From f9090405674d6a4465ba61269c14f3c133676e5d Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Wed, 22 Oct 2025 10:49:49 +0800 Subject: [PATCH] feat: Enhance knowledge base node validation by adding checks for embedding and reranking models (#27241) --- .../workflow/hooks/use-checklist.ts | 14 ++++++- .../components/embedding-model.tsx | 1 + .../reranking-model-selector.tsx | 1 + .../workflow/nodes/knowledge-base/default.ts | 40 +++++++++++++++---- .../workflow/nodes/knowledge-base/types.ts | 3 ++ web/i18n/en-US/workflow.ts | 2 + web/i18n/zh-Hans/workflow.ts | 2 + 7 files changed, 54 insertions(+), 9 deletions(-) diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 8a29551b89..1f474a699a 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -42,6 +42,9 @@ import { fetchDatasets } from '@/service/datasets' import { MAX_TREE_DEPTH } from '@/config' import useNodesAvailableVarList, { useGetNodesAvailableVarList } from './use-nodes-available-var-list' import { getNodeUsedVars, isSpecialVar } from '../nodes/_base/components/variable/utils' +import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { KnowledgeBaseNodeType } from '../nodes/knowledge-base/types' export const useChecklist = (nodes: Node[], edges: Edge[]) => { const { t } = useTranslation() @@ -57,6 +60,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { const getToolIcon = useGetToolIcon() const map = useNodesAvailableVarList(nodes) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) + const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) const getCheckData = useCallback((data: CommonNodeType<{}>) => { let checkData = data @@ -72,8 +77,15 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { _datasets, } as CommonNodeType } + else if (data.type === BlockEnum.KnowledgeBase) { + checkData = { + ...data, + _embeddingModelList: embeddingModelList, + _rerankModelList: rerankModelList, + } as CommonNodeType + } return checkData - }, [datasetsDetail]) + }, [datasetsDetail, embeddingModelList, rerankModelList]) const needWarningNodes = useMemo(() => { const list = [] diff --git a/web/app/components/workflow/nodes/knowledge-base/components/embedding-model.tsx b/web/app/components/workflow/nodes/knowledge-base/components/embedding-model.tsx index 23481cb529..7709fb49d7 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/embedding-model.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/embedding-model.tsx @@ -57,6 +57,7 @@ const EmbeddingModel = ({ modelList={embeddingModelList} onSelect={handleEmbeddingModelChange} readonly={readonly} + showDeprecatedWarnIcon /> ) diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.tsx b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.tsx index e1eccaf309..19566362a1 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.tsx @@ -44,6 +44,7 @@ const RerankingModelSelector = ({ modelList={rerankModelList} onSelect={handleRerankingModelChange} readonly={readonly} + showDeprecatedWarnIcon /> ) } diff --git a/web/app/components/workflow/nodes/knowledge-base/default.ts b/web/app/components/workflow/nodes/knowledge-base/default.ts index 190addde4d..952eb10fa0 100644 --- a/web/app/components/workflow/nodes/knowledge-base/default.ts +++ b/web/app/components/workflow/nodes/knowledge-base/default.ts @@ -31,6 +31,8 @@ const nodeDefault: NodeDefault = { embedding_model, embedding_model_provider, index_chunk_variable_selector, + _embeddingModelList, + _rerankModelList, } = payload const { @@ -39,6 +41,12 @@ const nodeDefault: NodeDefault = { reranking_model, } = retrieval_model || {} + const currentEmbeddingModelProvider = _embeddingModelList?.find(provider => provider.provider === embedding_model_provider) + const currentEmbeddingModel = currentEmbeddingModelProvider?.models.find(model => model.model === embedding_model) + + const currentRerankingModelProvider = _rerankModelList?.find(provider => provider.provider === reranking_model?.reranking_provider_name) + const currentRerankingModel = currentRerankingModelProvider?.models.find(model => model.model === reranking_model?.reranking_model_name) + if (!chunk_structure) { return { isValid: false, @@ -60,10 +68,18 @@ const nodeDefault: NodeDefault = { } } - if (indexing_technique === IndexingType.QUALIFIED && (!embedding_model || !embedding_model_provider)) { - return { - isValid: false, - errorMessage: t('workflow.nodes.knowledgeBase.embeddingModelIsRequired'), + if (indexing_technique === IndexingType.QUALIFIED) { + if (!embedding_model || !embedding_model_provider) { + return { + isValid: false, + errorMessage: t('workflow.nodes.knowledgeBase.embeddingModelIsRequired'), + } + } + else if (!currentEmbeddingModel) { + return { + isValid: false, + errorMessage: t('workflow.nodes.knowledgeBase.embeddingModelIsInvalid'), + } } } @@ -74,10 +90,18 @@ const nodeDefault: NodeDefault = { } } - if (reranking_enable && (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name)) { - return { - isValid: false, - errorMessage: t('workflow.nodes.knowledgeBase.rerankingModelIsRequired'), + if (reranking_enable) { + if (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name) { + return { + isValid: false, + errorMessage: t('workflow.nodes.knowledgeBase.rerankingModelIsRequired'), + } + } + else if (!currentRerankingModel) { + return { + isValid: false, + errorMessage: t('workflow.nodes.knowledgeBase.rerankingModelIsInvalid'), + } } } diff --git a/web/app/components/workflow/nodes/knowledge-base/types.ts b/web/app/components/workflow/nodes/knowledge-base/types.ts index a8a0811c54..1f484a5c55 100644 --- a/web/app/components/workflow/nodes/knowledge-base/types.ts +++ b/web/app/components/workflow/nodes/knowledge-base/types.ts @@ -3,6 +3,7 @@ import type { IndexingType } from '@/app/components/datasets/create/step-two' import type { RETRIEVE_METHOD } from '@/types/app' import type { WeightedScoreEnum } from '@/models/datasets' import type { RerankingModeEnum } from '@/models/datasets' +import type { Model } from '@/app/components/header/account-setting/model-provider-page/declarations' export { WeightedScoreEnum } from '@/models/datasets' export { IndexingType as IndexMethodEnum } from '@/app/components/datasets/create/step-two' export { RETRIEVE_METHOD as RetrievalSearchMethodEnum } from '@/types/app' @@ -49,4 +50,6 @@ export type KnowledgeBaseNodeType = CommonNodeType & { embedding_model_provider?: string keyword_number: number retrieval_model: RetrievalSetting + _embeddingModelList?: Model[] + _rerankModelList?: Model[] } diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index 01d17a4111..e07fc3f109 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -959,8 +959,10 @@ const translation = { indexMethodIsRequired: 'Index method is required', chunksVariableIsRequired: 'Chunks variable is required', embeddingModelIsRequired: 'Embedding model is required', + embeddingModelIsInvalid: 'Embedding model is invalid', retrievalSettingIsRequired: 'Retrieval setting is required', rerankingModelIsRequired: 'Reranking model is required', + rerankingModelIsInvalid: 'Reranking model is invalid', }, }, tracing: { diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 441c0a707e..c5c72eb712 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -959,8 +959,10 @@ const translation = { indexMethodIsRequired: '索引方法是必需的', chunksVariableIsRequired: 'Chunks 变量是必需的', embeddingModelIsRequired: 'Embedding 模型是必需的', + embeddingModelIsInvalid: '无效的 Embedding 模型', retrievalSettingIsRequired: '检索设置是必需的', rerankingModelIsRequired: 'Reranking 模型是必需的', + rerankingModelIsInvalid: '无效的 Reranking 模型', }, }, tracing: {