support model params change

This commit is contained in:
JzoNg 2024-12-24 14:15:18 +08:00
parent c8fc1deca6
commit e2e2090e0c
5 changed files with 131 additions and 71 deletions

View File

@ -2,7 +2,7 @@
import type { FC } from 'react' import type { FC } from 'react'
import React, { Fragment, useEffect, useState } from 'react' import React, { Fragment, useEffect, useState } from 'react'
import { Combobox, Listbox, Transition } from '@headlessui/react' import { Combobox, Listbox, Transition } from '@headlessui/react'
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid' import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
import Badge from '../badge/index' import Badge from '../badge/index'
import { RiCheckLine } from '@remixicon/react' import { RiCheckLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
@ -352,7 +352,7 @@ const PortalSelect: FC<PortalSelectProps> = ({
</PortalToFollowElemTrigger> </PortalToFollowElemTrigger>
<PortalToFollowElemContent className={`z-20 ${popupClassName}`}> <PortalToFollowElemContent className={`z-20 ${popupClassName}`}>
<div <div
className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md bg-white text-base shadow-lg border-gray-200 border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)} className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md text-base shadow-lg border-components-panel-border bg-components-panel-bg border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)}
> >
{items.map((item: Item) => ( {items.map((item: Item) => (
<div <div

View File

@ -72,25 +72,15 @@ const Form: FC<FormProps> = ({
onChange({ ...value, [key]: val, ...shouldClearVariable }) onChange({ ...value, [key]: val, ...shouldClearVariable })
} }
const handleModelChanged = useCallback((key: string, model: { provider: string; modelId: string; mode?: string }) => { const handleModelChanged = useCallback((key: string, model: any) => {
const newValue = { const newValue = {
...value[key], ...value[key],
provider: model.provider, ...model,
model: model.modelId,
mode: model.mode,
type: FormTypeEnum.modelSelector, type: FormTypeEnum.modelSelector,
} }
onChange({ ...value, [key]: newValue }) onChange({ ...value, [key]: newValue })
}, [onChange, value]) }, [onChange, value])
const handleCompletionParamsChange = useCallback((key: string, newParams: Record<string, any>) => {
const newValue = {
...value[key],
completion_params: newParams,
}
onChange({ ...value, [key]: newValue })
}, [onChange, value])
const renderField = (formSchema: CredentialFormSchema) => { const renderField = (formSchema: CredentialFormSchema) => {
const tooltip = formSchema.tooltip const tooltip = formSchema.tooltip
const tooltipContent = (tooltip && ( const tooltipContent = (tooltip && (
@ -302,12 +292,8 @@ const Form: FC<FormProps> = ({
popupClassName='!w-[387px]' popupClassName='!w-[387px]'
isAdvancedMode isAdvancedMode
isInWorkflow isInWorkflow
provider={value[variable]?.provider} value={value[variable]}
modelId={value[variable]?.model}
mode={value[variable]?.mode}
completionParams={value[variable]?.completion_params}
setModel={model => handleModelChanged(variable, model)} setModel={model => handleModelChanged(variable, model)}
onCompletionParamsChange={params => handleCompletionParamsChange(variable, params)}
readonly={readonly} readonly={readonly}
scope={scope} scope={scope}
/> />

View File

@ -21,6 +21,7 @@ import {
PortalToFollowElemTrigger, PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem' } from '@/app/components/base/portal-to-follow-elem'
import LLMParamsPanel from './llm-params-panel' import LLMParamsPanel from './llm-params-panel'
import TTSParamsPanel from './tts-params-panel'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
@ -28,12 +29,8 @@ export type ModelParameterModalProps = {
popupClassName?: string popupClassName?: string
portalToFollowElemContentClassName?: string portalToFollowElemContentClassName?: string
isAdvancedMode: boolean isAdvancedMode: boolean
mode: string value: any
modelId: string setModel: (model: any) => void
provider: string
setModel: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void
completionParams: FormValue
onCompletionParamsChange: (newParams: FormValue) => void
renderTrigger?: (v: TriggerProps) => ReactNode renderTrigger?: (v: TriggerProps) => ReactNode
readonly?: boolean readonly?: boolean
isInWorkflow?: boolean isInWorkflow?: boolean
@ -44,15 +41,12 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
popupClassName, popupClassName,
portalToFollowElemContentClassName, portalToFollowElemContentClassName,
isAdvancedMode, isAdvancedMode,
modelId, value,
provider,
setModel, setModel,
completionParams,
onCompletionParamsChange,
renderTrigger, renderTrigger,
readonly, readonly,
isInWorkflow, isInWorkflow,
scope = 'text-generation', scope = ModelTypeEnum.textGeneration,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { isAPIKeySet } = useProviderContext() const { isAPIKeySet } = useProviderContext()
@ -79,29 +73,29 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
...moderationList, ...moderationList,
] ]
} }
if (scopeArray.includes('text-generation')) if (scopeArray.includes(ModelTypeEnum.textGeneration))
return textGenerationList return textGenerationList
if (scopeArray.includes('embedding')) if (scopeArray.includes(ModelTypeEnum.textEmbedding))
return textEmbeddingList return textEmbeddingList
if (scopeArray.includes('rerank')) if (scopeArray.includes(ModelTypeEnum.rerank))
return rerankList return rerankList
if (scopeArray.includes('moderation')) if (scopeArray.includes(ModelTypeEnum.moderation))
return moderationList return moderationList
if (scopeArray.includes('stt')) if (scopeArray.includes(ModelTypeEnum.speech2text))
return sttList return sttList
if (scopeArray.includes('tts')) if (scopeArray.includes(ModelTypeEnum.tts))
return ttsList return ttsList
return resultList return resultList
}, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList]) }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
const { currentProvider, currentModel } = useMemo(() => { const { currentProvider, currentModel } = useMemo(() => {
const currentProvider = scopedModelList.find(item => item.provider === provider) const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === modelId) const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
return { return {
currentProvider, currentProvider,
currentModel, currentModel,
} }
}, [provider, modelId, scopedModelList]) }, [scopedModelList, value?.provider, value?.model])
const hasDeprecated = useMemo(() => { const hasDeprecated = useMemo(() => {
return !currentProvider || !currentModel return !currentProvider || !currentModel
@ -116,11 +110,33 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
const handleChangeModel = ({ provider, model }: DefaultModel) => { const handleChangeModel = ({ provider, model }: DefaultModel) => {
const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider) const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model) const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
const model_type = targetModelItem?.model_type as string
setModel({ setModel({
modelId: model,
provider, provider,
mode: targetModelItem?.model_properties.mode as string, model,
features: targetModelItem?.features || [], model_type,
...(model_type === ModelTypeEnum.textGeneration ? {
mode: targetModelItem?.model_properties.mode as string,
} : {}),
})
}
const handleLLMParamsChange = (newParams: FormValue) => {
const newValue = {
...(value?.completionParams || {}),
completion_params: newParams,
}
setModel({
...value,
...newValue,
})
}
const handleTTSParamsChange = (language: string, voice: string) => {
setModel({
...value,
language,
voice,
}) })
} }
@ -149,8 +165,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
hasDeprecated, hasDeprecated,
currentProvider, currentProvider,
currentModel, currentModel,
providerName: provider, providerName: value?.provider,
modelId, modelId: value?.model,
}) })
: ( : (
<Trigger <Trigger
@ -160,8 +176,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
hasDeprecated={hasDeprecated} hasDeprecated={hasDeprecated}
currentProvider={currentProvider} currentProvider={currentProvider}
currentModel={currentModel} currentModel={currentModel}
providerName={provider} providerName={value?.provider}
modelId={modelId} modelId={value?.model}
/> />
) )
} }
@ -174,7 +190,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
{t('common.modelProvider.model').toLocaleUpperCase()} {t('common.modelProvider.model').toLocaleUpperCase()}
</div> </div>
<ModelSelector <ModelSelector
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined} defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
modelList={scopedModelList} modelList={scopedModelList}
scopeFeatures={scopeFeatures} scopeFeatures={scopeFeatures}
onSelect={handleChangeModel} onSelect={handleChangeModel}
@ -185,13 +201,21 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
)} )}
{currentModel?.model_type === ModelTypeEnum.textGeneration && ( {currentModel?.model_type === ModelTypeEnum.textGeneration && (
<LLMParamsPanel <LLMParamsPanel
provider={provider} provider={value?.provider}
modelId={modelId} modelId={value?.model}
completionParams={completionParams} completionParams={value?.completion_params || {}}
onCompletionParamsChange={onCompletionParamsChange} onCompletionParamsChange={handleLLMParamsChange}
isAdvancedMode={isAdvancedMode} isAdvancedMode={isAdvancedMode}
/> />
)} )}
{currentModel?.model_type === ModelTypeEnum.tts && (
<TTSParamsPanel
currentModel={currentModel}
language={value?.language}
voice={value?.voice}
onChange={handleTTSParamsChange}
/>
)}
</div> </div>
</div> </div>
</PortalToFollowElemContent> </PortalToFollowElemContent>

View File

@ -0,0 +1,67 @@
import React, { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { languages } from '@/i18n/language'
import { PortalSelect } from '@/app/components/base/select'
import cn from '@/utils/classnames'
type Props = {
currentModel: any
language: string
voice: string
onChange: (language: string, voice: string) => void
}
const TTSParamsPanel = ({
currentModel,
language,
voice,
onChange,
}: Props) => {
const { t } = useTranslation()
const voiceList = useMemo(() => {
if (!currentModel)
return []
return currentModel.model_properties.voices.map((item: { mode: any }) => ({
...item,
value: item.mode,
}))
}, [currentModel])
const setLanguage = (language: string) => {
onChange(language, voice)
}
const setVoice = (voice: string) => {
onChange(language, voice)
}
return (
<>
<div className='mb-3'>
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
{t('appDebug.voice.voiceSettings.language')}
</div>
<PortalSelect
triggerClassName='h-8'
popupClassName={cn('z-[1000]')}
popupInnerClassName={cn('w-[354px]')}
value={language}
items={languages.filter(item => item.supported)}
onSelect={item => setLanguage(item.value as string)}
/>
</div>
<div className='mb-3'>
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
{t('appDebug.voice.voiceSettings.voice')}
</div>
<PortalSelect
triggerClassName='h-8'
popupClassName={cn('z-[1000]')}
popupInnerClassName={cn('w-[354px]')}
value={voice}
items={voiceList}
onSelect={item => setVoice(item.value as string)}
/>
</div>
</>
)
}
export default TTSParamsPanel

View File

@ -123,24 +123,11 @@ const InputVarList: FC<Props> = ({
} }
}, [onChange, value]) }, [onChange, value])
const handleModelChange = useCallback((variable: string) => { const handleModelChange = useCallback((variable: string) => {
return (model: { provider: string; modelId: string; mode?: string }) => { return (model: any) => {
const newValue = produce(value, (draft: ToolVarInputs) => { const newValue = produce(value, (draft: ToolVarInputs) => {
draft[variable] = { draft[variable] = {
...draft[variable], ...draft[variable],
provider: model.provider, ...model,
model: model.modelId,
mode: model.mode,
} as any
})
onChange(newValue)
}
}, [onChange, value])
const handleModelParamsChange = useCallback((variable: string) => {
return (newParams: Record<string, any>) => {
const newValue = produce(value, (draft: ToolVarInputs) => {
draft[variable] = {
...draft[variable],
completion_params: newParams,
} as any } as any
}) })
onChange(newValue) onChange(newValue)
@ -242,12 +229,8 @@ const InputVarList: FC<Props> = ({
popupClassName='!w-[387px]' popupClassName='!w-[387px]'
isAdvancedMode isAdvancedMode
isInWorkflow isInWorkflow
provider={(varInput as any)?.provider} value={varInput as any}
modelId={(varInput as any)?.model}
mode={(varInput as any)?.mode}
completionParams={(varInput as any)?.completion_params}
setModel={handleModelChange(variable)} setModel={handleModelChange(variable)}
onCompletionParamsChange={handleModelParamsChange(variable)}
readonly={readOnly} readonly={readOnly}
scope={scope} scope={scope}
/> />