feat: knowledge support one sigle

This commit is contained in:
Joel 2024-03-18 18:48:51 +08:00
parent 4eb7546177
commit c409ab4c3c
6 changed files with 195 additions and 10 deletions

View File

@ -16,15 +16,27 @@ import type {
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import type { ModelConfig } from '@/app/components/workflow/types'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import TooltipPlus from '@/app/components/base/tooltip-plus'
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
type Props = {
datasetConfigs: DatasetConfigs
onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void
isInWorkflow?: boolean
singleRetrievalModelConfig?: ModelConfig
onSingleRetrievalModelChange?: (config: ModelConfig) => void
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
}
const ConfigContent: FC<Props> = ({
datasetConfigs,
onChange,
isInWorkflow,
singleRetrievalModelConfig: singleRetrievalConfig = {} as ModelConfig,
onSingleRetrievalModelChange = () => { },
onSingleRetrievalModelParamsChange = () => { },
}) => {
const { t } = useTranslation()
const type = datasetConfigs.retrieval_model
@ -77,6 +89,9 @@ const ConfigContent: FC<Props> = ({
score_threshold_enabled: enable,
})
}
const model = singleRetrievalConfig
return (
<div>
<div className='mt-2 space-y-3'>
@ -122,7 +137,7 @@ const ConfigContent: FC<Props> = ({
enable={true}
/>
<ScoreThresholdItem
value={datasetConfigs.score_threshold}
value={datasetConfigs.score_threshold as number}
onChange={handleParamChange}
enable={datasetConfigs.score_threshold_enabled}
hasSwitch={true}
@ -131,7 +146,34 @@ const ConfigContent: FC<Props> = ({
</div>
</>
)}
</div>
{isInWorkflow && type === RETRIEVE_TYPE.oneWay && (
<div className='mt-6'>
<div className='flex items-center space-x-0.5'>
<div className='leading-[32px] text-[13px] font-medium text-gray-900'>{t('common.modelProvider.systemReasoningModel.key')}</div>
<TooltipPlus
popupContent={t('common.modelProvider.systemReasoningModel.tip')}
>
<HelpCircle className='w-3.5 h-4.5 text-gray-400' />
</TooltipPlus>
</div>
<ModelParameterModal
popupClassName='!w-[387px]'
portalToFollowElemContentClassName='!z-[1002]'
isAdvancedMode={true}
mode={model?.mode}
provider={model?.provider}
completionParams={model?.completion_params}
modelId={model?.name}
setModel={onSingleRetrievalModelChange as any}
onCompletionParamsChange={onSingleRetrievalModelParamsChange as any}
hideDebugWithMultipleModel
debugWithMultipleModel={false}
/>
</div>
)
}
</div >
)
}
export default React.memo(ConfigContent)

View File

@ -35,6 +35,7 @@ import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arr
export type ModelParameterModalProps = {
popupClassName?: string
portalToFollowElemContentClassName?: string
isAdvancedMode: boolean
mode: string
modelId: string
@ -69,6 +70,7 @@ const stopParameerRule: ModelParameterRule = {
const PROVIDER_WITH_PRESET_TONE = ['openai', 'azure_openai']
const ModelParameterModal: FC<ModelParameterModalProps> = ({
popupClassName,
portalToFollowElemContentClassName,
isAdvancedMode,
modelId,
provider,
@ -200,7 +202,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
)
}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[60]'>
<PortalToFollowElemContent className={cn(portalToFollowElemContentClassName, 'z-[60]')}>
<div className={cn(popupClassName, 'w-[496px] rounded-xl border border-gray-100 bg-white shadow-xl')}>
<div className='max-h-[480px] px-10 pt-6 pb-8 overflow-y-auto'>
<div className='flex items-center justify-between h-8'>

View File

@ -3,7 +3,8 @@ import type { FC } from 'react'
import React, { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import type { MultipleRetrievalConfig } from '../types'
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
import type { ModelConfig } from '../../../types'
import {
PortalToFollowElem,
PortalToFollowElemContent,
@ -23,15 +24,22 @@ type Props = {
payload: {
retrieval_mode: RETRIEVE_TYPE
multiple_retrieval_config?: MultipleRetrievalConfig
single_retrieval_config?: SingleRetrievalConfig
}
onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void
onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void
singleRetrievalModelConfig?: ModelConfig
onSingleRetrievalModelChange?: (config: ModelConfig) => void
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
}
const RetrievalConfig: FC<Props> = ({
payload,
onRetrievalModeChange,
onMultipleRetrievalConfigChange,
singleRetrievalModelConfig,
onSingleRetrievalModelChange,
onSingleRetrievalModelParamsChange,
}) => {
const { t } = useTranslation()
@ -43,6 +51,7 @@ const RetrievalConfig: FC<Props> = ({
const { multiple_retrieval_config } = payload
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
console.log(configs, isRetrievalModeChange)
if (isRetrievalModeChange) {
onRetrievalModeChange(configs.retrieval_model)
return
@ -62,7 +71,7 @@ const RetrievalConfig: FC<Props> = ({
model: configs.reranking_model?.reranking_model_name,
}),
})
}, [onRetrievalModeChange, onMultipleRetrievalConfigChange])
}, [onMultipleRetrievalConfigChange, payload.retrieval_mode, rerankDefaultModel?.provider?.provider, rerankDefaultModel?.model, onRetrievalModeChange])
return (
<PortalToFollowElem
@ -106,6 +115,10 @@ const RetrievalConfig: FC<Props> = ({
}
}
onChange={handleChange}
isInWorkflow
singleRetrievalModelConfig={singleRetrievalModelConfig}
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
/>
</div>
</PortalToFollowElemContent>

View File

@ -27,6 +27,8 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
inputs,
handleQueryVarChange,
filterVar,
handleModelChanged,
handleCompletionParamsChange,
handleRetrievalModeChange,
handleMultipleRetrievalConfigChange,
selectedDatasets,
@ -66,9 +68,13 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
payload={{
retrieval_mode: inputs.retrieval_mode,
multiple_retrieval_config: inputs.multiple_retrieval_config,
single_retrieval_config: inputs.single_retrieval_config,
}}
onRetrievalModeChange={handleRetrievalModeChange}
onMultipleRetrievalConfigChange={handleMultipleRetrievalConfigChange}
singleRetrievalModelConfig={inputs.single_retrieval_config?.model}
onSingleRetrievalModelChange={handleModelChanged as any}
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
/>
<div className='w-px h-3 bg-gray-200'></div>
<AddKnowledge

View File

@ -1,4 +1,4 @@
import type { CommonNodeType, ValueSelector } from '@/app/components/workflow/types'
import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
import type { RETRIEVE_TYPE } from '@/types/app'
export type MultipleRetrievalConfig = {
@ -9,9 +9,15 @@ export type MultipleRetrievalConfig = {
model: string
}
}
export type SingleRetrievalConfig = {
model: ModelConfig
}
export type KnowledgeRetrievalNodeType = CommonNodeType & {
query_variable_selector: ValueSelector
dataset_ids: string[]
retrieval_mode: RETRIEVE_TYPE
multiple_retrieval_config?: MultipleRetrievalConfig
single_retrieval_config?: SingleRetrievalConfig
}

View File

@ -1,22 +1,29 @@
import { useCallback, useEffect, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import produce from 'immer'
import type { ValueSelector, Var } from '../../types'
import { BlockEnum, VarType } from '../../types'
import { useIsChatMode, useWorkflow } from '../../hooks'
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
import type { RETRIEVE_TYPE } from '@/types/app'
import { RETRIEVE_TYPE } from '@/types/app'
import { DATASET_DEFAULT } from '@/config'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const isChatMode = useIsChatMode()
console.log()
const { getBeforeNodesInSameBranch } = useWorkflow()
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
const startNodeId = startNode?.id
const { inputs, setInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
const inputRef = useRef(inputs)
useEffect(() => {
inputRef.current = inputs
}, [inputs])
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
const newInputs = produce(inputs, (draft) => {
draft.query_variable_selector = newVar as ValueSelector
@ -24,12 +31,119 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
setInputs(newInputs)
}, [inputs, setInputs])
const {
currentProvider,
currentModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(1)
const {
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) {
draft.single_retrieval_config = {
model: {
provider: '',
name: '',
mode: '',
completion_params: {},
},
}
}
const draftModel = draft.single_retrieval_config?.model
draftModel.provider = model.provider
draftModel.name = model.modelId
draftModel.mode = model.mode!
})
setInputs(newInputs)
}, [setInputs])
const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) {
draft.single_retrieval_config = {
model: {
provider: '',
name: '',
mode: '',
completion_params: {},
},
}
}
draft.single_retrieval_config.model.completion_params = newParams
})
setInputs(newInputs)
}, [setInputs])
// set defaults models
useEffect(() => {
const inputs = inputRef.current
const newInput = produce(inputs, (draft) => {
if (currentProvider?.provider && currentModel?.model) {
const hasSetModel = draft.single_retrieval_config?.model?.provider
if (!hasSetModel) {
draft.single_retrieval_config = {
model: {
provider: currentProvider?.provider,
name: currentModel?.model,
mode: currentModel?.model_properties?.mode as string,
completion_params: {},
},
}
}
}
const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = {
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
score_threshold: multipleRetrievalConfig?.score_threshold,
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
? undefined
: (!multipleRetrievalConfig?.reranking_model?.provider
? {
provider: rerankDefaultModel?.provider?.provider || '',
model: rerankDefaultModel?.model || '',
}
: multipleRetrievalConfig?.reranking_model),
}
})
setInputs(newInput)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentProvider?.provider, currentModel, rerankDefaultModel])
const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
const newInputs = produce(inputs, (draft) => {
draft.retrieval_mode = newMode
if (newMode === RETRIEVE_TYPE.multiWay) {
draft.multiple_retrieval_config = {
top_k: draft.multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
score_threshold: draft.multiple_retrieval_config?.score_threshold,
reranking_model: !draft.multiple_retrieval_config?.reranking_model?.provider
? {
provider: rerankDefaultModel?.provider?.provider || '',
model: rerankDefaultModel?.model || '',
}
: draft.multiple_retrieval_config?.reranking_model,
}
}
else {
const hasSetModel = draft.single_retrieval_config?.model?.provider
if (!hasSetModel) {
draft.single_retrieval_config = {
model: {
provider: currentProvider?.provider || '',
name: currentModel?.model || '',
mode: currentModel?.model_properties?.mode as string,
completion_params: {},
},
}
}
}
})
setInputs(newInputs)
}, [inputs, setInputs])
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, rerankDefaultModel?.model, rerankDefaultModel?.provider?.provider, setInputs])
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
const newInputs = produce(inputs, (draft) => {
@ -111,6 +225,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
filterVar,
handleRetrievalModeChange,
handleMultipleRetrievalConfigChange,
handleModelChanged,
handleCompletionParamsChange,
selectedDatasets,
handleOnDatasetsChange,
isShowSingleRun,