support model params change

This commit is contained in:
JzoNg 2024-12-24 14:15:18 +08:00
parent c8fc1deca6
commit e2e2090e0c
5 changed files with 131 additions and 71 deletions

View File

@ -2,7 +2,7 @@
import type { FC } from 'react'
import React, { Fragment, useEffect, useState } from 'react'
import { Combobox, Listbox, Transition } from '@headlessui/react'
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
import Badge from '../badge/index'
import { RiCheckLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
@ -352,7 +352,7 @@ const PortalSelect: FC<PortalSelectProps> = ({
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className={`z-20 ${popupClassName}`}>
<div
className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md bg-white text-base shadow-lg border-gray-200 border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)}
className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md text-base shadow-lg border-components-panel-border bg-components-panel-bg border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)}
>
{items.map((item: Item) => (
<div

View File

@ -72,25 +72,15 @@ const Form: FC<FormProps> = ({
onChange({ ...value, [key]: val, ...shouldClearVariable })
}
const handleModelChanged = useCallback((key: string, model: { provider: string; modelId: string; mode?: string }) => {
const handleModelChanged = useCallback((key: string, model: any) => {
const newValue = {
...value[key],
provider: model.provider,
model: model.modelId,
mode: model.mode,
...model,
type: FormTypeEnum.modelSelector,
}
onChange({ ...value, [key]: newValue })
}, [onChange, value])
const handleCompletionParamsChange = useCallback((key: string, newParams: Record<string, any>) => {
const newValue = {
...value[key],
completion_params: newParams,
}
onChange({ ...value, [key]: newValue })
}, [onChange, value])
const renderField = (formSchema: CredentialFormSchema) => {
const tooltip = formSchema.tooltip
const tooltipContent = (tooltip && (
@ -302,12 +292,8 @@ const Form: FC<FormProps> = ({
popupClassName='!w-[387px]'
isAdvancedMode
isInWorkflow
provider={value[variable]?.provider}
modelId={value[variable]?.model}
mode={value[variable]?.mode}
completionParams={value[variable]?.completion_params}
value={value[variable]}
setModel={model => handleModelChanged(variable, model)}
onCompletionParamsChange={params => handleCompletionParamsChange(variable, params)}
readonly={readonly}
scope={scope}
/>

View File

@ -21,6 +21,7 @@ import {
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import LLMParamsPanel from './llm-params-panel'
import TTSParamsPanel from './tts-params-panel'
import { useProviderContext } from '@/context/provider-context'
import cn from '@/utils/classnames'
@ -28,12 +29,8 @@ export type ModelParameterModalProps = {
popupClassName?: string
portalToFollowElemContentClassName?: string
isAdvancedMode: boolean
mode: string
modelId: string
provider: string
setModel: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void
completionParams: FormValue
onCompletionParamsChange: (newParams: FormValue) => void
value: any
setModel: (model: any) => void
renderTrigger?: (v: TriggerProps) => ReactNode
readonly?: boolean
isInWorkflow?: boolean
@ -44,15 +41,12 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
popupClassName,
portalToFollowElemContentClassName,
isAdvancedMode,
modelId,
provider,
value,
setModel,
completionParams,
onCompletionParamsChange,
renderTrigger,
readonly,
isInWorkflow,
scope = 'text-generation',
scope = ModelTypeEnum.textGeneration,
}) => {
const { t } = useTranslation()
const { isAPIKeySet } = useProviderContext()
@ -79,29 +73,29 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
...moderationList,
]
}
if (scopeArray.includes('text-generation'))
if (scopeArray.includes(ModelTypeEnum.textGeneration))
return textGenerationList
if (scopeArray.includes('embedding'))
if (scopeArray.includes(ModelTypeEnum.textEmbedding))
return textEmbeddingList
if (scopeArray.includes('rerank'))
if (scopeArray.includes(ModelTypeEnum.rerank))
return rerankList
if (scopeArray.includes('moderation'))
if (scopeArray.includes(ModelTypeEnum.moderation))
return moderationList
if (scopeArray.includes('stt'))
if (scopeArray.includes(ModelTypeEnum.speech2text))
return sttList
if (scopeArray.includes('tts'))
if (scopeArray.includes(ModelTypeEnum.tts))
return ttsList
return resultList
}, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
const { currentProvider, currentModel } = useMemo(() => {
const currentProvider = scopedModelList.find(item => item.provider === provider)
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === modelId)
const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
return {
currentProvider,
currentModel,
}
}, [provider, modelId, scopedModelList])
}, [scopedModelList, value?.provider, value?.model])
const hasDeprecated = useMemo(() => {
return !currentProvider || !currentModel
@ -116,11 +110,33 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
const handleChangeModel = ({ provider, model }: DefaultModel) => {
const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
const model_type = targetModelItem?.model_type as string
setModel({
modelId: model,
provider,
mode: targetModelItem?.model_properties.mode as string,
features: targetModelItem?.features || [],
model,
model_type,
...(model_type === ModelTypeEnum.textGeneration ? {
mode: targetModelItem?.model_properties.mode as string,
} : {}),
})
}
const handleLLMParamsChange = (newParams: FormValue) => {
const newValue = {
...(value?.completionParams || {}),
completion_params: newParams,
}
setModel({
...value,
...newValue,
})
}
const handleTTSParamsChange = (language: string, voice: string) => {
setModel({
...value,
language,
voice,
})
}
@ -149,8 +165,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
hasDeprecated,
currentProvider,
currentModel,
providerName: provider,
modelId,
providerName: value?.provider,
modelId: value?.model,
})
: (
<Trigger
@ -160,8 +176,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
hasDeprecated={hasDeprecated}
currentProvider={currentProvider}
currentModel={currentModel}
providerName={provider}
modelId={modelId}
providerName={value?.provider}
modelId={value?.model}
/>
)
}
@ -174,7 +190,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
{t('common.modelProvider.model').toLocaleUpperCase()}
</div>
<ModelSelector
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
modelList={scopedModelList}
scopeFeatures={scopeFeatures}
onSelect={handleChangeModel}
@ -185,13 +201,21 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
)}
{currentModel?.model_type === ModelTypeEnum.textGeneration && (
<LLMParamsPanel
provider={provider}
modelId={modelId}
completionParams={completionParams}
onCompletionParamsChange={onCompletionParamsChange}
provider={value?.provider}
modelId={value?.model}
completionParams={value?.completion_params || {}}
onCompletionParamsChange={handleLLMParamsChange}
isAdvancedMode={isAdvancedMode}
/>
)}
{currentModel?.model_type === ModelTypeEnum.tts && (
<TTSParamsPanel
currentModel={currentModel}
language={value?.language}
voice={value?.voice}
onChange={handleTTSParamsChange}
/>
)}
</div>
</div>
</PortalToFollowElemContent>

View File

@ -0,0 +1,67 @@
import React, { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { languages } from '@/i18n/language'
import { PortalSelect } from '@/app/components/base/select'
import cn from '@/utils/classnames'
type Props = {
currentModel: any
language: string
voice: string
onChange: (language: string, voice: string) => void
}
const TTSParamsPanel = ({
currentModel,
language,
voice,
onChange,
}: Props) => {
const { t } = useTranslation()
const voiceList = useMemo(() => {
if (!currentModel)
return []
return currentModel.model_properties.voices.map((item: { mode: any }) => ({
...item,
value: item.mode,
}))
}, [currentModel])
const setLanguage = (language: string) => {
onChange(language, voice)
}
const setVoice = (voice: string) => {
onChange(language, voice)
}
return (
<>
<div className='mb-3'>
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
{t('appDebug.voice.voiceSettings.language')}
</div>
<PortalSelect
triggerClassName='h-8'
popupClassName={cn('z-[1000]')}
popupInnerClassName={cn('w-[354px]')}
value={language}
items={languages.filter(item => item.supported)}
onSelect={item => setLanguage(item.value as string)}
/>
</div>
<div className='mb-3'>
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
{t('appDebug.voice.voiceSettings.voice')}
</div>
<PortalSelect
triggerClassName='h-8'
popupClassName={cn('z-[1000]')}
popupInnerClassName={cn('w-[354px]')}
value={voice}
items={voiceList}
onSelect={item => setVoice(item.value as string)}
/>
</div>
</>
)
}
export default TTSParamsPanel

View File

@ -123,24 +123,11 @@ const InputVarList: FC<Props> = ({
}
}, [onChange, value])
const handleModelChange = useCallback((variable: string) => {
return (model: { provider: string; modelId: string; mode?: string }) => {
return (model: any) => {
const newValue = produce(value, (draft: ToolVarInputs) => {
draft[variable] = {
...draft[variable],
provider: model.provider,
model: model.modelId,
mode: model.mode,
} as any
})
onChange(newValue)
}
}, [onChange, value])
const handleModelParamsChange = useCallback((variable: string) => {
return (newParams: Record<string, any>) => {
const newValue = produce(value, (draft: ToolVarInputs) => {
draft[variable] = {
...draft[variable],
completion_params: newParams,
...model,
} as any
})
onChange(newValue)
@ -242,12 +229,8 @@ const InputVarList: FC<Props> = ({
popupClassName='!w-[387px]'
isAdvancedMode
isInWorkflow
provider={(varInput as any)?.provider}
modelId={(varInput as any)?.model}
mode={(varInput as any)?.mode}
completionParams={(varInput as any)?.completion_params}
value={varInput as any}
setModel={handleModelChange(variable)}
onCompletionParamsChange={handleModelParamsChange(variable)}
readonly={readOnly}
scope={scope}
/>