mirror of https://github.com/langgenius/dify.git
support model params change
This commit is contained in:
parent
c8fc1deca6
commit
e2e2090e0c
|
|
@ -2,7 +2,7 @@
|
|||
import type { FC } from 'react'
|
||||
import React, { Fragment, useEffect, useState } from 'react'
|
||||
import { Combobox, Listbox, Transition } from '@headlessui/react'
|
||||
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
|
||||
import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
|
||||
import Badge from '../badge/index'
|
||||
import { RiCheckLine } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
|
@ -352,7 +352,7 @@ const PortalSelect: FC<PortalSelectProps> = ({
|
|||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className={`z-20 ${popupClassName}`}>
|
||||
<div
|
||||
className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md bg-white text-base shadow-lg border-gray-200 border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)}
|
||||
className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md text-base shadow-lg border-components-panel-border bg-components-panel-bg border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)}
|
||||
>
|
||||
{items.map((item: Item) => (
|
||||
<div
|
||||
|
|
|
|||
|
|
@ -72,25 +72,15 @@ const Form: FC<FormProps> = ({
|
|||
onChange({ ...value, [key]: val, ...shouldClearVariable })
|
||||
}
|
||||
|
||||
const handleModelChanged = useCallback((key: string, model: { provider: string; modelId: string; mode?: string }) => {
|
||||
const handleModelChanged = useCallback((key: string, model: any) => {
|
||||
const newValue = {
|
||||
...value[key],
|
||||
provider: model.provider,
|
||||
model: model.modelId,
|
||||
mode: model.mode,
|
||||
...model,
|
||||
type: FormTypeEnum.modelSelector,
|
||||
}
|
||||
onChange({ ...value, [key]: newValue })
|
||||
}, [onChange, value])
|
||||
|
||||
const handleCompletionParamsChange = useCallback((key: string, newParams: Record<string, any>) => {
|
||||
const newValue = {
|
||||
...value[key],
|
||||
completion_params: newParams,
|
||||
}
|
||||
onChange({ ...value, [key]: newValue })
|
||||
}, [onChange, value])
|
||||
|
||||
const renderField = (formSchema: CredentialFormSchema) => {
|
||||
const tooltip = formSchema.tooltip
|
||||
const tooltipContent = (tooltip && (
|
||||
|
|
@ -302,12 +292,8 @@ const Form: FC<FormProps> = ({
|
|||
popupClassName='!w-[387px]'
|
||||
isAdvancedMode
|
||||
isInWorkflow
|
||||
provider={value[variable]?.provider}
|
||||
modelId={value[variable]?.model}
|
||||
mode={value[variable]?.mode}
|
||||
completionParams={value[variable]?.completion_params}
|
||||
value={value[variable]}
|
||||
setModel={model => handleModelChanged(variable, model)}
|
||||
onCompletionParamsChange={params => handleCompletionParamsChange(variable, params)}
|
||||
readonly={readonly}
|
||||
scope={scope}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import {
|
|||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import LLMParamsPanel from './llm-params-panel'
|
||||
import TTSParamsPanel from './tts-params-panel'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
|
|
@ -28,12 +29,8 @@ export type ModelParameterModalProps = {
|
|||
popupClassName?: string
|
||||
portalToFollowElemContentClassName?: string
|
||||
isAdvancedMode: boolean
|
||||
mode: string
|
||||
modelId: string
|
||||
provider: string
|
||||
setModel: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void
|
||||
completionParams: FormValue
|
||||
onCompletionParamsChange: (newParams: FormValue) => void
|
||||
value: any
|
||||
setModel: (model: any) => void
|
||||
renderTrigger?: (v: TriggerProps) => ReactNode
|
||||
readonly?: boolean
|
||||
isInWorkflow?: boolean
|
||||
|
|
@ -44,15 +41,12 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|||
popupClassName,
|
||||
portalToFollowElemContentClassName,
|
||||
isAdvancedMode,
|
||||
modelId,
|
||||
provider,
|
||||
value,
|
||||
setModel,
|
||||
completionParams,
|
||||
onCompletionParamsChange,
|
||||
renderTrigger,
|
||||
readonly,
|
||||
isInWorkflow,
|
||||
scope = 'text-generation',
|
||||
scope = ModelTypeEnum.textGeneration,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { isAPIKeySet } = useProviderContext()
|
||||
|
|
@ -79,29 +73,29 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|||
...moderationList,
|
||||
]
|
||||
}
|
||||
if (scopeArray.includes('text-generation'))
|
||||
if (scopeArray.includes(ModelTypeEnum.textGeneration))
|
||||
return textGenerationList
|
||||
if (scopeArray.includes('embedding'))
|
||||
if (scopeArray.includes(ModelTypeEnum.textEmbedding))
|
||||
return textEmbeddingList
|
||||
if (scopeArray.includes('rerank'))
|
||||
if (scopeArray.includes(ModelTypeEnum.rerank))
|
||||
return rerankList
|
||||
if (scopeArray.includes('moderation'))
|
||||
if (scopeArray.includes(ModelTypeEnum.moderation))
|
||||
return moderationList
|
||||
if (scopeArray.includes('stt'))
|
||||
if (scopeArray.includes(ModelTypeEnum.speech2text))
|
||||
return sttList
|
||||
if (scopeArray.includes('tts'))
|
||||
if (scopeArray.includes(ModelTypeEnum.tts))
|
||||
return ttsList
|
||||
return resultList
|
||||
}, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
|
||||
|
||||
const { currentProvider, currentModel } = useMemo(() => {
|
||||
const currentProvider = scopedModelList.find(item => item.provider === provider)
|
||||
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === modelId)
|
||||
const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
|
||||
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
|
||||
return {
|
||||
currentProvider,
|
||||
currentModel,
|
||||
}
|
||||
}, [provider, modelId, scopedModelList])
|
||||
}, [scopedModelList, value?.provider, value?.model])
|
||||
|
||||
const hasDeprecated = useMemo(() => {
|
||||
return !currentProvider || !currentModel
|
||||
|
|
@ -116,11 +110,33 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|||
const handleChangeModel = ({ provider, model }: DefaultModel) => {
|
||||
const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
|
||||
const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
|
||||
const model_type = targetModelItem?.model_type as string
|
||||
setModel({
|
||||
modelId: model,
|
||||
provider,
|
||||
mode: targetModelItem?.model_properties.mode as string,
|
||||
features: targetModelItem?.features || [],
|
||||
model,
|
||||
model_type,
|
||||
...(model_type === ModelTypeEnum.textGeneration ? {
|
||||
mode: targetModelItem?.model_properties.mode as string,
|
||||
} : {}),
|
||||
})
|
||||
}
|
||||
|
||||
const handleLLMParamsChange = (newParams: FormValue) => {
|
||||
const newValue = {
|
||||
...(value?.completionParams || {}),
|
||||
completion_params: newParams,
|
||||
}
|
||||
setModel({
|
||||
...value,
|
||||
...newValue,
|
||||
})
|
||||
}
|
||||
|
||||
const handleTTSParamsChange = (language: string, voice: string) => {
|
||||
setModel({
|
||||
...value,
|
||||
language,
|
||||
voice,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -149,8 +165,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|||
hasDeprecated,
|
||||
currentProvider,
|
||||
currentModel,
|
||||
providerName: provider,
|
||||
modelId,
|
||||
providerName: value?.provider,
|
||||
modelId: value?.model,
|
||||
})
|
||||
: (
|
||||
<Trigger
|
||||
|
|
@ -160,8 +176,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|||
hasDeprecated={hasDeprecated}
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
providerName={provider}
|
||||
modelId={modelId}
|
||||
providerName={value?.provider}
|
||||
modelId={value?.model}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
|
@ -174,7 +190,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|||
{t('common.modelProvider.model').toLocaleUpperCase()}
|
||||
</div>
|
||||
<ModelSelector
|
||||
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
|
||||
defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
|
||||
modelList={scopedModelList}
|
||||
scopeFeatures={scopeFeatures}
|
||||
onSelect={handleChangeModel}
|
||||
|
|
@ -185,13 +201,21 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|||
)}
|
||||
{currentModel?.model_type === ModelTypeEnum.textGeneration && (
|
||||
<LLMParamsPanel
|
||||
provider={provider}
|
||||
modelId={modelId}
|
||||
completionParams={completionParams}
|
||||
onCompletionParamsChange={onCompletionParamsChange}
|
||||
provider={value?.provider}
|
||||
modelId={value?.model}
|
||||
completionParams={value?.completion_params || {}}
|
||||
onCompletionParamsChange={handleLLMParamsChange}
|
||||
isAdvancedMode={isAdvancedMode}
|
||||
/>
|
||||
)}
|
||||
{currentModel?.model_type === ModelTypeEnum.tts && (
|
||||
<TTSParamsPanel
|
||||
currentModel={currentModel}
|
||||
language={value?.language}
|
||||
voice={value?.voice}
|
||||
onChange={handleTTSParamsChange}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
import React, { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { languages } from '@/i18n/language'
|
||||
import { PortalSelect } from '@/app/components/base/select'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type Props = {
|
||||
currentModel: any
|
||||
language: string
|
||||
voice: string
|
||||
onChange: (language: string, voice: string) => void
|
||||
}
|
||||
|
||||
const TTSParamsPanel = ({
|
||||
currentModel,
|
||||
language,
|
||||
voice,
|
||||
onChange,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const voiceList = useMemo(() => {
|
||||
if (!currentModel)
|
||||
return []
|
||||
return currentModel.model_properties.voices.map((item: { mode: any }) => ({
|
||||
...item,
|
||||
value: item.mode,
|
||||
}))
|
||||
}, [currentModel])
|
||||
const setLanguage = (language: string) => {
|
||||
onChange(language, voice)
|
||||
}
|
||||
const setVoice = (voice: string) => {
|
||||
onChange(language, voice)
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<div className='mb-3'>
|
||||
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
|
||||
{t('appDebug.voice.voiceSettings.language')}
|
||||
</div>
|
||||
<PortalSelect
|
||||
triggerClassName='h-8'
|
||||
popupClassName={cn('z-[1000]')}
|
||||
popupInnerClassName={cn('w-[354px]')}
|
||||
value={language}
|
||||
items={languages.filter(item => item.supported)}
|
||||
onSelect={item => setLanguage(item.value as string)}
|
||||
/>
|
||||
</div>
|
||||
<div className='mb-3'>
|
||||
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
|
||||
{t('appDebug.voice.voiceSettings.voice')}
|
||||
</div>
|
||||
<PortalSelect
|
||||
triggerClassName='h-8'
|
||||
popupClassName={cn('z-[1000]')}
|
||||
popupInnerClassName={cn('w-[354px]')}
|
||||
value={voice}
|
||||
items={voiceList}
|
||||
onSelect={item => setVoice(item.value as string)}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default TTSParamsPanel
|
||||
|
|
@ -123,24 +123,11 @@ const InputVarList: FC<Props> = ({
|
|||
}
|
||||
}, [onChange, value])
|
||||
const handleModelChange = useCallback((variable: string) => {
|
||||
return (model: { provider: string; modelId: string; mode?: string }) => {
|
||||
return (model: any) => {
|
||||
const newValue = produce(value, (draft: ToolVarInputs) => {
|
||||
draft[variable] = {
|
||||
...draft[variable],
|
||||
provider: model.provider,
|
||||
model: model.modelId,
|
||||
mode: model.mode,
|
||||
} as any
|
||||
})
|
||||
onChange(newValue)
|
||||
}
|
||||
}, [onChange, value])
|
||||
const handleModelParamsChange = useCallback((variable: string) => {
|
||||
return (newParams: Record<string, any>) => {
|
||||
const newValue = produce(value, (draft: ToolVarInputs) => {
|
||||
draft[variable] = {
|
||||
...draft[variable],
|
||||
completion_params: newParams,
|
||||
...model,
|
||||
} as any
|
||||
})
|
||||
onChange(newValue)
|
||||
|
|
@ -242,12 +229,8 @@ const InputVarList: FC<Props> = ({
|
|||
popupClassName='!w-[387px]'
|
||||
isAdvancedMode
|
||||
isInWorkflow
|
||||
provider={(varInput as any)?.provider}
|
||||
modelId={(varInput as any)?.model}
|
||||
mode={(varInput as any)?.mode}
|
||||
completionParams={(varInput as any)?.completion_params}
|
||||
value={varInput as any}
|
||||
setModel={handleModelChange(variable)}
|
||||
onCompletionParamsChange={handleModelParamsChange(variable)}
|
||||
readonly={readOnly}
|
||||
scope={scope}
|
||||
/>
|
||||
|
|
|
|||
Loading…
Reference in New Issue