From b9a6bf89ef60e721d583e6bcddcaf7c312d1d64a Mon Sep 17 00:00:00 2001 From: hjlarry Date: Tue, 19 Aug 2025 17:36:01 +0800 Subject: [PATCH] load balance save api also can switch custom model credential_id --- api/controllers/console/workspace/models.py | 15 ++++++++- api/core/entities/provider_configuration.py | 36 +++++++++++++++++++-- api/services/model_provider_service.py | 20 +++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index f223476d02..e5179bf1d9 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -117,8 +117,21 @@ class ModelProviderModelApi(Resource): ) parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") args = parser.parse_args() + if args.get("config_from", "") == "custom-model": + if not args.get("credential_id"): + raise ValueError("credential_id is required when configuring a custom-model") + service = ModelProviderService() + service.switch_active_custom_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + model_load_balancing_service = ModelLoadBalancingService() if ( @@ -375,7 +388,7 @@ class ModelProviderModelCredentialSwitchApi(Resource): args = parser.parse_args() service = ModelProviderService() - service.switch_active_provider_model_credential( + service.add_model_credential_to_model_list( tenant_id=current_user.current_tenant_id, provider=provider, model_type=args["model_type"], diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 6cf9a403e3..628b282d67 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1014,10 +1014,10 @@ class ProviderConfiguration(BaseModel): session.rollback() raise - def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: """ - Not only switch the custom model credential. - It can also add credential to a new custom model record. + if model list exist this custom model, switch the custom model credential. + if model list not exist this custom model, use the credential to add a new custom model record. :param model_type: model type :param model: model name @@ -1056,6 +1056,36 @@ class ProviderConfiguration(BaseModel): session.add(provider_model_record) session.commit() + def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + switch the custom model credential. + + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + if not provider_model_record: + raise ValueError("The custom model record not found.") + + provider_model_record.credential_id = credential_record.id + provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + session.add(provider_model_record) + session.commit() + def delete_custom_model(self, model_type: ModelType, model: str) -> None: """ Delete custom model. diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 62cf9884e4..67c3f0d6b2 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -319,7 +319,7 @@ class ModelProviderService: model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def switch_active_provider_model_credential( + def switch_active_custom_model_credential( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str ) -> None: """ @@ -337,6 +337,24 @@ class ModelProviderService: model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) + def add_model_credential_to_model_list( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + add model credentials to model list. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.add_model_credential_to_model( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: """ remove model credentials.