feat: knowledge support one sigle

This commit is contained in:
Joel 2024-03-18 18:48:51 +08:00
parent 4eb7546177
commit c409ab4c3c
6 changed files with 195 additions and 10 deletions

View File

@ -16,15 +16,27 @@ import type {
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import type { ModelConfig } from '@/app/components/workflow/types'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import TooltipPlus from '@/app/components/base/tooltip-plus'
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
type Props = { type Props = {
datasetConfigs: DatasetConfigs datasetConfigs: DatasetConfigs
onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void
isInWorkflow?: boolean
singleRetrievalModelConfig?: ModelConfig
onSingleRetrievalModelChange?: (config: ModelConfig) => void
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
} }
const ConfigContent: FC<Props> = ({ const ConfigContent: FC<Props> = ({
datasetConfigs, datasetConfigs,
onChange, onChange,
isInWorkflow,
singleRetrievalModelConfig: singleRetrievalConfig = {} as ModelConfig,
onSingleRetrievalModelChange = () => { },
onSingleRetrievalModelParamsChange = () => { },
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const type = datasetConfigs.retrieval_model const type = datasetConfigs.retrieval_model
@ -77,6 +89,9 @@ const ConfigContent: FC<Props> = ({
score_threshold_enabled: enable, score_threshold_enabled: enable,
}) })
} }
const model = singleRetrievalConfig
return ( return (
<div> <div>
<div className='mt-2 space-y-3'> <div className='mt-2 space-y-3'>
@ -122,7 +137,7 @@ const ConfigContent: FC<Props> = ({
enable={true} enable={true}
/> />
<ScoreThresholdItem <ScoreThresholdItem
value={datasetConfigs.score_threshold} value={datasetConfigs.score_threshold as number}
onChange={handleParamChange} onChange={handleParamChange}
enable={datasetConfigs.score_threshold_enabled} enable={datasetConfigs.score_threshold_enabled}
hasSwitch={true} hasSwitch={true}
@ -131,7 +146,34 @@ const ConfigContent: FC<Props> = ({
</div> </div>
</> </>
)} )}
</div>
{isInWorkflow && type === RETRIEVE_TYPE.oneWay && (
<div className='mt-6'>
<div className='flex items-center space-x-0.5'>
<div className='leading-[32px] text-[13px] font-medium text-gray-900'>{t('common.modelProvider.systemReasoningModel.key')}</div>
<TooltipPlus
popupContent={t('common.modelProvider.systemReasoningModel.tip')}
>
<HelpCircle className='w-3.5 h-4.5 text-gray-400' />
</TooltipPlus>
</div>
<ModelParameterModal
popupClassName='!w-[387px]'
portalToFollowElemContentClassName='!z-[1002]'
isAdvancedMode={true}
mode={model?.mode}
provider={model?.provider}
completionParams={model?.completion_params}
modelId={model?.name}
setModel={onSingleRetrievalModelChange as any}
onCompletionParamsChange={onSingleRetrievalModelParamsChange as any}
hideDebugWithMultipleModel
debugWithMultipleModel={false}
/>
</div>
)
}
</div >
) )
} }
export default React.memo(ConfigContent) export default React.memo(ConfigContent)

View File

@ -35,6 +35,7 @@ import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arr
export type ModelParameterModalProps = { export type ModelParameterModalProps = {
popupClassName?: string popupClassName?: string
portalToFollowElemContentClassName?: string
isAdvancedMode: boolean isAdvancedMode: boolean
mode: string mode: string
modelId: string modelId: string
@ -69,6 +70,7 @@ const stopParameerRule: ModelParameterRule = {
const PROVIDER_WITH_PRESET_TONE = ['openai', 'azure_openai'] const PROVIDER_WITH_PRESET_TONE = ['openai', 'azure_openai']
const ModelParameterModal: FC<ModelParameterModalProps> = ({ const ModelParameterModal: FC<ModelParameterModalProps> = ({
popupClassName, popupClassName,
portalToFollowElemContentClassName,
isAdvancedMode, isAdvancedMode,
modelId, modelId,
provider, provider,
@ -200,7 +202,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
) )
} }
</PortalToFollowElemTrigger> </PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[60]'> <PortalToFollowElemContent className={cn(portalToFollowElemContentClassName, 'z-[60]')}>
<div className={cn(popupClassName, 'w-[496px] rounded-xl border border-gray-100 bg-white shadow-xl')}> <div className={cn(popupClassName, 'w-[496px] rounded-xl border border-gray-100 bg-white shadow-xl')}>
<div className='max-h-[480px] px-10 pt-6 pb-8 overflow-y-auto'> <div className='max-h-[480px] px-10 pt-6 pb-8 overflow-y-auto'>
<div className='flex items-center justify-between h-8'> <div className='flex items-center justify-between h-8'>

View File

@ -3,7 +3,8 @@ import type { FC } from 'react'
import React, { useCallback, useState } from 'react' import React, { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import cn from 'classnames' import cn from 'classnames'
import type { MultipleRetrievalConfig } from '../types' import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
import type { ModelConfig } from '../../../types'
import { import {
PortalToFollowElem, PortalToFollowElem,
PortalToFollowElemContent, PortalToFollowElemContent,
@ -23,15 +24,22 @@ type Props = {
payload: { payload: {
retrieval_mode: RETRIEVE_TYPE retrieval_mode: RETRIEVE_TYPE
multiple_retrieval_config?: MultipleRetrievalConfig multiple_retrieval_config?: MultipleRetrievalConfig
single_retrieval_config?: SingleRetrievalConfig
} }
onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void
onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void
singleRetrievalModelConfig?: ModelConfig
onSingleRetrievalModelChange?: (config: ModelConfig) => void
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
} }
const RetrievalConfig: FC<Props> = ({ const RetrievalConfig: FC<Props> = ({
payload, payload,
onRetrievalModeChange, onRetrievalModeChange,
onMultipleRetrievalConfigChange, onMultipleRetrievalConfigChange,
singleRetrievalModelConfig,
onSingleRetrievalModelChange,
onSingleRetrievalModelParamsChange,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
@ -43,6 +51,7 @@ const RetrievalConfig: FC<Props> = ({
const { multiple_retrieval_config } = payload const { multiple_retrieval_config } = payload
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => { const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
console.log(configs, isRetrievalModeChange)
if (isRetrievalModeChange) { if (isRetrievalModeChange) {
onRetrievalModeChange(configs.retrieval_model) onRetrievalModeChange(configs.retrieval_model)
return return
@ -62,7 +71,7 @@ const RetrievalConfig: FC<Props> = ({
model: configs.reranking_model?.reranking_model_name, model: configs.reranking_model?.reranking_model_name,
}), }),
}) })
}, [onRetrievalModeChange, onMultipleRetrievalConfigChange]) }, [onMultipleRetrievalConfigChange, payload.retrieval_mode, rerankDefaultModel?.provider?.provider, rerankDefaultModel?.model, onRetrievalModeChange])
return ( return (
<PortalToFollowElem <PortalToFollowElem
@ -106,6 +115,10 @@ const RetrievalConfig: FC<Props> = ({
} }
} }
onChange={handleChange} onChange={handleChange}
isInWorkflow
singleRetrievalModelConfig={singleRetrievalModelConfig}
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
/> />
</div> </div>
</PortalToFollowElemContent> </PortalToFollowElemContent>

View File

@ -27,6 +27,8 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
inputs, inputs,
handleQueryVarChange, handleQueryVarChange,
filterVar, filterVar,
handleModelChanged,
handleCompletionParamsChange,
handleRetrievalModeChange, handleRetrievalModeChange,
handleMultipleRetrievalConfigChange, handleMultipleRetrievalConfigChange,
selectedDatasets, selectedDatasets,
@ -66,9 +68,13 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
payload={{ payload={{
retrieval_mode: inputs.retrieval_mode, retrieval_mode: inputs.retrieval_mode,
multiple_retrieval_config: inputs.multiple_retrieval_config, multiple_retrieval_config: inputs.multiple_retrieval_config,
single_retrieval_config: inputs.single_retrieval_config,
}} }}
onRetrievalModeChange={handleRetrievalModeChange} onRetrievalModeChange={handleRetrievalModeChange}
onMultipleRetrievalConfigChange={handleMultipleRetrievalConfigChange} onMultipleRetrievalConfigChange={handleMultipleRetrievalConfigChange}
singleRetrievalModelConfig={inputs.single_retrieval_config?.model}
onSingleRetrievalModelChange={handleModelChanged as any}
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
/> />
<div className='w-px h-3 bg-gray-200'></div> <div className='w-px h-3 bg-gray-200'></div>
<AddKnowledge <AddKnowledge

View File

@ -1,4 +1,4 @@
import type { CommonNodeType, ValueSelector } from '@/app/components/workflow/types' import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
import type { RETRIEVE_TYPE } from '@/types/app' import type { RETRIEVE_TYPE } from '@/types/app'
export type MultipleRetrievalConfig = { export type MultipleRetrievalConfig = {
@ -9,9 +9,15 @@ export type MultipleRetrievalConfig = {
model: string model: string
} }
} }
export type SingleRetrievalConfig = {
model: ModelConfig
}
export type KnowledgeRetrievalNodeType = CommonNodeType & { export type KnowledgeRetrievalNodeType = CommonNodeType & {
query_variable_selector: ValueSelector query_variable_selector: ValueSelector
dataset_ids: string[] dataset_ids: string[]
retrieval_mode: RETRIEVE_TYPE retrieval_mode: RETRIEVE_TYPE
multiple_retrieval_config?: MultipleRetrievalConfig multiple_retrieval_config?: MultipleRetrievalConfig
single_retrieval_config?: SingleRetrievalConfig
} }

View File

@ -1,22 +1,29 @@
import { useCallback, useEffect, useState } from 'react' import { useCallback, useEffect, useRef, useState } from 'react'
import produce from 'immer' import produce from 'immer'
import type { ValueSelector, Var } from '../../types' import type { ValueSelector, Var } from '../../types'
import { BlockEnum, VarType } from '../../types' import { BlockEnum, VarType } from '../../types'
import { useIsChatMode, useWorkflow } from '../../hooks' import { useIsChatMode, useWorkflow } from '../../hooks'
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types' import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
import type { RETRIEVE_TYPE } from '@/types/app' import { RETRIEVE_TYPE } from '@/types/app'
import { DATASET_DEFAULT } from '@/config'
import type { DataSet } from '@/models/datasets' import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets' import { fetchDatasets } from '@/service/datasets'
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const isChatMode = useIsChatMode() const isChatMode = useIsChatMode()
console.log()
const { getBeforeNodesInSameBranch } = useWorkflow() const { getBeforeNodesInSameBranch } = useWorkflow()
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start) const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
const startNodeId = startNode?.id const startNodeId = startNode?.id
const { inputs, setInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload) const { inputs, setInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
const inputRef = useRef(inputs)
useEffect(() => {
inputRef.current = inputs
}, [inputs])
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.query_variable_selector = newVar as ValueSelector draft.query_variable_selector = newVar as ValueSelector
@ -24,12 +31,119 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
setInputs(newInputs) setInputs(newInputs)
}, [inputs, setInputs]) }, [inputs, setInputs])
const {
currentProvider,
currentModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(1)
const {
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) {
draft.single_retrieval_config = {
model: {
provider: '',
name: '',
mode: '',
completion_params: {},
},
}
}
const draftModel = draft.single_retrieval_config?.model
draftModel.provider = model.provider
draftModel.name = model.modelId
draftModel.mode = model.mode!
})
setInputs(newInputs)
}, [setInputs])
const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) {
draft.single_retrieval_config = {
model: {
provider: '',
name: '',
mode: '',
completion_params: {},
},
}
}
draft.single_retrieval_config.model.completion_params = newParams
})
setInputs(newInputs)
}, [setInputs])
// set defaults models
useEffect(() => {
const inputs = inputRef.current
const newInput = produce(inputs, (draft) => {
if (currentProvider?.provider && currentModel?.model) {
const hasSetModel = draft.single_retrieval_config?.model?.provider
if (!hasSetModel) {
draft.single_retrieval_config = {
model: {
provider: currentProvider?.provider,
name: currentModel?.model,
mode: currentModel?.model_properties?.mode as string,
completion_params: {},
},
}
}
}
const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = {
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
score_threshold: multipleRetrievalConfig?.score_threshold,
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
? undefined
: (!multipleRetrievalConfig?.reranking_model?.provider
? {
provider: rerankDefaultModel?.provider?.provider || '',
model: rerankDefaultModel?.model || '',
}
: multipleRetrievalConfig?.reranking_model),
}
})
setInputs(newInput)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentProvider?.provider, currentModel, rerankDefaultModel])
const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => { const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.retrieval_mode = newMode draft.retrieval_mode = newMode
if (newMode === RETRIEVE_TYPE.multiWay) {
draft.multiple_retrieval_config = {
top_k: draft.multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
score_threshold: draft.multiple_retrieval_config?.score_threshold,
reranking_model: !draft.multiple_retrieval_config?.reranking_model?.provider
? {
provider: rerankDefaultModel?.provider?.provider || '',
model: rerankDefaultModel?.model || '',
}
: draft.multiple_retrieval_config?.reranking_model,
}
}
else {
const hasSetModel = draft.single_retrieval_config?.model?.provider
if (!hasSetModel) {
draft.single_retrieval_config = {
model: {
provider: currentProvider?.provider || '',
name: currentModel?.model || '',
mode: currentModel?.model_properties?.mode as string,
completion_params: {},
},
}
}
}
}) })
setInputs(newInputs) setInputs(newInputs)
}, [inputs, setInputs]) }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, rerankDefaultModel?.model, rerankDefaultModel?.provider?.provider, setInputs])
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
@ -111,6 +225,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
filterVar, filterVar,
handleRetrievalModeChange, handleRetrievalModeChange,
handleMultipleRetrievalConfigChange, handleMultipleRetrievalConfigChange,
handleModelChanged,
handleCompletionParamsChange,
selectedDatasets, selectedDatasets,
handleOnDatasetsChange, handleOnDatasetsChange,
isShowSingleRun, isShowSingleRun,