From b95e6f6a7a33cd1623633be388222143bff5210d Mon Sep 17 00:00:00 2001 From: Blackoutta <37723456+Blackoutta@users.noreply.github.com> Date: Sun, 10 May 2026 20:10:16 +0800 Subject: [PATCH] feat: support editable class labels in question classifier (#35430) --- eslint-suppressions.json | 23 +-- .../prompt-editor/__tests__/index.spec.tsx | 8 + .../components/base/prompt-editor/index.tsx | 16 +- web/app/components/workflow/constants.ts | 4 + .../__tests__/integration.spec.tsx | 6 + .../__tests__/node.spec.tsx | 10 +- .../__tests__/panel.spec.tsx | 1 + .../__tests__/use-config.spec.ts | 147 ++++++++++++++++++ .../components/__tests__/class-item.spec.tsx | 36 ++++- .../components/class-item.tsx | 99 +++++++++++- .../components/class-label-utils.ts | 66 ++++++++ .../components/class-list.tsx | 51 ++++-- .../nodes/question-classifier/default.ts | 5 +- .../nodes/question-classifier/node.tsx | 10 +- .../nodes/question-classifier/panel.tsx | 5 + .../nodes/question-classifier/types.ts | 1 + .../nodes/question-classifier/use-config.ts | 118 ++++++++++---- .../__tests__/node.spec.tsx | 5 +- .../nodes/question-classifier/node.tsx | 5 +- web/i18n/en-US/workflow.json | 4 + web/i18n/ja-JP/workflow.json | 4 + web/i18n/zh-Hans/workflow.json | 4 + web/i18n/zh-Hant/workflow.json | 4 + 23 files changed, 535 insertions(+), 97 deletions(-) create mode 100644 web/app/components/workflow/nodes/question-classifier/__tests__/use-config.spec.ts create mode 100644 web/app/components/workflow/nodes/question-classifier/components/class-label-utils.ts diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 23e2da9ee0..2de84456ee 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -1465,7 +1465,7 @@ }, "web/app/components/base/prompt-editor/index.tsx": { "ts/no-explicit-any": { - "count": 4 + "count": 3 } }, "web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx": { @@ -3858,30 +3858,9 @@ "count": 9 } }, - "web/app/components/workflow/nodes/question-classifier/components/class-item.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/workflow/nodes/question-classifier/components/class-list.tsx": { "react/set-state-in-effect": { "count": 1 - }, - "react/unsupported-syntax": { - "count": 2 - } - }, - "web/app/components/workflow/nodes/question-classifier/default.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/question-classifier/use-config.ts": { - "react/set-state-in-effect": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 2 } }, "web/app/components/workflow/nodes/question-classifier/use-single-run-form-params.ts": { diff --git a/web/app/components/base/prompt-editor/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/__tests__/index.spec.tsx index 9d75b9e061..31e25ab19e 100644 --- a/web/app/components/base/prompt-editor/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/__tests__/index.spec.tsx @@ -365,6 +365,14 @@ describe('PromptEditor', () => { expect(() => unmount()).not.toThrow() }) + it('should rerender without ref-driven update loops', () => { + const { rerender } = render() + + expect(() => { + rerender() + }).not.toThrow() + }) + it('should render hitl block when show=true', () => { render( = ({ } as any) }, [eventEmitter, historyBlock?.history]) - const [floatingAnchorElem, setFloatingAnchorElem] = useState(null) + const [floatingAnchorElem, setFloatingAnchorElem] = useState(null) - const onRef = (_floatingAnchorElem: any) => { - if (_floatingAnchorElem !== null) - setFloatingAnchorElem(_floatingAnchorElem) - } + const onRef = useCallback((nextFloatingAnchorElem: HTMLDivElement | null) => { + setFloatingAnchorElem((currentFloatingAnchorElem) => { + if (currentFloatingAnchorElem === nextFloatingAnchorElem) + return currentFloatingAnchorElem + + return nextFloatingAnchorElem + }) + }, []) return ( diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index ed9a072824..101d15a140 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -172,6 +172,10 @@ export const QUESTION_CLASSIFIER_OUTPUT_STRUCT = [ variable: 'class_name', type: VarType.string, }, + { + variable: 'class_label', + type: VarType.string, + }, { variable: 'usage', type: VarType.object, diff --git a/web/app/components/workflow/nodes/question-classifier/__tests__/integration.spec.tsx b/web/app/components/workflow/nodes/question-classifier/__tests__/integration.spec.tsx index ada3fc43cc..c4f8a41d47 100644 --- a/web/app/components/workflow/nodes/question-classifier/__tests__/integration.spec.tsx +++ b/web/app/components/workflow/nodes/question-classifier/__tests__/integration.spec.tsx @@ -55,6 +55,12 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-selec default: ({ defaultModel }: any) =>
{defaultModel.provider}:{defaultModel.model}
, })) +vi.mock('@langgenius/dify-ui/tooltip', () => ({ + Tooltip: ({ children }: any) =>
{children}
, + TooltipTrigger: ({ children }: any) => <>{children}, + TooltipContent: ({ children }: any) =>
{children}
, +})) + vi.mock('@/app/components/workflow/nodes/_base/components/readonly-input-with-select-var', () => ({ default: ({ value }: any) =>
{value}
, })) diff --git a/web/app/components/workflow/nodes/question-classifier/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/question-classifier/__tests__/node.spec.tsx index ad411639e9..3356226a85 100644 --- a/web/app/components/workflow/nodes/question-classifier/__tests__/node.spec.tsx +++ b/web/app/components/workflow/nodes/question-classifier/__tests__/node.spec.tsx @@ -76,12 +76,18 @@ describe('question-classifier/node', () => { render( ), + createTopic({ id: 'topic-2', name: 'Refunds', label: 'Refund desk' } as Partial), + ], + })} />, ) expect(screen.getByText('openai:gpt-4o')).toBeInTheDocument() - expect(screen.getByText('Billing questions')).toBeInTheDocument() + expect(screen.getByText('Billing')).toBeInTheDocument() + expect(screen.getByText('Refund desk')).toBeInTheDocument() expect(screen.getByText('handle-topic-1')).toBeInTheDocument() expect(screen.getByText('handle-topic-2')).toBeInTheDocument() }) diff --git a/web/app/components/workflow/nodes/question-classifier/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/question-classifier/__tests__/panel.spec.tsx index c205f57d08..f92bb5abb3 100644 --- a/web/app/components/workflow/nodes/question-classifier/__tests__/panel.spec.tsx +++ b/web/app/components/workflow/nodes/question-classifier/__tests__/panel.spec.tsx @@ -144,6 +144,7 @@ describe('question-classifier/panel', () => { expect(handleVisionResolutionEnabledChange).toHaveBeenCalledWith(true) expect(handleVisionResolutionChange).toHaveBeenCalledWith({ resolution: 'high' }) expect(screen.getByText('class_name:string')).toBeInTheDocument() + expect(screen.getByText('class_label:string')).toBeInTheDocument() expect(screen.getByText('usage:object')).toBeInTheDocument() }) }) diff --git a/web/app/components/workflow/nodes/question-classifier/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/question-classifier/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..89af550de1 --- /dev/null +++ b/web/app/components/workflow/nodes/question-classifier/__tests__/use-config.spec.ts @@ -0,0 +1,147 @@ +import type { QuestionClassifierNodeType } from '../types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { + useIsChatMode, + useNodesReadOnly, + useWorkflow, +} from '@/app/components/workflow/hooks' +import useConfigVision from '@/app/components/workflow/hooks/use-config-vision' +import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list' +import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' +import { useStore } from '@/app/components/workflow/store' +import { BlockEnum } from '@/app/components/workflow/types' +import { AppModeEnum } from '@/types/app' +import useConfig from '../use-config' + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: vi.fn(), + useIsChatMode: vi.fn(), + useWorkflow: vi.fn(), +})) + +vi.mock('reactflow', () => ({ + useUpdateNodeInternals: vi.fn(() => vi.fn()), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelListAndDefaultModelAndCurrentProviderAndModel: vi.fn(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: vi.fn(), +})) + +vi.mock('@/app/components/workflow/hooks/use-config-vision', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseIsChatMode = vi.mocked(useIsChatMode) +const mockUseWorkflow = vi.mocked(useWorkflow) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseModelListAndDefaultModelAndCurrentProviderAndModel = vi.mocked(useModelListAndDefaultModelAndCurrentProviderAndModel) +const mockUseStore = vi.mocked(useStore) +const mockUseConfigVision = vi.mocked(useConfigVision) +const mockUseAvailableVarList = vi.mocked(useAvailableVarList) + +const createPayload = (overrides: Partial = {}): QuestionClassifierNodeType => ({ + type: BlockEnum.QuestionClassifier, + title: 'Question Classifier', + desc: '', + model: { + provider: '', + name: '', + mode: AppModeEnum.CHAT, + completion_params: {}, + }, + classes: [{ id: 'topic-1', name: 'Billing questions', label: 'CLASS 1' }], + query_variable_selector: ['start-node', 'sys.query'], + instruction: 'Route by topic', + vision: { + enabled: false, + }, + ...overrides, +}) + +describe('question-classifier/use-config', () => { + const setInputs = vi.fn() + let latestVisionOptions: { + onChange: (payload: QuestionClassifierNodeType['vision']) => void + } | null = null + + beforeEach(() => { + vi.clearAllMocks() + latestVisionOptions = null + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseIsChatMode.mockReturnValue(true) + mockUseWorkflow.mockReturnValue({ + getBeforeNodesInSameBranch: vi.fn(() => []), + } as unknown as ReturnType) + mockUseNodeCrud.mockReturnValue({ + inputs: createPayload(), + setInputs, + }) + mockUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ + modelList: [], + defaultModel: undefined, + currentProvider: undefined, + currentModel: undefined, + } as ReturnType) + mockUseStore.mockImplementation((selector) => { + return selector({ nodesDefaultConfigs: {} } as never) + }) + mockUseConfigVision.mockImplementation((_model, options) => { + latestVisionOptions = options as typeof latestVisionOptions + return { + isVisionModel: false, + handleVisionResolutionEnabledChange: vi.fn(), + handleVisionResolutionChange: vi.fn(), + handleModelChanged: vi.fn(() => { + latestVisionOptions?.onChange({ enabled: false }) + }), + } + }) + mockUseAvailableVarList.mockReturnValue({ + availableVars: [], + availableNodes: [], + availableNodesWithParent: [], + } as unknown as ReturnType) + }) + + it('preserves the selected model when the vision follow-up updates after model selection', async () => { + const { result } = renderHook(() => useConfig('question-classifier-node', createPayload())) + + act(() => { + result.current.handleModelChanged({ + provider: 'openai', + modelId: 'gpt-4o', + mode: AppModeEnum.CHAT, + }) + }) + + await waitFor(() => { + expect(setInputs).toHaveBeenLastCalledWith(expect.objectContaining({ + model: expect.objectContaining({ + provider: 'openai', + name: 'gpt-4o', + mode: AppModeEnum.CHAT, + }), + vision: { + enabled: false, + }, + })) + }) + }) +}) diff --git a/web/app/components/workflow/nodes/question-classifier/components/__tests__/class-item.spec.tsx b/web/app/components/workflow/nodes/question-classifier/components/__tests__/class-item.spec.tsx index 6ba88016e0..ab1d0e0224 100644 --- a/web/app/components/workflow/nodes/question-classifier/components/__tests__/class-item.spec.tsx +++ b/web/app/components/workflow/nodes/question-classifier/components/__tests__/class-item.spec.tsx @@ -3,8 +3,6 @@ import { fireEvent, render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' import ClassItem from '../class-item' -const mockEditorRender = vi.hoisted(() => vi.fn()) - vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () => ({ __esModule: true, default: () => ({ @@ -16,13 +14,12 @@ vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () vi.mock('@/app/components/workflow/nodes/_base/components/prompt/editor', () => ({ __esModule: true, default: (props: { - title: string + title: React.ReactNode value: string onChange: (value: string) => void onRemove: () => void showRemove?: boolean }) => { - mockEditorRender(props) return (
{props.title}
@@ -70,9 +67,32 @@ describe('question-classifier/class-item', () => { name: 'Billing questions updated', }) expect(onRemove).toHaveBeenCalledTimes(1) - expect(mockEditorRender).toHaveBeenCalledWith(expect.objectContaining({ - title: 'workflow.nodes.questionClassifiers.class 1', - value: 'Billing questions', - })) + expect(screen.getByRole('button', { name: 'CLASS 1' })).toBeInTheDocument() + }) + + it('preserves a custom label when editing the classifier name', () => { + const onChange = vi.fn() + + render( + true} + />, + ) + + fireEvent.change(screen.getByLabelText('class-name'), { + target: { value: 'Billing questions updated' }, + }) + + expect(onChange).toHaveBeenCalledWith({ + id: 'topic-1', + name: 'Billing questions updated', + label: 'Billing', + }) + expect(screen.getByRole('button', { name: 'Billing' })).toBeInTheDocument() }) }) diff --git a/web/app/components/workflow/nodes/question-classifier/components/class-item.tsx b/web/app/components/workflow/nodes/question-classifier/components/class-item.tsx index 1e90d4590d..139c314def 100644 --- a/web/app/components/workflow/nodes/question-classifier/components/class-item.tsx +++ b/web/app/components/workflow/nodes/question-classifier/components/class-item.tsx @@ -2,12 +2,13 @@ import type { FC } from 'react' import type { Topic } from '../types' import type { ValueSelector, Var } from '@/app/components/workflow/types' -import { uniqueId } from 'es-toolkit/compat' +import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' -import { useCallback, useEffect, useState } from 'react' +import { useCallback, useEffect, useId, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list' +import { getCanonicalClassLabel, getDisplayClassLabel } from './class-label-utils' const i18nPrefix = 'nodes.questionClassifiers' @@ -21,6 +22,7 @@ type Props = { index: number readonly?: boolean filterVar: (payload: Var, valueSelector: ValueSelector) => boolean + onLabelEditStart?: () => void } const ClassItem: FC = ({ @@ -33,18 +35,49 @@ const ClassItem: FC = ({ index, readonly, filterVar, + onLabelEditStart, }) => { const { t } = useTranslation() - const [instanceId, setInstanceId] = useState(() => uniqueId()) + const reactId = useId() + const [isEditingLabel, setIsEditingLabel] = useState(false) + const [draftLabel, setDraftLabel] = useState('') + const labelInputRef = useRef(null) + const displayLabel = getDisplayClassLabel(payload.label, index, t) + const instanceId = `${nodeId}-${reactId}` useEffect(() => { - setInstanceId(`${nodeId}-${uniqueId()}`) - }, [nodeId]) + if (isEditingLabel) + labelInputRef.current?.select() + }, [isEditingLabel]) const handleNameChange = useCallback((value: string) => { onChange({ ...payload, name: value }) }, [onChange, payload]) + const handleLabelSave = useCallback((nextValue: string) => { + const normalizedLabel = getCanonicalClassLabel(nextValue, index, t) + setIsEditingLabel(false) + setDraftLabel(normalizedLabel) + const shouldPersistLabel = normalizedLabel !== displayLabel + || (payload.label !== undefined && payload.label !== normalizedLabel) + if (shouldPersistLabel) + onChange({ ...payload, label: normalizedLabel }) + }, [displayLabel, index, onChange, payload, t]) + + const handleLabelCancel = useCallback(() => { + setDraftLabel(displayLabel) + setIsEditingLabel(false) + }, [displayLabel]) + + const handleLabelEditStart = useCallback(() => { + if (readonly) + return + + setDraftLabel(displayLabel) + setIsEditingLabel(true) + onLabelEditStart?.() + }, [displayLabel, onLabelEditStart, readonly]) + const { availableVars, availableNodesWithParent } = useAvailableVarList(nodeId, { onlyLeafNodeVar: false, hideChatVar: false, @@ -52,11 +85,65 @@ const ClassItem: FC = ({ filterVar, }) + const title = isEditingLabel + ? ( + setDraftLabel(event.target.value)} + onBlur={() => handleLabelSave(draftLabel)} + onClick={event => event.stopPropagation()} + onDoubleClick={event => event.stopPropagation()} + onKeyDown={(event) => { + if (event.key === 'Enter') { + event.preventDefault() + handleLabelSave(draftLabel) + } + + if (event.key === 'Escape') { + event.preventDefault() + handleLabelCancel() + } + }} + autoFocus + /> + ) + : readonly + ? ( +
+ {displayLabel} +
+ ) + : ( + + ) + return ( `${LEGACY_DEFAULT_LABEL_PREFIX} ${index}` + +const getTranslatedDefaultClassLabel = (t: TFunction, index: number) => { + const translated = t(`${i18nPrefix}.defaultLabel`, { ns: 'workflow', index }) + if (typeof translated !== 'string') + return undefined + + const resolvedLabel = translated.replace('{{index}}', String(index)) + const rawWorkflowKey = `workflow.${i18nPrefix}.defaultLabel` + const rawKey = `${i18nPrefix}.defaultLabel` + if ( + resolvedLabel === rawWorkflowKey + || resolvedLabel === rawKey + || resolvedLabel.startsWith(`${rawWorkflowKey}:`) + || resolvedLabel.startsWith(`${rawKey}:`) + ) { + return undefined + } + + return resolvedLabel +} + +const normalizeClassLabel = (label?: string | null) => label?.trim() ?? '' + +export const getDefaultClassLabel = (_t: TFunction, index: number) => getCanonicalDefaultClassLabel(index) + +export const getDisplayClassLabel = ( + label: string | undefined, + index: number, + t: TFunction, +) => normalizeClassLabel(label) || getTranslatedDefaultClassLabel(t, index) || getCanonicalDefaultClassLabel(index) + +export const isDefaultClassLabel = ( + label: string | undefined, + index: number, + t: TFunction, +) => { + const normalizedLabel = normalizeClassLabel(label) + if (!normalizedLabel) + return true + + return DEFAULT_EQUIVALENT_PREFIXES.some(prefix => normalizedLabel === `${prefix} ${index}`) + || normalizedLabel === getTranslatedDefaultClassLabel(t, index) +} + +export const getCanonicalClassLabel = ( + label: string | undefined, + index: number, + t: TFunction, +) => { + const normalizedLabel = normalizeClassLabel(label) + if (!normalizedLabel) + return getCanonicalDefaultClassLabel(index) + + if (isDefaultClassLabel(normalizedLabel, index, t)) + return getCanonicalDefaultClassLabel(index) + + return normalizedLabel +} diff --git a/web/app/components/workflow/nodes/question-classifier/components/class-list.tsx b/web/app/components/workflow/nodes/question-classifier/components/class-list.tsx index fead81fb19..1a80266de5 100644 --- a/web/app/components/workflow/nodes/question-classifier/components/class-list.tsx +++ b/web/app/components/workflow/nodes/question-classifier/components/class-list.tsx @@ -14,8 +14,10 @@ import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid import { useEdgesInteractions } from '../../../hooks' import AddButton from '../../_base/components/add-button' import Item from './class-item' +import { getDefaultClassLabel, isDefaultClassLabel } from './class-label-utils' const i18nPrefix = 'nodes.questionClassifiers' +const INLINE_LABEL_HINT_STORAGE_KEY = 'question-classifier-inline-label-hint-dismissed' type Props = { nodeId: string @@ -40,6 +42,17 @@ const ClassList: FC = ({ const [shouldScrollToEnd, setShouldScrollToEnd] = useState(false) const prevListLength = useRef(list.length) const [collapsed, setCollapsed] = useState(false) + const [isRenameHintDismissed, setIsRenameHintDismissed] = useState(() => { + if (typeof window === 'undefined') + return true + + try { + return window.localStorage.getItem(INLINE_LABEL_HINT_STORAGE_KEY) === 'true' + } + catch { + return false + } + }) const handleClassChange = useCallback((index: number) => { return (value: Topic) => { @@ -52,13 +65,17 @@ const ClassList: FC = ({ const handleAddClass = useCallback(() => { const newList = produce(list, (draft) => { - draft.push({ id: `${Date.now()}`, name: '' }) + draft.push({ + id: `${Date.now()}`, + name: '', + label: getDefaultClassLabel(t, draft.length + 1), + }) }) onChange(newList) setShouldScrollToEnd(true) if (collapsed) setCollapsed(false) - }, [list, onChange, collapsed]) + }, [collapsed, list, onChange, t]) const handleRemoveClass = useCallback((index: number) => { return () => { @@ -72,7 +89,6 @@ const ClassList: FC = ({ const topicCount = list.length - // Scroll to the newly added item after the list updates useEffect(() => { if (shouldScrollToEnd && list.length > prevListLength.current) setShouldScrollToEnd(false) @@ -83,6 +99,22 @@ const ClassList: FC = ({ setCollapsed(!collapsed) }, [collapsed]) + const dismissRenameHint = useCallback(() => { + if (isRenameHintDismissed) + return + + setIsRenameHintDismissed(true) + try { + window.localStorage.setItem(INLINE_LABEL_HINT_STORAGE_KEY, 'true') + } + catch { + } + }, [isRenameHintDismissed]) + + const shouldShowRenameHint = !readonly && !isRenameHintDismissed && list.some((item, index) => { + return isDefaultClassLabel(item.label, index + 1, t) + }) + return ( <>
@@ -100,6 +132,11 @@ const ClassList: FC = ({ )}
+ {shouldShowRenameHint && ( +
+ {t(`${i18nPrefix}.renameHint`, { ns: 'workflow' })} +
+ )} {!collapsed && (
= ({ > { list.map((item, index) => { - const canDrag = (() => { - if (readonly) - return false - - return topicCount >= 2 - })() + const canDrag = !readonly && topicCount >= 2 return (
= ({ index={index + 1} readonly={readonly} filterVar={filterVar} + onLabelEditStart={dismissRenameHint} />
diff --git a/web/app/components/workflow/nodes/question-classifier/default.ts b/web/app/components/workflow/nodes/question-classifier/default.ts index 1ee0d3e8d1..c8f882ae31 100644 --- a/web/app/components/workflow/nodes/question-classifier/default.ts +++ b/web/app/components/workflow/nodes/question-classifier/default.ts @@ -1,3 +1,4 @@ +import type { TFunction } from 'i18next' import type { NodeDefault } from '../../types' import type { QuestionClassifierNodeType } from './types' import { BlockClassificationEnum } from '@/app/components/workflow/block-selector/types' @@ -28,10 +29,12 @@ const nodeDefault: NodeDefault = { { id: '1', name: '', + label: 'CLASS 1', }, { id: '2', name: '', + label: 'CLASS 2', }, ], _targetBranches: [ @@ -48,7 +51,7 @@ const nodeDefault: NodeDefault = { enabled: false, }, }, - checkValid(payload: QuestionClassifierNodeType, t: any) { + checkValid(payload: QuestionClassifierNodeType, t: TFunction<'workflow'>) { let errorMessages = '' if (!errorMessages && (!payload.query_variable_selector || payload.query_variable_selector.length === 0)) errorMessages = t(`${i18nPrefix}errorMsg.fieldRequired`, { ns: 'workflow', field: t(`${i18nPrefix}nodes.questionClassifiers.inputVars`, { ns: 'workflow' }) }) diff --git a/web/app/components/workflow/nodes/question-classifier/node.tsx b/web/app/components/workflow/nodes/question-classifier/node.tsx index 305eacc204..ac932f4767 100644 --- a/web/app/components/workflow/nodes/question-classifier/node.tsx +++ b/web/app/components/workflow/nodes/question-classifier/node.tsx @@ -11,19 +11,19 @@ import { import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { NodeSourceHandle } from '../_base/components/node-handle' import ReadonlyInputWithSelectVar from '../_base/components/readonly-input-with-select-var' - -const i18nPrefix = 'nodes.questionClassifiers' +import { getDisplayClassLabel } from './components/class-label-utils' const MAX_CLASS_TEXT_LENGTH = 50 type TruncatedClassItemProps = { - topic: { id: string, name: string } + topic: { id: string, name: string, label?: string } index: number nodeId: string t: TFunction } const TruncatedClassItem: FC = ({ topic, index, nodeId, t }) => { + const displayLabel = getDisplayClassLabel(topic.label, index + 1, t) const truncatedText = topic.name.length > MAX_CLASS_TEXT_LENGTH ? `${topic.name.slice(0, MAX_CLASS_TEXT_LENGTH)}...` : topic.name @@ -42,8 +42,8 @@ const TruncatedClassItem: FC = ({ topic, index, nodeId, return (
-
- {`${t(`${i18nPrefix}.class`, { ns: 'workflow' })} ${index + 1}`} +
+ {displayLabel}
{shouldShowTooltip ? ( diff --git a/web/app/components/workflow/nodes/question-classifier/panel.tsx b/web/app/components/workflow/nodes/question-classifier/panel.tsx index 8d0bd4665f..624952203b 100644 --- a/web/app/components/workflow/nodes/question-classifier/panel.tsx +++ b/web/app/components/workflow/nodes/question-classifier/panel.tsx @@ -127,6 +127,11 @@ const Panel: FC> = ({ type="string" description={t(`${i18nPrefix}.outputVars.className`, { ns: 'workflow' })} /> + { const { getBeforeNodesInSameBranch } = useWorkflow() const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start) const startNodeId = startNode?.id - const { inputs, setInputs } = useNodeCrud(id, payload) + const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) const inputRef = useRef(inputs) + const setInputs = useCallback((newInputs: QuestionClassifierNodeType) => { + doSetInputs(newInputs) + inputRef.current = newInputs + }, [doSetInputs]) useEffect(() => { inputRef.current = inputs }, [inputs]) - const [modelChanged, setModelChanged] = useState(false) + const isHandlingModelChangeRef = useRef(false) const { currentProvider, currentModel, @@ -42,6 +46,13 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { const modelMode = inputs.model?.mode const isChatModel = modelMode === AppModeEnum.CHAT + const handleVisionChange = useCallback((newPayload: QuestionClassifierNodeType['vision']) => { + const newInputs = produce(inputRef.current, (draft) => { + draft.vision = newPayload + }) + setInputs(newInputs) + }, [setInputs]) + const { isVisionModel, handleVisionResolutionEnabledChange, @@ -49,12 +60,7 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { handleModelChanged: handleVisionConfigAfterModelChanged, } = useConfigVision(model, { payload: inputs.vision, - onChange: (newPayload) => { - const newInputs = produce(inputs, (draft) => { - draft.vision = newPayload - }) - setInputs(newInputs) - }, + onChange: handleVisionChange, }) const handleModelChanged = useCallback((model: { provider: string, modelId: string, mode?: string }) => { @@ -63,21 +69,23 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { draft.model.name = model.modelId draft.model.mode = model.mode! }) + isHandlingModelChangeRef.current = true setInputs(newInputs) - setModelChanged(true) }, [setInputs]) useEffect(() => { if (currentProvider?.provider && currentModel?.model && !model.provider) { - handleModelChanged({ - provider: currentProvider?.provider, - modelId: currentModel?.model, - mode: currentModel?.model_properties?.mode as string, + startTransition(() => { + handleModelChanged({ + provider: currentProvider?.provider, + modelId: currentModel?.model, + mode: currentModel?.model_properties?.mode as string | undefined, + }) }) } }, [model.provider, currentProvider, currentModel, handleModelChanged]) - const handleCompletionParamsChange = useCallback((newParams: Record) => { + const handleCompletionParamsChange = useCallback((newParams: Record) => { const newInputs = produce(inputs, (draft) => { draft.model.completion_params = newParams }) @@ -86,11 +94,13 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { // change to vision model to set vision enabled, else disabled useEffect(() => { - if (!modelChanged) + if (!isHandlingModelChangeRef.current) return - setModelChanged(false) - handleVisionConfigAfterModelChanged() - }, [isVisionModel, modelChanged]) + isHandlingModelChangeRef.current = false + startTransition(() => { + handleVisionConfigAfterModelChanged() + }) + }, [handleVisionConfigAfterModelChanged, isVisionModel]) const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { const newInputs = produce(inputs, (draft) => { @@ -101,22 +111,58 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { useEffect(() => { const isReady = defaultConfig && Object.keys(defaultConfig).length > 0 - if (isReady) { - let query_variable_selector: ValueSelector = [] - if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId) - query_variable_selector = [startNodeId, 'sys.query'] - setInputs({ - ...inputs, - ...defaultConfig, - query_variable_selector: inputs.query_variable_selector.length > 0 ? inputs.query_variable_selector : query_variable_selector, - }) - } - }, [defaultConfig]) + if (!isReady) + return - const handleClassesChange = useCallback((newClasses: any) => { + const currentInputs = inputRef.current + let shouldUpdate = false + + const nextInputs = produce(currentInputs, (draft) => { + if (!draft.model) + draft.model = defaultConfig.model + + if (!draft.classes) + draft.classes = defaultConfig.classes + + if (!draft._targetBranches) + draft._targetBranches = defaultConfig._targetBranches + + if (!draft.vision) + draft.vision = defaultConfig.vision + + if (draft.query_variable_selector.length === 0 && isChatMode && startNodeId) { + draft.query_variable_selector = [startNodeId, 'sys.query'] + shouldUpdate = true + } + + if (!currentInputs.model && defaultConfig.model) + shouldUpdate = true + + if (!currentInputs.classes && defaultConfig.classes) + shouldUpdate = true + + if (!currentInputs._targetBranches && defaultConfig._targetBranches) + shouldUpdate = true + + if (!currentInputs.vision && defaultConfig.vision) + shouldUpdate = true + }) + + if (!shouldUpdate) + return + + startTransition(() => { + setInputs(nextInputs) + }) + }, [defaultConfig, isChatMode, setInputs, startNodeId]) + + const handleClassesChange = useCallback((newClasses: Topic[]) => { const newInputs = produce(inputs, (draft) => { draft.classes = newClasses - draft._targetBranches = newClasses + draft._targetBranches = newClasses.map((item: Topic) => ({ + id: item.id, + name: item.name, + })) }) setInputs(newInputs) }, [inputs, setInputs]) @@ -170,7 +216,13 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { const handleSortTopic = useCallback((newTopics: (Topic & { id: string })[]) => { const newInputs = produce(inputs, (draft) => { - draft.classes = newTopics.filter(Boolean).map(item => ({ + const sortedTopics = newTopics.filter(Boolean) + draft.classes = sortedTopics.map(item => ({ + id: item.id, + name: item.name, + label: item.label, + })) + draft._targetBranches = sortedTopics.map(item => ({ id: item.id, name: item.name, })) diff --git a/web/app/components/workflow/workflow-preview/components/nodes/question-classifier/__tests__/node.spec.tsx b/web/app/components/workflow/workflow-preview/components/nodes/question-classifier/__tests__/node.spec.tsx index 82cede1d85..463c4dd43d 100644 --- a/web/app/components/workflow/workflow-preview/components/nodes/question-classifier/__tests__/node.spec.tsx +++ b/web/app/components/workflow/workflow-preview/components/nodes/question-classifier/__tests__/node.spec.tsx @@ -28,7 +28,7 @@ describe('workflow preview question classifier node', () => { title: 'Classifier', desc: '', classes: [ - { id: 'class-1', name: 'Billing' }, + { id: 'class-1', name: 'Billing', label: 'Billing label' }, { id: 'class-2', name: 'Support' }, ], } as never, @@ -38,7 +38,8 @@ describe('workflow preview question classifier node', () => { , ) - expect(getByText('workflow.nodes.questionClassifiers.class 1')).toBeInTheDocument() + expect(getByText('Billing label')).toBeInTheDocument() + expect(getByText('CLASS 2')).toBeInTheDocument() expect(container.querySelector('[data-handleid="class-1"]')).toBeInTheDocument() expect(container.querySelector('[data-handleid="class-2"]')).toBeInTheDocument() }) diff --git a/web/app/components/workflow/workflow-preview/components/nodes/question-classifier/node.tsx b/web/app/components/workflow/workflow-preview/components/nodes/question-classifier/node.tsx index d0e23f7823..9884990309 100644 --- a/web/app/components/workflow/workflow-preview/components/nodes/question-classifier/node.tsx +++ b/web/app/components/workflow/workflow-preview/components/nodes/question-classifier/node.tsx @@ -4,10 +4,9 @@ import type { QuestionClassifierNodeType } from '@/app/components/workflow/nodes import * as React from 'react' import { useTranslation } from 'react-i18next' import InfoPanel from '@/app/components/workflow/nodes/_base/components/info-panel' +import { getDisplayClassLabel } from '@/app/components/workflow/nodes/question-classifier/components/class-label-utils' import { NodeSourceHandle } from '../../node-handle' -const i18nPrefix = 'nodes.questionClassifiers' - const Node: FC> = (props) => { const { t } = useTranslation() const { data } = props @@ -24,7 +23,7 @@ const Node: FC> = (props) => { className="relative" >