mirror of https://github.com/langgenius/dify.git
feat: knowledge support one sigle
This commit is contained in:
parent
4eb7546177
commit
c409ab4c3c
|
|
@ -16,15 +16,27 @@ import type {
|
|||
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import type { ModelConfig } from '@/app/components/workflow/types'
|
||||
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
|
||||
import TooltipPlus from '@/app/components/base/tooltip-plus'
|
||||
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
|
||||
|
||||
type Props = {
|
||||
datasetConfigs: DatasetConfigs
|
||||
onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void
|
||||
isInWorkflow?: boolean
|
||||
singleRetrievalModelConfig?: ModelConfig
|
||||
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
||||
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
||||
}
|
||||
|
||||
const ConfigContent: FC<Props> = ({
|
||||
datasetConfigs,
|
||||
onChange,
|
||||
isInWorkflow,
|
||||
singleRetrievalModelConfig: singleRetrievalConfig = {} as ModelConfig,
|
||||
onSingleRetrievalModelChange = () => { },
|
||||
onSingleRetrievalModelParamsChange = () => { },
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const type = datasetConfigs.retrieval_model
|
||||
|
|
@ -77,6 +89,9 @@ const ConfigContent: FC<Props> = ({
|
|||
score_threshold_enabled: enable,
|
||||
})
|
||||
}
|
||||
|
||||
const model = singleRetrievalConfig
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className='mt-2 space-y-3'>
|
||||
|
|
@ -122,7 +137,7 @@ const ConfigContent: FC<Props> = ({
|
|||
enable={true}
|
||||
/>
|
||||
<ScoreThresholdItem
|
||||
value={datasetConfigs.score_threshold}
|
||||
value={datasetConfigs.score_threshold as number}
|
||||
onChange={handleParamChange}
|
||||
enable={datasetConfigs.score_threshold_enabled}
|
||||
hasSwitch={true}
|
||||
|
|
@ -131,7 +146,34 @@ const ConfigContent: FC<Props> = ({
|
|||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isInWorkflow && type === RETRIEVE_TYPE.oneWay && (
|
||||
<div className='mt-6'>
|
||||
<div className='flex items-center space-x-0.5'>
|
||||
<div className='leading-[32px] text-[13px] font-medium text-gray-900'>{t('common.modelProvider.systemReasoningModel.key')}</div>
|
||||
<TooltipPlus
|
||||
popupContent={t('common.modelProvider.systemReasoningModel.tip')}
|
||||
>
|
||||
<HelpCircle className='w-3.5 h-4.5 text-gray-400' />
|
||||
</TooltipPlus>
|
||||
</div>
|
||||
<ModelParameterModal
|
||||
popupClassName='!w-[387px]'
|
||||
portalToFollowElemContentClassName='!z-[1002]'
|
||||
isAdvancedMode={true}
|
||||
mode={model?.mode}
|
||||
provider={model?.provider}
|
||||
completionParams={model?.completion_params}
|
||||
modelId={model?.name}
|
||||
setModel={onSingleRetrievalModelChange as any}
|
||||
onCompletionParamsChange={onSingleRetrievalModelParamsChange as any}
|
||||
hideDebugWithMultipleModel
|
||||
debugWithMultipleModel={false}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div >
|
||||
)
|
||||
}
|
||||
export default React.memo(ConfigContent)
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arr
|
|||
|
||||
export type ModelParameterModalProps = {
|
||||
popupClassName?: string
|
||||
portalToFollowElemContentClassName?: string
|
||||
isAdvancedMode: boolean
|
||||
mode: string
|
||||
modelId: string
|
||||
|
|
@ -69,6 +70,7 @@ const stopParameerRule: ModelParameterRule = {
|
|||
const PROVIDER_WITH_PRESET_TONE = ['openai', 'azure_openai']
|
||||
const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
popupClassName,
|
||||
portalToFollowElemContentClassName,
|
||||
isAdvancedMode,
|
||||
modelId,
|
||||
provider,
|
||||
|
|
@ -200,7 +202,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|||
)
|
||||
}
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-[60]'>
|
||||
<PortalToFollowElemContent className={cn(portalToFollowElemContentClassName, 'z-[60]')}>
|
||||
<div className={cn(popupClassName, 'w-[496px] rounded-xl border border-gray-100 bg-white shadow-xl')}>
|
||||
<div className='max-h-[480px] px-10 pt-6 pb-8 overflow-y-auto'>
|
||||
<div className='flex items-center justify-between h-8'>
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ import type { FC } from 'react'
|
|||
import React, { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import cn from 'classnames'
|
||||
import type { MultipleRetrievalConfig } from '../types'
|
||||
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
|
||||
import type { ModelConfig } from '../../../types'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
|
|
@ -23,15 +24,22 @@ type Props = {
|
|||
payload: {
|
||||
retrieval_mode: RETRIEVE_TYPE
|
||||
multiple_retrieval_config?: MultipleRetrievalConfig
|
||||
single_retrieval_config?: SingleRetrievalConfig
|
||||
}
|
||||
onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void
|
||||
onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void
|
||||
singleRetrievalModelConfig?: ModelConfig
|
||||
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
||||
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
||||
}
|
||||
|
||||
const RetrievalConfig: FC<Props> = ({
|
||||
payload,
|
||||
onRetrievalModeChange,
|
||||
onMultipleRetrievalConfigChange,
|
||||
singleRetrievalModelConfig,
|
||||
onSingleRetrievalModelChange,
|
||||
onSingleRetrievalModelParamsChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
|
|
@ -43,6 +51,7 @@ const RetrievalConfig: FC<Props> = ({
|
|||
|
||||
const { multiple_retrieval_config } = payload
|
||||
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
|
||||
console.log(configs, isRetrievalModeChange)
|
||||
if (isRetrievalModeChange) {
|
||||
onRetrievalModeChange(configs.retrieval_model)
|
||||
return
|
||||
|
|
@ -62,7 +71,7 @@ const RetrievalConfig: FC<Props> = ({
|
|||
model: configs.reranking_model?.reranking_model_name,
|
||||
}),
|
||||
})
|
||||
}, [onRetrievalModeChange, onMultipleRetrievalConfigChange])
|
||||
}, [onMultipleRetrievalConfigChange, payload.retrieval_mode, rerankDefaultModel?.provider?.provider, rerankDefaultModel?.model, onRetrievalModeChange])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
|
|
@ -106,6 +115,10 @@ const RetrievalConfig: FC<Props> = ({
|
|||
}
|
||||
}
|
||||
onChange={handleChange}
|
||||
isInWorkflow
|
||||
singleRetrievalModelConfig={singleRetrievalModelConfig}
|
||||
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
|
||||
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
|
||||
/>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
|||
inputs,
|
||||
handleQueryVarChange,
|
||||
filterVar,
|
||||
handleModelChanged,
|
||||
handleCompletionParamsChange,
|
||||
handleRetrievalModeChange,
|
||||
handleMultipleRetrievalConfigChange,
|
||||
selectedDatasets,
|
||||
|
|
@ -66,9 +68,13 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
|||
payload={{
|
||||
retrieval_mode: inputs.retrieval_mode,
|
||||
multiple_retrieval_config: inputs.multiple_retrieval_config,
|
||||
single_retrieval_config: inputs.single_retrieval_config,
|
||||
}}
|
||||
onRetrievalModeChange={handleRetrievalModeChange}
|
||||
onMultipleRetrievalConfigChange={handleMultipleRetrievalConfigChange}
|
||||
singleRetrievalModelConfig={inputs.single_retrieval_config?.model}
|
||||
onSingleRetrievalModelChange={handleModelChanged as any}
|
||||
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
|
||||
/>
|
||||
<div className='w-px h-3 bg-gray-200'></div>
|
||||
<AddKnowledge
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import type { CommonNodeType, ValueSelector } from '@/app/components/workflow/types'
|
||||
import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
|
||||
import type { RETRIEVE_TYPE } from '@/types/app'
|
||||
|
||||
export type MultipleRetrievalConfig = {
|
||||
|
|
@ -9,9 +9,15 @@ export type MultipleRetrievalConfig = {
|
|||
model: string
|
||||
}
|
||||
}
|
||||
|
||||
export type SingleRetrievalConfig = {
|
||||
model: ModelConfig
|
||||
}
|
||||
|
||||
export type KnowledgeRetrievalNodeType = CommonNodeType & {
|
||||
query_variable_selector: ValueSelector
|
||||
dataset_ids: string[]
|
||||
retrieval_mode: RETRIEVE_TYPE
|
||||
multiple_retrieval_config?: MultipleRetrievalConfig
|
||||
single_retrieval_config?: SingleRetrievalConfig
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,22 +1,29 @@
|
|||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import produce from 'immer'
|
||||
import type { ValueSelector, Var } from '../../types'
|
||||
import { BlockEnum, VarType } from '../../types'
|
||||
import { useIsChatMode, useWorkflow } from '../../hooks'
|
||||
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
|
||||
import type { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
||||
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
const isChatMode = useIsChatMode()
|
||||
console.log()
|
||||
const { getBeforeNodesInSameBranch } = useWorkflow()
|
||||
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
|
||||
const startNodeId = startNode?.id
|
||||
const { inputs, setInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
useEffect(() => {
|
||||
inputRef.current = inputs
|
||||
}, [inputs])
|
||||
|
||||
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.query_variable_selector = newVar as ValueSelector
|
||||
|
|
@ -24,12 +31,119 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
|
||||
const {
|
||||
currentProvider,
|
||||
currentModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(1)
|
||||
|
||||
const {
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
|
||||
|
||||
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (!draft.single_retrieval_config) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: '',
|
||||
name: '',
|
||||
mode: '',
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
const draftModel = draft.single_retrieval_config?.model
|
||||
draftModel.provider = model.provider
|
||||
draftModel.name = model.modelId
|
||||
draftModel.mode = model.mode!
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (!draft.single_retrieval_config) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: '',
|
||||
name: '',
|
||||
mode: '',
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
draft.single_retrieval_config.model.completion_params = newParams
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
// set defaults models
|
||||
useEffect(() => {
|
||||
const inputs = inputRef.current
|
||||
const newInput = produce(inputs, (draft) => {
|
||||
if (currentProvider?.provider && currentModel?.model) {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
if (!hasSetModel) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: currentProvider?.provider,
|
||||
name: currentModel?.model,
|
||||
mode: currentModel?.model_properties?.mode as string,
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: multipleRetrievalConfig?.score_threshold,
|
||||
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
|
||||
? undefined
|
||||
: (!multipleRetrievalConfig?.reranking_model?.provider
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: multipleRetrievalConfig?.reranking_model),
|
||||
}
|
||||
})
|
||||
setInputs(newInput)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentProvider?.provider, currentModel, rerankDefaultModel])
|
||||
|
||||
const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.retrieval_mode = newMode
|
||||
if (newMode === RETRIEVE_TYPE.multiWay) {
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: draft.multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: draft.multiple_retrieval_config?.score_threshold,
|
||||
reranking_model: !draft.multiple_retrieval_config?.reranking_model?.provider
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: draft.multiple_retrieval_config?.reranking_model,
|
||||
}
|
||||
}
|
||||
else {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
if (!hasSetModel) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: currentProvider?.provider || '',
|
||||
name: currentModel?.model || '',
|
||||
mode: currentModel?.model_properties?.mode as string,
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, rerankDefaultModel?.model, rerankDefaultModel?.provider?.provider, setInputs])
|
||||
|
||||
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
|
|
@ -111,6 +225,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
filterVar,
|
||||
handleRetrievalModeChange,
|
||||
handleMultipleRetrievalConfigChange,
|
||||
handleModelChanged,
|
||||
handleCompletionParamsChange,
|
||||
selectedDatasets,
|
||||
handleOnDatasetsChange,
|
||||
isShowSingleRun,
|
||||
|
|
|
|||
Loading…
Reference in New Issue