diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/index.tsx index c6dd76a04f..7613bebf37 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/index.tsx @@ -23,6 +23,7 @@ type ModelSelectorProps = { popupClassName?: string onSelect?: (model: DefaultModel) => void readonly?: boolean + scopeFeatures?: string[] } const ModelSelector: FC = ({ defaultModel, @@ -31,6 +32,7 @@ const ModelSelector: FC = ({ popupClassName, onSelect, readonly, + scopeFeatures = [], }) => { const [open, setOpen] = useState(false) const { @@ -101,6 +103,7 @@ const ModelSelector: FC = ({ defaultModel={defaultModel} modelList={modelList} onSelect={handleSelect} + scopeFeatures={scopeFeatures} /> diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx index cb9492c4d2..1089697c98 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import { useState } from 'react' +import { useMemo, useState } from 'react' import { RiSearchLine, } from '@remixicon/react' @@ -8,6 +8,7 @@ import type { Model, ModelItem, } from '../declarations' +import { ModelFeatureEnum } from '../declarations' import { useLanguage } from '../hooks' import PopupItem from './popup-item' import { XCircle } from '@/app/components/base/icons/src/vender/solid/general' @@ -16,27 +17,39 @@ type PopupProps = { defaultModel?: DefaultModel modelList: Model[] onSelect: (provider: string, model: ModelItem) => void + scopeFeatures?: string[] } const Popup: FC = ({ defaultModel, modelList, onSelect, + scopeFeatures = [], }) => { const language = useLanguage() const [searchText, setSearchText] = useState('') - const filteredModelList = modelList.map((model) => { - const filteredModels = model.models.filter((modelItem) => { - if (modelItem.label[language] !== undefined) - return modelItem.label[language].toLowerCase().includes(searchText.toLowerCase()) - - return Object.values(modelItem.label).some(label => - label.toLowerCase().includes(searchText.toLowerCase()), - ) - }) - - return { ...model, models: filteredModels } - }).filter(model => model.models.length > 0) + const filteredModelList = useMemo(() => { + return modelList.map((model) => { + const filteredModels = model.models + .filter((modelItem) => { + if (modelItem.label[language] !== undefined) + return modelItem.label[language].toLowerCase().includes(searchText.toLowerCase()) + return Object.values(modelItem.label).some(label => + label.toLowerCase().includes(searchText.toLowerCase()), + ) + }) + .filter((modelItem) => { + if (scopeFeatures.length === 0) + return true + return scopeFeatures.every((feature) => { + if (feature === ModelFeatureEnum.toolCall) + return modelItem.features?.some(featureItem => featureItem === ModelFeatureEnum.toolCall || featureItem === ModelFeatureEnum.multiToolCall) + return modelItem.features?.some(featureItem => featureItem === feature) + }) + }) + return { ...model, models: filteredModels } + }).filter(model => model.models.length > 0) + }, [language, modelList, scopeFeatures, searchText]) return (
diff --git a/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx index 881077d68c..8b56b4d3d1 100644 --- a/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx @@ -58,6 +58,7 @@ const ModelParameterModal: FC = ({ const { isAPIKeySet } = useProviderContext() const [open, setOpen] = useState(false) const scopeArray = scope.split('&') + const scopeFeatures = scopeArray.slice(1) || [] const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration) const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding) @@ -175,10 +176,11 @@ const ModelParameterModal: FC = ({
- {(currentModel?.model_type === ModelTypeEnum.textGeneration || currentModel.model_type === ModelTypeEnum.tts) && ( + {(currentModel?.model_type === ModelTypeEnum.textGeneration || currentModel?.model_type === ModelTypeEnum.tts) && (
)} {currentModel?.model_type === ModelTypeEnum.textGeneration && (