fix: tool node check

This commit is contained in:
StyleZhang 2024-04-02 20:19:27 +08:00
parent 5df66579a8
commit f7184c0e36
2 changed files with 64 additions and 7 deletions

View File

@ -10,14 +10,21 @@ import type {
} from '../types'
import { BlockEnum } from '../types'
import { useStore } from '../store'
import { getValidTreeNodes } from '../utils'
import {
getToolCheckParams,
getValidTreeNodes,
} from '../utils'
import { MAX_TREE_DEEPTH } from '../constants'
import type { ToolNodeType } from '../nodes/tool/types'
import { useIsChatMode } from './use-workflow'
import { useNodesExtraData } from './use-nodes-data'
import { useToastContext } from '@/app/components/base/toast'
import { CollectionType } from '@/app/components/tools/types'
import { useGetLanguage } from '@/context/i18n'
export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const { t } = useTranslation()
const language = useGetLanguage()
const nodesExtraData = useNodesExtraData()
const buildInTools = useStore(s => s.buildInTools)
const customTools = useStore(s => s.customTools)
@ -28,16 +35,21 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
for (let i = 0; i < nodes.length; i++) {
const node = nodes[i]
const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t)
let toolIcon
let moreDataForCheckValid
if (node.data.type === BlockEnum.Tool) {
if (node.data.provider_type === 'builtin')
const { provider_type } = node.data
const isBuiltIn = provider_type === CollectionType.builtIn
moreDataForCheckValid = getToolCheckParams(node.data as ToolNodeType, buildInTools, customTools, language)
if (isBuiltIn)
toolIcon = buildInTools.find(tool => tool.id === node.data.provider_id)?.icon
if (node.data.provider_type === 'custom')
if (!isBuiltIn)
toolIcon = customTools.find(tool => tool.id === node.data.provider_id)?.icon
}
const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t, moreDataForCheckValid)
if (errorMessage || !validNodes.find(n => n.id === node.id)) {
list.push({
@ -52,13 +64,16 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
}
return list
}, [t, nodes, edges, nodesExtraData, buildInTools, customTools])
}, [t, nodes, edges, nodesExtraData, buildInTools, customTools, language])
return needWarningNodes
}
export const useChecklistBeforePublish = () => {
const { t } = useTranslation()
const language = useGetLanguage()
const buildInTools = useStore(s => s.buildInTools)
const customTools = useStore(s => s.customTools)
const { notify } = useToastContext()
const isChatMode = useIsChatMode()
const store = useStoreApi()
@ -82,7 +97,11 @@ export const useChecklistBeforePublish = () => {
for (let i = 0; i < nodes.length; i++) {
const node = nodes[i]
const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(node.data, t)
let moreDataForCheckValid
if (node.data.type === BlockEnum.Tool)
moreDataForCheckValid = getToolCheckParams(node.data as ToolNodeType, buildInTools, customTools, language)
const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(node.data, t, moreDataForCheckValid)
if (errorMessage) {
notify({ type: 'error', message: `[${node.data.title}] ${errorMessage}` })
@ -106,7 +125,7 @@ export const useChecklistBeforePublish = () => {
}
return true
}, [nodesExtraData, notify, t, store, isChatMode])
}, [nodesExtraData, notify, t, store, isChatMode, buildInTools, customTools, language])
return {
handleCheckBeforePublish,

View File

@ -10,7 +10,9 @@ import {
} from 'lodash-es'
import type {
Edge,
InputVar,
Node,
ToolWithProvider,
} from './types'
import { BlockEnum } from './types'
import {
@ -18,6 +20,9 @@ import {
START_INITIAL_POSITION,
} from './constants'
import type { QuestionClassifierNodeType } from './nodes/question-classifier/types'
import type { ToolNodeType } from './nodes/tool/types'
import { CollectionType } from '@/app/components/tools/types'
import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
export const initialNodes = (nodes: Node[], edges: Edge[]) => {
const firstNode = nodes[0]
@ -228,3 +233,36 @@ export const getValidTreeNodes = (nodes: Node[], edges: Edge[]) => {
maxDepth,
}
}
export const getToolCheckParams = (
toolData: ToolNodeType,
buildInTools: ToolWithProvider[],
customTools: ToolWithProvider[],
language: string,
) => {
const { provider_id, provider_type, tool_name } = toolData
const isBuiltIn = provider_type === CollectionType.builtIn
const currentTools = isBuiltIn ? buildInTools : customTools
const currCollection = currentTools.find(item => item.id === provider_id)
const currTool = currCollection?.tools.find(tool => tool.name === tool_name)
const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : []
const toolInputVarSchema = formSchemas.filter((item: any) => item.form === 'llm')
const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm')
return {
toolInputsSchema: (() => {
const formInputs: InputVar[] = []
toolInputVarSchema.forEach((item: any) => {
formInputs.push({
label: item.label[language] || item.label.en_US,
variable: item.variable,
type: item.type,
required: item.required,
})
})
return formInputs
})(),
toolSettingSchema,
language,
}
}