mirror of https://github.com/langgenius/dify.git
load balance save api also can switch custom model credential_id
This commit is contained in:
parent
416b2634ed
commit
b9a6bf89ef
|
|
@ -117,8 +117,21 @@ class ModelProviderModelApi(Resource):
|
|||
)
|
||||
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.get("config_from", "") == "custom-model":
|
||||
if not args.get("credential_id"):
|
||||
raise ValueError("credential_id is required when configuring a custom-model")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_custom_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
||||
if (
|
||||
|
|
@ -375,7 +388,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_model_credential(
|
||||
service.add_model_credential_to_model_list(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
|
|
|
|||
|
|
@ -1014,10 +1014,10 @@ class ProviderConfiguration(BaseModel):
|
|||
session.rollback()
|
||||
raise
|
||||
|
||||
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
||||
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
||||
"""
|
||||
Not only switch the custom model credential.
|
||||
It can also add credential to a new custom model record.
|
||||
if model list exist this custom model, switch the custom model credential.
|
||||
if model list not exist this custom model, use the credential to add a new custom model record.
|
||||
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
|
|
@ -1056,6 +1056,36 @@ class ProviderConfiguration(BaseModel):
|
|||
session.add(provider_model_record)
|
||||
session.commit()
|
||||
|
||||
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
||||
"""
|
||||
switch the custom model credential.
|
||||
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
raise ValueError("Credential record not found.")
|
||||
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
if not provider_model_record:
|
||||
raise ValueError("The custom model record not found.")
|
||||
|
||||
provider_model_record.credential_id = credential_record.id
|
||||
provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
session.add(provider_model_record)
|
||||
session.commit()
|
||||
|
||||
def delete_custom_model(self, model_type: ModelType, model: str) -> None:
|
||||
"""
|
||||
Delete custom model.
|
||||
|
|
|
|||
|
|
@ -319,7 +319,7 @@ class ModelProviderService:
|
|||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def switch_active_provider_model_credential(
|
||||
def switch_active_custom_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -337,6 +337,24 @@ class ModelProviderService:
|
|||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def add_model_credential_to_model_list(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
"""
|
||||
add model credentials to model list.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.add_model_credential_to_model(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
"""
|
||||
remove model credentials.
|
||||
|
|
|
|||
Loading…
Reference in New Issue