diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx new file mode 100644 index 0000000000..0081ca03df --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -0,0 +1,139 @@ +'use client' +import React from 'react' +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import TopKItem from '@/app/components/base/param-item/top-k-item' +import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' +import RadioCard from '@/app/components/base/radio-card/simple' +import { RETRIEVE_TYPE } from '@/types/app' +import { + MultiPathRetrieval, + NTo1Retrieval, +} from '@/app/components/base/icons/src/public/common' +import type { + DatasetConfigs, +} from '@/models/debug' + +import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' + +type Props = { + datasetConfigs: DatasetConfigs + onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void +} + +const ConfigContent: FC = ({ + datasetConfigs, + onChange, +}) => { + const { t } = useTranslation() + const type = datasetConfigs.retrieval_model + const setType = (value: RETRIEVE_TYPE) => { + onChange({ + ...datasetConfigs, + retrieval_model: value, + }, true) + } + const { + modelList: rerankModelList, + defaultModel: rerankDefaultModel, + currentModel: isRerankDefaultModelVaild, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(3) + + const rerankModel = (() => { + if (datasetConfigs.reranking_model) { + return { + provider_name: datasetConfigs.reranking_model.reranking_provider_name, + model_name: datasetConfigs.reranking_model.reranking_model_name, + } + } + else if (rerankDefaultModel) { + return { + provider_name: rerankDefaultModel.provider.provider, + model_name: rerankDefaultModel.model, + } + } + })() + + const handleParamChange = (key: string, value: number) => { + if (key === 'top_k') { + onChange({ + ...datasetConfigs, + top_k: value, + }) + } + else if (key === 'score_threshold') { + onChange({ + ...datasetConfigs, + score_threshold: value, + }) + } + } + + const handleSwitch = (key: string, enable: boolean) => { + if (key === 'top_k') + return + + onChange({ + ...datasetConfigs, + score_threshold_enabled: enable, + }) + } + return ( +
+
+ } + title={t('appDebug.datasetConfig.retrieveOneWay.title')} + description={t('appDebug.datasetConfig.retrieveOneWay.description')} + isChosen={type === RETRIEVE_TYPE.oneWay} + onChosen={() => { setType(RETRIEVE_TYPE.oneWay) }} + /> + } + title={t('appDebug.datasetConfig.retrieveMultiWay.title')} + description={t('appDebug.datasetConfig.retrieveMultiWay.description')} + isChosen={type === RETRIEVE_TYPE.multiWay} + onChosen={() => { setType(RETRIEVE_TYPE.multiWay) }} + /> +
+ {type === RETRIEVE_TYPE.multiWay && ( + <> +
+
{t('common.modelProvider.rerankModel.key')}
+
+ { + onChange({ + ...datasetConfigs, + reranking_model: { + reranking_provider_name: v.provider, + reranking_model_name: v.model, + }, + }) + }} + modelList={rerankModelList} + /> +
+
+
+ + +
+ + )} +
+ ) +} +export default React.memo(ConfigContent) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 8aecaf734f..95ec81720f 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -4,21 +4,14 @@ import { memo, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import cn from 'classnames' +import ConfigContent from './config-content' import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' import ConfigContext from '@/context/debug-configuration' -import TopKItem from '@/app/components/base/param-item/top-k-item' -import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' -import RadioCard from '@/app/components/base/radio-card/simple' import { RETRIEVE_TYPE } from '@/types/app' import Toast from '@/app/components/base/toast' import { DATASET_DEFAULT } from '@/config' -import { - MultiPathRetrieval, - NTo1Retrieval, -} from '@/app/components/base/icons/src/public/common' -import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' const ParamsConfig: FC = () => { @@ -30,58 +23,11 @@ const ParamsConfig: FC = () => { } = useContext(ConfigContext) const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs) - const type = tempDataSetConfigs.retrieval_model - const setType = (value: RETRIEVE_TYPE) => { - setTempDataSetConfigs({ - ...tempDataSetConfigs, - retrieval_model: value, - }) - } const { - modelList: rerankModelList, defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelVaild, } = useModelListAndDefaultModelAndCurrentProviderAndModel(3) - const rerankModel = (() => { - if (tempDataSetConfigs.reranking_model) { - return { - provider_name: tempDataSetConfigs.reranking_model.reranking_provider_name, - model_name: tempDataSetConfigs.reranking_model.reranking_model_name, - } - } - else if (rerankDefaultModel) { - return { - provider_name: rerankDefaultModel.provider.provider, - model_name: rerankDefaultModel.model, - } - } - })() - - const handleParamChange = (key: string, value: number) => { - if (key === 'top_k') { - setTempDataSetConfigs({ - ...tempDataSetConfigs, - top_k: value, - }) - } - else if (key === 'score_threshold') { - setTempDataSetConfigs({ - ...tempDataSetConfigs, - score_threshold: value, - }) - } - } - - const handleSwitch = (key: string, enable: boolean) => { - if (key === 'top_k') - return - - setTempDataSetConfigs({ - ...tempDataSetConfigs, - score_threshold_enabled: enable, - }) - } const isValid = () => { let errMsg = '' if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { @@ -140,58 +86,11 @@ const ParamsConfig: FC = () => { wrapperClassName='z-50' title={t('appDebug.datasetConfig.settingTitle')} > -
- } - title={t('appDebug.datasetConfig.retrieveOneWay.title')} - description={t('appDebug.datasetConfig.retrieveOneWay.description')} - isChosen={type === RETRIEVE_TYPE.oneWay} - onChosen={() => { setType(RETRIEVE_TYPE.oneWay) }} - /> - } - title={t('appDebug.datasetConfig.retrieveMultiWay.title')} - description={t('appDebug.datasetConfig.retrieveMultiWay.description')} - isChosen={type === RETRIEVE_TYPE.multiWay} - onChosen={() => { setType(RETRIEVE_TYPE.multiWay) }} - /> -
- {type === RETRIEVE_TYPE.multiWay && ( - <> -
-
{t('common.modelProvider.rerankModel.key')}
-
- { - setTempDataSetConfigs({ - ...tempDataSetConfigs, - reranking_model: { - reranking_provider_name: v.provider, - reranking_model_name: v.model, - }, - }) - }} - modelList={rerankModelList} - /> -
-
-
- - -
- - )} + +
+ } + > + + diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts index 15d941c242..5fc765eeb7 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts @@ -1,16 +1,17 @@ import type { CommonNodeType, ValueSelector } from '@/app/components/workflow/types' import type { RETRIEVE_TYPE } from '@/types/app' +export type MultipleRetrievalConfig = { + top_k: number + score_threshold: number | null | undefined + reranking_model: { + provider: string + model: string + } +} export type KnowledgeRetrievalNodeType = CommonNodeType & { query_variable_selector: ValueSelector dataset_ids: string[] retrieval_mode: RETRIEVE_TYPE - multiple_retrieval_config?: { - top_k: number - score_threshold: number - reranking_model: { - provider: string - model: string - } - } + multiple_retrieval_config?: MultipleRetrievalConfig } diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts index db17b743bc..373955158f 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts @@ -1,7 +1,8 @@ import { useCallback, useState } from 'react' import produce from 'immer' import type { ValueSelector } from '../../types' -import type { KnowledgeRetrievalNodeType } from './types' +import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types' +import type { RETRIEVE_TYPE } from '@/types/app' const useConfig = (initInputs: KnowledgeRetrievalNodeType) => { const [inputs, setInputs] = useState(initInputs) @@ -13,9 +14,25 @@ const useConfig = (initInputs: KnowledgeRetrievalNodeType) => { setInputs(newInputs) }, [inputs, setInputs]) + const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => { + const newInputs = produce(inputs, (draft) => { + draft.retrieval_mode = newMode + }) + setInputs(newInputs) + }, [inputs, setInputs]) + + const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { + const newInputs = produce(inputs, (draft) => { + draft.multiple_retrieval_config = newConfig + }) + setInputs(newInputs) + }, [inputs, setInputs]) + return { inputs, handleQueryVarChange, + handleRetrievalModeChange, + handleMultipleRetrievalConfigChange, } } diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index 6f5d5a7485..aa074c3558 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -56,6 +56,7 @@ const translation = { }, knowledgeRetrieval: { queryVariable: 'Query Variable', + knowledge: 'Knowledge', outputVars: { output: 'Retrieval segmented data', content: 'Segmented content', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index a851777724..adf89b25b8 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -55,6 +55,7 @@ const translation = { }, knowledgeRetrieval: { queryVariable: '查询变量', + knowledge: '知识库', outputVars: { output: '召回的分段', content: '分段内容',