diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.spec.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.spec.tsx index cef945c968..6be41458f4 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.spec.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.spec.tsx @@ -771,7 +771,9 @@ describe('EmptyDatasetCreationModal', () => { // 3. These should NOT happen on error expect(mockInvalidDatasetList).not.toHaveBeenCalled() - expect(mockOnHide).not.toHaveBeenCalled() + // Dialog onClose can pass false; ensure submit didn't call onHide directly. + const submitHideCalls = mockOnHide.mock.calls.filter(call => call.length === 0) + expect(submitHideCalls).toHaveLength(0) expect(mockPush).not.toHaveBeenCalled() }) }) diff --git a/web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.spec.tsx b/web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.spec.tsx index fc1f0d0990..c82dbe1cdb 100644 --- a/web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.spec.tsx +++ b/web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.spec.tsx @@ -1,5 +1,5 @@ import type { BuiltInMetadataItem, MetadataItemWithValueLength } from '../types' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' import { describe, expect, it, vi } from 'vitest' import { DataType } from '../types' import DatasetMetadataDrawer from './dataset-metadata-drawer' @@ -270,22 +270,22 @@ describe('DatasetMetadataDrawer', () => { fireEvent.click(svgs[0]) } - // Change name and save + let renameModal: HTMLElement | undefined await waitFor(() => { - const inputs = document.querySelectorAll('input') - expect(inputs.length).toBeGreaterThan(0) + const dialogs = screen.getAllByRole('dialog') + renameModal = dialogs.find(dialog => dialog.querySelector('input')) as HTMLElement | undefined + expect(renameModal).toBeTruthy() }) - const inputs = document.querySelectorAll('input') - fireEvent.change(inputs[0], { target: { value: 'renamed_field' } }) + const modal = within(renameModal as HTMLElement) + const input = modal.getByRole('textbox') + fireEvent.change(input, { target: { value: 'renamed_field' } }) - // Find and click save button - const saveBtns = screen.getAllByText(/save/i) - const primaryBtn = saveBtns.find(btn => - btn.closest('button')?.classList.contains('btn-primary'), + const saveButton = modal.getAllByRole('button').find(button => + button.classList.contains('btn-primary'), ) - if (primaryBtn) - fireEvent.click(primaryBtn) + expect(saveButton).toBeTruthy() + fireEvent.click(saveButton as HTMLButtonElement) await waitFor(() => { expect(onRename).toHaveBeenCalled() diff --git a/web/app/components/plugins/marketplace/index.spec.tsx b/web/app/components/plugins/marketplace/index.spec.tsx index 654b667deb..858a724ec0 100644 --- a/web/app/components/plugins/marketplace/index.spec.tsx +++ b/web/app/components/plugins/marketplace/index.spec.tsx @@ -149,6 +149,16 @@ vi.mock('@/service/base', () => ({ }), })) +const mockMarketplaceClient = vi.hoisted(() => ({ + collectionPlugins: vi.fn().mockResolvedValue({ data: { plugins: [] } }), + collections: vi.fn().mockResolvedValue({ data: { collections: [] } }), + searchAdvanced: vi.fn().mockResolvedValue({ data: { plugins: [], bundles: [], total: 0 } }), +})) + +vi.mock('@/service/client', () => ({ + marketplaceClient: mockMarketplaceClient, +})) + // Mock config vi.mock('@/config', () => ({ API_PREFIX: '/api', @@ -1490,12 +1500,9 @@ describe('Async Utils', () => { { type: 'plugin', org: 'test', name: 'plugin2' }, ] - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + mockMarketplaceClient.collectionPlugins.mockResolvedValueOnce({ + data: { plugins: mockPlugins }, + }) const { getMarketplacePluginsByCollectionId } = await import('./utils') const result = await getMarketplacePluginsByCollectionId('test-collection', { @@ -1504,12 +1511,12 @@ describe('Async Utils', () => { type: 'plugin', }) - expect(globalThis.fetch).toHaveBeenCalled() + expect(mockMarketplaceClient.collectionPlugins).toHaveBeenCalled() expect(result).toHaveLength(2) }) it('should handle fetch error and return empty array', async () => { - globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error')) + mockMarketplaceClient.collectionPlugins.mockRejectedValueOnce(new Error('Network error')) const { getMarketplacePluginsByCollectionId } = await import('./utils') const result = await getMarketplacePluginsByCollectionId('test-collection') @@ -1519,25 +1526,23 @@ describe('Async Utils', () => { it('should pass abort signal when provided', async () => { const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + mockMarketplaceClient.collectionPlugins.mockResolvedValueOnce({ + data: { plugins: mockPlugins }, + }) const controller = new AbortController() const { getMarketplacePluginsByCollectionId } = await import('./utils') await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal }) - // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL - expect(globalThis.fetch).toHaveBeenCalledWith( - expect.any(Request), - expect.any(Object), + expect(mockMarketplaceClient.collectionPlugins).toHaveBeenCalledWith( + expect.objectContaining({ + params: { collectionId: 'test-collection' }, + body: {}, + }), + expect.objectContaining({ + signal: controller.signal, + }), ) - const call = vi.mocked(globalThis.fetch).mock.calls[0] - const request = call[0] as Request - expect(request.url).toContain('test-collection') }) }) @@ -1548,23 +1553,11 @@ describe('Async Utils', () => { ] const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] - let callCount = 0 - globalThis.fetch = vi.fn().mockImplementation(() => { - callCount++ - if (callCount === 1) { - return Promise.resolve( - new Response(JSON.stringify({ data: { collections: mockCollections } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) - } - return Promise.resolve( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + mockMarketplaceClient.collections.mockResolvedValueOnce({ + data: { collections: mockCollections }, + }) + mockMarketplaceClient.collectionPlugins.mockResolvedValueOnce({ + data: { plugins: mockPlugins }, }) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') @@ -1578,7 +1571,7 @@ describe('Async Utils', () => { }) it('should handle fetch error and return empty data', async () => { - globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error')) + mockMarketplaceClient.collections.mockRejectedValueOnce(new Error('Network error')) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') const result = await getMarketplaceCollectionsAndPlugins() @@ -1588,12 +1581,9 @@ describe('Async Utils', () => { }) it('should append condition and type to URL when provided', async () => { - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { collections: [] } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + mockMarketplaceClient.collections.mockResolvedValueOnce({ + data: { collections: [] }, + }) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') await getMarketplaceCollectionsAndPlugins({ @@ -1601,11 +1591,17 @@ describe('Async Utils', () => { type: 'bundle', }) - // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL - expect(globalThis.fetch).toHaveBeenCalled() - const call = vi.mocked(globalThis.fetch).mock.calls[0] - const request = call[0] as Request - expect(request.url).toContain('condition=category%3Dtool') + expect(mockMarketplaceClient.collections).toHaveBeenCalledWith( + expect.objectContaining({ + query: expect.objectContaining({ + condition: 'category=tool', + type: 'bundle', + page: 1, + page_size: 100, + }), + }), + expect.any(Object), + ) }) }) }) diff --git a/web/app/components/workflow/nodes/tool/components/context-generate-modal/components/chat-view.tsx b/web/app/components/workflow/nodes/tool/components/context-generate-modal/components/chat-view.tsx new file mode 100644 index 0000000000..68dd732f15 --- /dev/null +++ b/web/app/components/workflow/nodes/tool/components/context-generate-modal/components/chat-view.tsx @@ -0,0 +1,193 @@ +import type { ReactNode } from 'react' +import type { ContextGenerateChatMessage } from '../hooks/use-context-generate' +import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger' +import type { Model } from '@/types/app' +import { RiArrowDownSLine, RiArrowRightLine, RiSendPlaneLine, RiSparklingLine } from '@remixicon/react' +import { useEffect, useRef } from 'react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import LoadingAnim from '@/app/components/base/chat/chat/loading-anim' +import { CodeAssistant } from '@/app/components/base/icons/src/vender/line/general' +import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' +import { cn } from '@/utils/classnames' + +type VersionOption = { + index: number + label: string +} + +type ChatViewProps = { + promptMessages: ContextGenerateChatMessage[] + versionOptions: VersionOption[] + currentVersionIndex: number + onSelectVersion: (index: number) => void + defaultAssistantMessage: string + isGenerating: boolean + inputValue: string + onInputChange: (value: string) => void + onGenerate: () => void + model: Model + onModelChange: (newValue: { modelId: string, provider: string, mode?: string, features?: string[] }) => void + onCompletionParamsChange: (newParams: FormValue) => void + renderModelTrigger: (params: TriggerProps) => ReactNode +} + +const ChatView = ({ + promptMessages, + versionOptions, + currentVersionIndex, + onSelectVersion, + defaultAssistantMessage, + isGenerating, + inputValue, + onInputChange, + onGenerate, + model, + onModelChange, + onCompletionParamsChange, + renderModelTrigger, +}: ChatViewProps) => { + const { t } = useTranslation() + const chatListRef = useRef(null) + + useEffect(() => { + if (!chatListRef.current) + return + if (promptMessages.length === 0 && !isGenerating) + return + chatListRef.current.scrollTop = chatListRef.current.scrollHeight + }, [isGenerating, promptMessages.length]) + + return ( + <> +
+
+ {(() => { + let assistantIndex = -1 + return promptMessages.map((message, index) => { + if (message.role === 'assistant') + assistantIndex += 1 + const versionMeta = message.role === 'assistant' ? versionOptions[assistantIndex] : null + const isSelected = versionMeta?.index === currentVersionIndex + const showThoughtProcess = message.role === 'assistant' && message.content !== defaultAssistantMessage + const durationLabel = message.role === 'assistant' && message.durationMs + ? `${(message.durationMs / 1000).toFixed(1)}s` + : null + return ( +
+ {message.role === 'user' + ? ( +
+ {message.content} +
+ ) + : ( +
+ {showThoughtProcess && ( +
+
+ +
+ + {message.content} + + {durationLabel && ( + + {durationLabel} + + )} + +
+ )} +
+ {showThoughtProcess ? defaultAssistantMessage : message.content} +
+ {versionMeta && ( + + )} +
+ )} +
+ ) + }) + })()} + {isGenerating && ( +
+ + {t('nodes.tool.contextGenerate.generating', { ns: 'workflow' })} +
+ )} +
+
+
+
+
+