From 19c17220326725d2b67e87a73810a88a00cd92da Mon Sep 17 00:00:00 2001 From: StyleZhang Date: Thu, 14 Mar 2024 17:27:08 +0800 Subject: [PATCH] node default value --- .../model-provider-page/hooks.ts | 4 +- .../workflow/hooks/use-nodes-data.ts | 28 +++++- .../components/workflow/hooks/use-workflow.ts | 94 ++++++++++++++++++- web/app/components/workflow/index.tsx | 82 ++-------------- .../nodes/_base/components/next-step/add.tsx | 15 ++- web/app/components/workflow/store.ts | 4 + web/types/workflow.ts | 8 +- 7 files changed, 143 insertions(+), 92 deletions(-) diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 3b5bdbb682..6e483b4b9b 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -109,7 +109,7 @@ export const MODEL_TYPE_MAPS = { } export const useModelList = (type: ModelTypeIndex) => { - const { data, mutate, isLoading } = useSWR(`/workspaces/current/models/model-types/${MODEL_TYPE_MAPS[type]}`, fetchModelList) + const { data, mutate, isLoading } = useSWR(`/workspaces/current/models/model-types/${MODEL_TYPE_MAPS[type]}`, fetchModelList, { revalidateOnFocus: true }) return { data: data?.data || [], @@ -119,7 +119,7 @@ export const useModelList = (type: ModelTypeIndex) => { } export const useDefaultModel = (type: ModelTypeIndex) => { - const { data, mutate, isLoading } = useSWR(`/workspaces/current/default-model?model_type=${MODEL_TYPE_MAPS[type]}`, fetchDefaultModal) + const { data, mutate, isLoading } = useSWR(`/workspaces/current/default-model?model_type=${MODEL_TYPE_MAPS[type]}`, fetchDefaultModal, { revalidateOnFocus: true }) return { data: data?.data, diff --git a/web/app/components/workflow/hooks/use-nodes-data.ts b/web/app/components/workflow/hooks/use-nodes-data.ts index 29920c76de..7a904ffc08 100644 --- a/web/app/components/workflow/hooks/use-nodes-data.ts +++ b/web/app/components/workflow/hooks/use-nodes-data.ts @@ -1,19 +1,41 @@ +import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import produce from 'immer' -import type { BlockEnum } from '../types' +import { BlockEnum } from '../types' import { NODES_EXTRA_DATA, NODES_INITIAL_DATA, } from '../constants' +import { useStore } from '../store' +import type { LLMNodeType } from '../nodes/llm/types' +import type { QuestionClassifierNodeType } from '../nodes/question-classifier/types' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' export const useNodesInitialData = () => { const { t } = useTranslation() + const nodesDefaultConfigs = useStore(s => s.nodesDefaultConfigs) + const { + currentProvider, + currentModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(1) - return produce(NODES_INITIAL_DATA, (draft) => { + return useMemo(() => produce(NODES_INITIAL_DATA, (draft) => { Object.keys(draft).forEach((key) => { draft[key as BlockEnum].title = t(`workflow.blocks.${key}`) + + if (currentProvider && currentModel && (key === BlockEnum.LLM || key === BlockEnum.QuestionClassifier)) { + (draft[key as BlockEnum] as LLMNodeType | QuestionClassifierNodeType).model.provider = currentProvider.provider; + (draft[key as BlockEnum] as LLMNodeType | QuestionClassifierNodeType).model.name = currentModel.model + } + + if (nodesDefaultConfigs[key as BlockEnum]) { + draft[key as BlockEnum] = { + ...draft[key as BlockEnum], + ...nodesDefaultConfigs[key as BlockEnum], + } + } }) - }) + }), [t, nodesDefaultConfigs, currentProvider, currentModel]) } export const useNodesExtraData = () => { diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 76f9758764..1757184750 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -1,15 +1,34 @@ -import { useCallback } from 'react' +import { + useCallback, + useEffect, +} from 'react' +import useSWR from 'swr' import produce from 'immer' import { getIncomers, getOutgoers, useStoreApi, } from 'reactflow' -import { getLayoutByDagre } from '../utils' +import type { ToolsMap } from '../block-selector/types' +import { + generateNewNode, + getLayoutByDagre, +} from '../utils' import type { Node } from '../types' import { BlockEnum } from '../types' -import { SUPPORT_OUTPUT_VARS_NODE } from '../constants' +import { useStore } from '../store' +import { + START_INITIAL_POSITION, + SUPPORT_OUTPUT_VARS_NODE, +} from '../constants' +import { useNodesInitialData } from './use-nodes-data' import { useStore as useAppStore } from '@/app/components/app/store' +import { + fetchNodesDefaultConfigs, + fetchWorkflowDraft, + syncWorkflowDraft, +} from '@/service/workflow' +import { fetchCollectionList } from '@/service/tools' export const useIsChatMode = () => { const appDetail = useAppStore(s => s.appDetail) @@ -156,3 +175,72 @@ export const useWorkflow = () => { getAfterNodesInSameBranch, } } + +export const useWorkflowInit = () => { + const nodesInitialData = useNodesInitialData() + const appDetail = useAppStore(state => state.appDetail)! + const { data, error, mutate } = useSWR(`/apps/${appDetail.id}/workflows/draft`, fetchWorkflowDraft) + + const handleFetchPreloadData = async () => { + try { + const toolsets = await fetchCollectionList() + const nodesDefaultConfigsData = await fetchNodesDefaultConfigs(`/apps/${appDetail?.id}/workflows/default-workflow-block-configs`) + + useStore.setState({ + toolsets, + toolsMap: toolsets.reduce((acc, toolset) => { + acc[toolset.id] = [] + return acc + }, {} as ToolsMap), + }) + useStore.setState({ + nodesDefaultConfigs: nodesDefaultConfigsData.reduce((acc, block) => { + if (!acc[block.type]) + acc[block.type] = block.config + return acc + }, {} as Record), + }) + } + catch (e) { + + } + } + + useEffect(() => { + handleFetchPreloadData() + }, []) + + useEffect(() => { + if (data) + useStore.setState({ draftUpdatedAt: data.updated_at }) + }, [data]) + + if (error && error.json && !error.bodyUsed && appDetail) { + error.json().then((err: any) => { + if (err.code === 'draft_workflow_not_exist') { + useStore.setState({ notInitialWorkflow: true }) + syncWorkflowDraft({ + url: `/apps/${appDetail.id}/workflows/draft`, + params: { + graph: { + nodes: [generateNewNode({ + data: { + ...nodesInitialData.start, + selected: true, + }, + position: START_INITIAL_POSITION, + })], + edges: [], + }, + features: {}, + }, + }).then((res) => { + useStore.setState({ draftUpdatedAt: res.updated_at }) + mutate() + }) + } + }) + } + + return data +} diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index 5f3cd7d921..35a26338c6 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -4,7 +4,6 @@ import { useEffect, useMemo, } from 'react' -import useSWR from 'swr' import { setAutoFreeze } from 'immer' import { useKeyPress } from 'ahooks' import ReactFlow, { @@ -14,16 +13,15 @@ import ReactFlow, { } from 'reactflow' import type { Viewport } from 'reactflow' import 'reactflow/dist/style.css' -import type { ToolsMap } from './block-selector/types' import type { Edge, Node, } from './types' import { useEdgesInteractions, - useNodesInitialData, useNodesInteractions, useNodesSyncDraft, + useWorkflowInit, } from './hooks' import Header from './header' import CustomNode from './nodes' @@ -38,17 +36,9 @@ import { initialEdges, initialNodes, } from './utils' -import { START_INITIAL_POSITION } from './constants' -import { - fetchNodesDefaultConfigs, - fetchWorkflowDraft, - syncWorkflowDraft, -} from '@/service/workflow' -import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' import { FeaturesProvider } from '@/app/components/base/features' import type { Features as FeaturesData } from '@/app/components/base/features/types' -import { fetchCollectionList } from '@/service/tools' const nodeTypes = { custom: CustomNode, @@ -157,27 +147,7 @@ const WorkflowWrap: FC = ({ nodes, edges, }) => { - const appDetail = useAppStore(state => state.appDetail) - const { data, isLoading, error, mutate } = useSWR(appDetail?.id ? `/apps/${appDetail.id}/workflows/draft` : null, fetchWorkflowDraft) - const { data: nodesDefaultConfigs } = useSWR(appDetail?.id ? `/apps/${appDetail?.id}/workflows/default-workflow-block-configs` : null, fetchNodesDefaultConfigs) - const nodesInitialData = useNodesInitialData() - - useEffect(() => { - if (data) - useStore.setState({ draftUpdatedAt: data.updated_at }) - }, [data]) - - const startNode = useMemo(() => { - return { - id: `${Date.now()}`, - type: 'custom', - data: { - ...nodesInitialData.start, - selected: true, - }, - position: START_INITIAL_POSITION, - } - }, [nodesInitialData]) + const data = useWorkflowInit() const nodesData = useMemo(() => { if (nodes) @@ -186,8 +156,8 @@ const WorkflowWrap: FC = ({ if (data) return initialNodes(data.graph.nodes, data.graph.edges) - return [startNode] - }, [data, nodes, startNode]) + return [] + }, [data, nodes]) const edgesData = useMemo(() => { if (edges) return edges @@ -198,44 +168,7 @@ const WorkflowWrap: FC = ({ return [] }, [data, edges]) - const handleFetchCollectionList = async () => { - const toolsets = await fetchCollectionList() - - useStore.setState({ - toolsets, - toolsMap: toolsets.reduce((acc, toolset) => { - acc[toolset.id] = [] - return acc - }, {} as ToolsMap), - }) - } - - useEffect(() => { - handleFetchCollectionList() - }, []) - - if (error && error.json && !error.bodyUsed && appDetail) { - error.json().then((err: any) => { - if (err.code === 'draft_workflow_not_exist') { - useStore.setState({ notInitialWorkflow: true }) - syncWorkflowDraft({ - url: `/apps/${appDetail.id}/workflows/draft`, - params: { - graph: { - nodes: [startNode], - edges: [], - }, - features: {}, - }, - }).then((res) => { - useStore.setState({ draftUpdatedAt: res.updated_at }) - mutate() - }) - } - }) - } - - if (isLoading) { + if (!data) { return (
@@ -243,10 +176,7 @@ const WorkflowWrap: FC = ({ ) } - if (!data) - return null - - const features = data?.features || {} + const features = data.features || {} const initialFeatures: FeaturesData = { opening: { enabled: !!features.opening_statement, diff --git a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx index 8c4790f751..9a97502273 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx @@ -19,11 +19,20 @@ const Add = ({ branchName, }: AddProps) => { const { t } = useTranslation() - const { handleNodeAddNext } = useNodesInteractions() + const { handleNodeAdd } = useNodesInteractions() const handleSelect = useCallback((type, toolDefaultValue) => { - handleNodeAddNext(nodeId, type, sourceHandle, toolDefaultValue) - }, [nodeId, sourceHandle, handleNodeAddNext]) + handleNodeAdd( + { + nodeType: type, + toolDefaultValue, + }, + { + prevNodeId: nodeId, + prevNodeSourceHandle: sourceHandle, + }, + ) + }, [nodeId, sourceHandle, handleNodeAdd]) const renderTrigger = useCallback((open: boolean) => { return ( diff --git a/web/app/components/workflow/store.ts b/web/app/components/workflow/store.ts index 9881cccefc..18fd1224b9 100644 --- a/web/app/components/workflow/store.ts +++ b/web/app/components/workflow/store.ts @@ -38,6 +38,7 @@ type State = { viewport: Viewport } notInitialWorkflow: boolean + nodesDefaultConfigs: Record } type Action = { @@ -58,6 +59,7 @@ type Action = { setInputs: (inputs: Record) => void setBackupDraft: (backupDraft?: State['backupDraft']) => void setNotInitialWorkflow: (notInitialWorkflow: boolean) => void + setNodesDefaultConfigs: (nodesDefaultConfigs: Record) => void } export const useStore = create(set => ({ @@ -95,4 +97,6 @@ export const useStore = create(set => ({ setBackupDraft: backupDraft => set(() => ({ backupDraft })), notInitialWorkflow: false, setNotInitialWorkflow: notInitialWorkflow => set(() => ({ notInitialWorkflow })), + nodesDefaultConfigs: {}, + setNodesDefaultConfigs: nodesDefaultConfigs => set(() => ({ nodesDefaultConfigs })), })) diff --git a/web/types/workflow.ts b/web/types/workflow.ts index 7e77bdba5d..bac3a7d2de 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -162,8 +162,6 @@ export type WorkflowRunHistoryResponse = { } export type NodesDefaultConfigsResponse = { - blocks: { - type: string - config: any - }[] -} + type: string + config: any +}[]