diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index ff0fcbda6e..e56f365d51 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -10,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import StrLen, uuid_value from libs.login import login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -45,12 +46,109 @@ class ModelProviderCredentialApi(Resource): @account_initialization_required def get(self, provider: str): tenant_id = current_user.current_tenant_id + # if credential_id is not provided, return current used credential + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") + args = parser.parse_args() model_provider_service = ModelProviderService() - credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) + credentials = model_provider_service.get_provider_credential( + tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id") + ) return {"credentials": credentials} + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.create_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credentials=args["credentials"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def put(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.update_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credentials=args["credentials"], + credential_id=args["credential_id"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"} + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.remove_provider_credential( + tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] + ) + + return {"result": "success"}, 204 + + +class ModelProviderCredentialSwitchApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + service = ModelProviderService() + service.switch_active_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credential_id=args["credential_id"], + ) + return {"result": "success"} + class ModelProviderValidateApi(Resource): @setup_required @@ -69,7 +167,7 @@ class ModelProviderValidateApi(Resource): error = "" try: - model_provider_service.provider_credentials_validate( + model_provider_service.validate_provider_credentials( tenant_id=tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: @@ -84,42 +182,6 @@ class ModelProviderValidateApi(Resource): return response -class ModelProviderApi(Resource): - @setup_required - @login_required - @account_initialization_required - def post(self, provider: str): - if not current_user.is_admin_or_owner: - raise Forbidden() - - parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - args = parser.parse_args() - - model_provider_service = ModelProviderService() - - try: - model_provider_service.save_provider_credentials( - tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] - ) - except CredentialsValidateFailedError as ex: - raise ValueError(str(ex)) - - return {"result": "success"}, 201 - - @setup_required - @login_required - @account_initialization_required - def delete(self, provider: str): - if not current_user.is_admin_or_owner: - raise Forbidden() - - model_provider_service = ModelProviderService() - model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - - return {"result": "success"}, 204 - - class ModelProviderIconApi(Resource): """ Get model provider icon @@ -187,8 +249,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource( + ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers//credentials/switch" +) api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") -api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") api.add_resource( PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 514d1084c4..2dabbeaaca 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -9,6 +9,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import StrLen, uuid_value from libs.login import login_required from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -98,6 +99,7 @@ class ModelProviderModelApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + # To save the model's load balance configs if not current_user.is_admin_or_owner: raise Forbidden() @@ -113,7 +115,6 @@ class ModelProviderModelApi(Resource): choices=[mt.value for mt in ModelType], location="json", ) - parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") args = parser.parse_args() @@ -136,6 +137,7 @@ class ModelProviderModelApi(Resource): model=args["model"], model_type=args["model_type"], configs=args["load_balancing"]["configs"], + config_from=args.get("config_from", ""), ) # enable load balancing @@ -148,26 +150,6 @@ class ModelProviderModelApi(Resource): tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - if args.get("config_from", "") != "predefined-model": - model_provider_service = ModelProviderService() - - try: - model_provider_service.save_model_credentials( - tenant_id=tenant_id, - provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], - ) - except CredentialsValidateFailedError as ex: - logging.exception( - "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", - tenant_id, - args.get("model"), - args.get("model_type"), - ) - raise ValueError(str(ex)) - return {"result": "success"}, 200 @setup_required @@ -192,7 +174,7 @@ class ModelProviderModelApi(Resource): args = parser.parse_args() model_provider_service = ModelProviderService() - model_provider_service.remove_model_credentials( + model_provider_service.remove_model( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) @@ -216,11 +198,17 @@ class ModelProviderModelCredentialApi(Resource): choices=[mt.value for mt in ModelType], location="args", ) + parser.add_argument("config_from", type=str, required=False, nullable=True, location="args") + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() - credentials = model_provider_service.get_model_credentials( - tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] + credentials = model_provider_service.get_model_credential( + tenant_id=tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args.get("credential_id"), ) model_load_balancing_service = ModelLoadBalancingService() @@ -228,10 +216,167 @@ class ModelProviderModelCredentialApi(Resource): tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return { - "credentials": credentials, - "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, - } + if args.get("config_from", "") == "predefined-model": + available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( + tenant_id=tenant_id, provider_name=provider + ) + else: + model_type = ModelType.value_of(args["model_type"]).to_origin_model_type() + available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( + tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"] + ) + + return jsonable_encoder( + { + "credentials": credentials, + "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, + "available_credentials": available_credentials, + } + ) + + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + tenant_id = current_user.current_tenant_id + model_provider_service = ModelProviderService() + + try: + model_provider_service.create_model_credential( + tenant_id=tenant_id, + provider=provider, + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + logging.exception( + "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", + tenant_id, + args.get("model"), + args.get("model_type"), + ) + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def put(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.update_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credentials=args["credentials"], + credential_id=args["credential_id"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"} + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.remove_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + + return {"result": "success"}, 204 + + +class ModelProviderModelCredentialSwitchApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + service = ModelProviderService() + service.switch_active_provider_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + return {"result": "success"} class ModelProviderModelEnableApi(Resource): @@ -314,7 +459,7 @@ class ModelProviderModelValidateApi(Resource): error = "" try: - model_provider_service.model_credentials_validate( + model_provider_service.validate_model_credentials( tenant_id=tenant_id, provider=provider, model=args["model"], @@ -379,6 +524,10 @@ api.add_resource( api.add_resource( ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" ) +api.add_resource( + ModelProviderModelCredentialSwitchApi, + "/workspaces/current/model-providers//models/credentials/switch", +) api.add_resource( ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" ) diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index e1c021a44a..8290fc217e 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -54,6 +54,7 @@ class ProviderModelWithStatusEntity(ProviderModel): status: ModelStatus load_balancing_enabled: bool = False + has_invalid_load_balancing_configs: bool = False def raise_for_status(self) -> None: """ diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 8bfbd82e1f..74fc69b955 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -7,6 +7,8 @@ from json import JSONDecodeError from typing import Optional from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import func, select +from sqlalchemy.orm import Session from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity @@ -32,7 +34,9 @@ from extensions.ext_database import db from models.provider import ( LoadBalancingModelConfig, Provider, + ProviderCredential, ProviderModel, + ProviderModelCredential, ProviderModelSetting, ProviderType, TenantPreferredModelProvider, @@ -45,7 +49,16 @@ original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): """ - Model class for provider configuration. + Provider configuration entity for managing model provider settings. + + This class handles: + - Provider credentials CRUD and switch + - Custom Model credentials CRUD and switch + - System vs custom provider switching + - Load balancing configurations + - Model enablement/disablement + + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ tenant_id: str @@ -155,33 +168,17 @@ class ProviderConfiguration(BaseModel): Check custom configuration available. :return: """ - return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 - - def get_custom_credentials(self, obfuscated: bool = False) -> dict | None: - """ - Get custom credentials. - - :param obfuscated: obfuscated secret data in credentials - :return: - """ - if self.custom_configuration.provider is None: - return None - - credentials = self.custom_configuration.provider.credentials - if not obfuscated: - return credentials - - # Obfuscate credentials - return self.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [], + has_provider_credentials = ( + self.custom_configuration.provider is not None + and len(self.custom_configuration.provider.available_credentials) > 0 ) - def _get_custom_provider_credentials(self) -> Provider | None: + has_model_configurations = len(self.custom_configuration.models) > 0 + return has_provider_credentials or has_model_configurations + + def _get_provider_record(self, session: Session) -> Provider | None: """ - Get custom provider credentials. + Get custom provider record. """ # get provider model_provider_id = ModelProviderID(self.provider.provider) @@ -189,156 +186,430 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - provider_record = ( - db.session.query(Provider) - .where( - Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, - Provider.provider_name.in_(provider_names), - ) - .first() + stmt = select(Provider).where( + Provider.tenant_id == self.tenant_id, + Provider.provider_type == ProviderType.CUSTOM.value, + Provider.provider_name.in_(provider_names), ) - return provider_record + return session.execute(stmt).scalar_one_or_none() - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: + def _get_specific_provider_credential(self, credential_id: str) -> dict | None: """ - Validate custom credentials. - :param credentials: provider credentials + Get a specific provider credential by ID. + :param credential_id: Credential ID :return: """ - provider_record = self._get_custom_provider_credentials() - - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( + # Extract secret variables from provider credential schema + credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [] ) - if provider_record: - try: - # fix origin data - if provider_record.encrypted_config: - if not provider_record.encrypted_config.startswith("{"): - original_credentials = {"openai_api_key": provider_record.encrypted_config} - else: - original_credentials = json.loads(provider_record.encrypted_config) - else: - original_credentials = {} - except JSONDecodeError: - original_credentials = {} + with Session(db.engine) as session: + # Prefer the actual provider record name if exists (to handle aliased provider names) + provider_record = self._get_provider_record(session) + provider_name = provider_record.provider_name if provider_record else self.provider.provider - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - - model_provider_factory = ModelProviderFactory(self.tenant_id) - credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(self.tenant_id, value) - - return provider_record, credentials - - def add_or_update_custom_credentials(self, credentials: dict) -> None: - """ - Add or update custom provider credentials. - :param credentials: - :return: - """ - # validate custom provider config - provider_record, credentials = self.custom_credentials_validate(credentials) - - # save provider - # Note: Do not switch the preferred provider, which allows users to use quotas first - if provider_record: - provider_record.encrypted_config = json.dumps(credentials) - provider_record.is_valid = True - provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() - else: - provider_record = Provider() - provider_record.tenant_id = self.tenant_id - provider_record.provider_name = self.provider.provider - provider_record.provider_type = ProviderType.CUSTOM.value - provider_record.encrypted_config = json.dumps(credentials) - provider_record.is_valid = True - - db.session.add(provider_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER - ) - - provider_model_credentials_cache.delete() - - self.switch_preferred_provider_type(ProviderType.CUSTOM) - - def delete_custom_credentials(self) -> None: - """ - Delete custom provider credentials. - :return: - """ - # get provider - provider_record = self._get_custom_provider_credentials() - - # delete provider - if provider_record: - self.switch_preferred_provider_type(ProviderType.SYSTEM) - - db.session.delete(provider_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER, + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == provider_name, ) - provider_model_credentials_cache.delete() + credential = session.execute(stmt).scalar_one_or_none() - def get_custom_model_credentials( - self, model_type: ModelType, model: str, obfuscated: bool = False - ) -> Optional[dict]: + if not credential or not credential.encrypted_config: + raise ValueError(f"Credential with id {credential_id} not found.") + + try: + credentials = json.loads(credential.encrypted_config) + except JSONDecodeError: + credentials = {} + + # Decrypt secret variables + for key in credential_secret_variables: + if key in credentials and credentials[key] is not None: + try: + credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) + except Exception: + pass + + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [], + ) + + def _check_provider_credential_name_exists( + self, credential_name: str, session: Session, exclude_id: str | None = None + ) -> bool: """ - Get custom model credentials. + not allowed same name when create or update a credential + """ + stmt = select(ProviderCredential.id).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.credential_name == credential_name, + ) + if exclude_id: + stmt = stmt.where(ProviderCredential.id != exclude_id) + return session.execute(stmt).scalar_one_or_none() is not None - :param model_type: model type - :param model: model name - :param obfuscated: obfuscated secret data in credentials + def get_provider_credential(self, credential_id: str | None = None) -> dict | None: + """ + Get provider credentials. + + :param credential_id: if provided, return the specified credential :return: """ - if not self.custom_configuration.models: - return None - for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: - credentials = model_configuration.credentials - if not obfuscated: - return credentials + if credential_id: + return self._get_specific_provider_credential(credential_id) - # Obfuscate credentials - return self.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [], + # Default behavior: return current active provider credentials + credentials = self.custom_configuration.provider.credentials + + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [], + ) + + def validate_provider_credentials( + self, credentials: dict, credential_id: str = "", session: Session | None = None + ) -> dict: + """ + Validate custom credentials. + :param credentials: provider credentials + :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate + :param session: optional database session + :return: + """ + + def _validate(s: Session) -> tuple[Provider | None, dict]: + # Get provider credential secret variables + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) + + if credential_id: + try: + stmt = select(ProviderCredential).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.id == credential_id, + ) + credential_record = s.execute(stmt).scalar_one_or_none() + # fix origin data + if credential_record and credential_record.encrypted_config: + if not credential_record.encrypted_config.startswith("{"): + original_credentials = {"openai_api_key": credential_record.encrypted_config} + else: + original_credentials = json.loads(credential_record.encrypted_config) + else: + original_credentials = {} + except JSONDecodeError: + original_credentials = {} + + # encrypt credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) + + model_provider_factory = ModelProviderFactory(self.tenant_id) + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) + + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables: + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials + + if session: + return _validate(session) + else: + with Session(db.engine) as new_session: + return _validate(new_session) + + def create_provider_credential(self, credentials: dict, credential_name: str) -> None: + """ + Add custom provider credentials. + :param credentials: provider credentials + :param credential_name: credential name + :return: + """ + with Session(db.engine) as session: + if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + raise ValueError(f"Credential with name '{credential_name}' already exists.") + + credentials = self.validate_provider_credentials(credentials=credentials, session=session) + provider_record = self._get_provider_record(session) + try: + new_record = ProviderCredential( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + encrypted_config=json.dumps(credentials), + credential_name=credential_name, ) + session.add(new_record) + session.flush() - return None + if not provider_record: + # If provider record does not exist, create it + provider_record = Provider( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + provider_type=ProviderType.CUSTOM.value, + is_valid=True, + credential_id=new_record.id, + ) + session.add(provider_record) - def _get_custom_model_credentials( + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) + + session.commit() + except Exception: + session.rollback() + raise + + def update_provider_credential( + self, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: + """ + update a saved provider credential (by credential_id). + + :param credentials: provider credentials + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + with Session(db.engine) as session: + if self._check_provider_credential_name_exists( + credential_name=credential_name, session=session, exclude_id=credential_id + ): + raise ValueError(f"Credential with name '{credential_name}' already exists.") + + credentials = self.validate_provider_credentials( + credentials=credentials, credential_id=credential_id, session=session + ) + provider_record = self._get_provider_record(session) + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + + # Get the credential record to update + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + try: + # Update credential + credential_record.encrypted_config = json.dumps(credentials) + credential_record.credential_name = credential_name + credential_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + session.commit() + + if provider_record and provider_record.credential_id == credential_id: + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self._update_load_balancing_configs_with_credential( + credential_id=credential_id, + credential_record=credential_record, + credential_source="provider", + session=session, + ) + except Exception: + session.rollback() + raise + + def _update_load_balancing_configs_with_credential( + self, credential_id: str, credential_record: dict, credential_source: str, session: Session + ) -> None: + """ + Update load balancing configurations that reference the given credential_id. + + :param credential_id: credential id + :param credential_record: the encrypted_config and credential_name + :param credential_source: the credential comes from the provider_credential(`provider`) + or the provider_model_credential(`custom_model`) + :param session: the database session + :return: + """ + # Find all load balancing configs that use this credential_id + stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == credential_source, + ) + load_balancing_configs = session.execute(stmt).scalars().all() + + if not load_balancing_configs: + return + + # Update each load balancing config with the new credentials + for lb_config in load_balancing_configs: + # Update the encrypted_config with the new credentials + lb_config.encrypted_config = credential_record.encrypted_config + lb_config.name = credential_record.credential_name + lb_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + # Clear cache for this load balancing config + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + + session.commit() + + def delete_provider_credential(self, credential_id: str) -> None: + """ + Delete a saved provider credential (by credential_id). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + + # Get the credential record to update + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + # Check if this credential is used in load balancing configs + stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == "provider", + ) + lb_configs_using_credential = session.execute(stmt).scalars().all() + try: + for lb_config in lb_configs_using_credential: + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + + lb_config.credential_id = None + lb_config.encrypted_config = None + lb_config.enabled = False + lb_config.name = "__delete__" + lb_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + session.add(lb_config) + + # Check if this is the currently active credential + provider_record = self._get_provider_record(session) + + # Check available credentials count BEFORE deleting + # if this is the last credential, we need to delete the provider record + stmt = select(func.count(ProviderCredential.id)).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + available_credentials_count = session.execute(stmt).scalar() + session.delete(credential_record) + + if provider_record and available_credentials_count <= 1: + # If all credentials are deleted, delete the provider record, switch to system provider type + session.delete(provider_record) + elif provider_record and provider_record.credential_id == credential_id: + provider_record.credential_id = None + provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session) + + session.commit() + except Exception: + session.rollback() + raise + + def switch_active_provider_credential(self, credential_id: str) -> None: + """ + Switch active provider credential (copy the selected one into current active snapshot). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + provider_record = self._get_provider_record(session) + if not provider_record: + raise ValueError("Provider record not found.") + + try: + provider_record.credential_id = credential_record.id + provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session) + except Exception: + session.rollback() + raise + + def _get_custom_model_record( self, model_type: ModelType, model: str, + session: Session, ) -> ProviderModel | None: """ Get custom model credentials. @@ -349,128 +620,449 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - provider_model_record = ( - db.session.query(ProviderModel) - .where( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name.in_(provider_names), - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type(), - ) - .first() + stmt = select(ProviderModel).where( + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name.in_(provider_names), + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), ) - return provider_model_record + return session.execute(stmt).scalar_one_or_none() - def custom_model_credentials_validate( - self, model_type: ModelType, model: str, credentials: dict + def _get_specific_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str + ) -> dict | None: + """ + Get a specific provider credential by ID. + :param credential_id: Credential ID + :return: + """ + model_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) + + with Session(db.engine) as session: + model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == model_record.provider_name, + ProviderModelCredential.model_name == model_record.model_name, + ProviderModelCredential.model_type == model_record.model_type, + ) + + credential = session.execute(stmt).scalar_one_or_none() + + if not credential or not credential.encrypted_config: + raise ValueError(f"Credential with id {credential_id} not found.") + + try: + credentials = json.loads(credential.encrypted_config) + except JSONDecodeError: + credentials = {} + + # Decrypt secret variables + for key in model_credential_secret_variables: + if key in credentials and credentials[key] is not None: + try: + credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) + except Exception: + pass + + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [], + ) + + def _check_custom_model_credential_name_exists( + self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None + ) -> bool: + """ + not allowed same name when create or update a credential + """ + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.credential_name == credential_name, + ) + if exclude_id: + stmt = stmt.where(ProviderModelCredential.id != exclude_id) + return session.execute(stmt).scalar_one_or_none() is not None + + def get_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str | None + ) -> Optional[dict]: + """ + Get custom model credentials. + + :param model_type: model type + :param model: model name + :return: + """ + # If credential_id is provided, return the specific credential + if credential_id: + return self._get_specific_custom_model_credential( + model_type=model_type, model=model, credential_id=credential_id + ) + + for model_configuration in self.custom_configuration.models: + if model_configuration.model_type == model_type and model_configuration.model == model: + credentials = model_configuration.credentials + + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [], + ) + + def validate_custom_model_credentials( + self, + model_type: ModelType, + model: str, + credentials: dict, + credential_id: str = "", + session: Session | None = None, ) -> tuple[ProviderModel | None, dict]: """ Validate custom model credentials. :param model_type: model type :param model: model name - :param credentials: model credentials + :param credentials: model credentials dict + :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :return: """ - # get provider model - provider_model_record = self._get_custom_model_credentials(model_type, model) - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [] - ) - - if provider_model_record: - try: - original_credentials = ( - json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} - ) - except JSONDecodeError: - original_credentials = {} - - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - - model_provider_factory = ModelProviderFactory(self.tenant_id) - credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(self.tenant_id, value) - - return provider_model_record, credentials - - def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None: - """ - Add or update custom model credentials. - - :param model_type: model type - :param model: model name - :param credentials: model credentials - :return: - """ - # validate custom model config - provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials) - - # save provider model - # Note: Do not switch the preferred provider, which allows users to use quotas first - if provider_model_record: - provider_model_record.encrypted_config = json.dumps(credentials) - provider_model_record.is_valid = True - provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() - else: - provider_model_record = ProviderModel() - provider_model_record.tenant_id = self.tenant_id - provider_model_record.provider_name = self.provider.provider - provider_model_record.model_name = model - provider_model_record.model_type = model_type.to_origin_model_type() - provider_model_record.encrypted_config = json.dumps(credentials) - provider_model_record.is_valid = True - db.session.add(provider_model_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL, - ) - - provider_model_credentials_cache.delete() - - def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None: - """ - Delete custom model credentials. - :param model_type: model type - :param model: model name - :return: - """ - # get provider model - provider_model_record = self._get_custom_model_credentials(model_type, model) - - # delete provider model - if provider_model_record: - db.session.delete(provider_model_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL, + def _validate(s: Session) -> tuple[ProviderModel | None, dict]: + # Get provider credential secret variables + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] ) - provider_model_credentials_cache.delete() + if credential_id: + try: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = s.execute(stmt).scalar_one_or_none() + original_credentials = ( + json.loads(credential_record.encrypted_config) + if credential_record and credential_record.encrypted_config + else {} + ) + except JSONDecodeError: + original_credentials = {} - def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: + # decrypt credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) + + model_provider_factory = ModelProviderFactory(self.tenant_id) + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) + + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables: + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials + + if session: + return _validate(session) + else: + with Session(db.engine) as new_session: + return _validate(new_session) + + def create_custom_model_credential( + self, model_type: ModelType, model: str, credentials: dict, credential_name: str + ) -> None: + """ + Create a custom model credential. + + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :return: + """ + with Session(db.engine) as session: + if self._check_custom_model_credential_name_exists( + model=model, model_type=model_type, credential_name=credential_name, session=session + ): + raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") + # validate custom model config + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials, session=session + ) + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + try: + credential = ProviderModelCredential( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + encrypted_config=json.dumps(credentials), + credential_name=credential_name, + ) + session.add(credential) + session.flush() + + # save provider model + if not provider_model_record: + provider_model_record = ProviderModel( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + credential_id=credential.id, + is_valid=True, + ) + session.add(provider_model_record) + + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + provider_model_credentials_cache.delete() + except Exception: + session.rollback() + raise + + def update_custom_model_credential( + self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str + ) -> None: + """ + Update a custom model credential. + + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :param credential_name: credential name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + if self._check_custom_model_credential_name_exists( + model=model, + model_type=model_type, + credential_name=credential_name, + session=session, + exclude_id=credential_id, + ): + raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") + # validate custom model config + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + session=session, + ) + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + try: + # Update credential + credential_record.encrypted_config = json.dumps(credentials) + credential_record.credential_name = credential_name + credential_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + session.commit() + + if provider_model_record and provider_model_record.credential_id == credential_id: + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + provider_model_credentials_cache.delete() + + self._update_load_balancing_configs_with_credential( + credential_id=credential_id, + credential_record=credential_record, + credential_source="custom_model", + session=session, + ) + except Exception: + session.rollback() + raise + + def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + Delete a saved provider credential (by credential_id). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == "custom_model", + ) + lb_configs_using_credential = session.execute(stmt).scalars().all() + + try: + for lb_config in lb_configs_using_credential: + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + lb_config.credential_id = None + lb_config.encrypted_config = None + lb_config.enabled = False + lb_config.name = "__delete__" + lb_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.add(lb_config) + + # Check if this is the currently active credential + provider_model_record = self._get_custom_model_record(model_type, model, session=session) + + # Check available credentials count BEFORE deleting + # if this is the last credential, we need to delete the custom model record + stmt = select(func.count(ProviderModelCredential.id)).where( + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + available_credentials_count = session.execute(stmt).scalar() + session.delete(credential_record) + + if provider_model_record and available_credentials_count <= 1: + # If all credentials are deleted, delete the custom model record + session.delete(provider_model_record) + elif provider_model_record and provider_model_record.credential_id == credential_id: + provider_model_record.credential_id = None + provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + session.commit() + + except Exception: + session.rollback() + raise + + def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + Not only switch the custom model credential. + It can also add credential to a new custom model record. + + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + # validate custom model config + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + if not provider_model_record: + # create provider model record + provider_model_record = ProviderModel( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + credential_id=credential_id, + ) + else: + if provider_model_record.credential_id == credential_record.id: + raise ValueError("Can't add same credential") + provider_model_record.credential_id = credential_record.id + provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + session.add(provider_model_record) + session.commit() + + def delete_custom_model(self, model_type: ModelType, model: str) -> None: + """ + Delete custom model. + :param model_type: model type + :param model: model name + :return: + """ + with Session(db.engine) as session: + # get provider model + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + # delete provider model + if provider_model_record: + session.delete(provider_model_record) + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + + provider_model_credentials_cache.delete() + + def _get_provider_model_setting( + self, model_type: ModelType, model: str, session: Session + ) -> ProviderModelSetting | None: """ Get provider model setting. """ @@ -479,16 +1071,13 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - return ( - db.session.query(ProviderModelSetting) - .where( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name.in_(provider_names), - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() + stmt = select(ProviderModelSetting).where( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name.in_(provider_names), + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, ) + return session.execute(stmt).scalars().first() def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -497,21 +1086,23 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = self._get_provider_model_setting(model_type, model) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.enabled = True - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.enabled = True - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.enabled = True + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=True, + ) + session.add(model_setting) + session.commit() return model_setting @@ -522,21 +1113,22 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = self._get_provider_model_setting(model_type, model) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.enabled = False - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.enabled = False - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.enabled = False + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=False, + ) + session.add(model_setting) + session.commit() return model_setting @@ -547,27 +1139,8 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - return self._get_provider_model_setting(model_type, model) - - def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]: - """ - Get load balancing config. - """ - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - - return ( - db.session.query(LoadBalancingModelConfig) - .where( - LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name.in_(provider_names), - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - ) - .first() - ) + with Session(db.engine) as session: + return self._get_provider_model_setting(model_type=model_type, model=model, session=session) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -581,35 +1154,32 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - load_balancing_config_count = ( - db.session.query(LoadBalancingModelConfig) - .where( + with Session(db.engine) as session: + stmt = select(func.count(LoadBalancingModelConfig.id)).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) - .count() - ) + load_balancing_config_count = session.execute(stmt).scalar() + if load_balancing_config_count <= 1: + raise ValueError("Model load balancing configuration must be more than 1.") - if load_balancing_config_count <= 1: - raise ValueError("Model load balancing configuration must be more than 1.") + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - model_setting = self._get_provider_model_setting(model_type, model) - - if model_setting: - model_setting.load_balancing_enabled = True - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.load_balancing_enabled = True - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.load_balancing_enabled = True + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=True, + ) + session.add(model_setting) + session.commit() return model_setting @@ -620,35 +1190,23 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - model_setting = ( - db.session.query(ProviderModelSetting) - .where( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name.in_(provider_names), - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() - ) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.load_balancing_enabled = False - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.load_balancing_enabled = False - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.load_balancing_enabled = False + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=False, + ) + session.add(model_setting) + session.commit() return model_setting @@ -673,7 +1231,7 @@ class ProviderConfiguration(BaseModel): provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) - def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: + def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: """ Switch preferred provider type. :param provider_type: @@ -685,31 +1243,35 @@ class ProviderConfiguration(BaseModel): if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: return - # get preferred provider - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) + def _switch(s: Session) -> None: + # get preferred provider + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) - preferred_model_provider = ( - db.session.query(TenantPreferredModelProvider) - .where( + stmt = select(TenantPreferredModelProvider).where( TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.provider_name.in_(provider_names), ) - .first() - ) + preferred_model_provider = s.execute(stmt).scalars().first() - if preferred_model_provider: - preferred_model_provider.preferred_provider_type = provider_type.value + if preferred_model_provider: + preferred_model_provider.preferred_provider_type = provider_type.value + else: + preferred_model_provider = TenantPreferredModelProvider( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + preferred_provider_type=provider_type.value, + ) + session.add(preferred_model_provider) + session.commit() + + if session: + return _switch(session) else: - preferred_model_provider = TenantPreferredModelProvider() - preferred_model_provider.tenant_id = self.tenant_id - preferred_model_provider.provider_name = self.provider.provider - preferred_model_provider.preferred_provider_type = provider_type.value - db.session.add(preferred_model_provider) - - db.session.commit() + with Session(db.engine) as session: + return _switch(session) def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ @@ -973,6 +1535,7 @@ class ProviderConfiguration(BaseModel): status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE load_balancing_enabled = False + has_invalid_load_balancing_configs = False if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: @@ -981,6 +1544,9 @@ class ProviderConfiguration(BaseModel): if len(model_setting.load_balancing_configs) > 1: load_balancing_enabled = True + if model_setting.has_invalid_load_balancing_configs: + has_invalid_load_balancing_configs = True + provider_models.append( ModelWithProviderEntity( model=m.model, @@ -993,6 +1559,7 @@ class ProviderConfiguration(BaseModel): provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, + has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index a5a6e62bd7..00b9039610 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel): restrict_models: list[RestrictModel] = [] +class CredentialConfiguration(BaseModel): + """ + Model class for credential configuration. + """ + + credential_id: str + credential_name: str + + class SystemConfiguration(BaseModel): """ Model class for provider system configuration. @@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel): """ credentials: dict + current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None + available_credentials: list[CredentialConfiguration] = [] class CustomModelConfiguration(BaseModel): @@ -96,6 +108,8 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType credentials: dict + current_credential_id: Optional[str] = None + available_model_credentials: list[CredentialConfiguration] = [] # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -129,6 +143,7 @@ class ModelSettings(BaseModel): model_type: ModelType enabled: bool = True load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] + has_invalid_load_balancing_configs: bool = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 9250497d29..c3de77d1d6 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -11,6 +11,7 @@ from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( + CredentialConfiguration, CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, @@ -39,7 +40,9 @@ from extensions.ext_redis import redis_client from models.provider import ( LoadBalancingModelConfig, Provider, + ProviderCredential, ProviderModel, + ProviderModelCredential, ProviderModelSetting, ProviderType, TenantDefaultModel, @@ -487,6 +490,61 @@ class ProviderManager: return provider_name_to_provider_load_balancing_model_configs_dict + @staticmethod + def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: + """ + Get provider all credentials. + + :param tenant_id: workspace id + :param provider_name: provider name + :return: + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(ProviderCredential) + .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) + .order_by(ProviderCredential.created_at.desc()) + ) + + available_credentials = session.scalars(stmt).all() + + return [ + CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) + for credential in available_credentials + ] + + @staticmethod + def get_provider_model_available_credentials( + tenant_id: str, provider_name: str, model_name: str, model_type: str + ) -> list[CredentialConfiguration]: + """ + Get provider custom model all credentials. + + :param tenant_id: workspace id + :param provider_name: provider name + :param model_name: model name + :param model_type: model type + :return: + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(ProviderModelCredential) + .where( + ProviderModelCredential.tenant_id == tenant_id, + ProviderModelCredential.provider_name == provider_name, + ProviderModelCredential.model_name == model_name, + ProviderModelCredential.model_type == model_type, + ) + .order_by(ProviderModelCredential.created_at.desc()) + ) + + available_credentials = session.scalars(stmt).all() + + return [ + CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) + for credential in available_credentials + ] + @staticmethod def _init_trial_provider_records( tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] @@ -589,9 +647,6 @@ class ProviderManager: if provider_record.provider_type == ProviderType.SYSTEM.value: continue - if not provider_record.encrypted_config: - continue - custom_provider_record = provider_record # Get custom provider credentials @@ -610,8 +665,8 @@ class ProviderManager: try: # fix origin data if custom_provider_record.encrypted_config is None: - raise ValueError("No credentials found") - if not custom_provider_record.encrypted_config.startswith("{"): + provider_credentials = {} + elif not custom_provider_record.encrypted_config.startswith("{"): provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) @@ -638,7 +693,14 @@ class ProviderManager: else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) + custom_provider_configuration = CustomProviderConfiguration( + credentials=provider_credentials, + current_credential_name=custom_provider_record.credential_name, + current_credential_id=custom_provider_record.credential_id, + available_credentials=self.get_provider_available_credentials( + tenant_id, custom_provider_record.provider_name + ), + ) # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( @@ -650,8 +712,12 @@ class ProviderManager: # Get custom provider model credentials custom_model_configurations = [] for provider_model_record in provider_model_records: - if not provider_model_record.encrypted_config: - continue + available_model_credentials = self.get_provider_model_available_credentials( + tenant_id, + provider_model_record.provider_name, + provider_model_record.model_name, + provider_model_record.model_type, + ) provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL @@ -691,6 +757,8 @@ class ProviderManager: model=provider_model_record.model_name, model_type=ModelType.value_of(provider_model_record.model_type), credentials=provider_model_credentials, + current_credential_id=provider_model_record.credential_id, + available_model_credentials=available_model_credentials, ) ) @@ -894,6 +962,7 @@ class ProviderManager: if not provider_model_settings: return model_settings + has_invalid_load_balancing_configs = False for provider_model_setting in provider_model_settings: load_balancing_configs = [] if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: @@ -902,6 +971,10 @@ class ProviderManager: load_balancing_model_config.model_name == provider_model_setting.model_name and load_balancing_model_config.model_type == provider_model_setting.model_type ): + if load_balancing_model_config.name == "__delete__": + has_invalid_load_balancing_configs = True + continue + if not load_balancing_model_config.enabled: continue @@ -967,6 +1040,7 @@ class ProviderManager: model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], + has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) diff --git a/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py b/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py new file mode 100644 index 0000000000..87b42346df --- /dev/null +++ b/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py @@ -0,0 +1,177 @@ +"""Add provider multi credential support + +Revision ID: e8446f481c1e +Revises: 8bcc02c9bd07 +Create Date: 2025-08-09 15:53:54.341341 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.sql import table, column +import uuid + +# revision identifiers, used by Alembic. +revision = 'e8446f481c1e' +down_revision = 'fa8b0fa6f407' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create provider_credentials table + op.create_table('provider_credentials', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('credential_name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_credential_pkey') + ) + + # Create index for provider_credentials + with op.batch_alter_table('provider_credentials', schema=None) as batch_op: + batch_op.create_index('provider_credential_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + # Add credential_id to providers table + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + # Add credential_id to load_balancing_model_configs table + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + migrate_existing_providers_data() + + # Remove encrypted_config column from providers table after migration + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_column('encrypted_config') + + +def migrate_existing_providers_data(): + """migrate providers table data to provider_credentials""" + + # Define table structure for data manipulation + providers_table = table('providers', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()), + column('credential_id', models.types.StringUUID()), + ) + + provider_credential_table = table('provider_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) + + # Get database connection + conn = op.get_bind() + + # Query all existing providers data + existing_providers = conn.execute( + sa.select(providers_table.c.id, providers_table.c.tenant_id, + providers_table.c.provider_name, providers_table.c.encrypted_config, + providers_table.c.created_at, providers_table.c.updated_at) + .where(providers_table.c.encrypted_config.isnot(None)) + ).fetchall() + + # Iterate through each provider and insert into provider_credentials + for provider in existing_providers: + credential_id = str(uuid.uuid4()) + if not provider.encrypted_config or provider.encrypted_config.strip() == '': + continue + + # Insert into provider_credentials table + conn.execute( + provider_credential_table.insert().values( + id=credential_id, + tenant_id=provider.tenant_id, + provider_name=provider.provider_name, + credential_name='API_KEY1', # Use a default name + encrypted_config=provider.encrypted_config, + created_at=provider.created_at, + updated_at=provider.updated_at + ) + ) + + # Update original providers table, set credential_id + conn.execute( + providers_table.update() + .where(providers_table.c.id == provider.id) + .values( + credential_id=credential_id, + ) + ) + +def downgrade(): + # Re-add encrypted_config column to providers table + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) + + # Migrate data back from provider_credentials to providers + migrate_data_back_to_providers() + + # Remove credential_id columns + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + # Drop provider_credentials table + op.drop_table('provider_credentials') + + +def migrate_data_back_to_providers(): + """Migrate data back from provider_credentials to providers table for downgrade""" + + # Define table structure for data manipulation + providers_table = table('providers', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('encrypted_config', sa.Text()), + column('credential_id', models.types.StringUUID()), + ) + + provider_credential_table = table('provider_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + ) + + # Get database connection + conn = op.get_bind() + + # Query providers that have credential_id + providers_with_credentials = conn.execute( + sa.select(providers_table.c.id, providers_table.c.credential_id) + .where(providers_table.c.credential_id.isnot(None)) + ).fetchall() + + # For each provider, get the credential data and update providers table + for provider in providers_with_credentials: + credential = conn.execute( + sa.select(provider_credential_table.c.encrypted_config) + .where(provider_credential_table.c.id == provider.credential_id) + ).fetchone() + + if credential: + # Update providers table with encrypted_config from credential + conn.execute( + providers_table.update() + .where(providers_table.c.id == provider.id) + .values(encrypted_config=credential.encrypted_config) + ) \ No newline at end of file diff --git a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py new file mode 100644 index 0000000000..bec1a45404 --- /dev/null +++ b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py @@ -0,0 +1,186 @@ +"""Add provider model multi credential support + +Revision ID: 0e154742a5fa +Revises: e8446f481c1e +Create Date: 2025-08-13 16:05:42.657730 + +""" +import uuid + +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.sql import table, column + + +# revision identifiers, used by Alembic. +revision = '0e154742a5fa' +down_revision = 'e8446f481c1e' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create provider_model_credentials table + op.create_table('provider_model_credentials', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('credential_name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_model_credential_pkey') + ) + + # Create index for provider_model_credentials + with op.batch_alter_table('provider_model_credentials', schema=None) as batch_op: + batch_op.create_index('provider_model_credential_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_name', 'model_type'], unique=False) + + # Add credential_id to provider_models table + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + + # Add credential_source_type to load_balancing_model_configs table + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True)) + + # Migrate existing provider_models data + migrate_existing_provider_models_data() + + # Remove encrypted_config column from provider_models table after migration + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.drop_column('encrypted_config') + + +def migrate_existing_provider_models_data(): + """migrate provider_models table data to provider_model_credentials""" + + # Define table structure for data manipulation + provider_models_table = table('provider_models', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()), + column('credential_id', models.types.StringUUID()), + ) + + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) + + + # Get database connection + conn = op.get_bind() + + # Query all existing provider_models data with encrypted_config + existing_provider_models = conn.execute( + sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id, + provider_models_table.c.provider_name, provider_models_table.c.model_name, + provider_models_table.c.model_type, provider_models_table.c.encrypted_config, + provider_models_table.c.created_at, provider_models_table.c.updated_at) + .where(provider_models_table.c.encrypted_config.isnot(None)) + ).fetchall() + + # Iterate through each provider_model and insert into provider_model_credentials + for provider_model in existing_provider_models: + if not provider_model.encrypted_config or provider_model.encrypted_config.strip() == '': + continue + + credential_id = str(uuid.uuid4()) + + # Insert into provider_model_credentials table + conn.execute( + provider_model_credentials_table.insert().values( + id=credential_id, + tenant_id=provider_model.tenant_id, + provider_name=provider_model.provider_name, + model_name=provider_model.model_name, + model_type=provider_model.model_type, + credential_name='API_KEY1', # Use a default name + encrypted_config=provider_model.encrypted_config, + created_at=provider_model.created_at, + updated_at=provider_model.updated_at + ) + ) + + # Update original provider_models table, set credential_id + conn.execute( + provider_models_table.update() + .where(provider_models_table.c.id == provider_model.id) + .values(credential_id=credential_id) + ) + + +def downgrade(): + # Re-add encrypted_config column to provider_models table + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) + + # Migrate data back from provider_model_credentials to provider_models + migrate_data_back_to_provider_models() + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + # Remove credential_source_type column from load_balancing_model_configs + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_column('credential_source_type') + + # Drop provider_model_credentials table + op.drop_table('provider_model_credentials') + + +def migrate_data_back_to_provider_models(): + """Migrate data back from provider_model_credentials to provider_models table for downgrade""" + + # Define table structure for data manipulation + provider_models_table = table('provider_models', + column('id', models.types.StringUUID()), + column('encrypted_config', sa.Text()), + column('credential_id', models.types.StringUUID()), + ) + + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('encrypted_config', sa.Text()), + ) + + # Get database connection + conn = op.get_bind() + + # Query provider_models that have credential_id + provider_models_with_credentials = conn.execute( + sa.select(provider_models_table.c.id, provider_models_table.c.credential_id) + .where(provider_models_table.c.credential_id.isnot(None)) + ).fetchall() + + # For each provider_model, get the credential data and update provider_models table + for provider_model in provider_models_with_credentials: + credential = conn.execute( + sa.select(provider_model_credentials_table.c.encrypted_config) + .where(provider_model_credentials_table.c.id == provider_model.credential_id) + ).fetchone() + + if credential: + # Update provider_models table with encrypted_config from credential + conn.execute( + provider_models_table.update() + .where(provider_models_table.c.id == provider_model.id) + .values(encrypted_config=credential.encrypted_config) + ) diff --git a/api/models/provider.py b/api/models/provider.py index 4ea2c59fdb..0a32f92730 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,5 +1,6 @@ from datetime import datetime from enum import Enum +from functools import cached_property from typing import Optional import sqlalchemy as sa @@ -7,6 +8,7 @@ from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base +from .engine import db from .types import StringUUID @@ -60,9 +62,9 @@ class Provider(Base): provider_type: Mapped[str] = mapped_column( String(40), nullable=False, server_default=text("'custom'::character varying") ) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( String(40), nullable=True, server_default=text("''::character varying") @@ -79,6 +81,21 @@ class Provider(Base): f" provider_type='{self.provider_type}')>" ) + @cached_property + def credential(self): + if self.credential_id: + return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first() + + @property + def credential_name(self): + credential = self.credential + return credential.credential_name if credential else None + + @property + def encrypted_config(self): + credential = self.credential + return credential.encrypted_config if credential else None + @property def token_is_set(self): """ @@ -116,11 +133,30 @@ class ProviderModel(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) - is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + @cached_property + def credential(self): + if self.credential_id: + return ( + db.session.query(ProviderModelCredential) + .where(ProviderModelCredential.id == self.credential_id) + .first() + ) + + @property + def credential_name(self): + credential = self.credential + return credential.credential_name if credential else None + + @property + def encrypted_config(self): + credential = self.credential + return credential.encrypted_config if credential else None + class TenantDefaultModel(Base): __tablename__ = "tenant_default_models" @@ -220,6 +256,56 @@ class LoadBalancingModelConfig(Base): model_type: Mapped[str] = mapped_column(String(40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ProviderCredential(Base): + """ + Provider credential - stores multiple named credentials for each provider + """ + + __tablename__ = "provider_credentials" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="provider_credential_pkey"), + sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + credential_name: Mapped[str] = mapped_column(String(255), nullable=False) + encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ProviderModelCredential(Base): + """ + Provider model credential - stores multiple named credentials for each provider model + """ + + __tablename__ = "provider_model_credentials" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="provider_model_credential_pkey"), + db.Index( + "provider_model_credential_tenant_provider_model_idx", + "tenant_id", + "provider_name", + "model_name", + "model_type", + ), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + credential_name: Mapped[str] = mapped_column(String(255), nullable=False) + encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index bc385b2e22..056decda26 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -8,7 +8,12 @@ from core.entities.model_entities import ( ModelWithProviderEntity, ProviderModelWithStatusEntity, ) -from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration +from core.entities.provider_entities import ( + CredentialConfiguration, + CustomModelConfiguration, + ProviderQuotaType, + QuotaConfiguration, +) from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ( @@ -36,6 +41,10 @@ class CustomConfigurationResponse(BaseModel): """ status: CustomConfigurationStatus + current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None + available_credentials: Optional[list[CredentialConfiguration]] = None + custom_models: Optional[list[CustomModelConfiguration]] = None class SystemConfigurationResponse(BaseModel): diff --git a/api/services/errors/app_model_config.py b/api/services/errors/app_model_config.py index c0669ed231..bb5eb62b75 100644 --- a/api/services/errors/app_model_config.py +++ b/api/services/errors/app_model_config.py @@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError class AppModelConfigBrokenError(BaseServiceError): pass + + +class ProviderNotFoundError(BaseServiceError): + pass diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index fe28aa006e..4b677203ba 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -17,7 +17,7 @@ from core.model_runtime.entities.provider_entities import ( from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager from extensions.ext_database import db -from models.provider import LoadBalancingModelConfig +from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential logger = logging.getLogger(__name__) @@ -185,6 +185,7 @@ class ModelLoadBalancingService: "id": load_balancing_config.id, "name": load_balancing_config.name, "credentials": credentials, + "credential_id": load_balancing_config.credential_id, "enabled": load_balancing_config.enabled, "in_cooldown": in_cooldown, "ttl": ttl, @@ -280,7 +281,7 @@ class ModelLoadBalancingService: return inherit_config def update_load_balancing_configs( - self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] + self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str ) -> None: """ Update load balancing configurations. @@ -289,6 +290,7 @@ class ModelLoadBalancingService: :param model: model name :param model_type: model type :param configs: load balancing configs + :param config_from: predefined-model or custom-model :return: """ # Get all provider configurations of the current workspace @@ -327,8 +329,36 @@ class ModelLoadBalancingService: config_id = config.get("id") name = config.get("name") credentials = config.get("credentials") + credential_id = config.get("credential_id") enabled = config.get("enabled") + if credential_id: + if config_from == "predefined-model": + credential_record = ( + db.session.query(ProviderCredential) + .filter_by( + id=credential_id, + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + ) + .first() + ) + else: + credential_record = ( + db.session.query(ProviderModelCredential) + .filter_by( + id=credential_id, + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_name=model, + model_type=model_type_enum.to_origin_model_type(), + ) + .first() + ) + if not credential_record: + raise ValueError(f"Provider credential with id {credential_id} not found") + name = credential_record.credential_name + if not name: raise ValueError("Invalid load balancing config name") @@ -346,11 +376,6 @@ class ModelLoadBalancingService: load_balancing_config = current_load_balancing_configs_dict[config_id] - # check duplicate name - for current_load_balancing_config in current_load_balancing_configs: - if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError(f"Load balancing config name {name} already exists") - if credentials: if not isinstance(credentials, dict): raise ValueError("Invalid load balancing config credentials") @@ -377,39 +402,47 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name == "__inherit__": + if name in {"__inherit__", "__delete__"}: raise ValueError("Invalid load balancing config name") - # check duplicate name - for current_load_balancing_config in current_load_balancing_configs: - if current_load_balancing_config.name == name: - raise ValueError(f"Load balancing config name {name} already exists") + if credential_id: + credential_source = "provider" if config_from == "predefined-model" else "custom_model" + load_balancing_model_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_type=model_type_enum.to_origin_model_type(), + model_name=model, + name=credential_record.credential_name, + encrypted_config=credential_record.encrypted_config, + credential_id=credential_id, + credential_source_type=credential_source, + ) + else: + if not credentials: + raise ValueError("Invalid load balancing config credentials") - if not credentials: - raise ValueError("Invalid load balancing config credentials") + if not isinstance(credentials, dict): + raise ValueError("Invalid load balancing config credentials") - if not isinstance(credentials, dict): - raise ValueError("Invalid load balancing config credentials") + # validate custom provider config + credentials = self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type_enum, + model=model, + credentials=credentials, + validate=False, + ) - # validate custom provider config - credentials = self._custom_credentials_validate( - tenant_id=tenant_id, - provider_configuration=provider_configuration, - model_type=model_type_enum, - model=model, - credentials=credentials, - validate=False, - ) - - # create load balancing config - load_balancing_model_config = LoadBalancingModelConfig( - tenant_id=tenant_id, - provider_name=provider_configuration.provider.provider, - model_type=model_type_enum.to_origin_model_type(), - model_name=model, - name=name, - encrypted_config=json.dumps(credentials), - ) + # create load balancing config + load_balancing_model_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_type=model_type_enum.to_origin_model_type(), + model_name=model, + name=name, + encrypted_config=json.dumps(credentials), + ) db.session.add(load_balancing_model_config) db.session.commit() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 54197bf949..e330d2852f 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -16,6 +16,7 @@ from services.entities.model_provider_entities import ( SimpleProviderEntityResponse, SystemConfigurationResponse, ) +from services.errors.app_model_config import ProviderNotFoundError logger = logging.getLogger(__name__) @@ -28,6 +29,29 @@ class ModelProviderService: def __init__(self) -> None: self.provider_manager = ProviderManager() + def _get_provider_configuration(self, tenant_id: str, provider: str): + """ + Get provider configuration or raise exception if not found. + + Args: + tenant_id: Workspace identifier + provider: Provider name + + Returns: + Provider configuration instance + + Raises: + ProviderNotFoundError: If provider doesn't exist + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = provider_configurations.get(provider) + + if not provider_configuration: + raise ProviderNotFoundError(f"Provider {provider} does not exist.") + + return provider_configuration + def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: """ get provider list. @@ -46,6 +70,9 @@ class ModelProviderService: if model_type_entity not in provider_configuration.provider.supported_model_types: continue + provider_config = provider_configuration.custom_configuration.provider + model_config = provider_configuration.custom_configuration.models + provider_response = ProviderResponse( tenant_id=tenant_id, provider=provider_configuration.provider.provider, @@ -63,7 +90,11 @@ class ModelProviderService: custom_configuration=CustomConfigurationResponse( status=CustomConfigurationStatus.ACTIVE if provider_configuration.is_custom_configuration_available() - else CustomConfigurationStatus.NO_CONFIGURE + else CustomConfigurationStatus.NO_CONFIGURE, + current_credential_id=getattr(provider_config, "current_credential_id", None), + current_credential_name=getattr(provider_config, "current_credential_name", None), + available_credentials=getattr(provider_config, "available_credentials", []), + custom_models=model_config, ), system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, @@ -82,8 +113,8 @@ class ModelProviderService: For the model provider page, only supports passing in a single provider to query the list of supported models. - :param tenant_id: - :param provider: + :param tenant_id: workspace id + :param provider: provider name :return: """ # Get all provider configurations of the current workspace @@ -95,98 +126,111 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]: + def get_provider_credential( + self, tenant_id: str, provider: str, credential_id: Optional[str] = None + ) -> Optional[dict]: """ get provider credentials. - """ - provider_configurations = self.provider_manager.get_configurations(tenant_id) - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - return provider_configuration.get_custom_credentials(obfuscated=True) - - def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: - """ - validate provider credentials. - - :param tenant_id: - :param provider: - :param credentials: - """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - provider_configuration.custom_credentials_validate(credentials) - - def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: - """ - save custom provider config. :param tenant_id: workspace id :param provider: provider name - :param credentials: provider credentials + :param credential_id: credential id, if not provided, return current used credentials :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + return provider_configuration.get_provider_credential(credential_id=credential_id) - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Add or update custom provider credentials. - provider_configuration.add_or_update_custom_credentials(credentials) - - def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: + def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: """ - remove custom provider config. + validate provider credentials before saving. :param tenant_id: workspace id :param provider: provider name + :param credentials: provider credentials dict + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.validate_provider_credentials(credentials) + + def create_provider_credential( + self, tenant_id: str, provider: str, credentials: dict, credential_name: str + ) -> None: + """ + Create and save new provider credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param credentials: provider credentials dict + :param credential_name: credential name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.create_provider_credential(credentials, credential_name) - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Remove custom provider credentials. - provider_configuration.delete_custom_credentials() - - def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]: + def update_provider_credential( + self, + tenant_id: str, + provider: str, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: """ - get model credentials. + update a saved provider credential (by credential_id). + + :param tenant_id: workspace id + :param provider: provider name + :param credentials: provider credentials dict + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.update_provider_credential( + credential_id=credential_id, + credentials=credentials, + credential_name=credential_name, + ) + + def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: + """ + remove a saved provider credential (by credential_id). + :param tenant_id: workspace id + :param provider: provider name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_provider_credential(credential_id=credential_id) + + def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: + """ + :param tenant_id: workspace id + :param provider: provider name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.switch_active_provider_credential(credential_id=credential_id) + + def get_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None + ) -> Optional[dict]: + """ + Retrieve model-specific credentials. :param tenant_id: workspace id :param provider: provider name :param model_type: model type :param model: model name + :param credential_id: Optional credential ID, uses current if not provided :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Get model custom credentials from ProviderModel if exists - return provider_configuration.get_custom_model_credentials( - model_type=ModelType.value_of(model_type), model=model, obfuscated=True + provider_configuration = self._get_provider_configuration(tenant_id, provider) + return provider_configuration.get_custom_model_credential( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def model_credentials_validate( + def validate_model_credentials( self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict ) -> None: """ @@ -196,49 +240,104 @@ class ModelProviderService: :param provider: provider name :param model_type: model type :param model: model name - :param credentials: model credentials + :param credentials: model credentials dict :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Validate model credentials - provider_configuration.custom_model_credentials_validate( + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.validate_custom_model_credentials( model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) - def save_model_credentials( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + def create_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str ) -> None: """ - save model credentials. + create and save model credentials. :param tenant_id: workspace id :param provider: provider name :param model_type: model type :param model: model name - :param credentials: model credentials + :param credentials: model credentials dict + :param credential_name: credential name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Add or update custom model credentials - provider_configuration.add_or_update_custom_model_credentials( - model_type=ModelType.value_of(model_type), model=model, credentials=credentials + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.create_custom_model_credential( + model_type=ModelType.value_of(model_type), + model=model, + credentials=credentials, + credential_name=credential_name, ) - def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: + def update_model_credential( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: + """ + update model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.update_custom_model_credential( + model_type=ModelType.value_of(model_type), + model=model, + credentials=credentials, + credential_id=credential_id, + credential_name=credential_name, + ) + + def remove_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + remove model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_custom_model_credential( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + + def switch_active_provider_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + switch model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.switch_custom_model_credential( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + + def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: """ remove model credentials. @@ -248,16 +347,8 @@ class ModelProviderService: :param model: model name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Remove custom model credentials - provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ @@ -331,13 +422,7 @@ class ModelProviderService: :param model: model name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") + provider_configuration = self._get_provider_configuration(tenant_id, provider) # fetch credentials credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) @@ -424,17 +509,11 @@ class ModelProviderService: :param preferred_provider_type: preferred provider type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) # Convert preferred_provider_type to ProviderType preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type) - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - # Switch preferred provider type provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) @@ -448,15 +527,7 @@ class ModelProviderService: :param model_type: model type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Enable model + provider_configuration = self._get_provider_configuration(tenant_id, provider) provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: @@ -469,13 +540,5 @@ class ModelProviderService: :param model_type: model type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Enable model + provider_configuration = self._get_provider_configuration(tenant_id, provider) provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py new file mode 100644 index 0000000000..75621ecb6a --- /dev/null +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -0,0 +1,308 @@ +from unittest.mock import Mock, patch + +import pytest + +from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus +from core.entities.provider_entities import ( + CustomConfiguration, + ModelSettings, + ProviderQuotaType, + QuotaConfiguration, + QuotaUnit, + RestrictModel, + SystemConfiguration, +) +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from models.provider import Provider, ProviderType + + +@pytest.fixture +def mock_provider_entity(): + """Mock provider entity with basic configuration""" + provider_entity = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"), + icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"), + icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"), + background="background.png", + help=None, + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + provider_credential_schema=None, + model_credential_schema=None, + ) + + return provider_entity + + +@pytest.fixture +def mock_system_configuration(): + """Mock system configuration""" + quota_config = QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=1000, + quota_used=0, + is_valid=True, + restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)], + ) + + system_config = SystemConfiguration( + enabled=True, + credentials={"openai_api_key": "test_key"}, + quota_configurations=[quota_config], + current_quota_type=ProviderQuotaType.TRIAL, + ) + + return system_config + + +@pytest.fixture +def mock_custom_configuration(): + """Mock custom configuration""" + custom_config = CustomConfiguration(provider=None, models=[]) + return custom_config + + +@pytest.fixture +def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration): + """Create a test provider configuration instance""" + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + return ProviderConfiguration( + tenant_id="test_tenant", + provider=mock_provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=mock_system_configuration, + custom_configuration=mock_custom_configuration, + model_settings=[], + ) + + +class TestProviderConfiguration: + """Test cases for ProviderConfiguration class""" + + def test_get_current_credentials_system_provider_success(self, provider_configuration): + """Test successfully getting credentials from system provider""" + # Arrange + provider_configuration.using_provider_type = ProviderType.SYSTEM + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "test_key"} + + def test_get_current_credentials_model_disabled(self, provider_configuration): + """Test getting credentials when model is disabled""" + # Arrange + model_setting = ModelSettings( + model="gpt-4", + model_type=ModelType.LLM, + enabled=False, + load_balancing_configs=[], + has_invalid_load_balancing_configs=False, + ) + provider_configuration.model_settings = [model_setting] + + # Act & Assert + with pytest.raises(ValueError, match="Model gpt-4 is disabled"): + provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + def test_get_current_credentials_custom_provider_with_models(self, provider_configuration): + """Test getting credentials from custom provider with model configurations""" + # Arrange + provider_configuration.using_provider_type = ProviderType.CUSTOM + + mock_model_config = Mock() + mock_model_config.model_type = ModelType.LLM + mock_model_config.model = "gpt-4" + mock_model_config.credentials = {"openai_api_key": "custom_key"} + provider_configuration.custom_configuration.models = [mock_model_config] + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "custom_key"} + + def test_get_system_configuration_status_active(self, provider_configuration): + """Test getting active system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = True + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.ACTIVE + + def test_get_system_configuration_status_unsupported(self, provider_configuration): + """Test getting unsupported system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = False + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.UNSUPPORTED + + def test_get_system_configuration_status_quota_exceeded(self, provider_configuration): + """Test getting quota exceeded system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = True + quota_config = provider_configuration.system_configuration.quota_configurations[0] + quota_config.is_valid = False + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.QUOTA_EXCEEDED + + def test_is_custom_configuration_available_with_provider(self, provider_configuration): + """Test custom configuration availability with provider credentials""" + # Arrange + mock_provider = Mock() + mock_provider.available_credentials = ["openai_api_key"] + provider_configuration.custom_configuration.provider = mock_provider + provider_configuration.custom_configuration.models = [] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is True + + def test_is_custom_configuration_available_with_models(self, provider_configuration): + """Test custom configuration availability with model configurations""" + # Arrange + provider_configuration.custom_configuration.provider = None + provider_configuration.custom_configuration.models = [Mock()] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is True + + def test_is_custom_configuration_available_false(self, provider_configuration): + """Test custom configuration not available""" + # Arrange + provider_configuration.custom_configuration.provider = None + provider_configuration.custom_configuration.models = [] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is False + + @patch("core.entities.provider_configuration.Session") + def test_get_provider_record_found(self, mock_session, provider_configuration): + """Test getting provider record successfully""" + # Arrange + mock_provider = Mock(spec=Provider) + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider + + # Act + result = provider_configuration._get_provider_record(mock_session_instance) + + # Assert + assert result == mock_provider + + @patch("core.entities.provider_configuration.Session") + def test_get_provider_record_not_found(self, mock_session, provider_configuration): + """Test getting provider record when not found""" + # Arrange + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None + + # Act + result = provider_configuration._get_provider_record(mock_session_instance) + + # Assert + assert result is None + + def test_init_with_customizable_model_only( + self, mock_provider_entity, mock_system_configuration, mock_custom_configuration + ): + """Test initialization with customizable model only configuration""" + # Arrange + mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL] + + # Act + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + config = ProviderConfiguration( + tenant_id="test_tenant", + provider=mock_provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=mock_system_configuration, + custom_configuration=mock_custom_configuration, + model_settings=[], + ) + + # Assert + assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods + + def test_get_current_credentials_with_restricted_models(self, provider_configuration): + """Test getting credentials with model restrictions""" + # Arrange + provider_configuration.using_provider_type = ProviderType.SYSTEM + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo") + + # Assert + assert credentials is not None + assert "openai_api_key" in credentials + + @patch("core.entities.provider_configuration.Session") + def test_get_specific_provider_credential_success(self, mock_session, provider_configuration): + """Test getting specific provider credential successfully""" + # Arrange + credential_id = "test_credential_id" + mock_credential = Mock() + mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}' + + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential + + # Act + with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: + mock_get.return_value = {"openai_api_key": "test_key"} + result = provider_configuration._get_specific_provider_credential(credential_id) + + # Assert + assert result == {"openai_api_key": "test_key"} + + @patch("core.entities.provider_configuration.Session") + def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration): + """Test getting specific provider credential when not found""" + # Arrange + credential_id = "nonexistent_credential_id" + + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None + + # Act & Assert + with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: + mock_get.return_value = None + result = provider_configuration._get_specific_provider_credential(credential_id) + assert result is None + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "test_key"} diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 90d5a6f15b..2dab394029 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,190 +1,185 @@ -# from core.entities.provider_entities import ModelSettings -# from core.model_runtime.entities.model_entities import ModelType -# from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -# from core.provider_manager import ProviderManager -# from models.provider import LoadBalancingModelConfig, ProviderModelSetting +import pytest + +from core.entities.provider_entities import ModelSettings +from core.model_runtime.entities.model_entities import ModelType +from core.provider_manager import ProviderManager +from models.provider import LoadBalancingModelConfig, ProviderModelSetting -# def test__to_model_settings(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +@pytest.fixture +def mock_provider_entity(mocker): + mock_entity = mocker.Mock() + mock_entity.provider = "openai" + mock_entity.configurate_methods = ["predefined-model"] + mock_entity.supported_model_types = [ModelType.LLM] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mock_entity.model_credential_schema = mocker.Mock() + mock_entity.model_credential_schema.credential_form_schemas = [] -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=True, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ), -# LoadBalancingModelConfig( -# id="id2", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="first", -# encrypted_config='{"openai_api_key": "fake_key"}', -# enabled=True, -# ), -# ] - -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} -# ) - -# provider_manager = ProviderManager() - -# # Running the method -# result = provider_manager._to_model_settings(provider_entity, -# provider_model_settings, load_balancing_model_configs) - -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 2 -# assert result[0].load_balancing_configs[0].name == "__inherit__" -# assert result[0].load_balancing_configs[1].name == "first" + return mock_entity -# def test__to_model_settings_only_one_lb(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +def test__to_model_settings(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=True, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ) -# ] + provider_manager = ProviderManager() -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} -# ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) -# provider_manager = ProviderManager() - -# # Running the method -# result = provider_manager._to_model_settings( -# provider_entity, provider_model_settings, load_balancing_model_configs) - -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 0 + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 2 + assert result[0].load_balancing_configs[0].name == "__inherit__" + assert result[0].load_balancing_configs[1].name == "first" -# def test__to_model_settings_lb_disabled(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +def test__to_model_settings_only_one_lb(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ) + ] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=False, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ), -# LoadBalancingModelConfig( -# id="id2", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="first", -# encrypted_config='{"openai_api_key": "fake_key"}', -# enabled=True, -# ), -# ] + provider_manager = ProviderManager() -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", -# return_value={"openai_api_key": "fake_key"} -# ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) -# provider_manager = ProviderManager() + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 -# # Running the method -# result = provider_manager._to_model_settings(provider_entity, -# provider_model_settings, load_balancing_model_configs) -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 0 +def test__to_model_settings_lb_disabled(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] + + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0