diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 2dabbeaaca..f223476d02 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -203,7 +203,7 @@ class ModelProviderModelCredentialApi(Resource): args = parser.parse_args() model_provider_service = ModelProviderService() - credentials = model_provider_service.get_model_credential( + current_credential = model_provider_service.get_model_credential( tenant_id=tenant_id, provider=provider, model_type=args["model_type"], @@ -228,7 +228,13 @@ class ModelProviderModelCredentialApi(Resource): return jsonable_encoder( { - "credentials": credentials, + "credentials": current_credential.get("credentials") if current_credential else {}, + "current_credential_id": current_credential.get("current_credential_id") + if current_credential + else None, + "current_credential_name": current_credential.get("current_credential_name") + if current_credential + else None, "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, "available_credentials": available_credentials, } diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d35ce9bf89..26659e1056 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -656,13 +656,13 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) - credential = session.execute(stmt).scalar_one_or_none() + credential_record = session.execute(stmt).scalar_one_or_none() - if not credential or not credential.encrypted_config: + if not credential_record or not credential_record.encrypted_config: raise ValueError(f"Credential with id {credential_id} not found.") try: - credentials = json.loads(credential.encrypted_config) + credentials = json.loads(credential_record.encrypted_config) except JSONDecodeError: credentials = {} @@ -674,13 +674,21 @@ class ProviderConfiguration(BaseModel): except Exception: pass - return self.obfuscated_credentials( + current_credential_id = credential_record.id + current_credential_name = credential_record.credential_name + credentials = self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [], ) + return { + "current_credential_id": current_credential_id, + "current_credential_name": current_credential_name, + "credentials": credentials, + } + def _check_custom_model_credential_name_exists( self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None ) -> bool: @@ -715,15 +723,24 @@ class ProviderConfiguration(BaseModel): ) for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: - credentials = model_configuration.credentials - - return self.obfuscated_credentials( - credentials=credentials, + if ( + model_configuration.model_type == model_type + and model_configuration.model == model + and model_configuration.credentials + ): + current_credential_id = model_configuration.current_credential_id + current_credential_name = model_configuration.current_credential_name + credentials = self.obfuscated_credentials( + credentials=model_configuration.credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [], ) + return { + "current_credential_id": current_credential_id, + "current_credential_name": current_credential_name, + "credentials": credentials, + } return None def validate_custom_model_credentials( diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 63115b720a..98ba625c94 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -109,6 +109,7 @@ class CustomModelConfiguration(BaseModel): model_type: ModelType credentials: dict | None current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None available_model_credentials: list[CredentialConfiguration] = [] # pydantic configs diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 8502cbd92b..a99b4777f0 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -758,6 +758,7 @@ class ProviderManager: model_type=ModelType.value_of(provider_model_record.model_type), credentials=provider_model_credentials, current_credential_id=provider_model_record.credential_id, + current_credential_name=provider_model_record.credential_name, available_model_credentials=available_model_credentials, ) )