diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 1fa9345102..116dd51ee4 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -16,15 +16,27 @@ import type { import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import type { ModelConfig } from '@/app/components/workflow/types' +import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' +import TooltipPlus from '@/app/components/base/tooltip-plus' +import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general' type Props = { datasetConfigs: DatasetConfigs onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void + isInWorkflow?: boolean + singleRetrievalModelConfig?: ModelConfig + onSingleRetrievalModelChange?: (config: ModelConfig) => void + onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void } const ConfigContent: FC = ({ datasetConfigs, onChange, + isInWorkflow, + singleRetrievalModelConfig: singleRetrievalConfig = {} as ModelConfig, + onSingleRetrievalModelChange = () => { }, + onSingleRetrievalModelParamsChange = () => { }, }) => { const { t } = useTranslation() const type = datasetConfigs.retrieval_model @@ -77,6 +89,9 @@ const ConfigContent: FC = ({ score_threshold_enabled: enable, }) } + + const model = singleRetrievalConfig + return (
@@ -122,7 +137,7 @@ const ConfigContent: FC = ({ enable={true} /> = ({
)} -
+ + {isInWorkflow && type === RETRIEVE_TYPE.oneWay && ( +
+
+
{t('common.modelProvider.systemReasoningModel.key')}
+ + + +
+ +
+ ) + } + ) } export default React.memo(ConfigContent) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx index 4a719fe7ab..7a44f9a486 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx @@ -35,6 +35,7 @@ import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arr export type ModelParameterModalProps = { popupClassName?: string + portalToFollowElemContentClassName?: string isAdvancedMode: boolean mode: string modelId: string @@ -69,6 +70,7 @@ const stopParameerRule: ModelParameterRule = { const PROVIDER_WITH_PRESET_TONE = ['openai', 'azure_openai'] const ModelParameterModal: FC = ({ popupClassName, + portalToFollowElemContentClassName, isAdvancedMode, modelId, provider, @@ -200,7 +202,7 @@ const ModelParameterModal: FC = ({ ) } - +
diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx index ca9cd616ef..23ff4dac23 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx @@ -3,7 +3,8 @@ import type { FC } from 'react' import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import cn from 'classnames' -import type { MultipleRetrievalConfig } from '../types' +import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types' +import type { ModelConfig } from '../../../types' import { PortalToFollowElem, PortalToFollowElemContent, @@ -23,15 +24,22 @@ type Props = { payload: { retrieval_mode: RETRIEVE_TYPE multiple_retrieval_config?: MultipleRetrievalConfig + single_retrieval_config?: SingleRetrievalConfig } onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void + singleRetrievalModelConfig?: ModelConfig + onSingleRetrievalModelChange?: (config: ModelConfig) => void + onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void } const RetrievalConfig: FC = ({ payload, onRetrievalModeChange, onMultipleRetrievalConfigChange, + singleRetrievalModelConfig, + onSingleRetrievalModelChange, + onSingleRetrievalModelParamsChange, }) => { const { t } = useTranslation() @@ -43,6 +51,7 @@ const RetrievalConfig: FC = ({ const { multiple_retrieval_config } = payload const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => { + console.log(configs, isRetrievalModeChange) if (isRetrievalModeChange) { onRetrievalModeChange(configs.retrieval_model) return @@ -62,7 +71,7 @@ const RetrievalConfig: FC = ({ model: configs.reranking_model?.reranking_model_name, }), }) - }, [onRetrievalModeChange, onMultipleRetrievalConfigChange]) + }, [onMultipleRetrievalConfigChange, payload.retrieval_mode, rerankDefaultModel?.provider?.provider, rerankDefaultModel?.model, onRetrievalModeChange]) return ( = ({ } } onChange={handleChange} + isInWorkflow + singleRetrievalModelConfig={singleRetrievalModelConfig} + onSingleRetrievalModelChange={onSingleRetrievalModelChange} + onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange} />
diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/panel.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/panel.tsx index d8acf4fa87..2e7ae7c46c 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/panel.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/panel.tsx @@ -27,6 +27,8 @@ const Panel: FC> = ({ inputs, handleQueryVarChange, filterVar, + handleModelChanged, + handleCompletionParamsChange, handleRetrievalModeChange, handleMultipleRetrievalConfigChange, selectedDatasets, @@ -66,9 +68,13 @@ const Panel: FC> = ({ payload={{ retrieval_mode: inputs.retrieval_mode, multiple_retrieval_config: inputs.multiple_retrieval_config, + single_retrieval_config: inputs.single_retrieval_config, }} onRetrievalModeChange={handleRetrievalModeChange} onMultipleRetrievalConfigChange={handleMultipleRetrievalConfigChange} + singleRetrievalModelConfig={inputs.single_retrieval_config?.model} + onSingleRetrievalModelChange={handleModelChanged as any} + onSingleRetrievalModelParamsChange={handleCompletionParamsChange} />
{ const isChatMode = useIsChatMode() - console.log() const { getBeforeNodesInSameBranch } = useWorkflow() const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start) const startNodeId = startNode?.id const { inputs, setInputs } = useNodeCrud(id, payload) + + const inputRef = useRef(inputs) + useEffect(() => { + inputRef.current = inputs + }, [inputs]) + const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { const newInputs = produce(inputs, (draft) => { draft.query_variable_selector = newVar as ValueSelector @@ -24,12 +31,119 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { setInputs(newInputs) }, [inputs, setInputs]) + const { + currentProvider, + currentModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(1) + + const { + defaultModel: rerankDefaultModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(3) + + const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => { + const newInputs = produce(inputRef.current, (draft) => { + if (!draft.single_retrieval_config) { + draft.single_retrieval_config = { + model: { + provider: '', + name: '', + mode: '', + completion_params: {}, + }, + } + } + const draftModel = draft.single_retrieval_config?.model + draftModel.provider = model.provider + draftModel.name = model.modelId + draftModel.mode = model.mode! + }) + setInputs(newInputs) + }, [setInputs]) + + const handleCompletionParamsChange = useCallback((newParams: Record) => { + const newInputs = produce(inputRef.current, (draft) => { + if (!draft.single_retrieval_config) { + draft.single_retrieval_config = { + model: { + provider: '', + name: '', + mode: '', + completion_params: {}, + }, + } + } + draft.single_retrieval_config.model.completion_params = newParams + }) + setInputs(newInputs) + }, [setInputs]) + + // set defaults models + useEffect(() => { + const inputs = inputRef.current + const newInput = produce(inputs, (draft) => { + if (currentProvider?.provider && currentModel?.model) { + const hasSetModel = draft.single_retrieval_config?.model?.provider + if (!hasSetModel) { + draft.single_retrieval_config = { + model: { + provider: currentProvider?.provider, + name: currentModel?.model, + mode: currentModel?.model_properties?.mode as string, + completion_params: {}, + }, + } + } + } + + const multipleRetrievalConfig = draft.multiple_retrieval_config + draft.multiple_retrieval_config = { + top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k, + score_threshold: multipleRetrievalConfig?.score_threshold, + reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay + ? undefined + : (!multipleRetrievalConfig?.reranking_model?.provider + ? { + provider: rerankDefaultModel?.provider?.provider || '', + model: rerankDefaultModel?.model || '', + } + : multipleRetrievalConfig?.reranking_model), + } + }) + setInputs(newInput) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [currentProvider?.provider, currentModel, rerankDefaultModel]) + const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => { const newInputs = produce(inputs, (draft) => { draft.retrieval_mode = newMode + if (newMode === RETRIEVE_TYPE.multiWay) { + draft.multiple_retrieval_config = { + top_k: draft.multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k, + score_threshold: draft.multiple_retrieval_config?.score_threshold, + reranking_model: !draft.multiple_retrieval_config?.reranking_model?.provider + ? { + provider: rerankDefaultModel?.provider?.provider || '', + model: rerankDefaultModel?.model || '', + } + : draft.multiple_retrieval_config?.reranking_model, + } + } + else { + const hasSetModel = draft.single_retrieval_config?.model?.provider + if (!hasSetModel) { + draft.single_retrieval_config = { + model: { + provider: currentProvider?.provider || '', + name: currentModel?.model || '', + mode: currentModel?.model_properties?.mode as string, + completion_params: {}, + }, + } + } + } }) setInputs(newInputs) - }, [inputs, setInputs]) + }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, rerankDefaultModel?.model, rerankDefaultModel?.provider?.provider, setInputs]) const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { const newInputs = produce(inputs, (draft) => { @@ -111,6 +225,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { filterVar, handleRetrievalModeChange, handleMultipleRetrievalConfigChange, + handleModelChanged, + handleCompletionParamsChange, selectedDatasets, handleOnDatasetsChange, isShowSingleRun,