From c40cb7fd597549382b4e5ac25692437ea98fb68f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 3 Sep 2025 13:34:07 +0800 Subject: [PATCH 1/4] [Chore/Refactor] Update .gitignore to exclude pyrightconfig.json while preserving api/pyrightconfig.json (#25055) --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 8b75bd0db4..8a5a34cf88 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,8 @@ venv.bak/ .mypy_cache/ .dmypy.json dmypy.json +pyrightconfig.json +!api/pyrightconfig.json # Pyre type checker .pyre/ From b88146c4439df9824c48efa343d1a7302931fe0f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 3 Sep 2025 13:34:43 +0800 Subject: [PATCH 2/4] chore: consolidate type checking in style workflow (#25053) --- .github/workflows/api-tests.yml | 3 --- .github/workflows/style.yml | 4 ++-- dev/basedpyright-check | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 5efe422904..116fc59ee8 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -62,9 +62,6 @@ jobs: - name: Run dify config tests run: uv run --project api dev/pytest/pytest_config_tests.py - - name: Run Basedpyright Checks - run: dev/basedpyright-check - - name: Set up dotenvs run: | cp docker/.env.example docker/.env diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index a3643c9931..aaabec0cb5 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -44,9 +44,9 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: uv sync --project api --dev - - name: Run ty check + - name: Run Basedpyright Checks if: steps.changed-files.outputs.any_changed == 'true' - run: dev/ty-check + run: dev/basedpyright-check - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/dev/basedpyright-check b/dev/basedpyright-check index db0f02335e..267ef2a522 100755 --- a/dev/basedpyright-check +++ b/dev/basedpyright-check @@ -6,4 +6,4 @@ SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/.." # run basedpyright checks -uv --directory api run basedpyright +uv run --directory api --dev basedpyright From 9e125e2029184c4642f13a914961d35860aa998e Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Wed, 3 Sep 2025 13:36:59 +0800 Subject: [PATCH 3/4] Refactor/model credential (#24994) --- .../base/form/components/base/base-field.tsx | 37 +- .../base/form/components/base/base-form.tsx | 34 +- .../model-provider-page/declarations.ts | 12 + .../model-provider-page/hooks.ts | 25 +- .../add-credential-in-load-balancing.tsx | 71 ++-- .../model-auth/add-custom-model.tsx | 174 ++++++--- .../model-auth/authorized/authorized-item.tsx | 68 ++-- .../model-auth/authorized/credential-item.tsx | 18 +- .../model-auth/authorized/index.tsx | 137 ++++--- .../model-auth/config-model.tsx | 2 +- .../model-auth/config-provider.tsx | 41 +- .../model-auth/credential-selector.tsx | 115 ++++++ .../model-auth/hooks/use-auth-service.ts | 2 +- .../model-auth/hooks/use-auth.ts | 105 +++-- .../model-auth/hooks/use-custom-models.ts | 6 + .../hooks/use-model-form-schemas.ts | 42 +- .../model-provider-page/model-auth/index.tsx | 2 + .../manage-custom-model-credentials.tsx | 82 ++++ .../switch-credential-in-load-balancing.tsx | 32 +- .../model-provider-page/model-modal/index.tsx | 364 ++++++++++++------ .../provider-added-card/credential-panel.tsx | 1 - .../provider-added-card/index.tsx | 20 +- .../provider-added-card/model-list.tsx | 9 +- .../model-load-balancing-configs.tsx | 61 ++- .../model-load-balancing-modal.tsx | 317 ++++++++++----- .../model-provider-page/utils.ts | 2 +- web/context/modal-context.tsx | 23 +- web/i18n/en-US/common.ts | 9 + web/i18n/zh-Hans/common.ts | 9 + web/service/use-models.ts | 2 +- 30 files changed, 1226 insertions(+), 596 deletions(-) create mode 100644 web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx create mode 100644 web/app/components/header/account-setting/model-provider-page/model-auth/manage-custom-model-credentials.tsx diff --git a/web/app/components/base/form/components/base/base-field.tsx b/web/app/components/base/form/components/base/base-field.tsx index f25dfb069d..bf415e08a8 100644 --- a/web/app/components/base/form/components/base/base-field.tsx +++ b/web/app/components/base/form/components/base/base-field.tsx @@ -1,6 +1,7 @@ import { isValidElement, memo, + useCallback, useMemo, } from 'react' import { RiExternalLinkLine } from '@remixicon/react' @@ -23,6 +24,7 @@ export type BaseFieldProps = { formSchema: FormSchema field: AnyFieldApi disabled?: boolean + onChange?: (field: string, value: any) => void } const BaseField = ({ fieldClassName, @@ -32,6 +34,7 @@ const BaseField = ({ formSchema, field, disabled: propsDisabled, + onChange, }: BaseFieldProps) => { const renderI18nObject = useRenderI18nObject() const { @@ -40,7 +43,6 @@ const BaseField = ({ placeholder, options, labelClassName: formLabelClassName, - show_on = [], disabled: formSchemaDisabled, } = formSchema const disabled = propsDisabled || formSchemaDisabled @@ -90,21 +92,11 @@ const BaseField = ({ }) || [] }, [options, renderI18nObject, optionValues]) const value = useStore(field.form.store, s => s.values[field.name]) - const values = useStore(field.form.store, (s) => { - return show_on.reduce((acc, condition) => { - acc[condition.variable] = s.values[condition.variable] - return acc - }, {} as Record) - }) - const show = useMemo(() => { - return show_on.every((condition) => { - const conditionValue = values[condition.variable] - return conditionValue === condition.value - }) - }, [values, show_on]) - if (!show) - return null + const handleChange = useCallback((value: any) => { + field.handleChange(value) + onChange?.(field.name, value) + }, [field, onChange]) return (
@@ -124,7 +116,9 @@ const BaseField = ({ name={field.name} className={cn(inputClassName)} value={value || ''} - onChange={e => field.handleChange(e.target.value)} + onChange={(e) => { + handleChange(e.target.value) + }} onBlur={field.handleBlur} disabled={disabled} placeholder={memorizedPlaceholder} @@ -139,7 +133,7 @@ const BaseField = ({ type='password' className={cn(inputClassName)} value={value || ''} - onChange={e => field.handleChange(e.target.value)} + onChange={e => handleChange(e.target.value)} onBlur={field.handleBlur} disabled={disabled} placeholder={memorizedPlaceholder} @@ -155,7 +149,7 @@ const BaseField = ({ type='number' className={cn(inputClassName)} value={value || ''} - onChange={e => field.handleChange(e.target.value)} + onChange={e => handleChange(e.target.value)} onBlur={field.handleBlur} disabled={disabled} placeholder={memorizedPlaceholder} @@ -166,11 +160,14 @@ const BaseField = ({ formSchema.type === FormTypeEnum.select && ( field.handleChange(v)} + onChange={v => handleChange(v)} disabled={disabled} placeholder={memorizedPlaceholder} options={memorizedOptions} triggerPopupSameWidth + popupProps={{ + className: 'max-h-[320px] overflow-y-auto', + }} /> ) } @@ -189,7 +186,7 @@ const BaseField = ({ disabled && 'cursor-not-allowed opacity-50', inputClassName, )} - onClick={() => !disabled && field.handleChange(option.value)} + onClick={() => !disabled && handleChange(option.value)} > { formSchema.showRadioUI && ( diff --git a/web/app/components/base/form/components/base/base-form.tsx b/web/app/components/base/form/components/base/base-form.tsx index c056829db4..6b7e992510 100644 --- a/web/app/components/base/form/components/base/base-form.tsx +++ b/web/app/components/base/form/components/base/base-form.tsx @@ -8,7 +8,10 @@ import type { AnyFieldApi, AnyFormApi, } from '@tanstack/react-form' -import { useForm } from '@tanstack/react-form' +import { + useForm, + useStore, +} from '@tanstack/react-form' import type { FormRef, FormSchema, @@ -32,6 +35,7 @@ export type BaseFormProps = { ref?: FormRef disabled?: boolean formFromProps?: AnyFormApi + onChange?: (field: string, value: any) => void } & Pick const BaseForm = ({ @@ -45,6 +49,7 @@ const BaseForm = ({ ref, disabled, formFromProps, + onChange, }: BaseFormProps) => { const initialDefaultValues = useMemo(() => { if (defaultValues) @@ -63,6 +68,19 @@ const BaseForm = ({ const { getFormValues } = useGetFormValues(form, formSchemas) const { getValidators } = useGetValidators() + const showOnValues = useStore(form.store, (s: any) => { + const result: Record = {} + formSchemas.forEach((schema) => { + const { show_on } = schema + if (show_on?.length) { + show_on.forEach((condition) => { + result[condition.variable] = s.values[condition.variable] + }) + } + }) + return result + }) + useImperativeHandle(ref, () => { return { getForm() { @@ -87,19 +105,29 @@ const BaseForm = ({ inputContainerClassName={inputContainerClassName} inputClassName={inputClassName} disabled={disabled} + onChange={onChange} /> ) } return null - }, [formSchemas, fieldClassName, labelClassName, inputContainerClassName, inputClassName, disabled]) + }, [formSchemas, fieldClassName, labelClassName, inputContainerClassName, inputClassName, disabled, onChange]) const renderFieldWrapper = useCallback((formSchema: FormSchema) => { const validators = getValidators(formSchema) const { name, + show_on = [], } = formSchema + const show = show_on?.every((condition) => { + const conditionValue = showOnValues[condition.variable] + return conditionValue === condition.value + }) + + if (!show) + return null + return ( ) - }, [renderField, form, getValidators]) + }, [renderField, form, getValidators, showOnValues]) if (!formSchemas?.length) return null diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 74f47c9d1d..9fac34b21b 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -199,6 +199,7 @@ export type CustomModelCredential = CustomModel & { credentials?: Record available_model_credentials?: Credential[] current_credential_id?: string + current_credential_name?: string } export type CredentialWithModel = Credential & { @@ -236,6 +237,10 @@ export type ModelProvider = { current_credential_name?: string available_credentials?: Credential[] custom_models?: CustomModelCredential[] + can_added_models?: { + model: string + model_type: ModelTypeEnum + }[] } system_configuration: { enabled: boolean @@ -323,3 +328,10 @@ export type ModelCredential = { current_credential_id?: string current_credential_name?: string } + +export enum ModelModalModeEnum { + configProviderCredential = 'config-provider-credential', + configCustomModel = 'config-custom-model', + addCustomModelToModelList = 'add-custom-model-to-model-list', + configModelCredential = 'config-model-credential', +} diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index fa5130137a..c9e4f9961e 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -13,6 +13,7 @@ import type { DefaultModel, DefaultModelResponse, Model, + ModelModalModeEnum, ModelProvider, ModelTypeEnum, } from './declarations' @@ -348,29 +349,31 @@ export const useRefreshModel = () => { export const useModelModalHandler = () => { const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) - const { handleRefreshModel } = useRefreshModel() return ( provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, - isModelCredential?: boolean, - credential?: Credential, - model?: CustomModel, - onUpdate?: () => void, + extra: { + isModelCredential?: boolean, + credential?: Credential, + model?: CustomModel, + onUpdate?: (newPayload: any, formValues?: Record) => void, + mode?: ModelModalModeEnum, + } = {}, ) => { setShowModelModal({ payload: { currentProvider: provider, currentConfigurationMethod: configurationMethod, currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields, - isModelCredential, - credential, - model, + isModelCredential: extra.isModelCredential, + credential: extra.credential, + model: extra.model, + mode: extra.mode, }, - onSaveCallback: () => { - handleRefreshModel(provider, configurationMethod, CustomConfigurationModelFixedFields) - onUpdate?.() + onSaveCallback: (newPayload, formValues) => { + extra.onUpdate?.(newPayload, formValues) }, }) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx index a0c78e3292..30d56bced7 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx @@ -1,7 +1,6 @@ import { memo, useCallback, - useMemo, } from 'react' import { RiAddLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' @@ -9,20 +8,22 @@ import { Authorized } from '@/app/components/header/account-setting/model-provid import cn from '@/utils/classnames' import type { Credential, + CustomConfigurationModelFixedFields, CustomModelCredential, ModelCredential, ModelProvider, } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import Tooltip from '@/app/components/base/tooltip' +import { ConfigurationMethodEnum, ModelModalModeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' type AddCredentialInLoadBalancingProps = { provider: ModelProvider model: CustomModelCredential configurationMethod: ConfigurationMethodEnum + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields modelCredential: ModelCredential onSelectCredential: (credential: Credential) => void - onUpdate?: () => void + onUpdate?: (payload?: any, formValues?: Record) => void + onRemove?: (credentialId: string) => void } const AddCredentialInLoadBalancing = ({ provider, @@ -31,41 +32,17 @@ const AddCredentialInLoadBalancing = ({ modelCredential, onSelectCredential, onUpdate, + onRemove, }: AddCredentialInLoadBalancingProps) => { const { t } = useTranslation() const { available_credentials, } = modelCredential - const customModel = configurationMethod === ConfigurationMethodEnum.customizableModel + const isCustomModel = configurationMethod === ConfigurationMethodEnum.customizableModel const notAllowCustomCredential = provider.allow_custom_token === false - - const ButtonComponent = useMemo(() => { - const Item = ( -
- - { - customModel - ? t('common.modelProvider.auth.addCredential') - : t('common.modelProvider.auth.addApiKey') - } -
- ) - - if (notAllowCustomCredential) { - return ( - - {Item} - - ) - } - return Item - }, [notAllowCustomCredential, t, customModel]) + const handleUpdate = useCallback((payload?: any, formValues?: Record) => { + onUpdate?.(payload, formValues) + }, [onUpdate]) const renderTrigger = useCallback((open?: boolean) => { const Item = ( @@ -74,40 +51,40 @@ const AddCredentialInLoadBalancing = ({ open && 'bg-state-base-hover', )}> - { - customModel - ? t('common.modelProvider.auth.addCredential') - : t('common.modelProvider.auth.addApiKey') - } + {t('common.modelProvider.auth.addCredential')}
) return Item - }, [t, customModel]) - - if (!available_credentials?.length) - return ButtonComponent + }, [t, isCustomModel]) return ( ) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx index 0ec6fa45a0..dd9284398c 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx @@ -1,32 +1,39 @@ import { memo, useCallback, - useMemo, + useState, } from 'react' import { useTranslation } from 'react-i18next' import { RiAddCircleFill, + RiAddLine, } from '@remixicon/react' import { Button, } from '@/app/components/base/button' import type { + ConfigurationMethodEnum, CustomConfigurationModelFixedFields, ModelProvider, } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import Authorized from './authorized' -import { - useAuth, - useCustomModels, -} from './hooks' +import { ModelModalModeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import cn from '@/utils/classnames' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import ModelIcon from '../model-icon' +import { useCanAddedModels } from './hooks/use-custom-models' +import { useAuth } from './hooks/use-auth' import Tooltip from '@/app/components/base/tooltip' type AddCustomModelProps = { provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + open?: boolean + onOpenChange?: (open: boolean) => void } const AddCustomModel = ({ provider, @@ -34,44 +41,32 @@ const AddCustomModel = ({ currentCustomConfigurationModelFixedFields, }: AddCustomModelProps) => { const { t } = useTranslation() - const customModels = useCustomModels(provider) - const noModels = !customModels.length + const [open, setOpen] = useState(false) + const canAddedModels = useCanAddedModels(provider) + const noModels = !canAddedModels.length const { - handleOpenModal, - } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, true) + handleOpenModal: handleOpenModalForAddNewCustomModel, + } = useAuth( + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + { + isModelCredential: true, + mode: ModelModalModeEnum.configCustomModel, + }, + ) + const { + handleOpenModal: handleOpenModalForAddCustomModelToModelList, + } = useAuth( + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + { + isModelCredential: true, + mode: ModelModalModeEnum.addCustomModelToModelList, + }, + ) const notAllowCustomCredential = provider.allow_custom_token === false - const handleClick = useCallback(() => { - if (notAllowCustomCredential) - return - - handleOpenModal() - }, [handleOpenModal, notAllowCustomCredential]) - const ButtonComponent = useMemo(() => { - const Item = ( - - ) - if (notAllowCustomCredential) { - return ( - - {Item} - - ) - } - return Item - }, [handleClick, notAllowCustomCredential, t]) const renderTrigger = useCallback((open?: boolean) => { const Item = ( @@ -79,32 +74,93 @@ const AddCustomModel = ({ variant='ghost' size='small' className={cn( + 'text-text-tertiary', open && 'bg-components-button-ghost-bg-hover', + notAllowCustomCredential && !!noModels && 'cursor-not-allowed opacity-50', )} > {t('common.modelProvider.addModel')} ) + if (notAllowCustomCredential && !!noModels) { + return ( + + {Item} + + ) + } return Item - }, [t]) - - if (noModels) - return ButtonComponent + }, [t, notAllowCustomCredential, noModels]) return ( - ({ - model, - credentials: model.available_model_credentials ?? [], - }))} - renderTrigger={renderTrigger} - isModelCredential - enableAddModelCredential - bottomAddModelCredentialText={t('common.modelProvider.auth.addNewModel')} - /> + + { + if (noModels) { + if (notAllowCustomCredential) + return + handleOpenModalForAddNewCustomModel() + return + } + + setOpen(prev => !prev) + }}> + {renderTrigger(open)} + + +
+
+ { + canAddedModels.map(model => ( +
{ + handleOpenModalForAddCustomModelToModelList(undefined, model) + setOpen(false) + }} + > + +
+ {model.model} +
+
+ )) + } +
+ { + !notAllowCustomCredential && ( +
{ + handleOpenModalForAddNewCustomModel() + setOpen(false) + }} + > + + {t('common.modelProvider.auth.addNewModel')} +
+ ) + } +
+
+
) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx index 4f4c30bc9b..10dc48585c 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx @@ -2,18 +2,17 @@ import { memo, useCallback, } from 'react' -import { RiAddLine } from '@remixicon/react' -import { useTranslation } from 'react-i18next' import CredentialItem from './credential-item' import type { Credential, CustomModel, CustomModelCredential, + ModelProvider, } from '../../declarations' -import Button from '@/app/components/base/button' -import Tooltip from '@/app/components/base/tooltip' +import ModelIcon from '../../model-icon' type AuthorizedItemProps = { + provider: ModelProvider model?: CustomModelCredential title?: string disabled?: boolean @@ -25,8 +24,12 @@ type AuthorizedItemProps = { onItemClick?: (credential: Credential, model?: CustomModel) => void enableAddModelCredential?: boolean notAllowCustomCredential?: boolean + showModelTitle?: boolean + disableDeleteButShowAction?: boolean + disableDeleteTip?: string } export const AuthorizedItem = ({ + provider, model, title, credentials, @@ -36,10 +39,10 @@ export const AuthorizedItem = ({ showItemSelectedIcon, selectedCredentialId, onItemClick, - enableAddModelCredential, - notAllowCustomCredential, + showModelTitle, + disableDeleteButShowAction, + disableDeleteTip, }: AuthorizedItemProps) => { - const { t } = useTranslation() const handleEdit = useCallback((credential?: Credential) => { onEdit?.(credential, model) }, [onEdit, model]) @@ -52,34 +55,29 @@ export const AuthorizedItem = ({ return (
-
-
-
- {title ?? model?.model} -
- { - enableAddModelCredential && !notAllowCustomCredential && ( - + { + model?.model && ( + + ) + } +
- - - ) - } -
+ {title ?? model?.model} +
+
+ ) + } { credentials.map(credential => ( )) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx index 6596e64e0d..2d792d1705 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx @@ -24,6 +24,8 @@ type CredentialItemProps = { disableRename?: boolean disableEdit?: boolean disableDelete?: boolean + disableDeleteButShowAction?: boolean + disableDeleteTip?: string showSelectedIcon?: boolean selectedCredentialId?: string } @@ -36,6 +38,8 @@ const CredentialItem = ({ disableRename, disableEdit, disableDelete, + disableDeleteButShowAction, + disableDeleteTip, showSelectedIcon, selectedCredentialId, }: CredentialItemProps) => { @@ -43,6 +47,9 @@ const CredentialItem = ({ const showAction = useMemo(() => { return !(disableRename && disableEdit && disableDelete) }, [disableRename, disableEdit, disableDelete]) + const disableDeleteWhenSelected = useMemo(() => { + return disableDeleteButShowAction && selectedCredentialId === credential.credential_id + }, [disableDeleteButShowAction, selectedCredentialId, credential.credential_id]) const Item = (
+ { + if (disabled || disableDeleteWhenSelected) + return e.stopPropagation() onDelete?.(credential) }} > - + ) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx index 2aa64ffb89..6504fbc37e 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx @@ -1,12 +1,11 @@ import { + Fragment, memo, useCallback, - useMemo, useState, } from 'react' import { RiAddLine, - RiEqualizer2Line, } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { @@ -25,6 +24,7 @@ import type { Credential, CustomConfigurationModelFixedFields, CustomModel, + ModelModalModeEnum, ModelProvider, } from '../../declarations' import { useAuth } from '../hooks' @@ -34,15 +34,20 @@ type AuthorizedProps = { provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, - isModelCredential?: boolean + authParams?: { + isModelCredential?: boolean + onUpdate?: (newPayload?: any, formValues?: Record) => void + onRemove?: (credentialId: string) => void + mode?: ModelModalModeEnum + } items: { title?: string model?: CustomModel + selectedCredential?: Credential credentials: Credential[] }[] - selectedCredential?: Credential disabled?: boolean - renderTrigger?: (open?: boolean) => React.ReactNode + renderTrigger: (open?: boolean) => React.ReactNode isOpen?: boolean onOpenChange?: (open: boolean) => void offset?: PortalToFollowElemOptions['offset'] @@ -50,18 +55,22 @@ type AuthorizedProps = { triggerPopupSameWidth?: boolean popupClassName?: string showItemSelectedIcon?: boolean - onUpdate?: () => void onItemClick?: (credential: Credential, model?: CustomModel) => void enableAddModelCredential?: boolean - bottomAddModelCredentialText?: string + triggerOnlyOpenModal?: boolean + hideAddAction?: boolean + disableItemClick?: boolean + popupTitle?: string + showModelTitle?: boolean + disableDeleteButShowAction?: boolean + disableDeleteTip?: string } const Authorized = ({ provider, configurationMethod, currentCustomConfigurationModelFixedFields, items, - isModelCredential, - selectedCredential, + authParams, disabled, renderTrigger, isOpen, @@ -71,10 +80,14 @@ const Authorized = ({ triggerPopupSameWidth = false, popupClassName, showItemSelectedIcon, - onUpdate, onItemClick, - enableAddModelCredential, - bottomAddModelCredentialText, + triggerOnlyOpenModal, + hideAddAction, + disableItemClick, + popupTitle, + showModelTitle, + disableDeleteButShowAction, + disableDeleteTip, }: AuthorizedProps) => { const { t } = useTranslation() const [isLocalOpen, setIsLocalOpen] = useState(false) @@ -85,6 +98,12 @@ const Authorized = ({ setIsLocalOpen(open) }, [onOpenChange]) + const { + isModelCredential, + onUpdate, + onRemove, + mode, + } = authParams || {} const { openConfirmDelete, closeConfirmDelete, @@ -93,7 +112,17 @@ const Authorized = ({ handleConfirmDelete, deleteCredentialId, handleOpenModal, - } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate) + } = useAuth( + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + { + isModelCredential, + onUpdate, + onRemove, + mode, + }, + ) const handleEdit = useCallback((credential?: Credential, model?: CustomModel) => { handleOpenModal(credential, model) @@ -101,28 +130,18 @@ const Authorized = ({ }, [handleOpenModal, setMergedIsOpen]) const handleItemClick = useCallback((credential: Credential, model?: CustomModel) => { + if (disableItemClick) + return + if (onItemClick) onItemClick(credential, model) else handleActiveCredential(credential, model) setMergedIsOpen(false) - }, [handleActiveCredential, onItemClick, setMergedIsOpen]) + }, [handleActiveCredential, onItemClick, setMergedIsOpen, disableItemClick]) const notAllowCustomCredential = provider.allow_custom_token === false - const Trigger = useMemo(() => { - const Item = ( - - ) - return Item - }, [t]) - return ( <> { + if (triggerOnlyOpenModal) { + handleOpenModal() + return + } + setMergedIsOpen(!mergedIsOpen) }} asChild > - { - renderTrigger - ? renderTrigger(mergedIsOpen) - : Trigger - } + {renderTrigger(mergedIsOpen)}
+ { + popupTitle && ( +
+ {popupTitle} +
+ ) + }
{ items.map((item, index) => ( - + + + { + index !== items.length - 1 && ( +
+ ) + } +
)) }
{ - isModelCredential && !notAllowCustomCredential && ( + isModelCredential && !notAllowCustomCredential && !hideAddAction && (
handleEdit( undefined, @@ -182,15 +217,15 @@ const Authorized = ({ } : undefined, )} - className='system-xs-medium flex h-[30px] cursor-pointer items-center px-3 text-text-accent-light-mode-only' + className='system-xs-medium flex h-[40px] cursor-pointer items-center px-3 text-text-accent-light-mode-only' > - {bottomAddModelCredentialText ?? t('common.modelProvider.auth.addModelCredential')} + {t('common.modelProvider.auth.addModelCredential')}
) } { - !isModelCredential && !notAllowCustomCredential && ( + !isModelCredential && !notAllowCustomCredential && !hideAddAction && (
) - if (notAllowCustomCredential) { + if (notAllowCustomCredential && !hasCredential) { return ( ) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx new file mode 100644 index 0000000000..ef0a9a9be5 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx @@ -0,0 +1,115 @@ +import { + memo, + useCallback, + useState, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiAddLine, + RiArrowDownSLine, +} from '@remixicon/react' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import type { Credential } from '@/app/components/header/account-setting/model-provider-page/declarations' +import CredentialItem from './authorized/credential-item' +import Badge from '@/app/components/base/badge' +import Indicator from '@/app/components/header/indicator' + +type CredentialSelectorProps = { + selectedCredential?: Credential & { addNewCredential?: boolean } + credentials: Credential[] + onSelect: (credential: Credential & { addNewCredential?: boolean }) => void + disabled?: boolean + notAllowAddNewCredential?: boolean +} +const CredentialSelector = ({ + selectedCredential, + credentials, + onSelect, + disabled, + notAllowAddNewCredential, +}: CredentialSelectorProps) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + const handleSelect = useCallback((credential: Credential & { addNewCredential?: boolean }) => { + setOpen(false) + onSelect(credential) + }, [onSelect]) + const handleAddNewCredential = useCallback(() => { + handleSelect({ + credential_id: '__add_new_credential', + addNewCredential: true, + credential_name: t('common.modelProvider.auth.addNewModelCredential'), + }) + }, [handleSelect, t]) + + return ( + + !disabled && setOpen(v => !v)}> +
+ { + selectedCredential && ( +
+ { + !selectedCredential.addNewCredential && + } +
{selectedCredential.credential_name}
+ { + selectedCredential.from_enterprise && ( + Enterprise + ) + } +
+ ) + } + { + !selectedCredential && ( +
{t('common.modelProvider.auth.selectModelCredential')}
+ ) + } + +
+
+ +
+
+ { + credentials.map(credential => ( + + )) + } +
+ { + !notAllowAddNewCredential && ( +
+ + {t('common.modelProvider.auth.addNewModelCredential')} +
+ ) + } +
+
+
+ ) +} + +export default memo(CredentialSelector) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts index 317a1fe1a9..6de1333ea4 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts @@ -17,7 +17,7 @@ import type { export const useGetCredential = (provider: string, isModelCredential?: boolean, credentialId?: string, model?: CustomModel, configFrom?: string) => { const providerData = useGetProviderCredential(!isModelCredential && !!credentialId, provider, credentialId) - const modelData = useGetModelCredential(!!isModelCredential && !!credentialId, provider, credentialId, model?.model, model?.model_type, configFrom) + const modelData = useGetModelCredential(!!isModelCredential && (!!credentialId || !!model), provider, credentialId, model?.model, model?.model_type, configFrom) return isModelCredential ? modelData : providerData } diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts index d4a0417a44..14b21be7f7 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts @@ -11,20 +11,32 @@ import type { Credential, CustomConfigurationModelFixedFields, CustomModel, + ModelModalModeEnum, ModelProvider, } from '../../declarations' import { useModelModalHandler, useRefreshModel, } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useDeleteModel } from '@/service/use-models' export const useAuth = ( provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, - isModelCredential?: boolean, - onUpdate?: () => void, + extra: { + isModelCredential?: boolean, + onUpdate?: (newPayload?: any, formValues?: Record) => void, + onRemove?: (credentialId: string) => void, + mode?: ModelModalModeEnum, + } = {}, ) => { + const { + isModelCredential, + onUpdate, + onRemove, + mode, + } = extra const { t } = useTranslation() const { notify } = useToastContext() const { @@ -33,22 +45,30 @@ export const useAuth = ( getEditCredentialService, getAddCredentialService, } = useAuthService(provider.provider) + const { mutateAsync: deleteModelService } = useDeleteModel(provider.provider) const handleOpenModelModal = useModelModalHandler() const { handleRefreshModel } = useRefreshModel() const pendingOperationCredentialId = useRef(null) - const pendingOperationModel = useRef(null) const [deleteCredentialId, setDeleteCredentialId] = useState(null) + const handleSetDeleteCredentialId = useCallback((credentialId: string | null) => { + setDeleteCredentialId(credentialId) + pendingOperationCredentialId.current = credentialId + }, []) + const pendingOperationModel = useRef(null) + const [deleteModel, setDeleteModel] = useState(null) + const handleSetDeleteModel = useCallback((model: CustomModel | null) => { + setDeleteModel(model) + pendingOperationModel.current = model + }, []) const openConfirmDelete = useCallback((credential?: Credential, model?: CustomModel) => { if (credential) - pendingOperationCredentialId.current = credential.credential_id + handleSetDeleteCredentialId(credential.credential_id) if (model) - pendingOperationModel.current = model - - setDeleteCredentialId(pendingOperationCredentialId.current) + handleSetDeleteModel(model) }, []) const closeConfirmDelete = useCallback(() => { - setDeleteCredentialId(null) - pendingOperationCredentialId.current = null + handleSetDeleteCredentialId(null) + handleSetDeleteModel(null) }, []) const [doingAction, setDoingAction] = useState(false) const doingActionRef = useRef(doingAction) @@ -70,45 +90,49 @@ export const useAuth = ( type: 'success', message: t('common.api.actionSuccess'), }) - onUpdate?.() handleRefreshModel(provider, configurationMethod, undefined) } finally { handleSetDoingAction(false) } - }, [getActiveCredentialService, onUpdate, notify, t, handleSetDoingAction]) + }, [getActiveCredentialService, notify, t, handleSetDoingAction]) const handleConfirmDelete = useCallback(async () => { if (doingActionRef.current) return - if (!pendingOperationCredentialId.current) { - setDeleteCredentialId(null) + if (!pendingOperationCredentialId.current && !pendingOperationModel.current) { + closeConfirmDelete() return } try { handleSetDoingAction(true) - await getDeleteCredentialService(!!isModelCredential)({ - credential_id: pendingOperationCredentialId.current, - model: pendingOperationModel.current?.model, - model_type: pendingOperationModel.current?.model_type, - }) + let payload: any = {} + if (pendingOperationCredentialId.current) { + payload = { + credential_id: pendingOperationCredentialId.current, + model: pendingOperationModel.current?.model, + model_type: pendingOperationModel.current?.model_type, + } + await getDeleteCredentialService(!!isModelCredential)(payload) + } + if (!pendingOperationCredentialId.current && pendingOperationModel.current) { + payload = { + model: pendingOperationModel.current.model, + model_type: pendingOperationModel.current.model_type, + } + await deleteModelService(payload) + } notify({ type: 'success', message: t('common.api.actionSuccess'), }) - onUpdate?.() handleRefreshModel(provider, configurationMethod, undefined) - setDeleteCredentialId(null) - pendingOperationCredentialId.current = null - pendingOperationModel.current = null + onRemove?.(pendingOperationCredentialId.current ?? '') + closeConfirmDelete() } finally { handleSetDoingAction(false) } - }, [onUpdate, notify, t, handleSetDoingAction, getDeleteCredentialService, isModelCredential]) - const handleAddCredential = useCallback((model?: CustomModel) => { - if (model) - pendingOperationModel.current = model - }, []) + }, [notify, t, handleSetDoingAction, getDeleteCredentialService, isModelCredential, closeConfirmDelete, handleRefreshModel, provider, configurationMethod, deleteModelService]) const handleSaveCredential = useCallback(async (payload: Record) => { if (doingActionRef.current) return @@ -123,24 +147,35 @@ export const useAuth = ( if (res.result === 'success') { notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - onUpdate?.() + handleRefreshModel(provider, configurationMethod, undefined) } } finally { handleSetDoingAction(false) } - }, [onUpdate, notify, t, handleSetDoingAction, getEditCredentialService, getAddCredentialService]) + }, [notify, t, handleSetDoingAction, getEditCredentialService, getAddCredentialService]) const handleOpenModal = useCallback((credential?: Credential, model?: CustomModel) => { handleOpenModelModal( provider, configurationMethod, currentCustomConfigurationModelFixedFields, - isModelCredential, - credential, - model, - onUpdate, + { + isModelCredential, + credential, + model, + onUpdate, + mode, + }, ) - }, [handleOpenModelModal, provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate]) + }, [ + handleOpenModelModal, + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + isModelCredential, + onUpdate, + mode, + ]) return { pendingOperationCredentialId, @@ -150,8 +185,8 @@ export const useAuth = ( doingAction, handleActiveCredential, handleConfirmDelete, - handleAddCredential, deleteCredentialId, + deleteModel, handleSaveCredential, handleOpenModal, } diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts index f3b50f3f49..6abf6f51b6 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts @@ -7,3 +7,9 @@ export const useCustomModels = (provider: ModelProvider) => { return custom_models || [] } + +export const useCanAddedModels = (provider: ModelProvider) => { + const { can_added_models } = provider.custom_configuration + + return can_added_models || [] +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts index 22fab62bee..1cbe8f20b1 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts @@ -3,7 +3,6 @@ import { useTranslation } from 'react-i18next' import type { Credential, CustomModelCredential, - ModelLoadBalancingConfig, ModelProvider, } from '../../declarations' import { @@ -18,7 +17,6 @@ export const useModelFormSchemas = ( credentials?: Record, credential?: Credential, model?: CustomModelCredential, - draftConfig?: ModelLoadBalancingConfig, ) => { const { t } = useTranslation() const { @@ -27,26 +25,15 @@ export const useModelFormSchemas = ( model_credential_schema, } = provider const formSchemas = useMemo(() => { - const modelTypeSchema = genModelTypeFormSchema(supported_model_types) - const modelNameSchema = genModelNameFormSchema(model_credential_schema?.model) - if (!!model) { - modelTypeSchema.disabled = true - modelNameSchema.disabled = true - } return providerFormSchemaPredefined ? provider_credential_schema.credential_form_schemas - : [ - modelTypeSchema, - modelNameSchema, - ...(draftConfig?.enabled ? [] : model_credential_schema.credential_form_schemas), - ] + : model_credential_schema.credential_form_schemas }, [ providerFormSchemaPredefined, provider_credential_schema?.credential_form_schemas, supported_model_types, model_credential_schema?.credential_form_schemas, model_credential_schema?.model, - draftConfig?.enabled, model, ]) @@ -55,7 +42,7 @@ export const useModelFormSchemas = ( type: FormTypeEnum.textInput, variable: '__authorization_name__', label: t('plugin.auth.authorizationName'), - required: true, + required: false, } return [ @@ -79,8 +66,33 @@ export const useModelFormSchemas = ( return result }, [credentials, credential, model, formSchemas]) + const modelNameAndTypeFormSchemas = useMemo(() => { + if (providerFormSchemaPredefined) + return [] + + const modelNameSchema = genModelNameFormSchema(model_credential_schema?.model) + const modelTypeSchema = genModelTypeFormSchema(supported_model_types) + return [ + modelNameSchema, + modelTypeSchema, + ] + }, [supported_model_types, model_credential_schema?.model, providerFormSchemaPredefined]) + + const modelNameAndTypeFormValues = useMemo(() => { + let result = {} + if (providerFormSchemaPredefined) + return result + + if (model) + result = { ...result, __model_name: model?.model, __model_type: model?.model_type } + + return result + }, [model, providerFormSchemaPredefined]) + return { formSchemas: formSchemasWithAuthorizationName, formValues, + modelNameAndTypeFormSchemas, + modelNameAndTypeFormValues, } } diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx index 05effcea7c..f9708607a7 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx @@ -4,3 +4,5 @@ export { default as AddCredentialInLoadBalancing } from './add-credential-in-loa export { default as AddCustomModel } from './add-custom-model' export { default as ConfigProvider } from './config-provider' export { default as ConfigModel } from './config-model' +export { default as ManageCustomModelCredentials } from './manage-custom-model-credentials' +export { default as CredentialSelector } from './credential-selector' diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/manage-custom-model-credentials.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/manage-custom-model-credentials.tsx new file mode 100644 index 0000000000..3a9d10ea46 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/manage-custom-model-credentials.tsx @@ -0,0 +1,82 @@ +import { + memo, + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + Button, +} from '@/app/components/base/button' +import type { + CustomConfigurationModelFixedFields, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { + ConfigurationMethodEnum, + ModelModalModeEnum, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import Authorized from './authorized' +import { + useCustomModels, +} from './hooks' +import cn from '@/utils/classnames' + +type ManageCustomModelCredentialsProps = { + provider: ModelProvider, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, +} +const ManageCustomModelCredentials = ({ + provider, + currentCustomConfigurationModelFixedFields, +}: ManageCustomModelCredentialsProps) => { + const { t } = useTranslation() + const customModels = useCustomModels(provider) + const noModels = !customModels.length + + const renderTrigger = useCallback((open?: boolean) => { + const Item = ( + + ) + return Item + }, [t]) + + if (noModels) + return null + + return ( + ({ + model, + credentials: model.available_model_credentials ?? [], + selectedCredential: model.current_credential_id ? { + credential_id: model.current_credential_id, + credential_name: model.current_credential_name, + } : undefined, + }))} + renderTrigger={renderTrigger} + authParams={{ + isModelCredential: true, + mode: ModelModalModeEnum.configModelCredential, + }} + hideAddAction + disableItemClick + popupTitle={t('common.modelProvider.auth.customModelCredentials')} + showModelTitle + disableDeleteButShowAction + disableDeleteTip={t('common.modelProvider.auth.customModelCredentialsDeleteTip')} + /> + ) +} + +export default memo(ManageCustomModelCredentials) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx index 8f81107bb2..6ca120aea6 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx @@ -13,7 +13,7 @@ import type { CustomModel, ModelProvider, } from '../declarations' -import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ConfigurationMethodEnum, ModelModalModeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import cn from '@/utils/classnames' import Tooltip from '@/app/components/base/tooltip' import Badge from '@/app/components/base/badge' @@ -24,6 +24,8 @@ type SwitchCredentialInLoadBalancingProps = { credentials?: Credential[] customModelCredential?: Credential setCustomModelCredential: Dispatch> + onUpdate?: (payload?: any, formValues?: Record) => void + onRemove?: (credentialId: string) => void } const SwitchCredentialInLoadBalancing = ({ provider, @@ -31,6 +33,8 @@ const SwitchCredentialInLoadBalancing = ({ customModelCredential, setCustomModelCredential, credentials, + onUpdate, + onRemove, }: SwitchCredentialInLoadBalancingProps) => { const { t } = useTranslation() @@ -94,27 +98,31 @@ const SwitchCredentialInLoadBalancing = ({ ) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx index d754d24d90..adf633831b 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx @@ -5,6 +5,7 @@ import { useEffect, useMemo, useRef, + useState, } from 'react' import { RiCloseLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' @@ -15,6 +16,7 @@ import type { import { ConfigurationMethodEnum, FormTypeEnum, + ModelModalModeEnum, } from '../declarations' import { useLanguage, @@ -46,16 +48,19 @@ import { import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' import Badge from '@/app/components/base/badge' import { useRenderI18nObject } from '@/hooks/use-i18n' +import { CredentialSelector } from '../model-auth' type ModelModalProps = { provider: ModelProvider configurateMethod: ConfigurationMethodEnum currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields onCancel: () => void - onSave: () => void + onSave: (formValues?: Record) => void + onRemove: (formValues?: Record) => void model?: CustomModel credential?: Credential isModelCredential?: boolean + mode?: ModelModalModeEnum } const ModelModal: FC = ({ @@ -67,6 +72,7 @@ const ModelModal: FC = ({ model, credential, isModelCredential, + mode = ModelModalModeEnum.configProviderCredential, }) => { const renderI18nObject = useRenderI18nObject() const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel @@ -81,40 +87,88 @@ const ModelModal: FC = ({ closeConfirmDelete, openConfirmDelete, doingAction, - } = useAuth(provider, configurateMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onSave) + handleActiveCredential, + } = useAuth( + provider, + configurateMethod, + currentCustomConfigurationModelFixedFields, + { + isModelCredential, + mode, + }, + ) const { credentials: formSchemasValue, + available_credentials, } = credentialData as any const { isCurrentWorkspaceManager } = useAppContext() - const isEditMode = !!formSchemasValue && isCurrentWorkspaceManager const { t } = useTranslation() const language = useLanguage() const { formSchemas, formValues, + modelNameAndTypeFormSchemas, + modelNameAndTypeFormValues, } = useModelFormSchemas(provider, providerFormSchemaPredefined, formSchemasValue, credential, model) - const formRef = useRef(null) + const formRef1 = useRef(null) + const [selectedCredential, setSelectedCredential] = useState() + const formRef2 = useRef(null) + const isEditMode = !!Object.keys(formValues).filter((key) => { + return key !== '__model_name' && key !== '__model_type' + }).length && isCurrentWorkspaceManager const handleSave = useCallback(async () => { + if (mode === ModelModalModeEnum.addCustomModelToModelList && selectedCredential && !selectedCredential?.addNewCredential) { + handleActiveCredential(selectedCredential, model) + onCancel() + return + } + + let modelNameAndTypeIsCheckValidated = true + let modelNameAndTypeValues: Record = {} + + if (mode === ModelModalModeEnum.configCustomModel) { + const formResult = formRef1.current?.getFormValues({ + needCheckValidatedValues: true, + }) || { isCheckValidated: false, values: {} } + modelNameAndTypeIsCheckValidated = formResult.isCheckValidated + modelNameAndTypeValues = formResult.values + } + + if (mode === ModelModalModeEnum.configModelCredential && model) { + modelNameAndTypeValues = { + __model_name: model.model, + __model_type: model.model_type, + } + } + + if (mode === ModelModalModeEnum.addCustomModelToModelList && selectedCredential?.addNewCredential && model) { + modelNameAndTypeValues = { + __model_name: model.model, + __model_type: model.model_type, + } + } const { isCheckValidated, values, - } = formRef.current?.getFormValues({ + } = formRef2.current?.getFormValues({ needCheckValidatedValues: true, needTransformWhenSecretFieldIsPristine: true, }) || { isCheckValidated: false, values: {} } - if (!isCheckValidated) + if (!isCheckValidated || !modelNameAndTypeIsCheckValidated) return const { - __authorization_name__, __model_name, __model_type, + } = modelNameAndTypeValues + const { + __authorization_name__, ...rest } = values - if (__model_name && __model_type) { - handleSaveCredential({ + if (__model_name && __model_type && __authorization_name__) { + await handleSaveCredential({ credential_id: credential?.credential_id, credentials: rest, name: __authorization_name__, @@ -123,41 +177,33 @@ const ModelModal: FC = ({ }) } else { - handleSaveCredential({ + await handleSaveCredential({ credential_id: credential?.credential_id, credentials: rest, name: __authorization_name__, }) } - }, [handleSaveCredential, credential?.credential_id, model]) + onSave(values) + }, [handleSaveCredential, credential?.credential_id, model, onSave, mode, selectedCredential, handleActiveCredential]) const modalTitle = useMemo(() => { - if (!providerFormSchemaPredefined && !model) { - return ( -
- -
-
{t('common.modelProvider.auth.apiKeyModal.addModel')}
-
{renderI18nObject(provider.label)}
-
-
- ) - } let label = t('common.modelProvider.auth.apiKeyModal.title') - if (model) - label = t('common.modelProvider.auth.addModelCredential') + if (mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.addCustomModelToModelList) + label = t('common.modelProvider.auth.addModel') + if (mode === ModelModalModeEnum.configModelCredential) { + if (credential) + label = t('common.modelProvider.auth.editModelCredential') + else + label = t('common.modelProvider.auth.addModelCredential') + } return (
{label}
) - }, [providerFormSchemaPredefined, t, model, renderI18nObject]) + }, [t, mode, credential]) const modalDesc = useMemo(() => { if (providerFormSchemaPredefined) { @@ -172,7 +218,18 @@ const ModelModal: FC = ({ }, [providerFormSchemaPredefined, t]) const modalModel = useMemo(() => { - if (model) { + if (mode === ModelModalModeEnum.configCustomModel) { + return ( +
+ +
{renderI18nObject(provider.label)}
+
+ ) + } + if (model && (mode === ModelModalModeEnum.configModelCredential || mode === ModelModalModeEnum.addCustomModelToModelList)) { return (
= ({ } return null - }, [model, provider]) + }, [model, provider, mode, renderI18nObject]) + + const showCredentialLabel = useMemo(() => { + if (mode === ModelModalModeEnum.configCustomModel) + return true + if (mode === ModelModalModeEnum.addCustomModelToModelList) + return selectedCredential?.addNewCredential + }, [mode, selectedCredential]) + const showCredentialForm = useMemo(() => { + if (mode !== ModelModalModeEnum.addCustomModelToModelList) + return true + return selectedCredential?.addNewCredential + }, [mode, selectedCredential]) + const saveButtonText = useMemo(() => { + if (mode === ModelModalModeEnum.addCustomModelToModelList || mode === ModelModalModeEnum.configCustomModel) + return t('common.operation.add') + return t('common.operation.save') + }, [mode, t]) + + const handleDeleteCredential = useCallback(() => { + handleConfirmDelete() + onCancel() + }, [handleConfirmDelete]) + + const handleModelNameAndTypeChange = useCallback((field: string, value: any) => { + const { + getForm, + } = formRef2.current as FormRefObject || {} + if (getForm()) + getForm()?.setFieldValue(field, value) + }, []) + const notAllowCustomCredential = provider.allow_custom_token === false useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { @@ -214,100 +302,132 @@ const ModelModal: FC = ({ >
-
-
- {modalTitle} - {modalDesc} - {modalModel} -
- -
- { - isLoading && ( -
- -
- ) - } - { - !isLoading && ( - { - return { - ...formSchema, - name: formSchema.variable, - showRadioUI: formSchema.type === FormTypeEnum.radio, - } - }) as FormSchema[]} - defaultValues={formValues} - inputClassName='justify-start' - ref={formRef} - /> - ) - } -
- -
- { - (provider.help && (provider.help.title || provider.help.url)) - ? ( - !provider.help.url && e.preventDefault()} - > - {provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US} - - - ) - :
- } -
- { - isEditMode && ( - - ) - } - - -
-
+
+ {modalTitle} + {modalDesc} + {modalModel}
-
-
- - {t('common.modelProvider.encrypted.front')} - + { + mode === ModelModalModeEnum.configCustomModel && ( + { + return { + ...formSchema, + name: formSchema.variable, + } + }) as FormSchema[]} + defaultValues={modelNameAndTypeFormValues} + inputClassName='justify-start' + ref={formRef1} + onChange={handleModelNameAndTypeChange} + /> + ) + } + { + mode === ModelModalModeEnum.addCustomModelToModelList && ( + + ) + } + { + showCredentialLabel && ( +
+ {t('common.modelProvider.auth.modelCredential')} +
+
+ ) + } + { + isLoading && ( +
+ +
+ ) + } + { + !isLoading + && showCredentialForm + && ( + { + return { + ...formSchema, + name: formSchema.variable, + showRadioUI: formSchema.type === FormTypeEnum.radio, + } + }) as FormSchema[]} + defaultValues={formValues} + inputClassName='justify-start' + ref={formRef2} + /> + ) + } +
+
+ { + (provider.help && (provider.help.title || provider.help.url)) + ? ( + !provider.help.url && e.preventDefault()} + > + {provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US} + + + ) + :
+ } +
+ { + isEditMode && ( + + ) + } + +
+ { + (mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && ( +
+
+ + {t('common.modelProvider.encrypted.front')} + + PKCS1_OAEP + + {t('common.modelProvider.encrypted.back')} +
+
+ ) + }
{ deleteCredentialId && ( @@ -316,7 +436,7 @@ const ModelModal: FC = ({ title={t('common.modelProvider.confirmDelete')} isDisabled={doingAction} onCancel={closeConfirmDelete} - onConfirm={handleConfirmDelete} + onConfirm={handleDeleteCredential} /> ) } diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx index e67da77837..fda6abb2fc 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx @@ -111,7 +111,6 @@ const CredentialPanel = ({
{ systemConfig.enabled && isCustomConfigured && ( diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx index 559f630b49..d3601d04f9 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx @@ -25,7 +25,10 @@ import { useEventEmitterContextContext } from '@/context/event-emitter' import { IS_CE_EDITION } from '@/config' import { useAppContext } from '@/context/app-context' import cn from '@/utils/classnames' -import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth' +import { + AddCustomModel, + ManageCustomModelCredentials, +} from '@/app/components/header/account-setting/model-provider-page/model-auth' export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST' type ProviderAddedCardProps = { @@ -155,10 +158,17 @@ const ProviderAddedCard: FC = ({ )} { configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && ( - +
+ + +
) }
diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx index 8d902043ff..9e26d233c9 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx @@ -16,7 +16,10 @@ import { import ModelListItem from './model-list-item' import { useModalContextSelector } from '@/context/modal-context' import { useAppContext } from '@/context/app-context' -import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth' +import { + AddCustomModel, + ManageCustomModelCredentials, +} from '@/app/components/header/account-setting/model-provider-page/model-auth' type ModelListProps = { provider: ModelProvider @@ -67,6 +70,10 @@ const ModelList: FC = ({ { isConfigurable && isCurrentWorkspaceManager && (
+ void + onUpdate?: (payload?: any, formValues?: Record) => void + onRemove?: (credentialId: string) => void model: CustomModelCredential } @@ -55,11 +54,11 @@ const ModelLoadBalancingConfigs = ({ className, modelCredential, onUpdate, + onRemove, }: ModelLoadBalancingConfigsProps) => { const { t } = useTranslation() const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) - const handleOpenModal = useModelModalHandler() const updateConfigEntry = useCallback( ( @@ -130,6 +129,17 @@ const ModelLoadBalancingConfigs = ({ return draftConfig.configs }, [draftConfig]) + const handleUpdate = useCallback((payload?: any, formValues?: Record) => { + onUpdate?.(payload, formValues) + }, [onUpdate]) + + const handleRemove = useCallback((credentialId: string) => { + const index = draftConfig?.configs.findIndex(item => item.credential_id === credentialId && item.name !== '__inherit__') + if (index && index > -1) + updateConfigEntry(index, () => undefined) + onRemove?.(credentialId) + }, [draftConfig?.configs, updateConfigEntry, onRemove]) + if (!draftConfig) return null @@ -190,7 +200,7 @@ const ModelLoadBalancingConfigs = ({ )}
-
+
{isProviderManaged ? t('common.modelProvider.defaultConfig') : config.name}
{isProviderManaged && providerFormSchemaPredefined && ( @@ -206,34 +216,14 @@ const ModelLoadBalancingConfigs = ({ {!isProviderManaged && ( <>
- { - config.credential_id && !credential?.not_allowed_to_use && !credential?.from_enterprise && ( - { - handleOpenModal( - provider, - configurationMethod, - currentCustomConfigurationModelFixedFields, - configurationMethod === ConfigurationMethodEnum.customizableModel, - (config.credential_id && config.name) ? { - credential_id: config.credential_id, - credential_name: config.name, - } : undefined, - model, - ) - }} - > - - - ) - } - updateConfigEntry(index, () => undefined)} - > - - + + updateConfigEntry(index, () => undefined)} + > + + +
)} @@ -261,7 +251,8 @@ const ModelLoadBalancingConfigs = ({ configurationMethod={configurationMethod} modelCredential={modelCredential} onSelectCredential={addConfigEntry} - onUpdate={onUpdate} + onUpdate={handleUpdate} + onRemove={handleRemove} />
)} diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx index cbd19c7cae..070c2ee90f 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx @@ -2,6 +2,7 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import type { Credential, + CustomConfigurationModelFixedFields, ModelItem, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, @@ -24,10 +25,14 @@ import { useGetModelCredential, useUpdateModelLoadBalancingConfig, } from '@/service/use-models' +import { useAuth } from '../model-auth/hooks/use-auth' +import Confirm from '@/app/components/base/confirm' +import { useRefreshModel } from '../hooks' export type ModelLoadBalancingModalProps = { provider: ModelProvider configurateMethod: ConfigurationMethodEnum + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields model: ModelItem credential?: Credential open?: boolean @@ -39,6 +44,7 @@ export type ModelLoadBalancingModalProps = { const ModelLoadBalancingModal = ({ provider, configurateMethod, + currentCustomConfigurationModelFixedFields, model, credential, open = false, @@ -47,7 +53,20 @@ const ModelLoadBalancingModal = ({ }: ModelLoadBalancingModalProps) => { const { t } = useTranslation() const { notify } = useToastContext() - + const { + doingAction, + deleteModel, + openConfirmDelete, + closeConfirmDelete, + handleConfirmDelete, + } = useAuth( + provider, + configurateMethod, + currentCustomConfigurationModelFixedFields, + { + isModelCredential: true, + }, + ) const [loading, setLoading] = useState(false) const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel const configFrom = providerFormSchemaPredefined ? 'predefined-model' : 'custom-model' @@ -121,6 +140,7 @@ const ModelLoadBalancingModal = ({ } }, [current_credential_id, current_credential_name]) const [customModelCredential, setCustomModelCredential] = useState(initialCustomModelCredential) + const { handleRefreshModel } = useRefreshModel() const handleSave = async () => { try { setLoading(true) @@ -139,6 +159,7 @@ const ModelLoadBalancingModal = ({ ) if (res.result === 'success') { notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + handleRefreshModel(provider, configurateMethod, currentCustomConfigurationModelFixedFields) onSave?.(provider.provider) onClose?.() } @@ -147,120 +168,208 @@ const ModelLoadBalancingModal = ({ setLoading(false) } } + const handleDeleteModel = useCallback(async () => { + await handleConfirmDelete() + onClose?.() + }, [handleConfirmDelete, onClose]) + + const handleUpdate = useCallback(async (payload?: any, formValues?: Record) => { + const result = await refetch() + const available_credentials = result.data?.available_credentials || [] + const credentialName = formValues?.__authorization_name__ + const modelCredential = payload?.credential + + if (!available_credentials.length) { + onClose?.() + return + } + + if (!modelCredential) { + const currentCredential = available_credentials.find(c => c.credential_name === credentialName) + if (currentCredential) { + setDraftConfig((prev: any) => { + if (!prev) + return prev + return { + ...prev, + configs: [...prev.configs, { + credential_id: currentCredential.credential_id, + enabled: true, + name: currentCredential.credential_name, + }], + } + }) + } + } + else { + setDraftConfig((prev) => { + if (!prev) + return prev + const newConfigs = [...prev.configs] + const prevIndex = newConfigs.findIndex(item => item.credential_id === modelCredential.credential_id && item.name !== '__inherit__') + const newIndex = available_credentials.findIndex(c => c.credential_id === modelCredential.credential_id) + + if (newIndex > -1 && prevIndex > -1) + newConfigs[prevIndex].name = available_credentials[newIndex].credential_name || '' + + return { + ...prev, + configs: newConfigs, + } + }) + } + }, [refetch, credential]) + + const handleUpdateWhenSwitchCredential = useCallback(async () => { + const result = await refetch() + const available_credentials = result.data?.available_credentials || [] + if (!available_credentials.length) + onClose?.() + }, [refetch, onClose]) return ( - -
{ - draftConfig?.enabled - ? t('common.modelProvider.auth.configLoadBalancing') - : t('common.modelProvider.auth.configModel') - }
- {Boolean(model) && ( -
- - -
- )} -
- } - > - {!draftConfig - ? - : ( - <> -
-
toggleModalBalancing(false) : undefined} - > -
-
- {Boolean(model) && ( - - )} -
-
-
{ - providerFormSchemaPredefined - ? t('common.modelProvider.auth.providerManaged') - : t('common.modelProvider.auth.specifyModelCredential') - }
-
{ - providerFormSchemaPredefined - ? t('common.modelProvider.auth.providerManagedTip') - : t('common.modelProvider.auth.specifyModelCredentialTip') - }
+ <> + +
{ + draftConfig?.enabled + ? t('common.modelProvider.auth.configLoadBalancing') + : t('common.modelProvider.auth.configModel') + }
+ {Boolean(model) && ( +
+ + +
+ )} +
+ } + > + {!draftConfig + ? + : ( + <> +
+
toggleModalBalancing(false) : undefined} + > +
+
+ {Boolean(model) && ( + + )} +
+
+
{ + providerFormSchemaPredefined + ? t('common.modelProvider.auth.providerManaged') + : t('common.modelProvider.auth.specifyModelCredential') + }
+
{ + providerFormSchemaPredefined + ? t('common.modelProvider.auth.providerManagedTip') + : t('common.modelProvider.auth.specifyModelCredentialTip') + }
+
+ { + !providerFormSchemaPredefined && ( + + ) + }
+
+ { + modelCredential && ( + + ) + } +
+ +
+
{ !providerFormSchemaPredefined && ( - + ) }
+
+ + +
- { - modelCredential && ( - - ) - } -
- -
- - -
- + + ) + } + + { + deleteModel && ( + ) } - + ) } diff --git a/web/app/components/header/account-setting/model-provider-page/utils.ts b/web/app/components/header/account-setting/model-provider-page/utils.ts index f577a536dc..f19999cc1f 100644 --- a/web/app/components/header/account-setting/model-provider-page/utils.ts +++ b/web/app/components/header/account-setting/model-provider-page/utils.ts @@ -161,7 +161,7 @@ export const modelTypeFormat = (modelType: ModelTypeEnum) => { export const genModelTypeFormSchema = (modelTypes: ModelTypeEnum[]) => { return { - type: FormTypeEnum.radio, + type: FormTypeEnum.select, label: { zh_Hans: '模型类型', en_US: 'Model Type', diff --git a/web/context/modal-context.tsx b/web/context/modal-context.tsx index dac9ef30d5..5baadc934b 100644 --- a/web/context/modal-context.tsx +++ b/web/context/modal-context.tsx @@ -9,7 +9,6 @@ import type { Credential, CustomConfigurationModelFixedFields, CustomModel, - ModelLoadBalancingConfigEntry, ModelProvider, } from '@/app/components/header/account-setting/model-provider-page/declarations' import { @@ -29,6 +28,7 @@ import { removeSpecificQueryParam } from '@/utils' import { noop } from 'lodash-es' import dynamic from 'next/dynamic' import type { ExpireNoticeModalPayloadProps } from '@/app/education-apply/expire-notice-modal' +import type { ModelModalModeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' const AccountSetting = dynamic(() => import('@/app/components/header/account-setting'), { ssr: false, @@ -71,8 +71,8 @@ const ExpireNoticeModal = dynamic(() => import('@/app/education-apply/expire-not export type ModalState = { payload: T onCancelCallback?: () => void - onSaveCallback?: (newPayload: T) => void - onRemoveCallback?: (newPayload: T) => void + onSaveCallback?: (newPayload?: T, formValues?: Record) => void + onRemoveCallback?: (newPayload?: T, formValues?: Record) => void onEditCallback?: (newPayload: T) => void onValidateBeforeSaveCallback?: (newPayload: T) => boolean isEditMode?: boolean @@ -86,10 +86,7 @@ export type ModelModalType = { isModelCredential?: boolean credential?: Credential model?: CustomModel -} -export type LoadBalancingEntryModalType = ModelModalType & { - entry?: ModelLoadBalancingConfigEntry - index?: number + mode?: ModelModalModeEnum } export type ModalContextState = { @@ -187,9 +184,15 @@ export const ModalContextProvider = ({ showModelModal.onCancelCallback() }, [showModelModal]) - const handleSaveModelModal = useCallback(() => { + const handleSaveModelModal = useCallback((formValues?: Record) => { if (showModelModal?.onSaveCallback) - showModelModal.onSaveCallback(showModelModal.payload) + showModelModal.onSaveCallback(showModelModal.payload, formValues) + setShowModelModal(null) + }, [showModelModal]) + + const handleRemoveModelModal = useCallback((formValues?: Record) => { + if (showModelModal?.onRemoveCallback) + showModelModal.onRemoveCallback(showModelModal.payload, formValues) setShowModelModal(null) }, [showModelModal]) @@ -329,8 +332,10 @@ export const ModalContextProvider = ({ isModelCredential={showModelModal.payload.isModelCredential} credential={showModelModal.payload.credential} model={showModelModal.payload.model} + mode={showModelModal.payload.mode} onCancel={handleCancelModelModal} onSave={handleSaveModelModal} + onRemove={handleRemoveModelModal} /> ) } diff --git a/web/i18n/en-US/common.ts b/web/i18n/en-US/common.ts index 2f0082edd0..a54f6a4e47 100644 --- a/web/i18n/en-US/common.ts +++ b/web/i18n/en-US/common.ts @@ -498,10 +498,13 @@ const translation = { authRemoved: 'Auth removed', apiKeys: 'API Keys', addApiKey: 'Add API Key', + addModel: 'Add model', addNewModel: 'Add new model', addCredential: 'Add credential', addModelCredential: 'Add model credential', + editModelCredential: 'Edit model credential', modelCredentials: 'Model credentials', + modelCredential: 'Model credential', configModel: 'Config model', configLoadBalancing: 'Config Load Balancing', authorizationError: 'Authorization error', @@ -514,6 +517,12 @@ const translation = { desc: 'After configuring credentials, all members within the workspace can use this model when orchestrating applications.', addModel: 'Add model', }, + manageCredentials: 'Manage Credentials', + customModelCredentials: 'Custom Model Credentials', + addNewModelCredential: 'Add new model credential', + removeModel: 'Remove Model', + selectModelCredential: 'Select a model credential', + customModelCredentialsDeleteTip: 'Credential is in use and cannot be deleted', }, }, dataSource: { diff --git a/web/i18n/zh-Hans/common.ts b/web/i18n/zh-Hans/common.ts index 5d9f01834b..a83487e432 100644 --- a/web/i18n/zh-Hans/common.ts +++ b/web/i18n/zh-Hans/common.ts @@ -492,10 +492,13 @@ const translation = { authRemoved: '授权已移除', apiKeys: 'API 密钥', addApiKey: '添加 API 密钥', + addModel: '添加模型', addNewModel: '添加新模型', addCredential: '添加凭据', addModelCredential: '添加模型凭据', + editModelCredential: '编辑模型凭据', modelCredentials: '模型凭据', + modelCredential: '模型凭据', configModel: '配置模型', configLoadBalancing: '配置负载均衡', authorizationError: '授权错误', @@ -508,6 +511,12 @@ const translation = { desc: '配置凭据后,工作空间中的所有成员都可以在编排应用时使用此模型。', addModel: '添加模型', }, + manageCredentials: '管理凭据', + customModelCredentials: '自定义模型凭据', + addNewModelCredential: '添加模型新凭据', + removeModel: '移除模型', + selectModelCredential: '选择模型凭据', + customModelCredentialsDeleteTip: '模型凭据正在使用中,无法删除', }, }, dataSource: { diff --git a/web/service/use-models.ts b/web/service/use-models.ts index f3336dd03b..d6eb929646 100644 --- a/web/service/use-models.ts +++ b/web/service/use-models.ts @@ -122,7 +122,7 @@ export const useDeleteModel = (provider: string) => { mutationFn: (data: { model: string model_type: ModelTypeEnum - }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models`, { body: data, }), }) From b673560b92c107195dd36080463f5f59c4085254 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Wed, 3 Sep 2025 13:52:31 +0800 Subject: [PATCH 4/4] feat: improve multi model credentials (#25009) Co-authored-by: Claude --- .../console/workspace/model_providers.py | 4 +- api/controllers/console/workspace/models.py | 10 +- api/core/entities/provider_configuration.py | 124 +++++-- api/core/entities/provider_entities.py | 12 + api/core/provider_manager.py | 305 ++++++++++++------ .../entities/model_provider_entities.py | 2 + api/services/model_load_balancing_service.py | 15 +- api/services/model_provider_service.py | 10 +- 8 files changed, 332 insertions(+), 150 deletions(-) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 3861fb8e99..bfcc9a7f0a 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource): parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource): parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 35fc61e48a..f174fcc5d3 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource): model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + tenant_id=tenant_id, + provider=provider, + model=args["model"], + model_type=args["model_type"], + config_from=args.get("config_from", ""), ) if args.get("config_from", "") == "predefined-model": @@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource): choices=[mt.value for mt in ModelType], location="json", ) - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource): ) parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index ca3c36b878..b74e081dd4 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,5 +1,6 @@ import json import logging +import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError @@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel): with Session(db.engine) as new_session: return _validate(new_session) - def create_provider_credential(self, credentials: dict, credential_name: str) -> None: + def _generate_provider_credential_name(self, session) -> str: + """ + Generate a unique credential name for provider. + :return: credential name + """ + return self._generate_next_api_key_name( + session=session, + query_factory=lambda: select(ProviderCredential).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ), + ) + + def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str: + """ + Generate a unique credential name for custom model. + :return: credential name + """ + return self._generate_next_api_key_name( + session=session, + query_factory=lambda: select(ProviderModelCredential).where( + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ), + ) + + def _generate_next_api_key_name(self, session, query_factory) -> str: + """ + Generate next available API KEY name by finding the highest numbered suffix. + :param session: database session + :param query_factory: function that returns the SQLAlchemy query + :return: next available API KEY name + """ + try: + stmt = query_factory() + credential_records = session.execute(stmt).scalars().all() + + if not credential_records: + return "API KEY 1" + + # Extract numbers from API KEY pattern using list comprehension + pattern = re.compile(r"^API KEY\s+(\d+)$") + numbers = [ + int(match.group(1)) + for cr in credential_records + if cr.credential_name and (match := pattern.match(cr.credential_name.strip())) + ] + + # Return next sequential number + next_number = max(numbers, default=0) + 1 + return f"API KEY {next_number}" + + except Exception as e: + logger.warning("Error generating next credential name: %s", str(e)) + return "API KEY 1" + + def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: """ Add custom provider credentials. :param credentials: provider credentials @@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel): :return: """ with Session(db.engine) as session: - if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + if credential_name and self._check_provider_credential_name_exists( + credential_name=credential_name, session=session + ): raise ValueError(f"Credential with name '{credential_name}' already exists.") + else: + credential_name = self._generate_provider_credential_name(session) credentials = self.validate_provider_credentials(credentials=credentials, session=session) provider_record = self._get_provider_record(session) @@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel): self, credentials: dict, credential_id: str, - credential_name: str, + credential_name: str | None, ) -> None: """ update a saved provider credential (by credential_id). @@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel): :return: """ with Session(db.engine) as session: - if self._check_provider_credential_name_exists( + if credential_name and self._check_provider_credential_name_exists( credential_name=credential_name, session=session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") @@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel): try: # Update credential credential_record.encrypted_config = json.dumps(credentials) - credential_record.credential_name = credential_name credential_record.updated_at = naive_utc_now() - + if credential_name: + credential_record.credential_name = credential_name session.commit() if provider_record and provider_record.credential_id == credential_id: @@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel): cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) lb_credentials_cache.delete() - - lb_config.credential_id = None - lb_config.encrypted_config = None - lb_config.enabled = False - lb_config.name = "__delete__" - lb_config.updated_at = naive_utc_now() - session.add(lb_config) + session.delete(lb_config) # Check if this is the currently active credential provider_record = self._get_provider_record(session) @@ -822,7 +879,7 @@ class ProviderConfiguration(BaseModel): return _validate(new_session) def create_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str + self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None ) -> None: """ Create a custom model credential. @@ -833,10 +890,14 @@ class ProviderConfiguration(BaseModel): :return: """ with Session(db.engine) as session: - if self._check_custom_model_credential_name_exists( + if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, session=session ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") + else: + credential_name = self._generate_custom_model_credential_name( + model=model, model_type=model_type, session=session + ) # validate custom model config credentials = self.validate_custom_model_credentials( model_type=model_type, model=model, credentials=credentials, session=session @@ -880,7 +941,7 @@ class ProviderConfiguration(BaseModel): raise def update_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str + self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str ) -> None: """ Update a custom model credential. @@ -893,7 +954,7 @@ class ProviderConfiguration(BaseModel): :return: """ with Session(db.engine) as session: - if self._check_custom_model_credential_name_exists( + if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, @@ -925,8 +986,9 @@ class ProviderConfiguration(BaseModel): try: # Update credential credential_record.encrypted_config = json.dumps(credentials) - credential_record.credential_name = credential_name credential_record.updated_at = naive_utc_now() + if credential_name: + credential_record.credential_name = credential_name session.commit() if provider_model_record and provider_model_record.credential_id == credential_id: @@ -982,12 +1044,7 @@ class ProviderConfiguration(BaseModel): cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) lb_credentials_cache.delete() - lb_config.credential_id = None - lb_config.encrypted_config = None - lb_config.enabled = False - lb_config.name = "__delete__" - lb_config.updated_at = naive_utc_now() - session.add(lb_config) + session.delete(lb_config) # Check if this is the currently active credential provider_model_record = self._get_custom_model_record(model_type, model, session=session) @@ -1054,6 +1111,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_name=model, model_type=model_type.to_origin_model_type(), + is_valid=True, credential_id=credential_id, ) else: @@ -1605,11 +1663,9 @@ class ProviderConfiguration(BaseModel): if config.credential_source_type != "custom_model" ] - if len(provider_model_lb_configs) > 1: - load_balancing_enabled = True - - if any(config.name == "__delete__" for config in provider_model_lb_configs): - has_invalid_load_balancing_configs = True + load_balancing_enabled = model_setting.load_balancing_enabled + # when the user enable load_balancing but available configs are less than 2 display warning + has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2 provider_models.append( ModelWithProviderEntity( @@ -1631,6 +1687,8 @@ class ProviderConfiguration(BaseModel): for model_configuration in self.custom_configuration.models: if model_configuration.model_type not in model_types: continue + if model_configuration.unadded_to_model_list: + continue if model and model != model_configuration.model: continue try: @@ -1663,11 +1721,9 @@ class ProviderConfiguration(BaseModel): if config.credential_source_type != "provider" ] - if len(custom_model_lb_configs) > 1: - load_balancing_enabled = True - - if any(config.name == "__delete__" for config in custom_model_lb_configs): - has_invalid_load_balancing_configs = True + load_balancing_enabled = model_setting.load_balancing_enabled + # when the user enable load_balancing but available configs are less than 2 display warning + has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2 if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: status = ModelStatus.CREDENTIAL_REMOVED diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 1b87bffe57..79a7514bbc 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel): current_credential_id: Optional[str] = None current_credential_name: Optional[str] = None available_model_credentials: list[CredentialConfiguration] = [] + unadded_to_model_list: Optional[bool] = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) +class UnaddedModelConfiguration(BaseModel): + """ + Model class for provider unadded model configuration. + """ + + model: str + model_type: ModelType + + class CustomConfiguration(BaseModel): """ Model class for provider custom configuration. @@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel): provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] + can_added_models: list[UnaddedModelConfiguration] = [] class ModelLoadBalancingConfiguration(BaseModel): @@ -144,6 +155,7 @@ class ModelSettings(BaseModel): model: str model_type: ModelType enabled: bool = True + load_balancing_enabled: bool = False load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] # pydantic configs diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 04996442ca..f8ef0c1846 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,8 +1,9 @@ import contextlib import json from collections import defaultdict +from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, Optional +from typing import Any, Optional, cast from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -22,6 +23,7 @@ from core.entities.provider_entities import ( QuotaConfiguration, QuotaUnit, SystemConfiguration, + UnaddedModelConfiguration, ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType @@ -537,6 +539,23 @@ class ProviderManager: for credential in available_credentials ] + @staticmethod + def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]: + """ + Get all the credentials records from ProviderModelCredential by provider_name + + :param tenant_id: workspace id + :param provider_name: provider name + + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name + ) + + all_credentials = session.scalars(stmt).all() + return all_credentials + @staticmethod def _init_trial_provider_records( tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] @@ -623,6 +642,44 @@ class ProviderManager: :param provider_model_records: provider model records :return: """ + # Get custom provider configuration + custom_provider_configuration = self._get_custom_provider_configuration( + tenant_id, provider_entity, provider_records + ) + + # Get all model credentials once + all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider) + + # Get custom models which have not been added to the model list yet + unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials) + + # Get custom model configurations + custom_model_configurations = self._get_custom_model_configurations( + tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials + ) + + can_added_models = [ + UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models + ] + + return CustomConfiguration( + provider=custom_provider_configuration, + models=custom_model_configurations, + can_added_models=can_added_models, + ) + + def _get_custom_provider_configuration( + self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] + ) -> CustomProviderConfiguration | None: + """Get custom provider configuration.""" + # Find custom provider record (non-system) + custom_provider_record = next( + (record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None + ) + + if not custom_provider_record: + return None + # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas @@ -630,113 +687,98 @@ class ProviderManager: else [] ) - # Get custom provider record - custom_provider_record = None - for provider_record in provider_records: - if provider_record.provider_type == ProviderType.SYSTEM.value: - continue + # Get and decrypt provider credentials + provider_credentials = self._get_and_decrypt_credentials( + tenant_id=tenant_id, + record_id=custom_provider_record.id, + encrypted_config=custom_provider_record.encrypted_config, + secret_variables=provider_credential_secret_variables, + cache_type=ProviderCredentialsCacheType.PROVIDER, + is_provider=True, + ) - custom_provider_record = provider_record + return CustomProviderConfiguration( + credentials=provider_credentials, + current_credential_name=custom_provider_record.credential_name, + current_credential_id=custom_provider_record.credential_id, + available_credentials=self.get_provider_available_credentials( + tenant_id, custom_provider_record.provider_name + ), + ) - # Get custom provider credentials - custom_provider_configuration = None - if custom_provider_record: - provider_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=custom_provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER, - ) + def _get_can_added_models( + self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential] + ) -> list[dict]: + """Get the custom models and credentials from enterprise version which haven't add to the model list""" + existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records} - # Get cached provider credentials - cached_provider_credentials = provider_credentials_cache.get() + # Get not added custom models credentials + not_added_custom_models_credentials = [ + credential + for credential in all_model_credentials + if (credential.model_name, credential.model_type) not in existing_model_set + ] - if not cached_provider_credentials: - try: - # fix origin data - if custom_provider_record.encrypted_config is None: - provider_credentials = {} - elif not custom_provider_record.encrypted_config.startswith("{"): - provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} - else: - provider_credentials = json.loads(custom_provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} + # Group credentials by model + model_to_credentials = defaultdict(list) + for credential in not_added_custom_models_credentials: + model_to_credentials[(credential.model_name, credential.model_type)].append(credential) - # Get decoding rsa key and cipher for decrypting credentials - if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: - self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + return [ + { + "model": model_key[0], + "model_type": ModelType.value_of(model_key[1]), + "available_model_credentials": [ + CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name) + for cred in creds + ], + } + for model_key, creds in model_to_credentials.items() + ] - for variable in provider_credential_secret_variables: - if variable in provider_credentials: - with contextlib.suppress(ValueError): - provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable) or "", # type: ignore - self.decoding_rsa_key, - self.decoding_cipher_rsa, - ) - - # cache provider credentials - provider_credentials_cache.set(credentials=provider_credentials) - else: - provider_credentials = cached_provider_credentials - - custom_provider_configuration = CustomProviderConfiguration( - credentials=provider_credentials, - current_credential_name=custom_provider_record.credential_name, - current_credential_id=custom_provider_record.credential_id, - available_credentials=self.get_provider_available_credentials( - tenant_id, custom_provider_record.provider_name - ), - ) - - # Get provider model credential secret variables + def _get_custom_model_configurations( + self, + tenant_id: str, + provider_entity: ProviderEntity, + provider_model_records: list[ProviderModel], + can_added_models: list[dict], + all_model_credentials: Sequence[ProviderModelCredential], + ) -> list[CustomModelConfiguration]: + """Get custom model configurations.""" + # Get model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas if provider_entity.model_credential_schema else [] ) - # Get custom provider model credentials + # Create credentials lookup for efficient access + credentials_map = defaultdict(list) + for credential in all_model_credentials: + credentials_map[(credential.model_name, credential.model_type)].append(credential) + custom_model_configurations = [] + + # Process existing model records for provider_model_record in provider_model_records: - available_model_credentials = self.get_provider_model_available_credentials( - tenant_id, - provider_model_record.provider_name, - provider_model_record.model_name, - provider_model_record.model_type, + # Use pre-fetched credentials instead of individual database calls + available_model_credentials = [ + CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name) + for cred in credentials_map.get( + (provider_model_record.model_name, provider_model_record.model_type), [] + ) + ] + + # Get and decrypt model credentials + provider_model_credentials = self._get_and_decrypt_credentials( + tenant_id=tenant_id, + record_id=provider_model_record.id, + encrypted_config=provider_model_record.encrypted_config, + secret_variables=model_credential_secret_variables, + cache_type=ProviderCredentialsCacheType.MODEL, + is_provider=False, ) - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL - ) - - # Get cached provider model credentials - cached_provider_model_credentials = provider_model_credentials_cache.get() - - if not cached_provider_model_credentials and provider_model_record.encrypted_config: - try: - provider_model_credentials = json.loads(provider_model_record.encrypted_config) - except JSONDecodeError: - continue - - # Get decoding rsa key and cipher for decrypting credentials - if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: - self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - - for variable in model_credential_secret_variables: - if variable in provider_model_credentials: - with contextlib.suppress(ValueError): - provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_model_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa, - ) - - # cache provider model credentials - provider_model_credentials_cache.set(credentials=provider_model_credentials) - else: - provider_model_credentials = cached_provider_model_credentials - custom_model_configurations.append( CustomModelConfiguration( model=provider_model_record.model_name, @@ -748,7 +790,71 @@ class ProviderManager: ) ) - return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) + # Add models that can be added + for model in can_added_models: + custom_model_configurations.append( + CustomModelConfiguration( + model=model["model"], + model_type=model["model_type"], + credentials=None, + current_credential_id=None, + current_credential_name=None, + available_model_credentials=model["available_model_credentials"], + unadded_to_model_list=True, + ) + ) + + return custom_model_configurations + + def _get_and_decrypt_credentials( + self, + tenant_id: str, + record_id: str, + encrypted_config: str | None, + secret_variables: list[str], + cache_type: ProviderCredentialsCacheType, + is_provider: bool = False, + ) -> dict: + """Get and decrypt credentials with caching.""" + credentials_cache = ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=record_id, + cache_type=cache_type, + ) + + # Try to get from cache first + cached_credentials = credentials_cache.get() + if cached_credentials: + return cached_credentials + + # Parse encrypted config + if not encrypted_config: + return {} + + if is_provider and not encrypted_config.startswith("{"): + return {"openai_api_key": encrypted_config} + + try: + credentials = cast(dict, json.loads(encrypted_config)) + except JSONDecodeError: + return {} + + # Decrypt secret variables + if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: + self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + + for variable in secret_variables: + if variable in credentials: + with contextlib.suppress(ValueError): + credentials[variable] = encrypter.decrypt_token_with_decoding( + credentials.get(variable) or "", + self.decoding_rsa_key, + self.decoding_cipher_rsa, + ) + + # Cache the decrypted credentials + credentials_cache.set(credentials=credentials) + return credentials def _to_system_configuration( self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] @@ -956,18 +1062,6 @@ class ProviderManager: load_balancing_model_config.model_name == provider_model_setting.model_name and load_balancing_model_config.model_type == provider_model_setting.model_type ): - if load_balancing_model_config.name == "__delete__": - # to calculate current model whether has invalidate lb configs - load_balancing_configs.append( - ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials={}, - credential_source_type=load_balancing_model_config.credential_source_type, - ) - ) - continue - if not load_balancing_model_config.enabled: continue @@ -1033,6 +1127,7 @@ class ProviderManager: model=provider_model_setting.model_name, model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, + load_balancing_enabled=provider_model_setting.load_balancing_enabled, load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], ) ) diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 056decda26..1fe259dd46 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -13,6 +13,7 @@ from core.entities.provider_entities import ( CustomModelConfiguration, ProviderQuotaType, QuotaConfiguration, + UnaddedModelConfiguration, ) from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType @@ -45,6 +46,7 @@ class CustomConfigurationResponse(BaseModel): current_credential_name: Optional[str] = None available_credentials: Optional[list[CredentialConfiguration]] = None custom_models: Optional[list[CustomModelConfiguration]] = None + can_added_models: Optional[list[UnaddedModelConfiguration]] = None class SystemConfigurationResponse(BaseModel): diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index d830034f11..17696f5cd8 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -3,6 +3,8 @@ import logging from json import JSONDecodeError from typing import Optional, Union +from sqlalchemy import or_ + from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter @@ -69,7 +71,7 @@ class ModelLoadBalancingService: provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) def get_load_balancing_configs( - self, tenant_id: str, provider: str, model: str, model_type: str + self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = "" ) -> tuple[bool, list[dict]]: """ Get load balancing configurations. @@ -100,6 +102,11 @@ class ModelLoadBalancingService: if provider_model_setting and provider_model_setting.load_balancing_enabled: is_load_balancing_enabled = True + if config_from == "predefined-model": + credential_source_type = "provider" + else: + credential_source_type = "custom_model" + # Get load balancing configurations load_balancing_configs = ( db.session.query(LoadBalancingModelConfig) @@ -108,6 +115,10 @@ class ModelLoadBalancingService: LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, + or_( + LoadBalancingModelConfig.credential_source_type == credential_source_type, + LoadBalancingModelConfig.credential_source_type.is_(None), + ), ) .order_by(LoadBalancingModelConfig.created_at) .all() @@ -405,7 +416,7 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name in {"__inherit__", "__delete__"}: + if name == "__inherit__": raise ValueError("Invalid load balancing config name") if credential_id: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 9e9422f9f7..69c7e4cf58 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -72,6 +72,7 @@ class ModelProviderService: provider_config = provider_configuration.custom_configuration.provider model_config = provider_configuration.custom_configuration.models + can_added_models = provider_configuration.custom_configuration.can_added_models provider_response = ProviderResponse( tenant_id=tenant_id, @@ -95,6 +96,7 @@ class ModelProviderService: current_credential_name=getattr(provider_config, "current_credential_name", None), available_credentials=getattr(provider_config, "available_credentials", []), custom_models=model_config, + can_added_models=can_added_models, ), system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, @@ -152,7 +154,7 @@ class ModelProviderService: provider_configuration.validate_provider_credentials(credentials) def create_provider_credential( - self, tenant_id: str, provider: str, credentials: dict, credential_name: str + self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None ) -> None: """ Create and save new provider credentials. @@ -172,7 +174,7 @@ class ModelProviderService: provider: str, credentials: dict, credential_id: str, - credential_name: str, + credential_name: str | None, ) -> None: """ update a saved provider credential (by credential_id). @@ -249,7 +251,7 @@ class ModelProviderService: ) def create_model_credential( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None ) -> None: """ create and save model credentials. @@ -278,7 +280,7 @@ class ModelProviderService: model: str, credentials: dict, credential_id: str, - credential_name: str, + credential_name: str | None, ) -> None: """ update model credentials.