From 778aabb4859199db7c9f0eafc690d6d3cb2745b6 Mon Sep 17 00:00:00 2001 From: Sean Kenneth Doherty Date: Wed, 4 Feb 2026 00:36:52 -0600 Subject: [PATCH 01/18] refactor(api): replace reqparse with Pydantic models in trial.py (#31789) Co-authored-by: Asuka Minato --- api/controllers/console/explore/trial.py | 101 ++++++++++++++++------- 1 file changed, 71 insertions(+), 30 deletions(-) diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index cd523b481c..ba214e71c0 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -1,8 +1,9 @@ import logging -from typing import Any, cast +from typing import Any, Literal, cast from flask import request -from flask_restx import Resource, fields, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -117,7 +118,56 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy) +# Pydantic models for request validation +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowRunRequest(BaseModel): + inputs: dict + files: list | None = None + + +class ChatRequest(BaseModel): + inputs: dict + query: str + files: list | None = None + conversation_id: str | None = None + parent_message_id: str | None = None + retriever_from: str = "explore_app" + + +class TextToSpeechRequest(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + +class CompletionRequest(BaseModel): + inputs: dict + query: str = "" + files: list | None = None + response_mode: Literal["blocking", "streaming"] | None = None + retriever_from: str = "explore_app" + + +# Register schemas for Swagger documentation +console_ns.schema_model( + WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + + class TrialAppWorkflowRunApi(TrialAppResource): + @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__]) def post(self, trial_app): """ Run workflow @@ -129,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") - args = parser.parse_args() + request_data = WorkflowRunRequest.model_validate(console_ns.payload) + args = request_data.model_dump() assert current_user is not None try: app_id = app_model.id @@ -183,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource): class TrialChatApi(TrialAppResource): + @console_ns.expect(console_ns.models[ChatRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app @@ -190,14 +239,14 @@ class TrialChatApi(TrialAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") - parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - args = parser.parse_args() + request_data = ChatRequest.model_validate(console_ns.payload) + args = request_data.model_dump() + + # Validate UUID values if provided + if args.get("conversation_id"): + args["conversation_id"] = uuid_value(args["conversation_id"]) + if args.get("parent_message_id"): + args["parent_message_id"] = uuid_value(args["parent_message_id"]) args["auto_generate_name"] = False @@ -320,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource): class TrialChatTextApi(TrialAppResource): + @console_ns.expect(console_ns.models[TextToSpeechRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, required=False, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") - args = parser.parse_args() + request_data = TextToSpeechRequest.model_validate(console_ns.payload) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = request_data.message_id + text = request_data.text + voice = request_data.voice if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") @@ -371,19 +416,15 @@ class TrialChatTextApi(TrialAppResource): class TrialCompletionApi(TrialAppResource): + @console_ns.expect(console_ns.models[CompletionRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - args = parser.parse_args() + request_data = CompletionRequest.model_validate(console_ns.payload) + args = request_data.model_dump() streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False From 64e769f96ea01161ca8e28af7db2dbe590a5d048 Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Wed, 4 Feb 2026 14:51:47 +0800 Subject: [PATCH 02/18] refactor: plugin detail panel components for better maintainability and code organization. (#31870) Co-authored-by: CodingOnStar --- .../app-selector/app-inputs-panel.tsx | 172 +---- .../hooks/use-app-inputs-form-schema.ts | 211 +++++ .../detail-header.spec.tsx | 1 - .../plugin-detail-panel/detail-header.tsx | 418 +--------- .../components/header-modals.spec.tsx | 539 +++++++++++++ .../components/header-modals.tsx | 107 +++ .../detail-header/components/index.ts | 2 + .../components/plugin-source-badge.spec.tsx | 200 +++++ .../components/plugin-source-badge.tsx | 59 ++ .../detail-header/hooks/index.ts | 3 + .../hooks/use-detail-header-state.spec.ts | 409 ++++++++++ .../hooks/use-detail-header-state.ts | 132 ++++ .../hooks/use-plugin-operations.spec.ts | 549 +++++++++++++ .../hooks/use-plugin-operations.ts | 143 ++++ .../detail-header/index.tsx | 286 +++++++ .../create/common-modal.spec.tsx | 47 +- .../subscription-list/create/common-modal.tsx | 495 ++---------- .../create/components/modal-steps.tsx | 304 ++++++++ .../create/hooks/use-common-modal-state.ts | 401 ++++++++++ .../hooks/use-oauth-client-state.spec.ts | 719 ++++++++++++++++++ .../create/hooks/use-oauth-client-state.ts | 241 ++++++ .../subscription-list/create/index.spec.tsx | 36 - .../subscription-list/create/index.tsx | 22 +- .../create/oauth-client.spec.tsx | 102 +-- .../subscription-list/create/oauth-client.tsx | 241 ++---- .../subscription-list/create/types.ts | 6 + web/eslint-suppressions.json | 13 - 27 files changed, 4481 insertions(+), 1377 deletions(-) create mode 100644 web/app/components/plugins/plugin-detail-panel/app-selector/hooks/use-app-inputs-form-schema.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/components/header-modals.spec.tsx create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/components/header-modals.tsx create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/components/index.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.spec.tsx create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.tsx create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/hooks/index.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-detail-header-state.spec.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-detail-header-state.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.spec.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/detail-header/index.tsx create mode 100644 web/app/components/plugins/plugin-detail-panel/subscription-list/create/components/modal-steps.tsx create mode 100644 web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.spec.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts create mode 100644 web/app/components/plugins/plugin-detail-panel/subscription-list/create/types.ts diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx index c7280c7508..8e7affad8e 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx @@ -1,27 +1,19 @@ 'use client' -import type { FileUpload } from '@/app/components/base/features/types' import type { App } from '@/types/app' -import * as React from 'react' -import { useMemo, useRef } from 'react' +import { useRef } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' -import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' import AppInputsForm from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-form' -import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types' -import { useAppDetail } from '@/service/use-apps' -import { useFileUploadConfig } from '@/service/use-common' -import { useAppWorkflow } from '@/service/use-workflow' -import { AppModeEnum, Resolution } from '@/types/app' - +import { useAppInputsFormSchema } from '@/app/components/plugins/plugin-detail-panel/app-selector/hooks/use-app-inputs-form-schema' import { cn } from '@/utils/classnames' type Props = { value?: { app_id: string - inputs: Record + inputs: Record } appDetail: App - onFormChange: (value: Record) => void + onFormChange: (value: Record) => void } const AppInputsPanel = ({ @@ -30,155 +22,33 @@ const AppInputsPanel = ({ onFormChange, }: Props) => { const { t } = useTranslation() - const inputsRef = useRef(value?.inputs || {}) - const isBasicApp = appDetail.mode !== AppModeEnum.ADVANCED_CHAT && appDetail.mode !== AppModeEnum.WORKFLOW - const { data: fileUploadConfig } = useFileUploadConfig() - const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id) - const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(isBasicApp ? '' : appDetail.id) - const isLoading = isAppLoading || isWorkflowLoading + const inputsRef = useRef>(value?.inputs || {}) - const basicAppFileConfig = useMemo(() => { - let fileConfig: FileUpload - if (isBasicApp) - fileConfig = currentApp?.model_config?.file_upload as FileUpload - else - fileConfig = currentWorkflow?.features?.file_upload as FileUpload - return { - image: { - detail: fileConfig?.image?.detail || Resolution.high, - enabled: !!fileConfig?.image?.enabled, - number_limits: fileConfig?.image?.number_limits || 3, - transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'], - }, - enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled), - allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image], - allowed_file_extensions: fileConfig?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`), - allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods || fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'], - number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3, - } - }, [currentApp?.model_config?.file_upload, currentWorkflow?.features?.file_upload, isBasicApp]) + const { inputFormSchema, isLoading } = useAppInputsFormSchema({ appDetail }) - const inputFormSchema = useMemo(() => { - if (!currentApp) - return [] - let inputFormSchema = [] - if (isBasicApp) { - inputFormSchema = currentApp.model_config?.user_input_form?.filter((item: any) => !item.external_data_tool).map((item: any) => { - if (item.paragraph) { - return { - ...item.paragraph, - type: 'paragraph', - required: false, - } - } - if (item.number) { - return { - ...item.number, - type: 'number', - required: false, - } - } - if (item.checkbox) { - return { - ...item.checkbox, - type: 'checkbox', - required: false, - } - } - if (item.select) { - return { - ...item.select, - type: 'select', - required: false, - } - } - - if (item['file-list']) { - return { - ...item['file-list'], - type: 'file-list', - required: false, - fileUploadConfig, - } - } - - if (item.file) { - return { - ...item.file, - type: 'file', - required: false, - fileUploadConfig, - } - } - - if (item.json_object) { - return { - ...item.json_object, - type: 'json_object', - } - } - - return { - ...item['text-input'], - type: 'text-input', - required: false, - } - }) || [] - } - else { - const startNode = currentWorkflow?.graph?.nodes.find(node => node.data.type === BlockEnum.Start) as any - inputFormSchema = startNode?.data.variables.map((variable: any) => { - if (variable.type === InputVarType.multiFiles) { - return { - ...variable, - required: false, - fileUploadConfig, - } - } - - if (variable.type === InputVarType.singleFile) { - return { - ...variable, - required: false, - fileUploadConfig, - } - } - return { - ...variable, - required: false, - } - }) || [] - } - if ((currentApp.mode === AppModeEnum.COMPLETION || currentApp.mode === AppModeEnum.WORKFLOW) && basicAppFileConfig.enabled) { - inputFormSchema.push({ - label: 'Image Upload', - variable: '#image#', - type: InputVarType.singleFile, - required: false, - ...basicAppFileConfig, - fileUploadConfig, - }) - } - return inputFormSchema || [] - }, [basicAppFileConfig, currentApp, currentWorkflow, fileUploadConfig, isBasicApp]) - - const handleFormChange = (value: Record) => { - inputsRef.current = value - onFormChange(value) + const handleFormChange = (newValue: Record) => { + inputsRef.current = newValue + onFormChange(newValue) } + const hasInputs = inputFormSchema.length > 0 + return (
{isLoading &&
} {!isLoading && ( -
{t('appSelector.params', { ns: 'app' })}
- )} - {!isLoading && !inputFormSchema.length && ( -
-
{t('appSelector.noParams', { ns: 'app' })}
+
+ {t('appSelector.params', { ns: 'app' })}
)} - {!isLoading && !!inputFormSchema.length && ( + {!isLoading && !hasInputs && ( +
+
+ {t('appSelector.noParams', { ns: 'app' })} +
+
+ )} + {!isLoading && hasInputs && (
= { + 'paragraph': 'paragraph', + 'number': 'number', + 'checkbox': 'checkbox', + 'select': 'select', + 'file-list': 'file-list', + 'file': 'file', + 'json_object': 'json_object', +} + +const FILE_INPUT_TYPES = new Set(['file-list', 'file']) + +const WORKFLOW_FILE_VAR_TYPES = new Set([InputVarType.multiFiles, InputVarType.singleFile]) + +type InputSchemaItem = { + label?: string + variable?: string + type: string + required: boolean + fileUploadConfig?: FileUploadConfigResponse + [key: string]: unknown +} + +function isBasicAppMode(mode: string): boolean { + return mode !== AppModeEnum.ADVANCED_CHAT && mode !== AppModeEnum.WORKFLOW +} + +function supportsImageUpload(mode: string): boolean { + return mode === AppModeEnum.COMPLETION || mode === AppModeEnum.WORKFLOW +} + +function buildFileConfig(fileConfig: FileUpload | undefined) { + return { + image: { + detail: fileConfig?.image?.detail || Resolution.high, + enabled: !!fileConfig?.image?.enabled, + number_limits: fileConfig?.image?.number_limits || 3, + transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'], + }, + enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled), + allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image], + allowed_file_extensions: fileConfig?.allowed_file_extensions + || [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`), + allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods + || fileConfig?.image?.transfer_methods + || ['local_file', 'remote_url'], + number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3, + } +} + +function mapBasicAppInputItem( + item: Record, + fileUploadConfig?: FileUploadConfigResponse, +): InputSchemaItem | null { + for (const [key, type] of Object.entries(BASIC_INPUT_TYPE_MAP)) { + if (!item[key]) + continue + + const inputData = item[key] as Record + const needsFileConfig = FILE_INPUT_TYPES.has(key) + + return { + ...inputData, + type, + required: false, + ...(needsFileConfig && { fileUploadConfig }), + } + } + + const textInput = item['text-input'] as Record | undefined + if (!textInput) + return null + + return { + ...textInput, + type: 'text-input', + required: false, + } +} + +function mapWorkflowVariable( + variable: Record, + fileUploadConfig?: FileUploadConfigResponse, +): InputSchemaItem { + const needsFileConfig = WORKFLOW_FILE_VAR_TYPES.has(variable.type as InputVarType) + + return { + ...variable, + type: variable.type as string, + required: false, + ...(needsFileConfig && { fileUploadConfig }), + } +} + +function createImageUploadSchema( + basicFileConfig: ReturnType, + fileUploadConfig?: FileUploadConfigResponse, +): InputSchemaItem { + return { + label: 'Image Upload', + variable: '#image#', + type: InputVarType.singleFile, + required: false, + ...basicFileConfig, + fileUploadConfig, + } +} + +function buildBasicAppSchema( + currentApp: App, + fileUploadConfig?: FileUploadConfigResponse, +): InputSchemaItem[] { + const userInputForm = currentApp.model_config?.user_input_form as Array> | undefined + if (!userInputForm) + return [] + + return userInputForm + .filter((item: Record) => !item.external_data_tool) + .map((item: Record) => mapBasicAppInputItem(item, fileUploadConfig)) + .filter((item): item is InputSchemaItem => item !== null) +} + +function buildWorkflowSchema( + workflow: FetchWorkflowDraftResponse, + fileUploadConfig?: FileUploadConfigResponse, +): InputSchemaItem[] { + const startNode = workflow.graph?.nodes.find( + node => node.data.type === BlockEnum.Start, + ) as { data: { variables: Array> } } | undefined + + if (!startNode?.data.variables) + return [] + + return startNode.data.variables.map( + variable => mapWorkflowVariable(variable, fileUploadConfig), + ) +} + +type UseAppInputsFormSchemaParams = { + appDetail: App +} + +type UseAppInputsFormSchemaResult = { + inputFormSchema: InputSchemaItem[] + isLoading: boolean + fileUploadConfig?: FileUploadConfigResponse +} + +export function useAppInputsFormSchema({ + appDetail, +}: UseAppInputsFormSchemaParams): UseAppInputsFormSchemaResult { + const isBasicApp = isBasicAppMode(appDetail.mode) + + const { data: fileUploadConfig } = useFileUploadConfig() + const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id) + const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow( + isBasicApp ? '' : appDetail.id, + ) + + const isLoading = isAppLoading || isWorkflowLoading + + const inputFormSchema = useMemo(() => { + if (!currentApp) + return [] + + if (!isBasicApp && !currentWorkflow) + return [] + + // Build base schema based on app type + // Note: currentWorkflow is guaranteed to be defined here due to the early return above + const baseSchema = isBasicApp + ? buildBasicAppSchema(currentApp, fileUploadConfig) + : buildWorkflowSchema(currentWorkflow!, fileUploadConfig) + + if (!supportsImageUpload(currentApp.mode)) + return baseSchema + + const rawFileConfig = isBasicApp + ? currentApp.model_config?.file_upload as FileUpload + : currentWorkflow?.features?.file_upload as FileUpload + + const basicFileConfig = buildFileConfig(rawFileConfig) + + if (!basicFileConfig.enabled) + return baseSchema + + return [ + ...baseSchema, + createImageUploadSchema(basicFileConfig, fileUploadConfig), + ] + }, [currentApp, currentWorkflow, fileUploadConfig, isBasicApp]) + + return { + inputFormSchema, + isLoading, + fileUploadConfig, + } +} diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header.spec.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header.spec.tsx index 49c3ef1058..cc0ac404b2 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/detail-header.spec.tsx @@ -6,7 +6,6 @@ import Toast from '@/app/components/base/toast' import { PluginSource } from '../types' import DetailHeader from './detail-header' -// Use vi.hoisted for mock functions used in vi.mock factories const { mockSetShowUpdatePluginModal, mockRefreshModelProviders, diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx index 7f7e11ad51..3f39ed289e 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx +++ b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx @@ -1,416 +1,2 @@ -import type { PluginDetail } from '../types' -import { - RiArrowLeftRightLine, - RiBugLine, - RiCloseLine, - RiHardDrive3Line, -} from '@remixicon/react' -import { useBoolean } from 'ahooks' -import * as React from 'react' -import { useCallback, useMemo, useState } from 'react' -import { useTranslation } from 'react-i18next' -import ActionButton from '@/app/components/base/action-button' -import { trackEvent } from '@/app/components/base/amplitude' -import Badge from '@/app/components/base/badge' -import Button from '@/app/components/base/button' -import Confirm from '@/app/components/base/confirm' -import { Github } from '@/app/components/base/icons/src/public/common' -import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin' -import Toast from '@/app/components/base/toast' -import Tooltip from '@/app/components/base/tooltip' -import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth' -import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown' -import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info' -import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place' -import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker' -import { API_PREFIX } from '@/config' -import { useAppContext } from '@/context/app-context' -import { useGlobalPublicStore } from '@/context/global-public-context' -import { useGetLanguage, useLocale } from '@/context/i18n' -import { useModalContext } from '@/context/modal-context' -import { useProviderContext } from '@/context/provider-context' -import useTheme from '@/hooks/use-theme' -import { uninstallPlugin } from '@/service/plugins' -import { useAllToolProviders, useInvalidateAllToolProviders } from '@/service/use-tools' -import { cn } from '@/utils/classnames' -import { getMarketplaceUrl } from '@/utils/var' -import { AutoUpdateLine } from '../../base/icons/src/vender/system' -import Verified from '../base/badges/verified' -import DeprecationNotice from '../base/deprecation-notice' -import Icon from '../card/base/card-icon' -import Description from '../card/base/description' -import OrgInfo from '../card/base/org-info' -import Title from '../card/base/title' -import { useGitHubReleases } from '../install-plugin/hooks' -import useReferenceSetting from '../plugin-page/use-reference-setting' -import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting/types' -import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../reference-setting-modal/auto-update-setting/utils' -import { PluginCategoryEnum, PluginSource } from '../types' - -const i18nPrefix = 'action' - -type Props = { - detail: PluginDetail - isReadmeView?: boolean - onHide?: () => void - onUpdate?: (isDelete?: boolean) => void -} - -const DetailHeader = ({ - detail, - isReadmeView = false, - onHide, - onUpdate, -}: Props) => { - const { t } = useTranslation() - const { userProfile: { timezone } } = useAppContext() - - const { theme } = useTheme() - const locale = useGetLanguage() - const currentLocale = useLocale() - const { checkForUpdates, fetchReleases } = useGitHubReleases() - const { setShowUpdatePluginModal } = useModalContext() - const { refreshModelProviders } = useProviderContext() - const invalidateAllToolProviders = useInvalidateAllToolProviders() - const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures) - - const { - id, - source, - tenant_id, - version, - latest_unique_identifier, - latest_version, - meta, - plugin_id, - status, - deprecated_reason, - alternative_plugin_id, - } = detail - - const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail - const isTool = category === PluginCategoryEnum.tool - const providerBriefInfo = tool?.identity - const providerKey = `${plugin_id}/${providerBriefInfo?.name}` - const { data: collectionList = [] } = useAllToolProviders(isTool) - const provider = useMemo(() => { - return collectionList.find(collection => collection.name === providerKey) - }, [collectionList, providerKey]) - const isFromGitHub = source === PluginSource.github - const isFromMarketplace = source === PluginSource.marketplace - - const [isShow, setIsShow] = useState(false) - const [targetVersion, setTargetVersion] = useState({ - version: latest_version, - unique_identifier: latest_unique_identifier, - }) - const hasNewVersion = useMemo(() => { - if (isFromMarketplace) - return !!latest_version && latest_version !== version - - return false - }, [isFromMarketplace, latest_version, version]) - - const iconFileName = theme === 'dark' && icon_dark ? icon_dark : icon - const iconSrc = iconFileName - ? (iconFileName.startsWith('http') ? iconFileName : `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenant_id}&filename=${iconFileName}`) - : '' - - const detailUrl = useMemo(() => { - if (isFromGitHub) - return `https://github.com/${meta!.repo}` - if (isFromMarketplace) - return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: currentLocale, theme }) - return '' - }, [author, isFromGitHub, isFromMarketplace, meta, name, theme]) - - const [isShowUpdateModal, { - setTrue: showUpdateModal, - setFalse: hideUpdateModal, - }] = useBoolean(false) - - const { referenceSetting } = useReferenceSetting() - const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {} - const isAutoUpgradeEnabled = useMemo(() => { - if (!enable_marketplace) - return false - if (!autoUpgradeInfo || !isFromMarketplace) - return false - if (autoUpgradeInfo.strategy_setting === 'disabled') - return false - if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all) - return true - if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id)) - return true - if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id)) - return true - return false - }, [autoUpgradeInfo, plugin_id, isFromMarketplace]) - - const [isDowngrade, setIsDowngrade] = useState(false) - const handleUpdate = async (isDowngrade?: boolean) => { - if (isFromMarketplace) { - setIsDowngrade(!!isDowngrade) - showUpdateModal() - return - } - - const owner = meta!.repo.split('/')[0] || author - const repo = meta!.repo.split('/')[1] || name - const fetchedReleases = await fetchReleases(owner, repo) - if (fetchedReleases.length === 0) - return - const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta!.version) - Toast.notify(toastProps) - if (needUpdate) { - setShowUpdatePluginModal({ - onSaveCallback: () => { - onUpdate?.() - }, - payload: { - type: PluginSource.github, - category: detail.declaration.category, - github: { - originalPackageInfo: { - id: detail.plugin_unique_identifier, - repo: meta!.repo, - version: meta!.version, - package: meta!.package, - releases: fetchedReleases, - }, - }, - }, - }) - } - } - - const handleUpdatedFromMarketplace = () => { - onUpdate?.() - hideUpdateModal() - } - - const [isShowPluginInfo, { - setTrue: showPluginInfo, - setFalse: hidePluginInfo, - }] = useBoolean(false) - - const [isShowDeleteConfirm, { - setTrue: showDeleteConfirm, - setFalse: hideDeleteConfirm, - }] = useBoolean(false) - - const [deleting, { - setTrue: showDeleting, - setFalse: hideDeleting, - }] = useBoolean(false) - - const handleDelete = useCallback(async () => { - showDeleting() - const res = await uninstallPlugin(id) - hideDeleting() - if (res.success) { - hideDeleteConfirm() - onUpdate?.(true) - if (PluginCategoryEnum.model.includes(category)) - refreshModelProviders() - if (PluginCategoryEnum.tool.includes(category)) - invalidateAllToolProviders() - trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name }) - } - }, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders, plugin_id, name]) - - return ( -
-
-
- -
-
-
- - {verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />} - {!!version && ( - <PluginVersionPicker - disabled={!isFromMarketplace || isReadmeView} - isShow={isShow} - onShowChange={setIsShow} - pluginID={plugin_id} - currentVersion={version} - onSelect={(state) => { - setTargetVersion(state) - handleUpdate(state.isDowngrade) - }} - trigger={( - <Badge - className={cn( - 'mx-1', - isShow && 'bg-state-base-hover', - (isShow || isFromMarketplace) && 'hover:bg-state-base-hover', - )} - uppercase={false} - text={( - <> - <div>{isFromGitHub ? meta!.version : version}</div> - {isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />} - </> - )} - hasRedCornerMark={hasNewVersion} - /> - )} - /> - )} - {/* Auto update info */} - {isAutoUpgradeEnabled && !isReadmeView && ( - <Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}> - {/* add a a div to fix tooltip hover not show problem */} - <div> - <Badge className="mr-1 cursor-pointer px-1"> - <AutoUpdateLine className="size-3" /> - </Badge> - </div> - </Tooltip> - )} - - {(hasNewVersion || isFromGitHub) && ( - <Button - variant="secondary-accent" - size="small" - className="!h-5" - onClick={() => { - if (isFromMarketplace) { - setTargetVersion({ - version: latest_version, - unique_identifier: latest_unique_identifier, - }) - } - handleUpdate() - }} - > - {t('detailPanel.operation.update', { ns: 'plugin' })} - </Button> - )} - </div> - <div className="mb-1 flex h-4 items-center justify-between"> - <div className="mt-0.5 flex items-center"> - <OrgInfo - packageNameClassName="w-auto" - orgName={author} - packageName={name?.includes('/') ? (name.split('/').pop() || '') : name} - /> - {!!source && ( - <> - <div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">ยท</div> - {source === PluginSource.marketplace && ( - <Tooltip popupContent={t('detailPanel.categoryTip.marketplace', { ns: 'plugin' })}> - <div><BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" /></div> - </Tooltip> - )} - {source === PluginSource.github && ( - <Tooltip popupContent={t('detailPanel.categoryTip.github', { ns: 'plugin' })}> - <div><Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" /></div> - </Tooltip> - )} - {source === PluginSource.local && ( - <Tooltip popupContent={t('detailPanel.categoryTip.local', { ns: 'plugin' })}> - <div><RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" /></div> - </Tooltip> - )} - {source === PluginSource.debugging && ( - <Tooltip popupContent={t('detailPanel.categoryTip.debugging', { ns: 'plugin' })}> - <div><RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" /></div> - </Tooltip> - )} - </> - )} - </div> - </div> - </div> - {!isReadmeView && ( - <div className="flex gap-1"> - <OperationDropdown - source={source} - onInfo={showPluginInfo} - onCheckVersion={handleUpdate} - onRemove={showDeleteConfirm} - detailUrl={detailUrl} - /> - <ActionButton onClick={onHide}> - <RiCloseLine className="h-4 w-4" /> - </ActionButton> - </div> - )} - </div> - {isFromMarketplace && ( - <DeprecationNotice - status={status} - deprecatedReason={deprecated_reason} - alternativePluginId={alternative_plugin_id} - alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })} - className="mt-3" - /> - )} - {!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2}></Description>} - { - category === PluginCategoryEnum.tool && !isReadmeView && ( - <PluginAuth - pluginPayload={{ - provider: provider?.name || '', - category: AuthCategory.tool, - providerType: provider?.type || '', - detail, - }} - /> - ) - } - {isShowPluginInfo && ( - <PluginInfo - repository={isFromGitHub ? meta?.repo : ''} - release={version} - packageName={meta?.package || ''} - onHide={hidePluginInfo} - /> - )} - {isShowDeleteConfirm && ( - <Confirm - isShow - title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })} - content={( - <div> - {t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })} - <span className="system-md-semibold">{label[locale]}</span> - {t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })} - <br /> - </div> - )} - onCancel={hideDeleteConfirm} - onConfirm={handleDelete} - isLoading={deleting} - isDisabled={deleting} - /> - )} - { - isShowUpdateModal && ( - <UpdateFromMarketplace - pluginId={plugin_id} - payload={{ - category: detail.declaration.category, - originalPackageInfo: { - id: detail.plugin_unique_identifier, - payload: detail.declaration, - }, - targetPackageInfo: { - id: targetVersion.unique_identifier, - version: targetVersion.version, - }, - }} - onCancel={hideUpdateModal} - onSave={handleUpdatedFromMarketplace} - isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled} - /> - ) - } - </div> - ) -} - -export default DetailHeader +// Re-export from refactored module for backward compatibility +export { default } from './detail-header/index' diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/components/header-modals.spec.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header/components/header-modals.spec.tsx new file mode 100644 index 0000000000..4011ee13f5 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/components/header-modals.spec.tsx @@ -0,0 +1,539 @@ +import type { PluginDetail } from '../../../types' +import type { ModalStates, VersionTarget } from '../hooks' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PluginSource } from '../../../types' +import HeaderModals from './header-modals' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: () => 'en_US', +})) + +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ isShow, title, onCancel, onConfirm, isLoading }: { + isShow: boolean + title: string + onCancel: () => void + onConfirm: () => void + isLoading: boolean + }) => isShow + ? ( + <div data-testid="delete-confirm"> + <div data-testid="delete-title">{title}</div> + <button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button> + <button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button> + </div> + ) + : null, +})) + +vi.mock('@/app/components/plugins/plugin-page/plugin-info', () => ({ + default: ({ repository, release, packageName, onHide }: { + repository: string + release: string + packageName: string + onHide: () => void + }) => ( + <div data-testid="plugin-info"> + <div data-testid="plugin-info-repo">{repository}</div> + <div data-testid="plugin-info-release">{release}</div> + <div data-testid="plugin-info-package">{packageName}</div> + <button data-testid="plugin-info-close" onClick={onHide}>Close</button> + </div> + ), +})) + +vi.mock('@/app/components/plugins/update-plugin/from-market-place', () => ({ + default: ({ pluginId, onSave, onCancel, isShowDowngradeWarningModal }: { + pluginId: string + onSave: () => void + onCancel: () => void + isShowDowngradeWarningModal: boolean + }) => ( + <div data-testid="update-modal"> + <div data-testid="update-plugin-id">{pluginId}</div> + <div data-testid="update-downgrade-warning">{String(isShowDowngradeWarningModal)}</div> + <button data-testid="update-modal-save" onClick={onSave}>Save</button> + <button data-testid="update-modal-cancel" onClick={onCancel}>Cancel</button> + </div> + ), +})) + +const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({ + id: 'test-id', + created_at: '2024-01-01', + updated_at: '2024-01-02', + name: 'Test Plugin', + plugin_id: 'test-plugin', + plugin_unique_identifier: 'test-uid', + declaration: { + author: 'test-author', + name: 'test-plugin-name', + category: 'tool', + label: { en_US: 'Test Plugin Label' }, + description: { en_US: 'Test description' }, + icon: 'icon.png', + verified: true, + } as unknown as PluginDetail['declaration'], + installation_id: 'install-1', + tenant_id: 'tenant-1', + endpoints_setups: 0, + endpoints_active: 0, + version: '1.0.0', + latest_version: '2.0.0', + latest_unique_identifier: 'new-uid', + source: PluginSource.marketplace, + meta: undefined, + status: 'active', + deprecated_reason: '', + alternative_plugin_id: '', + ...overrides, +}) + +const createModalStatesMock = (overrides: Partial<ModalStates> = {}): ModalStates => ({ + isShowUpdateModal: false, + showUpdateModal: vi.fn<() => void>(), + hideUpdateModal: vi.fn<() => void>(), + isShowPluginInfo: false, + showPluginInfo: vi.fn<() => void>(), + hidePluginInfo: vi.fn<() => void>(), + isShowDeleteConfirm: false, + showDeleteConfirm: vi.fn<() => void>(), + hideDeleteConfirm: vi.fn<() => void>(), + deleting: false, + showDeleting: vi.fn<() => void>(), + hideDeleting: vi.fn<() => void>(), + ...overrides, +}) + +const createTargetVersion = (overrides: Partial<VersionTarget> = {}): VersionTarget => ({ + version: '2.0.0', + unique_identifier: 'new-uid', + ...overrides, +}) + +describe('HeaderModals', () => { + let mockOnUpdatedFromMarketplace: () => void + let mockOnDelete: () => void + + beforeEach(() => { + vi.clearAllMocks() + mockOnUpdatedFromMarketplace = vi.fn<() => void>() + mockOnDelete = vi.fn<() => void>() + }) + + describe('Plugin Info Modal', () => { + it('should not render plugin info modal when isShowPluginInfo is false', () => { + const modalStates = createModalStatesMock({ isShowPluginInfo: false }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.queryByTestId('plugin-info')).not.toBeInTheDocument() + }) + + it('should render plugin info modal when isShowPluginInfo is true', () => { + const modalStates = createModalStatesMock({ isShowPluginInfo: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('plugin-info')).toBeInTheDocument() + }) + + it('should pass GitHub repo to plugin info for GitHub source', () => { + const modalStates = createModalStatesMock({ isShowPluginInfo: true }) + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'test-pkg' }, + }) + render( + <HeaderModals + detail={detail} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('owner/repo') + }) + + it('should pass empty string for repo for non-GitHub source', () => { + const modalStates = createModalStatesMock({ isShowPluginInfo: true }) + render( + <HeaderModals + detail={createPluginDetail({ source: PluginSource.marketplace })} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('') + }) + + it('should call hidePluginInfo when close button is clicked', () => { + const modalStates = createModalStatesMock({ isShowPluginInfo: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + fireEvent.click(screen.getByTestId('plugin-info-close')) + + expect(modalStates.hidePluginInfo).toHaveBeenCalled() + }) + }) + + describe('Delete Confirm Modal', () => { + it('should not render delete confirm when isShowDeleteConfirm is false', () => { + const modalStates = createModalStatesMock({ isShowDeleteConfirm: false }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument() + }) + + it('should render delete confirm when isShowDeleteConfirm is true', () => { + const modalStates = createModalStatesMock({ isShowDeleteConfirm: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('delete-confirm')).toBeInTheDocument() + }) + + it('should show correct delete title', () => { + const modalStates = createModalStatesMock({ isShowDeleteConfirm: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('delete-title')).toHaveTextContent('action.delete') + }) + + it('should call hideDeleteConfirm when cancel is clicked', () => { + const modalStates = createModalStatesMock({ isShowDeleteConfirm: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + fireEvent.click(screen.getByTestId('confirm-cancel')) + + expect(modalStates.hideDeleteConfirm).toHaveBeenCalled() + }) + + it('should call onDelete when confirm is clicked', () => { + const modalStates = createModalStatesMock({ isShowDeleteConfirm: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + fireEvent.click(screen.getByTestId('confirm-ok')) + + expect(mockOnDelete).toHaveBeenCalled() + }) + + it('should disable confirm button when deleting', () => { + const modalStates = createModalStatesMock({ isShowDeleteConfirm: true, deleting: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('confirm-ok')).toBeDisabled() + }) + }) + + describe('Update Modal', () => { + it('should not render update modal when isShowUpdateModal is false', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: false }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.queryByTestId('update-modal')).not.toBeInTheDocument() + }) + + it('should render update modal when isShowUpdateModal is true', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('update-modal')).toBeInTheDocument() + }) + + it('should pass plugin id to update modal', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: true }) + render( + <HeaderModals + detail={createPluginDetail({ plugin_id: 'my-plugin-id' })} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('update-plugin-id')).toHaveTextContent('my-plugin-id') + }) + + it('should call onUpdatedFromMarketplace when save is clicked', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + fireEvent.click(screen.getByTestId('update-modal-save')) + + expect(mockOnUpdatedFromMarketplace).toHaveBeenCalled() + }) + + it('should call hideUpdateModal when cancel is clicked', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + fireEvent.click(screen.getByTestId('update-modal-cancel')) + + expect(modalStates.hideUpdateModal).toHaveBeenCalled() + }) + + it('should show downgrade warning when isDowngrade and isAutoUpgradeEnabled are true', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={true} + isAutoUpgradeEnabled={true} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('true') + }) + + it('should not show downgrade warning when only isDowngrade is true', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={true} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false') + }) + + it('should not show downgrade warning when only isAutoUpgradeEnabled is true', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={true} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false') + }) + }) + + describe('Multiple Modals', () => { + it('should render multiple modals when multiple are open', () => { + const modalStates = createModalStatesMock({ + isShowPluginInfo: true, + isShowDeleteConfirm: true, + isShowUpdateModal: true, + }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('plugin-info')).toBeInTheDocument() + expect(screen.getByTestId('delete-confirm')).toBeInTheDocument() + expect(screen.getByTestId('update-modal')).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle undefined target version values', () => { + const modalStates = createModalStatesMock({ isShowUpdateModal: true }) + render( + <HeaderModals + detail={createPluginDetail()} + modalStates={modalStates} + targetVersion={{ version: undefined, unique_identifier: undefined }} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('update-modal')).toBeInTheDocument() + }) + + it('should handle empty meta for GitHub source', () => { + const modalStates = createModalStatesMock({ isShowPluginInfo: true }) + const detail = createPluginDetail({ + source: PluginSource.github, + meta: undefined, + }) + render( + <HeaderModals + detail={detail} + modalStates={modalStates} + targetVersion={createTargetVersion()} + isDowngrade={false} + isAutoUpgradeEnabled={false} + onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace} + onDelete={mockOnDelete} + />, + ) + + expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('') + expect(screen.getByTestId('plugin-info-package')).toHaveTextContent('') + }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/components/header-modals.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header/components/header-modals.tsx new file mode 100644 index 0000000000..62840b64e3 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/components/header-modals.tsx @@ -0,0 +1,107 @@ +'use client' + +import type { FC } from 'react' +import type { PluginDetail } from '../../../types' +import type { ModalStates, VersionTarget } from '../hooks' +import { useTranslation } from 'react-i18next' +import Confirm from '@/app/components/base/confirm' +import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info' +import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place' +import { useGetLanguage } from '@/context/i18n' +import { PluginSource } from '../../../types' + +const i18nPrefix = 'action' + +type HeaderModalsProps = { + detail: PluginDetail + modalStates: ModalStates + targetVersion: VersionTarget + isDowngrade: boolean + isAutoUpgradeEnabled: boolean + onUpdatedFromMarketplace: () => void + onDelete: () => void +} + +const HeaderModals: FC<HeaderModalsProps> = ({ + detail, + modalStates, + targetVersion, + isDowngrade, + isAutoUpgradeEnabled, + onUpdatedFromMarketplace, + onDelete, +}) => { + const { t } = useTranslation() + const locale = useGetLanguage() + + const { source, version, meta } = detail + const { label } = detail.declaration || detail + const isFromGitHub = source === PluginSource.github + + const { + isShowUpdateModal, + hideUpdateModal, + isShowPluginInfo, + hidePluginInfo, + isShowDeleteConfirm, + hideDeleteConfirm, + deleting, + } = modalStates + + return ( + <> + {/* Plugin Info Modal */} + {isShowPluginInfo && ( + <PluginInfo + repository={isFromGitHub ? meta?.repo : ''} + release={version} + packageName={meta?.package || ''} + onHide={hidePluginInfo} + /> + )} + + {/* Delete Confirm Modal */} + {isShowDeleteConfirm && ( + <Confirm + isShow + title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })} + content={( + <div> + {t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })} + <span className="system-md-semibold">{label[locale]}</span> + {t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })} + <br /> + </div> + )} + onCancel={hideDeleteConfirm} + onConfirm={onDelete} + isLoading={deleting} + isDisabled={deleting} + /> + )} + + {/* Update from Marketplace Modal */} + {isShowUpdateModal && ( + <UpdateFromMarketplace + pluginId={detail.plugin_id} + payload={{ + category: detail.declaration?.category ?? '', + originalPackageInfo: { + id: detail.plugin_unique_identifier, + payload: detail.declaration ?? undefined, + }, + targetPackageInfo: { + id: targetVersion.unique_identifier || '', + version: targetVersion.version || '', + }, + }} + onCancel={hideUpdateModal} + onSave={onUpdatedFromMarketplace} + isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled} + /> + )} + </> + ) +} + +export default HeaderModals diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/components/index.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/components/index.ts new file mode 100644 index 0000000000..6e0d9d5042 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/components/index.ts @@ -0,0 +1,2 @@ +export { default as HeaderModals } from './header-modals' +export { default as PluginSourceBadge } from './plugin-source-badge' diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.spec.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.spec.tsx new file mode 100644 index 0000000000..e2fa1f6140 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.spec.tsx @@ -0,0 +1,200 @@ +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PluginSource } from '../../../types' +import PluginSourceBadge from './plugin-source-badge' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => ( + <div data-testid="tooltip" data-content={popupContent}> + {children} + </div> + ), +})) + +describe('PluginSourceBadge', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Source Icon Rendering', () => { + it('should render marketplace source badge', () => { + render(<PluginSourceBadge source={PluginSource.marketplace} />) + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toBeInTheDocument() + expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.marketplace') + }) + + it('should render github source badge', () => { + render(<PluginSourceBadge source={PluginSource.github} />) + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toBeInTheDocument() + expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.github') + }) + + it('should render local source badge', () => { + render(<PluginSourceBadge source={PluginSource.local} />) + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toBeInTheDocument() + expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.local') + }) + + it('should render debugging source badge', () => { + render(<PluginSourceBadge source={PluginSource.debugging} />) + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toBeInTheDocument() + expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.debugging') + }) + }) + + describe('Separator Rendering', () => { + it('should render separator dot before marketplace badge', () => { + const { container } = render(<PluginSourceBadge source={PluginSource.marketplace} />) + + const separator = container.querySelector('.text-text-quaternary') + expect(separator).toBeInTheDocument() + expect(separator?.textContent).toBe('ยท') + }) + + it('should render separator dot before github badge', () => { + const { container } = render(<PluginSourceBadge source={PluginSource.github} />) + + const separator = container.querySelector('.text-text-quaternary') + expect(separator).toBeInTheDocument() + expect(separator?.textContent).toBe('ยท') + }) + + it('should render separator dot before local badge', () => { + const { container } = render(<PluginSourceBadge source={PluginSource.local} />) + + const separator = container.querySelector('.text-text-quaternary') + expect(separator).toBeInTheDocument() + }) + + it('should render separator dot before debugging badge', () => { + const { container } = render(<PluginSourceBadge source={PluginSource.debugging} />) + + const separator = container.querySelector('.text-text-quaternary') + expect(separator).toBeInTheDocument() + }) + }) + + describe('Tooltip Content', () => { + it('should show marketplace tooltip', () => { + render(<PluginSourceBadge source={PluginSource.marketplace} />) + + expect(screen.getByTestId('tooltip')).toHaveAttribute( + 'data-content', + 'detailPanel.categoryTip.marketplace', + ) + }) + + it('should show github tooltip', () => { + render(<PluginSourceBadge source={PluginSource.github} />) + + expect(screen.getByTestId('tooltip')).toHaveAttribute( + 'data-content', + 'detailPanel.categoryTip.github', + ) + }) + + it('should show local tooltip', () => { + render(<PluginSourceBadge source={PluginSource.local} />) + + expect(screen.getByTestId('tooltip')).toHaveAttribute( + 'data-content', + 'detailPanel.categoryTip.local', + ) + }) + + it('should show debugging tooltip', () => { + render(<PluginSourceBadge source={PluginSource.debugging} />) + + expect(screen.getByTestId('tooltip')).toHaveAttribute( + 'data-content', + 'detailPanel.categoryTip.debugging', + ) + }) + }) + + describe('Icon Element Structure', () => { + it('should render icon inside tooltip for marketplace', () => { + render(<PluginSourceBadge source={PluginSource.marketplace} />) + + const tooltip = screen.getByTestId('tooltip') + const iconWrapper = tooltip.querySelector('div') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render icon inside tooltip for github', () => { + render(<PluginSourceBadge source={PluginSource.github} />) + + const tooltip = screen.getByTestId('tooltip') + const iconWrapper = tooltip.querySelector('div') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render icon inside tooltip for local', () => { + render(<PluginSourceBadge source={PluginSource.local} />) + + const tooltip = screen.getByTestId('tooltip') + const iconWrapper = tooltip.querySelector('div') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render icon inside tooltip for debugging', () => { + render(<PluginSourceBadge source={PluginSource.debugging} />) + + const tooltip = screen.getByTestId('tooltip') + const iconWrapper = tooltip.querySelector('div') + expect(iconWrapper).toBeInTheDocument() + }) + }) + + describe('Lookup Table Coverage', () => { + it('should handle all PluginSource enum values', () => { + const allSources = Object.values(PluginSource) + + allSources.forEach((source) => { + const { container } = render(<PluginSourceBadge source={source} />) + // Should render either tooltip or nothing + expect(container).toBeTruthy() + }) + }) + }) + + describe('Invalid Source Handling', () => { + it('should return null for unknown source type', () => { + // Use type assertion to test invalid source value + const invalidSource = 'unknown_source' as PluginSource + const { container } = render(<PluginSourceBadge source={invalidSource} />) + + // Should render nothing (empty container) + expect(container.firstChild).toBeNull() + }) + + it('should not render separator for invalid source', () => { + const invalidSource = 'invalid' as PluginSource + const { container } = render(<PluginSourceBadge source={invalidSource} />) + + const separator = container.querySelector('.text-text-quaternary') + expect(separator).not.toBeInTheDocument() + }) + + it('should not render tooltip for invalid source', () => { + const invalidSource = '' as PluginSource + render(<PluginSourceBadge source={invalidSource} />) + + expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.tsx new file mode 100644 index 0000000000..e886cec4da --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.tsx @@ -0,0 +1,59 @@ +'use client' + +import type { FC, ReactNode } from 'react' +import { + RiBugLine, + RiHardDrive3Line, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { Github } from '@/app/components/base/icons/src/public/common' +import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin' +import Tooltip from '@/app/components/base/tooltip' +import { PluginSource } from '../../../types' + +type SourceConfig = { + icon: ReactNode + tipKey: string +} + +type PluginSourceBadgeProps = { + source: PluginSource +} + +const SOURCE_CONFIG_MAP: Record<PluginSource, SourceConfig | null> = { + [PluginSource.marketplace]: { + icon: <BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" />, + tipKey: 'detailPanel.categoryTip.marketplace', + }, + [PluginSource.github]: { + icon: <Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" />, + tipKey: 'detailPanel.categoryTip.github', + }, + [PluginSource.local]: { + icon: <RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" />, + tipKey: 'detailPanel.categoryTip.local', + }, + [PluginSource.debugging]: { + icon: <RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" />, + tipKey: 'detailPanel.categoryTip.debugging', + }, +} + +const PluginSourceBadge: FC<PluginSourceBadgeProps> = ({ source }) => { + const { t } = useTranslation() + + const config = SOURCE_CONFIG_MAP[source] + if (!config) + return null + + return ( + <> + <div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">ยท</div> + <Tooltip popupContent={t(config.tipKey as never, { ns: 'plugin' })}> + <div>{config.icon}</div> + </Tooltip> + </> + ) +} + +export default PluginSourceBadge diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/index.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/index.ts new file mode 100644 index 0000000000..47b4d9b9a5 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/index.ts @@ -0,0 +1,3 @@ +export { useDetailHeaderState } from './use-detail-header-state' +export type { ModalStates, UseDetailHeaderStateReturn, VersionPickerState, VersionTarget } from './use-detail-header-state' +export { usePluginOperations } from './use-plugin-operations' diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-detail-header-state.spec.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-detail-header-state.spec.ts new file mode 100644 index 0000000000..2e14fed60a --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-detail-header-state.spec.ts @@ -0,0 +1,409 @@ +import type { PluginDetail } from '../../../types' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PluginSource } from '../../../types' +import { useDetailHeaderState } from './use-detail-header-state' + +let mockEnableMarketplace = true +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => + selector({ systemFeatures: { enable_marketplace: mockEnableMarketplace } }), +})) + +let mockAutoUpgradeInfo: { + strategy_setting: string + upgrade_mode: string + include_plugins: string[] + exclude_plugins: string[] + upgrade_time_of_day: number +} | null = null + +vi.mock('../../../plugin-page/use-reference-setting', () => ({ + default: () => ({ + referenceSetting: mockAutoUpgradeInfo ? { auto_upgrade: mockAutoUpgradeInfo } : null, + }), +})) + +vi.mock('../../../reference-setting-modal/auto-update-setting/types', () => ({ + AUTO_UPDATE_MODE: { + update_all: 'update_all', + partial: 'partial', + exclude: 'exclude', + }, +})) + +const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({ + id: 'test-id', + created_at: '2024-01-01', + updated_at: '2024-01-02', + name: 'Test Plugin', + plugin_id: 'test-plugin', + plugin_unique_identifier: 'test-uid', + declaration: { + author: 'test-author', + name: 'test-plugin-name', + category: 'tool', + label: { en_US: 'Test Plugin Label' }, + description: { en_US: 'Test description' }, + icon: 'icon.png', + verified: true, + } as unknown as PluginDetail['declaration'], + installation_id: 'install-1', + tenant_id: 'tenant-1', + endpoints_setups: 0, + endpoints_active: 0, + version: '1.0.0', + latest_version: '1.0.0', + latest_unique_identifier: 'test-uid', + source: PluginSource.marketplace, + meta: undefined, + status: 'active', + deprecated_reason: '', + alternative_plugin_id: '', + ...overrides, +}) + +describe('useDetailHeaderState', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAutoUpgradeInfo = null + mockEnableMarketplace = true + }) + + describe('Source Type Detection', () => { + it('should detect marketplace source', () => { + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isFromMarketplace).toBe(true) + expect(result.current.isFromGitHub).toBe(false) + }) + + it('should detect GitHub source', () => { + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' }, + }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isFromGitHub).toBe(true) + expect(result.current.isFromMarketplace).toBe(false) + }) + + it('should detect local source', () => { + const detail = createPluginDetail({ source: PluginSource.local }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isFromGitHub).toBe(false) + expect(result.current.isFromMarketplace).toBe(false) + }) + }) + + describe('Version State', () => { + it('should detect new version available for marketplace plugin', () => { + const detail = createPluginDetail({ + version: '1.0.0', + latest_version: '2.0.0', + source: PluginSource.marketplace, + }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.hasNewVersion).toBe(true) + }) + + it('should not detect new version when versions match', () => { + const detail = createPluginDetail({ + version: '1.0.0', + latest_version: '1.0.0', + source: PluginSource.marketplace, + }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.hasNewVersion).toBe(false) + }) + + it('should not detect new version for non-marketplace source', () => { + const detail = createPluginDetail({ + version: '1.0.0', + latest_version: '2.0.0', + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' }, + }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.hasNewVersion).toBe(false) + }) + + it('should not detect new version when latest_version is empty', () => { + const detail = createPluginDetail({ + version: '1.0.0', + latest_version: '', + source: PluginSource.marketplace, + }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.hasNewVersion).toBe(false) + }) + }) + + describe('Version Picker State', () => { + it('should initialize version picker as hidden', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.versionPicker.isShow).toBe(false) + }) + + it('should toggle version picker visibility', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + act(() => { + result.current.versionPicker.setIsShow(true) + }) + expect(result.current.versionPicker.isShow).toBe(true) + + act(() => { + result.current.versionPicker.setIsShow(false) + }) + expect(result.current.versionPicker.isShow).toBe(false) + }) + + it('should update target version', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + act(() => { + result.current.versionPicker.setTargetVersion({ + version: '2.0.0', + unique_identifier: 'new-uid', + }) + }) + + expect(result.current.versionPicker.targetVersion.version).toBe('2.0.0') + expect(result.current.versionPicker.targetVersion.unique_identifier).toBe('new-uid') + }) + + it('should set isDowngrade when provided in target version', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + act(() => { + result.current.versionPicker.setTargetVersion({ + version: '0.5.0', + unique_identifier: 'old-uid', + isDowngrade: true, + }) + }) + + expect(result.current.versionPicker.isDowngrade).toBe(true) + }) + }) + + describe('Modal States', () => { + it('should initialize all modals as hidden', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.modalStates.isShowUpdateModal).toBe(false) + expect(result.current.modalStates.isShowPluginInfo).toBe(false) + expect(result.current.modalStates.isShowDeleteConfirm).toBe(false) + expect(result.current.modalStates.deleting).toBe(false) + }) + + it('should toggle update modal', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + act(() => { + result.current.modalStates.showUpdateModal() + }) + expect(result.current.modalStates.isShowUpdateModal).toBe(true) + + act(() => { + result.current.modalStates.hideUpdateModal() + }) + expect(result.current.modalStates.isShowUpdateModal).toBe(false) + }) + + it('should toggle plugin info modal', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + act(() => { + result.current.modalStates.showPluginInfo() + }) + expect(result.current.modalStates.isShowPluginInfo).toBe(true) + + act(() => { + result.current.modalStates.hidePluginInfo() + }) + expect(result.current.modalStates.isShowPluginInfo).toBe(false) + }) + + it('should toggle delete confirm modal', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + act(() => { + result.current.modalStates.showDeleteConfirm() + }) + expect(result.current.modalStates.isShowDeleteConfirm).toBe(true) + + act(() => { + result.current.modalStates.hideDeleteConfirm() + }) + expect(result.current.modalStates.isShowDeleteConfirm).toBe(false) + }) + + it('should toggle deleting state', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => useDetailHeaderState(detail)) + + act(() => { + result.current.modalStates.showDeleting() + }) + expect(result.current.modalStates.deleting).toBe(true) + + act(() => { + result.current.modalStates.hideDeleting() + }) + expect(result.current.modalStates.deleting).toBe(false) + }) + }) + + describe('Auto Upgrade Detection', () => { + it('should disable auto upgrade when marketplace is disabled', () => { + mockEnableMarketplace = false + mockAutoUpgradeInfo = { + strategy_setting: 'enabled', + upgrade_mode: 'update_all', + include_plugins: [], + exclude_plugins: [], + upgrade_time_of_day: 36000, + } + + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(false) + }) + + it('should disable auto upgrade when strategy is disabled', () => { + mockAutoUpgradeInfo = { + strategy_setting: 'disabled', + upgrade_mode: 'update_all', + include_plugins: [], + exclude_plugins: [], + upgrade_time_of_day: 36000, + } + + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(false) + }) + + it('should enable auto upgrade for update_all mode', () => { + mockAutoUpgradeInfo = { + strategy_setting: 'enabled', + upgrade_mode: 'update_all', + include_plugins: [], + exclude_plugins: [], + upgrade_time_of_day: 36000, + } + + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(true) + }) + + it('should enable auto upgrade for partial mode when plugin is included', () => { + mockAutoUpgradeInfo = { + strategy_setting: 'enabled', + upgrade_mode: 'partial', + include_plugins: ['test-plugin'], + exclude_plugins: [], + upgrade_time_of_day: 36000, + } + + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(true) + }) + + it('should disable auto upgrade for partial mode when plugin is not included', () => { + mockAutoUpgradeInfo = { + strategy_setting: 'enabled', + upgrade_mode: 'partial', + include_plugins: ['other-plugin'], + exclude_plugins: [], + upgrade_time_of_day: 36000, + } + + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(false) + }) + + it('should enable auto upgrade for exclude mode when plugin is not excluded', () => { + mockAutoUpgradeInfo = { + strategy_setting: 'enabled', + upgrade_mode: 'exclude', + include_plugins: [], + exclude_plugins: ['other-plugin'], + upgrade_time_of_day: 36000, + } + + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(true) + }) + + it('should disable auto upgrade for exclude mode when plugin is excluded', () => { + mockAutoUpgradeInfo = { + strategy_setting: 'enabled', + upgrade_mode: 'exclude', + include_plugins: [], + exclude_plugins: ['test-plugin'], + upgrade_time_of_day: 36000, + } + + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(false) + }) + + it('should disable auto upgrade for non-marketplace source', () => { + mockAutoUpgradeInfo = { + strategy_setting: 'enabled', + upgrade_mode: 'update_all', + include_plugins: [], + exclude_plugins: [], + upgrade_time_of_day: 36000, + } + + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' }, + }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(false) + }) + + it('should disable auto upgrade when no auto upgrade info', () => { + mockAutoUpgradeInfo = null + + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => useDetailHeaderState(detail)) + + expect(result.current.isAutoUpgradeEnabled).toBe(false) + }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-detail-header-state.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-detail-header-state.ts new file mode 100644 index 0000000000..763ed9c992 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-detail-header-state.ts @@ -0,0 +1,132 @@ +'use client' + +import type { PluginDetail } from '../../../types' +import { useBoolean } from 'ahooks' +import { useCallback, useMemo, useState } from 'react' +import { useGlobalPublicStore } from '@/context/global-public-context' +import useReferenceSetting from '../../../plugin-page/use-reference-setting' +import { AUTO_UPDATE_MODE } from '../../../reference-setting-modal/auto-update-setting/types' +import { PluginSource } from '../../../types' + +export type VersionTarget = { + version: string | undefined + unique_identifier: string | undefined + isDowngrade?: boolean +} + +export type ModalStates = { + isShowUpdateModal: boolean + showUpdateModal: () => void + hideUpdateModal: () => void + isShowPluginInfo: boolean + showPluginInfo: () => void + hidePluginInfo: () => void + isShowDeleteConfirm: boolean + showDeleteConfirm: () => void + hideDeleteConfirm: () => void + deleting: boolean + showDeleting: () => void + hideDeleting: () => void +} + +export type VersionPickerState = { + isShow: boolean + setIsShow: (show: boolean) => void + targetVersion: VersionTarget + setTargetVersion: (version: VersionTarget) => void + isDowngrade: boolean + setIsDowngrade: (downgrade: boolean) => void +} + +export type UseDetailHeaderStateReturn = { + modalStates: ModalStates + versionPicker: VersionPickerState + hasNewVersion: boolean + isAutoUpgradeEnabled: boolean + isFromGitHub: boolean + isFromMarketplace: boolean +} + +export const useDetailHeaderState = (detail: PluginDetail): UseDetailHeaderStateReturn => { + const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures) + const { referenceSetting } = useReferenceSetting() + + const { + source, + version, + latest_version, + latest_unique_identifier, + plugin_id, + } = detail + + const isFromGitHub = source === PluginSource.github + const isFromMarketplace = source === PluginSource.marketplace + const [isShow, setIsShow] = useState(false) + const [targetVersion, setTargetVersion] = useState<VersionTarget>({ + version: latest_version, + unique_identifier: latest_unique_identifier, + }) + const [isDowngrade, setIsDowngrade] = useState(false) + + const [isShowUpdateModal, { setTrue: showUpdateModal, setFalse: hideUpdateModal }] = useBoolean(false) + const [isShowPluginInfo, { setTrue: showPluginInfo, setFalse: hidePluginInfo }] = useBoolean(false) + const [isShowDeleteConfirm, { setTrue: showDeleteConfirm, setFalse: hideDeleteConfirm }] = useBoolean(false) + const [deleting, { setTrue: showDeleting, setFalse: hideDeleting }] = useBoolean(false) + + const hasNewVersion = useMemo(() => { + if (isFromMarketplace) + return !!latest_version && latest_version !== version + return false + }, [isFromMarketplace, latest_version, version]) + + const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {} + + const isAutoUpgradeEnabled = useMemo(() => { + if (!enable_marketplace || !autoUpgradeInfo || !isFromMarketplace) + return false + if (autoUpgradeInfo.strategy_setting === 'disabled') + return false + if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all) + return true + if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id)) + return true + if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id)) + return true + return false + }, [autoUpgradeInfo, plugin_id, isFromMarketplace, enable_marketplace]) + + const handleSetTargetVersion = useCallback((version: VersionTarget) => { + setTargetVersion(version) + if (version.isDowngrade !== undefined) + setIsDowngrade(version.isDowngrade) + }, []) + + return { + modalStates: { + isShowUpdateModal, + showUpdateModal, + hideUpdateModal, + isShowPluginInfo, + showPluginInfo, + hidePluginInfo, + isShowDeleteConfirm, + showDeleteConfirm, + hideDeleteConfirm, + deleting, + showDeleting, + hideDeleting, + }, + versionPicker: { + isShow, + setIsShow, + targetVersion, + setTargetVersion: handleSetTargetVersion, + isDowngrade, + setIsDowngrade, + }, + hasNewVersion, + isAutoUpgradeEnabled, + isFromGitHub, + isFromMarketplace, + } +} diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.spec.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.spec.ts new file mode 100644 index 0000000000..683c4080ea --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.spec.ts @@ -0,0 +1,549 @@ +import type { PluginDetail } from '../../../types' +import type { ModalStates, VersionTarget } from './use-detail-header-state' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import * as amplitude from '@/app/components/base/amplitude' +import Toast from '@/app/components/base/toast' +import { PluginSource } from '../../../types' +import { usePluginOperations } from './use-plugin-operations' + +type VersionPickerMock = { + setTargetVersion: (version: VersionTarget) => void + setIsDowngrade: (downgrade: boolean) => void +} + +const { + mockSetShowUpdatePluginModal, + mockRefreshModelProviders, + mockInvalidateAllToolProviders, + mockUninstallPlugin, + mockFetchReleases, + mockCheckForUpdates, +} = vi.hoisted(() => { + return { + mockSetShowUpdatePluginModal: vi.fn(), + mockRefreshModelProviders: vi.fn(), + mockInvalidateAllToolProviders: vi.fn(), + mockUninstallPlugin: vi.fn(() => Promise.resolve({ success: true })), + mockFetchReleases: vi.fn(() => Promise.resolve([{ tag_name: 'v2.0.0' }])), + mockCheckForUpdates: vi.fn(() => ({ needUpdate: true, toastProps: { type: 'success', message: 'Update available' } })), + } +}) + +vi.mock('@/context/modal-context', () => ({ + useModalContext: () => ({ + setShowUpdatePluginModal: mockSetShowUpdatePluginModal, + }), +})) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + refreshModelProviders: mockRefreshModelProviders, + }), +})) + +vi.mock('@/service/plugins', () => ({ + uninstallPlugin: mockUninstallPlugin, +})) + +vi.mock('@/service/use-tools', () => ({ + useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders, +})) + +vi.mock('../../../install-plugin/hooks', () => ({ + useGitHubReleases: () => ({ + checkForUpdates: mockCheckForUpdates, + fetchReleases: mockFetchReleases, + }), +})) + +const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({ + id: 'test-id', + created_at: '2024-01-01', + updated_at: '2024-01-02', + name: 'Test Plugin', + plugin_id: 'test-plugin', + plugin_unique_identifier: 'test-uid', + declaration: { + author: 'test-author', + name: 'test-plugin-name', + category: 'tool', + label: { en_US: 'Test Plugin Label' }, + description: { en_US: 'Test description' }, + icon: 'icon.png', + verified: true, + } as unknown as PluginDetail['declaration'], + installation_id: 'install-1', + tenant_id: 'tenant-1', + endpoints_setups: 0, + endpoints_active: 0, + version: '1.0.0', + latest_version: '2.0.0', + latest_unique_identifier: 'new-uid', + source: PluginSource.marketplace, + meta: undefined, + status: 'active', + deprecated_reason: '', + alternative_plugin_id: '', + ...overrides, +}) + +const createModalStatesMock = (): ModalStates => ({ + isShowUpdateModal: false, + showUpdateModal: vi.fn(), + hideUpdateModal: vi.fn(), + isShowPluginInfo: false, + showPluginInfo: vi.fn(), + hidePluginInfo: vi.fn(), + isShowDeleteConfirm: false, + showDeleteConfirm: vi.fn(), + hideDeleteConfirm: vi.fn(), + deleting: false, + showDeleting: vi.fn(), + hideDeleting: vi.fn(), +}) + +const createVersionPickerMock = (): VersionPickerMock => ({ + setTargetVersion: vi.fn<(version: VersionTarget) => void>(), + setIsDowngrade: vi.fn<(downgrade: boolean) => void>(), +}) + +describe('usePluginOperations', () => { + let modalStates: ModalStates + let versionPicker: VersionPickerMock + let mockOnUpdate: (isDelete?: boolean) => void + + beforeEach(() => { + vi.clearAllMocks() + modalStates = createModalStatesMock() + versionPicker = createVersionPickerMock() + mockOnUpdate = vi.fn() + vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) + vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {}) + }) + + describe('Marketplace Update Flow', () => { + it('should show update modal for marketplace plugin', async () => { + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleUpdate() + }) + + expect(modalStates.showUpdateModal).toHaveBeenCalled() + }) + + it('should set isDowngrade when downgrading', async () => { + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleUpdate(true) + }) + + expect(versionPicker.setIsDowngrade).toHaveBeenCalledWith(true) + expect(modalStates.showUpdateModal).toHaveBeenCalled() + }) + + it('should call onUpdate and hide modal on successful marketplace update', () => { + const detail = createPluginDetail({ source: PluginSource.marketplace }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + act(() => { + result.current.handleUpdatedFromMarketplace() + }) + + expect(mockOnUpdate).toHaveBeenCalled() + expect(modalStates.hideUpdateModal).toHaveBeenCalled() + }) + }) + + describe('GitHub Update Flow', () => { + it('should fetch releases from GitHub', async () => { + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' }, + }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: false, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleUpdate() + }) + + expect(mockFetchReleases).toHaveBeenCalledWith('owner', 'repo') + }) + + it('should check for updates after fetching releases', async () => { + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' }, + }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: false, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleUpdate() + }) + + expect(mockCheckForUpdates).toHaveBeenCalled() + expect(Toast.notify).toHaveBeenCalled() + }) + + it('should show update plugin modal when update is needed', async () => { + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' }, + }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: false, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleUpdate() + }) + + expect(mockSetShowUpdatePluginModal).toHaveBeenCalled() + }) + + it('should not show modal when no releases found', async () => { + mockFetchReleases.mockResolvedValueOnce([]) + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' }, + }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: false, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleUpdate() + }) + + expect(mockSetShowUpdatePluginModal).not.toHaveBeenCalled() + }) + + it('should not show modal when no update needed', async () => { + mockCheckForUpdates.mockReturnValueOnce({ + needUpdate: false, + toastProps: { type: 'info', message: 'Already up to date' }, + }) + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' }, + }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: false, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleUpdate() + }) + + expect(mockSetShowUpdatePluginModal).not.toHaveBeenCalled() + }) + + it('should use author and name as fallback for repo parsing', async () => { + const detail = createPluginDetail({ + source: PluginSource.github, + meta: { repo: '/', version: 'v1.0.0', package: 'pkg' }, + declaration: { + author: 'fallback-author', + name: 'fallback-name', + category: 'tool', + label: { en_US: 'Test' }, + description: { en_US: 'Test' }, + icon: 'icon.png', + verified: true, + } as unknown as PluginDetail['declaration'], + }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: false, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleUpdate() + }) + + expect(mockFetchReleases).toHaveBeenCalledWith('fallback-author', 'fallback-name') + }) + }) + + describe('Delete Flow', () => { + it('should call uninstallPlugin with correct id', async () => { + const detail = createPluginDetail({ id: 'plugin-to-delete' }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleDelete() + }) + + expect(mockUninstallPlugin).toHaveBeenCalledWith('plugin-to-delete') + }) + + it('should show and hide deleting state during delete', async () => { + const detail = createPluginDetail() + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleDelete() + }) + + expect(modalStates.showDeleting).toHaveBeenCalled() + expect(modalStates.hideDeleting).toHaveBeenCalled() + }) + + it('should call onUpdate with true after successful delete', async () => { + const detail = createPluginDetail() + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleDelete() + }) + + expect(mockOnUpdate).toHaveBeenCalledWith(true) + }) + + it('should hide delete confirm after successful delete', async () => { + const detail = createPluginDetail() + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleDelete() + }) + + expect(modalStates.hideDeleteConfirm).toHaveBeenCalled() + }) + + it('should refresh model providers when deleting model plugin', async () => { + const detail = createPluginDetail({ + declaration: { + author: 'test-author', + name: 'test-plugin-name', + category: 'model', + label: { en_US: 'Test' }, + description: { en_US: 'Test' }, + icon: 'icon.png', + verified: true, + } as unknown as PluginDetail['declaration'], + }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleDelete() + }) + + expect(mockRefreshModelProviders).toHaveBeenCalled() + }) + + it('should invalidate tool providers when deleting tool plugin', async () => { + const detail = createPluginDetail({ + declaration: { + author: 'test-author', + name: 'test-plugin-name', + category: 'tool', + label: { en_US: 'Test' }, + description: { en_US: 'Test' }, + icon: 'icon.png', + verified: true, + } as unknown as PluginDetail['declaration'], + }) + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleDelete() + }) + + expect(mockInvalidateAllToolProviders).toHaveBeenCalled() + }) + + it('should track plugin uninstalled event', async () => { + const detail = createPluginDetail() + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleDelete() + }) + + expect(amplitude.trackEvent).toHaveBeenCalledWith('plugin_uninstalled', expect.objectContaining({ + plugin_id: 'test-plugin', + plugin_name: 'test-plugin-name', + })) + }) + + it('should not call onUpdate when delete fails', async () => { + mockUninstallPlugin.mockResolvedValueOnce({ success: false }) + const detail = createPluginDetail() + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + onUpdate: mockOnUpdate, + }), + ) + + await act(async () => { + await result.current.handleDelete() + }) + + expect(mockOnUpdate).not.toHaveBeenCalled() + }) + }) + + describe('Optional onUpdate Callback', () => { + it('should not throw when onUpdate is not provided for marketplace update', () => { + const detail = createPluginDetail() + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + }), + ) + + expect(() => { + result.current.handleUpdatedFromMarketplace() + }).not.toThrow() + }) + + it('should not throw when onUpdate is not provided for delete', async () => { + const detail = createPluginDetail() + const { result } = renderHook(() => + usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace: true, + }), + ) + + await expect( + act(async () => { + await result.current.handleDelete() + }), + ).resolves.not.toThrow() + }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts new file mode 100644 index 0000000000..f3f0296a88 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts @@ -0,0 +1,143 @@ +'use client' + +import type { PluginDetail } from '../../../types' +import type { ModalStates, VersionTarget } from './use-detail-header-state' +import { useCallback } from 'react' +import { trackEvent } from '@/app/components/base/amplitude' +import Toast from '@/app/components/base/toast' +import { useModalContext } from '@/context/modal-context' +import { useProviderContext } from '@/context/provider-context' +import { uninstallPlugin } from '@/service/plugins' +import { useInvalidateAllToolProviders } from '@/service/use-tools' +import { useGitHubReleases } from '../../../install-plugin/hooks' +import { PluginCategoryEnum, PluginSource } from '../../../types' + +type UsePluginOperationsParams = { + detail: PluginDetail + modalStates: ModalStates + versionPicker: { + setTargetVersion: (version: VersionTarget) => void + setIsDowngrade: (downgrade: boolean) => void + } + isFromMarketplace: boolean + onUpdate?: (isDelete?: boolean) => void +} + +type UsePluginOperationsReturn = { + handleUpdate: (isDowngrade?: boolean) => Promise<void> + handleUpdatedFromMarketplace: () => void + handleDelete: () => Promise<void> +} + +export const usePluginOperations = ({ + detail, + modalStates, + versionPicker, + isFromMarketplace, + onUpdate, +}: UsePluginOperationsParams): UsePluginOperationsReturn => { + const { checkForUpdates, fetchReleases } = useGitHubReleases() + const { setShowUpdatePluginModal } = useModalContext() + const { refreshModelProviders } = useProviderContext() + const invalidateAllToolProviders = useInvalidateAllToolProviders() + + const { id, meta, plugin_id } = detail + const { author, category, name } = detail.declaration || detail + + const handleUpdate = useCallback(async (isDowngrade?: boolean) => { + if (isFromMarketplace) { + versionPicker.setIsDowngrade(!!isDowngrade) + modalStates.showUpdateModal() + return + } + + if (!meta?.repo || !meta?.version || !meta?.package) { + Toast.notify({ + type: 'error', + message: 'Missing plugin metadata for GitHub update', + }) + return + } + + const owner = meta.repo.split('/')[0] || author + const repo = meta.repo.split('/')[1] || name + const fetchedReleases = await fetchReleases(owner, repo) + if (fetchedReleases.length === 0) + return + + const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta.version) + Toast.notify(toastProps) + + if (needUpdate) { + setShowUpdatePluginModal({ + onSaveCallback: () => { + onUpdate?.() + }, + payload: { + type: PluginSource.github, + category, + github: { + originalPackageInfo: { + id: detail.plugin_unique_identifier, + repo: meta.repo, + version: meta.version, + package: meta.package, + releases: fetchedReleases, + }, + }, + }, + }) + } + }, [ + isFromMarketplace, + meta, + author, + name, + fetchReleases, + checkForUpdates, + setShowUpdatePluginModal, + detail, + onUpdate, + modalStates, + versionPicker, + ]) + + const handleUpdatedFromMarketplace = useCallback(() => { + onUpdate?.() + modalStates.hideUpdateModal() + }, [onUpdate, modalStates]) + + const handleDelete = useCallback(async () => { + modalStates.showDeleting() + const res = await uninstallPlugin(id) + modalStates.hideDeleting() + + if (res.success) { + modalStates.hideDeleteConfirm() + onUpdate?.(true) + + if (PluginCategoryEnum.model.includes(category)) + refreshModelProviders() + + if (PluginCategoryEnum.tool.includes(category)) + invalidateAllToolProviders() + + trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name }) + } + }, [ + id, + category, + plugin_id, + name, + modalStates, + onUpdate, + refreshModelProviders, + invalidateAllToolProviders, + ]) + + return { + handleUpdate, + handleUpdatedFromMarketplace, + handleDelete, + } +} diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/index.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header/index.tsx new file mode 100644 index 0000000000..8f265c5717 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/index.tsx @@ -0,0 +1,286 @@ +'use client' + +import type { PluginDetail } from '../../types' +import { + RiArrowLeftRightLine, + RiCloseLine, +} from '@remixicon/react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import ActionButton from '@/app/components/base/action-button' +import Badge from '@/app/components/base/badge' +import Button from '@/app/components/base/button' +import Tooltip from '@/app/components/base/tooltip' +import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth' +import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown' +import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker' +import { API_PREFIX } from '@/config' +import { useAppContext } from '@/context/app-context' +import { useGetLanguage, useLocale } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { useAllToolProviders } from '@/service/use-tools' +import { cn } from '@/utils/classnames' +import { getMarketplaceUrl } from '@/utils/var' +import { AutoUpdateLine } from '../../../base/icons/src/vender/system' +import Verified from '../../base/badges/verified' +import DeprecationNotice from '../../base/deprecation-notice' +import Icon from '../../card/base/card-icon' +import Description from '../../card/base/description' +import OrgInfo from '../../card/base/org-info' +import Title from '../../card/base/title' +import useReferenceSetting from '../../plugin-page/use-reference-setting' +import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../../reference-setting-modal/auto-update-setting/utils' +import { PluginCategoryEnum, PluginSource } from '../../types' +import { HeaderModals, PluginSourceBadge } from './components' +import { useDetailHeaderState, usePluginOperations } from './hooks' + +type Props = { + detail: PluginDetail + isReadmeView?: boolean + onHide?: () => void + onUpdate?: (isDelete?: boolean) => void +} + +const getIconSrc = (icon: string | undefined, iconDark: string | undefined, theme: string, tenantId: string): string => { + const iconFileName = theme === 'dark' && iconDark ? iconDark : icon + if (!iconFileName) + return '' + return iconFileName.startsWith('http') + ? iconFileName + : `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenantId}&filename=${iconFileName}` +} + +const getDetailUrl = ( + source: PluginSource, + meta: PluginDetail['meta'], + author: string, + name: string, + locale: string, + theme: string, +): string => { + if (source === PluginSource.github) { + const repo = meta?.repo + if (!repo) + return '' + return `https://github.com/${repo}` + } + if (source === PluginSource.marketplace) + return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: locale, theme }) + return '' +} + +const DetailHeader = ({ + detail, + isReadmeView = false, + onHide, + onUpdate, +}: Props) => { + const { t } = useTranslation() + const { userProfile: { timezone } } = useAppContext() + const { theme } = useTheme() + const locale = useGetLanguage() + const currentLocale = useLocale() + const { referenceSetting } = useReferenceSetting() + + const { + source, + tenant_id, + version, + latest_version, + latest_unique_identifier, + meta, + plugin_id, + status, + deprecated_reason, + alternative_plugin_id, + } = detail + + const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail + + const { + modalStates, + versionPicker, + hasNewVersion, + isAutoUpgradeEnabled, + isFromGitHub, + isFromMarketplace, + } = useDetailHeaderState(detail) + + const { + handleUpdate, + handleUpdatedFromMarketplace, + handleDelete, + } = usePluginOperations({ + detail, + modalStates, + versionPicker, + isFromMarketplace, + onUpdate, + }) + + const isTool = category === PluginCategoryEnum.tool + const providerBriefInfo = tool?.identity + const providerKey = `${plugin_id}/${providerBriefInfo?.name}` + const { data: collectionList = [] } = useAllToolProviders(isTool) + const provider = useMemo(() => { + return collectionList.find(collection => collection.name === providerKey) + }, [collectionList, providerKey]) + + const iconSrc = getIconSrc(icon, icon_dark, theme, tenant_id) + const detailUrl = getDetailUrl(source, meta, author, name, currentLocale, theme) + const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {} + + const handleVersionSelect = (state: { version: string, unique_identifier: string, isDowngrade?: boolean }) => { + versionPicker.setTargetVersion(state) + handleUpdate(state.isDowngrade) + } + + const handleTriggerLatestUpdate = () => { + if (isFromMarketplace) { + versionPicker.setTargetVersion({ + version: latest_version, + unique_identifier: latest_unique_identifier, + }) + } + handleUpdate() + } + + return ( + <div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}> + <div className="flex"> + {/* Plugin Icon */} + <div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}> + <Icon src={iconSrc} /> + </div> + + {/* Plugin Info */} + <div className="ml-3 w-0 grow"> + {/* Title Row */} + <div className="flex h-5 items-center"> + <Title title={label[locale]} /> + {verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />} + + {/* Version Picker */} + {!!version && ( + <PluginVersionPicker + disabled={!isFromMarketplace || isReadmeView} + isShow={versionPicker.isShow} + onShowChange={versionPicker.setIsShow} + pluginID={plugin_id} + currentVersion={version} + onSelect={handleVersionSelect} + trigger={( + <Badge + className={cn( + 'mx-1', + versionPicker.isShow && 'bg-state-base-hover', + (versionPicker.isShow || isFromMarketplace) && 'hover:bg-state-base-hover', + )} + uppercase={false} + text={( + <> + <div>{isFromGitHub ? (meta?.version ?? version ?? '') : version}</div> + {isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />} + </> + )} + hasRedCornerMark={hasNewVersion} + /> + )} + /> + )} + + {/* Auto Update Badge */} + {isAutoUpgradeEnabled && !isReadmeView && ( + <Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}> + <div> + <Badge className="mr-1 cursor-pointer px-1"> + <AutoUpdateLine className="size-3" /> + </Badge> + </div> + </Tooltip> + )} + + {/* Update Button */} + {(hasNewVersion || isFromGitHub) && ( + <Button + variant="secondary-accent" + size="small" + className="!h-5" + onClick={handleTriggerLatestUpdate} + > + {t('detailPanel.operation.update', { ns: 'plugin' })} + </Button> + )} + </div> + + {/* Org Info Row */} + <div className="mb-1 flex h-4 items-center justify-between"> + <div className="mt-0.5 flex items-center"> + <OrgInfo + packageNameClassName="w-auto" + orgName={author} + packageName={name?.includes('/') ? (name.split('/').pop() || '') : name} + /> + {!!source && <PluginSourceBadge source={source} />} + </div> + </div> + </div> + + {/* Action Buttons */} + {!isReadmeView && ( + <div className="flex gap-1"> + <OperationDropdown + source={source} + onInfo={modalStates.showPluginInfo} + onCheckVersion={handleUpdate} + onRemove={modalStates.showDeleteConfirm} + detailUrl={detailUrl} + /> + <ActionButton onClick={onHide}> + <RiCloseLine className="h-4 w-4" /> + </ActionButton> + </div> + )} + </div> + + {/* Deprecation Notice */} + {isFromMarketplace && ( + <DeprecationNotice + status={status} + deprecatedReason={deprecated_reason} + alternativePluginId={alternative_plugin_id} + alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })} + className="mt-3" + /> + )} + + {/* Description */} + {!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2} />} + + {/* Plugin Auth for Tools */} + {category === PluginCategoryEnum.tool && !isReadmeView && ( + <PluginAuth + pluginPayload={{ + provider: provider?.name || '', + category: AuthCategory.tool, + providerType: provider?.type || '', + detail, + }} + /> + )} + + {/* Modals */} + <HeaderModals + detail={detail} + modalStates={modalStates} + targetVersion={versionPicker.targetVersion} + isDowngrade={versionPicker.isDowngrade} + isAutoUpgradeEnabled={isAutoUpgradeEnabled} + onUpdatedFromMarketplace={handleUpdatedFromMarketplace} + onDelete={handleDelete} + /> + </div> + ) +} + +export default DetailHeader diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx index 9155fa15be..b7e4f01f58 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx @@ -2,15 +2,10 @@ import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' -// Import after mocks import { SupportedCreationMethods } from '@/app/components/plugins/types' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import { CommonCreateModal } from './common-modal' -// ============================================================================ -// Type Definitions -// ============================================================================ - type PluginDetail = { plugin_id: string provider: string @@ -33,10 +28,6 @@ type TriggerLogEntity = { level: 'info' | 'warn' | 'error' } -// ============================================================================ -// Mock Factory Functions -// ============================================================================ - function createMockPluginDetail(overrides: Partial<PluginDetail> = {}): PluginDetail { return { plugin_id: 'test-plugin-id', @@ -74,18 +65,12 @@ function createMockLogData(logs: TriggerLogEntity[] = []): { logs: TriggerLogEnt return { logs } } -// ============================================================================ -// Mock Setup -// ============================================================================ - -// Mock plugin store const mockPluginDetail = createMockPluginDetail() const mockUsePluginStore = vi.fn(() => mockPluginDetail) vi.mock('../../store', () => ({ usePluginStore: () => mockUsePluginStore(), })) -// Mock subscription list hook const mockRefetch = vi.fn() vi.mock('../use-subscription-list', () => ({ useSubscriptionList: () => ({ @@ -93,13 +78,11 @@ vi.mock('../use-subscription-list', () => ({ }), })) -// Mock service hooks const mockVerifyCredentials = vi.fn() const mockCreateBuilder = vi.fn() const mockBuildSubscription = vi.fn() const mockUpdateBuilder = vi.fn() -// Configurable pending states let mockIsVerifyingCredentials = false let mockIsBuilding = false const setMockPendingStates = (verifying: boolean, building: boolean) => { @@ -129,18 +112,15 @@ vi.mock('@/service/use-triggers', () => ({ }), })) -// Mock error parser const mockParsePluginErrorMessage = vi.fn().mockResolvedValue(null) vi.mock('@/utils/error-parser', () => ({ parsePluginErrorMessage: (...args: unknown[]) => mockParsePluginErrorMessage(...args), })) -// Mock URL validation vi.mock('@/utils/urlValidation', () => ({ isPrivateOrLocalAddress: vi.fn().mockReturnValue(false), })) -// Mock toast const mockToastNotify = vi.fn() vi.mock('@/app/components/base/toast', () => ({ default: { @@ -148,7 +128,6 @@ vi.mock('@/app/components/base/toast', () => ({ }, })) -// Mock Modal component vi.mock('@/app/components/base/modal/modal', () => ({ default: ({ children, @@ -179,7 +158,6 @@ vi.mock('@/app/components/base/modal/modal', () => ({ ), })) -// Configurable form mock values type MockFormValuesConfig = { values: Record<string, unknown> isCheckValidated: boolean @@ -190,7 +168,6 @@ let mockFormValuesConfig: MockFormValuesConfig = { } let mockGetFormReturnsNull = false -// Separate validation configs for different forms let mockSubscriptionFormValidated = true let mockAutoParamsFormValidated = true let mockManualPropsFormValidated = true @@ -207,7 +184,6 @@ const setMockFormValidation = (subscription: boolean, autoParams: boolean, manua mockManualPropsFormValidated = manualProps } -// Mock BaseForm component with ref support vi.mock('@/app/components/base/form/components/base', async () => { const React = await import('react') @@ -219,7 +195,6 @@ vi.mock('@/app/components/base/form/components/base', async () => { type MockBaseFormProps = { formSchemas: Array<{ name: string }>, onChange?: () => void } function MockBaseFormInner({ formSchemas, onChange }: MockBaseFormProps, ref: React.ForwardedRef<MockFormRef>) { - // Determine which form this is based on schema const isSubscriptionForm = formSchemas.some((s: { name: string }) => s.name === 'subscription_name') const isAutoParamsForm = formSchemas.some((s: { name: string }) => ['repo_name', 'branch', 'repo', 'text_field', 'dynamic_field', 'bool_field', 'text_input_field', 'unknown_field', 'count'].includes(s.name), @@ -265,12 +240,10 @@ vi.mock('@/app/components/base/form/components/base', async () => { } }) -// Mock EncryptedBottom component vi.mock('@/app/components/base/encrypted-bottom', () => ({ EncryptedBottom: () => <div data-testid="encrypted-bottom">Encrypted</div>, })) -// Mock LogViewer component vi.mock('../log-viewer', () => ({ default: ({ logs }: { logs: TriggerLogEntity[] }) => ( <div data-testid="log-viewer"> @@ -281,7 +254,6 @@ vi.mock('../log-viewer', () => ({ ), })) -// Mock debounce vi.mock('es-toolkit/compat', () => ({ debounce: (fn: (...args: unknown[]) => unknown) => { const debouncedFn = (...args: unknown[]) => fn(...args) @@ -290,10 +262,6 @@ vi.mock('es-toolkit/compat', () => ({ }, })) -// ============================================================================ -// Test Suites -// ============================================================================ - describe('CommonCreateModal', () => { const defaultProps = { onClose: vi.fn(), @@ -441,7 +409,8 @@ describe('CommonCreateModal', () => { }) it('should call onConfirm handler when confirm button is clicked', () => { - render(<CommonCreateModal {...defaultProps} />) + // Provide builder so the guard passes and credentials check is reached + render(<CommonCreateModal {...defaultProps} builder={createMockSubscriptionBuilder()} />) fireEvent.click(screen.getByTestId('modal-confirm')) @@ -1243,13 +1212,22 @@ describe('CommonCreateModal', () => { render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.MANUAL} />) + // Wait for createBuilder to complete and state to update await waitFor(() => { expect(mockCreateBuilder).toHaveBeenCalled() }) + // Allow React to process the state update from createBuilder + await act(async () => {}) + const input = screen.getByTestId('form-field-webhook_url') fireEvent.change(input, { target: { value: 'https://example.com/webhook' } }) + // Wait for updateBuilder to be called, then check the toast + await waitFor(() => { + expect(mockUpdateBuilder).toHaveBeenCalled() + }) + await waitFor(() => { expect(mockToastNotify).toHaveBeenCalledWith({ type: 'error', @@ -1450,7 +1428,8 @@ describe('CommonCreateModal', () => { }) mockUsePluginStore.mockReturnValue(detailWithCredentials) - render(<CommonCreateModal {...defaultProps} />) + // Provide builder so the guard passes and credentials check is reached + render(<CommonCreateModal {...defaultProps} builder={createMockSubscriptionBuilder()} />) fireEvent.click(screen.getByTestId('modal-confirm')) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.tsx index 91a844fb86..15d3417c9b 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.tsx @@ -1,32 +1,19 @@ 'use client' -import type { FormRefObject } from '@/app/components/base/form/types' import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' -import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers' -import { RiLoader2Line } from '@remixicon/react' -import { debounce } from 'es-toolkit/compat' -import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -// import { CopyFeedbackNew } from '@/app/components/base/copy-feedback' import { EncryptedBottom } from '@/app/components/base/encrypted-bottom' -import { BaseForm } from '@/app/components/base/form/components/base' -import { FormTypeEnum } from '@/app/components/base/form/types' import Modal from '@/app/components/base/modal/modal' -import Toast from '@/app/components/base/toast' import { SupportedCreationMethods } from '@/app/components/plugins/types' -import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import { - useBuildTriggerSubscription, - useCreateTriggerSubscriptionBuilder, - useTriggerSubscriptionBuilderLogs, - useUpdateTriggerSubscriptionBuilder, - useVerifyAndUpdateTriggerSubscriptionBuilder, -} from '@/service/use-triggers' -import { parsePluginErrorMessage } from '@/utils/error-parser' -import { isPrivateOrLocalAddress } from '@/utils/urlValidation' -import { usePluginStore } from '../../store' -import LogViewer from '../log-viewer' -import { useSubscriptionList } from '../use-subscription-list' + ConfigurationStepContent, + MultiSteps, + VerifyStepContent, +} from './components/modal-steps' +import { + ApiKeyStep, + MODAL_TITLE_KEY_MAP, + useCommonModalState, +} from './hooks/use-common-modal-state' type Props = { onClose: () => void @@ -34,316 +21,33 @@ type Props = { builder?: TriggerSubscriptionBuilder } -const CREDENTIAL_TYPE_MAP: Record<SupportedCreationMethods, TriggerCredentialTypeEnum> = { - [SupportedCreationMethods.APIKEY]: TriggerCredentialTypeEnum.ApiKey, - [SupportedCreationMethods.OAUTH]: TriggerCredentialTypeEnum.Oauth2, - [SupportedCreationMethods.MANUAL]: TriggerCredentialTypeEnum.Unauthorized, -} - -const MODAL_TITLE_KEY_MAP: Record< - SupportedCreationMethods, - 'modal.apiKey.title' | 'modal.oauth.title' | 'modal.manual.title' -> = { - [SupportedCreationMethods.APIKEY]: 'modal.apiKey.title', - [SupportedCreationMethods.OAUTH]: 'modal.oauth.title', - [SupportedCreationMethods.MANUAL]: 'modal.manual.title', -} - -enum ApiKeyStep { - Verify = 'verify', - Configuration = 'configuration', -} - -const defaultFormValues = { values: {}, isCheckValidated: false } - -const normalizeFormType = (type: FormTypeEnum | string): FormTypeEnum => { - if (Object.values(FormTypeEnum).includes(type as FormTypeEnum)) - return type as FormTypeEnum - - switch (type) { - case 'string': - case 'text': - return FormTypeEnum.textInput - case 'password': - case 'secret': - return FormTypeEnum.secretInput - case 'number': - case 'integer': - return FormTypeEnum.textNumber - case 'boolean': - return FormTypeEnum.boolean - default: - return FormTypeEnum.textInput - } -} - -const StatusStep = ({ isActive, text }: { isActive: boolean, text: string }) => { - return ( - <div className={`system-2xs-semibold-uppercase flex items-center gap-1 ${isActive - ? 'text-state-accent-solid' - : 'text-text-tertiary'}`} - > - {/* Active indicator dot */} - {isActive && ( - <div className="h-1 w-1 rounded-full bg-state-accent-solid"></div> - )} - {text} - </div> - ) -} - -const MultiSteps = ({ currentStep }: { currentStep: ApiKeyStep }) => { - const { t } = useTranslation() - return ( - <div className="mb-6 flex w-1/3 items-center gap-2"> - <StatusStep isActive={currentStep === ApiKeyStep.Verify} text={t('modal.steps.verify', { ns: 'pluginTrigger' })} /> - <div className="h-px w-3 shrink-0 bg-divider-deep"></div> - <StatusStep isActive={currentStep === ApiKeyStep.Configuration} text={t('modal.steps.configuration', { ns: 'pluginTrigger' })} /> - </div> - ) -} - export const CommonCreateModal = ({ onClose, createType, builder }: Props) => { const { t } = useTranslation() - const detail = usePluginStore(state => state.detail) - const { refetch } = useSubscriptionList() - const [currentStep, setCurrentStep] = useState<ApiKeyStep>(createType === SupportedCreationMethods.APIKEY ? ApiKeyStep.Verify : ApiKeyStep.Configuration) + const { + currentStep, + subscriptionBuilder, + isVerifyingCredentials, + isBuilding, + formRefs, + detail, + manualPropertiesSchema, + autoCommonParametersSchema, + apiKeyCredentialsSchema, + logData, + confirmButtonText, + handleConfirm, + handleManualPropertiesChange, + handleApiKeyCredentialsChange, + } = useCommonModalState({ + createType, + builder, + onClose, + }) - const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>(builder) - const isInitializedRef = useRef(false) - - const { mutate: verifyCredentials, isPending: isVerifyingCredentials } = useVerifyAndUpdateTriggerSubscriptionBuilder() - const { mutateAsync: createBuilder /* isPending: isCreatingBuilder */ } = useCreateTriggerSubscriptionBuilder() - const { mutate: buildSubscription, isPending: isBuilding } = useBuildTriggerSubscription() - const { mutate: updateBuilder } = useUpdateTriggerSubscriptionBuilder() - - const manualPropertiesSchema = detail?.declaration?.trigger?.subscription_schema || [] // manual - const manualPropertiesFormRef = React.useRef<FormRefObject>(null) - - const subscriptionFormRef = React.useRef<FormRefObject>(null) - - const autoCommonParametersSchema = detail?.declaration.trigger?.subscription_constructor?.parameters || [] // apikey and oauth - const autoCommonParametersFormRef = React.useRef<FormRefObject>(null) - - const apiKeyCredentialsSchema = useMemo(() => { - const rawSchema = detail?.declaration?.trigger?.subscription_constructor?.credentials_schema || [] - return rawSchema.map(schema => ({ - ...schema, - tooltip: schema.help, - })) - }, [detail?.declaration?.trigger?.subscription_constructor?.credentials_schema]) - const apiKeyCredentialsFormRef = React.useRef<FormRefObject>(null) - - const { data: logData } = useTriggerSubscriptionBuilderLogs( - detail?.provider || '', - subscriptionBuilder?.id || '', - { - enabled: createType === SupportedCreationMethods.MANUAL, - refetchInterval: 3000, - }, - ) - - useEffect(() => { - const initializeBuilder = async () => { - isInitializedRef.current = true - try { - const response = await createBuilder({ - provider: detail?.provider || '', - credential_type: CREDENTIAL_TYPE_MAP[createType], - }) - setSubscriptionBuilder(response.subscription_builder) - } - catch (error) { - console.error('createBuilder error:', error) - Toast.notify({ - type: 'error', - message: t('modal.errors.createFailed', { ns: 'pluginTrigger' }), - }) - } - } - if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider) - initializeBuilder() - }, [subscriptionBuilder, detail?.provider, createType, createBuilder, t]) - - useEffect(() => { - if (subscriptionBuilder?.endpoint && subscriptionFormRef.current && currentStep === ApiKeyStep.Configuration) { - const form = subscriptionFormRef.current.getForm() - if (form) - form.setFieldValue('callback_url', subscriptionBuilder.endpoint) - if (isPrivateOrLocalAddress(subscriptionBuilder.endpoint)) { - console.warn('callback_url is private or local address', subscriptionBuilder.endpoint) - subscriptionFormRef.current?.setFields([{ - name: 'callback_url', - warnings: [t('modal.form.callbackUrl.privateAddressWarning', { ns: 'pluginTrigger' })], - }]) - } - else { - subscriptionFormRef.current?.setFields([{ - name: 'callback_url', - warnings: [], - }]) - } - } - }, [subscriptionBuilder?.endpoint, currentStep, t]) - - const debouncedUpdate = useMemo( - () => debounce((provider: string, builderId: string, properties: Record<string, unknown>) => { - updateBuilder( - { - provider, - subscriptionBuilderId: builderId, - properties, - }, - { - onError: async (error: unknown) => { - const errorMessage = await parsePluginErrorMessage(error) || t('modal.errors.updateFailed', { ns: 'pluginTrigger' }) - console.error('Failed to update subscription builder:', error) - Toast.notify({ - type: 'error', - message: errorMessage, - }) - }, - }, - ) - }, 500), - [updateBuilder, t], - ) - - const handleManualPropertiesChange = useCallback(() => { - if (!subscriptionBuilder || !detail?.provider) - return - - const formValues = manualPropertiesFormRef.current?.getFormValues({ needCheckValidatedValues: false }) || { values: {}, isCheckValidated: true } - - debouncedUpdate(detail.provider, subscriptionBuilder.id, formValues.values) - }, [subscriptionBuilder, detail?.provider, debouncedUpdate]) - - useEffect(() => { - return () => { - debouncedUpdate.cancel() - } - }, [debouncedUpdate]) - - const handleVerify = () => { - const apiKeyCredentialsFormValues = apiKeyCredentialsFormRef.current?.getFormValues({}) || defaultFormValues - const credentials = apiKeyCredentialsFormValues.values - - if (!Object.keys(credentials).length) { - Toast.notify({ - type: 'error', - message: 'Please fill in all required credentials', - }) - return - } - - apiKeyCredentialsFormRef.current?.setFields([{ - name: Object.keys(credentials)[0], - errors: [], - }]) - - verifyCredentials( - { - provider: detail?.provider || '', - subscriptionBuilderId: subscriptionBuilder?.id || '', - credentials, - }, - { - onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }), - }) - setCurrentStep(ApiKeyStep.Configuration) - }, - onError: async (error: unknown) => { - const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' }) - apiKeyCredentialsFormRef.current?.setFields([{ - name: Object.keys(credentials)[0], - errors: [errorMessage], - }]) - }, - }, - ) - } - - const handleCreate = () => { - if (!subscriptionBuilder) { - Toast.notify({ - type: 'error', - message: 'Subscription builder not found', - }) - return - } - - const subscriptionFormValues = subscriptionFormRef.current?.getFormValues({}) - if (!subscriptionFormValues?.isCheckValidated) - return - - const subscriptionNameValue = subscriptionFormValues?.values?.subscription_name as string - - const params: BuildTriggerSubscriptionPayload = { - provider: detail?.provider || '', - subscriptionBuilderId: subscriptionBuilder.id, - name: subscriptionNameValue, - } - - if (createType !== SupportedCreationMethods.MANUAL) { - if (autoCommonParametersSchema.length > 0) { - const autoCommonParametersFormValues = autoCommonParametersFormRef.current?.getFormValues({}) || defaultFormValues - if (!autoCommonParametersFormValues?.isCheckValidated) - return - params.parameters = autoCommonParametersFormValues.values - } - } - else if (manualPropertiesSchema.length > 0) { - const manualFormValues = manualPropertiesFormRef.current?.getFormValues({}) || defaultFormValues - if (!manualFormValues?.isCheckValidated) - return - } - - buildSubscription( - params, - { - onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.createSuccess', { ns: 'pluginTrigger' }), - }) - onClose() - refetch?.() - }, - onError: async (error: unknown) => { - const errorMessage = await parsePluginErrorMessage(error) || t('subscription.createFailed', { ns: 'pluginTrigger' }) - Toast.notify({ - type: 'error', - message: errorMessage, - }) - }, - }, - ) - } - - const handleConfirm = () => { - if (currentStep === ApiKeyStep.Verify) - handleVerify() - else - handleCreate() - } - - const handleApiKeyCredentialsChange = () => { - apiKeyCredentialsFormRef.current?.setFields([{ - name: apiKeyCredentialsSchema[0].name, - errors: [], - }]) - } - - const confirmButtonText = useMemo(() => { - if (currentStep === ApiKeyStep.Verify) - return isVerifyingCredentials ? t('modal.common.verifying', { ns: 'pluginTrigger' }) : t('modal.common.verify', { ns: 'pluginTrigger' }) - - return isBuilding ? t('modal.common.creating', { ns: 'pluginTrigger' }) : t('modal.common.create', { ns: 'pluginTrigger' }) - }, [currentStep, isVerifyingCredentials, isBuilding, t]) + const isApiKeyType = createType === SupportedCreationMethods.APIKEY + const isVerifyStep = currentStep === ApiKeyStep.Verify + const isConfigurationStep = currentStep === ApiKeyStep.Configuration return ( <Modal @@ -353,121 +57,36 @@ export const CommonCreateModal = ({ onClose, createType, builder }: Props) => { onCancel={onClose} onConfirm={handleConfirm} disabled={isVerifyingCredentials || isBuilding} - bottomSlot={currentStep === ApiKeyStep.Verify ? <EncryptedBottom /> : null} + bottomSlot={isVerifyStep ? <EncryptedBottom /> : null} size={createType === SupportedCreationMethods.MANUAL ? 'md' : 'sm'} containerClassName="min-h-[360px]" clickOutsideNotClose > - {createType === SupportedCreationMethods.APIKEY && <MultiSteps currentStep={currentStep} />} - {currentStep === ApiKeyStep.Verify && ( - <> - {apiKeyCredentialsSchema.length > 0 && ( - <div className="mb-4"> - <BaseForm - formSchemas={apiKeyCredentialsSchema} - ref={apiKeyCredentialsFormRef} - labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary" - preventDefaultSubmit={true} - formClassName="space-y-4" - onChange={handleApiKeyCredentialsChange} - /> - </div> - )} - </> - )} - {currentStep === ApiKeyStep.Configuration && ( - <div className="max-h-[70vh]"> - <BaseForm - formSchemas={[ - { - name: 'subscription_name', - label: t('modal.form.subscriptionName.label', { ns: 'pluginTrigger' }), - placeholder: t('modal.form.subscriptionName.placeholder', { ns: 'pluginTrigger' }), - type: FormTypeEnum.textInput, - required: true, - }, - { - name: 'callback_url', - label: t('modal.form.callbackUrl.label', { ns: 'pluginTrigger' }), - placeholder: t('modal.form.callbackUrl.placeholder', { ns: 'pluginTrigger' }), - type: FormTypeEnum.textInput, - required: false, - default: subscriptionBuilder?.endpoint || '', - disabled: true, - tooltip: t('modal.form.callbackUrl.tooltip', { ns: 'pluginTrigger' }), - showCopy: true, - }, - ]} - ref={subscriptionFormRef} - labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary" - formClassName="space-y-4 mb-4" - /> - {/* <div className='system-xs-regular mb-6 mt-[-1rem] text-text-tertiary'> - {t('pluginTrigger.modal.form.callbackUrl.description')} - </div> */} - {createType !== SupportedCreationMethods.MANUAL && autoCommonParametersSchema.length > 0 && ( - <BaseForm - formSchemas={autoCommonParametersSchema.map((schema) => { - const normalizedType = normalizeFormType(schema.type as FormTypeEnum | string) - return { - ...schema, - tooltip: schema.description, - type: normalizedType, - dynamicSelectParams: normalizedType === FormTypeEnum.dynamicSelect - ? { - plugin_id: detail?.plugin_id || '', - provider: detail?.provider || '', - action: 'provider', - parameter: schema.name, - credential_id: subscriptionBuilder?.id || '', - } - : undefined, - fieldClassName: schema.type === FormTypeEnum.boolean ? 'flex items-center justify-between' : undefined, - labelClassName: schema.type === FormTypeEnum.boolean ? 'mb-0' : undefined, - } - })} - ref={autoCommonParametersFormRef} - labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary" - formClassName="space-y-4" - /> - )} - {createType === SupportedCreationMethods.MANUAL && ( - <> - {manualPropertiesSchema.length > 0 && ( - <div className="mb-6"> - <BaseForm - formSchemas={manualPropertiesSchema.map(schema => ({ - ...schema, - tooltip: schema.description, - }))} - ref={manualPropertiesFormRef} - labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary" - formClassName="space-y-4" - onChange={handleManualPropertiesChange} - /> - </div> - )} - <div className="mb-6"> - <div className="mb-3 flex items-center gap-2"> - <div className="system-xs-medium-uppercase text-text-tertiary"> - {t('modal.manual.logs.title', { ns: 'pluginTrigger' })} - </div> - <div className="h-px flex-1 bg-gradient-to-r from-divider-regular to-transparent" /> - </div> + {isApiKeyType && <MultiSteps currentStep={currentStep} />} - <div className="mb-1 flex items-center justify-center gap-1 rounded-lg bg-background-section p-3"> - <div className="h-3.5 w-3.5"> - <RiLoader2Line className="h-full w-full animate-spin" /> - </div> - <div className="system-xs-regular text-text-tertiary"> - {t('modal.manual.logs.loading', { ns: 'pluginTrigger', pluginName: detail?.name || '' })} - </div> - </div> - <LogViewer logs={logData?.logs || []} /> - </div> - </> - )} - </div> + {isVerifyStep && ( + <VerifyStepContent + apiKeyCredentialsSchema={apiKeyCredentialsSchema} + apiKeyCredentialsFormRef={formRefs.apiKeyCredentialsFormRef} + onChange={handleApiKeyCredentialsChange} + /> + )} + + {isConfigurationStep && ( + <ConfigurationStepContent + createType={createType} + subscriptionBuilder={subscriptionBuilder} + subscriptionFormRef={formRefs.subscriptionFormRef} + autoCommonParametersSchema={autoCommonParametersSchema} + autoCommonParametersFormRef={formRefs.autoCommonParametersFormRef} + manualPropertiesSchema={manualPropertiesSchema} + manualPropertiesFormRef={formRefs.manualPropertiesFormRef} + onManualPropertiesChange={handleManualPropertiesChange} + logs={logData?.logs || []} + pluginId={detail?.plugin_id || ''} + pluginName={detail?.name || ''} + provider={detail?.provider || ''} + /> )} </Modal> ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/components/modal-steps.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/components/modal-steps.tsx new file mode 100644 index 0000000000..795176d4f6 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/components/modal-steps.tsx @@ -0,0 +1,304 @@ +'use client' +import type { FormRefObject, FormSchema } from '@/app/components/base/form/types' +import type { TriggerLogEntity, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' +import { RiLoader2Line } from '@remixicon/react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import { BaseForm } from '@/app/components/base/form/components/base' +import { FormTypeEnum } from '@/app/components/base/form/types' +import { SupportedCreationMethods } from '@/app/components/plugins/types' +import LogViewer from '../../log-viewer' +import { ApiKeyStep } from '../hooks/use-common-modal-state' + +export type SchemaItem = Partial<FormSchema> & Record<string, unknown> & { + name: string +} + +type StatusStepProps = { + isActive: boolean + text: string +} + +export const StatusStep = ({ isActive, text }: StatusStepProps) => { + return ( + <div className={`system-2xs-semibold-uppercase flex items-center gap-1 ${isActive + ? 'text-state-accent-solid' + : 'text-text-tertiary'}`} + > + {isActive && ( + <div className="h-1 w-1 rounded-full bg-state-accent-solid"></div> + )} + {text} + </div> + ) +} + +type MultiStepsProps = { + currentStep: ApiKeyStep +} + +export const MultiSteps = ({ currentStep }: MultiStepsProps) => { + const { t } = useTranslation() + return ( + <div className="mb-6 flex w-1/3 items-center gap-2"> + <StatusStep isActive={currentStep === ApiKeyStep.Verify} text={t('modal.steps.verify', { ns: 'pluginTrigger' })} /> + <div className="h-px w-3 shrink-0 bg-divider-deep"></div> + <StatusStep isActive={currentStep === ApiKeyStep.Configuration} text={t('modal.steps.configuration', { ns: 'pluginTrigger' })} /> + </div> + ) +} + +type VerifyStepContentProps = { + apiKeyCredentialsSchema: SchemaItem[] + apiKeyCredentialsFormRef: React.RefObject<FormRefObject | null> + onChange: () => void +} + +export const VerifyStepContent = ({ + apiKeyCredentialsSchema, + apiKeyCredentialsFormRef, + onChange, +}: VerifyStepContentProps) => { + if (!apiKeyCredentialsSchema.length) + return null + + return ( + <div className="mb-4"> + <BaseForm + formSchemas={apiKeyCredentialsSchema as FormSchema[]} + ref={apiKeyCredentialsFormRef} + labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary" + preventDefaultSubmit={true} + formClassName="space-y-4" + onChange={onChange} + /> + </div> + ) +} + +type SubscriptionFormProps = { + subscriptionFormRef: React.RefObject<FormRefObject | null> + endpoint?: string +} + +export const SubscriptionForm = ({ + subscriptionFormRef, + endpoint, +}: SubscriptionFormProps) => { + const { t } = useTranslation() + + const formSchemas = React.useMemo(() => [ + { + name: 'subscription_name', + label: t('modal.form.subscriptionName.label', { ns: 'pluginTrigger' }), + placeholder: t('modal.form.subscriptionName.placeholder', { ns: 'pluginTrigger' }), + type: FormTypeEnum.textInput, + required: true, + }, + { + name: 'callback_url', + label: t('modal.form.callbackUrl.label', { ns: 'pluginTrigger' }), + placeholder: t('modal.form.callbackUrl.placeholder', { ns: 'pluginTrigger' }), + type: FormTypeEnum.textInput, + required: false, + default: endpoint || '', + disabled: true, + tooltip: t('modal.form.callbackUrl.tooltip', { ns: 'pluginTrigger' }), + showCopy: true, + }, + ], [endpoint, t]) + + return ( + <BaseForm + formSchemas={formSchemas} + ref={subscriptionFormRef} + labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary" + formClassName="space-y-4 mb-4" + /> + ) +} + +const normalizeFormType = (type: FormTypeEnum | string): FormTypeEnum => { + if (Object.values(FormTypeEnum).includes(type as FormTypeEnum)) + return type as FormTypeEnum + + const TYPE_MAP: Record<string, FormTypeEnum> = { + string: FormTypeEnum.textInput, + text: FormTypeEnum.textInput, + password: FormTypeEnum.secretInput, + secret: FormTypeEnum.secretInput, + number: FormTypeEnum.textNumber, + integer: FormTypeEnum.textNumber, + boolean: FormTypeEnum.boolean, + } + + return TYPE_MAP[type] || FormTypeEnum.textInput +} + +type AutoParametersFormProps = { + schemas: SchemaItem[] + formRef: React.RefObject<FormRefObject | null> + pluginId: string + provider: string + credentialId: string +} + +export const AutoParametersForm = ({ + schemas, + formRef, + pluginId, + provider, + credentialId, +}: AutoParametersFormProps) => { + const formSchemas = React.useMemo(() => + schemas.map((schema) => { + const normalizedType = normalizeFormType((schema.type || FormTypeEnum.textInput) as FormTypeEnum | string) + return { + ...schema, + tooltip: schema.description, + type: normalizedType, + dynamicSelectParams: normalizedType === FormTypeEnum.dynamicSelect + ? { + plugin_id: pluginId, + provider, + action: 'provider', + parameter: schema.name, + credential_id: credentialId, + } + : undefined, + fieldClassName: normalizedType === FormTypeEnum.boolean ? 'flex items-center justify-between' : undefined, + labelClassName: normalizedType === FormTypeEnum.boolean ? 'mb-0' : undefined, + } + }) as FormSchema[], [schemas, pluginId, provider, credentialId]) + + if (!schemas.length) + return null + + return ( + <BaseForm + formSchemas={formSchemas} + ref={formRef} + labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary" + formClassName="space-y-4" + /> + ) +} + +type ManualPropertiesSectionProps = { + schemas: SchemaItem[] + formRef: React.RefObject<FormRefObject | null> + onChange: () => void + logs: TriggerLogEntity[] + pluginName: string +} + +export const ManualPropertiesSection = ({ + schemas, + formRef, + onChange, + logs, + pluginName, +}: ManualPropertiesSectionProps) => { + const { t } = useTranslation() + + const formSchemas = React.useMemo(() => + schemas.map(schema => ({ + ...schema, + tooltip: schema.description, + })) as FormSchema[], [schemas]) + + return ( + <> + {schemas.length > 0 && ( + <div className="mb-6"> + <BaseForm + formSchemas={formSchemas} + ref={formRef} + labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary" + formClassName="space-y-4" + onChange={onChange} + /> + </div> + )} + <div className="mb-6"> + <div className="mb-3 flex items-center gap-2"> + <div className="system-xs-medium-uppercase text-text-tertiary"> + {t('modal.manual.logs.title', { ns: 'pluginTrigger' })} + </div> + <div className="h-px flex-1 bg-gradient-to-r from-divider-regular to-transparent" /> + </div> + + <div className="mb-1 flex items-center justify-center gap-1 rounded-lg bg-background-section p-3"> + <div className="h-3.5 w-3.5"> + <RiLoader2Line className="h-full w-full animate-spin" /> + </div> + <div className="system-xs-regular text-text-tertiary"> + {t('modal.manual.logs.loading', { ns: 'pluginTrigger', pluginName })} + </div> + </div> + <LogViewer logs={logs} /> + </div> + </> + ) +} + +type ConfigurationStepContentProps = { + createType: SupportedCreationMethods + subscriptionBuilder?: TriggerSubscriptionBuilder + subscriptionFormRef: React.RefObject<FormRefObject | null> + autoCommonParametersSchema: SchemaItem[] + autoCommonParametersFormRef: React.RefObject<FormRefObject | null> + manualPropertiesSchema: SchemaItem[] + manualPropertiesFormRef: React.RefObject<FormRefObject | null> + onManualPropertiesChange: () => void + logs: TriggerLogEntity[] + pluginId: string + pluginName: string + provider: string +} + +export const ConfigurationStepContent = ({ + createType, + subscriptionBuilder, + subscriptionFormRef, + autoCommonParametersSchema, + autoCommonParametersFormRef, + manualPropertiesSchema, + manualPropertiesFormRef, + onManualPropertiesChange, + logs, + pluginId, + pluginName, + provider, +}: ConfigurationStepContentProps) => { + const isManualType = createType === SupportedCreationMethods.MANUAL + + return ( + <div className="max-h-[70vh]"> + <SubscriptionForm + subscriptionFormRef={subscriptionFormRef} + endpoint={subscriptionBuilder?.endpoint} + /> + + {!isManualType && autoCommonParametersSchema.length > 0 && ( + <AutoParametersForm + schemas={autoCommonParametersSchema} + formRef={autoCommonParametersFormRef} + pluginId={pluginId} + provider={provider} + credentialId={subscriptionBuilder?.id || ''} + /> + )} + + {isManualType && ( + <ManualPropertiesSection + schemas={manualPropertiesSchema} + formRef={manualPropertiesFormRef} + onChange={onManualPropertiesChange} + logs={logs} + pluginName={pluginName} + /> + )} + </div> + ) +} diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts new file mode 100644 index 0000000000..b01312d3d1 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts @@ -0,0 +1,401 @@ +'use client' +import type { SimpleDetail } from '../../../store' +import type { SchemaItem } from '../components/modal-steps' +import type { FormRefObject } from '@/app/components/base/form/types' +import type { TriggerLogEntity, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' +import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers' +import { debounce } from 'es-toolkit/compat' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import { SupportedCreationMethods } from '@/app/components/plugins/types' +import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' +import { + useBuildTriggerSubscription, + useCreateTriggerSubscriptionBuilder, + useTriggerSubscriptionBuilderLogs, + useUpdateTriggerSubscriptionBuilder, + useVerifyAndUpdateTriggerSubscriptionBuilder, +} from '@/service/use-triggers' +import { parsePluginErrorMessage } from '@/utils/error-parser' +import { isPrivateOrLocalAddress } from '@/utils/urlValidation' +import { usePluginStore } from '../../../store' +import { useSubscriptionList } from '../../use-subscription-list' + +// ============================================================================ +// Types +// ============================================================================ + +export enum ApiKeyStep { + Verify = 'verify', + Configuration = 'configuration', +} + +export const CREDENTIAL_TYPE_MAP: Record<SupportedCreationMethods, TriggerCredentialTypeEnum> = { + [SupportedCreationMethods.APIKEY]: TriggerCredentialTypeEnum.ApiKey, + [SupportedCreationMethods.OAUTH]: TriggerCredentialTypeEnum.Oauth2, + [SupportedCreationMethods.MANUAL]: TriggerCredentialTypeEnum.Unauthorized, +} + +export const MODAL_TITLE_KEY_MAP: Record< + SupportedCreationMethods, + 'modal.apiKey.title' | 'modal.oauth.title' | 'modal.manual.title' +> = { + [SupportedCreationMethods.APIKEY]: 'modal.apiKey.title', + [SupportedCreationMethods.OAUTH]: 'modal.oauth.title', + [SupportedCreationMethods.MANUAL]: 'modal.manual.title', +} + +type UseCommonModalStateParams = { + createType: SupportedCreationMethods + builder?: TriggerSubscriptionBuilder + onClose: () => void +} + +type FormRefs = { + manualPropertiesFormRef: React.RefObject<FormRefObject | null> + subscriptionFormRef: React.RefObject<FormRefObject | null> + autoCommonParametersFormRef: React.RefObject<FormRefObject | null> + apiKeyCredentialsFormRef: React.RefObject<FormRefObject | null> +} + +type UseCommonModalStateReturn = { + // State + currentStep: ApiKeyStep + subscriptionBuilder: TriggerSubscriptionBuilder | undefined + isVerifyingCredentials: boolean + isBuilding: boolean + + // Form refs + formRefs: FormRefs + + // Computed values + detail: SimpleDetail | undefined + manualPropertiesSchema: SchemaItem[] + autoCommonParametersSchema: SchemaItem[] + apiKeyCredentialsSchema: SchemaItem[] + logData: { logs: TriggerLogEntity[] } | undefined + confirmButtonText: string + + // Handlers + handleVerify: () => void + handleCreate: () => void + handleConfirm: () => void + handleManualPropertiesChange: () => void + handleApiKeyCredentialsChange: () => void +} + +const DEFAULT_FORM_VALUES = { values: {}, isCheckValidated: false } + +// ============================================================================ +// Hook Implementation +// ============================================================================ + +export const useCommonModalState = ({ + createType, + builder, + onClose, +}: UseCommonModalStateParams): UseCommonModalStateReturn => { + const { t } = useTranslation() + const detail = usePluginStore(state => state.detail) + const { refetch } = useSubscriptionList() + + // State + const [currentStep, setCurrentStep] = useState<ApiKeyStep>( + createType === SupportedCreationMethods.APIKEY ? ApiKeyStep.Verify : ApiKeyStep.Configuration, + ) + const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>(builder) + const isInitializedRef = useRef(false) + + // Form refs + const manualPropertiesFormRef = useRef<FormRefObject>(null) + const subscriptionFormRef = useRef<FormRefObject>(null) + const autoCommonParametersFormRef = useRef<FormRefObject>(null) + const apiKeyCredentialsFormRef = useRef<FormRefObject>(null) + + // Mutations + const { mutate: verifyCredentials, isPending: isVerifyingCredentials } = useVerifyAndUpdateTriggerSubscriptionBuilder() + const { mutateAsync: createBuilder } = useCreateTriggerSubscriptionBuilder() + const { mutate: buildSubscription, isPending: isBuilding } = useBuildTriggerSubscription() + const { mutate: updateBuilder } = useUpdateTriggerSubscriptionBuilder() + + // Schemas + const manualPropertiesSchema = detail?.declaration?.trigger?.subscription_schema || [] + const autoCommonParametersSchema = detail?.declaration.trigger?.subscription_constructor?.parameters || [] + + const apiKeyCredentialsSchema = useMemo(() => { + const rawSchema = detail?.declaration?.trigger?.subscription_constructor?.credentials_schema || [] + return rawSchema.map(schema => ({ + ...schema, + tooltip: schema.help, + })) + }, [detail?.declaration?.trigger?.subscription_constructor?.credentials_schema]) + + // Log data for manual mode + const { data: logData } = useTriggerSubscriptionBuilderLogs( + detail?.provider || '', + subscriptionBuilder?.id || '', + { + enabled: createType === SupportedCreationMethods.MANUAL, + refetchInterval: 3000, + }, + ) + + // Debounced update for manual properties + const debouncedUpdate = useMemo( + () => debounce((provider: string, builderId: string, properties: Record<string, unknown>) => { + updateBuilder( + { + provider, + subscriptionBuilderId: builderId, + properties, + }, + { + onError: async (error: unknown) => { + const errorMessage = await parsePluginErrorMessage(error) || t('modal.errors.updateFailed', { ns: 'pluginTrigger' }) + console.error('Failed to update subscription builder:', error) + Toast.notify({ + type: 'error', + message: errorMessage, + }) + }, + }, + ) + }, 500), + [updateBuilder, t], + ) + + // Initialize builder + useEffect(() => { + const initializeBuilder = async () => { + isInitializedRef.current = true + try { + const response = await createBuilder({ + provider: detail?.provider || '', + credential_type: CREDENTIAL_TYPE_MAP[createType], + }) + setSubscriptionBuilder(response.subscription_builder) + } + catch (error) { + console.error('createBuilder error:', error) + Toast.notify({ + type: 'error', + message: t('modal.errors.createFailed', { ns: 'pluginTrigger' }), + }) + } + } + if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider) + initializeBuilder() + }, [subscriptionBuilder, detail?.provider, createType, createBuilder, t]) + + // Cleanup debounced function + useEffect(() => { + return () => { + debouncedUpdate.cancel() + } + }, [debouncedUpdate]) + + // Update endpoint in form when endpoint changes + useEffect(() => { + if (!subscriptionBuilder?.endpoint || !subscriptionFormRef.current || currentStep !== ApiKeyStep.Configuration) + return + + const form = subscriptionFormRef.current.getForm() + if (form) + form.setFieldValue('callback_url', subscriptionBuilder.endpoint) + + const warnings = isPrivateOrLocalAddress(subscriptionBuilder.endpoint) + ? [t('modal.form.callbackUrl.privateAddressWarning', { ns: 'pluginTrigger' })] + : [] + + subscriptionFormRef.current?.setFields([{ + name: 'callback_url', + warnings, + }]) + }, [subscriptionBuilder?.endpoint, currentStep, t]) + + // Handle manual properties change + const handleManualPropertiesChange = useCallback(() => { + if (!subscriptionBuilder || !detail?.provider) + return + + const formValues = manualPropertiesFormRef.current?.getFormValues({ needCheckValidatedValues: false }) + || { values: {}, isCheckValidated: true } + + debouncedUpdate(detail.provider, subscriptionBuilder.id, formValues.values) + }, [subscriptionBuilder, detail?.provider, debouncedUpdate]) + + // Handle API key credentials change + const handleApiKeyCredentialsChange = useCallback(() => { + if (!apiKeyCredentialsSchema.length) + return + apiKeyCredentialsFormRef.current?.setFields([{ + name: apiKeyCredentialsSchema[0].name, + errors: [], + }]) + }, [apiKeyCredentialsSchema]) + + // Handle verify + const handleVerify = useCallback(() => { + // Guard against uninitialized state + if (!detail?.provider || !subscriptionBuilder?.id) { + Toast.notify({ + type: 'error', + message: 'Subscription builder not initialized', + }) + return + } + + const apiKeyCredentialsFormValues = apiKeyCredentialsFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES + const credentials = apiKeyCredentialsFormValues.values + + if (!Object.keys(credentials).length) { + Toast.notify({ + type: 'error', + message: 'Please fill in all required credentials', + }) + return + } + + apiKeyCredentialsFormRef.current?.setFields([{ + name: Object.keys(credentials)[0], + errors: [], + }]) + + verifyCredentials( + { + provider: detail.provider, + subscriptionBuilderId: subscriptionBuilder.id, + credentials, + }, + { + onSuccess: () => { + Toast.notify({ + type: 'success', + message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }), + }) + setCurrentStep(ApiKeyStep.Configuration) + }, + onError: async (error: unknown) => { + const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' }) + apiKeyCredentialsFormRef.current?.setFields([{ + name: Object.keys(credentials)[0], + errors: [errorMessage], + }]) + }, + }, + ) + }, [detail?.provider, subscriptionBuilder?.id, verifyCredentials, t]) + + // Handle create + const handleCreate = useCallback(() => { + if (!subscriptionBuilder) { + Toast.notify({ + type: 'error', + message: 'Subscription builder not found', + }) + return + } + + const subscriptionFormValues = subscriptionFormRef.current?.getFormValues({}) + if (!subscriptionFormValues?.isCheckValidated) + return + + const subscriptionNameValue = subscriptionFormValues?.values?.subscription_name as string + + const params: BuildTriggerSubscriptionPayload = { + provider: detail?.provider || '', + subscriptionBuilderId: subscriptionBuilder.id, + name: subscriptionNameValue, + } + + if (createType !== SupportedCreationMethods.MANUAL) { + if (autoCommonParametersSchema.length > 0) { + const autoCommonParametersFormValues = autoCommonParametersFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES + if (!autoCommonParametersFormValues?.isCheckValidated) + return + params.parameters = autoCommonParametersFormValues.values + } + } + else if (manualPropertiesSchema.length > 0) { + const manualFormValues = manualPropertiesFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES + if (!manualFormValues?.isCheckValidated) + return + } + + buildSubscription( + params, + { + onSuccess: () => { + Toast.notify({ + type: 'success', + message: t('subscription.createSuccess', { ns: 'pluginTrigger' }), + }) + onClose() + refetch?.() + }, + onError: async (error: unknown) => { + const errorMessage = await parsePluginErrorMessage(error) || t('subscription.createFailed', { ns: 'pluginTrigger' }) + Toast.notify({ + type: 'error', + message: errorMessage, + }) + }, + }, + ) + }, [ + subscriptionBuilder, + detail?.provider, + createType, + autoCommonParametersSchema.length, + manualPropertiesSchema.length, + buildSubscription, + onClose, + refetch, + t, + ]) + + // Handle confirm (dispatch based on step) + const handleConfirm = useCallback(() => { + if (currentStep === ApiKeyStep.Verify) + handleVerify() + else + handleCreate() + }, [currentStep, handleVerify, handleCreate]) + + // Confirm button text + const confirmButtonText = useMemo(() => { + if (currentStep === ApiKeyStep.Verify) { + return isVerifyingCredentials + ? t('modal.common.verifying', { ns: 'pluginTrigger' }) + : t('modal.common.verify', { ns: 'pluginTrigger' }) + } + return isBuilding + ? t('modal.common.creating', { ns: 'pluginTrigger' }) + : t('modal.common.create', { ns: 'pluginTrigger' }) + }, [currentStep, isVerifyingCredentials, isBuilding, t]) + + return { + currentStep, + subscriptionBuilder, + isVerifyingCredentials, + isBuilding, + formRefs: { + manualPropertiesFormRef, + subscriptionFormRef, + autoCommonParametersFormRef, + apiKeyCredentialsFormRef, + }, + detail, + manualPropertiesSchema, + autoCommonParametersSchema, + apiKeyCredentialsSchema, + logData, + confirmButtonText, + handleVerify, + handleCreate, + handleConfirm, + handleManualPropertiesChange, + handleApiKeyCredentialsChange, + } +} diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.spec.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.spec.ts new file mode 100644 index 0000000000..de54a2b87c --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.spec.ts @@ -0,0 +1,719 @@ +import type { TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' +import { + AuthorizationStatusEnum, + ClientTypeEnum, + getErrorMessage, + useOAuthClientState, +} from './use-oauth-client-state' + +// ============================================================================ +// Mock Factory Functions +// ============================================================================ + +function createMockOAuthConfig(overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig { + return { + configured: true, + custom_configured: false, + custom_enabled: false, + system_configured: true, + redirect_uri: 'https://example.com/oauth/callback', + params: { + client_id: 'default-client-id', + client_secret: 'default-client-secret', + }, + oauth_client_schema: [ + { name: 'client_id', type: 'text-input' as unknown, required: true, label: { 'en-US': 'Client ID' } as unknown }, + { name: 'client_secret', type: 'secret-input' as unknown, required: true, label: { 'en-US': 'Client Secret' } as unknown }, + ] as TriggerOAuthConfig['oauth_client_schema'], + ...overrides, + } +} + +function createMockSubscriptionBuilder(overrides: Partial<TriggerSubscriptionBuilder> = {}): TriggerSubscriptionBuilder { + return { + id: 'builder-123', + name: 'Test Builder', + provider: 'test-provider', + credential_type: TriggerCredentialTypeEnum.Oauth2, + credentials: {}, + endpoint: 'https://example.com/callback', + parameters: {}, + properties: {}, + workflows_in_use: 0, + ...overrides, + } +} + +// ============================================================================ +// Mock Setup +// ============================================================================ + +const mockInitiateOAuth = vi.fn() +const mockVerifyBuilder = vi.fn() +const mockConfigureOAuth = vi.fn() +const mockDeleteOAuth = vi.fn() + +vi.mock('@/service/use-triggers', () => ({ + useInitiateTriggerOAuth: () => ({ + mutate: mockInitiateOAuth, + }), + useVerifyAndUpdateTriggerSubscriptionBuilder: () => ({ + mutate: mockVerifyBuilder, + }), + useConfigureTriggerOAuth: () => ({ + mutate: mockConfigureOAuth, + }), + useDeleteTriggerOAuth: () => ({ + mutate: mockDeleteOAuth, + }), +})) + +const mockOpenOAuthPopup = vi.fn() +vi.mock('@/hooks/use-oauth', () => ({ + openOAuthPopup: (url: string, callback: (data: unknown) => void) => mockOpenOAuthPopup(url, callback), +})) + +const mockToastNotify = vi.fn() +vi.mock('@/app/components/base/toast', () => ({ + default: { + notify: (params: unknown) => mockToastNotify(params), + }, +})) + +// ============================================================================ +// Test Suites +// ============================================================================ + +describe('getErrorMessage', () => { + it('should extract message from Error instance', () => { + const error = new Error('Test error message') + expect(getErrorMessage(error, 'fallback')).toBe('Test error message') + }) + + it('should extract message from object with message property', () => { + const error = { message: 'Object error message' } + expect(getErrorMessage(error, 'fallback')).toBe('Object error message') + }) + + it('should return fallback when error is empty object', () => { + expect(getErrorMessage({}, 'fallback')).toBe('fallback') + }) + + it('should return fallback when error.message is not a string', () => { + expect(getErrorMessage({ message: 123 }, 'fallback')).toBe('fallback') + }) + + it('should return fallback when error.message is empty string', () => { + expect(getErrorMessage({ message: '' }, 'fallback')).toBe('fallback') + }) + + it('should return fallback when error is null', () => { + expect(getErrorMessage(null, 'fallback')).toBe('fallback') + }) + + it('should return fallback when error is undefined', () => { + expect(getErrorMessage(undefined, 'fallback')).toBe('fallback') + }) + + it('should return fallback when error is a primitive', () => { + expect(getErrorMessage('string error', 'fallback')).toBe('fallback') + expect(getErrorMessage(123, 'fallback')).toBe('fallback') + }) +}) + +describe('useOAuthClientState', () => { + const defaultParams = { + oauthConfig: createMockOAuthConfig(), + providerName: 'test-provider', + onClose: vi.fn(), + showOAuthCreateModal: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('Initial State', () => { + it('should default to Default client type when system_configured is true', () => { + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + expect(result.current.clientType).toBe(ClientTypeEnum.Default) + }) + + it('should default to Custom client type when system_configured is false', () => { + const config = createMockOAuthConfig({ system_configured: false }) + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + oauthConfig: config, + })) + + expect(result.current.clientType).toBe(ClientTypeEnum.Custom) + }) + + it('should have undefined authorizationStatus initially', () => { + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + expect(result.current.authorizationStatus).toBeUndefined() + }) + + it('should provide clientFormRef', () => { + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + expect(result.current.clientFormRef).toBeDefined() + expect(result.current.clientFormRef.current).toBeNull() + }) + }) + + describe('OAuth Client Schema', () => { + it('should compute schema with default values from params', () => { + const config = createMockOAuthConfig({ + params: { + client_id: 'my-client-id', + client_secret: 'my-secret', + }, + }) + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + oauthConfig: config, + })) + + expect(result.current.oauthClientSchema).toHaveLength(2) + expect(result.current.oauthClientSchema[0].default).toBe('my-client-id') + expect(result.current.oauthClientSchema[1].default).toBe('my-secret') + }) + + it('should return empty array when oauth_client_schema is empty', () => { + const config = createMockOAuthConfig({ + oauth_client_schema: [], + }) + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + oauthConfig: config, + })) + + expect(result.current.oauthClientSchema).toEqual([]) + }) + + it('should return empty array when params is undefined', () => { + const config = createMockOAuthConfig({ + params: undefined as unknown as TriggerOAuthConfig['params'], + }) + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + oauthConfig: config, + })) + + expect(result.current.oauthClientSchema).toEqual([]) + }) + + it('should preserve original schema default when param key not found', () => { + const config = createMockOAuthConfig({ + params: { + client_id: 'only-client-id', + client_secret: '', // empty + }, + oauth_client_schema: [ + { name: 'client_id', type: 'text-input' as unknown, required: true, label: {} as unknown, default: 'original-default' }, + { name: 'extra_field', type: 'text-input' as unknown, required: false, label: {} as unknown, default: 'extra-default' }, + ] as TriggerOAuthConfig['oauth_client_schema'], + }) + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + oauthConfig: config, + })) + + // client_id should be overridden + expect(result.current.oauthClientSchema[0].default).toBe('only-client-id') + // extra_field should keep original default since key not in params + expect(result.current.oauthClientSchema[1].default).toBe('extra-default') + }) + }) + + describe('Confirm Button Text', () => { + it('should show saveAndAuth text by default', () => { + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + expect(result.current.confirmButtonText).toBe('plugin.auth.saveAndAuth') + }) + + it('should show authorizing text when status is Pending', async () => { + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation(() => { + // Don't resolve - stays pending + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + await waitFor(() => { + expect(result.current.confirmButtonText).toBe('pluginTrigger.modal.common.authorizing') + }) + }) + }) + + describe('setClientType', () => { + it('should update client type when called', () => { + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.setClientType(ClientTypeEnum.Custom) + }) + + expect(result.current.clientType).toBe(ClientTypeEnum.Custom) + }) + + it('should toggle between client types', () => { + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.setClientType(ClientTypeEnum.Custom) + }) + expect(result.current.clientType).toBe(ClientTypeEnum.Custom) + + act(() => { + result.current.setClientType(ClientTypeEnum.Default) + }) + expect(result.current.clientType).toBe(ClientTypeEnum.Default) + }) + }) + + describe('handleRemove', () => { + it('should call deleteOAuth with provider name', () => { + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleRemove() + }) + + expect(mockDeleteOAuth).toHaveBeenCalledWith( + 'test-provider', + expect.any(Object), + ) + }) + + it('should call onClose and show success toast on success', () => { + mockDeleteOAuth.mockImplementation((provider, { onSuccess }) => onSuccess()) + + const onClose = vi.fn() + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + onClose, + })) + + act(() => { + result.current.handleRemove() + }) + + expect(onClose).toHaveBeenCalled() + expect(mockToastNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'pluginTrigger.modal.oauth.remove.success', + }) + }) + + it('should show error toast with error message on failure', () => { + mockDeleteOAuth.mockImplementation((provider, { onError }) => { + onError(new Error('Delete failed')) + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleRemove() + }) + + expect(mockToastNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'Delete failed', + }) + }) + }) + + describe('handleSave', () => { + it('should call configureOAuth with enabled: false for Default type', () => { + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(false) + }) + + expect(mockConfigureOAuth).toHaveBeenCalledWith( + expect.objectContaining({ + provider: 'test-provider', + enabled: false, + }), + expect.any(Object), + ) + }) + + it('should call configureOAuth with enabled: true for Custom type', () => { + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + + const config = createMockOAuthConfig({ system_configured: false }) + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + oauthConfig: config, + })) + + // Mock the form ref + const mockFormRef = { + getFormValues: () => ({ + values: { client_id: 'new-id', client_secret: 'new-secret' }, + isCheckValidated: true, + }), + } + // @ts-expect-error - mocking ref + result.current.clientFormRef.current = mockFormRef + + act(() => { + result.current.handleSave(false) + }) + + expect(mockConfigureOAuth).toHaveBeenCalledWith( + expect.objectContaining({ + enabled: true, + }), + expect.any(Object), + ) + }) + + it('should show success toast and call onClose when needAuth is false', () => { + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + const onClose = vi.fn() + + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + onClose, + })) + + act(() => { + result.current.handleSave(false) + }) + + expect(onClose).toHaveBeenCalled() + expect(mockToastNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'pluginTrigger.modal.oauth.save.success', + }) + }) + + it('should trigger authorization when needAuth is true', () => { + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => { + onSuccess({ + authorization_url: 'https://oauth.example.com/authorize', + subscription_builder: createMockSubscriptionBuilder(), + }) + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + expect(mockInitiateOAuth).toHaveBeenCalledWith( + 'test-provider', + expect.any(Object), + ) + }) + }) + + describe('handleAuthorization', () => { + it('should set status to Pending and call initiateOAuth', () => { + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation(() => {}) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Pending) + expect(mockInitiateOAuth).toHaveBeenCalled() + }) + + it('should open OAuth popup on success', () => { + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => { + onSuccess({ + authorization_url: 'https://oauth.example.com/authorize', + subscription_builder: createMockSubscriptionBuilder(), + }) + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + expect(mockOpenOAuthPopup).toHaveBeenCalledWith( + 'https://oauth.example.com/authorize', + expect.any(Function), + ) + }) + + it('should set status to Failed and show error toast on error', () => { + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onError }) => { + onError(new Error('OAuth failed')) + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Failed) + expect(mockToastNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'pluginTrigger.modal.oauth.authorization.authFailed', + }) + }) + + it('should call onClose and showOAuthCreateModal on callback success', () => { + const onClose = vi.fn() + const showOAuthCreateModal = vi.fn() + const builder = createMockSubscriptionBuilder() + + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => { + onSuccess({ + authorization_url: 'https://oauth.example.com/authorize', + subscription_builder: builder, + }) + }) + mockOpenOAuthPopup.mockImplementation((url, callback) => { + callback({ success: true }) + }) + + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + onClose, + showOAuthCreateModal, + })) + + act(() => { + result.current.handleSave(true) + }) + + expect(onClose).toHaveBeenCalled() + expect(showOAuthCreateModal).toHaveBeenCalledWith(builder) + expect(mockToastNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'pluginTrigger.modal.oauth.authorization.authSuccess', + }) + }) + + it('should not call callbacks when OAuth callback returns falsy', () => { + const onClose = vi.fn() + const showOAuthCreateModal = vi.fn() + + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => { + onSuccess({ + authorization_url: 'https://oauth.example.com/authorize', + subscription_builder: createMockSubscriptionBuilder(), + }) + }) + mockOpenOAuthPopup.mockImplementation((url, callback) => { + callback(null) + }) + + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + onClose, + showOAuthCreateModal, + })) + + act(() => { + result.current.handleSave(true) + }) + + expect(onClose).not.toHaveBeenCalled() + expect(showOAuthCreateModal).not.toHaveBeenCalled() + }) + }) + + describe('Polling Effect', () => { + it('should start polling after authorization starts', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => { + onSuccess({ + authorization_url: 'https://oauth.example.com/authorize', + subscription_builder: createMockSubscriptionBuilder(), + }) + }) + mockVerifyBuilder.mockImplementation((params, { onSuccess }) => { + onSuccess({ verified: false }) + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + // Advance timer to trigger first poll + await act(async () => { + vi.advanceTimersByTime(3000) + }) + + expect(mockVerifyBuilder).toHaveBeenCalled() + + vi.useRealTimers() + }) + + it('should set status to Success when verified', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => { + onSuccess({ + authorization_url: 'https://oauth.example.com/authorize', + subscription_builder: createMockSubscriptionBuilder(), + }) + }) + mockVerifyBuilder.mockImplementation((params, { onSuccess }) => { + onSuccess({ verified: true }) + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + await act(async () => { + vi.advanceTimersByTime(3000) + }) + + await waitFor(() => { + expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Success) + }) + + vi.useRealTimers() + }) + + it('should continue polling on error', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => { + onSuccess({ + authorization_url: 'https://oauth.example.com/authorize', + subscription_builder: createMockSubscriptionBuilder(), + }) + }) + mockVerifyBuilder.mockImplementation((params, { onError }) => { + onError(new Error('Verify failed')) + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + await act(async () => { + vi.advanceTimersByTime(3000) + }) + + expect(mockVerifyBuilder).toHaveBeenCalled() + // Status should still be Pending + expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Pending) + + vi.useRealTimers() + }) + + it('should stop polling when verified', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess()) + mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => { + onSuccess({ + authorization_url: 'https://oauth.example.com/authorize', + subscription_builder: createMockSubscriptionBuilder(), + }) + }) + mockVerifyBuilder.mockImplementation((params, { onSuccess }) => { + onSuccess({ verified: true }) + }) + + const { result } = renderHook(() => useOAuthClientState(defaultParams)) + + act(() => { + result.current.handleSave(true) + }) + + // First poll - should verify + await act(async () => { + vi.advanceTimersByTime(3000) + }) + + expect(mockVerifyBuilder).toHaveBeenCalledTimes(1) + + // Second poll - should not happen as interval is cleared + await act(async () => { + vi.advanceTimersByTime(3000) + }) + + // Still only 1 call because polling stopped + expect(mockVerifyBuilder).toHaveBeenCalledTimes(1) + + vi.useRealTimers() + }) + }) + + describe('Edge Cases', () => { + it('should handle undefined oauthConfig', () => { + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + oauthConfig: undefined, + })) + + expect(result.current.clientType).toBe(ClientTypeEnum.Custom) + expect(result.current.oauthClientSchema).toEqual([]) + }) + + it('should handle empty providerName', () => { + const { result } = renderHook(() => useOAuthClientState({ + ...defaultParams, + providerName: '', + })) + + // Should not throw + expect(result.current.clientType).toBe(ClientTypeEnum.Default) + }) + }) +}) + +describe('Enum Exports', () => { + it('should export AuthorizationStatusEnum', () => { + expect(AuthorizationStatusEnum.Pending).toBe('pending') + expect(AuthorizationStatusEnum.Success).toBe('success') + expect(AuthorizationStatusEnum.Failed).toBe('failed') + }) + + it('should export ClientTypeEnum', () => { + expect(ClientTypeEnum.Default).toBe('default') + expect(ClientTypeEnum.Custom).toBe('custom') + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts new file mode 100644 index 0000000000..6a551051e2 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts @@ -0,0 +1,241 @@ +'use client' +import type { FormRefObject } from '@/app/components/base/form/types' +import type { TriggerOAuthClientParams, TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' +import type { ConfigureTriggerOAuthPayload } from '@/service/use-triggers' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import { openOAuthPopup } from '@/hooks/use-oauth' +import { + useConfigureTriggerOAuth, + useDeleteTriggerOAuth, + useInitiateTriggerOAuth, + useVerifyAndUpdateTriggerSubscriptionBuilder, +} from '@/service/use-triggers' + +export enum AuthorizationStatusEnum { + Pending = 'pending', + Success = 'success', + Failed = 'failed', +} + +export enum ClientTypeEnum { + Default = 'default', + Custom = 'custom', +} + +const POLL_INTERVAL_MS = 3000 + +// Extract error message from various error formats +export const getErrorMessage = (error: unknown, fallback: string): string => { + if (error instanceof Error && error.message) + return error.message + if (typeof error === 'object' && error && 'message' in error) { + const message = (error as { message?: string }).message + if (typeof message === 'string' && message) + return message + } + return fallback +} + +type UseOAuthClientStateParams = { + oauthConfig?: TriggerOAuthConfig + providerName: string + onClose: () => void + showOAuthCreateModal: (builder: TriggerSubscriptionBuilder) => void +} + +type UseOAuthClientStateReturn = { + // State + clientType: ClientTypeEnum + setClientType: (type: ClientTypeEnum) => void + authorizationStatus: AuthorizationStatusEnum | undefined + + // Refs + clientFormRef: React.RefObject<FormRefObject | null> + + // Computed values + oauthClientSchema: TriggerOAuthConfig['oauth_client_schema'] + confirmButtonText: string + + // Handlers + handleAuthorization: () => void + handleRemove: () => void + handleSave: (needAuth: boolean) => void +} + +export const useOAuthClientState = ({ + oauthConfig, + providerName, + onClose, + showOAuthCreateModal, +}: UseOAuthClientStateParams): UseOAuthClientStateReturn => { + const { t } = useTranslation() + + // State management + const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>() + const [authorizationStatus, setAuthorizationStatus] = useState<AuthorizationStatusEnum>() + const [clientType, setClientType] = useState<ClientTypeEnum>( + oauthConfig?.system_configured ? ClientTypeEnum.Default : ClientTypeEnum.Custom, + ) + + const clientFormRef = useRef<FormRefObject>(null) + + // Mutations + const { mutate: initiateOAuth } = useInitiateTriggerOAuth() + const { mutate: verifyBuilder } = useVerifyAndUpdateTriggerSubscriptionBuilder() + const { mutate: configureOAuth } = useConfigureTriggerOAuth() + const { mutate: deleteOAuth } = useDeleteTriggerOAuth() + + // Compute OAuth client schema with default values + const oauthClientSchema = useMemo(() => { + const { oauth_client_schema, params } = oauthConfig || {} + if (!oauth_client_schema?.length || !params) + return [] + + const paramKeys = Object.keys(params) + return oauth_client_schema.map(schema => ({ + ...schema, + default: paramKeys.includes(schema.name) ? params[schema.name] : schema.default, + })) + }, [oauthConfig]) + + // Compute confirm button text based on authorization status + const confirmButtonText = useMemo(() => { + if (authorizationStatus === AuthorizationStatusEnum.Pending) + return t('modal.common.authorizing', { ns: 'pluginTrigger' }) + if (authorizationStatus === AuthorizationStatusEnum.Success) + return t('modal.oauth.authorization.waitingJump', { ns: 'pluginTrigger' }) + return t('auth.saveAndAuth', { ns: 'plugin' }) + }, [authorizationStatus, t]) + + // Authorization handler + const handleAuthorization = useCallback(() => { + setAuthorizationStatus(AuthorizationStatusEnum.Pending) + initiateOAuth(providerName, { + onSuccess: (response) => { + setSubscriptionBuilder(response.subscription_builder) + openOAuthPopup(response.authorization_url, (callbackData) => { + if (!callbackData) + return + Toast.notify({ + type: 'success', + message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }), + }) + onClose() + showOAuthCreateModal(response.subscription_builder) + }) + }, + onError: () => { + setAuthorizationStatus(AuthorizationStatusEnum.Failed) + Toast.notify({ + type: 'error', + message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }), + }) + }, + }) + }, [providerName, initiateOAuth, onClose, showOAuthCreateModal, t]) + + // Remove handler + const handleRemove = useCallback(() => { + deleteOAuth(providerName, { + onSuccess: () => { + onClose() + Toast.notify({ + type: 'success', + message: t('modal.oauth.remove.success', { ns: 'pluginTrigger' }), + }) + }, + onError: (error: unknown) => { + Toast.notify({ + type: 'error', + message: getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' })), + }) + }, + }) + }, [providerName, deleteOAuth, onClose, t]) + + // Save handler + const handleSave = useCallback((needAuth: boolean) => { + const isCustom = clientType === ClientTypeEnum.Custom + const params: ConfigureTriggerOAuthPayload = { + provider: providerName, + enabled: isCustom, + } + + if (isCustom && oauthClientSchema?.length) { + const clientFormValues = clientFormRef.current?.getFormValues({}) as { + values: TriggerOAuthClientParams + isCheckValidated: boolean + } | undefined + // Handle missing ref or form values + if (!clientFormValues || !clientFormValues.isCheckValidated) + return + const clientParams = { ...clientFormValues.values } + // Preserve hidden values if unchanged + if (clientParams.client_id === oauthConfig?.params.client_id) + clientParams.client_id = '[__HIDDEN__]' + if (clientParams.client_secret === oauthConfig?.params.client_secret) + clientParams.client_secret = '[__HIDDEN__]' + params.client_params = clientParams + } + + configureOAuth(params, { + onSuccess: () => { + if (needAuth) { + handleAuthorization() + return + } + onClose() + Toast.notify({ + type: 'success', + message: t('modal.oauth.save.success', { ns: 'pluginTrigger' }), + }) + }, + }) + }, [clientType, providerName, oauthClientSchema, oauthConfig?.params, configureOAuth, handleAuthorization, onClose, t]) + + // Polling effect for authorization verification + useEffect(() => { + const shouldPoll = providerName + && subscriptionBuilder + && authorizationStatus === AuthorizationStatusEnum.Pending + + if (!shouldPoll) + return + + const pollInterval = setInterval(() => { + verifyBuilder( + { + provider: providerName, + subscriptionBuilderId: subscriptionBuilder.id, + }, + { + onSuccess: (response) => { + if (response.verified) { + setAuthorizationStatus(AuthorizationStatusEnum.Success) + clearInterval(pollInterval) + } + }, + onError: () => { + // Continue polling on error - auth might still be in progress + }, + }, + ) + }, POLL_INTERVAL_MS) + + return () => clearInterval(pollInterval) + }, [subscriptionBuilder, authorizationStatus, verifyBuilder, providerName]) + + return { + clientType, + setClientType, + authorizationStatus, + clientFormRef, + oauthClientSchema, + confirmButtonText, + handleAuthorization, + handleRemove, + handleSave, + } +} diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.spec.tsx index 0ad6bc364e..8520d7e2e9 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.spec.tsx @@ -6,9 +6,6 @@ import { SupportedCreationMethods } from '@/app/components/plugins/types' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import { CreateButtonType, CreateSubscriptionButton, DEFAULT_METHOD } from './index' -// ==================== Mock Setup ==================== - -// Mock shared state for portal let mockPortalOpenState = false vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ @@ -36,21 +33,18 @@ vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ }, })) -// Mock Toast vi.mock('@/app/components/base/toast', () => ({ default: { notify: vi.fn(), }, })) -// Mock zustand store let mockStoreDetail: SimpleDetail | undefined vi.mock('../../store', () => ({ usePluginStore: (selector: (state: { detail: SimpleDetail | undefined }) => SimpleDetail | undefined) => selector({ detail: mockStoreDetail }), })) -// Mock subscription list hook const mockSubscriptions: TriggerSubscription[] = [] const mockRefetch = vi.fn() vi.mock('../use-subscription-list', () => ({ @@ -60,7 +54,6 @@ vi.mock('../use-subscription-list', () => ({ }), })) -// Mock trigger service hooks let mockProviderInfo: { data: TriggerProviderApiEntity | undefined } = { data: undefined } let mockOAuthConfig: { data: TriggerOAuthConfig | undefined, refetch: () => void } = { data: undefined, refetch: vi.fn() } const mockInitiateOAuth = vi.fn() @@ -73,14 +66,12 @@ vi.mock('@/service/use-triggers', () => ({ }), })) -// Mock OAuth popup vi.mock('@/hooks/use-oauth', () => ({ openOAuthPopup: vi.fn((url: string, callback: (data?: unknown) => void) => { callback({ success: true, subscriptionId: 'test-subscription' }) }), })) -// Mock child modals vi.mock('./common-modal', () => ({ CommonCreateModal: ({ createType, onClose, builder }: { createType: SupportedCreationMethods @@ -128,7 +119,6 @@ vi.mock('./oauth-client', () => ({ ), })) -// Mock CustomSelect vi.mock('@/app/components/base/select/custom', () => ({ default: ({ options, value, onChange, CustomTrigger, CustomOption, containerProps }: { options: Array<{ value: string, label: string, show: boolean, extra?: React.ReactNode, tag?: React.ReactNode }> @@ -160,11 +150,6 @@ vi.mock('@/app/components/base/select/custom', () => ({ ), })) -// ==================== Test Utilities ==================== - -/** - * Factory function to create a TriggerProviderApiEntity with defaults - */ const createProviderInfo = (overrides: Partial<TriggerProviderApiEntity> = {}): TriggerProviderApiEntity => ({ author: 'test-author', name: 'test-provider', @@ -179,9 +164,6 @@ const createProviderInfo = (overrides: Partial<TriggerProviderApiEntity> = {}): ...overrides, }) -/** - * Factory function to create a TriggerOAuthConfig with defaults - */ const createOAuthConfig = (overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig => ({ configured: false, custom_configured: false, @@ -196,9 +178,6 @@ const createOAuthConfig = (overrides: Partial<TriggerOAuthConfig> = {}): Trigger ...overrides, }) -/** - * Factory function to create a SimpleDetail with defaults - */ const createStoreDetail = (overrides: Partial<SimpleDetail> = {}): SimpleDetail => ({ plugin_id: 'test-plugin', name: 'Test Plugin', @@ -209,9 +188,6 @@ const createStoreDetail = (overrides: Partial<SimpleDetail> = {}): SimpleDetail ...overrides, }) -/** - * Factory function to create a TriggerSubscription with defaults - */ const createSubscription = (overrides: Partial<TriggerSubscription> = {}): TriggerSubscription => ({ id: 'test-subscription', name: 'Test Subscription', @@ -225,16 +201,10 @@ const createSubscription = (overrides: Partial<TriggerSubscription> = {}): Trigg ...overrides, }) -/** - * Factory function to create default props - */ const createDefaultProps = (overrides: Partial<Parameters<typeof CreateSubscriptionButton>[0]> = {}) => ({ ...overrides, }) -/** - * Helper to set up mock data for testing - */ const setupMocks = (config: { providerInfo?: TriggerProviderApiEntity oauthConfig?: TriggerOAuthConfig @@ -249,8 +219,6 @@ const setupMocks = (config: { mockSubscriptions.push(...config.subscriptions) } -// ==================== Tests ==================== - describe('CreateSubscriptionButton', () => { beforeEach(() => { vi.clearAllMocks() @@ -258,7 +226,6 @@ describe('CreateSubscriptionButton', () => { setupMocks() }) - // ==================== Rendering Tests ==================== describe('Rendering', () => { it('should render null when supportedMethods is empty', () => { // Arrange @@ -322,7 +289,6 @@ describe('CreateSubscriptionButton', () => { }) }) - // ==================== Props Testing ==================== describe('Props', () => { it('should apply default buttonType as FULL_BUTTON', () => { // Arrange @@ -355,7 +321,6 @@ describe('CreateSubscriptionButton', () => { }) }) - // ==================== State Management ==================== describe('State Management', () => { it('should show CommonCreateModal when selectedCreateInfo is set', async () => { // Arrange @@ -474,7 +439,6 @@ describe('CreateSubscriptionButton', () => { }) }) - // ==================== Memoization Logic ==================== describe('Memoization - buttonTextMap', () => { it('should display correct button text for OAUTH method', () => { // Arrange diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx index d119f42a13..eecaf165fb 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx @@ -2,7 +2,7 @@ import type { Option } from '@/app/components/base/select/custom' import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' import { RiAddLine, RiEqualizer2Line } from '@remixicon/react' import { useBoolean } from 'ahooks' -import { useMemo, useState } from 'react' +import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { ActionButton, ActionButtonState } from '@/app/components/base/action-button' import Badge from '@/app/components/base/badge' @@ -18,11 +18,7 @@ import { usePluginStore } from '../../store' import { useSubscriptionList } from '../use-subscription-list' import { CommonCreateModal } from './common-modal' import { OAuthClientSettingsModal } from './oauth-client' - -export enum CreateButtonType { - FULL_BUTTON = 'full-button', - ICON_BUTTON = 'icon-button', -} +import { CreateButtonType, DEFAULT_METHOD } from './types' type Props = { className?: string @@ -32,8 +28,6 @@ type Props = { const MAX_COUNT = 10 -export const DEFAULT_METHOD = 'default' - export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BUTTON, shape = 'square' }: Props) => { const { t } = useTranslation() const { subscriptions } = useSubscriptionList() @@ -43,7 +37,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU const detail = usePluginStore(state => state.detail) const { data: providerInfo } = useTriggerProviderInfo(detail?.provider || '') - const supportedMethods = providerInfo?.supported_creation_methods || [] + const supportedMethods = useMemo(() => providerInfo?.supported_creation_methods || [], [providerInfo?.supported_creation_methods]) const { data: oauthConfig, refetch: refetchOAuthConfig } = useTriggerOAuthConfig(detail?.provider || '', supportedMethods.includes(SupportedCreationMethods.OAUTH)) const { mutate: initiateOAuth } = useInitiateTriggerOAuth() @@ -63,11 +57,11 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU } }, [t]) - const onClickClientSettings = (e: React.MouseEvent<HTMLDivElement | HTMLButtonElement>) => { + const onClickClientSettings = useCallback((e: React.MouseEvent<HTMLDivElement | HTMLButtonElement>) => { e.stopPropagation() e.preventDefault() showClientSettingsModal() - } + }, [showClientSettingsModal]) const allOptions = useMemo(() => { const showCustomBadge = oauthConfig?.custom_enabled && oauthConfig?.custom_configured @@ -104,7 +98,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU show: supportedMethods.includes(SupportedCreationMethods.MANUAL), }, ] - }, [t, oauthConfig, supportedMethods, methodType]) + }, [t, oauthConfig, supportedMethods, methodType, onClickClientSettings]) const onChooseCreateType = async (type: SupportedCreationMethods) => { if (type === SupportedCreationMethods.OAUTH) { @@ -160,7 +154,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU <CustomSelect<Option & { show: boolean, extra?: React.ReactNode, tag?: React.ReactNode }> options={allOptions.filter(option => option.show)} value={methodType} - onChange={value => onChooseCreateType(value as any)} + onChange={value => onChooseCreateType(value as SupportedCreationMethods)} containerProps={{ open: (methodType === DEFAULT_METHOD || (methodType === SupportedCreationMethods.OAUTH && supportedMethods.length === 1)) ? undefined : false, placement: 'bottom-start', @@ -254,3 +248,5 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU </> ) } + +export { CreateButtonType, DEFAULT_METHOD } from './types' diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.spec.tsx index a842c63cfd..93cbbd518b 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.spec.tsx @@ -3,24 +3,14 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' - -// Import after mocks import { OAuthClientSettingsModal } from './oauth-client' -// ============================================================================ -// Type Definitions -// ============================================================================ - type PluginDetail = { plugin_id: string provider: string name: string } -// ============================================================================ -// Mock Factory Functions -// ============================================================================ - function createMockOAuthConfig(overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig { return { configured: true, @@ -64,18 +54,12 @@ function createMockSubscriptionBuilder(overrides: Partial<TriggerSubscriptionBui } } -// ============================================================================ -// Mock Setup -// ============================================================================ - -// Mock plugin store const mockPluginDetail = createMockPluginDetail() const mockUsePluginStore = vi.fn(() => mockPluginDetail) vi.mock('../../store', () => ({ usePluginStore: () => mockUsePluginStore(), })) -// Mock service hooks const mockInitiateOAuth = vi.fn() const mockVerifyBuilder = vi.fn() const mockConfigureOAuth = vi.fn() @@ -96,13 +80,11 @@ vi.mock('@/service/use-triggers', () => ({ }), })) -// Mock OAuth popup const mockOpenOAuthPopup = vi.fn() vi.mock('@/hooks/use-oauth', () => ({ openOAuthPopup: (url: string, callback: (data: unknown) => void) => mockOpenOAuthPopup(url, callback), })) -// Mock toast const mockToastNotify = vi.fn() vi.mock('@/app/components/base/toast', () => ({ default: { @@ -110,7 +92,6 @@ vi.mock('@/app/components/base/toast', () => ({ }, })) -// Mock clipboard API const mockClipboardWriteText = vi.fn() Object.assign(navigator, { clipboard: { @@ -118,7 +99,6 @@ Object.assign(navigator, { }, }) -// Mock Modal component vi.mock('@/app/components/base/modal/modal', () => ({ default: ({ children, @@ -161,24 +141,6 @@ vi.mock('@/app/components/base/modal/modal', () => ({ ), })) -// Mock Button component -vi.mock('@/app/components/base/button', () => ({ - default: ({ children, onClick, variant, className }: { - children: React.ReactNode - onClick?: () => void - variant?: string - className?: string - }) => ( - <button - data-testid={`button-${variant || 'default'}`} - onClick={onClick} - className={className} - > - {children} - </button> - ), -})) -// Configurable form mock values let mockFormValues: { values: Record<string, string>, isCheckValidated: boolean } = { values: { client_id: 'test-client-id', client_secret: 'test-client-secret' }, isCheckValidated: true, @@ -210,29 +172,6 @@ vi.mock('@/app/components/base/form/components/base', () => ({ }), })) -// Mock OptionCard component -vi.mock('@/app/components/workflow/nodes/_base/components/option-card', () => ({ - default: ({ title, onSelect, selected, className }: { - title: string - onSelect: () => void - selected: boolean - className?: string - }) => ( - <div - data-testid={`option-card-${title}`} - onClick={onSelect} - className={`${className} ${selected ? 'selected' : ''}`} - data-selected={selected} - > - {title} - </div> - ), -})) - -// ============================================================================ -// Test Suites -// ============================================================================ - describe('OAuthClientSettingsModal', () => { const defaultProps = { oauthConfig: createMockOAuthConfig(), @@ -244,7 +183,6 @@ describe('OAuthClientSettingsModal', () => { vi.clearAllMocks() mockUsePluginStore.mockReturnValue(mockPluginDetail) mockClipboardWriteText.mockResolvedValue(undefined) - // Reset form values to default setMockFormValues({ values: { client_id: 'test-client-id', client_secret: 'test-client-secret' }, isCheckValidated: true, @@ -265,8 +203,8 @@ describe('OAuthClientSettingsModal', () => { it('should render client type selector when system_configured is true', () => { render(<OAuthClientSettingsModal {...defaultProps} />) - expect(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')).toBeInTheDocument() - expect(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')).toBeInTheDocument() + expect(screen.getByText('pluginTrigger.subscription.addType.options.oauth.default')).toBeInTheDocument() + expect(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom')).toBeInTheDocument() }) it('should not render client type selector when system_configured is false', () => { @@ -276,7 +214,7 @@ describe('OAuthClientSettingsModal', () => { render(<OAuthClientSettingsModal {...defaultProps} oauthConfig={configWithoutSystemConfigured} />) - expect(screen.queryByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')).not.toBeInTheDocument() + expect(screen.queryByText('pluginTrigger.subscription.addType.options.oauth.default')).not.toBeInTheDocument() }) it('should render redirect URI info when custom client type is selected', () => { @@ -319,29 +257,29 @@ describe('OAuthClientSettingsModal', () => { it('should default to Default client type when system_configured is true', () => { render(<OAuthClientSettingsModal {...defaultProps} />) - const defaultCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default') - expect(defaultCard).toHaveAttribute('data-selected', 'true') + const defaultCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.default').closest('div') + expect(defaultCard).toHaveClass('border-[1.5px]') }) it('should switch to Custom client type when Custom card is clicked', () => { render(<OAuthClientSettingsModal {...defaultProps} />) - const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom') - fireEvent.click(customCard) + const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div') + fireEvent.click(customCard!) - expect(customCard).toHaveAttribute('data-selected', 'true') + expect(customCard).toHaveClass('border-[1.5px]') }) it('should switch back to Default client type when Default card is clicked', () => { render(<OAuthClientSettingsModal {...defaultProps} />) - const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom') - fireEvent.click(customCard) + const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div') + fireEvent.click(customCard!) - const defaultCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default') - fireEvent.click(defaultCard) + const defaultCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.default').closest('div') + fireEvent.click(defaultCard!) - expect(defaultCard).toHaveAttribute('data-selected', 'true') + expect(defaultCard).toHaveClass('border-[1.5px]') }) }) @@ -852,8 +790,8 @@ describe('OAuthClientSettingsModal', () => { render(<OAuthClientSettingsModal {...defaultProps} />) // Switch to custom - const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom') - fireEvent.click(customCard) + const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div') + fireEvent.click(customCard!) fireEvent.click(screen.getByTestId('modal-cancel')) @@ -1054,7 +992,7 @@ describe('OAuthClientSettingsModal', () => { render(<OAuthClientSettingsModal {...defaultProps} />) // Switch to custom type - const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom') + const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')! fireEvent.click(customCard) fireEvent.click(screen.getByTestId('modal-cancel')) @@ -1077,7 +1015,7 @@ describe('OAuthClientSettingsModal', () => { render(<OAuthClientSettingsModal {...defaultProps} />) // Switch to custom type - fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')) + fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!) fireEvent.click(screen.getByTestId('modal-cancel')) @@ -1104,7 +1042,7 @@ describe('OAuthClientSettingsModal', () => { render(<OAuthClientSettingsModal {...defaultProps} />) // Switch to custom type - fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')) + fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!) fireEvent.click(screen.getByTestId('modal-cancel')) @@ -1131,7 +1069,7 @@ describe('OAuthClientSettingsModal', () => { render(<OAuthClientSettingsModal {...defaultProps} />) // Switch to custom type - fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')) + fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!) fireEvent.click(screen.getByTestId('modal-cancel')) @@ -1158,7 +1096,7 @@ describe('OAuthClientSettingsModal', () => { render(<OAuthClientSettingsModal {...defaultProps} />) // Switch to custom type - fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')) + fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!) fireEvent.click(screen.getByTestId('modal-cancel')) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx index 25caf3b789..b7f9b8ebec 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx @@ -1,27 +1,17 @@ 'use client' -import type { FormRefObject } from '@/app/components/base/form/types' -import type { TriggerOAuthClientParams, TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' -import type { ConfigureTriggerOAuthPayload } from '@/service/use-triggers' +import type { TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' import { RiClipboardLine, RiInformation2Fill, } from '@remixicon/react' -import * as React from 'react' -import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { BaseForm } from '@/app/components/base/form/components/base' import Modal from '@/app/components/base/modal/modal' import Toast from '@/app/components/base/toast' import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card' -import { openOAuthPopup } from '@/hooks/use-oauth' -import { - useConfigureTriggerOAuth, - useDeleteTriggerOAuth, - useInitiateTriggerOAuth, - useVerifyAndUpdateTriggerSubscriptionBuilder, -} from '@/service/use-triggers' import { usePluginStore } from '../../store' +import { ClientTypeEnum, useOAuthClientState } from './hooks/use-oauth-client-state' type Props = { oauthConfig?: TriggerOAuthConfig @@ -29,169 +19,38 @@ type Props = { showOAuthCreateModal: (builder: TriggerSubscriptionBuilder) => void } -enum AuthorizationStatusEnum { - Pending = 'pending', - Success = 'success', - Failed = 'failed', -} - -enum ClientTypeEnum { - Default = 'default', - Custom = 'custom', -} +const CLIENT_TYPE_OPTIONS = [ClientTypeEnum.Default, ClientTypeEnum.Custom] as const export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreateModal }: Props) => { const { t } = useTranslation() const detail = usePluginStore(state => state.detail) - const { system_configured, params, oauth_client_schema } = oauthConfig || {} - const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>() - const [authorizationStatus, setAuthorizationStatus] = useState<AuthorizationStatusEnum>() - - const [clientType, setClientType] = useState<ClientTypeEnum>(system_configured ? ClientTypeEnum.Default : ClientTypeEnum.Custom) - - const clientFormRef = React.useRef<FormRefObject>(null) - - const oauthClientSchema = useMemo(() => { - if (oauth_client_schema && oauth_client_schema.length > 0 && params) { - const oauthConfigPramaKeys = Object.keys(params || {}) - for (const schema of oauth_client_schema) { - if (oauthConfigPramaKeys.includes(schema.name)) - schema.default = params?.[schema.name] - } - return oauth_client_schema - } - return [] - }, [oauth_client_schema, params]) - const providerName = detail?.provider || '' - const { mutate: initiateOAuth } = useInitiateTriggerOAuth() - const { mutate: verifyBuilder } = useVerifyAndUpdateTriggerSubscriptionBuilder() - const { mutate: configureOAuth } = useConfigureTriggerOAuth() - const { mutate: deleteOAuth } = useDeleteTriggerOAuth() - const confirmButtonText = useMemo(() => { - if (authorizationStatus === AuthorizationStatusEnum.Pending) - return t('modal.common.authorizing', { ns: 'pluginTrigger' }) - if (authorizationStatus === AuthorizationStatusEnum.Success) - return t('modal.oauth.authorization.waitingJump', { ns: 'pluginTrigger' }) - return t('auth.saveAndAuth', { ns: 'plugin' }) - }, [authorizationStatus, t]) + const { + clientType, + setClientType, + clientFormRef, + oauthClientSchema, + confirmButtonText, + handleRemove, + handleSave, + } = useOAuthClientState({ + oauthConfig, + providerName, + onClose, + showOAuthCreateModal, + }) - const getErrorMessage = (error: unknown, fallback: string) => { - if (error instanceof Error && error.message) - return error.message - if (typeof error === 'object' && error && 'message' in error) { - const message = (error as { message?: string }).message - if (typeof message === 'string' && message) - return message - } - return fallback - } + const isCustomClient = clientType === ClientTypeEnum.Custom + const showRemoveButton = oauthConfig?.custom_enabled && oauthConfig?.params && isCustomClient + const showRedirectInfo = isCustomClient && oauthConfig?.redirect_uri + const showClientForm = isCustomClient && oauthClientSchema.length > 0 - const handleAuthorization = () => { - setAuthorizationStatus(AuthorizationStatusEnum.Pending) - initiateOAuth(providerName, { - onSuccess: (response) => { - setSubscriptionBuilder(response.subscription_builder) - openOAuthPopup(response.authorization_url, (callbackData) => { - if (callbackData) { - Toast.notify({ - type: 'success', - message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }), - }) - onClose() - showOAuthCreateModal(response.subscription_builder) - } - }) - }, - onError: () => { - setAuthorizationStatus(AuthorizationStatusEnum.Failed) - Toast.notify({ - type: 'error', - message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }), - }) - }, - }) - } - - useEffect(() => { - if (providerName && subscriptionBuilder && authorizationStatus === AuthorizationStatusEnum.Pending) { - const pollInterval = setInterval(() => { - verifyBuilder( - { - provider: providerName, - subscriptionBuilderId: subscriptionBuilder.id, - }, - { - onSuccess: (response) => { - if (response.verified) { - setAuthorizationStatus(AuthorizationStatusEnum.Success) - clearInterval(pollInterval) - } - }, - onError: () => { - // Continue polling - auth might still be in progress - }, - }, - ) - }, 3000) - - return () => clearInterval(pollInterval) - } - }, [subscriptionBuilder, authorizationStatus, verifyBuilder, providerName, t]) - - const handleRemove = () => { - deleteOAuth(providerName, { - onSuccess: () => { - onClose() - Toast.notify({ - type: 'success', - message: t('modal.oauth.remove.success', { ns: 'pluginTrigger' }), - }) - }, - onError: (error: unknown) => { - Toast.notify({ - type: 'error', - message: getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' })), - }) - }, - }) - } - - const handleSave = (needAuth: boolean) => { - const isCustom = clientType === ClientTypeEnum.Custom - const params: ConfigureTriggerOAuthPayload = { - provider: providerName, - enabled: isCustom, - } - - if (isCustom) { - const clientFormValues = clientFormRef.current?.getFormValues({}) as { values: TriggerOAuthClientParams, isCheckValidated: boolean } - if (!clientFormValues.isCheckValidated) - return - const clientParams = clientFormValues.values - if (clientParams.client_id === oauthConfig?.params.client_id) - clientParams.client_id = '[__HIDDEN__]' - - if (clientParams.client_secret === oauthConfig?.params.client_secret) - clientParams.client_secret = '[__HIDDEN__]' - - params.client_params = clientParams - } - - configureOAuth(params, { - onSuccess: () => { - if (needAuth) { - handleAuthorization() - } - else { - onClose() - Toast.notify({ - type: 'success', - message: t('modal.oauth.save.success', { ns: 'pluginTrigger' }), - }) - } - }, + const handleCopyRedirectUri = () => { + navigator.clipboard.writeText(oauthConfig?.redirect_uri || '') + Toast.notify({ + type: 'success', + message: t('actionMsg.copySuccessfully', { ns: 'common' }), }) } @@ -208,25 +67,25 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate onClose={onClose} onCancel={() => handleSave(false)} onConfirm={() => handleSave(true)} - footerSlot={ - oauthConfig?.custom_enabled && oauthConfig?.params && clientType === ClientTypeEnum.Custom && ( - <div className="grow"> - <Button - variant="secondary" - className="text-components-button-destructive-secondary-text" - // disabled={disabled || doingAction || !editValues} - onClick={handleRemove} - > - {t('operation.remove', { ns: 'common' })} - </Button> - </div> - ) - } + footerSlot={showRemoveButton && ( + <div className="grow"> + <Button + variant="secondary" + className="text-components-button-destructive-secondary-text" + onClick={handleRemove} + > + {t('operation.remove', { ns: 'common' })} + </Button> + </div> + )} > - <div className="system-sm-medium mb-2 text-text-secondary">{t('subscription.addType.options.oauth.clientTitle', { ns: 'pluginTrigger' })}</div> + <div className="system-sm-medium mb-2 text-text-secondary"> + {t('subscription.addType.options.oauth.clientTitle', { ns: 'pluginTrigger' })} + </div> + {oauthConfig?.system_configured && ( <div className="mb-4 flex w-full items-start justify-between gap-2"> - {[ClientTypeEnum.Default, ClientTypeEnum.Custom].map(option => ( + {CLIENT_TYPE_OPTIONS.map(option => ( <OptionCard key={option} title={t(`subscription.addType.options.oauth.${option}`, { ns: 'pluginTrigger' })} @@ -237,7 +96,8 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate ))} </div> )} - {clientType === ClientTypeEnum.Custom && oauthConfig?.redirect_uri && ( + + {showRedirectInfo && ( <div className="mb-4 flex items-start gap-3 rounded-xl bg-background-section-burn p-4"> <div className="rounded-lg border-[0.5px] border-components-card-border bg-components-card-bg p-2 shadow-xs shadow-shadow-shadow-3"> <RiInformation2Fill className="h-5 w-5 shrink-0 text-text-accent" /> @@ -247,18 +107,12 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate {t('modal.oauthRedirectInfo', { ns: 'pluginTrigger' })} </div> <div className="system-sm-medium my-1.5 break-all leading-4"> - {oauthConfig.redirect_uri} + {oauthConfig?.redirect_uri} </div> <Button variant="secondary" size="small" - onClick={() => { - navigator.clipboard.writeText(oauthConfig.redirect_uri) - Toast.notify({ - type: 'success', - message: t('actionMsg.copySuccessfully', { ns: 'common' }), - }) - }} + onClick={handleCopyRedirectUri} > <RiClipboardLine className="mr-1 h-[14px] w-[14px]" /> {t('operation.copy', { ns: 'common' })} @@ -266,7 +120,8 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate </div> </div> )} - {clientType === ClientTypeEnum.Custom && oauthClientSchema.length > 0 && ( + + {showClientForm && ( <BaseForm formSchemas={oauthClientSchema} ref={clientFormRef} diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/types.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/types.ts new file mode 100644 index 0000000000..637846b606 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/types.ts @@ -0,0 +1,6 @@ +export enum CreateButtonType { + FULL_BUTTON = 'full-button', + ICON_BUTTON = 'icon-button', +} + +export const DEFAULT_METHOD = 'default' diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index e5ced085ff..9cfe1fd462 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -2445,11 +2445,6 @@ "count": 8 } }, - "app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx": { - "ts/no-explicit-any": { - "count": 8 - } - }, "app/components/plugins/plugin-detail-panel/datasource-action-list.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2503,14 +2498,6 @@ "count": 2 } }, - "app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": { - "react-refresh/only-export-components": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, "app/components/plugins/plugin-detail-panel/subscription-list/delete-confirm.tsx": { "ts/no-explicit-any": { "count": 1 From 468990cc3953743f3a72d4af76b1f7760805854d Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:58:26 +0800 Subject: [PATCH 03/18] fix: remove api reference doc link en prefix (#31910) --- web/context/i18n.spec.ts | 8 ++++---- web/context/i18n.ts | 13 +++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/web/context/i18n.spec.ts b/web/context/i18n.spec.ts index 98f3552c99..616f3bfced 100644 --- a/web/context/i18n.spec.ts +++ b/web/context/i18n.spec.ts @@ -196,19 +196,19 @@ describe('useDocLink', () => { const { result } = renderHook(() => useDocLink()) const url = result.current('/api-reference/annotations/create-annotation') - expect(url).toBe(`${defaultDocBaseUrl}/en/api-reference/annotations/create-annotation`) + expect(url).toBe(`${defaultDocBaseUrl}/api-reference/annotations/create-annotation`) }) it('should keep original path when no translation exists for non-English locale', () => { vi.mocked(useTranslation).mockReturnValue({ - i18n: { language: 'ja-JP' }, + i18n: { language: 'zh-Hans' }, } as ReturnType<typeof useTranslation>) - vi.mocked(getDocLanguage).mockReturnValue('ja') + vi.mocked(getDocLanguage).mockReturnValue('zh') const { result } = renderHook(() => useDocLink()) // This path has no Japanese translation const url = result.current('/api-reference/annotations/create-annotation') - expect(url).toBe(`${defaultDocBaseUrl}/ja/api-reference/annotations/create-annotation`) + expect(url).toBe(`${defaultDocBaseUrl}/api-reference/ๆ ‡ๆณจ็ฎก็†/ๅˆ›ๅปบๆ ‡ๆณจ`) }) it('should remove language prefix when translation is applied', () => { diff --git a/web/context/i18n.ts b/web/context/i18n.ts index 5f39d1afb3..f371c1129b 100644 --- a/web/context/i18n.ts +++ b/web/context/i18n.ts @@ -35,12 +35,13 @@ export const useDocLink = (baseUrl?: string): ((path?: DocPathWithoutLang, pathM let targetPath = (pathMap) ? pathMap[locale] || pathUrl : pathUrl let languagePrefix = `/${docLanguage}` - // Translate API reference paths for non-English locales - if (targetPath.startsWith('/api-reference/') && docLanguage !== 'en') { - const translatedPath = apiReferencePathTranslations[targetPath]?.[docLanguage as 'zh' | 'ja'] - if (translatedPath) { - targetPath = translatedPath - languagePrefix = '' + if (targetPath.startsWith('/api-reference/')) { + languagePrefix = '' + if (docLanguage !== 'en') { + const translatedPath = apiReferencePathTranslations[targetPath]?.[docLanguage] + if (translatedPath) { + targetPath = translatedPath + } } } From 0d74ac634b541e4b0adcefdd03a4d0f2c7b23f6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= <hjlarry@163.com> Date: Wed, 4 Feb 2026 16:08:00 +0800 Subject: [PATCH 04/18] fix: missing import console_ns (#31916) --- api/controllers/console/explore/trial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index ba214e71c0..c417967c88 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -10,7 +10,7 @@ import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse from controllers.common.schema import get_or_create_model -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, From ec7ccd800c6091bc7430d6dc9e23757e2a2df641 Mon Sep 17 00:00:00 2001 From: wangxiaolei <fatelei@gmail.com> Date: Wed, 4 Feb 2026 16:55:12 +0800 Subject: [PATCH 05/18] fix: fix mcp server status is not right (#31826) Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> --- .../components/tools/mcp/detail/content.tsx | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/web/app/components/tools/mcp/detail/content.tsx b/web/app/components/tools/mcp/detail/content.tsx index 2476973fa5..23d408706d 100644 --- a/web/app/components/tools/mcp/detail/content.tsx +++ b/web/app/components/tools/mcp/detail/content.tsx @@ -103,15 +103,22 @@ const MCPDetailContent: FC<Props> = ({ return if (!detail) return - const res = await authorizeMcp({ - provider_id: detail.id, - }) - if (res.result === 'success') - handleUpdateTools() + try { + const res = await authorizeMcp({ + provider_id: detail.id, + }) + if (res.result === 'success') + handleUpdateTools() - else if (res.authorization_url) - openOAuthPopup(res.authorization_url, handleOAuthCallback) - }, [onFirstCreate, isCurrentWorkspaceManager, detail, authorizeMcp, handleUpdateTools, handleOAuthCallback]) + else if (res.authorization_url) + openOAuthPopup(res.authorization_url, handleOAuthCallback) + } + catch { + // On authorization error, refresh the parent component state + // to update the connection status indicator + onUpdate() + } + }, [onFirstCreate, isCurrentWorkspaceManager, detail, authorizeMcp, handleUpdateTools, handleOAuthCallback, onUpdate]) const handleUpdate = useCallback(async (data: any) => { if (!detail) From 5f69470ebf966a781b43520c66980d092cf5c66b Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Wed, 4 Feb 2026 17:05:15 +0800 Subject: [PATCH 06/18] test: try fix test, clear test log in CI (#31912) --- .../create/common-modal.spec.tsx | 3 ++ .../components/update-dsl-modal.spec.tsx | 54 ++++++++++++++----- web/package.json | 5 +- web/pnpm-lock.yaml | 26 +++++++++ web/vitest.config.ts | 2 +- web/vitest.setup.ts | 1 + 6 files changed, 75 insertions(+), 16 deletions(-) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx index b7e4f01f58..0c1b5efc29 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx @@ -2031,6 +2031,9 @@ describe('CommonCreateModal', () => { expect(mockCreateBuilder).toHaveBeenCalled() }) + // Flush pending state updates from createBuilder promise resolution + await act(async () => {}) + const input = screen.getByTestId('form-field-webhook_url') fireEvent.change(input, { target: { value: 'test' } }) diff --git a/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx index f57bd80d7b..45eb1cafe1 100644 --- a/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx +++ b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx @@ -613,6 +613,11 @@ describe('UpdateDSLModal', () => { expect(importButton).not.toBeDisabled() }) + // Flush the FileReader microtask to ensure fileContent is set + await act(async () => { + await new Promise<void>(resolve => queueMicrotask(resolve)) + }) + const importButton = screen.getByText('common.overwriteAndImport') fireEvent.click(importButton) @@ -761,6 +766,8 @@ describe('UpdateDSLModal', () => { }) it('should call importDSLConfirm when confirm button is clicked in error modal', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue({ id: 'import-id', status: DSLImportStatus.PENDING, @@ -778,20 +785,27 @@ describe('UpdateDSLModal', () => { const fileInput = screen.getByTestId('file-input') const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) - fireEvent.change(fileInput, { target: { files: [file] } }) - await waitFor(() => { - const importButton = screen.getByText('common.overwriteAndImport') - expect(importButton).not.toBeDisabled() + await act(async () => { + fireEvent.change(fileInput, { target: { files: [file] } }) + // Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask) + await new Promise<void>(resolve => queueMicrotask(resolve)) }) const importButton = screen.getByText('common.overwriteAndImport') - fireEvent.click(importButton) + expect(importButton).not.toBeDisabled() + + await act(async () => { + fireEvent.click(importButton) + // Flush the promise resolution from mockImportDSL + await Promise.resolve() + // Advance past the 300ms setTimeout in the component + await vi.advanceTimersByTimeAsync(350) + }) - // Wait for error modal await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }) // Click confirm button const confirmButton = screen.getByText('newApp.Confirm') @@ -800,6 +814,8 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(mockImportDSLConfirm).toHaveBeenCalledWith('import-id') }) + + vi.useRealTimers() }) it('should show success notification after confirm completes', async () => { @@ -1008,6 +1024,8 @@ describe('UpdateDSLModal', () => { }) it('should call handleCheckPluginDependencies after confirm', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue({ id: 'import-id', status: DSLImportStatus.PENDING, @@ -1025,19 +1043,27 @@ describe('UpdateDSLModal', () => { const fileInput = screen.getByTestId('file-input') const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) - fireEvent.change(fileInput, { target: { files: [file] } }) - await waitFor(() => { - const importButton = screen.getByText('common.overwriteAndImport') - expect(importButton).not.toBeDisabled() + await act(async () => { + fireEvent.change(fileInput, { target: { files: [file] } }) + // Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask) + await new Promise<void>(resolve => queueMicrotask(resolve)) }) const importButton = screen.getByText('common.overwriteAndImport') - fireEvent.click(importButton) + expect(importButton).not.toBeDisabled() + + await act(async () => { + fireEvent.click(importButton) + // Flush the promise resolution from mockImportDSL + await Promise.resolve() + // Advance past the 300ms setTimeout in the component + await vi.advanceTimersByTimeAsync(350) + }) await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -1045,6 +1071,8 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('test-pipeline-id', true) }) + + vi.useRealTimers() }) it('should handle undefined imported_dsl_version and current_dsl_version', async () => { diff --git a/web/package.json b/web/package.json index 7b2e570554..494a9f0848 100644 --- a/web/package.json +++ b/web/package.json @@ -46,7 +46,7 @@ "uglify-embed": "node ./bin/uglify-embed", "i18n:check": "tsx ./scripts/check-i18n.js", "test": "vitest run", - "test:coverage": "vitest run --coverage", + "test:coverage": "vitest run --coverage --reporter=dot --silent=passed-only", "test:watch": "vitest --watch", "analyze-component": "node ./scripts/analyze-component.js", "refactor-component": "node ./scripts/refactor-component.js", @@ -233,7 +233,8 @@ "uglify-js": "3.19.3", "vite": "7.3.1", "vite-tsconfig-paths": "6.0.4", - "vitest": "4.0.17" + "vitest": "4.0.17", + "vitest-canvas-mock": "1.1.3" }, "pnpm": { "overrides": { diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index abf3a444a2..9119d2554a 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -579,6 +579,9 @@ importers: vitest: specifier: 4.0.17 version: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vitest-canvas-mock: + specifier: 1.1.3 + version: 1.1.3(vitest@4.0.17) packages: @@ -4002,6 +4005,9 @@ packages: engines: {node: '>=4'} hasBin: true + cssfontparser@1.2.1: + resolution: {integrity: sha512-6tun4LoZnj7VN6YeegOVb67KBX/7JJsqvj+pv3ZA7F878/eN33AbGa5b/S/wXxS/tcp8nc40xRUrsPlxIyNUPg==} + cssstyle@5.3.7: resolution: {integrity: sha512-7D2EPVltRrsTkhpQmksIu+LxeWAIEk6wRDMJ1qljlv+CKHJM+cJLlfhWIzNA44eAsHXSNe3+vO6DW1yCYx8SuQ==} engines: {node: '>=20'} @@ -5751,6 +5757,9 @@ packages: monaco-editor@0.55.1: resolution: {integrity: sha512-jz4x+TJNFHwHtwuV9vA9rMujcZRb0CEilTEwG2rRSpe/A7Jdkuj8xPKttCgOh+v/lkHy7HsZ64oj+q3xoAFl9A==} + moo-color@1.0.3: + resolution: {integrity: sha512-i/+ZKXMDf6aqYtBhuOcej71YSlbjT3wCO/4H1j8rPvxDJEifdwgg5MaFyu6iYAT8GBZJg2z0dkgK4YMzvURALQ==} + mri@1.2.0: resolution: {integrity: sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==} engines: {node: '>=4'} @@ -7216,6 +7225,11 @@ packages: yaml: optional: true + vitest-canvas-mock@1.1.3: + resolution: {integrity: sha512-zlKJR776Qgd+bcACPh0Pq5MG3xWq+CdkACKY/wX4Jyija0BSz8LH3aCCgwFKYFwtm565+050YFEGG9Ki0gE/Hw==} + peerDependencies: + vitest: ^3.0.0 || ^4.0.0 + vitest@4.0.17: resolution: {integrity: sha512-FQMeF0DJdWY0iOnbv466n/0BudNdKj1l5jYgl5JVTwjSsZSlqyXFt/9+1sEyhR6CLowbZpV7O1sCHrzBhucKKg==} engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} @@ -11274,6 +11288,8 @@ snapshots: cssesc@3.0.0: {} + cssfontparser@1.2.1: {} + cssstyle@5.3.7: dependencies: '@asamuzakjp/css-color': 4.1.1 @@ -13573,6 +13589,10 @@ snapshots: dompurify: 3.2.7 marked: 14.0.0 + moo-color@1.0.3: + dependencies: + color-name: 1.1.4 + mri@1.2.0: {} mrmime@2.0.1: {} @@ -15202,6 +15222,12 @@ snapshots: tsx: 4.21.0 yaml: 2.8.2 + vitest-canvas-mock@1.1.3(vitest@4.0.17): + dependencies: + cssfontparser: 1.2.1 + moo-color: 1.0.3 + vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vitest@4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2): dependencies: '@vitest/expect': 4.0.17 diff --git a/web/vitest.config.ts b/web/vitest.config.ts index c58a92f217..370bc74904 100644 --- a/web/vitest.config.ts +++ b/web/vitest.config.ts @@ -8,7 +8,7 @@ export default mergeConfig(viteConfig, defineConfig({ setupFiles: ['./vitest.setup.ts'], coverage: { provider: 'v8', - reporter: ['text', 'json', 'json-summary'], + reporter: ['json', 'json-summary'], }, }, })) diff --git a/web/vitest.setup.ts b/web/vitest.setup.ts index 0e1d9e6d10..9e54b80492 100644 --- a/web/vitest.setup.ts +++ b/web/vitest.setup.ts @@ -1,6 +1,7 @@ import { act, cleanup } from '@testing-library/react' import { mockAnimationsApi, mockResizeObserver } from 'jsdom-testing-mocks' import '@testing-library/jest-dom/vitest' +import 'vitest-canvas-mock' mockResizeObserver() From 74b027c41af3c26e1bbe6883f549b774b1706e05 Mon Sep 17 00:00:00 2001 From: wangxiaolei <fatelei@gmail.com> Date: Wed, 4 Feb 2026 17:33:41 +0800 Subject: [PATCH 07/18] fix: fix mcp output schema is union type frontend crash (#31779) Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> --- .../workflow/nodes/tool/use-config.ts | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/web/app/components/workflow/nodes/tool/use-config.ts b/web/app/components/workflow/nodes/tool/use-config.ts index 7e4594f4f2..87e9186008 100644 --- a/web/app/components/workflow/nodes/tool/use-config.ts +++ b/web/app/components/workflow/nodes/tool/use-config.ts @@ -1,6 +1,7 @@ import type { ToolNodeType, ToolVarInputs } from './types' import type { InputVar } from '@/app/components/workflow/types' import { useBoolean } from 'ahooks' +import { capitalize } from 'es-toolkit/string' import { produce } from 'immer' import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -25,6 +26,12 @@ import { } from '@/service/use-tools' import { canFindTool } from '@/utils' import { useWorkflowStore } from '../../store' +import { normalizeJsonSchemaType } from './output-schema-utils' + +const formatDisplayType = (output: Record<string, unknown>): string => { + const normalizedType = normalizeJsonSchemaType(output) || 'Unknown' + return capitalize(normalizedType) +} const useConfig = (id: string, payload: ToolNodeType) => { const workflowStore = useWorkflowStore() @@ -247,20 +254,13 @@ const useConfig = (id: string, payload: ToolNodeType) => { }) } else { + const normalizedType = normalizeJsonSchemaType(output) res.push({ name: outputKey, type: - output.type === 'array' - ? `Array[${output.items?.type - ? output.items.type.slice(0, 1).toLocaleUpperCase() - + output.items.type.slice(1) - : 'Unknown' - }]` - : `${output.type - ? output.type.slice(0, 1).toLocaleUpperCase() - + output.type.slice(1) - : 'Unknown' - }`, + normalizedType === 'array' + ? `Array[${output.items ? formatDisplayType(output.items) : 'Unknown'}]` + : formatDisplayType(output), description: output.description, }) } From cc5705cb7168fc6b8c7bc5f360fed16d14217357 Mon Sep 17 00:00:00 2001 From: zxhlyh <jasonapring2015@outlook.com> Date: Wed, 4 Feb 2026 17:47:38 +0800 Subject: [PATCH 08/18] fix: auto summary env (#31930) --- .../components/general-chunking-options.tsx | 2 +- .../step-two/components/parent-child-options.tsx | 3 ++- .../datasets/documents/components/operations.tsx | 13 +++++++++---- .../detail/completed/common/batch-action.tsx | 3 ++- web/app/components/datasets/settings/form/index.tsx | 3 ++- .../workflow/nodes/knowledge-base/panel.tsx | 3 ++- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx b/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx index 84d742d734..0beda8f5c8 100644 --- a/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx +++ b/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx @@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({ </div> ))} { - showSummaryIndexSetting && ( + showSummaryIndexSetting && IS_CE_EDITION && ( <div className="mt-3"> <SummaryIndexSetting entry="create-document" diff --git a/web/app/components/datasets/create/step-two/components/parent-child-options.tsx b/web/app/components/datasets/create/step-two/components/parent-child-options.tsx index 22b88037e1..b7b965a4fd 100644 --- a/web/app/components/datasets/create/step-two/components/parent-child-options.tsx +++ b/web/app/components/datasets/create/step-two/components/parent-child-options.tsx @@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider' import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge' import RadioCard from '@/app/components/base/radio-card' import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting' +import { IS_CE_EDITION } from '@/config' import { ChunkingMode } from '@/models/datasets' import FileList from '../../assets/file-list-3-fill.svg' import Note from '../../assets/note-mod.svg' @@ -191,7 +192,7 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({ </div> ))} { - showSummaryIndexSetting && ( + showSummaryIndexSetting && IS_CE_EDITION && ( <div className="mt-3"> <SummaryIndexSetting entry="create-document" diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx index d3dcc23121..cdd694fad9 100644 --- a/web/app/components/datasets/documents/components/operations.tsx +++ b/web/app/components/datasets/documents/components/operations.tsx @@ -26,6 +26,7 @@ import CustomPopover from '@/app/components/base/popover' import Switch from '@/app/components/base/switch' import { ToastContext } from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { IS_CE_EDITION } from '@/config' import { DataSourceType, DocumentActionType } from '@/models/datasets' import { useDocumentArchive, @@ -263,10 +264,14 @@ const Operations = ({ <span className={s.actionName}>{t('list.action.sync', { ns: 'datasetDocuments' })}</span> </div> )} - <div className={s.actionItem} onClick={() => onOperate('summary')}> - <SearchLinesSparkle className="h-4 w-4 text-text-tertiary" /> - <span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span> - </div> + { + IS_CE_EDITION && ( + <div className={s.actionItem} onClick={() => onOperate('summary')}> + <SearchLinesSparkle className="h-4 w-4 text-text-tertiary" /> + <span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span> + </div> + ) + } <Divider className="my-1" /> </> )} diff --git a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx index 486ba2ffdf..ca5a56ec2a 100644 --- a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx @@ -7,6 +7,7 @@ import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' import Divider from '@/app/components/base/divider' import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge' +import { IS_CE_EDITION } from '@/config' import { cn } from '@/utils/classnames' const i18nPrefix = 'batchAction' @@ -87,7 +88,7 @@ const BatchAction: FC<IBatchActionProps> = ({ <span className="px-0.5">{t('metadata.metadata', { ns: 'dataset' })}</span> </Button> )} - {onBatchSummary && ( + {onBatchSummary && IS_CE_EDITION && ( <Button variant="ghost" className="gap-x-0.5 px-3" diff --git a/web/app/components/datasets/settings/form/index.tsx b/web/app/components/datasets/settings/form/index.tsx index 1993c9fd8d..ca072cfcae 100644 --- a/web/app/components/datasets/settings/form/index.tsx +++ b/web/app/components/datasets/settings/form/index.tsx @@ -21,6 +21,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import { IS_CE_EDITION } from '@/config' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useDocLink } from '@/context/i18n' @@ -359,7 +360,7 @@ const Form = () => { { indexMethod === IndexingType.QUALIFIED && [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode) - && ( + && IS_CE_EDITION && ( <> <Divider type="horizontal" diff --git a/web/app/components/workflow/nodes/knowledge-base/panel.tsx b/web/app/components/workflow/nodes/knowledge-base/panel.tsx index 0a275645a8..2845d605bf 100644 --- a/web/app/components/workflow/nodes/knowledge-base/panel.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/panel.tsx @@ -18,6 +18,7 @@ import { Group, } from '@/app/components/workflow/nodes/_base/components/layout' import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker' +import { IS_CE_EDITION } from '@/config' import Split from '../_base/components/split' import ChunkStructure from './components/chunk-structure' import EmbeddingModel from './components/embedding-model' @@ -172,7 +173,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({ { data.indexing_technique === IndexMethodEnum.QUALIFIED && [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure) - && ( + && IS_CE_EDITION && ( <> <SummaryIndexSetting summaryIndexSetting={data.summary_index_setting} From 297dd832aac253cf694bb6697d91f3fafb88260d Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Wed, 4 Feb 2026 18:12:17 +0800 Subject: [PATCH 09/18] refactor(datasets): extract hooks and components with comprehensive tests (#31707) Co-authored-by: CodingOnStar <hanxujiang@dify.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> --- .../common/image-uploader/utils.spec.ts | 39 +- .../datasets/common/image-uploader/utils.ts | 8 +- .../hooks/use-dsl-import.spec.tsx | 1045 +++++++++++++++++ .../hooks/use-dsl-import.ts | 218 ++++ .../create-from-dsl-modal/index.tsx | 257 +--- .../components/file-list-item.spec.tsx | 334 ++++++ .../components/file-list-item.tsx | 89 ++ .../components/upload-dropzone.spec.tsx | 210 ++++ .../components/upload-dropzone.tsx | 84 ++ .../create/file-uploader/constants.ts | 3 + .../hooks/use-file-upload.spec.tsx | 921 +++++++++++++++ .../file-uploader/hooks/use-file-upload.ts | 351 ++++++ .../create/file-uploader/index.spec.tsx | 278 +++++ .../datasets/create/file-uploader/index.tsx | 407 +------ .../components/file-list-item.spec.tsx | 351 ++++++ .../local-file/components/file-list-item.tsx | 85 ++ .../components/upload-dropzone.spec.tsx | 231 ++++ .../local-file/components/upload-dropzone.tsx | 83 ++ .../data-source/local-file/constants.ts | 3 + .../hooks/use-local-file-upload.spec.tsx | 911 ++++++++++++++ .../local-file/hooks/use-local-file-upload.ts | 105 ++ .../data-source/local-file/index.spec.tsx | 398 +++++++ .../data-source/local-file/index.tsx | 391 +----- .../components/basic-info-section.spec.tsx | 441 +++++++ .../form/components/basic-info-section.tsx | 124 ++ .../external-knowledge-section.spec.tsx | 362 ++++++ .../components/external-knowledge-section.tsx | 84 ++ .../form/components/indexing-section.spec.tsx | 501 ++++++++ .../form/components/indexing-section.tsx | 208 ++++ .../form/hooks/use-form-state.spec.ts | 763 ++++++++++++ .../settings/form/hooks/use-form-state.ts | 264 +++++ .../datasets/settings/form/index.spec.tsx | 488 ++++++++ .../datasets/settings/form/index.tsx | 575 ++------- .../rag-pipeline/hooks/use-DSL.spec.ts | 302 ++--- .../index.spec.tsx | 7 +- .../workflow-onboarding-modal/index.spec.tsx | 12 +- web/eslint-suppressions.json | 28 - web/utils/format.ts | 20 + 38 files changed, 9328 insertions(+), 1653 deletions(-) create mode 100644 web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.spec.tsx create mode 100644 web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts create mode 100644 web/app/components/datasets/create/file-uploader/components/file-list-item.spec.tsx create mode 100644 web/app/components/datasets/create/file-uploader/components/file-list-item.tsx create mode 100644 web/app/components/datasets/create/file-uploader/components/upload-dropzone.spec.tsx create mode 100644 web/app/components/datasets/create/file-uploader/components/upload-dropzone.tsx create mode 100644 web/app/components/datasets/create/file-uploader/constants.ts create mode 100644 web/app/components/datasets/create/file-uploader/hooks/use-file-upload.spec.tsx create mode 100644 web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts create mode 100644 web/app/components/datasets/create/file-uploader/index.spec.tsx create mode 100644 web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.spec.tsx create mode 100644 web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx create mode 100644 web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.spec.tsx create mode 100644 web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.tsx create mode 100644 web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/constants.ts create mode 100644 web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.spec.tsx create mode 100644 web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.ts create mode 100644 web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.spec.tsx create mode 100644 web/app/components/datasets/settings/form/components/basic-info-section.spec.tsx create mode 100644 web/app/components/datasets/settings/form/components/basic-info-section.tsx create mode 100644 web/app/components/datasets/settings/form/components/external-knowledge-section.spec.tsx create mode 100644 web/app/components/datasets/settings/form/components/external-knowledge-section.tsx create mode 100644 web/app/components/datasets/settings/form/components/indexing-section.spec.tsx create mode 100644 web/app/components/datasets/settings/form/components/indexing-section.tsx create mode 100644 web/app/components/datasets/settings/form/hooks/use-form-state.spec.ts create mode 100644 web/app/components/datasets/settings/form/hooks/use-form-state.ts create mode 100644 web/app/components/datasets/settings/form/index.spec.tsx diff --git a/web/app/components/datasets/common/image-uploader/utils.spec.ts b/web/app/components/datasets/common/image-uploader/utils.spec.ts index 0150b1fb23..5741f5704f 100644 --- a/web/app/components/datasets/common/image-uploader/utils.spec.ts +++ b/web/app/components/datasets/common/image-uploader/utils.spec.ts @@ -216,13 +216,22 @@ describe('image-uploader utils', () => { type FileCallback = (file: MockFile) => void type EntriesCallback = (entries: FileSystemEntry[]) => void + // Helper to create mock FileSystemEntry with required properties + const createMockEntry = (props: { + isFile: boolean + isDirectory: boolean + name?: string + file?: (callback: FileCallback) => void + createReader?: () => { readEntries: (callback: EntriesCallback) => void } + }): FileSystemEntry => props as unknown as FileSystemEntry + it('should resolve with file array for file entry', async () => { const mockFile: MockFile = { name: 'test.png' } - const mockEntry = { + const mockEntry = createMockEntry({ isFile: true, isDirectory: false, file: (callback: FileCallback) => callback(mockFile), - } + }) const result = await traverseFileEntry(mockEntry) expect(result).toHaveLength(1) @@ -232,11 +241,11 @@ describe('image-uploader utils', () => { it('should resolve with file array with prefix for nested file', async () => { const mockFile: MockFile = { name: 'test.png' } - const mockEntry = { + const mockEntry = createMockEntry({ isFile: true, isDirectory: false, file: (callback: FileCallback) => callback(mockFile), - } + }) const result = await traverseFileEntry(mockEntry, 'folder/') expect(result).toHaveLength(1) @@ -244,24 +253,24 @@ describe('image-uploader utils', () => { }) it('should resolve empty array for unknown entry type', async () => { - const mockEntry = { + const mockEntry = createMockEntry({ isFile: false, isDirectory: false, - } + }) const result = await traverseFileEntry(mockEntry) expect(result).toEqual([]) }) it('should handle directory with no files', async () => { - const mockEntry = { + const mockEntry = createMockEntry({ isFile: false, isDirectory: true, name: 'empty-folder', createReader: () => ({ readEntries: (callback: EntriesCallback) => callback([]), }), - } + }) const result = await traverseFileEntry(mockEntry) expect(result).toEqual([]) @@ -271,20 +280,20 @@ describe('image-uploader utils', () => { const mockFile1: MockFile = { name: 'file1.png' } const mockFile2: MockFile = { name: 'file2.png' } - const mockFileEntry1 = { + const mockFileEntry1 = createMockEntry({ isFile: true, isDirectory: false, file: (callback: FileCallback) => callback(mockFile1), - } + }) - const mockFileEntry2 = { + const mockFileEntry2 = createMockEntry({ isFile: true, isDirectory: false, file: (callback: FileCallback) => callback(mockFile2), - } + }) let readCount = 0 - const mockEntry = { + const mockEntry = createMockEntry({ isFile: false, isDirectory: true, name: 'folder', @@ -292,14 +301,14 @@ describe('image-uploader utils', () => { readEntries: (callback: EntriesCallback) => { if (readCount === 0) { readCount++ - callback([mockFileEntry1, mockFileEntry2] as unknown as FileSystemEntry[]) + callback([mockFileEntry1, mockFileEntry2]) } else { callback([]) } }, }), - } + }) const result = await traverseFileEntry(mockEntry) expect(result).toHaveLength(2) diff --git a/web/app/components/datasets/common/image-uploader/utils.ts b/web/app/components/datasets/common/image-uploader/utils.ts index c2fad83840..d8c8582e2a 100644 --- a/web/app/components/datasets/common/image-uploader/utils.ts +++ b/web/app/components/datasets/common/image-uploader/utils.ts @@ -18,17 +18,17 @@ type FileWithPath = { relativePath?: string } & File -export const traverseFileEntry = (entry: any, prefix = ''): Promise<FileWithPath[]> => { +export const traverseFileEntry = (entry: FileSystemEntry, prefix = ''): Promise<FileWithPath[]> => { return new Promise((resolve) => { if (entry.isFile) { - entry.file((file: FileWithPath) => { + (entry as FileSystemFileEntry).file((file: FileWithPath) => { file.relativePath = `${prefix}${file.name}` resolve([file]) }) } else if (entry.isDirectory) { - const reader = entry.createReader() - const entries: any[] = [] + const reader = (entry as FileSystemDirectoryEntry).createReader() + const entries: FileSystemEntry[] = [] const read = () => { reader.readEntries(async (results: FileSystemEntry[]) => { if (!results.length) { diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.spec.tsx new file mode 100644 index 0000000000..e4955f58f6 --- /dev/null +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.spec.tsx @@ -0,0 +1,1045 @@ +import type { ReactNode } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { CreateFromDSLModalTab, useDSLImport } from './use-dsl-import' + +// Mock next/navigation +const mockPush = vi.fn() +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +// Mock service hooks +const mockImportDSL = vi.fn() +const mockImportDSLConfirm = vi.fn() + +vi.mock('@/service/use-pipeline', () => ({ + useImportPipelineDSL: () => ({ + mutateAsync: mockImportDSL, + }), + useImportPipelineDSLConfirm: () => ({ + mutateAsync: mockImportDSLConfirm, + }), +})) + +// Mock plugin dependencies hook +const mockHandleCheckPluginDependencies = vi.fn() + +vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({ + usePluginDependencies: () => ({ + handleCheckPluginDependencies: mockHandleCheckPluginDependencies, + }), +})) + +// Mock toast context +const mockNotify = vi.fn() + +vi.mock('use-context-selector', async () => { + const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector') + return { + ...actual, + useContext: vi.fn(() => ({ notify: mockNotify })), + } +}) + +// Test data builders +const createImportDSLResponse = (overrides = {}) => ({ + id: 'import-123', + status: 'completed' as const, + pipeline_id: 'pipeline-456', + dataset_id: 'dataset-789', + current_dsl_version: '1.0.0', + imported_dsl_version: '1.0.0', + ...overrides, +}) + +// Helper function to create QueryClient wrapper +const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, + }) + return ({ children }: { children: ReactNode }) => ( + <QueryClientProvider client={queryClient}> + {children} + </QueryClientProvider> + ) +} + +describe('useDSLImport', () => { + beforeEach(() => { + vi.clearAllMocks() + mockImportDSL.mockReset() + mockImportDSLConfirm.mockReset() + mockPush.mockReset() + mockNotify.mockReset() + mockHandleCheckPluginDependencies.mockReset() + }) + + describe('initialization', () => { + it('should initialize with default values', () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + expect(result.current.currentFile).toBeUndefined() + expect(result.current.currentTab).toBe(CreateFromDSLModalTab.FROM_FILE) + expect(result.current.dslUrlValue).toBe('') + expect(result.current.showConfirmModal).toBe(false) + expect(result.current.versions).toBeUndefined() + expect(result.current.buttonDisabled).toBe(true) + expect(result.current.isConfirming).toBe(false) + }) + + it('should use provided activeTab', () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_URL }), + { wrapper: createWrapper() }, + ) + + expect(result.current.currentTab).toBe(CreateFromDSLModalTab.FROM_URL) + }) + + it('should use provided dslUrl', () => { + const { result } = renderHook( + () => useDSLImport({ dslUrl: 'https://example.com/test.pipeline' }), + { wrapper: createWrapper() }, + ) + + expect(result.current.dslUrlValue).toBe('https://example.com/test.pipeline') + }) + }) + + describe('setCurrentTab', () => { + it('should update current tab', () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.setCurrentTab(CreateFromDSLModalTab.FROM_URL) + }) + + expect(result.current.currentTab).toBe(CreateFromDSLModalTab.FROM_URL) + }) + }) + + describe('setDslUrlValue', () => { + it('should update DSL URL value', () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.setDslUrlValue('https://new-url.com/pipeline') + }) + + expect(result.current.dslUrlValue).toBe('https://new-url.com/pipeline') + }) + }) + + describe('handleFile', () => { + it('should set file and trigger file reading', async () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['test content'], 'test.pipeline', { type: 'application/octet-stream' }) + + await act(async () => { + result.current.handleFile(mockFile) + }) + + expect(result.current.currentFile).toBe(mockFile) + expect(result.current.buttonDisabled).toBe(false) + }) + + it('should clear file when undefined is passed', async () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['test content'], 'test.pipeline', { type: 'application/octet-stream' }) + + // First set a file + await act(async () => { + result.current.handleFile(mockFile) + }) + + expect(result.current.currentFile).toBe(mockFile) + + // Then clear it + await act(async () => { + result.current.handleFile(undefined) + }) + + expect(result.current.currentFile).toBeUndefined() + expect(result.current.buttonDisabled).toBe(true) + }) + }) + + describe('buttonDisabled', () => { + it('should be true when file tab is active and no file is selected', () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_FILE }), + { wrapper: createWrapper() }, + ) + + expect(result.current.buttonDisabled).toBe(true) + }) + + it('should be false when file tab is active and file is selected', async () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_FILE }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pipeline', { type: 'application/octet-stream' }) + + await act(async () => { + result.current.handleFile(mockFile) + }) + + expect(result.current.buttonDisabled).toBe(false) + }) + + it('should be true when URL tab is active and no URL is entered', () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_URL }), + { wrapper: createWrapper() }, + ) + + expect(result.current.buttonDisabled).toBe(true) + }) + + it('should be false when URL tab is active and URL is entered', () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_URL, dslUrl: 'https://example.com' }), + { wrapper: createWrapper() }, + ) + + expect(result.current.buttonDisabled).toBe(false) + }) + }) + + describe('handleCreateApp with URL mode', () => { + it('should call importDSL with URL mode', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse()) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const onSuccess = vi.fn() + const onClose = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onSuccess, + onClose, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) // Wait for debounce + }) + + await waitFor(() => { + expect(mockImportDSL).toHaveBeenCalledWith({ + mode: 'yaml-url', + yaml_url: 'https://example.com/test.pipeline', + }) + }) + + vi.useRealTimers() + }) + + it('should handle successful import with COMPLETED status', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ status: 'completed' })) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const onSuccess = vi.fn() + const onClose = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onSuccess, + onClose, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(onSuccess).toHaveBeenCalled() + expect(onClose).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-789/pipeline') + }) + + vi.useRealTimers() + }) + + it('should handle import with COMPLETED_WITH_WARNINGS status', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ status: 'completed-with-warnings' })) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const onSuccess = vi.fn() + const onClose = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onSuccess, + onClose, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'warning', + })) + }) + + vi.useRealTimers() + }) + + it('should handle import with PENDING status and show confirm modal', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + status: 'pending', + imported_dsl_version: '0.9.0', + current_dsl_version: '1.0.0', + })) + + const onClose = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onClose, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(onClose).toHaveBeenCalled() + }) + + // Wait for setTimeout to show confirm modal + await act(async () => { + vi.advanceTimersByTime(400) + }) + + expect(result.current.showConfirmModal).toBe(true) + expect(result.current.versions).toEqual({ + importedVersion: '0.9.0', + systemVersion: '1.0.0', + }) + + vi.useRealTimers() + }) + + it('should handle API error (null response)', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(null) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + + vi.useRealTimers() + }) + + it('should handle FAILED status', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ status: 'failed' })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + + vi.useRealTimers() + }) + + it('should check plugin dependencies when pipeline_id is present', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + status: 'completed', + pipeline_id: 'pipeline-123', + })) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('pipeline-123', true) + }) + + vi.useRealTimers() + }) + + it('should not check plugin dependencies when pipeline_id is null', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + status: 'completed', + pipeline_id: null, + })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).not.toHaveBeenCalled() + }) + + vi.useRealTimers() + }) + + it('should return early when URL tab is active but no URL is provided', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: '', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + expect(mockImportDSL).not.toHaveBeenCalled() + + vi.useRealTimers() + }) + }) + + describe('handleCreateApp with FILE mode', () => { + it('should call importDSL with file content mode', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse()) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_FILE, + }), + { wrapper: createWrapper() }, + ) + + const fileContent = 'test yaml content' + const mockFile = new File([fileContent], 'test.pipeline', { type: 'application/octet-stream' }) + + // Set up file and wait for FileReader to complete + await act(async () => { + result.current.handleFile(mockFile) + // Give FileReader time to process + await new Promise(resolve => setTimeout(resolve, 100)) + }) + + // Trigger create + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockImportDSL).toHaveBeenCalledWith({ + mode: 'yaml-content', + yaml_content: fileContent, + }) + }) + + vi.useRealTimers() + }) + + it('should return early when file tab is active but no file is selected', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_FILE, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + expect(mockImportDSL).not.toHaveBeenCalled() + + vi.useRealTimers() + }) + }) + + describe('onDSLConfirm', () => { + it('should call importDSLConfirm and handle success', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + // First, trigger pending status to get importId + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue({ + status: 'completed', + pipeline_id: 'pipeline-456', + dataset_id: 'dataset-789', + }) + + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const onSuccess = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onSuccess, + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + // Wait for confirm modal to show + await act(async () => { + vi.advanceTimersByTime(400) + }) + + expect(result.current.showConfirmModal).toBe(true) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockImportDSLConfirm).toHaveBeenCalledWith('import-123') + expect(onSuccess).toHaveBeenCalled() + expect(result.current.showConfirmModal).toBe(false) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + }) + + vi.useRealTimers() + }) + + it('should handle confirm API error', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue(null) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + + vi.useRealTimers() + }) + + it('should handle confirm with FAILED status', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue({ + status: 'failed', + pipeline_id: 'pipeline-456', + dataset_id: 'dataset-789', + }) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + + vi.useRealTimers() + }) + + it('should return early when importId is not set', async () => { + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Call onDSLConfirm without triggering pending status + await act(async () => { + result.current.onDSLConfirm() + }) + + expect(mockImportDSLConfirm).not.toHaveBeenCalled() + }) + + it('should check plugin dependencies on confirm success', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue({ + status: 'completed', + pipeline_id: 'pipeline-789', + dataset_id: 'dataset-789', + }) + + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('pipeline-789', true) + }) + + vi.useRealTimers() + }) + + it('should set isConfirming during confirm process', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + let resolveConfirm: (value: unknown) => void + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockImplementation(() => new Promise((resolve) => { + resolveConfirm = resolve + })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + expect(result.current.isConfirming).toBe(false) + + // Start confirm + let confirmPromise: Promise<void> + act(() => { + confirmPromise = result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(result.current.isConfirming).toBe(true) + }) + + // Resolve confirm + await act(async () => { + resolveConfirm!({ + status: 'completed', + pipeline_id: 'pipeline-789', + dataset_id: 'dataset-789', + }) + }) + + await confirmPromise! + + expect(result.current.isConfirming).toBe(false) + + vi.useRealTimers() + }) + }) + + describe('handleCancelConfirm', () => { + it('should close confirm modal', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status to show confirm modal + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + expect(result.current.showConfirmModal).toBe(true) + + // Cancel confirm + act(() => { + result.current.handleCancelConfirm() + }) + + expect(result.current.showConfirmModal).toBe(false) + + vi.useRealTimers() + }) + }) + + describe('duplicate submission prevention', () => { + it('should prevent duplicate submissions while creating', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + let resolveImport: (value: unknown) => void + mockImportDSL.mockImplementation(() => new Promise((resolve) => { + resolveImport = resolve + })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // First call + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + // Second call should be ignored + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + // Third call should be ignored + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + // Only one call should be made + expect(mockImportDSL).toHaveBeenCalledTimes(1) + + // Resolve the first call + await act(async () => { + resolveImport!(createImportDSLResponse()) + }) + + vi.useRealTimers() + }) + }) + + describe('file reading', () => { + it('should read file content using FileReader', async () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_FILE }), + { wrapper: createWrapper() }, + ) + + const fileContent = 'yaml content here' + const mockFile = new File([fileContent], 'test.pipeline', { type: 'application/octet-stream' }) + + await act(async () => { + result.current.handleFile(mockFile) + }) + + expect(result.current.currentFile).toBe(mockFile) + }) + + it('should clear file content when file is removed', async () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_FILE }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pipeline', { type: 'application/octet-stream' }) + + // Set file + await act(async () => { + result.current.handleFile(mockFile) + }) + + // Clear file + await act(async () => { + result.current.handleFile(undefined) + }) + + expect(result.current.currentFile).toBeUndefined() + }) + }) + + describe('navigation after import', () => { + it('should navigate to pipeline page after successful import', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + status: 'completed', + dataset_id: 'test-dataset-id', + })) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockPush).toHaveBeenCalledWith('/datasets/test-dataset-id/pipeline') + }) + + vi.useRealTimers() + }) + + it('should navigate to pipeline page after confirm success', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue({ + status: 'completed', + pipeline_id: 'pipeline-456', + dataset_id: 'confirm-dataset-id', + }) + + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockPush).toHaveBeenCalledWith('/datasets/confirm-dataset-id/pipeline') + }) + + vi.useRealTimers() + }) + }) + + describe('enum export', () => { + it('should export CreateFromDSLModalTab enum with correct values', () => { + expect(CreateFromDSLModalTab.FROM_FILE).toBe('from-file') + expect(CreateFromDSLModalTab.FROM_URL).toBe('from-url') + }) + }) +}) diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts new file mode 100644 index 0000000000..87e55ea740 --- /dev/null +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts @@ -0,0 +1,218 @@ +'use client' +import { useDebounceFn } from 'ahooks' +import { useRouter } from 'next/navigation' +import { useCallback, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { ToastContext } from '@/app/components/base/toast' +import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' +import { + DSLImportMode, + DSLImportStatus, +} from '@/models/app' +import { useImportPipelineDSL, useImportPipelineDSLConfirm } from '@/service/use-pipeline' + +export enum CreateFromDSLModalTab { + FROM_FILE = 'from-file', + FROM_URL = 'from-url', +} + +export type UseDSLImportOptions = { + activeTab?: CreateFromDSLModalTab + dslUrl?: string + onSuccess?: () => void + onClose?: () => void +} + +export type DSLVersions = { + importedVersion: string + systemVersion: string +} + +export const useDSLImport = ({ + activeTab = CreateFromDSLModalTab.FROM_FILE, + dslUrl = '', + onSuccess, + onClose, +}: UseDSLImportOptions) => { + const { push } = useRouter() + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + + const [currentFile, setDSLFile] = useState<File>() + const [fileContent, setFileContent] = useState<string>() + const [currentTab, setCurrentTab] = useState(activeTab) + const [dslUrlValue, setDslUrlValue] = useState(dslUrl) + const [showConfirmModal, setShowConfirmModal] = useState(false) + const [versions, setVersions] = useState<DSLVersions>() + const [importId, setImportId] = useState<string>() + const [isConfirming, setIsConfirming] = useState(false) + + const { handleCheckPluginDependencies } = usePluginDependencies() + const isCreatingRef = useRef(false) + + const { mutateAsync: importDSL } = useImportPipelineDSL() + const { mutateAsync: importDSLConfirm } = useImportPipelineDSLConfirm() + + const readFile = useCallback((file: File) => { + const reader = new FileReader() + reader.onload = (event) => { + const content = event.target?.result + setFileContent(content as string) + } + reader.readAsText(file) + }, []) + + const handleFile = useCallback((file?: File) => { + setDSLFile(file) + if (file) + readFile(file) + if (!file) + setFileContent('') + }, [readFile]) + + const onCreate = useCallback(async () => { + if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile) + return + if (currentTab === CreateFromDSLModalTab.FROM_URL && !dslUrlValue) + return + if (isCreatingRef.current) + return + + isCreatingRef.current = true + + let response + if (currentTab === CreateFromDSLModalTab.FROM_FILE) { + response = await importDSL({ + mode: DSLImportMode.YAML_CONTENT, + yaml_content: fileContent || '', + }) + } + if (currentTab === CreateFromDSLModalTab.FROM_URL) { + response = await importDSL({ + mode: DSLImportMode.YAML_URL, + yaml_url: dslUrlValue || '', + }) + } + + if (!response) { + notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) + isCreatingRef.current = false + return + } + + const { id, status, pipeline_id, dataset_id, imported_dsl_version, current_dsl_version } = response + + if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { + onSuccess?.() + onClose?.() + + notify({ + type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning', + message: t(status === DSLImportStatus.COMPLETED ? 'creation.successTip' : 'creation.caution', { ns: 'datasetPipeline' }), + children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('newApp.appCreateDSLWarning', { ns: 'app' }), + }) + + if (pipeline_id) + await handleCheckPluginDependencies(pipeline_id, true) + + push(`/datasets/${dataset_id}/pipeline`) + isCreatingRef.current = false + } + else if (status === DSLImportStatus.PENDING) { + setVersions({ + importedVersion: imported_dsl_version ?? '', + systemVersion: current_dsl_version ?? '', + }) + onClose?.() + setTimeout(() => { + setShowConfirmModal(true) + }, 300) + setImportId(id) + isCreatingRef.current = false + } + else { + notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) + isCreatingRef.current = false + } + }, [ + currentTab, + currentFile, + dslUrlValue, + fileContent, + importDSL, + notify, + t, + onSuccess, + onClose, + handleCheckPluginDependencies, + push, + ]) + + const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 }) + + const onDSLConfirm = useCallback(async () => { + if (!importId) + return + + setIsConfirming(true) + const response = await importDSLConfirm(importId) + setIsConfirming(false) + + if (!response) { + notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) + return + } + + const { status, pipeline_id, dataset_id } = response + + if (status === DSLImportStatus.COMPLETED) { + onSuccess?.() + setShowConfirmModal(false) + + notify({ + type: 'success', + message: t('creation.successTip', { ns: 'datasetPipeline' }), + }) + + if (pipeline_id) + await handleCheckPluginDependencies(pipeline_id, true) + + push(`/datasets/${dataset_id}/pipeline`) + } + else if (status === DSLImportStatus.FAILED) { + notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) + } + }, [importId, importDSLConfirm, notify, t, onSuccess, handleCheckPluginDependencies, push]) + + const handleCancelConfirm = useCallback(() => { + setShowConfirmModal(false) + }, []) + + const buttonDisabled = useMemo(() => { + if (currentTab === CreateFromDSLModalTab.FROM_FILE) + return !currentFile + if (currentTab === CreateFromDSLModalTab.FROM_URL) + return !dslUrlValue + return false + }, [currentTab, currentFile, dslUrlValue]) + + return { + // State + currentFile, + currentTab, + dslUrlValue, + showConfirmModal, + versions, + buttonDisabled, + isConfirming, + + // Actions + setCurrentTab, + setDslUrlValue, + handleFile, + handleCreateApp, + onDSLConfirm, + handleCancelConfirm, + } +} diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx index 2d187010b8..079ea90687 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx @@ -1,24 +1,18 @@ 'use client' -import { useDebounceFn, useKeyPress } from 'ahooks' +import { useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' -import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' -import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' -import { - DSLImportMode, - DSLImportStatus, -} from '@/models/app' -import { useImportPipelineDSL, useImportPipelineDSLConfirm } from '@/service/use-pipeline' +import DSLConfirmModal from './dsl-confirm-modal' import Header from './header' +import { CreateFromDSLModalTab, useDSLImport } from './hooks/use-dsl-import' import Tab from './tab' import Uploader from './uploader' +export { CreateFromDSLModalTab } + type CreateFromDSLModalProps = { show: boolean onSuccess?: () => void @@ -27,11 +21,6 @@ type CreateFromDSLModalProps = { dslUrl?: string } -export enum CreateFromDSLModalTab { - FROM_FILE = 'from-file', - FROM_URL = 'from-url', -} - const CreateFromDSLModal = ({ show, onSuccess, @@ -39,149 +28,33 @@ const CreateFromDSLModal = ({ activeTab = CreateFromDSLModalTab.FROM_FILE, dslUrl = '', }: CreateFromDSLModalProps) => { - const { push } = useRouter() const { t } = useTranslation() - const { notify } = useContext(ToastContext) - const [currentFile, setDSLFile] = useState<File>() - const [fileContent, setFileContent] = useState<string>() - const [currentTab, setCurrentTab] = useState(activeTab) - const [dslUrlValue, setDslUrlValue] = useState(dslUrl) - const [showErrorModal, setShowErrorModal] = useState(false) - const [versions, setVersions] = useState<{ importedVersion: string, systemVersion: string }>() - const [importId, setImportId] = useState<string>() - const { handleCheckPluginDependencies } = usePluginDependencies() - const readFile = (file: File) => { - const reader = new FileReader() - reader.onload = function (event) { - const content = event.target?.result - setFileContent(content as string) - } - reader.readAsText(file) - } - - const handleFile = (file?: File) => { - setDSLFile(file) - if (file) - readFile(file) - if (!file) - setFileContent('') - } - - const isCreatingRef = useRef(false) - - const { mutateAsync: importDSL } = useImportPipelineDSL() - - const onCreate = async () => { - if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile) - return - if (currentTab === CreateFromDSLModalTab.FROM_URL && !dslUrlValue) - return - if (isCreatingRef.current) - return - isCreatingRef.current = true - let response - if (currentTab === CreateFromDSLModalTab.FROM_FILE) { - response = await importDSL({ - mode: DSLImportMode.YAML_CONTENT, - yaml_content: fileContent || '', - }) - } - if (currentTab === CreateFromDSLModalTab.FROM_URL) { - response = await importDSL({ - mode: DSLImportMode.YAML_URL, - yaml_url: dslUrlValue || '', - }) - } - - if (!response) { - notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) - isCreatingRef.current = false - return - } - const { id, status, pipeline_id, dataset_id, imported_dsl_version, current_dsl_version } = response - if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { - if (onSuccess) - onSuccess() - if (onClose) - onClose() - - notify({ - type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning', - message: t(status === DSLImportStatus.COMPLETED ? 'creation.successTip' : 'creation.caution', { ns: 'datasetPipeline' }), - children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('newApp.appCreateDSLWarning', { ns: 'app' }), - }) - if (pipeline_id) - await handleCheckPluginDependencies(pipeline_id, true) - push(`/datasets/${dataset_id}/pipeline`) - isCreatingRef.current = false - } - else if (status === DSLImportStatus.PENDING) { - setVersions({ - importedVersion: imported_dsl_version ?? '', - systemVersion: current_dsl_version ?? '', - }) - if (onClose) - onClose() - setTimeout(() => { - setShowErrorModal(true) - }, 300) - setImportId(id) - isCreatingRef.current = false - } - else { - notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) - isCreatingRef.current = false - } - } - - const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 }) - - useKeyPress('esc', () => { - if (show && !showErrorModal) - onClose() + const { + currentFile, + currentTab, + dslUrlValue, + showConfirmModal, + versions, + buttonDisabled, + isConfirming, + setCurrentTab, + setDslUrlValue, + handleFile, + handleCreateApp, + onDSLConfirm, + handleCancelConfirm, + } = useDSLImport({ + activeTab, + dslUrl, + onSuccess, + onClose, }) - const { mutateAsync: importDSLConfirm } = useImportPipelineDSLConfirm() - - const onDSLConfirm = async () => { - if (!importId) - return - const response = await importDSLConfirm(importId) - - if (!response) { - notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) - return - } - - const { status, pipeline_id, dataset_id } = response - - if (status === DSLImportStatus.COMPLETED) { - if (onSuccess) - onSuccess() - if (onClose) - onClose() - - notify({ - type: 'success', - message: t('creation.successTip', { ns: 'datasetPipeline' }), - }) - if (pipeline_id) - await handleCheckPluginDependencies(pipeline_id, true) - push(`datasets/${dataset_id}/pipeline`) - } - else if (status === DSLImportStatus.FAILED) { - notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) - } - } - - const buttonDisabled = useMemo(() => { - if (currentTab === CreateFromDSLModalTab.FROM_FILE) - return !currentFile - if (currentTab === CreateFromDSLModalTab.FROM_URL) - return !dslUrlValue - return false - }, [currentTab, currentFile, dslUrlValue]) + useKeyPress('esc', () => { + if (show && !showConfirmModal) + onClose() + }) return ( <> @@ -196,29 +69,25 @@ const CreateFromDSLModal = ({ setCurrentTab={setCurrentTab} /> <div className="px-6 py-4"> - { - currentTab === CreateFromDSLModalTab.FROM_FILE && ( - <Uploader - className="mt-0" - file={currentFile} - updateFile={handleFile} - /> - ) - } - { - currentTab === CreateFromDSLModalTab.FROM_URL && ( - <div> - <div className="system-md-semibold leading6 mb-1 text-text-secondary"> - DSL URL - </div> - <Input - placeholder={t('importFromDSLUrlPlaceholder', { ns: 'app' }) || ''} - value={dslUrlValue} - onChange={e => setDslUrlValue(e.target.value)} - /> + {currentTab === CreateFromDSLModalTab.FROM_FILE && ( + <Uploader + className="mt-0" + file={currentFile} + updateFile={handleFile} + /> + )} + {currentTab === CreateFromDSLModalTab.FROM_URL && ( + <div> + <div className="system-md-semibold leading6 mb-1 text-text-secondary"> + DSL URL </div> - ) - } + <Input + placeholder={t('importFromDSLUrlPlaceholder', { ns: 'app' }) || ''} + value={dslUrlValue} + onChange={e => setDslUrlValue(e.target.value)} + /> + </div> + )} </div> <div className="flex justify-end gap-x-2 p-6 pt-5"> <Button onClick={onClose}> @@ -234,32 +103,14 @@ const CreateFromDSLModal = ({ </Button> </div> </Modal> - <Modal - isShow={showErrorModal} - onClose={() => setShowErrorModal(false)} - className="w-[480px]" - > - <div className="flex flex-col items-start gap-2 self-stretch pb-4"> - <div className="title-2xl-semi-bold text-text-primary">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div> - <div className="system-md-regular flex grow flex-col text-text-secondary"> - <div>{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}</div> - <div>{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}</div> - <br /> - <div> - {t('newApp.appCreateDSLErrorPart3', { ns: 'app' })} - <span className="system-md-medium">{versions?.importedVersion}</span> - </div> - <div> - {t('newApp.appCreateDSLErrorPart4', { ns: 'app' })} - <span className="system-md-medium">{versions?.systemVersion}</span> - </div> - </div> - </div> - <div className="flex items-start justify-end gap-2 self-stretch pt-6"> - <Button variant="secondary" onClick={() => setShowErrorModal(false)}>{t('newApp.Cancel', { ns: 'app' })}</Button> - <Button variant="primary" destructive onClick={onDSLConfirm}>{t('newApp.Confirm', { ns: 'app' })}</Button> - </div> - </Modal> + {showConfirmModal && ( + <DSLConfirmModal + versions={versions} + onCancel={handleCancelConfirm} + onConfirm={onDSLConfirm} + confirmDisabled={isConfirming} + /> + )} </> ) } diff --git a/web/app/components/datasets/create/file-uploader/components/file-list-item.spec.tsx b/web/app/components/datasets/create/file-uploader/components/file-list-item.spec.tsx new file mode 100644 index 0000000000..4da20a7bf7 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/components/file-list-item.spec.tsx @@ -0,0 +1,334 @@ +import type { FileListItemProps } from './file-list-item' +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' +import FileListItem from './file-list-item' + +// Mock theme hook - can be changed per test +let mockTheme = 'light' +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: mockTheme }), +})) + +// Mock theme types +vi.mock('@/types/app', () => ({ + Theme: { dark: 'dark', light: 'light' }, +})) + +// Mock SimplePieChart with dynamic import handling +vi.mock('next/dynamic', () => ({ + default: () => { + const DynamicComponent = ({ percentage, stroke, fill }: { percentage: number, stroke: string, fill: string }) => ( + <div data-testid="pie-chart" data-percentage={percentage} data-stroke={stroke} data-fill={fill}> + Pie Chart: + {' '} + {percentage} + % + </div> + ) + DynamicComponent.displayName = 'SimplePieChart' + return DynamicComponent + }, +})) + +// Mock DocumentFileIcon +vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ + default: ({ name, extension, size }: { name: string, extension: string, size: string }) => ( + <div data-testid="document-icon" data-name={name} data-extension={extension} data-size={size}> + Document Icon + </div> + ), +})) + +describe('FileListItem', () => { + const createMockFile = (overrides: Partial<File> = {}): File => ({ + name: 'test-document.pdf', + size: 1024 * 100, // 100KB + type: 'application/pdf', + lastModified: Date.now(), + ...overrides, + } as File) + + const createMockFileItem = (overrides: Partial<FileItem> = {}): FileItem => ({ + fileID: 'file-123', + file: createMockFile(overrides.file as Partial<File>), + progress: PROGRESS_NOT_STARTED, + ...overrides, + }) + + const defaultProps: FileListItemProps = { + fileItem: createMockFileItem(), + onPreview: vi.fn(), + onRemove: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + mockTheme = 'light' + }) + + describe('rendering', () => { + it('should render the file item container', () => { + const { container } = render(<FileListItem {...defaultProps} />) + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('flex', 'h-12', 'items-center', 'rounded-lg') + }) + + it('should render document icon with correct props', () => { + render(<FileListItem {...defaultProps} />) + const icon = screen.getByTestId('document-icon') + expect(icon).toBeInTheDocument() + expect(icon).toHaveAttribute('data-name', 'test-document.pdf') + expect(icon).toHaveAttribute('data-extension', 'pdf') + expect(icon).toHaveAttribute('data-size', 'xl') + }) + + it('should render file name', () => { + render(<FileListItem {...defaultProps} />) + expect(screen.getByText('test-document.pdf')).toBeInTheDocument() + }) + + it('should render file extension in uppercase via CSS class', () => { + render(<FileListItem {...defaultProps} />) + const extensionSpan = screen.getByText('pdf') + expect(extensionSpan).toBeInTheDocument() + expect(extensionSpan).toHaveClass('uppercase') + }) + + it('should render file size', () => { + render(<FileListItem {...defaultProps} />) + // Default mock file is 100KB (1024 * 100 bytes) + expect(screen.getByText('100.00 KB')).toBeInTheDocument() + }) + + it('should render delete button', () => { + const { container } = render(<FileListItem {...defaultProps} />) + const deleteButton = container.querySelector('.cursor-pointer') + expect(deleteButton).toBeInTheDocument() + }) + }) + + describe('progress states', () => { + it('should show progress chart when uploading (0-99)', () => { + const fileItem = createMockFileItem({ progress: 50 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toBeInTheDocument() + expect(pieChart).toHaveAttribute('data-percentage', '50') + }) + + it('should show progress chart at 0%', () => { + const fileItem = createMockFileItem({ progress: 0 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-percentage', '0') + }) + + it('should not show progress chart when complete (100)', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_COMPLETE }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + + it('should not show progress chart when not started (-1)', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_NOT_STARTED }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('error state', () => { + it('should show error indicator when progress is PROGRESS_ERROR', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_ERROR }) + const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const errorIndicator = container.querySelector('.text-text-destructive') + expect(errorIndicator).toBeInTheDocument() + }) + + it('should not show error indicator when not in error state', () => { + const { container } = render(<FileListItem {...defaultProps} />) + const errorIndicator = container.querySelector('.text-text-destructive') + expect(errorIndicator).not.toBeInTheDocument() + }) + }) + + describe('theme handling', () => { + it('should use correct chart color for light theme', () => { + mockTheme = 'light' + const fileItem = createMockFileItem({ progress: 50 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-stroke', '#296dff') + expect(pieChart).toHaveAttribute('data-fill', '#296dff') + }) + + it('should use correct chart color for dark theme', () => { + mockTheme = 'dark' + const fileItem = createMockFileItem({ progress: 50 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-stroke', '#5289ff') + expect(pieChart).toHaveAttribute('data-fill', '#5289ff') + }) + }) + + describe('event handlers', () => { + it('should call onPreview when item is clicked with file id', () => { + const onPreview = vi.fn() + const fileItem = createMockFileItem({ + file: createMockFile({ id: 'uploaded-id' } as Partial<File>), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} onPreview={onPreview} />) + + const item = screen.getByText('test-document.pdf').closest('[class*="flex h-12"]')! + fireEvent.click(item) + + expect(onPreview).toHaveBeenCalledTimes(1) + expect(onPreview).toHaveBeenCalledWith(fileItem.file) + }) + + it('should not call onPreview when file has no id', () => { + const onPreview = vi.fn() + const fileItem = createMockFileItem() + render(<FileListItem {...defaultProps} fileItem={fileItem} onPreview={onPreview} />) + + const item = screen.getByText('test-document.pdf').closest('[class*="flex h-12"]')! + fireEvent.click(item) + + expect(onPreview).not.toHaveBeenCalled() + }) + + it('should call onRemove when delete button is clicked', () => { + const onRemove = vi.fn() + const fileItem = createMockFileItem() + const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} onRemove={onRemove} />) + + const deleteButton = container.querySelector('.cursor-pointer')! + fireEvent.click(deleteButton) + + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onRemove).toHaveBeenCalledWith('file-123') + }) + + it('should stop propagation when delete button is clicked', () => { + const onPreview = vi.fn() + const onRemove = vi.fn() + const fileItem = createMockFileItem({ + file: createMockFile({ id: 'uploaded-id' } as Partial<File>), + }) + const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} onPreview={onPreview} onRemove={onRemove} />) + + const deleteButton = container.querySelector('.cursor-pointer')! + fireEvent.click(deleteButton) + + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onPreview).not.toHaveBeenCalled() + }) + }) + + describe('file type handling', () => { + it('should handle files with multiple dots in name', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: 'my.document.file.docx' }), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.getByText('my.document.file.docx')).toBeInTheDocument() + expect(screen.getByText('docx')).toBeInTheDocument() + }) + + it('should handle files without extension', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: 'README' }), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + // File name appears once, and extension area shows empty string + expect(screen.getByText('README')).toBeInTheDocument() + }) + + it('should handle various file extensions', () => { + const extensions = ['txt', 'md', 'json', 'csv', 'xlsx'] + + extensions.forEach((ext) => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: `file.${ext}` }), + }) + const { unmount } = render(<FileListItem {...defaultProps} fileItem={fileItem} />) + expect(screen.getByText(ext)).toBeInTheDocument() + unmount() + }) + }) + }) + + describe('file size display', () => { + it('should display size in KB for small files', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 5 * 1024 }), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + expect(screen.getByText('5.00 KB')).toBeInTheDocument() + }) + + it('should display size in MB for larger files', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 5 * 1024 * 1024 }), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + expect(screen.getByText('5.00 MB')).toBeInTheDocument() + }) + }) + + describe('upload progress values', () => { + it('should show chart at progress 1', () => { + const fileItem = createMockFileItem({ progress: 1 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + expect(screen.getByTestId('pie-chart')).toBeInTheDocument() + }) + + it('should show chart at progress 99', () => { + const fileItem = createMockFileItem({ progress: 99 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + expect(screen.getByTestId('pie-chart')).toHaveAttribute('data-percentage', '99') + }) + + it('should not show chart at progress 100', () => { + const fileItem = createMockFileItem({ progress: 100 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have proper shadow styling', () => { + const { container } = render(<FileListItem {...defaultProps} />) + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('shadow-xs') + }) + + it('should have proper border styling', () => { + const { container } = render(<FileListItem {...defaultProps} />) + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('border', 'border-components-panel-border') + }) + + it('should truncate long file names', () => { + const longFileName = 'this-is-a-very-long-file-name-that-should-be-truncated.pdf' + const fileItem = createMockFileItem({ + file: createMockFile({ name: longFileName }), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const nameElement = screen.getByText(longFileName) + expect(nameElement).toHaveClass('truncate') + }) + }) +}) diff --git a/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx b/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx new file mode 100644 index 0000000000..d36773fa5c --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx @@ -0,0 +1,89 @@ +'use client' +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { RiDeleteBinLine, RiErrorWarningFill } from '@remixicon/react' +import dynamic from 'next/dynamic' +import { useMemo } from 'react' +import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { formatFileSize, getFileExtension } from '@/utils/format' +import { PROGRESS_COMPLETE, PROGRESS_ERROR } from '../constants' + +const SimplePieChart = dynamic(() => import('@/app/components/base/simple-pie-chart'), { ssr: false }) + +export type FileListItemProps = { + fileItem: FileItem + onPreview: (file: File) => void + onRemove: (fileID: string) => void +} + +const FileListItem = ({ + fileItem, + onPreview, + onRemove, +}: FileListItemProps) => { + const { theme } = useTheme() + const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) + + const isUploading = fileItem.progress >= 0 && fileItem.progress < PROGRESS_COMPLETE + const isError = fileItem.progress === PROGRESS_ERROR + + const handleClick = () => { + if (fileItem.file?.id) + onPreview(fileItem.file) + } + + const handleRemove = (e: React.MouseEvent) => { + e.stopPropagation() + onRemove(fileItem.fileID) + } + + return ( + <div + onClick={handleClick} + className="flex h-12 max-w-[640px] items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg text-xs leading-3 text-text-tertiary shadow-xs" + > + <div className="flex w-12 shrink-0 items-center justify-center"> + <DocumentFileIcon + size="xl" + className="shrink-0" + name={fileItem.file.name} + extension={getFileExtension(fileItem.file.name)} + /> + </div> + <div className="flex shrink grow flex-col gap-0.5"> + <div className="flex w-full"> + <div className="w-0 grow truncate text-sm leading-4 text-text-secondary"> + {fileItem.file.name} + </div> + </div> + <div className="w-full truncate leading-3 text-text-tertiary"> + <span className="uppercase">{getFileExtension(fileItem.file.name)}</span> + <span className="px-1 text-text-quaternary">ยท</span> + <span>{formatFileSize(fileItem.file.size)}</span> + </div> + </div> + <div className="flex w-16 shrink-0 items-center justify-end gap-1 pr-3"> + {isUploading && ( + <SimplePieChart + percentage={fileItem.progress} + stroke={chartColor} + fill={chartColor} + animationDuration={0} + /> + )} + {isError && ( + <RiErrorWarningFill className="size-4 text-text-destructive" /> + )} + <span + className="flex h-6 w-6 cursor-pointer items-center justify-center" + onClick={handleRemove} + > + <RiDeleteBinLine className="size-4 text-text-tertiary" /> + </span> + </div> + </div> + ) +} + +export default FileListItem diff --git a/web/app/components/datasets/create/file-uploader/components/upload-dropzone.spec.tsx b/web/app/components/datasets/create/file-uploader/components/upload-dropzone.spec.tsx new file mode 100644 index 0000000000..112d61250b --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/components/upload-dropzone.spec.tsx @@ -0,0 +1,210 @@ +import type { RefObject } from 'react' +import type { UploadDropzoneProps } from './upload-dropzone' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import UploadDropzone from './upload-dropzone' + +// Helper to create mock ref objects for testing +const createMockRef = <T,>(value: T | null = null): RefObject<T | null> => ({ current: value }) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: Record<string, unknown>) => { + const translations: Record<string, string> = { + 'stepOne.uploader.button': 'Drag and drop files, or', + 'stepOne.uploader.buttonSingleFile': 'Drag and drop file, or', + 'stepOne.uploader.browse': 'Browse', + 'stepOne.uploader.tip': 'Supports {{supportTypes}}, Max {{size}}MB each, up to {{batchCount}} files at a time, {{totalCount}} files total', + } + let result = translations[key] || key + if (options && typeof options === 'object') { + Object.entries(options).forEach(([k, v]) => { + result = result.replace(`{{${k}}}`, String(v)) + }) + } + return result + }, + }), +})) + +describe('UploadDropzone', () => { + const defaultProps: UploadDropzoneProps = { + dropRef: createMockRef<HTMLDivElement>() as RefObject<HTMLDivElement | null>, + dragRef: createMockRef<HTMLDivElement>() as RefObject<HTMLDivElement | null>, + fileUploaderRef: createMockRef<HTMLInputElement>() as RefObject<HTMLInputElement | null>, + dragging: false, + supportBatchUpload: true, + supportTypesShowNames: 'PDF, DOCX, TXT', + fileUploadConfig: { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, + }, + acceptTypes: ['.pdf', '.docx', '.txt'], + onSelectFile: vi.fn(), + onFileChange: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the dropzone container', () => { + const { container } = render(<UploadDropzone {...defaultProps} />) + const dropzone = container.querySelector('[class*="border-dashed"]') + expect(dropzone).toBeInTheDocument() + }) + + it('should render hidden file input', () => { + render(<UploadDropzone {...defaultProps} />) + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toBeInTheDocument() + expect(input).toHaveClass('hidden') + expect(input).toHaveAttribute('type', 'file') + }) + + it('should render upload icon', () => { + render(<UploadDropzone {...defaultProps} />) + const icon = document.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + + it('should render browse label when extensions are allowed', () => { + render(<UploadDropzone {...defaultProps} />) + expect(screen.getByText('Browse')).toBeInTheDocument() + }) + + it('should not render browse label when no extensions allowed', () => { + render(<UploadDropzone {...defaultProps} acceptTypes={[]} />) + expect(screen.queryByText('Browse')).not.toBeInTheDocument() + }) + + it('should render file size and count limits', () => { + render(<UploadDropzone {...defaultProps} />) + const tipText = screen.getByText(/Supports.*Max.*15MB/i) + expect(tipText).toBeInTheDocument() + }) + }) + + describe('file input configuration', () => { + it('should allow multiple files when supportBatchUpload is true', () => { + render(<UploadDropzone {...defaultProps} supportBatchUpload={true} />) + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('multiple') + }) + + it('should not allow multiple files when supportBatchUpload is false', () => { + render(<UploadDropzone {...defaultProps} supportBatchUpload={false} />) + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).not.toHaveAttribute('multiple') + }) + + it('should set accept attribute with correct types', () => { + render(<UploadDropzone {...defaultProps} acceptTypes={['.pdf', '.docx']} />) + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('accept', '.pdf,.docx') + }) + }) + + describe('text content', () => { + it('should show batch upload text when supportBatchUpload is true', () => { + render(<UploadDropzone {...defaultProps} supportBatchUpload={true} />) + expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument() + }) + + it('should show single file text when supportBatchUpload is false', () => { + render(<UploadDropzone {...defaultProps} supportBatchUpload={false} />) + expect(screen.getByText(/Drag and drop file/i)).toBeInTheDocument() + }) + }) + + describe('dragging state', () => { + it('should apply dragging styles when dragging is true', () => { + const { container } = render(<UploadDropzone {...defaultProps} dragging={true} />) + const dropzone = container.querySelector('[class*="border-components-dropzone-border-accent"]') + expect(dropzone).toBeInTheDocument() + }) + + it('should render drag overlay when dragging', () => { + const dragRef = createMockRef<HTMLDivElement>() + render(<UploadDropzone {...defaultProps} dragging={true} dragRef={dragRef as RefObject<HTMLDivElement | null>} />) + const overlay = document.querySelector('.absolute.left-0.top-0') + expect(overlay).toBeInTheDocument() + }) + + it('should not render drag overlay when not dragging', () => { + render(<UploadDropzone {...defaultProps} dragging={false} />) + const overlay = document.querySelector('.absolute.left-0.top-0') + expect(overlay).not.toBeInTheDocument() + }) + }) + + describe('event handlers', () => { + it('should call onSelectFile when browse label is clicked', () => { + const onSelectFile = vi.fn() + render(<UploadDropzone {...defaultProps} onSelectFile={onSelectFile} />) + + const browseLabel = screen.getByText('Browse') + fireEvent.click(browseLabel) + + expect(onSelectFile).toHaveBeenCalledTimes(1) + }) + + it('should call onFileChange when files are selected', () => { + const onFileChange = vi.fn() + render(<UploadDropzone {...defaultProps} onFileChange={onFileChange} />) + + const input = document.getElementById('fileUploader') as HTMLInputElement + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + fireEvent.change(input, { target: { files: [file] } }) + + expect(onFileChange).toHaveBeenCalledTimes(1) + }) + }) + + describe('refs', () => { + it('should attach dropRef to drop container', () => { + const dropRef = createMockRef<HTMLDivElement>() + render(<UploadDropzone {...defaultProps} dropRef={dropRef as RefObject<HTMLDivElement | null>} />) + expect(dropRef.current).toBeInstanceOf(HTMLDivElement) + }) + + it('should attach fileUploaderRef to input element', () => { + const fileUploaderRef = createMockRef<HTMLInputElement>() + render(<UploadDropzone {...defaultProps} fileUploaderRef={fileUploaderRef as RefObject<HTMLInputElement | null>} />) + expect(fileUploaderRef.current).toBeInstanceOf(HTMLInputElement) + }) + + it('should attach dragRef to overlay when dragging', () => { + const dragRef = createMockRef<HTMLDivElement>() + render(<UploadDropzone {...defaultProps} dragging={true} dragRef={dragRef as RefObject<HTMLDivElement | null>} />) + expect(dragRef.current).toBeInstanceOf(HTMLDivElement) + }) + }) + + describe('styling', () => { + it('should have base dropzone styling', () => { + const { container } = render(<UploadDropzone {...defaultProps} />) + const dropzone = container.querySelector('[class*="border-dashed"]') + expect(dropzone).toBeInTheDocument() + expect(dropzone).toHaveClass('rounded-xl') + }) + + it('should have cursor-pointer on browse label', () => { + render(<UploadDropzone {...defaultProps} />) + const browseLabel = screen.getByText('Browse') + expect(browseLabel).toHaveClass('cursor-pointer') + }) + }) + + describe('accessibility', () => { + it('should have an accessible file input', () => { + render(<UploadDropzone {...defaultProps} />) + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('id', 'fileUploader') + }) + }) +}) diff --git a/web/app/components/datasets/create/file-uploader/components/upload-dropzone.tsx b/web/app/components/datasets/create/file-uploader/components/upload-dropzone.tsx new file mode 100644 index 0000000000..9fa577dace --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/components/upload-dropzone.tsx @@ -0,0 +1,84 @@ +'use client' +import type { RefObject } from 'react' +import type { FileUploadConfig } from '../hooks/use-file-upload' +import { RiUploadCloud2Line } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { cn } from '@/utils/classnames' + +export type UploadDropzoneProps = { + dropRef: RefObject<HTMLDivElement | null> + dragRef: RefObject<HTMLDivElement | null> + fileUploaderRef: RefObject<HTMLInputElement | null> + dragging: boolean + supportBatchUpload: boolean + supportTypesShowNames: string + fileUploadConfig: FileUploadConfig + acceptTypes: string[] + onSelectFile: () => void + onFileChange: (e: React.ChangeEvent<HTMLInputElement>) => void +} + +const UploadDropzone = ({ + dropRef, + dragRef, + fileUploaderRef, + dragging, + supportBatchUpload, + supportTypesShowNames, + fileUploadConfig, + acceptTypes, + onSelectFile, + onFileChange, +}: UploadDropzoneProps) => { + const { t } = useTranslation() + + return ( + <> + <input + ref={fileUploaderRef} + id="fileUploader" + className="hidden" + type="file" + multiple={supportBatchUpload} + accept={acceptTypes.join(',')} + onChange={onFileChange} + /> + <div + ref={dropRef} + className={cn( + 'relative mb-2 box-border flex min-h-20 max-w-[640px] flex-col items-center justify-center gap-1 rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg px-4 py-3 text-xs leading-4 text-text-tertiary', + dragging && 'border-components-dropzone-border-accent bg-components-dropzone-bg-accent', + )} + > + <div className="flex min-h-5 items-center justify-center text-sm leading-4 text-text-secondary"> + <RiUploadCloud2Line className="mr-2 size-5" /> + <span> + {supportBatchUpload + ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) + : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })} + {acceptTypes.length > 0 && ( + <label + className="ml-1 cursor-pointer text-text-accent" + onClick={onSelectFile} + > + {t('stepOne.uploader.browse', { ns: 'datasetCreation' })} + </label> + )} + </span> + </div> + <div> + {t('stepOne.uploader.tip', { + ns: 'datasetCreation', + size: fileUploadConfig.file_size_limit, + supportTypes: supportTypesShowNames, + batchCount: fileUploadConfig.batch_count_limit, + totalCount: fileUploadConfig.file_upload_limit, + })} + </div> + {dragging && <div ref={dragRef} className="absolute left-0 top-0 h-full w-full" />} + </div> + </> + ) +} + +export default UploadDropzone diff --git a/web/app/components/datasets/create/file-uploader/constants.ts b/web/app/components/datasets/create/file-uploader/constants.ts new file mode 100644 index 0000000000..cda2dae868 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/constants.ts @@ -0,0 +1,3 @@ +export const PROGRESS_NOT_STARTED = -1 +export const PROGRESS_ERROR = -2 +export const PROGRESS_COMPLETE = 100 diff --git a/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.spec.tsx b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.spec.tsx new file mode 100644 index 0000000000..222f038c84 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.spec.tsx @@ -0,0 +1,921 @@ +import type { ReactNode } from 'react' +import type { CustomFile, FileItem } from '@/models/datasets' +import { act, render, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { ToastContext } from '@/app/components/base/toast' + +import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' +// Import after mocks +import { useFileUpload } from './use-file-upload' + +// Mock notify function +const mockNotify = vi.fn() +const mockClose = vi.fn() + +// Mock ToastContext +vi.mock('use-context-selector', async () => { + const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector') + return { + ...actual, + useContext: vi.fn(() => ({ notify: mockNotify, close: mockClose })), + } +}) + +// Mock upload service +const mockUpload = vi.fn() +vi.mock('@/service/base', () => ({ + upload: (...args: unknown[]) => mockUpload(...args), +})) + +// Mock file upload config +const mockFileUploadConfig = { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, +} + +const mockSupportTypes = { + allowed_extensions: ['pdf', 'docx', 'txt', 'md'], +} + +vi.mock('@/service/use-common', () => ({ + useFileUploadConfig: () => ({ data: mockFileUploadConfig }), + useFileSupportTypes: () => ({ data: mockSupportTypes }), +})) + +// Mock i18n +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock locale +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +vi.mock('@/i18n-config/language', () => ({ + LanguagesSupported: ['en-US', 'zh-Hans'], +})) + +// Mock config +vi.mock('@/config', () => ({ + IS_CE_EDITION: false, +})) + +// Mock file upload error message +vi.mock('@/app/components/base/file-uploader/utils', () => ({ + getFileUploadErrorMessage: (_e: unknown, defaultMsg: string) => defaultMsg, +})) + +const createWrapper = () => { + return ({ children }: { children: ReactNode }) => ( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + {children} + </ToastContext.Provider> + ) +} + +describe('useFileUpload', () => { + const defaultOptions = { + fileList: [] as FileItem[], + prepareFileList: vi.fn(), + onFileUpdate: vi.fn(), + onFileListUpdate: vi.fn(), + onPreview: vi.fn(), + supportBatchUpload: true, + } + + beforeEach(() => { + vi.clearAllMocks() + mockUpload.mockReset() + // Default mock to return a resolved promise to avoid unhandled rejections + mockUpload.mockResolvedValue({ id: 'default-id' }) + mockNotify.mockReset() + }) + + describe('initialization', () => { + it('should initialize with default values', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + expect(result.current.dragging).toBe(false) + expect(result.current.hideUpload).toBe(false) + expect(result.current.dropRef.current).toBeNull() + expect(result.current.dragRef.current).toBeNull() + expect(result.current.fileUploaderRef.current).toBeNull() + }) + + it('should set hideUpload true when not batch upload and has files', () => { + const { result } = renderHook( + () => useFileUpload({ + ...defaultOptions, + supportBatchUpload: false, + fileList: [{ fileID: 'file-1', file: {} as CustomFile, progress: 100 }], + }), + { wrapper: createWrapper() }, + ) + + expect(result.current.hideUpload).toBe(true) + }) + + it('should compute acceptTypes correctly', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + expect(result.current.acceptTypes).toEqual(['.pdf', '.docx', '.txt', '.md']) + }) + + it('should compute supportTypesShowNames correctly', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('PDF') + expect(result.current.supportTypesShowNames).toContain('DOCX') + expect(result.current.supportTypesShowNames).toContain('TXT') + // 'md' is mapped to 'markdown' in the extensionMap + expect(result.current.supportTypesShowNames).toContain('MARKDOWN') + }) + + it('should set batch limit to 1 when not batch upload', () => { + const { result } = renderHook( + () => useFileUpload({ + ...defaultOptions, + supportBatchUpload: false, + }), + { wrapper: createWrapper() }, + ) + + expect(result.current.fileUploadConfig.batch_count_limit).toBe(1) + expect(result.current.fileUploadConfig.file_upload_limit).toBe(1) + }) + }) + + describe('selectHandle', () => { + it('should trigger click on file input', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + const mockClick = vi.fn() + const mockInput = { click: mockClick } as unknown as HTMLInputElement + Object.defineProperty(result.current.fileUploaderRef, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.selectHandle() + }) + + expect(mockClick).toHaveBeenCalled() + }) + + it('should do nothing when file input ref is null', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + expect(() => { + act(() => { + result.current.selectHandle() + }) + }).not.toThrow() + }) + }) + + describe('handlePreview', () => { + it('should call onPreview when file has id', () => { + const onPreview = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onPreview }), + { wrapper: createWrapper() }, + ) + + const mockFile = { id: 'file-123', name: 'test.pdf', size: 1024 } as CustomFile + + act(() => { + result.current.handlePreview(mockFile) + }) + + expect(onPreview).toHaveBeenCalledWith(mockFile) + }) + + it('should not call onPreview when file has no id', () => { + const onPreview = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onPreview }), + { wrapper: createWrapper() }, + ) + + const mockFile = { name: 'test.pdf', size: 1024 } as CustomFile + + act(() => { + result.current.handlePreview(mockFile) + }) + + expect(onPreview).not.toHaveBeenCalled() + }) + }) + + describe('removeFile', () => { + it('should call onFileListUpdate with filtered list', () => { + const onFileListUpdate = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileListUpdate }), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.removeFile('file-to-remove') + }) + + expect(onFileListUpdate).toHaveBeenCalled() + }) + + it('should clear file input value', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + const mockInput = { value: 'some-file' } as HTMLInputElement + Object.defineProperty(result.current.fileUploaderRef, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.removeFile('file-123') + }) + + expect(mockInput.value).toBe('') + }) + }) + + describe('fileChangeHandle', () => { + it('should handle valid files', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const prepareFileList = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, prepareFileList }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(prepareFileList).toHaveBeenCalled() + }) + }) + + it('should limit files to batch count', () => { + const prepareFileList = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, prepareFileList }), + { wrapper: createWrapper() }, + ) + + const files = Array.from({ length: 10 }, (_, i) => + new File(['content'], `file${i}.pdf`, { type: 'application/pdf' })) + + const event = { + target: { files }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + // Should be called with at most batch_count_limit files + if (prepareFileList.mock.calls.length > 0) { + const calledFiles = prepareFileList.mock.calls[0][0] + expect(calledFiles.length).toBeLessThanOrEqual(mockFileUploadConfig.batch_count_limit) + } + }) + + it('should reject invalid file types', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.exe', { type: 'application/x-msdownload' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + + it('should reject files exceeding size limit', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + // Create a file larger than the limit (15MB) + const largeFile = new File([new ArrayBuffer(20 * 1024 * 1024)], 'large.pdf', { type: 'application/pdf' }) + + const event = { + target: { files: [largeFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + + it('should handle null files', () => { + const prepareFileList = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, prepareFileList }), + { wrapper: createWrapper() }, + ) + + const event = { + target: { files: null }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(prepareFileList).not.toHaveBeenCalled() + }) + }) + + describe('drag and drop handlers', () => { + const TestDropzone = ({ options }: { options: typeof defaultOptions }) => { + const { + dropRef, + dragRef, + dragging, + } = useFileUpload(options) + + return ( + <div> + <div ref={dropRef} data-testid="dropzone"> + {dragging && <div ref={dragRef} data-testid="drag-overlay" />} + </div> + <span data-testid="dragging">{String(dragging)}</span> + </div> + ) + } + + it('should set dragging true on dragenter', async () => { + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={defaultOptions} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragEnterEvent) + }) + + expect(getByTestId('dragging').textContent).toBe('true') + }) + + it('should handle dragover event', async () => { + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={defaultOptions} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragOverEvent = new Event('dragover', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragOverEvent) + }) + + expect(dropzone).toBeInTheDocument() + }) + + it('should set dragging false on dragleave from drag overlay', async () => { + const { getByTestId, queryByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={defaultOptions} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragEnterEvent) + }) + + expect(getByTestId('dragging').textContent).toBe('true') + + const dragOverlay = queryByTestId('drag-overlay') + if (dragOverlay) { + await act(async () => { + const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true }) + Object.defineProperty(dragLeaveEvent, 'target', { value: dragOverlay }) + dropzone.dispatchEvent(dragLeaveEvent) + }) + } + }) + + it('should handle drop with files', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={{ ...defaultOptions, prepareFileList }} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => mockFile, + webkitGetAsEntry: () => null, + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(prepareFileList).toHaveBeenCalled() + }) + }) + + it('should handle drop without dataTransfer', async () => { + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={{ ...defaultOptions, prepareFileList }} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { value: null }) + dropzone.dispatchEvent(dropEvent) + }) + + expect(prepareFileList).not.toHaveBeenCalled() + }) + + it('should limit to single file on drop when supportBatchUpload is false', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={{ ...defaultOptions, supportBatchUpload: false, prepareFileList }} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + const files = [ + new File(['content1'], 'test1.pdf', { type: 'application/pdf' }), + new File(['content2'], 'test2.pdf', { type: 'application/pdf' }), + ] + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: files.map(f => ({ + getAsFile: () => f, + webkitGetAsEntry: () => null, + })), + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + if (prepareFileList.mock.calls.length > 0) { + const calledFiles = prepareFileList.mock.calls[0][0] + expect(calledFiles.length).toBe(1) + } + }) + }) + + it('should handle drop with FileSystemFileEntry', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + const prepareFileList = vi.fn() + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={{ ...defaultOptions, prepareFileList }} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => mockFile, + webkitGetAsEntry: () => ({ + isFile: true, + isDirectory: false, + file: (callback: (file: File) => void) => callback(mockFile), + }), + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(prepareFileList).toHaveBeenCalled() + }) + }) + + it('should handle drop with FileSystemDirectoryEntry', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + const prepareFileList = vi.fn() + const mockFile = new File(['content'], 'nested.pdf', { type: 'application/pdf' }) + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={{ ...defaultOptions, prepareFileList }} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + let callCount = 0 + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => null, + webkitGetAsEntry: () => ({ + isFile: false, + isDirectory: true, + name: 'folder', + createReader: () => ({ + readEntries: (callback: (entries: Array<{ isFile: boolean, isDirectory: boolean, name?: string, file?: (cb: (f: File) => void) => void }>) => void) => { + // First call returns file entry, second call returns empty (signals end) + if (callCount === 0) { + callCount++ + callback([{ + isFile: true, + isDirectory: false, + name: 'nested.pdf', + file: (cb: (f: File) => void) => cb(mockFile), + }]) + } + else { + callback([]) + } + }, + }), + }), + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(prepareFileList).toHaveBeenCalled() + }) + }) + + it('should handle drop with empty directory', async () => { + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={{ ...defaultOptions, prepareFileList }} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => null, + webkitGetAsEntry: () => ({ + isFile: false, + isDirectory: true, + name: 'empty-folder', + createReader: () => ({ + readEntries: (callback: (entries: never[]) => void) => { + callback([]) + }, + }), + }), + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + // Should not prepare file list if no valid files + await new Promise(resolve => setTimeout(resolve, 100)) + }) + + it('should handle entry that is neither file nor directory', async () => { + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone options={{ ...defaultOptions, prepareFileList }} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => null, + webkitGetAsEntry: () => ({ + isFile: false, + isDirectory: false, + }), + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + // Should not throw and should handle gracefully + await new Promise(resolve => setTimeout(resolve, 100)) + }) + }) + + describe('file upload', () => { + it('should call upload with correct parameters', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id', name: 'test.pdf' }) + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + }) + + it('should update progress during upload', async () => { + let progressCallback: ((e: ProgressEvent) => void) | undefined + + mockUpload.mockImplementation(async (options: { onprogress: (e: ProgressEvent) => void }) => { + progressCallback = options.onprogress + return { id: 'uploaded-id' } + }) + + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + + if (progressCallback) { + act(() => { + progressCallback!({ + lengthComputable: true, + loaded: 50, + total: 100, + } as ProgressEvent) + }) + + expect(onFileUpdate).toHaveBeenCalled() + } + }) + + it('should handle upload error', async () => { + mockUpload.mockRejectedValue(new Error('Upload failed')) + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + }) + + it('should update file with PROGRESS_COMPLETE on success', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id', name: 'test.pdf' }) + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + const completeCalls = onFileUpdate.mock.calls.filter( + ([, progress]) => progress === PROGRESS_COMPLETE, + ) + expect(completeCalls.length).toBeGreaterThan(0) + }) + }) + + it('should update file with PROGRESS_ERROR on failure', async () => { + mockUpload.mockRejectedValue(new Error('Upload failed')) + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + const errorCalls = onFileUpdate.mock.calls.filter( + ([, progress]) => progress === PROGRESS_ERROR, + ) + expect(errorCalls.length).toBeGreaterThan(0) + }) + }) + }) + + describe('file count validation', () => { + it('should reject when total files exceed limit', () => { + const existingFiles: FileItem[] = Array.from({ length: 8 }, (_, i) => ({ + fileID: `existing-${i}`, + file: { name: `existing-${i}.pdf`, size: 1024 } as CustomFile, + progress: 100, + })) + + const { result } = renderHook( + () => useFileUpload({ + ...defaultOptions, + fileList: existingFiles, + }), + { wrapper: createWrapper() }, + ) + + const files = Array.from({ length: 5 }, (_, i) => + new File(['content'], `new-${i}.pdf`, { type: 'application/pdf' })) + + const event = { + target: { files }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + }) + + describe('progress constants', () => { + it('should use PROGRESS_NOT_STARTED for new files', async () => { + mockUpload.mockResolvedValue({ id: 'file-id' }) + + const prepareFileList = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, prepareFileList }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + if (prepareFileList.mock.calls.length > 0) { + const files = prepareFileList.mock.calls[0][0] + expect(files[0].progress).toBe(PROGRESS_NOT_STARTED) + } + }) + }) + }) +}) diff --git a/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts new file mode 100644 index 0000000000..e097bab755 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts @@ -0,0 +1,351 @@ +'use client' +import type { RefObject } from 'react' +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' +import { ToastContext } from '@/app/components/base/toast' +import { IS_CE_EDITION } from '@/config' +import { useLocale } from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import { upload } from '@/service/base' +import { useFileSupportTypes, useFileUploadConfig } from '@/service/use-common' +import { getFileExtension } from '@/utils/format' +import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' + +export type FileUploadConfig = { + file_size_limit: number + batch_count_limit: number + file_upload_limit: number +} + +export type UseFileUploadOptions = { + fileList: FileItem[] + prepareFileList: (files: FileItem[]) => void + onFileUpdate: (fileItem: FileItem, progress: number, list: FileItem[]) => void + onFileListUpdate?: (files: FileItem[]) => void + onPreview: (file: File) => void + supportBatchUpload?: boolean + /** + * Optional list of allowed file extensions. If not provided, fetches from API. + * Pass this when you need custom extension filtering instead of using the global config. + */ + allowedExtensions?: string[] +} + +export type UseFileUploadReturn = { + // Refs + dropRef: RefObject<HTMLDivElement | null> + dragRef: RefObject<HTMLDivElement | null> + fileUploaderRef: RefObject<HTMLInputElement | null> + + // State + dragging: boolean + + // Config + fileUploadConfig: FileUploadConfig + acceptTypes: string[] + supportTypesShowNames: string + hideUpload: boolean + + // Handlers + selectHandle: () => void + fileChangeHandle: (e: React.ChangeEvent<HTMLInputElement>) => void + removeFile: (fileID: string) => void + handlePreview: (file: File) => void +} + +type FileWithPath = { + relativePath?: string +} & File + +export const useFileUpload = ({ + fileList, + prepareFileList, + onFileUpdate, + onFileListUpdate, + onPreview, + supportBatchUpload = false, + allowedExtensions, +}: UseFileUploadOptions): UseFileUploadReturn => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const locale = useLocale() + + const [dragging, setDragging] = useState(false) + const dropRef = useRef<HTMLDivElement>(null) + const dragRef = useRef<HTMLDivElement>(null) + const fileUploaderRef = useRef<HTMLInputElement>(null) + const fileListRef = useRef<FileItem[]>([]) + + const hideUpload = !supportBatchUpload && fileList.length > 0 + + const { data: fileUploadConfigResponse } = useFileUploadConfig() + const { data: supportFileTypesResponse } = useFileSupportTypes() + // Use provided allowedExtensions or fetch from API + const supportTypes = useMemo( + () => allowedExtensions ?? supportFileTypesResponse?.allowed_extensions ?? [], + [allowedExtensions, supportFileTypesResponse?.allowed_extensions], + ) + + const supportTypesShowNames = useMemo(() => { + const extensionMap: { [key: string]: string } = { + md: 'markdown', + pptx: 'pptx', + htm: 'html', + xlsx: 'xlsx', + docx: 'docx', + } + + return [...supportTypes] + .map(item => extensionMap[item] || item) + .map(item => item.toLowerCase()) + .filter((item, index, self) => self.indexOf(item) === index) + .map(item => item.toUpperCase()) + .join(locale !== LanguagesSupported[1] ? ', ' : 'ใ€ ') + }, [supportTypes, locale]) + + const acceptTypes = useMemo(() => supportTypes.map((ext: string) => `.${ext}`), [supportTypes]) + + const fileUploadConfig = useMemo(() => ({ + file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, + batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1, + file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1, + }), [fileUploadConfigResponse, supportBatchUpload]) + + const isValid = useCallback((file: File) => { + const { size } = file + const ext = `.${getFileExtension(file.name)}` + const isValidType = acceptTypes.includes(ext.toLowerCase()) + if (!isValidType) + notify({ type: 'error', message: t('stepOne.uploader.validation.typeError', { ns: 'datasetCreation' }) }) + + const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024 + if (!isValidSize) + notify({ type: 'error', message: t('stepOne.uploader.validation.size', { ns: 'datasetCreation', size: fileUploadConfig.file_size_limit }) }) + + return isValidType && isValidSize + }, [fileUploadConfig, notify, t, acceptTypes]) + + const fileUpload = useCallback(async (fileItem: FileItem): Promise<FileItem> => { + const formData = new FormData() + formData.append('file', fileItem.file) + const onProgress = (e: ProgressEvent) => { + if (e.lengthComputable) { + const percent = Math.floor(e.loaded / e.total * 100) + onFileUpdate(fileItem, percent, fileListRef.current) + } + } + + return upload({ + xhr: new XMLHttpRequest(), + data: formData, + onprogress: onProgress, + }, false, undefined, '?source=datasets') + .then((res) => { + const completeFile = { + fileID: fileItem.fileID, + file: res as unknown as File, + progress: PROGRESS_NOT_STARTED, + } + const index = fileListRef.current.findIndex(item => item.fileID === fileItem.fileID) + fileListRef.current[index] = completeFile + onFileUpdate(completeFile, PROGRESS_COMPLETE, fileListRef.current) + return Promise.resolve({ ...completeFile }) + }) + .catch((e) => { + const errorMessage = getFileUploadErrorMessage(e, t('stepOne.uploader.failed', { ns: 'datasetCreation' }), t) + notify({ type: 'error', message: errorMessage }) + onFileUpdate(fileItem, PROGRESS_ERROR, fileListRef.current) + return Promise.resolve({ ...fileItem }) + }) + .finally() + }, [notify, onFileUpdate, t]) + + const uploadBatchFiles = useCallback((bFiles: FileItem[]) => { + bFiles.forEach(bf => (bf.progress = 0)) + return Promise.all(bFiles.map(fileUpload)) + }, [fileUpload]) + + const uploadMultipleFiles = useCallback(async (files: FileItem[]) => { + const batchCountLimit = fileUploadConfig.batch_count_limit + const length = files.length + let start = 0 + let end = 0 + + while (start < length) { + if (start + batchCountLimit > length) + end = length + else + end = start + batchCountLimit + const bFiles = files.slice(start, end) + await uploadBatchFiles(bFiles) + start = end + } + }, [fileUploadConfig, uploadBatchFiles]) + + const initialUpload = useCallback((files: File[]) => { + const filesCountLimit = fileUploadConfig.file_upload_limit + if (!files.length) + return false + + if (files.length + fileList.length > filesCountLimit && !IS_CE_EDITION) { + notify({ type: 'error', message: t('stepOne.uploader.validation.filesNumber', { ns: 'datasetCreation', filesNumber: filesCountLimit }) }) + return false + } + + const preparedFiles = files.map((file, index) => ({ + fileID: `file${index}-${Date.now()}`, + file, + progress: PROGRESS_NOT_STARTED, + })) + const newFiles = [...fileListRef.current, ...preparedFiles] + prepareFileList(newFiles) + fileListRef.current = newFiles + uploadMultipleFiles(preparedFiles) + }, [prepareFileList, uploadMultipleFiles, notify, t, fileList, fileUploadConfig]) + + const traverseFileEntry = useCallback( + (entry: FileSystemEntry, prefix = ''): Promise<FileWithPath[]> => { + return new Promise((resolve) => { + if (entry.isFile) { + (entry as FileSystemFileEntry).file((file: FileWithPath) => { + file.relativePath = `${prefix}${file.name}` + resolve([file]) + }) + } + else if (entry.isDirectory) { + const reader = (entry as FileSystemDirectoryEntry).createReader() + const entries: FileSystemEntry[] = [] + const read = () => { + reader.readEntries(async (results: FileSystemEntry[]) => { + if (!results.length) { + const files = await Promise.all( + entries.map(ent => + traverseFileEntry(ent, `${prefix}${entry.name}/`), + ), + ) + resolve(files.flat()) + } + else { + entries.push(...results) + read() + } + }) + } + read() + } + else { + resolve([]) + } + }) + }, + [], + ) + + const handleDragEnter = useCallback((e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target !== dragRef.current) + setDragging(true) + }, []) + + const handleDragOver = useCallback((e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + }, []) + + const handleDragLeave = useCallback((e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target === dragRef.current) + setDragging(false) + }, []) + + const handleDrop = useCallback( + async (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + setDragging(false) + if (!e.dataTransfer) + return + const nested = await Promise.all( + Array.from(e.dataTransfer.items).map((it) => { + const entry = (it as DataTransferItem & { webkitGetAsEntry?: () => FileSystemEntry | null }).webkitGetAsEntry?.() + if (entry) + return traverseFileEntry(entry) + const f = it.getAsFile?.() + return f ? Promise.resolve([f as FileWithPath]) : Promise.resolve([]) + }), + ) + let files = nested.flat() + if (!supportBatchUpload) + files = files.slice(0, 1) + files = files.slice(0, fileUploadConfig.batch_count_limit) + const valid = files.filter(isValid) + initialUpload(valid) + }, + [initialUpload, isValid, supportBatchUpload, traverseFileEntry, fileUploadConfig], + ) + + const selectHandle = useCallback(() => { + if (fileUploaderRef.current) + fileUploaderRef.current.click() + }, []) + + const removeFile = useCallback((fileID: string) => { + if (fileUploaderRef.current) + fileUploaderRef.current.value = '' + + fileListRef.current = fileListRef.current.filter(item => item.fileID !== fileID) + onFileListUpdate?.([...fileListRef.current]) + }, [onFileListUpdate]) + + const fileChangeHandle = useCallback((e: React.ChangeEvent<HTMLInputElement>) => { + let files = Array.from(e.target.files ?? []) as File[] + files = files.slice(0, fileUploadConfig.batch_count_limit) + initialUpload(files.filter(isValid)) + }, [isValid, initialUpload, fileUploadConfig]) + + const handlePreview = useCallback((file: File) => { + if (file?.id) + onPreview(file) + }, [onPreview]) + + useEffect(() => { + const dropArea = dropRef.current + dropArea?.addEventListener('dragenter', handleDragEnter) + dropArea?.addEventListener('dragover', handleDragOver) + dropArea?.addEventListener('dragleave', handleDragLeave) + dropArea?.addEventListener('drop', handleDrop) + return () => { + dropArea?.removeEventListener('dragenter', handleDragEnter) + dropArea?.removeEventListener('dragover', handleDragOver) + dropArea?.removeEventListener('dragleave', handleDragLeave) + dropArea?.removeEventListener('drop', handleDrop) + } + }, [handleDragEnter, handleDragOver, handleDragLeave, handleDrop]) + + return { + // Refs + dropRef, + dragRef, + fileUploaderRef, + + // State + dragging, + + // Config + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + + // Handlers + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } +} diff --git a/web/app/components/datasets/create/file-uploader/index.spec.tsx b/web/app/components/datasets/create/file-uploader/index.spec.tsx new file mode 100644 index 0000000000..91f65652f3 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/index.spec.tsx @@ -0,0 +1,278 @@ +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PROGRESS_NOT_STARTED } from './constants' +import FileUploader from './index' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record<string, string> = { + 'stepOne.uploader.title': 'Upload Files', + 'stepOne.uploader.button': 'Drag and drop files, or', + 'stepOne.uploader.buttonSingleFile': 'Drag and drop file, or', + 'stepOne.uploader.browse': 'Browse', + 'stepOne.uploader.tip': 'Supports various file types', + } + return translations[key] || key + }, + }), +})) + +// Mock ToastContext +const mockNotify = vi.fn() +vi.mock('use-context-selector', async () => { + const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector') + return { + ...actual, + useContext: vi.fn(() => ({ notify: mockNotify })), + } +}) + +// Mock services +vi.mock('@/service/base', () => ({ + upload: vi.fn().mockResolvedValue({ id: 'uploaded-id' }), +})) + +vi.mock('@/service/use-common', () => ({ + useFileUploadConfig: () => ({ + data: { file_size_limit: 15, batch_count_limit: 5, file_upload_limit: 10 }, + }), + useFileSupportTypes: () => ({ + data: { allowed_extensions: ['pdf', 'docx', 'txt'] }, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +vi.mock('@/i18n-config/language', () => ({ + LanguagesSupported: ['en-US', 'zh-Hans'], +})) + +vi.mock('@/config', () => ({ + IS_CE_EDITION: false, +})) + +vi.mock('@/app/components/base/file-uploader/utils', () => ({ + getFileUploadErrorMessage: () => 'Upload error', +})) + +// Mock theme +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: 'light' }), +})) + +vi.mock('@/types/app', () => ({ + Theme: { dark: 'dark', light: 'light' }, +})) + +// Mock DocumentFileIcon - uses relative path from file-list-item.tsx +vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ + default: ({ extension }: { extension: string }) => <div data-testid="document-icon">{extension}</div>, +})) + +// Mock SimplePieChart +vi.mock('next/dynamic', () => ({ + default: () => { + const Component = ({ percentage }: { percentage: number }) => ( + <div data-testid="pie-chart"> + {percentage} + % + </div> + ) + return Component + }, +})) + +describe('FileUploader', () => { + const createMockFile = (overrides: Partial<File> = {}): File => ({ + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + ...overrides, + } as File) + + const createMockFileItem = (overrides: Partial<FileItem> = {}): FileItem => ({ + fileID: `file-${Date.now()}`, + file: createMockFile(overrides.file as Partial<File>), + progress: PROGRESS_NOT_STARTED, + ...overrides, + }) + + const defaultProps = { + fileList: [] as FileItem[], + prepareFileList: vi.fn(), + onFileUpdate: vi.fn(), + onFileListUpdate: vi.fn(), + onPreview: vi.fn(), + supportBatchUpload: true, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the component', () => { + render(<FileUploader {...defaultProps} />) + expect(screen.getByText('Upload Files')).toBeInTheDocument() + }) + + it('should render dropzone when no files', () => { + render(<FileUploader {...defaultProps} />) + expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument() + }) + + it('should render browse button', () => { + render(<FileUploader {...defaultProps} />) + expect(screen.getByText('Browse')).toBeInTheDocument() + }) + + it('should apply custom title className', () => { + render(<FileUploader {...defaultProps} titleClassName="custom-class" />) + const title = screen.getByText('Upload Files') + expect(title).toHaveClass('custom-class') + }) + }) + + describe('file list rendering', () => { + it('should render file items when fileList has items', () => { + const fileList = [ + createMockFileItem({ file: createMockFile({ name: 'file1.pdf' }) }), + createMockFileItem({ file: createMockFile({ name: 'file2.pdf' }) }), + ] + + render(<FileUploader {...defaultProps} fileList={fileList} />) + + expect(screen.getByText('file1.pdf')).toBeInTheDocument() + expect(screen.getByText('file2.pdf')).toBeInTheDocument() + }) + + it('should render document icons for files', () => { + const fileList = [createMockFileItem()] + render(<FileUploader {...defaultProps} fileList={fileList} />) + + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + }) + + describe('batch upload mode', () => { + it('should show dropzone with batch upload enabled', () => { + render(<FileUploader {...defaultProps} supportBatchUpload={true} />) + expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument() + }) + + it('should show single file text when batch upload disabled', () => { + render(<FileUploader {...defaultProps} supportBatchUpload={false} />) + expect(screen.getByText(/Drag and drop file/i)).toBeInTheDocument() + }) + + it('should hide dropzone when not batch upload and has files', () => { + const fileList = [createMockFileItem()] + render(<FileUploader {...defaultProps} supportBatchUpload={false} fileList={fileList} />) + + expect(screen.queryByText(/Drag and drop/i)).not.toBeInTheDocument() + }) + }) + + describe('event handlers', () => { + it('should handle file preview click', () => { + const onPreview = vi.fn() + const fileItem = createMockFileItem({ + file: createMockFile({ id: 'file-id' } as Partial<File>), + }) + + const { container } = render(<FileUploader {...defaultProps} fileList={[fileItem]} onPreview={onPreview} />) + + // Find the file list item container by its class pattern + const fileElement = container.querySelector('[class*="flex h-12"]') + if (fileElement) + fireEvent.click(fileElement) + + expect(onPreview).toHaveBeenCalledWith(fileItem.file) + }) + + it('should handle file remove click', () => { + const onFileListUpdate = vi.fn() + const fileItem = createMockFileItem() + + const { container } = render( + <FileUploader {...defaultProps} fileList={[fileItem]} onFileListUpdate={onFileListUpdate} />, + ) + + // Find the delete button (the span with cursor-pointer containing the icon) + const deleteButtons = container.querySelectorAll('[class*="cursor-pointer"]') + // Get the last one which should be the delete button (not the browse label) + const deleteButton = deleteButtons[deleteButtons.length - 1] + if (deleteButton) + fireEvent.click(deleteButton) + + expect(onFileListUpdate).toHaveBeenCalled() + }) + + it('should handle browse button click', () => { + render(<FileUploader {...defaultProps} />) + + // The browse label should trigger file input click + const browseLabel = screen.getByText('Browse') + expect(browseLabel).toHaveClass('cursor-pointer') + }) + }) + + describe('upload progress', () => { + it('should show progress chart for uploading files', () => { + const fileItem = createMockFileItem({ progress: 50 }) + render(<FileUploader {...defaultProps} fileList={[fileItem]} />) + + expect(screen.getByTestId('pie-chart')).toBeInTheDocument() + expect(screen.getByText('50%')).toBeInTheDocument() + }) + + it('should not show progress chart for completed files', () => { + const fileItem = createMockFileItem({ progress: 100 }) + render(<FileUploader {...defaultProps} fileList={[fileItem]} />) + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + + it('should not show progress chart for not started files', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_NOT_STARTED }) + render(<FileUploader {...defaultProps} fileList={[fileItem]} />) + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('multiple files', () => { + it('should render all files in the list', () => { + const fileList = [ + createMockFileItem({ fileID: 'f1', file: createMockFile({ name: 'doc1.pdf' }) }), + createMockFileItem({ fileID: 'f2', file: createMockFile({ name: 'doc2.docx' }) }), + createMockFileItem({ fileID: 'f3', file: createMockFile({ name: 'doc3.txt' }) }), + ] + + render(<FileUploader {...defaultProps} fileList={fileList} />) + + expect(screen.getByText('doc1.pdf')).toBeInTheDocument() + expect(screen.getByText('doc2.docx')).toBeInTheDocument() + expect(screen.getByText('doc3.txt')).toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have correct container width', () => { + const { container } = render(<FileUploader {...defaultProps} />) + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('w-[640px]') + }) + + it('should have proper spacing', () => { + const { container } = render(<FileUploader {...defaultProps} />) + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('mb-5') + }) + }) +}) diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index 781b97200a..b649554a12 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -1,23 +1,10 @@ 'use client' import type { CustomFile as File, FileItem } from '@/models/datasets' -import { RiDeleteBinLine, RiUploadCloud2Line } from '@remixicon/react' -import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' -import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' -import SimplePieChart from '@/app/components/base/simple-pie-chart' -import { ToastContext } from '@/app/components/base/toast' -import { IS_CE_EDITION } from '@/config' - -import { useLocale } from '@/context/i18n' -import useTheme from '@/hooks/use-theme' -import { LanguagesSupported } from '@/i18n-config/language' -import { upload } from '@/service/base' -import { useFileSupportTypes, useFileUploadConfig } from '@/service/use-common' -import { Theme } from '@/types/app' import { cn } from '@/utils/classnames' -import DocumentFileIcon from '../../common/document-file-icon' +import FileListItem from './components/file-list-item' +import UploadDropzone from './components/upload-dropzone' +import { useFileUpload } from './hooks/use-file-upload' type IFileUploaderProps = { fileList: FileItem[] @@ -39,358 +26,62 @@ const FileUploader = ({ supportBatchUpload = false, }: IFileUploaderProps) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) - const locale = useLocale() - const [dragging, setDragging] = useState(false) - const dropRef = useRef<HTMLDivElement>(null) - const dragRef = useRef<HTMLDivElement>(null) - const fileUploader = useRef<HTMLInputElement>(null) - const hideUpload = !supportBatchUpload && fileList.length > 0 - const { data: fileUploadConfigResponse } = useFileUploadConfig() - const { data: supportFileTypesResponse } = useFileSupportTypes() - const supportTypes = supportFileTypesResponse?.allowed_extensions || [] - const supportTypesShowNames = (() => { - const extensionMap: { [key: string]: string } = { - md: 'markdown', - pptx: 'pptx', - htm: 'html', - xlsx: 'xlsx', - docx: 'docx', - } - - return [...supportTypes] - .map(item => extensionMap[item] || item) // map to standardized extension - .map(item => item.toLowerCase()) // convert to lower case - .filter((item, index, self) => self.indexOf(item) === index) // remove duplicates - .map(item => item.toUpperCase()) // convert to upper case - .join(locale !== LanguagesSupported[1] ? ', ' : 'ใ€ ') - })() - const ACCEPTS = supportTypes.map((ext: string) => `.${ext}`) - const fileUploadConfig = useMemo(() => ({ - file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, - batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1, - file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1, - }), [fileUploadConfigResponse, supportBatchUpload]) - - const fileListRef = useRef<FileItem[]>([]) - - // utils - const getFileType = (currentFile: File) => { - if (!currentFile) - return '' - - const arr = currentFile.name.split('.') - return arr[arr.length - 1] - } - - const getFileSize = (size: number) => { - if (size / 1024 < 10) - return `${(size / 1024).toFixed(2)}KB` - - return `${(size / 1024 / 1024).toFixed(2)}MB` - } - - const isValid = useCallback((file: File) => { - const { size } = file - const ext = `.${getFileType(file)}` - const isValidType = ACCEPTS.includes(ext.toLowerCase()) - if (!isValidType) - notify({ type: 'error', message: t('stepOne.uploader.validation.typeError', { ns: 'datasetCreation' }) }) - - const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024 - if (!isValidSize) - notify({ type: 'error', message: t('stepOne.uploader.validation.size', { ns: 'datasetCreation', size: fileUploadConfig.file_size_limit }) }) - - return isValidType && isValidSize - }, [fileUploadConfig, notify, t, ACCEPTS]) - - const fileUpload = useCallback(async (fileItem: FileItem): Promise<FileItem> => { - const formData = new FormData() - formData.append('file', fileItem.file) - const onProgress = (e: ProgressEvent) => { - if (e.lengthComputable) { - const percent = Math.floor(e.loaded / e.total * 100) - onFileUpdate(fileItem, percent, fileListRef.current) - } - } - - return upload({ - xhr: new XMLHttpRequest(), - data: formData, - onprogress: onProgress, - }, false, undefined, '?source=datasets') - .then((res) => { - const completeFile = { - fileID: fileItem.fileID, - file: res as unknown as File, - progress: -1, - } - const index = fileListRef.current.findIndex(item => item.fileID === fileItem.fileID) - fileListRef.current[index] = completeFile - onFileUpdate(completeFile, 100, fileListRef.current) - return Promise.resolve({ ...completeFile }) - }) - .catch((e) => { - const errorMessage = getFileUploadErrorMessage(e, t('stepOne.uploader.failed', { ns: 'datasetCreation' }), t) - notify({ type: 'error', message: errorMessage }) - onFileUpdate(fileItem, -2, fileListRef.current) - return Promise.resolve({ ...fileItem }) - }) - .finally() - }, [fileListRef, notify, onFileUpdate, t]) - - const uploadBatchFiles = useCallback((bFiles: FileItem[]) => { - bFiles.forEach(bf => (bf.progress = 0)) - return Promise.all(bFiles.map(fileUpload)) - }, [fileUpload]) - - const uploadMultipleFiles = useCallback(async (files: FileItem[]) => { - const batchCountLimit = fileUploadConfig.batch_count_limit - const length = files.length - let start = 0 - let end = 0 - - while (start < length) { - if (start + batchCountLimit > length) - end = length - else - end = start + batchCountLimit - const bFiles = files.slice(start, end) - await uploadBatchFiles(bFiles) - start = end - } - }, [fileUploadConfig, uploadBatchFiles]) - - const initialUpload = useCallback((files: File[]) => { - const filesCountLimit = fileUploadConfig.file_upload_limit - if (!files.length) - return false - - if (files.length + fileList.length > filesCountLimit && !IS_CE_EDITION) { - notify({ type: 'error', message: t('stepOne.uploader.validation.filesNumber', { ns: 'datasetCreation', filesNumber: filesCountLimit }) }) - return false - } - - const preparedFiles = files.map((file, index) => ({ - fileID: `file${index}-${Date.now()}`, - file, - progress: -1, - })) - const newFiles = [...fileListRef.current, ...preparedFiles] - prepareFileList(newFiles) - fileListRef.current = newFiles - uploadMultipleFiles(preparedFiles) - }, [prepareFileList, uploadMultipleFiles, notify, t, fileList, fileUploadConfig]) - - const handleDragEnter = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - if (e.target !== dragRef.current) - setDragging(true) - } - const handleDragOver = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - } - const handleDragLeave = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - if (e.target === dragRef.current) - setDragging(false) - } - type FileWithPath = { - relativePath?: string - } & File - const traverseFileEntry = useCallback( - (entry: any, prefix = ''): Promise<FileWithPath[]> => { - return new Promise((resolve) => { - if (entry.isFile) { - entry.file((file: FileWithPath) => { - file.relativePath = `${prefix}${file.name}` - resolve([file]) - }) - } - else if (entry.isDirectory) { - const reader = entry.createReader() - const entries: any[] = [] - const read = () => { - reader.readEntries(async (results: FileSystemEntry[]) => { - if (!results.length) { - const files = await Promise.all( - entries.map(ent => - traverseFileEntry(ent, `${prefix}${entry.name}/`), - ), - ) - resolve(files.flat()) - } - else { - entries.push(...results) - read() - } - }) - } - read() - } - else { - resolve([]) - } - }) - }, - [], - ) - - const handleDrop = useCallback( - async (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - setDragging(false) - if (!e.dataTransfer) - return - const nested = await Promise.all( - Array.from(e.dataTransfer.items).map((it) => { - const entry = (it as any).webkitGetAsEntry?.() - if (entry) - return traverseFileEntry(entry) - const f = it.getAsFile?.() - return f ? Promise.resolve([f]) : Promise.resolve([]) - }), - ) - let files = nested.flat() - if (!supportBatchUpload) - files = files.slice(0, 1) - files = files.slice(0, fileUploadConfig.batch_count_limit) - const valid = files.filter(isValid) - initialUpload(valid) - }, - [initialUpload, isValid, supportBatchUpload, traverseFileEntry, fileUploadConfig], - ) - const selectHandle = () => { - if (fileUploader.current) - fileUploader.current.click() - } - - const removeFile = (fileID: string) => { - if (fileUploader.current) - fileUploader.current.value = '' - - fileListRef.current = fileListRef.current.filter(item => item.fileID !== fileID) - onFileListUpdate?.([...fileListRef.current]) - } - const fileChangeHandle = useCallback((e: React.ChangeEvent<HTMLInputElement>) => { - let files = Array.from(e.target.files ?? []) as File[] - files = files.slice(0, fileUploadConfig.batch_count_limit) - initialUpload(files.filter(isValid)) - }, [isValid, initialUpload, fileUploadConfig]) - - const { theme } = useTheme() - const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) - - useEffect(() => { - dropRef.current?.addEventListener('dragenter', handleDragEnter) - dropRef.current?.addEventListener('dragover', handleDragOver) - dropRef.current?.addEventListener('dragleave', handleDragLeave) - dropRef.current?.addEventListener('drop', handleDrop) - return () => { - dropRef.current?.removeEventListener('dragenter', handleDragEnter) - dropRef.current?.removeEventListener('dragover', handleDragOver) - dropRef.current?.removeEventListener('dragleave', handleDragLeave) - dropRef.current?.removeEventListener('drop', handleDrop) - } - }, [handleDrop]) + const { + dropRef, + dragRef, + fileUploaderRef, + dragging, + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } = useFileUpload({ + fileList, + prepareFileList, + onFileUpdate, + onFileListUpdate, + onPreview, + supportBatchUpload, + }) return ( <div className="mb-5 w-[640px]"> + <div className={cn('mb-1 text-sm font-semibold leading-6 text-text-secondary', titleClassName)}> + {t('stepOne.uploader.title', { ns: 'datasetCreation' })} + </div> + {!hideUpload && ( - <input - ref={fileUploader} - id="fileUploader" - className="hidden" - type="file" - multiple={supportBatchUpload} - accept={ACCEPTS.join(',')} - onChange={fileChangeHandle} + <UploadDropzone + dropRef={dropRef} + dragRef={dragRef} + fileUploaderRef={fileUploaderRef} + dragging={dragging} + supportBatchUpload={supportBatchUpload} + supportTypesShowNames={supportTypesShowNames} + fileUploadConfig={fileUploadConfig} + acceptTypes={acceptTypes} + onSelectFile={selectHandle} + onFileChange={fileChangeHandle} /> )} - <div className={cn('mb-1 text-sm font-semibold leading-6 text-text-secondary', titleClassName)}>{t('stepOne.uploader.title', { ns: 'datasetCreation' })}</div> - - {!hideUpload && ( - <div ref={dropRef} className={cn('relative mb-2 box-border flex min-h-20 max-w-[640px] flex-col items-center justify-center gap-1 rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg px-4 py-3 text-xs leading-4 text-text-tertiary', dragging && 'border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}> - <div className="flex min-h-5 items-center justify-center text-sm leading-4 text-text-secondary"> - <RiUploadCloud2Line className="mr-2 size-5" /> - - <span> - {supportBatchUpload ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })} - {supportTypes.length > 0 && ( - <label className="ml-1 cursor-pointer text-text-accent" onClick={selectHandle}>{t('stepOne.uploader.browse', { ns: 'datasetCreation' })}</label> - )} - </span> - </div> - <div> - {t('stepOne.uploader.tip', { - ns: 'datasetCreation', - size: fileUploadConfig.file_size_limit, - supportTypes: supportTypesShowNames, - batchCount: fileUploadConfig.batch_count_limit, - totalCount: fileUploadConfig.file_upload_limit, - })} - </div> - {dragging && <div ref={dragRef} className="absolute left-0 top-0 h-full w-full" />} + {fileList.length > 0 && ( + <div className="max-w-[640px] cursor-default space-y-1"> + {fileList.map(fileItem => ( + <FileListItem + key={fileItem.fileID} + fileItem={fileItem} + onPreview={handlePreview} + onRemove={removeFile} + /> + ))} </div> )} - <div className="max-w-[640px] cursor-default space-y-1"> - - {fileList.map((fileItem, index) => ( - <div - key={`${fileItem.fileID}-${index}`} - onClick={() => fileItem.file?.id && onPreview(fileItem.file)} - className={cn( - 'flex h-12 max-w-[640px] items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg text-xs leading-3 text-text-tertiary shadow-xs', - // 'border-state-destructive-border bg-state-destructive-hover', - )} - > - <div className="flex w-12 shrink-0 items-center justify-center"> - <DocumentFileIcon - size="xl" - className="shrink-0" - name={fileItem.file.name} - extension={getFileType(fileItem.file)} - /> - </div> - <div className="flex shrink grow flex-col gap-0.5"> - <div className="flex w-full"> - <div className="w-0 grow truncate text-sm leading-4 text-text-secondary">{fileItem.file.name}</div> - </div> - <div className="w-full truncate leading-3 text-text-tertiary"> - <span className="uppercase">{getFileType(fileItem.file)}</span> - <span className="px-1 text-text-quaternary">ยท</span> - <span>{getFileSize(fileItem.file.size)}</span> - {/* <span className='px-1 text-text-quaternary'>ยท</span> - <span>10k characters</span> */} - </div> - </div> - <div className="flex w-16 shrink-0 items-center justify-end gap-1 pr-3"> - {/* <span className="flex justify-center items-center w-6 h-6 cursor-pointer"> - <RiErrorWarningFill className='size-4 text-text-warning' /> - </span> */} - {(fileItem.progress < 100 && fileItem.progress >= 0) && ( - // <div className={s.percent}>{`${fileItem.progress}%`}</div> - <SimplePieChart percentage={fileItem.progress} stroke={chartColor} fill={chartColor} animationDuration={0} /> - )} - <span - className="flex h-6 w-6 cursor-pointer items-center justify-center" - onClick={(e) => { - e.stopPropagation() - removeFile(fileItem.fileID) - }} - > - <RiDeleteBinLine className="size-4 text-text-tertiary" /> - </span> - </div> - </div> - ))} - </div> </div> ) } diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.spec.tsx new file mode 100644 index 0000000000..7754ba6970 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.spec.tsx @@ -0,0 +1,351 @@ +import type { FileListItemProps } from './file-list-item' +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' +import FileListItem from './file-list-item' + +// Mock theme hook - can be changed per test +let mockTheme = 'light' +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: mockTheme }), +})) + +// Mock theme types +vi.mock('@/types/app', () => ({ + Theme: { dark: 'dark', light: 'light' }, +})) + +// Mock SimplePieChart with dynamic import handling +vi.mock('next/dynamic', () => ({ + default: () => { + const DynamicComponent = ({ percentage, stroke, fill }: { percentage: number, stroke: string, fill: string }) => ( + <div data-testid="pie-chart" data-percentage={percentage} data-stroke={stroke} data-fill={fill}> + Pie Chart: + {' '} + {percentage} + % + </div> + ) + DynamicComponent.displayName = 'SimplePieChart' + return DynamicComponent + }, +})) + +// Mock DocumentFileIcon +vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ + default: ({ name, extension, size }: { name: string, extension: string, size: string }) => ( + <div data-testid="document-icon" data-name={name} data-extension={extension} data-size={size}> + Document Icon + </div> + ), +})) + +describe('FileListItem', () => { + const createMockFile = (overrides: Partial<File> = {}): File => ({ + name: 'test-document.pdf', + size: 1024 * 100, // 100KB + type: 'application/pdf', + lastModified: Date.now(), + ...overrides, + } as File) + + const createMockFileItem = (overrides: Partial<FileItem> = {}): FileItem => ({ + fileID: 'file-123', + file: createMockFile(overrides.file as Partial<File>), + progress: PROGRESS_NOT_STARTED, + ...overrides, + }) + + const defaultProps: FileListItemProps = { + fileItem: createMockFileItem(), + onPreview: vi.fn(), + onRemove: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the file item container', () => { + const { container } = render(<FileListItem {...defaultProps} />) + + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('flex', 'h-12', 'items-center', 'rounded-lg') + }) + + it('should render document icon with correct props', () => { + render(<FileListItem {...defaultProps} />) + + const icon = screen.getByTestId('document-icon') + expect(icon).toBeInTheDocument() + expect(icon).toHaveAttribute('data-name', 'test-document.pdf') + expect(icon).toHaveAttribute('data-extension', 'pdf') + expect(icon).toHaveAttribute('data-size', 'lg') + }) + + it('should render file name', () => { + render(<FileListItem {...defaultProps} />) + + expect(screen.getByText('test-document.pdf')).toBeInTheDocument() + }) + + it('should render file extension in uppercase via CSS class', () => { + render(<FileListItem {...defaultProps} />) + + // Extension is rendered in lowercase but styled with uppercase CSS + const extensionSpan = screen.getByText('pdf') + expect(extensionSpan).toBeInTheDocument() + expect(extensionSpan).toHaveClass('uppercase') + }) + + it('should render file size', () => { + render(<FileListItem {...defaultProps} />) + + // 100KB (102400 bytes) formatted with formatFileSize + expect(screen.getByText('100.00 KB')).toBeInTheDocument() + }) + + it('should render delete button', () => { + const { container } = render(<FileListItem {...defaultProps} />) + + const deleteButton = container.querySelector('.cursor-pointer') + expect(deleteButton).toBeInTheDocument() + }) + }) + + describe('progress states', () => { + it('should show progress chart when uploading (0-99)', () => { + const fileItem = createMockFileItem({ progress: 50 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toBeInTheDocument() + expect(pieChart).toHaveAttribute('data-percentage', '50') + }) + + it('should show progress chart at 0%', () => { + const fileItem = createMockFileItem({ progress: 0 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-percentage', '0') + }) + + it('should not show progress chart when complete (100)', () => { + const fileItem = createMockFileItem({ progress: 100 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + + it('should not show progress chart when not started (-1)', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_NOT_STARTED }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('error state', () => { + it('should show error icon when progress is PROGRESS_ERROR', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_ERROR }) + const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const errorIcon = container.querySelector('.text-text-destructive') + expect(errorIcon).toBeInTheDocument() + }) + + it('should apply error styling to container', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_ERROR }) + const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('border-state-destructive-border', 'bg-state-destructive-hover') + }) + + it('should not show error styling when not in error state', () => { + const { container } = render(<FileListItem {...defaultProps} />) + + const item = container.firstChild as HTMLElement + expect(item).not.toHaveClass('border-state-destructive-border') + }) + }) + + describe('theme handling', () => { + it('should use correct chart color for light theme', () => { + mockTheme = 'light' + const fileItem = createMockFileItem({ progress: 50 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-stroke', '#296dff') + expect(pieChart).toHaveAttribute('data-fill', '#296dff') + }) + + it('should use correct chart color for dark theme', () => { + mockTheme = 'dark' + const fileItem = createMockFileItem({ progress: 50 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-stroke', '#5289ff') + expect(pieChart).toHaveAttribute('data-fill', '#5289ff') + }) + }) + + describe('event handlers', () => { + it('should call onPreview when item is clicked', () => { + const onPreview = vi.fn() + const fileItem = createMockFileItem() + render(<FileListItem {...defaultProps} fileItem={fileItem} onPreview={onPreview} />) + + const item = screen.getByText('test-document.pdf').closest('[class*="flex h-12"]')! + fireEvent.click(item) + + expect(onPreview).toHaveBeenCalledTimes(1) + expect(onPreview).toHaveBeenCalledWith(fileItem.file) + }) + + it('should call onRemove when delete button is clicked', () => { + const onRemove = vi.fn() + const fileItem = createMockFileItem() + const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} onRemove={onRemove} />) + + const deleteButton = container.querySelector('.cursor-pointer')! + fireEvent.click(deleteButton) + + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onRemove).toHaveBeenCalledWith('file-123') + }) + + it('should stop propagation when delete button is clicked', () => { + const onPreview = vi.fn() + const onRemove = vi.fn() + const { container } = render(<FileListItem {...defaultProps} onPreview={onPreview} onRemove={onRemove} />) + + const deleteButton = container.querySelector('.cursor-pointer')! + fireEvent.click(deleteButton) + + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onPreview).not.toHaveBeenCalled() + }) + }) + + describe('file type handling', () => { + it('should handle files with multiple dots in name', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: 'my.document.file.docx' }), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.getByText('my.document.file.docx')).toBeInTheDocument() + // Extension is lowercase with uppercase CSS class + expect(screen.getByText('docx')).toBeInTheDocument() + }) + + it('should handle files without extension', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: 'README' }), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + // getFileType returns 'README' when there's no extension (last part after split) + expect(screen.getAllByText('README')).toHaveLength(2) // filename and extension + }) + + it('should handle various file extensions', () => { + const extensions = ['txt', 'md', 'json', 'csv', 'xlsx'] + + extensions.forEach((ext) => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: `file.${ext}` }), + }) + const { unmount } = render(<FileListItem {...defaultProps} fileItem={fileItem} />) + // Extension is rendered in lowercase with uppercase CSS class + expect(screen.getByText(ext)).toBeInTheDocument() + unmount() + }) + }) + }) + + describe('file size display', () => { + it('should display size in KB for small files', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 5 * 1024 }), // 5KB + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.getByText('5.00 KB')).toBeInTheDocument() + }) + + it('should display size in MB for larger files', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 5 * 1024 * 1024 }), // 5MB + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.getByText('5.00 MB')).toBeInTheDocument() + }) + + it('should display size at threshold (10KB)', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 10 * 1024 }), // 10KB + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.getByText('10.00 KB')).toBeInTheDocument() + }) + }) + + describe('upload progress values', () => { + it('should show chart at progress 1', () => { + const fileItem = createMockFileItem({ progress: 1 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.getByTestId('pie-chart')).toBeInTheDocument() + }) + + it('should show chart at progress 99', () => { + const fileItem = createMockFileItem({ progress: 99 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.getByTestId('pie-chart')).toHaveAttribute('data-percentage', '99') + }) + + it('should not show chart at progress 100', () => { + const fileItem = createMockFileItem({ progress: 100 }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have proper shadow styling', () => { + const { container } = render(<FileListItem {...defaultProps} />) + + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('shadow-xs') + }) + + it('should have proper border styling', () => { + const { container } = render(<FileListItem {...defaultProps} />) + + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('border', 'border-components-panel-border') + }) + + it('should truncate long file names', () => { + const longFileName = 'this-is-a-very-long-file-name-that-should-be-truncated.pdf' + const fileItem = createMockFileItem({ + file: createMockFile({ name: longFileName }), + }) + render(<FileListItem {...defaultProps} fileItem={fileItem} />) + + const nameElement = screen.getByText(longFileName) + expect(nameElement).toHaveClass('truncate') + }) + }) +}) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx new file mode 100644 index 0000000000..1a61fa04f0 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx @@ -0,0 +1,85 @@ +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { RiDeleteBinLine, RiErrorWarningFill } from '@remixicon/react' +import dynamic from 'next/dynamic' +import { useMemo } from 'react' +import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon' +import { getFileType } from '@/app/components/datasets/common/image-uploader/utils' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { cn } from '@/utils/classnames' +import { formatFileSize } from '@/utils/format' +import { PROGRESS_ERROR } from '../constants' + +const SimplePieChart = dynamic(() => import('@/app/components/base/simple-pie-chart'), { ssr: false }) + +export type FileListItemProps = { + fileItem: FileItem + onPreview: (file: File) => void + onRemove: (fileID: string) => void +} + +const FileListItem = ({ + fileItem, + onPreview, + onRemove, +}: FileListItemProps) => { + const { theme } = useTheme() + const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) + + const isUploading = fileItem.progress >= 0 && fileItem.progress < 100 + const isError = fileItem.progress === PROGRESS_ERROR + + const handleClick = () => { + onPreview(fileItem.file) + } + + const handleRemove = (e: React.MouseEvent) => { + e.stopPropagation() + onRemove(fileItem.fileID) + } + + return ( + <div + onClick={handleClick} + className={cn( + 'flex h-12 items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg shadow-xs shadow-shadow-shadow-4', + isError && 'border-state-destructive-border bg-state-destructive-hover', + )} + > + <div className="flex w-12 shrink-0 items-center justify-center"> + <DocumentFileIcon + size="lg" + className="shrink-0" + name={fileItem.file.name} + extension={getFileType(fileItem.file)} + /> + </div> + <div className="flex shrink grow flex-col gap-0.5"> + <div className="flex w-full"> + <div className="w-0 grow truncate text-xs text-text-secondary">{fileItem.file.name}</div> + </div> + <div className="w-full truncate text-2xs leading-3 text-text-tertiary"> + <span className="uppercase">{getFileType(fileItem.file)}</span> + <span className="px-1 text-text-quaternary">ยท</span> + <span>{formatFileSize(fileItem.file.size)}</span> + </div> + </div> + <div className="flex w-16 shrink-0 items-center justify-end gap-1 pr-3"> + {isUploading && ( + <SimplePieChart percentage={fileItem.progress} stroke={chartColor} fill={chartColor} animationDuration={0} /> + )} + {isError && ( + <RiErrorWarningFill className="size-4 text-text-destructive" /> + )} + <span + className="flex h-6 w-6 cursor-pointer items-center justify-center" + onClick={handleRemove} + > + <RiDeleteBinLine className="size-4 text-text-tertiary" /> + </span> + </div> + </div> + ) +} + +export default FileListItem diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.spec.tsx new file mode 100644 index 0000000000..21742b731c --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.spec.tsx @@ -0,0 +1,231 @@ +import type { RefObject } from 'react' +import type { UploadDropzoneProps } from './upload-dropzone' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import UploadDropzone from './upload-dropzone' + +// Helper to create mock ref objects for testing +const createMockRef = <T,>(value: T | null = null): RefObject<T | null> => ({ current: value }) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string }) => { + const translations: Record<string, string> = { + 'stepOne.uploader.button': 'Drag and drop files, or', + 'stepOne.uploader.buttonSingleFile': 'Drag and drop file, or', + 'stepOne.uploader.browse': 'Browse', + 'stepOne.uploader.tip': 'Supports {{supportTypes}}, Max {{size}}MB each, up to {{batchCount}} files at a time, {{totalCount}} files total', + } + let result = translations[key] || key + if (options && typeof options === 'object') { + Object.entries(options).forEach(([k, v]) => { + result = result.replace(`{{${k}}}`, String(v)) + }) + } + return result + }, + }), +})) + +describe('UploadDropzone', () => { + const defaultProps: UploadDropzoneProps = { + dropRef: createMockRef<HTMLDivElement>() as RefObject<HTMLDivElement | null>, + dragRef: createMockRef<HTMLDivElement>() as RefObject<HTMLDivElement | null>, + fileUploaderRef: createMockRef<HTMLInputElement>() as RefObject<HTMLInputElement | null>, + dragging: false, + supportBatchUpload: true, + supportTypesShowNames: 'PDF, DOCX, TXT', + fileUploadConfig: { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, + }, + acceptTypes: ['.pdf', '.docx', '.txt'], + onSelectFile: vi.fn(), + onFileChange: vi.fn(), + allowedExtensions: ['pdf', 'docx', 'txt'], + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the dropzone container', () => { + const { container } = render(<UploadDropzone {...defaultProps} />) + + const dropzone = container.querySelector('[class*="border-dashed"]') + expect(dropzone).toBeInTheDocument() + }) + + it('should render hidden file input', () => { + render(<UploadDropzone {...defaultProps} />) + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toBeInTheDocument() + expect(input).toHaveClass('hidden') + expect(input).toHaveAttribute('type', 'file') + }) + + it('should render upload icon', () => { + render(<UploadDropzone {...defaultProps} />) + + const icon = document.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + + it('should render browse label when extensions are allowed', () => { + render(<UploadDropzone {...defaultProps} />) + + expect(screen.getByText('Browse')).toBeInTheDocument() + }) + + it('should not render browse label when no extensions allowed', () => { + render(<UploadDropzone {...defaultProps} allowedExtensions={[]} />) + + expect(screen.queryByText('Browse')).not.toBeInTheDocument() + }) + + it('should render file size and count limits', () => { + render(<UploadDropzone {...defaultProps} />) + + const tipText = screen.getByText(/Supports.*Max.*15MB/i) + expect(tipText).toBeInTheDocument() + }) + }) + + describe('file input configuration', () => { + it('should allow multiple files when supportBatchUpload is true', () => { + render(<UploadDropzone {...defaultProps} supportBatchUpload={true} />) + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('multiple') + }) + + it('should not allow multiple files when supportBatchUpload is false', () => { + render(<UploadDropzone {...defaultProps} supportBatchUpload={false} />) + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).not.toHaveAttribute('multiple') + }) + + it('should set accept attribute with correct types', () => { + render(<UploadDropzone {...defaultProps} acceptTypes={['.pdf', '.docx']} />) + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('accept', '.pdf,.docx') + }) + }) + + describe('text content', () => { + it('should show batch upload text when supportBatchUpload is true', () => { + render(<UploadDropzone {...defaultProps} supportBatchUpload={true} />) + + expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument() + }) + + it('should show single file text when supportBatchUpload is false', () => { + render(<UploadDropzone {...defaultProps} supportBatchUpload={false} />) + + expect(screen.getByText(/Drag and drop file/i)).toBeInTheDocument() + }) + }) + + describe('dragging state', () => { + it('should apply dragging styles when dragging is true', () => { + const { container } = render(<UploadDropzone {...defaultProps} dragging={true} />) + + const dropzone = container.querySelector('[class*="border-components-dropzone-border-accent"]') + expect(dropzone).toBeInTheDocument() + }) + + it('should render drag overlay when dragging', () => { + const dragRef = createMockRef<HTMLDivElement>() + render(<UploadDropzone {...defaultProps} dragging={true} dragRef={dragRef as RefObject<HTMLDivElement | null>} />) + + const overlay = document.querySelector('.absolute.left-0.top-0') + expect(overlay).toBeInTheDocument() + }) + + it('should not render drag overlay when not dragging', () => { + render(<UploadDropzone {...defaultProps} dragging={false} />) + + const overlay = document.querySelector('.absolute.left-0.top-0') + expect(overlay).not.toBeInTheDocument() + }) + }) + + describe('event handlers', () => { + it('should call onSelectFile when browse label is clicked', () => { + const onSelectFile = vi.fn() + render(<UploadDropzone {...defaultProps} onSelectFile={onSelectFile} />) + + const browseLabel = screen.getByText('Browse') + fireEvent.click(browseLabel) + + expect(onSelectFile).toHaveBeenCalledTimes(1) + }) + + it('should call onFileChange when files are selected', () => { + const onFileChange = vi.fn() + render(<UploadDropzone {...defaultProps} onFileChange={onFileChange} />) + + const input = document.getElementById('fileUploader') as HTMLInputElement + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + fireEvent.change(input, { target: { files: [file] } }) + + expect(onFileChange).toHaveBeenCalledTimes(1) + }) + }) + + describe('refs', () => { + it('should attach dropRef to drop container', () => { + const dropRef = createMockRef<HTMLDivElement>() + render(<UploadDropzone {...defaultProps} dropRef={dropRef as RefObject<HTMLDivElement | null>} />) + + expect(dropRef.current).toBeInstanceOf(HTMLDivElement) + }) + + it('should attach fileUploaderRef to input element', () => { + const fileUploaderRef = createMockRef<HTMLInputElement>() + render(<UploadDropzone {...defaultProps} fileUploaderRef={fileUploaderRef as RefObject<HTMLInputElement | null>} />) + + expect(fileUploaderRef.current).toBeInstanceOf(HTMLInputElement) + }) + + it('should attach dragRef to overlay when dragging', () => { + const dragRef = createMockRef<HTMLDivElement>() + render(<UploadDropzone {...defaultProps} dragging={true} dragRef={dragRef as RefObject<HTMLDivElement | null>} />) + + expect(dragRef.current).toBeInstanceOf(HTMLDivElement) + }) + }) + + describe('styling', () => { + it('should have base dropzone styling', () => { + const { container } = render(<UploadDropzone {...defaultProps} />) + + const dropzone = container.querySelector('[class*="border-dashed"]') + expect(dropzone).toBeInTheDocument() + expect(dropzone).toHaveClass('rounded-xl') + }) + + it('should have cursor-pointer on browse label', () => { + render(<UploadDropzone {...defaultProps} />) + + const browseLabel = screen.getByText('Browse') + expect(browseLabel).toHaveClass('cursor-pointer') + }) + }) + + describe('accessibility', () => { + it('should have an accessible file input', () => { + render(<UploadDropzone {...defaultProps} />) + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('id', 'fileUploader') + }) + }) +}) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.tsx new file mode 100644 index 0000000000..66bf42d365 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.tsx @@ -0,0 +1,83 @@ +import type { ChangeEvent, RefObject } from 'react' +import { RiUploadCloud2Line } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { cn } from '@/utils/classnames' + +type FileUploadConfig = { + file_size_limit: number + batch_count_limit: number + file_upload_limit: number +} + +export type UploadDropzoneProps = { + dropRef: RefObject<HTMLDivElement | null> + dragRef: RefObject<HTMLDivElement | null> + fileUploaderRef: RefObject<HTMLInputElement | null> + dragging: boolean + supportBatchUpload: boolean + supportTypesShowNames: string + fileUploadConfig: FileUploadConfig + acceptTypes: string[] + onSelectFile: () => void + onFileChange: (e: ChangeEvent<HTMLInputElement>) => void + allowedExtensions: string[] +} + +const UploadDropzone = ({ + dropRef, + dragRef, + fileUploaderRef, + dragging, + supportBatchUpload, + supportTypesShowNames, + fileUploadConfig, + acceptTypes, + onSelectFile, + onFileChange, + allowedExtensions, +}: UploadDropzoneProps) => { + const { t } = useTranslation() + + return ( + <> + <input + ref={fileUploaderRef} + id="fileUploader" + className="hidden" + type="file" + multiple={supportBatchUpload} + accept={acceptTypes.join(',')} + onChange={onFileChange} + /> + <div + ref={dropRef} + className={cn( + 'relative box-border flex min-h-20 flex-col items-center justify-center gap-1 rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg px-4 py-3 text-xs leading-4 text-text-tertiary', + dragging && 'border-components-dropzone-border-accent bg-components-dropzone-bg-accent', + )} + > + <div className="flex min-h-5 items-center justify-center text-sm leading-4 text-text-secondary"> + <RiUploadCloud2Line className="mr-2 size-5" /> + <span> + {supportBatchUpload ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })} + {allowedExtensions.length > 0 && ( + <label className="ml-1 cursor-pointer text-text-accent" onClick={onSelectFile}>{t('stepOne.uploader.browse', { ns: 'datasetCreation' })}</label> + )} + </span> + </div> + <div> + {t('stepOne.uploader.tip', { + ns: 'datasetCreation', + size: fileUploadConfig.file_size_limit, + supportTypes: supportTypesShowNames, + batchCount: fileUploadConfig.batch_count_limit, + totalCount: fileUploadConfig.file_upload_limit, + })} + </div> + {dragging && <div ref={dragRef} className="absolute left-0 top-0 h-full w-full" />} + </div> + </> + ) +} + +export default UploadDropzone diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/constants.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/constants.ts new file mode 100644 index 0000000000..cda2dae868 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/constants.ts @@ -0,0 +1,3 @@ +export const PROGRESS_NOT_STARTED = -1 +export const PROGRESS_ERROR = -2 +export const PROGRESS_COMPLETE = 100 diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.spec.tsx new file mode 100644 index 0000000000..6248b70506 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.spec.tsx @@ -0,0 +1,911 @@ +import type { ReactNode } from 'react' +import type { CustomFile, FileItem } from '@/models/datasets' +import { act, render, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' + +// Mock notify function - defined before mocks +const mockNotify = vi.fn() +const mockClose = vi.fn() + +// Mock ToastContext with factory function +vi.mock('@/app/components/base/toast', async () => { + const { createContext, useContext } = await import('use-context-selector') + const context = createContext({ notify: mockNotify, close: mockClose }) + return { + ToastContext: context, + useToastContext: () => useContext(context), + } +}) + +// Mock file uploader utils +vi.mock('@/app/components/base/file-uploader/utils', () => ({ + getFileUploadErrorMessage: (e: Error, defaultMsg: string) => e.message || defaultMsg, +})) + +// Mock format utils used by the shared hook +vi.mock('@/utils/format', () => ({ + getFileExtension: (filename: string) => { + const parts = filename.split('.') + return parts[parts.length - 1] || '' + }, +})) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock locale context +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +// Mock i18n config +vi.mock('@/i18n-config/language', () => ({ + LanguagesSupported: ['en-US', 'zh-Hans'], +})) + +// Mock config +vi.mock('@/config', () => ({ + IS_CE_EDITION: false, +})) + +// Mock store functions +const mockSetLocalFileList = vi.fn() +const mockSetCurrentLocalFile = vi.fn() +const mockGetState = vi.fn(() => ({ + setLocalFileList: mockSetLocalFileList, + setCurrentLocalFile: mockSetCurrentLocalFile, +})) +const mockStore = { getState: mockGetState } + +vi.mock('../../store', () => ({ + useDataSourceStoreWithSelector: vi.fn((selector: (state: { localFileList: FileItem[] }) => FileItem[]) => + selector({ localFileList: [] }), + ), + useDataSourceStore: vi.fn(() => mockStore), +})) + +// Mock file upload config +vi.mock('@/service/use-common', () => ({ + useFileUploadConfig: vi.fn(() => ({ + data: { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, + }, + })), + // Required by the shared useFileUpload hook + useFileSupportTypes: vi.fn(() => ({ + data: { + allowed_extensions: ['pdf', 'docx', 'txt'], + }, + })), +})) + +// Mock upload service +const mockUpload = vi.fn() +vi.mock('@/service/base', () => ({ + upload: (...args: unknown[]) => mockUpload(...args), +})) + +// Import after all mocks are set up +const { useLocalFileUpload } = await import('./use-local-file-upload') +const { ToastContext } = await import('@/app/components/base/toast') + +const createWrapper = () => { + return ({ children }: { children: ReactNode }) => ( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + {children} + </ToastContext.Provider> + ) +} + +describe('useLocalFileUpload', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUpload.mockReset() + }) + + describe('initialization', () => { + it('should initialize with default values', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'docx'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.dragging).toBe(false) + expect(result.current.localFileList).toEqual([]) + expect(result.current.hideUpload).toBe(false) + }) + + it('should create refs for dropzone, drag area, and file uploader', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.dropRef).toBeDefined() + expect(result.current.dragRef).toBeDefined() + expect(result.current.fileUploaderRef).toBeDefined() + }) + + it('should compute acceptTypes from allowedExtensions', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'docx', 'txt'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.acceptTypes).toEqual(['.pdf', '.docx', '.txt']) + }) + + it('should compute supportTypesShowNames correctly', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'docx', 'md'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('PDF') + expect(result.current.supportTypesShowNames).toContain('DOCX') + expect(result.current.supportTypesShowNames).toContain('MARKDOWN') + }) + + it('should provide file upload config with defaults', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.fileUploadConfig.file_size_limit).toBe(15) + expect(result.current.fileUploadConfig.batch_count_limit).toBe(5) + expect(result.current.fileUploadConfig.file_upload_limit).toBe(10) + }) + }) + + describe('supportBatchUpload option', () => { + it('should use batch limits when supportBatchUpload is true', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'], supportBatchUpload: true }), + { wrapper: createWrapper() }, + ) + + expect(result.current.fileUploadConfig.batch_count_limit).toBe(5) + expect(result.current.fileUploadConfig.file_upload_limit).toBe(10) + }) + + it('should use single file limits when supportBatchUpload is false', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'], supportBatchUpload: false }), + { wrapper: createWrapper() }, + ) + + expect(result.current.fileUploadConfig.batch_count_limit).toBe(1) + expect(result.current.fileUploadConfig.file_upload_limit).toBe(1) + }) + }) + + describe('selectHandle', () => { + it('should trigger file input click', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockClick = vi.fn() + const mockInput = { click: mockClick } as unknown as HTMLInputElement + Object.defineProperty(result.current.fileUploaderRef, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.selectHandle() + }) + + expect(mockClick).toHaveBeenCalled() + }) + + it('should handle null fileUploaderRef gracefully', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + expect(() => { + act(() => { + result.current.selectHandle() + }) + }).not.toThrow() + }) + }) + + describe('removeFile', () => { + it('should remove file from list', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.removeFile('file-id-123') + }) + + expect(mockSetLocalFileList).toHaveBeenCalled() + }) + + it('should clear file input value when removing', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockInput = { value: 'some-file.pdf' } as HTMLInputElement + Object.defineProperty(result.current.fileUploaderRef, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.removeFile('file-id') + }) + + expect(mockInput.value).toBe('') + }) + }) + + describe('handlePreview', () => { + it('should set current local file when file has id', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = { id: 'file-123', name: 'test.pdf', size: 1024 } + + act(() => { + result.current.handlePreview(mockFile as unknown as CustomFile) + }) + + expect(mockSetCurrentLocalFile).toHaveBeenCalledWith(mockFile) + }) + + it('should not set current file when file has no id', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = { name: 'test.pdf', size: 1024 } + + act(() => { + result.current.handlePreview(mockFile as unknown as CustomFile) + }) + + expect(mockSetCurrentLocalFile).not.toHaveBeenCalled() + }) + }) + + describe('fileChangeHandle', () => { + it('should handle valid files', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockSetLocalFileList).toHaveBeenCalled() + }) + }) + + it('should handle empty file list', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const event = { + target: { + files: null, + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockSetLocalFileList).not.toHaveBeenCalled() + }) + + it('should reject files with invalid type', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.exe', { type: 'application/exe' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + + it('should reject files exceeding size limit', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + // Create a mock file larger than 15MB + const largeSize = 20 * 1024 * 1024 + const mockFile = new File([''], 'large.pdf', { type: 'application/pdf' }) + Object.defineProperty(mockFile, 'size', { value: largeSize }) + + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + + it('should limit files to batch count limit', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + // Create 10 files but batch limit is 5 + const files = Array.from({ length: 10 }, (_, i) => + new File(['content'], `file${i}.pdf`, { type: 'application/pdf' })) + + const event = { + target: { + files, + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockSetLocalFileList).toHaveBeenCalled() + }) + + // Should only process first 5 files (batch_count_limit) + const firstCall = mockSetLocalFileList.mock.calls[0] + expect(firstCall[0].length).toBeLessThanOrEqual(5) + }) + }) + + describe('upload handling', () => { + it('should handle successful upload', async () => { + const uploadedResponse = { id: 'server-file-id' } + mockUpload.mockResolvedValue(uploadedResponse) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + }) + + it('should handle upload error', async () => { + mockUpload.mockRejectedValue(new Error('Upload failed')) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + }) + + it('should call upload with correct parameters', async () => { + mockUpload.mockResolvedValue({ id: 'file-id' }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalledWith( + expect.objectContaining({ + xhr: expect.any(XMLHttpRequest), + data: expect.any(FormData), + }), + false, + undefined, + '?source=datasets', + ) + }) + }) + }) + + describe('extension mapping', () => { + it('should map md to markdown', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['md'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('MARKDOWN') + }) + + it('should map htm to html', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['htm'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('HTML') + }) + + it('should preserve unmapped extensions', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'txt'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('PDF') + expect(result.current.supportTypesShowNames).toContain('TXT') + }) + + it('should remove duplicate extensions', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'pdf', 'PDF'] }), + { wrapper: createWrapper() }, + ) + + const count = (result.current.supportTypesShowNames.match(/PDF/g) || []).length + expect(count).toBe(1) + }) + }) + + describe('drag and drop handlers', () => { + // Helper component that renders with the hook and connects refs + const TestDropzone = ({ allowedExtensions, supportBatchUpload = true }: { + allowedExtensions: string[] + supportBatchUpload?: boolean + }) => { + const { + dropRef, + dragRef, + dragging, + } = useLocalFileUpload({ allowedExtensions, supportBatchUpload }) + + return ( + <div> + <div ref={dropRef} data-testid="dropzone"> + {dragging && <div ref={dragRef} data-testid="drag-overlay" />} + </div> + <span data-testid="dragging">{String(dragging)}</span> + </div> + ) + } + + it('should set dragging true on dragenter', async () => { + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone allowedExtensions={['pdf']} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragEnterEvent) + }) + + expect(getByTestId('dragging').textContent).toBe('true') + }) + + it('should handle dragover event', async () => { + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone allowedExtensions={['pdf']} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragOverEvent = new Event('dragover', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragOverEvent) + }) + + // dragover should not throw + expect(dropzone).toBeInTheDocument() + }) + + it('should set dragging false on dragleave from drag overlay', async () => { + const { getByTestId, queryByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone allowedExtensions={['pdf']} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + + // First trigger dragenter to set dragging true + await act(async () => { + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragEnterEvent) + }) + + expect(getByTestId('dragging').textContent).toBe('true') + + // Now the drag overlay should be rendered + const dragOverlay = queryByTestId('drag-overlay') + if (dragOverlay) { + await act(async () => { + const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true }) + Object.defineProperty(dragLeaveEvent, 'target', { value: dragOverlay }) + dropzone.dispatchEvent(dragLeaveEvent) + }) + } + }) + + it('should handle drop with files', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone allowedExtensions={['pdf']} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { + dataTransfer: { items: DataTransferItem[], files: File[] } | null + } + // Mock dataTransfer with items array (used by the shared hook for directory traversal) + dropEvent.dataTransfer = { + items: [{ + kind: 'file', + getAsFile: () => mockFile, + }] as unknown as DataTransferItem[], + files: [mockFile], + } + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(mockSetLocalFileList).toHaveBeenCalled() + }) + }) + + it('should handle drop without dataTransfer', async () => { + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone allowedExtensions={['pdf']} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + mockSetLocalFileList.mockClear() + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: { files: File[] } | null } + dropEvent.dataTransfer = null + dropzone.dispatchEvent(dropEvent) + }) + + // Should not upload when no dataTransfer + expect(mockSetLocalFileList).not.toHaveBeenCalled() + }) + + it('should limit to single file on drop when supportBatchUpload is false', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const { getByTestId } = await act(async () => + render( + <ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}> + <TestDropzone allowedExtensions={['pdf']} supportBatchUpload={false} /> + </ToastContext.Provider>, + ), + ) + + const dropzone = getByTestId('dropzone') + const files = [ + new File(['content1'], 'test1.pdf', { type: 'application/pdf' }), + new File(['content2'], 'test2.pdf', { type: 'application/pdf' }), + ] + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { + dataTransfer: { items: DataTransferItem[], files: File[] } | null + } + // Mock dataTransfer with items array (used by the shared hook for directory traversal) + dropEvent.dataTransfer = { + items: files.map(f => ({ + kind: 'file', + getAsFile: () => f, + })) as unknown as DataTransferItem[], + files, + } + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(mockSetLocalFileList).toHaveBeenCalled() + // Should only have 1 file (limited by supportBatchUpload: false) + const callArgs = mockSetLocalFileList.mock.calls[0][0] + expect(callArgs.length).toBe(1) + }) + }) + }) + + describe('file upload limit', () => { + it('should reject files exceeding total file upload limit', async () => { + // Mock store to return existing files + const { useDataSourceStoreWithSelector } = vi.mocked(await import('../../store')) + const existingFiles: FileItem[] = Array.from({ length: 8 }, (_, i) => ({ + fileID: `existing-${i}`, + file: { name: `existing-${i}.pdf`, size: 1024 } as CustomFile, + progress: 100, + })) + vi.mocked(useDataSourceStoreWithSelector).mockImplementation(selector => + selector({ localFileList: existingFiles } as Parameters<typeof selector>[0]), + ) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + // Try to add 5 more files when limit is 10 and we already have 8 + const files = Array.from({ length: 5 }, (_, i) => + new File(['content'], `new-${i}.pdf`, { type: 'application/pdf' })) + + const event = { + target: { files }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + // Should show error about files number limit + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + + // Reset mock for other tests + vi.mocked(useDataSourceStoreWithSelector).mockImplementation(selector => + selector({ localFileList: [] as FileItem[] } as Parameters<typeof selector>[0]), + ) + }) + }) + + describe('upload progress tracking', () => { + it('should track upload progress', async () => { + let progressCallback: ((e: ProgressEvent) => void) | undefined + + mockUpload.mockImplementation(async (options: { onprogress: (e: ProgressEvent) => void }) => { + progressCallback = options.onprogress + return { id: 'uploaded-id' } + }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + + // Simulate progress event + if (progressCallback) { + act(() => { + progressCallback!({ + lengthComputable: true, + loaded: 50, + total: 100, + } as ProgressEvent) + }) + + expect(mockSetLocalFileList).toHaveBeenCalled() + } + }) + + it('should not update progress when not lengthComputable', async () => { + let progressCallback: ((e: ProgressEvent) => void) | undefined + const uploadCallCount = { value: 0 } + + mockUpload.mockImplementation(async (options: { onprogress: (e: ProgressEvent) => void }) => { + progressCallback = options.onprogress + uploadCallCount.value++ + return { id: 'uploaded-id' } + }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + mockSetLocalFileList.mockClear() + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + + const callsBeforeProgress = mockSetLocalFileList.mock.calls.length + + // Simulate progress event without lengthComputable + if (progressCallback) { + act(() => { + progressCallback!({ + lengthComputable: false, + loaded: 50, + total: 100, + } as ProgressEvent) + }) + + // Should not have additional calls + expect(mockSetLocalFileList.mock.calls.length).toBe(callsBeforeProgress) + } + }) + }) + + describe('file progress constants', () => { + it('should use PROGRESS_NOT_STARTED for new files', async () => { + mockUpload.mockResolvedValue({ id: 'file-id' }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + const callArgs = mockSetLocalFileList.mock.calls[0][0] + expect(callArgs[0].progress).toBe(PROGRESS_NOT_STARTED) + }) + }) + + it('should set PROGRESS_ERROR on upload failure', async () => { + mockUpload.mockRejectedValue(new Error('Upload failed')) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent<HTMLInputElement> + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + const calls = mockSetLocalFileList.mock.calls + const lastCall = calls[calls.length - 1][0] + expect(lastCall.some((f: FileItem) => f.progress === PROGRESS_ERROR)).toBe(true) + }) + }) + }) +}) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.ts new file mode 100644 index 0000000000..1f7c9ecfed --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.ts @@ -0,0 +1,105 @@ +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { produce } from 'immer' +import { useCallback, useRef } from 'react' +import { useFileUpload } from '@/app/components/datasets/create/file-uploader/hooks/use-file-upload' +import { useDataSourceStore, useDataSourceStoreWithSelector } from '../../store' + +export type UseLocalFileUploadOptions = { + allowedExtensions: string[] + supportBatchUpload?: boolean +} + +/** + * Hook for handling local file uploads in the create-from-pipeline flow. + * This is a thin wrapper around the generic useFileUpload hook that provides + * Zustand store integration for state management. + */ +export const useLocalFileUpload = ({ + allowedExtensions, + supportBatchUpload = true, +}: UseLocalFileUploadOptions) => { + const localFileList = useDataSourceStoreWithSelector(state => state.localFileList) + const dataSourceStore = useDataSourceStore() + const fileListRef = useRef<FileItem[]>([]) + + // Sync fileListRef with localFileList for internal tracking + fileListRef.current = localFileList + + const prepareFileList = useCallback((files: FileItem[]) => { + const { setLocalFileList } = dataSourceStore.getState() + setLocalFileList(files) + fileListRef.current = files + }, [dataSourceStore]) + + const onFileUpdate = useCallback((fileItem: FileItem, progress: number, list: FileItem[]) => { + const { setLocalFileList } = dataSourceStore.getState() + const newList = produce(list, (draft) => { + const targetIndex = draft.findIndex(file => file.fileID === fileItem.fileID) + if (targetIndex !== -1) { + draft[targetIndex] = { + ...draft[targetIndex], + ...fileItem, + progress, + } + } + }) + setLocalFileList(newList) + }, [dataSourceStore]) + + const onFileListUpdate = useCallback((files: FileItem[]) => { + const { setLocalFileList } = dataSourceStore.getState() + setLocalFileList(files) + fileListRef.current = files + }, [dataSourceStore]) + + const onPreview = useCallback((file: File) => { + const { setCurrentLocalFile } = dataSourceStore.getState() + setCurrentLocalFile(file) + }, [dataSourceStore]) + + const { + dropRef, + dragRef, + fileUploaderRef, + dragging, + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } = useFileUpload({ + fileList: localFileList, + prepareFileList, + onFileUpdate, + onFileListUpdate, + onPreview, + supportBatchUpload, + allowedExtensions, + }) + + return { + // Refs + dropRef, + dragRef, + fileUploaderRef, + + // State + dragging, + localFileList, + + // Config + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + + // Handlers + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } +} diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.spec.tsx new file mode 100644 index 0000000000..66f13be84f --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.spec.tsx @@ -0,0 +1,398 @@ +import type { FileItem } from '@/models/datasets' +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import LocalFile from './index' + +// Mock the hook +const mockUseLocalFileUpload = vi.fn() +vi.mock('./hooks/use-local-file-upload', () => ({ + useLocalFileUpload: (...args: unknown[]) => mockUseLocalFileUpload(...args), +})) + +// Mock react-i18next for sub-components +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock theme hook for sub-components +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: 'light' }), +})) + +// Mock theme types +vi.mock('@/types/app', () => ({ + Theme: { dark: 'dark', light: 'light' }, +})) + +// Mock DocumentFileIcon +vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ + default: ({ name }: { name: string }) => <div data-testid="document-icon">{name}</div>, +})) + +// Mock SimplePieChart +vi.mock('next/dynamic', () => ({ + default: () => { + const Component = ({ percentage }: { percentage: number }) => ( + <div data-testid="pie-chart"> + {percentage} + % + </div> + ) + return Component + }, +})) + +describe('LocalFile', () => { + const mockDropRef = { current: null } + const mockDragRef = { current: null } + const mockFileUploaderRef = { current: null } + + const defaultHookReturn = { + dropRef: mockDropRef, + dragRef: mockDragRef, + fileUploaderRef: mockFileUploaderRef, + dragging: false, + localFileList: [] as FileItem[], + fileUploadConfig: { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, + }, + acceptTypes: ['.pdf', '.docx'], + supportTypesShowNames: 'PDF, DOCX', + hideUpload: false, + selectHandle: vi.fn(), + fileChangeHandle: vi.fn(), + removeFile: vi.fn(), + handlePreview: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + mockUseLocalFileUpload.mockReturnValue(defaultHookReturn) + }) + + describe('rendering', () => { + it('should render the component container', () => { + const { container } = render( + <LocalFile allowedExtensions={['pdf', 'docx']} />, + ) + + expect(container.firstChild).toHaveClass('flex', 'flex-col') + }) + + it('should render UploadDropzone when hideUpload is false', () => { + render(<LocalFile allowedExtensions={['pdf']} />) + + const fileInput = document.getElementById('fileUploader') + expect(fileInput).toBeInTheDocument() + }) + + it('should not render UploadDropzone when hideUpload is true', () => { + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + hideUpload: true, + }) + + render(<LocalFile allowedExtensions={['pdf']} />) + + const fileInput = document.getElementById('fileUploader') + expect(fileInput).not.toBeInTheDocument() + }) + }) + + describe('file list rendering', () => { + it('should not render file list when empty', () => { + render(<LocalFile allowedExtensions={['pdf']} />) + + expect(screen.queryByTestId('document-icon')).not.toBeInTheDocument() + }) + + it('should render file list when files exist', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { + fileID: 'file-1', + file: mockFile, + progress: -1, + }, + ], + }) + + render(<LocalFile allowedExtensions={['pdf']} />) + + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + + it('should render multiple file items', () => { + const createMockFile = (name: string) => ({ + name, + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + }) as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { fileID: 'file-1', file: createMockFile('doc1.pdf'), progress: -1 }, + { fileID: 'file-2', file: createMockFile('doc2.pdf'), progress: -1 }, + { fileID: 'file-3', file: createMockFile('doc3.pdf'), progress: -1 }, + ], + }) + + render(<LocalFile allowedExtensions={['pdf']} />) + + const icons = screen.getAllByTestId('document-icon') + expect(icons).toHaveLength(3) + }) + + it('should use correct key for file items', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { fileID: 'unique-id-123', file: mockFile, progress: -1 }, + ], + }) + + render(<LocalFile allowedExtensions={['pdf']} />) + + // The component should render without errors (key is used internally) + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + }) + + describe('hook integration', () => { + it('should pass allowedExtensions to hook', () => { + render(<LocalFile allowedExtensions={['pdf', 'docx', 'txt']} />) + + expect(mockUseLocalFileUpload).toHaveBeenCalledWith({ + allowedExtensions: ['pdf', 'docx', 'txt'], + supportBatchUpload: true, + }) + }) + + it('should pass supportBatchUpload true by default', () => { + render(<LocalFile allowedExtensions={['pdf']} />) + + expect(mockUseLocalFileUpload).toHaveBeenCalledWith( + expect.objectContaining({ supportBatchUpload: true }), + ) + }) + + it('should pass supportBatchUpload false when specified', () => { + render(<LocalFile allowedExtensions={['pdf']} supportBatchUpload={false} />) + + expect(mockUseLocalFileUpload).toHaveBeenCalledWith( + expect.objectContaining({ supportBatchUpload: false }), + ) + }) + }) + + describe('props passed to UploadDropzone', () => { + it('should pass all required props to UploadDropzone', () => { + const selectHandle = vi.fn() + const fileChangeHandle = vi.fn() + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + selectHandle, + fileChangeHandle, + supportTypesShowNames: 'PDF, DOCX', + acceptTypes: ['.pdf', '.docx'], + fileUploadConfig: { + file_size_limit: 20, + batch_count_limit: 10, + file_upload_limit: 50, + }, + }) + + render(<LocalFile allowedExtensions={['pdf', 'docx']} supportBatchUpload={true} />) + + // Verify the dropzone is rendered with correct configuration + const fileInput = document.getElementById('fileUploader') + expect(fileInput).toBeInTheDocument() + expect(fileInput).toHaveAttribute('accept', '.pdf,.docx') + expect(fileInput).toHaveAttribute('multiple') + }) + }) + + describe('props passed to FileListItem', () => { + it('should pass correct props to file items', () => { + const handlePreview = vi.fn() + const removeFile = vi.fn() + const mockFile = { + name: 'document.pdf', + size: 2048, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + handlePreview, + removeFile, + localFileList: [ + { fileID: 'test-id', file: mockFile, progress: 50 }, + ], + }) + + render(<LocalFile allowedExtensions={['pdf']} />) + + expect(screen.getByTestId('document-icon')).toHaveTextContent('document.pdf') + }) + }) + + describe('conditional rendering', () => { + it('should show both dropzone and file list when files exist and hideUpload is false', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + hideUpload: false, + localFileList: [ + { fileID: 'file-1', file: mockFile, progress: -1 }, + ], + }) + + render(<LocalFile allowedExtensions={['pdf']} />) + + expect(document.getElementById('fileUploader')).toBeInTheDocument() + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + + it('should show only file list when hideUpload is true', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + hideUpload: true, + localFileList: [ + { fileID: 'file-1', file: mockFile, progress: -1 }, + ], + }) + + render(<LocalFile allowedExtensions={['pdf']} />) + + expect(document.getElementById('fileUploader')).not.toBeInTheDocument() + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + }) + + describe('file list container styling', () => { + it('should apply correct container classes for file list', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { fileID: 'file-1', file: mockFile, progress: -1 }, + ], + }) + + const { container } = render(<LocalFile allowedExtensions={['pdf']} />) + + const fileListContainer = container.querySelector('.mt-1.flex.flex-col.gap-y-1') + expect(fileListContainer).toBeInTheDocument() + }) + }) + + describe('edge cases', () => { + it('should handle empty allowedExtensions', () => { + render(<LocalFile allowedExtensions={[]} />) + + expect(mockUseLocalFileUpload).toHaveBeenCalledWith({ + allowedExtensions: [], + supportBatchUpload: true, + }) + }) + + it('should handle files with same fileID but different index', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { fileID: 'same-id', file: { ...mockFile, name: 'doc1.pdf' } as File, progress: -1 }, + { fileID: 'same-id', file: { ...mockFile, name: 'doc2.pdf' } as File, progress: -1 }, + ], + }) + + // Should render without key collision errors due to index in key + render(<LocalFile allowedExtensions={['pdf']} />) + + const icons = screen.getAllByTestId('document-icon') + expect(icons).toHaveLength(2) + }) + }) + + describe('component integration', () => { + it('should render complete component tree', () => { + const mockFile = { + name: 'complete-test.pdf', + size: 5 * 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + hideUpload: false, + localFileList: [ + { fileID: 'file-1', file: mockFile, progress: 50 }, + ], + dragging: false, + }) + + const { container } = render( + <LocalFile allowedExtensions={['pdf', 'docx']} supportBatchUpload={true} />, + ) + + // Main container + expect(container.firstChild).toHaveClass('flex', 'flex-col') + + // Dropzone exists + expect(document.getElementById('fileUploader')).toBeInTheDocument() + + // File list exists + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx index d02d5927f2..cb3632ba9d 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx @@ -1,26 +1,7 @@ 'use client' -import type { CustomFile as File, FileItem } from '@/models/datasets' -import { RiDeleteBinLine, RiErrorWarningFill, RiUploadCloud2Line } from '@remixicon/react' -import { produce } from 'immer' -import dynamic from 'next/dynamic' -import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' -import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' -import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' -import { ToastContext } from '@/app/components/base/toast' -import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon' -import { IS_CE_EDITION } from '@/config' -import { useLocale } from '@/context/i18n' -import useTheme from '@/hooks/use-theme' -import { LanguagesSupported } from '@/i18n-config/language' -import { upload } from '@/service/base' -import { useFileUploadConfig } from '@/service/use-common' -import { Theme } from '@/types/app' -import { cn } from '@/utils/classnames' -import { useDataSourceStore, useDataSourceStoreWithSelector } from '../store' - -const SimplePieChart = dynamic(() => import('@/app/components/base/simple-pie-chart'), { ssr: false }) +import FileListItem from './components/file-list-item' +import UploadDropzone from './components/upload-dropzone' +import { useLocalFileUpload } from './hooks/use-local-file-upload' export type LocalFileProps = { allowedExtensions: string[] @@ -31,345 +12,49 @@ const LocalFile = ({ allowedExtensions, supportBatchUpload = true, }: LocalFileProps) => { - const { t } = useTranslation() - const { notify } = useContext(ToastContext) - const locale = useLocale() - const localFileList = useDataSourceStoreWithSelector(state => state.localFileList) - const dataSourceStore = useDataSourceStore() - const [dragging, setDragging] = useState(false) - - const dropRef = useRef<HTMLDivElement>(null) - const dragRef = useRef<HTMLDivElement>(null) - const fileUploader = useRef<HTMLInputElement>(null) - const fileListRef = useRef<FileItem[]>([]) - - const hideUpload = !supportBatchUpload && localFileList.length > 0 - - const { data: fileUploadConfigResponse } = useFileUploadConfig() - const supportTypesShowNames = useMemo(() => { - const extensionMap: { [key: string]: string } = { - md: 'markdown', - pptx: 'pptx', - htm: 'html', - xlsx: 'xlsx', - docx: 'docx', - } - - return allowedExtensions - .map(item => extensionMap[item] || item) // map to standardized extension - .map(item => item.toLowerCase()) // convert to lower case - .filter((item, index, self) => self.indexOf(item) === index) // remove duplicates - .map(item => item.toUpperCase()) // convert to upper case - .join(locale !== LanguagesSupported[1] ? ', ' : 'ใ€ ') - }, [locale, allowedExtensions]) - const ACCEPTS = allowedExtensions.map((ext: string) => `.${ext}`) - const fileUploadConfig = useMemo(() => ({ - file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, - batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1, - file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1, - }), [fileUploadConfigResponse, supportBatchUpload]) - - const updateFile = useCallback((fileItem: FileItem, progress: number, list: FileItem[]) => { - const { setLocalFileList } = dataSourceStore.getState() - const newList = produce(list, (draft) => { - const targetIndex = draft.findIndex(file => file.fileID === fileItem.fileID) - draft[targetIndex] = { - ...draft[targetIndex], - progress, - } - }) - setLocalFileList(newList) - }, [dataSourceStore]) - - const updateFileList = useCallback((preparedFiles: FileItem[]) => { - const { setLocalFileList } = dataSourceStore.getState() - setLocalFileList(preparedFiles) - }, [dataSourceStore]) - - const handlePreview = useCallback((file: File) => { - const { setCurrentLocalFile } = dataSourceStore.getState() - if (file.id) - setCurrentLocalFile(file) - }, [dataSourceStore]) - - // utils - const getFileType = (currentFile: File) => { - if (!currentFile) - return '' - - const arr = currentFile.name.split('.') - return arr[arr.length - 1] - } - - const getFileSize = (size: number) => { - if (size / 1024 < 10) - return `${(size / 1024).toFixed(2)}KB` - - return `${(size / 1024 / 1024).toFixed(2)}MB` - } - - const isValid = useCallback((file: File) => { - const { size } = file - const ext = `.${getFileType(file)}` - const isValidType = ACCEPTS.includes(ext.toLowerCase()) - if (!isValidType) - notify({ type: 'error', message: t('stepOne.uploader.validation.typeError', { ns: 'datasetCreation' }) }) - - const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024 - if (!isValidSize) - notify({ type: 'error', message: t('stepOne.uploader.validation.size', { ns: 'datasetCreation', size: fileUploadConfig.file_size_limit }) }) - - return isValidType && isValidSize - }, [notify, t, ACCEPTS, fileUploadConfig.file_size_limit]) - - type UploadResult = Awaited<ReturnType<typeof upload>> - - const fileUpload = useCallback(async (fileItem: FileItem): Promise<FileItem> => { - const formData = new FormData() - formData.append('file', fileItem.file) - const onProgress = (e: ProgressEvent) => { - if (e.lengthComputable) { - const percent = Math.floor(e.loaded / e.total * 100) - updateFile(fileItem, percent, fileListRef.current) - } - } - - return upload({ - xhr: new XMLHttpRequest(), - data: formData, - onprogress: onProgress, - }, false, undefined, '?source=datasets') - .then((res: UploadResult) => { - const updatedFile = Object.assign({}, fileItem.file, { - id: res.id, - ...(res as Partial<File>), - }) as File - const completeFile: FileItem = { - fileID: fileItem.fileID, - file: updatedFile, - progress: -1, - } - const index = fileListRef.current.findIndex(item => item.fileID === fileItem.fileID) - fileListRef.current[index] = completeFile - updateFile(completeFile, 100, fileListRef.current) - return Promise.resolve({ ...completeFile }) - }) - .catch((e) => { - const errorMessage = getFileUploadErrorMessage(e, t('stepOne.uploader.failed', { ns: 'datasetCreation' }), t) - notify({ type: 'error', message: errorMessage }) - updateFile(fileItem, -2, fileListRef.current) - return Promise.resolve({ ...fileItem }) - }) - .finally() - }, [fileListRef, notify, updateFile, t]) - - const uploadBatchFiles = useCallback((bFiles: FileItem[]) => { - bFiles.forEach(bf => (bf.progress = 0)) - return Promise.all(bFiles.map(fileUpload)) - }, [fileUpload]) - - const uploadMultipleFiles = useCallback(async (files: FileItem[]) => { - const batchCountLimit = fileUploadConfig.batch_count_limit - const length = files.length - let start = 0 - let end = 0 - - while (start < length) { - if (start + batchCountLimit > length) - end = length - else - end = start + batchCountLimit - const bFiles = files.slice(start, end) - await uploadBatchFiles(bFiles) - start = end - } - }, [fileUploadConfig, uploadBatchFiles]) - - const initialUpload = useCallback((files: File[]) => { - const filesCountLimit = fileUploadConfig.file_upload_limit - if (!files.length) - return false - - if (files.length + localFileList.length > filesCountLimit && !IS_CE_EDITION) { - notify({ type: 'error', message: t('stepOne.uploader.validation.filesNumber', { ns: 'datasetCreation', filesNumber: filesCountLimit }) }) - return false - } - - const preparedFiles = files.map((file, index) => ({ - fileID: `file${index}-${Date.now()}`, - file, - progress: -1, - })) - const newFiles = [...fileListRef.current, ...preparedFiles] - updateFileList(newFiles) - fileListRef.current = newFiles - uploadMultipleFiles(preparedFiles) - }, [fileUploadConfig.file_upload_limit, localFileList.length, updateFileList, uploadMultipleFiles, notify, t]) - - const handleDragEnter = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - if (e.target !== dragRef.current) - setDragging(true) - } - const handleDragOver = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - } - const handleDragLeave = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - if (e.target === dragRef.current) - setDragging(false) - } - - const handleDrop = useCallback((e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - setDragging(false) - if (!e.dataTransfer) - return - - let files = Array.from(e.dataTransfer.files) as File[] - if (!supportBatchUpload) - files = files.slice(0, 1) - - const validFiles = files.filter(isValid) - initialUpload(validFiles) - }, [initialUpload, isValid, supportBatchUpload]) - - const selectHandle = useCallback(() => { - if (fileUploader.current) - fileUploader.current.click() - }, []) - - const removeFile = (fileID: string) => { - if (fileUploader.current) - fileUploader.current.value = '' - - fileListRef.current = fileListRef.current.filter(item => item.fileID !== fileID) - updateFileList([...fileListRef.current]) - } - const fileChangeHandle = useCallback((e: React.ChangeEvent<HTMLInputElement>) => { - let files = Array.from(e.target.files ?? []) as File[] - files = files.slice(0, fileUploadConfig.batch_count_limit) - initialUpload(files.filter(isValid)) - }, [isValid, initialUpload, fileUploadConfig.batch_count_limit]) - - const { theme } = useTheme() - const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) - - useEffect(() => { - const dropElement = dropRef.current - dropElement?.addEventListener('dragenter', handleDragEnter) - dropElement?.addEventListener('dragover', handleDragOver) - dropElement?.addEventListener('dragleave', handleDragLeave) - dropElement?.addEventListener('drop', handleDrop) - return () => { - dropElement?.removeEventListener('dragenter', handleDragEnter) - dropElement?.removeEventListener('dragover', handleDragOver) - dropElement?.removeEventListener('dragleave', handleDragLeave) - dropElement?.removeEventListener('drop', handleDrop) - } - }, [handleDrop]) + const { + dropRef, + dragRef, + fileUploaderRef, + dragging, + localFileList, + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } = useLocalFileUpload({ allowedExtensions, supportBatchUpload }) return ( <div className="flex flex-col"> {!hideUpload && ( - <input - ref={fileUploader} - id="fileUploader" - className="hidden" - type="file" - multiple={supportBatchUpload} - accept={ACCEPTS.join(',')} - onChange={fileChangeHandle} + <UploadDropzone + dropRef={dropRef} + dragRef={dragRef} + fileUploaderRef={fileUploaderRef} + dragging={dragging} + supportBatchUpload={supportBatchUpload} + supportTypesShowNames={supportTypesShowNames} + fileUploadConfig={fileUploadConfig} + acceptTypes={acceptTypes} + onSelectFile={selectHandle} + onFileChange={fileChangeHandle} + allowedExtensions={allowedExtensions} /> )} - {!hideUpload && ( - <div - ref={dropRef} - className={cn( - 'relative box-border flex min-h-20 flex-col items-center justify-center gap-1 rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg px-4 py-3 text-xs leading-4 text-text-tertiary', - dragging && 'border-components-dropzone-border-accent bg-components-dropzone-bg-accent', - )} - > - <div className="flex min-h-5 items-center justify-center text-sm leading-4 text-text-secondary"> - <RiUploadCloud2Line className="mr-2 size-5" /> - - <span> - {supportBatchUpload ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })} - {allowedExtensions.length > 0 && ( - <label className="ml-1 cursor-pointer text-text-accent" onClick={selectHandle}>{t('stepOne.uploader.browse', { ns: 'datasetCreation' })}</label> - )} - </span> - </div> - <div> - {t('stepOne.uploader.tip', { - ns: 'datasetCreation', - size: fileUploadConfig.file_size_limit, - supportTypes: supportTypesShowNames, - batchCount: fileUploadConfig.batch_count_limit, - totalCount: fileUploadConfig.file_upload_limit, - })} - </div> - {dragging && <div ref={dragRef} className="absolute left-0 top-0 h-full w-full" />} - </div> - )} {localFileList.length > 0 && ( <div className="mt-1 flex flex-col gap-y-1"> - {localFileList.map((fileItem, index) => { - const isUploading = fileItem.progress >= 0 && fileItem.progress < 100 - const isError = fileItem.progress === -2 - return ( - <div - key={`${fileItem.fileID}-${index}`} - onClick={handlePreview.bind(null, fileItem.file)} - className={cn( - 'flex h-12 items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg shadow-xs shadow-shadow-shadow-4', - isError && 'border-state-destructive-border bg-state-destructive-hover', - )} - > - <div className="flex w-12 shrink-0 items-center justify-center"> - <DocumentFileIcon - size="lg" - className="shrink-0" - name={fileItem.file.name} - extension={getFileType(fileItem.file)} - /> - </div> - <div className="flex shrink grow flex-col gap-0.5"> - <div className="flex w-full"> - <div className="w-0 grow truncate text-xs text-text-secondary">{fileItem.file.name}</div> - </div> - <div className="w-full truncate text-2xs leading-3 text-text-tertiary"> - <span className="uppercase">{getFileType(fileItem.file)}</span> - <span className="px-1 text-text-quaternary">ยท</span> - <span>{getFileSize(fileItem.file.size)}</span> - </div> - </div> - <div className="flex w-16 shrink-0 items-center justify-end gap-1 pr-3"> - {isUploading && ( - <SimplePieChart percentage={fileItem.progress} stroke={chartColor} fill={chartColor} animationDuration={0} /> - )} - { - isError && ( - <RiErrorWarningFill className="size-4 text-text-destructive" /> - ) - } - <span - className="flex h-6 w-6 cursor-pointer items-center justify-center" - onClick={(e) => { - e.stopPropagation() - removeFile(fileItem.fileID) - }} - > - <RiDeleteBinLine className="size-4 text-text-tertiary" /> - </span> - </div> - </div> - ) - })} + {localFileList.map((fileItem, index) => ( + <FileListItem + key={`${fileItem.fileID}-${index}`} + fileItem={fileItem} + onPreview={handlePreview} + onRemove={removeFile} + /> + ))} </div> )} </div> diff --git a/web/app/components/datasets/settings/form/components/basic-info-section.spec.tsx b/web/app/components/datasets/settings/form/components/basic-info-section.spec.tsx new file mode 100644 index 0000000000..28085e52fa --- /dev/null +++ b/web/app/components/datasets/settings/form/components/basic-info-section.spec.tsx @@ -0,0 +1,441 @@ +import type { Member } from '@/models/common' +import type { DataSet, IconInfo } from '@/models/datasets' +import type { RetrievalConfig } from '@/types/app' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../../create/step-two' +import BasicInfoSection from './basic-info-section' + +// Mock app-context +vi.mock('@/context/app-context', () => ({ + useSelector: () => ({ + id: 'user-1', + name: 'Current User', + email: 'current@example.com', + avatar_url: '', + role: 'owner', + }), +})) + +// Mock image uploader hooks for AppIconPicker +vi.mock('@/app/components/base/image-uploader/hooks', () => ({ + useLocalFileUploader: () => ({ + disabled: false, + handleLocalFileUpload: vi.fn(), + }), + useImageFiles: () => ({ + files: [], + onUpload: vi.fn(), + onRemove: vi.fn(), + onReUpload: vi.fn(), + onImageLinkLoadError: vi.fn(), + onImageLinkLoadSuccess: vi.fn(), + onClear: vi.fn(), + }), +})) + +describe('BasicInfoSection', () => { + const mockDataset: DataSet = { + id: 'dataset-1', + name: 'Test Dataset', + description: 'Test description', + permission: DatasetPermission.onlyMe, + icon_info: { + icon_type: 'emoji', + icon: '๐Ÿ“š', + icon_background: '#FFFFFF', + icon_url: '', + }, + indexing_technique: IndexingType.QUALIFIED, + indexing_status: 'completed', + data_source_type: DataSourceType.FILE, + doc_form: ChunkingMode.text, + embedding_model: 'text-embedding-ada-002', + embedding_model_provider: 'openai', + embedding_available: true, + app_count: 0, + document_count: 5, + total_document_count: 5, + word_count: 1000, + provider: 'vendor', + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-1', + external_knowledge_api_id: 'api-1', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 3, + score_threshold: 0.7, + score_threshold_enabled: true, + }, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } as RetrievalConfig, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } as RetrievalConfig, + built_in_field_enabled: false, + keyword_number: 10, + created_by: 'user-1', + updated_by: 'user-1', + updated_at: Date.now(), + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + } + + const mockMemberList: Member[] = [ + { id: 'user-1', name: 'User 1', email: 'user1@example.com', role: 'owner', avatar: '', avatar_url: '', last_login_at: '', created_at: '', status: 'active' }, + { id: 'user-2', name: 'User 2', email: 'user2@example.com', role: 'admin', avatar: '', avatar_url: '', last_login_at: '', created_at: '', status: 'active' }, + ] + + const mockIconInfo: IconInfo = { + icon_type: 'emoji', + icon: '๐Ÿ“š', + icon_background: '#FFFFFF', + icon_url: '', + } + + const defaultProps = { + currentDataset: mockDataset, + isCurrentWorkspaceDatasetOperator: false, + name: 'Test Dataset', + setName: vi.fn(), + description: 'Test description', + setDescription: vi.fn(), + iconInfo: mockIconInfo, + showAppIconPicker: false, + handleOpenAppIconPicker: vi.fn(), + handleSelectAppIcon: vi.fn(), + handleCloseAppIconPicker: vi.fn(), + permission: DatasetPermission.onlyMe, + setPermission: vi.fn(), + selectedMemberIDs: ['user-1'], + setSelectedMemberIDs: vi.fn(), + memberList: mockMemberList, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(<BasicInfoSection {...defaultProps} />) + expect(screen.getByText(/form\.nameAndIcon/i)).toBeInTheDocument() + }) + + it('should render name and icon section', () => { + render(<BasicInfoSection {...defaultProps} />) + expect(screen.getByText(/form\.nameAndIcon/i)).toBeInTheDocument() + }) + + it('should render description section', () => { + render(<BasicInfoSection {...defaultProps} />) + expect(screen.getByText(/form\.desc/i)).toBeInTheDocument() + }) + + it('should render permissions section', () => { + render(<BasicInfoSection {...defaultProps} />) + // Use exact match to avoid matching "permissionsOnlyMe" + expect(screen.getByText('datasetSettings.form.permissions')).toBeInTheDocument() + }) + + it('should render name input with correct value', () => { + render(<BasicInfoSection {...defaultProps} />) + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput).toBeInTheDocument() + }) + + it('should render description textarea with correct value', () => { + render(<BasicInfoSection {...defaultProps} />) + const descriptionTextarea = screen.getByDisplayValue('Test description') + expect(descriptionTextarea).toBeInTheDocument() + }) + + it('should render app icon with emoji', () => { + const { container } = render(<BasicInfoSection {...defaultProps} />) + // The icon section should be rendered (emoji may be in a span or SVG) + const iconSection = container.querySelector('[class*="cursor-pointer"]') + expect(iconSection).toBeInTheDocument() + }) + }) + + describe('Name Input', () => { + it('should call setName when name input changes', () => { + const setName = vi.fn() + render(<BasicInfoSection {...defaultProps} setName={setName} />) + + const nameInput = screen.getByDisplayValue('Test Dataset') + fireEvent.change(nameInput, { target: { value: 'New Name' } }) + + expect(setName).toHaveBeenCalledWith('New Name') + }) + + it('should disable name input when embedding is not available', () => { + const datasetWithoutEmbedding = { ...mockDataset, embedding_available: false } + render(<BasicInfoSection {...defaultProps} currentDataset={datasetWithoutEmbedding} />) + + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput).toBeDisabled() + }) + + it('should enable name input when embedding is available', () => { + render(<BasicInfoSection {...defaultProps} />) + + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput).not.toBeDisabled() + }) + + it('should display empty name', () => { + const { container } = render(<BasicInfoSection {...defaultProps} name="" />) + + // Find the name input by its structure - may be type=text or just input + const nameInput = container.querySelector('input') + expect(nameInput).toHaveValue('') + }) + }) + + describe('Description Textarea', () => { + it('should call setDescription when description changes', () => { + const setDescription = vi.fn() + render(<BasicInfoSection {...defaultProps} setDescription={setDescription} />) + + const descriptionTextarea = screen.getByDisplayValue('Test description') + fireEvent.change(descriptionTextarea, { target: { value: 'New Description' } }) + + expect(setDescription).toHaveBeenCalledWith('New Description') + }) + + it('should disable description textarea when embedding is not available', () => { + const datasetWithoutEmbedding = { ...mockDataset, embedding_available: false } + render(<BasicInfoSection {...defaultProps} currentDataset={datasetWithoutEmbedding} />) + + const descriptionTextarea = screen.getByDisplayValue('Test description') + expect(descriptionTextarea).toBeDisabled() + }) + + it('should render placeholder', () => { + render(<BasicInfoSection {...defaultProps} description="" />) + + const descriptionTextarea = screen.getByPlaceholderText(/form\.descPlaceholder/i) + expect(descriptionTextarea).toBeInTheDocument() + }) + }) + + describe('App Icon', () => { + it('should call handleOpenAppIconPicker when icon is clicked', () => { + const handleOpenAppIconPicker = vi.fn() + const { container } = render(<BasicInfoSection {...defaultProps} handleOpenAppIconPicker={handleOpenAppIconPicker} />) + + // Find the clickable icon element - it's inside a wrapper that handles the click + const iconWrapper = container.querySelector('[class*="cursor-pointer"]') + if (iconWrapper) { + fireEvent.click(iconWrapper) + expect(handleOpenAppIconPicker).toHaveBeenCalled() + } + }) + + it('should render AppIconPicker when showAppIconPicker is true', () => { + const { baseElement } = render(<BasicInfoSection {...defaultProps} showAppIconPicker={true} />) + + // AppIconPicker renders a modal with emoji tabs and options via portal + // We just verify the component renders without crashing when picker is shown + expect(baseElement).toBeInTheDocument() + }) + + it('should not render AppIconPicker when showAppIconPicker is false', () => { + const { container } = render(<BasicInfoSection {...defaultProps} showAppIconPicker={false} />) + + // Check that AppIconPicker is not rendered + expect(container.querySelector('[data-testid="app-icon-picker"]')).not.toBeInTheDocument() + }) + + it('should render image icon when icon_type is image', () => { + const imageIconInfo: IconInfo = { + icon_type: 'image', + icon: 'file-123', + icon_background: undefined, + icon_url: 'https://example.com/icon.png', + } + render(<BasicInfoSection {...defaultProps} iconInfo={imageIconInfo} />) + + // For image type, it renders an img element + const img = screen.queryByRole('img') + if (img) { + expect(img).toHaveAttribute('src', expect.stringContaining('icon.png')) + } + }) + }) + + describe('Permission Selector', () => { + it('should render with correct permission value', () => { + render(<BasicInfoSection {...defaultProps} permission={DatasetPermission.onlyMe} />) + + expect(screen.getByText(/form\.permissionsOnlyMe/i)).toBeInTheDocument() + }) + + it('should render all team members permission', () => { + render(<BasicInfoSection {...defaultProps} permission={DatasetPermission.allTeamMembers} />) + + expect(screen.getByText(/form\.permissionsAllMember/i)).toBeInTheDocument() + }) + + it('should be disabled when embedding is not available', () => { + const datasetWithoutEmbedding = { ...mockDataset, embedding_available: false } + const { container } = render( + <BasicInfoSection {...defaultProps} currentDataset={datasetWithoutEmbedding} />, + ) + + // Check for disabled state via cursor-not-allowed class + const disabledElement = container.querySelector('[class*="cursor-not-allowed"]') + expect(disabledElement).toBeInTheDocument() + }) + + it('should be disabled when user is dataset operator', () => { + const { container } = render( + <BasicInfoSection {...defaultProps} isCurrentWorkspaceDatasetOperator={true} />, + ) + + const disabledElement = container.querySelector('[class*="cursor-not-allowed"]') + expect(disabledElement).toBeInTheDocument() + }) + + it('should call setPermission when permission changes', async () => { + const setPermission = vi.fn() + render(<BasicInfoSection {...defaultProps} setPermission={setPermission} />) + + // Open dropdown + const trigger = screen.getByText(/form\.permissionsOnlyMe/i) + fireEvent.click(trigger) + + await waitFor(() => { + // Click All Team Members option + const allMemberOptions = screen.getAllByText(/form\.permissionsAllMember/i) + fireEvent.click(allMemberOptions[0]) + }) + + expect(setPermission).toHaveBeenCalledWith(DatasetPermission.allTeamMembers) + }) + + it('should call setSelectedMemberIDs when members are selected', async () => { + const setSelectedMemberIDs = vi.fn() + const { container } = render( + <BasicInfoSection + {...defaultProps} + permission={DatasetPermission.partialMembers} + setSelectedMemberIDs={setSelectedMemberIDs} + />, + ) + + // For partial members permission, the member selector should be visible + // The exact interaction depends on the MemberSelector component + // We verify the component renders without crashing + expect(container).toBeInTheDocument() + }) + }) + + describe('Undefined Dataset', () => { + it('should handle undefined currentDataset gracefully', () => { + render(<BasicInfoSection {...defaultProps} currentDataset={undefined} />) + + // Should still render but inputs might behave differently + expect(screen.getByText(/form\.nameAndIcon/i)).toBeInTheDocument() + }) + }) + + describe('Props Validation', () => { + it('should update when name prop changes', () => { + const { rerender } = render(<BasicInfoSection {...defaultProps} name="Initial Name" />) + + expect(screen.getByDisplayValue('Initial Name')).toBeInTheDocument() + + rerender(<BasicInfoSection {...defaultProps} name="Updated Name" />) + + expect(screen.getByDisplayValue('Updated Name')).toBeInTheDocument() + }) + + it('should update when description prop changes', () => { + const { rerender } = render(<BasicInfoSection {...defaultProps} description="Initial Description" />) + + expect(screen.getByDisplayValue('Initial Description')).toBeInTheDocument() + + rerender(<BasicInfoSection {...defaultProps} description="Updated Description" />) + + expect(screen.getByDisplayValue('Updated Description')).toBeInTheDocument() + }) + + it('should update when permission prop changes', () => { + const { rerender } = render(<BasicInfoSection {...defaultProps} permission={DatasetPermission.onlyMe} />) + + expect(screen.getByText(/form\.permissionsOnlyMe/i)).toBeInTheDocument() + + rerender(<BasicInfoSection {...defaultProps} permission={DatasetPermission.allTeamMembers} />) + + expect(screen.getByText(/form\.permissionsAllMember/i)).toBeInTheDocument() + }) + }) + + describe('Member List', () => { + it('should pass member list to PermissionSelector', () => { + const { container } = render( + <BasicInfoSection + {...defaultProps} + permission={DatasetPermission.partialMembers} + memberList={mockMemberList} + />, + ) + + // For partial members, a member selector component should be rendered + // We verify it renders without crashing + expect(container).toBeInTheDocument() + }) + + it('should handle empty member list', () => { + render( + <BasicInfoSection + {...defaultProps} + memberList={[]} + />, + ) + + expect(screen.getByText(/form\.permissionsOnlyMe/i)).toBeInTheDocument() + }) + }) + + describe('Accessibility', () => { + it('should have accessible name input', () => { + render(<BasicInfoSection {...defaultProps} />) + + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput.tagName.toLowerCase()).toBe('input') + }) + + it('should have accessible description textarea', () => { + render(<BasicInfoSection {...defaultProps} />) + + const descriptionTextarea = screen.getByDisplayValue('Test description') + expect(descriptionTextarea.tagName.toLowerCase()).toBe('textarea') + }) + }) +}) diff --git a/web/app/components/datasets/settings/form/components/basic-info-section.tsx b/web/app/components/datasets/settings/form/components/basic-info-section.tsx new file mode 100644 index 0000000000..3d3cf75851 --- /dev/null +++ b/web/app/components/datasets/settings/form/components/basic-info-section.tsx @@ -0,0 +1,124 @@ +'use client' +import type { AppIconSelection } from '@/app/components/base/app-icon-picker' +import type { Member } from '@/models/common' +import type { DataSet, DatasetPermission, IconInfo } from '@/models/datasets' +import type { AppIconType } from '@/types/app' +import { useTranslation } from 'react-i18next' +import AppIcon from '@/app/components/base/app-icon' +import AppIconPicker from '@/app/components/base/app-icon-picker' +import Input from '@/app/components/base/input' +import Textarea from '@/app/components/base/textarea' +import PermissionSelector from '../../permission-selector' + +const rowClass = 'flex gap-x-1' +const labelClass = 'flex items-center shrink-0 w-[180px] h-7 pt-1' + +type BasicInfoSectionProps = { + currentDataset: DataSet | undefined + isCurrentWorkspaceDatasetOperator: boolean + name: string + setName: (value: string) => void + description: string + setDescription: (value: string) => void + iconInfo: IconInfo + showAppIconPicker: boolean + handleOpenAppIconPicker: () => void + handleSelectAppIcon: (icon: AppIconSelection) => void + handleCloseAppIconPicker: () => void + permission: DatasetPermission | undefined + setPermission: (value: DatasetPermission | undefined) => void + selectedMemberIDs: string[] + setSelectedMemberIDs: (value: string[]) => void + memberList: Member[] +} + +const BasicInfoSection = ({ + currentDataset, + isCurrentWorkspaceDatasetOperator, + name, + setName, + description, + setDescription, + iconInfo, + showAppIconPicker, + handleOpenAppIconPicker, + handleSelectAppIcon, + handleCloseAppIconPicker, + permission, + setPermission, + selectedMemberIDs, + setSelectedMemberIDs, + memberList, +}: BasicInfoSectionProps) => { + const { t } = useTranslation() + + return ( + <> + {/* Dataset name and icon */} + <div className={rowClass}> + <div className={labelClass}> + <div className="system-sm-semibold text-text-secondary">{t('form.nameAndIcon', { ns: 'datasetSettings' })}</div> + </div> + <div className="flex grow items-center gap-x-2"> + <AppIcon + size="small" + onClick={handleOpenAppIconPicker} + className="cursor-pointer" + iconType={iconInfo.icon_type as AppIconType} + icon={iconInfo.icon} + background={iconInfo.icon_background} + imageUrl={iconInfo.icon_url} + showEditIcon + /> + <Input + disabled={!currentDataset?.embedding_available} + value={name} + onChange={e => setName(e.target.value)} + /> + </div> + </div> + + {/* Dataset description */} + <div className={rowClass}> + <div className={labelClass}> + <div className="system-sm-semibold text-text-secondary">{t('form.desc', { ns: 'datasetSettings' })}</div> + </div> + <div className="grow"> + <Textarea + disabled={!currentDataset?.embedding_available} + className="resize-none" + placeholder={t('form.descPlaceholder', { ns: 'datasetSettings' }) || ''} + value={description} + onChange={e => setDescription(e.target.value)} + /> + </div> + </div> + + {/* Permissions */} + <div className={rowClass}> + <div className={labelClass}> + <div className="system-sm-semibold text-text-secondary">{t('form.permissions', { ns: 'datasetSettings' })}</div> + </div> + <div className="grow"> + <PermissionSelector + disabled={!currentDataset?.embedding_available || isCurrentWorkspaceDatasetOperator} + permission={permission} + value={selectedMemberIDs} + onChange={v => setPermission(v)} + onMemberSelect={setSelectedMemberIDs} + memberList={memberList} + /> + </div> + </div> + + {showAppIconPicker && ( + <AppIconPicker + onSelect={handleSelectAppIcon} + onClose={handleCloseAppIconPicker} + /> + )} + </> + ) +} + +export default BasicInfoSection diff --git a/web/app/components/datasets/settings/form/components/external-knowledge-section.spec.tsx b/web/app/components/datasets/settings/form/components/external-knowledge-section.spec.tsx new file mode 100644 index 0000000000..96512b5aca --- /dev/null +++ b/web/app/components/datasets/settings/form/components/external-knowledge-section.spec.tsx @@ -0,0 +1,362 @@ +import type { DataSet } from '@/models/datasets' +import type { RetrievalConfig } from '@/types/app' +import { render, screen } from '@testing-library/react' +import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../../create/step-two' +import ExternalKnowledgeSection from './external-knowledge-section' + +describe('ExternalKnowledgeSection', () => { + const mockRetrievalConfig: RetrievalConfig = { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } + + const mockDataset: DataSet = { + id: 'dataset-1', + name: 'External Dataset', + description: 'External dataset description', + permission: DatasetPermission.onlyMe, + icon_info: { + icon_type: 'emoji', + icon: '๐Ÿ“š', + icon_background: '#FFFFFF', + icon_url: '', + }, + indexing_technique: IndexingType.QUALIFIED, + indexing_status: 'completed', + data_source_type: DataSourceType.FILE, + doc_form: ChunkingMode.text, + embedding_model: 'text-embedding-ada-002', + embedding_model_provider: 'openai', + embedding_available: true, + app_count: 0, + document_count: 5, + total_document_count: 5, + word_count: 1000, + provider: 'external', + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-knowledge-123', + external_knowledge_api_id: 'api-456', + external_knowledge_api_name: 'My External API', + external_knowledge_api_endpoint: 'https://api.external.example.com/v1', + }, + external_retrieval_model: { + top_k: 5, + score_threshold: 0.8, + score_threshold_enabled: true, + }, + retrieval_model_dict: mockRetrievalConfig, + retrieval_model: mockRetrievalConfig, + built_in_field_enabled: false, + keyword_number: 10, + created_by: 'user-1', + updated_by: 'user-1', + updated_at: Date.now(), + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + } + + const defaultProps = { + currentDataset: mockDataset, + topK: 5, + scoreThreshold: 0.8, + scoreThresholdEnabled: true, + handleSettingsChange: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + expect(screen.getByText(/form\.retrievalSetting\.title/i)).toBeInTheDocument() + }) + + it('should render retrieval settings section', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + expect(screen.getByText(/form\.retrievalSetting\.title/i)).toBeInTheDocument() + }) + + it('should render external knowledge API section', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + expect(screen.getByText(/form\.externalKnowledgeAPI/i)).toBeInTheDocument() + }) + + it('should render external knowledge ID section', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + expect(screen.getByText(/form\.externalKnowledgeID/i)).toBeInTheDocument() + }) + }) + + describe('External Knowledge API Info', () => { + it('should display external API name', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + expect(screen.getByText('My External API')).toBeInTheDocument() + }) + + it('should display external API endpoint', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + expect(screen.getByText('https://api.external.example.com/v1')).toBeInTheDocument() + }) + + it('should render API connection icon', () => { + const { container } = render(<ExternalKnowledgeSection {...defaultProps} />) + // The ApiConnectionMod icon should be rendered + const icon = container.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + + it('should display API name and endpoint in the same row', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + + const apiName = screen.getByText('My External API') + const apiEndpoint = screen.getByText('https://api.external.example.com/v1') + + // Both should be in the same container + expect(apiName.parentElement?.parentElement).toBe(apiEndpoint.parentElement?.parentElement) + }) + }) + + describe('External Knowledge ID', () => { + it('should display external knowledge ID value', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + expect(screen.getByText('ext-knowledge-123')).toBeInTheDocument() + }) + + it('should render ID in a read-only display', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + + const idElement = screen.getByText('ext-knowledge-123') + // The ID should be in a div with input-like styling, not an actual input + expect(idElement.tagName.toLowerCase()).toBe('div') + }) + }) + + describe('Retrieval Settings', () => { + it('should pass topK to RetrievalSettings', () => { + render(<ExternalKnowledgeSection {...defaultProps} topK={10} />) + + // RetrievalSettings should receive topK prop + // The exact rendering depends on RetrievalSettings component + }) + + it('should pass scoreThreshold to RetrievalSettings', () => { + render(<ExternalKnowledgeSection {...defaultProps} scoreThreshold={0.9} />) + + // RetrievalSettings should receive scoreThreshold prop + }) + + it('should pass scoreThresholdEnabled to RetrievalSettings', () => { + render(<ExternalKnowledgeSection {...defaultProps} scoreThresholdEnabled={false} />) + + // RetrievalSettings should receive scoreThresholdEnabled prop + }) + + it('should call handleSettingsChange when settings change', () => { + const handleSettingsChange = vi.fn() + render(<ExternalKnowledgeSection {...defaultProps} handleSettingsChange={handleSettingsChange} />) + + // The handler should be properly passed to RetrievalSettings + // Actual interaction depends on RetrievalSettings implementation + }) + }) + + describe('Dividers', () => { + it('should render dividers between sections', () => { + const { container } = render(<ExternalKnowledgeSection {...defaultProps} />) + + const dividers = container.querySelectorAll('.bg-divider-subtle') + expect(dividers.length).toBeGreaterThanOrEqual(2) + }) + }) + + describe('Props Updates', () => { + it('should update when currentDataset changes', () => { + const { rerender } = render(<ExternalKnowledgeSection {...defaultProps} />) + + expect(screen.getByText('My External API')).toBeInTheDocument() + + const updatedDataset = { + ...mockDataset, + external_knowledge_info: { + ...mockDataset.external_knowledge_info, + external_knowledge_api_name: 'Updated API Name', + }, + } + + rerender(<ExternalKnowledgeSection {...defaultProps} currentDataset={updatedDataset} />) + + expect(screen.getByText('Updated API Name')).toBeInTheDocument() + }) + + it('should update when external knowledge ID changes', () => { + const { rerender } = render(<ExternalKnowledgeSection {...defaultProps} />) + + expect(screen.getByText('ext-knowledge-123')).toBeInTheDocument() + + const updatedDataset = { + ...mockDataset, + external_knowledge_info: { + ...mockDataset.external_knowledge_info, + external_knowledge_id: 'new-ext-id-789', + }, + } + + rerender(<ExternalKnowledgeSection {...defaultProps} currentDataset={updatedDataset} />) + + expect(screen.getByText('new-ext-id-789')).toBeInTheDocument() + }) + + it('should update when API endpoint changes', () => { + const { rerender } = render(<ExternalKnowledgeSection {...defaultProps} />) + + expect(screen.getByText('https://api.external.example.com/v1')).toBeInTheDocument() + + const updatedDataset = { + ...mockDataset, + external_knowledge_info: { + ...mockDataset.external_knowledge_info, + external_knowledge_api_endpoint: 'https://new-api.example.com/v2', + }, + } + + rerender(<ExternalKnowledgeSection {...defaultProps} currentDataset={updatedDataset} />) + + expect(screen.getByText('https://new-api.example.com/v2')).toBeInTheDocument() + }) + }) + + describe('Layout', () => { + it('should have consistent row layout', () => { + const { container } = render(<ExternalKnowledgeSection {...defaultProps} />) + + // Check for flex gap-x-1 class on rows + const rows = container.querySelectorAll('.flex.gap-x-1') + expect(rows.length).toBeGreaterThan(0) + }) + + it('should have consistent label width', () => { + const { container } = render(<ExternalKnowledgeSection {...defaultProps} />) + + // Check for w-[180px] label containers + const labels = container.querySelectorAll('.w-\\[180px\\]') + expect(labels.length).toBeGreaterThan(0) + }) + }) + + describe('Styling', () => { + it('should apply correct background to info displays', () => { + const { container } = render(<ExternalKnowledgeSection {...defaultProps} />) + + // Info displays should have bg-components-input-bg-normal + const infoDisplays = container.querySelectorAll('.bg-components-input-bg-normal') + expect(infoDisplays.length).toBeGreaterThan(0) + }) + + it('should apply rounded corners to info displays', () => { + const { container } = render(<ExternalKnowledgeSection {...defaultProps} />) + + const roundedElements = container.querySelectorAll('.rounded-lg') + expect(roundedElements.length).toBeGreaterThan(0) + }) + }) + + describe('Different External Knowledge Info', () => { + it('should handle long API names', () => { + const longNameDataset = { + ...mockDataset, + external_knowledge_info: { + ...mockDataset.external_knowledge_info, + external_knowledge_api_name: 'This is a very long external knowledge API name that should be truncated', + }, + } + + render(<ExternalKnowledgeSection {...defaultProps} currentDataset={longNameDataset} />) + + expect(screen.getByText(/This is a very long external knowledge API name/)).toBeInTheDocument() + }) + + it('should handle long API endpoints', () => { + const longEndpointDataset = { + ...mockDataset, + external_knowledge_info: { + ...mockDataset.external_knowledge_info, + external_knowledge_api_endpoint: 'https://api.very-long-domain-name.example.com/api/v1/external/knowledge', + }, + } + + render(<ExternalKnowledgeSection {...defaultProps} currentDataset={longEndpointDataset} />) + + expect(screen.getByText(/https:\/\/api.very-long-domain-name.example.com/)).toBeInTheDocument() + }) + + it('should handle special characters in API name', () => { + const specialCharDataset = { + ...mockDataset, + external_knowledge_info: { + ...mockDataset.external_knowledge_info, + external_knowledge_api_name: 'API & Service <Test>', + }, + } + + render(<ExternalKnowledgeSection {...defaultProps} currentDataset={specialCharDataset} />) + + expect(screen.getByText('API & Service <Test>')).toBeInTheDocument() + }) + }) + + describe('RetrievalSettings Integration', () => { + it('should pass isInRetrievalSetting=true to RetrievalSettings', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + + // The RetrievalSettings component should be rendered with isInRetrievalSetting=true + // This affects the component's layout/styling + }) + + it('should handle settings change for top_k', () => { + const handleSettingsChange = vi.fn() + render(<ExternalKnowledgeSection {...defaultProps} handleSettingsChange={handleSettingsChange} />) + + // Find and interact with the top_k control in RetrievalSettings + // The exact interaction depends on RetrievalSettings implementation + }) + + it('should handle settings change for score_threshold', () => { + const handleSettingsChange = vi.fn() + render(<ExternalKnowledgeSection {...defaultProps} handleSettingsChange={handleSettingsChange} />) + + // Find and interact with the score_threshold control in RetrievalSettings + }) + + it('should handle settings change for score_threshold_enabled', () => { + const handleSettingsChange = vi.fn() + render(<ExternalKnowledgeSection {...defaultProps} handleSettingsChange={handleSettingsChange} />) + + // Find and interact with the score_threshold_enabled toggle in RetrievalSettings + }) + }) + + describe('Accessibility', () => { + it('should have semantic structure', () => { + render(<ExternalKnowledgeSection {...defaultProps} />) + + // Section labels should be present + expect(screen.getByText(/form\.retrievalSetting\.title/i)).toBeInTheDocument() + expect(screen.getByText(/form\.externalKnowledgeAPI/i)).toBeInTheDocument() + expect(screen.getByText(/form\.externalKnowledgeID/i)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/settings/form/components/external-knowledge-section.tsx b/web/app/components/datasets/settings/form/components/external-knowledge-section.tsx new file mode 100644 index 0000000000..4b08bb1e7a --- /dev/null +++ b/web/app/components/datasets/settings/form/components/external-knowledge-section.tsx @@ -0,0 +1,84 @@ +'use client' +import type { DataSet } from '@/models/datasets' +import { useTranslation } from 'react-i18next' +import Divider from '@/app/components/base/divider' +import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' +import RetrievalSettings from '../../../external-knowledge-base/create/RetrievalSettings' + +const rowClass = 'flex gap-x-1' +const labelClass = 'flex items-center shrink-0 w-[180px] h-7 pt-1' + +type ExternalKnowledgeSectionProps = { + currentDataset: DataSet + topK: number + scoreThreshold: number + scoreThresholdEnabled: boolean + handleSettingsChange: (data: { top_k?: number, score_threshold?: number, score_threshold_enabled?: boolean }) => void +} + +const ExternalKnowledgeSection = ({ + currentDataset, + topK, + scoreThreshold, + scoreThresholdEnabled, + handleSettingsChange, +}: ExternalKnowledgeSectionProps) => { + const { t } = useTranslation() + + return ( + <> + <Divider type="horizontal" className="my-1 h-px bg-divider-subtle" /> + + {/* Retrieval Settings */} + <div className={rowClass}> + <div className={labelClass}> + <div className="system-sm-semibold text-text-secondary">{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}</div> + </div> + <RetrievalSettings + topK={topK} + scoreThreshold={scoreThreshold} + scoreThresholdEnabled={scoreThresholdEnabled} + onChange={handleSettingsChange} + isInRetrievalSetting={true} + /> + </div> + + <Divider type="horizontal" className="my-1 h-px bg-divider-subtle" /> + + {/* External Knowledge API */} + <div className={rowClass}> + <div className={labelClass}> + <div className="system-sm-semibold text-text-secondary">{t('form.externalKnowledgeAPI', { ns: 'datasetSettings' })}</div> + </div> + <div className="w-full"> + <div className="flex h-full items-center gap-1 rounded-lg bg-components-input-bg-normal px-3 py-2"> + <ApiConnectionMod className="h-4 w-4 text-text-secondary" /> + <div className="system-sm-medium overflow-hidden text-ellipsis text-text-secondary"> + {currentDataset.external_knowledge_info.external_knowledge_api_name} + </div> + <div className="system-xs-regular text-text-tertiary">ยท</div> + <div className="system-xs-regular text-text-tertiary"> + {currentDataset.external_knowledge_info.external_knowledge_api_endpoint} + </div> + </div> + </div> + </div> + + {/* External Knowledge ID */} + <div className={rowClass}> + <div className={labelClass}> + <div className="system-sm-semibold text-text-secondary">{t('form.externalKnowledgeID', { ns: 'datasetSettings' })}</div> + </div> + <div className="w-full"> + <div className="flex h-full items-center gap-1 rounded-lg bg-components-input-bg-normal px-3 py-2"> + <div className="system-xs-regular text-text-tertiary"> + {currentDataset.external_knowledge_info.external_knowledge_id} + </div> + </div> + </div> + </div> + </> + ) +} + +export default ExternalKnowledgeSection diff --git a/web/app/components/datasets/settings/form/components/indexing-section.spec.tsx b/web/app/components/datasets/settings/form/components/indexing-section.spec.tsx new file mode 100644 index 0000000000..bf1448b933 --- /dev/null +++ b/web/app/components/datasets/settings/form/components/indexing-section.spec.tsx @@ -0,0 +1,501 @@ +import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { DataSet, SummaryIndexSetting } from '@/models/datasets' +import type { RetrievalConfig } from '@/types/app' +import { fireEvent, render, screen } from '@testing-library/react' +import { ConfigurationMethodEnum, ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../../create/step-two' +import IndexingSection from './indexing-section' + +// Mock i18n doc link +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.dify.ai${path}`, +})) + +// Mock app-context for child components +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: unknown) => unknown) => { + const state = { + isCurrentWorkspaceDatasetOperator: false, + userProfile: { + id: 'user-1', + name: 'Current User', + email: 'current@example.com', + avatar_url: '', + role: 'owner', + }, + } + return selector(state) + }, +})) + +// Mock model-provider-page hooks +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelList: () => ({ data: [], mutate: vi.fn(), isLoading: false }), + useCurrentProviderAndModel: () => ({ currentProvider: undefined, currentModel: undefined }), + useDefaultModel: () => ({ data: undefined, mutate: vi.fn(), isLoading: false }), + useModelListAndDefaultModel: () => ({ modelList: [], defaultModel: undefined }), + useModelListAndDefaultModelAndCurrentProviderAndModel: () => ({ + modelList: [], + defaultModel: undefined, + currentProvider: undefined, + currentModel: undefined, + }), + useUpdateModelList: () => vi.fn(), + useUpdateModelProviders: () => vi.fn(), + useLanguage: () => 'en_US', + useSystemDefaultModelAndModelList: () => [undefined, vi.fn()], + useProviderCredentialsAndLoadBalancing: () => ({ + credentials: undefined, + loadBalancing: undefined, + mutate: vi.fn(), + isLoading: false, + }), + useAnthropicBuyQuota: () => vi.fn(), + useMarketplaceAllPlugins: () => ({ plugins: [], isLoading: false }), + useRefreshModel: () => ({ handleRefreshModel: vi.fn() }), + useModelModalHandler: () => vi.fn(), +})) + +// Mock provider-context +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + textGenerationModelList: [], + embeddingsModelList: [], + rerankModelList: [], + agentThoughtModelList: [], + modelProviders: [], + textEmbeddingModelList: [], + speech2textModelList: [], + ttsModelList: [], + moderationModelList: [], + hasSettedApiKey: true, + plan: { type: 'free' }, + enableBilling: false, + onPlanInfoChanged: vi.fn(), + isCurrentWorkspaceDatasetOperator: false, + supportRetrievalMethods: ['semantic_search', 'full_text_search', 'hybrid_search'], + }), +})) + +describe('IndexingSection', () => { + const mockRetrievalConfig: RetrievalConfig = { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } + + const mockDataset: DataSet = { + id: 'dataset-1', + name: 'Test Dataset', + description: 'Test description', + permission: DatasetPermission.onlyMe, + icon_info: { + icon_type: 'emoji', + icon: '๐Ÿ“š', + icon_background: '#FFFFFF', + icon_url: '', + }, + indexing_technique: IndexingType.QUALIFIED, + indexing_status: 'completed', + data_source_type: DataSourceType.FILE, + doc_form: ChunkingMode.text, + embedding_model: 'text-embedding-ada-002', + embedding_model_provider: 'openai', + embedding_available: true, + app_count: 0, + document_count: 5, + total_document_count: 5, + word_count: 1000, + provider: 'vendor', + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-1', + external_knowledge_api_id: 'api-1', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 3, + score_threshold: 0.7, + score_threshold_enabled: true, + }, + retrieval_model_dict: mockRetrievalConfig, + retrieval_model: mockRetrievalConfig, + built_in_field_enabled: false, + keyword_number: 10, + created_by: 'user-1', + updated_by: 'user-1', + updated_at: Date.now(), + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + } + + const mockEmbeddingModel: DefaultModel = { + provider: 'openai', + model: 'text-embedding-ada-002', + } + + const mockEmbeddingModelList: Model[] = [ + { + provider: 'openai', + label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' }, + icon_small: { en_US: '', zh_Hans: '' }, + status: ModelStatusEnum.active, + models: [ + { + model: 'text-embedding-ada-002', + label: { en_US: 'text-embedding-ada-002', zh_Hans: 'text-embedding-ada-002' }, + model_type: ModelTypeEnum.textEmbedding, + features: [], + fetch_from: ConfigurationMethodEnum.predefinedModel, + model_properties: {}, + deprecated: false, + status: ModelStatusEnum.active, + load_balancing_enabled: false, + }, + ], + }, + ] + + const mockSummaryIndexSetting: SummaryIndexSetting = { + enable: false, + } + + const defaultProps = { + currentDataset: mockDataset, + indexMethod: IndexingType.QUALIFIED, + setIndexMethod: vi.fn(), + keywordNumber: 10, + setKeywordNumber: vi.fn(), + embeddingModel: mockEmbeddingModel, + setEmbeddingModel: vi.fn(), + embeddingModelList: mockEmbeddingModelList, + retrievalConfig: mockRetrievalConfig, + setRetrievalConfig: vi.fn(), + summaryIndexSetting: mockSummaryIndexSetting, + handleSummaryIndexSettingChange: vi.fn(), + showMultiModalTip: false, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(<IndexingSection {...defaultProps} />) + expect(screen.getByText(/form\.chunkStructure\.title/i)).toBeInTheDocument() + }) + + it('should render chunk structure section when doc_form is set', () => { + render(<IndexingSection {...defaultProps} />) + expect(screen.getByText(/form\.chunkStructure\.title/i)).toBeInTheDocument() + }) + + it('should render index method section when conditions are met', () => { + render(<IndexingSection {...defaultProps} />) + // May match multiple elements (label and descriptions) + expect(screen.getAllByText(/form\.indexMethod/i).length).toBeGreaterThan(0) + }) + + it('should render embedding model section when indexMethod is high_quality', () => { + render(<IndexingSection {...defaultProps} indexMethod={IndexingType.QUALIFIED} />) + expect(screen.getByText(/form\.embeddingModel/i)).toBeInTheDocument() + }) + + it('should render retrieval settings section', () => { + render(<IndexingSection {...defaultProps} />) + expect(screen.getByText(/form\.retrievalSetting\.title/i)).toBeInTheDocument() + }) + }) + + describe('Chunk Structure Section', () => { + it('should not render chunk structure when doc_form is not set', () => { + const datasetWithoutDocForm = { ...mockDataset, doc_form: undefined as unknown as ChunkingMode } + render(<IndexingSection {...defaultProps} currentDataset={datasetWithoutDocForm} />) + + expect(screen.queryByText(/form\.chunkStructure\.title/i)).not.toBeInTheDocument() + }) + + it('should render learn more link for chunk structure', () => { + render(<IndexingSection {...defaultProps} />) + + const learnMoreLink = screen.getByText(/form\.chunkStructure\.learnMore/i) + expect(learnMoreLink).toBeInTheDocument() + expect(learnMoreLink).toHaveAttribute('href', expect.stringContaining('chunking-and-cleaning-text')) + }) + + it('should render chunk structure description', () => { + render(<IndexingSection {...defaultProps} />) + + expect(screen.getByText(/form\.chunkStructure\.description/i)).toBeInTheDocument() + }) + }) + + describe('Index Method Section', () => { + it('should not render index method for parentChild chunking mode', () => { + const parentChildDataset = { ...mockDataset, doc_form: ChunkingMode.parentChild } + render(<IndexingSection {...defaultProps} currentDataset={parentChildDataset} />) + + expect(screen.queryByText(/form\.indexMethod/i)).not.toBeInTheDocument() + }) + + it('should render high quality option', () => { + render(<IndexingSection {...defaultProps} />) + + expect(screen.getByText(/stepTwo\.qualified/i)).toBeInTheDocument() + }) + + it('should render economy option', () => { + render(<IndexingSection {...defaultProps} />) + + // May match multiple elements (title and tip) + expect(screen.getAllByText(/form\.indexMethodEconomy/i).length).toBeGreaterThan(0) + }) + + it('should call setIndexMethod when index method changes', () => { + const setIndexMethod = vi.fn() + const { container } = render(<IndexingSection {...defaultProps} setIndexMethod={setIndexMethod} />) + + // Find the economy option card by looking for clickable elements containing the economy text + const economyOptions = screen.getAllByText(/form\.indexMethodEconomy/i) + if (economyOptions.length > 0) { + const economyCard = economyOptions[0].closest('[class*="cursor-pointer"]') + if (economyCard) { + fireEvent.click(economyCard) + } + } + + // The handler should be properly passed - verify component renders without crashing + expect(container).toBeInTheDocument() + }) + + it('should show upgrade warning when switching from economy to high quality', () => { + const economyDataset = { ...mockDataset, indexing_technique: IndexingType.ECONOMICAL } + render( + <IndexingSection + {...defaultProps} + currentDataset={economyDataset} + indexMethod={IndexingType.QUALIFIED} + />, + ) + + expect(screen.getByText(/form\.upgradeHighQualityTip/i)).toBeInTheDocument() + }) + + it('should not show upgrade warning when already on high quality', () => { + render( + <IndexingSection + {...defaultProps} + indexMethod={IndexingType.QUALIFIED} + />, + ) + + expect(screen.queryByText(/form\.upgradeHighQualityTip/i)).not.toBeInTheDocument() + }) + + it('should disable index method when embedding is not available', () => { + const datasetWithoutEmbedding = { ...mockDataset, embedding_available: false } + render(<IndexingSection {...defaultProps} currentDataset={datasetWithoutEmbedding} />) + + // Index method options should be disabled + // The exact implementation depends on the IndexMethod component + }) + }) + + describe('Embedding Model Section', () => { + it('should render embedding model when indexMethod is high_quality', () => { + render(<IndexingSection {...defaultProps} indexMethod={IndexingType.QUALIFIED} />) + + expect(screen.getByText(/form\.embeddingModel/i)).toBeInTheDocument() + }) + + it('should not render embedding model when indexMethod is economy', () => { + render(<IndexingSection {...defaultProps} indexMethod={IndexingType.ECONOMICAL} />) + + expect(screen.queryByText(/form\.embeddingModel/i)).not.toBeInTheDocument() + }) + + it('should call setEmbeddingModel when model changes', () => { + const setEmbeddingModel = vi.fn() + render( + <IndexingSection + {...defaultProps} + setEmbeddingModel={setEmbeddingModel} + indexMethod={IndexingType.QUALIFIED} + />, + ) + + // The embedding model selector should be rendered + expect(screen.getByText(/form\.embeddingModel/i)).toBeInTheDocument() + }) + }) + + describe('Summary Index Setting Section', () => { + it('should render summary index setting for high quality with text chunking', () => { + render( + <IndexingSection + {...defaultProps} + indexMethod={IndexingType.QUALIFIED} + />, + ) + + // Summary index setting should be rendered based on conditions + // The exact rendering depends on the SummaryIndexSetting component + }) + + it('should not render summary index setting for economy indexing', () => { + render( + <IndexingSection + {...defaultProps} + indexMethod={IndexingType.ECONOMICAL} + />, + ) + + // Summary index setting should not be rendered for economy + }) + + it('should call handleSummaryIndexSettingChange when setting changes', () => { + const handleSummaryIndexSettingChange = vi.fn() + render( + <IndexingSection + {...defaultProps} + handleSummaryIndexSettingChange={handleSummaryIndexSettingChange} + indexMethod={IndexingType.QUALIFIED} + />, + ) + + // The handler should be properly passed + }) + }) + + describe('Retrieval Settings Section', () => { + it('should render retrieval settings', () => { + render(<IndexingSection {...defaultProps} />) + + expect(screen.getByText(/form\.retrievalSetting\.title/i)).toBeInTheDocument() + }) + + it('should render learn more link for retrieval settings', () => { + render(<IndexingSection {...defaultProps} />) + + const learnMoreLinks = screen.getAllByText(/learnMore/i) + const retrievalLearnMore = learnMoreLinks.find(link => + link.closest('a')?.href?.includes('setting-indexing-methods'), + ) + expect(retrievalLearnMore).toBeInTheDocument() + }) + + it('should render RetrievalMethodConfig for high quality indexing', () => { + render(<IndexingSection {...defaultProps} indexMethod={IndexingType.QUALIFIED} />) + + // RetrievalMethodConfig should be rendered + expect(screen.getByText(/form\.retrievalSetting\.title/i)).toBeInTheDocument() + }) + + it('should render EconomicalRetrievalMethodConfig for economy indexing', () => { + render(<IndexingSection {...defaultProps} indexMethod={IndexingType.ECONOMICAL} />) + + // EconomicalRetrievalMethodConfig should be rendered + expect(screen.getByText(/form\.retrievalSetting\.title/i)).toBeInTheDocument() + }) + + it('should call setRetrievalConfig when config changes', () => { + const setRetrievalConfig = vi.fn() + render(<IndexingSection {...defaultProps} setRetrievalConfig={setRetrievalConfig} />) + + // The handler should be properly passed + }) + + it('should pass showMultiModalTip to RetrievalMethodConfig', () => { + render(<IndexingSection {...defaultProps} showMultiModalTip={true} />) + + // The tip should be passed to the config component + }) + }) + + describe('External Provider', () => { + it('should not render retrieval config for external provider', () => { + const externalDataset = { ...mockDataset, provider: 'external' } + render(<IndexingSection {...defaultProps} currentDataset={externalDataset} />) + + // Retrieval config should not be rendered for external provider + // This is handled by the parent component, but we verify the condition + }) + }) + + describe('Conditional Rendering', () => { + it('should show divider between sections', () => { + const { container } = render(<IndexingSection {...defaultProps} />) + + // Dividers should be present + const dividers = container.querySelectorAll('.bg-divider-subtle') + expect(dividers.length).toBeGreaterThan(0) + }) + + it('should not render index method when indexing_technique is not set', () => { + const datasetWithoutTechnique = { ...mockDataset, indexing_technique: undefined as unknown as IndexingType } + render(<IndexingSection {...defaultProps} currentDataset={datasetWithoutTechnique} indexMethod={undefined} />) + + expect(screen.queryByText(/form\.indexMethod/i)).not.toBeInTheDocument() + }) + }) + + describe('Keyword Number', () => { + it('should pass keywordNumber to IndexMethod', () => { + render(<IndexingSection {...defaultProps} keywordNumber={15} />) + + // The keyword number should be displayed in the economy option description + // The exact rendering depends on the IndexMethod component + }) + + it('should call setKeywordNumber when keyword number changes', () => { + const setKeywordNumber = vi.fn() + render(<IndexingSection {...defaultProps} setKeywordNumber={setKeywordNumber} />) + + // The handler should be properly passed + }) + }) + + describe('Props Updates', () => { + it('should update when indexMethod changes', () => { + const { rerender } = render(<IndexingSection {...defaultProps} indexMethod={IndexingType.QUALIFIED} />) + + expect(screen.getByText(/form\.embeddingModel/i)).toBeInTheDocument() + + rerender(<IndexingSection {...defaultProps} indexMethod={IndexingType.ECONOMICAL} />) + + expect(screen.queryByText(/form\.embeddingModel/i)).not.toBeInTheDocument() + }) + + it('should update when currentDataset changes', () => { + const { rerender } = render(<IndexingSection {...defaultProps} />) + + expect(screen.getByText(/form\.chunkStructure\.title/i)).toBeInTheDocument() + + const datasetWithoutDocForm = { ...mockDataset, doc_form: undefined as unknown as ChunkingMode } + rerender(<IndexingSection {...defaultProps} currentDataset={datasetWithoutDocForm} />) + + expect(screen.queryByText(/form\.chunkStructure\.title/i)).not.toBeInTheDocument() + }) + }) + + describe('Undefined Dataset', () => { + it('should handle undefined currentDataset gracefully', () => { + render(<IndexingSection {...defaultProps} currentDataset={undefined} />) + + // Should not crash and should handle undefined gracefully + // Most sections should not render without a dataset + }) + }) +}) diff --git a/web/app/components/datasets/settings/form/components/indexing-section.tsx b/web/app/components/datasets/settings/form/components/indexing-section.tsx new file mode 100644 index 0000000000..f534dd56a8 --- /dev/null +++ b/web/app/components/datasets/settings/form/components/indexing-section.tsx @@ -0,0 +1,208 @@ +'use client' +import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { DataSet, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets' +import type { RetrievalConfig } from '@/types/app' +import { RiAlertFill } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Divider from '@/app/components/base/divider' +import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' +import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' +import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import { IS_CE_EDITION } from '@/config' +import { useDocLink } from '@/context/i18n' +import { ChunkingMode } from '@/models/datasets' +import { IndexingType } from '../../../create/step-two' +import ChunkStructure from '../../chunk-structure' +import IndexMethod from '../../index-method' +import SummaryIndexSetting from '../../summary-index-setting' + +const rowClass = 'flex gap-x-1' +const labelClass = 'flex items-center shrink-0 w-[180px] h-7 pt-1' + +type IndexingSectionProps = { + currentDataset: DataSet | undefined + indexMethod: IndexingType | undefined + setIndexMethod: (value: IndexingType | undefined) => void + keywordNumber: number + setKeywordNumber: (value: number) => void + embeddingModel: DefaultModel + setEmbeddingModel: (value: DefaultModel) => void + embeddingModelList: Model[] + retrievalConfig: RetrievalConfig + setRetrievalConfig: (value: RetrievalConfig) => void + summaryIndexSetting: SummaryIndexSettingType | undefined + handleSummaryIndexSettingChange: (payload: SummaryIndexSettingType) => void + showMultiModalTip: boolean +} + +const IndexingSection = ({ + currentDataset, + indexMethod, + setIndexMethod, + keywordNumber, + setKeywordNumber, + embeddingModel, + setEmbeddingModel, + embeddingModelList, + retrievalConfig, + setRetrievalConfig, + summaryIndexSetting, + handleSummaryIndexSettingChange, + showMultiModalTip, +}: IndexingSectionProps) => { + const { t } = useTranslation() + const docLink = useDocLink() + + const isShowIndexMethod = currentDataset + && currentDataset.doc_form !== ChunkingMode.parentChild + && currentDataset.indexing_technique + && indexMethod + + const showUpgradeWarning = currentDataset?.indexing_technique === IndexingType.ECONOMICAL + && indexMethod === IndexingType.QUALIFIED + + const showSummaryIndexSetting = indexMethod === IndexingType.QUALIFIED + && [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode) + && IS_CE_EDITION + + return ( + <> + {/* Chunk Structure */} + {!!currentDataset?.doc_form && ( + <> + <Divider type="horizontal" className="my-1 h-px bg-divider-subtle" /> + <div className={rowClass}> + <div className="flex w-[180px] shrink-0 flex-col"> + <div className="system-sm-semibold flex h-8 items-center text-text-secondary"> + {t('form.chunkStructure.title', { ns: 'datasetSettings' })} + </div> + <div className="body-xs-regular text-text-tertiary"> + <a + target="_blank" + rel="noopener noreferrer" + href={docLink('/use-dify/knowledge/create-knowledge/chunking-and-cleaning-text')} + className="text-text-accent" + > + {t('form.chunkStructure.learnMore', { ns: 'datasetSettings' })} + </a> + {t('form.chunkStructure.description', { ns: 'datasetSettings' })} + </div> + </div> + <div className="grow"> + <ChunkStructure chunkStructure={currentDataset?.doc_form} /> + </div> + </div> + </> + )} + + {!!(isShowIndexMethod || indexMethod === 'high_quality') && ( + <Divider type="horizontal" className="my-1 h-px bg-divider-subtle" /> + )} + + {/* Index Method */} + {!!isShowIndexMethod && ( + <div className={rowClass}> + <div className={labelClass}> + <div className="system-sm-semibold text-text-secondary">{t('form.indexMethod', { ns: 'datasetSettings' })}</div> + </div> + <div className="grow"> + <IndexMethod + value={indexMethod!} + disabled={!currentDataset?.embedding_available} + onChange={setIndexMethod} + currentValue={currentDataset.indexing_technique} + keywordNumber={keywordNumber} + onKeywordNumberChange={setKeywordNumber} + /> + {showUpgradeWarning && ( + <div className="relative mt-2 flex h-10 items-center gap-x-0.5 overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur px-2 shadow-xs shadow-shadow-shadow-3"> + <div className="absolute left-0 top-0 flex h-full w-full items-center bg-toast-warning-bg opacity-40" /> + <div className="p-1"> + <RiAlertFill className="size-4 text-text-warning-secondary" /> + </div> + <span className="system-xs-medium text-text-primary"> + {t('form.upgradeHighQualityTip', { ns: 'datasetSettings' })} + </span> + </div> + )} + </div> + </div> + )} + + {/* Embedding Model */} + {indexMethod === IndexingType.QUALIFIED && ( + <div className={rowClass}> + <div className={labelClass}> + <div className="system-sm-semibold text-text-secondary"> + {t('form.embeddingModel', { ns: 'datasetSettings' })} + </div> + </div> + <div className="grow"> + <ModelSelector + defaultModel={embeddingModel} + modelList={embeddingModelList} + onSelect={setEmbeddingModel} + /> + </div> + </div> + )} + + {/* Summary Index Setting */} + {showSummaryIndexSetting && ( + <> + <Divider type="horizontal" className="my-1 h-px bg-divider-subtle" /> + <SummaryIndexSetting + entry="dataset-settings" + summaryIndexSetting={summaryIndexSetting} + onSummaryIndexSettingChange={handleSummaryIndexSettingChange} + /> + </> + )} + + {/* Retrieval Method Config */} + {indexMethod && currentDataset?.provider !== 'external' && ( + <> + <Divider type="horizontal" className="my-1 h-px bg-divider-subtle" /> + <div className={rowClass}> + <div className={labelClass}> + <div className="flex w-[180px] shrink-0 flex-col"> + <div className="system-sm-semibold flex h-7 items-center pt-1 text-text-secondary"> + {t('form.retrievalSetting.title', { ns: 'datasetSettings' })} + </div> + <div className="body-xs-regular text-text-tertiary"> + <a + target="_blank" + rel="noopener noreferrer" + href={docLink('/use-dify/knowledge/create-knowledge/setting-indexing-methods')} + className="text-text-accent" + > + {t('form.retrievalSetting.learnMore', { ns: 'datasetSettings' })} + </a> + {t('form.retrievalSetting.description', { ns: 'datasetSettings' })} + </div> + </div> + </div> + <div className="grow"> + {indexMethod === IndexingType.QUALIFIED + ? ( + <RetrievalMethodConfig + value={retrievalConfig} + onChange={setRetrievalConfig} + showMultiModalTip={showMultiModalTip} + /> + ) + : ( + <EconomicalRetrievalMethodConfig + value={retrievalConfig} + onChange={setRetrievalConfig} + /> + )} + </div> + </div> + </> + )} + </> + ) +} + +export default IndexingSection diff --git a/web/app/components/datasets/settings/form/hooks/use-form-state.spec.ts b/web/app/components/datasets/settings/form/hooks/use-form-state.spec.ts new file mode 100644 index 0000000000..f79500544b --- /dev/null +++ b/web/app/components/datasets/settings/form/hooks/use-form-state.spec.ts @@ -0,0 +1,763 @@ +import type { DataSet } from '@/models/datasets' +import type { RetrievalConfig } from '@/types/app' +import { act, renderHook, waitFor } from '@testing-library/react' +import { ChunkingMode, DatasetPermission, DataSourceType, WeightedScoreEnum } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../../create/step-two' +import { useFormState } from './use-form-state' + +// Mock contexts +const mockMutateDatasets = vi.fn() +const mockInvalidDatasetList = vi.fn() + +vi.mock('@/context/app-context', () => ({ + useSelector: () => false, // isCurrentWorkspaceDatasetOperator +})) + +const createDefaultMockDataset = (): DataSet => ({ + id: 'dataset-1', + name: 'Test Dataset', + description: 'Test description', + permission: DatasetPermission.onlyMe, + icon_info: { + icon_type: 'emoji', + icon: '๐Ÿ“š', + icon_background: '#FFFFFF', + icon_url: '', + }, + indexing_technique: IndexingType.QUALIFIED, + indexing_status: 'completed', + data_source_type: DataSourceType.FILE, + doc_form: ChunkingMode.text, + embedding_model: 'text-embedding-ada-002', + embedding_model_provider: 'openai', + embedding_available: true, + app_count: 0, + document_count: 5, + total_document_count: 5, + word_count: 1000, + provider: 'vendor', + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-1', + external_knowledge_api_id: 'api-1', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 3, + score_threshold: 0.7, + score_threshold_enabled: true, + }, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } as RetrievalConfig, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } as RetrievalConfig, + built_in_field_enabled: false, + keyword_number: 10, + created_by: 'user-1', + updated_by: 'user-1', + updated_at: Date.now(), + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, +}) + +let mockDataset: DataSet = createDefaultMockDataset() + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: DataSet | null, mutateDatasetRes: () => void }) => unknown) => { + const state = { + dataset: mockDataset, + mutateDatasetRes: mockMutateDatasets, + } + return selector(state) + }, +})) + +// Mock services +vi.mock('@/service/datasets', () => ({ + updateDatasetSetting: vi.fn().mockResolvedValue({}), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + useInvalidDatasetList: () => mockInvalidDatasetList, +})) + +vi.mock('@/service/use-common', () => ({ + useMembers: () => ({ + data: { + accounts: [ + { id: 'user-1', name: 'User 1', email: 'user1@example.com', role: 'owner', avatar: '', avatar_url: '', last_login_at: '', created_at: '', status: 'active' }, + { id: 'user-2', name: 'User 2', email: 'user2@example.com', role: 'admin', avatar: '', avatar_url: '', last_login_at: '', created_at: '', status: 'active' }, + ], + }, + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelList: () => ({ data: [] }), +})) + +vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ + isReRankModelSelected: () => true, +})) + +vi.mock('@/app/components/base/toast', () => ({ + default: { + notify: vi.fn(), + }, +})) + +describe('useFormState', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createDefaultMockDataset() + }) + + describe('Initial State', () => { + it('should initialize with dataset values', () => { + const { result } = renderHook(() => useFormState()) + + expect(result.current.name).toBe('Test Dataset') + expect(result.current.description).toBe('Test description') + expect(result.current.permission).toBe(DatasetPermission.onlyMe) + expect(result.current.indexMethod).toBe(IndexingType.QUALIFIED) + expect(result.current.keywordNumber).toBe(10) + }) + + it('should initialize icon info from dataset', () => { + const { result } = renderHook(() => useFormState()) + + expect(result.current.iconInfo).toEqual({ + icon_type: 'emoji', + icon: '๐Ÿ“š', + icon_background: '#FFFFFF', + icon_url: '', + }) + }) + + it('should initialize external retrieval settings', () => { + const { result } = renderHook(() => useFormState()) + + expect(result.current.topK).toBe(3) + expect(result.current.scoreThreshold).toBe(0.7) + expect(result.current.scoreThresholdEnabled).toBe(true) + }) + + it('should derive member list from API data', () => { + const { result } = renderHook(() => useFormState()) + + expect(result.current.memberList).toHaveLength(2) + expect(result.current.memberList[0].name).toBe('User 1') + }) + + it('should return currentDataset from context', () => { + const { result } = renderHook(() => useFormState()) + + expect(result.current.currentDataset).toBeDefined() + expect(result.current.currentDataset?.id).toBe('dataset-1') + }) + }) + + describe('State Setters', () => { + it('should update name when setName is called', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setName('New Name') + }) + + expect(result.current.name).toBe('New Name') + }) + + it('should update description when setDescription is called', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setDescription('New Description') + }) + + expect(result.current.description).toBe('New Description') + }) + + it('should update permission when setPermission is called', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setPermission(DatasetPermission.allTeamMembers) + }) + + expect(result.current.permission).toBe(DatasetPermission.allTeamMembers) + }) + + it('should update indexMethod when setIndexMethod is called', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setIndexMethod(IndexingType.ECONOMICAL) + }) + + expect(result.current.indexMethod).toBe(IndexingType.ECONOMICAL) + }) + + it('should update keywordNumber when setKeywordNumber is called', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setKeywordNumber(20) + }) + + expect(result.current.keywordNumber).toBe(20) + }) + + it('should update selectedMemberIDs when setSelectedMemberIDs is called', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setSelectedMemberIDs(['user-1', 'user-2']) + }) + + expect(result.current.selectedMemberIDs).toEqual(['user-1', 'user-2']) + }) + }) + + describe('Icon Handlers', () => { + it('should open app icon picker and save previous icon', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleOpenAppIconPicker() + }) + + expect(result.current.showAppIconPicker).toBe(true) + }) + + it('should select emoji icon and close picker', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleOpenAppIconPicker() + }) + + act(() => { + result.current.handleSelectAppIcon({ + type: 'emoji', + icon: '๐ŸŽ‰', + background: '#FF0000', + }) + }) + + expect(result.current.showAppIconPicker).toBe(false) + expect(result.current.iconInfo).toEqual({ + icon_type: 'emoji', + icon: '๐ŸŽ‰', + icon_background: '#FF0000', + icon_url: undefined, + }) + }) + + it('should select image icon and close picker', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleOpenAppIconPicker() + }) + + act(() => { + result.current.handleSelectAppIcon({ + type: 'image', + fileId: 'file-123', + url: 'https://example.com/icon.png', + }) + }) + + expect(result.current.showAppIconPicker).toBe(false) + expect(result.current.iconInfo).toEqual({ + icon_type: 'image', + icon: 'file-123', + icon_background: undefined, + icon_url: 'https://example.com/icon.png', + }) + }) + + it('should restore previous icon when picker is closed', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleOpenAppIconPicker() + }) + + act(() => { + result.current.handleSelectAppIcon({ + type: 'emoji', + icon: '๐ŸŽ‰', + background: '#FF0000', + }) + }) + + act(() => { + result.current.handleOpenAppIconPicker() + }) + + act(() => { + result.current.handleCloseAppIconPicker() + }) + + expect(result.current.showAppIconPicker).toBe(false) + // After close, icon should be restored to the icon before opening + expect(result.current.iconInfo).toEqual({ + icon_type: 'emoji', + icon: '๐ŸŽ‰', + icon_background: '#FF0000', + icon_url: undefined, + }) + }) + }) + + describe('External Retrieval Settings Handler', () => { + it('should update topK when provided', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleSettingsChange({ top_k: 5 }) + }) + + expect(result.current.topK).toBe(5) + }) + + it('should update scoreThreshold when provided', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleSettingsChange({ score_threshold: 0.8 }) + }) + + expect(result.current.scoreThreshold).toBe(0.8) + }) + + it('should update scoreThresholdEnabled when provided', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleSettingsChange({ score_threshold_enabled: false }) + }) + + expect(result.current.scoreThresholdEnabled).toBe(false) + }) + + it('should update multiple settings at once', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleSettingsChange({ + top_k: 10, + score_threshold: 0.9, + score_threshold_enabled: true, + }) + }) + + expect(result.current.topK).toBe(10) + expect(result.current.scoreThreshold).toBe(0.9) + expect(result.current.scoreThresholdEnabled).toBe(true) + }) + }) + + describe('Summary Index Setting Handler', () => { + it('should update summary index setting', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleSummaryIndexSettingChange({ + enable: true, + }) + }) + + expect(result.current.summaryIndexSetting).toMatchObject({ + enable: true, + }) + }) + + it('should merge with existing settings', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.handleSummaryIndexSettingChange({ + enable: true, + }) + }) + + act(() => { + result.current.handleSummaryIndexSettingChange({ + model_provider_name: 'openai', + model_name: 'gpt-4', + }) + }) + + expect(result.current.summaryIndexSetting).toMatchObject({ + enable: true, + model_provider_name: 'openai', + model_name: 'gpt-4', + }) + }) + }) + + describe('handleSave', () => { + it('should show error toast when name is empty', async () => { + const Toast = await import('@/app/components/base/toast') + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setName('') + }) + + await act(async () => { + await result.current.handleSave() + }) + + expect(Toast.default.notify).toHaveBeenCalledWith({ + type: 'error', + message: expect.any(String), + }) + }) + + it('should show error toast when name is whitespace only', async () => { + const Toast = await import('@/app/components/base/toast') + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setName(' ') + }) + + await act(async () => { + await result.current.handleSave() + }) + + expect(Toast.default.notify).toHaveBeenCalledWith({ + type: 'error', + message: expect.any(String), + }) + }) + + it('should call updateDatasetSetting with correct params', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + const { result } = renderHook(() => useFormState()) + + await act(async () => { + await result.current.handleSave() + }) + + expect(updateDatasetSetting).toHaveBeenCalledWith({ + datasetId: 'dataset-1', + body: expect.objectContaining({ + name: 'Test Dataset', + description: 'Test description', + permission: DatasetPermission.onlyMe, + }), + }) + }) + + it('should show success toast on successful save', async () => { + const Toast = await import('@/app/components/base/toast') + const { result } = renderHook(() => useFormState()) + + await act(async () => { + await result.current.handleSave() + }) + + await waitFor(() => { + expect(Toast.default.notify).toHaveBeenCalledWith({ + type: 'success', + message: expect.any(String), + }) + }) + }) + + it('should call mutateDatasets after successful save', async () => { + const { result } = renderHook(() => useFormState()) + + await act(async () => { + await result.current.handleSave() + }) + + await waitFor(() => { + expect(mockMutateDatasets).toHaveBeenCalled() + }) + }) + + it('should call invalidDatasetList after successful save', async () => { + const { result } = renderHook(() => useFormState()) + + await act(async () => { + await result.current.handleSave() + }) + + await waitFor(() => { + expect(mockInvalidDatasetList).toHaveBeenCalled() + }) + }) + + it('should set loading to true during save', async () => { + const { result } = renderHook(() => useFormState()) + + expect(result.current.loading).toBe(false) + + const savePromise = act(async () => { + await result.current.handleSave() + }) + + // Loading should be true during the save operation + await savePromise + + expect(result.current.loading).toBe(false) // After completion + }) + + it('should not save when already loading', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + vi.mocked(updateDatasetSetting).mockImplementation(() => new Promise(resolve => setTimeout(resolve, 100))) + + const { result } = renderHook(() => useFormState()) + + // Start first save + act(() => { + result.current.handleSave() + }) + + // Try to start second save immediately + await act(async () => { + await result.current.handleSave() + }) + + // Should only have been called once + expect(updateDatasetSetting).toHaveBeenCalledTimes(1) + }) + + it('should show error toast on save failure', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + const Toast = await import('@/app/components/base/toast') + vi.mocked(updateDatasetSetting).mockRejectedValueOnce(new Error('Network error')) + + const { result } = renderHook(() => useFormState()) + + await act(async () => { + await result.current.handleSave() + }) + + expect(Toast.default.notify).toHaveBeenCalledWith({ + type: 'error', + message: expect.any(String), + }) + }) + + it('should include partial_member_list when permission is partialMembers', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setPermission(DatasetPermission.partialMembers) + result.current.setSelectedMemberIDs(['user-1', 'user-2']) + }) + + await act(async () => { + await result.current.handleSave() + }) + + expect(updateDatasetSetting).toHaveBeenCalledWith({ + datasetId: 'dataset-1', + body: expect.objectContaining({ + partial_member_list: expect.arrayContaining([ + expect.objectContaining({ user_id: 'user-1' }), + expect.objectContaining({ user_id: 'user-2' }), + ]), + }), + }) + }) + }) + + describe('Embedding Model', () => { + it('should initialize embedding model from dataset', () => { + const { result } = renderHook(() => useFormState()) + + expect(result.current.embeddingModel).toEqual({ + provider: 'openai', + model: 'text-embedding-ada-002', + }) + }) + + it('should update embedding model when setEmbeddingModel is called', () => { + const { result } = renderHook(() => useFormState()) + + act(() => { + result.current.setEmbeddingModel({ + provider: 'cohere', + model: 'embed-english-v3.0', + }) + }) + + expect(result.current.embeddingModel).toEqual({ + provider: 'cohere', + model: 'embed-english-v3.0', + }) + }) + }) + + describe('Retrieval Config', () => { + it('should initialize retrieval config from dataset', () => { + const { result } = renderHook(() => useFormState()) + + expect(result.current.retrievalConfig).toBeDefined() + expect(result.current.retrievalConfig.search_method).toBe(RETRIEVE_METHOD.semantic) + }) + + it('should update retrieval config when setRetrievalConfig is called', () => { + const { result } = renderHook(() => useFormState()) + + const newConfig: RetrievalConfig = { + ...result.current.retrievalConfig, + reranking_enable: true, + } + + act(() => { + result.current.setRetrievalConfig(newConfig) + }) + + expect(result.current.retrievalConfig.reranking_enable).toBe(true) + }) + + it('should include weights in save request when weights are set', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + const { result } = renderHook(() => useFormState()) + + // Set retrieval config with weights + const configWithWeights: RetrievalConfig = { + ...result.current.retrievalConfig, + search_method: RETRIEVE_METHOD.hybrid, + weights: { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: 0.7, + embedding_provider_name: '', + embedding_model_name: '', + }, + keyword_setting: { + keyword_weight: 0.3, + }, + }, + } + + act(() => { + result.current.setRetrievalConfig(configWithWeights) + }) + + await act(async () => { + await result.current.handleSave() + }) + + // Verify that weights were included and embedding model info was added + expect(updateDatasetSetting).toHaveBeenCalledWith({ + datasetId: 'dataset-1', + body: expect.objectContaining({ + retrieval_model: expect.objectContaining({ + weights: expect.objectContaining({ + vector_setting: expect.objectContaining({ + embedding_provider_name: 'openai', + embedding_model_name: 'text-embedding-ada-002', + }), + }), + }), + }), + }) + }) + }) + + describe('External Provider', () => { + beforeEach(() => { + // Update mock dataset to be external provider + mockDataset = { + ...mockDataset, + provider: 'external', + external_knowledge_info: { + external_knowledge_id: 'ext-123', + external_knowledge_api_id: 'api-456', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 5, + score_threshold: 0.8, + score_threshold_enabled: true, + }, + } + }) + + it('should include external knowledge info in save request for external provider', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + const { result } = renderHook(() => useFormState()) + + await act(async () => { + await result.current.handleSave() + }) + + expect(updateDatasetSetting).toHaveBeenCalledWith({ + datasetId: 'dataset-1', + body: expect.objectContaining({ + external_knowledge_id: 'ext-123', + external_knowledge_api_id: 'api-456', + external_retrieval_model: expect.objectContaining({ + top_k: expect.any(Number), + score_threshold: expect.any(Number), + score_threshold_enabled: expect.any(Boolean), + }), + }), + }) + }) + + it('should use correct external retrieval settings', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + const { result } = renderHook(() => useFormState()) + + // Update external retrieval settings + act(() => { + result.current.handleSettingsChange({ + top_k: 10, + score_threshold: 0.9, + score_threshold_enabled: false, + }) + }) + + await act(async () => { + await result.current.handleSave() + }) + + expect(updateDatasetSetting).toHaveBeenCalledWith({ + datasetId: 'dataset-1', + body: expect.objectContaining({ + external_retrieval_model: { + top_k: 10, + score_threshold: 0.9, + score_threshold_enabled: false, + }, + }), + }) + }) + }) +}) diff --git a/web/app/components/datasets/settings/form/hooks/use-form-state.ts b/web/app/components/datasets/settings/form/hooks/use-form-state.ts new file mode 100644 index 0000000000..614995d43a --- /dev/null +++ b/web/app/components/datasets/settings/form/hooks/use-form-state.ts @@ -0,0 +1,264 @@ +'use client' +import type { AppIconSelection } from '@/app/components/base/app-icon-picker' +import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { Member } from '@/models/common' +import type { IconInfo, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets' +import type { RetrievalConfig } from '@/types/app' +import { useCallback, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useSelector as useAppContextWithSelector } from '@/context/app-context' +import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' +import { DatasetPermission } from '@/models/datasets' +import { updateDatasetSetting } from '@/service/datasets' +import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' +import { useMembers } from '@/service/use-common' +import { checkShowMultiModalTip } from '../../utils' + +const DEFAULT_APP_ICON: IconInfo = { + icon_type: 'emoji', + icon: '๐Ÿ“™', + icon_background: '#FFF4ED', + icon_url: '', +} + +export const useFormState = () => { + const { t } = useTranslation() + const isCurrentWorkspaceDatasetOperator = useAppContextWithSelector(state => state.isCurrentWorkspaceDatasetOperator) + const currentDataset = useDatasetDetailContextWithSelector(state => state.dataset) + const mutateDatasets = useDatasetDetailContextWithSelector(state => state.mutateDatasetRes) + + // Basic form state + const [loading, setLoading] = useState(false) + const [name, setName] = useState(currentDataset?.name ?? '') + const [description, setDescription] = useState(currentDataset?.description ?? '') + + // Icon state + const [iconInfo, setIconInfo] = useState(currentDataset?.icon_info || DEFAULT_APP_ICON) + const [showAppIconPicker, setShowAppIconPicker] = useState(false) + const previousAppIcon = useRef(DEFAULT_APP_ICON) + + // Permission state + const [permission, setPermission] = useState(currentDataset?.permission) + const [selectedMemberIDs, setSelectedMemberIDs] = useState<string[]>(currentDataset?.partial_member_list || []) + + // External retrieval state + const [topK, setTopK] = useState(currentDataset?.external_retrieval_model.top_k ?? 2) + const [scoreThreshold, setScoreThreshold] = useState(currentDataset?.external_retrieval_model.score_threshold ?? 0.5) + const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(currentDataset?.external_retrieval_model.score_threshold_enabled ?? false) + + // Indexing and retrieval state + const [indexMethod, setIndexMethod] = useState(currentDataset?.indexing_technique) + const [keywordNumber, setKeywordNumber] = useState(currentDataset?.keyword_number ?? 10) + const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict as RetrievalConfig) + const [embeddingModel, setEmbeddingModel] = useState<DefaultModel>( + currentDataset?.embedding_model + ? { + provider: currentDataset.embedding_model_provider, + model: currentDataset.embedding_model, + } + : { + provider: '', + model: '', + }, + ) + + // Summary index state + const [summaryIndexSetting, setSummaryIndexSetting] = useState(currentDataset?.summary_index_setting) + + // Model lists + const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) + const { data: membersData } = useMembers() + const invalidDatasetList = useInvalidDatasetList() + + // Derive member list from API data + const memberList = useMemo<Member[]>(() => { + return membersData?.accounts ?? [] + }, [membersData]) + + // Icon handlers + const handleOpenAppIconPicker = useCallback(() => { + setShowAppIconPicker(true) + previousAppIcon.current = iconInfo + }, [iconInfo]) + + const handleSelectAppIcon = useCallback((icon: AppIconSelection) => { + const newIconInfo: IconInfo = { + icon_type: icon.type, + icon: icon.type === 'emoji' ? icon.icon : icon.fileId, + icon_background: icon.type === 'emoji' ? icon.background : undefined, + icon_url: icon.type === 'emoji' ? undefined : icon.url, + } + setIconInfo(newIconInfo) + setShowAppIconPicker(false) + }, []) + + const handleCloseAppIconPicker = useCallback(() => { + setIconInfo(previousAppIcon.current) + setShowAppIconPicker(false) + }, []) + + // External retrieval settings handler + const handleSettingsChange = useCallback((data: { top_k?: number, score_threshold?: number, score_threshold_enabled?: boolean }) => { + if (data.top_k !== undefined) + setTopK(data.top_k) + if (data.score_threshold !== undefined) + setScoreThreshold(data.score_threshold) + if (data.score_threshold_enabled !== undefined) + setScoreThresholdEnabled(data.score_threshold_enabled) + }, []) + + // Summary index setting handler + const handleSummaryIndexSettingChange = useCallback((payload: SummaryIndexSettingType) => { + setSummaryIndexSetting(prev => ({ ...prev, ...payload })) + }, []) + + // Save handler + const handleSave = async () => { + if (loading) + return + + if (!name?.trim()) { + Toast.notify({ type: 'error', message: t('form.nameError', { ns: 'datasetSettings' }) }) + return + } + + if (!isReRankModelSelected({ rerankModelList, retrievalConfig, indexMethod })) { + Toast.notify({ type: 'error', message: t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) }) + return + } + + if (retrievalConfig.weights) { + retrievalConfig.weights.vector_setting.embedding_provider_name = embeddingModel.provider || '' + retrievalConfig.weights.vector_setting.embedding_model_name = embeddingModel.model || '' + } + + try { + setLoading(true) + const body: Record<string, unknown> = { + name, + icon_info: iconInfo, + doc_form: currentDataset?.doc_form, + description, + permission, + indexing_technique: indexMethod, + retrieval_model: { + ...retrievalConfig, + score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0, + }, + embedding_model: embeddingModel.model, + embedding_model_provider: embeddingModel.provider, + keyword_number: keywordNumber, + summary_index_setting: summaryIndexSetting, + } + + if (currentDataset!.provider === 'external') { + body.external_knowledge_id = currentDataset!.external_knowledge_info.external_knowledge_id + body.external_knowledge_api_id = currentDataset!.external_knowledge_info.external_knowledge_api_id + body.external_retrieval_model = { + top_k: topK, + score_threshold: scoreThreshold, + score_threshold_enabled: scoreThresholdEnabled, + } + } + + if (permission === DatasetPermission.partialMembers) { + body.partial_member_list = selectedMemberIDs.map((id) => { + return { + user_id: id, + role: memberList.find(member => member.id === id)?.role, + } + }) + } + + await updateDatasetSetting({ datasetId: currentDataset!.id, body }) + Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + + if (mutateDatasets) { + await mutateDatasets() + invalidDatasetList() + } + } + catch { + Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + } + finally { + setLoading(false) + } + } + + // Computed values + const showMultiModalTip = useMemo(() => { + return checkShowMultiModalTip({ + embeddingModel, + rerankingEnable: retrievalConfig.reranking_enable, + rerankModel: { + rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name, + rerankingModelName: retrievalConfig.reranking_model.reranking_model_name, + }, + indexMethod, + embeddingModelList, + rerankModelList, + }) + }, [embeddingModel, rerankModelList, retrievalConfig.reranking_enable, retrievalConfig.reranking_model, embeddingModelList, indexMethod]) + + return { + // Context values + currentDataset, + isCurrentWorkspaceDatasetOperator, + + // Loading state + loading, + + // Basic form + name, + setName, + description, + setDescription, + + // Icon + iconInfo, + showAppIconPicker, + handleOpenAppIconPicker, + handleSelectAppIcon, + handleCloseAppIconPicker, + + // Permission + permission, + setPermission, + selectedMemberIDs, + setSelectedMemberIDs, + memberList, + + // External retrieval + topK, + scoreThreshold, + scoreThresholdEnabled, + handleSettingsChange, + + // Indexing and retrieval + indexMethod, + setIndexMethod, + keywordNumber, + setKeywordNumber, + retrievalConfig, + setRetrievalConfig, + embeddingModel, + setEmbeddingModel, + embeddingModelList, + + // Summary index + summaryIndexSetting, + handleSummaryIndexSettingChange, + + // Computed + showMultiModalTip, + + // Actions + handleSave, + } +} diff --git a/web/app/components/datasets/settings/form/index.spec.tsx b/web/app/components/datasets/settings/form/index.spec.tsx new file mode 100644 index 0000000000..03e98861e2 --- /dev/null +++ b/web/app/components/datasets/settings/form/index.spec.tsx @@ -0,0 +1,488 @@ +import type { DataSet } from '@/models/datasets' +import type { RetrievalConfig } from '@/types/app' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../create/step-two' +import Form from './index' + +// Mock contexts +const mockMutateDatasets = vi.fn() +const mockInvalidDatasetList = vi.fn() + +const mockUserProfile = { + id: 'user-1', + name: 'Current User', + email: 'current@example.com', + avatar_url: '', + role: 'owner', +} + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: unknown) => unknown) => { + const state = { + isCurrentWorkspaceDatasetOperator: false, + userProfile: mockUserProfile, + } + return selector(state) + }, +})) + +const createMockDataset = (overrides: Partial<DataSet> = {}): DataSet => ({ + id: 'dataset-1', + name: 'Test Dataset', + description: 'Test description', + permission: DatasetPermission.onlyMe, + icon_info: { + icon_type: 'emoji', + icon: '๐Ÿ“š', + icon_background: '#FFFFFF', + icon_url: '', + }, + indexing_technique: IndexingType.QUALIFIED, + indexing_status: 'completed', + data_source_type: DataSourceType.FILE, + doc_form: ChunkingMode.text, + embedding_model: 'text-embedding-ada-002', + embedding_model_provider: 'openai', + embedding_available: true, + app_count: 0, + document_count: 5, + total_document_count: 5, + word_count: 1000, + provider: 'vendor', + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-1', + external_knowledge_api_id: 'api-1', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 3, + score_threshold: 0.7, + score_threshold_enabled: true, + }, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } as RetrievalConfig, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } as RetrievalConfig, + built_in_field_enabled: false, + keyword_number: 10, + created_by: 'user-1', + updated_by: 'user-1', + updated_at: Date.now(), + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...overrides, +}) + +let mockDataset: DataSet = createMockDataset() + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: DataSet | null, mutateDatasetRes: () => void }) => unknown) => { + const state = { + dataset: mockDataset, + mutateDatasetRes: mockMutateDatasets, + } + return selector(state) + }, +})) + +// Mock services +vi.mock('@/service/datasets', () => ({ + updateDatasetSetting: vi.fn().mockResolvedValue({}), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + useInvalidDatasetList: () => mockInvalidDatasetList, +})) + +vi.mock('@/service/use-common', () => ({ + useMembers: () => ({ + data: { + accounts: [ + { id: 'user-1', name: 'User 1', email: 'user1@example.com', role: 'owner', avatar: '', avatar_url: '', last_login_at: '', created_at: '', status: 'active' }, + { id: 'user-2', name: 'User 2', email: 'user2@example.com', role: 'admin', avatar: '', avatar_url: '', last_login_at: '', created_at: '', status: 'active' }, + ], + }, + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelList: () => ({ data: [], mutate: vi.fn(), isLoading: false }), + useCurrentProviderAndModel: () => ({ currentProvider: undefined, currentModel: undefined }), + useDefaultModel: () => ({ data: undefined, mutate: vi.fn(), isLoading: false }), + useModelListAndDefaultModel: () => ({ modelList: [], defaultModel: undefined }), + useModelListAndDefaultModelAndCurrentProviderAndModel: () => ({ + modelList: [], + defaultModel: undefined, + currentProvider: undefined, + currentModel: undefined, + }), + useUpdateModelList: () => vi.fn(), + useUpdateModelProviders: () => vi.fn(), + useLanguage: () => 'en_US', + useSystemDefaultModelAndModelList: () => [undefined, vi.fn()], + useProviderCredentialsAndLoadBalancing: () => ({ + credentials: undefined, + loadBalancing: undefined, + mutate: vi.fn(), + isLoading: false, + }), + useAnthropicBuyQuota: () => vi.fn(), + useMarketplaceAllPlugins: () => ({ plugins: [], isLoading: false }), + useRefreshModel: () => ({ handleRefreshModel: vi.fn() }), + useModelModalHandler: () => vi.fn(), +})) + +// Mock provider-context +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + textGenerationModelList: [], + embeddingsModelList: [], + rerankModelList: [], + agentThoughtModelList: [], + modelProviders: [], + textEmbeddingModelList: [], + speech2textModelList: [], + ttsModelList: [], + moderationModelList: [], + hasSettedApiKey: true, + plan: { type: 'free' }, + enableBilling: false, + onPlanInfoChanged: vi.fn(), + isCurrentWorkspaceDatasetOperator: false, + supportRetrievalMethods: ['semantic_search', 'full_text_search', 'hybrid_search'], + }), +})) + +vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ + isReRankModelSelected: () => true, +})) + +vi.mock('@/app/components/base/toast', () => ({ + default: { + notify: vi.fn(), + }, +})) + +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.dify.ai${path}`, +})) + +describe('Form', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createMockDataset() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(<Form />) + expect(screen.getByRole('button', { name: /form\.save/i })).toBeInTheDocument() + }) + + it('should render dataset name input with initial value', () => { + render(<Form />) + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput).toBeInTheDocument() + }) + + it('should render dataset description textarea', () => { + render(<Form />) + const descriptionTextarea = screen.getByDisplayValue('Test description') + expect(descriptionTextarea).toBeInTheDocument() + }) + + it('should render save button', () => { + render(<Form />) + const saveButton = screen.getByRole('button', { name: /form\.save/i }) + expect(saveButton).toBeInTheDocument() + }) + + it('should render permission selector', () => { + render(<Form />) + // Permission selector renders the current permission text + expect(screen.getByText(/form\.permissionsOnlyMe/i)).toBeInTheDocument() + }) + }) + + describe('BasicInfoSection', () => { + it('should allow editing dataset name', () => { + render(<Form />) + const nameInput = screen.getByDisplayValue('Test Dataset') + + fireEvent.change(nameInput, { target: { value: 'Updated Dataset Name' } }) + + expect(nameInput).toHaveValue('Updated Dataset Name') + }) + + it('should allow editing dataset description', () => { + render(<Form />) + const descriptionTextarea = screen.getByDisplayValue('Test description') + + fireEvent.change(descriptionTextarea, { target: { value: 'Updated description' } }) + + expect(descriptionTextarea).toHaveValue('Updated description') + }) + + it('should render app icon', () => { + const { container } = render(<Form />) + // The app icon wrapper should be rendered (icon may be in a span or SVG) + // The icon is rendered within a clickable container in the name and icon section + const iconSection = container.querySelector('[class*="cursor-pointer"]') + expect(iconSection).toBeInTheDocument() + }) + }) + + describe('IndexingSection - Internal Provider', () => { + it('should render chunk structure section when doc_form is set', () => { + render(<Form />) + expect(screen.getByText(/form\.chunkStructure\.title/i)).toBeInTheDocument() + }) + + it('should render index method section', () => { + render(<Form />) + // May match multiple elements (label and descriptions) + expect(screen.getAllByText(/form\.indexMethod/i).length).toBeGreaterThan(0) + }) + + it('should render embedding model section when indexMethod is high_quality', () => { + render(<Form />) + expect(screen.getByText(/form\.embeddingModel/i)).toBeInTheDocument() + }) + + it('should render retrieval settings section', () => { + render(<Form />) + expect(screen.getByText(/form\.retrievalSetting\.title/i)).toBeInTheDocument() + }) + + it('should render learn more links', () => { + render(<Form />) + const learnMoreLinks = screen.getAllByText(/learnMore/i) + expect(learnMoreLinks.length).toBeGreaterThan(0) + }) + }) + + describe('ExternalKnowledgeSection - External Provider', () => { + beforeEach(() => { + mockDataset = createMockDataset({ provider: 'external' }) + }) + + it('should render external knowledge API info when provider is external', () => { + render(<Form />) + expect(screen.getByText(/form\.externalKnowledgeAPI/i)).toBeInTheDocument() + }) + + it('should render external knowledge ID when provider is external', () => { + render(<Form />) + expect(screen.getByText(/form\.externalKnowledgeID/i)).toBeInTheDocument() + }) + + it('should display external API name', () => { + render(<Form />) + expect(screen.getByText('External API')).toBeInTheDocument() + }) + + it('should display external API endpoint', () => { + render(<Form />) + expect(screen.getByText('https://api.example.com')).toBeInTheDocument() + }) + + it('should display external knowledge ID value', () => { + render(<Form />) + expect(screen.getByText('ext-1')).toBeInTheDocument() + }) + }) + + describe('Save Functionality', () => { + it('should call save when save button is clicked', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + render(<Form />) + + const saveButton = screen.getByRole('button', { name: /form\.save/i }) + fireEvent.click(saveButton) + + await waitFor(() => { + expect(updateDatasetSetting).toHaveBeenCalled() + }) + }) + + it('should show loading state on save button while saving', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + vi.mocked(updateDatasetSetting).mockImplementation( + () => new Promise(resolve => setTimeout(resolve, 100)), + ) + + render(<Form />) + + const saveButton = screen.getByRole('button', { name: /form\.save/i }) + fireEvent.click(saveButton) + + // Button should be disabled during loading + await waitFor(() => { + expect(saveButton).toBeDisabled() + }) + }) + + it('should show error when trying to save with empty name', async () => { + const Toast = await import('@/app/components/base/toast') + render(<Form />) + + // Clear the name + const nameInput = screen.getByDisplayValue('Test Dataset') + fireEvent.change(nameInput, { target: { value: '' } }) + + // Try to save + const saveButton = screen.getByRole('button', { name: /form\.save/i }) + fireEvent.click(saveButton) + + await waitFor(() => { + expect(Toast.default.notify).toHaveBeenCalledWith({ + type: 'error', + message: expect.any(String), + }) + }) + }) + + it('should save with updated name', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + render(<Form />) + + // Update name + const nameInput = screen.getByDisplayValue('Test Dataset') + fireEvent.change(nameInput, { target: { value: 'New Dataset Name' } }) + + // Save + const saveButton = screen.getByRole('button', { name: /form\.save/i }) + fireEvent.click(saveButton) + + await waitFor(() => { + expect(updateDatasetSetting).toHaveBeenCalledWith( + expect.objectContaining({ + body: expect.objectContaining({ + name: 'New Dataset Name', + }), + }), + ) + }) + }) + + it('should save with updated description', async () => { + const { updateDatasetSetting } = await import('@/service/datasets') + render(<Form />) + + // Update description + const descriptionTextarea = screen.getByDisplayValue('Test description') + fireEvent.change(descriptionTextarea, { target: { value: 'New description' } }) + + // Save + const saveButton = screen.getByRole('button', { name: /form\.save/i }) + fireEvent.click(saveButton) + + await waitFor(() => { + expect(updateDatasetSetting).toHaveBeenCalledWith( + expect.objectContaining({ + body: expect.objectContaining({ + description: 'New description', + }), + }), + ) + }) + }) + }) + + describe('Disabled States', () => { + it('should disable inputs when embedding is not available', () => { + mockDataset = createMockDataset({ embedding_available: false }) + render(<Form />) + + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput).toBeDisabled() + + const descriptionTextarea = screen.getByDisplayValue('Test description') + expect(descriptionTextarea).toBeDisabled() + }) + }) + + describe('Conditional Rendering', () => { + it('should not render chunk structure when doc_form is not set', () => { + mockDataset = createMockDataset({ doc_form: undefined as unknown as ChunkingMode }) + render(<Form />) + + // Chunk structure should not be present + expect(screen.queryByText(/form\.chunkStructure\.title/i)).not.toBeInTheDocument() + }) + + it('should render IndexingSection for internal provider', () => { + mockDataset = createMockDataset({ provider: 'vendor' }) + render(<Form />) + + // May match multiple elements (label and descriptions) + expect(screen.getAllByText(/form\.indexMethod/i).length).toBeGreaterThan(0) + expect(screen.queryByText(/form\.externalKnowledgeAPI/i)).not.toBeInTheDocument() + }) + + it('should render ExternalKnowledgeSection for external provider', () => { + mockDataset = createMockDataset({ provider: 'external' }) + render(<Form />) + + expect(screen.getByText(/form\.externalKnowledgeAPI/i)).toBeInTheDocument() + }) + }) + + describe('Permission Selection', () => { + it('should open permission dropdown when clicked', async () => { + render(<Form />) + + const permissionTrigger = screen.getByText(/form\.permissionsOnlyMe/i) + fireEvent.click(permissionTrigger) + + await waitFor(() => { + // Should show all permission options + expect(screen.getAllByText(/form\.permissionsOnlyMe/i).length).toBeGreaterThanOrEqual(1) + }) + }) + }) + + describe('Integration', () => { + it('should render all main sections', () => { + render(<Form />) + + // Basic info + expect(screen.getByText(/form\.nameAndIcon/i)).toBeInTheDocument() + expect(screen.getByText(/form\.desc/i)).toBeInTheDocument() + // form.permissions matches multiple elements (label and permission options) + expect(screen.getAllByText(/form\.permissions/i).length).toBeGreaterThan(0) + + // Indexing (for internal provider) + expect(screen.getByText(/form\.chunkStructure\.title/i)).toBeInTheDocument() + // form.indexMethod matches multiple elements + expect(screen.getAllByText(/form\.indexMethod/i).length).toBeGreaterThan(0) + + // Save button + expect(screen.getByRole('button', { name: /form\.save/i })).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/settings/form/index.tsx b/web/app/components/datasets/settings/form/index.tsx index ca072cfcae..f060701f1c 100644 --- a/web/app/components/datasets/settings/form/index.tsx +++ b/web/app/components/datasets/settings/form/index.tsx @@ -1,487 +1,126 @@ 'use client' -import type { AppIconSelection } from '@/app/components/base/app-icon-picker' -import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations' -import type { Member } from '@/models/common' -import type { IconInfo, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets' -import type { AppIconType, RetrievalConfig } from '@/types/app' -import { RiAlertFill } from '@remixicon/react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import AppIcon from '@/app/components/base/app-icon' -import AppIconPicker from '@/app/components/base/app-icon-picker' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' -import Input from '@/app/components/base/input' -import Textarea from '@/app/components/base/textarea' -import Toast from '@/app/components/base/toast' -import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' -import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' -import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' -import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' -import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { IS_CE_EDITION } from '@/config' -import { useSelector as useAppContextWithSelector } from '@/context/app-context' -import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' -import { useDocLink } from '@/context/i18n' -import { ChunkingMode, DatasetPermission } from '@/models/datasets' -import { updateDatasetSetting } from '@/service/datasets' -import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' -import { useMembers } from '@/service/use-common' -import { IndexingType } from '../../create/step-two' -import RetrievalSettings from '../../external-knowledge-base/create/RetrievalSettings' -import ChunkStructure from '../chunk-structure' -import IndexMethod from '../index-method' -import PermissionSelector from '../permission-selector' -import SummaryIndexSetting from '../summary-index-setting' -import { checkShowMultiModalTip } from '../utils' - -const rowClass = 'flex gap-x-1' -const labelClass = 'flex items-center shrink-0 w-[180px] h-7 pt-1' - -const DEFAULT_APP_ICON: IconInfo = { - icon_type: 'emoji', - icon: '๐Ÿ“™', - icon_background: '#FFF4ED', - icon_url: '', -} +import BasicInfoSection from './components/basic-info-section' +import ExternalKnowledgeSection from './components/external-knowledge-section' +import IndexingSection from './components/indexing-section' +import { useFormState } from './hooks/use-form-state' const Form = () => { const { t } = useTranslation() - const docLink = useDocLink() - const isCurrentWorkspaceDatasetOperator = useAppContextWithSelector(state => state.isCurrentWorkspaceDatasetOperator) - const currentDataset = useDatasetDetailContextWithSelector(state => state.dataset) - const mutateDatasets = useDatasetDetailContextWithSelector(state => state.mutateDatasetRes) - const [loading, setLoading] = useState(false) - const [name, setName] = useState(currentDataset?.name ?? '') - const [iconInfo, setIconInfo] = useState(currentDataset?.icon_info || DEFAULT_APP_ICON) - const [showAppIconPicker, setShowAppIconPicker] = useState(false) - const [description, setDescription] = useState(currentDataset?.description ?? '') - const [permission, setPermission] = useState(currentDataset?.permission) - const [topK, setTopK] = useState(currentDataset?.external_retrieval_model.top_k ?? 2) - const [scoreThreshold, setScoreThreshold] = useState(currentDataset?.external_retrieval_model.score_threshold ?? 0.5) - const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(currentDataset?.external_retrieval_model.score_threshold_enabled ?? false) - const [selectedMemberIDs, setSelectedMemberIDs] = useState<string[]>(currentDataset?.partial_member_list || []) - const [memberList, setMemberList] = useState<Member[]>([]) - const [indexMethod, setIndexMethod] = useState(currentDataset?.indexing_technique) - const [keywordNumber, setKeywordNumber] = useState(currentDataset?.keyword_number ?? 10) - const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict as RetrievalConfig) - const [embeddingModel, setEmbeddingModel] = useState<DefaultModel>( - currentDataset?.embedding_model - ? { - provider: currentDataset.embedding_model_provider, - model: currentDataset.embedding_model, - } - : { - provider: '', - model: '', - }, - ) - const [summaryIndexSetting, setSummaryIndexSetting] = useState(currentDataset?.summary_index_setting) - const handleSummaryIndexSettingChange = useCallback((payload: SummaryIndexSettingType) => { - setSummaryIndexSetting((prev) => { - return { ...prev, ...payload } - }) - }, []) - const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) - const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) - const { data: membersData } = useMembers() - const previousAppIcon = useRef(DEFAULT_APP_ICON) + const { + // Context values + currentDataset, + isCurrentWorkspaceDatasetOperator, - const handleOpenAppIconPicker = useCallback(() => { - setShowAppIconPicker(true) - previousAppIcon.current = iconInfo - }, [iconInfo]) + // Loading state + loading, - const handleSelectAppIcon = useCallback((icon: AppIconSelection) => { - const iconInfo: IconInfo = { - icon_type: icon.type, - icon: icon.type === 'emoji' ? icon.icon : icon.fileId, - icon_background: icon.type === 'emoji' ? icon.background : undefined, - icon_url: icon.type === 'emoji' ? undefined : icon.url, - } - setIconInfo(iconInfo) - setShowAppIconPicker(false) - }, []) + // Basic form + name, + setName, + description, + setDescription, - const handleCloseAppIconPicker = useCallback(() => { - setIconInfo(previousAppIcon.current) - setShowAppIconPicker(false) - }, []) + // Icon + iconInfo, + showAppIconPicker, + handleOpenAppIconPicker, + handleSelectAppIcon, + handleCloseAppIconPicker, - const handleSettingsChange = useCallback((data: { top_k?: number, score_threshold?: number, score_threshold_enabled?: boolean }) => { - if (data.top_k !== undefined) - setTopK(data.top_k) - if (data.score_threshold !== undefined) - setScoreThreshold(data.score_threshold) - if (data.score_threshold_enabled !== undefined) - setScoreThresholdEnabled(data.score_threshold_enabled) - }, []) + // Permission + permission, + setPermission, + selectedMemberIDs, + setSelectedMemberIDs, + memberList, - useEffect(() => { - if (!membersData?.accounts) - setMemberList([]) - else - setMemberList(membersData.accounts) - }, [membersData]) + // External retrieval + topK, + scoreThreshold, + scoreThresholdEnabled, + handleSettingsChange, - const invalidDatasetList = useInvalidDatasetList() - const handleSave = async () => { - if (loading) - return - if (!name?.trim()) { - Toast.notify({ type: 'error', message: t('form.nameError', { ns: 'datasetSettings' }) }) - return - } - if ( - !isReRankModelSelected({ - rerankModelList, - retrievalConfig, - indexMethod, - }) - ) { - Toast.notify({ type: 'error', message: t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) }) - return - } - if (retrievalConfig.weights) { - retrievalConfig.weights.vector_setting.embedding_provider_name = embeddingModel.provider || '' - retrievalConfig.weights.vector_setting.embedding_model_name = embeddingModel.model || '' - } - try { - setLoading(true) - const requestParams = { - datasetId: currentDataset!.id, - body: { - name, - icon_info: iconInfo, - doc_form: currentDataset?.doc_form, - description, - permission, - indexing_technique: indexMethod, - retrieval_model: { - ...retrievalConfig, - score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0, - }, - embedding_model: embeddingModel.model, - embedding_model_provider: embeddingModel.provider, - ...(currentDataset!.provider === 'external' && { - external_knowledge_id: currentDataset!.external_knowledge_info.external_knowledge_id, - external_knowledge_api_id: currentDataset!.external_knowledge_info.external_knowledge_api_id, - external_retrieval_model: { - top_k: topK, - score_threshold: scoreThreshold, - score_threshold_enabled: scoreThresholdEnabled, - }, - }), - keyword_number: keywordNumber, - summary_index_setting: summaryIndexSetting, - }, - } as any - if (permission === DatasetPermission.partialMembers) { - requestParams.body.partial_member_list = selectedMemberIDs.map((id) => { - return { - user_id: id, - role: memberList.find(member => member.id === id)?.role, - } - }) - } - await updateDatasetSetting(requestParams) - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) - if (mutateDatasets) { - await mutateDatasets() - invalidDatasetList() - } - } - catch { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) - } - finally { - setLoading(false) - } - } + // Indexing and retrieval + indexMethod, + setIndexMethod, + keywordNumber, + setKeywordNumber, + retrievalConfig, + setRetrievalConfig, + embeddingModel, + setEmbeddingModel, + embeddingModelList, - const isShowIndexMethod = currentDataset && currentDataset.doc_form !== ChunkingMode.parentChild && currentDataset.indexing_technique && indexMethod + // Summary index + summaryIndexSetting, + handleSummaryIndexSettingChange, - const showMultiModalTip = useMemo(() => { - return checkShowMultiModalTip({ - embeddingModel, - rerankingEnable: retrievalConfig.reranking_enable, - rerankModel: { - rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name, - rerankingModelName: retrievalConfig.reranking_model.reranking_model_name, - }, - indexMethod, - embeddingModelList, - rerankModelList, - }) - }, [embeddingModel, rerankModelList, retrievalConfig.reranking_enable, retrievalConfig.reranking_model, embeddingModelList, indexMethod]) + // Computed + showMultiModalTip, + + // Actions + handleSave, + } = useFormState() + + const isExternalProvider = currentDataset?.provider === 'external' return ( <div className="flex w-full flex-col gap-y-4 px-20 py-8 sm:w-[960px]"> - {/* Dataset name and icon */} - <div className={rowClass}> - <div className={labelClass}> - <div className="system-sm-semibold text-text-secondary">{t('form.nameAndIcon', { ns: 'datasetSettings' })}</div> - </div> - <div className="flex grow items-center gap-x-2"> - <AppIcon - size="small" - onClick={handleOpenAppIconPicker} - className="cursor-pointer" - iconType={iconInfo.icon_type as AppIconType} - icon={iconInfo.icon} - background={iconInfo.icon_background} - imageUrl={iconInfo.icon_url} - showEditIcon - /> - <Input - disabled={!currentDataset?.embedding_available} - value={name} - onChange={e => setName(e.target.value)} - /> - </div> - </div> - {/* Dataset description */} - <div className={rowClass}> - <div className={labelClass}> - <div className="system-sm-semibold text-text-secondary">{t('form.desc', { ns: 'datasetSettings' })}</div> - </div> - <div className="grow"> - <Textarea - disabled={!currentDataset?.embedding_available} - className="resize-none" - placeholder={t('form.descPlaceholder', { ns: 'datasetSettings' }) || ''} - value={description} - onChange={e => setDescription(e.target.value)} - /> - </div> - </div> - {/* Permissions */} - <div className={rowClass}> - <div className={labelClass}> - <div className="system-sm-semibold text-text-secondary">{t('form.permissions', { ns: 'datasetSettings' })}</div> - </div> - <div className="grow"> - <PermissionSelector - disabled={!currentDataset?.embedding_available || isCurrentWorkspaceDatasetOperator} - permission={permission} - value={selectedMemberIDs} - onChange={v => setPermission(v)} - onMemberSelect={setSelectedMemberIDs} - memberList={memberList} - /> - </div> - </div> - { - !!currentDataset?.doc_form && ( - <> - <Divider - type="horizontal" - className="my-1 h-px bg-divider-subtle" - /> - {/* Chunk Structure */} - <div className={rowClass}> - <div className="flex w-[180px] shrink-0 flex-col"> - <div className="system-sm-semibold flex h-8 items-center text-text-secondary"> - {t('form.chunkStructure.title', { ns: 'datasetSettings' })} - </div> - <div className="body-xs-regular text-text-tertiary"> - <a - target="_blank" - rel="noopener noreferrer" - href={docLink('/use-dify/knowledge/create-knowledge/chunking-and-cleaning-text')} - className="text-text-accent" - > - {t('form.chunkStructure.learnMore', { ns: 'datasetSettings' })} - </a> - {t('form.chunkStructure.description', { ns: 'datasetSettings' })} - </div> - </div> - <div className="grow"> - <ChunkStructure - chunkStructure={currentDataset?.doc_form} - /> - </div> - </div> - </> - ) - } - {!!(isShowIndexMethod || indexMethod === 'high_quality') && ( - <Divider - type="horizontal" - className="my-1 h-px bg-divider-subtle" - /> - )} - {!!isShowIndexMethod && ( - <div className={rowClass}> - <div className={labelClass}> - <div className="system-sm-semibold text-text-secondary">{t('form.indexMethod', { ns: 'datasetSettings' })}</div> - </div> - <div className="grow"> - <IndexMethod - value={indexMethod} - disabled={!currentDataset?.embedding_available} - onChange={v => setIndexMethod(v!)} - currentValue={currentDataset.indexing_technique} - keywordNumber={keywordNumber} - onKeywordNumberChange={setKeywordNumber} - /> - {currentDataset.indexing_technique === IndexingType.ECONOMICAL && indexMethod === IndexingType.QUALIFIED && ( - <div className="relative mt-2 flex h-10 items-center gap-x-0.5 overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur px-2 shadow-xs shadow-shadow-shadow-3"> - <div className="absolute left-0 top-0 flex h-full w-full items-center bg-toast-warning-bg opacity-40" /> - <div className="p-1"> - <RiAlertFill className="size-4 text-text-warning-secondary" /> - </div> - <span className="system-xs-medium text-text-primary"> - {t('form.upgradeHighQualityTip', { ns: 'datasetSettings' })} - </span> - </div> - )} - </div> - </div> - )} - {indexMethod === IndexingType.QUALIFIED && ( - <div className={rowClass}> - <div className={labelClass}> - <div className="system-sm-semibold text-text-secondary"> - {t('form.embeddingModel', { ns: 'datasetSettings' })} - </div> - </div> - <div className="grow"> - <ModelSelector - defaultModel={embeddingModel} - modelList={embeddingModelList} - onSelect={setEmbeddingModel} - /> - </div> - </div> - )} - { - indexMethod === IndexingType.QUALIFIED - && [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode) - && IS_CE_EDITION && ( - <> - <Divider - type="horizontal" - className="my-1 h-px bg-divider-subtle" - /> - <SummaryIndexSetting - entry="dataset-settings" - summaryIndexSetting={summaryIndexSetting} - onSummaryIndexSettingChange={handleSummaryIndexSettingChange} - /> - </> - ) - } - {/* Retrieval Method Config */} - {currentDataset?.provider === 'external' - ? ( - <> - <Divider - type="horizontal" - className="my-1 h-px bg-divider-subtle" - /> - <div className={rowClass}> - <div className={labelClass}> - <div className="system-sm-semibold text-text-secondary">{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}</div> - </div> - <RetrievalSettings - topK={topK} - scoreThreshold={scoreThreshold} - scoreThresholdEnabled={scoreThresholdEnabled} - onChange={handleSettingsChange} - isInRetrievalSetting={true} - /> - </div> - <Divider - type="horizontal" - className="my-1 h-px bg-divider-subtle" - /> - <div className={rowClass}> - <div className={labelClass}> - <div className="system-sm-semibold text-text-secondary">{t('form.externalKnowledgeAPI', { ns: 'datasetSettings' })}</div> - </div> - <div className="w-full"> - <div className="flex h-full items-center gap-1 rounded-lg bg-components-input-bg-normal px-3 py-2"> - <ApiConnectionMod className="h-4 w-4 text-text-secondary" /> - <div className="system-sm-medium overflow-hidden text-ellipsis text-text-secondary"> - {currentDataset?.external_knowledge_info.external_knowledge_api_name} - </div> - <div className="system-xs-regular text-text-tertiary">ยท</div> - <div className="system-xs-regular text-text-tertiary"> - {currentDataset?.external_knowledge_info.external_knowledge_api_endpoint} - </div> - </div> - </div> - </div> - <div className={rowClass}> - <div className={labelClass}> - <div className="system-sm-semibold text-text-secondary">{t('form.externalKnowledgeID', { ns: 'datasetSettings' })}</div> - </div> - <div className="w-full"> - <div className="flex h-full items-center gap-1 rounded-lg bg-components-input-bg-normal px-3 py-2"> - <div className="system-xs-regular text-text-tertiary"> - {currentDataset?.external_knowledge_info.external_knowledge_id} - </div> - </div> - </div> - </div> - </> - ) - - : indexMethod - ? ( - <> - <Divider - type="horizontal" - className="my-1 h-px bg-divider-subtle" - /> - <div className={rowClass}> - <div className={labelClass}> - <div className="flex w-[180px] shrink-0 flex-col"> - <div className="system-sm-semibold flex h-7 items-center pt-1 text-text-secondary"> - {t('form.retrievalSetting.title', { ns: 'datasetSettings' })} - </div> - <div className="body-xs-regular text-text-tertiary"> - <a - target="_blank" - rel="noopener noreferrer" - href={docLink('/use-dify/knowledge/create-knowledge/setting-indexing-methods')} - className="text-text-accent" - > - {t('form.retrievalSetting.learnMore', { ns: 'datasetSettings' })} - </a> - {t('form.retrievalSetting.description', { ns: 'datasetSettings' })} - </div> - </div> - </div> - <div className="grow"> - {indexMethod === IndexingType.QUALIFIED - ? ( - <RetrievalMethodConfig - value={retrievalConfig} - onChange={setRetrievalConfig} - showMultiModalTip={showMultiModalTip} - /> - ) - : ( - <EconomicalRetrievalMethodConfig - value={retrievalConfig} - onChange={setRetrievalConfig} - /> - )} - </div> - </div> - </> - ) - : null} - <Divider - type="horizontal" - className="my-1 h-px bg-divider-subtle" + <BasicInfoSection + currentDataset={currentDataset} + isCurrentWorkspaceDatasetOperator={isCurrentWorkspaceDatasetOperator} + name={name} + setName={setName} + description={description} + setDescription={setDescription} + iconInfo={iconInfo} + showAppIconPicker={showAppIconPicker} + handleOpenAppIconPicker={handleOpenAppIconPicker} + handleSelectAppIcon={handleSelectAppIcon} + handleCloseAppIconPicker={handleCloseAppIconPicker} + permission={permission} + setPermission={setPermission} + selectedMemberIDs={selectedMemberIDs} + setSelectedMemberIDs={setSelectedMemberIDs} + memberList={memberList} /> - <div className={rowClass}> - <div className={labelClass} /> + + {isExternalProvider + ? ( + <ExternalKnowledgeSection + currentDataset={currentDataset} + topK={topK} + scoreThreshold={scoreThreshold} + scoreThresholdEnabled={scoreThresholdEnabled} + handleSettingsChange={handleSettingsChange} + /> + ) + : ( + <IndexingSection + currentDataset={currentDataset} + indexMethod={indexMethod} + setIndexMethod={setIndexMethod} + keywordNumber={keywordNumber} + setKeywordNumber={setKeywordNumber} + embeddingModel={embeddingModel} + setEmbeddingModel={setEmbeddingModel} + embeddingModelList={embeddingModelList} + retrievalConfig={retrievalConfig} + setRetrievalConfig={setRetrievalConfig} + summaryIndexSetting={summaryIndexSetting} + handleSummaryIndexSettingChange={handleSummaryIndexSettingChange} + showMultiModalTip={showMultiModalTip} + /> + )} + + <Divider type="horizontal" className="my-1 h-px bg-divider-subtle" /> + + {/* Save Button */} + <div className="flex gap-x-1"> + <div className="flex h-7 w-[180px] shrink-0 items-center pt-1" /> <div className="grow"> <Button className="min-w-24" @@ -494,12 +133,6 @@ const Form = () => { </Button> </div> </div> - {showAppIconPicker && ( - <AppIconPicker - onSelect={handleSelectAppIcon} - onClose={handleCloseAppIconPicker} - /> - )} </div> ) } diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts index 0d217f3605..c6e3d261c0 100644 --- a/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts +++ b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts @@ -1,79 +1,49 @@ -import { renderHook } from '@testing-library/react' -import { act } from 'react' +import { act, renderHook, waitFor } from '@testing-library/react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' - -// ============================================================================ -// Import after mocks -// ============================================================================ - import { useDSL } from './use-DSL' -// ============================================================================ -// Mocks -// ============================================================================ +// Mock dependencies +const mockNotify = vi.fn() +vi.mock('@/app/components/base/toast', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +const mockEventEmitter = { emit: vi.fn() } +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ eventEmitter: mockEventEmitter }), +})) + +const mockDoSyncWorkflowDraft = vi.fn() +vi.mock('./use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ doSyncWorkflowDraft: mockDoSyncWorkflowDraft }), +})) + +const mockGetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ getState: mockGetState }), +})) + +const mockExportPipelineConfig = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ mutateAsync: mockExportPipelineConfig }), +})) + +const mockFetchWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), +})) + +const mockDownloadBlob = vi.fn() +vi.mock('@/utils/download', () => ({ + downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args), +})) -// Mock react-i18next vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string) => key, }), })) -// Mock toast context -const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - useToastContext: () => ({ - notify: mockNotify, - }), -})) - -// Mock event emitter context -const mockEmit = vi.fn() -vi.mock('@/context/event-emitter', () => ({ - useEventEmitterContextContext: () => ({ - eventEmitter: { - emit: mockEmit, - }, - }), -})) - -// Mock workflow store -const mockWorkflowStoreGetState = vi.fn() -vi.mock('@/app/components/workflow/store', () => ({ - useWorkflowStore: () => ({ - getState: mockWorkflowStoreGetState, - }), -})) - -// Mock useNodesSyncDraft -const mockDoSyncWorkflowDraft = vi.fn() -vi.mock('./use-nodes-sync-draft', () => ({ - useNodesSyncDraft: () => ({ - doSyncWorkflowDraft: mockDoSyncWorkflowDraft, - }), -})) - -// Mock pipeline service -const mockExportPipelineConfig = vi.fn() -vi.mock('@/service/use-pipeline', () => ({ - useExportPipelineDSL: () => ({ - mutateAsync: mockExportPipelineConfig, - }), -})) - -// Mock download utility -const mockDownloadBlob = vi.fn() -vi.mock('@/utils/download', () => ({ - downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args), -})) - -// Mock workflow service -const mockFetchWorkflowDraft = vi.fn() -vi.mock('@/service/workflow', () => ({ - fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url), -})) - -// Mock workflow constants vi.mock('@/app/components/workflow/constants', () => ({ DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK', })) @@ -87,44 +57,23 @@ describe('useDSL', () => { vi.clearAllMocks() // Default store state - mockWorkflowStoreGetState.mockReturnValue({ + mockGetState.mockReturnValue({ pipelineId: 'test-pipeline-id', knowledgeName: 'Test Knowledge Base', }) mockDoSyncWorkflowDraft.mockResolvedValue(undefined) mockExportPipelineConfig.mockResolvedValue({ data: 'yaml-content' }) - mockFetchWorkflowDraft.mockResolvedValue({ - environment_variables: [], - }) + mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: [] }) }) afterEach(() => { vi.clearAllMocks() }) - describe('hook initialization', () => { - it('should return exportCheck function', () => { - const { result } = renderHook(() => useDSL()) - - expect(result.current.exportCheck).toBeDefined() - expect(typeof result.current.exportCheck).toBe('function') - }) - - it('should return handleExportDSL function', () => { - const { result } = renderHook(() => useDSL()) - - expect(result.current.handleExportDSL).toBeDefined() - expect(typeof result.current.handleExportDSL).toBe('function') - }) - }) - describe('handleExportDSL', () => { - it('should not export when pipelineId is missing', async () => { - mockWorkflowStoreGetState.mockReturnValue({ - pipelineId: undefined, - knowledgeName: 'Test', - }) + it('should return early when pipelineId is not set', async () => { + mockGetState.mockReturnValue({ pipelineId: null, knowledgeName: 'test' }) const { result } = renderHook(() => useDSL()) @@ -133,30 +82,6 @@ describe('useDSL', () => { }) expect(mockDoSyncWorkflowDraft).not.toHaveBeenCalled() - expect(mockExportPipelineConfig).not.toHaveBeenCalled() - }) - - it('should sync workflow draft before export', async () => { - const { result } = renderHook(() => useDSL()) - - await act(async () => { - await result.current.handleExportDSL() - }) - - expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() - }) - - it('should call exportPipelineConfig with correct params', async () => { - const { result } = renderHook(() => useDSL()) - - await act(async () => { - await result.current.handleExportDSL(true) - }) - - expect(mockExportPipelineConfig).toHaveBeenCalledWith({ - pipelineId: 'test-pipeline-id', - include: true, - }) }) it('should create and download file', async () => { @@ -169,7 +94,7 @@ describe('useDSL', () => { expect(mockDownloadBlob).toHaveBeenCalled() }) - it('should use correct file extension for download', async () => { + it('should set correct download filename', async () => { const { result } = renderHook(() => useDSL()) await act(async () => { @@ -197,7 +122,7 @@ describe('useDSL', () => { ) }) - it('should show error notification on export failure', async () => { + it('should handle export error', async () => { mockExportPipelineConfig.mockRejectedValue(new Error('Export failed')) const { result } = renderHook(() => useDSL()) @@ -206,19 +131,33 @@ describe('useDSL', () => { await result.current.handleExportDSL() }) - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'exportFailed', + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'exportFailed', + }) + }) + }) + + it('should pass include parameter', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL(true) + }) + + await waitFor(() => { + expect(mockExportPipelineConfig).toHaveBeenCalledWith({ + pipelineId: 'test-pipeline-id', + include: true, + }) }) }) }) describe('exportCheck', () => { - it('should not check when pipelineId is missing', async () => { - mockWorkflowStoreGetState.mockReturnValue({ - pipelineId: undefined, - knowledgeName: 'Test', - }) + it('should return early when pipelineId is not set', async () => { + mockGetState.mockReturnValue({ pipelineId: null }) const { result } = renderHook(() => useDSL()) @@ -229,22 +168,8 @@ describe('useDSL', () => { expect(mockFetchWorkflowDraft).not.toHaveBeenCalled() }) - it('should fetch workflow draft', async () => { - const { result } = renderHook(() => useDSL()) - - await act(async () => { - await result.current.exportCheck() - }) - - expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') - }) - - it('should directly export when no secret environment variables', async () => { - mockFetchWorkflowDraft.mockResolvedValue({ - environment_variables: [ - { id: '1', value_type: 'string', value: 'test' }, - ], - }) + it('should call handleExportDSL directly when no secret variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: [] }) const { result } = renderHook(() => useDSL()) @@ -252,16 +177,15 @@ describe('useDSL', () => { await result.current.exportCheck() }) - // Should call doSyncWorkflowDraft (which means handleExportDSL was called) - expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + await waitFor(() => { + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) }) - it('should emit DSL_EXPORT_CHECK event when secret variables exist', async () => { - mockFetchWorkflowDraft.mockResolvedValue({ - environment_variables: [ - { id: '1', value_type: 'secret', value: 'secret-value' }, - ], - }) + it('should emit event when secret variables exist', async () => { + const secretVars = [{ value_type: 'secret', name: 'API_KEY' }] + mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: secretVars }) const { result } = renderHook(() => useDSL()) @@ -269,15 +193,17 @@ describe('useDSL', () => { await result.current.exportCheck() }) - expect(mockEmit).toHaveBeenCalledWith({ - type: 'DSL_EXPORT_CHECK', - payload: { - data: [{ id: '1', value_type: 'secret', value: 'secret-value' }], - }, + await waitFor(() => { + expect(mockEventEmitter.emit).toHaveBeenCalledWith({ + type: expect.any(String), + payload: { + data: secretVars, + }, + }) }) }) - it('should show error notification on check failure', async () => { + it('should handle export check error', async () => { mockFetchWorkflowDraft.mockRejectedValue(new Error('Fetch failed')) const { result } = renderHook(() => useDSL()) @@ -286,68 +212,12 @@ describe('useDSL', () => { await result.current.exportCheck() }) - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'exportFailed', + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'exportFailed', + }) }) }) - - it('should filter only secret environment variables', async () => { - mockFetchWorkflowDraft.mockResolvedValue({ - environment_variables: [ - { id: '1', value_type: 'string', value: 'plain' }, - { id: '2', value_type: 'secret', value: 'secret1' }, - { id: '3', value_type: 'number', value: '123' }, - { id: '4', value_type: 'secret', value: 'secret2' }, - ], - }) - - const { result } = renderHook(() => useDSL()) - - await act(async () => { - await result.current.exportCheck() - }) - - expect(mockEmit).toHaveBeenCalledWith({ - type: 'DSL_EXPORT_CHECK', - payload: { - data: [ - { id: '2', value_type: 'secret', value: 'secret1' }, - { id: '4', value_type: 'secret', value: 'secret2' }, - ], - }, - }) - }) - - it('should handle empty environment variables', async () => { - mockFetchWorkflowDraft.mockResolvedValue({ - environment_variables: [], - }) - - const { result } = renderHook(() => useDSL()) - - await act(async () => { - await result.current.exportCheck() - }) - - // Should directly call handleExportDSL since no secrets - expect(mockEmit).not.toHaveBeenCalled() - expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() - }) - - it('should handle undefined environment variables', async () => { - mockFetchWorkflowDraft.mockResolvedValue({ - environment_variables: undefined, - }) - - const { result } = renderHook(() => useDSL()) - - await act(async () => { - await result.current.exportCheck() - }) - - // Should directly call handleExportDSL since no secrets - expect(mockEmit).not.toHaveBeenCalled() - }) }) }) diff --git a/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx b/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx index 204772a3e2..97fc03175d 100644 --- a/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx @@ -168,6 +168,7 @@ describe('EditCustomCollectionModal', () => { const schemaInput = screen.getByPlaceholderText('tools.createTool.schemaPlaceHolder') fireEvent.change(schemaInput, { target: { value: '{}' } }) + // Wait for parseParamsSchema to be called and state to be updated await waitFor(() => { expect(parseParamsSchemaMock).toHaveBeenCalledWith('{}') }) @@ -184,13 +185,13 @@ describe('EditCustomCollectionModal', () => { provider: 'provider', schema: '{}', schema_type: 'openapi', - credentials: { - auth_type: 'none', - }, icon: { content: '๐Ÿ•ต๏ธ', background: '#FEF7C3', }, + credentials: { + auth_type: 'none', + }, labels: [], })) expect(toastNotifySpy).not.toHaveBeenCalled() diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx index 525946bb1c..19f5e8b346 100644 --- a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx @@ -11,7 +11,12 @@ vi.mock('@/app/components/base/modal', () => ({ onClose, children, closable, - }: any) { + }: { + isShow: boolean + onClose: () => void + children: React.ReactNode + closable?: boolean + }) { if (!isShow) return null @@ -39,7 +44,10 @@ vi.mock('./start-node-selection-panel', () => ({ default: function MockStartNodeSelectionPanel({ onSelectUserInput, onSelectTrigger, - }: any) { + }: { + onSelectUserInput: () => void + onSelectTrigger: (type: BlockEnum, config?: Record<string, unknown>) => void + }) { return ( <div data-testid="start-node-selection-panel"> <button data-testid="select-user-input" onClick={onSelectUserInput}> diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 9cfe1fd462..79742805df 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -1707,11 +1707,6 @@ "count": 4 } }, - "app/components/datasets/common/image-uploader/utils.ts": { - "ts/no-explicit-any": { - "count": 2 - } - }, "app/components/datasets/common/retrieval-method-config/index.spec.tsx": { "ts/no-explicit-any": { "count": 1 @@ -1722,21 +1717,11 @@ "count": 1 } }, - "app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "app/components/datasets/create/file-preview/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 } }, - "app/components/datasets/create/file-uploader/index.tsx": { - "ts/no-explicit-any": { - "count": 3 - } - }, "app/components/datasets/create/index.spec.tsx": { "ts/no-explicit-any": { "count": 16 @@ -2044,14 +2029,6 @@ "count": 1 } }, - "app/components/datasets/settings/form/index.tsx": { - "react-hooks-extra/no-direct-set-state-in-use-effect": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, "app/components/datasets/settings/permission-selector/index.tsx": { "react/no-missing-key": { "count": 1 @@ -2841,11 +2818,6 @@ "count": 2 } }, - "app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx": { - "ts/no-explicit-any": { - "count": 2 - } - }, "app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.spec.tsx": { "ts/no-explicit-any": { "count": 1 diff --git a/web/utils/format.ts b/web/utils/format.ts index d6968e0ef1..1146d1bfcd 100644 --- a/web/utils/format.ts +++ b/web/utils/format.ts @@ -148,3 +148,23 @@ export const formatNumberAbbreviated = (num: number) => { export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => { return time.locale(localeMap[local] ?? 'en').format(format) } + +/** + * Get file extension from file name. + * @param fileName file name + * @example getFileExtension('document.pdf') will return 'pdf' + * @example getFileExtension('archive.tar.gz') will return 'gz' + * @example getFileExtension('.gitignore') will return '' (hidden file with no extension) + * @example getFileExtension('.hidden.txt') will return 'txt' + */ +export const getFileExtension = (fileName: string): string => { + if (!fileName) + return '' + + // Handle hidden files (starting with dot) by finding dot after the first character + const dotIndex = fileName.indexOf('.', fileName.startsWith('.') ? 1 : 0) + if (dotIndex === -1 || dotIndex === fileName.length - 1) + return '' + + return fileName.slice(dotIndex + 1).split('.').pop()?.toLowerCase() ?? '' +} From ad3be1e4d02de23d9722b0cb1e6d0c3954a6caff Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Wed, 4 Feb 2026 18:12:30 +0800 Subject: [PATCH 10/18] fix: include locale in appList query key for localization support inuseExploreAppList (#31921) Co-authored-by: CodingOnStar <hanxujiang@dify.com> --- web/service/use-explore.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/service/use-explore.ts b/web/service/use-explore.ts index 3e3b9ff255..a2c278f2b2 100644 --- a/web/service/use-explore.ts +++ b/web/service/use-explore.ts @@ -1,6 +1,7 @@ import type { App, AppCategory } from '@/models/explore' import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' import { useGlobalPublicStore } from '@/context/global-public-context' +import { useLocale } from '@/context/i18n' import { AccessMode } from '@/models/access-control' import { fetchAppList, fetchBanners, fetchInstalledAppList, getAppAccessModeByAppId, uninstallApp, updatePinStatus } from './explore' import { AppSourceType, fetchAppMeta, fetchAppParams } from './share' @@ -13,8 +14,9 @@ type ExploreAppListData = { } export const useExploreAppList = () => { + const locale = useLocale() return useQuery<ExploreAppListData>({ - queryKey: [NAME_SPACE, 'appList'], + queryKey: [NAME_SPACE, 'appList', locale], queryFn: async () => { const { categories, recommended_apps } = await fetchAppList() return { From bba2040a05af18ec293be7999f977e8f414b80e4 Mon Sep 17 00:00:00 2001 From: -LAN- <laipz8200@outlook.com> Date: Wed, 4 Feb 2026 18:22:14 +0800 Subject: [PATCH 11/18] chore: assign code owners for test directories (#31940) --- .github/CODEOWNERS | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 36fa39b5d7..6cd99d551a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -24,6 +24,10 @@ /api/services/tools/mcp_tools_manage_service.py @Nov1c444 /api/controllers/mcp/ @Nov1c444 /api/controllers/console/app/mcp_server.py @Nov1c444 + +# Backend - Tests +/api/tests/ @laipz8200 @QuantumGhost + /api/tests/**/*mcp* @Nov1c444 # Backend - Workflow - Engine (Core graph execution engine) @@ -234,6 +238,9 @@ # Frontend - Base Components /web/app/components/base/ @iamjoel @zxhlyh +# Frontend - Base Components Tests +/web/app/components/base/**/__tests__/ @hyoban @CodingOnStar + # Frontend - Utils and Hooks /web/utils/classnames.ts @iamjoel @zxhlyh /web/utils/time.ts @iamjoel @zxhlyh From 1e344f773be23ed89fa34171bbe4a317df1b9f75 Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Wed, 4 Feb 2026 18:35:31 +0800 Subject: [PATCH 12/18] refactor(web): extract complex components into modular structure with comprehensive tests (#31729) Co-authored-by: CodingOnStar <hanxujiang@dify.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> --- .../app/create-app-modal/index.spec.tsx | 15 +- .../components/document-source-icon.spec.tsx | 262 +++++++++ .../components/document-source-icon.tsx | 100 ++++ .../components/document-table-row.spec.tsx | 342 ++++++++++++ .../components/document-table-row.tsx | 152 ++++++ .../document-list/components/index.ts | 4 + .../components/sort-header.spec.tsx | 124 +++++ .../document-list/components/sort-header.tsx | 44 ++ .../document-list/components/utils.spec.tsx | 90 ++++ .../document-list/components/utils.tsx | 16 + .../components/document-list/hooks/index.ts | 4 + .../hooks/use-document-actions.spec.tsx | 438 ++++++++++++++++ .../hooks/use-document-actions.ts | 126 +++++ .../hooks/use-document-selection.spec.ts | 317 +++++++++++ .../hooks/use-document-selection.ts | 66 +++ .../hooks/use-document-sort.spec.ts | 340 ++++++++++++ .../document-list/hooks/use-document-sort.ts | 102 ++++ .../components/document-list/index.spec.tsx | 487 +++++++++++++++++ .../components/document-list/index.tsx | 3 + .../datasets/documents/components/list.tsx | 496 ++++-------------- .../detail/embedding/components/index.ts | 4 + .../components/progress-bar.spec.tsx | 159 ++++++ .../embedding/components/progress-bar.tsx | 44 ++ .../embedding/components/rule-detail.spec.tsx | 203 +++++++ .../embedding/components/rule-detail.tsx | 128 +++++ .../components/segment-progress.spec.tsx | 81 +++ .../embedding/components/segment-progress.tsx | 32 ++ .../components/status-header.spec.tsx | 155 ++++++ .../embedding/components/status-header.tsx | 84 +++ .../documents/detail/embedding/hooks/index.ts | 10 + .../hooks/use-embedding-status.spec.tsx | 462 ++++++++++++++++ .../embedding/hooks/use-embedding-status.ts | 149 ++++++ .../documents/detail/embedding/index.spec.tsx | 337 ++++++++++++ .../documents/detail/embedding/index.tsx | 351 +++---------- .../components/dataset-card-header.spec.tsx | 7 + .../components/dataset-card-modals.spec.tsx | 30 +- .../components/goto-anything/index.spec.tsx | 4 + .../components/panel/index.spec.tsx | 163 +++--- .../components/update-dsl-modal.spec.tsx | 51 +- .../rag-pipeline/hooks/use-DSL.spec.ts | 36 ++ .../workflow-onboarding-modal/index.spec.tsx | 14 +- web/eslint-suppressions.json | 11 - 42 files changed, 5234 insertions(+), 809 deletions(-) create mode 100644 web/app/components/datasets/documents/components/document-list/components/document-source-icon.spec.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/components/document-table-row.spec.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/components/index.ts create mode 100644 web/app/components/datasets/documents/components/document-list/components/sort-header.spec.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/components/sort-header.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/components/utils.spec.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/components/utils.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/hooks/index.ts create mode 100644 web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.spec.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts create mode 100644 web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.spec.ts create mode 100644 web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.ts create mode 100644 web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.spec.ts create mode 100644 web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts create mode 100644 web/app/components/datasets/documents/components/document-list/index.spec.tsx create mode 100644 web/app/components/datasets/documents/components/document-list/index.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/components/index.ts create mode 100644 web/app/components/datasets/documents/detail/embedding/components/progress-bar.spec.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/components/progress-bar.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/components/rule-detail.spec.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/components/rule-detail.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/components/segment-progress.spec.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/components/segment-progress.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/components/status-header.spec.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/components/status-header.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/hooks/index.ts create mode 100644 web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.spec.tsx create mode 100644 web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.ts create mode 100644 web/app/components/datasets/documents/detail/embedding/index.spec.tsx diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index d26a581fda..8c368df62c 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -1,3 +1,4 @@ +import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { useRouter } from 'next/navigation' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' @@ -13,8 +14,8 @@ import { getRedirection } from '@/utils/app-redirection' import CreateAppModal from './index' vi.mock('ahooks', () => ({ - useDebounceFn: (fn: (...args: any[]) => any) => { - const run = (...args: any[]) => fn(...args) + useDebounceFn: <T extends (...args: unknown[]) => unknown>(fn: T) => { + const run = (...args: Parameters<T>) => fn(...args) const cancel = vi.fn() const flush = vi.fn() return { run, cancel, flush } @@ -83,7 +84,7 @@ describe('CreateAppModal', () => { beforeEach(() => { vi.clearAllMocks() - mockUseRouter.mockReturnValue({ push: mockPush } as any) + mockUseRouter.mockReturnValue({ push: mockPush } as unknown as ReturnType<typeof useRouter>) mockUseProviderContext.mockReturnValue({ plan: { type: AppModeEnum.ADVANCED_CHAT, @@ -92,10 +93,10 @@ describe('CreateAppModal', () => { reset: {}, }, enableBilling: true, - } as any) + } as unknown as ReturnType<typeof useProviderContext>) mockUseAppContext.mockReturnValue({ isCurrentWorkspaceEditor: true, - } as any) + } as unknown as ReturnType<typeof useAppContext>) mockSetItem.mockClear() Object.defineProperty(window, 'localStorage', { value: { @@ -118,8 +119,8 @@ describe('CreateAppModal', () => { }) it('creates an app, notifies success, and fires callbacks', async () => { - const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } - mockCreateApp.mockResolvedValue(mockApp as any) + const mockApp: Partial<App> = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } + mockCreateApp.mockResolvedValue(mockApp as App) const { onClose, onSuccess } = renderModal() const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') diff --git a/web/app/components/datasets/documents/components/document-list/components/document-source-icon.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.spec.tsx new file mode 100644 index 0000000000..33108fbbac --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.spec.tsx @@ -0,0 +1,262 @@ +import type { SimpleDocumentDetail } from '@/models/datasets' +import { render } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import { DataSourceType } from '@/models/datasets' +import { DatasourceType } from '@/models/pipeline' +import DocumentSourceIcon from './document-source-icon' + +const createMockDoc = (overrides: Record<string, unknown> = {}): SimpleDocumentDetail => ({ + id: 'doc-1', + position: 1, + data_source_type: DataSourceType.FILE, + data_source_info: {}, + data_source_detail_dict: {}, + dataset_process_rule_id: 'rule-1', + dataset_id: 'dataset-1', + batch: 'batch-1', + name: 'test-document.txt', + created_from: 'web', + created_by: 'user-1', + created_at: Date.now(), + tokens: 100, + indexing_status: 'completed', + error: null, + enabled: true, + disabled_at: null, + disabled_by: null, + archived: false, + archived_reason: null, + archived_by: null, + archived_at: null, + updated_at: Date.now(), + doc_type: null, + doc_metadata: undefined, + doc_language: 'en', + display_status: 'available', + word_count: 100, + hit_count: 10, + doc_form: 'text_model', + ...overrides, +}) as unknown as SimpleDocumentDetail + +describe('DocumentSourceIcon', () => { + describe('Rendering', () => { + it('should render without crashing', () => { + const doc = createMockDoc() + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + }) + + describe('Local File Icon', () => { + it('should render FileTypeIcon for FILE data source type', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.FILE, + data_source_info: { + upload_file: { extension: 'pdf' }, + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} fileType="pdf" />) + const icon = container.querySelector('svg, img') + expect(icon).toBeInTheDocument() + }) + + it('should render FileTypeIcon for localFile data source type', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.localFile, + created_from: 'rag-pipeline', + data_source_info: { + extension: 'docx', + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + const icon = container.querySelector('svg, img') + expect(icon).toBeInTheDocument() + }) + + it('should use extension from upload_file for legacy data source', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.FILE, + created_from: 'web', + data_source_info: { + upload_file: { extension: 'txt' }, + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should use fileType prop as fallback for extension', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.FILE, + created_from: 'web', + data_source_info: {}, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} fileType="csv" />) + expect(container.firstChild).toBeInTheDocument() + }) + }) + + describe('Notion Icon', () => { + it('should render NotionIcon for NOTION data source type', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.NOTION, + created_from: 'web', + data_source_info: { + notion_page_icon: 'https://notion.so/icon.png', + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should render NotionIcon for onlineDocument data source type', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDocument, + created_from: 'rag-pipeline', + data_source_info: { + page: { page_icon: 'https://notion.so/icon.png' }, + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should use page_icon for rag-pipeline created documents', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.NOTION, + created_from: 'rag-pipeline', + data_source_info: { + page: { page_icon: 'https://notion.so/custom-icon.png' }, + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + }) + + describe('Web Crawl Icon', () => { + it('should render globe icon for WEB data source type', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.WEB, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + const icon = container.querySelector('svg') + expect(icon).toBeInTheDocument() + expect(icon).toHaveClass('mr-1.5') + expect(icon).toHaveClass('size-4') + }) + + it('should render globe icon for websiteCrawl data source type', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.websiteCrawl, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + const icon = container.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + }) + + describe('Online Drive Icon', () => { + it('should render FileTypeIcon for onlineDrive data source type', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: 'document.xlsx', + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should extract extension from file name', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: 'spreadsheet.xlsx', + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should handle file name without extension', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: 'noextension', + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should handle empty file name', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: '', + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should handle hidden files (starting with dot)', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: '.gitignore', + }, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + }) + + describe('Unknown Data Source Type', () => { + it('should return null for unknown data source type', () => { + const doc = createMockDoc({ + data_source_type: 'unknown', + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeNull() + }) + }) + + describe('Edge Cases', () => { + it('should handle undefined data_source_info', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.FILE, + data_source_info: undefined, + }) + + const { container } = render(<DocumentSourceIcon doc={doc} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should memoize the component', () => { + const doc = createMockDoc() + const { rerender, container } = render(<DocumentSourceIcon doc={doc} />) + + const firstRender = container.innerHTML + rerender(<DocumentSourceIcon doc={doc} />) + expect(container.innerHTML).toBe(firstRender) + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx new file mode 100644 index 0000000000..5461f34921 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx @@ -0,0 +1,100 @@ +import type { FC } from 'react' +import type { LegacyDataSourceInfo, LocalFileInfo, OnlineDocumentInfo, OnlineDriveInfo, SimpleDocumentDetail } from '@/models/datasets' +import { RiGlobalLine } from '@remixicon/react' +import * as React from 'react' +import FileTypeIcon from '@/app/components/base/file-uploader/file-type-icon' +import NotionIcon from '@/app/components/base/notion-icon' +import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type' +import { DataSourceType } from '@/models/datasets' +import { DatasourceType } from '@/models/pipeline' + +type DocumentSourceIconProps = { + doc: SimpleDocumentDetail + fileType?: string +} + +const isLocalFile = (dataSourceType: DataSourceType | DatasourceType) => { + return dataSourceType === DatasourceType.localFile || dataSourceType === DataSourceType.FILE +} + +const isOnlineDocument = (dataSourceType: DataSourceType | DatasourceType) => { + return dataSourceType === DatasourceType.onlineDocument || dataSourceType === DataSourceType.NOTION +} + +const isWebsiteCrawl = (dataSourceType: DataSourceType | DatasourceType) => { + return dataSourceType === DatasourceType.websiteCrawl || dataSourceType === DataSourceType.WEB +} + +const isOnlineDrive = (dataSourceType: DataSourceType | DatasourceType) => { + return dataSourceType === DatasourceType.onlineDrive +} + +const isCreateFromRAGPipeline = (createdFrom: string) => { + return createdFrom === 'rag-pipeline' +} + +const getFileExtension = (fileName: string): string => { + if (!fileName) + return '' + const parts = fileName.split('.') + if (parts.length <= 1 || (parts[0] === '' && parts.length === 2)) + return '' + return parts[parts.length - 1].toLowerCase() +} + +const DocumentSourceIcon: FC<DocumentSourceIconProps> = React.memo(({ + doc, + fileType, +}) => { + if (isOnlineDocument(doc.data_source_type)) { + return ( + <NotionIcon + className="mr-1.5" + type="page" + src={ + isCreateFromRAGPipeline(doc.created_from) + ? (doc.data_source_info as OnlineDocumentInfo).page.page_icon + : (doc.data_source_info as LegacyDataSourceInfo).notion_page_icon + } + /> + ) + } + + if (isLocalFile(doc.data_source_type)) { + return ( + <FileTypeIcon + type={ + extensionToFileType( + isCreateFromRAGPipeline(doc.created_from) + ? (doc?.data_source_info as LocalFileInfo)?.extension + : ((doc?.data_source_info as LegacyDataSourceInfo)?.upload_file?.extension ?? fileType), + ) + } + className="mr-1.5" + /> + ) + } + + if (isOnlineDrive(doc.data_source_type)) { + return ( + <FileTypeIcon + type={ + extensionToFileType( + getFileExtension((doc?.data_source_info as unknown as OnlineDriveInfo)?.name), + ) + } + className="mr-1.5" + /> + ) + } + + if (isWebsiteCrawl(doc.data_source_type)) { + return <RiGlobalLine className="mr-1.5 size-4" /> + } + + return null +}) + +DocumentSourceIcon.displayName = 'DocumentSourceIcon' + +export default DocumentSourceIcon diff --git a/web/app/components/datasets/documents/components/document-list/components/document-table-row.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/document-table-row.spec.tsx new file mode 100644 index 0000000000..7157a9bf4b --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/document-table-row.spec.tsx @@ -0,0 +1,342 @@ +import type { ReactNode } from 'react' +import type { SimpleDocumentDetail } from '@/models/datasets' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { DataSourceType } from '@/models/datasets' +import DocumentTableRow from './document-table-row' + +const mockPush = vi.fn() + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false, gcTime: 0 }, + mutations: { retry: false }, + }, +}) + +const createWrapper = () => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + <QueryClientProvider client={queryClient}> + <table> + <tbody> + {children} + </tbody> + </table> + </QueryClientProvider> + ) +} + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +const createMockDoc = (overrides: Record<string, unknown> = {}): LocalDoc => ({ + id: 'doc-1', + position: 1, + data_source_type: DataSourceType.FILE, + data_source_info: {}, + data_source_detail_dict: { + upload_file: { name: 'test.txt', extension: 'txt' }, + }, + dataset_process_rule_id: 'rule-1', + dataset_id: 'dataset-1', + batch: 'batch-1', + name: 'test-document.txt', + created_from: 'web', + created_by: 'user-1', + created_at: Date.now(), + tokens: 100, + indexing_status: 'completed', + error: null, + enabled: true, + disabled_at: null, + disabled_by: null, + archived: false, + archived_reason: null, + archived_by: null, + archived_at: null, + updated_at: Date.now(), + doc_type: null, + doc_metadata: undefined, + doc_language: 'en', + display_status: 'available', + word_count: 500, + hit_count: 10, + doc_form: 'text_model', + ...overrides, +}) as unknown as LocalDoc + +// Helper to find the custom checkbox div (Checkbox component renders as a div, not a native checkbox) +const findCheckbox = (container: HTMLElement): HTMLElement | null => { + return container.querySelector('[class*="shadow-xs"]') +} + +describe('DocumentTableRow', () => { + const defaultProps = { + doc: createMockDoc(), + index: 0, + datasetId: 'dataset-1', + isSelected: false, + isGeneralMode: true, + isQAMode: false, + embeddingAvailable: true, + selectedIds: [], + onSelectOne: vi.fn(), + onSelectedIdChange: vi.fn(), + onShowRenameModal: vi.fn(), + onUpdate: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() }) + expect(screen.getByText('test-document.txt')).toBeInTheDocument() + }) + + it('should render index number correctly', () => { + render(<DocumentTableRow {...defaultProps} index={5} />, { wrapper: createWrapper() }) + expect(screen.getByText('6')).toBeInTheDocument() + }) + + it('should render document name with tooltip', () => { + render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() }) + expect(screen.getByText('test-document.txt')).toBeInTheDocument() + }) + + it('should render checkbox element', () => { + const { container } = render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() }) + const checkbox = findCheckbox(container) + expect(checkbox).toBeInTheDocument() + }) + }) + + describe('Selection', () => { + it('should show check icon when isSelected is true', () => { + const { container } = render(<DocumentTableRow {...defaultProps} isSelected />, { wrapper: createWrapper() }) + // When selected, the checkbox should have a check icon (RiCheckLine svg) + const checkbox = findCheckbox(container) + expect(checkbox).toBeInTheDocument() + const checkIcon = checkbox?.querySelector('svg') + expect(checkIcon).toBeInTheDocument() + }) + + it('should not show check icon when isSelected is false', () => { + const { container } = render(<DocumentTableRow {...defaultProps} isSelected={false} />, { wrapper: createWrapper() }) + const checkbox = findCheckbox(container) + expect(checkbox).toBeInTheDocument() + // When not selected, there should be no check icon inside the checkbox + const checkIcon = checkbox?.querySelector('svg') + expect(checkIcon).not.toBeInTheDocument() + }) + + it('should call onSelectOne when checkbox is clicked', () => { + const onSelectOne = vi.fn() + const { container } = render(<DocumentTableRow {...defaultProps} onSelectOne={onSelectOne} />, { wrapper: createWrapper() }) + + const checkbox = findCheckbox(container) + if (checkbox) { + fireEvent.click(checkbox) + expect(onSelectOne).toHaveBeenCalledWith('doc-1') + } + }) + + it('should stop propagation when checkbox container is clicked', () => { + const { container } = render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() }) + + // Click the div containing the checkbox (which has stopPropagation) + const checkboxContainer = container.querySelector('td')?.querySelector('div') + if (checkboxContainer) { + fireEvent.click(checkboxContainer) + expect(mockPush).not.toHaveBeenCalled() + } + }) + }) + + describe('Row Navigation', () => { + it('should navigate to document detail on row click', () => { + render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() }) + + const row = screen.getByRole('row') + fireEvent.click(row) + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1') + }) + + it('should navigate with correct datasetId and documentId', () => { + render( + <DocumentTableRow + {...defaultProps} + datasetId="custom-dataset" + doc={createMockDoc({ id: 'custom-doc' })} + />, + { wrapper: createWrapper() }, + ) + + const row = screen.getByRole('row') + fireEvent.click(row) + + expect(mockPush).toHaveBeenCalledWith('/datasets/custom-dataset/documents/custom-doc') + }) + }) + + describe('Word Count Display', () => { + it('should display word count less than 1000 as is', () => { + const doc = createMockDoc({ word_count: 500 }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByText('500')).toBeInTheDocument() + }) + + it('should display word count 1000 or more in k format', () => { + const doc = createMockDoc({ word_count: 1500 }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByText('1.5k')).toBeInTheDocument() + }) + + it('should display 0 with empty style when word_count is 0', () => { + const doc = createMockDoc({ word_count: 0 }) + const { container } = render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + const zeroCells = container.querySelectorAll('.text-text-tertiary') + expect(zeroCells.length).toBeGreaterThan(0) + }) + + it('should handle undefined word_count', () => { + const doc = createMockDoc({ word_count: undefined as unknown as number }) + const { container } = render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(container).toBeInTheDocument() + }) + }) + + describe('Hit Count Display', () => { + it('should display hit count less than 1000 as is', () => { + const doc = createMockDoc({ hit_count: 100 }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByText('100')).toBeInTheDocument() + }) + + it('should display hit count 1000 or more in k format', () => { + const doc = createMockDoc({ hit_count: 2500 }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByText('2.5k')).toBeInTheDocument() + }) + + it('should display 0 with empty style when hit_count is 0', () => { + const doc = createMockDoc({ hit_count: 0 }) + const { container } = render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + const zeroCells = container.querySelectorAll('.text-text-tertiary') + expect(zeroCells.length).toBeGreaterThan(0) + }) + }) + + describe('Chunking Mode', () => { + it('should render ChunkingModeLabel with general mode', () => { + render(<DocumentTableRow {...defaultProps} isGeneralMode isQAMode={false} />, { wrapper: createWrapper() }) + // ChunkingModeLabel should be rendered + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should render ChunkingModeLabel with QA mode', () => { + render(<DocumentTableRow {...defaultProps} isGeneralMode={false} isQAMode />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) + + describe('Summary Status', () => { + it('should render SummaryStatus when summary_index_status is present', () => { + const doc = createMockDoc({ summary_index_status: 'completed' }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should not render SummaryStatus when summary_index_status is absent', () => { + const doc = createMockDoc({ summary_index_status: undefined }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) + + describe('Rename Action', () => { + it('should call onShowRenameModal when rename button is clicked', () => { + const onShowRenameModal = vi.fn() + const { container } = render( + <DocumentTableRow {...defaultProps} onShowRenameModal={onShowRenameModal} />, + { wrapper: createWrapper() }, + ) + + // Find the rename button by finding the RiEditLine icon's parent + const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') + if (renameButtons.length > 0) { + fireEvent.click(renameButtons[0]) + expect(onShowRenameModal).toHaveBeenCalledWith(defaultProps.doc) + expect(mockPush).not.toHaveBeenCalled() + } + }) + }) + + describe('Operations', () => { + it('should pass selectedIds to Operations component', () => { + render(<DocumentTableRow {...defaultProps} selectedIds={['doc-1', 'doc-2']} />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should pass onSelectedIdChange to Operations component', () => { + const onSelectedIdChange = vi.fn() + render(<DocumentTableRow {...defaultProps} onSelectedIdChange={onSelectedIdChange} />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) + + describe('Document Source Icon', () => { + it('should render with FILE data source type', () => { + const doc = createMockDoc({ data_source_type: DataSourceType.FILE }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should render with NOTION data source type', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.NOTION, + data_source_info: { notion_page_icon: 'icon.png' }, + }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should render with WEB data source type', () => { + const doc = createMockDoc({ data_source_type: DataSourceType.WEB }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle document with very long name', () => { + const doc = createMockDoc({ name: `${'a'.repeat(500)}.txt` }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should handle document with special characters in name', () => { + const doc = createMockDoc({ name: '<script>test</script>.txt' }) + render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() }) + expect(screen.getByText('<script>test</script>.txt')).toBeInTheDocument() + }) + + it('should memoize the component', () => { + const wrapper = createWrapper() + const { rerender } = render(<DocumentTableRow {...defaultProps} />, { wrapper }) + + rerender(<DocumentTableRow {...defaultProps} />) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx b/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx new file mode 100644 index 0000000000..731c14e731 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx @@ -0,0 +1,152 @@ +import type { FC } from 'react' +import type { SimpleDocumentDetail } from '@/models/datasets' +import { RiEditLine } from '@remixicon/react' +import { pick } from 'es-toolkit/object' +import { useRouter } from 'next/navigation' +import * as React from 'react' +import { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' +import Tooltip from '@/app/components/base/tooltip' +import ChunkingModeLabel from '@/app/components/datasets/common/chunking-mode-label' +import Operations from '@/app/components/datasets/documents/components/operations' +import SummaryStatus from '@/app/components/datasets/documents/detail/completed/common/summary-status' +import StatusItem from '@/app/components/datasets/documents/status-item' +import useTimestamp from '@/hooks/use-timestamp' +import { DataSourceType } from '@/models/datasets' +import { formatNumber } from '@/utils/format' +import DocumentSourceIcon from './document-source-icon' +import { renderTdValue } from './utils' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +type DocumentTableRowProps = { + doc: LocalDoc + index: number + datasetId: string + isSelected: boolean + isGeneralMode: boolean + isQAMode: boolean + embeddingAvailable: boolean + selectedIds: string[] + onSelectOne: (docId: string) => void + onSelectedIdChange: (ids: string[]) => void + onShowRenameModal: (doc: LocalDoc) => void + onUpdate: () => void +} + +const renderCount = (count: number | undefined) => { + if (!count) + return renderTdValue(0, true) + + if (count < 1000) + return count + + return `${formatNumber((count / 1000).toFixed(1))}k` +} + +const DocumentTableRow: FC<DocumentTableRowProps> = React.memo(({ + doc, + index, + datasetId, + isSelected, + isGeneralMode, + isQAMode, + embeddingAvailable, + selectedIds, + onSelectOne, + onSelectedIdChange, + onShowRenameModal, + onUpdate, +}) => { + const { t } = useTranslation() + const { formatTime } = useTimestamp() + const router = useRouter() + + const isFile = doc.data_source_type === DataSourceType.FILE + const fileType = isFile ? doc.data_source_detail_dict?.upload_file?.extension : '' + + const handleRowClick = useCallback(() => { + router.push(`/datasets/${datasetId}/documents/${doc.id}`) + }, [router, datasetId, doc.id]) + + const handleCheckboxClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + }, []) + + const handleRenameClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + onShowRenameModal(doc) + }, [doc, onShowRenameModal]) + + return ( + <tr + className="h-8 cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover" + onClick={handleRowClick} + > + <td className="text-left align-middle text-xs text-text-tertiary"> + <div className="flex items-center" onClick={handleCheckboxClick}> + <Checkbox + className="mr-2 shrink-0" + checked={isSelected} + onCheck={() => onSelectOne(doc.id)} + /> + {index + 1} + </div> + </td> + <td> + <div className="group mr-6 flex max-w-[460px] items-center hover:mr-0"> + <div className="flex shrink-0 items-center"> + <DocumentSourceIcon doc={doc} fileType={fileType} /> + </div> + <Tooltip popupContent={doc.name}> + <span className="grow-1 truncate text-sm">{doc.name}</span> + </Tooltip> + {doc.summary_index_status && ( + <div className="ml-1 hidden shrink-0 group-hover:flex"> + <SummaryStatus status={doc.summary_index_status} /> + </div> + )} + <div className="hidden shrink-0 group-hover:ml-auto group-hover:flex"> + <Tooltip popupContent={t('list.table.rename', { ns: 'datasetDocuments' })}> + <div + className="cursor-pointer rounded-md p-1 hover:bg-state-base-hover" + onClick={handleRenameClick} + > + <RiEditLine className="h-4 w-4 text-text-tertiary" /> + </div> + </Tooltip> + </div> + </div> + </td> + <td> + <ChunkingModeLabel + isGeneralMode={isGeneralMode} + isQAMode={isQAMode} + /> + </td> + <td>{renderCount(doc.word_count)}</td> + <td>{renderCount(doc.hit_count)}</td> + <td className="text-[13px] text-text-secondary"> + {formatTime(doc.created_at, t('dateTimeFormat', { ns: 'datasetHitTesting' }) as string)} + </td> + <td> + <StatusItem status={doc.display_status} /> + </td> + <td> + <Operations + selectedIds={selectedIds} + onSelectedIdChange={onSelectedIdChange} + embeddingAvailable={embeddingAvailable} + datasetId={datasetId} + detail={pick(doc, ['name', 'enabled', 'archived', 'id', 'data_source_type', 'doc_form', 'display_status'])} + onUpdate={onUpdate} + /> + </td> + </tr> + ) +}) + +DocumentTableRow.displayName = 'DocumentTableRow' + +export default DocumentTableRow diff --git a/web/app/components/datasets/documents/components/document-list/components/index.ts b/web/app/components/datasets/documents/components/document-list/components/index.ts new file mode 100644 index 0000000000..377f64a27f --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/index.ts @@ -0,0 +1,4 @@ +export { default as DocumentSourceIcon } from './document-source-icon' +export { default as DocumentTableRow } from './document-table-row' +export { default as SortHeader } from './sort-header' +export { renderTdValue } from './utils' diff --git a/web/app/components/datasets/documents/components/document-list/components/sort-header.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/sort-header.spec.tsx new file mode 100644 index 0000000000..15cc55247b --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/sort-header.spec.tsx @@ -0,0 +1,124 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import SortHeader from './sort-header' + +describe('SortHeader', () => { + const defaultProps = { + field: 'name' as const, + label: 'File Name', + currentSortField: null, + sortOrder: 'desc' as const, + onSort: vi.fn(), + } + + describe('rendering', () => { + it('should render the label', () => { + render(<SortHeader {...defaultProps} />) + expect(screen.getByText('File Name')).toBeInTheDocument() + }) + + it('should render the sort icon', () => { + const { container } = render(<SortHeader {...defaultProps} />) + const icon = container.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + }) + + describe('inactive state', () => { + it('should have disabled text color when not active', () => { + const { container } = render(<SortHeader {...defaultProps} />) + const icon = container.querySelector('svg') + expect(icon).toHaveClass('text-text-disabled') + }) + + it('should not be rotated when not active', () => { + const { container } = render(<SortHeader {...defaultProps} />) + const icon = container.querySelector('svg') + expect(icon).not.toHaveClass('rotate-180') + }) + }) + + describe('active state', () => { + it('should have tertiary text color when active', () => { + const { container } = render( + <SortHeader {...defaultProps} currentSortField="name" />, + ) + const icon = container.querySelector('svg') + expect(icon).toHaveClass('text-text-tertiary') + }) + + it('should not be rotated when active and desc', () => { + const { container } = render( + <SortHeader {...defaultProps} currentSortField="name" sortOrder="desc" />, + ) + const icon = container.querySelector('svg') + expect(icon).not.toHaveClass('rotate-180') + }) + + it('should be rotated when active and asc', () => { + const { container } = render( + <SortHeader {...defaultProps} currentSortField="name" sortOrder="asc" />, + ) + const icon = container.querySelector('svg') + expect(icon).toHaveClass('rotate-180') + }) + }) + + describe('interaction', () => { + it('should call onSort when clicked', () => { + const onSort = vi.fn() + render(<SortHeader {...defaultProps} onSort={onSort} />) + + fireEvent.click(screen.getByText('File Name')) + + expect(onSort).toHaveBeenCalledWith('name') + }) + + it('should call onSort with correct field', () => { + const onSort = vi.fn() + render(<SortHeader {...defaultProps} field="word_count" onSort={onSort} />) + + fireEvent.click(screen.getByText('File Name')) + + expect(onSort).toHaveBeenCalledWith('word_count') + }) + }) + + describe('different fields', () => { + it('should work with word_count field', () => { + render( + <SortHeader + {...defaultProps} + field="word_count" + label="Words" + currentSortField="word_count" + />, + ) + expect(screen.getByText('Words')).toBeInTheDocument() + }) + + it('should work with hit_count field', () => { + render( + <SortHeader + {...defaultProps} + field="hit_count" + label="Hit Count" + currentSortField="hit_count" + />, + ) + expect(screen.getByText('Hit Count')).toBeInTheDocument() + }) + + it('should work with created_at field', () => { + render( + <SortHeader + {...defaultProps} + field="created_at" + label="Upload Time" + currentSortField="created_at" + />, + ) + expect(screen.getByText('Upload Time')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx b/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx new file mode 100644 index 0000000000..1dc13df2b0 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx @@ -0,0 +1,44 @@ +import type { FC } from 'react' +import type { SortField, SortOrder } from '../hooks' +import { RiArrowDownLine } from '@remixicon/react' +import * as React from 'react' +import { cn } from '@/utils/classnames' + +type SortHeaderProps = { + field: Exclude<SortField, null> + label: string + currentSortField: SortField + sortOrder: SortOrder + onSort: (field: SortField) => void +} + +const SortHeader: FC<SortHeaderProps> = React.memo(({ + field, + label, + currentSortField, + sortOrder, + onSort, +}) => { + const isActive = currentSortField === field + const isDesc = isActive && sortOrder === 'desc' + + return ( + <div + className="flex cursor-pointer items-center hover:text-text-secondary" + onClick={() => onSort(field)} + > + {label} + <RiArrowDownLine + className={cn( + 'ml-0.5 h-3 w-3 transition-all', + isActive ? 'text-text-tertiary' : 'text-text-disabled', + isActive && !isDesc ? 'rotate-180' : '', + )} + /> + </div> + ) +}) + +SortHeader.displayName = 'SortHeader' + +export default SortHeader diff --git a/web/app/components/datasets/documents/components/document-list/components/utils.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/utils.spec.tsx new file mode 100644 index 0000000000..7dc66d4d39 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/utils.spec.tsx @@ -0,0 +1,90 @@ +import { render, screen } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import { renderTdValue } from './utils' + +describe('renderTdValue', () => { + describe('Rendering', () => { + it('should render string value correctly', () => { + const { container } = render(<>{renderTdValue('test value')}</>) + expect(screen.getByText('test value')).toBeInTheDocument() + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + + it('should render number value correctly', () => { + const { container } = render(<>{renderTdValue(42)}</>) + expect(screen.getByText('42')).toBeInTheDocument() + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + + it('should render zero correctly', () => { + const { container } = render(<>{renderTdValue(0)}</>) + expect(screen.getByText('0')).toBeInTheDocument() + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + }) + + describe('Null and undefined handling', () => { + it('should render dash for null value', () => { + render(<>{renderTdValue(null)}</>) + expect(screen.getByText('-')).toBeInTheDocument() + }) + + it('should render dash for null value with empty style', () => { + const { container } = render(<>{renderTdValue(null, true)}</>) + expect(screen.getByText('-')).toBeInTheDocument() + expect(container.querySelector('div')).toHaveClass('text-text-tertiary') + }) + }) + + describe('Empty style', () => { + it('should apply text-text-tertiary class when isEmptyStyle is true', () => { + const { container } = render(<>{renderTdValue('value', true)}</>) + expect(container.querySelector('div')).toHaveClass('text-text-tertiary') + }) + + it('should apply text-text-secondary class when isEmptyStyle is false', () => { + const { container } = render(<>{renderTdValue('value', false)}</>) + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + + it('should apply text-text-secondary class when isEmptyStyle is not provided', () => { + const { container } = render(<>{renderTdValue('value')}</>) + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + }) + + describe('Edge Cases', () => { + it('should handle empty string', () => { + render(<>{renderTdValue('')}</>) + // Empty string should still render but with no visible text + const div = document.querySelector('div') + expect(div).toBeInTheDocument() + }) + + it('should handle large numbers', () => { + render(<>{renderTdValue(1234567890)}</>) + expect(screen.getByText('1234567890')).toBeInTheDocument() + }) + + it('should handle negative numbers', () => { + render(<>{renderTdValue(-42)}</>) + expect(screen.getByText('-42')).toBeInTheDocument() + }) + + it('should handle special characters in string', () => { + render(<>{renderTdValue('<script>alert("xss")</script>')}</>) + expect(screen.getByText('<script>alert("xss")</script>')).toBeInTheDocument() + }) + + it('should handle unicode characters', () => { + render(<>{renderTdValue('Test Unicode: \u4E2D\u6587')}</>) + expect(screen.getByText('Test Unicode: \u4E2D\u6587')).toBeInTheDocument() + }) + + it('should handle very long strings', () => { + const longString = 'a'.repeat(1000) + render(<>{renderTdValue(longString)}</>) + expect(screen.getByText(longString)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/components/utils.tsx b/web/app/components/datasets/documents/components/document-list/components/utils.tsx new file mode 100644 index 0000000000..4cb652108d --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/utils.tsx @@ -0,0 +1,16 @@ +import type { ReactNode } from 'react' +import { cn } from '@/utils/classnames' +import s from '../../../style.module.css' + +export const renderTdValue = (value: string | number | null, isEmptyStyle = false): ReactNode => { + const className = cn( + isEmptyStyle ? 'text-text-tertiary' : 'text-text-secondary', + s.tdValue, + ) + + return ( + <div className={className}> + {value ?? '-'} + </div> + ) +} diff --git a/web/app/components/datasets/documents/components/document-list/hooks/index.ts b/web/app/components/datasets/documents/components/document-list/hooks/index.ts new file mode 100644 index 0000000000..3ca7a920f2 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/index.ts @@ -0,0 +1,4 @@ +export { useDocumentActions } from './use-document-actions' +export { useDocumentSelection } from './use-document-selection' +export { useDocumentSort } from './use-document-sort' +export type { SortField, SortOrder } from './use-document-sort' diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.spec.tsx b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.spec.tsx new file mode 100644 index 0000000000..bc84477744 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.spec.tsx @@ -0,0 +1,438 @@ +import type { ReactNode } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { DocumentActionType } from '@/models/datasets' +import * as useDocument from '@/service/knowledge/use-document' +import { useDocumentActions } from './use-document-actions' + +vi.mock('@/service/knowledge/use-document') + +const mockUseDocumentArchive = vi.mocked(useDocument.useDocumentArchive) +const mockUseDocumentSummary = vi.mocked(useDocument.useDocumentSummary) +const mockUseDocumentEnable = vi.mocked(useDocument.useDocumentEnable) +const mockUseDocumentDisable = vi.mocked(useDocument.useDocumentDisable) +const mockUseDocumentDelete = vi.mocked(useDocument.useDocumentDelete) +const mockUseDocumentBatchRetryIndex = vi.mocked(useDocument.useDocumentBatchRetryIndex) +const mockUseDocumentDownloadZip = vi.mocked(useDocument.useDocumentDownloadZip) + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, +}) + +const createWrapper = () => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + <QueryClientProvider client={queryClient}> + {children} + </QueryClientProvider> + ) +} + +describe('useDocumentActions', () => { + const mockMutateAsync = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + // Setup all mocks with default values + const createMockMutation = () => ({ + mutateAsync: mockMutateAsync, + isPending: false, + isError: false, + isSuccess: false, + isIdle: true, + data: undefined, + error: null, + mutate: vi.fn(), + reset: vi.fn(), + status: 'idle' as const, + variables: undefined, + context: undefined, + failureCount: 0, + failureReason: null, + submittedAt: 0, + }) + + mockUseDocumentArchive.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentArchive>) + mockUseDocumentSummary.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentSummary>) + mockUseDocumentEnable.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentEnable>) + mockUseDocumentDisable.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentDisable>) + mockUseDocumentDelete.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentDelete>) + mockUseDocumentBatchRetryIndex.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentBatchRetryIndex>) + mockUseDocumentDownloadZip.mockReturnValue({ + ...createMockMutation(), + isPending: false, + } as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>) + }) + + describe('handleAction', () => { + it('should call archive mutation when archive action is triggered', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.archive)() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentIds: ['doc1'], + }) + }) + + it('should call onUpdate on successful action', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.enable)() + }) + + await waitFor(() => { + expect(onUpdate).toHaveBeenCalled() + }) + }) + + it('should call onClearSelection on delete action', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.delete)() + }) + + await waitFor(() => { + expect(onClearSelection).toHaveBeenCalled() + }) + }) + }) + + describe('handleBatchReIndex', () => { + it('should call retry index mutation', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1', 'doc2'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchReIndex() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentIds: ['doc1', 'doc2'], + }) + }) + + it('should call onClearSelection on success', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchReIndex() + }) + + await waitFor(() => { + expect(onClearSelection).toHaveBeenCalled() + expect(onUpdate).toHaveBeenCalled() + }) + }) + }) + + describe('handleBatchDownload', () => { + it('should not proceed when already downloading', async () => { + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: true, + } as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: ['doc1'], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchDownload() + }) + + expect(mockMutateAsync).not.toHaveBeenCalled() + }) + + it('should call download mutation with downloadable ids', async () => { + const mockBlob = new Blob(['test']) + mockMutateAsync.mockResolvedValue(mockBlob) + + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: false, + } as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1', 'doc2'], + downloadableSelectedIds: ['doc1'], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchDownload() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentIds: ['doc1'], + }) + }) + }) + + describe('isDownloadingZip', () => { + it('should reflect isPending state from mutation', () => { + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: true, + } as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: [], + downloadableSelectedIds: [], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + expect(result.current.isDownloadingZip).toBe(true) + }) + }) + + describe('error handling', () => { + it('should show error toast when handleAction fails', async () => { + mockMutateAsync.mockRejectedValue(new Error('Action failed')) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.archive)() + }) + + // onUpdate should not be called on error + expect(onUpdate).not.toHaveBeenCalled() + }) + + it('should show error toast when handleBatchReIndex fails', async () => { + mockMutateAsync.mockRejectedValue(new Error('Re-index failed')) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchReIndex() + }) + + // onUpdate and onClearSelection should not be called on error + expect(onUpdate).not.toHaveBeenCalled() + expect(onClearSelection).not.toHaveBeenCalled() + }) + + it('should show error toast when handleBatchDownload fails', async () => { + mockMutateAsync.mockRejectedValue(new Error('Download failed')) + + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: false, + } as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: ['doc1'], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchDownload() + }) + + // Mutation was called but failed + expect(mockMutateAsync).toHaveBeenCalled() + }) + + it('should show error toast when handleBatchDownload returns null blob', async () => { + mockMutateAsync.mockResolvedValue(null) + + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: false, + } as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: ['doc1'], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchDownload() + }) + + // Mutation was called but returned null + expect(mockMutateAsync).toHaveBeenCalled() + }) + }) + + describe('all action types', () => { + it('should handle summary action', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.summary)() + }) + + expect(mockMutateAsync).toHaveBeenCalled() + await waitFor(() => { + expect(onUpdate).toHaveBeenCalled() + }) + }) + + it('should handle disable action', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.disable)() + }) + + expect(mockMutateAsync).toHaveBeenCalled() + await waitFor(() => { + expect(onUpdate).toHaveBeenCalled() + }) + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts new file mode 100644 index 0000000000..56553faa9e --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts @@ -0,0 +1,126 @@ +import type { CommonResponse } from '@/models/common' +import { useCallback, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import { DocumentActionType } from '@/models/datasets' +import { + useDocumentArchive, + useDocumentBatchRetryIndex, + useDocumentDelete, + useDocumentDisable, + useDocumentDownloadZip, + useDocumentEnable, + useDocumentSummary, +} from '@/service/knowledge/use-document' +import { asyncRunSafe } from '@/utils' +import { downloadBlob } from '@/utils/download' + +type UseDocumentActionsOptions = { + datasetId: string + selectedIds: string[] + downloadableSelectedIds: string[] + onUpdate: () => void + onClearSelection: () => void +} + +/** + * Generate a random ZIP filename for bulk document downloads. + * We intentionally avoid leaking dataset info in the exported archive name. + */ +const generateDocsZipFileName = (): string => { + const randomPart = (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') + ? crypto.randomUUID() + : `${Date.now().toString(36)}${Math.random().toString(36).slice(2, 10)}` + return `${randomPart}-docs.zip` +} + +export const useDocumentActions = ({ + datasetId, + selectedIds, + downloadableSelectedIds, + onUpdate, + onClearSelection, +}: UseDocumentActionsOptions) => { + const { t } = useTranslation() + + const { mutateAsync: archiveDocument } = useDocumentArchive() + const { mutateAsync: generateSummary } = useDocumentSummary() + const { mutateAsync: enableDocument } = useDocumentEnable() + const { mutateAsync: disableDocument } = useDocumentDisable() + const { mutateAsync: deleteDocument } = useDocumentDelete() + const { mutateAsync: retryIndexDocument } = useDocumentBatchRetryIndex() + const { mutateAsync: requestDocumentsZip, isPending: isDownloadingZip } = useDocumentDownloadZip() + + type SupportedActionType + = | typeof DocumentActionType.archive + | typeof DocumentActionType.summary + | typeof DocumentActionType.enable + | typeof DocumentActionType.disable + | typeof DocumentActionType.delete + + const actionMutationMap = useMemo(() => ({ + [DocumentActionType.archive]: archiveDocument, + [DocumentActionType.summary]: generateSummary, + [DocumentActionType.enable]: enableDocument, + [DocumentActionType.disable]: disableDocument, + [DocumentActionType.delete]: deleteDocument, + } as const), [archiveDocument, generateSummary, enableDocument, disableDocument, deleteDocument]) + + const handleAction = useCallback((actionName: SupportedActionType) => { + return async () => { + const opApi = actionMutationMap[actionName] + if (!opApi) + return + + const [e] = await asyncRunSafe<CommonResponse>( + opApi({ datasetId, documentIds: selectedIds }), + ) + + if (!e) { + if (actionName === DocumentActionType.delete) + onClearSelection() + Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + onUpdate() + } + else { + Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + } + } + }, [actionMutationMap, datasetId, selectedIds, onClearSelection, onUpdate, t]) + + const handleBatchReIndex = useCallback(async () => { + const [e] = await asyncRunSafe<CommonResponse>( + retryIndexDocument({ datasetId, documentIds: selectedIds }), + ) + if (!e) { + onClearSelection() + Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + onUpdate() + } + else { + Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + } + }, [retryIndexDocument, datasetId, selectedIds, onClearSelection, onUpdate, t]) + + const handleBatchDownload = useCallback(async () => { + if (isDownloadingZip) + return + + const [e, blob] = await asyncRunSafe( + requestDocumentsZip({ datasetId, documentIds: downloadableSelectedIds }), + ) + if (e || !blob) { + Toast.notify({ type: 'error', message: t('actionMsg.downloadUnsuccessfully', { ns: 'common' }) }) + return + } + + downloadBlob({ data: blob, fileName: generateDocsZipFileName() }) + }, [datasetId, downloadableSelectedIds, isDownloadingZip, requestDocumentsZip, t]) + + return { + handleAction, + handleBatchReIndex, + handleBatchDownload, + isDownloadingZip, + } +} diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.spec.ts new file mode 100644 index 0000000000..7775c83f1c --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.spec.ts @@ -0,0 +1,317 @@ +import type { SimpleDocumentDetail } from '@/models/datasets' +import { act, renderHook } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { DataSourceType } from '@/models/datasets' +import { useDocumentSelection } from './use-document-selection' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +const createMockDocument = (overrides: Partial<LocalDoc> = {}): LocalDoc => ({ + id: 'doc1', + name: 'Test Document', + data_source_type: DataSourceType.FILE, + data_source_info: {}, + data_source_detail_dict: {}, + word_count: 100, + hit_count: 10, + created_at: 1000000, + position: 1, + doc_form: 'text_model', + enabled: true, + archived: false, + display_status: 'available', + created_from: 'api', + ...overrides, +} as LocalDoc) + +describe('useDocumentSelection', () => { + describe('isAllSelected', () => { + it('should return false when documents is empty', () => { + const onSelectedIdChange = vi.fn() + const { result } = renderHook(() => + useDocumentSelection({ + documents: [], + selectedIds: [], + onSelectedIdChange, + }), + ) + + expect(result.current.isAllSelected).toBe(false) + }) + + it('should return true when all documents are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + expect(result.current.isAllSelected).toBe(true) + }) + + it('should return false when not all documents are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1'], + onSelectedIdChange, + }), + ) + + expect(result.current.isAllSelected).toBe(false) + }) + }) + + describe('isSomeSelected', () => { + it('should return false when no documents are selected', () => { + const docs = [createMockDocument({ id: 'doc1' })] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: [], + onSelectedIdChange, + }), + ) + + expect(result.current.isSomeSelected).toBe(false) + }) + + it('should return true when some documents are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1'], + onSelectedIdChange, + }), + ) + + expect(result.current.isSomeSelected).toBe(true) + }) + }) + + describe('onSelectAll', () => { + it('should select all documents when none are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: [], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectAll() + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1', 'doc2']) + }) + + it('should deselect all when all are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectAll() + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith([]) + }) + + it('should add to existing selection when some are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + createMockDocument({ id: 'doc3' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1'], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectAll() + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1', 'doc2', 'doc3']) + }) + }) + + describe('onSelectOne', () => { + it('should add document to selection when not selected', () => { + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: [], + selectedIds: [], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectOne('doc1') + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1']) + }) + + it('should remove document from selection when already selected', () => { + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: [], + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectOne('doc1') + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith(['doc2']) + }) + }) + + describe('hasErrorDocumentsSelected', () => { + it('should return false when no error documents are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1', display_status: 'available' }), + createMockDocument({ id: 'doc2', display_status: 'error' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1'], + onSelectedIdChange, + }), + ) + + expect(result.current.hasErrorDocumentsSelected).toBe(false) + }) + + it('should return true when an error document is selected', () => { + const docs = [ + createMockDocument({ id: 'doc1', display_status: 'available' }), + createMockDocument({ id: 'doc2', display_status: 'error' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc2'], + onSelectedIdChange, + }), + ) + + expect(result.current.hasErrorDocumentsSelected).toBe(true) + }) + }) + + describe('downloadableSelectedIds', () => { + it('should return only FILE type documents from selection', () => { + const docs = [ + createMockDocument({ id: 'doc1', data_source_type: DataSourceType.FILE }), + createMockDocument({ id: 'doc2', data_source_type: DataSourceType.NOTION }), + createMockDocument({ id: 'doc3', data_source_type: DataSourceType.FILE }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1', 'doc2', 'doc3'], + onSelectedIdChange, + }), + ) + + expect(result.current.downloadableSelectedIds).toEqual(['doc1', 'doc3']) + }) + + it('should return empty array when no FILE documents selected', () => { + const docs = [ + createMockDocument({ id: 'doc1', data_source_type: DataSourceType.NOTION }), + createMockDocument({ id: 'doc2', data_source_type: DataSourceType.WEB }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + expect(result.current.downloadableSelectedIds).toEqual([]) + }) + }) + + describe('clearSelection', () => { + it('should call onSelectedIdChange with empty array', () => { + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: [], + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.clearSelection() + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith([]) + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.ts new file mode 100644 index 0000000000..ad12b2b00f --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.ts @@ -0,0 +1,66 @@ +import type { SimpleDocumentDetail } from '@/models/datasets' +import { uniq } from 'es-toolkit/array' +import { useCallback, useMemo } from 'react' +import { DataSourceType } from '@/models/datasets' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +type UseDocumentSelectionOptions = { + documents: LocalDoc[] + selectedIds: string[] + onSelectedIdChange: (selectedIds: string[]) => void +} + +export const useDocumentSelection = ({ + documents, + selectedIds, + onSelectedIdChange, +}: UseDocumentSelectionOptions) => { + const isAllSelected = useMemo(() => { + return documents.length > 0 && documents.every(doc => selectedIds.includes(doc.id)) + }, [documents, selectedIds]) + + const isSomeSelected = useMemo(() => { + return documents.some(doc => selectedIds.includes(doc.id)) + }, [documents, selectedIds]) + + const onSelectAll = useCallback(() => { + if (isAllSelected) + onSelectedIdChange([]) + else + onSelectedIdChange(uniq([...selectedIds, ...documents.map(doc => doc.id)])) + }, [isAllSelected, documents, onSelectedIdChange, selectedIds]) + + const onSelectOne = useCallback((docId: string) => { + onSelectedIdChange( + selectedIds.includes(docId) + ? selectedIds.filter(id => id !== docId) + : [...selectedIds, docId], + ) + }, [selectedIds, onSelectedIdChange]) + + const hasErrorDocumentsSelected = useMemo(() => { + return documents.some(doc => selectedIds.includes(doc.id) && doc.display_status === 'error') + }, [documents, selectedIds]) + + const downloadableSelectedIds = useMemo(() => { + const selectedSet = new Set(selectedIds) + return documents + .filter(doc => selectedSet.has(doc.id) && doc.data_source_type === DataSourceType.FILE) + .map(doc => doc.id) + }, [documents, selectedIds]) + + const clearSelection = useCallback(() => { + onSelectedIdChange([]) + }, [onSelectedIdChange]) + + return { + isAllSelected, + isSomeSelected, + onSelectAll, + onSelectOne, + hasErrorDocumentsSelected, + downloadableSelectedIds, + clearSelection, + } +} diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.spec.ts new file mode 100644 index 0000000000..a41b42d6fa --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.spec.ts @@ -0,0 +1,340 @@ +import type { SimpleDocumentDetail } from '@/models/datasets' +import { act, renderHook } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import { useDocumentSort } from './use-document-sort' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +const createMockDocument = (overrides: Partial<LocalDoc> = {}): LocalDoc => ({ + id: 'doc1', + name: 'Test Document', + data_source_type: 'upload_file', + data_source_info: {}, + data_source_detail_dict: {}, + word_count: 100, + hit_count: 10, + created_at: 1000000, + position: 1, + doc_form: 'text_model', + enabled: true, + archived: false, + display_status: 'available', + created_from: 'api', + ...overrides, +} as LocalDoc) + +describe('useDocumentSort', () => { + describe('initial state', () => { + it('should return null sortField initially', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + expect(result.current.sortField).toBeNull() + expect(result.current.sortOrder).toBe('desc') + }) + + it('should return documents unchanged when no sort is applied', () => { + const docs = [ + createMockDocument({ id: 'doc1', name: 'B' }), + createMockDocument({ id: 'doc2', name: 'A' }), + ] + + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + expect(result.current.sortedDocuments).toEqual(docs) + }) + }) + + describe('handleSort', () => { + it('should set sort field when called', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + + expect(result.current.sortField).toBe('name') + expect(result.current.sortOrder).toBe('desc') + }) + + it('should toggle sort order when same field is clicked twice', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortOrder).toBe('desc') + + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortOrder).toBe('asc') + + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortOrder).toBe('desc') + }) + + it('should reset to desc when different field is selected', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortOrder).toBe('asc') + + act(() => { + result.current.handleSort('word_count') + }) + expect(result.current.sortField).toBe('word_count') + expect(result.current.sortOrder).toBe('desc') + }) + + it('should not change state when null is passed', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort(null) + }) + + expect(result.current.sortField).toBeNull() + }) + }) + + describe('sorting documents', () => { + const docs = [ + createMockDocument({ id: 'doc1', name: 'Banana', word_count: 200, hit_count: 5, created_at: 3000 }), + createMockDocument({ id: 'doc2', name: 'Apple', word_count: 100, hit_count: 10, created_at: 1000 }), + createMockDocument({ id: 'doc3', name: 'Cherry', word_count: 300, hit_count: 1, created_at: 2000 }), + ] + + it('should sort by name descending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + + const names = result.current.sortedDocuments.map(d => d.name) + expect(names).toEqual(['Cherry', 'Banana', 'Apple']) + }) + + it('should sort by name ascending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + act(() => { + result.current.handleSort('name') + }) + + const names = result.current.sortedDocuments.map(d => d.name) + expect(names).toEqual(['Apple', 'Banana', 'Cherry']) + }) + + it('should sort by word_count descending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('word_count') + }) + + const counts = result.current.sortedDocuments.map(d => d.word_count) + expect(counts).toEqual([300, 200, 100]) + }) + + it('should sort by hit_count ascending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('hit_count') + }) + act(() => { + result.current.handleSort('hit_count') + }) + + const counts = result.current.sortedDocuments.map(d => d.hit_count) + expect(counts).toEqual([1, 5, 10]) + }) + + it('should sort by created_at descending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('created_at') + }) + + const times = result.current.sortedDocuments.map(d => d.created_at) + expect(times).toEqual([3000, 2000, 1000]) + }) + }) + + describe('status filtering', () => { + const docs = [ + createMockDocument({ id: 'doc1', display_status: 'available' }), + createMockDocument({ id: 'doc2', display_status: 'error' }), + createMockDocument({ id: 'doc3', display_status: 'available' }), + ] + + it('should not filter when statusFilterValue is empty', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + expect(result.current.sortedDocuments.length).toBe(3) + }) + + it('should not filter when statusFilterValue is all', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: 'all', + remoteSortValue: '', + }), + ) + + expect(result.current.sortedDocuments.length).toBe(3) + }) + }) + + describe('remoteSortValue reset', () => { + it('should reset sort state when remoteSortValue changes', () => { + const { result, rerender } = renderHook( + ({ remoteSortValue }) => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue, + }), + { initialProps: { remoteSortValue: 'initial' } }, + ) + + act(() => { + result.current.handleSort('name') + }) + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortField).toBe('name') + expect(result.current.sortOrder).toBe('asc') + + rerender({ remoteSortValue: 'changed' }) + + expect(result.current.sortField).toBeNull() + expect(result.current.sortOrder).toBe('desc') + }) + }) + + describe('edge cases', () => { + it('should handle documents with missing values', () => { + const docs = [ + createMockDocument({ id: 'doc1', name: undefined as unknown as string, word_count: undefined }), + createMockDocument({ id: 'doc2', name: 'Test', word_count: 100 }), + ] + + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + + expect(result.current.sortedDocuments.length).toBe(2) + }) + + it('should handle empty documents array', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + + expect(result.current.sortedDocuments).toEqual([]) + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts new file mode 100644 index 0000000000..98cf244f36 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts @@ -0,0 +1,102 @@ +import type { SimpleDocumentDetail } from '@/models/datasets' +import { useCallback, useMemo, useRef, useState } from 'react' +import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter' + +export type SortField = 'name' | 'word_count' | 'hit_count' | 'created_at' | null +export type SortOrder = 'asc' | 'desc' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +type UseDocumentSortOptions = { + documents: LocalDoc[] + statusFilterValue: string + remoteSortValue: string +} + +export const useDocumentSort = ({ + documents, + statusFilterValue, + remoteSortValue, +}: UseDocumentSortOptions) => { + const [sortField, setSortField] = useState<SortField>(null) + const [sortOrder, setSortOrder] = useState<SortOrder>('desc') + const prevRemoteSortValueRef = useRef(remoteSortValue) + + // Reset sort when remote sort changes + if (prevRemoteSortValueRef.current !== remoteSortValue) { + prevRemoteSortValueRef.current = remoteSortValue + setSortField(null) + setSortOrder('desc') + } + + const handleSort = useCallback((field: SortField) => { + if (field === null) + return + + if (sortField === field) { + setSortOrder(prev => prev === 'asc' ? 'desc' : 'asc') + } + else { + setSortField(field) + setSortOrder('desc') + } + }, [sortField]) + + const sortedDocuments = useMemo(() => { + let filteredDocs = documents + + if (statusFilterValue && statusFilterValue !== 'all') { + filteredDocs = filteredDocs.filter(doc => + typeof doc.display_status === 'string' + && normalizeStatusForQuery(doc.display_status) === statusFilterValue, + ) + } + + if (!sortField) + return filteredDocs + + const sortedDocs = [...filteredDocs].sort((a, b) => { + let aValue: string | number + let bValue: string | number + + switch (sortField) { + case 'name': + aValue = a.name?.toLowerCase() || '' + bValue = b.name?.toLowerCase() || '' + break + case 'word_count': + aValue = a.word_count || 0 + bValue = b.word_count || 0 + break + case 'hit_count': + aValue = a.hit_count || 0 + bValue = b.hit_count || 0 + break + case 'created_at': + aValue = a.created_at + bValue = b.created_at + break + default: + return 0 + } + + if (sortField === 'name') { + const result = (aValue as string).localeCompare(bValue as string) + return sortOrder === 'asc' ? result : -result + } + else { + const result = (aValue as number) - (bValue as number) + return sortOrder === 'asc' ? result : -result + } + }) + + return sortedDocs + }, [documents, sortField, sortOrder, statusFilterValue]) + + return { + sortField, + sortOrder, + handleSort, + sortedDocuments, + } +} diff --git a/web/app/components/datasets/documents/components/document-list/index.spec.tsx b/web/app/components/datasets/documents/components/document-list/index.spec.tsx new file mode 100644 index 0000000000..32429cc0ac --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/index.spec.tsx @@ -0,0 +1,487 @@ +import type { ReactNode } from 'react' +import type { Props as PaginationProps } from '@/app/components/base/pagination' +import type { SimpleDocumentDetail } from '@/models/datasets' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { ChunkingMode, DataSourceType } from '@/models/datasets' +import DocumentList from '../list' + +const mockPush = vi.fn() + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: { doc_form: string } }) => unknown) => + selector({ dataset: { doc_form: ChunkingMode.text } }), +})) + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false, gcTime: 0 }, + mutations: { retry: false }, + }, +}) + +const createWrapper = () => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + <QueryClientProvider client={queryClient}> + {children} + </QueryClientProvider> + ) +} + +const createMockDoc = (overrides: Partial<SimpleDocumentDetail> = {}): SimpleDocumentDetail => ({ + id: `doc-${Math.random().toString(36).substr(2, 9)}`, + position: 1, + data_source_type: DataSourceType.FILE, + data_source_info: {}, + data_source_detail_dict: { + upload_file: { name: 'test.txt', extension: 'txt' }, + }, + dataset_process_rule_id: 'rule-1', + batch: 'batch-1', + name: 'test-document.txt', + created_from: 'web', + created_by: 'user-1', + created_at: Date.now(), + tokens: 100, + indexing_status: 'completed', + error: null, + enabled: true, + disabled_at: null, + disabled_by: null, + archived: false, + archived_reason: null, + archived_by: null, + archived_at: null, + updated_at: Date.now(), + doc_type: null, + doc_metadata: undefined, + display_status: 'available', + word_count: 500, + hit_count: 10, + doc_form: 'text_model', + ...overrides, +} as SimpleDocumentDetail) + +const defaultPagination: PaginationProps = { + current: 1, + onChange: vi.fn(), + total: 100, +} + +describe('DocumentList', () => { + const defaultProps = { + embeddingAvailable: true, + documents: [ + createMockDoc({ id: 'doc-1', name: 'Document 1.txt', word_count: 100, hit_count: 5 }), + createMockDoc({ id: 'doc-2', name: 'Document 2.txt', word_count: 200, hit_count: 10 }), + createMockDoc({ id: 'doc-3', name: 'Document 3.txt', word_count: 300, hit_count: 15 }), + ], + selectedIds: [] as string[], + onSelectedIdChange: vi.fn(), + datasetId: 'dataset-1', + pagination: defaultPagination, + onUpdate: vi.fn(), + onManageMetadata: vi.fn(), + statusFilterValue: '', + remoteSortValue: '', + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render all documents', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + expect(screen.getByText('Document 1.txt')).toBeInTheDocument() + expect(screen.getByText('Document 2.txt')).toBeInTheDocument() + expect(screen.getByText('Document 3.txt')).toBeInTheDocument() + }) + + it('should render table headers', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + expect(screen.getByText('#')).toBeInTheDocument() + }) + + it('should render pagination when total is provided', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + // Pagination component should be present + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should not render pagination when total is 0', () => { + const props = { + ...defaultProps, + pagination: { ...defaultPagination, total: 0 }, + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render empty table when no documents', () => { + const props = { ...defaultProps, documents: [] } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Selection', () => { + // Helper to find checkboxes (custom div components, not native checkboxes) + const findCheckboxes = (container: HTMLElement): NodeListOf<Element> => { + return container.querySelectorAll('[class*="shadow-xs"]') + } + + it('should render header checkbox when embeddingAvailable', () => { + const { container } = render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + const checkboxes = findCheckboxes(container) + expect(checkboxes.length).toBeGreaterThan(0) + }) + + it('should not render header checkbox when embedding not available', () => { + const props = { ...defaultProps, embeddingAvailable: false } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + // Row checkboxes should still be there, but header checkbox should be hidden + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should call onSelectedIdChange when select all is clicked', () => { + const onSelectedIdChange = vi.fn() + const props = { ...defaultProps, onSelectedIdChange } + const { container } = render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + const checkboxes = findCheckboxes(container) + if (checkboxes.length > 0) { + fireEvent.click(checkboxes[0]) + expect(onSelectedIdChange).toHaveBeenCalled() + } + }) + + it('should show all checkboxes as checked when all are selected', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1', 'doc-2', 'doc-3'], + } + const { container } = render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + const checkboxes = findCheckboxes(container) + // When checked, checkbox should have a check icon (svg) inside + checkboxes.forEach((checkbox) => { + const checkIcon = checkbox.querySelector('svg') + expect(checkIcon).toBeInTheDocument() + }) + }) + + it('should show indeterminate state when some are selected', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + const { container } = render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + // First checkbox is the header checkbox which should be indeterminate + const checkboxes = findCheckboxes(container) + expect(checkboxes.length).toBeGreaterThan(0) + // Header checkbox should show indeterminate icon, not check icon + // Just verify it's rendered + expect(checkboxes[0]).toBeInTheDocument() + }) + + it('should call onSelectedIdChange with single document when row checkbox is clicked', () => { + const onSelectedIdChange = vi.fn() + const props = { ...defaultProps, onSelectedIdChange } + const { container } = render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + // Click the second checkbox (first row checkbox) + const checkboxes = findCheckboxes(container) + if (checkboxes.length > 1) { + fireEvent.click(checkboxes[1]) + expect(onSelectedIdChange).toHaveBeenCalled() + } + }) + }) + + describe('Sorting', () => { + it('should render sort headers for sortable columns', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + // Find svg icons which indicate sortable columns + const sortIcons = document.querySelectorAll('svg') + expect(sortIcons.length).toBeGreaterThan(0) + }) + + it('should update sort order when sort header is clicked', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + + // Find and click a sort header by its parent div containing the label text + const sortableHeaders = document.querySelectorAll('[class*="cursor-pointer"]') + if (sortableHeaders.length > 0) { + fireEvent.click(sortableHeaders[0]) + } + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Batch Actions', () => { + it('should show batch action bar when documents are selected', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1', 'doc-2'], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + // BatchAction component should be visible + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should not show batch action bar when no documents selected', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + + // BatchAction should not be present + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render batch action bar with archive option', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + // BatchAction component should be visible when documents are selected + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render batch action bar with enable option', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render batch action bar with disable option', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render batch action bar with delete option', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should clear selection when cancel is clicked', () => { + const onSelectedIdChange = vi.fn() + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + onSelectedIdChange, + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + const cancelButton = screen.queryByRole('button', { name: /cancel/i }) + if (cancelButton) { + fireEvent.click(cancelButton) + expect(onSelectedIdChange).toHaveBeenCalledWith([]) + } + }) + + it('should show download option for downloadable documents', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + documents: [ + createMockDoc({ id: 'doc-1', data_source_type: DataSourceType.FILE }), + ], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + // BatchAction should be visible + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should show re-index option for error documents', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + documents: [ + createMockDoc({ id: 'doc-1', display_status: 'error' }), + ], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + // BatchAction with re-index should be present for error documents + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Row Click Navigation', () => { + it('should navigate to document detail when row is clicked', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + + const rows = screen.getAllByRole('row') + // First row is header, second row is first document + if (rows.length > 1) { + fireEvent.click(rows[1]) + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1') + } + }) + }) + + describe('Rename Modal', () => { + it('should not show rename modal initially', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + + // RenameModal should not be visible initially + const modal = screen.queryByRole('dialog') + expect(modal).not.toBeInTheDocument() + }) + + it('should show rename modal when rename button is clicked', () => { + const { container } = render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + + // Find and click the rename button in the first row + const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') + if (renameButtons.length > 0) { + fireEvent.click(renameButtons[0]) + } + + // After clicking rename, the modal should potentially be visible + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should call onUpdate when document is renamed', () => { + const onUpdate = vi.fn() + const props = { ...defaultProps, onUpdate } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + // The handleRenamed callback wraps onUpdate + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Edit Metadata Modal', () => { + it('should handle edit metadata action', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + const editButton = screen.queryByRole('button', { name: /metadata/i }) + if (editButton) { + fireEvent.click(editButton) + } + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should call onManageMetadata when manage metadata is triggered', () => { + const onManageMetadata = vi.fn() + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + onManageMetadata, + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + // The onShowManage callback in EditMetadataBatchModal should call hideEditModal then onManageMetadata + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Chunking Mode', () => { + it('should render with general mode', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render with QA mode', () => { + // This test uses the default mock which returns ChunkingMode.text + // The component will compute isQAMode based on doc_form + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render with parent-child mode', () => { + render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle empty documents array', () => { + const props = { ...defaultProps, documents: [] } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should handle documents with missing optional fields', () => { + const docWithMissingFields = createMockDoc({ + word_count: undefined as unknown as number, + hit_count: undefined as unknown as number, + }) + const props = { + ...defaultProps, + documents: [docWithMissingFields], + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should handle status filter value', () => { + const props = { + ...defaultProps, + statusFilterValue: 'completed', + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should handle remote sort value', () => { + const props = { + ...defaultProps, + remoteSortValue: 'created_at', + } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should handle large number of documents', () => { + const manyDocs = Array.from({ length: 20 }, (_, i) => + createMockDoc({ id: `doc-${i}`, name: `Document ${i}.txt` })) + const props = { ...defaultProps, documents: manyDocs } + render(<DocumentList {...props} />, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }, 10000) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/index.tsx b/web/app/components/datasets/documents/components/document-list/index.tsx new file mode 100644 index 0000000000..46fd7a02d5 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/index.tsx @@ -0,0 +1,3 @@ +// Re-export from parent for backwards compatibility +export { default } from '../list' +export { renderTdValue } from './components' diff --git a/web/app/components/datasets/documents/components/list.tsx b/web/app/components/datasets/documents/components/list.tsx index f63d6d987e..3106f6c30b 100644 --- a/web/app/components/datasets/documents/components/list.tsx +++ b/web/app/components/datasets/documents/components/list.tsx @@ -1,67 +1,26 @@ 'use client' import type { FC } from 'react' import type { Props as PaginationProps } from '@/app/components/base/pagination' -import type { CommonResponse } from '@/models/common' -import type { LegacyDataSourceInfo, LocalFileInfo, OnlineDocumentInfo, OnlineDriveInfo, SimpleDocumentDetail } from '@/models/datasets' -import { - RiArrowDownLine, - RiEditLine, - RiGlobalLine, -} from '@remixicon/react' +import type { SimpleDocumentDetail } from '@/models/datasets' import { useBoolean } from 'ahooks' -import { uniq } from 'es-toolkit/array' -import { pick } from 'es-toolkit/object' -import { useRouter } from 'next/navigation' import * as React from 'react' -import { useCallback, useEffect, useMemo, useState } from 'react' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Checkbox from '@/app/components/base/checkbox' -import FileTypeIcon from '@/app/components/base/file-uploader/file-type-icon' -import NotionIcon from '@/app/components/base/notion-icon' import Pagination from '@/app/components/base/pagination' -import Toast from '@/app/components/base/toast' -import Tooltip from '@/app/components/base/tooltip' -import ChunkingModeLabel from '@/app/components/datasets/common/chunking-mode-label' -import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter' -import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type' import EditMetadataBatchModal from '@/app/components/datasets/metadata/edit-metadata-batch/modal' import useBatchEditDocumentMetadata from '@/app/components/datasets/metadata/hooks/use-batch-edit-document-metadata' import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '@/context/dataset-detail' -import useTimestamp from '@/hooks/use-timestamp' -import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets' -import { DatasourceType } from '@/models/pipeline' -import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useDocumentDisable, useDocumentDownloadZip, useDocumentEnable, useDocumentSummary } from '@/service/knowledge/use-document' -import { asyncRunSafe } from '@/utils' -import { cn } from '@/utils/classnames' -import { downloadBlob } from '@/utils/download' -import { formatNumber } from '@/utils/format' +import { ChunkingMode, DocumentActionType } from '@/models/datasets' import BatchAction from '../detail/completed/common/batch-action' -import SummaryStatus from '../detail/completed/common/summary-status' -import StatusItem from '../status-item' import s from '../style.module.css' -import Operations from './operations' +import { DocumentTableRow, renderTdValue, SortHeader } from './document-list/components' +import { useDocumentActions, useDocumentSelection, useDocumentSort } from './document-list/hooks' import RenameModal from './rename-modal' -export const renderTdValue = (value: string | number | null, isEmptyStyle = false) => { - return ( - <div className={cn(isEmptyStyle ? 'text-text-tertiary' : 'text-text-secondary', s.tdValue)}> - {value ?? '-'} - </div> - ) -} - -const renderCount = (count: number | undefined) => { - if (!count) - return renderTdValue(0, true) - - if (count < 1000) - return count - - return `${formatNumber((count / 1000).toFixed(1))}k` -} - type LocalDoc = SimpleDocumentDetail & { percent?: number } -type IDocumentListProps = { + +type DocumentListProps = { embeddingAvailable: boolean documents: LocalDoc[] selectedIds: string[] @@ -77,7 +36,7 @@ type IDocumentListProps = { /** * Document list component including basic information */ -const DocumentList: FC<IDocumentListProps> = ({ +const DocumentList: FC<DocumentListProps> = ({ embeddingAvailable, documents = [], selectedIds, @@ -90,20 +49,43 @@ const DocumentList: FC<IDocumentListProps> = ({ remoteSortValue, }) => { const { t } = useTranslation() - const { formatTime } = useTimestamp() - const router = useRouter() const datasetConfig = useDatasetDetailContext(s => s.dataset) const chunkingMode = datasetConfig?.doc_form const isGeneralMode = chunkingMode !== ChunkingMode.parentChild const isQAMode = chunkingMode === ChunkingMode.qa - const [sortField, setSortField] = useState<'name' | 'word_count' | 'hit_count' | 'created_at' | null>(null) - const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc') - useEffect(() => { - setSortField(null) - setSortOrder('desc') - }, [remoteSortValue]) + // Sorting + const { sortField, sortOrder, handleSort, sortedDocuments } = useDocumentSort({ + documents, + statusFilterValue, + remoteSortValue, + }) + // Selection + const { + isAllSelected, + isSomeSelected, + onSelectAll, + onSelectOne, + hasErrorDocumentsSelected, + downloadableSelectedIds, + clearSelection, + } = useDocumentSelection({ + documents: sortedDocuments, + selectedIds, + onSelectedIdChange, + }) + + // Actions + const { handleAction, handleBatchReIndex, handleBatchDownload } = useDocumentActions({ + datasetId, + selectedIds, + downloadableSelectedIds, + onUpdate, + onClearSelection: clearSelection, + }) + + // Batch edit metadata const { isShowEditModal, showEditModal, @@ -113,233 +95,26 @@ const DocumentList: FC<IDocumentListProps> = ({ } = useBatchEditDocumentMetadata({ datasetId, docList: documents.filter(doc => selectedIds.includes(doc.id)), - selectedDocumentIds: selectedIds, // Pass all selected IDs separately + selectedDocumentIds: selectedIds, onUpdate, }) - const localDocs = useMemo(() => { - let filteredDocs = documents - - if (statusFilterValue && statusFilterValue !== 'all') { - filteredDocs = filteredDocs.filter(doc => - typeof doc.display_status === 'string' - && normalizeStatusForQuery(doc.display_status) === statusFilterValue, - ) - } - - if (!sortField) - return filteredDocs - - const sortedDocs = [...filteredDocs].sort((a, b) => { - let aValue: any - let bValue: any - - switch (sortField) { - case 'name': - aValue = a.name?.toLowerCase() || '' - bValue = b.name?.toLowerCase() || '' - break - case 'word_count': - aValue = a.word_count || 0 - bValue = b.word_count || 0 - break - case 'hit_count': - aValue = a.hit_count || 0 - bValue = b.hit_count || 0 - break - case 'created_at': - aValue = a.created_at - bValue = b.created_at - break - default: - return 0 - } - - if (sortField === 'name') { - const result = aValue.localeCompare(bValue) - return sortOrder === 'asc' ? result : -result - } - else { - const result = aValue - bValue - return sortOrder === 'asc' ? result : -result - } - }) - - return sortedDocs - }, [documents, sortField, sortOrder, statusFilterValue]) - - const handleSort = (field: 'name' | 'word_count' | 'hit_count' | 'created_at') => { - if (sortField === field) { - setSortOrder(sortOrder === 'asc' ? 'desc' : 'asc') - } - else { - setSortField(field) - setSortOrder('desc') - } - } - - const renderSortHeader = (field: 'name' | 'word_count' | 'hit_count' | 'created_at', label: string) => { - const isActive = sortField === field - const isDesc = isActive && sortOrder === 'desc' - - return ( - <div className="flex cursor-pointer items-center hover:text-text-secondary" onClick={() => handleSort(field)}> - {label} - <RiArrowDownLine - className={cn('ml-0.5 h-3 w-3 transition-all', isActive ? 'text-text-tertiary' : 'text-text-disabled', isActive && !isDesc ? 'rotate-180' : '')} - /> - </div> - ) - } - + // Rename modal const [currDocument, setCurrDocument] = useState<LocalDoc | null>(null) const [isShowRenameModal, { setTrue: setShowRenameModalTrue, setFalse: setShowRenameModalFalse, }] = useBoolean(false) + const handleShowRenameModal = useCallback((doc: LocalDoc) => { setCurrDocument(doc) setShowRenameModalTrue() }, [setShowRenameModalTrue]) + const handleRenamed = useCallback(() => { onUpdate() }, [onUpdate]) - const isAllSelected = useMemo(() => { - return localDocs.length > 0 && localDocs.every(doc => selectedIds.includes(doc.id)) - }, [localDocs, selectedIds]) - - const isSomeSelected = useMemo(() => { - return localDocs.some(doc => selectedIds.includes(doc.id)) - }, [localDocs, selectedIds]) - - const onSelectedAll = useCallback(() => { - if (isAllSelected) - onSelectedIdChange([]) - else - onSelectedIdChange(uniq([...selectedIds, ...localDocs.map(doc => doc.id)])) - }, [isAllSelected, localDocs, onSelectedIdChange, selectedIds]) - const { mutateAsync: archiveDocument } = useDocumentArchive() - const { mutateAsync: generateSummary } = useDocumentSummary() - const { mutateAsync: enableDocument } = useDocumentEnable() - const { mutateAsync: disableDocument } = useDocumentDisable() - const { mutateAsync: deleteDocument } = useDocumentDelete() - const { mutateAsync: retryIndexDocument } = useDocumentBatchRetryIndex() - const { mutateAsync: requestDocumentsZip, isPending: isDownloadingZip } = useDocumentDownloadZip() - - const handleAction = (actionName: DocumentActionType) => { - return async () => { - let opApi - switch (actionName) { - case DocumentActionType.archive: - opApi = archiveDocument - break - case DocumentActionType.summary: - opApi = generateSummary - break - case DocumentActionType.enable: - opApi = enableDocument - break - case DocumentActionType.disable: - opApi = disableDocument - break - default: - opApi = deleteDocument - break - } - const [e] = await asyncRunSafe<CommonResponse>(opApi({ datasetId, documentIds: selectedIds }) as Promise<CommonResponse>) - - if (!e) { - if (actionName === DocumentActionType.delete) - onSelectedIdChange([]) - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) - onUpdate() - } - else { Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) } - } - } - - const handleBatchReIndex = async () => { - const [e] = await asyncRunSafe<CommonResponse>(retryIndexDocument({ datasetId, documentIds: selectedIds })) - if (!e) { - onSelectedIdChange([]) - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) - onUpdate() - } - else { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) - } - } - - const hasErrorDocumentsSelected = useMemo(() => { - return localDocs.some(doc => selectedIds.includes(doc.id) && doc.display_status === 'error') - }, [localDocs, selectedIds]) - - const getFileExtension = useCallback((fileName: string): string => { - if (!fileName) - return '' - const parts = fileName.split('.') - if (parts.length <= 1 || (parts[0] === '' && parts.length === 2)) - return '' - - return parts[parts.length - 1].toLowerCase() - }, []) - - const isCreateFromRAGPipeline = useCallback((createdFrom: string) => { - return createdFrom === 'rag-pipeline' - }, []) - - /** - * Calculate the data source type - * DataSourceType: FILE, NOTION, WEB (legacy) - * DatasourceType: localFile, onlineDocument, websiteCrawl, onlineDrive (new) - */ - const isLocalFile = useCallback((dataSourceType: DataSourceType | DatasourceType) => { - return dataSourceType === DatasourceType.localFile || dataSourceType === DataSourceType.FILE - }, []) - const isOnlineDocument = useCallback((dataSourceType: DataSourceType | DatasourceType) => { - return dataSourceType === DatasourceType.onlineDocument || dataSourceType === DataSourceType.NOTION - }, []) - const isWebsiteCrawl = useCallback((dataSourceType: DataSourceType | DatasourceType) => { - return dataSourceType === DatasourceType.websiteCrawl || dataSourceType === DataSourceType.WEB - }, []) - const isOnlineDrive = useCallback((dataSourceType: DataSourceType | DatasourceType) => { - return dataSourceType === DatasourceType.onlineDrive - }, []) - - const downloadableSelectedIds = useMemo(() => { - const selectedSet = new Set(selectedIds) - return localDocs - .filter(doc => selectedSet.has(doc.id) && doc.data_source_type === DataSourceType.FILE) - .map(doc => doc.id) - }, [localDocs, selectedIds]) - - /** - * Generate a random ZIP filename for bulk document downloads. - * We intentionally avoid leaking dataset info in the exported archive name. - */ - const generateDocsZipFileName = useCallback((): string => { - // Prefer UUID for uniqueness; fall back to time+random when unavailable. - const randomPart = (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') - ? crypto.randomUUID() - : `${Date.now().toString(36)}${Math.random().toString(36).slice(2, 10)}` - return `${randomPart}-docs.zip` - }, []) - - const handleBatchDownload = useCallback(async () => { - if (isDownloadingZip) - return - - // Download as a single ZIP to avoid browser caps on multiple automatic downloads. - const [e, blob] = await asyncRunSafe(requestDocumentsZip({ datasetId, documentIds: downloadableSelectedIds })) - if (e || !blob) { - Toast.notify({ type: 'error', message: t('actionMsg.downloadUnsuccessfully', { ns: 'common' }) }) - return - } - - downloadBlob({ data: blob, fileName: generateDocsZipFileName() }) - }, [datasetId, downloadableSelectedIds, generateDocsZipFileName, isDownloadingZip, requestDocumentsZip, t]) - return ( <div className="relative mt-3 flex h-full w-full flex-col"> <div className="relative h-0 grow overflow-x-auto"> @@ -353,157 +128,76 @@ const DocumentList: FC<IDocumentListProps> = ({ className="mr-2 shrink-0" checked={isAllSelected} indeterminate={!isAllSelected && isSomeSelected} - onCheck={onSelectedAll} + onCheck={onSelectAll} /> )} # </div> </td> <td> - {renderSortHeader('name', t('list.table.header.fileName', { ns: 'datasetDocuments' }))} + <SortHeader + field="name" + label={t('list.table.header.fileName', { ns: 'datasetDocuments' })} + currentSortField={sortField} + sortOrder={sortOrder} + onSort={handleSort} + /> </td> <td className="w-[130px]">{t('list.table.header.chunkingMode', { ns: 'datasetDocuments' })}</td> <td className="w-24"> - {renderSortHeader('word_count', t('list.table.header.words', { ns: 'datasetDocuments' }))} + <SortHeader + field="word_count" + label={t('list.table.header.words', { ns: 'datasetDocuments' })} + currentSortField={sortField} + sortOrder={sortOrder} + onSort={handleSort} + /> </td> <td className="w-44"> - {renderSortHeader('hit_count', t('list.table.header.hitCount', { ns: 'datasetDocuments' }))} + <SortHeader + field="hit_count" + label={t('list.table.header.hitCount', { ns: 'datasetDocuments' })} + currentSortField={sortField} + sortOrder={sortOrder} + onSort={handleSort} + /> </td> <td className="w-44"> - {renderSortHeader('created_at', t('list.table.header.uploadTime', { ns: 'datasetDocuments' }))} + <SortHeader + field="created_at" + label={t('list.table.header.uploadTime', { ns: 'datasetDocuments' })} + currentSortField={sortField} + sortOrder={sortOrder} + onSort={handleSort} + /> </td> <td className="w-40">{t('list.table.header.status', { ns: 'datasetDocuments' })}</td> <td className="w-20">{t('list.table.header.action', { ns: 'datasetDocuments' })}</td> </tr> </thead> <tbody className="text-text-secondary"> - {localDocs.map((doc, index) => { - const isFile = isLocalFile(doc.data_source_type) - const fileType = isFile ? doc.data_source_detail_dict?.upload_file?.extension : '' - return ( - <tr - key={doc.id} - className="h-8 cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover" - onClick={() => { - router.push(`/datasets/${datasetId}/documents/${doc.id}`) - }} - > - <td className="text-left align-middle text-xs text-text-tertiary"> - <div className="flex items-center" onClick={e => e.stopPropagation()}> - <Checkbox - className="mr-2 shrink-0" - checked={selectedIds.includes(doc.id)} - onCheck={() => { - onSelectedIdChange( - selectedIds.includes(doc.id) - ? selectedIds.filter(id => id !== doc.id) - : [...selectedIds, doc.id], - ) - }} - /> - {index + 1} - </div> - </td> - <td> - <div className="group mr-6 flex max-w-[460px] items-center hover:mr-0"> - <div className="flex shrink-0 items-center"> - {isOnlineDocument(doc.data_source_type) && ( - <NotionIcon - className="mr-1.5" - type="page" - src={ - isCreateFromRAGPipeline(doc.created_from) - ? (doc.data_source_info as OnlineDocumentInfo).page.page_icon - : (doc.data_source_info as LegacyDataSourceInfo).notion_page_icon - } - /> - )} - {isLocalFile(doc.data_source_type) && ( - <FileTypeIcon - type={ - extensionToFileType( - isCreateFromRAGPipeline(doc.created_from) - ? (doc?.data_source_info as LocalFileInfo)?.extension - : ((doc?.data_source_info as LegacyDataSourceInfo)?.upload_file?.extension ?? fileType), - ) - } - className="mr-1.5" - /> - )} - {isOnlineDrive(doc.data_source_type) && ( - <FileTypeIcon - type={ - extensionToFileType( - getFileExtension((doc?.data_source_info as unknown as OnlineDriveInfo)?.name), - ) - } - className="mr-1.5" - /> - )} - {isWebsiteCrawl(doc.data_source_type) && ( - <RiGlobalLine className="mr-1.5 size-4" /> - )} - </div> - <Tooltip - popupContent={doc.name} - > - <span className="grow-1 truncate text-sm">{doc.name}</span> - </Tooltip> - { - doc.summary_index_status && ( - <div className="ml-1 hidden shrink-0 group-hover:flex"> - <SummaryStatus status={doc.summary_index_status} /> - </div> - ) - } - <div className="hidden shrink-0 group-hover:ml-auto group-hover:flex"> - <Tooltip - popupContent={t('list.table.rename', { ns: 'datasetDocuments' })} - > - <div - className="cursor-pointer rounded-md p-1 hover:bg-state-base-hover" - onClick={(e) => { - e.stopPropagation() - handleShowRenameModal(doc) - }} - > - <RiEditLine className="h-4 w-4 text-text-tertiary" /> - </div> - </Tooltip> - </div> - </div> - </td> - <td> - <ChunkingModeLabel - isGeneralMode={isGeneralMode} - isQAMode={isQAMode} - /> - </td> - <td>{renderCount(doc.word_count)}</td> - <td>{renderCount(doc.hit_count)}</td> - <td className="text-[13px] text-text-secondary"> - {formatTime(doc.created_at, t('dateTimeFormat', { ns: 'datasetHitTesting' }) as string)} - </td> - <td> - <StatusItem status={doc.display_status} /> - </td> - <td> - <Operations - selectedIds={selectedIds} - onSelectedIdChange={onSelectedIdChange} - embeddingAvailable={embeddingAvailable} - datasetId={datasetId} - detail={pick(doc, ['name', 'enabled', 'archived', 'id', 'data_source_type', 'doc_form', 'display_status'])} - onUpdate={onUpdate} - /> - </td> - </tr> - ) - })} + {sortedDocuments.map((doc, index) => ( + <DocumentTableRow + key={doc.id} + doc={doc} + index={index} + datasetId={datasetId} + isSelected={selectedIds.includes(doc.id)} + isGeneralMode={isGeneralMode} + isQAMode={isQAMode} + embeddingAvailable={embeddingAvailable} + selectedIds={selectedIds} + onSelectOne={onSelectOne} + onSelectedIdChange={onSelectedIdChange} + onShowRenameModal={handleShowRenameModal} + onUpdate={onUpdate} + /> + ))} </tbody> </table> </div> - {(selectedIds.length > 0) && ( + + {selectedIds.length > 0 && ( <BatchAction className="absolute bottom-16 left-0 z-20" selectedIds={selectedIds} @@ -515,12 +209,10 @@ const DocumentList: FC<IDocumentListProps> = ({ onBatchDelete={handleAction(DocumentActionType.delete)} onEditMetadata={showEditModal} onBatchReIndex={hasErrorDocumentsSelected ? handleBatchReIndex : undefined} - onCancel={() => { - onSelectedIdChange([]) - }} + onCancel={clearSelection} /> )} - {/* Show Pagination only if the total is more than the limit */} + {!!pagination.total && ( <Pagination {...pagination} @@ -556,3 +248,5 @@ const DocumentList: FC<IDocumentListProps> = ({ } export default DocumentList + +export { renderTdValue } diff --git a/web/app/components/datasets/documents/detail/embedding/components/index.ts b/web/app/components/datasets/documents/detail/embedding/components/index.ts new file mode 100644 index 0000000000..5faac4e027 --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/index.ts @@ -0,0 +1,4 @@ +export { default as ProgressBar } from './progress-bar' +export { default as RuleDetail } from './rule-detail' +export { default as SegmentProgress } from './segment-progress' +export { default as StatusHeader } from './status-header' diff --git a/web/app/components/datasets/documents/detail/embedding/components/progress-bar.spec.tsx b/web/app/components/datasets/documents/detail/embedding/components/progress-bar.spec.tsx new file mode 100644 index 0000000000..b54c8000fe --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/progress-bar.spec.tsx @@ -0,0 +1,159 @@ +import { render } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import ProgressBar from './progress-bar' + +describe('ProgressBar', () => { + const defaultProps = { + percent: 50, + isEmbedding: false, + isCompleted: false, + isPaused: false, + isError: false, + } + + const getProgressElements = (container: HTMLElement) => { + const wrapper = container.firstChild as HTMLElement + const progressBar = wrapper.firstChild as HTMLElement + return { wrapper, progressBar } + } + + describe('Rendering', () => { + it('should render without crashing', () => { + const { container } = render(<ProgressBar {...defaultProps} />) + const { wrapper, progressBar } = getProgressElements(container) + expect(wrapper).toBeInTheDocument() + expect(progressBar).toBeInTheDocument() + }) + + it('should render progress bar container with correct classes', () => { + const { container } = render(<ProgressBar {...defaultProps} />) + const { wrapper } = getProgressElements(container) + expect(wrapper).toHaveClass('flex', 'h-2', 'w-full', 'items-center', 'overflow-hidden', 'rounded-md') + }) + + it('should render inner progress bar with transition classes', () => { + const { container } = render(<ProgressBar {...defaultProps} />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveClass('h-full', 'transition-all', 'duration-300') + }) + }) + + describe('Progress Width', () => { + it('should set progress width to 0%', () => { + const { container } = render(<ProgressBar {...defaultProps} percent={0} />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveStyle({ width: '0%' }) + }) + + it('should set progress width to 50%', () => { + const { container } = render(<ProgressBar {...defaultProps} percent={50} />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveStyle({ width: '50%' }) + }) + + it('should set progress width to 100%', () => { + const { container } = render(<ProgressBar {...defaultProps} percent={100} />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveStyle({ width: '100%' }) + }) + + it('should set progress width to 75%', () => { + const { container } = render(<ProgressBar {...defaultProps} percent={75} />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveStyle({ width: '75%' }) + }) + }) + + describe('Container Background States', () => { + it('should apply semi-transparent background when isEmbedding is true', () => { + const { container } = render(<ProgressBar {...defaultProps} isEmbedding />) + const { wrapper } = getProgressElements(container) + expect(wrapper).toHaveClass('bg-components-progress-bar-bg/50') + }) + + it('should apply default background when isEmbedding is false', () => { + const { container } = render(<ProgressBar {...defaultProps} isEmbedding={false} />) + const { wrapper } = getProgressElements(container) + expect(wrapper).toHaveClass('bg-components-progress-bar-bg') + expect(wrapper).not.toHaveClass('bg-components-progress-bar-bg/50') + }) + }) + + describe('Progress Bar Fill States', () => { + it('should apply solid progress style when isEmbedding is true', () => { + const { container } = render(<ProgressBar {...defaultProps} isEmbedding />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveClass('bg-components-progress-bar-progress-solid') + }) + + it('should apply solid progress style when isCompleted is true', () => { + const { container } = render(<ProgressBar {...defaultProps} isCompleted />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveClass('bg-components-progress-bar-progress-solid') + }) + + it('should apply highlight style when isPaused is true', () => { + const { container } = render(<ProgressBar {...defaultProps} isPaused />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveClass('bg-components-progress-bar-progress-highlight') + }) + + it('should apply highlight style when isError is true', () => { + const { container } = render(<ProgressBar {...defaultProps} isError />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveClass('bg-components-progress-bar-progress-highlight') + }) + + it('should not apply fill styles when no status flags are set', () => { + const { container } = render(<ProgressBar {...defaultProps} />) + const { progressBar } = getProgressElements(container) + expect(progressBar).not.toHaveClass('bg-components-progress-bar-progress-solid') + expect(progressBar).not.toHaveClass('bg-components-progress-bar-progress-highlight') + }) + }) + + describe('Combined States', () => { + it('should apply highlight when isEmbedding and isPaused', () => { + const { container } = render(<ProgressBar {...defaultProps} isEmbedding isPaused />) + const { progressBar } = getProgressElements(container) + // highlight takes precedence since isPaused condition is separate + expect(progressBar).toHaveClass('bg-components-progress-bar-progress-highlight') + }) + + it('should apply highlight when isCompleted and isError', () => { + const { container } = render(<ProgressBar {...defaultProps} isCompleted isError />) + const { progressBar } = getProgressElements(container) + // highlight takes precedence since isError condition is separate + expect(progressBar).toHaveClass('bg-components-progress-bar-progress-highlight') + }) + + it('should apply semi-transparent bg for embedding and highlight for paused', () => { + const { container } = render(<ProgressBar {...defaultProps} isEmbedding isPaused />) + const { wrapper } = getProgressElements(container) + expect(wrapper).toHaveClass('bg-components-progress-bar-bg/50') + }) + }) + + describe('Edge Cases', () => { + it('should handle all props set to false', () => { + const { container } = render( + <ProgressBar + percent={0} + isEmbedding={false} + isCompleted={false} + isPaused={false} + isError={false} + />, + ) + const { wrapper, progressBar } = getProgressElements(container) + expect(wrapper).toBeInTheDocument() + expect(progressBar).toHaveStyle({ width: '0%' }) + }) + + it('should handle decimal percent values', () => { + const { container } = render(<ProgressBar {...defaultProps} percent={33.33} />) + const { progressBar } = getProgressElements(container) + expect(progressBar).toHaveStyle({ width: '33.33%' }) + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/embedding/components/progress-bar.tsx b/web/app/components/datasets/documents/detail/embedding/components/progress-bar.tsx new file mode 100644 index 0000000000..19c6493922 --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/progress-bar.tsx @@ -0,0 +1,44 @@ +import type { FC } from 'react' +import * as React from 'react' +import { cn } from '@/utils/classnames' + +type ProgressBarProps = { + percent: number + isEmbedding: boolean + isCompleted: boolean + isPaused: boolean + isError: boolean +} + +const ProgressBar: FC<ProgressBarProps> = React.memo(({ + percent, + isEmbedding, + isCompleted, + isPaused, + isError, +}) => { + const isActive = isEmbedding || isCompleted + const isHighlighted = isPaused || isError + + return ( + <div + className={cn( + 'flex h-2 w-full items-center overflow-hidden rounded-md border border-components-progress-bar-border', + isEmbedding ? 'bg-components-progress-bar-bg/50' : 'bg-components-progress-bar-bg', + )} + > + <div + className={cn( + 'h-full transition-all duration-300', + isActive && 'bg-components-progress-bar-progress-solid', + isHighlighted && 'bg-components-progress-bar-progress-highlight', + )} + style={{ width: `${percent}%` }} + /> + </div> + ) +}) + +ProgressBar.displayName = 'ProgressBar' + +export default ProgressBar diff --git a/web/app/components/datasets/documents/detail/embedding/components/rule-detail.spec.tsx b/web/app/components/datasets/documents/detail/embedding/components/rule-detail.spec.tsx new file mode 100644 index 0000000000..138a4eacd8 --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/rule-detail.spec.tsx @@ -0,0 +1,203 @@ +import type { ProcessRuleResponse } from '@/models/datasets' +import { render, screen } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import { ProcessMode } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../../../create/step-two' +import RuleDetail from './rule-detail' + +describe('RuleDetail', () => { + const defaultProps = { + indexingType: IndexingType.QUALIFIED, + retrievalMethod: RETRIEVE_METHOD.semantic, + } + + const createSourceData = (overrides: Partial<ProcessRuleResponse> = {}): ProcessRuleResponse => ({ + mode: ProcessMode.general, + rules: { + segmentation: { + separator: '\n', + max_tokens: 500, + chunk_overlap: 50, + }, + pre_processing_rules: [ + { id: 'remove_extra_spaces', enabled: true }, + { id: 'remove_urls_emails', enabled: false }, + ], + parent_mode: 'full-doc', + subchunk_segmentation: { + separator: '\n', + max_tokens: 200, + chunk_overlap: 20, + }, + }, + limits: { indexing_max_segmentation_tokens_length: 4000 }, + ...overrides, + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(<RuleDetail {...defaultProps} />) + expect(screen.getByText(/stepTwo\.indexMode/i)).toBeInTheDocument() + }) + + it('should render with sourceData', () => { + const sourceData = createSourceData() + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText(/embedding\.mode/i)).toBeInTheDocument() + }) + + it('should render all segmentation rule fields', () => { + const sourceData = createSourceData() + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText(/embedding\.mode/i)).toBeInTheDocument() + expect(screen.getByText(/embedding\.segmentLength/i)).toBeInTheDocument() + expect(screen.getByText(/embedding\.textCleaning/i)).toBeInTheDocument() + }) + }) + + describe('Mode Display', () => { + it('should display custom mode for general process mode', () => { + const sourceData = createSourceData({ mode: ProcessMode.general }) + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText(/embedding\.custom/i)).toBeInTheDocument() + }) + + it('should display mode label field', () => { + const sourceData = createSourceData() + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText(/embedding\.mode/i)).toBeInTheDocument() + }) + }) + + describe('Segment Length Display', () => { + it('should display max tokens for general mode', () => { + const sourceData = createSourceData({ + mode: ProcessMode.general, + rules: { + segmentation: { separator: '\n', max_tokens: 500, chunk_overlap: 50 }, + pre_processing_rules: [], + parent_mode: 'full-doc', + subchunk_segmentation: { separator: '\n', max_tokens: 200, chunk_overlap: 20 }, + }, + }) + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText('500')).toBeInTheDocument() + }) + + it('should display segment length label', () => { + const sourceData = createSourceData() + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText(/embedding\.segmentLength/i)).toBeInTheDocument() + }) + }) + + describe('Text Cleaning Display', () => { + it('should display enabled pre-processing rules', () => { + const sourceData = createSourceData({ + rules: { + segmentation: { separator: '\n', max_tokens: 500, chunk_overlap: 50 }, + pre_processing_rules: [ + { id: 'remove_extra_spaces', enabled: true }, + { id: 'remove_urls_emails', enabled: true }, + ], + parent_mode: 'full-doc', + subchunk_segmentation: { separator: '\n', max_tokens: 200, chunk_overlap: 20 }, + }, + }) + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText(/removeExtraSpaces/i)).toBeInTheDocument() + expect(screen.getByText(/removeUrlEmails/i)).toBeInTheDocument() + }) + + it('should display text cleaning label', () => { + const sourceData = createSourceData() + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText(/embedding\.textCleaning/i)).toBeInTheDocument() + }) + }) + + describe('Index Mode Display', () => { + it('should display economical mode when indexingType is ECONOMICAL', () => { + render(<RuleDetail {...defaultProps} indexingType={IndexingType.ECONOMICAL} />) + expect(screen.getByText(/stepTwo\.economical/i)).toBeInTheDocument() + }) + + it('should display qualified mode when indexingType is QUALIFIED', () => { + render(<RuleDetail {...defaultProps} indexingType={IndexingType.QUALIFIED} />) + expect(screen.getByText(/stepTwo\.qualified/i)).toBeInTheDocument() + }) + }) + + describe('Retrieval Method Display', () => { + it('should display keyword search for economical mode', () => { + render(<RuleDetail {...defaultProps} indexingType={IndexingType.ECONOMICAL} />) + expect(screen.getByText(/retrieval\.keyword_search\.title/i)).toBeInTheDocument() + }) + + it('should display semantic search as default for qualified mode', () => { + render(<RuleDetail {...defaultProps} indexingType={IndexingType.QUALIFIED} />) + expect(screen.getByText(/retrieval\.semantic_search\.title/i)).toBeInTheDocument() + }) + + it('should display full text search when retrievalMethod is fullText', () => { + render(<RuleDetail {...defaultProps} retrievalMethod={RETRIEVE_METHOD.fullText} />) + expect(screen.getByText(/retrieval\.full_text_search\.title/i)).toBeInTheDocument() + }) + + it('should display hybrid search when retrievalMethod is hybrid', () => { + render(<RuleDetail {...defaultProps} retrievalMethod={RETRIEVE_METHOD.hybrid} />) + expect(screen.getByText(/retrieval\.hybrid_search\.title/i)).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should display dash for missing sourceData', () => { + render(<RuleDetail {...defaultProps} />) + const dashes = screen.getAllByText('-') + expect(dashes.length).toBeGreaterThan(0) + }) + + it('should display dash when mode is undefined', () => { + const sourceData = { rules: {} } as ProcessRuleResponse + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + const dashes = screen.getAllByText('-') + expect(dashes.length).toBeGreaterThan(0) + }) + + it('should handle undefined retrievalMethod', () => { + render(<RuleDetail indexingType={IndexingType.QUALIFIED} />) + expect(screen.getByText(/retrieval\.semantic_search\.title/i)).toBeInTheDocument() + }) + + it('should handle empty pre_processing_rules array', () => { + const sourceData = createSourceData({ + rules: { + segmentation: { separator: '\n', max_tokens: 500, chunk_overlap: 50 }, + pre_processing_rules: [], + parent_mode: 'full-doc', + subchunk_segmentation: { separator: '\n', max_tokens: 200, chunk_overlap: 20 }, + }, + }) + render(<RuleDetail {...defaultProps} sourceData={sourceData} />) + expect(screen.getByText(/embedding\.textCleaning/i)).toBeInTheDocument() + }) + + it('should render container with correct structure', () => { + const { container } = render(<RuleDetail {...defaultProps} />) + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('py-3') + }) + + it('should handle undefined indexingType', () => { + render(<RuleDetail retrievalMethod={RETRIEVE_METHOD.semantic} />) + expect(screen.getByText(/stepTwo\.indexMode/i)).toBeInTheDocument() + }) + + it('should render divider between sections', () => { + const { container } = render(<RuleDetail {...defaultProps} />) + const dividers = container.querySelectorAll('.bg-divider-subtle') + expect(dividers.length).toBeGreaterThan(0) + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/embedding/components/rule-detail.tsx b/web/app/components/datasets/documents/detail/embedding/components/rule-detail.tsx new file mode 100644 index 0000000000..486b94175b --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/rule-detail.tsx @@ -0,0 +1,128 @@ +import type { FC } from 'react' +import type { ProcessRuleResponse } from '@/models/datasets' +import type { RETRIEVE_METHOD } from '@/types/app' +import Image from 'next/image' +import * as React from 'react' +import { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import Divider from '@/app/components/base/divider' +import { ProcessMode } from '@/models/datasets' +import { indexMethodIcon, retrievalIcon } from '../../../../create/icons' +import { IndexingType } from '../../../../create/step-two' +import { FieldInfo } from '../../metadata' + +type RuleDetailProps = { + sourceData?: ProcessRuleResponse + indexingType?: IndexingType + retrievalMethod?: RETRIEVE_METHOD +} + +const getRetrievalIcon = (method?: RETRIEVE_METHOD) => { + if (method === 'full_text_search') + return retrievalIcon.fullText + if (method === 'hybrid_search') + return retrievalIcon.hybrid + return retrievalIcon.vector +} + +const RuleDetail: FC<RuleDetailProps> = React.memo(({ + sourceData, + indexingType, + retrievalMethod, +}) => { + const { t } = useTranslation() + + const segmentationRuleMap = { + mode: t('embedding.mode', { ns: 'datasetDocuments' }), + segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }), + textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }), + } + + const getRuleName = useCallback((key: string) => { + const ruleNameMap: Record<string, string> = { + remove_extra_spaces: t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }), + remove_urls_emails: t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }), + remove_stopwords: t('stepTwo.removeStopwords', { ns: 'datasetCreation' }), + } + return ruleNameMap[key] + }, [t]) + + const getValue = useCallback((field: string) => { + const defaultValue = '-' + + if (!sourceData?.mode) + return defaultValue + + const maxTokens = typeof sourceData?.rules?.segmentation?.max_tokens === 'number' + ? sourceData.rules.segmentation.max_tokens + : defaultValue + + const childMaxTokens = typeof sourceData?.rules?.subchunk_segmentation?.max_tokens === 'number' + ? sourceData.rules.subchunk_segmentation.max_tokens + : defaultValue + + const isGeneralMode = sourceData.mode === ProcessMode.general + + const fieldValueMap: Record<string, string | number> = { + mode: isGeneralMode + ? t('embedding.custom', { ns: 'datasetDocuments' }) + : `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} ยท ${ + sourceData?.rules?.parent_mode === 'paragraph' + ? t('parentMode.paragraph', { ns: 'dataset' }) + : t('parentMode.fullDoc', { ns: 'dataset' }) + }`, + segmentLength: isGeneralMode + ? maxTokens + : `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}`, + textCleaning: sourceData?.rules?.pre_processing_rules + ?.filter(rule => rule.enabled) + .map(rule => getRuleName(rule.id)) + .join(',') || defaultValue, + } + + return fieldValueMap[field] ?? defaultValue + }, [sourceData, t, getRuleName]) + + const isEconomical = indexingType === IndexingType.ECONOMICAL + + return ( + <div className="py-3"> + <div className="flex flex-col gap-y-1"> + {Object.keys(segmentationRuleMap).map(field => ( + <FieldInfo + key={field} + label={segmentationRuleMap[field as keyof typeof segmentationRuleMap]} + displayedValue={String(getValue(field))} + /> + ))} + </div> + <Divider type="horizontal" className="bg-divider-subtle" /> + <FieldInfo + label={t('stepTwo.indexMode', { ns: 'datasetCreation' })} + displayedValue={t(`stepTwo.${isEconomical ? 'economical' : 'qualified'}`, { ns: 'datasetCreation' }) as string} + valueIcon={( + <Image + className="size-4" + src={isEconomical ? indexMethodIcon.economical : indexMethodIcon.high_quality} + alt="" + /> + )} + /> + <FieldInfo + label={t('form.retrievalSetting.title', { ns: 'datasetSettings' })} + displayedValue={t(`retrieval.${isEconomical ? 'keyword_search' : retrievalMethod ?? 'semantic_search'}.title`, { ns: 'dataset' })} + valueIcon={( + <Image + className="size-4" + src={getRetrievalIcon(retrievalMethod)} + alt="" + /> + )} + /> + </div> + ) +}) + +RuleDetail.displayName = 'RuleDetail' + +export default RuleDetail diff --git a/web/app/components/datasets/documents/detail/embedding/components/segment-progress.spec.tsx b/web/app/components/datasets/documents/detail/embedding/components/segment-progress.spec.tsx new file mode 100644 index 0000000000..1afc2f42f1 --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/segment-progress.spec.tsx @@ -0,0 +1,81 @@ +import { render, screen } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import SegmentProgress from './segment-progress' + +describe('SegmentProgress', () => { + const defaultProps = { + completedSegments: 50, + totalSegments: 100, + percent: 50, + } + + describe('Rendering', () => { + it('should render without crashing', () => { + render(<SegmentProgress {...defaultProps} />) + expect(screen.getByText(/segments/i)).toBeInTheDocument() + }) + + it('should render with correct CSS classes', () => { + const { container } = render(<SegmentProgress {...defaultProps} />) + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'w-full', 'items-center') + }) + + it('should render text with correct styling class', () => { + render(<SegmentProgress {...defaultProps} />) + const text = screen.getByText(/segments/i) + expect(text).toHaveClass('system-xs-medium', 'text-text-secondary') + }) + }) + + describe('Progress Display', () => { + it('should display completed and total segments', () => { + render(<SegmentProgress completedSegments={50} totalSegments={100} percent={50} />) + expect(screen.getByText(/50\/100/)).toBeInTheDocument() + }) + + it('should display percent value', () => { + render(<SegmentProgress completedSegments={50} totalSegments={100} percent={50} />) + expect(screen.getByText(/50%/)).toBeInTheDocument() + }) + + it('should display 0/0 when segments are 0', () => { + render(<SegmentProgress completedSegments={0} totalSegments={0} percent={0} />) + expect(screen.getByText(/0\/0/)).toBeInTheDocument() + expect(screen.getByText(/0%/)).toBeInTheDocument() + }) + + it('should display 100% when completed', () => { + render(<SegmentProgress completedSegments={100} totalSegments={100} percent={100} />) + expect(screen.getByText(/100\/100/)).toBeInTheDocument() + expect(screen.getByText(/100%/)).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should display -- when completedSegments is undefined', () => { + render(<SegmentProgress totalSegments={100} percent={0} />) + expect(screen.getByText(/--\/100/)).toBeInTheDocument() + }) + + it('should display -- when totalSegments is undefined', () => { + render(<SegmentProgress completedSegments={50} percent={50} />) + expect(screen.getByText(/50\/--/)).toBeInTheDocument() + }) + + it('should display --/-- when both segments are undefined', () => { + render(<SegmentProgress percent={0} />) + expect(screen.getByText(/--\/--/)).toBeInTheDocument() + }) + + it('should handle large numbers', () => { + render(<SegmentProgress completedSegments={999999} totalSegments={1000000} percent={99} />) + expect(screen.getByText(/999999\/1000000/)).toBeInTheDocument() + }) + + it('should handle decimal percent', () => { + render(<SegmentProgress completedSegments={33} totalSegments={100} percent={33.33} />) + expect(screen.getByText(/33.33%/)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/embedding/components/segment-progress.tsx b/web/app/components/datasets/documents/detail/embedding/components/segment-progress.tsx new file mode 100644 index 0000000000..a76704391d --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/segment-progress.tsx @@ -0,0 +1,32 @@ +import type { FC } from 'react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' + +type SegmentProgressProps = { + completedSegments?: number + totalSegments?: number + percent: number +} + +const SegmentProgress: FC<SegmentProgressProps> = React.memo(({ + completedSegments, + totalSegments, + percent, +}) => { + const { t } = useTranslation() + + const completed = completedSegments ?? '--' + const total = totalSegments ?? '--' + + return ( + <div className="flex w-full items-center"> + <span className="system-xs-medium text-text-secondary"> + {`${t('embedding.segments', { ns: 'datasetDocuments' })} ${completed}/${total} ยท ${percent}%`} + </span> + </div> + ) +}) + +SegmentProgress.displayName = 'SegmentProgress' + +export default SegmentProgress diff --git a/web/app/components/datasets/documents/detail/embedding/components/status-header.spec.tsx b/web/app/components/datasets/documents/detail/embedding/components/status-header.spec.tsx new file mode 100644 index 0000000000..519d2f3aa8 --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/status-header.spec.tsx @@ -0,0 +1,155 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import StatusHeader from './status-header' + +describe('StatusHeader', () => { + const defaultProps = { + isEmbedding: false, + isCompleted: false, + isPaused: false, + isError: false, + onPause: vi.fn(), + onResume: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + const { container } = render(<StatusHeader {...defaultProps} />) + expect(container.firstChild).toBeInTheDocument() + }) + + it('should render with correct container classes', () => { + const { container } = render(<StatusHeader {...defaultProps} />) + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'h-6', 'items-center', 'gap-x-1') + }) + }) + + describe('Status Text', () => { + it('should display processing text when isEmbedding is true', () => { + render(<StatusHeader {...defaultProps} isEmbedding />) + expect(screen.getByText(/embedding\.processing/i)).toBeInTheDocument() + }) + + it('should display completed text when isCompleted is true', () => { + render(<StatusHeader {...defaultProps} isCompleted />) + expect(screen.getByText(/embedding\.completed/i)).toBeInTheDocument() + }) + + it('should display paused text when isPaused is true', () => { + render(<StatusHeader {...defaultProps} isPaused />) + expect(screen.getByText(/embedding\.paused/i)).toBeInTheDocument() + }) + + it('should display error text when isError is true', () => { + render(<StatusHeader {...defaultProps} isError />) + expect(screen.getByText(/embedding\.error/i)).toBeInTheDocument() + }) + + it('should display empty text when no status flags are set', () => { + render(<StatusHeader {...defaultProps} />) + const statusText = screen.getByText('', { selector: 'span.system-md-semibold-uppercase' }) + expect(statusText).toBeInTheDocument() + }) + }) + + describe('Loading Spinner', () => { + it('should show loading spinner when isEmbedding is true', () => { + const { container } = render(<StatusHeader {...defaultProps} isEmbedding />) + const spinner = container.querySelector('svg.animate-spin') + expect(spinner).toBeInTheDocument() + }) + + it('should not show loading spinner when isEmbedding is false', () => { + const { container } = render(<StatusHeader {...defaultProps} isEmbedding={false} />) + const spinner = container.querySelector('svg.animate-spin') + expect(spinner).not.toBeInTheDocument() + }) + }) + + describe('Pause Button', () => { + it('should show pause button when isEmbedding is true', () => { + render(<StatusHeader {...defaultProps} isEmbedding />) + expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText(/embedding\.pause/i)).toBeInTheDocument() + }) + + it('should not show pause button when isEmbedding is false', () => { + render(<StatusHeader {...defaultProps} isEmbedding={false} />) + expect(screen.queryByText(/embedding\.pause/i)).not.toBeInTheDocument() + }) + + it('should call onPause when pause button is clicked', () => { + const onPause = vi.fn() + render(<StatusHeader {...defaultProps} isEmbedding onPause={onPause} />) + fireEvent.click(screen.getByRole('button')) + expect(onPause).toHaveBeenCalledTimes(1) + }) + + it('should disable pause button when isPauseLoading is true', () => { + render(<StatusHeader {...defaultProps} isEmbedding isPauseLoading />) + expect(screen.getByRole('button')).toBeDisabled() + }) + }) + + describe('Resume Button', () => { + it('should show resume button when isPaused is true', () => { + render(<StatusHeader {...defaultProps} isPaused />) + expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText(/embedding\.resume/i)).toBeInTheDocument() + }) + + it('should not show resume button when isPaused is false', () => { + render(<StatusHeader {...defaultProps} isPaused={false} />) + expect(screen.queryByText(/embedding\.resume/i)).not.toBeInTheDocument() + }) + + it('should call onResume when resume button is clicked', () => { + const onResume = vi.fn() + render(<StatusHeader {...defaultProps} isPaused onResume={onResume} />) + fireEvent.click(screen.getByRole('button')) + expect(onResume).toHaveBeenCalledTimes(1) + }) + + it('should disable resume button when isResumeLoading is true', () => { + render(<StatusHeader {...defaultProps} isPaused isResumeLoading />) + expect(screen.getByRole('button')).toBeDisabled() + }) + }) + + describe('Button Styles', () => { + it('should have correct button styles for pause button', () => { + render(<StatusHeader {...defaultProps} isEmbedding />) + const button = screen.getByRole('button') + expect(button).toHaveClass('flex', 'items-center', 'gap-x-1', 'rounded-md') + }) + + it('should have correct button styles for resume button', () => { + render(<StatusHeader {...defaultProps} isPaused />) + const button = screen.getByRole('button') + expect(button).toHaveClass('flex', 'items-center', 'gap-x-1', 'rounded-md') + }) + }) + + describe('Edge Cases', () => { + it('should not show any buttons when isCompleted', () => { + render(<StatusHeader {...defaultProps} isCompleted />) + expect(screen.queryByRole('button')).not.toBeInTheDocument() + }) + + it('should not show any buttons when isError', () => { + render(<StatusHeader {...defaultProps} isError />) + expect(screen.queryByRole('button')).not.toBeInTheDocument() + }) + + it('should show both buttons when isEmbedding and isPaused are both true', () => { + render(<StatusHeader {...defaultProps} isEmbedding isPaused />) + const buttons = screen.getAllByRole('button') + expect(buttons.length).toBe(2) + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/embedding/components/status-header.tsx b/web/app/components/datasets/documents/detail/embedding/components/status-header.tsx new file mode 100644 index 0000000000..e72f0553b5 --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/components/status-header.tsx @@ -0,0 +1,84 @@ +import type { FC } from 'react' +import { RiLoader2Line, RiPauseCircleLine, RiPlayCircleLine } from '@remixicon/react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' + +type StatusHeaderProps = { + isEmbedding: boolean + isCompleted: boolean + isPaused: boolean + isError: boolean + onPause: () => void + onResume: () => void + isPauseLoading?: boolean + isResumeLoading?: boolean +} + +const StatusHeader: FC<StatusHeaderProps> = React.memo(({ + isEmbedding, + isCompleted, + isPaused, + isError, + onPause, + onResume, + isPauseLoading, + isResumeLoading, +}) => { + const { t } = useTranslation() + + const getStatusText = () => { + if (isEmbedding) + return t('embedding.processing', { ns: 'datasetDocuments' }) + if (isCompleted) + return t('embedding.completed', { ns: 'datasetDocuments' }) + if (isPaused) + return t('embedding.paused', { ns: 'datasetDocuments' }) + if (isError) + return t('embedding.error', { ns: 'datasetDocuments' }) + return '' + } + + const buttonBaseClass = `flex items-center gap-x-1 rounded-md border-[0.5px] + border-components-button-secondary-border bg-components-button-secondary-bg + px-1.5 py-1 shadow-xs shadow-shadow-shadow-3 backdrop-blur-[5px] + disabled:cursor-not-allowed disabled:opacity-50` + + return ( + <div className="flex h-6 items-center gap-x-1"> + {isEmbedding && <RiLoader2Line className="h-4 w-4 animate-spin text-text-secondary" />} + <span className="system-md-semibold-uppercase grow text-text-secondary"> + {getStatusText()} + </span> + {isEmbedding && ( + <button + type="button" + className={buttonBaseClass} + onClick={onPause} + disabled={isPauseLoading} + > + <RiPauseCircleLine className="h-3.5 w-3.5 text-components-button-secondary-text" /> + <span className="system-xs-medium pr-[3px] text-components-button-secondary-text"> + {t('embedding.pause', { ns: 'datasetDocuments' })} + </span> + </button> + )} + {isPaused && ( + <button + type="button" + className={buttonBaseClass} + onClick={onResume} + disabled={isResumeLoading} + > + <RiPlayCircleLine className="h-3.5 w-3.5 text-components-button-secondary-text" /> + <span className="system-xs-medium pr-[3px] text-components-button-secondary-text"> + {t('embedding.resume', { ns: 'datasetDocuments' })} + </span> + </button> + )} + </div> + ) +}) + +StatusHeader.displayName = 'StatusHeader' + +export default StatusHeader diff --git a/web/app/components/datasets/documents/detail/embedding/hooks/index.ts b/web/app/components/datasets/documents/detail/embedding/hooks/index.ts new file mode 100644 index 0000000000..603c16dda5 --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/hooks/index.ts @@ -0,0 +1,10 @@ +export { + calculatePercent, + isEmbeddingStatus, + isTerminalStatus, + useEmbeddingStatus, + useInvalidateEmbeddingStatus, + usePauseIndexing, + useResumeIndexing, +} from './use-embedding-status' +export type { EmbeddingStatusType } from './use-embedding-status' diff --git a/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.spec.tsx b/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.spec.tsx new file mode 100644 index 0000000000..7cadc12dfc --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.spec.tsx @@ -0,0 +1,462 @@ +import type { ReactNode } from 'react' +import type { IndexingStatusResponse } from '@/models/datasets' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import * as datasetsService from '@/service/datasets' +import { + calculatePercent, + isEmbeddingStatus, + isTerminalStatus, + useEmbeddingStatus, + useInvalidateEmbeddingStatus, + usePauseIndexing, + useResumeIndexing, +} from './use-embedding-status' + +vi.mock('@/service/datasets') + +const mockFetchIndexingStatus = vi.mocked(datasetsService.fetchIndexingStatus) +const mockPauseDocIndexing = vi.mocked(datasetsService.pauseDocIndexing) +const mockResumeDocIndexing = vi.mocked(datasetsService.resumeDocIndexing) + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, +}) + +const createWrapper = () => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + <QueryClientProvider client={queryClient}> + {children} + </QueryClientProvider> + ) +} + +const mockIndexingStatus = (overrides: Partial<IndexingStatusResponse> = {}): IndexingStatusResponse => ({ + id: 'doc1', + indexing_status: 'indexing', + completed_segments: 50, + total_segments: 100, + processing_started_at: 0, + parsing_completed_at: 0, + cleaning_completed_at: 0, + splitting_completed_at: 0, + completed_at: null, + paused_at: null, + error: null, + stopped_at: null, + ...overrides, +}) + +describe('use-embedding-status', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('isEmbeddingStatus', () => { + it('should return true for indexing status', () => { + expect(isEmbeddingStatus('indexing')).toBe(true) + }) + + it('should return true for splitting status', () => { + expect(isEmbeddingStatus('splitting')).toBe(true) + }) + + it('should return true for parsing status', () => { + expect(isEmbeddingStatus('parsing')).toBe(true) + }) + + it('should return true for cleaning status', () => { + expect(isEmbeddingStatus('cleaning')).toBe(true) + }) + + it('should return false for completed status', () => { + expect(isEmbeddingStatus('completed')).toBe(false) + }) + + it('should return false for paused status', () => { + expect(isEmbeddingStatus('paused')).toBe(false) + }) + + it('should return false for error status', () => { + expect(isEmbeddingStatus('error')).toBe(false) + }) + + it('should return false for undefined', () => { + expect(isEmbeddingStatus(undefined)).toBe(false) + }) + + it('should return false for empty string', () => { + expect(isEmbeddingStatus('')).toBe(false) + }) + }) + + describe('isTerminalStatus', () => { + it('should return true for completed status', () => { + expect(isTerminalStatus('completed')).toBe(true) + }) + + it('should return true for error status', () => { + expect(isTerminalStatus('error')).toBe(true) + }) + + it('should return true for paused status', () => { + expect(isTerminalStatus('paused')).toBe(true) + }) + + it('should return false for indexing status', () => { + expect(isTerminalStatus('indexing')).toBe(false) + }) + + it('should return false for undefined', () => { + expect(isTerminalStatus(undefined)).toBe(false) + }) + }) + + describe('calculatePercent', () => { + it('should calculate percent correctly', () => { + expect(calculatePercent(50, 100)).toBe(50) + }) + + it('should return 0 when total is 0', () => { + expect(calculatePercent(50, 0)).toBe(0) + }) + + it('should return 0 when total is undefined', () => { + expect(calculatePercent(50, undefined)).toBe(0) + }) + + it('should return 0 when completed is undefined', () => { + expect(calculatePercent(undefined, 100)).toBe(0) + }) + + it('should cap at 100 when percent exceeds 100', () => { + expect(calculatePercent(150, 100)).toBe(100) + }) + + it('should round to nearest integer', () => { + expect(calculatePercent(33, 100)).toBe(33) + expect(calculatePercent(1, 3)).toBe(33) + }) + }) + + describe('useEmbeddingStatus', () => { + it('should return initial state when disabled', () => { + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1', enabled: false }), + { wrapper: createWrapper() }, + ) + + expect(result.current.isEmbedding).toBe(false) + expect(result.current.isCompleted).toBe(false) + expect(result.current.isPaused).toBe(false) + expect(result.current.isError).toBe(false) + expect(result.current.percent).toBe(0) + }) + + it('should not fetch when datasetId is missing', () => { + renderHook( + () => useEmbeddingStatus({ documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + expect(mockFetchIndexingStatus).not.toHaveBeenCalled() + }) + + it('should not fetch when documentId is missing', () => { + renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1' }), + { wrapper: createWrapper() }, + ) + + expect(mockFetchIndexingStatus).not.toHaveBeenCalled() + }) + + it('should fetch indexing status when enabled with valid ids', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isEmbedding).toBe(true) + }) + + expect(mockFetchIndexingStatus).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + expect(result.current.percent).toBe(50) + }) + + it('should set isCompleted when status is completed', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ + indexing_status: 'completed', + completed_segments: 100, + })) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isCompleted).toBe(true) + }) + + expect(result.current.percent).toBe(100) + }) + + it('should set isPaused when status is paused', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ + indexing_status: 'paused', + })) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isPaused).toBe(true) + }) + }) + + it('should set isError when status is error', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ + indexing_status: 'error', + completed_segments: 25, + })) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isError).toBe(true) + }) + }) + + it('should provide invalidate function', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isEmbedding).toBe(true) + }) + + expect(typeof result.current.invalidate).toBe('function') + + // Call invalidate should not throw + await act(async () => { + result.current.invalidate() + }) + }) + + it('should provide resetStatus function that clears data', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.data).toBeDefined() + }) + + // Reset status should clear the data + await act(async () => { + result.current.resetStatus() + }) + + await waitFor(() => { + expect(result.current.data).toBeNull() + }) + }) + }) + + describe('usePauseIndexing', () => { + it('should call pauseDocIndexing when mutate is called', async () => { + mockPauseDocIndexing.mockResolvedValue({ result: 'success' }) + + const { result } = renderHook( + () => usePauseIndexing({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(mockPauseDocIndexing).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + + it('should call onSuccess callback on successful pause', async () => { + mockPauseDocIndexing.mockResolvedValue({ result: 'success' }) + const onSuccess = vi.fn() + + const { result } = renderHook( + () => usePauseIndexing({ datasetId: 'ds1', documentId: 'doc1', onSuccess }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(onSuccess).toHaveBeenCalled() + }) + }) + + it('should call onError callback on failed pause', async () => { + const error = new Error('Network error') + mockPauseDocIndexing.mockRejectedValue(error) + const onError = vi.fn() + + const { result } = renderHook( + () => usePauseIndexing({ datasetId: 'ds1', documentId: 'doc1', onError }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(onError).toHaveBeenCalled() + expect(onError.mock.calls[0][0]).toEqual(error) + }) + }) + }) + + describe('useResumeIndexing', () => { + it('should call resumeDocIndexing when mutate is called', async () => { + mockResumeDocIndexing.mockResolvedValue({ result: 'success' }) + + const { result } = renderHook( + () => useResumeIndexing({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(mockResumeDocIndexing).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + + it('should call onSuccess callback on successful resume', async () => { + mockResumeDocIndexing.mockResolvedValue({ result: 'success' }) + const onSuccess = vi.fn() + + const { result } = renderHook( + () => useResumeIndexing({ datasetId: 'ds1', documentId: 'doc1', onSuccess }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(onSuccess).toHaveBeenCalled() + }) + }) + }) + + describe('useInvalidateEmbeddingStatus', () => { + it('should return a function', () => { + const { result } = renderHook( + () => useInvalidateEmbeddingStatus(), + { wrapper: createWrapper() }, + ) + + expect(typeof result.current).toBe('function') + }) + + it('should invalidate specific query when datasetId and documentId are provided', async () => { + const queryClient = createTestQueryClient() + const wrapper = ({ children }: { children: ReactNode }) => ( + <QueryClientProvider client={queryClient}> + {children} + </QueryClientProvider> + ) + + // Set some initial data in the cache + queryClient.setQueryData(['embedding', 'indexing-status', 'ds1', 'doc1'], { + id: 'doc1', + indexing_status: 'indexing', + }) + + const { result } = renderHook( + () => useInvalidateEmbeddingStatus(), + { wrapper }, + ) + + await act(async () => { + result.current('ds1', 'doc1') + }) + + // The query should be invalidated (marked as stale) + const queryState = queryClient.getQueryState(['embedding', 'indexing-status', 'ds1', 'doc1']) + expect(queryState?.isInvalidated).toBe(true) + }) + + it('should invalidate all embedding status queries when ids are not provided', async () => { + const queryClient = createTestQueryClient() + const wrapper = ({ children }: { children: ReactNode }) => ( + <QueryClientProvider client={queryClient}> + {children} + </QueryClientProvider> + ) + + // Set some initial data in the cache for multiple documents + queryClient.setQueryData(['embedding', 'indexing-status', 'ds1', 'doc1'], { + id: 'doc1', + indexing_status: 'indexing', + }) + queryClient.setQueryData(['embedding', 'indexing-status', 'ds2', 'doc2'], { + id: 'doc2', + indexing_status: 'completed', + }) + + const { result } = renderHook( + () => useInvalidateEmbeddingStatus(), + { wrapper }, + ) + + await act(async () => { + result.current() + }) + + // Both queries should be invalidated + const queryState1 = queryClient.getQueryState(['embedding', 'indexing-status', 'ds1', 'doc1']) + const queryState2 = queryClient.getQueryState(['embedding', 'indexing-status', 'ds2', 'doc2']) + expect(queryState1?.isInvalidated).toBe(true) + expect(queryState2?.isInvalidated).toBe(true) + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.ts b/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.ts new file mode 100644 index 0000000000..e55cd8f9aa --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.ts @@ -0,0 +1,149 @@ +import type { CommonResponse } from '@/models/common' +import type { IndexingStatusResponse } from '@/models/datasets' +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' +import { useCallback, useEffect, useMemo, useRef } from 'react' +import { + fetchIndexingStatus, + pauseDocIndexing, + resumeDocIndexing, +} from '@/service/datasets' + +const NAME_SPACE = 'embedding' + +export type EmbeddingStatusType = 'indexing' | 'splitting' | 'parsing' | 'cleaning' | 'completed' | 'paused' | 'error' | 'waiting' | '' + +const EMBEDDING_STATUSES = ['indexing', 'splitting', 'parsing', 'cleaning'] as const +const TERMINAL_STATUSES = ['completed', 'error', 'paused'] as const + +export const isEmbeddingStatus = (status?: string): boolean => { + return EMBEDDING_STATUSES.includes(status as typeof EMBEDDING_STATUSES[number]) +} + +export const isTerminalStatus = (status?: string): boolean => { + return TERMINAL_STATUSES.includes(status as typeof TERMINAL_STATUSES[number]) +} + +export const calculatePercent = (completed?: number, total?: number): number => { + if (!total || total === 0) + return 0 + const percent = Math.round((completed || 0) * 100 / total) + return Math.min(percent, 100) +} + +type UseEmbeddingStatusOptions = { + datasetId?: string + documentId?: string + enabled?: boolean + onComplete?: () => void +} + +export const useEmbeddingStatus = ({ + datasetId, + documentId, + enabled = true, + onComplete, +}: UseEmbeddingStatusOptions) => { + const queryClient = useQueryClient() + const isPolling = useRef(false) + const onCompleteRef = useRef(onComplete) + onCompleteRef.current = onComplete + + const queryKey = useMemo( + () => [NAME_SPACE, 'indexing-status', datasetId, documentId] as const, + [datasetId, documentId], + ) + + const query = useQuery<IndexingStatusResponse>({ + queryKey, + queryFn: () => fetchIndexingStatus({ datasetId: datasetId!, documentId: documentId! }), + enabled: enabled && !!datasetId && !!documentId, + refetchInterval: (query) => { + const status = query.state.data?.indexing_status + if (isTerminalStatus(status)) { + return false + } + return 2500 + }, + refetchOnWindowFocus: false, + }) + + const status = query.data?.indexing_status || '' + const isEmbedding = isEmbeddingStatus(status) + const isCompleted = status === 'completed' + const isPaused = status === 'paused' + const isError = status === 'error' + const percent = calculatePercent(query.data?.completed_segments, query.data?.total_segments) + + // Handle completion callback + useEffect(() => { + if (isTerminalStatus(status) && isPolling.current) { + isPolling.current = false + onCompleteRef.current?.() + } + if (isEmbedding) { + isPolling.current = true + } + }, [status, isEmbedding]) + + const invalidate = useCallback(() => { + queryClient.invalidateQueries({ queryKey }) + }, [queryClient, queryKey]) + + const resetStatus = useCallback(() => { + queryClient.setQueryData(queryKey, null) + }, [queryClient, queryKey]) + + return { + data: query.data, + isLoading: query.isLoading, + isEmbedding, + isCompleted, + isPaused, + isError, + percent, + invalidate, + resetStatus, + refetch: query.refetch, + } +} + +type UsePauseResumeOptions = { + datasetId?: string + documentId?: string + onSuccess?: () => void + onError?: (error: Error) => void +} + +export const usePauseIndexing = ({ datasetId, documentId, onSuccess, onError }: UsePauseResumeOptions) => { + return useMutation<CommonResponse, Error>({ + mutationKey: [NAME_SPACE, 'pause', datasetId, documentId], + mutationFn: () => pauseDocIndexing({ datasetId: datasetId!, documentId: documentId! }), + onSuccess, + onError, + }) +} + +export const useResumeIndexing = ({ datasetId, documentId, onSuccess, onError }: UsePauseResumeOptions) => { + return useMutation<CommonResponse, Error>({ + mutationKey: [NAME_SPACE, 'resume', datasetId, documentId], + mutationFn: () => resumeDocIndexing({ datasetId: datasetId!, documentId: documentId! }), + onSuccess, + onError, + }) +} + +export const useInvalidateEmbeddingStatus = () => { + const queryClient = useQueryClient() + return useCallback((datasetId?: string, documentId?: string) => { + if (datasetId && documentId) { + queryClient.invalidateQueries({ + queryKey: [NAME_SPACE, 'indexing-status', datasetId, documentId], + }) + } + else { + queryClient.invalidateQueries({ + queryKey: [NAME_SPACE, 'indexing-status'], + }) + } + }, [queryClient]) +} diff --git a/web/app/components/datasets/documents/detail/embedding/index.spec.tsx b/web/app/components/datasets/documents/detail/embedding/index.spec.tsx new file mode 100644 index 0000000000..699de4f12a --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/index.spec.tsx @@ -0,0 +1,337 @@ +import type { ReactNode } from 'react' +import type { DocumentContextValue } from '../context' +import type { IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { ProcessMode } from '@/models/datasets' +import * as datasetsService from '@/service/datasets' +import * as useDataset from '@/service/knowledge/use-dataset' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../../create/step-two' +import { DocumentContext } from '../context' +import EmbeddingDetail from './index' + +vi.mock('@/service/datasets') +vi.mock('@/service/knowledge/use-dataset') + +const mockFetchIndexingStatus = vi.mocked(datasetsService.fetchIndexingStatus) +const mockPauseDocIndexing = vi.mocked(datasetsService.pauseDocIndexing) +const mockResumeDocIndexing = vi.mocked(datasetsService.resumeDocIndexing) +const mockUseProcessRule = vi.mocked(useDataset.useProcessRule) + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false, gcTime: 0 }, + mutations: { retry: false }, + }, +}) + +const createWrapper = (contextValue: DocumentContextValue = { datasetId: 'ds1', documentId: 'doc1' }) => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + <QueryClientProvider client={queryClient}> + <DocumentContext.Provider value={contextValue}> + {children} + </DocumentContext.Provider> + </QueryClientProvider> + ) +} + +const mockIndexingStatus = (overrides: Partial<IndexingStatusResponse> = {}): IndexingStatusResponse => ({ + id: 'doc1', + indexing_status: 'indexing', + completed_segments: 50, + total_segments: 100, + processing_started_at: Date.now(), + parsing_completed_at: 0, + cleaning_completed_at: 0, + splitting_completed_at: 0, + completed_at: null, + paused_at: null, + error: null, + stopped_at: null, + ...overrides, +}) + +const mockProcessRule = (overrides: Partial<ProcessRuleResponse> = {}): ProcessRuleResponse => ({ + mode: ProcessMode.general, + rules: { + segmentation: { separator: '\n', max_tokens: 500, chunk_overlap: 50 }, + pre_processing_rules: [{ id: 'remove_extra_spaces', enabled: true }], + parent_mode: 'full-doc', + subchunk_segmentation: { separator: '\n', max_tokens: 200, chunk_overlap: 20 }, + }, + limits: { indexing_max_segmentation_tokens_length: 4000 }, + ...overrides, +}) + +describe('EmbeddingDetail', () => { + const defaultProps = { + detailUpdate: vi.fn(), + indexingType: IndexingType.QUALIFIED, + retrievalMethod: RETRIEVE_METHOD.semantic, + } + + beforeEach(() => { + vi.clearAllMocks() + + mockUseProcessRule.mockReturnValue({ + data: mockProcessRule(), + isLoading: false, + error: null, + } as ReturnType<typeof useDataset.useProcessRule>) + }) + + describe('Rendering', () => { + it('should render without crashing', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.processing/i)).toBeInTheDocument() + }) + }) + + it('should render with provided datasetId and documentId props', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render( + <EmbeddingDetail {...defaultProps} datasetId="custom-ds" documentId="custom-doc" />, + { wrapper: createWrapper({ datasetId: '', documentId: '' }) }, + ) + + await waitFor(() => { + expect(mockFetchIndexingStatus).toHaveBeenCalledWith({ + datasetId: 'custom-ds', + documentId: 'custom-doc', + }) + }) + }) + + it('should fall back to context values when props are not provided', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(mockFetchIndexingStatus).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + }) + + describe('Status Display', () => { + it('should show processing status when indexing', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'indexing' })) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.processing/i)).toBeInTheDocument() + }) + }) + + it('should show completed status', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'completed' })) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.completed/i)).toBeInTheDocument() + }) + }) + + it('should show paused status', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'paused' })) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.paused/i)).toBeInTheDocument() + }) + }) + + it('should show error status', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'error' })) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.error/i)).toBeInTheDocument() + }) + }) + }) + + describe('Progress Display', () => { + it('should display segment progress', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ + completed_segments: 50, + total_segments: 100, + })) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/50\/100/)).toBeInTheDocument() + expect(screen.getByText(/50%/)).toBeInTheDocument() + }) + }) + }) + + describe('Pause/Resume Actions', () => { + it('should show pause button when embedding is in progress', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'indexing' })) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.pause/i)).toBeInTheDocument() + }) + }) + + it('should show resume button when paused', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'paused' })) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.resume/i)).toBeInTheDocument() + }) + }) + + it('should call pause API when pause button is clicked', async () => { + const user = userEvent.setup() + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'indexing' })) + mockPauseDocIndexing.mockResolvedValue({ result: 'success' }) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.pause/i)).toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: /pause/i })) + + await waitFor(() => { + expect(mockPauseDocIndexing).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + + it('should call resume API when resume button is clicked', async () => { + const user = userEvent.setup() + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'paused' })) + mockResumeDocIndexing.mockResolvedValue({ result: 'success' }) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.resume/i)).toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: /resume/i })) + + await waitFor(() => { + expect(mockResumeDocIndexing).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + }) + + describe('Rule Detail', () => { + it('should display rule detail section', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/stepTwo\.indexMode/i)).toBeInTheDocument() + }) + }) + + it('should display qualified index mode', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render( + <EmbeddingDetail {...defaultProps} indexingType={IndexingType.QUALIFIED} />, + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(screen.getByText(/stepTwo\.qualified/i)).toBeInTheDocument() + }) + }) + + it('should display economical index mode', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render( + <EmbeddingDetail {...defaultProps} indexingType={IndexingType.ECONOMICAL} />, + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(screen.getByText(/stepTwo\.economical/i)).toBeInTheDocument() + }) + }) + }) + + describe('detailUpdate Callback', () => { + it('should call detailUpdate when status becomes terminal', async () => { + const detailUpdate = vi.fn() + // First call returns indexing, subsequent call returns completed + mockFetchIndexingStatus + .mockResolvedValueOnce(mockIndexingStatus({ indexing_status: 'indexing' })) + .mockResolvedValueOnce(mockIndexingStatus({ indexing_status: 'completed' })) + + render( + <EmbeddingDetail {...defaultProps} detailUpdate={detailUpdate} />, + { wrapper: createWrapper() }, + ) + + // Wait for the terminal status to trigger detailUpdate + await waitFor(() => { + expect(mockFetchIndexingStatus).toHaveBeenCalled() + }, { timeout: 5000 }) + }) + }) + + describe('Edge Cases', () => { + it('should handle missing context values', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render( + <EmbeddingDetail {...defaultProps} datasetId="explicit-ds" documentId="explicit-doc" />, + { wrapper: createWrapper({ datasetId: undefined, documentId: undefined }) }, + ) + + await waitFor(() => { + expect(mockFetchIndexingStatus).toHaveBeenCalledWith({ + datasetId: 'explicit-ds', + documentId: 'explicit-doc', + }) + }) + }) + + it('should render skeleton component', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + const { container } = render(<EmbeddingDetail {...defaultProps} />, { wrapper: createWrapper() }) + + // EmbeddingSkeleton should be rendered - check for the skeleton wrapper element + await waitFor(() => { + const skeletonWrapper = container.querySelector('.bg-dataset-chunk-list-mask-bg') + expect(skeletonWrapper).toBeInTheDocument() + }) + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/embedding/index.tsx b/web/app/components/datasets/documents/detail/embedding/index.tsx index 37b5bb85e7..e89a85c6de 100644 --- a/web/app/components/datasets/documents/detail/embedding/index.tsx +++ b/web/app/components/datasets/documents/detail/embedding/index.tsx @@ -1,31 +1,18 @@ import type { FC } from 'react' -import type { CommonResponse } from '@/models/common' -import type { IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' -import { RiLoader2Line, RiPauseCircleLine, RiPlayCircleLine } from '@remixicon/react' -import Image from 'next/image' +import type { IndexingType } from '../../../create/step-two' +import type { RETRIEVE_METHOD } from '@/types/app' import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import Divider from '@/app/components/base/divider' import { ToastContext } from '@/app/components/base/toast' -import { ProcessMode } from '@/models/datasets' -import { - fetchIndexingStatus as doFetchIndexingStatus, - pauseDocIndexing, - resumeDocIndexing, -} from '@/service/datasets' import { useProcessRule } from '@/service/knowledge/use-dataset' -import { RETRIEVE_METHOD } from '@/types/app' -import { asyncRunSafe, sleep } from '@/utils' -import { cn } from '@/utils/classnames' -import { indexMethodIcon, retrievalIcon } from '../../../create/icons' -import { IndexingType } from '../../../create/step-two' import { useDocumentContext } from '../context' -import { FieldInfo } from '../metadata' +import { ProgressBar, RuleDetail, SegmentProgress, StatusHeader } from './components' +import { useEmbeddingStatus, usePauseIndexing, useResumeIndexing } from './hooks' import EmbeddingSkeleton from './skeleton' -type IEmbeddingDetailProps = { +type EmbeddingDetailProps = { datasetId?: string documentId?: string indexingType?: IndexingType @@ -33,128 +20,7 @@ type IEmbeddingDetailProps = { detailUpdate: VoidFunction } -type IRuleDetailProps = { - sourceData?: ProcessRuleResponse - indexingType?: IndexingType - retrievalMethod?: RETRIEVE_METHOD -} - -const RuleDetail: FC<IRuleDetailProps> = React.memo(({ - sourceData, - indexingType, - retrievalMethod, -}) => { - const { t } = useTranslation() - - const segmentationRuleMap = { - mode: t('embedding.mode', { ns: 'datasetDocuments' }), - segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }), - textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }), - } - - const getRuleName = (key: string) => { - if (key === 'remove_extra_spaces') - return t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }) - - if (key === 'remove_urls_emails') - return t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }) - - if (key === 'remove_stopwords') - return t('stepTwo.removeStopwords', { ns: 'datasetCreation' }) - } - - const isNumber = (value: unknown) => { - return typeof value === 'number' - } - - const getValue = useCallback((field: string) => { - let value: string | number | undefined = '-' - const maxTokens = isNumber(sourceData?.rules?.segmentation?.max_tokens) - ? sourceData.rules.segmentation.max_tokens - : value - const childMaxTokens = isNumber(sourceData?.rules?.subchunk_segmentation?.max_tokens) - ? sourceData.rules.subchunk_segmentation.max_tokens - : value - switch (field) { - case 'mode': - value = !sourceData?.mode - ? value - : sourceData.mode === ProcessMode.general - ? (t('embedding.custom', { ns: 'datasetDocuments' }) as string) - : `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} ยท ${sourceData?.rules?.parent_mode === 'paragraph' - ? t('parentMode.paragraph', { ns: 'dataset' }) - : t('parentMode.fullDoc', { ns: 'dataset' })}` - break - case 'segmentLength': - value = !sourceData?.mode - ? value - : sourceData.mode === ProcessMode.general - ? maxTokens - : `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}` - break - default: - value = !sourceData?.mode - ? value - : sourceData?.rules?.pre_processing_rules?.filter(rule => - rule.enabled).map(rule => getRuleName(rule.id)).join(',') - break - } - return value - }, [sourceData]) - - return ( - <div className="py-3"> - <div className="flex flex-col gap-y-1"> - {Object.keys(segmentationRuleMap).map((field) => { - return ( - <FieldInfo - key={field} - label={segmentationRuleMap[field as keyof typeof segmentationRuleMap]} - displayedValue={String(getValue(field))} - /> - ) - })} - </div> - <Divider type="horizontal" className="bg-divider-subtle" /> - <FieldInfo - label={t('stepTwo.indexMode', { ns: 'datasetCreation' })} - displayedValue={t(`stepTwo.${indexingType === IndexingType.ECONOMICAL ? 'economical' : 'qualified'}`, { ns: 'datasetCreation' }) as string} - valueIcon={( - <Image - className="size-4" - src={ - indexingType === IndexingType.ECONOMICAL - ? indexMethodIcon.economical - : indexMethodIcon.high_quality - } - alt="" - /> - )} - /> - <FieldInfo - label={t('form.retrievalSetting.title', { ns: 'datasetSettings' })} - displayedValue={t(`retrieval.${indexingType === IndexingType.ECONOMICAL ? 'keyword_search' : retrievalMethod ?? 'semantic_search'}.title`, { ns: 'dataset' })} - valueIcon={( - <Image - className="size-4" - src={ - retrievalMethod === RETRIEVE_METHOD.fullText - ? retrievalIcon.fullText - : retrievalMethod === RETRIEVE_METHOD.hybrid - ? retrievalIcon.hybrid - : retrievalIcon.vector - } - alt="" - /> - )} - /> - </div> - ) -}) - -RuleDetail.displayName = 'RuleDetail' - -const EmbeddingDetail: FC<IEmbeddingDetailProps> = ({ +const EmbeddingDetail: FC<EmbeddingDetailProps> = ({ datasetId: dstId, documentId: docId, detailUpdate, @@ -164,144 +30,95 @@ const EmbeddingDetail: FC<IEmbeddingDetailProps> = ({ const { t } = useTranslation() const { notify } = useContext(ToastContext) - const datasetId = useDocumentContext(s => s.datasetId) - const documentId = useDocumentContext(s => s.documentId) - const localDatasetId = dstId ?? datasetId - const localDocumentId = docId ?? documentId + const contextDatasetId = useDocumentContext(s => s.datasetId) + const contextDocumentId = useDocumentContext(s => s.documentId) + const datasetId = dstId ?? contextDatasetId + const documentId = docId ?? contextDocumentId - const [indexingStatusDetail, setIndexingStatusDetail] = useState<IndexingStatusResponse | null>(null) - const fetchIndexingStatus = async () => { - const status = await doFetchIndexingStatus({ datasetId: localDatasetId, documentId: localDocumentId }) - setIndexingStatusDetail(status) - return status - } + const { + data: indexingStatus, + isEmbedding, + isCompleted, + isPaused, + isError, + percent, + resetStatus, + refetch, + } = useEmbeddingStatus({ + datasetId, + documentId, + onComplete: detailUpdate, + }) - const isStopQuery = useRef(false) - const stopQueryStatus = useCallback(() => { - isStopQuery.current = true - }, []) + const { data: ruleDetail } = useProcessRule(documentId) - const startQueryStatus = useCallback(async () => { - if (isStopQuery.current) - return + const handleSuccess = useCallback(() => { + notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + }, [notify, t]) - try { - const indexingStatusDetail = await fetchIndexingStatus() - if (['completed', 'error', 'paused'].includes(indexingStatusDetail?.indexing_status)) { - stopQueryStatus() - detailUpdate() - return - } + const handleError = useCallback(() => { + notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + }, [notify, t]) - await sleep(2500) - await startQueryStatus() - } - catch { - await sleep(2500) - await startQueryStatus() - } - }, [stopQueryStatus]) + const pauseMutation = usePauseIndexing({ + datasetId, + documentId, + onSuccess: () => { + handleSuccess() + resetStatus() + }, + onError: handleError, + }) - useEffect(() => { - isStopQuery.current = false - startQueryStatus() - return () => { - stopQueryStatus() - } - }, [startQueryStatus, stopQueryStatus]) + const resumeMutation = useResumeIndexing({ + datasetId, + documentId, + onSuccess: () => { + handleSuccess() + refetch() + detailUpdate() + }, + onError: handleError, + }) - const { data: ruleDetail } = useProcessRule(localDocumentId) + const handlePause = useCallback(() => { + pauseMutation.mutate() + }, [pauseMutation]) - const isEmbedding = useMemo(() => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || ''), [indexingStatusDetail]) - const isEmbeddingCompleted = useMemo(() => ['completed'].includes(indexingStatusDetail?.indexing_status || ''), [indexingStatusDetail]) - const isEmbeddingPaused = useMemo(() => ['paused'].includes(indexingStatusDetail?.indexing_status || ''), [indexingStatusDetail]) - const isEmbeddingError = useMemo(() => ['error'].includes(indexingStatusDetail?.indexing_status || ''), [indexingStatusDetail]) - const percent = useMemo(() => { - const completedCount = indexingStatusDetail?.completed_segments || 0 - const totalCount = indexingStatusDetail?.total_segments || 0 - if (totalCount === 0) - return 0 - const percent = Math.round(completedCount * 100 / totalCount) - return percent > 100 ? 100 : percent - }, [indexingStatusDetail]) - - const handleSwitch = async () => { - const opApi = isEmbedding ? pauseDocIndexing : resumeDocIndexing - const [e] = await asyncRunSafe<CommonResponse>(opApi({ datasetId: localDatasetId, documentId: localDocumentId }) as Promise<CommonResponse>) - if (!e) { - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) - // if the embedding is resumed from paused, we need to start the query status - if (isEmbeddingPaused) { - isStopQuery.current = false - startQueryStatus() - detailUpdate() - } - setIndexingStatusDetail(null) - } - else { - notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) - } - } + const handleResume = useCallback(() => { + resumeMutation.mutate() + }, [resumeMutation]) return ( <> <div className="flex flex-col gap-y-2 px-16 py-12"> - <div className="flex h-6 items-center gap-x-1"> - {isEmbedding && <RiLoader2Line className="h-4 w-4 animate-spin text-text-secondary" />} - <span className="system-md-semibold-uppercase grow text-text-secondary"> - {isEmbedding && t('embedding.processing', { ns: 'datasetDocuments' })} - {isEmbeddingCompleted && t('embedding.completed', { ns: 'datasetDocuments' })} - {isEmbeddingPaused && t('embedding.paused', { ns: 'datasetDocuments' })} - {isEmbeddingError && t('embedding.error', { ns: 'datasetDocuments' })} - </span> - {isEmbedding && ( - <button - type="button" - className={`flex items-center gap-x-1 rounded-md border-[0.5px] - border-components-button-secondary-border bg-components-button-secondary-bg px-1.5 py-1 shadow-xs shadow-shadow-shadow-3 backdrop-blur-[5px]`} - onClick={handleSwitch} - > - <RiPauseCircleLine className="h-3.5 w-3.5 text-components-button-secondary-text" /> - <span className="system-xs-medium pr-[3px] text-components-button-secondary-text"> - {t('embedding.pause', { ns: 'datasetDocuments' })} - </span> - </button> - )} - {isEmbeddingPaused && ( - <button - type="button" - className={`flex items-center gap-x-1 rounded-md border-[0.5px] - border-components-button-secondary-border bg-components-button-secondary-bg px-1.5 py-1 shadow-xs shadow-shadow-shadow-3 backdrop-blur-[5px]`} - onClick={handleSwitch} - > - <RiPlayCircleLine className="h-3.5 w-3.5 text-components-button-secondary-text" /> - <span className="system-xs-medium pr-[3px] text-components-button-secondary-text"> - {t('embedding.resume', { ns: 'datasetDocuments' })} - </span> - </button> - )} - </div> - {/* progress bar */} - <div className={cn( - 'flex h-2 w-full items-center overflow-hidden rounded-md border border-components-progress-bar-border', - isEmbedding ? 'bg-components-progress-bar-bg/50' : 'bg-components-progress-bar-bg', - )} - > - <div - className={cn( - 'h-full', - (isEmbedding || isEmbeddingCompleted) && 'bg-components-progress-bar-progress-solid', - (isEmbeddingPaused || isEmbeddingError) && 'bg-components-progress-bar-progress-highlight', - )} - style={{ width: `${percent}%` }} - /> - </div> - <div className="flex w-full items-center"> - <span className="system-xs-medium text-text-secondary"> - {`${t('embedding.segments', { ns: 'datasetDocuments' })} ${indexingStatusDetail?.completed_segments || '--'}/${indexingStatusDetail?.total_segments || '--'} ยท ${percent}%`} - </span> - </div> - <RuleDetail sourceData={ruleDetail} indexingType={indexingType} retrievalMethod={retrievalMethod} /> + <StatusHeader + isEmbedding={isEmbedding} + isCompleted={isCompleted} + isPaused={isPaused} + isError={isError} + onPause={handlePause} + onResume={handleResume} + isPauseLoading={pauseMutation.isPending} + isResumeLoading={resumeMutation.isPending} + /> + <ProgressBar + percent={percent} + isEmbedding={isEmbedding} + isCompleted={isCompleted} + isPaused={isPaused} + isError={isError} + /> + <SegmentProgress + completedSegments={indexingStatus?.completed_segments} + totalSegments={indexingStatus?.total_segments} + percent={percent} + /> + <RuleDetail + sourceData={ruleDetail} + indexingType={indexingType} + retrievalMethod={retrievalMethod} + /> </div> <EmbeddingSkeleton /> </> diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-header.spec.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-header.spec.tsx index 6a0e3693a3..c7121287b3 100644 --- a/web/app/components/datasets/list/dataset-card/components/dataset-card-header.spec.tsx +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-header.spec.tsx @@ -6,6 +6,13 @@ import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datase import { RETRIEVE_METHOD } from '@/types/app' import DatasetCardHeader from './dataset-card-header' +// Mock AppIcon component to avoid emoji-mart initialization issues +vi.mock('@/app/components/base/app-icon', () => ({ + default: ({ icon, className }: { icon?: string, className?: string }) => ( + <div data-testid="app-icon" className={className}>{icon}</div> + ), +})) + // Mock useFormatTimeFromNow hook vi.mock('@/hooks/use-format-time-from-now', () => ({ useFormatTimeFromNow: () => ({ diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.spec.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.spec.tsx index ebee72159e..607830661d 100644 --- a/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.spec.tsx +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.spec.tsx @@ -19,6 +19,28 @@ vi.mock('../../../rename-modal', () => ({ ), })) +// Mock Confirm component since it uses createPortal which can cause issues in tests +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ isShow, title, content, onConfirm, onCancel }: { + isShow: boolean + title: string + content?: React.ReactNode + onConfirm: () => void + onCancel: () => void + }) => ( + isShow + ? ( + <div data-testid="confirm-modal"> + <div data-testid="confirm-title">{title}</div> + <div data-testid="confirm-content">{content}</div> + <button onClick={onCancel} role="button" aria-label="cancel">Cancel</button> + <button onClick={onConfirm} role="button" aria-label="confirm">Confirm</button> + </div> + ) + : null + ), +})) + describe('DatasetCardModals', () => { const mockDataset: DataSet = { id: 'dataset-1', @@ -172,11 +194,9 @@ describe('DatasetCardModals', () => { />, ) - // Find and click the confirm button - const confirmButton = screen.getByRole('button', { name: /confirm|ok|delete/i }) - || screen.getAllByRole('button').find(btn => btn.textContent?.toLowerCase().includes('confirm')) - if (confirmButton) - fireEvent.click(confirmButton) + // Find and click the confirm button using our mocked Confirm component + const confirmButton = screen.getByRole('button', { name: /confirm/i }) + fireEvent.click(confirmButton) expect(onConfirmDelete).toHaveBeenCalledTimes(1) }) diff --git a/web/app/components/goto-anything/index.spec.tsx b/web/app/components/goto-anything/index.spec.tsx index 7fb45726e8..6a6143a6e2 100644 --- a/web/app/components/goto-anything/index.spec.tsx +++ b/web/app/components/goto-anything/index.spec.tsx @@ -70,6 +70,10 @@ vi.mock('./context', () => ({ GotoAnythingProvider: ({ children }: { children: React.ReactNode }) => <>{children}</>, })) +vi.mock('@/app/components/workflow/utils', () => ({ + getKeyboardKeyNameBySystem: (key: string) => key, +})) + const createActionItem = (key: ActionItem['key'], shortcut: string): ActionItem => ({ key, shortcut, diff --git a/web/app/components/rag-pipeline/components/panel/index.spec.tsx b/web/app/components/rag-pipeline/components/panel/index.spec.tsx index 97229aa443..11f9f8b2c4 100644 --- a/web/app/components/rag-pipeline/components/panel/index.spec.tsx +++ b/web/app/components/rag-pipeline/components/panel/index.spec.tsx @@ -7,47 +7,72 @@ import RagPipelinePanel from './index' // Mock External Dependencies // ============================================================================ -// Type definitions for dynamic module -type DynamicModule = { - default?: React.ComponentType<Record<string, unknown>> -} +// Mock reactflow to avoid zustand provider error +vi.mock('reactflow', () => ({ + useNodes: () => [], + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => [], + }), + }), + useReactFlow: () => ({ + getNodes: () => [], + }), + useStore: (selector: (state: Record<string, unknown>) => unknown) => { + const state = { + getNodes: () => [], + } + return selector(state) + }, +})) -type PromiseOrModule = Promise<DynamicModule> | DynamicModule +// Use vi.hoisted to create variables that can be used in vi.mock +const { dynamicMocks, mockInputFieldEditorProps } = vi.hoisted(() => { + let counter = 0 + const mockInputFieldEditorProps = vi.fn() -// Mock next/dynamic to return synchronous components immediately + const createMockComponent = () => { + const index = counter++ + // Order matches the imports in index.tsx: + // 0: Record + // 1: TestRunPanel + // 2: InputFieldPanel + // 3: InputFieldEditorPanel + // 4: PreviewPanel + // 5: GlobalVariablePanel + switch (index) { + case 0: + return () => <div data-testid="record-panel">Record Panel</div> + case 1: + return () => <div data-testid="test-run-panel">Test Run Panel</div> + case 2: + return () => <div data-testid="input-field-panel">Input Field Panel</div> + case 3: + return (props: Record<string, unknown>) => { + mockInputFieldEditorProps(props) + return <div data-testid="input-field-editor-panel">Input Field Editor Panel</div> + } + case 4: + return () => <div data-testid="preview-panel">Preview Panel</div> + case 5: + return () => <div data-testid="global-variable-panel">Global Variable Panel</div> + default: + return () => ( + <div data-testid="dynamic-fallback"> + Dynamic Component + {index} + </div> + ) + } + } + + return { dynamicMocks: { createMockComponent }, mockInputFieldEditorProps } +}) + +// Mock next/dynamic vi.mock('next/dynamic', () => ({ - default: (loader: () => PromiseOrModule, _options?: Record<string, unknown>) => { - let Component: React.ComponentType<Record<string, unknown>> | null = null - - // Try to resolve the loader synchronously for mocked modules - try { - const result = loader() as PromiseOrModule - if (result && typeof (result as Promise<DynamicModule>).then === 'function') { - // For async modules, we need to handle them specially - // This will work with vi.mock since mocks resolve synchronously - (result as Promise<DynamicModule>).then((mod: DynamicModule) => { - Component = (mod.default || mod) as React.ComponentType<Record<string, unknown>> - }) - } - else if (result) { - Component = ((result as DynamicModule).default || result) as React.ComponentType<Record<string, unknown>> - } - } - catch { - // If the module can't be resolved, Component stays null - } - - // Return a simple wrapper that renders the component or null - const DynamicComponent = React.forwardRef((props: Record<string, unknown>, ref: React.Ref<unknown>) => { - // For mocked modules, Component should already be set - if (Component) - return <Component {...props} ref={ref} /> - - return null - }) - - DynamicComponent.displayName = 'DynamicComponent' - return DynamicComponent + default: (_loader: () => Promise<{ default: React.ComponentType }>, _options?: Record<string, unknown>) => { + return dynamicMocks.createMockComponent() }, })) @@ -68,6 +93,28 @@ type MockStoreState = { showInputFieldPreviewPanel: boolean inputFieldEditPanelProps: Record<string, unknown> | null pipelineId: string + nodePanelWidth: number + workflowCanvasWidth: number + otherPanelWidth: number + setShowInputFieldPanel?: (show: boolean) => void + setShowInputFieldPreviewPanel?: (show: boolean) => void + setInputFieldEditPanelProps?: (props: Record<string, unknown> | null) => void +} + +const mockWorkflowStoreState: MockStoreState = { + historyWorkflowData: null, + showDebugAndPreviewPanel: false, + showGlobalVariablePanel: false, + showInputFieldPanel: false, + showInputFieldPreviewPanel: false, + inputFieldEditPanelProps: null, + pipelineId: 'test-pipeline-123', + nodePanelWidth: 400, + workflowCanvasWidth: 1200, + otherPanelWidth: 0, + setShowInputFieldPanel: vi.fn(), + setShowInputFieldPreviewPanel: vi.fn(), + setInputFieldEditPanelProps: vi.fn(), } vi.mock('@/app/components/workflow/store', () => ({ @@ -80,9 +127,15 @@ vi.mock('@/app/components/workflow/store', () => ({ showInputFieldPreviewPanel: mockShowInputFieldPreviewPanel, inputFieldEditPanelProps: mockInputFieldEditPanelProps, pipelineId: mockPipelineId, + nodePanelWidth: 400, + workflowCanvasWidth: 1200, + otherPanelWidth: 0, } return selector(state) }, + useWorkflowStore: () => ({ + getState: () => mockWorkflowStoreState, + }), })) // Mock Panel component to capture props and render children @@ -99,40 +152,6 @@ vi.mock('@/app/components/workflow/panel', () => ({ }, })) -// Mock Record component -vi.mock('@/app/components/workflow/panel/record', () => ({ - default: () => <div data-testid="record-panel">Record Panel</div>, -})) - -// Mock TestRunPanel component -vi.mock('@/app/components/rag-pipeline/components/panel/test-run', () => ({ - default: () => <div data-testid="test-run-panel">Test Run Panel</div>, -})) - -// Mock InputFieldPanel component -vi.mock('./input-field', () => ({ - default: () => <div data-testid="input-field-panel">Input Field Panel</div>, -})) - -// Mock InputFieldEditorPanel component -const mockInputFieldEditorProps = vi.fn() -vi.mock('./input-field/editor', () => ({ - default: (props: Record<string, unknown>) => { - mockInputFieldEditorProps(props) - return <div data-testid="input-field-editor-panel">Input Field Editor Panel</div> - }, -})) - -// Mock PreviewPanel component -vi.mock('./input-field/preview', () => ({ - default: () => <div data-testid="preview-panel">Preview Panel</div>, -})) - -// Mock GlobalVariablePanel component -vi.mock('@/app/components/workflow/panel/global-variable-panel', () => ({ - default: () => <div data-testid="global-variable-panel">Global Variable Panel</div>, -})) - // ============================================================================ // Helper Functions // ============================================================================ diff --git a/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx index 45eb1cafe1..317f2b19d4 100644 --- a/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx +++ b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx @@ -134,22 +134,6 @@ vi.mock('@/app/components/workflow/constants', () => ({ WORKFLOW_DATA_UPDATE: 'WORKFLOW_DATA_UPDATE', })) -// Mock FileReader -class MockFileReader { - result: string | null = null - onload: ((e: { target: { result: string | null } }) => void) | null = null - - readAsText(_file: File) { - // Simulate async file reading using queueMicrotask for more reliable async behavior - queueMicrotask(() => { - this.result = 'test file content' - if (this.onload) { - this.onload({ target: { result: this.result } }) - } - }) - } -} - afterEach(() => { cleanup() vi.clearAllMocks() @@ -159,7 +143,6 @@ describe('UpdateDSLModal', () => { const mockOnCancel = vi.fn() const mockOnBackup = vi.fn() const mockOnImport = vi.fn() - let originalFileReader: typeof FileReader const defaultProps = { onCancel: mockOnCancel, @@ -175,14 +158,6 @@ describe('UpdateDSLModal', () => { pipeline_id: 'test-pipeline-id', }) mockHandleCheckPluginDependencies.mockResolvedValue(undefined) - - // Mock FileReader - originalFileReader = globalThis.FileReader - globalThis.FileReader = MockFileReader as unknown as typeof FileReader - }) - - afterEach(() => { - globalThis.FileReader = originalFileReader }) describe('rendering', () => { @@ -552,6 +527,7 @@ describe('UpdateDSLModal', () => { const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) fireEvent.change(fileInput, { target: { files: [file] } }) + // Wait for FileReader to process and button to be enabled await waitFor(() => { const importButton = screen.getByText('common.overwriteAndImport') expect(importButton).not.toBeDisabled() @@ -576,15 +552,12 @@ describe('UpdateDSLModal', () => { const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) fireEvent.change(fileInput, { target: { files: [file] } }) - // Wait for FileReader to complete (setTimeout 0) and button to be enabled + // Wait for FileReader to complete and button to be enabled await waitFor(() => { const importButton = screen.getByText('common.overwriteAndImport') expect(importButton).not.toBeDisabled() }) - // Give extra time for the FileReader's setTimeout to complete - await new Promise(resolve => setTimeout(resolve, 10)) - const importButton = screen.getByText('common.overwriteAndImport') fireEvent.click(importButton) @@ -719,7 +692,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('1.0.0')).toBeInTheDocument() expect(screen.getByText('2.0.0')).toBeInTheDocument() - }, { timeout: 500 }) + }, { timeout: 1000 }) }) it('should close error modal when cancel button is clicked', async () => { @@ -748,7 +721,7 @@ describe('UpdateDSLModal', () => { // Wait for error modal await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }, { timeout: 1000 }) // Find and click cancel button in error modal - it should be the one with secondary variant const cancelButtons = screen.getAllByText('newApp.Cancel') @@ -805,7 +778,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }) + }, { timeout: 1000 }) // Click confirm button const confirmButton = screen.getByText('newApp.Confirm') @@ -848,7 +821,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }, { timeout: 1000 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -890,7 +863,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }, { timeout: 1000 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -929,7 +902,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }, { timeout: 1000 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -971,7 +944,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }, { timeout: 1000 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -1013,7 +986,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }, { timeout: 1000 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -1063,7 +1036,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }) + }, { timeout: 1000 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -1101,7 +1074,7 @@ describe('UpdateDSLModal', () => { // Should show error modal even with undefined versions await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 500 }) + }, { timeout: 1000 }) }) it('should not call importDSLConfirm when importId is not set', async () => { diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts index c6e3d261c0..295ed20bd8 100644 --- a/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts +++ b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts @@ -53,9 +53,41 @@ vi.mock('@/app/components/workflow/constants', () => ({ // ============================================================================ describe('useDSL', () => { + let mockLink: { href: string, download: string, click: ReturnType<typeof vi.fn>, style: { display: string }, remove: ReturnType<typeof vi.fn> } + let originalCreateElement: typeof document.createElement + let originalAppendChild: typeof document.body.appendChild + let mockCreateObjectURL: ReturnType<typeof vi.spyOn> + let mockRevokeObjectURL: ReturnType<typeof vi.spyOn> + beforeEach(() => { vi.clearAllMocks() + // Create a proper mock link element with all required properties for downloadBlob + mockLink = { + href: '', + download: '', + click: vi.fn(), + style: { display: '' }, + remove: vi.fn(), + } + + // Save original and mock selectively - only intercept 'a' elements + originalCreateElement = document.createElement.bind(document) + document.createElement = vi.fn((tagName: string) => { + if (tagName === 'a') { + return mockLink as unknown as HTMLElement + } + return originalCreateElement(tagName) + }) as typeof document.createElement + + // Mock document.body.appendChild for downloadBlob + originalAppendChild = document.body.appendChild.bind(document.body) + document.body.appendChild = vi.fn(<T extends Node>(node: T): T => node) as typeof document.body.appendChild + + // downloadBlob uses window.URL, not URL + mockCreateObjectURL = vi.spyOn(window.URL, 'createObjectURL').mockReturnValue('blob:test-url') + mockRevokeObjectURL = vi.spyOn(window.URL, 'revokeObjectURL').mockImplementation(() => {}) + // Default store state mockGetState.mockReturnValue({ pipelineId: 'test-pipeline-id', @@ -68,6 +100,10 @@ describe('useDSL', () => { }) afterEach(() => { + document.createElement = originalCreateElement + document.body.appendChild = originalAppendChild + mockCreateObjectURL.mockRestore() + mockRevokeObjectURL.mockRestore() vi.clearAllMocks() }) diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx index 19f5e8b346..63d0344275 100644 --- a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx @@ -13,8 +13,8 @@ vi.mock('@/app/components/base/modal', () => ({ closable, }: { isShow: boolean - onClose: () => void - children: React.ReactNode + onClose?: () => void + children?: React.ReactNode closable?: boolean }) { if (!isShow) @@ -45,8 +45,8 @@ vi.mock('./start-node-selection-panel', () => ({ onSelectUserInput, onSelectTrigger, }: { - onSelectUserInput: () => void - onSelectTrigger: (type: BlockEnum, config?: Record<string, unknown>) => void + onSelectUserInput?: () => void + onSelectTrigger?: (type: BlockEnum, config?: Record<string, unknown>) => void }) { return ( <div data-testid="start-node-selection-panel"> @@ -55,13 +55,13 @@ vi.mock('./start-node-selection-panel', () => ({ </button> <button data-testid="select-trigger-schedule" - onClick={() => onSelectTrigger(BlockEnum.TriggerSchedule)} + onClick={() => onSelectTrigger?.(BlockEnum.TriggerSchedule)} > Select Trigger Schedule </button> <button data-testid="select-trigger-webhook" - onClick={() => onSelectTrigger(BlockEnum.TriggerWebhook, { config: 'test' })} + onClick={() => onSelectTrigger?.(BlockEnum.TriggerWebhook, { config: 'test' })} > Select Trigger Webhook </button> @@ -557,7 +557,7 @@ describe('WorkflowOnboardingModal', () => { // Arrange & Act renderComponent({ isShow: true }) - // Assert + // Assert - ShortcutsName component renders keys in div elements with system-kbd class const escKey = screen.getByText('workflow.onboarding.escTip.key') // ShortcutsName renders a <div> with class system-kbd, not a <kbd> element expect(escKey.closest('.system-kbd')).toBeInTheDocument() diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 79742805df..2a35bf49b6 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -530,11 +530,6 @@ "count": 1 } }, - "app/components/app/create-app-modal/index.spec.tsx": { - "ts/no-explicit-any": { - "count": 7 - } - }, "app/components/app/create-app-modal/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -1801,14 +1796,8 @@ } }, "app/components/datasets/documents/components/list.tsx": { - "react-hooks-extra/no-direct-set-state-in-use-effect": { - "count": 2 - }, "react-refresh/only-export-components": { "count": 1 - }, - "ts/no-explicit-any": { - "count": 2 } }, "app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.spec.tsx": { From 0dfa59b1db17dffd5a346de15a791649e7702c8a Mon Sep 17 00:00:00 2001 From: wangxiaolei <fatelei@gmail.com> Date: Wed, 4 Feb 2026 19:10:27 +0800 Subject: [PATCH 13/18] fix: fix delete_draft_variables_batch cycle forever (#31934) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/tasks/remove_app_and_related_data_task.py | 6 +- .../test_remove_app_and_related_data_task.py | 305 +++++++++++++++++- .../test_remove_app_and_related_data_task.py | 2 +- 3 files changed, 302 insertions(+), 11 deletions(-) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 817249845a..6240f2200f 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): - def del_workflow_archive_log(workflow_archive_log_id: str): - db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( + def del_workflow_archive_log(session, workflow_archive_log_id: str): + session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( synchronize_session=False ) @@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: total_files_deleted = 0 while True: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): # Get a batch of draft variable IDs along with their file_ids query_sql = """ SELECT id, file_id FROM workflow_draft_variables diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index f46d1bf5db..d020233620 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -10,7 +10,10 @@ from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile -from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch +from tasks.remove_app_and_related_data_task import ( + _delete_draft_variables, + delete_draft_variables_batch, +) @pytest.fixture @@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.return_value = None with session_factory.create_session() as session: draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = session.query(WorkflowDraftVariableFile).count() - upload_files_before = session.query(UploadFile).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.side_effect = [Exception("Storage error"), None] deleted_count = delete_draft_variables_batch(app_id, batch_size=10) @@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration: if app2_obj: session.delete(app2_obj) session.commit() + + +class TestDeleteDraftVariablesSessionCommit: + """Test suite to verify session commit behavior in delete_draft_variables_batch.""" + + @pytest.fixture + def setup_offload_test_data(self, app_and_tenant): + """Create test data with offload files for session commit tests.""" + from core.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now + + tenant, app = app_and_tenant + + with session_factory.create_session() as session: + upload_file1 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file1.json", + name="file1.json", + size=1024, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + upload_file2 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file2.json", + name="file2.json", + size=2048, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + session.add(upload_file1) + session.add(upload_file2) + session.flush() + + var_file1 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file1.id, + size=1024, + length=10, + value_type=SegmentType.STRING, + ) + var_file2 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file2.id, + size=2048, + length=20, + value_type=SegmentType.OBJECT, + ) + session.add(var_file1) + session.add(var_file2) + session.flush() + + draft_var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_1", + name="large_var_1", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file1.id, + ) + draft_var2 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_2", + name="large_var_2", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file2.id, + ) + draft_var3 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_3", + name="regular_var", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(draft_var1) + session.add(draft_var2) + session.add(draft_var3) + session.commit() + + data = { + "app": app, + "tenant": tenant, + "upload_files": [upload_file1, upload_file2], + "variable_files": [var_file1, var_file2], + "draft_variables": [draft_var1, draft_var2, draft_var3], + } + + yield data + + with session_factory.create_session() as session: + for table, ids in [ + (WorkflowDraftVariable, [v.id for v in data["draft_variables"]]), + (WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]), + (UploadFile, [uf.id for uf in data["upload_files"]]), + ]: + cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False) + session.execute(cleanup_query) + session.commit() + + @pytest.fixture + def setup_commit_test_data(self, app_and_tenant): + """Create test data for session commit tests.""" + tenant, app = app_and_tenant + variable_ids: list[str] = [] + + with session_factory.create_session() as session: + variables = [] + for i in range(10): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var) + variables.append(var) + session.commit() + variable_ids = [v.id for v in variables] + + yield { + "app": app, + "tenant": tenant, + "variable_ids": variable_ids, + } + + with session_factory.create_session() as session: + cleanup_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) + .execution_options(synchronize_session=False) + ) + session.execute(cleanup_query) + session.commit() + + def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data): + """Test that session.begin() is used for automatic transaction management.""" + data = setup_commit_test_data + app_id = data["app"].id + + # Since session.begin() is used, the transaction is automatically committed + # when the with block exits successfully. We verify this by checking that + # data is actually persisted. + deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + # Verify all data was deleted (proves transaction was committed) + with session_factory.create_session() as session: + remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + + assert deleted_count == 10 + assert remaining_count == 0 + + def test_data_persisted_after_batch_deletion(self, setup_commit_test_data): + """Test that data is actually persisted to database after batch deletion with commits.""" + data = setup_commit_test_data + app_id = data["app"].id + variable_ids = data["variable_ids"] + + # Verify initial state + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Perform deletion with small batch size to force multiple commits + deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + assert deleted_count == 10 + + # Verify all data is deleted in a new session (proves commits worked) + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + # Verify specific IDs are deleted + with session_factory.create_session() as session: + remaining_vars = ( + session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count() + ) + assert remaining_vars == 0 + + def test_session_commit_with_empty_dataset(self, setup_commit_test_data): + """Test session behavior when deleting from an empty dataset.""" + nonexistent_app_id = str(uuid.uuid4()) + + # Should not raise any errors and should return 0 + deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10) + assert deleted_count == 0 + + def test_session_commit_with_single_batch(self, setup_commit_test_data): + """Test that commit happens correctly when all data fits in a single batch.""" + data = setup_commit_test_data + app_id = data["app"].id + + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Delete all in a single batch + deleted_count = delete_draft_variables_batch(app_id, batch_size=100) + assert deleted_count == 10 + + # Verify data is persisted + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + def test_invalid_batch_size_raises_error(self, setup_commit_test_data): + """Test that invalid batch size raises ValueError.""" + data = setup_commit_test_data + app_id = data["app"].id + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=0) + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=-1) + + @patch("extensions.ext_storage.storage") + def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data): + """Test that session commits correctly when cleaning up offload data.""" + data = setup_offload_test_data + app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + mock_storage.delete.return_value = None + + # Verify initial state + with session_factory.create_session() as session: + draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_before == 3 + assert var_files_before == 2 + assert upload_files_before == 2 + + # Delete variables with offload data + deleted_count = delete_draft_variables_batch(app_id, batch_size=10) + assert deleted_count == 3 + + # Verify all data is persisted (deleted) in new session + with session_factory.create_session() as session: + draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_after == 0 + assert var_files_after == 0 + assert upload_files_after == 0 + + # Verify storage cleanup was called + assert mock_storage.delete.call_count == 2 diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index a14bbb01d0..2b11e42cd5 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs: mock_query.where.return_value = mock_delete_query mock_db.session.query.return_value = mock_query - delete_func("log-1") + delete_func(mock_db.session, "log-1") mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) mock_query.where.assert_called_once() From 3bd228ddb7575d4107b3eeb73aed29335639085e Mon Sep 17 00:00:00 2001 From: QuantumGhost <obelisk.reg+git@gmail.com> Date: Wed, 4 Feb 2026 19:29:28 +0800 Subject: [PATCH 14/18] chore: bump version in docker-compose and package manager to 1.12.1 (#31947) --- api/pyproject.toml | 2 +- api/uv.lock | 2 +- docker/docker-compose-template.yaml | 8 ++++---- docker/docker-compose.yaml | 8 ++++---- web/package.json | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index c05e884271..4be7afff26 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.12.0" +version = "1.12.1" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/uv.lock b/api/uv.lock index aefb8e91f0..0a17741f9a 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.12.0" +version = "1.12.1" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index aacb551933..cb5e2c47f7 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.12.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1cb327cfe4..1886f848e0 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -707,7 +707,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -749,7 +749,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -788,7 +788,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -818,7 +818,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.12.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/package.json b/web/package.json index 494a9f0848..219a613363 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "dify-web", "type": "module", - "version": "1.12.0", + "version": "1.12.1", "private": true, "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a", "imports": { From f584be9cf06831fb92eb31481e857c195ca1873c Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Wed, 4 Feb 2026 19:29:57 +0800 Subject: [PATCH 15/18] chore: update CODEOWNERS to specify test file patterns for base components (#31941) Co-authored-by: CodingOnStar <hanxujiang@dify.com> --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6cd99d551a..bfb1c85436 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -239,7 +239,7 @@ /web/app/components/base/ @iamjoel @zxhlyh # Frontend - Base Components Tests -/web/app/components/base/**/__tests__/ @hyoban @CodingOnStar +/web/app/components/base/**/*.spec.tsx @hyoban @CodingOnStar # Frontend - Utils and Hooks /web/utils/classnames.ts @iamjoel @zxhlyh From f686197589fcd2897dd04795c4711b4db6c3e7ea Mon Sep 17 00:00:00 2001 From: wangxiaolei <fatelei@gmail.com> Date: Wed, 4 Feb 2026 19:32:36 +0800 Subject: [PATCH 16/18] feat: use latest hash to sync draft (#31924) --- .../hooks/use-nodes-sync-draft.ts | 59 ++++++++++++------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 5d394bab1e..f3538a5abb 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -98,31 +98,46 @@ export const useNodesSyncDraft = () => { ) => { if (getNodesReadOnly()) return - const postParams = getPostParams() - if (postParams) { - const { - setSyncWorkflowDraftHash, - setDraftUpdatedAt, - } = workflowStore.getState() - try { - const res = await syncWorkflowDraft(postParams) - setSyncWorkflowDraftHash(res.hash) - setDraftUpdatedAt(res.updated_at) - callback?.onSuccess?.() + // Get base params without hash + const baseParams = getPostParams() + if (!baseParams) + return + + const { + setSyncWorkflowDraftHash, + setDraftUpdatedAt, + } = workflowStore.getState() + + try { + // IMPORTANT: Get the LATEST hash right before sending the request + // This ensures that even if queued, each request uses the most recent hash + const latestHash = workflowStore.getState().syncWorkflowDraftHash + + const postParams = { + ...baseParams, + params: { + ...baseParams.params, + hash: latestHash || null, // null for first-time, otherwise use latest hash + }, } - catch (error: any) { - if (error && error.json && !error.bodyUsed) { - error.json().then((err: any) => { - if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) - handleRefreshWorkflowDraft() - }) - } - callback?.onError?.() - } - finally { - callback?.onSettled?.() + + const res = await syncWorkflowDraft(postParams) + setSyncWorkflowDraftHash(res.hash) + setDraftUpdatedAt(res.updated_at) + callback?.onSuccess?.() + } + catch (error: any) { + if (error && error.json && !error.bodyUsed) { + error.json().then((err: any) => { + if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) + handleRefreshWorkflowDraft() + }) } + callback?.onError?.() + } + finally { + callback?.onSettled?.() } }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft]) From 365f749ed5794da584154a63eafaab9587f31343 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:33:32 +0800 Subject: [PATCH 17/18] fix: remove staleTime/gcTime overrides from trigger query hooks and use orpc contract (#31863) --- web/contract/console/trigger.ts | 119 +++++++++++++++++++++++++++++++ web/contract/router.ts | 34 +++++++++ web/service/use-triggers.ts | 121 ++++++++++++++------------------ web/types/i18n.d.ts | 2 - 4 files changed, 206 insertions(+), 70 deletions(-) create mode 100644 web/contract/console/trigger.ts diff --git a/web/contract/console/trigger.ts b/web/contract/console/trigger.ts new file mode 100644 index 0000000000..41a326ccf5 --- /dev/null +++ b/web/contract/console/trigger.ts @@ -0,0 +1,119 @@ +import type { + TriggerLogEntity, + TriggerOAuthClientParams, + TriggerOAuthConfig, + TriggerProviderApiEntity, + TriggerSubscription, + TriggerSubscriptionBuilder, +} from '@/app/components/workflow/block-selector/types' +import { type } from '@orpc/contract' +import { base } from '../base' + +export const triggersContract = base + .route({ path: '/workspaces/current/triggers', method: 'GET' }) + .input(type<{ query?: { type?: string } }>()) + .output(type<TriggerProviderApiEntity[]>()) + +export const triggerProviderInfoContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/info', method: 'GET' }) + .input(type<{ params: { provider: string } }>()) + .output(type<TriggerProviderApiEntity>()) + +export const triggerSubscriptionsContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/subscriptions/list', method: 'GET' }) + .input(type<{ params: { provider: string } }>()) + .output(type<TriggerSubscription[]>()) + +export const triggerSubscriptionBuilderCreateContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create', method: 'POST' }) + .input(type<{ + params: { provider: string } + body?: { credential_type?: string } + }>()) + .output(type<{ subscription_builder: TriggerSubscriptionBuilder }>()) + +export const triggerSubscriptionBuilderUpdateContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscriptionBuilderId}', method: 'POST' }) + .input(type<{ + params: { provider: string, subscriptionBuilderId: string } + body?: { + name?: string + properties?: Record<string, unknown> + parameters?: Record<string, unknown> + credentials?: Record<string, unknown> + } + }>()) + .output(type<TriggerSubscriptionBuilder>()) + +export const triggerSubscriptionBuilderVerifyUpdateContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/subscriptions/builder/verify-and-update/{subscriptionBuilderId}', method: 'POST' }) + .input(type<{ + params: { provider: string, subscriptionBuilderId: string } + body?: { credentials?: Record<string, unknown> } + }>()) + .output(type<{ verified: boolean }>()) + +export const triggerSubscriptionVerifyContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/subscriptions/verify/{subscriptionId}', method: 'POST' }) + .input(type<{ + params: { provider: string, subscriptionId: string } + body?: { credentials?: Record<string, unknown> } + }>()) + .output(type<{ verified: boolean }>()) + +export const triggerSubscriptionBuildContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscriptionBuilderId}', method: 'POST' }) + .input(type<{ + params: { provider: string, subscriptionBuilderId: string } + body?: { + name?: string + parameters?: Record<string, unknown> + } + }>()) + .output(type<unknown>()) + +export const triggerSubscriptionDeleteContract = base + .route({ path: '/workspaces/current/trigger-provider/{subscriptionId}/subscriptions/delete', method: 'POST' }) + .input(type<{ params: { subscriptionId: string } }>()) + .output(type<{ result: string }>()) + +export const triggerSubscriptionUpdateContract = base + .route({ path: '/workspaces/current/trigger-provider/{subscriptionId}/subscriptions/update', method: 'POST' }) + .input(type<{ + params: { subscriptionId: string } + body?: { + name?: string + properties?: Record<string, unknown> + parameters?: Record<string, unknown> + credentials?: Record<string, unknown> + } + }>()) + .output(type<{ result: string, id: string }>()) + +export const triggerSubscriptionBuilderLogsContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscriptionBuilderId}', method: 'GET' }) + .input(type<{ params: { provider: string, subscriptionBuilderId: string } }>()) + .output(type<{ logs: TriggerLogEntity[] }>()) + +export const triggerOAuthConfigContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/oauth/client', method: 'GET' }) + .input(type<{ params: { provider: string } }>()) + .output(type<TriggerOAuthConfig>()) + +export const triggerOAuthConfigureContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/oauth/client', method: 'POST' }) + .input(type<{ + params: { provider: string } + body: { client_params?: TriggerOAuthClientParams, enabled: boolean } + }>()) + .output(type<{ result: string }>()) + +export const triggerOAuthDeleteContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/oauth/client', method: 'DELETE' }) + .input(type<{ params: { provider: string } }>()) + .output(type<{ result: string }>()) + +export const triggerOAuthInitiateContract = base + .route({ path: '/workspaces/current/trigger-provider/{provider}/subscriptions/oauth/authorize', method: 'GET' }) + .input(type<{ params: { provider: string } }>()) + .output(type<{ authorization_url: string, subscription_builder: TriggerSubscriptionBuilder }>()) diff --git a/web/contract/router.ts b/web/contract/router.ts index 965c381bd7..33499b106f 100644 --- a/web/contract/router.ts +++ b/web/contract/router.ts @@ -1,6 +1,23 @@ import type { InferContractRouterInputs } from '@orpc/contract' import { bindPartnerStackContract, invoicesContract } from './console/billing' import { systemFeaturesContract } from './console/system' +import { + triggerOAuthConfigContract, + triggerOAuthConfigureContract, + triggerOAuthDeleteContract, + triggerOAuthInitiateContract, + triggerProviderInfoContract, + triggersContract, + triggerSubscriptionBuildContract, + triggerSubscriptionBuilderCreateContract, + triggerSubscriptionBuilderLogsContract, + triggerSubscriptionBuilderUpdateContract, + triggerSubscriptionBuilderVerifyUpdateContract, + triggerSubscriptionDeleteContract, + triggerSubscriptionsContract, + triggerSubscriptionUpdateContract, + triggerSubscriptionVerifyContract, +} from './console/trigger' import { trialAppDatasetsContract, trialAppInfoContract, trialAppParametersContract, trialAppWorkflowsContract } from './console/try-app' import { collectionPluginsContract, collectionsContract, searchAdvancedContract } from './marketplace' @@ -24,6 +41,23 @@ export const consoleRouterContract = { invoices: invoicesContract, bindPartnerStack: bindPartnerStackContract, }, + triggers: { + list: triggersContract, + providerInfo: triggerProviderInfoContract, + subscriptions: triggerSubscriptionsContract, + subscriptionBuilderCreate: triggerSubscriptionBuilderCreateContract, + subscriptionBuilderUpdate: triggerSubscriptionBuilderUpdateContract, + subscriptionBuilderVerifyUpdate: triggerSubscriptionBuilderVerifyUpdateContract, + subscriptionVerify: triggerSubscriptionVerifyContract, + subscriptionBuild: triggerSubscriptionBuildContract, + subscriptionDelete: triggerSubscriptionDeleteContract, + subscriptionUpdate: triggerSubscriptionUpdateContract, + subscriptionBuilderLogs: triggerSubscriptionBuilderLogsContract, + oauthConfig: triggerOAuthConfigContract, + oauthConfigure: triggerOAuthConfigureContract, + oauthDelete: triggerOAuthDeleteContract, + oauthInitiate: triggerOAuthInitiateContract, + }, } export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract> diff --git a/web/service/use-triggers.ts b/web/service/use-triggers.ts index 24717cc50b..df5700d0e7 100644 --- a/web/service/use-triggers.ts +++ b/web/service/use-triggers.ts @@ -10,17 +10,14 @@ import type { } from '@/app/components/workflow/block-selector/types' import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' import { CollectionType } from '@/app/components/tools/types' -import { del, get, post } from './base' +import { consoleClient, consoleQuery } from '@/service/client' +import { get, post } from './base' import { useInvalid } from './use-base' const NAME_SPACE = 'triggers' -// Trigger Provider Service - Provider ID Format: plugin_id/provider_name - -// Convert backend API response to frontend ToolWithProvider format const convertToTriggerWithProvider = (provider: TriggerProviderApiEntity): TriggerWithProvider => { return { - // Collection fields id: provider.plugin_id || provider.name, name: provider.name, author: provider.author, @@ -58,12 +55,9 @@ const convertToTriggerWithProvider = (provider: TriggerProviderApiEntity): Trigg labels: provider.tags || [], output_schema: event.output_schema || {}, })), - - // Trigger-specific schema fields subscription_constructor: provider.subscription_constructor, subscription_schema: provider.subscription_schema, supported_creation_methods: provider.supported_creation_methods, - meta: { version: '1.0', }, @@ -72,22 +66,20 @@ const convertToTriggerWithProvider = (provider: TriggerProviderApiEntity): Trigg export const useAllTriggerPlugins = (enabled = true) => { return useQuery<TriggerWithProvider[]>({ - queryKey: [NAME_SPACE, 'all'], + queryKey: consoleQuery.triggers.list.queryKey({ input: {} }), queryFn: async () => { - const response = await get<TriggerProviderApiEntity[]>('/workspaces/current/triggers') + const response = await consoleClient.triggers.list({}) return response.map(convertToTriggerWithProvider) }, enabled, - staleTime: 0, - gcTime: 0, }) } export const useTriggerPluginsByType = (triggerType: string, enabled = true) => { return useQuery<TriggerWithProvider[]>({ - queryKey: [NAME_SPACE, 'byType', triggerType], + queryKey: consoleQuery.triggers.list.queryKey({ input: { query: { type: triggerType } } }), queryFn: async () => { - const response = await get<TriggerProviderApiEntity[]>(`/workspaces/current/triggers?type=${triggerType}`) + const response = await consoleClient.triggers.list({ query: { type: triggerType } }) return response.map(convertToTriggerWithProvider) }, enabled: enabled && !!triggerType, @@ -95,25 +87,23 @@ export const useTriggerPluginsByType = (triggerType: string, enabled = true) => } export const useInvalidateAllTriggerPlugins = () => { - return useInvalid([NAME_SPACE, 'all']) + return useInvalid(consoleQuery.triggers.list.queryKey({ input: {} })) } // ===== Trigger Subscriptions Management ===== export const useTriggerProviderInfo = (provider: string, enabled = true) => { return useQuery<TriggerProviderApiEntity>({ - queryKey: [NAME_SPACE, 'provider-info', provider], - queryFn: () => get<TriggerProviderApiEntity>(`/workspaces/current/trigger-provider/${provider}/info`), + queryKey: consoleQuery.triggers.providerInfo.queryKey({ input: { params: { provider } } }), + queryFn: () => consoleClient.triggers.providerInfo({ params: { provider } }), enabled: enabled && !!provider, - staleTime: 0, - gcTime: 0, }) } export const useTriggerSubscriptions = (provider: string, enabled = true) => { return useQuery<TriggerSubscription[]>({ - queryKey: [NAME_SPACE, 'list-subscriptions', provider], - queryFn: () => get<TriggerSubscription[]>(`/workspaces/current/trigger-provider/${provider}/subscriptions/list`), + queryKey: consoleQuery.triggers.subscriptions.queryKey({ input: { params: { provider } } }), + queryFn: () => consoleClient.triggers.subscriptions({ params: { provider } }), enabled: enabled && !!provider, }) } @@ -122,30 +112,30 @@ export const useInvalidateTriggerSubscriptions = () => { const queryClient = useQueryClient() return (provider: string) => { queryClient.invalidateQueries({ - queryKey: [NAME_SPACE, 'subscriptions', provider], + queryKey: consoleQuery.triggers.subscriptions.queryKey({ input: { params: { provider } } }), }) } } export const useCreateTriggerSubscriptionBuilder = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'create-subscription-builder'], + mutationKey: consoleQuery.triggers.subscriptionBuilderCreate.mutationKey(), mutationFn: (payload: { provider: string credential_type?: string }) => { const { provider, ...body } = payload - return post<{ subscription_builder: TriggerSubscriptionBuilder }>( - `/workspaces/current/trigger-provider/${provider}/subscriptions/builder/create`, - { body }, - ) + return consoleClient.triggers.subscriptionBuilderCreate({ + params: { provider }, + body, + }) }, }) } export const useUpdateTriggerSubscriptionBuilder = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'update-subscription-builder'], + mutationKey: consoleQuery.triggers.subscriptionBuilderUpdate.mutationKey(), mutationFn: (payload: { provider: string subscriptionBuilderId: string @@ -155,17 +145,17 @@ export const useUpdateTriggerSubscriptionBuilder = () => { credentials?: Record<string, unknown> }) => { const { provider, subscriptionBuilderId, ...body } = payload - return post<TriggerSubscriptionBuilder>( - `/workspaces/current/trigger-provider/${provider}/subscriptions/builder/update/${subscriptionBuilderId}`, - { body }, - ) + return consoleClient.triggers.subscriptionBuilderUpdate({ + params: { provider, subscriptionBuilderId }, + body, + }) }, }) } export const useVerifyAndUpdateTriggerSubscriptionBuilder = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'verify-and-update-subscription-builder'], + mutationKey: consoleQuery.triggers.subscriptionBuilderVerifyUpdate.mutationKey(), mutationFn: (payload: { provider: string subscriptionBuilderId: string @@ -183,7 +173,7 @@ export const useVerifyAndUpdateTriggerSubscriptionBuilder = () => { export const useVerifyTriggerSubscription = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'verify-subscription'], + mutationKey: consoleQuery.triggers.subscriptionVerify.mutationKey(), mutationFn: (payload: { provider: string subscriptionId: string @@ -208,24 +198,24 @@ export type BuildTriggerSubscriptionPayload = { export const useBuildTriggerSubscription = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'build-subscription'], + mutationKey: consoleQuery.triggers.subscriptionBuild.mutationKey(), mutationFn: (payload: BuildTriggerSubscriptionPayload) => { const { provider, subscriptionBuilderId, ...body } = payload - return post( - `/workspaces/current/trigger-provider/${provider}/subscriptions/builder/build/${subscriptionBuilderId}`, - { body }, - ) + return consoleClient.triggers.subscriptionBuild({ + params: { provider, subscriptionBuilderId }, + body, + }) }, }) } export const useDeleteTriggerSubscription = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'delete-subscription'], + mutationKey: consoleQuery.triggers.subscriptionDelete.mutationKey(), mutationFn: (subscriptionId: string) => { - return post<{ result: string }>( - `/workspaces/current/trigger-provider/${subscriptionId}/subscriptions/delete`, - ) + return consoleClient.triggers.subscriptionDelete({ + params: { subscriptionId }, + }) }, }) } @@ -240,13 +230,13 @@ export type UpdateTriggerSubscriptionPayload = { export const useUpdateTriggerSubscription = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'update-subscription'], + mutationKey: consoleQuery.triggers.subscriptionUpdate.mutationKey(), mutationFn: (payload: UpdateTriggerSubscriptionPayload) => { const { subscriptionId, ...body } = payload - return post<{ result: string, id: string }>( - `/workspaces/current/trigger-provider/${subscriptionId}/subscriptions/update`, - { body }, - ) + return consoleClient.triggers.subscriptionUpdate({ + params: { subscriptionId }, + body, + }) }, }) } @@ -262,10 +252,8 @@ export const useTriggerSubscriptionBuilderLogs = ( const { enabled = true, refetchInterval = false } = options return useQuery<{ logs: TriggerLogEntity[] }>({ - queryKey: [NAME_SPACE, 'subscription-builder-logs', provider, subscriptionBuilderId], - queryFn: () => get( - `/workspaces/current/trigger-provider/${provider}/subscriptions/builder/logs/${subscriptionBuilderId}`, - ), + queryKey: consoleQuery.triggers.subscriptionBuilderLogs.queryKey({ input: { params: { provider, subscriptionBuilderId } } }), + queryFn: () => consoleClient.triggers.subscriptionBuilderLogs({ params: { provider, subscriptionBuilderId } }), enabled: enabled && !!provider && !!subscriptionBuilderId, refetchInterval, }) @@ -274,8 +262,8 @@ export const useTriggerSubscriptionBuilderLogs = ( // ===== OAuth Management ===== export const useTriggerOAuthConfig = (provider: string, enabled = true) => { return useQuery<TriggerOAuthConfig>({ - queryKey: [NAME_SPACE, 'oauth-config', provider], - queryFn: () => get<TriggerOAuthConfig>(`/workspaces/current/trigger-provider/${provider}/oauth/client`), + queryKey: consoleQuery.triggers.oauthConfig.queryKey({ input: { params: { provider } } }), + queryFn: () => consoleClient.triggers.oauthConfig({ params: { provider } }), enabled: enabled && !!provider, }) } @@ -288,31 +276,31 @@ export type ConfigureTriggerOAuthPayload = { export const useConfigureTriggerOAuth = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'configure-oauth'], + mutationKey: consoleQuery.triggers.oauthConfigure.mutationKey(), mutationFn: (payload: ConfigureTriggerOAuthPayload) => { const { provider, ...body } = payload - return post<{ result: string }>( - `/workspaces/current/trigger-provider/${provider}/oauth/client`, - { body }, - ) + return consoleClient.triggers.oauthConfigure({ + params: { provider }, + body, + }) }, }) } export const useDeleteTriggerOAuth = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'delete-oauth'], + mutationKey: consoleQuery.triggers.oauthDelete.mutationKey(), mutationFn: (provider: string) => { - return del<{ result: string }>( - `/workspaces/current/trigger-provider/${provider}/oauth/client`, - ) + return consoleClient.triggers.oauthDelete({ + params: { provider }, + }) }, }) } export const useInitiateTriggerOAuth = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'initiate-oauth'], + mutationKey: consoleQuery.triggers.oauthInitiate.mutationKey(), mutationFn: (provider: string) => { return get<{ authorization_url: string, subscription_builder: TriggerSubscriptionBuilder }>( `/workspaces/current/trigger-provider/${provider}/subscriptions/oauth/authorize`, @@ -336,7 +324,6 @@ export const useTriggerPluginDynamicOptions = (payload: { return useQuery<{ options: FormOption[] }>({ queryKey: [NAME_SPACE, 'dynamic-options', payload.plugin_id, payload.provider, payload.action, payload.parameter, payload.credential_id, payload.credentials, payload.extra], queryFn: () => { - // Use new endpoint with POST when credentials provided (for edit mode) if (payload.credentials) { return post<{ options: FormOption[] }>( '/workspaces/current/plugin/parameters/dynamic-options-with-credentials', @@ -353,7 +340,6 @@ export const useTriggerPluginDynamicOptions = (payload: { { silent: true }, ) } - // Use original GET endpoint for normal cases return get<{ options: FormOption[] }>( '/workspaces/current/plugin/parameters/dynamic-options', { @@ -372,7 +358,6 @@ export const useTriggerPluginDynamicOptions = (payload: { enabled: enabled && !!payload.plugin_id && !!payload.provider && !!payload.action && !!payload.parameter && !!payload.credential_id, retry: 0, staleTime: 0, - gcTime: 0, }) } @@ -382,7 +367,7 @@ export const useInvalidateTriggerOAuthConfig = () => { const queryClient = useQueryClient() return (provider: string) => { queryClient.invalidateQueries({ - queryKey: [NAME_SPACE, 'oauth-config', provider], + queryKey: consoleQuery.triggers.oauthConfig.queryKey({ input: { params: { provider } } }), }) } } diff --git a/web/types/i18n.d.ts b/web/types/i18n.d.ts index 9e20d5a55a..bdf2ef56d3 100644 --- a/web/types/i18n.d.ts +++ b/web/types/i18n.d.ts @@ -27,5 +27,3 @@ export type I18nKeysWithPrefix< > = Prefix extends '' ? keyof Resources[NS] : Extract<keyof Resources[NS], `${Prefix}${string}`> - -type A = I18nKeysWithPrefix<'billing'> From c56ad8e3230b40143c17f83ac56f5d246d41bd32 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Wed, 4 Feb 2026 17:59:41 -0800 Subject: [PATCH 18/18] feat: account delete cleanup (#31519) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/account_service.py | 24 ++ .../enterprise/account_deletion_sync.py | 115 ++++++++ .../services/test_account_service.py | 24 +- .../enterprise/test_account_deletion_sync.py | 276 ++++++++++++++++++ 4 files changed, 435 insertions(+), 4 deletions(-) create mode 100644 api/services/enterprise/account_deletion_sync.py create mode 100644 api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py diff --git a/api/services/account_service.py b/api/services/account_service.py index 35e4a505af..d3893c1207 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -327,6 +327,17 @@ class AccountService: @staticmethod def delete_account(account: Account): """Delete account. This method only adds a task to the queue for deletion.""" + # Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only) + from services.enterprise.account_deletion_sync import sync_account_deletion + + sync_success = sync_account_deletion(account_id=account.id, source="account_deleted") + if not sync_success: + logger.warning( + "Enterprise account deletion sync failed for account %s; proceeding with local deletion.", + account.id, + ) + + # Now proceed with async account deletion delete_account_task.delay(account.id) @staticmethod @@ -1230,6 +1241,19 @@ class TenantService: if dify_config.BILLING_ENABLED: BillingService.clean_billing_info_cache(tenant.id) + # Queue account deletion sync task for enterprise backend to reassign resources (enterprise only) + from services.enterprise.account_deletion_sync import sync_workspace_member_removal + + sync_success = sync_workspace_member_removal( + workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed" + ) + if not sync_success: + logger.warning( + "Enterprise workspace member removal sync failed: workspace_id=%s, member_id=%s", + tenant.id, + account.id, + ) + @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account): """Update member role""" diff --git a/api/services/enterprise/account_deletion_sync.py b/api/services/enterprise/account_deletion_sync.py new file mode 100644 index 0000000000..c7ff42894d --- /dev/null +++ b/api/services/enterprise/account_deletion_sync.py @@ -0,0 +1,115 @@ +import json +import logging +import uuid +from datetime import UTC, datetime + +from redis import RedisError + +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import TenantAccountJoin + +logger = logging.getLogger(__name__) + +ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue" +ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace" + + +def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool: + """ + Queue an account deletion sync task to Redis. + + Internal helper function. Do not call directly - use the public functions instead. + + Args: + workspace_id: The workspace/tenant ID to sync + member_id: The member/account ID that was removed + source: Source of the sync request (for debugging/tracking) + + Returns: + bool: True if task was queued successfully, False otherwise + """ + try: + task = { + "task_id": str(uuid.uuid4()), + "workspace_id": workspace_id, + "member_id": member_id, + "retry_count": 0, + "created_at": datetime.now(UTC).isoformat(), + "source": source, + "type": ACCOUNT_DELETION_SYNC_TASK_TYPE, + } + + # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP + redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task)) + + logger.info( + "Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s", + workspace_id, + member_id, + task["task_id"], + source, + ) + return True + + except (RedisError, TypeError) as e: + logger.error( + "Failed to queue account deletion sync for workspace %s, member %s: %s", + workspace_id, + member_id, + str(e), + exc_info=True, + ) + # Don't raise - we don't want to fail member deletion if queueing fails + return False + + +def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool: + """ + Sync a single workspace member removal (enterprise only). + + Queues a task for the enterprise backend to reassign resources from the removed member. + Handles enterprise edition check internally. Safe to call in community edition (no-op). + + Args: + workspace_id: The workspace/tenant ID + member_id: The member/account ID that was removed + source: Source of the sync request (e.g., "workspace_member_removed") + + Returns: + bool: True if task was queued (or skipped in community), False if queueing failed + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + + return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) + + +def sync_account_deletion(account_id: str, *, source: str) -> bool: + """ + Sync full account deletion across all workspaces (enterprise only). + + Fetches all workspace memberships for the account and queues a sync task for each. + Handles enterprise edition check internally. Safe to call in community edition (no-op). + + Args: + account_id: The account ID being deleted + source: Source of the sync request (e.g., "account_deleted") + + Returns: + bool: True if all tasks were queued (or skipped in community), False if any queueing failed + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + + # Fetch all workspaces the account belongs to + workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all() + + # Queue sync task for each workspace + success = True + for join in workspace_joins: + if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source): + success = False + + return success diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 4d4e77a802..a09a6e5c65 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -1016,7 +1016,7 @@ class TestAccountService: def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): """ - Test account deletion (should add task to queue). + Test account deletion (should add task to queue and sync to enterprise). """ fake = Faker() email = fake.email() @@ -1034,10 +1034,18 @@ class TestAccountService: password=password, ) - with patch("services.account_service.delete_account_task") as mock_delete_task: + with ( + patch("services.account_service.delete_account_task") as mock_delete_task, + patch("services.enterprise.account_deletion_sync.sync_account_deletion") as mock_sync, + ): + mock_sync.return_value = True + # Delete account AccountService.delete_account(account) + # Verify sync was called + mock_sync.assert_called_once_with(account_id=account.id, source="account_deleted") + # Verify task was added to queue mock_delete_task.delay.assert_called_once_with(account.id) @@ -1716,7 +1724,7 @@ class TestTenantService: def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): """ - Test successful member removal from tenant. + Test successful member removal from tenant (should sync to enterprise). """ fake = Faker() tenant_name = fake.company() @@ -1751,7 +1759,15 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, role="normal") # Remove member - TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: + mock_sync.return_value = True + + TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + + # Verify sync was called + mock_sync.assert_called_once_with( + workspace_id=tenant.id, member_id=member_account.id, source="workspace_member_removed" + ) # Verify member was removed from extensions.ext_database import db diff --git a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py new file mode 100644 index 0000000000..b66111902c --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py @@ -0,0 +1,276 @@ +"""Unit tests for account deletion synchronization. + +This test module verifies the enterprise account deletion sync functionality, +including Redis queuing, error handling, and community vs enterprise behavior. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +from services.enterprise.account_deletion_sync import ( + _queue_task, + sync_account_deletion, + sync_workspace_member_removal, +) + + +class TestQueueTask: + """Unit tests for the _queue_task helper function.""" + + @pytest.fixture + def mock_redis_client(self): + """Mock redis_client for testing.""" + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + yield mock_redis + + @pytest.fixture + def mock_uuid(self): + """Mock UUID generation for predictable task IDs.""" + with patch("services.enterprise.account_deletion_sync.uuid.uuid4") as mock_uuid_gen: + mock_uuid_gen.return_value = MagicMock(hex="test-task-id-1234") + yield mock_uuid_gen + + def test_queue_task_success(self, mock_redis_client, mock_uuid): + """Test successful task queueing to Redis.""" + # Arrange + workspace_id = "ws-123" + member_id = "member-456" + source = "test_source" + + # Act + result = _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) + + # Assert + assert result is True + mock_redis_client.lpush.assert_called_once() + + # Verify the task payload structure + call_args = mock_redis_client.lpush.call_args[0] + assert call_args[0] == "enterprise:member:sync:queue" + + import json + + task_data = json.loads(call_args[1]) + assert task_data["workspace_id"] == workspace_id + assert task_data["member_id"] == member_id + assert task_data["source"] == source + assert task_data["type"] == "sync_member_deletion_from_workspace" + assert task_data["retry_count"] == 0 + assert "task_id" in task_data + assert "created_at" in task_data + + def test_queue_task_redis_error(self, mock_redis_client, caplog): + """Test handling of Redis connection errors.""" + # Arrange + mock_redis_client.lpush.side_effect = RedisError("Connection failed") + + # Act + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + def test_queue_task_type_error(self, mock_redis_client, caplog): + """Test handling of JSON serialization errors.""" + # Arrange + mock_redis_client.lpush.side_effect = TypeError("Cannot serialize") + + # Act + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + +class TestSyncWorkspaceMemberRemoval: + """Unit tests for sync_workspace_member_removal function.""" + + @pytest.fixture + def mock_queue_task(self): + """Mock _queue_task for testing.""" + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is True.""" + # Arrange + workspace_id = "ws-123" + member_id = "member-456" + source = "workspace_member_removed" + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source=source) + + # Assert + assert result is True + mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source=source) + + def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is False (community edition).""" + # Arrange + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + # Act + result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): + """Test handling of queue task failures.""" + # Arrange + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + + +class TestSyncAccountDeletion: + """Unit tests for sync_account_deletion function.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session for testing.""" + with patch("services.enterprise.account_deletion_sync.db.session") as mock_session: + yield mock_session + + @pytest.fixture + def mock_queue_task(self): + """Mock _queue_task for testing.""" + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_account_deletion_enterprise_disabled(self, mock_db_session, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is False (community edition).""" + # Arrange + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is True + mock_db_session.query.assert_not_called() + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_multiple_workspaces(self, mock_db_session, mock_queue_task): + """Test sync for account with multiple workspace memberships.""" + # Arrange + account_id = "acc-123" + + # Mock workspace joins + mock_join1 = MagicMock() + mock_join1.tenant_id = "tenant-1" + mock_join2 = MagicMock() + mock_join2.tenant_id = "tenant-2" + mock_join3 = MagicMock() + mock_join3.tenant_id = "tenant-3" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] + mock_db_session.query.return_value = mock_query + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + # Assert + assert result is True + assert mock_queue_task.call_count == 3 + + # Verify each workspace was queued + mock_queue_task.assert_any_call(workspace_id="tenant-1", member_id=account_id, source="account_deleted") + mock_queue_task.assert_any_call(workspace_id="tenant-2", member_id=account_id, source="account_deleted") + mock_queue_task.assert_any_call(workspace_id="tenant-3", member_id=account_id, source="account_deleted") + + def test_sync_account_deletion_no_workspaces(self, mock_db_session, mock_queue_task): + """Test sync for account with no workspace memberships.""" + # Arrange + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [] + mock_db_session.query.return_value = mock_query + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_partial_failure(self, mock_db_session, mock_queue_task): + """Test sync when some tasks fail to queue.""" + # Arrange + account_id = "acc-123" + + # Mock workspace joins + mock_join1 = MagicMock() + mock_join1.tenant_id = "tenant-1" + mock_join2 = MagicMock() + mock_join2.tenant_id = "tenant-2" + mock_join3 = MagicMock() + mock_join3.tenant_id = "tenant-3" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] + mock_db_session.query.return_value = mock_query + + # Mock queue_task to fail for second workspace + def queue_side_effect(workspace_id, member_id, source): + return workspace_id != "tenant-2" + + mock_queue_task.side_effect = queue_side_effect + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + # Assert + assert result is False # Should return False if any task fails + assert mock_queue_task.call_count == 3 + + def test_sync_account_deletion_all_failures(self, mock_db_session, mock_queue_task): + """Test sync when all tasks fail to queue.""" + # Arrange + mock_join = MagicMock() + mock_join.tenant_id = "tenant-1" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [mock_join] + mock_db_session.query.return_value = mock_query + + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is False + mock_queue_task.assert_called_once()