diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index f7cbfc916a..687d402582 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React, { Fragment, useEffect, useState } from 'react' import { Combobox, Listbox, Transition } from '@headlessui/react' -import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid' +import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid' import Badge from '../badge/index' import { RiCheckLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' @@ -352,7 +352,7 @@ const PortalSelect: FC = ({
{items.map((item: Item) => (
= ({ onChange({ ...value, [key]: val, ...shouldClearVariable }) } - const handleModelChanged = useCallback((key: string, model: { provider: string; modelId: string; mode?: string }) => { + const handleModelChanged = useCallback((key: string, model: any) => { const newValue = { ...value[key], - provider: model.provider, - model: model.modelId, - mode: model.mode, + ...model, type: FormTypeEnum.modelSelector, } onChange({ ...value, [key]: newValue }) }, [onChange, value]) - const handleCompletionParamsChange = useCallback((key: string, newParams: Record) => { - const newValue = { - ...value[key], - completion_params: newParams, - } - onChange({ ...value, [key]: newValue }) - }, [onChange, value]) - const renderField = (formSchema: CredentialFormSchema) => { const tooltip = formSchema.tooltip const tooltipContent = (tooltip && ( @@ -302,12 +292,8 @@ const Form: FC = ({ popupClassName='!w-[387px]' isAdvancedMode isInWorkflow - provider={value[variable]?.provider} - modelId={value[variable]?.model} - mode={value[variable]?.mode} - completionParams={value[variable]?.completion_params} + value={value[variable]} setModel={model => handleModelChanged(variable, model)} - onCompletionParamsChange={params => handleCompletionParamsChange(variable, params)} readonly={readonly} scope={scope} /> diff --git a/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx index 8b56b4d3d1..6bd750c8c3 100644 --- a/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx @@ -21,6 +21,7 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import LLMParamsPanel from './llm-params-panel' +import TTSParamsPanel from './tts-params-panel' import { useProviderContext } from '@/context/provider-context' import cn from '@/utils/classnames' @@ -28,12 +29,8 @@ export type ModelParameterModalProps = { popupClassName?: string portalToFollowElemContentClassName?: string isAdvancedMode: boolean - mode: string - modelId: string - provider: string - setModel: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void - completionParams: FormValue - onCompletionParamsChange: (newParams: FormValue) => void + value: any + setModel: (model: any) => void renderTrigger?: (v: TriggerProps) => ReactNode readonly?: boolean isInWorkflow?: boolean @@ -44,15 +41,12 @@ const ModelParameterModal: FC = ({ popupClassName, portalToFollowElemContentClassName, isAdvancedMode, - modelId, - provider, + value, setModel, - completionParams, - onCompletionParamsChange, renderTrigger, readonly, isInWorkflow, - scope = 'text-generation', + scope = ModelTypeEnum.textGeneration, }) => { const { t } = useTranslation() const { isAPIKeySet } = useProviderContext() @@ -79,29 +73,29 @@ const ModelParameterModal: FC = ({ ...moderationList, ] } - if (scopeArray.includes('text-generation')) + if (scopeArray.includes(ModelTypeEnum.textGeneration)) return textGenerationList - if (scopeArray.includes('embedding')) + if (scopeArray.includes(ModelTypeEnum.textEmbedding)) return textEmbeddingList - if (scopeArray.includes('rerank')) + if (scopeArray.includes(ModelTypeEnum.rerank)) return rerankList - if (scopeArray.includes('moderation')) + if (scopeArray.includes(ModelTypeEnum.moderation)) return moderationList - if (scopeArray.includes('stt')) + if (scopeArray.includes(ModelTypeEnum.speech2text)) return sttList - if (scopeArray.includes('tts')) + if (scopeArray.includes(ModelTypeEnum.tts)) return ttsList return resultList }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList]) const { currentProvider, currentModel } = useMemo(() => { - const currentProvider = scopedModelList.find(item => item.provider === provider) - const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === modelId) + const currentProvider = scopedModelList.find(item => item.provider === value?.provider) + const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model) return { currentProvider, currentModel, } - }, [provider, modelId, scopedModelList]) + }, [scopedModelList, value?.provider, value?.model]) const hasDeprecated = useMemo(() => { return !currentProvider || !currentModel @@ -116,11 +110,33 @@ const ModelParameterModal: FC = ({ const handleChangeModel = ({ provider, model }: DefaultModel) => { const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider) const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model) + const model_type = targetModelItem?.model_type as string setModel({ - modelId: model, provider, - mode: targetModelItem?.model_properties.mode as string, - features: targetModelItem?.features || [], + model, + model_type, + ...(model_type === ModelTypeEnum.textGeneration ? { + mode: targetModelItem?.model_properties.mode as string, + } : {}), + }) + } + + const handleLLMParamsChange = (newParams: FormValue) => { + const newValue = { + ...(value?.completionParams || {}), + completion_params: newParams, + } + setModel({ + ...value, + ...newValue, + }) + } + + const handleTTSParamsChange = (language: string, voice: string) => { + setModel({ + ...value, + language, + voice, }) } @@ -149,8 +165,8 @@ const ModelParameterModal: FC = ({ hasDeprecated, currentProvider, currentModel, - providerName: provider, - modelId, + providerName: value?.provider, + modelId: value?.model, }) : ( = ({ hasDeprecated={hasDeprecated} currentProvider={currentProvider} currentModel={currentModel} - providerName={provider} - modelId={modelId} + providerName={value?.provider} + modelId={value?.model} /> ) } @@ -174,7 +190,7 @@ const ModelParameterModal: FC = ({ {t('common.modelProvider.model').toLocaleUpperCase()}
= ({ )} {currentModel?.model_type === ModelTypeEnum.textGeneration && ( )} + {currentModel?.model_type === ModelTypeEnum.tts && ( + + )}
diff --git a/web/app/components/plugins/plugin-detail-panel/model-selector/tts-params-panel.tsx b/web/app/components/plugins/plugin-detail-panel/model-selector/tts-params-panel.tsx new file mode 100644 index 0000000000..a13b9905d3 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/model-selector/tts-params-panel.tsx @@ -0,0 +1,67 @@ +import React, { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { languages } from '@/i18n/language' +import { PortalSelect } from '@/app/components/base/select' +import cn from '@/utils/classnames' + +type Props = { + currentModel: any + language: string + voice: string + onChange: (language: string, voice: string) => void +} + +const TTSParamsPanel = ({ + currentModel, + language, + voice, + onChange, +}: Props) => { + const { t } = useTranslation() + const voiceList = useMemo(() => { + if (!currentModel) + return [] + return currentModel.model_properties.voices.map((item: { mode: any }) => ({ + ...item, + value: item.mode, + })) + }, [currentModel]) + const setLanguage = (language: string) => { + onChange(language, voice) + } + const setVoice = (voice: string) => { + onChange(language, voice) + } + return ( + <> +
+
+ {t('appDebug.voice.voiceSettings.language')} +
+ item.supported)} + onSelect={item => setLanguage(item.value as string)} + /> +
+
+
+ {t('appDebug.voice.voiceSettings.voice')} +
+ setVoice(item.value as string)} + /> +
+ + ) +} + +export default TTSParamsPanel diff --git a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx index 9c9d097d3a..d4cef82c43 100644 --- a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx +++ b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx @@ -123,24 +123,11 @@ const InputVarList: FC = ({ } }, [onChange, value]) const handleModelChange = useCallback((variable: string) => { - return (model: { provider: string; modelId: string; mode?: string }) => { + return (model: any) => { const newValue = produce(value, (draft: ToolVarInputs) => { draft[variable] = { ...draft[variable], - provider: model.provider, - model: model.modelId, - mode: model.mode, - } as any - }) - onChange(newValue) - } - }, [onChange, value]) - const handleModelParamsChange = useCallback((variable: string) => { - return (newParams: Record) => { - const newValue = produce(value, (draft: ToolVarInputs) => { - draft[variable] = { - ...draft[variable], - completion_params: newParams, + ...model, } as any }) onChange(newValue) @@ -242,12 +229,8 @@ const InputVarList: FC = ({ popupClassName='!w-[387px]' isAdvancedMode isInWorkflow - provider={(varInput as any)?.provider} - modelId={(varInput as any)?.model} - mode={(varInput as any)?.mode} - completionParams={(varInput as any)?.completion_params} + value={varInput as any} setModel={handleModelChange(variable)} - onCompletionParamsChange={handleModelParamsChange(variable)} readonly={readOnly} scope={scope} />