From 2068640a4b0461859dc9db5e9122fad4d32d1845 Mon Sep 17 00:00:00 2001
From: FFXN <31929997+FFXN@users.noreply.github.com>
Date: Tue, 3 Mar 2026 15:54:43 +0800
Subject: [PATCH 001/256] fix: Add the missing validation of doc_form in the
 service API. (#32892)

---
 api/controllers/console/datasets/datasets.py    | 16 ++++++++++++++++
 api/controllers/service_api/dataset/document.py | 16 +++++++++++++++-
 api/models/dataset.py                           |  2 ++
 .../knowledge_entities/knowledge_entities.py    | 15 ++++++++++++++-
 4 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 45def1ae62..54303b2482 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -119,6 +119,14 @@ def _validate_indexing_technique(value: str | None) -> str | None:
     return value
 
 
+def _validate_doc_form(value: str | None) -> str | None:
+    if value is None:
+        return value
+    if value not in Dataset.DOC_FORM_LIST:
+        raise ValueError("Invalid doc_form.")
+    return value
+
+
 class DatasetCreatePayload(BaseModel):
     name: str = Field(..., min_length=1, max_length=40)
     description: str = Field("", max_length=400)
@@ -179,6 +187,14 @@ class IndexingEstimatePayload(BaseModel):
             raise ValueError("indexing_technique is required.")
         return result
 
+    @field_validator("doc_form")
+    @classmethod
+    def validate_doc_form(cls, value: str) -> str:
+        result = _validate_doc_form(value)
+        if result is None:
+            return "text_model"
+        return result
+
 
 class ConsoleDatasetListQuery(BaseModel):
     page: int = Field(default=1, description="Page number")
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 0aeb4a2d36..dc8da025d4 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -4,7 +4,7 @@ from uuid import UUID
 
 from flask import request
 from flask_restx import marshal
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
 from sqlalchemy import desc, select
 from werkzeug.exceptions import Forbidden, NotFound
 
@@ -60,6 +60,13 @@ class DocumentTextCreatePayload(BaseModel):
     embedding_model: str | None = None
     embedding_model_provider: str | None = None
 
+    @field_validator("doc_form")
+    @classmethod
+    def validate_doc_form(cls, value: str) -> str:
+        if value not in Dataset.DOC_FORM_LIST:
+            raise ValueError("Invalid doc_form.")
+        return value
+
 
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
 
@@ -72,6 +79,13 @@
     doc_language: str = "English"
     retrieval_model: RetrievalModel | None = None
 
+    @field_validator("doc_form")
+    @classmethod
+    def validate_doc_form(cls, value: str) -> str:
+        if value not in Dataset.DOC_FORM_LIST:
+            raise ValueError("Invalid doc_form.")
+        return value
+
     @model_validator(mode="after")
     def check_text_and_name(self) -> Self:
         if self.text is not None and self.name is None:
diff --git a/api/models/dataset.py b/api/models/dataset.py
index e7da2961bc..4ef39fcde1 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -19,6 +19,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
 
 from configs import dify_config
 from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
+from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.signature import sign_upload_file
@@ -51,6 +52,7 @@ class Dataset(Base):
 
     INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
     PROVIDER_LIST = ["vendor", "external", None]
+    DOC_FORM_LIST = [member.value for member in IndexStructureType]
 
     id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
     tenant_id: Mapped[str] = mapped_column(StringUUID)
diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py
index 8dc5b93501..66309f0e59 100644
--- a/api/services/entities/knowledge_entities/knowledge_entities.py
+++ b/api/services/entities/knowledge_entities/knowledge_entities.py
@@ -1,8 +1,9 @@
 from enum import StrEnum
 from typing import Literal
 
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 
+from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 
 
@@ -127,6 +128,18 @@ class KnowledgeConfig(BaseModel):
     name: str | None = None
     is_multimodal: bool = False
 
+    @field_validator("doc_form")
+    @classmethod
+    def validate_doc_form(cls, value: str) -> str:
+        valid_forms = [
+            IndexStructureType.PARAGRAPH_INDEX,
+            IndexStructureType.QA_INDEX,
+            IndexStructureType.PARENT_CHILD_INDEX,
+        ]
+        if value not in valid_forms:
+            raise ValueError("Invalid doc_form.")
+        return value
+
 
 class SegmentCreateArgs(BaseModel):
     content: str | None = None

From 5e79d3588190272ee2f8197465a5b8df4cde23db Mon Sep 17 00:00:00 2001
From: Stephen Zhou
Date: Tue, 3 Mar 2026 16:00:56 +0800
Subject: [PATCH 002/256] fix: downgrade node version to 22 (#32897)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .github/workflows/style.yml                 | 2 +-
 .github/workflows/tool-test-sdks.yaml       | 2 +-
 .github/workflows/translate-i18n-claude.yml | 2 +-
 .github/workflows/web-tests.yml             | 6 +++---
 web/.nvmrc                                  | 2 +-
 web/Dockerfile                              | 2 +-
 web/package.json                            | 2 +-
 7 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index cbd6edf94b..eb13c3d096 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -89,7 +89,7 @@ jobs:
         uses: actions/setup-node@v6
         if: steps.changed-files.outputs.any_changed == 'true'
         with:
-          node-version: 24
+          node-version: 22
           cache: pnpm
           cache-dependency-path: ./web/pnpm-lock.yaml
 
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index ec392cb3b2..d9a1168636 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -28,7 +28,7 @@ jobs:
       - name: Use Node.js
         uses: actions/setup-node@v6
         with:
-          node-version: 24
+          node-version: 22
           cache: ''
           cache-dependency-path: 'pnpm-lock.yaml'
 
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
index 5d9440ff35..b431c36a8b 100644
--- a/.github/workflows/translate-i18n-claude.yml
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -57,7 +57,7 @@ jobs:
       - name: Set up Node.js
        uses: actions/setup-node@v6
         with:
-          node-version: 24
+          node-version: 22
           cache: pnpm
           cache-dependency-path: ./web/pnpm-lock.yaml
 
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index f50689636b..659620b2a9 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -39,7 +39,7 @@ jobs:
       - name: Setup Node.js
        uses: actions/setup-node@v6
         with:
-          node-version: 24
+          node-version: 22
           cache: pnpm
           cache-dependency-path: ./web/pnpm-lock.yaml
 
@@ -83,7 +83,7 @@ jobs:
       - name: Setup Node.js
         uses: actions/setup-node@v6
         with:
-          node-version: 24
+          node-version: 22
           cache: pnpm
           cache-dependency-path: ./web/pnpm-lock.yaml
 
@@ -457,7 +457,7 @@ jobs:
         uses: actions/setup-node@v6
         if: steps.changed-files.outputs.any_changed == 'true'
         with:
-          node-version: 24
+          node-version: 22
           cache: pnpm
           cache-dependency-path: ./web/pnpm-lock.yaml
 
diff --git a/web/.nvmrc b/web/.nvmrc
index a45fd52cc5..2bd5a0a98a 100644
--- a/web/.nvmrc
+++ b/web/.nvmrc
@@ -1 +1 @@
-24
+22
diff --git a/web/Dockerfile b/web/Dockerfile
index fe4ea1a579..9b24f9ea0a 100644
--- a/web/Dockerfile
+++ b/web/Dockerfile
@@ -1,5 +1,5 @@
 # base image
-FROM node:24-alpine AS base
+FROM node:22-alpine AS base
 LABEL maintainer="takatost@gmail.com"
 
 # if you located in China, you can use aliyun mirror to speed up
diff --git a/web/package.json b/web/package.json
index a1f1c5966a..2291d78998 100644
--- a/web/package.json
+++ b/web/package.json
@@ -23,7 +23,7 @@
     "and_qq >= 14.9"
   ],
   "engines": {
-    "node": ">=24"
+    "node": "^22"
   },
   "scripts": {
     "dev": "next dev",

From 7f67e1a2fc1bf0123196f50b835fc2ae50a7a770 Mon Sep 17 00:00:00 2001
From: yyh <92089059+lyzno1@users.noreply.github.com>
Date: Tue, 3 Mar 2026 16:56:13 +0800
Subject: [PATCH 003/256] feat(web): overlay migration guardrails + Base UI
 primitives (#32824)

Signed-off-by: yyh
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
 web/app/components/base/modal/index.tsx          |    5 +
 web/app/components/base/modal/modal.tsx          |    5 +
 .../base/portal-to-follow-elem/index.tsx         |   16 +
 web/app/components/base/select/index.tsx         |    7 +-
 web/app/components/base/tooltip/index.tsx        |    7 +-
 web/app/components/base/ui/dialog/index.tsx      |   58 +
 .../base/ui/dropdown-menu/index.tsx              |  277 +++
 web/app/components/base/ui/placement.ts          |   29 +
 web/app/components/base/ui/popover/index.tsx     |   67 +
 web/app/components/base/ui/select/index.tsx      |  163 ++
 web/app/components/base/ui/tooltip/index.tsx     |   59 +
 .../account-dropdown/compliance.spec.tsx         |   19 +-
 .../header/account-dropdown/compliance.tsx       |  258 +--
 .../header/account-dropdown/index.spec.tsx       |   10 +
 .../header/account-dropdown/index.tsx            |  369 ++--
 .../account-dropdown/menu-item-content.tsx       |   31 +
 .../header/account-dropdown/support.spec.tsx     |   38 +-
 .../header/account-dropdown/support.tsx          |  158 +-
 web/app/layout.tsx                               |    5 +-
 web/docs/lint.md                                 |    2 +
 web/docs/overlay-migration.md                    |   50 +
 web/eslint-suppressions.json                     | 1585 ++++++++++++++++-
 web/eslint.config.mjs                            |   48 +
 web/eslint.constants.mjs                         |   29 +
 web/package.json                                 |    1 +
 web/pnpm-lock.yaml                               |   53 +
 26 files changed, 2926 insertions(+), 423 deletions(-)
 create mode 100644 web/app/components/base/ui/dialog/index.tsx
 create mode 100644 web/app/components/base/ui/dropdown-menu/index.tsx
 create mode 100644 web/app/components/base/ui/placement.ts
 create mode 100644 web/app/components/base/ui/popover/index.tsx
 create mode 100644 web/app/components/base/ui/select/index.tsx
 create mode 100644 web/app/components/base/ui/tooltip/index.tsx
 create mode 100644 web/app/components/header/account-dropdown/menu-item-content.tsx
 create mode 100644 web/docs/overlay-migration.md
 create mode 100644 web/eslint.constants.mjs

diff --git a/web/app/components/base/modal/index.tsx b/web/app/components/base/modal/index.tsx
index 023934b674..af11e5aa69 100644
--- a/web/app/components/base/modal/index.tsx
+++ b/web/app/components/base/modal/index.tsx
@@ -1,3 +1,8 @@
+/**
+ * @deprecated Use `@/app/components/base/ui/dialog` instead.
+ * This component will be removed after migration is complete.
+ * See: https://github.com/langgenius/dify/issues/32767
+ */
 import { Dialog, DialogPanel, DialogTitle, Transition, TransitionChild } from '@headlessui/react'
 import { noop } from 'es-toolkit/function'
 import { Fragment } from 'react'
diff --git a/web/app/components/base/modal/modal.tsx b/web/app/components/base/modal/modal.tsx
index 3ad08e2493..5a3e4d74c1 100644
--- a/web/app/components/base/modal/modal.tsx
+++ b/web/app/components/base/modal/modal.tsx
@@ -1,3 +1,8 @@
+/**
+ * @deprecated Use `@/app/components/base/ui/dialog` instead.
+ * This component will be removed after migration is complete.
+ * See: https://github.com/langgenius/dify/issues/32767
+ */
 import type { ButtonProps } from '@/app/components/base/button'
 import { noop } from 'es-toolkit/function'
 import { memo } from 'react'
diff --git a/web/app/components/base/portal-to-follow-elem/index.tsx b/web/app/components/base/portal-to-follow-elem/index.tsx
index c57fba9dd0..7d4f6baa9b 100644
--- a/web/app/components/base/portal-to-follow-elem/index.tsx
+++ b/web/app/components/base/portal-to-follow-elem/index.tsx
@@ -1,4 +1,16 @@
 'use client'
+/**
+ * @deprecated Use semantic overlay primitives from `@/app/components/base/ui/` instead.
+ * This component will be removed after migration is complete.
+ * See: https://github.com/langgenius/dify/issues/32767
+ *
+ * Migration guide:
+ * - Tooltip → `@/app/components/base/ui/tooltip`
+ * - Menu/Dropdown → `@/app/components/base/ui/dropdown-menu`
+ * - Popover → `@/app/components/base/ui/popover`
+ * - Dialog/Modal → `@/app/components/base/ui/dialog`
+ * - Select → `@/app/components/base/ui/select`
+ */
 import type { OffsetOptions, Placement } from '@floating-ui/react'
 import {
   autoUpdate,
@@ -33,6 +45,7 @@ export type PortalToFollowElemOptions = {
   triggerPopupSameWidth?: boolean
 }
 
+/** @deprecated Use semantic overlay primitives instead. See #32767. */
 export function usePortalToFollowElem({
   placement = 'bottom',
   open: controlledOpen,
@@ -110,6 +123,7 @@ export function usePortalToFollowElemContext() {
   return context
 }
 
+/** @deprecated Use semantic overlay primitives instead. See #32767. */
 export function PortalToFollowElem({
   children,
   ...options
@@ -124,6 +138,7 @@
   )
 }
 
+/** @deprecated Use semantic overlay primitives instead. See #32767. */
 export const PortalToFollowElemTrigger = (
   {
     ref: propRef,
@@ -164,6 +179,7 @@
 }
 PortalToFollowElemTrigger.displayName = 'PortalToFollowElemTrigger'
 
+/** @deprecated Use semantic overlay primitives instead. See #32767. */
 export const PortalToFollowElemContent = (
   {
     ref: propRef,
diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx
index ac59894771..144629c380 100644
--- a/web/app/components/base/select/index.tsx
+++ b/web/app/components/base/select/index.tsx
@@ -1,4 +1,9 @@
 'use client'
+/**
+ * @deprecated Use `@/app/components/base/ui/select` instead.
+ * This component will be removed after migration is complete.
+ * See: https://github.com/langgenius/dify/issues/32767 + */ import type { FC } from 'react' import { Combobox, ComboboxButton, ComboboxInput, ComboboxOption, ComboboxOptions, Listbox, ListboxButton, ListboxOption, ListboxOptions } from '@headlessui/react' import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid' @@ -236,7 +241,7 @@ const SimpleSelect: FC = ({ }} className={cn(`flex h-full w-full items-center rounded-lg border-0 bg-components-input-bg-normal pl-3 pr-10 focus-visible:bg-state-base-hover-alt focus-visible:outline-none group-hover/simple-select:bg-state-base-hover-alt sm:text-sm sm:leading-6 ${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`, className)} > - {selectedItem?.name ?? localPlaceholder} + {selectedItem?.name ?? localPlaceholder} {isLoading ? diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index d1047ff902..7eb15b2c19 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -1,4 +1,9 @@ 'use client' +/** + * @deprecated Use `@/app/components/base/ui/tooltip` instead. + * This component will be removed after migration is complete. + * See: https://github.com/langgenius/dify/issues/32767 + */ import type { OffsetOptions, Placement } from '@floating-ui/react' import type { FC } from 'react' import { RiQuestionLine } from '@remixicon/react' @@ -130,7 +135,7 @@ const Tooltip: FC = ({ {!!popupContent && (
{ diff --git a/web/app/components/base/ui/dialog/index.tsx b/web/app/components/base/ui/dialog/index.tsx new file mode 100644 index 0000000000..605fccee09 --- /dev/null +++ b/web/app/components/base/ui/dialog/index.tsx @@ -0,0 +1,58 @@ +'use client' + +// z-index strategy (relies on root `isolation: isolate` in layout.tsx): +// All overlay primitives (Tooltip / Popover / Dropdown / Select / Dialog) — z-50 +// Overlays share the same z-index; DOM order handles stacking when multiple are open. +// This ensures overlays inside a Dialog (e.g. a Tooltip on a dialog button) render +// above the dialog backdrop instead of being clipped by it. +// Toast — z-[99], always on top (defined in toast component) + +import { Dialog as BaseDialog } from '@base-ui/react/dialog' +import * as React from 'react' +import { cn } from '@/utils/classnames' + +export const Dialog = BaseDialog.Root +export const DialogTrigger = BaseDialog.Trigger +export const DialogTitle = BaseDialog.Title +export const DialogDescription = BaseDialog.Description +export const DialogClose = BaseDialog.Close + +type DialogContentProps = { + children: React.ReactNode + className?: string + overlayClassName?: string + closable?: boolean +} + +export function DialogContent({ + children, + className, + overlayClassName, + closable = false, +}: DialogContentProps) { + return ( + + + + {closable && ( + + + + )} + {children} + + + ) +} diff --git a/web/app/components/base/ui/dropdown-menu/index.tsx b/web/app/components/base/ui/dropdown-menu/index.tsx new file mode 100644 index 0000000000..e839fd24ef --- /dev/null +++ b/web/app/components/base/ui/dropdown-menu/index.tsx @@ -0,0 +1,277 @@ +'use client' + +import type { Placement } from '@/app/components/base/ui/placement' +import { Menu } from '@base-ui/react/menu' +import * as React from 'react' +import { parsePlacement } from '@/app/components/base/ui/placement' +import { cn } from '@/utils/classnames' + +export const DropdownMenu = Menu.Root +export const DropdownMenuPortal = Menu.Portal +export const DropdownMenuTrigger = Menu.Trigger +export const DropdownMenuSub = Menu.SubmenuRoot +export const DropdownMenuGroup = Menu.Group +export const DropdownMenuRadioGroup = Menu.RadioGroup + +const menuRowBaseClassName = 'mx-1 flex h-8 cursor-pointer select-none items-center rounded-lg px-2 outline-none' +const menuRowStateClassName = 'data-[highlighted]:bg-state-base-hover data-[disabled]:cursor-not-allowed data-[disabled]:opacity-50' + +export function DropdownMenuRadioItem({ + className, + ...props +}: React.ComponentPropsWithoutRef) { + return ( + + ) +} + +export function DropdownMenuRadioItemIndicator({ + className, + ...props +}: Omit, 'children'>) { + return ( + + + + ) +} + +export function DropdownMenuCheckboxItem({ + className, + ...props +}: React.ComponentPropsWithoutRef) { + return ( + + ) +} + +export function DropdownMenuCheckboxItemIndicator({ + className, + ...props +}: Omit, 'children'>) { + return ( + + + + ) +} + +export function DropdownMenuGroupLabel({ + className, + ...props +}: React.ComponentPropsWithoutRef) { + return ( + + ) +} + +type DropdownMenuContentProps = { + children: React.ReactNode + placement?: Placement + sideOffset?: number + alignOffset?: number + className?: string + popupClassName?: string + positionerProps?: Omit< + React.ComponentPropsWithoutRef, + 'children' | 'className' | 'side' | 'align' | 'sideOffset' | 'alignOffset' + > + popupProps?: Omit< + React.ComponentPropsWithoutRef, + 'children' | 'className' + > +} + +type DropdownMenuPopupRenderProps = 
Required> & { + placement: Placement + sideOffset: number + alignOffset: number + className?: string + popupClassName?: string + positionerProps?: DropdownMenuContentProps['positionerProps'] + popupProps?: DropdownMenuContentProps['popupProps'] +} + +function renderDropdownMenuPopup({ + children, + placement, + sideOffset, + alignOffset, + className, + popupClassName, + positionerProps, + popupProps, +}: DropdownMenuPopupRenderProps) { + const { side, align } = parsePlacement(placement) + + return ( + + + + {children} + + + + ) +} + +export function DropdownMenuContent({ + children, + placement = 'bottom-end', + sideOffset = 4, + alignOffset = 0, + className, + popupClassName, + positionerProps, + popupProps, +}: DropdownMenuContentProps) { + return renderDropdownMenuPopup({ + children, + placement, + sideOffset, + alignOffset, + className, + popupClassName, + positionerProps, + popupProps, + }) +} + +type DropdownMenuSubTriggerProps = React.ComponentPropsWithoutRef & { + destructive?: boolean +} + +export function DropdownMenuSubTrigger({ + className, + destructive, + children, + ...props +}: DropdownMenuSubTriggerProps) { + return ( + + {children} + + + ) +} + +type DropdownMenuSubContentProps = { + children: React.ReactNode + placement?: Placement + sideOffset?: number + alignOffset?: number + className?: string + popupClassName?: string + positionerProps?: DropdownMenuContentProps['positionerProps'] + popupProps?: DropdownMenuContentProps['popupProps'] +} + +export function DropdownMenuSubContent({ + children, + placement = 'left-start', + sideOffset = 4, + alignOffset = 0, + className, + popupClassName, + positionerProps, + popupProps, +}: DropdownMenuSubContentProps) { + return renderDropdownMenuPopup({ + children, + placement, + sideOffset, + alignOffset, + className, + popupClassName, + positionerProps, + popupProps, + }) +} + +type DropdownMenuItemProps = React.ComponentPropsWithoutRef & { + destructive?: boolean +} + +export function DropdownMenuItem({ + className, + destructive, + ...props +}: DropdownMenuItemProps) { + return ( + + ) +} + +export function DropdownMenuSeparator({ + className, + ...props +}: React.ComponentPropsWithoutRef) { + return ( + + ) +} diff --git a/web/app/components/base/ui/placement.ts b/web/app/components/base/ui/placement.ts new file mode 100644 index 0000000000..95f233fd56 --- /dev/null +++ b/web/app/components/base/ui/placement.ts @@ -0,0 +1,29 @@ +// Placement type for overlay positioning. +// Mirrors the Floating UI Placement spec — a stable set of 12 CSS-based position values. +// Reference: https://floating-ui.com/docs/useFloating#placement + +type Side = 'top' | 'bottom' | 'left' | 'right' +type Align = 'start' | 'center' | 'end' + +export type Placement + = 'top' + | 'top-start' + | 'top-end' + | 'right' + | 'right-start' + | 'right-end' + | 'bottom' + | 'bottom-start' + | 'bottom-end' + | 'left' + | 'left-start' + | 'left-end' + +export function parsePlacement(placement: Placement): { side: Side, align: Align } { + const [side, align] = placement.split('-') as [Side, Align | undefined] + + return { + side, + align: align ?? 
'center', + } +} diff --git a/web/app/components/base/ui/popover/index.tsx b/web/app/components/base/ui/popover/index.tsx new file mode 100644 index 0000000000..8c806cd9da --- /dev/null +++ b/web/app/components/base/ui/popover/index.tsx @@ -0,0 +1,67 @@ +'use client' + +import type { Placement } from '@/app/components/base/ui/placement' +import { Popover as BasePopover } from '@base-ui/react/popover' +import * as React from 'react' +import { parsePlacement } from '@/app/components/base/ui/placement' +import { cn } from '@/utils/classnames' + +export const Popover = BasePopover.Root +export const PopoverTrigger = BasePopover.Trigger +export const PopoverClose = BasePopover.Close +export const PopoverTitle = BasePopover.Title +export const PopoverDescription = BasePopover.Description + +type PopoverContentProps = { + children: React.ReactNode + placement?: Placement + sideOffset?: number + alignOffset?: number + className?: string + popupClassName?: string + positionerProps?: Omit< + React.ComponentPropsWithoutRef, + 'children' | 'className' | 'side' | 'align' | 'sideOffset' | 'alignOffset' + > + popupProps?: Omit< + React.ComponentPropsWithoutRef, + 'children' | 'className' + > +} + +export function PopoverContent({ + children, + placement = 'bottom', + sideOffset = 8, + alignOffset = 0, + className, + popupClassName, + positionerProps, + popupProps, +}: PopoverContentProps) { + const { side, align } = parsePlacement(placement) + + return ( + + + + {children} + + + + ) +} diff --git a/web/app/components/base/ui/select/index.tsx b/web/app/components/base/ui/select/index.tsx new file mode 100644 index 0000000000..c7e51e0939 --- /dev/null +++ b/web/app/components/base/ui/select/index.tsx @@ -0,0 +1,163 @@ +'use client' + +import type { Placement } from '@/app/components/base/ui/placement' +import { Select as BaseSelect } from '@base-ui/react/select' +import * as React from 'react' +import { parsePlacement } from '@/app/components/base/ui/placement' +import { cn } from '@/utils/classnames' + +export const Select = BaseSelect.Root +export const SelectValue = BaseSelect.Value +export const SelectGroup = BaseSelect.Group +export const SelectGroupLabel = BaseSelect.GroupLabel +export const SelectSeparator = BaseSelect.Separator + +type SelectTriggerProps = React.ComponentPropsWithoutRef & { + clearable?: boolean + onClear?: () => void + loading?: boolean +} + +export function SelectTrigger({ + className, + children, + clearable = false, + onClear, + loading = false, + ...props +}: SelectTriggerProps) { + const showClear = clearable && !loading + + return ( + + {children} + {loading + ? ( + + + + ) + : showClear + ? 
( + { + e.stopPropagation() + onClear?.() + }} + onMouseDown={(e) => { + e.stopPropagation() + }} + > + + + ) + : ( + + + + )} + + ) +} + +type SelectContentProps = { + children: React.ReactNode + placement?: Placement + sideOffset?: number + alignOffset?: number + className?: string + popupClassName?: string + listClassName?: string + positionerProps?: Omit< + React.ComponentPropsWithoutRef, + 'children' | 'className' | 'side' | 'align' | 'sideOffset' | 'alignOffset' + > + popupProps?: Omit< + React.ComponentPropsWithoutRef, + 'children' | 'className' + > + listProps?: Omit< + React.ComponentPropsWithoutRef, + 'children' | 'className' + > +} + +export function SelectContent({ + children, + placement = 'bottom-start', + sideOffset = 4, + alignOffset = 0, + className, + popupClassName, + listClassName, + positionerProps, + popupProps, + listProps, +}: SelectContentProps) { + const { side, align } = parsePlacement(placement) + + return ( + + + + + {children} + + + + + ) +} + +export function SelectItem({ + className, + children, + ...props +}: React.ComponentPropsWithoutRef) { + return ( + + + {children} + + + + + + ) +} diff --git a/web/app/components/base/ui/tooltip/index.tsx b/web/app/components/base/ui/tooltip/index.tsx new file mode 100644 index 0000000000..1a41cc0f56 --- /dev/null +++ b/web/app/components/base/ui/tooltip/index.tsx @@ -0,0 +1,59 @@ +'use client' + +import type { Placement } from '@/app/components/base/ui/placement' +import { Tooltip as BaseTooltip } from '@base-ui/react/tooltip' +import * as React from 'react' +import { parsePlacement } from '@/app/components/base/ui/placement' +import { cn } from '@/utils/classnames' + +type TooltipContentVariant = 'default' | 'plain' + +export type TooltipContentProps = { + children: React.ReactNode + placement?: Placement + sideOffset?: number + alignOffset?: number + className?: string + popupClassName?: string + variant?: TooltipContentVariant +} & Omit, 'children' | 'className'> + +export function TooltipContent({ + children, + placement = 'top', + sideOffset = 8, + alignOffset = 0, + className, + popupClassName, + variant = 'default', + ...props +}: TooltipContentProps) { + const { side, align } = parsePlacement(placement) + + return ( + + + + {children} + + + + ) +} + +export const TooltipProvider = BaseTooltip.Provider +export const Tooltip = BaseTooltip.Root +export const TooltipTrigger = BaseTooltip.Trigger diff --git a/web/app/components/header/account-dropdown/compliance.spec.tsx b/web/app/components/header/account-dropdown/compliance.spec.tsx index 54a0460f82..1eb747e154 100644 --- a/web/app/components/header/account-dropdown/compliance.spec.tsx +++ b/web/app/components/header/account-dropdown/compliance.spec.tsx @@ -1,6 +1,7 @@ import type { ModalContextState } from '@/context/modal-context' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { DropdownMenu, DropdownMenuContent, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu' import { Plan } from '@/app/components/billing/type' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useModalContext } from '@/context/modal-context' @@ -70,16 +71,26 @@ describe('Compliance', () => { ) } - // Wrapper for tests that need the menu open + const renderCompliance = () => { + return renderWithQueryClient( + {}}> + open + + + + , + ) + } + const openMenuAndRender = () => { - renderWithQueryClient() - 
fireEvent.click(screen.getByRole('button')) + renderCompliance() + fireEvent.click(screen.getByText('common.userProfile.compliance')) } describe('Rendering', () => { it('should render compliance menu trigger', () => { // Act - renderWithQueryClient() + renderCompliance() // Assert expect(screen.getByText('common.userProfile.compliance')).toBeInTheDocument() diff --git a/web/app/components/header/account-dropdown/compliance.tsx b/web/app/components/header/account-dropdown/compliance.tsx index 6bc5b5c3f1..1627a5a052 100644 --- a/web/app/components/header/account-dropdown/compliance.tsx +++ b/web/app/components/header/account-dropdown/compliance.tsx @@ -1,9 +1,9 @@ -import type { FC, MouseEvent } from 'react' -import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react' -import { RiArrowDownCircleLine, RiArrowRightSLine, RiVerifiedBadgeLine } from '@remixicon/react' +import type { ReactNode } from 'react' import { useMutation } from '@tanstack/react-query' -import { Fragment, useCallback } from 'react' +import { useCallback } from 'react' import { useTranslation } from 'react-i18next' +import { DropdownMenuGroup, DropdownMenuItem, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger } from '@/app/components/base/ui/dropdown-menu' +import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' import { Plan } from '@/app/components/billing/type' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useModalContext } from '@/context/modal-context' @@ -11,14 +11,14 @@ import { useProviderContext } from '@/context/provider-context' import { getDocDownloadUrl } from '@/service/common' import { cn } from '@/utils/classnames' import { downloadUrl } from '@/utils/download' -import Button from '../../base/button' import Gdpr from '../../base/icons/src/public/common/Gdpr' import Iso from '../../base/icons/src/public/common/Iso' import Soc2 from '../../base/icons/src/public/common/Soc2' import SparklesSoft from '../../base/icons/src/public/common/SparklesSoft' import PremiumBadge from '../../base/premium-badge' +import Spinner from '../../base/spinner' import Toast from '../../base/toast' -import Tooltip from '../../base/tooltip' +import { MenuItemContent } from './menu-item-content' enum DocName { SOC2_Type_I = 'SOC2_Type_I', @@ -27,27 +27,83 @@ enum DocName { GDPR = 'GDPR', } -type UpgradeOrDownloadProps = { - doc_name: DocName +type ComplianceDocActionVisualProps = { + isCurrentPlanCanDownload: boolean + isPending: boolean + tooltipText: string + downloadText: string + upgradeText: string } -const UpgradeOrDownload: FC = ({ doc_name }) => { + +function ComplianceDocActionVisual({ + isCurrentPlanCanDownload, + isPending, + tooltipText, + downloadText, + upgradeText, +}: ComplianceDocActionVisualProps) { + if (isCurrentPlanCanDownload) { + return ( +
+ + {downloadText} + {isPending && } +
+ ) + } + + const canShowUpgradeTooltip = tooltipText.length > 0 + + return ( + + + +
+ {upgradeText} +
+ + )} + /> + {canShowUpgradeTooltip && ( + + {tooltipText} + + )} +
+ ) +} + +type ComplianceDocRowItemProps = { + icon: ReactNode + label: ReactNode + docName: DocName +} + +function ComplianceDocRowItem({ + icon, + label, + docName, +}: ComplianceDocRowItemProps) { const { t } = useTranslation() const { plan } = useProviderContext() const { setShowPricingModal, setShowAccountSettingModal } = useModalContext() const isFreePlan = plan.type === Plan.sandbox - const handlePlanClick = useCallback(() => { - if (isFreePlan) - setShowPricingModal() - else - setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.BILLING }) - }, [isFreePlan, setShowAccountSettingModal, setShowPricingModal]) - const { isPending, mutate: downloadCompliance } = useMutation({ - mutationKey: ['downloadCompliance', doc_name], + mutationKey: ['downloadCompliance', docName], mutationFn: async () => { try { - const ret = await getDocDownloadUrl(doc_name) + const ret = await getDocDownloadUrl(docName) downloadUrl({ url: ret.url }) Toast.notify({ type: 'success', @@ -63,6 +119,7 @@ const UpgradeOrDownload: FC = ({ doc_name }) => { } }, }) + const whichPlanCanDownloadCompliance = { [DocName.SOC2_Type_I]: [Plan.professional, Plan.team], [DocName.SOC2_Type_II]: [Plan.team], @@ -70,118 +127,85 @@ const UpgradeOrDownload: FC = ({ doc_name }) => { [DocName.GDPR]: [Plan.team, Plan.professional, Plan.sandbox], } - const isCurrentPlanCanDownload = whichPlanCanDownloadCompliance[doc_name].includes(plan.type) - const handleDownloadClick = useCallback((e: MouseEvent) => { - e.preventDefault() - downloadCompliance() - }, [downloadCompliance]) - if (isCurrentPlanCanDownload) { - return ( - - ) - } + const isCurrentPlanCanDownload = whichPlanCanDownloadCompliance[docName].includes(plan.type) + + const handleSelect = useCallback(() => { + if (isCurrentPlanCanDownload) { + if (!isPending) + downloadCompliance() + return + } + + if (isFreePlan) + setShowPricingModal() + else + setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.BILLING }) + }, [downloadCompliance, isCurrentPlanCanDownload, isFreePlan, isPending, setShowAccountSettingModal, setShowPricingModal]) + const upgradeTooltip: Record = { [Plan.sandbox]: t('compliance.sandboxUpgradeTooltip', { ns: 'common' }), [Plan.professional]: t('compliance.professionalUpgradeTooltip', { ns: 'common' }), [Plan.team]: '', [Plan.enterprise]: '', } + return ( - - - -
- - {t('upgradeBtn.encourageShort', { ns: 'billing' })} - -
-
-
+ + {icon} +
{label}
+ +
) } +// Submenu-only: this component must be rendered within an existing DropdownMenu root. export default function Compliance() { - const itemClassName = ` - flex items-center w-full h-10 pl-1 pr-2 py-1 text-text-secondary system-md-regular - rounded-lg hover:bg-state-base-hover gap-1 -` const { t } = useTranslation() return ( - - { - ({ open }) => ( - <> - - -
{t('userProfile.compliance', { ns: 'common' })}
- -
- - -
- -
- -
{t('compliance.soc2Type1', { ns: 'common' })}
- -
-
- -
- -
{t('compliance.soc2Type2', { ns: 'common' })}
- -
-
- -
- -
{t('compliance.iso27001', { ns: 'common' })}
- -
-
- -
- -
{t('compliance.gdpr', { ns: 'common' })}
- -
-
-
-
-
- - ) - } -
+ + + + + + + } + label={t('compliance.soc2Type1', { ns: 'common' })} + docName={DocName.SOC2_Type_I} + /> + } + label={t('compliance.soc2Type2', { ns: 'common' })} + docName={DocName.SOC2_Type_II} + /> + } + label={t('compliance.iso27001', { ns: 'common' })} + docName={DocName.ISO_27001} + /> + } + label={t('compliance.gdpr', { ns: 'common' })} + docName={DocName.GDPR} + /> + + + ) } diff --git a/web/app/components/header/account-dropdown/index.spec.tsx b/web/app/components/header/account-dropdown/index.spec.tsx index a954351267..60e00868c1 100644 --- a/web/app/components/header/account-dropdown/index.spec.tsx +++ b/web/app/components/header/account-dropdown/index.spec.tsx @@ -65,6 +65,7 @@ vi.mock('@/context/i18n', () => ({ const { mockConfig, mockEnv } = vi.hoisted(() => ({ mockConfig: { IS_CLOUD_EDITION: false, + ZENDESK_WIDGET_KEY: '', }, mockEnv: { env: { @@ -74,6 +75,7 @@ const { mockConfig, mockEnv } = vi.hoisted(() => ({ })) vi.mock('@/config', () => ({ get IS_CLOUD_EDITION() { return mockConfig.IS_CLOUD_EDITION }, + get ZENDESK_WIDGET_KEY() { return mockConfig.ZENDESK_WIDGET_KEY }, IS_DEV: false, IS_CE_EDITION: false, })) @@ -187,6 +189,14 @@ describe('AccountDropdown', () => { expect(screen.getByText('test@example.com')).toBeInTheDocument() }) + it('should set an accessible label on avatar trigger when menu trigger is rendered', () => { + // Act + renderWithRouter() + + // Assert + expect(screen.getByRole('button', { name: 'common.account.account' })).toBeInTheDocument() + }) + it('should show EDU badge for education accounts', () => { // Arrange vi.mocked(useProviderContext).mockReturnValue({ diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 983f9e434d..4cf1f4efda 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -1,26 +1,15 @@ 'use client' -import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react' -import { - RiAccountCircleLine, - RiArrowRightUpLine, - RiBookOpenLine, - RiGithubLine, - RiGraduationCapFill, - RiInformation2Line, - RiLogoutBoxRLine, - RiMap2Line, - RiSettings3Line, - RiStarLine, - RiTShirt2Line, -} from '@remixicon/react' + +import type { MouseEventHandler, ReactNode } from 'react' import Link from 'next/link' import { useRouter } from 'next/navigation' -import { Fragment, useState } from 'react' +import { useState } from 'react' import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' import Avatar from '@/app/components/base/avatar' import PremiumBadge from '@/app/components/base/premium-badge' import ThemeSwitcher from '@/app/components/base/theme-switcher' +import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { IS_CLOUD_EDITION } from '@/config' import { useAppContext } from '@/context/app-context' @@ -35,15 +24,90 @@ import AccountAbout from '../account-about' import GithubStar from '../github-star' import Indicator from '../indicator' import Compliance from './compliance' +import { ExternalLinkIndicator, MenuItemContent } from './menu-item-content' import Support from './support' +type AccountMenuRouteItemProps = { + href: string + iconClassName: string + label: ReactNode + trailing?: ReactNode +} + +function 
AccountMenuRouteItem({ + href, + iconClassName, + label, + trailing, +}: AccountMenuRouteItemProps) { + return ( + } + > + + + ) +} + +type AccountMenuExternalItemProps = { + href: string + iconClassName: string + label: ReactNode + trailing?: ReactNode +} + +function AccountMenuExternalItem({ + href, + iconClassName, + label, + trailing, +}: AccountMenuExternalItemProps) { + return ( + } + > + + + ) +} + +type AccountMenuActionItemProps = { + iconClassName: string + label: ReactNode + onClick?: MouseEventHandler + trailing?: ReactNode +} + +function AccountMenuActionItem({ + iconClassName, + label, + onClick, + trailing, +}: AccountMenuActionItemProps) { + return ( + + + + ) +} + +type AccountMenuSectionProps = { + children: ReactNode +} + +function AccountMenuSection({ children }: AccountMenuSectionProps) { + return {children} +} + export default function AppSelector() { - const itemClassName = ` - flex items-center w-full h-8 pl-3 pr-2 text-text-secondary system-md-regular - rounded-lg hover:bg-state-base-hover cursor-pointer gap-1 - ` const router = useRouter() const [aboutVisible, setAboutVisible] = useState(false) + const [isAccountMenuOpen, setIsAccountMenuOpen] = useState(false) const { systemFeatures } = useGlobalPublicStore() const { t } = useTranslation() @@ -68,161 +132,124 @@ export default function AppSelector() { } return ( -
- - { - ({ open, close }) => ( - <> - - - - - -
- -
-
-
- {userProfile.name} - {isEducationAccount && ( - - - EDU - - )} -
-
{userProfile.email}
-
- -
-
- - - -
{t('account.account', { ns: 'common' })}
- - -
- -
setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.MEMBERS })} - > - -
{t('userProfile.settings', { ns: 'common' })}
-
-
-
- {!systemFeatures.branding.enabled && ( - <> -
- - - -
{t('userProfile.helpCenter', { ns: 'common' })}
- - -
- - {IS_CLOUD_EDITION && isCurrentWorkspaceOwner && } -
-
- - - -
{t('userProfile.roadmap', { ns: 'common' })}
- - -
- - - -
{t('userProfile.github', { ns: 'common' })}
-
- - -
- -
- { - env.NEXT_PUBLIC_SITE_ABOUT !== 'hide' && ( - -
setAboutVisible(true)} - > - -
{t('userProfile.about', { ns: 'common' })}
-
-
{langGeniusVersionInfo.current_version}
- -
-
-
- ) - } -
- +
+ + + + + + +
+
+
+ {userProfile.name} + {isEducationAccount && ( + + + EDU + )} - -
-
- -
{t('theme.theme', { ns: 'common' })}
- -
+
+
{userProfile.email}
+
+ +
+ } + /> + setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.MEMBERS })} + /> + + + {!systemFeatures.branding.enabled && ( + <> + + } + /> + setIsAccountMenuOpen(false)} /> + {IS_CLOUD_EDITION && isCurrentWorkspaceOwner && } + + + + } + /> + + +
- - -
handleLogout()}> -
- -
{t('userProfile.logout', { ns: 'common' })}
-
-
-
- - + )} + /> + { + env.NEXT_PUBLIC_SITE_ABOUT !== 'hide' && ( + { + setAboutVisible(true) + setIsAccountMenuOpen(false) + }} + trailing={( +
+
{langGeniusVersionInfo.current_version}
+ +
+ )} + /> + ) + } + + - ) - } -
+ )} + + e.preventDefault()} + > + } + /> + + + + + { + void handleLogout() + }} + /> + + + { aboutVisible && setAboutVisible(false)} langGeniusVersionInfo={langGeniusVersionInfo} /> } diff --git a/web/app/components/header/account-dropdown/menu-item-content.tsx b/web/app/components/header/account-dropdown/menu-item-content.tsx new file mode 100644 index 0000000000..47f0042047 --- /dev/null +++ b/web/app/components/header/account-dropdown/menu-item-content.tsx @@ -0,0 +1,31 @@ +import type { ReactNode } from 'react' +import { cn } from '@/utils/classnames' + +const menuLabelClassName = 'min-w-0 grow truncate px-1 text-text-secondary system-md-regular' +const menuLeadingIconClassName = 'size-4 shrink-0 text-text-tertiary' + +export const menuTrailingIconClassName = 'size-[14px] shrink-0 text-text-tertiary' + +type MenuItemContentProps = { + iconClassName: string + label: ReactNode + trailing?: ReactNode +} + +export function MenuItemContent({ + iconClassName, + label, + trailing, +}: MenuItemContentProps) { + return ( + <> + +
{label}
+ {trailing} + + ) +} + +export function ExternalLinkIndicator() { + return +} diff --git a/web/app/components/header/account-dropdown/support.spec.tsx b/web/app/components/header/account-dropdown/support.spec.tsx index b30a290ea5..90bcb9f3ec 100644 --- a/web/app/components/header/account-dropdown/support.spec.tsx +++ b/web/app/components/header/account-dropdown/support.spec.tsx @@ -1,6 +1,7 @@ import type { AppContextValue } from '@/context/app-context' import { fireEvent, render, screen } from '@testing-library/react' +import { DropdownMenu, DropdownMenuContent, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' import { baseProviderContextValue, useProviderContext } from '@/context/provider-context' @@ -93,10 +94,21 @@ describe('Support', () => { }) }) + const renderSupport = () => { + return render( + {}}> + open + + + + , + ) + } + describe('Rendering', () => { it('should render support menu trigger', () => { // Act - render() + renderSupport() // Assert expect(screen.getByText('common.userProfile.support')).toBeInTheDocument() @@ -104,8 +116,8 @@ describe('Support', () => { it('should show forum and community links when opened', () => { // Act - render() - fireEvent.click(screen.getByRole('button')) + renderSupport() + fireEvent.click(screen.getByText('common.userProfile.support')) // Assert expect(screen.getByText('common.userProfile.forum')).toBeInTheDocument() @@ -116,8 +128,8 @@ describe('Support', () => { describe('Plan-based Channels', () => { it('should show "Contact Us" when ZENDESK_WIDGET_KEY is present', () => { // Act - render() - fireEvent.click(screen.getByRole('button')) + renderSupport() + fireEvent.click(screen.getByText('common.userProfile.support')) // Assert expect(screen.getByText('common.userProfile.contactUs')).toBeInTheDocument() @@ -134,8 +146,8 @@ describe('Support', () => { }) // Act - render() - fireEvent.click(screen.getByRole('button')) + renderSupport() + fireEvent.click(screen.getByText('common.userProfile.support')) // Assert expect(screen.queryByText('common.userProfile.contactUs')).not.toBeInTheDocument() @@ -147,8 +159,8 @@ describe('Support', () => { mockZendeskKey.value = '' // Act - render() - fireEvent.click(screen.getByRole('button')) + renderSupport() + fireEvent.click(screen.getByText('common.userProfile.support')) // Assert expect(screen.getByText('common.userProfile.emailSupport')).toBeInTheDocument() @@ -159,8 +171,8 @@ describe('Support', () => { describe('Interactions and Links', () => { it('should call toggleZendeskWindow and closeAccountDropdown when "Contact Us" is clicked', () => { // Act - render() - fireEvent.click(screen.getByRole('button')) + renderSupport() + fireEvent.click(screen.getByText('common.userProfile.support')) fireEvent.click(screen.getByText('common.userProfile.contactUs')) // Assert @@ -170,8 +182,8 @@ describe('Support', () => { it('should have correct forum and community links', () => { // Act - render() - fireEvent.click(screen.getByRole('button')) + renderSupport() + fireEvent.click(screen.getByText('common.userProfile.support')) // Assert const forumLink = screen.getByText('common.userProfile.forum').closest('a') diff --git a/web/app/components/header/account-dropdown/support.tsx b/web/app/components/header/account-dropdown/support.tsx index 7873b676c3..7ec2766977 100644 --- a/web/app/components/header/account-dropdown/support.tsx +++ 
b/web/app/components/header/account-dropdown/support.tsx @@ -1,119 +1,85 @@ -import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react' -import { RiArrowRightSLine, RiArrowRightUpLine, RiChatSmile2Line, RiDiscordLine, RiDiscussLine, RiMailSendLine, RiQuestionLine } from '@remixicon/react' -import Link from 'next/link' -import { Fragment } from 'react' import { useTranslation } from 'react-i18next' +import { DropdownMenuGroup, DropdownMenuItem, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger } from '@/app/components/base/ui/dropdown-menu' import { toggleZendeskWindow } from '@/app/components/base/zendesk/utils' import { Plan } from '@/app/components/billing/type' import { ZENDESK_WIDGET_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' -import { cn } from '@/utils/classnames' import { mailToSupport } from '../utils/util' +import { ExternalLinkIndicator, MenuItemContent } from './menu-item-content' type SupportProps = { closeAccountDropdown: () => void } +// Submenu-only: this component must be rendered within an existing DropdownMenu root. export default function Support({ closeAccountDropdown }: SupportProps) { - const itemClassName = ` - flex items-center w-full h-9 pl-3 pr-2 text-text-secondary system-md-regular - rounded-lg hover:bg-state-base-hover cursor-pointer gap-1 -` const { t } = useTranslation() const { plan } = useProviderContext() const { userProfile, langGeniusVersionInfo } = useAppContext() const hasDedicatedChannel = plan.type !== Plan.sandbox + const hasZendeskWidget = !!ZENDESK_WIDGET_KEY?.trim() return ( - - { - ({ open }) => ( - <> - + + + + + + {hasDedicatedChannel && hasZendeskWidget && ( + { + toggleZendeskWindow(true) + closeAccountDropdown() + }} > - -
{t('userProfile.support', { ns: 'common' })}
- -
- + + )} + {hasDedicatedChannel && !hasZendeskWidget && ( + } > - -
- {hasDedicatedChannel && ( - - {ZENDESK_WIDGET_KEY && ZENDESK_WIDGET_KEY.trim() !== '' - ? ( - - ) - : ( - - -
{t('userProfile.emailSupport', { ns: 'common' })}
- -
- )} -
- )} - - - -
{t('userProfile.forum', { ns: 'common' })}
- - -
- - - -
{t('userProfile.community', { ns: 'common' })}
- - -
-
-
-
- - ) - } -
+ } + /> + + )} + } + > + } + /> + + } + > + } + /> + + + + ) } diff --git a/web/app/layout.tsx b/web/app/layout.tsx index fd81e09cb6..318cad3a6c 100644 --- a/web/app/layout.tsx +++ b/web/app/layout.tsx @@ -9,6 +9,7 @@ import { getDatasetMap } from '@/env' import { getLocaleOnServer } from '@/i18n-config/server' import { cn } from '@/utils/classnames' import { ToastProvider } from './components/base/toast' +import { TooltipProvider } from './components/base/ui/tooltip' import BrowserInitializer from './components/browser-initializer' import { ReactScanLoader } from './components/devtools/react-scan/loader' import { I18nServerProvider } from './components/provider/i18n-server' @@ -79,7 +80,9 @@ const LocaleLayout = async ({ - {children} + + {children} + diff --git a/web/docs/lint.md b/web/docs/lint.md index a0ec9d58ad..1105d4af08 100644 --- a/web/docs/lint.md +++ b/web/docs/lint.md @@ -43,6 +43,8 @@ This command lints the entire project and is intended for final verification bef If a new rule causes many existing code errors or automatic fixes generate too many diffs, do not use the `--fix` option for automatic fixes. You can introduce the rule first, then use the `--suppress-all` option to temporarily suppress these errors, and gradually fix them in subsequent changes. +For overlay migration policy and cleanup phases, see [Overlay Migration Guide](./overlay-migration.md). + ## Type Check You should be able to see suggestions from TypeScript in your editor for all open files. diff --git a/web/docs/overlay-migration.md b/web/docs/overlay-migration.md new file mode 100644 index 0000000000..77c5fe5b04 --- /dev/null +++ b/web/docs/overlay-migration.md @@ -0,0 +1,50 @@ +# Overlay Migration Guide + +This document tracks the migration away from legacy `portal-to-follow-elem` APIs. + +## Scope + +- Deprecated API: `@/app/components/base/portal-to-follow-elem` +- Replacement primitives: + - `@/app/components/base/ui/tooltip` + - `@/app/components/base/ui/dropdown-menu` + - `@/app/components/base/ui/popover` + - `@/app/components/base/ui/dialog` + - `@/app/components/base/ui/select` +- Tracking issue: https://github.com/langgenius/dify/issues/32767 + +## ESLint policy + +- `no-restricted-imports` blocks new usage of `portal-to-follow-elem`. +- The rule is enabled for normal source files and test files are excluded. +- Legacy `app/components/base/*` callers are temporarily allowlisted in ESLint config. +- New files must not be added to the allowlist without migration owner approval. + +## Migration phases + +1. Business/UI features outside `app/components/base/**` + - Migrate old calls to semantic primitives. + - Keep `eslint-suppressions.json` stable or shrinking. +1. Legacy base components in allowlist + - Migrate allowlisted base callers gradually. + - Remove migrated files from allowlist immediately. +1. Cleanup + - Remove remaining suppressions for `no-restricted-imports`. + - Remove legacy `portal-to-follow-elem` implementation. + +## Suppression maintenance + +- After each migration batch, run: + +```sh +pnpm eslint --prune-suppressions --pass-on-unpruned-suppressions +``` + +- Never increase suppressions to bypass new code. +- Prefer direct migration over adding suppression entries. + +## React Refresh policy for base UI primitives + +- We keep primitive aliases (for example `DropdownMenu = Menu.Root`) in the same module. +- For `app/components/base/ui/**/*.tsx`, `react-refresh/only-export-components` is currently set to `off` in ESLint to avoid false positives and IDE noise during migration. 
+- Do not use file-level `eslint-disable` comments for this policy; keep control in the scoped ESLint override. diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index c5307cbe8e..95f00e92b3 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -81,12 +81,20 @@ "count": 1 } }, + "app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, "app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -102,7 +110,15 @@ "count": 3 } }, + "app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 6 } @@ -113,6 +129,9 @@ } }, "app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -213,11 +232,17 @@ } }, "app/account/(commonLayout)/account-page/AvatarWithEdit.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/account/(commonLayout)/account-page/email-change-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 21 }, @@ -226,6 +251,9 @@ } }, "app/account/(commonLayout)/account-page/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 14 }, @@ -278,20 +306,46 @@ } }, "app/components/app-sidebar/app-info/app-operations.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 4 } }, + "app/components/app-sidebar/app-sidebar-dropdown.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/app-sidebar/basic.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app-sidebar/dataset-info/dropdown.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } }, + "app/components/app-sidebar/dataset-sidebar-dropdown.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app-sidebar/index.tsx": { "ts/no-explicit-any": { "count": 1 } }, + "app/components/app-sidebar/toggle-button.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/annotation/add-annotation-modal/edit-item/index.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -321,6 +375,9 @@ } }, "app/components/app/annotation/batch-add-annotation-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -416,6 +473,9 @@ } }, "app/components/app/app-access-control/add-member-or-group-pop.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 8 } @@ -426,6 +486,9 @@ } }, "app/components/app/app-access-control/specific-groups-or-members.tsx": { + "no-restricted-imports": { + 
"count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 8 } @@ -436,16 +499,27 @@ } }, "app/components/app/app-publisher/index.tsx": { + "no-restricted-imports": { + "count": 2 + }, "ts/no-explicit-any": { "count": 5 } }, + "app/components/app/app-publisher/publish-with-multiple-model.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/app-publisher/suggested-action.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/app/app-publisher/version-info-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -461,6 +535,9 @@ } }, "app/components/app/configuration/config-prompt/advanced-prompt-input.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -471,6 +548,9 @@ } }, "app/components/app/configuration/config-prompt/conversation-history/edit-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -486,6 +566,9 @@ } }, "app/components/app/configuration/config-prompt/simple-prompt-input.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -494,6 +577,9 @@ } }, "app/components/app/configuration/config-var/config-modal/index.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -501,6 +587,11 @@ "count": 6 } }, + "app/components/app/configuration/config-var/config-modal/type-select.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/configuration/config-var/config-select/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -509,6 +600,11 @@ "count": 1 } }, + "app/components/app/configuration/config-var/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/configuration/config-var/input-type-icon.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -523,6 +619,9 @@ } }, "app/components/app/configuration/config-var/select-var-type.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -541,6 +640,9 @@ } }, "app/components/app/configuration/config-vision/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -549,10 +651,18 @@ } }, "app/components/app/configuration/config-vision/param-config-content.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 } }, + "app/components/app/configuration/config-vision/param-config.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/configuration/config/agent-setting-button.spec.tsx": { "ts/no-explicit-any": { "count": 2 @@ -566,12 +676,20 @@ "count": 1 } }, + "app/components/app/configuration/config/agent/agent-setting/item-panel.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/configuration/config/agent/agent-tools/index.spec.tsx": { "ts/no-explicit-any": { "count": 5 } }, "app/components/app/configuration/config/agent/agent-tools/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -607,6 +725,9 @@ } }, "app/components/app/configuration/config/assistant-type-picker/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 3 }, @@ -615,6 +736,9 @@ } }, 
"app/components/app/configuration/config/automatic/get-automatic-res.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 4 }, @@ -652,6 +776,9 @@ } }, "app/components/app/configuration/config/automatic/version-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 }, @@ -660,6 +787,9 @@ } }, "app/components/app/configuration/config/code-generator/get-code-generator-res.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 4 }, @@ -679,6 +809,9 @@ } }, "app/components/app/configuration/config/config-audio.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -689,6 +822,9 @@ } }, "app/components/app/configuration/config/config-document.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -709,11 +845,17 @@ } }, "app/components/app/configuration/dataset-config/context-var/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 } }, "app/components/app/configuration/dataset-config/context-var/var-picker.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 4 } @@ -728,7 +870,15 @@ "count": 1 } }, + "app/components/app/configuration/dataset-config/params-config/config-content.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/configuration/dataset-config/params-config/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 } @@ -744,6 +894,9 @@ } }, "app/components/app/configuration/dataset-config/select-dataset/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 }, @@ -771,6 +924,9 @@ } }, "app/components/app/configuration/debug/chat-user-input.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -795,6 +951,11 @@ "count": 2 } }, + "app/components/app/configuration/debug/debug-with-multiple-model/model-parameter-trigger.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/configuration/debug/debug-with-multiple-model/text-generation-item.tsx": { "ts/no-explicit-any": { "count": 8 @@ -819,6 +980,9 @@ } }, "app/components/app/configuration/debug/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 }, @@ -857,6 +1021,9 @@ } }, "app/components/app/configuration/prompt-value-panel/index.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -867,6 +1034,9 @@ } }, "app/components/app/configuration/tools/external-data-tool-modal.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 }, @@ -874,6 +1044,11 @@ "count": 2 } }, + "app/components/app/configuration/tools/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/app/create-app-dialog/app-card/index.spec.tsx": { "ts/no-explicit-any": { "count": 1 @@ -917,11 +1092,17 @@ } }, "app/components/app/create-from-dsl-modal/dsl-confirm-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, 
"tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/app/create-from-dsl-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 }, @@ -938,6 +1119,9 @@ } }, "app/components/app/duplicate-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -961,6 +1145,9 @@ } }, "app/components/app/log/list.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 6 }, @@ -975,6 +1162,9 @@ } }, "app/components/app/log/model-info.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 }, @@ -998,6 +1188,9 @@ } }, "app/components/app/overview/app-card.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 3 }, @@ -1017,11 +1210,17 @@ } }, "app/components/app/overview/customize/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/app/overview/embedded/index.tsx": { + "no-restricted-imports": { + "count": 2 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -1033,6 +1232,9 @@ } }, "app/components/app/overview/settings/index.tsx": { + "no-restricted-imports": { + "count": 3 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 3 }, @@ -1052,6 +1254,9 @@ } }, "app/components/app/switch-app-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 } @@ -1095,11 +1300,17 @@ } }, "app/components/app/type-selector/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, "app/components/app/workflow-log/detail.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -1136,6 +1347,9 @@ } }, "app/components/apps/app-card.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -1227,6 +1441,11 @@ "count": 3 } }, + "app/components/base/audio-btn/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/audio-gallery/AudioPlayer.tsx": { "ts/no-explicit-any": { "count": 2 @@ -1284,6 +1503,11 @@ "count": 1 } }, + "app/components/base/button/sync-button.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/carousel/index.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -1308,6 +1532,9 @@ } }, "app/components/base/chat/chat-with-history/header/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -1337,6 +1564,9 @@ } }, "app/components/base/chat/chat-with-history/inputs-form/content.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 }, @@ -1377,6 +1607,11 @@ "count": 3 } }, + "app/components/base/chat/chat-with-history/sidebar/rename-modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/chat/chat/answer/agent-content.tsx": { "style/multiline-ternary": { "count": 2 @@ -1401,6 +1636,11 @@ "count": 1 } }, + "app/components/base/chat/chat/answer/operation.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, 
"app/components/base/chat/chat/answer/tool-detail.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 5 @@ -1478,6 +1718,11 @@ "count": 7 } }, + "app/components/base/chat/embedded-chatbot/header/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/chat/embedded-chatbot/hooks.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 3 @@ -1489,6 +1734,9 @@ } }, "app/components/base/chat/embedded-chatbot/inputs-form/content.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } @@ -1535,6 +1783,11 @@ "count": 1 } }, + "app/components/base/copy-feedback/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/corner-label/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -1621,6 +1874,11 @@ "count": 1 } }, + "app/components/base/emoji-picker/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/encrypted-bottom/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -1639,12 +1897,23 @@ "count": 1 } }, + "app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-button.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/base/features/new-feature-panel/annotation-reply/config-param.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -1681,16 +1950,25 @@ } }, "app/components/base/features/new-feature-panel/conversation-opener/modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 } }, "app/components/base/features/new-feature-panel/feature-bar.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/base/features/new-feature-panel/feature-card.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -1718,6 +1996,11 @@ "count": 3 } }, + "app/components/base/features/new-feature-panel/moderation/form-generation.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/features/new-feature-panel/moderation/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 5 @@ -1727,6 +2010,9 @@ } }, "app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -1736,6 +2022,11 @@ "count": 7 } }, + "app/components/base/features/new-feature-panel/text-to-speech/param-config-content.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, "app/components/base/features/new-feature-panel/text-to-speech/voice-settings.tsx": { "ts/no-explicit-any": { "count": 1 @@ -1762,6 +2053,9 @@ } }, "app/components/base/file-uploader/file-list-in-log.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react/no-missing-key": { "count": 1 }, @@ -1787,6 +2081,11 @@ "count": 3 } }, + "app/components/base/file-uploader/pdf-preview.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/file-uploader/store.tsx": { "react-refresh/only-export-components": { "count": 4 @@ -1798,6 +2097,9 @@ } }, 
"app/components/base/form/components/base/base-field.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 }, @@ -1926,7 +2228,15 @@ "count": 1 } }, + "app/components/base/image-uploader/image-list.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/image-uploader/image-preview.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -2017,6 +2327,9 @@ } }, "app/components/base/markdown-blocks/form.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -2147,6 +2460,9 @@ } }, "app/components/base/new-audio-button/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -2190,6 +2506,9 @@ } }, "app/components/base/param-item/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -2298,6 +2617,11 @@ "count": 1 } }, + "app/components/base/prompt-editor/plugins/hitl-input-block/variable-block.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/prompt-editor/plugins/last-run-block/component.tsx": { "tailwindcss/no-unnecessary-whitespace": { "count": 2 @@ -2343,6 +2667,11 @@ "count": 2 } }, + "app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx": { "react-refresh/only-export-components": { "count": 4 @@ -2368,6 +2697,11 @@ "count": 1 } }, + "app/components/base/qrcode/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/radio-card/index.stories.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2431,9 +2765,6 @@ "style/multiline-ternary": { "count": 2 }, - "tailwindcss/enforce-consistent-class-order": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -2495,6 +2826,21 @@ "count": 1 } }, + "app/components/base/tag-management/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/base/tag-management/tag-item-editor.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/base/tag-management/tag-remove-modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/base/text-generation/hooks.ts": { "ts/no-explicit-any": { "count": 1 @@ -2521,11 +2867,6 @@ "count": 2 } }, - "app/components/base/tooltip/index.tsx": { - "tailwindcss/enforce-consistent-class-order": { - "count": 1 - } - }, "app/components/base/video-gallery/VideoPlayer.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -2559,6 +2900,11 @@ "count": 4 } }, + "app/components/billing/annotation-full/modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/billing/apps-full-in-dialog/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 5 @@ -2570,6 +2916,9 @@ } }, "app/components/billing/plan-upgrade-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -2646,6 +2995,9 @@ } }, "app/components/billing/priority-label/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -2656,6 +3008,9 @@ } }, "app/components/billing/usage-info/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { 
"count": 8 } @@ -2676,11 +3031,17 @@ } }, "app/components/datasets/common/document-picker/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/datasets/common/document-picker/preview-document-picker.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -2709,6 +3070,9 @@ } }, "app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -2724,11 +3088,17 @@ } }, "app/components/datasets/common/retrieval-param-config/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, "app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -2739,6 +3109,9 @@ } }, "app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -2787,6 +3160,9 @@ } }, "app/components/datasets/create-from-pipeline/list/template-card/details/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -2796,6 +3172,11 @@ "count": 3 } }, + "app/components/datasets/create-from-pipeline/list/template-card/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/create-from-pipeline/list/template-card/operations.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -2807,10 +3188,18 @@ } }, "app/components/datasets/create/embedding-process/indexing-progress-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/datasets/create/empty-dataset-creation-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/create/file-preview/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -2840,16 +3229,25 @@ } }, "app/components/datasets/create/step-two/components/general-chunking-options.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } }, "app/components/datasets/create/step-two/components/indexing-mode-section.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 8 } }, "app/components/datasets/create/step-two/components/inputs.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -2889,16 +3287,31 @@ "count": 2 } }, + "app/components/datasets/create/stop-embedding-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/create/top-bar/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/datasets/create/website/base/checkbox-with-label.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/create/website/base/error-message.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, + "app/components/datasets/create/website/base/field.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, 
"app/components/datasets/create/website/base/header.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -2981,11 +3394,31 @@ "count": 1 } }, + "app/components/datasets/documents/components/document-list/components/document-table-row.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/datasets/documents/components/documents-header.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/documents/components/list.tsx": { "react-refresh/only-export-components": { "count": 1 } }, + "app/components/datasets/documents/components/operations.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/datasets/documents/components/rename-modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/documents/create-from-pipeline/actions/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 @@ -2996,6 +3429,11 @@ "count": 1 } }, + "app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -3007,6 +3445,9 @@ } }, "app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3032,6 +3473,9 @@ } }, "app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/bucket.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -3042,6 +3486,9 @@ } }, "app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3072,6 +3519,9 @@ } }, "app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -3102,6 +3552,9 @@ } }, "app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/checkbox-with-label.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3181,6 +3634,9 @@ } }, "app/components/datasets/documents/create-from-pipeline/processing/embedding-process/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -3196,6 +3652,9 @@ } }, "app/components/datasets/documents/detail/batch-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 } @@ -3249,6 +3708,9 @@ } }, "app/components/datasets/documents/detail/completed/common/regeneration-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 6 } @@ -3263,16 +3725,34 @@ "count": 2 } }, + "app/components/datasets/documents/detail/completed/common/summary-status.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/documents/detail/completed/common/summary-text.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, 
"app/components/datasets/documents/detail/completed/components/menu-bar.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/datasets/documents/detail/completed/display-toggle.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/datasets/documents/detail/completed/hooks/use-search-filter.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/documents/detail/completed/index.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -3292,6 +3772,9 @@ } }, "app/components/datasets/documents/detail/completed/segment-card/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -3306,6 +3789,11 @@ "count": 1 } }, + "app/components/datasets/documents/detail/completed/status-item.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/documents/detail/context.ts": { "ts/no-explicit-any": { "count": 1 @@ -3321,6 +3809,16 @@ "count": 3 } }, + "app/components/datasets/documents/detail/metadata/components/doc-type-selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/datasets/documents/detail/metadata/components/field-info.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/documents/detail/new-segment.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -3365,12 +3863,20 @@ "count": 2 } }, + "app/components/datasets/documents/status-item/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/external-api/external-api-modal/Form.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/datasets/external-api/external-api-modal/index.tsx": { + "no-restricted-imports": { + "count": 2 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -3430,6 +3936,9 @@ } }, "app/components/datasets/extra-info/api-access/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3440,11 +3949,17 @@ } }, "app/components/datasets/extra-info/service-api/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/datasets/extra-info/statistics.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -3463,6 +3978,9 @@ } }, "app/components/datasets/hit-testing/components/chunk-detail-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -3473,6 +3991,9 @@ } }, "app/components/datasets/hit-testing/components/query-input/textarea.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -3483,6 +4004,9 @@ } }, "app/components/datasets/hit-testing/components/result-item-external.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -3503,6 +4027,9 @@ } }, "app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3537,6 +4064,11 @@ "count": 1 } }, + "app/components/datasets/metadata/edit-metadata-batch/edited-beacon.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, 
"app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx": { "ts/no-explicit-any": { "count": 2 @@ -3556,6 +4088,9 @@ } }, "app/components/datasets/metadata/edit-metadata-batch/modal.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -3576,11 +4111,17 @@ } }, "app/components/datasets/metadata/metadata-dataset/create-metadata-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } }, "app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 5 }, @@ -3593,6 +4134,11 @@ "count": 1 } }, + "app/components/datasets/metadata/metadata-dataset/select-metadata-modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/metadata/metadata-dataset/select-metadata.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -3607,6 +4153,9 @@ } }, "app/components/datasets/metadata/metadata-document/info-group.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -3624,6 +4173,11 @@ "count": 1 } }, + "app/components/datasets/rename-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/settings/form/components/basic-info-section.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -3639,7 +4193,15 @@ "count": 7 } }, + "app/components/datasets/settings/index-method/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/datasets/settings/index-method/keyword-number.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3650,6 +4212,9 @@ } }, "app/components/datasets/settings/permission-selector/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react/no-missing-key": { "count": 1 }, @@ -3668,6 +4233,9 @@ } }, "app/components/datasets/settings/summary-index-setting.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 11 } @@ -3688,11 +4256,26 @@ "count": 2 } }, + "app/components/develop/secret-key/input-copy.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/develop/secret-key/secret-key-button.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/develop/secret-key/secret-key-generate.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/develop/secret-key/secret-key-modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/explore/banner/banner-item.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 4 @@ -3712,6 +4295,9 @@ } }, "app/components/explore/create-app-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -3723,6 +4309,9 @@ } }, "app/components/explore/item-operation/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 } @@ -3738,6 +4327,9 @@ } }, "app/components/explore/try-app/app/chat.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3750,6 +4342,11 @@ "count": 1 } }, + "app/components/explore/try-app/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, 
"app/components/explore/try-app/preview/basic-app-preview.tsx": { "tailwindcss/no-unnecessary-whitespace": { "count": 2 @@ -3796,19 +4393,14 @@ "count": 1 } }, - "app/components/header/account-dropdown/compliance.tsx": { - "tailwindcss/enforce-consistent-class-order": { - "count": 6 + "app/components/goto-anything/index.tsx": { + "no-restricted-imports": { + "count": 1 } }, - "app/components/header/account-dropdown/index.tsx": { - "tailwindcss/enforce-consistent-class-order": { - "count": 12 - } - }, - "app/components/header/account-dropdown/support.tsx": { - "tailwindcss/enforce-consistent-class-order": { - "count": 5 + "app/components/header/account-about/index.tsx": { + "no-restricted-imports": { + "count": 1 } }, "app/components/header/account-dropdown/workplace-selector/index.tsx": { @@ -3824,6 +4416,16 @@ "count": 2 } }, + "app/components/header/account-setting/api-based-extension-page/modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/header/account-setting/api-based-extension-page/selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/header/account-setting/data-source-page-new/card.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 4 @@ -3833,6 +4435,9 @@ } }, "app/components/header/account-setting/data-source-page-new/configure.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3874,16 +4479,25 @@ } }, "app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -3916,12 +4530,48 @@ "count": 1 } }, + "app/components/header/account-setting/language-page/index.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, + "app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/header/account-setting/members-page/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/header/account-setting/members-page/invite-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 3 } }, + "app/components/header/account-setting/members-page/invite-modal/role-selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/header/account-setting/members-page/invited-modal/index.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, + "app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/header/account-setting/members-page/operation/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 5 } @@ -3932,11 +4582,17 @@ } }, "app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } 
}, "app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -3973,6 +4629,9 @@ } }, "app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -3983,11 +4642,17 @@ } }, "app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -4000,7 +4665,15 @@ "count": 1 } }, + "app/components/header/account-setting/model-provider-page/model-auth/config-provider.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -4016,6 +4689,9 @@ } }, "app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } @@ -4026,6 +4702,9 @@ } }, "app/components/header/account-setting/model-provider-page/model-modal/Form.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 10 }, @@ -4039,6 +4718,9 @@ } }, "app/components/header/account-setting/model-provider-page/model-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 6 }, @@ -4065,6 +4747,9 @@ } }, "app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -4077,12 +4762,20 @@ "count": 1 } }, + "app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, "app/components/header/account-setting/model-provider-page/model-parameter-modal/presets-parameter.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/header/account-setting/model-provider-page/model-parameter-modal/status-indicators.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 }, @@ -4091,26 +4784,58 @@ } }, "app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 2 } }, "app/components/header/account-setting/model-provider-page/model-selector/deprecated-model-trigger.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/header/account-setting/model-provider-page/model-selector/feature-icon.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/header/account-setting/model-provider-page/model-selector/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + 
"app/components/header/account-setting/model-provider-page/model-selector/model-trigger.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/header/account-setting/model-provider-page/model-selector/popup-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, + "app/components/header/account-setting/model-provider-page/model-selector/popup.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/header/account-setting/model-provider-page/provider-added-card/add-model-button.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/header/account-setting/model-provider-page/provider-added-card/cooldown-timer.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 } @@ -4129,6 +4854,9 @@ } }, "app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4139,11 +4867,17 @@ } }, "app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 5 } }, "app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -4154,7 +4888,15 @@ "count": 3 } }, + "app/components/header/account-setting/model-provider-page/provider-added-card/priority-use-tip.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -4164,6 +4906,11 @@ "count": 1 } }, + "app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, "app/components/header/account-setting/plugin-page/utils.ts": { "ts/no-explicit-any": { "count": 4 @@ -4202,12 +4949,20 @@ "count": 1 } }, + "app/components/plugins/base/badges/icon-with-tooltip.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/base/deprecation-notice.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/plugins/base/key-value-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -4264,6 +5019,9 @@ } }, "app/components/plugins/install-plugin/install-bundle/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-refresh/only-export-components": { "count": 1 }, @@ -4282,6 +5040,9 @@ } }, "app/components/plugins/install-plugin/install-from-github/index.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -4295,6 +5056,9 @@ } }, "app/components/plugins/install-plugin/install-from-github/steps/selectPackage.tsx": { + "no-restricted-imports": { + "count": 2 + }, "ts/no-explicit-any": { "count": 1 } @@ -4305,6 +5069,9 @@ } }, "app/components/plugins/install-plugin/install-from-local-package/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4323,6 +5090,9 @@ } }, 
"app/components/plugins/install-plugin/install-from-marketplace/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4371,6 +5141,9 @@ } }, "app/components/plugins/marketplace/search-box/tags-filter.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4386,6 +5159,9 @@ } }, "app/components/plugins/marketplace/sort-dropdown/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -4399,16 +5175,25 @@ } }, "app/components/plugins/plugin-auth/authorize/api-key-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } }, "app/components/plugins/plugin-auth/authorize/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -4419,6 +5204,9 @@ } }, "app/components/plugins/plugin-auth/authorized/index.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -4427,6 +5215,9 @@ } }, "app/components/plugins/plugin-auth/authorized/item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -4473,6 +5264,9 @@ } }, "app/components/plugins/plugin-detail-panel/app-selector/app-inputs-form.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -4486,6 +5280,9 @@ } }, "app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -4496,6 +5293,9 @@ } }, "app/components/plugins/plugin-detail-panel/app-selector/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4509,11 +5309,22 @@ } }, "app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/plugins/plugin-detail-panel/detail-header/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/plugin-detail-panel/endpoint-card.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 5 }, @@ -4522,6 +5333,9 @@ } }, "app/components/plugins/plugin-detail-panel/endpoint-list.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 }, @@ -4546,6 +5360,9 @@ } }, "app/components/plugins/plugin-detail-panel/model-selector/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -4559,6 +5376,9 @@ } }, "app/components/plugins/plugin-detail-panel/model-selector/tts-params-panel.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -4567,6 +5387,9 @@ } }, "app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 }, @@ -4575,6 +5398,9 @@ } }, 
"app/components/plugins/plugin-detail-panel/operation-dropdown.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -4592,12 +5418,25 @@ "count": 2 } }, + "app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/plugin-detail-panel/subscription-list/create/components/modal-steps.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, + "app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": { + "no-restricted-imports": { + "count": 3 + } + }, "app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -4611,16 +5450,32 @@ } }, "app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/plugin-detail-panel/subscription-list/index.tsx": { "react-refresh/only-export-components": { "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/list-view.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4634,26 +5489,41 @@ } }, "app/components/plugins/plugin-detail-panel/subscription-list/selector-entry.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/selector-view.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/plugins/plugin-detail-panel/subscription-list/subscription-card.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, "app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 6 } }, "app/components/plugins/plugin-detail-panel/tool-selector/components/schema-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4663,6 +5533,11 @@ "count": 2 } }, + "app/components/plugins/plugin-detail-panel/tool-selector/components/tool-item.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, "app/components/plugins/plugin-detail-panel/tool-selector/components/tool-settings-panel.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 4 @@ -4674,6 +5549,9 @@ } }, "app/components/plugins/plugin-detail-panel/tool-selector/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4691,7 +5569,15 @@ "count": 3 } }, + "app/components/plugins/plugin-item/action.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/plugin-item/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, 
"tailwindcss/enforce-consistent-class-order": { "count": 7 }, @@ -4700,6 +5586,9 @@ } }, "app/components/plugins/plugin-mutation-model/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4713,6 +5602,9 @@ } }, "app/components/plugins/plugin-page/debug-info.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4726,16 +5618,30 @@ } }, "app/components/plugins/plugin-page/filter-management/category-filter.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, "app/components/plugins/plugin-page/filter-management/tag-filter.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, + "app/components/plugins/plugin-page/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/plugin-page/install-plugin-dropdown.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 }, @@ -4743,11 +5649,26 @@ "count": 2 } }, + "app/components/plugins/plugin-page/plugin-info.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/plugin-page/plugin-tasks/components/plugin-task-list.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, + "app/components/plugins/plugin-page/plugin-tasks/components/task-status-indicator.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/plugins/plugin-page/plugin-tasks/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/provider-card.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 @@ -4779,6 +5700,9 @@ } }, "app/components/plugins/reference-setting-modal/auto-update-setting/strategy-picker.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -4791,7 +5715,15 @@ "count": 1 } }, + "app/components/plugins/reference-setting-modal/auto-update-setting/tool-picker.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/plugins/reference-setting-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4812,11 +5744,17 @@ } }, "app/components/plugins/update-plugin/from-market-place.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/plugins/update-plugin/plugin-version-picker.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -4880,6 +5818,9 @@ } }, "app/components/rag-pipeline/components/panel/input-field/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -4890,6 +5831,9 @@ } }, "app/components/rag-pipeline/components/panel/input-field/label-right-content/global-inputs.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -4958,6 +5902,9 @@ } }, "app/components/rag-pipeline/components/publish-as-knowledge-pipeline-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 }, @@ -4975,6 +5922,11 @@ "count": 1 } }, + "app/components/rag-pipeline/components/rag-pipeline-header/publisher/index.tsx": { + 
"no-restricted-imports": { + "count": 1 + } + }, "app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 8 @@ -4994,11 +5946,17 @@ } }, "app/components/rag-pipeline/components/update-dsl-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, "app/components/rag-pipeline/components/version-mismatch-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -5055,11 +6013,17 @@ } }, "app/components/share/text-generation/info-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/share/text-generation/menu-dropdown.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -5097,6 +6061,9 @@ } }, "app/components/share/text-generation/run-once/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -5118,6 +6085,9 @@ } }, "app/components/tools/edit-custom-collection-modal/config-credentials.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 7 } @@ -5159,11 +6129,17 @@ } }, "app/components/tools/labels/filter.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 } }, "app/components/tools/labels/selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 } @@ -5182,6 +6158,9 @@ } }, "app/components/tools/mcp/detail/content.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 12 }, @@ -5190,11 +6169,17 @@ } }, "app/components/tools/mcp/detail/operation-dropdown.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/tools/mcp/detail/tool-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 8 } @@ -5205,6 +6190,9 @@ } }, "app/components/tools/mcp/mcp-server-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 5 }, @@ -5221,11 +6209,17 @@ } }, "app/components/tools/mcp/mcp-service-card.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } }, "app/components/tools/mcp/modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 7 } @@ -5294,7 +6288,15 @@ "count": 4 } }, + "app/components/tools/workflow-tool/confirm-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/tools/workflow-tool/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 7 }, @@ -5303,6 +6305,9 @@ } }, "app/components/tools/workflow-tool/method-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 } @@ -5321,6 +6326,9 @@ } }, "app/components/workflow-app/components/workflow-onboarding-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -5389,11 +6397,17 @@ } }, 
"app/components/workflow/block-selector/blocks.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/workflow/block-selector/featured-tools.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 }, @@ -5408,6 +6422,9 @@ } }, "app/components/workflow/block-selector/featured-triggers.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 }, @@ -5431,7 +6448,15 @@ "count": 1 } }, + "app/components/workflow/block-selector/main.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/block-selector/market-place-plugin/action.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -5466,16 +6491,30 @@ } }, "app/components/workflow/block-selector/start-blocks.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } }, "app/components/workflow/block-selector/tabs.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/workflow/block-selector/tool-picker.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/block-selector/tool/action-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -5494,6 +6533,9 @@ } }, "app/components/workflow/block-selector/trigger-plugin/action-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -5541,10 +6583,18 @@ } }, "app/components/workflow/dsl-export-confirm-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 6 } }, + "app/components/workflow/header/checklist.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/header/editing-title.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -5577,6 +6627,9 @@ } }, "app/components/workflow/header/test-run-menu.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-refresh/only-export-components": { "count": 1 }, @@ -5590,11 +6643,22 @@ } }, "app/components/workflow/header/version-history-button.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/workflow/header/view-history.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, "app/components/workflow/header/view-workflow-history.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -5688,6 +6752,9 @@ } }, "app/components/workflow/nodes/_base/components/agent-strategy-selector.tsx": { + "no-restricted-imports": { + "count": 3 + }, "ts/no-explicit-any": { "count": 4 } @@ -5706,6 +6773,9 @@ } }, "app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 11 } @@ -5733,6 +6803,11 @@ "count": 1 } }, + "app/components/workflow/nodes/_base/components/config-vision.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/editor/base.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -5773,6 +6848,9 @@ } }, 
"app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -5783,6 +6861,9 @@ } }, "app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -5798,6 +6879,9 @@ } }, "app/components/workflow/nodes/_base/components/field.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -5818,6 +6902,9 @@ } }, "app/components/workflow/nodes/_base/components/form-input-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -5825,17 +6912,30 @@ "count": 33 } }, + "app/components/workflow/nodes/_base/components/form-input-type-switch.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/group.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/workflow/nodes/_base/components/help-link.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/info-panel.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/workflow/nodes/_base/components/input-support-select-var.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 }, @@ -5849,6 +6949,9 @@ } }, "app/components/workflow/nodes/_base/components/layout/field-title.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -5868,6 +6971,11 @@ "count": 1 } }, + "app/components/workflow/nodes/_base/components/mcp-tool-not-support-tooltip.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/memory-config.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -5900,10 +7008,18 @@ } }, "app/components/workflow/nodes/_base/components/next-step/operator.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/workflow/nodes/_base/components/node-control.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/node-handle.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -5916,6 +7032,9 @@ } }, "app/components/workflow/nodes/_base/components/option-card.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -5925,7 +7044,15 @@ "count": 3 } }, + "app/components/workflow/nodes/_base/components/panel-operator/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/prompt/editor.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 2 }, @@ -5952,6 +7079,9 @@ } }, "app/components/workflow/nodes/_base/components/setting-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -5960,6 +7090,9 @@ } }, "app/components/workflow/nodes/_base/components/switch-plugin-version.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -5969,6 +7102,11 @@ "count": 
1 } }, + "app/components/workflow/nodes/_base/components/variable/constant-field.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/variable/manage-input-field.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 @@ -5980,6 +7118,9 @@ } }, "app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/picker/field.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -6008,6 +7149,9 @@ } }, "app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx": { + "no-restricted-imports": { + "count": 2 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 }, @@ -6027,6 +7171,9 @@ } }, "app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 6 }, @@ -6034,7 +7181,15 @@ "count": 3 } }, + "app/components/workflow/nodes/_base/components/variable/var-type-picker.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/variable/variable-label/base/variable-label.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -6050,6 +7205,9 @@ } }, "app/components/workflow/nodes/_base/components/workflow-panel/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 3 }, @@ -6116,6 +7274,9 @@ } }, "app/components/workflow/nodes/_base/node.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } @@ -6126,10 +7287,18 @@ } }, "app/components/workflow/nodes/agent/components/model-bar.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-empty-object-type": { "count": 1 } }, + "app/components/workflow/nodes/agent/components/tool-icon.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/agent/default.ts": { "ts/no-explicit-any": { "count": 3 @@ -6169,6 +7338,9 @@ } }, "app/components/workflow/nodes/assigner/components/operation-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -6213,6 +7385,11 @@ "count": 1 } }, + "app/components/workflow/nodes/code/dependency-picker.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/code/use-config.ts": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 @@ -6309,11 +7486,21 @@ "count": 1 } }, + "app/components/workflow/nodes/http/components/authorization/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/http/components/authorization/radio-group.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, + "app/components/workflow/nodes/http/components/curl-panel.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -6323,6 +7510,9 @@ } }, "app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -6354,31 +7544,49 @@ } }, "app/components/workflow/nodes/human-input/components/button-style-dropdown.tsx": { + "no-restricted-imports": 
{ + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/workflow/nodes/human-input/components/delivery-method/email-configure-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 8 } }, "app/components/workflow/nodes/human-input/components/delivery-method/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/workflow/nodes/human-input/components/delivery-method/method-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/workflow/nodes/human-input/components/delivery-method/method-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 14 } }, "app/components/workflow/nodes/human-input/components/delivery-method/recipient/email-input.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -6398,7 +7606,15 @@ "count": 6 } }, + "app/components/workflow/nodes/human-input/components/delivery-method/recipient/member-selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/human-input/components/delivery-method/test-email-sender.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 21 }, @@ -6407,6 +7623,9 @@ } }, "app/components/workflow/nodes/human-input/components/delivery-method/upgrade-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -6457,6 +7676,9 @@ } }, "app/components/workflow/nodes/human-input/panel.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 }, @@ -6464,6 +7686,11 @@ "count": 1 } }, + "app/components/workflow/nodes/if-else/components/condition-add.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/if-else/components/condition-files-list-value.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -6478,11 +7705,27 @@ } }, "app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, + "app/components/workflow/nodes/if-else/components/condition-list/condition-operator.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/workflow/nodes/if-else/components/condition-list/condition-var-selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/if-else/components/condition-number-input.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -6491,6 +7734,9 @@ } }, "app/components/workflow/nodes/if-else/components/condition-wrap.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 2 } @@ -6505,6 +7751,11 @@ "count": 5 } }, + "app/components/workflow/nodes/iteration-start/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/iteration/add-block.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -6521,6 +7772,9 @@ } }, "app/components/workflow/nodes/iteration/panel.tsx": { + "no-restricted-imports": { + "count": 1 + }, 
"tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -6528,6 +7782,11 @@ "count": 1 } }, + "app/components/workflow/nodes/iteration/use-config.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/iteration/use-single-run-form-params.ts": { "ts/no-explicit-any": { "count": 6 @@ -6539,11 +7798,17 @@ } }, "app/components/workflow/nodes/knowledge-base/components/chunk-structure/selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/workflow/nodes/knowledge-base/components/index-method.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -6563,6 +7828,16 @@ "count": 1 } }, + "app/components/workflow/nodes/knowledge-base/components/retrieval-setting/search-method-option.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/knowledge-base/components/retrieval-setting/type.ts": { "ts/no-explicit-any": { "count": 2 @@ -6587,11 +7862,17 @@ } }, "app/components/workflow/nodes/knowledge-retrieval/components/metadata/add-condition.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-common-variable-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -6609,17 +7890,36 @@ "count": 1 } }, + "app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-operator.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-value-method.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-variable-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, "app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter/metadata-filter-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -6630,11 +7930,17 @@ } }, "app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-trigger.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 2 } @@ -6673,11 +7979,17 @@ } }, "app/components/workflow/nodes/list-operator/components/filter-condition.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } }, "app/components/workflow/nodes/list-operator/components/sub-variable-picker.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 }, 
@@ -6699,6 +8011,9 @@ } }, "app/components/workflow/nodes/llm/components/config-prompt-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } @@ -6709,6 +8024,9 @@ } }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/code-editor.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -6721,7 +8039,15 @@ "count": 1 } }, + "app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -6738,11 +8064,17 @@ } }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 } }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -6757,6 +8089,11 @@ "count": 3 } }, + "app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/actions.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/advanced-options.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 @@ -6781,6 +8118,9 @@ } }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/type-selector.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -6804,6 +8144,9 @@ } }, "app/components/workflow/nodes/llm/panel.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -6826,11 +8169,21 @@ "count": 10 } }, + "app/components/workflow/nodes/loop-start/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/loop/add-block.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, + "app/components/workflow/nodes/loop/components/condition-add.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/loop/components/condition-files-list-value.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -6845,11 +8198,27 @@ } }, "app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } }, + "app/components/workflow/nodes/loop/components/condition-list/condition-operator.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/workflow/nodes/loop/components/condition-list/condition-var-selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/loop/components/condition-number-input.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -6858,6 +8227,9 @@ } }, "app/components/workflow/nodes/loop/components/condition-wrap.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 } @@ -6872,11 +8244,21 @@ "count": 3 } 
}, + "app/components/workflow/nodes/loop/components/loop-variables/input-mode-selec.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/loop/components/loop-variables/item.tsx": { "ts/no-explicit-any": { "count": 4 } }, + "app/components/workflow/nodes/loop/components/loop-variables/variable-type-select.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/loop/default.ts": { "ts/no-explicit-any": { "count": 1 @@ -6898,6 +8280,9 @@ } }, "app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx": { + "no-restricted-imports": { + "count": 2 + }, "ts/no-explicit-any": { "count": 1 } @@ -6907,6 +8292,11 @@ "count": 1 } }, + "app/components/workflow/nodes/parameter-extractor/panel.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/parameter-extractor/use-config.ts": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 @@ -6920,6 +8310,11 @@ "count": 9 } }, + "app/components/workflow/nodes/question-classifier/components/advanced-setting.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/question-classifier/components/class-item.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -6936,6 +8331,9 @@ } }, "app/components/workflow/nodes/question-classifier/node.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -6997,6 +8395,9 @@ } }, "app/components/workflow/nodes/tool/components/copy-id.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -7031,6 +8432,9 @@ } }, "app/components/workflow/nodes/tool/components/tool-form/item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 5 }, @@ -7079,6 +8483,9 @@ } }, "app/components/workflow/nodes/trigger-plugin/components/trigger-form/item.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 5 }, @@ -7121,6 +8528,16 @@ "count": 7 } }, + "app/components/workflow/nodes/trigger-schedule/components/frequency-selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/workflow/nodes/trigger-schedule/components/monthly-days-selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/trigger-schedule/default.ts": { "regexp/no-unused-capturing-group": { "count": 2 @@ -7130,6 +8547,9 @@ } }, "app/components/workflow/nodes/trigger-webhook/components/generic-table.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -7140,6 +8560,9 @@ } }, "app/components/workflow/nodes/trigger-webhook/panel.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 } @@ -7154,6 +8577,11 @@ "count": 1 } }, + "app/components/workflow/nodes/variable-assigner/components/add-variable/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/nodes/variable-assigner/components/node-variable-item.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -7189,10 +8617,28 @@ } }, "app/components/workflow/note-node/note-editor/toolbar/color-picker.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-refresh/only-export-components": { "count": 1 } }, + 
"app/components/workflow/note-node/note-editor/toolbar/command.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/workflow/note-node/note-editor/toolbar/font-size-selector.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/workflow/note-node/note-editor/toolbar/operator.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/note-node/note-editor/utils.ts": { "regexp/no-useless-quantifier": { "count": 1 @@ -7209,16 +8655,25 @@ } }, "app/components/workflow/operator/more-actions.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 6 } }, "app/components/workflow/operator/tip-popup.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/workflow/operator/zoom-in-out.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -7273,6 +8728,11 @@ "count": 6 } }, + "app/components/workflow/panel/chat-variable-panel/components/variable-modal-trigger.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 8 @@ -7285,6 +8745,9 @@ } }, "app/components/workflow/panel/chat-variable-panel/components/variable-type-select.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 6 }, @@ -7303,6 +8766,9 @@ } }, "app/components/workflow/panel/debug-and-preview/conversation-variable-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 9 }, @@ -7316,6 +8782,9 @@ } }, "app/components/workflow/panel/debug-and-preview/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -7331,6 +8800,9 @@ } }, "app/components/workflow/panel/env-panel/variable-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 4 }, @@ -7341,6 +8813,11 @@ "count": 1 } }, + "app/components/workflow/panel/env-panel/variable-trigger.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/panel/global-variable-panel/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 @@ -7366,6 +8843,11 @@ "count": 1 } }, + "app/components/workflow/panel/version-history-panel/context-menu/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/panel/version-history-panel/context-menu/menu-item.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -7375,6 +8857,9 @@ } }, "app/components/workflow/panel/version-history-panel/delete-confirm-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -7398,6 +8883,9 @@ } }, "app/components/workflow/panel/version-history-panel/filter/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 } @@ -7418,6 +8906,9 @@ } }, "app/components/workflow/panel/version-history-panel/restore-confirm-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -7444,6 +8935,9 @@ } }, "app/components/workflow/run/agent-log/agent-log-nav-more.tsx": { + 
"no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 } @@ -7519,6 +9013,9 @@ } }, "app/components/workflow/run/node.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 }, @@ -7657,6 +9154,9 @@ } }, "app/components/workflow/update-dsl-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 5 }, @@ -7713,6 +9213,9 @@ } }, "app/components/workflow/variable-inspect/group.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 }, @@ -7734,6 +9237,9 @@ } }, "app/components/workflow/variable-inspect/listening.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 }, @@ -7750,6 +9256,9 @@ } }, "app/components/workflow/variable-inspect/right.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 }, @@ -7758,6 +9267,9 @@ } }, "app/components/workflow/variable-inspect/trigger.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 }, @@ -7795,6 +9307,9 @@ } }, "app/components/workflow/workflow-preview/components/nodes/base.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 3 }, @@ -7807,7 +9322,20 @@ "count": 1 } }, + "app/components/workflow/workflow-preview/components/nodes/iteration-start/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "app/components/workflow/workflow-preview/components/nodes/loop-start/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/components/workflow/workflow-preview/components/zoom-in-out.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 2 } @@ -7818,6 +9346,9 @@ } }, "app/education-apply/expire-notice-modal.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 4 } @@ -7836,6 +9367,9 @@ } }, "app/education-apply/search-input.tsx": { + "no-restricted-imports": { + "count": 1 + }, "tailwindcss/enforce-consistent-class-order": { "count": 1 }, @@ -7894,6 +9428,11 @@ "count": 6 } }, + "app/signin/_header.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/signin/components/mail-and-code-auth.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -7904,6 +9443,11 @@ "count": 1 } }, + "app/signin/invite-settings/page.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "app/signin/layout.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -7913,6 +9457,9 @@ } }, "app/signin/one-more-step.tsx": { + "no-restricted-imports": { + "count": 2 + }, "tailwindcss/enforce-consistent-class-order": { "count": 7 }, diff --git a/web/eslint.config.mjs b/web/eslint.config.mjs index cf7825fc61..acc585ca7e 100644 --- a/web/eslint.config.mjs +++ b/web/eslint.config.mjs @@ -6,6 +6,7 @@ import hyoban from 'eslint-plugin-hyoban' import sonar from 'eslint-plugin-sonarjs' import storybook from 'eslint-plugin-storybook' import dify from './eslint-rules/index.js' +import { OVERLAY_MIGRATION_LEGACY_BASE_FILES } from './eslint.constants.mjs' // Enable Tailwind CSS IntelliSense mode for ESLint runs // See: tailwind-css-plugin.ts @@ -145,4 +146,51 @@ export default antfu( 'hyoban/no-dependency-version-prefix': 'error', }, }, + { + 
name: 'dify/base-ui-primitives',
+    files: ['app/components/base/ui/**/*.tsx'],
+    rules: {
+      'react-refresh/only-export-components': 'off',
+    },
+  },
+  {
+    name: 'dify/overlay-migration',
+    files: [GLOB_TS, GLOB_TSX],
+    ignores: [
+      ...GLOB_TESTS,
+      ...OVERLAY_MIGRATION_LEGACY_BASE_FILES,
+    ],
+    rules: {
+      'no-restricted-imports': ['error', {
+        patterns: [{
+          group: [
+            '**/portal-to-follow-elem',
+            '**/portal-to-follow-elem/index',
+          ],
+          message: 'Deprecated: use semantic overlay primitives from @/app/components/base/ui/ instead. See issue #32767.',
+        }, {
+          group: [
+            '**/base/tooltip',
+            '**/base/tooltip/index',
+          ],
+          message: 'Deprecated: use @/app/components/base/ui/tooltip instead. See issue #32767.',
+        }, {
+          group: [
+            '**/base/modal',
+            '**/base/modal/index',
+            '**/base/modal/modal',
+          ],
+          message: 'Deprecated: use @/app/components/base/ui/dialog instead. See issue #32767.',
+        }, {
+          group: [
+            '**/base/select',
+            '**/base/select/index',
+            '**/base/select/custom',
+            '**/base/select/pure',
+          ],
+          message: 'Deprecated: use @/app/components/base/ui/select instead. See issue #32767.',
+        }],
+      }],
+    },
+  },
 )
diff --git a/web/eslint.constants.mjs b/web/eslint.constants.mjs
new file mode 100644
index 0000000000..2ec571de84
--- /dev/null
+++ b/web/eslint.constants.mjs
@@ -0,0 +1,29 @@
+export const OVERLAY_MIGRATION_LEGACY_BASE_FILES = [
+  'app/components/base/chat/chat-with-history/header/mobile-operation-dropdown.tsx',
+  'app/components/base/chat/chat-with-history/header/operation.tsx',
+  'app/components/base/chat/chat-with-history/inputs-form/view-form-dropdown.tsx',
+  'app/components/base/chat/chat-with-history/sidebar/operation.tsx',
+  'app/components/base/chat/chat/citation/popup.tsx',
+  'app/components/base/chat/chat/citation/progress-tooltip.tsx',
+  'app/components/base/chat/chat/citation/tooltip.tsx',
+  'app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown.tsx',
+  'app/components/base/chip/index.tsx',
+  'app/components/base/date-and-time-picker/date-picker/index.tsx',
+  'app/components/base/date-and-time-picker/time-picker/index.tsx',
+  'app/components/base/dropdown/index.tsx',
+  'app/components/base/features/new-feature-panel/file-upload/setting-modal.tsx',
+  'app/components/base/features/new-feature-panel/text-to-speech/voice-settings.tsx',
+  'app/components/base/file-uploader/file-from-link-or-local/index.tsx',
+  'app/components/base/image-uploader/chat-image-uploader.tsx',
+  'app/components/base/image-uploader/text-generation-image-uploader.tsx',
+  'app/components/base/modal/modal.tsx',
+  'app/components/base/prompt-editor/plugins/context-block/component.tsx',
+  'app/components/base/prompt-editor/plugins/history-block/component.tsx',
+  'app/components/base/select/custom.tsx',
+  'app/components/base/select/index.tsx',
+  'app/components/base/select/pure.tsx',
+  'app/components/base/sort/index.tsx',
+  'app/components/base/tag-management/filter.tsx',
+  'app/components/base/theme-selector.tsx',
+  'app/components/base/tooltip/index.tsx',
+]
diff --git a/web/package.json b/web/package.json
index 2291d78998..a120f5718d 100644
--- a/web/package.json
+++ b/web/package.json
@@ -63,6 +63,7 @@
   "dependencies": {
     "@amplitude/analytics-browser": "2.33.1",
     "@amplitude/plugin-session-replay-browser": "1.23.6",
+    "@base-ui/react": "1.2.0",
     "@emoji-mart/data": "1.2.1",
     "@floating-ui/react": "0.26.28",
     "@formatjs/intl-localematcher": "0.5.10",
diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml
index 338cae78ca..17c3fb7f06 100644
--- a/web/pnpm-lock.yaml
+++ 
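
Editor's note: the rule block above only flags the legacy imports; the fix it points to lives under `@/app/components/base/ui/`. A minimal before/after sketch — the named exports are taken from the tooltip wrapper tests added in patch 004 below, while `HelpHint` itself is a hypothetical component invented for illustration:

```tsx
// Hypothetical migration for a file flagged by dify/overlay-migration.

// Before — flagged by no-restricted-imports:
// import Tooltip from '@/app/components/base/tooltip'

// After — the semantic primitive the rule's message points to:
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'

const HelpHint = () => (
  <Tooltip>
    <TooltipTrigger>Hover me</TooltipTrigger>
    <TooltipContent>Contextual help text</TooltipContent>
  </Tooltip>
)

export default HelpHint
```

Keeping the allow-list in `eslint.constants.mjs` rather than inline presumably lets the legacy base components migrate one file at a time: delete a path from the list and the rule starts enforcing it.
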
b/web/pnpm-lock.yaml @@ -60,6 +60,9 @@ importers: '@amplitude/plugin-session-replay-browser': specifier: 1.23.6 version: 1.23.6(@amplitude/rrweb@2.0.0-alpha.35)(rollup@4.56.0) + '@base-ui/react': + specifier: 1.2.0 + version: 1.2.0(@types/react@19.2.9)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@emoji-mart/data': specifier: 1.2.1 version: 1.2.1 @@ -900,6 +903,27 @@ packages: resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==} engines: {node: '>=6.9.0'} + '@base-ui/react@1.2.0': + resolution: {integrity: sha512-O6aEQHcm+QyGTFY28xuwRD3SEJGZOBDpyjN2WvpfWYFVhg+3zfXPysAILqtM0C1kWC82MccOE/v1j+GHXE4qIw==} + engines: {node: '>=14.0.0'} + peerDependencies: + '@types/react': ^17 || ^18 || ^19 + react: ^17 || ^18 || ^19 + react-dom: ^17 || ^18 || ^19 + peerDependenciesMeta: + '@types/react': + optional: true + + '@base-ui/utils@0.2.5': + resolution: {integrity: sha512-oYC7w0gp76RI5MxprlGLV0wze0SErZaRl3AAkeP3OnNB/UBMb6RqNf6ZSIlxOc9Qp68Ab3C2VOcJQyRs7Xc7Vw==} + peerDependencies: + '@types/react': ^17 || ^18 || ^19 + react: ^17 || ^18 || ^19 + react-dom: ^17 || ^18 || ^19 + peerDependenciesMeta: + '@types/react': + optional: true + '@bcoe/v8-coverage@1.0.2': resolution: {integrity: sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==} engines: {node: '>=18'} @@ -6812,6 +6836,9 @@ packages: resolution: {integrity: sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==} engines: {node: '>=0.10.0'} + reselect@5.1.1: + resolution: {integrity: sha512-K/BG6eIky/SBpzfHZv/dd+9JBFiS4SWV7FIujVyJRux6e45+73RaUHXLmIR1f7WOMaQ0U1km6qwklRQxpJJY0w==} + reserved-identifiers@1.2.0: resolution: {integrity: sha512-yE7KUfFvaBFzGPs5H3Ops1RevfUEsDc5Iz65rOwWg4lE8HJSYtle77uul3+573457oHvBKuHYDl/xqUkKpEEdw==} engines: {node: '>=18'} @@ -8316,6 +8343,30 @@ snapshots: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.28.5 + '@base-ui/react@1.2.0(@types/react@19.2.9)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + dependencies: + '@babel/runtime': 7.28.6 + '@base-ui/utils': 0.2.5(@types/react@19.2.9)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@floating-ui/react-dom': 2.1.6(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@floating-ui/utils': 0.2.10 + react: 19.2.4 + react-dom: 19.2.4(react@19.2.4) + tabbable: 6.4.0 + use-sync-external-store: 1.6.0(react@19.2.4) + optionalDependencies: + '@types/react': 19.2.9 + + '@base-ui/utils@0.2.5(@types/react@19.2.9)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + dependencies: + '@babel/runtime': 7.28.6 + '@floating-ui/utils': 0.2.10 + react: 19.2.4 + react-dom: 19.2.4(react@19.2.4) + reselect: 5.1.1 + use-sync-external-store: 1.6.0(react@19.2.4) + optionalDependencies: + '@types/react': 19.2.9 + '@bcoe/v8-coverage@1.0.2': {} '@braintree/sanitize-url@7.1.1': {} @@ -15127,6 +15178,8 @@ snapshots: require-from-string@2.0.2: {} + reselect@5.1.1: {} + reserved-identifiers@1.2.0: {} resize-observer-polyfill@1.5.1: {} From 3a8ff301fca403e30435b971f7082f809b760ca1 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:21:33 +0800 Subject: [PATCH 004/256] test(web): add high-quality unit tests for Base UI wrapper primitives (#32904) --- .../base/ui/dialog/__tests__/index.spec.tsx | 70 +++++ .../ui/dropdown-menu/__tests__/index.spec.tsx | 294 ++++++++++++++++++ .../base/ui/popover/__tests__/index.spec.tsx | 107 +++++++ .../base/ui/select/__tests__/index.spec.tsx | 
219 +++++++++++++ .../base/ui/tooltip/__tests__/index.spec.tsx | 95 ++++++ 5 files changed, 785 insertions(+) create mode 100644 web/app/components/base/ui/dialog/__tests__/index.spec.tsx create mode 100644 web/app/components/base/ui/dropdown-menu/__tests__/index.spec.tsx create mode 100644 web/app/components/base/ui/popover/__tests__/index.spec.tsx create mode 100644 web/app/components/base/ui/select/__tests__/index.spec.tsx create mode 100644 web/app/components/base/ui/tooltip/__tests__/index.spec.tsx diff --git a/web/app/components/base/ui/dialog/__tests__/index.spec.tsx b/web/app/components/base/ui/dialog/__tests__/index.spec.tsx new file mode 100644 index 0000000000..0861fff603 --- /dev/null +++ b/web/app/components/base/ui/dialog/__tests__/index.spec.tsx @@ -0,0 +1,70 @@ +import { Dialog as BaseDialog } from '@base-ui/react/dialog' +import { render, screen } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import { + Dialog, + DialogClose, + DialogContent, + DialogDescription, + DialogTitle, + DialogTrigger, +} from '../index' + +describe('Dialog wrapper', () => { + describe('Rendering', () => { + it('should render dialog content when dialog is open', () => { + render( + + + Dialog Title + Dialog Description + + , + ) + + const dialog = screen.getByRole('dialog') + expect(dialog).toHaveTextContent('Dialog Title') + expect(dialog).toHaveTextContent('Dialog Description') + }) + }) + + describe('Props', () => { + it('should not render close button when closable is omitted', () => { + render( + + + Dialog body + + , + ) + + expect(screen.queryByRole('button', { name: 'Close' })).not.toBeInTheDocument() + }) + + it('should render close button when closable is true', () => { + render( + + + Dialog body + + , + ) + + const dialog = screen.getByRole('dialog') + const closeButton = screen.getByRole('button', { name: 'Close' }) + + expect(dialog).toContainElement(closeButton) + expect(closeButton).toHaveAttribute('aria-label', 'Close') + }) + }) + + describe('Exports', () => { + it('should map dialog aliases to the matching base dialog primitives', () => { + expect(Dialog).toBe(BaseDialog.Root) + expect(DialogTrigger).toBe(BaseDialog.Trigger) + expect(DialogTitle).toBe(BaseDialog.Title) + expect(DialogDescription).toBe(BaseDialog.Description) + expect(DialogClose).toBe(BaseDialog.Close) + }) + }) +}) diff --git a/web/app/components/base/ui/dropdown-menu/__tests__/index.spec.tsx b/web/app/components/base/ui/dropdown-menu/__tests__/index.spec.tsx new file mode 100644 index 0000000000..b381078180 --- /dev/null +++ b/web/app/components/base/ui/dropdown-menu/__tests__/index.spec.tsx @@ -0,0 +1,294 @@ +import { Menu } from '@base-ui/react/menu' +import { fireEvent, render, screen, within } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuGroup, + DropdownMenuItem, + DropdownMenuPortal, + DropdownMenuRadioGroup, + DropdownMenuSeparator, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from '../index' + +describe('dropdown-menu wrapper', () => { + describe('alias exports', () => { + it('should map direct aliases to the corresponding Menu primitive when importing menu roots', () => { + expect(DropdownMenu).toBe(Menu.Root) + expect(DropdownMenuPortal).toBe(Menu.Portal) + expect(DropdownMenuTrigger).toBe(Menu.Trigger) + expect(DropdownMenuSub).toBe(Menu.SubmenuRoot) + expect(DropdownMenuGroup).toBe(Menu.Group) + 
expect(DropdownMenuRadioGroup).toBe(Menu.RadioGroup) + }) + }) + + describe('DropdownMenuContent', () => { + it('should position content at bottom-end with default placement when props are omitted', () => { + render( + + Open + + Content action + + , + ) + + const positioner = screen.getByRole('group', { name: 'content positioner' }) + const popup = screen.getByRole('menu') + + expect(positioner).toHaveAttribute('data-side', 'bottom') + expect(positioner).toHaveAttribute('data-align', 'end') + expect(within(popup).getByRole('menuitem', { name: 'Content action' })).toBeInTheDocument() + }) + + it('should apply custom placement when custom positioning props are provided', () => { + render( + + Open + + Custom content + + , + ) + + const positioner = screen.getByRole('group', { name: 'custom content positioner' }) + const popup = screen.getByRole('menu') + + expect(positioner).toHaveAttribute('data-side', 'top') + expect(positioner).toHaveAttribute('data-align', 'start') + expect(within(popup).getByRole('menuitem', { name: 'Custom content' })).toBeInTheDocument() + }) + + it('should forward passthrough attributes and handlers when positionerProps and popupProps are provided', () => { + const handlePositionerMouseEnter = vi.fn() + const handlePopupClick = vi.fn() + + render( + + Open + + Passthrough content + + , + ) + + const positioner = screen.getByRole('group', { name: 'dropdown content positioner' }) + const popup = screen.getByRole('menu') + fireEvent.mouseEnter(positioner) + fireEvent.click(popup) + + expect(positioner).toHaveAttribute('id', 'dropdown-content-positioner') + expect(popup).toHaveAttribute('id', 'dropdown-content-popup') + expect(handlePositionerMouseEnter).toHaveBeenCalledTimes(1) + expect(handlePopupClick).toHaveBeenCalledTimes(1) + }) + }) + + describe('DropdownMenuSubContent', () => { + it('should position sub-content at left-start with default placement when props are omitted', () => { + render( + + Open + + + More actions + + Sub action + + + + , + ) + + const positioner = screen.getByRole('group', { name: 'sub positioner' }) + expect(positioner).toHaveAttribute('data-side', 'left') + expect(positioner).toHaveAttribute('data-align', 'start') + expect(screen.getByRole('menuitem', { name: 'Sub action' })).toBeInTheDocument() + }) + + it('should apply custom placement and forward passthrough props for sub-content when custom props are provided', () => { + const handlePositionerFocus = vi.fn() + const handlePopupClick = vi.fn() + + render( + + Open + + + More actions + + Custom sub action + + + + , + ) + + const positioner = screen.getByRole('group', { name: 'dropdown sub positioner' }) + const popup = screen.getByRole('menu', { name: 'More actions' }) + fireEvent.focus(positioner) + fireEvent.click(popup) + + expect(positioner).toHaveAttribute('data-side', 'right') + expect(positioner).toHaveAttribute('data-align', 'end') + expect(positioner).toHaveAttribute('id', 'dropdown-sub-positioner') + expect(popup).toHaveAttribute('id', 'dropdown-sub-popup') + expect(handlePositionerFocus).toHaveBeenCalledTimes(1) + expect(handlePopupClick).toHaveBeenCalledTimes(1) + }) + }) + + describe('DropdownMenuSubTrigger', () => { + it('should render submenu trigger content when trigger children are provided', () => { + render( + + Open + + + Trigger item + + + , + ) + + expect(screen.getByRole('menuitem', { name: 'Trigger item' })).toBeInTheDocument() + }) + + it.each([true, false])('should remain interactive and not leak destructive prop when destructive is %s', (destructive) => { + 
const handleClick = vi.fn() + + render( + + Open + + + + Trigger item + + + + , + ) + + const subTrigger = screen.getByRole('menuitem', { name: 'submenu action' }) + fireEvent.click(subTrigger) + + expect(subTrigger).toHaveAttribute('id', `submenu-trigger-${String(destructive)}`) + expect(subTrigger).not.toHaveAttribute('destructive') + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) + + describe('DropdownMenuItem', () => { + it.each([true, false])('should remain interactive and not leak destructive prop when destructive is %s', (destructive) => { + const handleClick = vi.fn() + + render( + + Open + + + Item label + + + , + ) + + const item = screen.getByRole('menuitem', { name: 'menu action' }) + fireEvent.click(item) + + expect(item).toHaveAttribute('id', `menu-item-${String(destructive)}`) + expect(item).not.toHaveAttribute('destructive') + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) + + describe('DropdownMenuSeparator', () => { + it('should forward passthrough props and handlers when separator props are provided', () => { + const handleMouseEnter = vi.fn() + + render( + + Open + + + + , + ) + + const separator = screen.getByRole('separator', { name: 'actions divider' }) + fireEvent.mouseEnter(separator) + + expect(separator).toHaveAttribute('id', 'menu-separator') + expect(handleMouseEnter).toHaveBeenCalledTimes(1) + }) + + it('should keep surrounding menu rows rendered when separator is placed between items', () => { + render( + + Open + + First action + + Second action + + , + ) + + expect(screen.getByRole('menuitem', { name: 'First action' })).toBeInTheDocument() + expect(screen.getByRole('menuitem', { name: 'Second action' })).toBeInTheDocument() + expect(screen.getAllByRole('separator')).toHaveLength(1) + }) + }) +}) diff --git a/web/app/components/base/ui/popover/__tests__/index.spec.tsx b/web/app/components/base/ui/popover/__tests__/index.spec.tsx new file mode 100644 index 0000000000..9d65f8c934 --- /dev/null +++ b/web/app/components/base/ui/popover/__tests__/index.spec.tsx @@ -0,0 +1,107 @@ +import { Popover as BasePopover } from '@base-ui/react/popover' +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { + Popover, + PopoverClose, + PopoverContent, + PopoverDescription, + PopoverTitle, + PopoverTrigger, +} from '..' 
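
Editor's note: the popover assertions below map a single `placement` string onto separate `data-side`/`data-align` attributes. Purely as an illustration of that contract — not the actual `PopoverContent` implementation, which this patch does not show — the side/align split the tests rely on behaves like this:

```ts
// Illustrative only. Models the mapping the assertions depend on:
// placement="top-end" -> data-side="top", data-align="end";
// a bare side such as "bottom" falls back to align="center".
type Side = 'top' | 'right' | 'bottom' | 'left'
type Align = 'start' | 'center' | 'end'

const parsePlacement = (placement: string): { side: Side, align: Align } => {
  const [side, align] = placement.split('-')
  return { side: side as Side, align: (align ?? 'center') as Align }
}

// parsePlacement('bottom')  -> { side: 'bottom', align: 'center' }
// parsePlacement('top-end') -> { side: 'top', align: 'end' }
```
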
+ +describe('PopoverContent', () => { + describe('Placement', () => { + it('should use bottom placement and default offsets when placement props are not provided', () => { + render( + + Open + + Default content + + , + ) + + const positioner = screen.getByRole('group', { name: 'default positioner' }) + const popup = screen.getByRole('dialog', { name: 'default popover' }) + + expect(positioner).toHaveAttribute('data-side', 'bottom') + expect(positioner).toHaveAttribute('data-align', 'center') + expect(popup).toHaveTextContent('Default content') + }) + + it('should apply parsed custom placement and custom offsets when placement props are provided', () => { + render( + + Open + + Custom placement content + + , + ) + + const positioner = screen.getByRole('group', { name: 'custom positioner' }) + const popup = screen.getByRole('dialog', { name: 'custom popover' }) + + expect(positioner).toHaveAttribute('data-side', 'top') + expect(positioner).toHaveAttribute('data-align', 'end') + expect(popup).toHaveTextContent('Custom placement content') + }) + }) + + describe('Passthrough props', () => { + it('should forward positionerProps and popupProps when passthrough props are provided', () => { + const onPopupClick = vi.fn() + + render( + + Open + + Popover body + + , + ) + + const positioner = screen.getByRole('group', { name: 'popover positioner' }) + const popup = screen.getByRole('dialog', { name: 'popover content' }) + fireEvent.click(popup) + + expect(positioner).toHaveAttribute('id', 'popover-positioner-id') + expect(popup).toHaveAttribute('id', 'popover-popup-id') + expect(onPopupClick).toHaveBeenCalledTimes(1) + }) + }) +}) + +describe('Popover aliases', () => { + describe('Export mapping', () => { + it('should map aliases to the matching base popover primitives when wrapper exports are imported', () => { + expect(Popover).toBe(BasePopover.Root) + expect(PopoverTrigger).toBe(BasePopover.Trigger) + expect(PopoverClose).toBe(BasePopover.Close) + expect(PopoverTitle).toBe(BasePopover.Title) + expect(PopoverDescription).toBe(BasePopover.Description) + }) + }) +}) diff --git a/web/app/components/base/ui/select/__tests__/index.spec.tsx b/web/app/components/base/ui/select/__tests__/index.spec.tsx new file mode 100644 index 0000000000..7744a0e22c --- /dev/null +++ b/web/app/components/base/ui/select/__tests__/index.spec.tsx @@ -0,0 +1,219 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '../index' + +const renderOpenSelect = ({ + triggerProps = {}, + contentProps = {}, + onValueChange, +}: { + triggerProps?: Record + contentProps?: Record + onValueChange?: (value: string | null) => void +} = {}) => { + return render( + , + ) +} + +describe('Select wrappers', () => { + describe('SelectTrigger', () => { + it('should render clear button when clearable is true and loading is false', () => { + renderOpenSelect({ + triggerProps: { clearable: true }, + }) + + expect(screen.getByRole('button', { name: /clear selection/i })).toBeInTheDocument() + }) + + it('should hide clear button when loading is true', () => { + renderOpenSelect({ + triggerProps: { clearable: true, loading: true }, + }) + + expect(screen.queryByRole('button', { name: /clear selection/i })).not.toBeInTheDocument() + }) + + it('should forward native trigger props when trigger props are provided', () => { + renderOpenSelect({ + triggerProps: { + 'aria-label': 'Choose option', + 'disabled': true, 
+ }, + }) + + const trigger = screen.getByRole('combobox', { name: 'Choose option' }) + expect(trigger).toBeDisabled() + }) + + it('should call onClear and stop click propagation when clear button is clicked', () => { + const onClear = vi.fn() + const onTriggerClick = vi.fn() + + renderOpenSelect({ + triggerProps: { + clearable: true, + onClear, + onClick: onTriggerClick, + }, + }) + + fireEvent.click(screen.getByRole('button', { name: /clear selection/i })) + + expect(onClear).toHaveBeenCalledTimes(1) + expect(onTriggerClick).not.toHaveBeenCalled() + }) + + it('should stop mouse down propagation when clear button receives mouse down', () => { + const onTriggerMouseDown = vi.fn() + + renderOpenSelect({ + triggerProps: { + clearable: true, + onMouseDown: onTriggerMouseDown, + }, + }) + + fireEvent.mouseDown(screen.getByRole('button', { name: /clear selection/i })) + + expect(onTriggerMouseDown).not.toHaveBeenCalled() + }) + + it('should not throw when clear button is clicked without onClear handler', () => { + renderOpenSelect({ + triggerProps: { clearable: true }, + }) + + const clearButton = screen.getByRole('button', { name: /clear selection/i }) + expect(() => fireEvent.click(clearButton)).not.toThrow() + }) + }) + + describe('SelectContent', () => { + it('should use default placement when placement is not provided', () => { + renderOpenSelect() + + const positioner = screen.getByRole('group', { name: 'select positioner' }) + expect(positioner).toHaveAttribute('data-side', 'bottom') + expect(positioner).toHaveAttribute('data-align', 'start') + }) + + it('should apply custom placement when placement props are provided', () => { + renderOpenSelect({ + contentProps: { + placement: 'top-end', + sideOffset: 12, + alignOffset: 6, + }, + }) + + const positioner = screen.getByRole('group', { name: 'select positioner' }) + expect(positioner).toHaveAttribute('data-side', 'top') + expect(positioner).toHaveAttribute('data-align', 'end') + }) + + it('should forward passthrough props to positioner popup and list when passthrough props are provided', () => { + const onPositionerMouseEnter = vi.fn() + const onPopupClick = vi.fn() + const onListFocus = vi.fn() + + render( + , + ) + + const positioner = screen.getByRole('group', { name: 'select positioner' }) + const popup = screen.getByRole('dialog', { name: 'select popup' }) + const list = screen.getByRole('listbox', { name: 'select list' }) + + fireEvent.mouseEnter(positioner) + fireEvent.click(popup) + fireEvent.focus(list) + + expect(positioner).toHaveAttribute('id', 'select-positioner') + expect(popup).toHaveAttribute('id', 'select-popup') + expect(list).toHaveAttribute('id', 'select-list') + expect(onPositionerMouseEnter).toHaveBeenCalledTimes(1) + expect(onPopupClick).toHaveBeenCalledTimes(1) + expect(onListFocus).toHaveBeenCalledTimes(1) + }) + }) + + describe('SelectItem', () => { + it('should render options when children are provided', () => { + renderOpenSelect() + + expect(screen.getByRole('option', { name: 'Seattle' })).toBeInTheDocument() + expect(screen.getByRole('option', { name: 'New York' })).toBeInTheDocument() + }) + + it('should not call onValueChange when disabled item is clicked', () => { + const onValueChange = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByRole('option', { name: 'Disabled New York' })) + + expect(onValueChange).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/base/ui/tooltip/__tests__/index.spec.tsx b/web/app/components/base/ui/tooltip/__tests__/index.spec.tsx new file mode 
100644 index 0000000000..4582f07cbe --- /dev/null +++ b/web/app/components/base/ui/tooltip/__tests__/index.spec.tsx @@ -0,0 +1,95 @@ +import { Tooltip as BaseTooltip } from '@base-ui/react/tooltip' +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '../index' + +describe('TooltipContent', () => { + describe('Placement and offsets', () => { + it('should use default top placement when placement is not provided', () => { + render( + + Trigger + + Tooltip body + + , + ) + + const popup = screen.getByRole('tooltip', { name: 'default tooltip' }) + expect(popup).toHaveAttribute('data-side', 'top') + expect(popup).toHaveAttribute('data-align', 'center') + expect(popup).toHaveTextContent('Tooltip body') + }) + + it('should apply custom placement when placement props are provided', () => { + render( + + Trigger + + Custom tooltip body + + , + ) + + const popup = screen.getByRole('tooltip', { name: 'custom tooltip' }) + expect(popup).toHaveAttribute('data-side', 'bottom') + expect(popup).toHaveAttribute('data-align', 'start') + expect(popup).toHaveTextContent('Custom tooltip body') + }) + }) + + describe('Variant and popup props', () => { + it('should render popup content when variant is plain', () => { + render( + + Trigger + + Plain tooltip body + + , + ) + + expect(screen.getByRole('tooltip', { name: 'plain tooltip' })).toHaveTextContent('Plain tooltip body') + }) + + it('should forward popup props and handlers when popup props are provided', () => { + const onMouseEnter = vi.fn() + + render( + + Trigger + + Tooltip body + + , + ) + + const popup = screen.getByRole('tooltip', { name: 'help text' }) + fireEvent.mouseEnter(popup) + + expect(popup).toHaveAttribute('id', 'tooltip-popup-id') + expect(popup).toHaveAttribute('data-track-id', 'tooltip-track') + expect(onMouseEnter).toHaveBeenCalledTimes(1) + }) + }) +}) + +describe('Tooltip aliases', () => { + it('should map alias exports to BaseTooltip components when wrapper exports are imported', () => { + expect(TooltipProvider).toBe(BaseTooltip.Provider) + expect(Tooltip).toBe(BaseTooltip.Root) + expect(TooltipTrigger).toBe(BaseTooltip.Trigger) + }) +}) From 1a90c4d81bcacab9115fd8ebc6bb9bb69e40f8da Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:29:23 +0800 Subject: [PATCH 005/256] refactor(web): migrate document list query state to nuqs (#32339) --- .agents/skills/frontend-testing/SKILL.md | 10 + .../frontend-testing/references/checklist.md | 3 + .../frontend-testing/references/mocking.md | 25 + .../apps/app-list-browsing-flow.test.tsx | 35 +- web/__tests__/apps/create-app-flow.test.tsx | 10 +- .../datasets/document-management.test.tsx | 121 ++-- .../tool-browsing-and-filtering.test.tsx | 16 +- .../components/apps/__tests__/list.spec.tsx | 35 +- .../__tests__/use-apps-query-state.spec.tsx | 15 +- web/app/components/apps/list.tsx | 25 +- .../documents/__tests__/index.spec.tsx | 147 +++-- .../components/__tests__/list.spec.tsx | 17 +- .../document-list/__tests__/index.spec.tsx | 42 +- .../__tests__/document-table-row.spec.tsx | 12 + .../components/__tests__/sort-header.spec.tsx | 46 +- .../components/document-table-row.tsx | 13 +- .../document-list/components/sort-header.tsx | 12 +- .../hooks/__tests__/use-document-sort.spec.ts | 362 ++--------- .../document-list/hooks/use-document-sort.ts | 96 +-- .../datasets/documents/components/list.tsx | 35 +- 
 .../documents/detail/__tests__/index.spec.tsx     |  38 +-
 .../datasets/documents/detail/index.tsx           |  38 +-
 .../use-document-list-query-state.spec.ts         | 439 -------------
 .../use-document-list-query-state.spec.tsx        | 426 +++++++++++++
 .../use-documents-page-state.spec.ts              | 600 +++---------------
 .../hooks/use-document-list-query-state.ts        | 154 ++---
 .../hooks/use-documents-page-state.ts             | 153 +----
 .../components/datasets/documents/index.tsx       |  35 +-
 .../explore/app-list/__tests__/index.spec.tsx     |  11 +-
 .../marketplace/__tests__/atoms.spec.tsx          |  11 +-
 .../__tests__/plugin-type-switch.spec.tsx         |  11 +-
 .../marketplace/__tests__/state.spec.tsx          |   7 +-
 .../sticky-search-and-switch-wrapper.spec.tsx     |  24 +-
 .../plugins/marketplace/hydration-server.tsx      |   3 +-
 .../plugins/marketplace/search-params.ts          |   3 +
 .../plugin-page/__tests__/context.spec.tsx        |   3 +
 .../plugin-page/__tests__/index.spec.tsx          |   3 +
 .../plugins/plugin-page/context.tsx               |  25 +-
 .../components/plugins/plugin-page/index.tsx      |  16 +-
 .../tools/__tests__/provider-list.spec.tsx        |  11 +-
 web/app/components/tools/provider-list.tsx        |  19 +-
 web/context/modal-context.test.tsx                |  14 +-
 web/context/modal-context.tsx                     |   2 +-
 web/docs/test.md                                  |  32 +
 web/eslint-suppressions.json                      |  15 -
 web/hooks/use-query-params.spec.tsx               |  15 +-
 web/hooks/use-query-params.ts                     |  23 +-
 web/service/knowledge/use-document.ts             |   7 +-
 web/test/nuqs-testing.tsx                         |  60 ++
 49 files changed, 1272 insertions(+), 2003 deletions(-)
 delete mode 100644 web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.ts
 create mode 100644 web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.tsx
 create mode 100644 web/test/nuqs-testing.tsx

diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md
index 280fcb6341..69c099a262 100644
--- a/.agents/skills/frontend-testing/SKILL.md
+++ b/.agents/skills/frontend-testing/SKILL.md
@@ -204,6 +204,16 @@ When assigned to test a directory/path, test **ALL content** within that path:
 
 > See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
 
+### `nuqs` Query State Testing (Required for URL State Hooks)
+
+When a component or hook uses `useQueryState` / `useQueryStates`:
+
+- ✅ Use `NuqsTestingAdapter` (prefer shared helpers in `web/test/nuqs-testing.tsx`)
+- ✅ Assert URL synchronization via `onUrlUpdate` (`searchParams`, `options.history`)
+- ✅ For custom parsers (`createParser`), keep `parse` and `serialize` bijective and add round-trip edge cases (`%2F`, `%25`, spaces, legacy encoded values)
+- ✅ Verify default-clearing behavior (default values should be removed from URL when applicable)
+- ⚠️ Only mock `nuqs` directly when URL behavior is explicitly out of scope for the test
+
 ## Core Principles
 
 ### 1. 
AAA Pattern (Arrange-Act-Assert) diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md index 1ff2b27bbb..10b8fb66f9 100644 --- a/.agents/skills/frontend-testing/references/checklist.md +++ b/.agents/skills/frontend-testing/references/checklist.md @@ -80,6 +80,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen - [ ] Router mocks match actual Next.js API - [ ] Mocks reflect actual component conditional behavior - [ ] Only mock: API services, complex context providers, third-party libs +- [ ] For `nuqs` URL-state tests, wrap with `NuqsTestingAdapter` (prefer `web/test/nuqs-testing.tsx`) +- [ ] For `nuqs` URL-state tests, assert `onUrlUpdate` payload (`searchParams`, `options.history`) +- [ ] If custom `nuqs` parser exists, add round-trip tests for encoded edge cases (`%2F`, `%25`, spaces, legacy encoded values) ### Queries diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index 86bd375987..f58377c4a5 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -125,6 +125,31 @@ describe('Component', () => { }) ``` +### 2.1 `nuqs` Query State (Preferred: Testing Adapter) + +For tests that validate URL query behavior, use `NuqsTestingAdapter` instead of mocking `nuqs` directly. + +```typescript +import { renderHookWithNuqs } from '@/test/nuqs-testing' + +it('should sync query to URL with push history', async () => { + const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), { + searchParams: '?page=1', + }) + + act(() => { + result.current.setQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('push') + expect(update.searchParams.get('page')).toBe('2') +}) +``` + +Use direct `vi.mock('nuqs')` only when URL synchronization is intentionally out of scope. + ### 3. 
Portal Components (with Shared State) ```typescript diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 1c046f5dd0..19288ecd95 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -8,11 +8,11 @@ */ import type { AppListResponse } from '@/models/app' import type { App } from '@/types/app' -import { fireEvent, render, screen } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { fireEvent, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import List from '@/app/components/apps/list' import { AccessMode } from '@/models/access-control' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -161,10 +161,9 @@ const createPage = (apps: App[], hasMore = false, page = 1): AppListResponse => }) const renderList = (searchParams?: Record) => { - return render( - - - , + return renderWithNuqs( + , + { searchParams }, ) } @@ -209,11 +208,7 @@ describe('App List Browsing Flow', () => { it('should transition from loading to content when data loads', () => { mockIsLoading = true - const { rerender } = render( - - - , - ) + const { rerender } = renderWithNuqs() const skeletonCards = document.querySelectorAll('.animate-pulse') expect(skeletonCards.length).toBeGreaterThan(0) @@ -224,11 +219,7 @@ describe('App List Browsing Flow', () => { createMockApp({ id: 'app-1', name: 'Loaded App' }), ])] - rerender( - - - , - ) + rerender() expect(screen.getByText('Loaded App')).toBeInTheDocument() }) @@ -424,17 +415,9 @@ describe('App List Browsing Flow', () => { it('should call refetch when controlRefreshList increments', () => { mockPages = [createPage([createMockApp()])] - const { rerender } = render( - - - , - ) + const { rerender } = renderWithNuqs() - rerender( - - - , - ) + rerender() expect(mockRefetch).toHaveBeenCalled() }) diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index 556c973b06..a0976d32cc 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -9,11 +9,11 @@ */ import type { AppListResponse } from '@/models/app' import type { App } from '@/types/app' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { fireEvent, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import List from '@/app/components/apps/list' import { AccessMode } from '@/models/access-control' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -214,11 +214,7 @@ const createPage = (apps: App[]): AppListResponse => ({ }) const renderList = () => { - return render( - - - , - ) + return renderWithNuqs() } describe('Create App Flow', () => { diff --git a/web/__tests__/datasets/document-management.test.tsx b/web/__tests__/datasets/document-management.test.tsx index 3b901ccee2..8aedd4fc63 100644 --- a/web/__tests__/datasets/document-management.test.tsx +++ b/web/__tests__/datasets/document-management.test.tsx @@ -7,9 +7,10 @@ */ import type { SimpleDocumentDetail } from '@/models/datasets' -import { act, renderHook } from '@testing-library/react' +import { act, renderHook, 
waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { DataSourceType } from '@/models/datasets' +import { renderHookWithNuqs } from '@/test/nuqs-testing' const mockPush = vi.fn() vi.mock('next/navigation', () => ({ @@ -28,12 +29,16 @@ const { useDocumentSort } = await import( const { useDocumentSelection } = await import( '@/app/components/datasets/documents/components/document-list/hooks/use-document-selection', ) -const { default: useDocumentListQueryState } = await import( +const { useDocumentListQueryState } = await import( '@/app/components/datasets/documents/hooks/use-document-list-query-state', ) type LocalDoc = SimpleDocumentDetail & { percent?: number } +const renderQueryStateHook = (searchParams = '') => { + return renderHookWithNuqs(() => useDocumentListQueryState(), { searchParams }) +} + const createDoc = (overrides?: Partial): LocalDoc => ({ id: `doc-${Math.random().toString(36).slice(2, 8)}`, name: 'test-doc.txt', @@ -85,7 +90,7 @@ describe('Document Management Flow', () => { describe('URL-based Query State', () => { it('should parse default query from empty URL params', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + const { result } = renderQueryStateHook() expect(result.current.query).toEqual({ page: 1, @@ -96,107 +101,85 @@ describe('Document Management Flow', () => { }) }) - it('should update query and push to router', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + it('should update keyword query with replace history', async () => { + const { result, onUrlUpdate } = renderQueryStateHook() act(() => { result.current.updateQuery({ keyword: 'test', page: 2 }) }) - expect(mockPush).toHaveBeenCalled() - // The push call should contain the updated query params - const pushUrl = mockPush.mock.calls[0][0] as string - expect(pushUrl).toContain('keyword=test') - expect(pushUrl).toContain('page=2') + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('replace') + expect(update.searchParams.get('keyword')).toBe('test') + expect(update.searchParams.get('page')).toBe('2') }) - it('should reset query to defaults', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + it('should reset query to defaults', async () => { + const { result, onUrlUpdate } = renderQueryStateHook() act(() => { result.current.resetQuery() }) - expect(mockPush).toHaveBeenCalled() - // Default query omits default values from URL - const pushUrl = mockPush.mock.calls[0][0] as string - expect(pushUrl).toBe('/datasets/ds-1/documents') + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('replace') + expect(update.searchParams.toString()).toBe('') }) }) describe('Document Sort Integration', () => { - it('should return documents unsorted when no sort field set', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'Banana.txt', word_count: 300 }), - createDoc({ id: 'doc-2', name: 'Apple.txt', word_count: 100 }), - createDoc({ id: 'doc-3', name: 'Cherry.txt', word_count: 200 }), - ] - + it('should derive sort field and order from remote sort value', () => { const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', remoteSortValue: '-created_at', + onRemoteSortChange: vi.fn(), })) - 
expect(result.current.sortField).toBeNull() - expect(result.current.sortedDocuments).toHaveLength(3) + expect(result.current.sortField).toBe('created_at') + expect(result.current.sortOrder).toBe('desc') }) - it('should sort by name descending', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'Banana.txt' }), - createDoc({ id: 'doc-2', name: 'Apple.txt' }), - createDoc({ id: 'doc-3', name: 'Cherry.txt' }), - ] - + it('should call remote sort change with descending sort for a new field', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', remoteSortValue: '-created_at', + onRemoteSortChange, })) act(() => { - result.current.handleSort('name') + result.current.handleSort('hit_count') }) - expect(result.current.sortField).toBe('name') - expect(result.current.sortOrder).toBe('desc') - const names = result.current.sortedDocuments.map(d => d.name) - expect(names).toEqual(['Cherry.txt', 'Banana.txt', 'Apple.txt']) + expect(onRemoteSortChange).toHaveBeenCalledWith('-hit_count') }) - it('should toggle sort order on same field click', () => { - const docs = [createDoc({ id: 'doc-1', name: 'A.txt' }), createDoc({ id: 'doc-2', name: 'B.txt' })] - + it('should toggle descending to ascending when clicking active field', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '-created_at', + remoteSortValue: '-hit_count', + onRemoteSortChange, })) - act(() => result.current.handleSort('name')) - expect(result.current.sortOrder).toBe('desc') + act(() => { + result.current.handleSort('hit_count') + }) - act(() => result.current.handleSort('name')) - expect(result.current.sortOrder).toBe('asc') + expect(onRemoteSortChange).toHaveBeenCalledWith('hit_count') }) - it('should filter by status before sorting', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'A.txt', display_status: 'available' }), - createDoc({ id: 'doc-2', name: 'B.txt', display_status: 'error' }), - createDoc({ id: 'doc-3', name: 'C.txt', display_status: 'available' }), - ] - + it('should ignore null sort field updates', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: 'available', remoteSortValue: '-created_at', + onRemoteSortChange, })) - // Only 'available' documents should remain - expect(result.current.sortedDocuments).toHaveLength(2) - expect(result.current.sortedDocuments.every(d => d.display_status === 'available')).toBe(true) + act(() => { + result.current.handleSort(null) + }) + + expect(onRemoteSortChange).not.toHaveBeenCalled() }) }) @@ -309,14 +292,13 @@ describe('Document Management Flow', () => { describe('Cross-Module: Query State → Sort → Selection Pipeline', () => { it('should maintain consistent default state across all hooks', () => { const docs = [createDoc({ id: 'doc-1' })] - const { result: queryResult } = renderHook(() => useDocumentListQueryState()) + const { result: queryResult } = renderQueryStateHook() const { result: sortResult } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: queryResult.current.query.status, remoteSortValue: queryResult.current.query.sort, + onRemoteSortChange: vi.fn(), })) const { result: selResult } = renderHook(() => useDocumentSelection({ - documents: sortResult.current.sortedDocuments, + documents: docs, selectedIds: [], onSelectedIdChange: vi.fn(), })) @@ 
-325,8 +307,9 @@ describe('Document Management Flow', () => { expect(queryResult.current.query.sort).toBe('-created_at') expect(queryResult.current.query.status).toBe('all') - // Sort inherits 'all' status → no filtering applied - expect(sortResult.current.sortedDocuments).toHaveLength(1) + // Sort state is derived from URL default sort. + expect(sortResult.current.sortField).toBe('created_at') + expect(sortResult.current.sortOrder).toBe('desc') // Selection starts empty expect(selResult.current.isAllSelected).toBe(false) diff --git a/web/__tests__/tools/tool-browsing-and-filtering.test.tsx b/web/__tests__/tools/tool-browsing-and-filtering.test.tsx index 4e7fa4952b..dbefb1fdc3 100644 --- a/web/__tests__/tools/tool-browsing-and-filtering.test.tsx +++ b/web/__tests__/tools/tool-browsing-and-filtering.test.tsx @@ -28,9 +28,13 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('nuqs', () => ({ - useQueryState: () => ['builtin', vi.fn()], -})) +vi.mock('nuqs', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useQueryState: () => ['builtin', vi.fn()], + } +}) vi.mock('@/context/global-public-context', () => ({ useGlobalPublicStore: () => ({ enable_marketplace: false }), @@ -212,6 +216,12 @@ vi.mock('@/app/components/tools/marketplace', () => ({ default: () => null, })) +vi.mock('@/app/components/tools/marketplace/hooks', () => ({ + useMarketplace: () => ({ + handleScroll: vi.fn(), + }), +})) + vi.mock('@/app/components/tools/mcp', () => ({ default: () =>
<div>MCP List</div>
, })) diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index fa83296267..a9bef08243 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -1,9 +1,7 @@ -import type { UrlUpdateEvent } from 'nuqs/adapters/testing' -import type { ReactNode } from 'react' -import { act, fireEvent, render, screen } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { act, fireEvent, screen } from '@testing-library/react' import * as React from 'react' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' import List from '../list' @@ -186,21 +184,14 @@ beforeAll(() => { } as unknown as typeof IntersectionObserver }) -// Render helper wrapping with NuqsTestingAdapter -const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>() +// Render helper wrapping with shared nuqs testing helper. const renderList = (searchParams = '') => { - const wrapper = ({ children }: { children: ReactNode }) => ( - - {children} - - ) - return render(, { wrapper }) + return renderWithNuqs(, { searchParams }) } describe('List', () => { beforeEach(() => { vi.clearAllMocks() - onUrlUpdate.mockClear() useTagStore.setState({ tagList: [{ id: 'tag-1', name: 'Test Tag', type: 'app', binding_count: 0 }], showTagManagementModal: false, @@ -277,7 +268,7 @@ describe('List', () => { describe('Tab Navigation', () => { it('should update URL when workflow tab is clicked', async () => { - renderList() + const { onUrlUpdate } = renderList() fireEvent.click(screen.getByText('app.types.workflow')) @@ -287,7 +278,7 @@ describe('List', () => { }) it('should update URL when all tab is clicked', async () => { - renderList('?category=workflow') + const { onUrlUpdate } = renderList('?category=workflow') fireEvent.click(screen.getByText('app.types.all')) @@ -391,18 +382,10 @@ describe('List', () => { describe('Edge Cases', () => { it('should handle multiple renders without issues', () => { - const { rerender } = render( - - - , - ) + const { rerender } = renderWithNuqs() expect(screen.getByText('app.types.all')).toBeInTheDocument() - rerender( - - - , - ) + rerender() expect(screen.getByText('app.types.all')).toBeInTheDocument() }) @@ -448,7 +431,7 @@ describe('List', () => { }) it('should update URL for each app type tab click', async () => { - renderList() + const { onUrlUpdate } = renderList() const appTypeTexts = [ { mode: AppModeEnum.WORKFLOW, text: 'app.types.workflow' }, diff --git a/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx b/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx index 0c956b78a4..8e8e5821a8 100644 --- a/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx +++ b/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx @@ -1,18 +1,9 @@ -import type { UrlUpdateEvent } from 'nuqs/adapters/testing' -import type { ReactNode } from 'react' -import { act, renderHook, waitFor } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { act, waitFor } from '@testing-library/react' +import { renderHookWithNuqs } from '@/test/nuqs-testing' import useAppsQueryState from '../use-apps-query-state' const renderWithAdapter = (searchParams = '') => { - const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>() - const wrapper = ({ children }: { 
children: ReactNode }) => ( - - {children} - - ) - const { result } = renderHook(() => useAppsQueryState(), { wrapper }) - return { result, onUrlUpdate } + return renderHookWithNuqs(() => useAppsQueryState(), { searchParams }) } describe('useAppsQueryState', () => { diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index d97cd176ca..6ae422f716 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import { useDebounceFn } from 'ahooks' import dynamic from 'next/dynamic' -import { parseAsString, useQueryState } from 'nuqs' +import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Input from '@/app/components/base/input' @@ -16,7 +16,7 @@ import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { CheckModal } from '@/hooks/use-pay' import { useInfiniteAppList } from '@/service/use-apps' -import { AppModeEnum } from '@/types/app' +import { AppModeEnum, AppModes } from '@/types/app' import { cn } from '@/utils/classnames' import AppCard from './app-card' import { AppCardSkeleton } from './app-card-skeleton' @@ -33,6 +33,18 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro ssr: false, }) +const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const +type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number] +const appListCategorySet = new Set(APP_LIST_CATEGORY_VALUES) + +const isAppListCategory = (value: string): value is AppListCategory => { + return appListCategorySet.has(value) +} + +const parseAsAppListCategory = parseAsStringLiteral(APP_LIST_CATEGORY_VALUES) + .withDefault('all') + .withOptions({ history: 'push' }) + type Props = { controlRefreshList?: number } @@ -45,7 +57,7 @@ const List: FC = ({ const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [activeTab, setActiveTab] = useQueryState( 'category', - parseAsString.withDefault('all').withOptions({ history: 'push' }), + parseAsAppListCategory, ) const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState() @@ -80,7 +92,7 @@ const List: FC = ({ name: searchKeywords, tag_ids: tagIDs, is_created_by_me: isCreatedByMe, - ...(activeTab !== 'all' ? { mode: activeTab as AppModeEnum } : {}), + ...(activeTab !== 'all' ? { mode: activeTab } : {}), } const { @@ -186,7 +198,10 @@ const List: FC = ({
           <SegmentedControl
-            onChange={setActiveTab}
+            onChange={(nextValue) => {
+              if (isAppListCategory(nextValue))
+                setActiveTab(nextValue)
+            }}
             options={options}
           />
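The shared helper `web/test/nuqs-testing.tsx` added by this patch (60 lines in the diffstat) is referenced throughout these tests, but its body is not shown in the hunks above. A minimal sketch consistent with its call sites — `renderWithNuqs(<List />, { searchParams })` and `renderHookWithNuqs(() => useAppsQueryState(), { searchParams })` — might look like the following; the option shape and internals are inferred from usage, not copied from the actual file:

```tsx
import type { UrlUpdateEvent } from 'nuqs/adapters/testing'
import type { ReactElement, ReactNode } from 'react'
import { render, renderHook } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { vi } from 'vitest'

type NuqsTestingOptions = {
  searchParams?: string
}

// One adapter and one fresh onUrlUpdate spy per call, so each test can
// assert URL writes without clearing a shared mock between tests.
const createNuqsWrapper = ({ searchParams = '' }: NuqsTestingOptions = {}) => {
  const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>()
  const wrapper = ({ children }: { children: ReactNode }) => (
    <NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={onUrlUpdate}>
      {children}
    </NuqsTestingAdapter>
  )
  return { wrapper, onUrlUpdate }
}

export const renderWithNuqs = (ui: ReactElement, options?: NuqsTestingOptions) => {
  const { wrapper, onUrlUpdate } = createNuqsWrapper(options)
  return { ...render(ui, { wrapper }), onUrlUpdate }
}

export const renderHookWithNuqs = <Result,>(hook: () => Result, options?: NuqsTestingOptions) => {
  const { wrapper, onUrlUpdate } = createNuqsWrapper(options)
  return { ...renderHook(hook, { wrapper }), onUrlUpdate }
}
```

With a helper of this shape, the `parseAsAppListCategory` change above can be asserted end to end: render `useQueryState('category', parseAsAppListCategory)` under the adapter, then check both `searchParams.get('category')` and `options.history === 'push'` in the `onUrlUpdate` payload, as the mocking reference shows.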
diff --git a/web/app/components/datasets/documents/__tests__/index.spec.tsx b/web/app/components/datasets/documents/__tests__/index.spec.tsx index 1749508ee1..f464c97395 100644 --- a/web/app/components/datasets/documents/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/__tests__/index.spec.tsx @@ -4,7 +4,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' import { DataSourceType } from '@/models/datasets' import { useDocumentList } from '@/service/knowledge/use-document' -import useDocumentsPageState from '../hooks/use-documents-page-state' +import { useDocumentsPageState } from '../hooks/use-documents-page-state' import Documents from '../index' // Type for mock selector function - use `as MockState` to bypass strict type checking in tests @@ -117,13 +117,10 @@ const mockHandleStatusFilterClear = vi.fn() const mockHandleSortChange = vi.fn() const mockHandlePageChange = vi.fn() const mockHandleLimitChange = vi.fn() -const mockUpdatePollingState = vi.fn() -const mockAdjustPageForTotal = vi.fn() vi.mock('../hooks/use-documents-page-state', () => ({ - default: vi.fn(() => ({ + useDocumentsPageState: vi.fn(() => ({ inputValue: '', - searchValue: '', debouncedSearchValue: '', handleInputChange: mockHandleInputChange, statusFilterValue: 'all', @@ -138,9 +135,6 @@ vi.mock('../hooks/use-documents-page-state', () => ({ handleLimitChange: mockHandleLimitChange, selectedIds: [] as string[], setSelectedIds: mockSetSelectedIds, - timerCanRun: false, - updatePollingState: mockUpdatePollingState, - adjustPageForTotal: mockAdjustPageForTotal, })), })) @@ -319,6 +313,33 @@ describe('Documents', () => { expect(screen.queryByTestId('documents-list')).not.toBeInTheDocument() }) + it('should keep rendering list when loading with existing data', () => { + vi.mocked(useDocumentList).mockReturnValueOnce({ + data: { + data: [ + { + id: 'doc-1', + name: 'Document 1', + indexing_status: 'completed', + data_source_type: 'upload_file', + position: 1, + enabled: true, + }, + ], + total: 1, + page: 1, + limit: 10, + has_more: false, + } as DocumentListResponse, + isLoading: true, + refetch: vi.fn(), + } as unknown as ReturnType) + + render() + expect(screen.getByTestId('documents-list')).toBeInTheDocument() + expect(screen.getByTestId('list-documents-count')).toHaveTextContent('1') + }) + it('should render empty element when no documents exist', () => { vi.mocked(useDocumentList).mockReturnValueOnce({ data: { data: [], total: 0, page: 1, limit: 10, has_more: false }, @@ -484,17 +505,75 @@ describe('Documents', () => { }) }) - describe('Side Effects and Cleanup', () => { - it('should call updatePollingState when documents response changes', () => { + describe('Query Options', () => { + it('should pass function refetchInterval to useDocumentList', () => { render() - expect(mockUpdatePollingState).toHaveBeenCalled() + const payload = vi.mocked(useDocumentList).mock.calls.at(-1)?.[0] + expect(payload).toBeDefined() + expect(typeof payload?.refetchInterval).toBe('function') }) - it('should call adjustPageForTotal when documents response changes', () => { + it('should stop polling when all documents are in terminal statuses', () => { render() - expect(mockAdjustPageForTotal).toHaveBeenCalled() + const payload = vi.mocked(useDocumentList).mock.calls.at(-1)?.[0] + const refetchInterval = payload?.refetchInterval + expect(typeof refetchInterval).toBe('function') + if (typeof refetchInterval !== 'function') + throw 
new Error('Expected function refetchInterval') + + const interval = refetchInterval({ + state: { + data: { + data: [ + { indexing_status: 'completed' }, + { indexing_status: 'paused' }, + { indexing_status: 'error' }, + ], + }, + }, + } as unknown as Parameters[0]) + + expect(interval).toBe(false) + }) + + it('should keep polling for transient status filters', () => { + vi.mocked(useDocumentsPageState).mockReturnValueOnce({ + inputValue: '', + debouncedSearchValue: '', + handleInputChange: mockHandleInputChange, + statusFilterValue: 'indexing', + sortValue: '-created_at' as const, + normalizedStatusFilterValue: 'indexing', + handleStatusFilterChange: mockHandleStatusFilterChange, + handleStatusFilterClear: mockHandleStatusFilterClear, + handleSortChange: mockHandleSortChange, + currPage: 0, + limit: 10, + handlePageChange: mockHandlePageChange, + handleLimitChange: mockHandleLimitChange, + selectedIds: [] as string[], + setSelectedIds: mockSetSelectedIds, + }) + + render() + + const payload = vi.mocked(useDocumentList).mock.calls.at(-1)?.[0] + const refetchInterval = payload?.refetchInterval + expect(typeof refetchInterval).toBe('function') + if (typeof refetchInterval !== 'function') + throw new Error('Expected function refetchInterval') + + const interval = refetchInterval({ + state: { + data: { + data: [{ indexing_status: 'completed' }], + }, + }, + } as unknown as Parameters[0]) + + expect(interval).toBe(2500) }) }) @@ -591,36 +670,6 @@ describe('Documents', () => { }) }) - describe('Polling State', () => { - it('should enable polling when documents are indexing', () => { - vi.mocked(useDocumentsPageState).mockReturnValueOnce({ - inputValue: '', - searchValue: '', - debouncedSearchValue: '', - handleInputChange: mockHandleInputChange, - statusFilterValue: 'all', - sortValue: '-created_at' as const, - normalizedStatusFilterValue: 'all', - handleStatusFilterChange: mockHandleStatusFilterChange, - handleStatusFilterClear: mockHandleStatusFilterClear, - handleSortChange: mockHandleSortChange, - currPage: 0, - limit: 10, - handlePageChange: mockHandlePageChange, - handleLimitChange: mockHandleLimitChange, - selectedIds: [] as string[], - setSelectedIds: mockSetSelectedIds, - timerCanRun: true, - updatePollingState: mockUpdatePollingState, - adjustPageForTotal: mockAdjustPageForTotal, - }) - - render() - - expect(screen.getByTestId('documents-list')).toBeInTheDocument() - }) - }) - describe('Pagination', () => { it('should display correct total in list', () => { render() @@ -635,7 +684,6 @@ describe('Documents', () => { it('should handle page changes', () => { vi.mocked(useDocumentsPageState).mockReturnValueOnce({ inputValue: '', - searchValue: '', debouncedSearchValue: '', handleInputChange: mockHandleInputChange, statusFilterValue: 'all', @@ -650,9 +698,6 @@ describe('Documents', () => { handleLimitChange: mockHandleLimitChange, selectedIds: [] as string[], setSelectedIds: mockSetSelectedIds, - timerCanRun: false, - updatePollingState: mockUpdatePollingState, - adjustPageForTotal: mockAdjustPageForTotal, }) render() @@ -664,7 +709,6 @@ describe('Documents', () => { it('should display selected count', () => { vi.mocked(useDocumentsPageState).mockReturnValueOnce({ inputValue: '', - searchValue: '', debouncedSearchValue: '', handleInputChange: mockHandleInputChange, statusFilterValue: 'all', @@ -679,9 +723,6 @@ describe('Documents', () => { handleLimitChange: mockHandleLimitChange, selectedIds: ['doc-1', 'doc-2'], setSelectedIds: mockSetSelectedIds, - timerCanRun: false, - updatePollingState: 
mockUpdatePollingState, - adjustPageForTotal: mockAdjustPageForTotal, }) render() @@ -693,7 +734,6 @@ describe('Documents', () => { it('should pass filter value to list', () => { vi.mocked(useDocumentsPageState).mockReturnValueOnce({ inputValue: 'test search', - searchValue: 'test search', debouncedSearchValue: 'test search', handleInputChange: mockHandleInputChange, statusFilterValue: 'completed', @@ -708,9 +748,6 @@ describe('Documents', () => { handleLimitChange: mockHandleLimitChange, selectedIds: [] as string[], setSelectedIds: mockSetSelectedIds, - timerCanRun: false, - updatePollingState: mockUpdatePollingState, - adjustPageForTotal: mockAdjustPageForTotal, }) render() diff --git a/web/app/components/datasets/documents/components/__tests__/list.spec.tsx b/web/app/components/datasets/documents/components/__tests__/list.spec.tsx index a96afe3cb4..bb7e170783 100644 --- a/web/app/components/datasets/documents/components/__tests__/list.spec.tsx +++ b/web/app/components/datasets/documents/components/__tests__/list.spec.tsx @@ -20,9 +20,8 @@ const mockHandleSave = vi.fn() vi.mock('../document-list/hooks', () => ({ useDocumentSort: vi.fn(() => ({ sortField: null, - sortOrder: null, + sortOrder: 'desc', handleSort: mockHandleSort, - sortedDocuments: [], })), useDocumentSelection: vi.fn(() => ({ isAllSelected: false, @@ -125,8 +124,8 @@ const defaultProps = { pagination: { total: 0, current: 1, limit: 10, onChange: vi.fn() }, onUpdate: vi.fn(), onManageMetadata: vi.fn(), - statusFilterValue: 'all', - remoteSortValue: '', + remoteSortValue: '-created_at', + onSortChange: vi.fn(), } describe('DocumentList', () => { @@ -140,8 +139,6 @@ describe('DocumentList', () => { render() expect(screen.getByText('#')).toBeInTheDocument() - expect(screen.getByTestId('sort-name')).toBeInTheDocument() - expect(screen.getByTestId('sort-word_count')).toBeInTheDocument() expect(screen.getByTestId('sort-hit_count')).toBeInTheDocument() expect(screen.getByTestId('sort-created_at')).toBeInTheDocument() }) @@ -164,10 +161,9 @@ describe('DocumentList', () => { it('should render document rows from sortedDocuments', () => { const docs = [createDoc({ id: 'a', name: 'Doc A' }), createDoc({ id: 'b', name: 'Doc B' })] vi.mocked(useDocumentSort).mockReturnValue({ - sortField: null, + sortField: 'created_at', sortOrder: 'desc', handleSort: mockHandleSort, - sortedDocuments: docs, } as unknown as ReturnType) render() @@ -182,9 +178,9 @@ describe('DocumentList', () => { it('should call handleSort when sort header is clicked', () => { render() - fireEvent.click(screen.getByTestId('sort-name')) + fireEvent.click(screen.getByTestId('sort-created_at')) - expect(mockHandleSort).toHaveBeenCalledWith('name') + expect(mockHandleSort).toHaveBeenCalledWith('created_at') }) }) @@ -229,7 +225,6 @@ describe('DocumentList', () => { sortField: null, sortOrder: 'desc', handleSort: mockHandleSort, - sortedDocuments: [], } as unknown as ReturnType) render() diff --git a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx index 5ea2a00a7d..5053038d5e 100644 --- a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx @@ -2,7 +2,7 @@ import type { ReactNode } from 'react' import type { Props as PaginationProps } from '@/app/components/base/pagination' import type { SimpleDocumentDetail } from 
'@/models/datasets' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' -import { fireEvent, render, screen } from '@testing-library/react' +import { act, fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { ChunkingMode, DataSourceType } from '@/models/datasets' import DocumentList from '../../list' @@ -13,6 +13,7 @@ vi.mock('next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), + useSearchParams: () => new URLSearchParams(), })) vi.mock('@/context/dataset-detail', () => ({ @@ -90,8 +91,8 @@ describe('DocumentList', () => { pagination: defaultPagination, onUpdate: vi.fn(), onManageMetadata: vi.fn(), - statusFilterValue: '', - remoteSortValue: '', + remoteSortValue: '-created_at', + onSortChange: vi.fn(), } beforeEach(() => { @@ -220,16 +221,15 @@ describe('DocumentList', () => { expect(sortIcons.length).toBeGreaterThan(0) }) - it('should update sort order when sort header is clicked', () => { - render(, { wrapper: createWrapper() }) + it('should call onSortChange when sortable header is clicked', () => { + const onSortChange = vi.fn() + const { container } = render(, { wrapper: createWrapper() }) - // Find and click a sort header by its parent div containing the label text - const sortableHeaders = document.querySelectorAll('[class*="cursor-pointer"]') - if (sortableHeaders.length > 0) { + const sortableHeaders = container.querySelectorAll('thead button') + if (sortableHeaders.length > 0) fireEvent.click(sortableHeaders[0]) - } - expect(screen.getByRole('table')).toBeInTheDocument() + expect(onSortChange).toHaveBeenCalled() }) }) @@ -360,13 +360,15 @@ describe('DocumentList', () => { expect(modal).not.toBeInTheDocument() }) - it('should show rename modal when rename button is clicked', () => { + it('should show rename modal when rename button is clicked', async () => { const { container } = render(, { wrapper: createWrapper() }) // Find and click the rename button in the first row const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') if (renameButtons.length > 0) { - fireEvent.click(renameButtons[0]) + await act(async () => { + fireEvent.click(renameButtons[0]) + }) } // After clicking rename, the modal should potentially be visible @@ -384,7 +386,7 @@ describe('DocumentList', () => { }) describe('Edit Metadata Modal', () => { - it('should handle edit metadata action', () => { + it('should handle edit metadata action', async () => { const props = { ...defaultProps, selectedIds: ['doc-1'], @@ -393,7 +395,9 @@ describe('DocumentList', () => { const editButton = screen.queryByRole('button', { name: /metadata/i }) if (editButton) { - fireEvent.click(editButton) + await act(async () => { + fireEvent.click(editButton) + }) } expect(screen.getByRole('table')).toBeInTheDocument() @@ -454,16 +458,6 @@ describe('DocumentList', () => { expect(screen.getByRole('table')).toBeInTheDocument() }) - it('should handle status filter value', () => { - const props = { - ...defaultProps, - statusFilterValue: 'completed', - } - render(, { wrapper: createWrapper() }) - - expect(screen.getByRole('table')).toBeInTheDocument() - }) - it('should handle remote sort value', () => { const props = { ...defaultProps, diff --git a/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx index ad920e9a37..20a3f7cee1 100644 
--- a/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx @@ -7,11 +7,13 @@ import { DataSourceType } from '@/models/datasets' import DocumentTableRow from '../document-table-row' const mockPush = vi.fn() +let mockSearchParams = '' vi.mock('next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), + useSearchParams: () => new URLSearchParams(mockSearchParams), })) const createTestQueryClient = () => new QueryClient({ @@ -95,6 +97,7 @@ describe('DocumentTableRow', () => { beforeEach(() => { vi.clearAllMocks() + mockSearchParams = '' }) describe('Rendering', () => { @@ -186,6 +189,15 @@ describe('DocumentTableRow', () => { expect(mockPush).toHaveBeenCalledWith('/datasets/custom-dataset/documents/custom-doc') }) + + it('should preserve search params when navigating to detail', () => { + mockSearchParams = 'page=2&status=error' + render(, { wrapper: createWrapper() }) + + fireEvent.click(screen.getByRole('row')) + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1?page=2&status=error') + }) }) describe('Word Count Display', () => { diff --git a/web/app/components/datasets/documents/components/document-list/components/__tests__/sort-header.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/__tests__/sort-header.spec.tsx index 777f240d00..8730f3f278 100644 --- a/web/app/components/datasets/documents/components/document-list/components/__tests__/sort-header.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/__tests__/sort-header.spec.tsx @@ -4,8 +4,8 @@ import SortHeader from '../sort-header' describe('SortHeader', () => { const defaultProps = { - field: 'name' as const, - label: 'File Name', + field: 'created_at' as const, + label: 'Upload Time', currentSortField: null, sortOrder: 'desc' as const, onSort: vi.fn(), @@ -14,12 +14,12 @@ describe('SortHeader', () => { describe('rendering', () => { it('should render the label', () => { render() - expect(screen.getByText('File Name')).toBeInTheDocument() + expect(screen.getByText('Upload Time')).toBeInTheDocument() }) it('should render the sort icon', () => { const { container } = render() - const icon = container.querySelector('svg') + const icon = container.querySelector('button span') expect(icon).toBeInTheDocument() }) }) @@ -27,13 +27,13 @@ describe('SortHeader', () => { describe('inactive state', () => { it('should have disabled text color when not active', () => { const { container } = render() - const icon = container.querySelector('svg') + const icon = container.querySelector('button span') expect(icon).toHaveClass('text-text-disabled') }) it('should not be rotated when not active', () => { const { container } = render() - const icon = container.querySelector('svg') + const icon = container.querySelector('button span') expect(icon).not.toHaveClass('rotate-180') }) }) @@ -41,25 +41,25 @@ describe('SortHeader', () => { describe('active state', () => { it('should have tertiary text color when active', () => { const { container } = render( - , + , ) - const icon = container.querySelector('svg') + const icon = container.querySelector('button span') expect(icon).toHaveClass('text-text-tertiary') }) it('should not be rotated when active and desc', () => { const { container } = render( - , + , ) - const icon = container.querySelector('svg') + const icon = 
container.querySelector('button span') expect(icon).not.toHaveClass('rotate-180') }) it('should be rotated when active and asc', () => { const { container } = render( - , + , ) - const icon = container.querySelector('svg') + const icon = container.querySelector('button span') expect(icon).toHaveClass('rotate-180') }) }) @@ -69,34 +69,22 @@ describe('SortHeader', () => { const onSort = vi.fn() render() - fireEvent.click(screen.getByText('File Name')) + fireEvent.click(screen.getByText('Upload Time')) - expect(onSort).toHaveBeenCalledWith('name') + expect(onSort).toHaveBeenCalledWith('created_at') }) it('should call onSort with correct field', () => { const onSort = vi.fn() - render() + render() - fireEvent.click(screen.getByText('File Name')) + fireEvent.click(screen.getByText('Upload Time')) - expect(onSort).toHaveBeenCalledWith('word_count') + expect(onSort).toHaveBeenCalledWith('hit_count') }) }) describe('different fields', () => { - it('should work with word_count field', () => { - render( - , - ) - expect(screen.getByText('Words')).toBeInTheDocument() - }) - it('should work with hit_count field', () => { render( = React.memo(({ const { t } = useTranslation() const { formatTime } = useTimestamp() const router = useRouter() + const searchParams = useSearchParams() const isFile = doc.data_source_type === DataSourceType.FILE const fileType = isFile ? doc.data_source_detail_dict?.upload_file?.extension : '' + const queryString = searchParams.toString() const handleRowClick = useCallback(() => { - router.push(`/datasets/${datasetId}/documents/${doc.id}`) - }, [router, datasetId, doc.id]) + router.push(`/datasets/${datasetId}/documents/${doc.id}${queryString ? `?${queryString}` : ''}`) + }, [router, datasetId, doc.id, queryString]) const handleCheckboxClick = useCallback((e: React.MouseEvent) => { e.stopPropagation() @@ -100,7 +101,7 @@ const DocumentTableRow: FC = React.memo(({
- {doc.name} + {doc.name} {doc.summary_index_status && (
@@ -113,7 +114,7 @@ const DocumentTableRow: FC = React.memo(({ className="cursor-pointer rounded-md p-1 hover:bg-state-base-hover" onClick={handleRenameClick} > - +
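`handleRowClick` above now threads the current query string onto the document detail route, and `backToPrev` in `detail/index.tsx` below applies the same conditional-`?` pattern on the way back. As a standalone sketch (the helper name is hypothetical; the patch inlines this logic):

```typescript
// Hypothetical helper capturing the pattern used by handleRowClick and
// backToPrev: append the current query string, if any, to a target path.
const withQueryString = (path: string, searchParams: URLSearchParams): string => {
  const queryString = searchParams.toString()
  return queryString ? `${path}?${queryString}` : path
}

// withQueryString('/datasets/ds-1/documents/doc-1', new URLSearchParams('page=2&status=error'))
//   -> '/datasets/ds-1/documents/doc-1?page=2&status=error'
// withQueryString('/datasets/ds-1/documents', new URLSearchParams())
//   -> '/datasets/ds-1/documents'
```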
diff --git a/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx b/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx index 1dc13df2b0..1d693565cb 100644 --- a/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx @@ -1,6 +1,5 @@ import type { FC } from 'react' import type { SortField, SortOrder } from '../hooks' -import { RiArrowDownLine } from '@remixicon/react' import * as React from 'react' import { cn } from '@/utils/classnames' @@ -23,19 +22,20 @@ const SortHeader: FC = React.memo(({ const isDesc = isActive && sortOrder === 'desc' return ( -
onSort(field)} > {label} - -
+ ) }) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-sort.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-sort.spec.ts index 43bc0e1dd5..004597afa9 100644 --- a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-sort.spec.ts +++ b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-sort.spec.ts @@ -1,340 +1,98 @@ -import type { SimpleDocumentDetail } from '@/models/datasets' import { act, renderHook } from '@testing-library/react' -import { describe, expect, it } from 'vitest' +import { describe, expect, it, vi } from 'vitest' import { useDocumentSort } from '../use-document-sort' -type LocalDoc = SimpleDocumentDetail & { percent?: number } - -const createMockDocument = (overrides: Partial = {}): LocalDoc => ({ - id: 'doc1', - name: 'Test Document', - data_source_type: 'upload_file', - data_source_info: {}, - data_source_detail_dict: {}, - word_count: 100, - hit_count: 10, - created_at: 1000000, - position: 1, - doc_form: 'text_model', - enabled: true, - archived: false, - display_status: 'available', - created_from: 'api', - ...overrides, -} as LocalDoc) - describe('useDocumentSort', () => { - describe('initial state', () => { - it('should return null sortField initially', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) + describe('remote state parsing', () => { + it('should parse descending created_at sort', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-created_at', + onRemoteSortChange, + })) - expect(result.current.sortField).toBeNull() + expect(result.current.sortField).toBe('created_at') expect(result.current.sortOrder).toBe('desc') }) - it('should return documents unchanged when no sort is applied', () => { - const docs = [ - createMockDocument({ id: 'doc1', name: 'B' }), - createMockDocument({ id: 'doc2', name: 'A' }), - ] + it('should parse ascending hit_count sort', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: 'hit_count', + onRemoteSortChange, + })) - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) + expect(result.current.sortField).toBe('hit_count') + expect(result.current.sortOrder).toBe('asc') + }) - expect(result.current.sortedDocuments).toEqual(docs) + it('should fallback to inactive field for unsupported sort key', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-name', + onRemoteSortChange, + })) + + expect(result.current.sortField).toBeNull() + expect(result.current.sortOrder).toBe('desc') }) }) describe('handleSort', () => { - it('should set sort field when called', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - - expect(result.current.sortField).toBe('name') - expect(result.current.sortOrder).toBe('desc') - }) - - it('should toggle sort order when same field is clicked twice', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - 
result.current.handleSort('name') - }) - expect(result.current.sortOrder).toBe('desc') - - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortOrder).toBe('asc') - - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortOrder).toBe('desc') - }) - - it('should reset to desc when different field is selected', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortOrder).toBe('asc') - - act(() => { - result.current.handleSort('word_count') - }) - expect(result.current.sortField).toBe('word_count') - expect(result.current.sortOrder).toBe('desc') - }) - - it('should not change state when null is passed', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort(null) - }) - - expect(result.current.sortField).toBeNull() - }) - }) - - describe('sorting documents', () => { - const docs = [ - createMockDocument({ id: 'doc1', name: 'Banana', word_count: 200, hit_count: 5, created_at: 3000 }), - createMockDocument({ id: 'doc2', name: 'Apple', word_count: 100, hit_count: 10, created_at: 1000 }), - createMockDocument({ id: 'doc3', name: 'Cherry', word_count: 300, hit_count: 1, created_at: 2000 }), - ] - - it('should sort by name descending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - - const names = result.current.sortedDocuments.map(d => d.name) - expect(names).toEqual(['Cherry', 'Banana', 'Apple']) - }) - - it('should sort by name ascending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - act(() => { - result.current.handleSort('name') - }) - - const names = result.current.sortedDocuments.map(d => d.name) - expect(names).toEqual(['Apple', 'Banana', 'Cherry']) - }) - - it('should sort by word_count descending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('word_count') - }) - - const counts = result.current.sortedDocuments.map(d => d.word_count) - expect(counts).toEqual([300, 200, 100]) - }) - - it('should sort by hit_count ascending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) + it('should switch to desc when selecting a different field', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-created_at', + onRemoteSortChange, + })) act(() => { result.current.handleSort('hit_count') }) + + expect(onRemoteSortChange).toHaveBeenCalledWith('-hit_count') + }) + + it('should toggle desc -> asc when clicking active field', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-hit_count', + onRemoteSortChange, + })) + act(() => { result.current.handleSort('hit_count') }) - const counts = result.current.sortedDocuments.map(d => 
d.hit_count) - expect(counts).toEqual([1, 5, 10]) + expect(onRemoteSortChange).toHaveBeenCalledWith('hit_count') }) - it('should sort by created_at descending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) + it('should toggle asc -> desc when clicking active field', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: 'created_at', + onRemoteSortChange, + })) act(() => { result.current.handleSort('created_at') }) - const times = result.current.sortedDocuments.map(d => d.created_at) - expect(times).toEqual([3000, 2000, 1000]) - }) - }) - - describe('status filtering', () => { - const docs = [ - createMockDocument({ id: 'doc1', display_status: 'available' }), - createMockDocument({ id: 'doc2', display_status: 'error' }), - createMockDocument({ id: 'doc3', display_status: 'available' }), - ] - - it('should not filter when statusFilterValue is empty', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - expect(result.current.sortedDocuments.length).toBe(3) + expect(onRemoteSortChange).toHaveBeenCalledWith('-created_at') }) - it('should not filter when statusFilterValue is all', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: 'all', - remoteSortValue: '', - }), - ) - - expect(result.current.sortedDocuments.length).toBe(3) - }) - }) - - describe('remoteSortValue reset', () => { - it('should reset sort state when remoteSortValue changes', () => { - const { result, rerender } = renderHook( - ({ remoteSortValue }) => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue, - }), - { initialProps: { remoteSortValue: 'initial' } }, - ) + it('should ignore null field', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-created_at', + onRemoteSortChange, + })) act(() => { - result.current.handleSort('name') - }) - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortField).toBe('name') - expect(result.current.sortOrder).toBe('asc') - - rerender({ remoteSortValue: 'changed' }) - - expect(result.current.sortField).toBeNull() - expect(result.current.sortOrder).toBe('desc') - }) - }) - - describe('edge cases', () => { - it('should handle documents with missing values', () => { - const docs = [ - createMockDocument({ id: 'doc1', name: undefined as unknown as string, word_count: undefined }), - createMockDocument({ id: 'doc2', name: 'Test', word_count: 100 }), - ] - - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') + result.current.handleSort(null) }) - expect(result.current.sortedDocuments.length).toBe(2) - }) - - it('should handle empty documents array', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - - expect(result.current.sortedDocuments).toEqual([]) + expect(onRemoteSortChange).not.toHaveBeenCalled() }) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts 
b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts index 98cf244f36..0e0b07db6f 100644 --- a/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts @@ -1,102 +1,42 @@ -import type { SimpleDocumentDetail } from '@/models/datasets' -import { useCallback, useMemo, useRef, useState } from 'react' -import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter' +import { useCallback, useMemo } from 'react' -export type SortField = 'name' | 'word_count' | 'hit_count' | 'created_at' | null +type RemoteSortField = 'hit_count' | 'created_at' +const REMOTE_SORT_FIELDS = new Set(['hit_count', 'created_at']) + +export type SortField = RemoteSortField | null export type SortOrder = 'asc' | 'desc' -type LocalDoc = SimpleDocumentDetail & { percent?: number } - type UseDocumentSortOptions = { - documents: LocalDoc[] - statusFilterValue: string remoteSortValue: string + onRemoteSortChange: (nextSortValue: string) => void } export const useDocumentSort = ({ - documents, - statusFilterValue, remoteSortValue, + onRemoteSortChange, }: UseDocumentSortOptions) => { - const [sortField, setSortField] = useState(null) - const [sortOrder, setSortOrder] = useState('desc') - const prevRemoteSortValueRef = useRef(remoteSortValue) + const sortOrder: SortOrder = remoteSortValue.startsWith('-') ? 'desc' : 'asc' + const sortKey = remoteSortValue.startsWith('-') ? remoteSortValue.slice(1) : remoteSortValue - // Reset sort when remote sort changes - if (prevRemoteSortValueRef.current !== remoteSortValue) { - prevRemoteSortValueRef.current = remoteSortValue - setSortField(null) - setSortOrder('desc') - } + const sortField = useMemo(() => { + return REMOTE_SORT_FIELDS.has(sortKey as RemoteSortField) ? sortKey as RemoteSortField : null + }, [sortKey]) const handleSort = useCallback((field: SortField) => { - if (field === null) + if (!field) return if (sortField === field) { - setSortOrder(prev => prev === 'asc' ? 'desc' : 'asc') + const nextSortOrder = sortOrder === 'desc' ? 'asc' : 'desc' + onRemoteSortChange(nextSortOrder === 'desc' ? `-${field}` : field) + return } - else { - setSortField(field) - setSortOrder('desc') - } - }, [sortField]) - - const sortedDocuments = useMemo(() => { - let filteredDocs = documents - - if (statusFilterValue && statusFilterValue !== 'all') { - filteredDocs = filteredDocs.filter(doc => - typeof doc.display_status === 'string' - && normalizeStatusForQuery(doc.display_status) === statusFilterValue, - ) - } - - if (!sortField) - return filteredDocs - - const sortedDocs = [...filteredDocs].sort((a, b) => { - let aValue: string | number - let bValue: string | number - - switch (sortField) { - case 'name': - aValue = a.name?.toLowerCase() || '' - bValue = b.name?.toLowerCase() || '' - break - case 'word_count': - aValue = a.word_count || 0 - bValue = b.word_count || 0 - break - case 'hit_count': - aValue = a.hit_count || 0 - bValue = b.hit_count || 0 - break - case 'created_at': - aValue = a.created_at - bValue = b.created_at - break - default: - return 0 - } - - if (sortField === 'name') { - const result = (aValue as string).localeCompare(bValue as string) - return sortOrder === 'asc' ? result : -result - } - else { - const result = (aValue as number) - (bValue as number) - return sortOrder === 'asc' ? 
result : -result - } - }) - - return sortedDocs - }, [documents, sortField, sortOrder, statusFilterValue]) + onRemoteSortChange(`-${field}`) + }, [onRemoteSortChange, sortField, sortOrder]) return { sortField, sortOrder, handleSort, - sortedDocuments, } } diff --git a/web/app/components/datasets/documents/components/list.tsx b/web/app/components/datasets/documents/components/list.tsx index 3106f6c30b..e40e4c061b 100644 --- a/web/app/components/datasets/documents/components/list.tsx +++ b/web/app/components/datasets/documents/components/list.tsx @@ -14,7 +14,7 @@ import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from ' import { ChunkingMode, DocumentActionType } from '@/models/datasets' import BatchAction from '../detail/completed/common/batch-action' import s from '../style.module.css' -import { DocumentTableRow, renderTdValue, SortHeader } from './document-list/components' +import { DocumentTableRow, SortHeader } from './document-list/components' import { useDocumentActions, useDocumentSelection, useDocumentSort } from './document-list/hooks' import RenameModal from './rename-modal' @@ -29,8 +29,8 @@ type DocumentListProps = { pagination: PaginationProps onUpdate: () => void onManageMetadata: () => void - statusFilterValue: string remoteSortValue: string + onSortChange: (value: string) => void } /** @@ -45,8 +45,8 @@ const DocumentList: FC = ({ pagination, onUpdate, onManageMetadata, - statusFilterValue, remoteSortValue, + onSortChange, }) => { const { t } = useTranslation() const datasetConfig = useDatasetDetailContext(s => s.dataset) @@ -55,10 +55,9 @@ const DocumentList: FC = ({ const isQAMode = chunkingMode === ChunkingMode.qa // Sorting - const { sortField, sortOrder, handleSort, sortedDocuments } = useDocumentSort({ - documents, - statusFilterValue, + const { sortField, sortOrder, handleSort } = useDocumentSort({ remoteSortValue, + onRemoteSortChange: onSortChange, }) // Selection @@ -71,7 +70,7 @@ const DocumentList: FC = ({ downloadableSelectedIds, clearSelection, } = useDocumentSelection({ - documents: sortedDocuments, + documents, selectedIds, onSelectedIdChange, }) @@ -135,24 +134,10 @@ const DocumentList: FC = ({
- + {t('list.table.header.fileName', { ns: 'datasetDocuments' })} {t('list.table.header.chunkingMode', { ns: 'datasetDocuments' })} - - - + {t('list.table.header.words', { ns: 'datasetDocuments' })} = ({ - {sortedDocuments.map((doc, index) => ( + {documents.map((doc, index) => ( = ({ } export default DocumentList - -export { renderTdValue } diff --git a/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx b/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx index ad8741a8e1..f01a64e34e 100644 --- a/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx @@ -9,6 +9,7 @@ const mocks = vi.hoisted(() => { documentError: null as Error | null, documentMetadata: null as Record | null, media: 'desktop' as string, + searchParams: '' as string, } return { state, @@ -26,6 +27,7 @@ const mocks = vi.hoisted(() => { // --- External mocks --- vi.mock('next/navigation', () => ({ useRouter: () => ({ push: mocks.push }), + useSearchParams: () => new URLSearchParams(mocks.state.searchParams), })) vi.mock('@/hooks/use-breakpoints', () => ({ @@ -193,6 +195,7 @@ describe('DocumentDetail', () => { mocks.state.documentError = null mocks.state.documentMetadata = null mocks.state.media = 'desktop' + mocks.state.searchParams = '' }) afterEach(() => { @@ -286,15 +289,23 @@ describe('DocumentDetail', () => { }) it('should toggle metadata panel when button clicked', () => { - const { container } = render() + render() expect(screen.getByTestId('metadata')).toBeInTheDocument() - const svgs = container.querySelectorAll('svg') - const toggleBtn = svgs[svgs.length - 1].closest('button')! - fireEvent.click(toggleBtn) + fireEvent.click(screen.getByTestId('document-detail-metadata-toggle')) expect(screen.queryByTestId('metadata')).not.toBeInTheDocument() }) + it('should expose aria semantics for metadata toggle button', () => { + render() + const toggle = screen.getByTestId('document-detail-metadata-toggle') + expect(toggle).toHaveAttribute('aria-label') + expect(toggle).toHaveAttribute('aria-pressed', 'true') + + fireEvent.click(toggle) + expect(toggle).toHaveAttribute('aria-pressed', 'false') + }) + it('should pass correct props to Metadata', () => { render() const metadata = screen.getByTestId('metadata') @@ -305,20 +316,21 @@ describe('DocumentDetail', () => { describe('Navigation', () => { it('should navigate back when back button clicked', () => { - const { container } = render() - const backBtn = container.querySelector('svg')!.parentElement! - fireEvent.click(backBtn) + render() + fireEvent.click(screen.getByTestId('document-detail-back-button')) expect(mocks.push).toHaveBeenCalledWith('/datasets/ds-1/documents') }) + it('should expose aria label for back button', () => { + render() + expect(screen.getByTestId('document-detail-back-button')).toHaveAttribute('aria-label') + }) + it('should preserve query params when navigating back', () => { - const origLocation = window.location - window.history.pushState({}, '', '?page=2&status=active') - const { container } = render() - const backBtn = container.querySelector('svg')!.parentElement! 
- fireEvent.click(backBtn) + mocks.state.searchParams = 'page=2&status=active' + render() + fireEvent.click(screen.getByTestId('document-detail-back-button')) expect(mocks.push).toHaveBeenCalledWith('/datasets/ds-1/documents?page=2&status=active') - window.history.pushState({}, '', origLocation.href) }) }) diff --git a/web/app/components/datasets/documents/detail/index.tsx b/web/app/components/datasets/documents/detail/index.tsx index e147bf9aba..b6842605c6 100644 --- a/web/app/components/datasets/documents/detail/index.tsx +++ b/web/app/components/datasets/documents/detail/index.tsx @@ -1,8 +1,7 @@ 'use client' import type { FC } from 'react' import type { DataSourceInfo, FileItem, FullDocumentDetail, LegacyDataSourceInfo } from '@/models/datasets' -import { RiArrowLeftLine, RiLayoutLeft2Line, RiLayoutRight2Line } from '@remixicon/react' -import { useRouter } from 'next/navigation' +import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -35,6 +34,7 @@ type DocumentDetailProps = { const DocumentDetail: FC = ({ datasetId, documentId }) => { const router = useRouter() + const searchParams = useSearchParams() const { t } = useTranslation() const media = useBreakpoints() @@ -98,11 +98,8 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { }) const backToPrev = () => { - // Preserve pagination and filter states when navigating back - const searchParams = new URLSearchParams(window.location.search) const queryString = searchParams.toString() - const separator = queryString ? '?' : '' - const backPath = `/datasets/${datasetId}/documents${separator}${queryString}` + const backPath = `/datasets/${datasetId}/documents${queryString ? `?${queryString}` : ''}` router.push(backPath) } @@ -152,6 +149,11 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { return chunkMode === ChunkingMode.parentChild && parentMode === 'full-doc' }, [documentDetail?.doc_form, parentMode]) + const backButtonLabel = t('operation.back', { ns: 'common' }) + const metadataToggleLabel = `${showMetadata + ? t('operation.close', { ns: 'common' }) + : t('operation.view', { ns: 'common' })} ${t('metadata.title', { ns: 'datasetDocuments' })}` + return ( = ({ datasetId, documentId }) => { >
[JSX hunk garbled in extraction; the recoverable change: the icon-only back button and the metadata panel toggle become <button> elements carrying aria-label (backButtonLabel / metadataToggleLabel), aria-pressed on the toggle, and the data-testid values 'document-detail-back-button' and 'document-detail-metadata-toggle' asserted by the tests above.]
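The detail tests earlier in this patch stop walking the rendered SVG tree and instead select these buttons by data-testid and assert their ARIA state. A self-contained sketch of that pattern, using a hypothetical stand-in component (the real markup lives in index.tsx above) and assuming the @testing-library/jest-dom matchers these specs already rely on:

    import { fireEvent, render, screen } from '@testing-library/react'
    import { useState } from 'react'
    import { expect, it } from 'vitest'

    // Hypothetical stand-in for the metadata toggle; only the attributes matter here.
    function MetadataToggle({ label }: { label: string }) {
      const [pressed, setPressed] = useState(true)
      return (
        <button
          type="button"
          aria-label={label}
          aria-pressed={pressed}
          data-testid="document-detail-metadata-toggle"
          onClick={() => setPressed(p => !p)}
        >
          {label}
        </button>
      )
    }

    it('asserts state through stable selectors instead of SVG traversal', () => {
      render(<MetadataToggle label="Toggle metadata" />)
      const toggle = screen.getByTestId('document-detail-metadata-toggle')
      expect(toggle).toHaveAttribute('aria-pressed', 'true')
      fireEvent.click(toggle)
      expect(toggle).toHaveAttribute('aria-pressed', 'false')
    })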
diff --git a/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.ts b/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.ts deleted file mode 100644 index e31d4ac547..0000000000 --- a/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.ts +++ /dev/null @@ -1,439 +0,0 @@ -import type { DocumentListQuery } from '../use-document-list-query-state' -import { act, renderHook } from '@testing-library/react' - -import { beforeEach, describe, expect, it, vi } from 'vitest' -import useDocumentListQueryState from '../use-document-list-query-state' - -const mockPush = vi.fn() -const mockSearchParams = new URLSearchParams() - -vi.mock('@/models/datasets', () => ({ - DisplayStatusList: [ - 'queuing', - 'indexing', - 'paused', - 'error', - 'available', - 'enabled', - 'disabled', - 'archived', - ], -})) - -vi.mock('next/navigation', () => ({ - useRouter: () => ({ push: mockPush }), - usePathname: () => '/datasets/test-id/documents', - useSearchParams: () => mockSearchParams, -})) - -describe('useDocumentListQueryState', () => { - beforeEach(() => { - vi.clearAllMocks() - // Reset mock search params to empty - for (const key of [...mockSearchParams.keys()]) - mockSearchParams.delete(key) - }) - - // Tests for parseParams (exposed via the query property) - describe('parseParams (via query)', () => { - it('should return default query when no search params present', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query).toEqual({ - page: 1, - limit: 10, - keyword: '', - status: 'all', - sort: '-created_at', - }) - }) - - it('should parse page from search params', () => { - mockSearchParams.set('page', '3') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.page).toBe(3) - }) - - it('should default page to 1 when page is zero', () => { - mockSearchParams.set('page', '0') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.page).toBe(1) - }) - - it('should default page to 1 when page is negative', () => { - mockSearchParams.set('page', '-5') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.page).toBe(1) - }) - - it('should default page to 1 when page is NaN', () => { - mockSearchParams.set('page', 'abc') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.page).toBe(1) - }) - - it('should parse limit from search params', () => { - mockSearchParams.set('limit', '50') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(50) - }) - - it('should default limit to 10 when limit is zero', () => { - mockSearchParams.set('limit', '0') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(10) - }) - - it('should default limit to 10 when limit exceeds 100', () => { - mockSearchParams.set('limit', '101') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(10) - }) - - it('should default limit to 10 when limit is negative', () => { - mockSearchParams.set('limit', '-1') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(10) - }) - - it('should accept limit at boundary 100', () => { - mockSearchParams.set('limit', '100') - - const 
{ result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(100) - }) - - it('should accept limit at boundary 1', () => { - mockSearchParams.set('limit', '1') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(1) - }) - - it('should parse and decode keyword from search params', () => { - mockSearchParams.set('keyword', encodeURIComponent('hello world')) - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.keyword).toBe('hello world') - }) - - it('should return empty keyword when not present', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.keyword).toBe('') - }) - - it('should sanitize status from search params', () => { - mockSearchParams.set('status', 'available') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.status).toBe('available') - }) - - it('should fallback status to all for unknown status', () => { - mockSearchParams.set('status', 'badvalue') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.status).toBe('all') - }) - - it('should resolve active status alias to available', () => { - mockSearchParams.set('status', 'active') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.status).toBe('available') - }) - - it('should parse valid sort value from search params', () => { - mockSearchParams.set('sort', 'hit_count') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.sort).toBe('hit_count') - }) - - it('should default sort to -created_at for invalid sort value', () => { - mockSearchParams.set('sort', 'invalid_sort') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.sort).toBe('-created_at') - }) - - it('should default sort to -created_at when not present', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.sort).toBe('-created_at') - }) - - it.each([ - '-created_at', - 'created_at', - '-hit_count', - 'hit_count', - ] as const)('should accept valid sort value %s', (sortValue) => { - mockSearchParams.set('sort', sortValue) - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.sort).toBe(sortValue) - }) - }) - - // Tests for updateQuery - describe('updateQuery', () => { - it('should call router.push with updated params when page is changed', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 3 }) - }) - - expect(mockPush).toHaveBeenCalledTimes(1) - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('page=3') - }) - - it('should call router.push with scroll false', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 2 }) - }) - - expect(mockPush).toHaveBeenCalledWith( - expect.any(String), - { scroll: false }, - ) - }) - - it('should set status in URL when status is not all', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ status: 'error' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('status=error') - }) - - it('should not 
set status in URL when status is all', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ status: 'all' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('status=') - }) - - it('should set sort in URL when sort is not the default -created_at', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ sort: 'hit_count' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('sort=hit_count') - }) - - it('should not set sort in URL when sort is default -created_at', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ sort: '-created_at' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('sort=') - }) - - it('should encode keyword in URL when keyword is provided', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ keyword: 'test query' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - // Source code applies encodeURIComponent before setting in URLSearchParams - expect(pushedUrl).toContain('keyword=') - const params = new URLSearchParams(pushedUrl.split('?')[1]) - // params.get decodes one layer, but the value was pre-encoded with encodeURIComponent - expect(decodeURIComponent(params.get('keyword')!)).toBe('test query') - }) - - it('should remove keyword from URL when keyword is empty', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ keyword: '' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('keyword=') - }) - - it('should sanitize invalid status to all and not include in URL', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ status: 'invalidstatus' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('status=') - }) - - it('should sanitize invalid sort to -created_at and not include in URL', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ sort: 'invalidsort' as DocumentListQuery['sort'] }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('sort=') - }) - - it('should omit page and limit when they are default and no keyword', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 1, limit: 10 }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('page=') - expect(pushedUrl).not.toContain('limit=') - }) - - it('should include page and limit when page is greater than 1', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 2 }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('page=2') - expect(pushedUrl).toContain('limit=10') - }) - - it('should include page and limit when limit is non-default', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ limit: 25 }) - }) - - const pushedUrl = 
mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('page=1') - expect(pushedUrl).toContain('limit=25') - }) - - it('should include page and limit when keyword is provided', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ keyword: 'search' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('page=1') - expect(pushedUrl).toContain('limit=10') - }) - - it('should use pathname prefix in pushed URL', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 2 }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toMatch(/^\/datasets\/test-id\/documents/) - }) - - it('should push path without query string when all values are defaults', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({}) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toBe('/datasets/test-id/documents') - }) - }) - - // Tests for resetQuery - describe('resetQuery', () => { - it('should push URL with default query params when called', () => { - mockSearchParams.set('page', '5') - mockSearchParams.set('status', 'error') - - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.resetQuery() - }) - - expect(mockPush).toHaveBeenCalledTimes(1) - const pushedUrl = mockPush.mock.calls[0][0] as string - // Default query has all defaults, so no params should be in the URL - expect(pushedUrl).toBe('/datasets/test-id/documents') - }) - - it('should call router.push with scroll false when resetting', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.resetQuery() - }) - - expect(mockPush).toHaveBeenCalledWith( - expect.any(String), - { scroll: false }, - ) - }) - }) - - // Tests for return value stability - describe('return value', () => { - it('should return query, updateQuery, and resetQuery', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current).toHaveProperty('query') - expect(result.current).toHaveProperty('updateQuery') - expect(result.current).toHaveProperty('resetQuery') - expect(typeof result.current.updateQuery).toBe('function') - expect(typeof result.current.resetQuery).toBe('function') - }) - }) -}) diff --git a/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.tsx b/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.tsx new file mode 100644 index 0000000000..5879e72782 --- /dev/null +++ b/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.tsx @@ -0,0 +1,426 @@ +import type { DocumentListQuery } from '../use-document-list-query-state' +import { act, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { renderHookWithNuqs } from '@/test/nuqs-testing' +import { useDocumentListQueryState } from '../use-document-list-query-state' + +vi.mock('@/models/datasets', () => ({ + DisplayStatusList: [ + 'queuing', + 'indexing', + 'paused', + 'error', + 'available', + 'enabled', + 'disabled', + 'archived', + ], +})) + +const renderWithAdapter = (searchParams = '') => { + return renderHookWithNuqs(() => useDocumentListQueryState(), { searchParams }) +} + +describe('useDocumentListQueryState', 
() => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('query parsing', () => { + it('should return default query when no search params present', () => { + const { result } = renderWithAdapter() + + expect(result.current.query).toEqual({ + page: 1, + limit: 10, + keyword: '', + status: 'all', + sort: '-created_at', + }) + }) + + it('should parse page from search params', () => { + const { result } = renderWithAdapter('?page=3') + expect(result.current.query.page).toBe(3) + }) + + it('should default page to 1 when page is zero', () => { + const { result } = renderWithAdapter('?page=0') + expect(result.current.query.page).toBe(1) + }) + + it('should default page to 1 when page is negative', () => { + const { result } = renderWithAdapter('?page=-5') + expect(result.current.query.page).toBe(1) + }) + + it('should default page to 1 when page is NaN', () => { + const { result } = renderWithAdapter('?page=abc') + expect(result.current.query.page).toBe(1) + }) + + it('should parse limit from search params', () => { + const { result } = renderWithAdapter('?limit=50') + expect(result.current.query.limit).toBe(50) + }) + + it('should default limit to 10 when limit is zero', () => { + const { result } = renderWithAdapter('?limit=0') + expect(result.current.query.limit).toBe(10) + }) + + it('should default limit to 10 when limit exceeds 100', () => { + const { result } = renderWithAdapter('?limit=101') + expect(result.current.query.limit).toBe(10) + }) + + it('should default limit to 10 when limit is negative', () => { + const { result } = renderWithAdapter('?limit=-1') + expect(result.current.query.limit).toBe(10) + }) + + it('should accept limit at boundary 100', () => { + const { result } = renderWithAdapter('?limit=100') + expect(result.current.query.limit).toBe(100) + }) + + it('should accept limit at boundary 1', () => { + const { result } = renderWithAdapter('?limit=1') + expect(result.current.query.limit).toBe(1) + }) + + it('should parse keyword from search params', () => { + const { result } = renderWithAdapter('?keyword=hello+world') + expect(result.current.query.keyword).toBe('hello world') + }) + + it('should preserve legacy double-encoded keyword text after URL decoding', () => { + const { result } = renderWithAdapter('?keyword=test%2520query') + expect(result.current.query.keyword).toBe('test%20query') + }) + + it('should return empty keyword when not present', () => { + const { result } = renderWithAdapter() + expect(result.current.query.keyword).toBe('') + }) + + it('should sanitize status from search params', () => { + const { result } = renderWithAdapter('?status=available') + expect(result.current.query.status).toBe('available') + }) + + it('should fallback status to all for unknown status', () => { + const { result } = renderWithAdapter('?status=badvalue') + expect(result.current.query.status).toBe('all') + }) + + it('should resolve active status alias to available', () => { + const { result } = renderWithAdapter('?status=active') + expect(result.current.query.status).toBe('available') + }) + + it('should parse valid sort value from search params', () => { + const { result } = renderWithAdapter('?sort=hit_count') + expect(result.current.query.sort).toBe('hit_count') + }) + + it('should default sort to -created_at for invalid sort value', () => { + const { result } = renderWithAdapter('?sort=invalid_sort') + expect(result.current.query.sort).toBe('-created_at') + }) + + it('should default sort to -created_at when not present', () => { + const { result } = 
renderWithAdapter() + expect(result.current.query.sort).toBe('-created_at') + }) + + it.each([ + '-created_at', + 'created_at', + '-hit_count', + 'hit_count', + ] as const)('should accept valid sort value %s', (sortValue) => { + const { result } = renderWithAdapter(`?sort=${sortValue}`) + expect(result.current.query.sort).toBe(sortValue) + }) + }) + + describe('updateQuery', () => { + it('should update page in state when page is changed', () => { + const { result } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: 3 }) + }) + + expect(result.current.query.page).toBe(3) + }) + + it('should sync page to URL with push history', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('page')).toBe('2') + expect(update.options.history).toBe('push') + }) + + it('should set status in URL when status is not all', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ status: 'error' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('status')).toBe('error') + }) + + it('should not set status in URL when status is all', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ status: 'all' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('status')).toBe(false) + }) + + it('should set sort in URL when sort is not the default -created_at', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ sort: 'hit_count' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('sort')).toBe('hit_count') + }) + + it('should not set sort in URL when sort is default -created_at', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ sort: '-created_at' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('sort')).toBe(false) + }) + + it('should set keyword in URL when keyword is provided', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ keyword: 'test query' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('keyword')).toBe('test query') + expect(update.options.history).toBe('replace') + }) + + it('should use replace history when keyword update also resets page', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?page=3') + + act(() => { + result.current.updateQuery({ keyword: 'hello', page: 1 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + 
expect(update.searchParams.get('keyword')).toBe('hello') + expect(update.searchParams.has('page')).toBe(false) + expect(update.options.history).toBe('replace') + }) + + it('should remove keyword from URL when keyword is empty', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?keyword=existing') + + act(() => { + result.current.updateQuery({ keyword: '' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('keyword')).toBe(false) + expect(update.options.history).toBe('replace') + }) + + it('should remove keyword from URL when keyword contains only whitespace', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?keyword=existing') + + act(() => { + result.current.updateQuery({ keyword: ' ' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('keyword')).toBe(false) + expect(result.current.query.keyword).toBe('') + }) + + it('should preserve literal percent-encoded-like keyword values', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ keyword: '%2F' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('keyword')).toBe('%2F') + expect(result.current.query.keyword).toBe('%2F') + }) + + it('should keep keyword text unchanged when updating query from legacy URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?keyword=test%2520query') + + act(() => { + result.current.updateQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + expect(result.current.query.keyword).toBe('test%20query') + }) + + it('should sanitize invalid status to all and not include in URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ status: 'invalidstatus' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('status')).toBe(false) + }) + + it('should sanitize invalid sort to -created_at and not include in URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ sort: 'invalidsort' as DocumentListQuery['sort'] }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('sort')).toBe(false) + }) + + it('should not include page in URL when page is default', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: 1 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('page')).toBe(false) + }) + + it('should include page in URL when page is greater than 1', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + 
expect(update.searchParams.get('page')).toBe('2') + }) + + it('should include limit in URL when limit is non-default', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ limit: 25 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('limit')).toBe('25') + }) + + it('should sanitize invalid page to default and omit page from URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: -1 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('page')).toBe(false) + expect(result.current.query.page).toBe(1) + }) + + it('should sanitize invalid limit to default and omit limit from URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ limit: 999 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('limit')).toBe(false) + expect(result.current.query.limit).toBe(10) + }) + }) + + describe('resetQuery', () => { + it('should reset all values to defaults', () => { + const { result } = renderWithAdapter('?page=5&status=error&sort=hit_count') + + act(() => { + result.current.resetQuery() + }) + + expect(result.current.query).toEqual({ + page: 1, + limit: 10, + keyword: '', + status: 'all', + sort: '-created_at', + }) + }) + + it('should clear all params from URL when called', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?page=5&status=error') + + act(() => { + result.current.resetQuery() + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('page')).toBe(false) + expect(update.searchParams.has('status')).toBe(false) + }) + }) + + describe('return value', () => { + it('should return query, updateQuery, and resetQuery', () => { + const { result } = renderWithAdapter() + + expect(result.current).toHaveProperty('query') + expect(result.current).toHaveProperty('updateQuery') + expect(result.current).toHaveProperty('resetQuery') + expect(typeof result.current.updateQuery).toBe('function') + expect(typeof result.current.resetQuery).toBe('function') + }) + }) +}) diff --git a/web/app/components/datasets/documents/hooks/__tests__/use-documents-page-state.spec.ts b/web/app/components/datasets/documents/hooks/__tests__/use-documents-page-state.spec.ts index 34911e9e9c..e0dbee6660 100644 --- a/web/app/components/datasets/documents/hooks/__tests__/use-documents-page-state.spec.ts +++ b/web/app/components/datasets/documents/hooks/__tests__/use-documents-page-state.spec.ts @@ -1,12 +1,10 @@ import type { DocumentListQuery } from '../use-document-list-query-state' -import type { DocumentListResponse } from '@/models/datasets' import { act, renderHook } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { useDocumentsPageState } from '../use-documents-page-state' const mockUpdateQuery = vi.fn() -const mockResetQuery = vi.fn() let mockQuery: DocumentListQuery = { page: 1, limit: 10, keyword: '', status: 'all', sort: '-created_at' } 
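The renderHookWithNuqs helper imported from '@/test/nuqs-testing' in the query-state spec above is not part of this patch. A minimal sketch of what such a helper can look like, built on the NuqsTestingAdapter that nuqs ships in 'nuqs/adapters/testing' (the repo's actual helper may differ): the adapter seeds the initial search params and reports every attempted URL write to the onUrlUpdate spy the assertions inspect.

    import type { ReactNode } from 'react'
    import { renderHook } from '@testing-library/react'
    import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
    import { vi } from 'vitest'

    export function renderHookWithNuqs<Result>(
      hook: () => Result,
      { searchParams = '' }: { searchParams?: string } = {},
    ) {
      // Each onUrlUpdate call receives { searchParams, queryString, options },
      // which is exactly what the tests above read back.
      const onUrlUpdate = vi.fn()
      const wrapper = ({ children }: { children: ReactNode }) => (
        <NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={onUrlUpdate}>
          {children}
        </NuqsTestingAdapter>
      )
      return { ...renderHook(hook, { wrapper }), onUrlUpdate }
    }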
vi.mock('@/models/datasets', () => ({ @@ -22,151 +20,70 @@ vi.mock('@/models/datasets', () => ({ ], })) -vi.mock('next/navigation', () => ({ - useRouter: () => ({ push: vi.fn() }), - usePathname: () => '/datasets/test-id/documents', - useSearchParams: () => new URLSearchParams(), -})) - -// Mock ahooks debounce utilities: required because tests capture the debounce -// callback reference to invoke it synchronously, bypassing real timer delays. -let capturedDebounceFnCallback: (() => void) | null = null - vi.mock('ahooks', () => ({ useDebounce: (value: unknown, _options?: { wait?: number }) => value, - useDebounceFn: (fn: () => void, _options?: { wait?: number }) => { - capturedDebounceFnCallback = fn - return { run: fn, cancel: vi.fn(), flush: vi.fn() } - }, })) -// Mock the dependent hook -vi.mock('../use-document-list-query-state', () => ({ - default: () => ({ - query: mockQuery, - updateQuery: mockUpdateQuery, - resetQuery: mockResetQuery, - }), -})) - -// Factory for creating DocumentListResponse test data -function createDocumentListResponse(overrides: Partial = {}): DocumentListResponse { +vi.mock('../use-document-list-query-state', async () => { + const React = await import('react') return { - data: [], - has_more: false, - total: 0, - page: 1, - limit: 10, - ...overrides, + useDocumentListQueryState: () => { + const [query, setQuery] = React.useState(mockQuery) + return { + query, + updateQuery: (updates: Partial) => { + mockUpdateQuery(updates) + setQuery(prev => ({ ...prev, ...updates })) + }, + } + }, } -} - -// Factory for creating a minimal document item -function createDocumentItem(overrides: Record = {}) { - return { - id: `doc-${Math.random().toString(36).slice(2, 8)}`, - name: 'test-doc.txt', - indexing_status: 'completed' as string, - display_status: 'available' as string, - enabled: true, - archived: false, - word_count: 100, - created_at: Date.now(), - updated_at: Date.now(), - created_from: 'web' as const, - created_by: 'user-1', - dataset_process_rule_id: 'rule-1', - doc_form: 'text_model' as const, - doc_language: 'en', - position: 1, - data_source_type: 'upload_file', - ...overrides, - } -} +}) describe('useDocumentsPageState', () => { beforeEach(() => { vi.clearAllMocks() - capturedDebounceFnCallback = null mockQuery = { page: 1, limit: 10, keyword: '', status: 'all', sort: '-created_at' } }) // Initial state verification describe('initial state', () => { - it('should return correct initial search state', () => { + it('should return correct initial query-derived state', () => { const { result } = renderHook(() => useDocumentsPageState()) expect(result.current.inputValue).toBe('') - expect(result.current.searchValue).toBe('') expect(result.current.debouncedSearchValue).toBe('') - }) - - it('should return correct initial filter and sort state', () => { - const { result } = renderHook(() => useDocumentsPageState()) - expect(result.current.statusFilterValue).toBe('all') expect(result.current.sortValue).toBe('-created_at') expect(result.current.normalizedStatusFilterValue).toBe('all') - }) - - it('should return correct initial pagination state', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - // page is query.page - 1 = 0 expect(result.current.currPage).toBe(0) expect(result.current.limit).toBe(10) - }) - - it('should return correct initial selection state', () => { - const { result } = renderHook(() => useDocumentsPageState()) - expect(result.current.selectedIds).toEqual([]) }) - it('should return correct initial polling state', () => { - 
const { result } = renderHook(() => useDocumentsPageState()) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should initialize from query when query has keyword', () => { - mockQuery = { ...mockQuery, keyword: 'initial search' } + it('should initialize from non-default query values', () => { + mockQuery = { + page: 3, + limit: 25, + keyword: 'initial', + status: 'enabled', + sort: 'hit_count', + } const { result } = renderHook(() => useDocumentsPageState()) - expect(result.current.inputValue).toBe('initial search') - expect(result.current.searchValue).toBe('initial search') - }) - - it('should initialize pagination from query with non-default page', () => { - mockQuery = { ...mockQuery, page: 3, limit: 25 } - - const { result } = renderHook(() => useDocumentsPageState()) - - expect(result.current.currPage).toBe(2) // page - 1 + expect(result.current.inputValue).toBe('initial') + expect(result.current.currPage).toBe(2) expect(result.current.limit).toBe(25) - }) - - it('should initialize status filter from query', () => { - mockQuery = { ...mockQuery, status: 'error' } - - const { result } = renderHook(() => useDocumentsPageState()) - - expect(result.current.statusFilterValue).toBe('error') - }) - - it('should initialize sort from query', () => { - mockQuery = { ...mockQuery, sort: 'hit_count' } - - const { result } = renderHook(() => useDocumentsPageState()) - + expect(result.current.statusFilterValue).toBe('enabled') + expect(result.current.normalizedStatusFilterValue).toBe('available') expect(result.current.sortValue).toBe('hit_count') }) }) // Handler behaviors describe('handleInputChange', () => { - it('should update input value when called', () => { + it('should update keyword and reset page', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -174,30 +91,59 @@ describe('useDocumentsPageState', () => { }) expect(result.current.inputValue).toBe('new value') + expect(mockUpdateQuery).toHaveBeenCalledWith({ keyword: 'new value', page: 1 }) }) - it('should trigger debounced search callback when called', () => { + it('should clear selected ids when keyword changes', () => { const { result } = renderHook(() => useDocumentsPageState()) - // First call sets inputValue and triggers the debounced fn act(() => { - result.current.handleInputChange('search term') + result.current.setSelectedIds(['doc-1']) + }) + expect(result.current.selectedIds).toEqual(['doc-1']) + + act(() => { + result.current.handleInputChange('keyword') }) - // The debounced fn captures inputValue from its render closure. - // After re-render with new inputValue, calling the captured callback again - // should reflect the updated state. 
+ expect(result.current.selectedIds).toEqual([]) + }) + + it('should keep selected ids when keyword is unchanged', () => { + mockQuery = { ...mockQuery, keyword: 'same' } + const { result } = renderHook(() => useDocumentsPageState()) + act(() => { - if (capturedDebounceFnCallback) - capturedDebounceFnCallback() + result.current.setSelectedIds(['doc-1']) }) - expect(result.current.searchValue).toBe('search term') + act(() => { + result.current.handleInputChange('same') + }) + + expect(result.current.selectedIds).toEqual(['doc-1']) + expect(mockUpdateQuery).toHaveBeenCalledWith({ keyword: 'same', page: 1 }) }) }) describe('handleStatusFilterChange', () => { - it('should update status filter value when called with valid status', () => { + it('should sanitize status, reset page, and clear selection', () => { + const { result } = renderHook(() => useDocumentsPageState()) + + act(() => { + result.current.setSelectedIds(['doc-1']) + }) + + act(() => { + result.current.handleStatusFilterChange('invalid') + }) + + expect(result.current.statusFilterValue).toBe('all') + expect(result.current.selectedIds).toEqual([]) + expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 }) + }) + + it('should update to valid status value', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -205,61 +151,23 @@ describe('useDocumentsPageState', () => { }) expect(result.current.statusFilterValue).toBe('error') - }) - - it('should reset page to 0 when status filter changes', () => { - mockQuery = { ...mockQuery, page: 3 } - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('error') - }) - - expect(result.current.currPage).toBe(0) - }) - - it('should call updateQuery with sanitized status and page 1', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('error') - }) - expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'error', page: 1 }) }) - - it('should sanitize invalid status to all', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('invalid') - }) - - expect(result.current.statusFilterValue).toBe('all') - expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 }) - }) }) describe('handleStatusFilterClear', () => { - it('should set status to all and reset page when status is not all', () => { + it('should reset status to all when status is not all', () => { + mockQuery = { ...mockQuery, status: 'error' } const { result } = renderHook(() => useDocumentsPageState()) - // First set a non-all status - act(() => { - result.current.handleStatusFilterChange('error') - }) - vi.clearAllMocks() - - // Then clear act(() => { result.current.handleStatusFilterClear() }) - expect(result.current.statusFilterValue).toBe('all') expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 }) }) - it('should not call updateQuery when status is already all', () => { + it('should do nothing when status is already all', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -271,7 +179,7 @@ describe('useDocumentsPageState', () => { }) describe('handleSortChange', () => { - it('should update sort value and call updateQuery when value changes', () => { + it('should update sort and reset page when sort changes', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -282,18 +190,7 @@ 
describe('useDocumentsPageState', () => { expect(mockUpdateQuery).toHaveBeenCalledWith({ sort: 'hit_count', page: 1 }) }) - it('should reset page to 0 when sort changes', () => { - mockQuery = { ...mockQuery, page: 5 } - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleSortChange('hit_count') - }) - - expect(result.current.currPage).toBe(0) - }) - - it('should not call updateQuery when sort value is same as current', () => { + it('should ignore sort update when value is unchanged', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -304,8 +201,8 @@ describe('useDocumentsPageState', () => { }) }) - describe('handlePageChange', () => { - it('should update current page and call updateQuery', () => { + describe('pagination handlers', () => { + it('should update page with one-based value', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -313,23 +210,10 @@ describe('useDocumentsPageState', () => { }) expect(result.current.currPage).toBe(2) - expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 }) // newPage + 1 + expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 }) }) - it('should handle page 0 (first page)', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handlePageChange(0) - }) - - expect(result.current.currPage).toBe(0) - expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 1 }) - }) - }) - - describe('handleLimitChange', () => { - it('should update limit, reset page to 0, and call updateQuery', () => { + it('should update limit and reset page', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -342,359 +226,29 @@ describe('useDocumentsPageState', () => { }) }) - // Selection state - describe('selection state', () => { - it('should update selectedIds via setSelectedIds', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.setSelectedIds(['doc-1', 'doc-2']) - }) - - expect(result.current.selectedIds).toEqual(['doc-1', 'doc-2']) - }) - }) - - // Polling state management - describe('updatePollingState', () => { - it('should not update timer when documentsRes is undefined', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(undefined) - }) - - // timerCanRun remains true (initial value) - expect(result.current.timerCanRun).toBe(true) - }) - - it('should not update timer when documentsRes.data is undefined', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState({ data: undefined } as unknown as DocumentListResponse) - }) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should set timerCanRun to false when all documents are completed and status filter is all', () => { - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'completed' }), - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 2, - }) - - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(false) - }) - - it('should set timerCanRun to true when some documents are not completed', () => { - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'completed' }), - 
createDocumentItem({ indexing_status: 'indexing' }), - ] as DocumentListResponse['data'], - total: 2, - }) - - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should count paused documents as completed for polling purposes', () => { - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'paused' }), - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 2, - }) - - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(response) - }) - - // All docs are "embedded" (completed, paused, error), so hasIncomplete = false - // statusFilter is 'all', so shouldForcePolling = false - expect(result.current.timerCanRun).toBe(false) - }) - - it('should count error documents as completed for polling purposes', () => { - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'error' }), - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 2, - }) - - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(false) - }) - - it('should force polling when status filter is a transient status (queuing)', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - // Set status filter to queuing - act(() => { - result.current.handleStatusFilterChange('queuing') - }) - - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 1, - }) - - act(() => { - result.current.updatePollingState(response) - }) - - // shouldForcePolling = true (queuing is transient), hasIncomplete = false - // timerCanRun = true || false = true - expect(result.current.timerCanRun).toBe(true) - }) - - it('should force polling when status filter is indexing', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('indexing') - }) - - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 1, - }) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should force polling when status filter is paused', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('paused') - }) - - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'paused' }), - ] as DocumentListResponse['data'], - total: 1, - }) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should not force polling when status filter is a non-transient status (error)', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('error') - }) - - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'error' }), - ] as DocumentListResponse['data'], - total: 1, - }) - - act(() => { - result.current.updatePollingState(response) - }) - - // 
shouldForcePolling = false (error is not transient), hasIncomplete = false (error is embedded) - expect(result.current.timerCanRun).toBe(false) - }) - - it('should set timerCanRun to true when data is empty and filter is transient', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('indexing') - }) - - const response = createDocumentListResponse({ data: [] as DocumentListResponse['data'], total: 0 }) - - act(() => { - result.current.updatePollingState(response) - }) - - // shouldForcePolling = true (indexing is transient), hasIncomplete = false (0 !== 0 is false) - expect(result.current.timerCanRun).toBe(true) - }) - }) - - // Page adjustment - describe('adjustPageForTotal', () => { - it('should not adjust page when documentsRes is undefined', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.adjustPageForTotal(undefined) - }) - - expect(result.current.currPage).toBe(0) - }) - - it('should not adjust page when currPage is within total pages', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - const response = createDocumentListResponse({ total: 20 }) - - act(() => { - result.current.adjustPageForTotal(response) - }) - - // currPage is 0, totalPages is 2, so no adjustment needed - expect(result.current.currPage).toBe(0) - }) - - it('should adjust page to last page when currPage exceeds total pages', () => { - mockQuery = { ...mockQuery, page: 6 } - const { result } = renderHook(() => useDocumentsPageState()) - - // currPage should be 5 (page - 1) - expect(result.current.currPage).toBe(5) - - const response = createDocumentListResponse({ total: 30 }) // 30/10 = 3 pages - - act(() => { - result.current.adjustPageForTotal(response) - }) - - // currPage (5) + 1 > totalPages (3), so adjust to totalPages - 1 = 2 - expect(result.current.currPage).toBe(2) - expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 }) // handlePageChange passes newPage + 1 - }) - - it('should adjust page to 0 when total is 0 and currPage > 0', () => { - mockQuery = { ...mockQuery, page: 3 } - const { result } = renderHook(() => useDocumentsPageState()) - - const response = createDocumentListResponse({ total: 0 }) - - act(() => { - result.current.adjustPageForTotal(response) - }) - - // totalPages = 0, so adjust to max(0 - 1, 0) = 0 - expect(result.current.currPage).toBe(0) - expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 1 }) - }) - - it('should not adjust page when currPage is 0 even if total is 0', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - const response = createDocumentListResponse({ total: 0 }) - - act(() => { - result.current.adjustPageForTotal(response) - }) - - // currPage is 0, condition is currPage > 0 so no adjustment - expect(mockUpdateQuery).not.toHaveBeenCalled() - }) - }) - - // Normalized status filter value - describe('normalizedStatusFilterValue', () => { - it('should return all for default status', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - expect(result.current.normalizedStatusFilterValue).toBe('all') - }) - - it('should normalize enabled to available', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('enabled') - }) - - expect(result.current.normalizedStatusFilterValue).toBe('available') - }) - - it('should return non-aliased status as-is', () => { - const { result } = renderHook(() => 
useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('error') - }) - - expect(result.current.normalizedStatusFilterValue).toBe('error') - }) - }) - // Return value shape describe('return value', () => { it('should return all expected properties', () => { const { result } = renderHook(() => useDocumentsPageState()) - // Search state expect(result.current).toHaveProperty('inputValue') - expect(result.current).toHaveProperty('searchValue') expect(result.current).toHaveProperty('debouncedSearchValue') expect(result.current).toHaveProperty('handleInputChange') - - // Filter & sort state expect(result.current).toHaveProperty('statusFilterValue') expect(result.current).toHaveProperty('sortValue') expect(result.current).toHaveProperty('normalizedStatusFilterValue') expect(result.current).toHaveProperty('handleStatusFilterChange') expect(result.current).toHaveProperty('handleStatusFilterClear') expect(result.current).toHaveProperty('handleSortChange') - - // Pagination state expect(result.current).toHaveProperty('currPage') expect(result.current).toHaveProperty('limit') expect(result.current).toHaveProperty('handlePageChange') expect(result.current).toHaveProperty('handleLimitChange') - - // Selection state expect(result.current).toHaveProperty('selectedIds') expect(result.current).toHaveProperty('setSelectedIds') - - // Polling state - expect(result.current).toHaveProperty('timerCanRun') - expect(result.current).toHaveProperty('updatePollingState') - expect(result.current).toHaveProperty('adjustPageForTotal') }) - it('should have function types for all handlers', () => { + it('should expose function handlers', () => { const { result } = renderHook(() => useDocumentsPageState()) expect(typeof result.current.handleInputChange).toBe('function') @@ -704,8 +258,6 @@ describe('useDocumentsPageState', () => { expect(typeof result.current.handlePageChange).toBe('function') expect(typeof result.current.handleLimitChange).toBe('function') expect(typeof result.current.setSelectedIds).toBe('function') - expect(typeof result.current.updatePollingState).toBe('function') - expect(typeof result.current.adjustPageForTotal).toBe('function') }) }) }) diff --git a/web/app/components/datasets/documents/hooks/use-document-list-query-state.ts b/web/app/components/datasets/documents/hooks/use-document-list-query-state.ts index 505f15efc0..60717d532c 100644 --- a/web/app/components/datasets/documents/hooks/use-document-list-query-state.ts +++ b/web/app/components/datasets/documents/hooks/use-document-list-query-state.ts @@ -1,6 +1,6 @@ -import type { ReadonlyURLSearchParams } from 'next/navigation' +import type { inferParserType } from 'nuqs' import type { SortType } from '@/service/datasets' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' +import { createParser, parseAsString, throttle, useQueryStates } from 'nuqs' import { useCallback, useMemo } from 'react' import { sanitizeStatusValue } from '../status-filter' @@ -13,99 +13,87 @@ const sanitizeSortValue = (value?: string | null): SortType => { return (ALLOWED_SORT_VALUES.includes(value as SortType) ? value : '-created_at') as SortType } -export type DocumentListQuery = { - page: number - limit: number - keyword: string - status: string - sort: SortType +const sanitizePageValue = (value: number): number => { + return Number.isInteger(value) && value > 0 ? 
value : 1 } -const DEFAULT_QUERY: DocumentListQuery = { - page: 1, - limit: 10, - keyword: '', - status: 'all', - sort: '-created_at', +const sanitizeLimitValue = (value: number): number => { + return Number.isInteger(value) && value > 0 && value <= 100 ? value : 10 } -// Parse the query parameters from the URL search string. -function parseParams(params: ReadonlyURLSearchParams): DocumentListQuery { - const page = Number.parseInt(params.get('page') || '1', 10) - const limit = Number.parseInt(params.get('limit') || '10', 10) - const keyword = params.get('keyword') || '' - const status = sanitizeStatusValue(params.get('status')) - const sort = sanitizeSortValue(params.get('sort')) +const parseAsPage = createParser({ + parse: (value) => { + const n = Number.parseInt(value, 10) + return Number.isNaN(n) || n <= 0 ? null : n + }, + serialize: value => value.toString(), +}).withDefault(1) - return { - page: page > 0 ? page : 1, - limit: (limit > 0 && limit <= 100) ? limit : 10, - keyword: keyword ? decodeURIComponent(keyword) : '', - status, - sort, - } +const parseAsLimit = createParser({ + parse: (value) => { + const n = Number.parseInt(value, 10) + return Number.isNaN(n) || n <= 0 || n > 100 ? null : n + }, + serialize: value => value.toString(), +}).withDefault(10) + +const parseAsDocStatus = createParser({ + parse: value => sanitizeStatusValue(value), + serialize: value => value, +}).withDefault('all') + +const parseAsDocSort = createParser({ + parse: value => sanitizeSortValue(value), + serialize: value => value, +}).withDefault('-created_at' as SortType) + +const parseAsKeyword = parseAsString.withDefault('') + +export const documentListParsers = { + page: parseAsPage, + limit: parseAsLimit, + keyword: parseAsKeyword, + status: parseAsDocStatus, + sort: parseAsDocSort, } -// Update the URL search string with the given query parameters. -function updateSearchParams(query: DocumentListQuery, searchParams: URLSearchParams) { - const { page, limit, keyword, status, sort } = query || {} +export type DocumentListQuery = inferParserType - const hasNonDefaultParams = (page && page > 1) || (limit && limit !== 10) || (keyword && keyword.trim()) +// Search input updates can be frequent; throttle URL writes to reduce history/api churn. 
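nuqs's throttle helper, combined with the limitUrlUpdates option used further down, decouples client state from URL writes: the hook's state updates on every call, while the URL and its history entry are flushed at most once per interval. A minimal standalone illustration of the same mechanism with useQueryState (the hook and key names here are illustrative):

    import { parseAsString, throttle, useQueryState } from 'nuqs'

    export function useThrottledKeyword() {
      const [keyword, setKeyword] = useQueryState('keyword', parseAsString.withDefault(''))

      const onInput = (value: string) => {
        // State (and the re-render) updates on every keystroke;
        // the URL write is batched to at most one per 300 ms.
        setKeyword(value, { history: 'replace', limitUrlUpdates: throttle(300) })
      }

      return { keyword, onInput }
    }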
+const KEYWORD_URL_UPDATE_THROTTLE = throttle(300) - if (hasNonDefaultParams) { - searchParams.set('page', (page || 1).toString()) - searchParams.set('limit', (limit || 10).toString()) - } - else { - searchParams.delete('page') - searchParams.delete('limit') - } +export function useDocumentListQueryState() { + const [query, setQuery] = useQueryStates(documentListParsers) - if (keyword && keyword.trim()) - searchParams.set('keyword', encodeURIComponent(keyword)) - else - searchParams.delete('keyword') - - const sanitizedStatus = sanitizeStatusValue(status) - if (sanitizedStatus && sanitizedStatus !== 'all') - searchParams.set('status', sanitizedStatus) - else - searchParams.delete('status') - - const sanitizedSort = sanitizeSortValue(sort) - if (sanitizedSort !== '-created_at') - searchParams.set('sort', sanitizedSort) - else - searchParams.delete('sort') -} - -function useDocumentListQueryState() { - const searchParams = useSearchParams() - const query = useMemo(() => parseParams(searchParams), [searchParams]) - - const router = useRouter() - const pathname = usePathname() - - // Helper function to update specific query parameters const updateQuery = useCallback((updates: Partial) => { - const newQuery = { ...query, ...updates } - newQuery.status = sanitizeStatusValue(newQuery.status) - newQuery.sort = sanitizeSortValue(newQuery.sort) - const params = new URLSearchParams() - updateSearchParams(newQuery, params) - const search = params.toString() - const queryString = search ? `?${search}` : '' - router.push(`${pathname}${queryString}`, { scroll: false }) - }, [query, router, pathname]) + const patch = { ...updates } + if ('page' in patch && patch.page !== undefined) + patch.page = sanitizePageValue(patch.page) + if ('limit' in patch && patch.limit !== undefined) + patch.limit = sanitizeLimitValue(patch.limit) + if ('status' in patch) + patch.status = sanitizeStatusValue(patch.status) + if ('sort' in patch) + patch.sort = sanitizeSortValue(patch.sort) + if ('keyword' in patch && typeof patch.keyword === 'string' && patch.keyword.trim() === '') + patch.keyword = '' + + // If keyword is part of this patch (even with page reset), treat it as a search update: + // use replace to avoid creating a history entry per input-driven change. + if ('keyword' in patch) { + setQuery(patch, { + history: 'replace', + limitUrlUpdates: patch.keyword === '' ? undefined : KEYWORD_URL_UPDATE_THROTTLE, + }) + return + } + + setQuery(patch, { history: 'push' }) + }, [setQuery]) - // Helper function to reset query to defaults const resetQuery = useCallback(() => { - const params = new URLSearchParams() - updateSearchParams(DEFAULT_QUERY, params) - const search = params.toString() - const queryString = search ? 
`?${search}` : '' - router.push(`${pathname}${queryString}`, { scroll: false }) - }, [router, pathname]) + setQuery(null, { history: 'replace' }) + }, [setQuery]) return useMemo(() => ({ query, @@ -113,5 +101,3 @@ function useDocumentListQueryState() { resetQuery, }), [query, updateQuery, resetQuery]) } - -export default useDocumentListQueryState diff --git a/web/app/components/datasets/documents/hooks/use-documents-page-state.ts b/web/app/components/datasets/documents/hooks/use-documents-page-state.ts index 4fb227f717..36b1e8c760 100644 --- a/web/app/components/datasets/documents/hooks/use-documents-page-state.ts +++ b/web/app/components/datasets/documents/hooks/use-documents-page-state.ts @@ -1,175 +1,63 @@ -import type { DocumentListResponse } from '@/models/datasets' import type { SortType } from '@/service/datasets' -import { useDebounce, useDebounceFn } from 'ahooks' -import { useCallback, useEffect, useMemo, useState } from 'react' +import { useDebounce } from 'ahooks' +import { useCallback, useState } from 'react' import { normalizeStatusForQuery, sanitizeStatusValue } from '../status-filter' -import useDocumentListQueryState from './use-document-list-query-state' +import { useDocumentListQueryState } from './use-document-list-query-state' -/** - * Custom hook to manage documents page state including: - * - Search state (input value, debounced search value) - * - Filter state (status filter, sort value) - * - Pagination state (current page, limit) - * - Selection state (selected document ids) - * - Polling state (timer control for auto-refresh) - */ export function useDocumentsPageState() { const { query, updateQuery } = useDocumentListQueryState() - // Search state - const [inputValue, setInputValue] = useState('') - const [searchValue, setSearchValue] = useState('') - const debouncedSearchValue = useDebounce(searchValue, { wait: 500 }) + const inputValue = query.keyword + const debouncedSearchValue = useDebounce(query.keyword, { wait: 500 }) - // Filter & sort state - const [statusFilterValue, setStatusFilterValue] = useState(() => sanitizeStatusValue(query.status)) - const [sortValue, setSortValue] = useState(query.sort) - const normalizedStatusFilterValue = useMemo( - () => normalizeStatusForQuery(statusFilterValue), - [statusFilterValue], - ) + const statusFilterValue = sanitizeStatusValue(query.status) + const sortValue = query.sort + const normalizedStatusFilterValue = normalizeStatusForQuery(statusFilterValue) - // Pagination state - const [currPage, setCurrPage] = useState(query.page - 1) - const [limit, setLimit] = useState(query.limit) + const currPage = query.page - 1 + const limit = query.limit - // Selection state const [selectedIds, setSelectedIds] = useState([]) - // Polling state - const [timerCanRun, setTimerCanRun] = useState(true) - - // Initialize search value from URL on mount - useEffect(() => { - if (query.keyword) { - setInputValue(query.keyword) - setSearchValue(query.keyword) - } - }, []) // Only run on mount - - // Sync local state with URL query changes - useEffect(() => { - setCurrPage(query.page - 1) - setLimit(query.limit) - if (query.keyword !== searchValue) { - setInputValue(query.keyword) - setSearchValue(query.keyword) - } - setStatusFilterValue((prev) => { - const nextValue = sanitizeStatusValue(query.status) - return prev === nextValue ? 
prev : nextValue - }) - setSortValue(query.sort) - }, [query]) - - // Update URL when search changes - useEffect(() => { - if (debouncedSearchValue !== query.keyword) { - setCurrPage(0) - updateQuery({ keyword: debouncedSearchValue, page: 1 }) - } - }, [debouncedSearchValue, query.keyword, updateQuery]) - - // Clear selection when search changes - useEffect(() => { - if (searchValue !== query.keyword) - setSelectedIds([]) - }, [searchValue, query.keyword]) - - // Clear selection when status filter changes - useEffect(() => { - setSelectedIds([]) - }, [normalizedStatusFilterValue]) - - // Page change handler const handlePageChange = useCallback((newPage: number) => { - setCurrPage(newPage) updateQuery({ page: newPage + 1 }) }, [updateQuery]) - // Limit change handler const handleLimitChange = useCallback((newLimit: number) => { - setLimit(newLimit) - setCurrPage(0) updateQuery({ limit: newLimit, page: 1 }) }, [updateQuery]) - // Debounced search handler - const { run: handleSearch } = useDebounceFn(() => { - setSearchValue(inputValue) - }, { wait: 500 }) - - // Input change handler const handleInputChange = useCallback((value: string) => { - setInputValue(value) - handleSearch() - }, [handleSearch]) + if (value !== query.keyword) + setSelectedIds([]) + updateQuery({ keyword: value, page: 1 }) + }, [query.keyword, updateQuery]) - // Status filter change handler const handleStatusFilterChange = useCallback((value: string) => { const selectedValue = sanitizeStatusValue(value) - setStatusFilterValue(selectedValue) - setCurrPage(0) + setSelectedIds([]) updateQuery({ status: selectedValue, page: 1 }) }, [updateQuery]) - // Status filter clear handler const handleStatusFilterClear = useCallback(() => { if (statusFilterValue === 'all') return - setStatusFilterValue('all') - setCurrPage(0) + setSelectedIds([]) updateQuery({ status: 'all', page: 1 }) }, [statusFilterValue, updateQuery]) - // Sort change handler const handleSortChange = useCallback((value: string) => { const next = value as SortType if (next === sortValue) return - setSortValue(next) - setCurrPage(0) updateQuery({ sort: next, page: 1 }) }, [sortValue, updateQuery]) - // Update polling state based on documents response - const updatePollingState = useCallback((documentsRes: DocumentListResponse | undefined) => { - if (!documentsRes?.data) - return - - let completedNum = 0 - documentsRes.data.forEach((documentItem) => { - const { indexing_status } = documentItem - const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error' - if (isEmbedded) - completedNum++ - }) - - const hasIncompleteDocuments = completedNum !== documentsRes.data.length - const transientStatuses = ['queuing', 'indexing', 'paused'] - const shouldForcePolling = normalizedStatusFilterValue === 'all' - ? false - : transientStatuses.includes(normalizedStatusFilterValue) - setTimerCanRun(shouldForcePolling || hasIncompleteDocuments) - }, [normalizedStatusFilterValue]) - - // Adjust page when total pages change - const adjustPageForTotal = useCallback((documentsRes: DocumentListResponse | undefined) => { - if (!documentsRes) - return - const totalPages = Math.ceil(documentsRes.total / limit) - if (currPage > 0 && currPage + 1 > totalPages) - handlePageChange(totalPages > 0 ? 
totalPages - 1 : 0) - }, [limit, currPage, handlePageChange]) - return { - // Search state inputValue, - searchValue, debouncedSearchValue, handleInputChange, - // Filter & sort state statusFilterValue, sortValue, normalizedStatusFilterValue, @@ -177,21 +65,12 @@ export function useDocumentsPageState() { handleStatusFilterClear, handleSortChange, - // Pagination state currPage, limit, handlePageChange, handleLimitChange, - // Selection state selectedIds, setSelectedIds, - - // Polling state - timerCanRun, - updatePollingState, - adjustPageForTotal, } } - -export default useDocumentsPageState diff --git a/web/app/components/datasets/documents/index.tsx b/web/app/components/datasets/documents/index.tsx index 676e715f56..764b04227c 100644 --- a/web/app/components/datasets/documents/index.tsx +++ b/web/app/components/datasets/documents/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import { useRouter } from 'next/navigation' -import { useCallback, useEffect } from 'react' +import { useCallback } from 'react' import Loading from '@/app/components/base/loading' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' @@ -13,12 +13,16 @@ import useEditDocumentMetadata from '../metadata/hooks/use-edit-dataset-metadata import DocumentsHeader from './components/documents-header' import EmptyElement from './components/empty-element' import List from './components/list' -import useDocumentsPageState from './hooks/use-documents-page-state' +import { useDocumentsPageState } from './hooks/use-documents-page-state' type IDocumentsProps = { datasetId: string } +const POLLING_INTERVAL = 2500 +const TERMINAL_INDEXING_STATUSES = new Set(['completed', 'paused', 'error']) +const FORCED_POLLING_STATUSES = new Set(['queuing', 'indexing', 'paused']) + const Documents: FC = ({ datasetId }) => { const router = useRouter() const { plan } = useProviderContext() @@ -44,9 +48,6 @@ const Documents: FC = ({ datasetId }) => { handleLimitChange, selectedIds, setSelectedIds, - timerCanRun, - updatePollingState, - adjustPageForTotal, } = useDocumentsPageState() // Fetch document list @@ -59,19 +60,17 @@ const Documents: FC = ({ datasetId }) => { status: normalizedStatusFilterValue, sort: sortValue, }, - refetchInterval: timerCanRun ? 2500 : 0, + refetchInterval: (query) => { + const shouldForcePolling = normalizedStatusFilterValue !== 'all' + && FORCED_POLLING_STATUSES.has(normalizedStatusFilterValue) + const documents = query.state.data?.data + if (!documents) + return POLLING_INTERVAL + const hasIncompleteDocuments = documents.some(({ indexing_status }) => !TERMINAL_INDEXING_STATUSES.has(indexing_status)) + return shouldForcePolling || hasIncompleteDocuments ? 
POLLING_INTERVAL : false + }, }) - // Update polling state when documents change - useEffect(() => { - updatePollingState(documentsRes) - }, [documentsRes, updatePollingState]) - - // Adjust page when total changes - useEffect(() => { - adjustPageForTotal(documentsRes) - }, [documentsRes, adjustPageForTotal]) - // Invalidation hooks const invalidDocumentList = useInvalidDocumentList(datasetId) const invalidDocumentDetail = useInvalidDocumentDetail() @@ -119,7 +118,7 @@ const Documents: FC = ({ datasetId }) => { // Render content based on loading and data state const renderContent = () => { - if (isListLoading) + if (isListLoading && !documentsRes) return if (total > 0) { @@ -131,8 +130,8 @@ const Documents: FC = ({ datasetId }) => { onUpdate={handleUpdate} selectedIds={selectedIds} onSelectedIdChange={setSelectedIds} - statusFilterValue={normalizedStatusFilterValue} remoteSortValue={sortValue} + onSortChange={handleSortChange} pagination={{ total, limit, diff --git a/web/app/components/explore/app-list/__tests__/index.spec.tsx b/web/app/components/explore/app-list/__tests__/index.spec.tsx index 1a389e21ba..5d7dffd40a 100644 --- a/web/app/components/explore/app-list/__tests__/index.spec.tsx +++ b/web/app/components/explore/app-list/__tests__/index.spec.tsx @@ -1,12 +1,12 @@ import type { Mock } from 'vitest' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { App } from '@/models/explore' -import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { act, fireEvent, screen, waitFor } from '@testing-library/react' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { fetchAppDetail } from '@/service/explore' import { useMembers } from '@/service/use-common' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' import AppList from '../index' @@ -132,10 +132,9 @@ const mockMemberRole = (hasEditPermission: boolean) => { const renderAppList = (hasEditPermission = false, onSuccess?: () => void, searchParams?: Record) => { mockMemberRole(hasEditPermission) - return render( - - - , + return renderWithNuqs( + , + { searchParams }, ) } diff --git a/web/app/components/plugins/marketplace/__tests__/atoms.spec.tsx b/web/app/components/plugins/marketplace/__tests__/atoms.spec.tsx index 40fc47a9d2..2fae0f90d6 100644 --- a/web/app/components/plugins/marketplace/__tests__/atoms.spec.tsx +++ b/web/app/components/plugins/marketplace/__tests__/atoms.spec.tsx @@ -1,21 +1,20 @@ -import type { UrlUpdateEvent } from 'nuqs/adapters/testing' import type { ReactNode } from 'react' import { act, renderHook } from '@testing-library/react' import { Provider as JotaiProvider } from 'jotai' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { createNuqsTestWrapper } from '@/test/nuqs-testing' import { DEFAULT_SORT } from '../constants' const createWrapper = (searchParams = '') => { - const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>() + const { wrapper: NuqsWrapper } = createNuqsTestWrapper({ searchParams }) const wrapper = ({ children }: { children: ReactNode }) => ( - + {children} - + ) - return { wrapper, onUrlUpdate } + return { wrapper } } describe('Marketplace sort atoms', () => { diff --git a/web/app/components/plugins/marketplace/__tests__/plugin-type-switch.spec.tsx 
b/web/app/components/plugins/marketplace/__tests__/plugin-type-switch.spec.tsx index 6bb075410e..3e5e6a5e0a 100644 --- a/web/app/components/plugins/marketplace/__tests__/plugin-type-switch.spec.tsx +++ b/web/app/components/plugins/marketplace/__tests__/plugin-type-switch.spec.tsx @@ -1,9 +1,8 @@ -import type { UrlUpdateEvent } from 'nuqs/adapters/testing' import type { ReactNode } from 'react' import { fireEvent, render, screen } from '@testing-library/react' import { Provider as JotaiProvider } from 'jotai' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { createNuqsTestWrapper } from '@/test/nuqs-testing' import PluginTypeSwitch from '../plugin-type-switch' vi.mock('#i18n', () => ({ @@ -25,15 +24,15 @@ vi.mock('#i18n', () => ({ })) const createWrapper = (searchParams = '') => { - const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>() + const { wrapper: NuqsWrapper } = createNuqsTestWrapper({ searchParams }) const Wrapper = ({ children }: { children: ReactNode }) => ( - + {children} - + ) - return { Wrapper, onUrlUpdate } + return { Wrapper } } describe('PluginTypeSwitch', () => { diff --git a/web/app/components/plugins/marketplace/__tests__/state.spec.tsx b/web/app/components/plugins/marketplace/__tests__/state.spec.tsx index 4177c9b2b7..f4330f31b4 100644 --- a/web/app/components/plugins/marketplace/__tests__/state.spec.tsx +++ b/web/app/components/plugins/marketplace/__tests__/state.spec.tsx @@ -2,8 +2,8 @@ import type { ReactNode } from 'react' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { renderHook, waitFor } from '@testing-library/react' import { Provider as JotaiProvider } from 'jotai' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { createNuqsTestWrapper } from '@/test/nuqs-testing' vi.mock('@/config', () => ({ API_PREFIX: '/api', @@ -37,6 +37,7 @@ vi.mock('@/service/client', () => ({ })) const createWrapper = (searchParams = '') => { + const { wrapper: NuqsWrapper } = createNuqsTestWrapper({ searchParams }) const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false, gcTime: 0 }, @@ -45,9 +46,9 @@ const createWrapper = (searchParams = '') => { const Wrapper = ({ children }: { children: ReactNode }) => ( - + {children} - + ) diff --git a/web/app/components/plugins/marketplace/__tests__/sticky-search-and-switch-wrapper.spec.tsx b/web/app/components/plugins/marketplace/__tests__/sticky-search-and-switch-wrapper.spec.tsx index 1311adb508..1876692dad 100644 --- a/web/app/components/plugins/marketplace/__tests__/sticky-search-and-switch-wrapper.spec.tsx +++ b/web/app/components/plugins/marketplace/__tests__/sticky-search-and-switch-wrapper.spec.tsx @@ -1,8 +1,8 @@ import type { ReactNode } from 'react' import { render } from '@testing-library/react' import { Provider as JotaiProvider } from 'jotai' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { createNuqsTestWrapper } from '@/test/nuqs-testing' import StickySearchAndSwitchWrapper from '../sticky-search-and-switch-wrapper' vi.mock('#i18n', () => ({ @@ -20,13 +20,17 @@ vi.mock('../search-box/search-box-wrapper', () => ({ default: () =>
SearchBoxWrapper
, })) -const Wrapper = ({ children }: { children: ReactNode }) => ( - - - {children} - - -) +const createWrapper = () => { + const { wrapper: NuqsWrapper } = createNuqsTestWrapper() + const Wrapper = ({ children }: { children: ReactNode }) => ( + + + {children} + + + ) + return { Wrapper } +} describe('StickySearchAndSwitchWrapper', () => { beforeEach(() => { @@ -34,6 +38,7 @@ describe('StickySearchAndSwitchWrapper', () => { }) it('should render SearchBoxWrapper and PluginTypeSwitch', () => { + const { Wrapper } = createWrapper() const { getByTestId } = render( , { wrapper: Wrapper }, @@ -44,6 +49,7 @@ describe('StickySearchAndSwitchWrapper', () => { }) it('should not apply sticky class when no pluginTypeSwitchClassName', () => { + const { Wrapper } = createWrapper() const { container } = render( , { wrapper: Wrapper }, @@ -55,6 +61,7 @@ describe('StickySearchAndSwitchWrapper', () => { }) it('should apply sticky class when pluginTypeSwitchClassName contains top-', () => { + const { Wrapper } = createWrapper() const { container } = render( , { wrapper: Wrapper }, @@ -67,6 +74,7 @@ describe('StickySearchAndSwitchWrapper', () => { }) it('should not apply sticky class when pluginTypeSwitchClassName does not contain top-', () => { + const { Wrapper } = createWrapper() const { container } = render( , { wrapper: Wrapper }, diff --git a/web/app/components/plugins/marketplace/hydration-server.tsx b/web/app/components/plugins/marketplace/hydration-server.tsx index 287bcbe8c0..2b1dea25cb 100644 --- a/web/app/components/plugins/marketplace/hydration-server.tsx +++ b/web/app/components/plugins/marketplace/hydration-server.tsx @@ -1,4 +1,5 @@ import type { SearchParams } from 'nuqs/server' +import type { MarketplaceSearchParams } from './search-params' import { dehydrate, HydrationBoundary } from '@tanstack/react-query' import { createLoader } from 'nuqs/server' import { getQueryClientServer } from '@/context/query-client-server' @@ -14,7 +15,7 @@ async function getDehydratedState(searchParams?: Promise) { return } const loadSearchParams = createLoader(marketplaceSearchParamsParsers) - const params = await loadSearchParams(searchParams) + const params: MarketplaceSearchParams = await loadSearchParams(searchParams) if (!PLUGIN_CATEGORY_WITH_COLLECTIONS.has(params.category)) { return diff --git a/web/app/components/plugins/marketplace/search-params.ts b/web/app/components/plugins/marketplace/search-params.ts index ad0b16977f..a7da306045 100644 --- a/web/app/components/plugins/marketplace/search-params.ts +++ b/web/app/components/plugins/marketplace/search-params.ts @@ -1,3 +1,4 @@ +import type { inferParserType } from 'nuqs/server' import type { ActivePluginType } from './constants' import { parseAsArrayOf, parseAsString, parseAsStringEnum } from 'nuqs/server' import { PLUGIN_TYPE_SEARCH_MAP } from './constants' @@ -7,3 +8,5 @@ export const marketplaceSearchParamsParsers = { q: parseAsString.withDefault('').withOptions({ history: 'replace' }), tags: parseAsArrayOf(parseAsString).withDefault([]).withOptions({ history: 'replace' }), } + +export type MarketplaceSearchParams = inferParserType diff --git a/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx index 4dd23f53f1..ab0e35e042 100644 --- a/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx +++ b/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx @@ -7,6 +7,9 @@ import { PluginPageContext, PluginPageContextProvider, usePluginPageContext } fr 
// Mock dependencies vi.mock('nuqs', () => ({ + parseAsStringEnum: vi.fn(() => ({ + withDefault: vi.fn(() => ({})), + })), useQueryState: vi.fn(() => ['plugins', vi.fn()]), })) diff --git a/web/app/components/plugins/plugin-page/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/index.spec.tsx index be9f0b1858..dafcbe57c2 100644 --- a/web/app/components/plugins/plugin-page/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-page/__tests__/index.spec.tsx @@ -80,6 +80,9 @@ vi.mock('@/service/use-plugins', () => ({ })) vi.mock('nuqs', () => ({ + parseAsStringEnum: vi.fn(() => ({ + withDefault: vi.fn(() => ({})), + })), useQueryState: vi.fn(() => ['plugins', vi.fn()]), })) diff --git a/web/app/components/plugins/plugin-page/context.tsx b/web/app/components/plugins/plugin-page/context.tsx index abc4408d62..01ec518347 100644 --- a/web/app/components/plugins/plugin-page/context.tsx +++ b/web/app/components/plugins/plugin-page/context.tsx @@ -3,7 +3,7 @@ import type { ReactNode, RefObject } from 'react' import type { FilterState } from './filter-management' import { noop } from 'es-toolkit/function' -import { useQueryState } from 'nuqs' +import { parseAsStringEnum, useQueryState } from 'nuqs' import { useMemo, useRef, @@ -15,6 +15,19 @@ import { } from 'use-context-selector' import { useGlobalPublicStore } from '@/context/global-public-context' import { PLUGIN_PAGE_TABS_MAP, usePluginPageTabs } from '../hooks' +import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants' + +export type PluginPageTab = typeof PLUGIN_PAGE_TABS_MAP[keyof typeof PLUGIN_PAGE_TABS_MAP] + | (typeof PLUGIN_TYPE_SEARCH_MAP)[keyof typeof PLUGIN_TYPE_SEARCH_MAP] + +const PLUGIN_PAGE_TAB_VALUES: PluginPageTab[] = [ + PLUGIN_PAGE_TABS_MAP.plugins, + PLUGIN_PAGE_TABS_MAP.marketplace, + ...Object.values(PLUGIN_TYPE_SEARCH_MAP), +] + +const parseAsPluginPageTab = parseAsStringEnum(PLUGIN_PAGE_TAB_VALUES) + .withDefault(PLUGIN_PAGE_TABS_MAP.plugins) export type PluginPageContextValue = { containerRef: RefObject @@ -22,8 +35,8 @@ export type PluginPageContextValue = { setCurrentPluginID: (pluginID?: string) => void filters: FilterState setFilters: (filter: FilterState) => void - activeTab: string - setActiveTab: (tab: string) => void + activeTab: PluginPageTab + setActiveTab: (tab: PluginPageTab) => void options: Array<{ value: string, text: string }> } @@ -39,7 +52,7 @@ export const PluginPageContext = createContext({ searchQuery: '', }, setFilters: noop, - activeTab: '', + activeTab: PLUGIN_PAGE_TABS_MAP.plugins, setActiveTab: noop, options: [], }) @@ -68,9 +81,7 @@ export const PluginPageContextProvider = ({ const options = useMemo(() => { return enable_marketplace ? tabs : tabs.filter(tab => tab.value !== PLUGIN_PAGE_TABS_MAP.marketplace) }, [tabs, enable_marketplace]) - const [activeTab, setActiveTab] = useQueryState('tab', { - defaultValue: options[0].value, - }) + const [activeTab, setActiveTab] = useQueryState('tab', parseAsPluginPageTab) return ( ([ + PLUGIN_PAGE_TABS_MAP.plugins, + PLUGIN_PAGE_TABS_MAP.marketplace, + ...Object.values(PLUGIN_TYPE_SEARCH_MAP), +]) + +const isPluginPageTab = (value: string): value is PluginPageTab => { + return pluginPageTabSet.has(value) +} + export type PluginPageProps = { plugins: React.ReactNode marketplace: React.ReactNode @@ -154,7 +165,10 @@ const PluginPage = ({
{ + if (isPluginPageTab(nextTab)) + setActiveTab(nextTab) + }} options={options} />
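The plugin-page tab change above and the tool-provider category change below apply the same pattern: a literal-union parser guarantees whatever is read from the URL, and a type guard narrows the plain string a UI callback hands back before it is written. A minimal self-contained sketch of that pattern, using illustrative tab ids rather than the app's real values:

```tsx
import { parseAsStringLiteral, useQueryState } from 'nuqs'

// Hypothetical tab ids, for illustration only.
const TAB_VALUES = ['overview', 'settings'] as const
type Tab = typeof TAB_VALUES[number]
const tabSet = new Set<string>(TAB_VALUES)

const isTab = (value: string): value is Tab => tabSet.has(value)

export function useTabQueryState() {
  // The parser rejects unknown URL values, so `tab` is always a member of
  // TAB_VALUES (bad deep links fall back to the default).
  const [tab, setTab] = useQueryState(
    'tab',
    parseAsStringLiteral(TAB_VALUES).withDefault('overview'),
  )

  // UI components typically emit a plain string; narrow it before writing.
  const onTabChange = (next: string) => {
    if (isTab(next))
      setTab(next)
  }

  return { tab, onTabChange }
}
```

Keeping the guard next to the parser gives two complementary protections: invalid deep links degrade to the default, while invalid change events are ignored instead of polluting the URL.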
diff --git a/web/app/components/tools/__tests__/provider-list.spec.tsx b/web/app/components/tools/__tests__/provider-list.spec.tsx index 2c75c20979..a5cb4821bb 100644 --- a/web/app/components/tools/__tests__/provider-list.spec.tsx +++ b/web/app/components/tools/__tests__/provider-list.spec.tsx @@ -1,6 +1,6 @@ -import { cleanup, fireEvent, render, screen } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { cleanup, fireEvent, screen } from '@testing-library/react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { renderWithNuqs } from '@/test/nuqs-testing' import { ToolTypeEnum } from '../../workflow/block-selector/types' import ProviderList from '../provider-list' import { getToolType } from '../utils' @@ -206,10 +206,9 @@ describe('getToolType', () => { }) const renderProviderList = (searchParams?: Record) => { - return render( - - - , + return renderWithNuqs( + , + { searchParams }, ) } diff --git a/web/app/components/tools/provider-list.tsx b/web/app/components/tools/provider-list.tsx index ed6136f3c5..3501726cd0 100644 --- a/web/app/components/tools/provider-list.tsx +++ b/web/app/components/tools/provider-list.tsx @@ -1,6 +1,6 @@ 'use client' import type { Collection } from './types' -import { useQueryState } from 'nuqs' +import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Input from '@/app/components/base/input' @@ -23,6 +23,17 @@ import { useMarketplace } from './marketplace/hooks' import MCPList from './mcp' import { getToolType } from './utils' +const TOOL_PROVIDER_CATEGORY_VALUES = ['builtin', 'api', 'workflow', 'mcp'] as const +type ToolProviderCategory = typeof TOOL_PROVIDER_CATEGORY_VALUES[number] +const toolProviderCategorySet = new Set(TOOL_PROVIDER_CATEGORY_VALUES) + +const isToolProviderCategory = (value: string): value is ToolProviderCategory => { + return toolProviderCategorySet.has(value) +} + +const parseAsToolProviderCategory = parseAsStringLiteral(TOOL_PROVIDER_CATEGORY_VALUES) + .withDefault('builtin') + const ProviderList = () => { // const searchParams = useSearchParams() // searchParams.get('category') === 'workflow' @@ -31,9 +42,7 @@ const ProviderList = () => { const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures) const containerRef = useRef(null) - const [activeTab, setActiveTab] = useQueryState('category', { - defaultValue: 'builtin', - }) + const [activeTab, setActiveTab] = useQueryState('category', parseAsToolProviderCategory) const options = [ { value: 'builtin', text: t('type.builtIn', { ns: 'tools' }) }, { value: 'api', text: t('type.custom', { ns: 'tools' }) }, @@ -124,6 +133,8 @@ const ProviderList = () => { { + if (!isToolProviderCategory(state)) + return setActiveTab(state) if (state !== activeTab) setCurrentProviderId(undefined) diff --git a/web/context/modal-context.test.tsx b/web/context/modal-context.test.tsx index 2f2d09c6f0..a0f6ff35ec 100644 --- a/web/context/modal-context.test.tsx +++ b/web/context/modal-context.test.tsx @@ -1,9 +1,9 @@ -import { act, render, screen, waitFor } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { act, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { defaultPlan } from '@/app/components/billing/config' import { Plan } from '@/app/components/billing/type' import { ModalContextProvider } from 
'@/context/modal-context' +import { renderWithNuqs } from '@/test/nuqs-testing' vi.mock('@/config', async (importOriginal) => { const actual = await importOriginal() @@ -71,12 +71,10 @@ const createPlan = (overrides: PlanOverrides = {}): PlanShape => ({ }, }) -const renderProvider = () => render( - - -
- - , +const renderProvider = () => renderWithNuqs( + +
+ , ) describe('ModalContextProvider trigger events limit modal', () => { diff --git a/web/context/modal-context.tsx b/web/context/modal-context.tsx index 293970259a..ae1184e5bf 100644 --- a/web/context/modal-context.tsx +++ b/web/context/modal-context.tsx @@ -158,7 +158,7 @@ export const ModalContextProvider = ({ }: ModalContextProviderProps) => { // Use nuqs hooks for URL-based modal state management const [showPricingModal, setPricingModalOpen] = usePricingModal() - const [urlAccountModalState, setUrlAccountModalState] = useAccountSettingModal() + const [urlAccountModalState, setUrlAccountModalState] = useAccountSettingModal() const accountSettingCallbacksRef = useRef, 'payload'> | null>(null) const accountSettingTab = urlAccountModalState.isOpen diff --git a/web/docs/test.md b/web/docs/test.md index cac0e0e351..0204ce0c77 100644 --- a/web/docs/test.md +++ b/web/docs/test.md @@ -225,6 +225,38 @@ Simulate the interactions that matter to users—primary clicks, change events, Mock the specific Next.js navigation hooks your component consumes (`useRouter`, `usePathname`, `useSearchParams`) and drive realistic routing flows—query parameters, redirects, guarded routes, URL updates—while asserting the rendered outcome or navigation side effects. +#### 7.1 `nuqs` Query State Testing + +When testing code that uses `useQueryState` or `useQueryStates`, treat `nuqs` as the source of truth for URL synchronization. + +- ✅ In runtime, keep `NuqsAdapter` in app layout (already wired in `app/layout.tsx`). +- ✅ In tests, wrap with `NuqsTestingAdapter` (prefer helper utilities from `@/test/nuqs-testing`). +- ✅ Assert URL behavior via `onUrlUpdate` events (`searchParams`, `options.history`) instead of only asserting router mocks. +- ✅ For custom parsers created with `createParser`, keep `parse` and `serialize` bijective (round-trip safe). Add edge-case coverage for values like `%2F`, `%25`, spaces, and legacy encoded URLs. +- ✅ Assert default-clearing behavior explicitly (`clearOnDefault` semantics remove params when value equals default). +- ⚠️ Only mock `nuqs` directly when URL behavior is intentionally out of scope for the test. For ESM-safe partial mocks, use async `vi.mock` with `importOriginal`. + +Example: + +```tsx +import { renderHookWithNuqs } from '@/test/nuqs-testing' + +it('should update query with push history', async () => { + const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), { + searchParams: '?page=1', + }) + + act(() => { + result.current.setQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls.at(-1)![0] + expect(update.options.history).toBe('push') + expect(update.searchParams.get('page')).toBe('2') +}) +``` + ### 8. 
Edge Cases (REQUIRED - All Components) **Must Test**: diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 95f00e92b3..4b64291612 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -3404,11 +3404,6 @@ "count": 1 } }, - "app/components/datasets/documents/components/list.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "app/components/datasets/documents/components/operations.tsx": { "no-restricted-imports": { "count": 1 @@ -3853,16 +3848,6 @@ "count": 3 } }, - "app/components/datasets/documents/hooks/use-documents-page-state.ts": { - "react-hooks-extra/no-direct-set-state-in-use-effect": { - "count": 12 - } - }, - "app/components/datasets/documents/index.tsx": { - "react-hooks-extra/no-direct-set-state-in-use-effect": { - "count": 2 - } - }, "app/components/datasets/documents/status-item/index.tsx": { "no-restricted-imports": { "count": 1 diff --git a/web/hooks/use-query-params.spec.tsx b/web/hooks/use-query-params.spec.tsx index 35e234881d..ef744742cf 100644 --- a/web/hooks/use-query-params.spec.tsx +++ b/web/hooks/use-query-params.spec.tsx @@ -1,8 +1,6 @@ -import type { UrlUpdateEvent } from 'nuqs/adapters/testing' -import type { ReactNode } from 'react' -import { act, renderHook, waitFor } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { act, waitFor } from '@testing-library/react' import { ACCOUNT_SETTING_MODAL_ACTION } from '@/app/components/header/account-setting/constants' +import { renderHookWithNuqs } from '@/test/nuqs-testing' import { clearQueryParams, PRICING_MODAL_QUERY_PARAM, @@ -20,14 +18,7 @@ vi.mock('@/utils/client', () => ({ })) const renderWithAdapter = (hook: () => T, searchParams = '') => { - const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>() - const wrapper = ({ children }: { children: ReactNode }) => ( - - {children} - - ) - const { result } = renderHook(hook, { wrapper }) - return { result, onUrlUpdate } + return renderHookWithNuqs(hook, { searchParams }) } // Query param hooks: defaults, parsing, and URL sync behavior. 
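The round-trip rule for `createParser` called out in the testing guide is easiest to verify with a concrete example. A minimal sketch assuming only the public `nuqs` API; the `pageParser` name and bounds are illustrative, not part of the codebase:

```ts
import { createParser } from 'nuqs'

// parse and serialize are bijective over the values serialize can emit:
// parse(serialize(n)) === n for every positive integer n, and anything
// unparseable maps to null so the default applies instead.
const pageParser = createParser({
  parse: (value: string) => {
    const n = Number.parseInt(value, 10)
    return Number.isInteger(n) && n > 0 ? n : null
  },
  serialize: (value: number) => value.toString(),
}).withDefault(1)

// Round-trip check of the kind the guide recommends:
// expect(pageParser.parse(pageParser.serialize(42))).toBe(42)
```

Because `parse` returns `null` for garbage input, legacy or hand-edited URLs degrade to the default rather than leaking `NaN` into application state.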
diff --git a/web/hooks/use-query-params.ts b/web/hooks/use-query-params.ts index 0749a1ffa5..2de6cfcd2e 100644 --- a/web/hooks/use-query-params.ts +++ b/web/hooks/use-query-params.ts @@ -13,14 +13,19 @@ * - Use shallow routing to avoid unnecessary re-renders */ +import type { AccountSettingTab } from '@/app/components/header/account-setting/constants' import { createParser, - parseAsString, + parseAsStringEnum, + parseAsStringLiteral, useQueryState, useQueryStates, } from 'nuqs' import { useCallback } from 'react' -import { ACCOUNT_SETTING_MODAL_ACTION } from '@/app/components/header/account-setting/constants' +import { + ACCOUNT_SETTING_MODAL_ACTION, + ACCOUNT_SETTING_TAB, +} from '@/app/components/header/account-setting/constants' import { isServer } from '@/utils/client' /** @@ -52,6 +57,10 @@ export function usePricingModal() { ) } +const accountSettingTabValues = Object.values(ACCOUNT_SETTING_TAB) as AccountSettingTab[] +const parseAsAccountSettingAction = parseAsStringLiteral([ACCOUNT_SETTING_MODAL_ACTION] as const) +const parseAsAccountSettingTab = parseAsStringEnum(accountSettingTabValues) + /** * Hook to manage account setting modal state via URL * @returns [state, setState] - Object with isOpen + payload (tab) and setter @@ -61,11 +70,11 @@ export function usePricingModal() { * setAccountModalState({ payload: 'billing' }) // Sets ?action=showSettings&tab=billing * setAccountModalState(null) // Removes both params */ -export function useAccountSettingModal() { +export function useAccountSettingModal() { const [accountState, setAccountState] = useQueryStates( { - action: parseAsString, - tab: parseAsString, + action: parseAsAccountSettingAction, + tab: parseAsAccountSettingTab, }, { history: 'replace', @@ -73,7 +82,7 @@ export function useAccountSettingModal() { ) const setState = useCallback( - (state: { payload: T } | null) => { + (state: { payload: AccountSettingTab } | null) => { if (!state) { setAccountState({ action: null, tab: null }, { history: 'replace' }) return @@ -88,7 +97,7 @@ export function useAccountSettingModal() { ) const isOpen = accountState.action === ACCOUNT_SETTING_MODAL_ACTION - const currentTab = (isOpen ? accountState.tab : null) as T | null + const currentTab = isOpen ? 
accountState.tab : null

   return [{ isOpen, payload: currentTab }, setState] as const
 }
diff --git a/web/service/knowledge/use-document.ts b/web/service/knowledge/use-document.ts
index 74c9a77bcf..4eb2b7d282 100644
--- a/web/service/knowledge/use-document.ts
+++ b/web/service/knowledge/use-document.ts
@@ -1,7 +1,9 @@
+import type { UseQueryOptions } from '@tanstack/react-query'
 import type { DocumentDownloadResponse, DocumentDownloadZipRequest, MetadataType, SortType } from '../datasets'
 import type { CommonResponse } from '@/models/common'
 import type { DocumentDetailResponse, DocumentListResponse, UpdateDocumentBatchParams } from '@/models/datasets'
 import {
+  keepPreviousData,
   useMutation,
   useQuery,
 } from '@tanstack/react-query'
@@ -14,6 +16,8 @@ import { useInvalid } from '../use-base'
 const NAME_SPACE = 'knowledge/document'

 export const useDocumentListKey = [NAME_SPACE, 'documentList']

+type DocumentListRefetchInterval = UseQueryOptions<DocumentListResponse>['refetchInterval']
+
 export const useDocumentList = (payload: {
   datasetId: string
   query: {
@@ -23,7 +27,7 @@
     sort?: SortType
     status?: string
   }
-  refetchInterval?: number | false
+  refetchInterval?: DocumentListRefetchInterval
 }) => {
   const { query, datasetId, refetchInterval } = payload
   const { keyword, page, limit, sort, status } = query
@@ -42,6 +46,7 @@
     queryFn: () => get<DocumentListResponse>(`/datasets/${datasetId}/documents`, {
       params,
     }),
+    placeholderData: keepPreviousData,
     refetchInterval,
   })
 }
diff --git a/web/test/nuqs-testing.tsx b/web/test/nuqs-testing.tsx
new file mode 100644
index 0000000000..b5bf9fa83c
--- /dev/null
+++ b/web/test/nuqs-testing.tsx
@@ -0,0 +1,60 @@
+import type { UrlUpdateEvent } from 'nuqs/adapters/testing'
+import type { ComponentProps, ReactElement, ReactNode } from 'react'
+import type { Mock } from 'vitest'
+import { render, renderHook } from '@testing-library/react'
+import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
+import { vi } from 'vitest'
+
+type NuqsSearchParams = ComponentProps<typeof NuqsTestingAdapter>['searchParams']
+type NuqsOnUrlUpdate = (event: UrlUpdateEvent) => void
+type NuqsOnUrlUpdateSpy = Mock<NuqsOnUrlUpdate>
+
+type NuqsTestOptions = {
+  searchParams?: NuqsSearchParams
+  onUrlUpdate?: NuqsOnUrlUpdateSpy
+}
+
+type NuqsHookTestOptions<Props> = NuqsTestOptions & {
+  initialProps?: Props
+}
+
+type NuqsWrapperProps = {
+  children: ReactNode
+}
+
+export const createNuqsTestWrapper = (options: NuqsTestOptions = {}) => {
+  const { searchParams = '', onUrlUpdate } = options
+  const urlUpdateSpy = onUrlUpdate ?? vi.fn()
+  const wrapper = ({ children }: NuqsWrapperProps) => (
+    <NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={urlUpdateSpy}>
+      {children}
+    </NuqsTestingAdapter>
+  )
+
+  return {
+    wrapper,
+    onUrlUpdate: urlUpdateSpy,
+  }
+}
+
+export const renderWithNuqs = (ui: ReactElement, options: NuqsTestOptions = {}) => {
+  const { wrapper, onUrlUpdate } = createNuqsTestWrapper(options)
+  const rendered = render(ui, { wrapper })
+  return {
+    ...rendered,
+    onUrlUpdate,
+  }
+}
+
+export const renderHookWithNuqs = <Result, Props>(
+  callback: (props: Props) => Result,
+  options: NuqsHookTestOptions<Props> = {},
+) => {
+  const { initialProps, ...nuqsOptions } = options
+  const { wrapper, onUrlUpdate } = createNuqsTestWrapper(nuqsOptions)
+  const rendered = renderHook(callback, { wrapper, initialProps })
+  return {
+    ...rendered,
+    onUrlUpdate,
+  }
+}

From 65bf632ec0bddc14ce9242786075693eb126a4db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=A8=E4=B9=8B=E6=9C=AC=E6=BE=AA?=
Date: Tue, 3 Mar 2026 19:29:58 +0800
Subject: [PATCH 006/256] test: migrate
 test_dataset_service_batch_update_document_status SQL tests to testcontainers
 (#32537)

Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
 ...et_service_batch_update_document_status.py | 660 ++++++++++++++++
 ...et_service_batch_update_document_status.py | 704 +-----------------
 2 files changed, 662 insertions(+), 702 deletions(-)
 create mode 100644 api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py

diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py
new file mode 100644
index 0000000000..ffdb501474
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py
@@ -0,0 +1,660 @@
+"""Integration tests for DocumentService.batch_update_document_status.
+
+This suite validates SQL-backed batch status updates with testcontainers.
+It keeps database access real and only patches non-DB side effects.
+""" + +import datetime +import json +from dataclasses import dataclass +from unittest.mock import call, patch +from uuid import uuid4 + +import pytest + +from extensions.ext_database import db +from models.dataset import Dataset, Document +from services.dataset_service import DocumentService +from services.errors.document import DocumentIndexingError + +FIXED_TIME = datetime.datetime(2023, 1, 1, 12, 0, 0) + + +@dataclass +class UserDouble: + """Minimal user object for batch update operations.""" + + id: str + + +class DocumentBatchUpdateIntegrationDataFactory: + """Factory for creating persisted entities used in integration tests.""" + + @staticmethod + def create_dataset( + dataset_id: str | None = None, + tenant_id: str | None = None, + name: str = "Test Dataset", + created_by: str | None = None, + ) -> Dataset: + """Create and persist a dataset.""" + dataset = Dataset( + tenant_id=tenant_id or str(uuid4()), + name=name, + data_source_type="upload_file", + created_by=created_by or str(uuid4()), + ) + if dataset_id: + dataset.id = dataset_id + + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def create_document( + dataset: Dataset, + document_id: str | None = None, + name: str = "test_document.pdf", + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + completed_at: datetime.datetime | None = None, + position: int = 1, + created_by: str | None = None, + commit: bool = True, + **kwargs, + ) -> Document: + """Create a document bound to the given dataset and persist it.""" + document = Document( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type="upload_file", + data_source_info=json.dumps({"upload_file_id": str(uuid4())}), + batch=f"batch-{uuid4()}", + name=name, + created_from="web", + created_by=created_by or str(uuid4()), + doc_form="text_model", + ) + document.id = document_id or str(uuid4()) + document.enabled = enabled + document.archived = archived + document.indexing_status = indexing_status + document.completed_at = ( + completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None) + ) + + for key, value in kwargs.items(): + setattr(document, key, value) + + db.session.add(document) + if commit: + db.session.commit() + return document + + @staticmethod + def create_multiple_documents( + dataset: Dataset, + document_ids: list[str], + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + ) -> list[Document]: + """Create and persist multiple documents for one dataset in a single transaction.""" + documents: list[Document] = [] + for index, doc_id in enumerate(document_ids, start=1): + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + dataset=dataset, + document_id=doc_id, + name=f"document_{doc_id}.pdf", + enabled=enabled, + archived=archived, + indexing_status=indexing_status, + position=index, + commit=False, + ) + documents.append(document) + db.session.commit() + return documents + + @staticmethod + def create_user(user_id: str | None = None) -> UserDouble: + """Create a lightweight user for update metadata fields.""" + return UserDouble(id=user_id or str(uuid4())) + + +class TestDatasetServiceBatchUpdateDocumentStatus: + """Integration coverage for batch document status updates.""" + + @pytest.fixture + def patched_dependencies(self): + """Patch non-DB collaborators only.""" + with ( + patch("services.dataset_service.redis_client") as redis_client, + 
patch("services.dataset_service.add_document_to_index_task") as add_task, + patch("services.dataset_service.remove_document_from_index_task") as remove_task, + patch("services.dataset_service.naive_utc_now") as naive_utc_now, + ): + naive_utc_now.return_value = FIXED_TIME + redis_client.get.return_value = None + yield { + "redis_client": redis_client, + "add_task": add_task, + "remove_task": remove_task, + "naive_utc_now": naive_utc_now, + } + + def _assert_document_enabled(self, document: Document, current_time: datetime.datetime): + """Verify enabled-state fields after action=enable.""" + assert document.enabled is True + assert document.disabled_at is None + assert document.disabled_by is None + assert document.updated_at == current_time + + def _assert_document_disabled(self, document: Document, user_id: str, current_time: datetime.datetime): + """Verify disabled-state fields after action=disable.""" + assert document.enabled is False + assert document.disabled_at == current_time + assert document.disabled_by == user_id + assert document.updated_at == current_time + + def _assert_document_archived(self, document: Document, user_id: str, current_time: datetime.datetime): + """Verify archived-state fields after action=archive.""" + assert document.archived is True + assert document.archived_at == current_time + assert document.archived_by == user_id + assert document.updated_at == current_time + + def _assert_document_unarchived(self, document: Document): + """Verify unarchived-state fields after action=un_archive.""" + assert document.archived is False + assert document.archived_at is None + assert document.archived_by is None + + def test_batch_update_enable_documents_success(self, db_session_with_containers, patched_dependencies): + """Enable disabled documents and trigger indexing side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document_ids = [str(uuid4()), str(uuid4())] + disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + dataset=dataset, + document_ids=document_ids, + enabled=False, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=document_ids, action="enable", user=user + ) + + # Assert + for document in disabled_docs: + db.session.refresh(document) + self._assert_document_enabled(document, FIXED_TIME) + + expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] + expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + expected_add_calls = [call(doc_id) for doc_id in document_ids] + patched_dependencies["redis_client"].get.assert_has_calls(expected_get_calls) + patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls) + patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls) + + def test_batch_update_enable_already_enabled_document_skipped( + self, db_session_with_containers, patched_dependencies + ): + """Skip enable operation for already-enabled documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="enable", + user=user, + ) + + # Assert + 
db.session.refresh(document) + assert document.enabled is True + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_disable_documents_success(self, db_session_with_containers, patched_dependencies): + """Disable completed documents and trigger remove-index tasks.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document_ids = [str(uuid4()), str(uuid4())] + enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + dataset=dataset, + document_ids=document_ids, + enabled=True, + indexing_status="completed", + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="disable", + user=user, + ) + + # Assert + for document in enabled_docs: + db.session.refresh(document) + self._assert_document_disabled(document, user.id, FIXED_TIME) + + expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] + expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + expected_remove_calls = [call(doc_id) for doc_id in document_ids] + patched_dependencies["redis_client"].get.assert_has_calls(expected_get_calls) + patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls) + patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls) + + def test_batch_update_disable_already_disabled_document_skipped( + self, db_session_with_containers, patched_dependencies + ): + """Skip disable operation for already-disabled documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + dataset=dataset, + enabled=False, + indexing_status="completed", + completed_at=FIXED_TIME, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[disabled_doc.id], + action="disable", + user=user, + ) + + # Assert + db.session.refresh(disabled_doc) + assert disabled_doc.enabled is False + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_disable_non_completed_document_error(self, db_session_with_containers, patched_dependencies): + """Raise error when disabling a non-completed document.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + dataset=dataset, + enabled=True, + indexing_status="indexing", + completed_at=None, + ) + + # Act / Assert + with pytest.raises(DocumentIndexingError, match="is not completed"): + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[non_completed_doc.id], + action="disable", + user=user, + ) + + def test_batch_update_archive_documents_success(self, db_session_with_containers, patched_dependencies): + """Archive enabled documents and trigger remove-index task.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + 
dataset=dataset, enabled=True, archived=False + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db.session.refresh(document) + self._assert_document_archived(document, user.id, FIXED_TIME) + patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + patched_dependencies["remove_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_archive_already_archived_document_skipped( + self, db_session_with_containers, patched_dependencies + ): + """Skip archive operation for already-archived documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + dataset=dataset, enabled=True, archived=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db.session.refresh(document) + assert document.archived is True + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_archive_disabled_document_no_index_removal( + self, db_session_with_containers, patched_dependencies + ): + """Archive disabled document without index-removal side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + dataset=dataset, enabled=False, archived=False + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db.session.refresh(document) + self._assert_document_archived(document, user.id, FIXED_TIME) + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_unarchive_documents_success(self, db_session_with_containers, patched_dependencies): + """Unarchive enabled documents and trigger add-index task.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + dataset=dataset, enabled=True, archived=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="un_archive", + user=user, + ) + + # Assert + db.session.refresh(document) + self._assert_document_unarchived(document) + assert document.updated_at == FIXED_TIME + patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + patched_dependencies["add_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_unarchive_already_unarchived_document_skipped( + self, db_session_with_containers, patched_dependencies + ): + """Skip unarchive operation for already-unarchived documents.""" + # Arrange + dataset = 
DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+        document = DocumentBatchUpdateIntegrationDataFactory.create_document(
+            dataset=dataset, enabled=True, archived=False
+        )
+
+        # Act
+        DocumentService.batch_update_document_status(
+            dataset=dataset,
+            document_ids=[document.id],
+            action="un_archive",
+            user=user,
+        )
+
+        # Assert
+        db.session.refresh(document)
+        assert document.archived is False
+        patched_dependencies["redis_client"].setex.assert_not_called()
+        patched_dependencies["add_task"].delay.assert_not_called()
+
+    def test_batch_update_unarchive_disabled_document_no_index_addition(
+        self, db_session_with_containers, patched_dependencies
+    ):
+        """Unarchive disabled document without index-add side effects."""
+        # Arrange
+        dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+        document = DocumentBatchUpdateIntegrationDataFactory.create_document(
+            dataset=dataset, enabled=False, archived=True
+        )
+
+        # Act
+        DocumentService.batch_update_document_status(
+            dataset=dataset,
+            document_ids=[document.id],
+            action="un_archive",
+            user=user,
+        )
+
+        # Assert
+        db.session.refresh(document)
+        self._assert_document_unarchived(document)
+        assert document.updated_at == FIXED_TIME
+        patched_dependencies["redis_client"].setex.assert_not_called()
+        patched_dependencies["add_task"].delay.assert_not_called()
+
+    def test_batch_update_document_indexing_error_redis_cache_hit(
+        self, db_session_with_containers, patched_dependencies
+    ):
+        """Raise DocumentIndexingError when redis indicates active indexing."""
+        # Arrange
+        dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+        document = DocumentBatchUpdateIntegrationDataFactory.create_document(
+            dataset=dataset,
+            name="test_document.pdf",
+            enabled=True,
+        )
+        patched_dependencies["redis_client"].get.return_value = "indexing"
+
+        # Act / Assert
+        with pytest.raises(DocumentIndexingError, match="is being indexed") as exc_info:
+            DocumentService.batch_update_document_status(
+                dataset=dataset,
+                document_ids=[document.id],
+                action="enable",
+                user=user,
+            )
+
+        assert "test_document.pdf" in str(exc_info.value)
+        patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
+
+    def test_batch_update_async_task_error_handling(self, db_session_with_containers, patched_dependencies):
+        """Persist DB update, then propagate async task error."""
+        # Arrange
+        dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+        document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
+        patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error")
+
+        # Act / Assert
+        with pytest.raises(Exception, match="Celery task error"):
+            DocumentService.batch_update_document_status(
+                dataset=dataset,
+                document_ids=[document.id],
+                action="enable",
+                user=user,
+            )
+
+        db.session.refresh(document)
+        self._assert_document_enabled(document, FIXED_TIME)
+        patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1)
+
+    def test_batch_update_empty_document_list(self, db_session_with_containers, patched_dependencies):
+        """Return early when document_ids is empty."""
+        # Arrange
+        dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+
+        # Act
+        result = DocumentService.batch_update_document_status(
+            dataset=dataset, document_ids=[], action="enable", user=user
+        )
+
+        # Assert
+        assert result is None
+        patched_dependencies["redis_client"].get.assert_not_called()
+        patched_dependencies["redis_client"].setex.assert_not_called()
+
+    def test_batch_update_document_not_found_skipped(self, db_session_with_containers, patched_dependencies):
+        """Skip IDs that do not map to existing dataset documents."""
+        # Arrange
+        dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+        missing_document_id = str(uuid4())
+
+        # Act
+        DocumentService.batch_update_document_status(
+            dataset=dataset,
+            document_ids=[missing_document_id],
+            action="enable",
+            user=user,
+        )
+
+        # Assert
+        patched_dependencies["redis_client"].get.assert_not_called()
+        patched_dependencies["redis_client"].setex.assert_not_called()
+        patched_dependencies["add_task"].delay.assert_not_called()
+
+    def test_batch_update_mixed_document_states_and_actions(self, db_session_with_containers, patched_dependencies):
+        """Process only the applicable document in a mixed-state enable batch."""
+        # Arrange
+        dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+        disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
+        enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
+            dataset=dataset,
+            enabled=True,
+            position=2,
+        )
+        archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
+            dataset=dataset,
+            enabled=True,
+            archived=True,
+            position=3,
+        )
+        document_ids = [disabled_doc.id, enabled_doc.id, archived_doc.id]
+
+        # Act
+        DocumentService.batch_update_document_status(
+            dataset=dataset,
+            document_ids=document_ids,
+            action="enable",
+            user=user,
+        )
+
+        # Assert
+        db.session.refresh(disabled_doc)
+        db.session.refresh(enabled_doc)
+        db.session.refresh(archived_doc)
+        self._assert_document_enabled(disabled_doc, FIXED_TIME)
+        assert enabled_doc.enabled is True
+        assert archived_doc.enabled is True
+
+        patched_dependencies["redis_client"].setex.assert_called_once_with(
+            f"document_{disabled_doc.id}_indexing",
+            600,
+            1,
+        )
+        patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id)
+
+    def test_batch_update_large_document_list_performance(self, db_session_with_containers, patched_dependencies):
+        """Handle large document lists with consistent updates and side effects."""
+        # Arrange
+        dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+        document_ids = [str(uuid4()) for _ in range(100)]
+        documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
+            dataset=dataset,
+            document_ids=document_ids,
+            enabled=False,
+        )
+
+        # Act
+        DocumentService.batch_update_document_status(
+            dataset=dataset,
+            document_ids=document_ids,
+            action="enable",
+            user=user,
+        )
+
+        # Assert
+        for document in documents:
+            db.session.refresh(document)
+            self._assert_document_enabled(document, FIXED_TIME)
+
+        assert patched_dependencies["redis_client"].setex.call_count == len(document_ids)
+        assert patched_dependencies["add_task"].delay.call_count == len(document_ids)
+
+        expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids]
+        expected_task_calls = [call(doc_id) for doc_id in document_ids]
+        patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls)
+        patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)
+
+    def test_batch_update_mixed_document_states_complex_scenario(
+        self, db_session_with_containers, patched_dependencies
+    ):
+        """Process a complex mixed-state batch and update only eligible records."""
+        # Arrange
+        dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
+        user = DocumentBatchUpdateIntegrationDataFactory.create_user()
+        doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
+        doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=2)
+        doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=3)
+        doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=4)
+        doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document(
+            dataset=dataset,
+            enabled=True,
+            archived=True,
+            position=5,
+        )
+        missing_id = str(uuid4())
+
+        document_ids = [doc1.id, doc2.id, doc3.id, doc4.id, doc5.id, missing_id]
+
+        # Act
+        DocumentService.batch_update_document_status(
+            dataset=dataset,
+            document_ids=document_ids,
+            action="enable",
+            user=user,
+        )
+
+        # Assert
+        db.session.refresh(doc1)
+        db.session.refresh(doc2)
+        db.session.refresh(doc3)
+        db.session.refresh(doc4)
+        db.session.refresh(doc5)
+        self._assert_document_enabled(doc1, FIXED_TIME)
+        assert doc2.enabled is True
+        assert doc3.enabled is True
+        assert doc4.enabled is True
+        assert doc5.enabled is True
+
+        patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1)
+        patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id)
diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py
index 69766188f3..abff48347e 100644
--- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py
+++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py
@@ -1,13 +1,10 @@
 import datetime
-
-# Mock redis_client before importing dataset_service
-from unittest.mock import Mock, call, patch
+from unittest.mock import Mock, patch
 
 import pytest
 
 from models.dataset import Dataset, Document
 from services.dataset_service import DocumentService
-from services.errors.document import DocumentIndexingError
 from tests.unit_tests.conftest import redis_mock
 
 
@@ -48,7 +45,6 @@ class DocumentBatchUpdateTestDataFactory:
         document.indexing_status = indexing_status
         document.completed_at = completed_at or datetime.datetime.now()
 
-        # Set default values for optional fields
         document.disabled_at = None
         document.disabled_by = None
         document.archived_at = None
@@ -59,32 +55,9 @@ class DocumentBatchUpdateTestDataFactory:
             setattr(document, key, value)
         return document
 
-    @staticmethod
-    def create_multiple_documents(
-        document_ids: list[str], enabled: bool = True, archived: bool = False, indexing_status: str = "completed"
-    ) -> list[Mock]:
-        """Create multiple mock documents with specified attributes."""
-        documents = []
-        for doc_id in document_ids:
-            doc = DocumentBatchUpdateTestDataFactory.create_document_mock(
-                document_id=doc_id,
name=f"document_{doc_id}.pdf", - enabled=enabled, - archived=archived, - indexing_status=indexing_status, - ) - documents.append(doc) - return documents - class TestDatasetServiceBatchUpdateDocumentStatus: - """ - Comprehensive unit tests for DocumentService.batch_update_document_status method. - - This test suite covers all supported actions (enable, disable, archive, un_archive), - error conditions, edge cases, and validates proper interaction with Redis cache, - database operations, and async task triggers. - """ + """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" @pytest.fixture def mock_document_service_dependencies(self): @@ -104,697 +77,24 @@ class TestDatasetServiceBatchUpdateDocumentStatus: "current_time": current_time, } - @pytest.fixture - def mock_async_task_dependencies(self): - """Mock setup for async task dependencies.""" - with ( - patch("services.dataset_service.add_document_to_index_task") as mock_add_task, - patch("services.dataset_service.remove_document_from_index_task") as mock_remove_task, - ): - yield {"add_task": mock_add_task, "remove_task": mock_remove_task} - - def _assert_document_enabled(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was enabled correctly.""" - assert document.enabled == True - assert document.disabled_at is None - assert document.disabled_by is None - assert document.updated_at == current_time - - def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was disabled correctly.""" - assert document.enabled == False - assert document.disabled_at == current_time - assert document.disabled_by == user_id - assert document.updated_at == current_time - - def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was archived correctly.""" - assert document.archived == True - assert document.archived_at == current_time - assert document.archived_by == user_id - assert document.updated_at == current_time - - def _assert_document_unarchived(self, document: Mock): - """Helper method to verify document was unarchived correctly.""" - assert document.archived == False - assert document.archived_at is None - assert document.archived_by is None - - def _assert_redis_cache_operations(self, document_ids: list[str], action: str = "setex"): - """Helper method to verify Redis cache operations.""" - if action == "setex": - expected_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] - redis_mock.setex.assert_has_calls(expected_calls) - elif action == "get": - expected_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] - redis_mock.get.assert_has_calls(expected_calls) - - def _assert_async_task_calls(self, mock_task, document_ids: list[str], task_type: str): - """Helper method to verify async task calls.""" - expected_calls = [call(doc_id) for doc_id in document_ids] - if task_type in {"add", "remove"}: - mock_task.delay.assert_has_calls(expected_calls) - - # ==================== Enable Document Tests ==================== - - def test_batch_update_enable_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful enabling of disabled documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create disabled documents 
- disabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=False) - mock_document_service_dependencies["get_document"].side_effect = disabled_docs - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to enable documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2"], action="enable", user=user - ) - - # Verify document attributes were updated correctly - for doc in disabled_docs: - self._assert_document_enabled(doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache operations - self._assert_redis_cache_operations(["doc-1", "doc-2"], "get") - self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex") - - # Verify async tasks were triggered for indexing - self._assert_async_task_calls(mock_async_task_dependencies["add_task"], ["doc-1", "doc-2"], "add") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 2 - assert mock_db.commit.call_count == 1 - - def test_batch_update_enable_already_enabled_document_skipped(self, mock_document_service_dependencies): - """Test enabling documents that are already enabled.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already enabled document - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = enabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to enable already enabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # ==================== Disable Document Tests ==================== - - def test_batch_update_disable_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful disabling of enabled and completed documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create enabled documents - enabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=True) - mock_document_service_dependencies["get_document"].side_effect = enabled_docs - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to disable documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2"], action="disable", user=user - ) - - # Verify document attributes were updated correctly - for doc in enabled_docs: - self._assert_document_disabled(doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache operations for indexing prevention - self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex") - - # Verify async tasks were triggered to remove from index - self._assert_async_task_calls(mock_async_task_dependencies["remove_task"], ["doc-1", 
"doc-2"], "remove") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 2 - assert mock_db.commit.call_count == 1 - - def test_batch_update_disable_already_disabled_document_skipped( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test disabling documents that are already disabled.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already disabled document - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False) - mock_document_service_dependencies["get_document"].return_value = disabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to disable already disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="disable", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # Verify no async tasks were triggered (document was skipped) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - def test_batch_update_disable_non_completed_document_error(self, mock_document_service_dependencies): - """Test that DocumentIndexingError is raised when trying to disable non-completed documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create a document that's not completed - non_completed_doc = DocumentBatchUpdateTestDataFactory.create_document_mock( - enabled=True, - indexing_status="indexing", # Not completed - completed_at=None, # Not completed - ) - mock_document_service_dependencies["get_document"].return_value = non_completed_doc - - # Verify that DocumentIndexingError is raised - with pytest.raises(DocumentIndexingError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="disable", user=user - ) - - # Verify error message indicates document is not completed - assert "is not completed" in str(exc_info.value) - - # ==================== Archive Document Tests ==================== - - def test_batch_update_archive_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful archiving of unarchived documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create unarchived enabled document - unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False) - mock_document_service_dependencies["get_document"].return_value = unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to archive documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="archive", user=user - ) - - # Verify document attributes were updated correctly - self._assert_document_archived(unarchived_doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache was set (because document was enabled) - 
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify async task was triggered to remove from index (because enabled) - mock_async_task_dependencies["remove_task"].delay.assert_called_once_with("doc-1") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_batch_update_archive_already_archived_document_skipped(self, mock_document_service_dependencies): - """Test archiving documents that are already archived.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already archived document - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to archive already archived document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-3"], action="archive", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - def test_batch_update_archive_disabled_document_no_index_removal( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test archiving disabled documents (should not trigger index removal).""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Set up disabled, unarchived document - disabled_unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=False) - mock_document_service_dependencies["get_document"].return_value = disabled_unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Archive the disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="archive", user=user - ) - - # Verify document was archived - self._assert_document_archived( - disabled_unarchived_doc, user.id, mock_document_service_dependencies["current_time"] - ) - - # Verify no Redis cache was set (document is disabled) - redis_mock.setex.assert_not_called() - - # Verify no index removal task was triggered (document is disabled) - mock_async_task_dependencies["remove_task"].delay.assert_not_called() - - # Verify database operations still occurred - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # ==================== Unarchive Document Tests ==================== - - def test_batch_update_unarchive_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful unarchiving of archived documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock archived document - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_doc - - # 
Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to unarchive documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify document attributes were updated correctly - self._assert_document_unarchived(archived_doc) - assert archived_doc.updated_at == mock_document_service_dependencies["current_time"] - - # Verify Redis cache was set (because document is enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify async task was triggered to add back to index (because enabled) - mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_batch_update_unarchive_already_unarchived_document_skipped( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test unarchiving documents that are already unarchived.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already unarchived document - unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False) - mock_document_service_dependencies["get_document"].return_value = unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to unarchive already unarchived document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # Verify no async tasks were triggered (document was skipped) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - def test_batch_update_unarchive_disabled_document_no_index_addition( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test unarchiving disabled documents (should not trigger index addition).""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock archived but disabled document - archived_disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_disabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Unarchive the disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify document was unarchived - self._assert_document_unarchived(archived_disabled_doc) - assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"] - - # Verify no Redis cache was set (document is disabled) - redis_mock.setex.assert_not_called() - - # Verify no index addition task was triggered (document is disabled) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - # Verify database operations still occurred - 
mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # ==================== Error Handling Tests ==================== - - def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_document_service_dependencies): - """Test that DocumentIndexingError is raised when documents are currently being indexed.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock enabled document - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = enabled_doc - - # Set up mock to indicate document is being indexed - redis_mock.reset_mock() - redis_mock.get.return_value = "indexing" - - # Verify that DocumentIndexingError is raised - with pytest.raises(DocumentIndexingError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify error message contains document name - assert "test_document.pdf" in str(exc_info.value) - assert "is being indexed" in str(exc_info.value) - - # Verify Redis cache was checked - redis_mock.get.assert_called_once_with("document_doc-1_indexing") - def test_batch_update_invalid_action_error(self, mock_document_service_dependencies): """Test that ValueError is raised when an invalid action is provided.""" dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() user = DocumentBatchUpdateTestDataFactory.create_user_mock() - # Create mock document doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) mock_document_service_dependencies["get_document"].return_value = doc - # Reset module-level Redis mock redis_mock.reset_mock() redis_mock.get.return_value = None - # Test with invalid action invalid_action = "invalid_action" with pytest.raises(ValueError) as exc_info: DocumentService.batch_update_document_status( dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user ) - # Verify error message contains the invalid action assert invalid_action in str(exc_info.value) assert "Invalid action" in str(exc_info.value) - # Verify no Redis operations occurred redis_mock.setex.assert_not_called() - - def test_batch_update_async_task_error_handling( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test handling of async task errors during batch operations.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock disabled document - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False) - mock_document_service_dependencies["get_document"].return_value = disabled_doc - - # Mock async task to raise an exception - mock_async_task_dependencies["add_task"].delay.side_effect = Exception("Celery task error") - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Verify that async task error is propagated - with pytest.raises(Exception) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify error message - assert "Celery task error" in str(exc_info.value) - - # Verify database operations completed successfully - mock_db = mock_document_service_dependencies["db_session"] - 
mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # Verify Redis cache was set successfully - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify document was updated - self._assert_document_enabled(disabled_doc, user.id, mock_document_service_dependencies["current_time"]) - - # ==================== Edge Case Tests ==================== - - def test_batch_update_empty_document_list(self, mock_document_service_dependencies): - """Test batch operations with an empty document ID list.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Call method with empty document list - result = DocumentService.batch_update_document_status( - dataset=dataset, document_ids=[], action="enable", user=user - ) - - # Verify no document lookups were performed - mock_document_service_dependencies["get_document"].assert_not_called() - - # Verify method returns None (early return) - assert result is None - - def test_batch_update_document_not_found_skipped(self, mock_document_service_dependencies): - """Test behavior when some documents don't exist in the database.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Mock document service to return None (document not found) - mock_document_service_dependencies["get_document"].return_value = None - - # Call method with non-existent document ID - # This should not raise an error, just skip the missing document - try: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["non-existent-doc"], action="enable", user=user - ) - except Exception as e: - pytest.fail(f"Method should not raise exception for missing documents: {e}") - - # Verify document lookup was attempted - mock_document_service_dependencies["get_document"].assert_called_once_with(dataset.id, "non-existent-doc") - - def test_batch_update_mixed_document_states_and_actions( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test batch operations on documents with mixed states and various scenarios.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create documents in various states - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-2", enabled=True) - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-3", enabled=True, archived=True) - - # Mix of different document states - documents = [disabled_doc, enabled_doc, archived_doc] - mock_document_service_dependencies["get_document"].side_effect = documents - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform enable operation on mixed state documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=user - ) - - # Verify only the disabled document was processed - # (enabled and archived documents should be skipped for enable action) - - # Only one add should occur (for the disabled document that was enabled) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - # Only one commit should occur - mock_db.commit.assert_called_once() - - # Only one Redis setex 
should occur (for the document that was enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Only one async task should be triggered (for the document that was enabled) - mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1") - - # ==================== Performance Tests ==================== - - def test_batch_update_large_document_list_performance( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test batch operations with a large number of documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create large list of document IDs - document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents - - # Create mock documents - mock_documents = DocumentBatchUpdateTestDataFactory.create_multiple_documents( - document_ids, - enabled=False, # All disabled, will be enabled - ) - mock_document_service_dependencies["get_document"].side_effect = mock_documents - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform batch enable operation - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=document_ids, action="enable", user=user - ) - - # Verify all documents were processed - assert mock_document_service_dependencies["get_document"].call_count == 100 - - # Verify all documents were updated - for mock_doc in mock_documents: - self._assert_document_enabled(mock_doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 100 - assert mock_db.commit.call_count == 1 - - # Verify Redis cache operations occurred for each document - assert redis_mock.setex.call_count == 100 - - # Verify async tasks were triggered for each document - assert mock_async_task_dependencies["add_task"].delay.call_count == 100 - - # Verify correct Redis cache keys were set - expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)] - redis_mock.setex.assert_has_calls(expected_redis_calls) - - # Verify correct async task calls - expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)] - mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) - - def test_batch_update_mixed_document_states_complex_scenario( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test complex batch operations with documents in various states.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create documents in various states - doc1 = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) # Will be enabled - doc2 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-2", enabled=True - ) # Already enabled, will be skipped - doc3 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-3", enabled=True - ) # Already enabled, will be skipped - doc4 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-4", enabled=True - ) # Not affected by enable action - doc5 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-5", enabled=True, archived=True - ) # Not affected by enable action - doc6 = None # Non-existent, will be skipped - - mock_document_service_dependencies["get_document"].side_effect = 
-
-        # Reset module-level Redis mock
-        redis_mock.reset_mock()
-        redis_mock.get.return_value = None
-
-        # Perform mixed batch operations
-        DocumentService.batch_update_document_status(
-            dataset=dataset,
-            document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"],
-            action="enable",  # This will only affect doc1
-            user=user,
-        )
-
-        # Verify document 1 was enabled
-        self._assert_document_enabled(doc1, user.id, mock_document_service_dependencies["current_time"])
-
-        # Verify other documents were skipped appropriately
-        assert doc2.enabled == True  # No change
-        assert doc3.enabled == True  # No change
-        assert doc4.enabled == True  # No change
-        assert doc5.enabled == True  # No change
-
-        # Verify database commits occurred for processed documents
-        # Only doc1 should be added (others were skipped, doc6 doesn't exist)
-        mock_db = mock_document_service_dependencies["db_session"]
-        assert mock_db.add.call_count == 1
-        assert mock_db.commit.call_count == 1
-
-        # Verify Redis cache operations occurred for processed documents
-        # Only doc1 should have Redis operations
-        assert redis_mock.setex.call_count == 1
-
-        # Verify async tasks were triggered for processed documents
-        # Only doc1 should trigger tasks
-        assert mock_async_task_dependencies["add_task"].delay.call_count == 1
-
-        # Verify correct Redis cache keys were set
-        expected_redis_calls = [call("document_doc-1_indexing", 600, 1)]
-        redis_mock.setex.assert_has_calls(expected_redis_calls)
-
-        # Verify correct async task calls
-        expected_task_calls = [call("doc-1")]
-        mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)

From d7e399872d001e2ff43e754120356eab238c5644 Mon Sep 17 00:00:00 2001
From: Stephen Zhou
Date: Tue, 3 Mar 2026 21:42:47 +0800
Subject: [PATCH 007/256] fix: get i18n lazily, make vinext build work (#32917)

---
 .../goto-anything/actions/commands/slash.tsx  |  13 +-
 web/package.json                              |   6 +-
 web/pnpm-lock.yaml                            | 463 +++++++++++++++---
 web/vite.config.ts                            |  20 +
 4 files changed, 433 insertions(+), 69 deletions(-)

diff --git a/web/app/components/goto-anything/actions/commands/slash.tsx b/web/app/components/goto-anything/actions/commands/slash.tsx
index 6aad67731f..25d1aa59c5 100644
--- a/web/app/components/goto-anything/actions/commands/slash.tsx
+++ b/web/app/components/goto-anything/actions/commands/slash.tsx
@@ -14,13 +14,17 @@ import { slashCommandRegistry } from './registry'
 import { themeCommand } from './theme'
 import { zenCommand } from './zen'
 
-const i18n = getI18n()
-
 export const slashAction: ActionItem = {
   key: '/',
   shortcut: '/',
-  title: i18n.t('gotoAnything.actions.slashTitle', { ns: 'app' }),
-  description: i18n.t('gotoAnything.actions.slashDesc', { ns: 'app' }),
+  get title() {
+    const i18n = getI18n()
+    return i18n.t('gotoAnything.actions.slashTitle', { ns: 'app' })
+  },
+  get description() {
+    const i18n = getI18n()
+    return i18n.t('gotoAnything.actions.slashDesc', { ns: 'app' })
+  },
   action: (result) => {
     if (result.type !== 'command')
       return
@@ -28,6 +32,7 @@ export const slashAction: ActionItem = {
     executeCommand(command, args)
   },
   search: async (query, _searchTerm = '') => {
+    const i18n = getI18n()
     // Delegate all search logic to the command registry system
     return slashCommandRegistry.search(query, i18n.language)
   },
diff --git a/web/package.json b/web/package.json
index a120f5718d..e0df37836e 100644
--- a/web/package.json
+++ b/web/package.json
@@ -31,9 +31,9 @@
     "dev:vinext": "vinext dev",
     "build": "next build",
    "build:docker": "next build && node scripts/optimize-standalone.js",
&& node scripts/optimize-standalone.js", - "build:vinext": "vinext build", + "build:vinext": "cross-env NODE_ENV=production vinext build", "start": "node ./scripts/copy-and-start.mjs", - "start:vinext": "vinext start", + "start:vinext": "cross-env NODE_ENV=production vinext start", "lint": "eslint --cache --concurrency=auto", "lint:ci": "eslint --cache --concurrency 2", "lint:fix": "pnpm lint --fix", @@ -248,7 +248,7 @@ "typescript": "5.9.3", "uglify-js": "3.19.3", "vinext": "https://pkg.pr.new/hyoban/vinext@cfae669", - "vite": "7.3.1", + "vite": "8.0.0-beta.16", "vite-tsconfig-paths": "6.1.1", "vitest": "4.0.18", "vitest-canvas-mock": "1.1.3" diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 17c3fb7f06..81ecb7f857 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -375,7 +375,7 @@ importers: devDependencies: '@antfu/eslint-config': specifier: 7.6.1 - version: 7.6.1(@eslint-react/eslint-plugin@2.13.0(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.1.6)(@typescript-eslint/rule-tester@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.56.1(typescript@5.9.3))(@typescript-eslint/utils@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@vue/compiler-sfc@3.5.27)(eslint-plugin-react-hooks@7.0.1(eslint@10.0.2(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.0.2(jiti@1.21.7)))(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + version: 7.6.1(@eslint-react/eslint-plugin@2.13.0(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.1.6)(@typescript-eslint/rule-tester@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.56.1(typescript@5.9.3))(@typescript-eslint/utils@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@vue/compiler-sfc@3.5.27)(eslint-plugin-react-hooks@7.0.1(eslint@10.0.2(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.0.2(jiti@1.21.7)))(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) '@chromatic-com/storybook': specifier: 5.0.1 version: 5.0.1(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) @@ -414,7 +414,7 @@ importers: version: 9.5.4(@swc/helpers@0.5.18)(esbuild-wasm@0.27.2)(esbuild@0.27.2)(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react@19.2.4)(typescript@5.9.3) '@storybook/addon-docs': specifier: 10.2.13 - version: 10.2.13(@types/react@19.2.9)(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + version: 10.2.13(@types/react@19.2.9)(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/addon-links': specifier: 10.2.13 version: 10.2.13(react@19.2.4)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) @@ -426,7 +426,7 @@ importers: 
version: 10.2.13(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) '@storybook/nextjs-vite': specifier: 10.2.13 - version: 10.2.13(@babel/core@7.29.0)(esbuild@0.27.2)(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + version: 10.2.13(@babel/core@7.29.0)(esbuild@0.27.2)(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/react': specifier: 10.2.13 version: 10.2.13(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) @@ -513,13 +513,13 @@ importers: version: 7.0.0-dev.20251209.1 '@vitejs/plugin-react': specifier: 5.1.4 - version: 5.1.4(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + version: 5.1.4(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) '@vitejs/plugin-rsc': specifier: 0.5.21 - version: 0.5.21(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + version: 0.5.21(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) '@vitest/coverage-v8': specifier: 4.0.18 - version: 4.0.18(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + version: 4.0.18(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) autoprefixer: specifier: 10.4.21 version: 10.4.21(postcss@8.5.6) @@ -609,19 +609,19 @@ importers: version: 3.19.3 vinext: specifier: https://pkg.pr.new/hyoban/vinext@cfae669 - version: https://pkg.pr.new/hyoban/vinext@cfae669(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + version: 
https://pkg.pr.new/hyoban/vinext@cfae669(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) vite: - specifier: 7.3.1 - version: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + specifier: 8.0.0-beta.16 + version: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) vite-tsconfig-paths: specifier: 6.1.1 - version: 6.1.1(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + version: 6.1.1(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) vitest: specifier: 4.0.18 - version: 4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + version: 4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) vitest-canvas-mock: specifier: 1.1.3 - version: 1.1.3(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + version: 1.1.3(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) packages: @@ -2002,6 +2002,13 @@ packages: resolution: {integrity: sha512-XRO0zi2NIUKq2lUk3T1ecFSld1fMWRKE6naRFGkgkdeosx7IslyUKNv5Dcb5PJTja9tHJoFu0v/7yEpAkrkrTg==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} + '@oxc-project/runtime@0.115.0': + resolution: {integrity: sha512-Rg8Wlt5dCbXhQnsXPrkOjL1DTSvXLgb2R/KYfnf1/K+R0k6UMLEmbQXPM+kwrWqSmWA2t0B1EtHy2/3zikQpvQ==} + engines: {node: ^20.19.0 || >=22.12.0} + + '@oxc-project/types@0.115.0': + resolution: {integrity: sha512-4n91DKnebUS4yjUHl2g3/b2T+IUdCfmoZGhmwsovZCDaJSs+QkVAM+0AqqTxHSsHfeiMuueT75cZaZcT/m0pSw==} + '@oxc-resolver/binding-android-arm-eabi@11.16.4': resolution: {integrity: sha512-6XUHilmj8D6Ggus+sTBp64x/DUQ7LgC/dvTDdUOt4iMQnDdSep6N1mnvVLIiG+qM5tRnNHravNzBJnUlYwRQoA==} cpu: [arm] @@ -2479,12 +2486,92 @@ packages: resolution: {integrity: sha512-UuBOt7BOsKVOkFXRe4Ypd/lADuNIfqJXv8GvHqtXaTYXPPKkj2nS2zPllVsrtRjcomDhIJVBnZwfmlI222WH8g==} engines: {node: '>=14.0.0'} + '@rolldown/binding-android-arm64@1.0.0-rc.6': + resolution: {integrity: sha512-kvjTSWGcrv+BaR2vge57rsKiYdVR8V8CoS0vgKrc570qRBfty4bT+1X0z3j2TaVV+kAYzA0PjeB9+mdZyqUZlg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@rolldown/binding-darwin-arm64@1.0.0-rc.6': + resolution: {integrity: sha512-+tJhD21KvGNtUrpLXrZQlT+j5HZKiEwR2qtcZb3vNOUpvoT9QjEykr75ZW/Kr0W89gose/HVXU6351uVZD8Qvw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@rolldown/binding-darwin-x64@1.0.0-rc.6': + resolution: {integrity: sha512-DKNhjMk38FAWaHwUt1dFR3rA/qRAvn2NUvSG2UGvxvlMxSmN/qqww/j4ABAbXhNRXtGQNmrAINMXRuwHl16ZHg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@rolldown/binding-freebsd-x64@1.0.0-rc.6': + resolution: {integrity: sha512-8TThsRkCPAnfyMBShxrGdtoOE6h36QepqRQI97iFaQSCRbHFWHcDHppcojZnzXoruuhPnjMEygzaykvPVJsMRg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + 
'@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.6': + resolution: {integrity: sha512-ZfmFoOwPUZCWtGOVC9/qbQzfc0249FrRUOzV2XabSMUV60Crp211OWLQN1zmQAsRIVWRcEwhJ46Z1mXGo/L/nQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.6': + resolution: {integrity: sha512-ZsGzbNETxPodGlLTYHaCSGVhNN/rvkMDCJYHdT7PZr5jFJRmBfmDi2awhF64Dt2vxrJqY6VeeYSgOzEbHRsb7Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + + '@rolldown/binding-linux-arm64-musl@1.0.0-rc.6': + resolution: {integrity: sha512-elPpdevtCdUOqziemR86C4CSCr/5sUxalzDrf/CJdMT+kZt2C556as++qHikNOz0vuFf52h+GJNXZM08eWgGPQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + + '@rolldown/binding-linux-x64-gnu@1.0.0-rc.6': + resolution: {integrity: sha512-IBwXsf56o3xhzAyaZxdM1CX8UFiBEUFCjiVUgny67Q8vPIqkjzJj0YKhd3TbBHanuxThgBa59f6Pgutg2OGk5A==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + + '@rolldown/binding-linux-x64-musl@1.0.0-rc.6': + resolution: {integrity: sha512-vOk7G8V9Zm+8a6PL6JTpCea61q491oYlGtO6CvnsbhNLlKdf0bbCPytFzGQhYmCKZDKkEbmnkcIprTEGCURnwg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + + '@rolldown/binding-openharmony-arm64@1.0.0-rc.6': + resolution: {integrity: sha512-ASjEDI4MRv7XCQb2JVaBzfEYO98JKCGrAgoW6M03fJzH/ilCnC43Mb3ptB9q/lzsaahoJyIBoAGKAYEjUvpyvQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@rolldown/binding-wasm32-wasi@1.0.0-rc.6': + resolution: {integrity: sha512-mYa1+h2l6Zc0LvmwUh0oXKKYihnw/1WC73vTqw+IgtfEtv47A+rWzzcWwVDkW73+UDr0d/Ie/HRXoaOY22pQDw==} + engines: {node: '>=14.0.0'} + cpu: [wasm32] + + '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.6': + resolution: {integrity: sha512-e2ABskbNH3MRUBMjgxaMjYIw11DSwjLJxBII3UgpF6WClGLIh8A20kamc+FKH5vIaFVnYQInmcLYSUVpqMPLow==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@rolldown/binding-win32-x64-msvc@1.0.0-rc.6': + resolution: {integrity: sha512-dJVc3ifhaRXxIEh1xowLohzFrlQXkJ66LepHm+CmSprTWgVrPa8Fx3OL57xwIqDEH9hufcKkDX2v65rS3NZyRA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + '@rolldown/pluginutils@1.0.0-rc.3': resolution: {integrity: sha512-eybk3TjzzzV97Dlj5c+XrBFW57eTNhzod66y9HrBlzJ6NsCrWCp/2kaPS3K9wJmurBC0Tdw4yPjXKZqlznim3Q==} '@rolldown/pluginutils@1.0.0-rc.5': resolution: {integrity: sha512-RxlLX/DPoarZ9PtxVrQgZhPoor987YtKQqCo5zkjX+0S0yLJ7Vv515Wk6+xtTL67VONKJKxETWZwuZjss2idYw==} + '@rolldown/pluginutils@1.0.0-rc.6': + resolution: {integrity: sha512-Y0+JT8Mi1mmW08K6HieG315XNRu4L0rkfCpA364HtytjgiqYnMYRdFPcxRl+BQQqNXzecL2S9nii+RUpO93XIA==} + '@rollup/plugin-replace@6.0.3': resolution: {integrity: sha512-J4RZarRvQAm5IF0/LwUUg+obsm+xZhYnbMXmXROyoSE1ATJe3oXSb9L5MMppdxP2ylNSjv6zFBwKYjcKMucVfA==} engines: {node: '>=14.0.0'} @@ -5720,6 +5807,76 @@ packages: engines: {node: '>=16'} hasBin: true + lightningcss-android-arm64@1.31.1: + resolution: {integrity: sha512-HXJF3x8w9nQ4jbXRiNppBCqeZPIAfUo8zE/kOEGbW5NZvGc/K7nMxbhIr+YlFlHW5mpbg/YFPdbnCh1wAXCKFg==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [android] + + lightningcss-darwin-arm64@1.31.1: + resolution: {integrity: sha512-02uTEqf3vIfNMq3h/z2cJfcOXnQ0GRwQrkmPafhueLb2h7mqEidiCzkE4gBMEH65abHRiQvhdcQ+aP0D0g67sg==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [darwin] + + lightningcss-darwin-x64@1.31.1: + resolution: {integrity: sha512-1ObhyoCY+tGxtsz1lSx5NXCj3nirk0Y0kB/g8B8DT+sSx4G9djitg9ejFnjb3gJNWo7qXH4DIy2SUHvpoFwfTA==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + 
os: [darwin] + + lightningcss-freebsd-x64@1.31.1: + resolution: {integrity: sha512-1RINmQKAItO6ISxYgPwszQE1BrsVU5aB45ho6O42mu96UiZBxEXsuQ7cJW4zs4CEodPUioj/QrXW1r9pLUM74A==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [freebsd] + + lightningcss-linux-arm-gnueabihf@1.31.1: + resolution: {integrity: sha512-OOCm2//MZJ87CdDK62rZIu+aw9gBv4azMJuA8/KB74wmfS3lnC4yoPHm0uXZ/dvNNHmnZnB8XLAZzObeG0nS1g==} + engines: {node: '>= 12.0.0'} + cpu: [arm] + os: [linux] + + lightningcss-linux-arm64-gnu@1.31.1: + resolution: {integrity: sha512-WKyLWztD71rTnou4xAD5kQT+982wvca7E6QoLpoawZ1gP9JM0GJj4Tp5jMUh9B3AitHbRZ2/H3W5xQmdEOUlLg==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [linux] + + lightningcss-linux-arm64-musl@1.31.1: + resolution: {integrity: sha512-mVZ7Pg2zIbe3XlNbZJdjs86YViQFoJSpc41CbVmKBPiGmC4YrfeOyz65ms2qpAobVd7WQsbW4PdsSJEMymyIMg==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [linux] + + lightningcss-linux-x64-gnu@1.31.1: + resolution: {integrity: sha512-xGlFWRMl+0KvUhgySdIaReQdB4FNudfUTARn7q0hh/V67PVGCs3ADFjw+6++kG1RNd0zdGRlEKa+T13/tQjPMA==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [linux] + + lightningcss-linux-x64-musl@1.31.1: + resolution: {integrity: sha512-eowF8PrKHw9LpoZii5tdZwnBcYDxRw2rRCyvAXLi34iyeYfqCQNA9rmUM0ce62NlPhCvof1+9ivRaTY6pSKDaA==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [linux] + + lightningcss-win32-arm64-msvc@1.31.1: + resolution: {integrity: sha512-aJReEbSEQzx1uBlQizAOBSjcmr9dCdL3XuC/6HLXAxmtErsj2ICo5yYggg1qOODQMtnjNQv2UHb9NpOuFtYe4w==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [win32] + + lightningcss-win32-x64-msvc@1.31.1: + resolution: {integrity: sha512-I9aiFrbd7oYHwlnQDqr1Roz+fTz61oDDJX7n9tYF9FJymH1cIN1DtKw3iYt6b8WZgEjoNwVSncwF4wx/ZedMhw==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [win32] + + lightningcss@1.31.1: + resolution: {integrity: sha512-l51N2r93WmGUye3WuFoN5k10zyvrVs0qfKBhyC5ogUQ6Ew6JUSswh78mbSO+IU3nTWsyOArqPCcShdQSadghBQ==} + engines: {node: '>= 12.0.0'} + lilconfig@3.1.3: resolution: {integrity: sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==} engines: {node: '>=14'} @@ -6872,6 +7029,11 @@ packages: robust-predicates@3.0.2: resolution: {integrity: sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==} + rolldown@1.0.0-rc.6: + resolution: {integrity: sha512-B8vFPV1ADyegoYfhg+E7RAucYKv0xdVlwYYsIJgfPNeiSxZGWNxts9RqhyGzC11ULK/VaeXyKezGCwpMiH8Ktw==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + rollup@4.56.0: resolution: {integrity: sha512-9FwVqlgUHzbXtDg9RCMgodF3Ua4Na6Gau+Sdt9vyCN4RhHfVKX2DCHy3BjMLTDd47ITDhYAnTwGulWTblJSDLg==} engines: {node: '>=18.0.0', npm: '>=8.0.0'} @@ -7641,6 +7803,49 @@ packages: yaml: optional: true + vite@8.0.0-beta.16: + resolution: {integrity: sha512-c0t7hYkxsjws89HH+BUFh/sL3BpPNhNsL9CJrTpMxBmwKQBRSa5OJ5w4o9O0bQVI/H/vx7UpUUIevvXa37NS/Q==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + peerDependencies: + '@types/node': ^20.19.0 || >=22.12.0 + '@vitejs/devtools': ^0.0.0-alpha.31 + esbuild: 0.27.2 + jiti: '>=1.21.0' + less: ^4.0.0 + sass: ^1.70.0 + sass-embedded: ^1.70.0 + stylus: '>=0.54.8' + sugarss: ^5.0.0 + terser: ^5.16.0 + tsx: ^4.8.1 + yaml: ^2.4.2 + peerDependenciesMeta: + '@types/node': + optional: true + '@vitejs/devtools': + optional: true + esbuild: + optional: true + jiti: + optional: true + less: + optional: true + sass: + optional: true + sass-embedded: + optional: true + stylus: + optional: true + sugarss: + optional: true + terser: + optional: true + tsx: 
+ optional: true + yaml: + optional: true + vitefu@1.1.2: resolution: {integrity: sha512-zpKATdUbzbsycPFBN71nS2uzBUQiVnFoOrr2rvqv34S1lcAgMKKkjWleLGeiJlZ8lwCXvtWaRn7R3ZC16SYRuw==} peerDependencies: @@ -8087,7 +8292,7 @@ snapshots: idb: 8.0.3 tslib: 2.8.1 - '@antfu/eslint-config@7.6.1(@eslint-react/eslint-plugin@2.13.0(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.1.6)(@typescript-eslint/rule-tester@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.56.1(typescript@5.9.3))(@typescript-eslint/utils@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@vue/compiler-sfc@3.5.27)(eslint-plugin-react-hooks@7.0.1(eslint@10.0.2(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.0.2(jiti@1.21.7)))(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': + '@antfu/eslint-config@7.6.1(@eslint-react/eslint-plugin@2.13.0(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.1.6)(@typescript-eslint/rule-tester@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.56.1(typescript@5.9.3))(@typescript-eslint/utils@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(@vue/compiler-sfc@3.5.27)(eslint-plugin-react-hooks@7.0.1(eslint@10.0.2(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.0.2(jiti@1.21.7)))(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': dependencies: '@antfu/install-pkg': 1.1.0 '@clack/prompts': 1.0.1 @@ -8096,7 +8301,7 @@ snapshots: '@stylistic/eslint-plugin': https://pkg.pr.new/@stylistic/eslint-plugin@258f9d8(eslint@10.0.2(jiti@1.21.7)) '@typescript-eslint/eslint-plugin': 8.56.1(@typescript-eslint/parser@8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3) '@typescript-eslint/parser': 8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3) - '@vitest/eslint-plugin': 1.6.9(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + '@vitest/eslint-plugin': 1.6.9(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) ansis: 4.2.0 cac: 6.7.14 eslint: 10.0.2(jiti@1.21.7) @@ -9088,11 +9293,11 @@ snapshots: dependencies: minipass: 7.1.3 - '@joshwooding/vite-plugin-react-docgen-typescript@0.6.4(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': + '@joshwooding/vite-plugin-react-docgen-typescript@0.6.4(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': dependencies: glob: 13.0.6 react-docgen-typescript: 2.4.0(typescript@5.9.3) - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) optionalDependencies: typescript: 5.9.3 @@ -9583,6 +9788,10 @@ snapshots: '@ota-meshi/ast-token-store@0.3.0': {} + '@oxc-project/runtime@0.115.0': {} + + '@oxc-project/types@0.115.0': 
{} + '@oxc-resolver/binding-android-arm-eabi@11.16.4': optional: true @@ -10012,10 +10221,53 @@ snapshots: '@rgrove/parse-xml@4.2.0': {} + '@rolldown/binding-android-arm64@1.0.0-rc.6': + optional: true + + '@rolldown/binding-darwin-arm64@1.0.0-rc.6': + optional: true + + '@rolldown/binding-darwin-x64@1.0.0-rc.6': + optional: true + + '@rolldown/binding-freebsd-x64@1.0.0-rc.6': + optional: true + + '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.6': + optional: true + + '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.6': + optional: true + + '@rolldown/binding-linux-arm64-musl@1.0.0-rc.6': + optional: true + + '@rolldown/binding-linux-x64-gnu@1.0.0-rc.6': + optional: true + + '@rolldown/binding-linux-x64-musl@1.0.0-rc.6': + optional: true + + '@rolldown/binding-openharmony-arm64@1.0.0-rc.6': + optional: true + + '@rolldown/binding-wasm32-wasi@1.0.0-rc.6': + dependencies: + '@napi-rs/wasm-runtime': 1.1.1 + optional: true + + '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.6': + optional: true + + '@rolldown/binding-win32-x64-msvc@1.0.0-rc.6': + optional: true + '@rolldown/pluginutils@1.0.0-rc.3': {} '@rolldown/pluginutils@1.0.0-rc.5': {} + '@rolldown/pluginutils@1.0.0-rc.6': {} + '@rollup/plugin-replace@6.0.3(rollup@4.56.0)': dependencies: '@rollup/pluginutils': 5.3.0(rollup@4.56.0) @@ -10232,10 +10484,10 @@ snapshots: '@standard-schema/spec@1.1.0': {} - '@storybook/addon-docs@10.2.13(@types/react@19.2.9)(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/addon-docs@10.2.13(@types/react@19.2.9)(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: '@mdx-js/react': 3.1.1(@types/react@19.2.9)(react@19.2.4) - '@storybook/csf-plugin': 10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/csf-plugin': 10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/icons': 2.0.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@storybook/react-dom-shim': 10.2.13(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) react: 19.2.4 @@ -10265,25 +10517,25 @@ snapshots: storybook: 10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - '@storybook/builder-vite@10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': + 
'@storybook/builder-vite@10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@storybook/csf-plugin': 10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/csf-plugin': 10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) storybook: 10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) transitivePeerDependencies: - esbuild - rollup - webpack - '@storybook/csf-plugin@10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/csf-plugin@10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: storybook: 10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) unplugin: 2.3.11 optionalDependencies: esbuild: 0.27.2 rollup: 4.56.0 - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) webpack: 5.104.1(esbuild@0.27.2)(uglify-js@3.19.3) '@storybook/global@5.0.0': {} @@ -10293,18 +10545,18 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@storybook/nextjs-vite@10.2.13(@babel/core@7.29.0)(esbuild@0.27.2)(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': + 
'@storybook/nextjs-vite@10.2.13(@babel/core@7.29.0)(esbuild@0.27.2)(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@storybook/builder-vite': 10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/builder-vite': 10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/react': 10.2.13(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) - '@storybook/react-vite': 10.2.13(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/react-vite': 10.2.13(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) next: 16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2) react: 19.2.4 react-dom: 19.2.4(react@19.2.4) storybook: 10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) styled-jsx: 5.1.6(@babel/core@7.29.0)(react@19.2.4) - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) - vite-plugin-storybook-nextjs: 3.2.2(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite-plugin-storybook-nextjs: 3.2.2(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) optionalDependencies: typescript: 5.9.3 transitivePeerDependencies: @@ -10321,11 +10573,11 @@ snapshots: react-dom: 19.2.4(react@19.2.4) storybook: 
10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@storybook/react-vite@10.2.13(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/react-vite@10.2.13(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@joshwooding/vite-plugin-react-docgen-typescript': 0.6.4(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + '@joshwooding/vite-plugin-react-docgen-typescript': 0.6.4(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) '@rollup/pluginutils': 5.3.0(rollup@4.56.0) - '@storybook/builder-vite': 10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/builder-vite': 10.2.13(esbuild@0.27.2)(rollup@4.56.0)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/react': 10.2.13(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) empathic: 2.0.0 magic-string: 0.30.21 @@ -10335,7 +10587,7 @@ snapshots: resolve: 1.22.11 storybook: 10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) tsconfig-paths: 4.2.0 - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) transitivePeerDependencies: - esbuild - rollup @@ -11171,7 +11423,7 @@ snapshots: '@resvg/resvg-wasm': 2.4.0 satori: 0.16.0 - '@vitejs/plugin-react@5.1.4(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': + '@vitejs/plugin-react@5.1.4(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': dependencies: '@babel/core': 7.29.0 '@babel/plugin-transform-react-jsx-self': 7.27.1(@babel/core@7.29.0) @@ -11179,11 +11431,11 @@ snapshots: '@rolldown/pluginutils': 1.0.0-rc.3 '@types/babel__core': 7.20.5 react-refresh: 0.18.0 - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) transitivePeerDependencies: - supports-color - 
'@vitejs/plugin-rsc@0.5.21(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': + '@vitejs/plugin-rsc@0.5.21(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': dependencies: '@rolldown/pluginutils': 1.0.0-rc.5 es-module-lexer: 2.0.0 @@ -11195,12 +11447,12 @@ snapshots: srvx: 0.11.7 strip-literal: 3.1.0 turbo-stream: 3.1.0 - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) - vitefu: 1.1.2(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vitefu: 1.1.2(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) optionalDependencies: react-server-dom-webpack: 19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) - '@vitest/coverage-v8@4.0.18(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': + '@vitest/coverage-v8@4.0.18(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': dependencies: '@bcoe/v8-coverage': 1.0.2 '@vitest/utils': 4.0.18 @@ -11212,16 +11464,16 @@ snapshots: obug: 2.1.1 std-env: 3.10.0 tinyrainbow: 3.0.3 - vitest: 4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vitest: 4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) - '@vitest/eslint-plugin@1.6.9(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': + '@vitest/eslint-plugin@1.6.9(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3)(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': dependencies: '@typescript-eslint/scope-manager': 8.56.1 '@typescript-eslint/utils': 8.56.1(eslint@10.0.2(jiti@1.21.7))(typescript@5.9.3) eslint: 10.0.2(jiti@1.21.7) optionalDependencies: typescript: 5.9.3 - vitest: 4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vitest: 4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) transitivePeerDependencies: - supports-color @@ -11242,13 +11494,13 @@ snapshots: chai: 6.2.2 tinyrainbow: 3.0.3 - '@vitest/mocker@4.0.18(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': + '@vitest/mocker@4.0.18(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))': dependencies: '@vitest/spy': 
4.0.18 estree-walker: 3.0.3 magic-string: 0.30.21 optionalDependencies: - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) '@vitest/pretty-format@3.2.4': dependencies: @@ -13635,6 +13887,55 @@ snapshots: dependencies: isomorphic.js: 0.2.5 + lightningcss-android-arm64@1.31.1: + optional: true + + lightningcss-darwin-arm64@1.31.1: + optional: true + + lightningcss-darwin-x64@1.31.1: + optional: true + + lightningcss-freebsd-x64@1.31.1: + optional: true + + lightningcss-linux-arm-gnueabihf@1.31.1: + optional: true + + lightningcss-linux-arm64-gnu@1.31.1: + optional: true + + lightningcss-linux-arm64-musl@1.31.1: + optional: true + + lightningcss-linux-x64-gnu@1.31.1: + optional: true + + lightningcss-linux-x64-musl@1.31.1: + optional: true + + lightningcss-win32-arm64-msvc@1.31.1: + optional: true + + lightningcss-win32-x64-msvc@1.31.1: + optional: true + + lightningcss@1.31.1: + dependencies: + detect-libc: 2.1.2 + optionalDependencies: + lightningcss-android-arm64: 1.31.1 + lightningcss-darwin-arm64: 1.31.1 + lightningcss-darwin-x64: 1.31.1 + lightningcss-freebsd-x64: 1.31.1 + lightningcss-linux-arm-gnueabihf: 1.31.1 + lightningcss-linux-arm64-gnu: 1.31.1 + lightningcss-linux-arm64-musl: 1.31.1 + lightningcss-linux-x64-gnu: 1.31.1 + lightningcss-linux-x64-musl: 1.31.1 + lightningcss-win32-arm64-msvc: 1.31.1 + lightningcss-win32-x64-msvc: 1.31.1 + lilconfig@3.1.3: {} linebreak@1.1.0: @@ -15205,6 +15506,25 @@ snapshots: robust-predicates@3.0.2: {} + rolldown@1.0.0-rc.6: + dependencies: + '@oxc-project/types': 0.115.0 + '@rolldown/pluginutils': 1.0.0-rc.6 + optionalDependencies: + '@rolldown/binding-android-arm64': 1.0.0-rc.6 + '@rolldown/binding-darwin-arm64': 1.0.0-rc.6 + '@rolldown/binding-darwin-x64': 1.0.0-rc.6 + '@rolldown/binding-freebsd-x64': 1.0.0-rc.6 + '@rolldown/binding-linux-arm-gnueabihf': 1.0.0-rc.6 + '@rolldown/binding-linux-arm64-gnu': 1.0.0-rc.6 + '@rolldown/binding-linux-arm64-musl': 1.0.0-rc.6 + '@rolldown/binding-linux-x64-gnu': 1.0.0-rc.6 + '@rolldown/binding-linux-x64-musl': 1.0.0-rc.6 + '@rolldown/binding-openharmony-arm64': 1.0.0-rc.6 + '@rolldown/binding-wasm32-wasi': 1.0.0-rc.6 + '@rolldown/binding-win32-arm64-msvc': 1.0.0-rc.6 + '@rolldown/binding-win32-x64-msvc': 1.0.0-rc.6 + rollup@4.56.0: dependencies: '@types/estree': 1.0.8 @@ -15971,19 +16291,19 @@ snapshots: '@types/unist': 3.0.3 vfile-message: 4.0.3 - vinext@https://pkg.pr.new/hyoban/vinext@cfae669(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)): + vinext@https://pkg.pr.new/hyoban/vinext@cfae669(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: '@unpic/react': 1.0.2(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@vercel/og': 0.8.6 - '@vitejs/plugin-rsc': 
0.5.21(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + '@vitejs/plugin-rsc': 0.5.21(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) magic-string: 0.30.21 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) react-server-dom-webpack: 19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) rsc-html-stream: 0.0.7 - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) vite-plugin-commonjs: 0.10.4 - vite-tsconfig-paths: 6.1.1(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + vite-tsconfig-paths: 6.1.1(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) transitivePeerDependencies: - next - supports-color @@ -16003,7 +16323,7 @@ snapshots: fast-glob: 3.3.3 magic-string: 0.30.21 - vite-plugin-storybook-nextjs@3.2.2(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + vite-plugin-storybook-nextjs@3.2.2(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): dependencies: '@next/env': 16.0.0 image-size: 2.0.2 @@ -16012,34 +16332,34 @@ snapshots: next: 16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2) storybook: 10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) - vite-tsconfig-paths: 5.1.4(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite-tsconfig-paths: 5.1.4(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) transitivePeerDependencies: - supports-color - typescript - vite-tsconfig-paths@5.1.4(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + vite-tsconfig-paths@5.1.4(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): dependencies: debug: 4.4.3 globrex: 0.1.2 tsconfck: 3.1.6(typescript@5.9.3) optionalDependencies: - vite: 
7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) transitivePeerDependencies: - supports-color - typescript - vite-tsconfig-paths@6.1.1(typescript@5.9.3)(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + vite-tsconfig-paths@6.1.1(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): dependencies: debug: 4.4.3 globrex: 0.1.2 tsconfck: 3.1.6(typescript@5.9.3) - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) transitivePeerDependencies: - supports-color - typescript - vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2): + vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2): dependencies: esbuild: 0.27.2 fdir: 6.5.0(picomatch@4.0.3) @@ -16051,25 +16371,44 @@ snapshots: '@types/node': 24.10.12 fsevents: 2.3.3 jiti: 1.21.7 + lightningcss: 1.31.1 sass: 1.93.2 terser: 5.46.0 tsx: 4.21.0 yaml: 2.8.2 - vitefu@1.1.2(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2): + dependencies: + '@oxc-project/runtime': 0.115.0 + lightningcss: 1.31.1 + picomatch: 4.0.3 + postcss: 8.5.6 + rolldown: 1.0.0-rc.6 + tinyglobby: 0.2.15 optionalDependencies: - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + '@types/node': 24.10.12 + esbuild: 0.27.2 + fsevents: 2.3.3 + jiti: 1.21.7 + sass: 1.93.2 + terser: 5.46.0 + tsx: 4.21.0 + yaml: 2.8.2 - vitest-canvas-mock@1.1.3(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + vitefu@1.1.2(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + optionalDependencies: + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + + vitest-canvas-mock@1.1.3(vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): dependencies: cssfontparser: 1.2.1 moo-color: 1.0.3 - vitest: 4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vitest: 4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) - vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2): + vitest@4.0.18(@types/node@24.10.12)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2): dependencies: '@vitest/expect': 4.0.18 - '@vitest/mocker': 4.0.18(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + '@vitest/mocker': 
4.0.18(vite@7.3.1(@types/node@24.10.12)(jiti@1.21.7)(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) '@vitest/pretty-format': 4.0.18 '@vitest/runner': 4.0.18 '@vitest/snapshot': 4.0.18 @@ -16086,7 +16425,7 @@ snapshots: tinyexec: 1.0.2 tinyglobby: 0.2.15 tinyrainbow: 3.0.3 - vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite: 7.3.1(@types/node@24.10.12)(jiti@1.21.7)(lightningcss@1.31.1)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) why-is-node-running: 2.3.0 optionalDependencies: '@types/node': 24.10.12 diff --git a/web/vite.config.ts b/web/vite.config.ts index 1fdef49d0c..a3e1cb77ef 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -108,10 +108,30 @@ export default defineConfig(({ mode }) => { ? { optimizeDeps: { exclude: ['nuqs'], + // Make Prism in lexical works + // https://github.com/vitejs/rolldown-vite/issues/396 + rolldownOptions: { + output: { + strictExecutionOrder: true, + }, + }, }, server: { port: 3000, }, + ssr: { + // SyntaxError: Named export not found. The requested module is a CommonJS module, which may not support all module.exports as named exports + noExternal: ['emoji-mart'], + }, + // Make Prism in lexical works + // https://github.com/vitejs/rolldown-vite/issues/396 + build: { + rolldownOptions: { + output: { + strictExecutionOrder: true, + }, + }, + }, } : {}), From 664ab123c33624416ff0ea7521b223cf68670236 Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Tue, 3 Mar 2026 22:30:36 +0800 Subject: [PATCH 008/256] chore: add dependency groups to Dependabot config (#32721) --- .github/dependabot.yml | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 1a57bb0050..78f6eefd0d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,25 +1,37 @@ version: 2 -multi-ecosystem-groups: - python: - schedule: - interval: "weekly" # or whatever schedule you want - updates: - package-ecosystem: "pip" directory: "/api" open-pull-requests-limit: 2 - patterns: ["*"] schedule: interval: "weekly" + groups: + python-dependencies: + patterns: + - "*" - package-ecosystem: "uv" directory: "/api" open-pull-requests-limit: 2 - patterns: ["*"] schedule: interval: "weekly" + groups: + uv-dependencies: + patterns: + - "*" - package-ecosystem: "npm" directory: "/web" schedule: interval: "weekly" open-pull-requests-limit: 2 + groups: + storybook: + patterns: + - "storybook" + - "@storybook/*" + npm-dependencies: + patterns: + - "*" + exclude-patterns: + - "storybook" + - "@storybook/*" From 2b47db0462ae6bbb8e392cac255858725f37696b Mon Sep 17 00:00:00 2001 From: Br1an <932039080@qq.com> Date: Tue, 3 Mar 2026 23:01:16 +0800 Subject: [PATCH 009/256] fix(api): resolve OpenTelemetry histogram type mismatch (#32771) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/ops/tencent_trace/client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/core/ops/tencent_trace/client.py b/api/core/ops/tencent_trace/client.py index 99ccf00400..c39093bf4c 100644 --- a/api/core/ops/tencent_trace/client.py +++ b/api/core/ops/tencent_trace/client.py @@ -120,7 +120,8 @@ class TencentTraceClient: # Metrics exporter and instruments try: - from opentelemetry.sdk.metrics import Histogram, MeterProvider + from opentelemetry.sdk.metrics import Histogram as SdkHistogram + from opentelemetry.sdk.metrics import MeterProvider from 
opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower() @@ -128,7 +129,7 @@ class TencentTraceClient: use_http_json = protocol in {"http/json", "http-json"} # Tencent APM works best with delta aggregation temporality - preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA} + preferred_temporality: dict[type, AggregationTemporality] = {SdkHistogram: AggregationTemporality.DELTA} def _create_metric_exporter(exporter_cls, **kwargs): """Create metric exporter with preferred_temporality support""" From 6002fd09b4de87ac9b863e0b445e742100e7e4d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=A8=E4=B9=8B=E6=9C=AC=E6=BE=AA?= Date: Wed, 4 Mar 2026 02:40:18 +0800 Subject: [PATCH 010/256] test: migrate test_dataset_service_create_dataset SQL tests to testcontainers (#32538) Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com> --- .../services/test_dataset_service.py | 273 +++++- .../test_dataset_service_create_dataset.py | 789 +----------------- 2 files changed, 282 insertions(+), 780 deletions(-) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0ca649b36d..c3decbf39d 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -14,9 +14,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings +from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity from services.errors.dataset import DatasetNameDuplicateError @@ -274,6 +275,276 @@ class TestDatasetServiceCreateDataset: assert result.retrieval_model == retrieval_model.model_dump() mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0") + def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(self, db_session_with_containers): + """Create high-quality dataset with explicitly configured embedding model.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + embedding_provider = "openai" + embedding_model_name = "text-embedding-3-small" + embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model( + provider=embedding_provider, model_name=embedding_model_name + ) + + # Act + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, + ): + mock_model_manager.return_value.get_model_instance.return_value = embedding_model + + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Custom Embedding Dataset", + description=None, + indexing_technique="high_quality", + 
account=account, + embedding_model_provider=embedding_provider, + embedding_model_name=embedding_model_name, + ) + + # Assert + db.session.refresh(result) + assert result.indexing_technique == "high_quality" + assert result.embedding_model_provider == embedding_provider + assert result.embedding_model == embedding_model_name + mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) + mock_model_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id=tenant.id, + provider=embedding_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=embedding_model_name, + ) + + def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers): + """Persist retrieval model settings when creating an internal dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=False, + top_k=2, + score_threshold_enabled=True, + score_threshold=0.0, + ) + + # Act + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Retrieval Model Dataset", + description=None, + indexing_technique=None, + account=account, + retrieval_model=retrieval_model, + ) + + # Assert + db.session.refresh(result) + assert result.retrieval_model == retrieval_model.model_dump() + + def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers): + """Persist canonical custom permission when creating an internal dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + + # Act + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Custom Permission Dataset", + description=None, + indexing_technique=None, + account=account, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Assert + db.session.refresh(result) + assert result.permission == DatasetPermissionEnum.ALL_TEAM + + def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers): + """Raise error when external API template does not exist.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + external_knowledge_api_id = str(uuid4()) + + # Act / Assert + with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api: + mock_get_api.return_value = None + with pytest.raises(ValueError, match=r"External API template not found\.?"): + DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="External Missing API Dataset", + description=None, + indexing_technique=None, + account=account, + provider="external", + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id="knowledge-123", + ) + + def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): + """Raise error when external knowledge id is missing for external dataset creation.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + external_knowledge_api_id = str(uuid4()) + + # Act / Assert + with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api: + mock_get_api.return_value = Mock(id=external_knowledge_api_id) + with pytest.raises(ValueError, match="external_knowledge_id is required"): + DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="External Missing Knowledge Dataset", + 
description=None, + indexing_technique=None, + account=account, + provider="external", + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id=None, + ) + + +class TestDatasetServiceCreateRagPipelineDataset: + """Integration coverage for DatasetService.create_empty_rag_pipeline_dataset.""" + + def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers): + """Create rag-pipeline dataset and pipeline rows when a name is provided.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="RAG Pipeline Dataset", + description="RAG Pipeline Description", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + created_dataset = db.session.get(Dataset, result.id) + created_pipeline = db.session.get(Pipeline, result.pipeline_id) + assert created_dataset is not None + assert created_dataset.name == entity.name + assert created_dataset.runtime_mode == "rag_pipeline" + assert created_dataset.created_by == account.id + assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME + assert created_pipeline is not None + assert created_pipeline.name == entity.name + assert created_pipeline.created_by == account.id + + def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers): + """Create rag-pipeline dataset with generated incremental name when input name is empty.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + generated_name = "Untitled 1" + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with ( + patch("services.dataset_service.current_user", account), + patch("services.dataset_service.generate_incremental_name") as mock_generate_name, + ): + mock_generate_name.return_value = generated_name + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db.session.refresh(result) + created_pipeline = db.session.get(Pipeline, result.pipeline_id) + assert result.name == generated_name + assert created_pipeline is not None + assert created_pipeline.name == generated_name + mock_generate_name.assert_called_once() + + def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers): + """Raise duplicate-name error when rag-pipeline dataset name already exists.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + duplicate_name = "Duplicate RAG Dataset" + DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + name=duplicate_name, + indexing_technique=None, + ) + db.session.commit() + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name=duplicate_name, + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act / Assert + with ( + 
patch("services.dataset_service.current_user", account), + pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {duplicate_name} already exists"), + ): + DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers): + """Persist canonical custom permission for rag-pipeline dataset creation.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="Custom Permission RAG Dataset", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db.session.refresh(result) + assert result.permission == DatasetPermissionEnum.ALL_TEAM + + def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers): + """Persist icon metadata when creating rag-pipeline dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + icon_info = IconInfo( + icon="📚", + icon_background="#E8F5E9", + icon_type="emoji", + icon_url="https://example.com/icon.png", + ) + entity = RagPipelineDatasetCreateEntity( + name="Icon Info RAG Dataset", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db.session.refresh(result) + assert result.icon_info == icon_info.model_dump() + class TestDatasetServiceUpdateAndDeleteDataset: """Integration coverage for SQL-backed update and delete behavior.""" diff --git a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py index 87a0d6b678..f8c5270656 100644 --- a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py @@ -1,726 +1,39 @@ -""" -Comprehensive unit tests for DatasetService creation methods. 
+"""Unit tests for non-SQL validation paths in DatasetService dataset creation.""" -This test suite covers: -- create_empty_dataset for internal datasets -- create_empty_dataset for external datasets -- create_empty_rag_pipeline_dataset -- Error conditions and edge cases -""" - -from unittest.mock import Mock, create_autospec, patch +from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from dify_graph.model_runtime.entities.model_entities import ModelType -from models.account import Account -from models.dataset import Dataset, Pipeline from services.dataset_service import DatasetService -from services.entities.knowledge_entities.knowledge_entities import RetrievalModel -from services.entities.knowledge_entities.rag_pipeline_entities import ( - IconInfo, - RagPipelineDatasetCreateEntity, -) -from services.errors.dataset import DatasetNameDuplicateError +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity -class DatasetCreateTestDataFactory: - """Factory class for creating test data and mock objects for dataset creation tests.""" - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock account.""" - account = create_autospec(Account, instance=True) - account.id = account_id - account.current_tenant_id = tenant_id - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock: - """Create a mock embedding model.""" - embedding_model = Mock() - embedding_model.model_name = model - embedding_model.provider = provider - return embedding_model - - @staticmethod - def create_retrieval_model_mock() -> Mock: - """Create a mock retrieval model.""" - retrieval_model = Mock(spec=RetrievalModel) - retrieval_model.model_dump.return_value = { - "search_method": "semantic_search", - "top_k": 2, - "score_threshold": 0.0, - } - retrieval_model.reranking_model = None - return retrieval_model - - @staticmethod - def create_external_knowledge_api_mock(api_id: str = "api-123", **kwargs) -> Mock: - """Create a mock external knowledge API.""" - api = Mock() - api.id = api_id - for key, value in kwargs.items(): - setattr(api, key, value) - return api - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - name: str = "Test Dataset", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock dataset.""" - dataset = create_autospec(Dataset, instance=True) - dataset.id = dataset_id - dataset.name = name - dataset.tenant_id = tenant_id - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_pipeline_mock( - pipeline_id: str = "pipeline-123", - name: str = "Test Pipeline", - **kwargs, - ) -> Mock: - """Create a mock pipeline.""" - pipeline = Mock(spec=Pipeline) - pipeline.id = pipeline_id - pipeline.name = name - for key, value in kwargs.items(): - setattr(pipeline, key, value) - return pipeline - - -class TestDatasetServiceCreateEmptyDataset: - """ - Comprehensive unit tests for DatasetService.create_empty_dataset method. 
- - This test suite covers: - - Internal dataset creation (vendor provider) - - External dataset creation - - High quality indexing technique with embedding models - - Economy indexing technique - - Retrieval model configuration - - Error conditions (duplicate names, missing external knowledge IDs) - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.ModelManager") as mock_model_manager, - patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, - patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, - patch("services.dataset_service.ExternalDatasetService") as mock_external_service, - ): - yield { - "db_session": mock_db, - "model_manager": mock_model_manager, - "check_embedding": mock_check_embedding, - "check_reranking": mock_check_reranking, - "external_service": mock_external_service, - } - - # ==================== Internal Dataset Creation Tests ==================== - - def test_create_internal_dataset_basic_success(self, mock_dataset_service_dependencies): - """Test successful creation of basic internal dataset.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Test Dataset" - description = "Test description" - - # Mock database query to return None (no duplicate name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock database session operations - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=description, - indexing_technique=None, - account=account, - ) - - # Assert - assert result is not None - assert result.name == name - assert result.description == description - assert result.tenant_id == tenant_id - assert result.created_by == account.id - assert result.updated_by == account.id - assert result.provider == "vendor" - assert result.permission == "only_me" - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_economy_indexing(self, mock_dataset_service_dependencies): - """Test successful creation of internal dataset with economy indexing.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Economy Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="economy", - account=account, - ) - - # Assert - assert result.indexing_technique == "economy" - assert result.embedding_model_provider is None - assert result.embedding_model is None - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_high_quality_indexing_default_embedding( - self, 
mock_dataset_service_dependencies - ): - """Test creation with high_quality indexing using default embedding model.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "High Quality Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock() - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_default_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - ) - - # Assert - assert result.indexing_technique == "high_quality" - assert result.embedding_model_provider == embedding_model.provider - assert result.embedding_model == embedding_model.model_name - mock_model_manager_instance.get_default_model_instance.assert_called_once_with( - tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_high_quality_indexing_custom_embedding( - self, mock_dataset_service_dependencies - ): - """Test creation with high_quality indexing using custom embedding model.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Custom Embedding Dataset" - embedding_provider = "openai" - embedding_model_name = "text-embedding-3-small" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock( - model=embedding_model_name, provider=embedding_provider - ) - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - embedding_model_provider=embedding_provider, - embedding_model_name=embedding_model_name, - ) - - # Assert - assert result.indexing_technique == "high_quality" - assert result.embedding_model_provider == embedding_provider - assert result.embedding_model == embedding_model_name - mock_dataset_service_dependencies["check_embedding"].assert_called_once_with( - tenant_id, embedding_provider, embedding_model_name - ) - mock_model_manager_instance.get_model_instance.assert_called_once_with( - tenant_id=tenant_id, - provider=embedding_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model_name, - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_retrieval_model(self, 
mock_dataset_service_dependencies): - """Test creation with retrieval model configuration.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Retrieval Model Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock retrieval model - retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock() - retrieval_model_dict = {"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0} - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - retrieval_model=retrieval_model, - ) - - # Assert - assert result.retrieval_model == retrieval_model_dict - retrieval_model.model_dump.assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_retrieval_model_reranking(self, mock_dataset_service_dependencies): - """Test creation with retrieval model that includes reranking.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Reranking Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock() - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_default_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - # Mock retrieval model with reranking - reranking_model = Mock() - reranking_model.reranking_provider_name = "cohere" - reranking_model.reranking_model_name = "rerank-english-v3.0" - - retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock() - retrieval_model.reranking_model = reranking_model - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - retrieval_model=retrieval_model, - ) - - # Assert - mock_dataset_service_dependencies["check_reranking"].assert_called_once_with( - tenant_id, "cohere", "rerank-english-v3.0" - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_custom_permission(self, mock_dataset_service_dependencies): - """Test creation with custom permission setting.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Custom Permission Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - 
name=name, - description=None, - indexing_technique=None, - account=account, - permission="all_team_members", - ) - - # Assert - assert result.permission == "all_team_members" - mock_db.commit.assert_called_once() - - # ==================== External Dataset Creation Tests ==================== - - def test_create_external_dataset_success(self, mock_dataset_service_dependencies): - """Test successful creation of external dataset.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "external-api-123" - external_knowledge_id = "external-knowledge-456" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API - external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id) - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id=external_knowledge_id, - ) - - # Assert - assert result.provider == "external" - assert mock_db.add.call_count == 2 # Dataset + ExternalKnowledgeBindings - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.assert_called_once_with( - external_api_id - ) - mock_db.commit.assert_called_once() - - def test_create_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge API is not found.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "non-existent-api" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API not found - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = None - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - - # Act & Assert - with pytest.raises(ValueError, match="External API template not found"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id="knowledge-123", - ) - - def test_create_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge ID is missing.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "external-api-123" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API - 
external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id) - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - - # Act & Assert - with pytest.raises(ValueError, match="external_knowledge_id is required"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id=None, - ) - - # ==================== Error Handling Tests ==================== - - def test_create_dataset_duplicate_name_error(self, mock_dataset_service_dependencies): - """Test error when dataset name already exists.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Duplicate Dataset" - - # Mock database query to return existing dataset - existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = existing_dataset - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Act & Assert - with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - ) - - -class TestDatasetServiceCreateEmptyRagPipelineDataset: - """ - Comprehensive unit tests for DatasetService.create_empty_rag_pipeline_dataset method. - - This test suite covers: - - RAG pipeline dataset creation with provided name - - RAG pipeline dataset creation with auto-generated name - - Pipeline creation - - Error conditions (duplicate names, missing current user) - """ +class TestDatasetServiceCreateRagPipelineDatasetNonSQL: + """Unit coverage for non-SQL validation in create_empty_rag_pipeline_dataset.""" @pytest.fixture def mock_rag_pipeline_dependencies(self): - """Common mock setup for RAG pipeline dataset creation.""" + """Patch database session and current_user for validation-only unit coverage.""" with ( patch("services.dataset_service.db.session") as mock_db, patch("services.dataset_service.current_user") as mock_current_user, - patch("services.dataset_service.generate_incremental_name") as mock_generate_name, ): - # Configure mock_current_user to behave like a Flask-Login proxy - # Default: no user (falsy) - mock_current_user.id = None yield { "db_session": mock_db, "current_user_mock": mock_current_user, - "generate_name": mock_generate_name, } - def test_create_rag_pipeline_dataset_with_name_success(self, mock_rag_pipeline_dependencies): - """Test successful creation of RAG pipeline dataset with provided name.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "RAG Pipeline Dataset" - description = "RAG Pipeline Description" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query (no duplicate name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - 
mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description=description, - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result is not None - assert result.name == name - assert result.description == description - assert result.tenant_id == tenant_id - assert result.created_by == user_id - assert result.provider == "vendor" - assert result.runtime_mode == "rag_pipeline" - assert result.permission == "only_me" - assert mock_db.add.call_count == 2 # Pipeline + Dataset - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_with_auto_generated_name(self, mock_rag_pipeline_dependencies): - """Test creation of RAG pipeline dataset with auto-generated name.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - auto_name = "Untitled 1" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query (empty name, need to generate) - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [] - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock name generation - mock_rag_pipeline_dependencies["generate_name"].return_value = auto_name - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity with empty name - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name="", - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.name == auto_name - mock_rag_pipeline_dependencies["generate_name"].assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_duplicate_name_error(self, mock_rag_pipeline_dependencies): - """Test error when RAG pipeline dataset name already exists.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Duplicate RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query to return existing dataset - existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = existing_dataset - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act & Assert - with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"): - DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - def test_create_rag_pipeline_dataset_missing_current_user_error(self, 
mock_rag_pipeline_dependencies): - """Test error when current user is not available.""" + """Raise ValueError when current_user.id is unavailable before SQL persistence.""" # Arrange tenant_id = str(uuid4()) - - # Mock current user as None - set id to None so the check fails mock_rag_pipeline_dependencies["current_user_mock"].id = None - # Mock database query mock_query = Mock() mock_query.filter_by.return_value.first.return_value = None mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - # Create entity icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name="Test Dataset", @@ -729,91 +42,9 @@ class TestDatasetServiceCreateEmptyRagPipelineDataset: permission="only_me", ) - # Act & Assert + # Act / Assert with pytest.raises(ValueError, match="Current user or current user id not found"): DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity + tenant_id=tenant_id, + rag_pipeline_dataset_create_entity=entity, ) - - def test_create_rag_pipeline_dataset_with_custom_permission(self, mock_rag_pipeline_dependencies): - """Test creation with custom permission setting.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Custom Permission RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="all_team", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.permission == "all_team" - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_with_icon_info(self, mock_rag_pipeline_dependencies): - """Test creation with icon info configuration.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Icon Info RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity with icon info - icon_info = IconInfo( - icon="📚", - icon_background="#E8F5E9", - icon_type="emoji", - icon_url="https://example.com/icon.png", - ) - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.icon_info == icon_info.model_dump() - 
mock_db.commit.assert_called_once() From 477bf6e0758eb7717bfbe1dac05e686c173a3be3 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Wed, 4 Mar 2026 09:53:15 +0800 Subject: [PATCH 011/256] fix(web): align dropdown-menu styles with Figma design (#32922) --- .../base/ui/dropdown-menu/index.stories.tsx | 317 ++++++++++++++++++ .../base/ui/dropdown-menu/index.tsx | 12 +- 2 files changed, 323 insertions(+), 6 deletions(-) create mode 100644 web/app/components/base/ui/dropdown-menu/index.stories.tsx diff --git a/web/app/components/base/ui/dropdown-menu/index.stories.tsx b/web/app/components/base/ui/dropdown-menu/index.stories.tsx new file mode 100644 index 0000000000..70afc07819 --- /dev/null +++ b/web/app/components/base/ui/dropdown-menu/index.stories.tsx @@ -0,0 +1,317 @@ +import type { Meta, StoryObj } from '@storybook/nextjs-vite' +import { useState } from 'react' +import { + DropdownMenu, + DropdownMenuCheckboxItem, + DropdownMenuCheckboxItemIndicator, + DropdownMenuContent, + DropdownMenuGroup, + DropdownMenuGroupLabel, + DropdownMenuItem, + DropdownMenuRadioGroup, + DropdownMenuRadioItem, + DropdownMenuRadioItemIndicator, + DropdownMenuSeparator, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from '.' + +const TriggerButton = ({ label = 'Open Menu' }: { label?: string }) => ( + } + > + {label} + +) + +const meta = { + title: 'Base/Navigation/DropdownMenu', + component: DropdownMenu, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Compound dropdown menu built on Base UI Menu. Supports items, separators, group labels, submenus, radio groups, checkbox items, destructive items, and disabled states.', + }, + }, + }, + tags: ['autodocs'], +} satisfies Meta + +export default meta +type Story = StoryObj + +export const Default: Story = { + render: () => ( + + + + Edit + Duplicate + Archive + + + ), +} + +export const WithSeparator: Story = { + render: () => ( + + + + Cut + Copy + Paste + + Select All + + Find and Replace + + + ), +} + +export const WithGroupLabel: Story = { + render: () => ( + + + + + Actions + Edit + Duplicate + + + + Export + Export as PDF + Export as CSV + + + + ), +} + +export const WithDestructiveItem: Story = { + render: () => ( + + + + Edit + Duplicate + + Delete + + + ), +} + +export const WithSubmenu: Story = { + render: () => ( + + + + New File + Open + + + Share + + Email + Slack + Copy Link + + + + Download + + + ), +} + +const WithRadioItemsDemo = () => { + const [value, setValue] = useState('comfortable') + + return ( + + + + + + Compact + + + + Comfortable + + + + Spacious + + + + + + ) +} + +export const WithRadioItems: Story = { + render: () => , +} + +const WithCheckboxItemsDemo = () => { + const [showToolbar, setShowToolbar] = useState(true) + const [showSidebar, setShowSidebar] = useState(false) + const [showStatusBar, setShowStatusBar] = useState(true) + + return ( + + + + + Toolbar + + + + Sidebar + + + + Status Bar + + + + + ) +} + +export const WithCheckboxItems: Story = { + render: () => , +} + +export const WithDisabledItems: Story = { + render: () => ( + + + + Edit + Duplicate + Archive + + Restore + Delete + + + ), +} + +export const WithIcons: Story = { + render: () => ( + + + + + + Edit + + + + Duplicate + + + + Archive + + + + + Delete + + + + ), +} + +const ComplexDemo = () => { + const [sortOrder, setSortOrder] = useState('newest') + const [showArchived, setShowArchived] = useState(false) + + return ( + + + + + Edit + + + Rename + + + 
+ Duplicate + + + + Move to Workspace + + + + + + + Share + + + + + Email + + + + Slack + + + + Copy Link + + + + + + Sort by + + + Newest first + + + + Oldest first + + + + Name + + + + + + + + Show Archived + + + + + + Delete + + + + ) +} + +export const Complex: Story = { + render: () => , +} diff --git a/web/app/components/base/ui/dropdown-menu/index.tsx b/web/app/components/base/ui/dropdown-menu/index.tsx index e839fd24ef..8d4f630adc 100644 --- a/web/app/components/base/ui/dropdown-menu/index.tsx +++ b/web/app/components/base/ui/dropdown-menu/index.tsx @@ -13,8 +13,8 @@ export const DropdownMenuSub = Menu.SubmenuRoot export const DropdownMenuGroup = Menu.Group export const DropdownMenuRadioGroup = Menu.RadioGroup -const menuRowBaseClassName = 'mx-1 flex h-8 cursor-pointer select-none items-center rounded-lg px-2 outline-none' -const menuRowStateClassName = 'data-[highlighted]:bg-state-base-hover data-[disabled]:cursor-not-allowed data-[disabled]:opacity-50' +const menuRowBaseClassName = 'mx-1 flex h-8 cursor-pointer select-none items-center gap-1 rounded-lg px-2 outline-none' +const menuRowStateClassName = 'data-[highlighted]:bg-state-base-hover data-[disabled]:cursor-not-allowed data-[disabled]:opacity-30' export function DropdownMenuRadioItem({ className, @@ -89,7 +89,7 @@ export function DropdownMenuGroupLabel({ return ( {children} - + ) } @@ -270,7 +270,7 @@ export function DropdownMenuSeparator({ }: React.ComponentPropsWithoutRef) { return ( ) From 1c1edb4a226acf5e169bd72376036a21575d2dfa Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Wed, 4 Mar 2026 09:53:36 +0800 Subject: [PATCH 012/256] fix: keep account dropdown open when switching theme (#32918) --- web/app/components/base/theme-switcher.tsx | 18 ++++++++++++------ .../header/account-dropdown/index.spec.tsx | 14 ++++++++++++++ .../header/account-dropdown/index.tsx | 2 +- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/web/app/components/base/theme-switcher.tsx b/web/app/components/base/theme-switcher.tsx index 86e24a443c..58da8f4664 100644 --- a/web/app/components/base/theme-switcher.tsx +++ b/web/app/components/base/theme-switcher.tsx @@ -13,44 +13,50 @@ export default function ThemeSwitcher() { return (
-      <div
+      <button
+        type="button"
         onClick={() => handleThemeChange('system')}
+        aria-label="System theme"
         data-testid="system-theme-container"
       >
-      </div>
+      </button>
-      <div
+      <button
+        type="button"
         onClick={() => handleThemeChange('light')}
+        aria-label="Light theme"
         data-testid="light-theme-container"
       >
-      </div>
+      </button>
-      <div
+      <button
+        type="button"
         onClick={() => handleThemeChange('dark')}
+        aria-label="Dark theme"
         data-testid="dark-theme-container"
       >
-      </div>
+      </button>
) } diff --git a/web/app/components/header/account-dropdown/index.spec.tsx b/web/app/components/header/account-dropdown/index.spec.tsx index 60e00868c1..a92f8503ee 100644 --- a/web/app/components/header/account-dropdown/index.spec.tsx +++ b/web/app/components/header/account-dropdown/index.spec.tsx @@ -29,6 +29,10 @@ vi.mock('@/app/components/header/github-star', () => ({ default: () =>
GithubStar
, })) +vi.mock('@/app/components/base/theme-switcher', () => ({ + default: () => , +})) + vi.mock('@/context/app-context', () => ({ useAppContext: vi.fn(), })) @@ -276,6 +280,16 @@ describe('AccountDropdown', () => { // Assert expect(screen.queryByTestId('account-about')).not.toBeInTheDocument() }) + + it('should keep account dropdown open when clicking the theme switcher', () => { + // Act + renderWithRouter() + fireEvent.click(screen.getByRole('button', { name: 'common.account.account' })) + fireEvent.click(screen.getByTestId('theme-switcher-button')) + + // Assert + expect(screen.getByText('common.userProfile.logout')).toBeInTheDocument() + }) }) describe('Branding and Environment', () => { diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 4cf1f4efda..c4f1c5699f 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -228,8 +228,8 @@ export default function AppSelector() { )} e.preventDefault()} > Date: Wed, 4 Mar 2026 10:07:29 +0800 Subject: [PATCH 013/256] chore: fix load env and treeshaking for vinext (#32928) --- web/config/index.ts | 4 ++-- web/env.ts | 4 ---- web/next.config.ts | 2 +- web/package.json | 6 +++--- web/pnpm-lock.yaml | 10 +++++----- web/proxy.ts | 2 +- web/vite.config.ts | 14 +++++--------- 7 files changed, 17 insertions(+), 25 deletions(-) diff --git a/web/config/index.ts b/web/config/index.ts index 167c87ae34..35ea3780a8 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -42,8 +42,8 @@ export const AMPLITUDE_API_KEY = getStringConfig( '', ) -export const IS_DEV = env.NODE_ENV === 'development' -export const IS_PROD = env.NODE_ENV === 'production' +export const IS_DEV = process.env.NODE_ENV === 'development' +export const IS_PROD = process.env.NODE_ENV === 'production' export const SUPPORT_MAIL_LOGIN = env.NEXT_PUBLIC_SUPPORT_MAIL_LOGIN diff --git a/web/env.ts b/web/env.ts index f240fcd980..f931c48677 100644 --- a/web/env.ts +++ b/web/env.ts @@ -153,12 +153,8 @@ export const env = createEnv({ */ TEXT_GENERATION_TIMEOUT_MS: coercedNumber.default(60000), }, - shared: { - NODE_ENV: z.enum(['development', 'test', 'production']).default('development'), - }, client: clientSchema, experimental__runtimeEnv: { - NODE_ENV: process.env.NODE_ENV, NEXT_PUBLIC_ALLOW_EMBED: isServer ? process.env.NEXT_PUBLIC_ALLOW_EMBED : getRuntimeEnvFromBody('allowEmbed'), NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME: isServer ? process.env.NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME : getRuntimeEnvFromBody('allowUnsafeDataScheme'), NEXT_PUBLIC_AMPLITUDE_API_KEY: isServer ? 
process.env.NEXT_PUBLIC_AMPLITUDE_API_KEY : getRuntimeEnvFromBody('amplitudeApiKey'), diff --git a/web/next.config.ts b/web/next.config.ts index 2236278a74..591c210fe9 100644 --- a/web/next.config.ts +++ b/web/next.config.ts @@ -3,7 +3,7 @@ import createMDX from '@next/mdx' import { codeInspectorPlugin } from 'code-inspector-plugin' import { env } from './env' -const isDev = env.NODE_ENV === 'development' +const isDev = process.env.NODE_ENV === 'development' const withMDX = createMDX({ extension: /\.mdx?$/, options: { diff --git a/web/package.json b/web/package.json index e0df37836e..d352943328 100644 --- a/web/package.json +++ b/web/package.json @@ -31,9 +31,9 @@ "dev:vinext": "vinext dev", "build": "next build", "build:docker": "next build && node scripts/optimize-standalone.js", - "build:vinext": "cross-env NODE_ENV=production vinext build", + "build:vinext": "vinext build", "start": "node ./scripts/copy-and-start.mjs", - "start:vinext": "cross-env NODE_ENV=production vinext start", + "start:vinext": "vinext start", "lint": "eslint --cache --concurrency=auto", "lint:ci": "eslint --cache --concurrency 2", "lint:fix": "pnpm lint --fix", @@ -247,7 +247,7 @@ "tsx": "4.21.0", "typescript": "5.9.3", "uglify-js": "3.19.3", - "vinext": "https://pkg.pr.new/hyoban/vinext@cfae669", + "vinext": "https://pkg.pr.new/hyoban/vinext@a30ba79", "vite": "8.0.0-beta.16", "vite-tsconfig-paths": "6.1.1", "vitest": "4.0.18", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 81ecb7f857..3251493a13 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -608,8 +608,8 @@ importers: specifier: 3.19.3 version: 3.19.3 vinext: - specifier: https://pkg.pr.new/hyoban/vinext@cfae669 - version: https://pkg.pr.new/hyoban/vinext@cfae669(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + specifier: https://pkg.pr.new/hyoban/vinext@a30ba79 + version: https://pkg.pr.new/hyoban/vinext@a30ba79(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) vite: specifier: 8.0.0-beta.16 version: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) @@ -7727,8 +7727,8 @@ packages: vfile@6.0.3: resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} - vinext@https://pkg.pr.new/hyoban/vinext@cfae669: - resolution: {integrity: sha512-4SRm/Dkou0Ib0UYexP8xg0G83jIM17XPUC32uXwLHt5lO47AisblMpDZXTh84fhN058FEHtPaAGtoFThaoZLIw==, tarball: https://pkg.pr.new/hyoban/vinext@cfae669} + vinext@https://pkg.pr.new/hyoban/vinext@a30ba79: + resolution: {integrity: sha512-yx/gCneOli5eGTrLUq6/M7A6DGQs14qOJW/Xp9RN6sTI0mErKyWWjO5E7FZT98BJbqH5xzI5nk8EOCLF3bojkA==, tarball: https://pkg.pr.new/hyoban/vinext@a30ba79} version: 0.0.5 engines: {node: '>=22'} hasBin: true @@ -16291,7 +16291,7 @@ snapshots: '@types/unist': 3.0.3 vfile-message: 4.0.3 - 
vinext@https://pkg.pr.new/hyoban/vinext@cfae669(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)): + vinext@https://pkg.pr.new/hyoban/vinext@a30ba79(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: '@unpic/react': 1.0.2(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@vercel/og': 0.8.6 diff --git a/web/proxy.ts b/web/proxy.ts index bc4a4a3d89..8d7c28153e 100644 --- a/web/proxy.ts +++ b/web/proxy.ts @@ -22,7 +22,7 @@ export function proxy(request: NextRequest) { }, }) - const isWhiteListEnabled = !!env.NEXT_PUBLIC_CSP_WHITELIST && env.NODE_ENV === 'production' + const isWhiteListEnabled = !!env.NEXT_PUBLIC_CSP_WHITELIST && process.env.NODE_ENV === 'production' if (!isWhiteListEnabled) return wrapResponseWithXFrameOptions(response, pathname) diff --git a/web/vite.config.ts b/web/vite.config.ts index a3e1cb77ef..9cfa6a8e7b 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -71,10 +71,10 @@ const createForceInspectorClientInjectionPlugin = (): Plugin => { } export default defineConfig(({ mode }) => { - const isDev = mode === 'development' + const isTest = mode === 'test' return { - plugins: mode === 'test' + plugins: isTest ? [ tsconfigPaths(), react(), @@ -89,12 +89,8 @@ export default defineConfig(({ mode }) => { } as Plugin, ] : [ - ...(isDev - ? [ - createCodeInspectorPlugin(), - createForceInspectorClientInjectionPlugin(), - ] - : []), + createCodeInspectorPlugin(), + createForceInspectorClientInjectionPlugin(), vinext(), ], resolve: { @@ -104,7 +100,7 @@ export default defineConfig(({ mode }) => { }, // vinext related config - ...(mode !== 'test' + ...(!isTest ? 
{ optimizeDeps: { exclude: ['nuqs'], From 3398962bfa9b20bc2e5e7a095c9326bc65ad0f4a Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Wed, 4 Mar 2026 10:59:31 +0800 Subject: [PATCH 014/256] test(workflow): add unit tests for workflow store slices (#32932) Co-authored-by: CodingOnStar --- .../__tests__/chat-variable-slice.spec.ts | 67 +++ .../__tests__/datasets-detail-store.spec.ts | 62 +++ .../__tests__/env-variable-slice.spec.ts | 67 +++ .../__tests__/inspect-vars-slice.spec.ts | 240 +++++++++ .../__tests__/plugin-dependency-store.spec.ts | 43 ++ .../store/__tests__/version-slice.spec.ts | 61 +++ .../__tests__/workflow-draft-slice.spec.ts | 105 ++++ .../store/__tests__/workflow-store.spec.ts | 486 ++++++++++++++++++ 8 files changed, 1131 insertions(+) create mode 100644 web/app/components/workflow/store/__tests__/chat-variable-slice.spec.ts create mode 100644 web/app/components/workflow/store/__tests__/datasets-detail-store.spec.ts create mode 100644 web/app/components/workflow/store/__tests__/env-variable-slice.spec.ts create mode 100644 web/app/components/workflow/store/__tests__/inspect-vars-slice.spec.ts create mode 100644 web/app/components/workflow/store/__tests__/plugin-dependency-store.spec.ts create mode 100644 web/app/components/workflow/store/__tests__/version-slice.spec.ts create mode 100644 web/app/components/workflow/store/__tests__/workflow-draft-slice.spec.ts create mode 100644 web/app/components/workflow/store/__tests__/workflow-store.spec.ts diff --git a/web/app/components/workflow/store/__tests__/chat-variable-slice.spec.ts b/web/app/components/workflow/store/__tests__/chat-variable-slice.spec.ts new file mode 100644 index 0000000000..512eb5b404 --- /dev/null +++ b/web/app/components/workflow/store/__tests__/chat-variable-slice.spec.ts @@ -0,0 +1,67 @@ +import type { ConversationVariable } from '@/app/components/workflow/types' +import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' +import { createWorkflowStore } from '../workflow' + +function createStore() { + return createWorkflowStore({}) +} + +describe('Chat Variable Slice', () => { + describe('setShowChatVariablePanel', () => { + it('should hide other panels when opening', () => { + const store = createStore() + store.getState().setShowDebugAndPreviewPanel(true) + store.getState().setShowEnvPanel(true) + + store.getState().setShowChatVariablePanel(true) + + const state = store.getState() + expect(state.showChatVariablePanel).toBe(true) + expect(state.showDebugAndPreviewPanel).toBe(false) + expect(state.showEnvPanel).toBe(false) + expect(state.showGlobalVariablePanel).toBe(false) + }) + + it('should only close itself when setting false', () => { + const store = createStore() + store.getState().setShowChatVariablePanel(true) + + store.getState().setShowChatVariablePanel(false) + + expect(store.getState().showChatVariablePanel).toBe(false) + }) + }) + + describe('setShowGlobalVariablePanel', () => { + it('should hide other panels when opening', () => { + const store = createStore() + store.getState().setShowDebugAndPreviewPanel(true) + store.getState().setShowChatVariablePanel(true) + + store.getState().setShowGlobalVariablePanel(true) + + const state = store.getState() + expect(state.showGlobalVariablePanel).toBe(true) + expect(state.showDebugAndPreviewPanel).toBe(false) + expect(state.showChatVariablePanel).toBe(false) + expect(state.showEnvPanel).toBe(false) + }) + + it('should only close itself when setting false', () => { + const store = createStore() + 
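+      // Sanity pre-check: the flag is assumed to start as false on a fresh
+      // store, like the other panel booleans exercised above.
+      expect(store.getState().showGlobalVariablePanel).toBe(false)
+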
store.getState().setShowGlobalVariablePanel(true) + store.getState().setShowGlobalVariablePanel(false) + + expect(store.getState().showGlobalVariablePanel).toBe(false) + }) + }) + + describe('setConversationVariables', () => { + it('should update conversationVariables', () => { + const store = createStore() + const vars: ConversationVariable[] = [{ id: 'cv1', name: 'history', value: [], value_type: ChatVarType.String, description: '' }] + store.getState().setConversationVariables(vars) + expect(store.getState().conversationVariables).toEqual(vars) + }) + }) +}) diff --git a/web/app/components/workflow/store/__tests__/datasets-detail-store.spec.ts b/web/app/components/workflow/store/__tests__/datasets-detail-store.spec.ts new file mode 100644 index 0000000000..3728aeda8e --- /dev/null +++ b/web/app/components/workflow/store/__tests__/datasets-detail-store.spec.ts @@ -0,0 +1,62 @@ +import type { DataSet } from '@/models/datasets' +import { createDatasetsDetailStore } from '../../datasets-detail-store/store' + +function makeDataset(id: string, name: string): DataSet { + return { id, name } as DataSet +} + +describe('DatasetsDetailStore', () => { + describe('Initial State', () => { + it('should start with empty datasetsDetail', () => { + const store = createDatasetsDetailStore() + expect(store.getState().datasetsDetail).toEqual({}) + }) + }) + + describe('updateDatasetsDetail', () => { + it('should add datasets by id', () => { + const store = createDatasetsDetailStore() + const ds1 = makeDataset('ds-1', 'Dataset 1') + const ds2 = makeDataset('ds-2', 'Dataset 2') + + store.getState().updateDatasetsDetail([ds1, ds2]) + + expect(store.getState().datasetsDetail['ds-1']).toEqual(ds1) + expect(store.getState().datasetsDetail['ds-2']).toEqual(ds2) + }) + + it('should merge new datasets into existing ones', () => { + const store = createDatasetsDetailStore() + const ds1 = makeDataset('ds-1', 'First') + const ds2 = makeDataset('ds-2', 'Second') + const ds3 = makeDataset('ds-3', 'Third') + + store.getState().updateDatasetsDetail([ds1, ds2]) + store.getState().updateDatasetsDetail([ds3]) + + const detail = store.getState().datasetsDetail + expect(detail['ds-1']).toEqual(ds1) + expect(detail['ds-2']).toEqual(ds2) + expect(detail['ds-3']).toEqual(ds3) + }) + + it('should overwrite existing datasets with same id', () => { + const store = createDatasetsDetailStore() + const ds1v1 = makeDataset('ds-1', 'Version 1') + const ds1v2 = makeDataset('ds-1', 'Version 2') + + store.getState().updateDatasetsDetail([ds1v1]) + store.getState().updateDatasetsDetail([ds1v2]) + + expect(store.getState().datasetsDetail['ds-1'].name).toBe('Version 2') + }) + + it('should handle empty array without errors', () => { + const store = createDatasetsDetailStore() + store.getState().updateDatasetsDetail([makeDataset('ds-1', 'Test')]) + store.getState().updateDatasetsDetail([]) + + expect(store.getState().datasetsDetail['ds-1'].name).toBe('Test') + }) + }) +}) diff --git a/web/app/components/workflow/store/__tests__/env-variable-slice.spec.ts b/web/app/components/workflow/store/__tests__/env-variable-slice.spec.ts new file mode 100644 index 0000000000..95ed7d3955 --- /dev/null +++ b/web/app/components/workflow/store/__tests__/env-variable-slice.spec.ts @@ -0,0 +1,67 @@ +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import { createWorkflowStore } from '../workflow' + +function createStore() { + return createWorkflowStore({}) +} + +describe('Env Variable Slice', () => { + describe('setShowEnvPanel', () => 
{ + it('should hide other panels when opening', () => { + const store = createStore() + store.getState().setShowDebugAndPreviewPanel(true) + store.getState().setShowChatVariablePanel(true) + + store.getState().setShowEnvPanel(true) + + const state = store.getState() + expect(state.showEnvPanel).toBe(true) + expect(state.showDebugAndPreviewPanel).toBe(false) + expect(state.showChatVariablePanel).toBe(false) + expect(state.showGlobalVariablePanel).toBe(false) + }) + + it('should only close itself when setting false', () => { + const store = createStore() + store.getState().setShowEnvPanel(true) + + store.getState().setShowEnvPanel(false) + + expect(store.getState().showEnvPanel).toBe(false) + }) + }) + + describe('setEnvironmentVariables', () => { + it('should update environmentVariables', () => { + const store = createStore() + const vars: EnvironmentVariable[] = [{ id: 'v1', name: 'API_KEY', value: 'secret', value_type: 'string', description: '' }] + store.getState().setEnvironmentVariables(vars) + expect(store.getState().environmentVariables).toEqual(vars) + }) + }) + + describe('setEnvSecrets', () => { + it('should update envSecrets', () => { + const store = createStore() + store.getState().setEnvSecrets({ API_KEY: '***' }) + expect(store.getState().envSecrets).toEqual({ API_KEY: '***' }) + }) + }) + + describe('Sequential Panel Switching', () => { + it('should correctly switch between exclusive panels', () => { + const store = createStore() + + store.getState().setShowChatVariablePanel(true) + expect(store.getState().showChatVariablePanel).toBe(true) + + store.getState().setShowEnvPanel(true) + expect(store.getState().showEnvPanel).toBe(true) + expect(store.getState().showChatVariablePanel).toBe(false) + + store.getState().setShowGlobalVariablePanel(true) + expect(store.getState().showGlobalVariablePanel).toBe(true) + expect(store.getState().showEnvPanel).toBe(false) + }) + }) +}) diff --git a/web/app/components/workflow/store/__tests__/inspect-vars-slice.spec.ts b/web/app/components/workflow/store/__tests__/inspect-vars-slice.spec.ts new file mode 100644 index 0000000000..225cb6a6c8 --- /dev/null +++ b/web/app/components/workflow/store/__tests__/inspect-vars-slice.spec.ts @@ -0,0 +1,240 @@ +import type { NodeWithVar, VarInInspect } from '@/types/workflow' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import { VarInInspectType } from '@/types/workflow' +import { createWorkflowStore } from '../workflow' + +function createStore() { + return createWorkflowStore({}) +} + +function makeVar(overrides: Partial = {}): VarInInspect { + return { + id: 'var-1', + name: 'output', + type: VarInInspectType.node, + description: '', + selector: ['node-1', 'output'], + value_type: VarType.string, + value: 'hello', + edited: false, + visible: true, + is_truncated: false, + full_content: { size_bytes: 0, download_url: '' }, + ...overrides, + } +} + +function makeNodeWithVar(nodeId: string, vars: VarInInspect[]): NodeWithVar { + return { + nodeId, + nodePayload: { title: `Node ${nodeId}`, desc: '', type: BlockEnum.Code } as NodeWithVar['nodePayload'], + nodeType: BlockEnum.Code, + title: `Node ${nodeId}`, + vars, + isValueFetched: false, + } +} + +describe('Inspect Vars Slice', () => { + describe('setNodesWithInspectVars', () => { + it('should replace the entire list', () => { + const store = createStore() + const nodes = [makeNodeWithVar('n1', [makeVar()])] + store.getState().setNodesWithInspectVars(nodes) + expect(store.getState().nodesWithInspectVars).toEqual(nodes) + }) + }) + 
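+
+  // Usage sketch, not exercised by the cases below: the factories above
+  // compose, so a multi-node fixture can be seeded in a single call. The
+  // helper name is illustrative only.
+  function seedTwoNodes(store: ReturnType<typeof createStore>) {
+    store.getState().setNodesWithInspectVars([
+      makeNodeWithVar('n1', [makeVar({ id: 'v1' })]),
+      makeNodeWithVar('n2', [makeVar({ id: 'v2' })]),
+    ])
+  }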
+ describe('deleteAllInspectVars', () => { + it('should clear all nodes', () => { + const store = createStore() + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [makeVar()])]) + store.getState().deleteAllInspectVars() + expect(store.getState().nodesWithInspectVars).toEqual([]) + }) + }) + + describe('setNodeInspectVars', () => { + it('should update vars for a specific node and mark as fetched', () => { + const store = createStore() + const v1 = makeVar({ id: 'v1', name: 'a' }) + const v2 = makeVar({ id: 'v2', name: 'b' }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v1])]) + + store.getState().setNodeInspectVars('n1', [v2]) + + const node = store.getState().nodesWithInspectVars[0] + expect(node.vars).toEqual([v2]) + expect(node.isValueFetched).toBe(true) + }) + + it('should not modify state when node is not found', () => { + const store = createStore() + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [makeVar()])]) + + store.getState().setNodeInspectVars('non-existent', []) + + expect(store.getState().nodesWithInspectVars[0].vars).toHaveLength(1) + }) + }) + + describe('deleteNodeInspectVars', () => { + it('should remove the matching node', () => { + const store = createStore() + store.getState().setNodesWithInspectVars([ + makeNodeWithVar('n1', [makeVar()]), + makeNodeWithVar('n2', [makeVar()]), + ]) + + store.getState().deleteNodeInspectVars('n1') + + expect(store.getState().nodesWithInspectVars).toHaveLength(1) + expect(store.getState().nodesWithInspectVars[0].nodeId).toBe('n2') + }) + }) + + describe('setInspectVarValue', () => { + it('should update the value and set edited=true', () => { + const store = createStore() + const v = makeVar({ id: 'v1', value: 'old', edited: false }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v])]) + + store.getState().setInspectVarValue('n1', 'v1', 'new') + + const updated = store.getState().nodesWithInspectVars[0].vars[0] + expect(updated.value).toBe('new') + expect(updated.edited).toBe(true) + }) + + it('should not change state when var is not found', () => { + const store = createStore() + const v = makeVar({ id: 'v1', value: 'old' }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v])]) + + store.getState().setInspectVarValue('n1', 'wrong-id', 'new') + + expect(store.getState().nodesWithInspectVars[0].vars[0].value).toBe('old') + }) + + it('should not change state when node is not found', () => { + const store = createStore() + const v = makeVar({ id: 'v1', value: 'old' }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v])]) + + store.getState().setInspectVarValue('wrong-node', 'v1', 'new') + + expect(store.getState().nodesWithInspectVars[0].vars[0].value).toBe('old') + }) + }) + + describe('resetToLastRunVar', () => { + it('should restore value and set edited=false', () => { + const store = createStore() + const v = makeVar({ id: 'v1', value: 'modified', edited: true }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v])]) + + store.getState().resetToLastRunVar('n1', 'v1', 'original') + + const updated = store.getState().nodesWithInspectVars[0].vars[0] + expect(updated.value).toBe('original') + expect(updated.edited).toBe(false) + }) + + it('should not change state when node is not found', () => { + const store = createStore() + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [makeVar()])]) + + store.getState().resetToLastRunVar('wrong-node', 'v1', 'val') + + 
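+      // The mismatched nodeId makes the call a no-op, so the seeded value
+      // should also survive untouched (makeVar defaults value to 'hello').
+      expect(store.getState().nodesWithInspectVars[0].vars[0].value).toBe('hello')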
expect(store.getState().nodesWithInspectVars[0].vars[0].edited).toBe(false) + }) + + it('should not change state when var is not found', () => { + const store = createStore() + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [makeVar({ id: 'v1', edited: true })])]) + + store.getState().resetToLastRunVar('n1', 'wrong-var', 'val') + + expect(store.getState().nodesWithInspectVars[0].vars[0].edited).toBe(true) + }) + }) + + describe('renameInspectVarName', () => { + it('should update name and selector', () => { + const store = createStore() + const v = makeVar({ id: 'v1', name: 'old_name', selector: ['n1', 'old_name'] }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v])]) + + store.getState().renameInspectVarName('n1', 'v1', ['n1', 'new_name']) + + const updated = store.getState().nodesWithInspectVars[0].vars[0] + expect(updated.name).toBe('new_name') + expect(updated.selector).toEqual(['n1', 'new_name']) + }) + + it('should not change state when node is not found', () => { + const store = createStore() + const v = makeVar({ id: 'v1', name: 'old' }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v])]) + + store.getState().renameInspectVarName('wrong-node', 'v1', ['x', 'y']) + + expect(store.getState().nodesWithInspectVars[0].vars[0].name).toBe('old') + }) + + it('should not change state when var is not found', () => { + const store = createStore() + const v = makeVar({ id: 'v1', name: 'old' }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v])]) + + store.getState().renameInspectVarName('n1', 'wrong-var', ['x', 'y']) + + expect(store.getState().nodesWithInspectVars[0].vars[0].name).toBe('old') + }) + }) + + describe('deleteInspectVar', () => { + it('should remove the matching var from the node', () => { + const store = createStore() + const v1 = makeVar({ id: 'v1' }) + const v2 = makeVar({ id: 'v2' }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v1, v2])]) + + store.getState().deleteInspectVar('n1', 'v1') + + const vars = store.getState().nodesWithInspectVars[0].vars + expect(vars).toHaveLength(1) + expect(vars[0].id).toBe('v2') + }) + + it('should not change state when var is not found', () => { + const store = createStore() + const v = makeVar({ id: 'v1' }) + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [v])]) + + store.getState().deleteInspectVar('n1', 'wrong-id') + + expect(store.getState().nodesWithInspectVars[0].vars).toHaveLength(1) + }) + + it('should not change state when node is not found', () => { + const store = createStore() + store.getState().setNodesWithInspectVars([makeNodeWithVar('n1', [makeVar()])]) + + store.getState().deleteInspectVar('wrong-node', 'v1') + + expect(store.getState().nodesWithInspectVars[0].vars).toHaveLength(1) + }) + }) + + describe('currentFocusNodeId', () => { + it('should update and clear focus node', () => { + const store = createStore() + store.getState().setCurrentFocusNodeId('n1') + expect(store.getState().currentFocusNodeId).toBe('n1') + + store.getState().setCurrentFocusNodeId(null) + expect(store.getState().currentFocusNodeId).toBeNull() + }) + }) +}) diff --git a/web/app/components/workflow/store/__tests__/plugin-dependency-store.spec.ts b/web/app/components/workflow/store/__tests__/plugin-dependency-store.spec.ts new file mode 100644 index 0000000000..8c0cdd8337 --- /dev/null +++ b/web/app/components/workflow/store/__tests__/plugin-dependency-store.spec.ts @@ -0,0 +1,43 @@ +import type { Dependency } from 
'@/app/components/plugins/types' +import { useStore } from '../../plugin-dependency/store' + +describe('Plugin Dependency Store', () => { + beforeEach(() => { + useStore.setState({ dependencies: [] }) + }) + + describe('Initial State', () => { + it('should start with empty dependencies', () => { + expect(useStore.getState().dependencies).toEqual([]) + }) + }) + + describe('setDependencies', () => { + it('should update dependencies list', () => { + const deps: Dependency[] = [ + { type: 'marketplace', value: { plugin_unique_identifier: 'p1' } }, + { type: 'marketplace', value: { plugin_unique_identifier: 'p2' } }, + ] as Dependency[] + + useStore.getState().setDependencies(deps) + expect(useStore.getState().dependencies).toEqual(deps) + }) + + it('should replace existing dependencies', () => { + const dep1: Dependency = { type: 'marketplace', value: { plugin_unique_identifier: 'p1' } } as Dependency + const dep2: Dependency = { type: 'marketplace', value: { plugin_unique_identifier: 'p2' } } as Dependency + useStore.getState().setDependencies([dep1]) + useStore.getState().setDependencies([dep2]) + + expect(useStore.getState().dependencies).toHaveLength(1) + }) + + it('should handle empty array', () => { + const dep: Dependency = { type: 'marketplace', value: { plugin_unique_identifier: 'p1' } } as Dependency + useStore.getState().setDependencies([dep]) + useStore.getState().setDependencies([]) + + expect(useStore.getState().dependencies).toEqual([]) + }) + }) +}) diff --git a/web/app/components/workflow/store/__tests__/version-slice.spec.ts b/web/app/components/workflow/store/__tests__/version-slice.spec.ts new file mode 100644 index 0000000000..8d76a62256 --- /dev/null +++ b/web/app/components/workflow/store/__tests__/version-slice.spec.ts @@ -0,0 +1,61 @@ +import type { VersionHistory } from '@/types/workflow' +import { createWorkflowStore } from '../workflow' + +function createStore() { + return createWorkflowStore({}) +} + +describe('Version Slice', () => { + describe('setDraftUpdatedAt', () => { + it('should multiply timestamp by 1000 (seconds to milliseconds)', () => { + const store = createStore() + store.getState().setDraftUpdatedAt(1704067200) + expect(store.getState().draftUpdatedAt).toBe(1704067200000) + }) + + it('should set 0 when given 0', () => { + const store = createStore() + store.getState().setDraftUpdatedAt(0) + expect(store.getState().draftUpdatedAt).toBe(0) + }) + }) + + describe('setPublishedAt', () => { + it('should multiply timestamp by 1000', () => { + const store = createStore() + store.getState().setPublishedAt(1704067200) + expect(store.getState().publishedAt).toBe(1704067200000) + }) + + it('should set 0 when given 0', () => { + const store = createStore() + store.getState().setPublishedAt(0) + expect(store.getState().publishedAt).toBe(0) + }) + }) + + describe('currentVersion', () => { + it('should default to null', () => { + const store = createStore() + expect(store.getState().currentVersion).toBeNull() + }) + + it('should update current version', () => { + const store = createStore() + const version = { hash: 'abc', updated_at: 1000, version: '1.0' } as VersionHistory + store.getState().setCurrentVersion(version) + expect(store.getState().currentVersion).toEqual(version) + }) + }) + + describe('isRestoring', () => { + it('should toggle restoring state', () => { + const store = createStore() + store.getState().setIsRestoring(true) + expect(store.getState().isRestoring).toBe(true) + + store.getState().setIsRestoring(false) + 
expect(store.getState().isRestoring).toBe(false) + }) + }) +}) diff --git a/web/app/components/workflow/store/__tests__/workflow-draft-slice.spec.ts b/web/app/components/workflow/store/__tests__/workflow-draft-slice.spec.ts new file mode 100644 index 0000000000..dfbc58e050 --- /dev/null +++ b/web/app/components/workflow/store/__tests__/workflow-draft-slice.spec.ts @@ -0,0 +1,105 @@ +import type { Node } from '@/app/components/workflow/types' +import { createWorkflowStore } from '../workflow' + +function createStore() { + return createWorkflowStore({}) +} + +describe('Workflow Draft Slice', () => { + describe('Initial State', () => { + it('should have empty default values', () => { + const store = createStore() + const state = store.getState() + expect(state.backupDraft).toBeUndefined() + expect(state.syncWorkflowDraftHash).toBe('') + expect(state.isSyncingWorkflowDraft).toBe(false) + expect(state.isWorkflowDataLoaded).toBe(false) + expect(state.nodes).toEqual([]) + }) + }) + + describe('setBackupDraft', () => { + it('should set and clear backup draft', () => { + const store = createStore() + const draft = { + nodes: [] as Node[], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + environmentVariables: [], + } + store.getState().setBackupDraft(draft) + expect(store.getState().backupDraft).toEqual(draft) + + store.getState().setBackupDraft(undefined) + expect(store.getState().backupDraft).toBeUndefined() + }) + }) + + describe('setSyncWorkflowDraftHash', () => { + it('should update the hash', () => { + const store = createStore() + store.getState().setSyncWorkflowDraftHash('abc123') + expect(store.getState().syncWorkflowDraftHash).toBe('abc123') + }) + }) + + describe('setIsSyncingWorkflowDraft', () => { + it('should toggle syncing state', () => { + const store = createStore() + store.getState().setIsSyncingWorkflowDraft(true) + expect(store.getState().isSyncingWorkflowDraft).toBe(true) + }) + }) + + describe('setIsWorkflowDataLoaded', () => { + it('should toggle loaded state', () => { + const store = createStore() + store.getState().setIsWorkflowDataLoaded(true) + expect(store.getState().isWorkflowDataLoaded).toBe(true) + }) + }) + + describe('setNodes', () => { + it('should update nodes array', () => { + const store = createStore() + const nodes: Node[] = [] + store.getState().setNodes(nodes) + expect(store.getState().nodes).toEqual(nodes) + }) + }) + + describe('debouncedSyncWorkflowDraft', () => { + it('should be a callable function', () => { + const store = createStore() + expect(typeof store.getState().debouncedSyncWorkflowDraft).toBe('function') + }) + + it('should debounce the sync call', () => { + vi.useFakeTimers() + const store = createStore() + const syncFn = vi.fn() + + store.getState().debouncedSyncWorkflowDraft(syncFn) + expect(syncFn).not.toHaveBeenCalled() + + vi.advanceTimersByTime(5000) + expect(syncFn).toHaveBeenCalledTimes(1) + + vi.useRealTimers() + }) + + it('should flush pending sync via flushPendingSync', () => { + vi.useFakeTimers() + const store = createStore() + const syncFn = vi.fn() + + store.getState().debouncedSyncWorkflowDraft(syncFn) + expect(syncFn).not.toHaveBeenCalled() + + store.getState().flushPendingSync() + expect(syncFn).toHaveBeenCalledTimes(1) + + vi.useRealTimers() + }) + }) +}) diff --git a/web/app/components/workflow/store/__tests__/workflow-store.spec.ts b/web/app/components/workflow/store/__tests__/workflow-store.spec.ts new file mode 100644 index 0000000000..df94be90b8 --- /dev/null +++ 
b/web/app/components/workflow/store/__tests__/workflow-store.spec.ts @@ -0,0 +1,486 @@ +import type { Shape, SliceFromInjection } from '../workflow' +import type { HelpLineHorizontalPosition, HelpLineVerticalPosition } from '@/app/components/workflow/help-line/types' +import type { WorkflowRunningData } from '@/app/components/workflow/types' +import type { FileUploadConfigResponse } from '@/models/common' +import type { VersionHistory } from '@/types/workflow' +import { renderHook } from '@testing-library/react' +import * as React from 'react' +import { BlockEnum } from '@/app/components/workflow/types' +import { WorkflowContext } from '../../context' +import { createWorkflowStore, useStore, useWorkflowStore } from '../workflow' + +function createStore() { + return createWorkflowStore({}) +} + +describe('createWorkflowStore', () => { + describe('Initial State', () => { + it('should create a store with all slices merged', () => { + const store = createStore() + const state = store.getState() + + expect(state.showSingleRunPanel).toBe(false) + expect(state.controlMode).toBeDefined() + expect(state.nodes).toEqual([]) + expect(state.environmentVariables).toEqual([]) + expect(state.conversationVariables).toEqual([]) + expect(state.nodesWithInspectVars).toEqual([]) + expect(state.workflowCanvasWidth).toBeUndefined() + expect(state.draftUpdatedAt).toBe(0) + expect(state.versionHistory).toEqual([]) + }) + }) + + describe('Workflow Slice Setters', () => { + it('should update workflowRunningData', () => { + const store = createStore() + const data: Partial = { result: { status: 'running', inputs_truncated: false, process_data_truncated: false, outputs_truncated: false } } + store.getState().setWorkflowRunningData(data as Parameters[0]) + expect(store.getState().workflowRunningData).toEqual(data) + }) + + it('should update isListening', () => { + const store = createStore() + store.getState().setIsListening(true) + expect(store.getState().isListening).toBe(true) + }) + + it('should update listeningTriggerType', () => { + const store = createStore() + store.getState().setListeningTriggerType(BlockEnum.TriggerWebhook) + expect(store.getState().listeningTriggerType).toBe(BlockEnum.TriggerWebhook) + }) + + it('should update listeningTriggerNodeId', () => { + const store = createStore() + store.getState().setListeningTriggerNodeId('node-abc') + expect(store.getState().listeningTriggerNodeId).toBe('node-abc') + }) + + it('should update listeningTriggerNodeIds', () => { + const store = createStore() + store.getState().setListeningTriggerNodeIds(['n1', 'n2']) + expect(store.getState().listeningTriggerNodeIds).toEqual(['n1', 'n2']) + }) + + it('should update listeningTriggerIsAll', () => { + const store = createStore() + store.getState().setListeningTriggerIsAll(true) + expect(store.getState().listeningTriggerIsAll).toBe(true) + }) + + it('should update clipboardElements', () => { + const store = createStore() + store.getState().setClipboardElements([]) + expect(store.getState().clipboardElements).toEqual([]) + }) + + it('should update selection', () => { + const store = createStore() + const sel = { x1: 0, y1: 0, x2: 100, y2: 100 } + store.getState().setSelection(sel) + expect(store.getState().selection).toEqual(sel) + }) + + it('should update bundleNodeSize', () => { + const store = createStore() + store.getState().setBundleNodeSize({ width: 200, height: 100 }) + expect(store.getState().bundleNodeSize).toEqual({ width: 200, height: 100 }) + }) + + it('should persist controlMode to localStorage', () => { + 
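+      // Assumes localStorage is stubbed with spies in the shared test setup;
+      // without a vi.spyOn on setItem, the call assertion below cannot pass.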
const store = createStore() + store.getState().setControlMode('pointer') + expect(store.getState().controlMode).toBe('pointer') + expect(localStorage.setItem).toHaveBeenCalledWith('workflow-operation-mode', 'pointer') + }) + + it('should update mousePosition', () => { + const store = createStore() + const pos = { pageX: 10, pageY: 20, elementX: 5, elementY: 15 } + store.getState().setMousePosition(pos) + expect(store.getState().mousePosition).toEqual(pos) + }) + + it('should update showConfirm', () => { + const store = createStore() + const confirm = { title: 'Delete?', onConfirm: vi.fn() } + store.getState().setShowConfirm(confirm) + expect(store.getState().showConfirm).toEqual(confirm) + }) + + it('should update controlPromptEditorRerenderKey', () => { + const store = createStore() + store.getState().setControlPromptEditorRerenderKey(42) + expect(store.getState().controlPromptEditorRerenderKey).toBe(42) + }) + + it('should update showImportDSLModal', () => { + const store = createStore() + store.getState().setShowImportDSLModal(true) + expect(store.getState().showImportDSLModal).toBe(true) + }) + + it('should update fileUploadConfig', () => { + const store = createStore() + const config: FileUploadConfigResponse = { + batch_count_limit: 5, + image_file_batch_limit: 10, + single_chunk_attachment_limit: 10, + attachment_image_file_size_limit: 2, + file_size_limit: 15, + file_upload_limit: 5, + } + store.getState().setFileUploadConfig(config) + expect(store.getState().fileUploadConfig).toEqual(config) + }) + }) + + describe('Node Slice Setters', () => { + it('should update showSingleRunPanel', () => { + const store = createStore() + store.getState().setShowSingleRunPanel(true) + expect(store.getState().showSingleRunPanel).toBe(true) + }) + + it('should update nodeAnimation', () => { + const store = createStore() + store.getState().setNodeAnimation(true) + expect(store.getState().nodeAnimation).toBe(true) + }) + + it('should update candidateNode', () => { + const store = createStore() + store.getState().setCandidateNode(undefined) + expect(store.getState().candidateNode).toBeUndefined() + }) + + it('should update nodeMenu', () => { + const store = createStore() + store.getState().setNodeMenu({ top: 100, left: 200, nodeId: 'n1' }) + expect(store.getState().nodeMenu).toEqual({ top: 100, left: 200, nodeId: 'n1' }) + }) + + it('should update showAssignVariablePopup', () => { + const store = createStore() + store.getState().setShowAssignVariablePopup(undefined) + expect(store.getState().showAssignVariablePopup).toBeUndefined() + }) + + it('should update hoveringAssignVariableGroupId', () => { + const store = createStore() + store.getState().setHoveringAssignVariableGroupId('group-1') + expect(store.getState().hoveringAssignVariableGroupId).toBe('group-1') + }) + + it('should update connectingNodePayload', () => { + const store = createStore() + const payload = { nodeId: 'n1', nodeType: 'llm', handleType: 'source', handleId: 'h1' } + store.getState().setConnectingNodePayload(payload) + expect(store.getState().connectingNodePayload).toEqual(payload) + }) + + it('should update enteringNodePayload', () => { + const store = createStore() + store.getState().setEnteringNodePayload(undefined) + expect(store.getState().enteringNodePayload).toBeUndefined() + }) + + it('should update iterTimes', () => { + const store = createStore() + store.getState().setIterTimes(5) + expect(store.getState().iterTimes).toBe(5) + }) + + it('should update loopTimes', () => { + const store = createStore() + 
store.getState().setLoopTimes(10) + expect(store.getState().loopTimes).toBe(10) + }) + + it('should update iterParallelLogMap', () => { + const store = createStore() + const map = new Map>() + store.getState().setIterParallelLogMap(map) + expect(store.getState().iterParallelLogMap).toBe(map) + }) + + it('should update pendingSingleRun', () => { + const store = createStore() + store.getState().setPendingSingleRun({ nodeId: 'n1', action: 'run' }) + expect(store.getState().pendingSingleRun).toEqual({ nodeId: 'n1', action: 'run' }) + }) + }) + + describe('Panel Slice Setters', () => { + it('should update showFeaturesPanel', () => { + const store = createStore() + store.getState().setShowFeaturesPanel(true) + expect(store.getState().showFeaturesPanel).toBe(true) + }) + + it('should update showWorkflowVersionHistoryPanel', () => { + const store = createStore() + store.getState().setShowWorkflowVersionHistoryPanel(true) + expect(store.getState().showWorkflowVersionHistoryPanel).toBe(true) + }) + + it('should update showInputsPanel', () => { + const store = createStore() + store.getState().setShowInputsPanel(true) + expect(store.getState().showInputsPanel).toBe(true) + }) + + it('should update showDebugAndPreviewPanel', () => { + const store = createStore() + store.getState().setShowDebugAndPreviewPanel(true) + expect(store.getState().showDebugAndPreviewPanel).toBe(true) + }) + + it('should update panelMenu', () => { + const store = createStore() + store.getState().setPanelMenu({ top: 10, left: 20 }) + expect(store.getState().panelMenu).toEqual({ top: 10, left: 20 }) + }) + + it('should update selectionMenu', () => { + const store = createStore() + store.getState().setSelectionMenu({ top: 50, left: 60 }) + expect(store.getState().selectionMenu).toEqual({ top: 50, left: 60 }) + }) + + it('should update showVariableInspectPanel', () => { + const store = createStore() + store.getState().setShowVariableInspectPanel(true) + expect(store.getState().showVariableInspectPanel).toBe(true) + }) + + it('should update initShowLastRunTab', () => { + const store = createStore() + store.getState().setInitShowLastRunTab(true) + expect(store.getState().initShowLastRunTab).toBe(true) + }) + }) + + describe('Help Line Slice Setters', () => { + it('should update helpLineHorizontal', () => { + const store = createStore() + const pos: HelpLineHorizontalPosition = { top: 100, left: 0, width: 500 } + store.getState().setHelpLineHorizontal(pos) + expect(store.getState().helpLineHorizontal).toEqual(pos) + }) + + it('should clear helpLineHorizontal', () => { + const store = createStore() + store.getState().setHelpLineHorizontal({ top: 100, left: 0, width: 500 }) + store.getState().setHelpLineHorizontal(undefined) + expect(store.getState().helpLineHorizontal).toBeUndefined() + }) + + it('should update helpLineVertical', () => { + const store = createStore() + const pos: HelpLineVerticalPosition = { top: 0, left: 200, height: 300 } + store.getState().setHelpLineVertical(pos) + expect(store.getState().helpLineVertical).toEqual(pos) + }) + }) + + describe('History Slice Setters', () => { + it('should update historyWorkflowData', () => { + const store = createStore() + store.getState().setHistoryWorkflowData({ id: 'run-1', status: 'succeeded' }) + expect(store.getState().historyWorkflowData).toEqual({ id: 'run-1', status: 'succeeded' }) + }) + + it('should update showRunHistory', () => { + const store = createStore() + store.getState().setShowRunHistory(true) + expect(store.getState().showRunHistory).toBe(true) + }) + + 
it('should update versionHistory', () => { + const store = createStore() + const history: VersionHistory[] = [] + store.getState().setVersionHistory(history) + expect(store.getState().versionHistory).toEqual(history) + }) + }) + + describe('Form Slice Setters', () => { + it('should update inputs', () => { + const store = createStore() + store.getState().setInputs({ name: 'test', count: 42 }) + expect(store.getState().inputs).toEqual({ name: 'test', count: 42 }) + }) + + it('should update files', () => { + const store = createStore() + store.getState().setFiles([]) + expect(store.getState().files).toEqual([]) + }) + }) + + describe('Tool Slice Setters', () => { + it('should update toolPublished', () => { + const store = createStore() + store.getState().setToolPublished(true) + expect(store.getState().toolPublished).toBe(true) + }) + + it('should update lastPublishedHasUserInput', () => { + const store = createStore() + store.getState().setLastPublishedHasUserInput(true) + expect(store.getState().lastPublishedHasUserInput).toBe(true) + }) + }) + + describe('Layout Slice Setters', () => { + it('should update workflowCanvasWidth', () => { + const store = createStore() + store.getState().setWorkflowCanvasWidth(1200) + expect(store.getState().workflowCanvasWidth).toBe(1200) + }) + + it('should update workflowCanvasHeight', () => { + const store = createStore() + store.getState().setWorkflowCanvasHeight(800) + expect(store.getState().workflowCanvasHeight).toBe(800) + }) + + it('should update rightPanelWidth', () => { + const store = createStore() + store.getState().setRightPanelWidth(500) + expect(store.getState().rightPanelWidth).toBe(500) + }) + + it('should update nodePanelWidth', () => { + const store = createStore() + store.getState().setNodePanelWidth(350) + expect(store.getState().nodePanelWidth).toBe(350) + }) + + it('should update previewPanelWidth', () => { + const store = createStore() + store.getState().setPreviewPanelWidth(450) + expect(store.getState().previewPanelWidth).toBe(450) + }) + + it('should update otherPanelWidth', () => { + const store = createStore() + store.getState().setOtherPanelWidth(380) + expect(store.getState().otherPanelWidth).toBe(380) + }) + + it('should update bottomPanelWidth', () => { + const store = createStore() + store.getState().setBottomPanelWidth(600) + expect(store.getState().bottomPanelWidth).toBe(600) + }) + + it('should update bottomPanelHeight', () => { + const store = createStore() + store.getState().setBottomPanelHeight(500) + expect(store.getState().bottomPanelHeight).toBe(500) + }) + + it('should update variableInspectPanelHeight', () => { + const store = createStore() + store.getState().setVariableInspectPanelHeight(250) + expect(store.getState().variableInspectPanelHeight).toBe(250) + }) + + it('should update maximizeCanvas', () => { + const store = createStore() + store.getState().setMaximizeCanvas(true) + expect(store.getState().maximizeCanvas).toBe(true) + }) + }) + + describe('localStorage Initialization', () => { + it('should read controlMode from localStorage', () => { + localStorage.setItem('workflow-operation-mode', 'pointer') + const store = createStore() + expect(store.getState().controlMode).toBe('pointer') + }) + + it('should default controlMode to hand when localStorage has no value', () => { + const store = createStore() + expect(store.getState().controlMode).toBe('hand') + }) + + it('should read panelWidth from localStorage', () => { + localStorage.setItem('workflow-node-panel-width', '500') + const store = createStore() + 
expect(store.getState().panelWidth).toBe(500) + }) + + it('should default panelWidth to 420 when localStorage is empty', () => { + const store = createStore() + expect(store.getState().panelWidth).toBe(420) + }) + + it('should read nodePanelWidth from localStorage', () => { + localStorage.setItem('workflow-node-panel-width', '350') + const store = createStore() + expect(store.getState().nodePanelWidth).toBe(350) + }) + + it('should read previewPanelWidth from localStorage', () => { + localStorage.setItem('debug-and-preview-panel-width', '450') + const store = createStore() + expect(store.getState().previewPanelWidth).toBe(450) + }) + + it('should read variableInspectPanelHeight from localStorage', () => { + localStorage.setItem('workflow-variable-inpsect-panel-height', '200') + const store = createStore() + expect(store.getState().variableInspectPanelHeight).toBe(200) + }) + + it('should read maximizeCanvas from localStorage', () => { + localStorage.setItem('workflow-canvas-maximize', 'true') + const store = createStore() + expect(store.getState().maximizeCanvas).toBe(true) + }) + }) + + describe('useStore hook', () => { + it('should read state via selector when wrapped in WorkflowContext', () => { + const store = createStore() + store.getState().setShowSingleRunPanel(true) + + const wrapper = ({ children }: { children: React.ReactNode }) => + React.createElement(WorkflowContext.Provider, { value: store }, children) + + const { result } = renderHook(() => useStore(s => s.showSingleRunPanel), { wrapper }) + expect(result.current).toBe(true) + }) + + it('should throw when used without WorkflowContext.Provider', () => { + expect(() => { + renderHook(() => useStore(s => s.showSingleRunPanel)) + }).toThrow('Missing WorkflowContext.Provider in the tree') + }) + }) + + describe('useWorkflowStore hook', () => { + it('should return the store instance when wrapped in WorkflowContext', () => { + const store = createStore() + const wrapper = ({ children }: { children: React.ReactNode }) => + React.createElement(WorkflowContext.Provider, { value: store }, children) + + const { result } = renderHook(() => useWorkflowStore(), { wrapper }) + expect(result.current).toBe(store) + }) + }) + + describe('Injection', () => { + it('should support injecting additional slice', () => { + const injected: SliceFromInjection = {} + const store = createWorkflowStore({ + injectWorkflowStoreSliceFn: () => injected, + }) + expect(store.getState()).toBeDefined() + }) + }) +}) From b68ee600c1e2206a3c87f3ac01e28af2a7331c04 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Wed, 4 Mar 2026 11:13:43 +0800 Subject: [PATCH 015/256] refactor: migrate workflow onboarding modal to base dialog (#32915) --- .../base/ui/dialog/__tests__/index.spec.tsx | 48 +- web/app/components/base/ui/dialog/index.tsx | 35 +- .../components/workflow-children.tsx | 1 - .../__tests__/index.spec.tsx | 461 +++++------------- .../__tests__/start-node-option.spec.tsx | 1 - .../workflow-onboarding-modal/index.tsx | 62 +-- .../start-node-option.tsx | 13 +- .../start-node-selection-panel.tsx | 7 +- web/docs/overlay-migration.md | 32 +- web/eslint-suppressions.json | 16 - 10 files changed, 239 insertions(+), 437 deletions(-) diff --git a/web/app/components/base/ui/dialog/__tests__/index.spec.tsx b/web/app/components/base/ui/dialog/__tests__/index.spec.tsx index 0861fff603..2e52bd547a 100644 --- a/web/app/components/base/ui/dialog/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/dialog/__tests__/index.spec.tsx @@ -1,11 
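Editorial aside: the suite above constructs stores directly, while the application provides the same store through `WorkflowContext`. A minimal sketch of that wiring, under stated assumptions — only `createWorkflowStore`, `WorkflowContext`, and `useStore` are taken from the files under test; the provider component itself is hypothetical:

```tsx
import * as React from 'react'
import { WorkflowContext } from '../../context'
import { createWorkflowStore, useStore } from '../workflow'

// Hypothetical provider for illustration: one store instance per subtree,
// matching what the hook tests above assemble inline via renderHook wrappers.
function WorkflowStoreProvider({ children }: { children: React.ReactNode }) {
  const storeRef = React.useRef<ReturnType<typeof createWorkflowStore>>()
  if (!storeRef.current)
    storeRef.current = createWorkflowStore({})
  return (
    <WorkflowContext.Provider value={storeRef.current}>
      {children}
    </WorkflowContext.Provider>
  )
}

// Consumers select slices; without a surrounding provider, useStore throws
// 'Missing WorkflowContext.Provider in the tree', exactly as the tests assert.
function SingleRunPanelFlag() {
  const show = useStore(s => s.showSingleRunPanel)
  return <span>{String(show)}</span>
}
```

The `injectWorkflowStoreSliceFn` option exercised in the Injection suite follows the same shape: callers pass a factory returning extra state, and the store merges it with the built-in slices.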
From b68ee600c1e2206a3c87f3ac01e28af2a7331c04 Mon Sep 17 00:00:00 2001
From: yyh <92089059+lyzno1@users.noreply.github.com>
Date: Wed, 4 Mar 2026 11:13:43 +0800
Subject: [PATCH 015/256] refactor: migrate workflow onboarding modal to base
 dialog (#32915)

---
 .../base/ui/dialog/__tests__/index.spec.tsx   |  48 +-
 web/app/components/base/ui/dialog/index.tsx   |  35 +-
 .../components/workflow-children.tsx          |   1 -
 .../__tests__/index.spec.tsx                  | 461 +++++-------------
 .../__tests__/start-node-option.spec.tsx      |   1 -
 .../workflow-onboarding-modal/index.tsx       |  62 +--
 .../start-node-option.tsx                     |  13 +-
 .../start-node-selection-panel.tsx            |   7 +-
 web/docs/overlay-migration.md                 |  32 +-
 web/eslint-suppressions.json                  |  16 -
 10 files changed, 239 insertions(+), 437 deletions(-)

diff --git a/web/app/components/base/ui/dialog/__tests__/index.spec.tsx b/web/app/components/base/ui/dialog/__tests__/index.spec.tsx
index 0861fff603..2e52bd547a 100644
--- a/web/app/components/base/ui/dialog/__tests__/index.spec.tsx
+++ b/web/app/components/base/ui/dialog/__tests__/index.spec.tsx
@@ -1,11 +1,13 @@
 import { Dialog as BaseDialog } from '@base-ui/react/dialog'
-import { render, screen } from '@testing-library/react'
-import { describe, expect, it } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import { describe, expect, it, vi } from 'vitest'
 import {
   Dialog,
   DialogClose,
+  DialogCloseButton,
   DialogContent,
   DialogDescription,
+  DialogPortal,
   DialogTitle,
   DialogTrigger,
 } from '../index'
@@ -29,7 +31,7 @@ describe('Dialog wrapper', () => {
   })
 
   describe('Props', () => {
-    it('should not render close button when closable is omitted', () => {
+    it('should not render close button when DialogCloseButton is not provided', () => {
       render(
         <Dialog open>
           <DialogContent>
             Dialog body
           </DialogContent>
         </Dialog>,
       )
@@ -41,20 +43,47 @@ describe('Dialog wrapper', () => {
       expect(screen.queryByRole('button', { name: 'Close' })).not.toBeInTheDocument()
     })
 
-    it('should render close button when closable is true', () => {
+    it('should render explicit close button with custom aria-label', () => {
       render(
         <Dialog open>
-          <DialogContent closable>
+          <DialogContent>
+            <DialogCloseButton aria-label='Dismiss dialog' />
             Dialog body
           </DialogContent>
         </Dialog>,
       )
 
-      const dialog = screen.getByRole('dialog')
-      const closeButton = screen.getByRole('button', { name: 'Close' })
+      expect(screen.getByRole('button', { name: 'Dismiss dialog' })).toBeInTheDocument()
+    })
 
-      expect(dialog).toContainElement(closeButton)
-      expect(closeButton).toHaveAttribute('aria-label', 'Close')
+    it('should render default close button label when aria-label is omitted', () => {
+      render(
+        <Dialog open>
+          <DialogContent>
+            <DialogCloseButton />
+            Dialog body
+          </DialogContent>
+        </Dialog>,
+      )
+
+      expect(screen.getByRole('button', { name: 'Close' })).toBeInTheDocument()
+    })
+
+    it('should forward close button props to base primitive', () => {
+      const onClick = vi.fn()
+      render(
+        <Dialog open>
+          <DialogContent>
+            <DialogCloseButton data-testid='close-button' disabled onClick={onClick} />
+            Dialog body
+          </DialogContent>
+        </Dialog>,
+      )
+
+      const closeButton = screen.getByTestId('close-button')
+      expect(closeButton).toBeDisabled()
+      fireEvent.click(closeButton)
+      expect(onClick).not.toHaveBeenCalled()
     })
   })
 
@@ -65,6 +94,7 @@ describe('Dialog wrapper', () => {
     expect(DialogTitle).toBe(BaseDialog.Title)
     expect(DialogDescription).toBe(BaseDialog.Description)
     expect(DialogClose).toBe(BaseDialog.Close)
+    expect(DialogPortal).toBe(BaseDialog.Portal)
   })
 })
diff --git a/web/app/components/base/ui/dialog/index.tsx b/web/app/components/base/ui/dialog/index.tsx
index 605fccee09..b86f94c46f 100644
--- a/web/app/components/base/ui/dialog/index.tsx
+++ b/web/app/components/base/ui/dialog/index.tsx
@@ -16,22 +16,42 @@ export const DialogTrigger = BaseDialog.Trigger
 export const DialogTitle = BaseDialog.Title
 export const DialogDescription = BaseDialog.Description
 export const DialogClose = BaseDialog.Close
+export const DialogPortal = BaseDialog.Portal
+
+type DialogCloseButtonProps = Omit<React.ComponentProps<typeof BaseDialog.Close>, 'children'>
+
+export function DialogCloseButton({
+  className,
+  'aria-label': ariaLabel = 'Close',
+  ...props
+}: DialogCloseButtonProps) {
+  return (
+    <BaseDialog.Close
+      aria-label={ariaLabel}
+      className={className}
+      {...props}
+    >
+      <RiCloseLine className='h-4 w-4' />
+    </BaseDialog.Close>
+  )
+}
 
 type DialogContentProps = {
   children: React.ReactNode
   className?: string
   overlayClassName?: string
-  closable?: boolean
 }
 
 export function DialogContent({
   children,
   className,
   overlayClassName,
-  closable = false,
 }: DialogContentProps) {
   return (
-    <BaseDialog.Portal>
+    <>
       <BaseDialog.Backdrop className={overlayClassName} />
       <BaseDialog.Popup className={className}>
-        {closable && (
-          <BaseDialog.Close aria-label='Close'>
-            <RiCloseLine className='h-4 w-4' />
-          </BaseDialog.Close>
-        )}
         {children}
       </BaseDialog.Popup>
-    </BaseDialog.Portal>
+    </>
   )
 }
diff --git a/web/app/components/workflow-app/components/workflow-children.tsx b/web/app/components/workflow-app/components/workflow-children.tsx
index 2634e8da2a..0fbb399dd7 100644
--- a/web/app/components/workflow-app/components/workflow-children.tsx
+++ b/web/app/components/workflow-app/components/workflow-children.tsx
@@ -147,7 +147,6 @@ const WorkflowChildren = () => {
       handleSyncWorkflowDraft(true, false, {
         onSuccess: () => {
           autoGenerateWebhookUrl(newNode.id)
-          console.log('Node successfully saved to draft')
         },
         onError: () => {
           console.error('Failed to save node to draft')
diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/__tests__/index.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/__tests__/index.spec.tsx
index ca627f9679..af38ca113f 100644
--- a/web/app/components/workflow-app/components/workflow-onboarding-modal/__tests__/index.spec.tsx
+++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/__tests__/index.spec.tsx
@@ -1,65 +1,32 @@
+import type { ReactNode } from 'react'
 import { fireEvent, render, screen, waitFor } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
-import * as React from 'react'
 import { BlockEnum } from '@/app/components/workflow/types'
 import WorkflowOnboardingModal from '../index'
 
-// Mock Modal component
-vi.mock('@/app/components/base/modal', () => ({
-  default: function MockModal({
-    isShow,
-    onClose,
-    children,
-    closable,
+vi.mock('@/app/components/workflow/block-selector', () => ({
+  default: function MockNodeSelector({
+    open,
+    onSelect,
+    trigger,
   }: {
-    isShow: boolean
-    onClose?: () => void
-    children?: React.ReactNode
-    closable?: boolean
+    open?: boolean
+    onSelect: (type: BlockEnum, config?: Record<string, unknown>) => void
+    trigger?: ((open: boolean) => ReactNode) | ReactNode
  }) {
-    if (!isShow)
-      return null
-    return (
-      <div data-testid='modal'>
-        {closable && (
-          <button data-testid='modal-close-button' onClick={onClose}>
-            Close
-          </button>
+    return (
+      <div>
+        {typeof trigger === 'function' ? trigger(Boolean(open)) : trigger}
+        {open && (
+          <div>
+            <button data-testid='select-trigger-schedule' onClick={() => onSelect(BlockEnum.TriggerSchedule)} />
+            <button data-testid='select-trigger-webhook' onClick={() => onSelect(BlockEnum.TriggerWebhook, { config: 'test' })} />
+          </div>
         )}
-        {children}
-      </div>
-    )
-  },
-}))
-
-// Mock StartNodeSelectionPanel (using real component would be better for integration,
-// but for this test we'll mock to control behavior)
-vi.mock('../start-node-selection-panel', () => ({
-  default: function MockStartNodeSelectionPanel({
-    onSelectUserInput,
-    onSelectTrigger,
-  }: {
-    onSelectUserInput?: () => void
-    onSelectTrigger?: (type: BlockEnum, config?: Record<string, unknown>) => void
-  }) {
-    return (
-      <div data-testid='start-node-selection-panel'>
-        <button data-testid='select-user-input' onClick={onSelectUserInput} />
-        <button data-testid='select-trigger-schedule' onClick={() => onSelectTrigger?.(BlockEnum.TriggerSchedule)} />
-        <button data-testid='select-trigger-webhook' onClick={() => onSelectTrigger?.(BlockEnum.TriggerWebhook, { config: 'test' })} />
      </div>
    )
  },
}))
@@ -79,401 +46,292 @@ describe('WorkflowOnboardingModal', () => {
     vi.clearAllMocks()
   })
 
-  // Helper function to render component
   const renderComponent = (props = {}) => {
     return render(<WorkflowOnboardingModal isShow onClose={mockOnClose} onSelectStartNode={mockOnSelectStartNode} {...props} />)
   }
+  const getBackdrop = () => document.body.querySelector('.bg-workflow-canvas-canvas-overlay')
+  const getUserInputHeading = () => screen.getByRole('heading', { name: 'workflow.onboarding.userInputFull' })
+  const getTriggerHeading = () => screen.getByRole('heading', { name: 'workflow.onboarding.trigger' })
 
-  // Rendering tests (REQUIRED)
   describe('Rendering', () => {
     it('should render without crashing', () => {
-      // Arrange & Act
       renderComponent()
 
-      // Assert
       expect(screen.getByRole('dialog')).toBeInTheDocument()
     })
 
-    it('should render modal when isShow is true', () => {
-      // Arrange & Act
+    it('should render dialog when isShow is true', () => {
       renderComponent({ isShow: true })
 
-      // Assert
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
     })
 
-    it('should not render modal when isShow is false', () => {
-      // Arrange & Act
+    it('should not render dialog when isShow is false', () => {
       renderComponent({ isShow: false })
 
-      // Assert
-      expect(screen.queryByTestId('modal')).not.toBeInTheDocument()
+      expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
    })
 
-    it('should render modal title', () => {
-      // Arrange & Act
+    it('should render title', () => {
       renderComponent()
 
-      // Assert
       expect(screen.getByText('workflow.onboarding.title')).toBeInTheDocument()
     })
 
-    it('should render modal description', () => {
-      // Arrange & Act
-      const { container } = renderComponent()
+    it('should render description', () => {
+      renderComponent()
 
-      // Assert - Check both parts of description (separated by link)
-      const descriptionDiv = container.querySelector('.body-xs-regular.leading-4')
-      expect(descriptionDiv).toBeInTheDocument()
-      expect(descriptionDiv).toHaveTextContent('workflow.onboarding.description')
+      expect(screen.getByText('workflow.onboarding.description')).toBeInTheDocument()
     })
 
     it('should render StartNodeSelectionPanel', () => {
-      // Arrange & Act
       renderComponent()
 
-      // Assert
-      expect(screen.getByTestId('start-node-selection-panel')).toBeInTheDocument()
+      expect(getUserInputHeading()).toBeInTheDocument()
+      expect(getTriggerHeading()).toBeInTheDocument()
     })
 
-    it('should render ESC tip when modal is shown', () => {
-      // Arrange & Act
+    it('should render ESC tip when shown', () => {
       renderComponent({ isShow: true })
 
-      // Assert
       expect(screen.getByText('workflow.onboarding.escTip.press')).toBeInTheDocument()
       expect(screen.getByText('workflow.onboarding.escTip.key')).toBeInTheDocument()
       expect(screen.getByText('workflow.onboarding.escTip.toDismiss')).toBeInTheDocument()
     })
 
-    it('should not render ESC tip when modal is hidden', () => {
-      // Arrange & Act
+    it('should not render ESC tip when hidden', () => {
       renderComponent({ isShow: false })
 
-      // Assert
       expect(screen.queryByText('workflow.onboarding.escTip.press')).not.toBeInTheDocument()
     })
 
     it('should have correct styling for title', () => {
-      // Arrange & Act
       renderComponent()
 
-      // Assert
       const title = screen.getByText('workflow.onboarding.title')
       expect(title).toHaveClass('title-2xl-semi-bold')
       expect(title).toHaveClass('text-text-primary')
     })
 
-    it('should have modal close button', () => {
-      // Arrange & Act
+    it('should have close button', () => {
       renderComponent()
 
-      // Assert
-      expect(screen.getByTestId('modal-close-button')).toBeInTheDocument()
+      expect(screen.getByRole('button', { name: 'Close' })).toBeInTheDocument()
+    })
+
+    it('should render workflow canvas backdrop when shown', () => {
+      renderComponent({ isShow: true })
+
+      const backdrop = getBackdrop()
+      expect(backdrop).toBeInTheDocument()
+      expect(backdrop).not.toHaveClass('opacity-20')
     })
   })
 
-  // Props tests (REQUIRED)
   describe('Props', () => {
     it('should accept isShow prop', () => {
-      // Arrange & Act
       const { rerender } = renderComponent({ isShow: false })
 
-      // Assert
-      expect(screen.queryByTestId('modal')).not.toBeInTheDocument()
+      expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
 
-      // Act
       rerender(<WorkflowOnboardingModal isShow onClose={mockOnClose} onSelectStartNode={mockOnSelectStartNode} />)
 
-      // Assert
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
     })
 
     it('should accept onClose prop', () => {
-      // Arrange
       const customOnClose = vi.fn()
 
-      // Act
       renderComponent({ onClose: customOnClose })
 
-      // Assert
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
     })
 
     it('should accept onSelectStartNode prop', () => {
-      // Arrange
       const customHandler = vi.fn()
 
-      // Act
       renderComponent({ onSelectStartNode: customHandler })
 
-      // Assert
-      expect(screen.getByTestId('start-node-selection-panel')).toBeInTheDocument()
-    })
-
-    it('should handle undefined onClose gracefully', () => {
-      // Arrange & Act
-      expect(() => {
-        renderComponent({ onClose: undefined })
-      }).not.toThrow()
-
-      // Assert
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
-    })
-
-    it('should handle undefined onSelectStartNode gracefully', () => {
-      // Arrange & Act
-      expect(() => {
-        renderComponent({ onSelectStartNode: undefined })
-      }).not.toThrow()
-
-      // Assert
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(getUserInputHeading()).toBeInTheDocument()
     })
   })
 
-  // User Interactions - Start Node Selection
   describe('User Interactions - Start Node Selection', () => {
     it('should call onSelectStartNode with Start block when user input is selected', async () => {
-      // Arrange
       const user = userEvent.setup()
       renderComponent()
 
-      // Act
-      const userInputButton = screen.getByTestId('select-user-input')
-      await user.click(userInputButton)
+      await user.click(getUserInputHeading())
 
-      // Assert
       expect(mockOnSelectStartNode).toHaveBeenCalledTimes(1)
       expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.Start)
     })
 
-    it('should call onClose after selecting user input', async () => {
-      // Arrange
+    it('should not call onClose when selecting user input (parent handles closing)', async () => {
       const user = userEvent.setup()
       renderComponent()
 
-      // Act
-      const userInputButton = screen.getByTestId('select-user-input')
-      await user.click(userInputButton)
+      await user.click(getUserInputHeading())
 
-      // Assert
-      expect(mockOnClose).toHaveBeenCalledTimes(1)
+      expect(mockOnClose).not.toHaveBeenCalled()
     })
 
     it('should call onSelectStartNode with trigger type when trigger is selected', async () => {
-      // Arrange
       const user = userEvent.setup()
       renderComponent()
 
-      // Act
-      const triggerButton = screen.getByTestId('select-trigger-schedule')
-      await user.click(triggerButton)
+      await user.click(getTriggerHeading())
+      await user.click(screen.getByTestId('select-trigger-schedule'))
 
-      // Assert
       expect(mockOnSelectStartNode).toHaveBeenCalledTimes(1)
       expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.TriggerSchedule, undefined)
     })
 
-    it('should call onClose after selecting trigger', async () => {
-      // Arrange
+    it('should not call onClose when selecting trigger (parent handles closing)', async () => {
       const user = userEvent.setup()
       renderComponent()
 
-      // Act
-      const triggerButton = screen.getByTestId('select-trigger-schedule')
-      await user.click(triggerButton)
+      await user.click(getTriggerHeading())
+      await user.click(screen.getByTestId('select-trigger-schedule'))
 
-      // Assert
-      expect(mockOnClose).toHaveBeenCalledTimes(1)
+      expect(mockOnClose).not.toHaveBeenCalled()
     })
 
     it('should pass tool config when selecting trigger with config', async () => {
-      // Arrange
       const user = userEvent.setup()
       renderComponent()
 
-      // Act
-      const webhookButton = screen.getByTestId('select-trigger-webhook')
-      await user.click(webhookButton)
+      await user.click(getTriggerHeading())
+      await user.click(screen.getByTestId('select-trigger-webhook'))
 
-      // Assert
       expect(mockOnSelectStartNode).toHaveBeenCalledTimes(1)
       expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.TriggerWebhook, { config: 'test' })
-      expect(mockOnClose).toHaveBeenCalledTimes(1)
+      expect(mockOnClose).not.toHaveBeenCalled()
     })
   })
 
-  // User Interactions - Modal Close
-  describe('User Interactions - Modal Close', () => {
+  describe('User Interactions - Dialog Close', () => {
     it('should call onClose when close button is clicked', async () => {
-      // Arrange
       const user = userEvent.setup()
       renderComponent()
 
-      // Act
-      const closeButton = screen.getByTestId('modal-close-button')
-      await user.click(closeButton)
+      await user.click(screen.getByRole('button', { name: 'Close' }))
 
-      // Assert
       expect(mockOnClose).toHaveBeenCalledTimes(1)
     })
 
     it('should not call onSelectStartNode when closing without selection', async () => {
-      // Arrange
       const user = userEvent.setup()
       renderComponent()
 
-      // Act
-      const closeButton = screen.getByTestId('modal-close-button')
-      await user.click(closeButton)
+      await user.click(screen.getByRole('button', { name: 'Close' }))
 
-      // Assert
       expect(mockOnSelectStartNode).not.toHaveBeenCalled()
       expect(mockOnClose).toHaveBeenCalledTimes(1)
     })
+
+    it('should call onClose exactly once when close button is clicked (no double-close)', async () => {
+      const user = userEvent.setup()
+      const onClose = vi.fn()
+      renderComponent({ onClose })
+
+      await user.click(screen.getByRole('button', { name: 'Close' }))
+
+      expect(onClose).toHaveBeenCalledTimes(1)
+    })
+
+    it('should not call onClose when clicking backdrop', async () => {
+      const user = userEvent.setup()
+      renderComponent()
+
+      const backdrop = getBackdrop()
+      expect(backdrop).toBeInTheDocument()
+      if (!backdrop)
+        throw new Error('backdrop should exist when dialog is open')
+
+      await user.click(backdrop)
+
+      expect(mockOnClose).not.toHaveBeenCalled()
+    })
   })
 
-  // Keyboard Event Handling
   describe('Keyboard Event Handling', () => {
     it('should call onClose when ESC key is pressed', () => {
-      // Arrange
       renderComponent({ isShow: true })
 
-      // Act
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
+      fireEvent.keyDown(screen.getByRole('dialog'), { key: 'Escape' })
 
-      // Assert
       expect(mockOnClose).toHaveBeenCalledTimes(1)
     })
 
-    it('should not call onClose when other keys are pressed', () => {
-      // Arrange
-      renderComponent({ isShow: true })
-
-      // Act
-      fireEvent.keyDown(document, { key: 'Enter', code: 'Enter' })
-      fireEvent.keyDown(document, { key: 'Tab', code: 'Tab' })
-      fireEvent.keyDown(document, { key: 'a', code: 'KeyA' })
-
-      // Assert
-      expect(mockOnClose).not.toHaveBeenCalled()
-    })
-
-    it('should not call onClose when ESC is pressed but modal is hidden', () => {
-      // Arrange
+    it('should not call onClose when ESC is pressed but dialog is hidden', () => {
       renderComponent({ isShow: false })
 
-      // Act
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
+      fireEvent.keyDown(document, { key: 'Escape' })
 
-      // Assert
       expect(mockOnClose).not.toHaveBeenCalled()
     })
 
-    it('should clean up event listener on unmount', () => {
-      // Arrange
+    it('should clean up on unmount', () => {
       const { unmount } = renderComponent({ isShow: true })
 
-      // Act
       unmount()
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
+      fireEvent.keyDown(document, { key: 'Escape' })
 
-      // Assert
       expect(mockOnClose).not.toHaveBeenCalled()
     })
 
-    it('should update event listener when isShow changes', () => {
-      // Arrange
+    it('should respond to ESC based on open state', () => {
       const { rerender } = renderComponent({ isShow: true })
 
-      // Act - Press ESC when shown
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
-
-      // Assert
+      fireEvent.keyDown(screen.getByRole('dialog'), { key: 'Escape' })
       expect(mockOnClose).toHaveBeenCalledTimes(1)
 
-      // Act - Hide modal and clear mock
       mockOnClose.mockClear()
       rerender(<WorkflowOnboardingModal isShow={false} onClose={mockOnClose} onSelectStartNode={mockOnSelectStartNode} />)
 
-      // Act - Press ESC when hidden
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
-
-      // Assert
+      fireEvent.keyDown(document, { key: 'Escape' })
       expect(mockOnClose).not.toHaveBeenCalled()
     })
-
-    it('should handle multiple ESC key presses', () => {
-      // Arrange
-      renderComponent({ isShow: true })
-
-      // Act
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
-
-      // Assert
-      expect(mockOnClose).toHaveBeenCalledTimes(3)
-    })
   })
 
-  // Edge Cases (REQUIRED)
   describe('Edge Cases', () => {
-    it('should handle rapid modal show/hide toggling', async () => {
-      // Arrange
+    it('should handle rapid show/hide toggling', async () => {
       const { rerender } = renderComponent({ isShow: false })
 
-      // Assert
-      expect(screen.queryByTestId('modal')).not.toBeInTheDocument()
+      expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
 
-      // Act
       rerender(<WorkflowOnboardingModal isShow onClose={mockOnClose} onSelectStartNode={mockOnSelectStartNode} />)
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
 
-      // Assert
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
-
-      // Act
       rerender(<WorkflowOnboardingModal isShow={false} onClose={mockOnClose} onSelectStartNode={mockOnSelectStartNode} />)
-
-      // Assert
       await waitFor(() => {
-        expect(screen.queryByTestId('modal')).not.toBeInTheDocument()
+        expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
       })
     })
 
     it('should handle selecting multiple nodes in sequence', async () => {
-      // Arrange
       const user = userEvent.setup()
       const { rerender } = renderComponent()
 
-      // Act - Select user input
-      await user.click(screen.getByTestId('select-user-input'))
-
-      // Assert
+      await user.click(getUserInputHeading())
       expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.Start)
-      expect(mockOnClose).toHaveBeenCalledTimes(1)
+      expect(mockOnClose).not.toHaveBeenCalled()
 
-      // Act - Re-show modal and select trigger
-      mockOnClose.mockClear()
       mockOnSelectStartNode.mockClear()
       rerender(<WorkflowOnboardingModal isShow onClose={mockOnClose} onSelectStartNode={mockOnSelectStartNode} />)
+      await user.click(getTriggerHeading())
       await user.click(screen.getByTestId('select-trigger-schedule'))
-
-      // Assert
       expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.TriggerSchedule, undefined)
-      expect(mockOnClose).toHaveBeenCalledTimes(1)
+      expect(mockOnClose).not.toHaveBeenCalled()
     })
 
     it('should handle prop updates correctly', () => {
-      // Arrange
       const { rerender } = renderComponent({ isShow: true })
 
-      // Assert
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
 
-      // Act - Update props
       const newOnClose = vi.fn()
       const newOnSelectStartNode = vi.fn()
       rerender(
@@ -484,169 +342,120 @@ describe('WorkflowOnboardingModal', () => {
         <WorkflowOnboardingModal
           isShow
           onClose={newOnClose}
           onSelectStartNode={newOnSelectStartNode}
         />,
       )
 
-      // Assert - Modal still renders with new props
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
     })
 
-    it('should handle onClose being called multiple times', async () => {
-      // Arrange
-      const user = userEvent.setup()
-      renderComponent()
-
-      // Act
-      await user.click(screen.getByTestId('modal-close-button'))
-      await user.click(screen.getByTestId('modal-close-button'))
-
-      // Assert
-      expect(mockOnClose).toHaveBeenCalledTimes(2)
-    })
-
-    it('should maintain modal state when props change', () => {
-      // Arrange
+    it('should maintain dialog when props change', () => {
       const { rerender } = renderComponent({ isShow: true })
 
-      // Assert
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
 
-      // Act - Change onClose handler
       const newOnClose = vi.fn()
       rerender(<WorkflowOnboardingModal isShow onClose={newOnClose} onSelectStartNode={mockOnSelectStartNode} />)
 
-      // Assert - Modal should still be visible
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
     })
   })
 
-  // Accessibility Tests
   describe('Accessibility', () => {
     it('should have dialog role', () => {
-      // Arrange & Act
       renderComponent()
 
-      // Assert
       expect(screen.getByRole('dialog')).toBeInTheDocument()
     })
 
     it('should have proper heading hierarchy', () => {
-      // Arrange & Act
-      const { container } = renderComponent()
+      renderComponent()
 
-      // Assert
-      const heading = container.querySelector('h3')
+      const heading = screen.getByRole('heading', { name: 'workflow.onboarding.title' })
       expect(heading).toBeInTheDocument()
       expect(heading).toHaveTextContent('workflow.onboarding.title')
     })
 
-    it('should have keyboard navigation support via ESC key', () => {
-      // Arrange
+    it('should expose dialog accessible name from title', () => {
+      renderComponent()
+
+      expect(screen.getByRole('dialog', { name: 'workflow.onboarding.title' })).toBeInTheDocument()
+    })
+
+    it('should support ESC key dismissal', () => {
       renderComponent({ isShow: true })
 
-      // Act
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
+      fireEvent.keyDown(screen.getByRole('dialog'), { key: 'Escape' })
 
-      // Assert
       expect(mockOnClose).toHaveBeenCalledTimes(1)
     })
 
     it('should have visible ESC key hint', () => {
-      // Arrange & Act
       renderComponent({ isShow: true })
 
-      // Assert - ShortcutsName component renders keys in div elements with system-kbd class
       const escKey = screen.getByText('workflow.onboarding.escTip.key')
-      // ShortcutsName renders a <div> with class system-kbd, not a <kbd> element
       expect(escKey.closest('.system-kbd')).toBeInTheDocument()
     })
 
     it('should have descriptive text for ESC functionality', () => {
-      // Arrange & Act
       renderComponent({ isShow: true })
 
-      // Assert
       expect(screen.getByText('workflow.onboarding.escTip.press')).toBeInTheDocument()
       expect(screen.getByText('workflow.onboarding.escTip.toDismiss')).toBeInTheDocument()
     })
 
     it('should have proper text color classes', () => {
-      // Arrange & Act
       renderComponent()
 
-      // Assert
       const title = screen.getByText('workflow.onboarding.title')
       expect(title).toHaveClass('text-text-primary')
     })
   })
 
-  // Integration Tests
   describe('Integration', () => {
     it('should complete full flow of selecting user input node', async () => {
-      // Arrange
       const user = userEvent.setup()
       renderComponent()
 
-      // Assert - Initial state
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
       expect(screen.getByText('workflow.onboarding.title')).toBeInTheDocument()
-      expect(screen.getByTestId('start-node-selection-panel')).toBeInTheDocument()
+      expect(getUserInputHeading()).toBeInTheDocument()
 
-      // Act - Select user input
-      await user.click(screen.getByTestId('select-user-input'))
+      await user.click(getUserInputHeading())
 
-      // Assert - Callbacks called
       expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.Start)
-      expect(mockOnClose).toHaveBeenCalledTimes(1)
+      expect(mockOnClose).not.toHaveBeenCalled()
     })
 
     it('should complete full flow of selecting trigger node', async () => {
-      // Arrange
       const user = userEvent.setup()
       renderComponent()
 
-      // Assert - Initial state
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
 
-      // Act - Select trigger
+      await user.click(getTriggerHeading())
       await user.click(screen.getByTestId('select-trigger-webhook'))
 
-      // Assert - Callbacks called with config
       expect(mockOnSelectStartNode).toHaveBeenCalledWith(BlockEnum.TriggerWebhook, { config: 'test' })
-      expect(mockOnClose).toHaveBeenCalledTimes(1)
+      expect(mockOnClose).not.toHaveBeenCalled()
    })
 
     it('should render all components in correct hierarchy', () => {
-      // Arrange & Act
-      const { container } = renderComponent()
+      renderComponent()
 
-      // Assert - Modal is the root
-      expect(screen.getByTestId('modal')).toBeInTheDocument()
-
-      // Assert - Header elements
-      const heading = container.querySelector('h3')
-      expect(heading).toBeInTheDocument()
-
-      // Assert - Selection panel
-      expect(screen.getByTestId('start-node-selection-panel')).toBeInTheDocument()
-
-      // Assert - ESC tip
+      const dialog = screen.getByRole('dialog')
+      expect(dialog).toBeInTheDocument()
+      expect(screen.getByText('workflow.onboarding.title')).toBeInTheDocument()
+      expect(getUserInputHeading()).toBeInTheDocument()
       expect(screen.getByText('workflow.onboarding.escTip.key')).toBeInTheDocument()
+      expect(dialog).not.toContainElement(screen.getByText('workflow.onboarding.escTip.key'))
     })
 
     it('should coordinate between keyboard and click interactions', async () => {
-      // Arrange
       const user = userEvent.setup()
       renderComponent()
 
-      // Act - Click close button
-      await user.click(screen.getByTestId('modal-close-button'))
-
-      // Assert
+      await user.click(screen.getByRole('button', { name: 'Close' }))
       expect(mockOnClose).toHaveBeenCalledTimes(1)
 
-      // Act - Clear and try ESC key
       mockOnClose.mockClear()
-      fireEvent.keyDown(document, { key: 'Escape', code: 'Escape' })
-
-      // Assert
+      fireEvent.keyDown(screen.getByRole('dialog'), { key: 'Escape' })
       expect(mockOnClose).toHaveBeenCalledTimes(1)
     })
   })
diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/__tests__/start-node-option.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/__tests__/start-node-option.spec.tsx
index 04c223499a..2739c51b62 100644
--- a/web/app/components/workflow-app/components/workflow-onboarding-modal/__tests__/start-node-option.spec.tsx
+++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/__tests__/start-node-option.spec.tsx
@@ -47,7 +47,6 @@ describe('StartNodeOption', () => {
     // Assert
     const title = screen.getByText('Test Title')
     expect(title).toBeInTheDocument()
-    expect(title).toHaveClass('system-md-semi-bold')
     expect(title).toHaveClass('text-text-primary')
   })
 
diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx
index 16bae51246..8db281bda0 100644
--- a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx
+++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx
@@ -1,12 +1,8 @@
 'use client'
 import type { FC } from 'react'
 import type { PluginDefaultValue } from '@/app/components/workflow/block-selector/types'
-import {
-  useCallback,
-  useEffect,
-} from 'react'
 import { useTranslation } from 'react-i18next'
-import Modal from '@/app/components/base/modal'
+import { Dialog, DialogCloseButton, DialogContent, DialogDescription, DialogPortal, DialogTitle } from '@/app/components/base/ui/dialog'
 import ShortcutsName from '@/app/components/workflow/shortcuts-name'
 import { BlockEnum } from '@/app/components/workflow/types'
 import StartNodeSelectionPanel from './start-node-selection-panel'
@@ -24,63 +20,39 @@ const WorkflowOnboardingModal: FC<WorkflowOnboardingModalProps> = ({
 }) => {
   const { t } = useTranslation()
 
-  const handleSelectUserInput = useCallback(() => {
-    onSelectStartNode(BlockEnum.Start)
-    onClose() // Close modal after selection
-  }, [onSelectStartNode, onClose])
-
-  const handleTriggerSelect = useCallback((nodeType: BlockEnum, toolConfig?: PluginDefaultValue) => {
-    onSelectStartNode(nodeType, toolConfig)
-    onClose() // Close modal after selection
-  }, [onSelectStartNode, onClose])
-
-  useEffect(() => {
-    const handleEsc = (e: KeyboardEvent) => {
-      if (e.key === 'Escape' && isShow)
-        onClose()
-    }
-    document.addEventListener('keydown', handleEsc)
-    return () => document.removeEventListener('keydown', handleEsc)
-  }, [isShow, onClose])
-
   return (
-    <>
-      <Modal
-        isShow={isShow}
-        onClose={onClose}
-        closable
-      >
+    <Dialog open={isShow} onOpenChange={open => !open && onClose()}>
+      <DialogPortal>
+        <DialogContent overlayClassName='bg-workflow-canvas-canvas-overlay'>
+          <DialogCloseButton />
-        {/* Header */}
-        <div>
-          <h3 className='title-2xl-semi-bold text-text-primary'>
+          <DialogTitle className='title-2xl-semi-bold text-text-primary'>
             {t('onboarding.title', { ns: 'workflow' })}
-          </h3>
-          <div className='body-xs-regular leading-4'>
+          </DialogTitle>
+          <DialogDescription className='body-xs-regular leading-4'>
             {t('onboarding.description', { ns: 'workflow' })}
-          </div>
-        </div>
+          </DialogDescription>
 
-        {/* Content */}
-        <StartNodeSelectionPanel
-          onSelectUserInput={handleSelectUserInput}
-          onSelectTrigger={handleTriggerSelect}
-        />
-      </Modal>
+          <StartNodeSelectionPanel
+            onSelectUserInput={() => onSelectStartNode(BlockEnum.Start)}
+            onSelectTrigger={onSelectStartNode}
+          />
+        </DialogContent>
 
-      {/* ESC tip below modal */}
-      {isShow && (
-        <div>
+        <div>
           {t('onboarding.escTip.press', { ns: 'workflow' })}
           <ShortcutsName keys={[t('onboarding.escTip.key', { ns: 'workflow' })]} />
           {t('onboarding.escTip.toDismiss', { ns: 'workflow' })}
         </div>
-      )}
-    </>
+      </DialogPortal>
+    </Dialog>
   )
 }
 
diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx
index 77f2a842c9..8b1ce699e7 100644
--- a/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx
+++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx
@@ -1,6 +1,5 @@
 'use client'
 import type { FC, ReactNode } from 'react'
-import { cn } from '@/utils/classnames'
 
 type StartNodeOptionProps = {
   icon: ReactNode
@@ -20,22 +19,18 @@ const StartNodeOption: FC<StartNodeOptionProps> = ({
   return (
-      {/* Icon */}
       <div>
         {icon}
       </div>
 
-      {/* Text content */}
-      <div className={cn('flex flex-col')}>
+      <div className='flex flex-col'>
-        <div className={cn('system-md-semi-bold text-text-primary')}>
+        <div className='text-text-primary'>
           {title}
           {subtitle && (
-            <span className={cn('system-md-regular text-text-tertiary')}>
+            <span className='system-md-regular text-text-tertiary'>
               {' '}
               {subtitle}
             </span>
           )}
         </div>
 
-        <div className={cn('system-xs-regular text-text-tertiary')}>
+        <div className='system-xs-regular text-text-tertiary'>
           {description}
         </div>
       </div>
diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.tsx
index 6d13cbf6a8..b4a1cd135b 100644
--- a/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.tsx
+++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.tsx
@@ -21,10 +21,6 @@ const StartNodeSelectionPanel: FC<StartNodeSelectionPanelProps> = ({
   const { t } = useTranslation()
   const [showTriggerSelector, setShowTriggerSelector] = useState(false)
 
-  const handleTriggerClick = useCallback(() => {
-    setShowTriggerSelector(true)
-  }, [])
-
   const handleTriggerSelect = useCallback((nodeType: BlockEnum, toolConfig?: PluginDefaultValue) => {
     setShowTriggerSelector(false)
     onSelectTrigger(nodeType, toolConfig)
@@ -67,10 +63,9 @@ const StartNodeSelectionPanel: FC<StartNodeSelectionPanelProps> = ({
           )}
           title={t('onboarding.trigger', { ns: 'workflow' })}
           description={t('onboarding.triggerDescription', { ns: 'workflow' })}
-          onClick={handleTriggerClick}
+          onClick={() => setShowTriggerSelector(true)}
         />
       )}
-      popupClassName="z-[1200]"
     />
) diff --git a/web/docs/overlay-migration.md b/web/docs/overlay-migration.md index 77c5fe5b04..ffe54afa1a 100644 --- a/web/docs/overlay-migration.md +++ b/web/docs/overlay-migration.md @@ -1,10 +1,14 @@ # Overlay Migration Guide -This document tracks the migration away from legacy `portal-to-follow-elem` APIs. +This document tracks the migration away from legacy overlay APIs. ## Scope -- Deprecated API: `@/app/components/base/portal-to-follow-elem` +- Deprecated imports: + - `@/app/components/base/portal-to-follow-elem` + - `@/app/components/base/tooltip` + - `@/app/components/base/modal` + - `@/app/components/base/select` (including `custom` / `pure`) - Replacement primitives: - `@/app/components/base/ui/tooltip` - `@/app/components/base/ui/dropdown-menu` @@ -15,33 +19,33 @@ This document tracks the migration away from legacy `portal-to-follow-elem` APIs ## ESLint policy -- `no-restricted-imports` blocks new usage of `portal-to-follow-elem`. -- The rule is enabled for normal source files and test files are excluded. -- Legacy `app/components/base/*` callers are temporarily allowlisted in ESLint config. +- `no-restricted-imports` blocks all deprecated imports listed above. +- The rule is enabled for normal source files (`.ts` / `.tsx`) and test files are excluded. +- Legacy `app/components/base/*` callers are temporarily allowlisted in `OVERLAY_MIGRATION_LEGACY_BASE_FILES` (`web/eslint.constants.mjs`). - New files must not be added to the allowlist without migration owner approval. ## Migration phases 1. Business/UI features outside `app/components/base/**` - - Migrate old calls to semantic primitives. - - Keep `eslint-suppressions.json` stable or shrinking. + - Migrate old calls to semantic primitives from `@/app/components/base/ui/**`. + - Keep deprecated imports out of newly touched files. 1. Legacy base components in allowlist - Migrate allowlisted base callers gradually. - - Remove migrated files from allowlist immediately. + - Remove migrated files from `OVERLAY_MIGRATION_LEGACY_BASE_FILES` immediately. 1. Cleanup - - Remove remaining suppressions for `no-restricted-imports`. - - Remove legacy `portal-to-follow-elem` implementation. + - Remove remaining allowlist entries. + - Remove legacy overlay implementations when import count reaches zero. -## Suppression maintenance +## Allowlist maintenance - After each migration batch, run: ```sh -pnpm eslint --prune-suppressions --pass-on-unpruned-suppressions +pnpm -C web lint:fix --prune-suppressions ``` -- Never increase suppressions to bypass new code. -- Prefer direct migration over adding suppression entries. +- If a migrated file was in the allowlist, remove it from `web/eslint.constants.mjs` in the same PR. +- Never increase allowlist scope to bypass new code. 
diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json
index 4b64291612..f0eb575d8d 100644
--- a/web/eslint-suppressions.json
+++ b/web/eslint-suppressions.json
@@ -6298,9 +6298,6 @@
     }
   },
   "app/components/workflow-app/components/workflow-children.tsx": {
-    "no-console": {
-      "count": 1
-    },
     "ts/no-explicit-any": {
       "count": 3
     }
@@ -6310,19 +6307,6 @@
       "count": 2
     }
   },
-  "app/components/workflow-app/components/workflow-onboarding-modal/index.tsx": {
-    "no-restricted-imports": {
-      "count": 1
-    },
-    "tailwindcss/enforce-consistent-class-order": {
-      "count": 3
-    }
-  },
-  "app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx": {
-    "tailwindcss/enforce-consistent-class-order": {
-      "count": 2
-    }
-  },
   "app/components/workflow-app/hooks/use-DSL.ts": {
     "ts/no-explicit-any": {
       "count": 1
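Editorial aside: for callers migrating off the legacy `Modal`, the composition this patch enables looks roughly like the sketch below. This is a hypothetical consumer, not code from the patch; the prop names come from the `base/ui/dialog` exports above, and `onOpenChange` is the Base UI `Dialog.Root` callback the wrapper re-exports:

```tsx
import {
  Dialog,
  DialogCloseButton,
  DialogContent,
  DialogPortal,
  DialogTitle,
} from '@/app/components/base/ui/dialog'

type Props = {
  isShow: boolean
  onClose: () => void
}

// Open state stays controlled by the caller; ESC and the close button both
// funnel through onOpenChange, replacing bespoke document keydown listeners.
const ExampleDialog = ({ isShow, onClose }: Props) => (
  <Dialog open={isShow} onOpenChange={open => !open && onClose()}>
    <DialogPortal>
      <DialogContent>
        <DialogCloseButton aria-label='Close' />
        <DialogTitle>Example</DialogTitle>
      </DialogContent>
    </DialogPortal>
  </Dialog>
)

export default ExampleDialog
```

Because `DialogContent` no longer portals itself, callers decide where the `DialogPortal` boundary sits — which is what lets the onboarding modal render its ESC tip outside the dialog popup.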
From b584434e2862465694949b35a60bf38da8a3d76a Mon Sep 17 00:00:00 2001
From: wangxiaolei
Date: Wed, 4 Mar 2026 11:52:43 +0800
Subject: [PATCH 016/256] feat: redis connection support max connections
 (#32935)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
 api/.env.example                             |  2 +
 api/configs/middleware/cache/redis_config.py |  5 +++
 api/extensions/ext_redis.py                  | 45 ++++++++++++++------
 docker/.env.example                          |  3 ++
 docker/docker-compose.yaml                   |  1 +
 docker/middleware.env.example                |  3 ++
 6 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/api/.env.example b/api/.env.example
index 38a096da0a..ab8b6c5287 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -42,6 +42,8 @@ REFRESH_TOKEN_EXPIRE_DAYS=30
 # redis configuration
 REDIS_HOST=localhost
 REDIS_PORT=6379
+# Optional: limit total connections in connection pool (unset for default)
+# REDIS_MAX_CONNECTIONS=200
 REDIS_USERNAME=
 REDIS_PASSWORD=difyai123456
 REDIS_USE_SSL=false
diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py
index 4705b28c69..367cb52731 100644
--- a/api/configs/middleware/cache/redis_config.py
+++ b/api/configs/middleware/cache/redis_config.py
@@ -111,3 +111,8 @@ class RedisConfig(BaseSettings):
         description="Enable client side cache in redis",
         default=False,
     )
+
+    REDIS_MAX_CONNECTIONS: PositiveInt | None = Field(
+        description="Maximum connections in the Redis connection pool (unset for library default)",
+        default=None,
+    )
diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py
index 658e6a0738..cadd9cb263 100644
--- a/api/extensions/ext_redis.py
+++ b/api/extensions/ext_redis.py
@@ -181,13 +181,18 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
 
     sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")]
 
+    sentinel_kwargs = {
+        "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
+        "username": dify_config.REDIS_SENTINEL_USERNAME,
+        "password": dify_config.REDIS_SENTINEL_PASSWORD,
+    }
+
+    if dify_config.REDIS_MAX_CONNECTIONS:
+        sentinel_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
+
     sentinel = Sentinel(
         sentinel_hosts,
-        sentinel_kwargs={
-            "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
-            "username": dify_config.REDIS_SENTINEL_USERNAME,
-            "password": dify_config.REDIS_SENTINEL_PASSWORD,
-        },
+        sentinel_kwargs=sentinel_kwargs,
     )
 
     master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
@@ -204,12 +209,15 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
         for node in dify_config.REDIS_CLUSTERS.split(",")
     ]
 
-    cluster: RedisCluster = RedisCluster(
-        startup_nodes=nodes,
-        password=dify_config.REDIS_CLUSTERS_PASSWORD,
-        protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
-        cache_config=_get_cache_configuration(),
-    )
+    cluster_kwargs: dict[str, Any] = {
+        "startup_nodes": nodes,
+        "password": dify_config.REDIS_CLUSTERS_PASSWORD,
+        "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
+        "cache_config": _get_cache_configuration(),
+    }
+    if dify_config.REDIS_MAX_CONNECTIONS:
+        cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
+    cluster: RedisCluster = RedisCluster(**cluster_kwargs)
 
     return cluster
 
@@ -225,6 +233,9 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
         }
     )
 
+    if dify_config.REDIS_MAX_CONNECTIONS:
+        redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
+
     if ssl_kwargs:
         redis_params.update(ssl_kwargs)
 
@@ -234,9 +245,17 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
 
 
 def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
+    max_conns = dify_config.REDIS_MAX_CONNECTIONS
     if use_clusters:
-        return RedisCluster.from_url(pubsub_url)
-    return redis.Redis.from_url(pubsub_url)
+        if max_conns:
+            return RedisCluster.from_url(pubsub_url, max_connections=max_conns)
+        else:
+            return RedisCluster.from_url(pubsub_url)
+
+    if max_conns:
+        return redis.Redis.from_url(pubsub_url, max_connections=max_conns)
+    else:
+        return redis.Redis.from_url(pubsub_url)
 
 
 def init_app(app: DifyApp):
diff --git a/docker/.env.example b/docker/.env.example
index 3d0009711d..399242cea3 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -349,6 +349,9 @@ REDIS_SSL_CERTFILE=
 REDIS_SSL_KEYFILE=
 # Path to client private key file for SSL authentication
 REDIS_DB=0
+# Optional: limit total Redis connections used by API/Worker (unset for default)
+# Align with API's REDIS_MAX_CONNECTIONS in configs
+REDIS_MAX_CONNECTIONS=
 
 # Whether to use Redis Sentinel mode.
 # If set to true, the application will automatically discover and connect to the master node through Sentinel.
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 62421d7ec4..8ab3af9788 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -90,6 +90,7 @@ x-shared-env: &shared-api-worker-env
   REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-}
   REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-}
   REDIS_DB: ${REDIS_DB:-0}
+  REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-}
   REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false}
   REDIS_SENTINELS: ${REDIS_SENTINELS:-}
   REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-}
diff --git a/docker/middleware.env.example b/docker/middleware.env.example
index c88dbe5511..7b28a77fe3 100644
--- a/docker/middleware.env.example
+++ b/docker/middleware.env.example
@@ -91,6 +91,9 @@ MYSQL_INNODB_FLUSH_LOG_AT_TRX_COMMIT=2
 # -----------------------------
 REDIS_HOST_VOLUME=./volumes/redis/data
 REDIS_PASSWORD=difyai123456
+# Optional: limit total Redis connections used by API/Worker (unset for default)
+# Align with API's REDIS_MAX_CONNECTIONS in configs
+REDIS_MAX_CONNECTIONS=
 
 # ------------------------------
 # Environment Variables for sandbox Service

From e14b09d4dbb0a01872005886dc788cf4744037de Mon Sep 17 00:00:00 2001
From: wangxiaolei
Date: Wed, 4 Mar 2026 13:18:32 +0800
Subject: [PATCH 017/256] refactor: human input node decouple db (#32900)

---
 api/.importlinter                                  |  4 --
 .../advanced_chat/generate_task_pipeline.py        |  1 -
 .../repositories/human_input_repository.py         | 29 +++++------
 api/core/workflow/node_factory.py                  | 11 +++++
 .../nodes/human_input/human_input_node.py          |  9 +---
 api/services/human_input_service.py                |  2 +-
 api/services/workflow_service.py                   |  3 +-
 api/tasks/human_input_timeout_tasks.py             |  2 +-
 .../test_human_input_form_repository_impl.py       |  8 +--
 .../test_mail_human_input_delivery_task.py         |  3 +-
 .../test_human_input_form_repository_impl.py       | 49 +++++++++++++------
 .../tasks/test_human_input_timeout_tasks.py        |  6 +--
 12 files changed, 69 insertions(+), 58 deletions(-)

diff --git a/api/.importlinter b/api/.importlinter
index 14c2b30101..0d9af6e065 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -58,8 +58,6 @@ ignore_imports =
     dify_graph.nodes.tool.tool_node -> extensions.ext_database
     dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
     dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
-    # TODO(QuantumGhost): use DI to avoid depending on global DB.
-    dify_graph.nodes.human_input.human_input_node -> extensions.ext_database
 
 [importlinter:contract:workflow-external-imports]
 name = Workflow External Imports
@@ -153,8 +151,6 @@ ignore_imports =
     dify_graph.nodes.llm.file_saver -> extensions.ext_database
     dify_graph.nodes.llm.node -> extensions.ext_database
     dify_graph.nodes.tool.tool_node -> extensions.ext_database
-    dify_graph.nodes.human_input.human_input_node -> extensions.ext_database
-    dify_graph.nodes.human_input.human_input_node -> core.repositories.human_input_repository
     dify_graph.nodes.agent.agent_node -> models
     dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
     dify_graph.nodes.llm.node -> models.model
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index f57a0d9b3b..fbd5060b8c 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -735,7 +735,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
 
     def _load_human_input_form_id(self, *, node_id: str) -> str | None:
         form_repository = HumanInputFormRepositoryImpl(
-            session_factory=db.engine,
             tenant_id=self._workflow_tenant_id,
         )
         form = form_repository.get_form(self._workflow_run_id, node_id)
diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py
index bd9afe36f0..6607a87032 100644
--- a/api/core/repositories/human_input_repository.py
+++ b/api/core/repositories/human_input_repository.py
@@ -4,9 +4,10 @@ from collections.abc import Mapping, Sequence
 from datetime import datetime
 from typing import Any
 
-from sqlalchemy import Engine, select
-from sqlalchemy.orm import Session, selectinload, sessionmaker
+from sqlalchemy import select
+from sqlalchemy.orm import Session, selectinload
 
+from core.db.session_factory import session_factory
 from dify_graph.nodes.human_input.entities import (
     DeliveryChannelConfig,
     EmailDeliveryMethod,
@@ -198,12 +199,9 @@ class _InvalidTimeoutStatusError(ValueError):
 class HumanInputFormRepositoryImpl:
     def __init__(
         self,
-        session_factory: sessionmaker | Engine,
+        *,
         tenant_id: str,
     ):
-        if isinstance(session_factory, Engine):
-            session_factory = sessionmaker(bind=session_factory)
-        self._session_factory = session_factory
         self._tenant_id = tenant_id
 
     def _delivery_method_to_model(
@@ -217,7 +215,7 @@ class HumanInputFormRepositoryImpl:
             id=delivery_id,
             form_id=form_id,
             delivery_method_type=delivery_method.type,
-            delivery_config_id=delivery_method.id,
+            delivery_config_id=str(delivery_method.id),
             channel_payload=delivery_method.model_dump_json(),
         )
         recipients: list[HumanInputFormRecipient] = []
@@ -343,7 +341,7 @@ class HumanInputFormRepositoryImpl:
     def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
         form_config: HumanInputNodeData = params.form_config
 
-        with self._session_factory(expire_on_commit=False) as session, session.begin():
+        with session_factory.create_session() as session, session.begin():
             # Generate unique form ID
             form_id = str(uuidv7())
             start_time = naive_utc_now()
@@ -435,7 +433,7 @@ class HumanInputFormRepositoryImpl:
             HumanInputForm.node_id == node_id,
             HumanInputForm.tenant_id == self._tenant_id,
         )
-        with self._session_factory(expire_on_commit=False) as session:
+        with session_factory.create_session() as session:
             form_model: HumanInputForm | None = session.scalars(form_query).first()
             if form_model is None:
                 return None
@@ -448,18 +446,13 @@ class HumanInputFormRepositoryImpl:
 class HumanInputFormSubmissionRepository:
     """Repository for fetching and submitting human input forms."""
 
-    def __init__(self, session_factory: sessionmaker | Engine):
-        if isinstance(session_factory, Engine):
-            session_factory = sessionmaker(bind=session_factory)
-        self._session_factory = session_factory
-
     def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
         query = (
             select(HumanInputFormRecipient)
             .options(selectinload(HumanInputFormRecipient.form))
             .where(HumanInputFormRecipient.access_token == form_token)
         )
-        with self._session_factory(expire_on_commit=False) as session:
+        with session_factory.create_session() as session:
             recipient_model = session.scalars(query).first()
         if recipient_model is None or recipient_model.form is None:
             return None
@@ -478,7 +471,7 @@ class HumanInputFormSubmissionRepository:
                 HumanInputFormRecipient.recipient_type == recipient_type,
             )
         )
-        with self._session_factory(expire_on_commit=False) as session:
+        with session_factory.create_session() as session:
             recipient_model = session.scalars(query).first()
         if recipient_model is None or recipient_model.form is None:
             return None
@@ -494,7 +487,7 @@ class HumanInputFormSubmissionRepository:
         submission_user_id: str | None,
         submission_end_user_id: str | None,
     ) -> HumanInputFormRecord:
-        with self._session_factory(expire_on_commit=False) as session, session.begin():
+        with session_factory.create_session() as session, session.begin():
             form_model = session.get(HumanInputForm, form_id)
             if form_model is None:
                 raise FormNotFoundError(f"form not found, id={form_id}")
@@ -524,7 +517,7 @@ class HumanInputFormSubmissionRepository:
         timeout_status: HumanInputFormStatus,
         reason: str | None = None,
     ) -> HumanInputFormRecord:
-        with self._session_factory(expire_on_commit=False) as session, session.begin():
+        with session_factory.create_session() as session, session.begin():
             form_model = session.get(HumanInputForm, form_id)
             if form_model is None:
                 raise FormNotFoundError(f"form not found, id={form_id}")
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index 22d86748f1..1b4937769e 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -19,6 +19,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.rag.index_processor.index_processor import IndexProcessor
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.summary_index.summary_index import SummaryIndex
+from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
 from core.tools.tool_file_manager import ToolFileManager
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import NodeType, SystemVariableKey
@@ -34,6 +35,7 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
 from dify_graph.nodes.datasource import DatasourceNode
 from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
+from dify_graph.nodes.human_input.human_input_node import HumanInputNode
 from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
 from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
 from dify_graph.nodes.llm.entities import ModelConfig
@@ -205,6 +207,15 @@ class DifyNodeFactory(NodeFactory):
                 file_manager=self._http_request_file_manager,
             )
 
+        if node_type == NodeType.HUMAN_INPUT:
+            return HumanInputNode(
+                id=node_id,
+                config=node_config,
+                graph_init_params=self.graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+                form_repository=HumanInputFormRepositoryImpl(tenant_id=self.graph_init_params.tenant_id),
+            )
+
         if node_type == NodeType.KNOWLEDGE_INDEX:
             return KnowledgeIndexNode(
                 id=node_id,
diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py
index f41423f550..e54650898d 100644
--- a/api/dify_graph/nodes/human_input/human_input_node.py
+++ b/api/dify_graph/nodes/human_input/human_input_node.py
@@ -3,7 +3,6 @@ import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
-from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
 from dify_graph.entities.pause_reason import HumanInputRequired
 from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
 from dify_graph.node_events import (
@@ -21,7 +20,6 @@ from dify_graph.repositories.human_input_form_repository import (
     HumanInputFormRepository,
 )
 from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 
 from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
@@ -66,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
         config: Mapping[str, Any],
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
-        form_repository: HumanInputFormRepository | None = None,
+        form_repository: HumanInputFormRepository,
     ) -> None:
         super().__init__(
             id=id,
@@ -74,11 +72,6 @@ class HumanInputNode(Node[HumanInputNodeData]):
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
         )
-        if form_repository is None:
-            form_repository = HumanInputFormRepositoryImpl(
-                session_factory=db.engine,
-                tenant_id=self.tenant_id,
-            )
         self._form_repository = form_repository
 
     @classmethod
diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py
index cfab723fef..2e74c50963 100644
--- a/api/services/human_input_service.py
+++ b/api/services/human_input_service.py
@@ -130,7 +130,7 @@ class HumanInputService:
         if isinstance(session_factory, Engine):
             session_factory = sessionmaker(bind=session_factory)
         self._session_factory = session_factory
-        self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory)
+        self._form_repository = form_repository or HumanInputFormSubmissionRepository()
 
     def get_form_by_token(self, form_token: str) -> Form | None:
         record = self._form_repository.get_by_token(form_token)
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 3ea38c3535..21bc95136e 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -1015,7 +1015,7 @@ class WorkflowService:
         rendered_content: str,
         resolved_default_values: Mapping[str, Any],
     ) -> tuple[str, list[DeliveryTestEmailRecipient]]:
-        repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id)
+        repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id)
         params = FormCreateParams(
             app_id=app_model.id,
             workflow_execution_id=None,
@@ -1081,6 +1081,7 @@ class WorkflowService:
             config=node_config,
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
+            form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id),
         )
 
         return node
diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py
index 03441683b0..dd3b6a4530 100644
--- a/api/tasks/human_input_timeout_tasks.py
+++ b/api/tasks/human_input_timeout_tasks.py
@@ -58,7 +58,7 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
     """Scan for expired human input forms and resume or end workflows."""
     session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
 
-    form_repo = HumanInputFormSubmissionRepository(session_factory)
+    form_repo = HumanInputFormSubmissionRepository()
     service = HumanInputService(session_factory, form_repository=form_repo)
     now = naive_utc_now()
     global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS
diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
index 4b362d1abe..9d0fad4b12 100644
--- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -100,7 +100,7 @@ class TestHumanInputFormRepositoryImplWithContainers:
             member_emails=["member1@example.com", "member2@example.com"],
         )
 
-        repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
+        repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
         params = _build_form_params(
             delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])],
         )
@@ -129,7 +129,7 @@ class TestHumanInputFormRepositoryImplWithContainers:
             member_emails=["primary@example.com", "secondary@example.com"],
         )
 
-        repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
+        repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
         params = _build_form_params(
             delivery_methods=[
                 _build_email_delivery(
@@ -173,7 +173,7 @@ class TestHumanInputFormRepositoryImplWithContainers:
             member_emails=["prefill@example.com"],
         )
 
-        repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
+        repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
         resolved_values = {"greeting": "Hello!"}
         params = FormCreateParams(
             app_id=str(uuid4()),
@@ -210,7 +210,7 @@ class TestHumanInputFormRepositoryImplWithContainers:
             member_emails=["ui@example.com"],
         )
 
-        repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
+        repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
         params = FormCreateParams(
             app_id=str(uuid4()),
             workflow_execution_id=str(uuid4()),
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
index c9bcba6639..0876a39f82 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
@@ -96,8 +96,7 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor
         delivery_methods=[delivery_method],
     )
 
-    engine = db_session_with_containers.get_bind()
-    repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
+    repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
     params = FormCreateParams(
         app_id=app_id,
         workflow_execution_id=workflow_execution_id,
diff --git
a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 7bf7c2e5f6..9af4d12664 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -5,7 +5,6 @@ from __future__ import annotations import dataclasses from datetime import datetime from types import SimpleNamespace -from unittest.mock import MagicMock import pytest @@ -35,7 +34,7 @@ from models.human_input import ( def _build_repository() -> HumanInputFormRepositoryImpl: - return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") + return HumanInputFormRepositoryImpl(tenant_id="tenant-id") def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: @@ -389,8 +388,21 @@ def _session_factory(session: _FakeSession): return _factory +def _patch_repo_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + """Patch repository's global session factory to return our fake session. + + The repositories under test now use a global session factory; patch its + create_session method so unit tests don't hit a real database. + """ + monkeypatch.setattr( + "core.repositories.human_input_repository.session_factory.create_session", + _session_factory(session), + raising=True, + ) + + class TestHumanInputFormRepositoryImplPublicMethods: - def test_get_form_returns_entity_and_recipients(self): + def test_get_form_returns_entity_and_recipients(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -408,7 +420,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: access_token="token-123", ) session = _FakeSession(scalars_results=[form, [recipient]]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -418,13 +431,14 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" - def test_get_form_returns_none_when_missing(self): + def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") assert repo.get_form("run-1", "node-1") is None - def test_get_form_returns_unsubmitted_state(self): + def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -436,7 +450,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: expiration_time=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -445,7 +460,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert entity.selected_action_id is None assert entity.submitted_data is None - def 
test_get_form_returns_submission_when_completed(self): + def test_get_form_returns_submission_when_completed(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -460,7 +475,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: submitted_at=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -471,7 +487,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: class TestHumanInputFormSubmissionRepository: - def test_get_by_token_returns_record(self): + def test_get_by_token_returns_record(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -490,7 +506,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_token("token-123") @@ -499,7 +516,7 @@ class TestHumanInputFormSubmissionRepository: assert record.recipient_type == RecipientType.STANDALONE_WEB_APP assert record.submitted is False - def test_get_by_form_id_and_recipient_type_uses_recipient(self): + def test_get_by_form_id_and_recipient_type_uses_recipient(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -518,7 +535,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_form_id_and_recipient_type( form_id=form.id, @@ -553,7 +571,8 @@ class TestHumanInputFormSubmissionRepository: forms={form.id: form}, recipients={recipient.id: recipient}, ) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record: HumanInputFormRecord = repo.mark_submitted( form_id=form.id, diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index 6b07f88c41..bd0182a402 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -47,7 +47,7 @@ class _FakeSessionFactory: class _FakeFormRepo: - def __init__(self, _session_factory, form_map: dict[str, Any] | None = None): + def __init__(self, form_map: dict[str, Any] | None = None): self.calls: list[dict[str, Any]] = [] self._form_map = form_map or {} @@ -149,9 +149,9 @@ def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pyt monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture)) form_map = {form.id: form for form in forms} - repo = _FakeFormRepo(None, form_map=form_map) + repo = _FakeFormRepo(form_map=form_map) - def _repo_factory(_session_factory): + def _repo_factory(): return repo service = _FakeService(None) From 2f4c740d4643f0eaf63e44ab9f423e9928c294a8 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 4 Mar 2026 13:18:55 +0800 Subject: [PATCH 
018/256] feat: support redis xstream (#32586)

---
 .../middleware/cache/redis_pubsub_config.py | 49 +++--
 api/extensions/ext_redis.py | 6 +
 .../redis/streams_channel.py | 159 ++++++++++++++
 api/services/app_generate_service.py | 23 +-
 .../redis/test_streams_channel_unit_tests.py | 145 +++++++++++++
 ..._generate_service_streaming_integration.py | 197 ++++++++++++++++++
 6 files changed, 558 insertions(+), 21 deletions(-)
 create mode 100644 api/libs/broadcast_channel/redis/streams_channel.py
 create mode 100644 api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py
 create mode 100644 api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py

diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py
index a72e1dd28f..8cddc5677a 100644
--- a/api/configs/middleware/cache/redis_pubsub_config.py
+++ b/api/configs/middleware/cache/redis_pubsub_config.py
@@ -1,7 +1,7 @@
 from typing import Literal, Protocol
 from urllib.parse import quote_plus, urlunparse
 
-from pydantic import Field
+from pydantic import AliasChoices, Field
 from pydantic_settings import BaseSettings
 
 
@@ -23,41 +23,56 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
     """
-    Configuration settings for Redis pub/sub streaming.
+    Configuration settings for event transport between API and workers.
+
+    Supported transports:
+    - pubsub: Redis PUBLISH/SUBSCRIBE (at-most-once)
+    - sharded: Redis 7+ Sharded Pub/Sub (at-most-once, better scaling)
+    - streams: Redis Streams (at-least-once, supports late subscribers)
     """
 
     PUBSUB_REDIS_URL: str | None = Field(
-        alias="PUBSUB_REDIS_URL",
+        validation_alias=AliasChoices("EVENT_BUS_REDIS_URL", "PUBSUB_REDIS_URL"),
         description=(
-            "Redis connection URL for pub/sub streaming events between API "
-            "and celery worker, defaults to url constructed from "
-            "`REDIS_*` configurations"
+            "Redis connection URL for streaming events between API and celery worker; "
+            "defaults to a URL constructed from `REDIS_*` configurations. Also accepts ENV: EVENT_BUS_REDIS_URL."
         ),
         default=None,
     )
 
     PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
+        validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
         description=(
-            "Enable Redis Cluster mode for pub/sub streaming. It's highly "
-            "recommended to enable this for large deployments."
+            "Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. "
+            "Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS."
        ),
         default=False,
     )
 
-    PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
+    PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded", "streams"] = Field(
+        validation_alias=AliasChoices("EVENT_BUS_REDIS_CHANNEL_TYPE", "PUBSUB_REDIS_CHANNEL_TYPE"),
         description=(
-            "Pub/sub channel type for streaming events. "
-            "Valid options are:\n"
-            "\n"
-            " - pubsub: for normal Pub/Sub\n"
-            " - sharded: for sharded Pub/Sub\n"
-            "\n"
-            "It's highly recommended to use sharded Pub/Sub AND redis cluster "
-            "for large deployments."
+            "Event transport type. 
Options are:\n\n" + " - pubsub: normal Pub/Sub (at-most-once)\n" + " - sharded: sharded Pub/Sub (at-most-once)\n" + " - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)\n\n" + "Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.\n" + "Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce\n" + "the risk of data loss from Redis auto-eviction under memory pressure.\n" + "Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE." ), default="pubsub", ) + PUBSUB_STREAMS_RETENTION_SECONDS: int = Field( + validation_alias=AliasChoices("EVENT_BUS_STREAMS_RETENTION_SECONDS", "PUBSUB_STREAMS_RETENTION_SECONDS"), + description=( + "When using 'streams', expire each stream key this many seconds after the last event is published. " + "Also accepts ENV: EVENT_BUS_STREAMS_RETENTION_SECONDS." + ), + default=600, + ) + def _build_default_pubsub_url(self) -> str: defaults = self._redis_defaults() if not defaults.REDIS_HOST or not defaults.REDIS_PORT: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index cadd9cb263..26262484f9 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -18,6 +18,7 @@ from dify_app import DifyApp from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel if TYPE_CHECKING: from redis.lock import Lock @@ -288,6 +289,11 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here." if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": return ShardedRedisBroadcastChannel(_pubsub_redis_client) + if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams": + return StreamsBroadcastChannel( + _pubsub_redis_client, + retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS, + ) return RedisBroadcastChannel(_pubsub_redis_client) diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py new file mode 100644 index 0000000000..d6ec5504ca --- /dev/null +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import logging +import queue +import threading +from collections.abc import Iterator +from typing import Self + +from libs.broadcast_channel.channel import Producer, Subscriber, Subscription +from libs.broadcast_channel.exc import SubscriptionClosedError +from redis import Redis, RedisCluster + +logger = logging.getLogger(__name__) + + +class StreamsBroadcastChannel: + """ + Redis Streams based broadcast channel implementation. + + Characteristics: + - At-least-once delivery for late subscribers within the stream retention window. + - Each topic is stored as a dedicated Redis Stream key. + - The stream key expires `retention_seconds` after the last event is published (to bound storage). 
+ """ + + def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600): + self._client = redis_client + self._retention_seconds = max(int(retention_seconds or 0), 0) + + def topic(self, topic: str) -> StreamsTopic: + return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds) + + +class StreamsTopic: + def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): + self._client = redis_client + self._topic = topic + self._key = f"stream:{topic}" + self._retention_seconds = retention_seconds + self.max_length = 5000 + + def as_producer(self) -> Producer: + return self + + def publish(self, payload: bytes) -> None: + self._client.xadd(self._key, {b"data": payload}, maxlen=self.max_length) + if self._retention_seconds > 0: + try: + self._client.expire(self._key, self._retention_seconds) + except Exception as e: + logger.warning("Failed to set expire for stream key %s: %s", self._key, e, exc_info=True) + + def as_subscriber(self) -> Subscriber: + return self + + def subscribe(self) -> Subscription: + return _StreamsSubscription(self._client, self._key) + + +class _StreamsSubscription(Subscription): + _SENTINEL = object() + + def __init__(self, client: Redis | RedisCluster, key: str): + self._client = client + self._key = key + self._closed = threading.Event() + self._last_id = "0-0" + self._queue: queue.Queue[object] = queue.Queue() + self._start_lock = threading.Lock() + self._listener: threading.Thread | None = None + + def _listen(self) -> None: + try: + while not self._closed.is_set(): + streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + + if not streams: + continue + + for _key, entries in streams: + for entry_id, fields in entries: + data = None + if isinstance(fields, dict): + data = fields.get(b"data") + data_bytes: bytes | None = None + if isinstance(data, str): + data_bytes = data.encode() + elif isinstance(data, (bytes, bytearray)): + data_bytes = bytes(data) + if data_bytes is not None: + self._queue.put_nowait(data_bytes) + self._last_id = entry_id + finally: + self._queue.put_nowait(self._SENTINEL) + self._listener = None + + def _start_if_needed(self) -> None: + if self._listener is not None: + return + # Ensure only one listener thread is created under concurrent calls + with self._start_lock: + if self._listener is not None or self._closed.is_set(): + return + self._listener = threading.Thread( + target=self._listen, + name=f"redis-streams-sub-{self._key}", + daemon=True, + ) + self._listener.start() + + def __iter__(self) -> Iterator[bytes]: + # Iterator delegates to receive with timeout; stops on closure. 
+ self._start_if_needed() + while not self._closed.is_set(): + item = self.receive(timeout=1) + if item is not None: + yield item + + def receive(self, timeout: float | None = 0.1) -> bytes | None: + if self._closed.is_set(): + raise SubscriptionClosedError("The Redis streams subscription is closed") + self._start_if_needed() + + try: + if timeout is None: + item = self._queue.get() + else: + item = self._queue.get(timeout=timeout) + except queue.Empty: + return None + + if item is self._SENTINEL or self._closed.is_set(): + raise SubscriptionClosedError("The Redis streams subscription is closed") + assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" + return bytes(item) + + def close(self) -> None: + if self._closed.is_set(): + return + self._closed.set() + listener = self._listener + if listener is not None: + listener.join(timeout=2.0) + if listener.is_alive(): + logger.warning( + "Streams subscription listener for key %s did not stop within timeout; keeping reference.", + self._key, + ) + else: + self._listener = None + + # Context manager helpers + def __enter__(self) -> Self: + self._start_if_needed() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool | None: + self.close() + return None diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 31003cb8f7..40013f2b66 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -38,6 +38,13 @@ if TYPE_CHECKING: class AppGenerateService: @staticmethod def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]: + """ + Build a subscription callback that coordinates when the background task starts. + + - streams transport: start immediately (events are durable; late subscribers can replay). + - pubsub/sharded transport: start on first subscribe, with a short fallback timer so the task + still runs if the client never connects. + """ started = False lock = threading.Lock() @@ -54,10 +61,18 @@ class AppGenerateService: started = True return True - # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber. - # The Celery task may publish the first event before the API side actually subscribes, - # causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe, - # but also use a short fallback timer so the task still runs if the client never consumes. + channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE + if channel_type == "streams": + # With Redis Streams, we can safely start right away; consumers can read past events. + _try_start() + + # Keep return type Callable[[], None] consistent while allowing an extra (no-op) call. + def _on_subscribe_streams() -> None: + _try_start() + + return _on_subscribe_streams + + # Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback. 
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start) timer.daemon = True timer.start() diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py new file mode 100644 index 0000000000..248aa0b145 --- /dev/null +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -0,0 +1,145 @@ +import time + +import pytest + +from libs.broadcast_channel.redis.streams_channel import ( + StreamsBroadcastChannel, + StreamsTopic, + _StreamsSubscription, +) + + +class FakeStreamsRedis: + """Minimal in-memory Redis Streams stub for unit tests. + + - Stores entries per key as [(id, {b"data": bytes}), ...] + - xadd appends entries and returns an auto-increment id like "1-0" + - xread returns entries strictly greater than last_id + - expire is recorded but has no effect on behavior + """ + + def __init__(self) -> None: + self._store: dict[str, list[tuple[str, dict]]] = {} + self._next_id: dict[str, int] = {} + self._expire_calls: dict[str, int] = {} + + # Publisher API + def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + """Append entry to stream; accept optional maxlen for API compatibility. + + The test double ignores maxlen trimming semantics; only records the entry. + """ + n = self._next_id.get(key, 0) + 1 + self._next_id[key] = n + entry_id = f"{n}-0" + self._store.setdefault(key, []).append((entry_id, fields)) + return entry_id + + def expire(self, key: str, seconds: int) -> None: + self._expire_calls[key] = self._expire_calls.get(key, 0) + 1 + + # Consumer API + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + # Expect a single key + assert len(streams) == 1 + key, last_id = next(iter(streams.items())) + entries = self._store.get(key, []) + + # Find position strictly greater than last_id + start_idx = 0 + if last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start_idx = i + 1 + break + if start_idx >= len(entries): + # Simulate blocking wait (bounded) if requested + if block and block > 0: + time.sleep(min(0.01, block / 1000.0)) + return [] + + end_idx = len(entries) if count is None else min(len(entries), start_idx + count) + batch = entries[start_idx:end_idx] + return [(key, batch)] + + +@pytest.fixture +def fake_redis() -> FakeStreamsRedis: + return FakeStreamsRedis() + + +@pytest.fixture +def streams_channel(fake_redis: FakeStreamsRedis) -> StreamsBroadcastChannel: + return StreamsBroadcastChannel(fake_redis, retention_seconds=60) + + +class TestStreamsBroadcastChannel: + def test_topic_creation(self, streams_channel: StreamsBroadcastChannel, fake_redis: FakeStreamsRedis): + topic = streams_channel.topic("alpha") + assert isinstance(topic, StreamsTopic) + assert topic._client is fake_redis + assert topic._topic == "alpha" + assert topic._key == "stream:alpha" + + def test_publish_calls_xadd_and_expire( + self, + streams_channel: StreamsBroadcastChannel, + fake_redis: FakeStreamsRedis, + ): + topic = streams_channel.topic("beta") + payload = b"hello" + topic.publish(payload) + # One entry stored under stream key (bytes key for payload field) + assert fake_redis._store["stream:beta"][0][1] == {b"data": payload} + # Expire called after publish + assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + + +class TestStreamsSubscription: + def test_subscribe_and_receive_from_beginning(self, streams_channel: 
StreamsBroadcastChannel): + topic = streams_channel.topic("gamma") + # Pre-publish events before subscribing (late subscriber) + topic.publish(b"e1") + topic.publish(b"e2") + + sub = topic.subscribe() + assert isinstance(sub, _StreamsSubscription) + + received: list[bytes] = [] + with sub: + # Give listener thread a moment to xread + time.sleep(0.05) + # Drain using receive() to avoid indefinite iteration in tests + for _ in range(5): + msg = sub.receive(timeout=0.1) + if msg is None: + break + received.append(msg) + + assert received == [b"e1", b"e2"] + + def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("delta") + sub = topic.subscribe() + with sub: + # No messages yet + assert sub.receive(timeout=0.05) is None + + def test_close_stops_listener(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("epsilon") + sub = topic.subscribe() + with sub: + # Listener running; now close and ensure no crash + sub.close() + # After close, receive should raise SubscriptionClosedError + from libs.broadcast_channel.exc import SubscriptionClosedError + + with pytest.raises(SubscriptionClosedError): + sub.receive() + + def test_no_expire_when_zero_retention(self, fake_redis: FakeStreamsRedis): + channel = StreamsBroadcastChannel(fake_redis, retention_seconds=0) + topic = channel.topic("zeta") + topic.publish(b"payload") + # No expire recorded when retention is disabled + assert fake_redis._expire_calls.get("stream:zeta") is None diff --git a/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py new file mode 100644 index 0000000000..e66d52f66b --- /dev/null +++ b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py @@ -0,0 +1,197 @@ +import json +import uuid +from collections import defaultdict, deque + +import pytest + +from core.app.apps.message_generator import MessageGenerator +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +# ----------------------------- +# Fakes for Redis Pub/Sub flow +# ----------------------------- +class _FakePubSub: + def __init__(self, store: dict[str, deque[bytes]]): + self._store = store + self._subs: set[str] = set() + self._closed = False + + def subscribe(self, topic: str) -> None: + self._subs.add(topic) + + def unsubscribe(self, topic: str) -> None: + self._subs.discard(topic) + + def close(self) -> None: + self._closed = True + + def get_message(self, ignore_subscribe_messages: bool = True, timeout: int | float | None = 1): + # simulate a non-blocking poll; return first available + if self._closed: + return None + for t in list(self._subs): + q = self._store.get(t) + if q and len(q) > 0: + payload = q.popleft() + return {"type": "message", "channel": t, "data": payload} + # no message + return None + + +class _FakeRedisClient: + def __init__(self, store: dict[str, deque[bytes]]): + self._store = store + + def pubsub(self): + return _FakePubSub(self._store) + + def publish(self, topic: str, payload: bytes) -> None: + self._store.setdefault(topic, deque()).append(payload) + + +# ------------------------------------ +# Fakes for Redis Streams (XADD/XREAD) +# ------------------------------------ +class _FakeStreams: + def __init__(self) -> None: + # key -> list[(id, {field: value})] + self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list) + self._seq: dict[str, int] = 
defaultdict(int) + + def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + # maxlen is accepted for API compatibility with redis-py; ignored in this test double + self._seq[key] += 1 + eid = f"{self._seq[key]}-0" + self._data[key].append((eid, fields)) + return eid + + def expire(self, key: str, seconds: int) -> None: + # no-op for tests + return None + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + assert len(streams) == 1 + key, last_id = next(iter(streams.items())) + entries = self._data.get(key, []) + start = 0 + if last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start = i + 1 + break + if start >= len(entries): + return [] + end = len(entries) if count is None else min(len(entries), start + count) + return [(key, entries[start:end])] + + +@pytest.fixture +def _patch_get_channel_streams(monkeypatch): + from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + fake = _FakeStreams() + chan = StreamsBroadcastChannel(fake, retention_seconds=60) + + def _get_channel(): + return chan + + # Patch both the source and the imported alias used by MessageGenerator + monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan) + monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan) + # Ensure AppGenerateService sees streams mode + import services.app_generate_service as ags + + monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams", raising=False) + + +@pytest.fixture +def _patch_get_channel_pubsub(monkeypatch): + from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel + + store: dict[str, deque[bytes]] = defaultdict(deque) + client = _FakeRedisClient(store) + chan = RedisBroadcastChannel(client) + + def _get_channel(): + return chan + + # Patch both the source and the imported alias used by MessageGenerator + monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan) + monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan) + # Ensure AppGenerateService sees pubsub mode + import services.app_generate_service as ags + + monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub", raising=False) + + +def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]): + # Publish events to the same topic used by MessageGenerator + topic = MessageGenerator.get_response_topic(app_mode, run_id) + for ev in events: + topic.publish(json.dumps(ev).encode()) + + +@pytest.mark.usefixtures("_patch_get_channel_streams") +def test_streams_full_flow_prepublish_and_replay(): + app_mode = AppMode.WORKFLOW + run_id = str(uuid.uuid4()) + + # Build start_task that publishes two events immediately + events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] + + def start_task(): + _publish_events(app_mode, run_id, events) + + on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) + + # Start retrieving BEFORE subscription is established; in streams mode, we also started immediately + gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) + + received = [] + for msg in gen: + if isinstance(msg, str): + # skip ping events + continue + received.append(msg) + if msg.get("event") == "workflow_finished": + break + + assert [m.get("event") for m in received] == ["workflow_started", "workflow_finished"] 
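
For reference, the replay behavior exercised by the test above can also be reproduced directly against the channel API introduced in this patch. A minimal sketch, assuming a reachable local Redis at the URL shown; the topic name and payloads are illustrative and not part of the patch:

    from redis import Redis

    from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel

    client = Redis.from_url("redis://localhost:6379/0")
    channel = StreamsBroadcastChannel(client, retention_seconds=600)
    topic = channel.topic("demo-run")

    # Publish before anyone subscribes. Streams retain entries, so a late
    # subscriber can still replay them; plain Pub/Sub would have dropped them.
    topic.publish(b"event-1")
    topic.publish(b"event-2")

    with topic.subscribe() as sub:
        while True:
            msg = sub.receive(timeout=1.0)  # returns None once the backlog drains
            if msg is None:
                break
            print(msg)  # b"event-1", then b"event-2"
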
+ + +@pytest.mark.usefixtures("_patch_get_channel_pubsub") +def test_pubsub_full_flow_start_on_subscribe_gated(monkeypatch): + # Speed up any potential timer if it accidentally triggers + monkeypatch.setattr("services.app_generate_service.SSE_TASK_START_FALLBACK_MS", 50) + + app_mode = AppMode.WORKFLOW + run_id = str(uuid.uuid4()) + + published_order: list[str] = [] + + def start_task(): + # When called (on subscribe), publish both events + events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] + _publish_events(app_mode, run_id, events) + published_order.extend([e["event"] for e in events]) + + on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) + + # Producer not started yet; only when subscribe happens + assert published_order == [] + + gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) + + received = [] + for msg in gen: + if isinstance(msg, str): + continue + received.append(msg) + if msg.get("event") == "workflow_finished": + break + + # Verify publish happened and consumer received in order + assert published_order == ["workflow_started", "workflow_finished"] + assert [m.get("event") for m in received] == ["workflow_started", "workflow_finished"] From ad000c42b7beafc0cfd384eec3fd32d16e9f4063 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Tue, 3 Mar 2026 21:50:41 -0800 Subject: [PATCH 019/256] feat: replace db.session with db_session_with_containers (#32942) --- .../services/dataset_collection_binding.py | 38 +- .../services/dataset_service_update_delete.py | 119 +++-- .../services/test_account_service.py | 413 +++++++++-------- .../services/test_agent_service.py | 167 ++++--- .../services/test_annotation_service.py | 184 ++++---- .../test_api_based_extension_service.py | 61 ++- .../services/test_app_generate_service.py | 109 +++-- .../services/test_app_service.py | 64 +-- .../services/test_dataset_service.py | 168 ++++--- ...et_service_batch_update_document_status.py | 183 +++++--- .../test_dataset_service_get_segments.py | 137 ++++-- .../test_dataset_service_retrieval.py | 245 ++++++---- .../test_dataset_service_update_dataset.py | 123 +++-- .../services/test_file_service.py | 162 ++++--- .../services/test_message_service.py | 158 ++++--- .../services/test_messages_clean_service.py | 427 +++++++++++------- .../services/test_metadata_service.py | 197 ++++---- .../test_model_load_balancing_service.py | 81 ++-- .../services/test_model_provider_service.py | 99 ++-- .../services/test_saved_message_service.py | 113 +++-- .../services/test_tag_service.py | 194 ++++---- .../services/test_trigger_provider_service.py | 30 +- .../services/test_web_conversation_service.py | 90 ++-- .../services/test_webapp_auth_service.py | 160 ++++--- .../services/test_workflow_app_service.py | 177 ++++---- .../test_workflow_draft_variable_service.py | 128 +++--- .../services/test_workflow_run_service.py | 63 ++- .../services/test_workflow_service.py | 225 ++++----- .../services/test_workspace_service.py | 99 ++-- .../tools/test_api_tools_manage_service.py | 45 +- .../tools/test_mcp_tools_manage_service.py | 217 ++++----- .../tools/test_tools_transform_service.py | 79 ++-- .../test_workflow_tools_manage_service.py | 93 ++-- .../workflow/test_workflow_converter.py | 75 ++- .../tasks/test_add_document_to_index_task.py | 134 +++--- .../tasks/test_batch_clean_document_task.py | 146 +++--- ...test_batch_create_segment_to_index_task.py | 119 +++-- 
.../tasks/test_clean_dataset_task.py | 37 +- .../test_disable_segment_from_index_task.py | 170 +++---- .../test_disable_segments_from_index_task.py | 62 +-- .../test_enable_segments_to_index_task.py | 66 +-- .../tasks/test_mail_account_deletion_task.py | 30 +- .../tasks/test_rag_pipeline_run_tasks.py | 60 +-- 43 files changed, 3078 insertions(+), 2669 deletions(-) diff --git a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py index 73df2d9ed9..191c161613 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py +++ b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py @@ -9,8 +9,8 @@ from itertools import starmap from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from services.dataset_service import DatasetCollectionBindingService @@ -28,6 +28,7 @@ class DatasetCollectionBindingTestDataFactory: @staticmethod def create_collection_binding( + db_session_with_containers: Session, provider_name: str = "openai", model_name: str = "text-embedding-ada-002", collection_name: str = "collection-abc", @@ -51,8 +52,8 @@ class DatasetCollectionBindingTestDataFactory: collection_name=collection_name, type=collection_type, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return binding @@ -64,7 +65,7 @@ class TestDatasetCollectionBindingServiceGetBinding: including various provider/model combinations, collection types, and edge cases. """ - def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers: Session): """ Test successful retrieval of an existing collection binding. @@ -77,6 +78,7 @@ class TestDatasetCollectionBindingServiceGetBinding: model_name = "text-embedding-ada-002" collection_type = "dataset" existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name=provider_name, model_name=model_name, collection_name="existing-collection", @@ -92,7 +94,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.id == existing_binding.id assert result.collection_name == "existing-collection" - def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers: Session): """ Test successful creation of a new collection binding when none exists. 
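
The pattern repeated across the test suites in this patch, reduced to a self-contained sketch: helpers receive the per-test Session explicitly instead of importing the global db.session. The Widget model and in-memory engine below are illustrative stand-ins, not part of the patch:

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Widget(Base):  # stand-in for a real model such as DatasetCollectionBinding
        __tablename__ = "widgets"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(40))


    def create_widget(db_session_with_containers: Session, name: str) -> Widget:
        # The helper mutates only the session it is handed, so each test owns
        # its transaction lifecycle; no hidden global db.session state remains.
        widget = Widget(name=name)
        db_session_with_containers.add(widget)
        db_session_with_containers.commit()
        return widget


    engine = create_engine("sqlite://")  # in-memory database for the sketch
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        assert create_widget(session, "demo").id is not None
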
@@ -116,7 +118,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.type == collection_type assert result.collection_name is not None - def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers: Session): """Test get_dataset_collection_binding with different collection type.""" # Arrange provider_name = "openai" @@ -133,7 +135,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.provider_name == provider_name assert result.model_name == model_name - def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers: Session): """Test get_dataset_collection_binding with default collection type parameter.""" # Arrange provider_name = "openai" @@ -147,7 +149,9 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.provider_name == provider_name assert result.model_name == model_name - def test_get_dataset_collection_binding_different_provider_model_combination(self, db_session_with_containers): + def test_get_dataset_collection_binding_different_provider_model_combination( + self, db_session_with_containers: Session + ): """Test get_dataset_collection_binding with various provider/model combinations.""" # Arrange combinations = [ @@ -174,10 +178,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: including successful retrieval and error handling for missing bindings. """ - def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers: Session): """Test successful retrieval of collection binding by ID and type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -194,7 +199,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.collection_name == "test-collection" assert result.type == "dataset" - def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session): """Test error handling when collection binding is not found by ID and type.""" # Arrange non_existent_id = str(uuid4()) @@ -203,10 +208,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: with pytest.raises(ValueError, match="Dataset collection binding not found"): DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset") - def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_different_collection_type( + self, db_session_with_containers: Session + ): """Test retrieval by ID and type with different collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -222,10 +230,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.id == 
binding.id assert result.type == "custom_type" - def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_default_collection_type( + self, db_session_with_containers: Session + ): """Test retrieval by ID with default collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -239,10 +250,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.id == binding.id assert result.type == "dataset" - def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session): """Test error when binding exists but with wrong collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 9871ef37e6..4b98bddd26 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -10,9 +10,9 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.model import App @@ -27,6 +27,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL, tenant: Tenant | None = None, ) -> tuple[Account, Tenant]: @@ -37,13 +38,13 @@ class DatasetUpdateDeleteTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() if tenant is None: tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -51,14 +52,15 @@ class DatasetUpdateDeleteTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -78,12 +80,12 @@ class DatasetUpdateDeleteTestDataFactory: retrieval_model={"top_k": 2}, enable_api=enable_api, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_app(tenant_id: str, created_by: str, name: str = "Test App") -> App: + def create_app(db_session_with_containers: 
Session, tenant_id: str, created_by: str, name: str = "Test App") -> App: """Create a real app for AppDatasetJoin.""" app = App( tenant_id=tenant_id, @@ -96,16 +98,16 @@ class DatasetUpdateDeleteTestDataFactory: enable_api=True, created_by=created_by, ) - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app @staticmethod - def create_app_dataset_join(app_id: str, dataset_id: str) -> AppDatasetJoin: + def create_app_dataset_join(db_session_with_containers: Session, app_id: str, dataset_id: str) -> AppDatasetJoin: """Create a real AppDatasetJoin record.""" join = AppDatasetJoin(app_id=app_id, dataset_id=dataset_id) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return join @@ -114,7 +116,7 @@ class TestDatasetServiceDeleteDataset: Comprehensive integration tests for DatasetService.delete_dataset method. """ - def test_delete_dataset_success(self, db_session_with_containers): + def test_delete_dataset_success(self, db_session_with_containers: Session): """ Test successful deletion of a dataset. @@ -130,8 +132,10 @@ class TestDatasetServiceDeleteDataset: - Method returns True """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted: @@ -139,10 +143,10 @@ class TestDatasetServiceDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None mock_dataset_was_deleted.send.assert_called_once_with(dataset) - def test_delete_dataset_not_found(self, db_session_with_containers): + def test_delete_dataset_not_found(self, db_session_with_containers: Session): """ Test handling when dataset is not found. @@ -156,7 +160,9 @@ class TestDatasetServiceDeleteDataset: - No database operations are performed """ # Arrange - owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) dataset_id = str(uuid4()) # Act @@ -165,7 +171,7 @@ class TestDatasetServiceDeleteDataset: # Assert assert result is False - def test_delete_dataset_permission_denied_error(self, db_session_with_containers): + def test_delete_dataset_permission_denied_error(self, db_session_with_containers: Session): """ Test error handling when user lacks permission. 
@@ -178,19 +184,22 @@ class TestDatasetServiceDeleteDataset: - No database operations are performed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) normal_user, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL, tenant=tenant, ) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act & Assert with pytest.raises(NoPermissionError): DatasetService.delete_dataset(dataset.id, normal_user) # Verify no deletion was attempted - assert db.session.get(Dataset, dataset.id) is not None + assert db_session_with_containers.get(Dataset, dataset.id) is not None class TestDatasetServiceDatasetUseCheck: @@ -198,7 +207,7 @@ class TestDatasetServiceDatasetUseCheck: Comprehensive integration tests for DatasetService.dataset_use_check method. """ - def test_dataset_use_check_in_use(self, db_session_with_containers): + def test_dataset_use_check_in_use(self, db_session_with_containers: Session): """ Test detection when dataset is in use. @@ -211,10 +220,12 @@ class TestDatasetServiceDatasetUseCheck: - Database query is executed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) - app = DatasetUpdateDeleteTestDataFactory.create_app(tenant.id, owner.id) - DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(app.id, dataset.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + app = DatasetUpdateDeleteTestDataFactory.create_app(db_session_with_containers, tenant.id, owner.id) + DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(db_session_with_containers, app.id, dataset.id) # Act result = DatasetService.dataset_use_check(dataset.id) @@ -222,7 +233,7 @@ class TestDatasetServiceDatasetUseCheck: # Assert assert result is True - def test_dataset_use_check_not_in_use(self, db_session_with_containers): + def test_dataset_use_check_not_in_use(self, db_session_with_containers: Session): """ Test detection when dataset is not in use. @@ -235,8 +246,10 @@ class TestDatasetServiceDatasetUseCheck: - Database query is executed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act result = DatasetService.dataset_use_check(dataset.id) @@ -250,7 +263,7 @@ class TestDatasetServiceUpdateDatasetApiStatus: Comprehensive integration tests for DatasetService.update_dataset_api_status method. 
""" - def test_update_dataset_api_status_enable_success(self, db_session_with_containers): + def test_update_dataset_api_status_enable_success(self, db_session_with_containers: Session): """ Test successful enabling of dataset API access. @@ -264,8 +277,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - Transaction is committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=False + ) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) # Act @@ -276,12 +293,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, True) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.enable_api is True assert dataset.updated_by == owner.id assert dataset.updated_at == current_time - def test_update_dataset_api_status_disable_success(self, db_session_with_containers): + def test_update_dataset_api_status_disable_success(self, db_session_with_containers: Session): """ Test successful disabling of dataset API access. @@ -295,8 +312,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - Transaction is committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=True) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=True + ) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) # Act @@ -307,11 +328,11 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, False) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.enable_api is False assert dataset.updated_by == owner.id - def test_update_dataset_api_status_not_found_error(self, db_session_with_containers): + def test_update_dataset_api_status_not_found_error(self, db_session_with_containers: Session): """ Test error handling when dataset is not found. @@ -330,7 +351,7 @@ class TestDatasetServiceUpdateDatasetApiStatus: with pytest.raises(NotFound, match="Dataset not found"): DatasetService.update_dataset_api_status(dataset_id, True) - def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers): + def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers: Session): """ Test error handling when current_user is missing. 
@@ -343,8 +364,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
         - No updates are committed
         """
         # Arrange
-        owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
-        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False)
+        owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
+            db_session_with_containers, role=TenantAccountRole.OWNER
+        )
+        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
+            db_session_with_containers, tenant.id, owner.id, enable_api=False
+        )
 
         # Act & Assert
         with (
@@ -354,6 +379,6 @@ class TestDatasetServiceUpdateDatasetApiStatus:
             DatasetService.update_dataset_api_status(dataset.id, True)
 
         # Verify no commit was attempted
-        db.session.rollback()
-        db.session.refresh(dataset)
+        db_session_with_containers.rollback()
+        db_session_with_containers.refresh(dataset)
         assert dataset.enable_api is False
diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py
index 606e7e0b57..8595f5bf14 100644
--- a/api/tests/test_containers_integration_tests/services/test_account_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_account_service.py
@@ -4,6 +4,7 @@
 from unittest.mock import patch
 
 import pytest
 from faker import Faker
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Unauthorized
 from configs import dify_config
@@ -45,7 +46,7 @@ class TestAccountService:
             "passport_service": mock_passport_service,
         }
 
-    def test_create_account_and_login(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_create_account_and_login(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test account creation and login with correct password.
         """
@@ -70,7 +71,9 @@ class TestAccountService:
         logged_in = AccountService.authenticate(email, password)
         assert logged_in.id == account.id
 
-    def test_create_account_without_password(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_create_account_without_password(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test account creation without password (for OAuth users).
         """
@@ -92,7 +95,7 @@ class TestAccountService:
         assert account.password_salt is None
 
     def test_create_account_password_invalid_new_password(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test account creation with invalid new password format.
         """
@@ -113,7 +116,9 @@ class TestAccountService:
                 password="invalid_new_password",
             )
 
-    def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_create_account_registration_disabled(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test account creation when registration is disabled.
         """
@@ -131,7 +136,9 @@ class TestAccountService:
                 password=fake.password(length=12),
             )
 
-    def test_create_account_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_create_account_email_in_freeze(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test account creation when email is in freeze period.
""" @@ -154,7 +161,9 @@ class TestAccountService: dify_config.BILLING_ENABLED = False # Reset config for other tests - def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with non-existent account. """ @@ -164,7 +173,7 @@ class TestAccountService: with pytest.raises(AccountPasswordError): AccountService.authenticate(email, password) - def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with banned account. """ @@ -186,14 +195,13 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(AccountLoginError): AccountService.authenticate(email, password) - def test_authenticate_wrong_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_wrong_password(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with wrong password. """ @@ -217,7 +225,9 @@ class TestAccountService: with pytest.raises(AccountPasswordError): AccountService.authenticate(email, wrong_password) - def test_authenticate_with_invite_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_with_invite_token( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with invite token to set password for account without password. """ @@ -249,7 +259,7 @@ class TestAccountService: assert authenticated_account.password_salt is not None def test_authenticate_pending_account_activation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test authentication activates pending account. @@ -270,16 +280,17 @@ class TestAccountService: password=password, ) account.status = AccountStatus.PENDING - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Authenticate should activate the account authenticated_account = AccountService.authenticate(email, password) assert authenticated_account.status == AccountStatus.ACTIVE assert authenticated_account.initialized_at is not None - def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_password_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful password update. """ @@ -308,7 +319,7 @@ class TestAccountService: assert authenticated_account.id == account.id def test_update_account_password_wrong_current_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test password update with wrong current password. 
@@ -335,7 +346,7 @@ class TestAccountService: AccountService.update_account_password(account, wrong_password, new_password) def test_update_account_password_invalid_new_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test password update with invalid new password format. @@ -360,7 +371,7 @@ class TestAccountService: with pytest.raises(ValueError): # Password validation error AccountService.update_account_password(account, old_password, "123") - def test_create_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account creation with automatic tenant creation. """ @@ -387,14 +398,13 @@ class TestAccountService: assert account.email == email # Verify tenant was created and linked - from extensions.ext_database import db - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" def test_create_account_and_tenant_workspace_creation_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account creation when workspace creation is disabled. @@ -419,7 +429,7 @@ class TestAccountService: ) def test_create_account_and_tenant_workspace_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account creation when workspace limit is exceeded. @@ -446,7 +456,9 @@ class TestAccountService: password=password, ) - def test_link_account_integrate_new_provider(self, db_session_with_containers, mock_external_service_dependencies): + def test_link_account_integrate_new_provider( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test linking account with new OAuth provider. """ @@ -469,15 +481,18 @@ class TestAccountService: AccountService.link_account_integrate("new-google", "google_open_id_123", account) # Verify integration was created - from extensions.ext_database import db from models import AccountIntegrate - integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="new-google").first() + integration = ( + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider="new-google") + .first() + ) assert integration is not None assert integration.open_id == "google_open_id_123" def test_link_account_integrate_existing_provider( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test linking account with existing provider (should update). 
@@ -504,15 +519,16 @@ class TestAccountService: AccountService.link_account_integrate("exists-google", "google_open_id_456", account) # Verify integration was updated - from extensions.ext_database import db from models import AccountIntegrate integration = ( - db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="exists-google").first() + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider="exists-google") + .first() ) assert integration.open_id == "google_open_id_456" - def test_close_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_close_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test closing an account. """ @@ -536,12 +552,11 @@ class TestAccountService: AccountService.close_account(account) # Verify account status changed - from extensions.ext_database import db - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.status == AccountStatus.CLOSED - def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating account fields. """ @@ -568,7 +583,9 @@ class TestAccountService: assert updated_account.name == updated_name assert updated_account.interface_theme == "dark" - def test_update_account_invalid_field(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_invalid_field( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating account with invalid field. """ @@ -591,7 +608,7 @@ class TestAccountService: with pytest.raises(AttributeError): AccountService.update_account(account, invalid_field="value") - def test_update_login_info(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_login_info(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating login information. """ @@ -616,13 +633,12 @@ class TestAccountService: AccountService.update_login_info(account, ip_address=ip_address) # Verify login info was updated - from extensions.ext_database import db - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.last_login_ip == ip_address assert account.last_login_at is not None - def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful login with token generation. """ @@ -659,7 +675,9 @@ class TestAccountService: assert call_args["iss"] is not None assert call_args["sub"] == "Console API Passport" - def test_login_pending_account_activation(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_pending_account_activation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test login activates pending account. 
""" @@ -680,17 +698,16 @@ class TestAccountService: password=password, ) account.status = AccountStatus.PENDING - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Login should activate the account token_pair = AccountService.login(account) - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.status == AccountStatus.ACTIVE - def test_logout(self, db_session_with_containers, mock_external_service_dependencies): + def test_logout(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test logout functionality. """ @@ -723,7 +740,7 @@ class TestAccountService: refresh_token_key = f"account_refresh_token:{account.id}" assert redis_client.get(refresh_token_key) is None - def test_refresh_token_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful token refresh. """ @@ -757,7 +774,7 @@ class TestAccountService: assert new_token_pair.access_token == "new_mock_access_token" assert new_token_pair.refresh_token != initial_token_pair.refresh_token - def test_refresh_token_invalid_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_invalid_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test refresh token with invalid token. """ @@ -766,7 +783,9 @@ class TestAccountService: with pytest.raises(ValueError, match="Invalid refresh token"): AccountService.refresh_token(invalid_token) - def test_refresh_token_invalid_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_invalid_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test refresh token with valid token but invalid account. """ @@ -791,16 +810,15 @@ class TestAccountService: token_pair = AccountService.login(account) # Delete account - from extensions.ext_database import db - db.session.delete(account) - db.session.commit() + db_session_with_containers.delete(account) + db_session_with_containers.commit() # Try to refresh token with deleted account with pytest.raises(ValueError, match="Invalid account"): AccountService.refresh_token(token_pair.refresh_token) - def test_load_user_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading user by ID successfully. """ @@ -830,7 +848,7 @@ class TestAccountService: assert loaded_user.id == account.id assert loaded_user.email == account.email - def test_load_user_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading non-existent user. """ @@ -839,7 +857,7 @@ class TestAccountService: loaded_user = AccountService.load_user(non_existent_user_id) assert loaded_user is None - def test_load_user_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading banned user raises Unauthorized. 
""" @@ -861,14 +879,13 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception AccountService.load_user(account.id) - def test_get_account_jwt_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_account_jwt_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test JWT token generation for account. """ @@ -902,7 +919,7 @@ class TestAccountService: assert call_args["iss"] is not None assert call_args["sub"] == "Console API Passport" - def test_load_logged_in_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_logged_in_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading logged in account by ID. """ @@ -931,7 +948,9 @@ class TestAccountService: assert loaded_account is not None assert loaded_account.id == account.id - def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through email successfully. """ @@ -957,7 +976,9 @@ class TestAccountService: assert found_user is not None assert found_user.id == account.id - def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through non-existent email. """ @@ -968,7 +989,7 @@ class TestAccountService: assert found_user is None def test_get_user_through_email_banned_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting banned user through email raises Unauthorized. @@ -991,14 +1012,15 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception AccountService.get_user_through_email(email) - def test_get_user_through_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_in_freeze( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through email that is in freeze period. """ @@ -1014,7 +1036,7 @@ class TestAccountService: # Reset config dify_config.BILLING_ENABLED = False - def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account deletion (should add task to queue and sync to enterprise). """ @@ -1050,7 +1072,7 @@ class TestAccountService: mock_delete_task.delay.assert_called_once_with(account.id) def test_generate_account_deletion_verification_code( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generating account deletion verification code. 
@@ -1079,7 +1101,9 @@ class TestAccountService: assert len(code) == 6 assert code.isdigit() - def test_verify_account_deletion_code_valid(self, db_session_with_containers, mock_external_service_dependencies): + def test_verify_account_deletion_code_valid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test verifying valid account deletion code. """ @@ -1106,7 +1130,9 @@ class TestAccountService: is_valid = AccountService.verify_account_deletion_code(token, code) assert is_valid is True - def test_verify_account_deletion_code_invalid(self, db_session_with_containers, mock_external_service_dependencies): + def test_verify_account_deletion_code_invalid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test verifying invalid account deletion code. """ @@ -1135,7 +1161,7 @@ class TestAccountService: assert is_valid is False def test_verify_account_deletion_code_invalid_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test verifying account deletion code with invalid token. @@ -1167,7 +1193,7 @@ class TestTenantService: "billing_service": mock_billing_service, } - def test_create_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tenant creation with default settings. """ @@ -1187,7 +1213,7 @@ class TestTenantService: assert tenant.encrypt_public_key is not None def test_create_tenant_workspace_creation_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant creation when workspace creation is disabled. @@ -1202,7 +1228,9 @@ class TestTenantService: with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception TenantService.create_tenant(name=tenant_name) - def test_create_tenant_with_custom_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_with_custom_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant creation with custom name and setup flag. """ @@ -1221,7 +1249,9 @@ class TestTenantService: assert tenant.status == "normal" assert tenant.encrypt_public_key is not None - def test_create_tenant_member_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful tenant member creation. """ @@ -1251,7 +1281,9 @@ class TestTenantService: assert tenant_member.account_id == account.id assert tenant_member.role == "admin" - def test_create_tenant_member_duplicate_owner(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_duplicate_owner( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test creating duplicate owner for a tenant (should fail). 
""" @@ -1290,7 +1322,9 @@ class TestTenantService: with pytest.raises(Exception, match="Tenant already has an owner"): TenantService.create_tenant_member(tenant, account2, role="owner") - def test_create_tenant_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_existing_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating role for existing tenant member. """ @@ -1323,7 +1357,7 @@ class TestTenantService: assert tenant_member2.account_id == tenant_member1.account_id assert tenant_member2.role == "editor" - def test_get_join_tenants_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_join_tenants_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting join tenants for an account. """ @@ -1361,7 +1395,7 @@ class TestTenantService: assert tenant2_name in tenant_names def test_get_current_tenant_by_account_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting current tenant by account successfully. @@ -1388,9 +1422,8 @@ class TestTenantService: # Add account to tenant and set as current TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Get current tenant current_tenant = TenantService.get_current_tenant_by_account(account) @@ -1400,7 +1433,7 @@ class TestTenantService: assert current_tenant.role == "owner" def test_get_current_tenant_by_account_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting current tenant when account has no current tenant. @@ -1426,7 +1459,7 @@ class TestTenantService: with pytest.raises((AttributeError, TenantNotFoundError)): TenantService.get_current_tenant_by_account(account) - def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tenant switching. """ @@ -1457,18 +1490,17 @@ class TestTenantService: # Set initial current tenant account.current_tenant = tenant1 - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Switch to second tenant TenantService.switch_tenant(account, tenant2.id) # Verify tenant was switched - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.current_tenant_id == tenant2.id - def test_switch_tenant_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tenant switching without providing tenant ID. 
""" @@ -1493,7 +1525,9 @@ class TestTenantService: with pytest.raises(ValueError, match="Tenant ID must be provided"): TenantService.switch_tenant(account, None) - def test_switch_tenant_account_not_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_account_not_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test switching to a tenant where account is not a member. """ @@ -1520,7 +1554,7 @@ class TestTenantService: with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): TenantService.switch_tenant(account, tenant.id) - def test_has_roles_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_has_roles_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test checking if tenant has specific roles. """ @@ -1570,7 +1604,7 @@ class TestTenantService: has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) assert has_normal is False - def test_has_roles_invalid_role_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_has_roles_invalid_role_type(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test checking roles with invalid role type. """ @@ -1589,7 +1623,7 @@ class TestTenantService: with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): TenantService.has_roles(tenant, [invalid_role]) - def test_get_user_role_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting user role in a tenant. """ @@ -1620,7 +1654,9 @@ class TestTenantService: assert user_role == "editor" - def test_check_member_permission_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_member_permission_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test checking member permission successfully. """ @@ -1660,7 +1696,7 @@ class TestTenantService: TenantService.check_member_permission(tenant, owner_account, member_account, "add") def test_check_member_permission_invalid_action( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test checking member permission with invalid action. @@ -1692,7 +1728,9 @@ class TestTenantService: with pytest.raises(Exception, match="Invalid action"): TenantService.check_member_permission(tenant, account, None, invalid_action) - def test_check_member_permission_operate_self(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_member_permission_operate_self( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test checking member permission when trying to operate self. 
""" @@ -1722,7 +1760,9 @@ class TestTenantService: with pytest.raises(Exception, match="Cannot operate self"): TenantService.check_member_permission(tenant, account, account, "remove") - def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_member_from_tenant_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful member removal from tenant (should sync to enterprise). """ @@ -1770,16 +1810,17 @@ class TestTenantService: ) # Verify member was removed - from extensions.ext_database import db from models.account import TenantAccountJoin member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert member_join is None def test_remove_member_from_tenant_operate_self( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test removing member when trying to operate self. @@ -1810,7 +1851,9 @@ class TestTenantService: with pytest.raises(Exception, match="Cannot operate self"): TenantService.remove_member_from_tenant(tenant, account, account) - def test_remove_member_from_tenant_not_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_member_from_tenant_not_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test removing member who is not in the tenant. """ @@ -1849,7 +1892,7 @@ class TestTenantService: with pytest.raises(Exception, match="Member not in tenant"): TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) - def test_update_member_role_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful member role update. """ @@ -1889,15 +1932,16 @@ class TestTenantService: TenantService.update_member_role(tenant, member_account, "admin", owner_account) # Verify role was updated - from extensions.ext_database import db from models.account import TenantAccountJoin member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert member_join.role == "admin" - def test_update_member_role_to_owner(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_to_owner(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating member role to owner (should change current owner to admin). 
""" @@ -1937,19 +1981,24 @@ class TestTenantService: TenantService.update_member_role(tenant, member_account, "owner", owner_account) # Verify roles were updated correctly - from extensions.ext_database import db from models.account import TenantAccountJoin owner_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=owner_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=owner_account.id) + .first() ) member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert owner_join.role == "admin" assert member_join.role == "owner" - def test_update_member_role_already_assigned(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_already_assigned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating member role to already assigned role. """ @@ -1989,7 +2038,7 @@ class TestTenantService: with pytest.raises(Exception, match="The provided role is already assigned to the member"): TenantService.update_member_role(tenant, member_account, "admin", owner_account) - def test_get_tenant_count_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_count_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting tenant count successfully. """ @@ -2014,7 +2063,7 @@ class TestTenantService: assert tenant_count >= 3 def test_create_owner_tenant_if_not_exist_new_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant for new user without existing tenants. @@ -2044,17 +2093,16 @@ class TestTenantService: TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" assert account.current_tenant is not None assert account.current_tenant.name == workspace_name def test_create_owner_tenant_if_not_exist_existing_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant when user already has a tenant. 
@@ -2083,20 +2131,19 @@ class TestTenantService: existing_tenant = TenantService.create_tenant(name=existing_tenant_name) TenantService.create_tenant_member(existing_tenant, account, role="owner") account.current_tenant = existing_tenant - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Try to create owner tenant again (should not create new one) TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) # Verify no new tenant was created - tenant_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).all() + tenant_joins = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).all() assert len(tenant_joins) == 1 assert account.current_tenant.id == existing_tenant.id def test_create_owner_tenant_if_not_exist_workspace_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant when workspace creation is disabled. @@ -2123,7 +2170,7 @@ class TestTenantService: with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) - def test_get_tenant_members_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_members_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting tenant members successfully. """ @@ -2187,7 +2234,9 @@ class TestTenantService: elif member.email == normal_email: assert member.role == "normal" - def test_get_dataset_operator_members_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_operator_members_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting dataset operator members successfully. """ @@ -2240,7 +2289,7 @@ class TestTenantService: assert dataset_operators[0].email == operator_email assert dataset_operators[0].role == "dataset_operator" - def test_get_custom_config_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_custom_config_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting custom config successfully. """ @@ -2259,9 +2308,8 @@ class TestTenantService: # Set custom config custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} tenant.custom_config_dict = custom_config - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Get custom config retrieved_config = TenantService.get_custom_config(tenant.id) @@ -2296,7 +2344,7 @@ class TestRegisterService: "passport_service": mock_passport_service, } - def test_setup_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_setup_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful system setup with account creation and tenant setup. 
""" @@ -2309,11 +2357,10 @@ class TestRegisterService: mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False - from extensions.ext_database import db from models.model import DifySetup - db.session.query(DifySetup).delete() - db.session.commit() + db_session_with_containers.query(DifySetup).delete() + db_session_with_containers.commit() # Execute setup RegisterService.setup( @@ -2327,7 +2374,7 @@ class TestRegisterService: # Verify account was created from models import Account - account = db.session.query(Account).filter_by(email=admin_email).first() + account = db_session_with_containers.query(Account).filter_by(email=admin_email).first() assert account is not None assert account.name == admin_name assert account.last_login_ip == ip_address @@ -2335,17 +2382,17 @@ class TestRegisterService: assert account.status == "active" # Verify DifySetup was created - dify_setup = db.session.query(DifySetup).first() + dify_setup = db_session_with_containers.query(DifySetup).first() assert dify_setup is not None # Verify tenant was created and linked from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" - def test_setup_failure_rollback(self, db_session_with_containers, mock_external_service_dependencies): + def test_setup_failure_rollback(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test setup failure with proper rollback of all created entities. """ @@ -2373,21 +2420,20 @@ class TestRegisterService: ) # Verify no entities were created (rollback worked) - from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin from models.model import DifySetup - account = db.session.query(Account).filter_by(email=admin_email).first() - tenant_count = db.session.query(Tenant).count() - tenant_join_count = db.session.query(TenantAccountJoin).count() - dify_setup_count = db.session.query(DifySetup).count() + account = db_session_with_containers.query(Account).filter_by(email=admin_email).first() + tenant_count = db_session_with_containers.query(Tenant).count() + tenant_join_count = db_session_with_containers.query(TenantAccountJoin).count() + dify_setup_count = db_session_with_containers.query(DifySetup).count() assert account is None assert tenant_count == 0 assert tenant_join_count == 0 assert dify_setup_count == 0 - def test_register_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful account registration with workspace creation. 
""" @@ -2421,16 +2467,15 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" assert account.current_tenant is not None assert account.current_tenant.name == f"{name}'s Workspace" - def test_register_with_oauth(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_with_oauth(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account registration with OAuth integration. """ @@ -2467,14 +2512,19 @@ class TestRegisterService: assert account.initialized_at is not None # Verify OAuth integration was created - from extensions.ext_database import db from models import AccountIntegrate - integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + integration = ( + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider=provider) + .first() + ) assert integration is not None assert integration.open_id == open_id - def test_register_with_pending_status(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_with_pending_status( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration with pending status. """ @@ -2511,14 +2561,15 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" - def test_register_workspace_creation_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_workspace_creation_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration when workspace creation is disabled. """ @@ -2549,13 +2600,14 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_register_workspace_limit_exceeded(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_workspace_limit_exceeded( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration when workspace limit is exceeded. 
""" @@ -2589,13 +2641,12 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_register_without_workspace(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_without_workspace(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account registration without workspace creation. """ @@ -2624,13 +2675,14 @@ class TestRegisterService: assert account.initialized_at is not None # Verify no tenant was created - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_invite_new_member_new_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_new_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a new member who doesn't have an account yet. """ @@ -2682,22 +2734,25 @@ class TestRegisterService: mock_send_mail.delay.assert_called_once() # Verify new account was created with pending status - from extensions.ext_database import db from models import Account, TenantAccountJoin - new_account = db.session.query(Account).filter_by(email=new_member_email).first() + new_account = db_session_with_containers.query(Account).filter_by(email=new_member_email).first() assert new_account is not None assert new_account.name == new_member_email.split("@")[0] # Default name from email assert new_account.status == "pending" # Verify tenant member was created tenant_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=new_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=new_account.id) + .first() ) assert tenant_join is not None assert tenant_join.role == "normal" - def test_invite_new_member_existing_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_existing_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting an existing member who is not in the tenant yet. 
""" @@ -2749,16 +2804,19 @@ class TestRegisterService: mock_send_mail.delay.assert_not_called() # Verify tenant member was created for existing account - from extensions.ext_database import db from models.account import TenantAccountJoin tenant_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=existing_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=existing_account.id) + .first() ) assert tenant_join is not None assert tenant_join.role == "admin" - def test_invite_new_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_existing_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a member who is already in the tenant with pending status. """ @@ -2793,9 +2851,8 @@ class TestRegisterService: password=existing_pending_member_password, ) existing_account.status = "pending" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Add existing account to tenant TenantService.create_tenant_member(tenant, existing_account, role="normal") @@ -2820,7 +2877,9 @@ class TestRegisterService: # Verify email task was called mock_send_mail.delay.assert_called_once() - def test_invite_new_member_no_inviter(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_no_inviter( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a member without providing an inviter. """ @@ -2846,7 +2905,7 @@ class TestRegisterService: ) def test_invite_new_member_account_already_in_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test inviting a member who is already in the tenant with active status. @@ -2882,9 +2941,8 @@ class TestRegisterService: password=already_in_tenant_password, ) existing_account.status = "active" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Add existing account to tenant TenantService.create_tenant_member(tenant, existing_account, role="normal") @@ -2899,7 +2957,9 @@ class TestRegisterService: inviter=inviter, ) - def test_generate_invite_token_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_invite_token_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation of invite token. """ @@ -2943,7 +3003,7 @@ class TestRegisterService: assert invitation_data["email"] == account.email assert invitation_data["workspace_id"] == tenant.id - def test_is_valid_invite_token_valid(self, db_session_with_containers, mock_external_service_dependencies): + def test_is_valid_invite_token_valid(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test validation of valid invite token. """ @@ -2974,7 +3034,9 @@ class TestRegisterService: # Verify token is valid assert is_valid is True - def test_is_valid_invite_token_invalid(self, db_session_with_containers, mock_external_service_dependencies): + def test_is_valid_invite_token_invalid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation of invalid invite token. 
""" @@ -2987,7 +3049,7 @@ class TestRegisterService: assert is_valid is False def test_revoke_token_with_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test revoking token with workspace ID and email. @@ -3030,7 +3092,7 @@ class TestRegisterService: assert redis_client.get(token_key) is not None def test_revoke_token_without_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test revoking token without workspace ID and email. @@ -3073,7 +3135,7 @@ class TestRegisterService: assert redis_client.get(token_key) is None def test_get_invitation_if_token_valid_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with valid token. @@ -3122,7 +3184,7 @@ class TestRegisterService: assert result["data"]["workspace_id"] == tenant.id def test_get_invitation_if_token_valid_invalid_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with invalid token. @@ -3142,7 +3204,7 @@ class TestRegisterService: assert result is None def test_get_invitation_if_token_valid_invalid_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with invalid tenant. @@ -3192,7 +3254,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_if_token_valid_account_mismatch( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with account ID mismatch. @@ -3242,7 +3304,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_if_token_valid_tenant_not_normal( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with tenant not in normal status. @@ -3269,9 +3331,8 @@ class TestRegisterService: # Change tenant status to non-normal tenant.status = "suspended" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Create a real token from extensions.ext_redis import redis_client @@ -3300,7 +3361,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_by_token_with_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation by token with workspace ID and email. @@ -3339,7 +3400,7 @@ class TestRegisterService: redis_client.delete(cache_key) def test_get_invitation_by_token_without_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation by token without workspace ID and email. 
@@ -3372,7 +3433,7 @@ class TestRegisterService: # Clean up redis_client.delete(token_key) - def test_get_invitation_token_key(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_invitation_token_key(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting invitation token key. """ diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 00bce32f48..45839fd463 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.plugin.impl.exc import PluginDaemonClientSideError from models import Account @@ -87,7 +88,7 @@ class TestAgentService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -133,13 +134,12 @@ class TestAgentService: # Update the app model config to set agent_mode for agent-chat mode if app.mode == "agent-chat" and app.app_model_config: app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []}) - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() return app, account - def _create_test_conversation_and_message(self, db_session_with_containers, app, account): + def _create_test_conversation_and_message(self, db_session_with_containers: Session, app, account): """ Helper method to create a test conversation and message with agent thoughts. @@ -153,8 +153,6 @@ class TestAgentService: """ fake = Faker() - from extensions.ext_database import db - # Create conversation conversation = Conversation( id=fake.uuid4(), @@ -167,8 +165,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -180,12 +178,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -206,12 +204,12 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return conversation, message - def _create_test_agent_thoughts(self, db_session_with_containers, message): + def _create_test_agent_thoughts(self, db_session_with_containers: Session, message): """ Helper method to create test agent thoughts for a message. 
@@ -224,8 +222,6 @@ class TestAgentService: """ fake = Faker() - from extensions.ext_database import db - agent_thoughts = [] # Create first agent thought @@ -251,7 +247,7 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought1) + db_session_with_containers.add(thought1) agent_thoughts.append(thought1) # Create second agent thought @@ -277,14 +273,14 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought2) + db_session_with_containers.add(thought2) agent_thoughts.append(thought2) - db.session.commit() + db_session_with_containers.commit() return agent_thoughts - def test_get_agent_logs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of agent logs with complete data. """ @@ -344,7 +340,7 @@ class TestAgentService: assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon def test_get_agent_logs_conversation_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when conversation is not found. @@ -358,7 +354,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Conversation not found"): AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4()) - def test_get_agent_logs_message_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_message_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when message is not found. """ @@ -372,7 +370,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Message not found"): AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4()) - def test_get_agent_logs_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_end_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval when conversation is from end user. 
""" @@ -381,8 +381,6 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create end user end_user = EndUser( id=fake.uuid4(), @@ -393,8 +391,8 @@ class TestAgentService: session_id=fake.uuid4(), name=fake.name(), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create conversation with end user conversation = Conversation( @@ -408,8 +406,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -421,12 +419,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -447,8 +445,8 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -457,7 +455,9 @@ class TestAgentService: assert result is not None assert result["meta"]["executor"] == end_user.name - def test_get_agent_logs_with_unknown_executor(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_unknown_executor( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval when executor is unknown. 
""" @@ -466,8 +466,6 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create conversation with non-existent account conversation = Conversation( id=fake.uuid4(), @@ -480,8 +478,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -493,12 +491,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -519,8 +517,8 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -529,7 +527,9 @@ class TestAgentService: assert result is not None assert result["meta"]["executor"] == "Unknown" - def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_tool_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with tool errors. """ @@ -539,8 +539,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with tool error thought_with_error = MessageAgentThought( message_id=message.id, @@ -564,8 +562,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought_with_error) - db.session.commit() + db_session_with_containers.add(thought_with_error) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -580,7 +578,7 @@ class TestAgentService: assert tool_call["error"] == "Tool execution failed" def test_get_agent_logs_without_agent_thoughts( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval when message has no agent thoughts. @@ -600,7 +598,7 @@ class TestAgentService: assert len(result["iterations"]) == 0 def test_get_agent_logs_app_model_config_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when app model config is not found. 
@@ -610,11 +608,9 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Remove app model config to test error handling app.app_model_config_id = None - db.session.commit() + db_session_with_containers.commit() # Create conversation without app model config conversation = Conversation( @@ -629,8 +625,8 @@ class TestAgentService: from_source="api", app_model_config_id=None, # Explicitly set to None ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message( @@ -651,15 +647,15 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test with pytest.raises(ValueError, match="App model config not found"): AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) def test_get_agent_logs_agent_config_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when agent config is not found. @@ -677,7 +673,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Agent config not found"): AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) - def test_list_agent_providers_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_agent_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful listing of agent providers. """ @@ -698,7 +696,7 @@ class TestAgentService: mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id)) - def test_get_agent_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of specific agent provider. """ @@ -720,7 +718,9 @@ class TestAgentService: mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name) - def test_get_agent_provider_plugin_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_provider_plugin_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when plugin daemon client raises an error. """ @@ -741,7 +741,7 @@ class TestAgentService: AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name) def test_get_agent_logs_with_complex_tool_data( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval with complex tool data and multiple tools. 
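The complex-tool-data test in the next hunk packs several tools into one MessageAgentThought and expects one tool call per name in the output. The unpacking its assertions imply looks roughly like this; the ";"-separated names and JSON objects keyed by tool name are assumptions read off the test data:

    import json

    def split_tool_calls(tool: str, tool_input: str, observation: str) -> list[dict]:
        names = [name for name in tool.split(";") if name]
        inputs = json.loads(tool_input) if tool_input else {}
        outputs = json.loads(observation) if observation else {}
        return [
            {
                "tool_name": name,
                "tool_input": inputs.get(name, {}),
                "tool_output": outputs.get(name, ""),
                # dataset-retrieval tools render with an empty icon, per the assertions
                "tool_icon": "" if name == "dataset-retrieval" else "default",
            }
            for name in names
        ]

    calls = split_tool_calls(
        "web_search;dataset-retrieval",
        json.dumps({"web_search": {"q": "x"}, "dataset-retrieval": {"query": "x"}}),
        json.dumps({"web_search": "r1", "dataset-retrieval": "r2"}),
    )
    assert len(calls) == 2 and calls[1]["tool_icon"] == ""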
@@ -752,8 +752,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with multiple tools complex_thought = MessageAgentThought( message_id=message.id, @@ -799,8 +797,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(complex_thought) - db.session.commit() + db_session_with_containers.add(complex_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -831,7 +829,7 @@ class TestAgentService: assert tool_calls[2]["status"] == "success" assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon - def test_get_agent_logs_with_files(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_files(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test agent logs retrieval with message files and agent thought files. """ @@ -842,7 +840,6 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from dify_graph.file import FileTransferMethod, FileType - from extensions.ext_database import db from models.enums import CreatorUserRole # Add files to message @@ -867,9 +864,9 @@ class TestAgentService: created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) - db.session.add(message_file1) - db.session.add(message_file2) - db.session.commit() + db_session_with_containers.add(message_file1) + db_session_with_containers.add(message_file2) + db_session_with_containers.commit() # Create agent thought with files thought_with_files = MessageAgentThought( @@ -895,8 +892,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought_with_files) - db.session.commit() + db_session_with_containers.add(thought_with_files) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -912,7 +909,7 @@ class TestAgentService: assert "file2" in iterations[0]["files"] def test_get_agent_logs_with_different_timezone( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval with different timezone settings. @@ -938,7 +935,9 @@ class TestAgentService: assert "T" in start_time # ISO format assert "+08:00" in start_time or "Z" in start_time # Timezone offset - def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_empty_tool_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with empty tool data. 
""" @@ -948,8 +947,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with empty tool data empty_thought = MessageAgentThought( message_id=message.id, @@ -964,8 +961,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(empty_thought) - db.session.commit() + db_session_with_containers.add(empty_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -979,7 +976,9 @@ class TestAgentService: tool_calls = iterations[0]["tool_calls"] assert len(tool_calls) == 0 # No tools to process - def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_malformed_json( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with malformed JSON data in tool fields. """ @@ -989,8 +988,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with malformed JSON malformed_thought = MessageAgentThought( message_id=message.id, @@ -1005,8 +1002,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(malformed_thought) - db.session.commit() + db_session_with_containers.add(malformed_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 4f5190e533..004d643955 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account @@ -52,7 +53,7 @@ class TestAnnotationService: "current_user": mock_user, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -115,11 +116,10 @@ class TestAnnotationService: tenant_id, ) - def _create_test_conversation(self, app, account, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake): """ Helper method to create a test conversation with all required fields. 
""" - from extensions.ext_database import db from models.model import Conversation conversation = Conversation( @@ -141,17 +141,16 @@ class TestAnnotationService: from_account_id=account.id, ) - db.session.add(conversation) - db.session.flush() + db_session_with_containers.add(conversation) + db_session_with_containers.flush() return conversation - def _create_test_message(self, app, conversation, account, fake): + def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake): """ Helper method to create a test message with all required fields. """ import json - from extensions.ext_database import db from models.model import Message message = Message( @@ -180,12 +179,12 @@ class TestAnnotationService: from_account_id=account.id, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message def test_insert_app_annotation_directly_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct insertion of app annotation. @@ -211,9 +210,8 @@ class TestAnnotationService: assert annotation.id is not None # Verify annotation was saved to database - from extensions.ext_database import db - db.session.refresh(annotation) + db_session_with_containers.refresh(annotation) assert annotation.id is not None # Verify add_annotation_to_index_task was called (when annotation setting exists) @@ -221,7 +219,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_insert_app_annotation_directly_requires_question( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Question must be provided when inserting annotations directly. @@ -238,7 +236,7 @@ class TestAnnotationService: AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) def test_insert_app_annotation_directly_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test direct insertion of app annotation when app is not found. @@ -260,7 +258,7 @@ class TestAnnotationService: AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) def test_update_app_annotation_directly_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct update of app annotation. @@ -298,7 +296,7 @@ class TestAnnotationService: mock_external_service_dependencies["update_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_new( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating new annotation from message. 
@@ -307,8 +305,8 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Setup annotation data with message_id annotation_args = { @@ -333,7 +331,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_update( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test updating existing annotation from message. @@ -342,8 +340,8 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial annotation initial_args = { @@ -373,7 +371,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating annotation from message when app is not found. @@ -395,7 +393,7 @@ class TestAnnotationService: AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) def test_get_annotation_list_by_app_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of annotation list by app ID. @@ -428,7 +426,7 @@ class TestAnnotationService: assert annotation.account_id == account.id def test_get_annotation_list_by_app_id_with_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of annotation list with keyword search. @@ -462,7 +460,7 @@ class TestAnnotationService: assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. 
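The special-characters test that follows exists because keyword search must treat %, _ and backslashes as literals rather than LIKE wildcards. The standard SQLAlchemy answer is an escaped, parameterized ilike filter, sketched below with assumed column names; this is illustration, not the service's actual query.

    from sqlalchemy import or_

    def apply_keyword_filter(query, model, keyword: str):
        if not keyword:
            return query
        # Escape LIKE metacharacters so a keyword like "50%" matches literally.
        escaped = keyword.replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
        pattern = f"%{escaped}%"
        return query.filter(
            or_(
                model.question.ilike(pattern, escape="\\"),
                model.content.ilike(pattern, escape="\\"),
            )
        )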
@@ -534,7 +532,7 @@ class TestAnnotationService: assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) def test_get_annotation_list_by_app_id_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of annotation list when app is not found. @@ -549,7 +547,9 @@ class TestAnnotationService: with pytest.raises(NotFound, match="App not found"): AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="") - def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of app annotation. """ @@ -568,16 +568,19 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - from extensions.ext_database import db - deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + deleted_annotation = ( + db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + ) assert deleted_annotation is None # Verify delete_annotation_index_task was called (when annotation setting exists) # Note: In this test, no annotation setting exists, so task should not be called mock_external_service_dependencies["delete_task"].delay.assert_not_called() - def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_annotation_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deletion of app annotation when app is not found. """ @@ -593,7 +596,7 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) def test_delete_app_annotation_annotation_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test deletion of app annotation when annotation is not found. @@ -606,7 +609,9 @@ class TestAnnotationService: with pytest.raises(NotFound, match="Annotation not found"): AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) - def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful enabling of app annotation. """ @@ -632,7 +637,9 @@ class TestAnnotationService: # Verify task was called mock_external_service_dependencies["enable_task"].delay.assert_called_once() - def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful disabling of app annotation. 
""" @@ -651,7 +658,9 @@ class TestAnnotationService: # Verify task was called mock_external_service_dependencies["disable_task"].delay.assert_called_once() - def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_app_annotation_cached_job( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test enabling app annotation when job is already cached. """ @@ -685,7 +694,9 @@ class TestAnnotationService: # Clean up redis_client.delete(enable_app_annotation_key) - def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_annotation_hit_histories_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of annotation hit histories. """ @@ -728,7 +739,9 @@ class TestAnnotationService: assert history.app_id == app.id assert history.account_id == account.id - def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_add_annotation_history_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful addition of annotation history. """ @@ -763,16 +776,15 @@ class TestAnnotationService: ) # Verify hit count was incremented - from extensions.ext_database import db - db.session.refresh(annotation) + db_session_with_containers.refresh(annotation) assert annotation.hit_count == initial_hit_count + 1 # Verify history was created from models.model import AppAnnotationHitHistory history = ( - db.session.query(AppAnnotationHitHistory) + db_session_with_containers.query(AppAnnotationHitHistory) .where( AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id ) @@ -786,7 +798,9 @@ class TestAnnotationService: assert history.score == score assert history.source == "console" - def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_annotation_by_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of annotation by ID. """ @@ -811,7 +825,9 @@ class TestAnnotationService: assert retrieved_annotation.content == annotation_args["answer"] assert retrieved_annotation.account_id == account.id - def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_batch_import_app_annotations_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful batch import of app annotations. """ @@ -854,7 +870,7 @@ class TestAnnotationService: mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() def test_batch_import_app_annotations_empty_file( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test batch import with empty CSV file. @@ -889,7 +905,7 @@ class TestAnnotationService: assert "empty" in result["error_msg"].lower() def test_batch_import_app_annotations_quota_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test batch import when quota is exceeded. 
@@ -935,7 +951,7 @@ class TestAnnotationService: assert "limit" in result["error_msg"].lower() def test_get_app_annotation_setting_by_app_id_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting enabled app annotation setting by app ID. @@ -944,7 +960,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -956,8 +971,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -967,8 +982,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Get annotation setting result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) @@ -981,7 +996,7 @@ class TestAnnotationService: assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" def test_get_app_annotation_setting_by_app_id_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting disabled app annotation setting by app ID. @@ -996,7 +1011,7 @@ class TestAnnotationService: assert result["enabled"] is False def test_update_app_annotation_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of app annotation setting. 
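The setting tests above flush() the collection binding before creating the setting that references it, keeping both rows inside one transaction until the final commit. Flush-then-commit is also how a database-generated id becomes available early; a self-contained illustration against SQLite:

    from sqlalchemy import ForeignKey, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Binding(Base):
        __tablename__ = "binding"
        id: Mapped[int] = mapped_column(primary_key=True)

    class Setting(Base):
        __tablename__ = "setting"
        id: Mapped[int] = mapped_column(primary_key=True)
        binding_id: Mapped[int] = mapped_column(ForeignKey("binding.id"))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        binding = Binding()
        session.add(binding)
        session.flush()  # binding.id is assigned here, transaction still open
        session.add(Setting(binding_id=binding.id))
        session.commit()  # both rows persist together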
@@ -1005,7 +1020,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1017,8 +1031,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1028,8 +1042,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Update annotation setting update_args = { @@ -1046,11 +1060,11 @@ class TestAnnotationService: assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" # Verify database was updated - db.session.refresh(annotation_setting) + db_session_with_containers.refresh(annotation_setting) assert annotation_setting.score_threshold == 0.9 def test_export_annotation_list_by_app_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful export of annotation list by app ID. @@ -1083,7 +1097,7 @@ class TestAnnotationService: assert annotation.created_at <= exported_annotations[i - 1].created_at def test_export_annotation_list_by_app_id_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test export of annotation list when app is not found. @@ -1099,7 +1113,7 @@ class TestAnnotationService: AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) def test_insert_app_annotation_directly_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct insertion of app annotation with annotation setting enabled. 
@@ -1108,7 +1122,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1120,8 +1133,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1131,8 +1144,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Setup annotation data annotation_args = { @@ -1161,7 +1174,7 @@ class TestAnnotationService: assert call_args[4] == collection_binding.id # collection_binding_id def test_update_app_annotation_directly_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct update of app annotation with annotation setting enabled. @@ -1170,7 +1183,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1182,8 +1194,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1193,8 +1205,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # First, create an annotation original_args = { @@ -1234,7 +1246,7 @@ class TestAnnotationService: assert call_args[4] == collection_binding.id # collection_binding_id def test_delete_app_annotation_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful deletion of app annotation with annotation setting enabled. 
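The with-setting variants assert on positional arguments of the queued Celery task (call_args[4] carries the collection binding id). A generic, runnable illustration of that call_args idiom from unittest.mock:

    from unittest.mock import MagicMock

    add_annotation_to_index_task = MagicMock()  # stands in for the patched Celery task
    add_annotation_to_index_task.delay("ann-1", "app-1", "tenant-1", "question?", "binding-1")

    positional = add_annotation_to_index_task.delay.call_args[0]  # the positional-args tuple
    assert positional[4] == "binding-1"  # collection_binding_id rides in slot 4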
@@ -1243,7 +1255,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1255,8 +1266,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1267,8 +1278,8 @@ class TestAnnotationService: updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Create an annotation first annotation_args = { @@ -1285,7 +1296,9 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + deleted_annotation = ( + db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + ) assert deleted_annotation is None # Verify delete_annotation_index_task was called @@ -1297,7 +1310,7 @@ class TestAnnotationService: assert call_args[3] == collection_binding.id # collection_binding_id def test_up_insert_app_annotation_from_message_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating annotation from message with annotation setting enabled. 
@@ -1306,7 +1319,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1318,8 +1330,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1329,12 +1341,12 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Setup annotation data with message_id annotation_args = { diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 8c8be2e670..b8bf8543bc 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.api_based_extension import APIBasedExtension from services.account_service import AccountService, TenantService @@ -31,7 +32,7 @@ class TestAPIBasedExtensionService: "requestor_instance": mock_requestor_instance, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -61,7 +62,7 @@ class TestAPIBasedExtensionService: return account, tenant - def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful saving of API-based extension. 
""" @@ -90,15 +91,16 @@ class TestAPIBasedExtensionService: assert saved_extension.created_at is not None # Verify extension was saved to database - from extensions.ext_database import db - db.session.refresh(saved_extension) + db_session_with_containers.refresh(saved_extension) assert saved_extension.id is not None # Verify ping connection was called mock_external_service_dependencies["requestor_instance"].request.assert_called_once() - def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_validation_errors( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation errors when saving extension with invalid data. """ @@ -132,7 +134,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must not be empty"): APIBasedExtensionService.save(extension_data) - def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_by_tenant_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of all extensions by tenant ID. """ @@ -169,7 +173,7 @@ class TestAPIBasedExtensionService: # Verify descending order (newer first) assert extension.created_at <= extension_list[i - 1].created_at - def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of extension by tenant ID and extension ID. """ @@ -200,7 +204,9 @@ class TestAPIBasedExtensionService: assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted assert retrieved_extension.created_at is not None - def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extension when extension is not found. """ @@ -214,7 +220,7 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) - def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful deletion of extension. """ @@ -238,12 +244,15 @@ class TestAPIBasedExtensionService: APIBasedExtensionService.delete(created_extension) # Verify extension was deleted - from extensions.ext_database import db - deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + deleted_extension = ( + db_session_with_containers.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + ) assert deleted_extension is None - def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when saving extension with duplicate name. 
""" @@ -272,7 +281,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="name must be unique, it is already existed"): APIBasedExtensionService.save(extension_data2) - def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_update_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful update of existing extension. """ @@ -329,7 +340,9 @@ class TestAPIBasedExtensionService: assert retrieved_extension.api_endpoint == new_endpoint assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved - def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_connection_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test connection error when saving extension with invalid endpoint. """ @@ -356,7 +369,7 @@ class TestAPIBasedExtensionService: APIBasedExtensionService.save(extension_data) def test_save_extension_invalid_api_key_length( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test validation error when saving extension with API key that is too short. @@ -378,7 +391,7 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must be at least 5 characters"): APIBasedExtensionService.save(extension_data) - def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test validation errors when saving extension with empty required fields. """ @@ -412,7 +425,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must not be empty"): APIBasedExtensionService.save(extension_data) - def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_by_tenant_id_empty_list( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extensions when no extensions exist for tenant. """ @@ -428,7 +443,9 @@ class TestAPIBasedExtensionService: assert len(extension_list) == 0 assert extension_list == [] - def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_invalid_ping_response( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when ping response is invalid. """ @@ -452,7 +469,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="{'result': 'invalid'}"): APIBasedExtensionService.save(extension_data) - def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_missing_ping_result( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when ping response is missing result field. 
""" @@ -476,7 +495,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="{'status': 'ok'}"): APIBasedExtensionService.save(extension_data) - def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_wrong_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extension when tenant ID doesn't match. """ diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 8544d23cdf..787a99f3e8 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -3,6 +3,7 @@ from unittest.mock import ANY, MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.model import EndUser @@ -118,7 +119,9 @@ class TestAppGenerateService: "global_dify_config": mock_global_dify_config, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"): + def _create_test_app_and_account( + self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat" + ): """ Helper method to create a test app and account for testing. @@ -169,7 +172,7 @@ class TestAppGenerateService: return app, account - def _create_test_workflow(self, db_session_with_containers, app): + def _create_test_workflow(self, db_session_with_containers: Session, app): """ Helper method to create a test workflow for testing. @@ -191,14 +194,14 @@ class TestAppGenerateService: status="published", ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow - def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_completion_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for completion mode app. """ @@ -226,7 +229,7 @@ class TestAppGenerateService: mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once() - def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_chat_mode_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful generation for chat mode app. """ @@ -250,7 +253,9 @@ class TestAppGenerateService: mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once() - def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_agent_chat_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for agent chat mode app. 
""" @@ -274,7 +279,9 @@ class TestAppGenerateService: mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() - def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_advanced_chat_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for advanced chat mode app. """ @@ -300,7 +307,9 @@ class TestAppGenerateService: "advanced_chat_generator" ].return_value.convert_to_event_stream.assert_called_once() - def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_workflow_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for workflow mode app. """ @@ -324,7 +333,9 @@ class TestAppGenerateService: mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() - def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_specific_workflow_id( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with a specific workflow ID. """ @@ -355,7 +366,9 @@ class TestAppGenerateService: "workflow_service" ].return_value.get_published_workflow_by_id.assert_called_once() - def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_debugger_invoke_from( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with debugger invoke from. """ @@ -378,7 +391,9 @@ class TestAppGenerateService: # Verify draft workflow was fetched for debugger mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once() - def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_non_streaming_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with non-streaming mode. """ @@ -401,7 +416,7 @@ class TestAppGenerateService: # Verify rate limit exit was called for non-streaming mode mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() - def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test generation with EndUser instead of Account. 
""" @@ -421,10 +436,8 @@ class TestAppGenerateService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Setup test arguments args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} @@ -438,7 +451,7 @@ class TestAppGenerateService: assert result == ["test_response"] def test_generate_with_billing_enabled_sandbox_plan( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with billing enabled and sandbox plan. @@ -466,7 +479,9 @@ class TestAppGenerateService: # Verify billing service was called to consume quota mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() - def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_invalid_app_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with invalid app mode. """ @@ -491,7 +506,7 @@ class TestAppGenerateService: assert "Invalid app mode" in str(exc_info.value) def test_generate_with_workflow_id_format_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with invalid workflow ID format. @@ -518,7 +533,7 @@ class TestAppGenerateService: assert "Invalid workflow_id format" in str(exc_info.value) def test_generate_with_workflow_not_found_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not found. @@ -552,7 +567,7 @@ class TestAppGenerateService: assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value) def test_generate_with_workflow_not_initialized_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not initialized for debugger. @@ -578,7 +593,7 @@ class TestAppGenerateService: assert "Workflow not initialized" in str(exc_info.value) def test_generate_with_workflow_not_published_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not published for non-debugger. @@ -604,7 +619,7 @@ class TestAppGenerateService: assert "Workflow not published" in str(exc_info.value) def test_generate_single_iteration_advanced_chat_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single iteration generation for advanced chat mode. @@ -631,7 +646,7 @@ class TestAppGenerateService: ].return_value.single_iteration_generate.assert_called_once() def test_generate_single_iteration_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single iteration generation for workflow mode. 
@@ -658,7 +673,7 @@ class TestAppGenerateService: ].return_value.single_iteration_generate.assert_called_once() def test_generate_single_iteration_invalid_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test single iteration generation with invalid app mode. @@ -681,7 +696,7 @@ class TestAppGenerateService: assert "Invalid app mode" in str(exc_info.value) def test_generate_single_loop_advanced_chat_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single loop generation for advanced chat mode. @@ -708,7 +723,7 @@ class TestAppGenerateService: ].return_value.single_loop_generate.assert_called_once() def test_generate_single_loop_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single loop generation for workflow mode. @@ -732,7 +747,9 @@ class TestAppGenerateService: # Verify workflow generator was called mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once() - def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_single_loop_invalid_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test single loop generation with invalid app mode. """ @@ -753,7 +770,9 @@ class TestAppGenerateService: # Verify error message assert "Invalid app mode" in str(exc_info.value) - def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_more_like_this_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful more like this generation. """ @@ -778,7 +797,7 @@ class TestAppGenerateService: ].return_value.generate_more_like_this.assert_called_once() def test_generate_more_like_this_with_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test more like this generation with EndUser. @@ -799,10 +818,8 @@ class TestAppGenerateService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() message_id = fake.uuid4() @@ -815,7 +832,7 @@ class TestAppGenerateService: assert result == ["more_like_this_response"] def test_get_max_active_requests_with_app_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with app-specific limit. @@ -835,7 +852,7 @@ class TestAppGenerateService: assert result == 10 def test_get_max_active_requests_with_config_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with config limit being smaller. 
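
The fixture itself is defined in the suite's conftest and does not appear in this diff. For orientation, here is a minimal sketch of how a container-backed session fixture of this shape is commonly wired with pytest, SQLAlchemy, and testcontainers; all names, the image tag, and the scoping below are assumptions, not the project's actual conftest:

    from collections.abc import Iterator

    import pytest
    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session, sessionmaker
    from testcontainers.postgres import PostgresContainer

    @pytest.fixture(scope="session")
    def _postgres() -> Iterator[PostgresContainer]:
        # Hypothetical: one throwaway Postgres for the whole test session.
        with PostgresContainer("postgres:16-alpine") as pg:
            yield pg

    @pytest.fixture
    def db_session_with_containers(_postgres: PostgresContainer) -> Iterator[Session]:
        # Hypothetical: hand each test a typed Session bound to the container DB.
        engine = create_engine(_postgres.get_connection_url())
        with sessionmaker(bind=engine)() as session:
            yield session
            session.rollback()  # discard any uncommitted state the test left behind

Annotating the parameter as `Session` rather than leaving it untyped is what lets mypy/pyright check the `add`, `add_all`, `flush`, `commit`, `refresh`, `get`, and `query` calls the annotated tests make against it.
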
@@ -856,7 +873,7 @@ class TestAppGenerateService: assert result <= 100 def test_get_max_active_requests_with_zero_limits( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with zero limits (infinite). @@ -875,7 +892,9 @@ class TestAppGenerateService: # Verify the result (should return config limit when app limit is 0) assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS - def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_exception_cleanup( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that rate limit exit is called when an exception occurs. """ @@ -904,7 +923,9 @@ class TestAppGenerateService: # Verify rate limit exit was called for cleanup mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() - def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_agent_mode_detection( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with agent mode detection based on app configuration. """ @@ -932,7 +953,7 @@ class TestAppGenerateService: mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() def test_generate_with_different_invoke_from_values( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with different invoke from values. @@ -962,7 +983,7 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_complex_args(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test generation with complex arguments including files and external trace ID. """ diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 745d6c97b0..fc3b20aaae 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account @@ -44,7 +45,7 @@ class TestAppService: "account_feature_service": mock_account_feature_service, } - def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app creation with basic parameters. """ @@ -98,7 +99,9 @@ class TestAppService: assert app.is_public is False assert app.is_universal is False - def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_with_different_modes( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app creation with different app modes. 
""" @@ -141,7 +144,7 @@ class TestAppService: assert app.tenant_id == tenant.id assert app.created_by == account.id - def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app retrieval. """ @@ -189,7 +192,7 @@ class TestAppService: assert retrieved_app.tenant_id == created_app.tenant_id assert retrieved_app.created_by == created_app.created_by - def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful paginated app list retrieval. """ @@ -243,7 +246,9 @@ class TestAppService: assert app.tenant_id == tenant.id assert app.mode == "chat" - def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_with_filters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test paginated app list with various filters. """ @@ -316,7 +321,9 @@ class TestAppService: my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) assert len(my_apps.items) == 1 - def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_with_tag_filters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test paginated app list with tag filters. """ @@ -386,7 +393,7 @@ class TestAppService: # Should return None when no apps match tag filter assert paginated_apps is None - def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app update with all fields. """ @@ -455,7 +462,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. """ @@ -508,7 +515,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_icon_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app icon update. """ @@ -565,7 +572,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_site_status_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful app site status update. 
""" @@ -623,7 +632,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_api_status_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful app API status update. """ @@ -681,7 +692,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_site_status_no_change( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app site status update when status doesn't change. """ @@ -732,7 +745,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app deletion. """ @@ -778,12 +791,13 @@ class TestAppService: mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) # Verify app was deleted from database - from extensions.ext_database import db - deleted_app = db.session.query(App).filter_by(id=app_id).first() + deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first() assert deleted_app is None - def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_with_related_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app deletion with related data cleanup. """ @@ -839,12 +853,11 @@ class TestAppService: mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) # Verify app was deleted from database - from extensions.ext_database import db - deleted_app = db.session.query(App).filter_by(id=app_id).first() + deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first() assert deleted_app is None - def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_meta_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app metadata retrieval. """ @@ -883,7 +896,7 @@ class TestAppService: assert "tool_icons" in app_meta # Note: get_app_meta currently only returns tool_icons - def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_code_by_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app code retrieval by app ID. """ @@ -923,7 +936,7 @@ class TestAppService: assert app_code is not None assert len(app_code) > 0 - def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_id_by_code_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app ID retrieval by app code. 
""" @@ -963,10 +976,9 @@ class TestAppService: site.status = "normal" site.default_language = "en-US" site.customize_token_strategy = "uuid" - from extensions.ext_database import db - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() # Get app ID by code app_id = AppService.get_app_id_by_code(site.code) @@ -974,7 +986,7 @@ class TestAppService: # Verify app ID was retrieved correctly assert app_id == app.id - def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_invalid_mode(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test app creation with invalid mode. """ @@ -1010,7 +1022,7 @@ class TestAppService: app_service.create_app(tenant.id, app_args, account) def test_get_apps_with_special_characters_in_name( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test app retrieval with special characters in name search to verify SQL injection prevention. diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index c3decbf39d..102c1a1eb5 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,10 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from services.dataset_service import DatasetService @@ -25,7 +25,9 @@ class DatasetServiceIntegrationDataFactory: """Factory for creating real database entities used by integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> tuple[Account, Tenant]: """Create an account and tenant, then bind the account as current tenant member.""" account = Account( email=f"{uuid4()}@example.com", @@ -34,8 +36,8 @@ class DatasetServiceIntegrationDataFactory: status="active", ) tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add_all([account, tenant]) - db.session.flush() + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -43,8 +45,8 @@ class DatasetServiceIntegrationDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.flush() + db_session_with_containers.add(join) + db_session_with_containers.flush() # Keep tenant context on the in-memory user without opening a separate session. 
account.role = role @@ -53,6 +55,7 @@ class DatasetServiceIntegrationDataFactory: @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -82,12 +85,14 @@ class DatasetServiceIntegrationDataFactory: collection_binding_id=collection_binding_id, chunk_structure=chunk_structure, ) - db.session.add(dataset) - db.session.flush() + db_session_with_containers.add(dataset) + db_session_with_containers.flush() return dataset @staticmethod - def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document: + def create_document( + db_session_with_containers: Session, dataset: Dataset, created_by: str, name: str = "doc.txt" + ) -> Document: """Create a document row belonging to the given dataset.""" document = Document( tenant_id=dataset.tenant_id, @@ -102,8 +107,8 @@ class DatasetServiceIntegrationDataFactory: indexing_status="completed", doc_form="text_model", ) - db.session.add(document) - db.session.flush() + db_session_with_containers.add(document) + db_session_with_containers.flush() return document @staticmethod @@ -118,10 +123,10 @@ class DatasetServiceIntegrationDataFactory: class TestDatasetServiceCreateDataset: """Integration coverage for DatasetService.create_empty_dataset.""" - def test_create_internal_dataset_basic_success(self, db_session_with_containers): + def test_create_internal_dataset_basic_success(self, db_session_with_containers: Session): """Create a basic internal dataset with minimal configuration.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -133,17 +138,17 @@ class TestDatasetServiceCreateDataset: ) # Assert - created_dataset = db.session.get(Dataset, result.id) + created_dataset = db_session_with_containers.get(Dataset, result.id) assert created_dataset is not None assert created_dataset.provider == "vendor" assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME assert created_dataset.embedding_model_provider is None assert created_dataset.embedding_model is None - def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers): + def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers: Session): """Create an internal dataset with economy indexing and no embedding model.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -155,15 +160,15 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "economy" assert result.embedding_model_provider is None assert result.embedding_model is None - def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers): + def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers: Session): """Create a high-quality dataset and persist embedding model settings.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) 
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act @@ -179,7 +184,7 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "high_quality" assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name @@ -188,11 +193,12 @@ class TestDatasetServiceCreateDataset: model_type=ModelType.TEXT_EMBEDDING, ) - def test_create_dataset_duplicate_name_error(self, db_session_with_containers): + def test_create_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Raise duplicate-name error when the same tenant already has the name.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Duplicate Dataset", @@ -209,10 +215,10 @@ class TestDatasetServiceCreateDataset: account=account, ) - def test_create_external_dataset_success(self, db_session_with_containers): + def test_create_external_dataset_success(self, db_session_with_containers: Session): """Create an external dataset and persist external knowledge binding.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) external_knowledge_api_id = str(uuid4()) external_knowledge_id = "knowledge-123" @@ -231,16 +237,16 @@ class TestDatasetServiceCreateDataset: ) # Assert - binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() + binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() assert result.provider == "external" assert binding is not None assert binding.external_knowledge_id == external_knowledge_id assert binding.external_knowledge_api_id == external_knowledge_api_id - def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers): + def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers: Session): """Create a high-quality dataset with retrieval/reranking settings.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() retrieval_model = RetrievalModel( search_method=RetrievalMethod.SEMANTIC_SEARCH, @@ -271,14 +277,16 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.retrieval_model == retrieval_model.model_dump() mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0") - def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(self, db_session_with_containers): + def test_create_internal_dataset_with_high_quality_indexing_custom_embedding( + self, db_session_with_containers: Session + ): """Create high-quality dataset with explicitly configured embedding model.""" # Arrange - account, tenant = 
DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_provider = "openai" embedding_model_name = "text-embedding-3-small" embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model( @@ -303,7 +311,7 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "high_quality" assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name @@ -315,10 +323,10 @@ class TestDatasetServiceCreateDataset: model=embedding_model_name, ) - def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers): + def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers: Session): """Persist retrieval model settings when creating an internal dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) retrieval_model = RetrievalModel( search_method=RetrievalMethod.SEMANTIC_SEARCH, reranking_enable=False, @@ -338,13 +346,13 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.retrieval_model == retrieval_model.model_dump() - def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers): + def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers: Session): """Persist canonical custom permission when creating an internal dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -357,13 +365,13 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.permission == DatasetPermissionEnum.ALL_TEAM - def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers): + def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers: Session): """Raise error when external API template does not exist.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) external_knowledge_api_id = str(uuid4()) # Act / Assert @@ -381,10 +389,10 @@ class TestDatasetServiceCreateDataset: external_knowledge_id="knowledge-123", ) - def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): + def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): """Raise error when external knowledge id is missing for external dataset creation.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) external_knowledge_api_id = str(uuid4()) # Act / Assert @@ -406,10 +414,10 @@ class TestDatasetServiceCreateDataset: class TestDatasetServiceCreateRagPipelineDataset: 
"""Integration coverage for DatasetService.create_empty_rag_pipeline_dataset.""" - def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers: Session): """Create rag-pipeline dataset and pipeline rows when a name is provided.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name="RAG Pipeline Dataset", @@ -425,8 +433,8 @@ class TestDatasetServiceCreateRagPipelineDataset: ) # Assert - created_dataset = db.session.get(Dataset, result.id) - created_pipeline = db.session.get(Pipeline, result.pipeline_id) + created_dataset = db_session_with_containers.get(Dataset, result.id) + created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) assert created_dataset is not None assert created_dataset.name == entity.name assert created_dataset.runtime_mode == "rag_pipeline" @@ -436,10 +444,10 @@ class TestDatasetServiceCreateRagPipelineDataset: assert created_pipeline.name == entity.name assert created_pipeline.created_by == account.id - def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers: Session): """Create rag-pipeline dataset with generated incremental name when input name is empty.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) generated_name = "Untitled 1" icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( @@ -460,25 +468,26 @@ class TestDatasetServiceCreateRagPipelineDataset: ) # Assert - db.session.refresh(result) - created_pipeline = db.session.get(Pipeline, result.pipeline_id) + db_session_with_containers.refresh(result) + created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) assert result.name == generated_name assert created_pipeline is not None assert created_pipeline.name == generated_name mock_generate_name.assert_called_once() - def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Raise duplicate-name error when rag-pipeline dataset name already exists.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) duplicate_name = "Duplicate RAG Dataset" DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=duplicate_name, indexing_technique=None, ) - db.session.commit() + db_session_with_containers.commit() icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name=duplicate_name, @@ -496,10 +505,10 @@ class TestDatasetServiceCreateRagPipelineDataset: tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity ) - def 
test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers: Session): """Persist canonical custom permission for rag-pipeline dataset creation.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name="Custom Permission RAG Dataset", @@ -515,13 +524,13 @@ class TestDatasetServiceCreateRagPipelineDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.permission == DatasetPermissionEnum.ALL_TEAM - def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers: Session): """Persist icon metadata when creating rag-pipeline dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) icon_info = IconInfo( icon="📚", icon_background="#E8F5E9", @@ -542,23 +551,25 @@ class TestDatasetServiceCreateRagPipelineDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.icon_info == icon_info.model_dump() class TestDatasetServiceUpdateAndDeleteDataset: """Integration coverage for SQL-backed update and delete behavior.""" - def test_update_dataset_duplicate_name_error(self, db_session_with_containers): + def test_update_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Reject update when target name already exists within the same tenant.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) source_dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Source Dataset", ) DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Existing Dataset", @@ -568,17 +579,20 @@ class TestDatasetServiceUpdateAndDeleteDataset: with pytest.raises(ValueError, match="Dataset name already exists"): DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account) - def test_delete_dataset_with_documents_success(self, db_session_with_containers): + def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset that already has documents.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique="high_quality", chunk_structure="text_model", ) - DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id) + DatasetServiceIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, 
created_by=account.id + ) # Act with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: @@ -586,14 +600,15 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) - def test_delete_empty_dataset_success(self, db_session_with_containers): + def test_delete_empty_dataset_success(self, db_session_with_containers: Session): """Delete a dataset that has no documents and no indexing technique.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique=None, @@ -606,14 +621,15 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) - def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session): """Delete dataset when indexing_technique is None but doc_form path still exists.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique=None, @@ -626,17 +642,17 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) class TestDatasetServiceRetrievalConfiguration: """Integration coverage for retrieval configuration persistence.""" - def test_get_dataset_retrieval_configuration(self, db_session_with_containers): + def test_get_dataset_retrieval_configuration(self, db_session_with_containers: Session): """Return retrieval configuration that is persisted in SQL.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) retrieval_model = { "search_method": "semantic_search", "top_k": 5, @@ -644,6 +660,7 @@ class TestDatasetServiceRetrievalConfiguration: "reranking_enable": True, } dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, retrieval_model=retrieval_model, @@ -658,11 +675,12 @@ class TestDatasetServiceRetrievalConfiguration: assert result.retrieval_model["search_method"] == "semantic_search" assert result.retrieval_model["top_k"] == 5 - def test_update_dataset_retrieval_configuration(self, db_session_with_containers): + def test_update_dataset_retrieval_configuration(self, db_session_with_containers: Session): """Persist retrieval configuration updates through DatasetService.update_dataset.""" # 
Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique="high_quality", @@ -684,6 +702,6 @@ class TestDatasetServiceRetrievalConfiguration: result = DatasetService.update_dataset(dataset.id, update_data, account) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index ffdb501474..322b67d373 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -11,8 +11,8 @@ from unittest.mock import call, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from extensions.ext_database import db from models.dataset import Dataset, Document from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -32,6 +32,7 @@ class DocumentBatchUpdateIntegrationDataFactory: @staticmethod def create_dataset( + db_session_with_containers: Session, dataset_id: str | None = None, tenant_id: str | None = None, name: str = "Test Dataset", @@ -47,12 +48,13 @@ class DocumentBatchUpdateIntegrationDataFactory: if dataset_id: dataset.id = dataset_id - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod def create_document( + db_session_with_containers: Session, dataset: Dataset, document_id: str | None = None, name: str = "test_document.pdf", @@ -89,13 +91,14 @@ class DocumentBatchUpdateIntegrationDataFactory: for key, value in kwargs.items(): setattr(document, key, value) - db.session.add(document) + db_session_with_containers.add(document) if commit: - db.session.commit() + db_session_with_containers.commit() return document @staticmethod def create_multiple_documents( + db_session_with_containers: Session, dataset: Dataset, document_ids: list[str], enabled: bool = True, @@ -106,6 +109,7 @@ class DocumentBatchUpdateIntegrationDataFactory: documents: list[Document] = [] for index, doc_id in enumerate(document_ids, start=1): document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, document_id=doc_id, name=f"document_{doc_id}.pdf", @@ -116,7 +120,7 @@ class DocumentBatchUpdateIntegrationDataFactory: commit=False, ) documents.append(document) - db.session.commit() + db_session_with_containers.commit() return documents @staticmethod @@ -173,13 +177,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: assert document.archived_at is None assert document.archived_by is None - def test_batch_update_enable_documents_success(self, db_session_with_containers, patched_dependencies): + def test_batch_update_enable_documents_success(self, db_session_with_containers: Session, patched_dependencies): """Enable disabled documents and trigger indexing side effects.""" # 
Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document_ids = [str(uuid4()), str(uuid4())] disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, dataset=dataset, document_ids=document_ids, enabled=False, @@ -192,7 +197,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Assert for document in disabled_docs: - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_enabled(document, FIXED_TIME) expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] @@ -203,13 +208,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls) def test_batch_update_enable_already_enabled_document_skipped( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Skip enable operation for already-enabled documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() - document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True) + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True + ) # Act DocumentService.batch_update_document_status( @@ -220,18 +227,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is True patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called() - def test_batch_update_disable_documents_success(self, db_session_with_containers, patched_dependencies): + def test_batch_update_disable_documents_success(self, db_session_with_containers: Session, patched_dependencies): """Disable completed documents and trigger remove-index tasks.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document_ids = [str(uuid4()), str(uuid4())] enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, dataset=dataset, document_ids=document_ids, enabled=True, @@ -248,7 +256,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Assert for document in enabled_docs: - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_disabled(document, user.id, FIXED_TIME) expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] @@ -259,13 +267,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls) def test_batch_update_disable_already_disabled_document_skipped( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Skip disable operation for already-disabled documents.""" # Arrange - dataset = 
DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False, indexing_status="completed", @@ -281,17 +290,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(disabled_doc) + db_session_with_containers.refresh(disabled_doc) assert disabled_doc.enabled is False patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called() - def test_batch_update_disable_non_completed_document_error(self, db_session_with_containers, patched_dependencies): + def test_batch_update_disable_non_completed_document_error( + self, db_session_with_containers: Session, patched_dependencies + ): """Raise error when disabling a non-completed document.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, indexing_status="indexing", @@ -307,13 +319,13 @@ class TestDatasetServiceBatchUpdateDocumentStatus: user=user, ) - def test_batch_update_archive_documents_success(self, db_session_with_containers, patched_dependencies): + def test_batch_update_archive_documents_success(self, db_session_with_containers: Session, patched_dependencies): """Archive enabled documents and trigger remove-index task.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=True, archived=False + db_session_with_containers, dataset=dataset, enabled=True, archived=False ) # Act @@ -325,21 +337,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_archived(document, user.id, FIXED_TIME) patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) patched_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_archive_already_archived_document_skipped( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Skip archive operation for already-archived documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=True, archived=True + db_session_with_containers, dataset=dataset, enabled=True, archived=True ) # Act @@ -351,20 +363,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - 
db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.archived is True patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called() def test_batch_update_archive_disabled_document_no_index_removal( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Archive disabled document without index-removal side effects.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=False, archived=False + db_session_with_containers, dataset=dataset, enabled=False, archived=False ) # Act @@ -376,18 +388,18 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_archived(document, user.id, FIXED_TIME) patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called() - def test_batch_update_unarchive_documents_success(self, db_session_with_containers, patched_dependencies): + def test_batch_update_unarchive_documents_success(self, db_session_with_containers: Session, patched_dependencies): """Unarchive enabled documents and trigger add-index task.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=True, archived=True + db_session_with_containers, dataset=dataset, enabled=True, archived=True ) # Act @@ -399,7 +411,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_unarchived(document) assert document.updated_at == FIXED_TIME patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") @@ -407,14 +419,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["add_task"].delay.assert_called_once_with(document.id) def test_batch_update_unarchive_already_unarchived_document_skipped( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Skip unarchive operation for already-unarchived documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=True, archived=False + db_session_with_containers, dataset=dataset, enabled=True, archived=False ) # Act @@ -426,20 +438,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.archived is False patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called() def 
test_batch_update_unarchive_disabled_document_no_index_addition( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Unarchive disabled document without index-add side effects.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=False, archived=True + db_session_with_containers, dataset=dataset, enabled=False, archived=True ) # Act @@ -451,20 +463,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_unarchived(document) assert document.updated_at == FIXED_TIME patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called() def test_batch_update_document_indexing_error_redis_cache_hit( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Raise DocumentIndexingError when redis indicates active indexing.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, name="test_document.pdf", enabled=True, @@ -483,12 +496,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: assert "test_document.pdf" in str(exc_info.value) patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") - def test_batch_update_async_task_error_handling(self, db_session_with_containers, patched_dependencies): + def test_batch_update_async_task_error_handling(self, db_session_with_containers: Session, patched_dependencies): """Persist DB update, then propagate async task error.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() - document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error") # Act / Assert @@ -500,14 +515,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: user=user, ) - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_enabled(document, FIXED_TIME) patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) - def test_batch_update_empty_document_list(self, db_session_with_containers, patched_dependencies): + def test_batch_update_empty_document_list(self, db_session_with_containers: Session, patched_dependencies): """Return early when document_ids is empty.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = 
DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() # Act @@ -520,10 +535,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].get.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called() - def test_batch_update_document_not_found_skipped(self, db_session_with_containers, patched_dependencies): + def test_batch_update_document_not_found_skipped(self, db_session_with_containers: Session, patched_dependencies): """Skip IDs that do not map to existing dataset documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() missing_document_id = str(uuid4()) @@ -540,18 +555,24 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called() - def test_batch_update_mixed_document_states_and_actions(self, db_session_with_containers, patched_dependencies): + def test_batch_update_mixed_document_states_and_actions( + self, db_session_with_containers: Session, patched_dependencies + ): """Process only the applicable document in a mixed-state enable batch.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() - disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) + disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=2, ) archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=True, @@ -568,9 +589,9 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(disabled_doc) - db.session.refresh(enabled_doc) - db.session.refresh(archived_doc) + db_session_with_containers.refresh(disabled_doc) + db_session_with_containers.refresh(enabled_doc) + db_session_with_containers.refresh(archived_doc) self._assert_document_enabled(disabled_doc, FIXED_TIME) assert enabled_doc.enabled is True assert archived_doc.enabled is True @@ -582,13 +603,16 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id) - def test_batch_update_large_document_list_performance(self, db_session_with_containers, patched_dependencies): + def test_batch_update_large_document_list_performance( + self, db_session_with_containers: Session, patched_dependencies + ): """Handle large document lists with consistent updates and side effects.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document_ids = [str(uuid4()) for _ in range(100)] documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, 
dataset=dataset, document_ids=document_ids, enabled=False, @@ -604,7 +628,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Assert for document in documents: - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_enabled(document, FIXED_TIME) assert patched_dependencies["redis_client"].setex.call_count == len(document_ids) @@ -616,17 +640,26 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) def test_batch_update_mixed_document_states_complex_scenario( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Process a complex mixed-state batch and update only eligible records.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() - doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) - doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=2) - doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=3) - doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=4) + doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) + doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=2 + ) + doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=3 + ) + doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=4 + ) doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=True, @@ -645,11 +678,11 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(doc1) - db.session.refresh(doc2) - db.session.refresh(doc3) - db.session.refresh(doc4) - db.session.refresh(doc5) + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + db_session_with_containers.refresh(doc3) + db_session_with_containers.refresh(doc4) + db_session_with_containers.refresh(doc5) self._assert_document_enabled(doc1, FIXED_TIME) assert doc2.enabled is True assert doc3.enabled is True diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index 6effe795e2..e78894fcae 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -10,7 +10,8 @@ Tests the retrieval of document segments with pagination and filtering: from uuid import uuid4 -from extensions.ext_database import db +from sqlalchemy.orm import Session + from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from services.dataset_service import SegmentService @@ -23,6 +24,7 @@ class 
SegmentServiceTestDataFactory: @staticmethod def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER, tenant: Tenant | None = None, ) -> tuple[Account, Tenant]: @@ -33,13 +35,13 @@ class SegmentServiceTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() if tenant is None: tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -47,14 +49,14 @@ class SegmentServiceTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod - def create_dataset(tenant_id: str, created_by: str) -> Dataset: + def create_dataset(db_session_with_containers: Session, tenant_id: str, created_by: str) -> Dataset: """Create a real dataset.""" dataset = Dataset( tenant_id=tenant_id, @@ -67,12 +69,14 @@ class SegmentServiceTestDataFactory: provider="vendor", retrieval_model={"top_k": 2}, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_document(tenant_id: str, dataset_id: str, created_by: str) -> Document: + def create_document( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str + ) -> Document: """Create a real document.""" document = Document( tenant_id=tenant_id, @@ -84,12 +88,13 @@ class SegmentServiceTestDataFactory: created_from="api", created_by=created_by, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document @staticmethod def create_segment( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, document_id: str, @@ -112,8 +117,8 @@ class SegmentServiceTestDataFactory: tokens=tokens, created_by=created_by, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment @@ -130,7 +135,7 @@ class TestSegmentServiceGetSegments: - Combined filters """ - def test_get_segments_basic_pagination(self, db_session_with_containers): + def test_get_segments_basic_pagination(self, db_session_with_containers: Session): """ Test basic pagination functionality. 
@@ -140,11 +145,14 @@ class TestSegmentServiceGetSegments: - Returns segments and total count """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) segment1 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -153,6 +161,7 @@ class TestSegmentServiceGetSegments: content="First segment", ) segment2 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -170,7 +179,7 @@ class TestSegmentServiceGetSegments: assert items[0].id == segment1.id assert items[1].id == segment2.id - def test_get_segments_with_status_filter(self, db_session_with_containers): + def test_get_segments_with_status_filter(self, db_session_with_containers: Session): """ Test filtering by status list. @@ -179,11 +188,14 @@ class TestSegmentServiceGetSegments: - Only segments with matching status are returned """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -192,6 +204,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -200,6 +213,7 @@ class TestSegmentServiceGetSegments: status="indexing", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -219,7 +233,7 @@ class TestSegmentServiceGetSegments: statuses = {item.status for item in items} assert statuses == {"completed", "indexing"} - def test_get_segments_with_empty_status_list(self, db_session_with_containers): + def test_get_segments_with_empty_status_list(self, db_session_with_containers: Session): """ Test with empty status list. 
@@ -228,11 +242,14 @@ class TestSegmentServiceGetSegments: - No status filter is applied to avoid WHERE false condition """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -241,6 +258,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -256,7 +274,7 @@ class TestSegmentServiceGetSegments: assert len(items) == 2 assert total == 2 - def test_get_segments_with_keyword_search(self, db_session_with_containers): + def test_get_segments_with_keyword_search(self, db_session_with_containers: Session): """ Test keyword search functionality. @@ -265,11 +283,14 @@ class TestSegmentServiceGetSegments: - Search pattern includes wildcards (%keyword%) """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -278,6 +299,7 @@ class TestSegmentServiceGetSegments: content="This contains search term in the middle", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -294,7 +316,7 @@ class TestSegmentServiceGetSegments: assert total == 1 assert "search term" in items[0].content - def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers): + def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers: Session): """ Test ordering by position and id. 
@@ -304,12 +326,15 @@ class TestSegmentServiceGetSegments: - This prevents duplicate data across pages when positions are not unique """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create segments with different positions seg_pos2 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -318,6 +343,7 @@ class TestSegmentServiceGetSegments: content="Position 2", ) seg_pos1 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -326,6 +352,7 @@ class TestSegmentServiceGetSegments: content="Position 1", ) seg_pos3 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -344,7 +371,7 @@ class TestSegmentServiceGetSegments: assert items[1].id == seg_pos2.id assert items[2].id == seg_pos3.id - def test_get_segments_empty_results(self, db_session_with_containers): + def test_get_segments_empty_results(self, db_session_with_containers: Session): """ Test when no segments match the criteria. @@ -353,7 +380,7 @@ class TestSegmentServiceGetSegments: - Total count is 0 """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) non_existent_doc_id = str(uuid4()) # Act @@ -363,7 +390,7 @@ class TestSegmentServiceGetSegments: assert items == [] assert total == 0 - def test_get_segments_combined_filters(self, db_session_with_containers): + def test_get_segments_combined_filters(self, db_session_with_containers: Session): """ Test with multiple filters combined. 
@@ -372,12 +399,15 @@ class TestSegmentServiceGetSegments: - Status list and keyword search both applied """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create segments with various statuses and content SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -387,6 +417,7 @@ class TestSegmentServiceGetSegments: content="This is important information", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -396,6 +427,7 @@ class TestSegmentServiceGetSegments: content="This is also important", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -421,7 +453,7 @@ class TestSegmentServiceGetSegments: assert items[0].status == "completed" assert "important" in items[0].content - def test_get_segments_with_none_status_list(self, db_session_with_containers): + def test_get_segments_with_none_status_list(self, db_session_with_containers: Session): """ Test with None status list. @@ -430,11 +462,14 @@ class TestSegmentServiceGetSegments: - No status filter is applied """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -443,6 +478,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -462,7 +498,7 @@ class TestSegmentServiceGetSegments: assert len(items) == 2 assert total == 2 - def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers): + def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers: Session): """ Test that max_per_page is correctly set to 100. 
@@ -471,13 +507,16 @@ class TestSegmentServiceGetSegments: - This prevents excessive page sizes """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create 105 segments to exceed max_per_page of 100 for i in range(105): SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index f605a286ed..8bd994937a 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -13,7 +13,8 @@ This test suite covers: import json from uuid import uuid4 -from extensions.ext_database import db +from sqlalchemy.orm import Session + from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -31,7 +32,9 @@ class DatasetRetrievalTestDataFactory: """Factory class for creating database-backed test data for dataset retrieval integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.NORMAL) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL + ) -> tuple[Account, Tenant]: """Create an account and tenant with the specified role.""" account = Account( email=f"{uuid4()}@example.com", @@ -43,8 +46,8 @@ class DatasetRetrievalTestDataFactory: name=f"tenant-{uuid4()}", status="normal", ) - db.session.add_all([account, tenant]) - db.session.flush() + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -52,14 +55,16 @@ class DatasetRetrievalTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod - def create_account_in_tenant(tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER) -> Account: + def create_account_in_tenant( + db_session_with_containers: Session, tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> Account: """Create an account and add it to an existing tenant.""" account = Account( email=f"{uuid4()}@example.com", @@ -67,8 +72,8 @@ class DatasetRetrievalTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.flush() + db_session_with_containers.add(account) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -76,14 +81,15 @@ class DatasetRetrievalTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + 
db_session_with_containers.commit() account.current_tenant = tenant return account @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -101,12 +107,14 @@ class DatasetRetrievalTestDataFactory: provider="vendor", retrieval_model={"top_k": 2}, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_dataset_permission(dataset_id: str, tenant_id: str, account_id: str) -> DatasetPermission: + def create_dataset_permission( + db_session_with_containers: Session, dataset_id: str, tenant_id: str, account_id: str + ) -> DatasetPermission: """Create a dataset permission.""" permission = DatasetPermission( dataset_id=dataset_id, @@ -114,12 +122,14 @@ class DatasetRetrievalTestDataFactory: account_id=account_id, has_permission=True, ) - db.session.add(permission) - db.session.commit() + db_session_with_containers.add(permission) + db_session_with_containers.commit() return permission @staticmethod - def create_process_rule(dataset_id: str, created_by: str, mode: str, rules: dict) -> DatasetProcessRule: + def create_process_rule( + db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict + ) -> DatasetProcessRule: """Create a dataset process rule.""" process_rule = DatasetProcessRule( dataset_id=dataset_id, @@ -127,12 +137,14 @@ class DatasetRetrievalTestDataFactory: mode=mode, rules=json.dumps(rules), ) - db.session.add(process_rule) - db.session.commit() + db_session_with_containers.add(process_rule) + db_session_with_containers.commit() return process_rule @staticmethod - def create_dataset_query(dataset_id: str, created_by: str, content: str) -> DatasetQuery: + def create_dataset_query( + db_session_with_containers: Session, dataset_id: str, created_by: str, content: str + ) -> DatasetQuery: """Create a dataset query.""" dataset_query = DatasetQuery( dataset_id=dataset_id, @@ -142,23 +154,23 @@ class DatasetRetrievalTestDataFactory: created_by_role="account", created_by=created_by, ) - db.session.add(dataset_query) - db.session.commit() + db_session_with_containers.add(dataset_query) + db_session_with_containers.commit() return dataset_query @staticmethod - def create_app_dataset_join(dataset_id: str) -> AppDatasetJoin: + def create_app_dataset_join(db_session_with_containers: Session, dataset_id: str) -> AppDatasetJoin: """Create an app-dataset join.""" join = AppDatasetJoin( app_id=str(uuid4()), dataset_id=dataset_id, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return join @staticmethod - def create_tag_binding(tenant_id: str, created_by: str, target_id: str) -> Tag: + def create_tag_binding(db_session_with_containers: Session, tenant_id: str, created_by: str, target_id: str) -> Tag: """Create a knowledge tag and bind it to the target dataset.""" tag = Tag( tenant_id=tenant_id, @@ -166,8 +178,8 @@ class DatasetRetrievalTestDataFactory: name=f"tag-{uuid4()}", created_by=created_by, ) - db.session.add(tag) - db.session.flush() + db_session_with_containers.add(tag) + db_session_with_containers.flush() binding = TagBinding( tenant_id=tenant_id, @@ -175,8 +187,8 @@ class DatasetRetrievalTestDataFactory: target_id=target_id, created_by=created_by, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return tag @@ 
-195,15 +207,16 @@ class TestDatasetServiceGetDatasets: # ==================== Basic Retrieval Tests ==================== - def test_get_datasets_basic_pagination(self, db_session_with_containers): + def test_get_datasets_basic_pagination(self, db_session_with_containers: Session): """Test basic pagination without user or filters.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 for i in range(5): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=f"Dataset {i}", @@ -217,21 +230,23 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 5 assert total == 5 - def test_get_datasets_with_search(self, db_session_with_containers): + def test_get_datasets_with_search(self, db_session_with_containers: Session): """Test get_datasets with search keyword.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 search = "test" DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Test Dataset", permission=DatasetPermissionEnum.ALL_TEAM, ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Another Dataset", @@ -245,26 +260,32 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_with_tag_filtering(self, db_session_with_containers): + def test_get_datasets_with_tag_filtering(self, db_session_with_containers: Session): """Test get_datasets with tag_ids filtering.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 dataset_1 = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) dataset_2 = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) - tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_1.id) - tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_2.id) + tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding( + db_session_with_containers, tenant.id, account.id, dataset_1.id + ) + tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding( + db_session_with_containers, tenant.id, account.id, dataset_2.id + ) tag_ids = [tag_1.id, tag_2.id] # Act @@ -274,16 +295,17 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 2 assert total == 2 - def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers): + def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session): """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 
per_page = 20 tag_ids = [] for i in range(3): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=f"dataset-{i}", @@ -300,19 +322,21 @@ class TestDatasetServiceGetDatasets: # ==================== Permission-Based Filtering Tests ==================== - def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers): + def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers: Session): """Test that without user, only ALL_TEAM datasets are shown.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ONLY_ME, @@ -325,15 +349,18 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_owner_with_include_all(self, db_session_with_containers): + def test_get_datasets_owner_with_include_all(self, db_session_with_containers: Session): """Test that OWNER with include_all=True sees all datasets.""" # Arrange - owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) for i, permission in enumerate( [DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM] ): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, name=f"dataset-{i}", @@ -353,12 +380,15 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 3 assert total == 3 - def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers: Session): """Test that normal user sees ONLY_ME datasets they created.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, permission=DatasetPermissionEnum.ONLY_ME, @@ -371,13 +401,18 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers: Session): """Test that normal user sees ALL_TEAM datasets.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) + owner = 
DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER + ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ALL_TEAM, @@ -390,18 +425,25 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers: Session): """Test that normal user sees PARTIAL_TEAM datasets they have permission for.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER + ) dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.PARTIAL_TEAM, ) - DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, user.id) + DatasetRetrievalTestDataFactory.create_dataset_permission( + db_session_with_containers, dataset.id, tenant.id, user.id + ) # Act datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) @@ -410,20 +452,25 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers): + def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers: Session): """Test that DATASET_OPERATOR only sees datasets they have explicit permission for.""" # Arrange operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( - role=TenantAccountRole.DATASET_OPERATOR + db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER ) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ONLY_ME, ) - DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, operator.id) + DatasetRetrievalTestDataFactory.create_dataset_permission( + db_session_with_containers, dataset.id, tenant.id, operator.id + ) # Act datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) @@ -432,14 +479,17 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers): + def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers: Session): """Test that DATASET_OPERATOR without permissions returns empty result.""" # Arrange operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( - role=TenantAccountRole.DATASET_OPERATOR + 
db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER ) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ALL_TEAM, @@ -456,11 +506,13 @@ class TestDatasetServiceGetDatasets: class TestDatasetServiceGetDataset: """Comprehensive integration tests for DatasetService.get_dataset method.""" - def test_get_dataset_success(self, db_session_with_containers): + def test_get_dataset_success(self, db_session_with_containers: Session): """Test successful retrieval of a single dataset.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_dataset(dataset.id) @@ -469,7 +521,7 @@ class TestDatasetServiceGetDataset: assert result is not None assert result.id == dataset.id - def test_get_dataset_not_found(self, db_session_with_containers): + def test_get_dataset_not_found(self, db_session_with_containers: Session): """Test retrieval when dataset doesn't exist.""" # Arrange dataset_id = str(uuid4()) @@ -484,12 +536,15 @@ class TestDatasetServiceGetDataset: class TestDatasetServiceGetDatasetsByIds: """Comprehensive integration tests for DatasetService.get_datasets_by_ids method.""" - def test_get_datasets_by_ids_success(self, db_session_with_containers): + def test_get_datasets_by_ids_success(self, db_session_with_containers: Session): """Test successful bulk retrieval of datasets by IDs.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) datasets = [ - DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) for _ in range(3) + DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) + for _ in range(3) ] dataset_ids = [dataset.id for dataset in datasets] @@ -501,7 +556,7 @@ class TestDatasetServiceGetDatasetsByIds: assert total == 3 assert all(dataset.id in dataset_ids for dataset in result_datasets) - def test_get_datasets_by_ids_empty_list(self, db_session_with_containers): + def test_get_datasets_by_ids_empty_list(self, db_session_with_containers: Session): """Test get_datasets_by_ids with empty list returns empty result.""" # Arrange tenant_id = str(uuid4()) @@ -514,7 +569,7 @@ class TestDatasetServiceGetDatasetsByIds: assert datasets == [] assert total == 0 - def test_get_datasets_by_ids_none_list(self, db_session_with_containers): + def test_get_datasets_by_ids_none_list(self, db_session_with_containers: Session): """Test get_datasets_by_ids with None returns empty result.""" # Arrange tenant_id = str(uuid4()) @@ -530,17 +585,20 @@ class TestDatasetServiceGetDatasetsByIds: class TestDatasetServiceGetProcessRules: """Comprehensive integration tests for DatasetService.get_process_rules 
method.""" - def test_get_process_rules_with_existing_rule(self, db_session_with_containers): + def test_get_process_rules_with_existing_rule(self, db_session_with_containers: Session): """Test retrieval of process rules when rule exists.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) rules_data = { "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}], "segmentation": {"delimiter": "\n", "max_tokens": 500}, } DatasetRetrievalTestDataFactory.create_process_rule( + db_session_with_containers, dataset_id=dataset.id, created_by=account.id, mode="custom", @@ -554,11 +612,13 @@ class TestDatasetServiceGetProcessRules: assert result["mode"] == "custom" assert result["rules"] == rules_data - def test_get_process_rules_without_existing_rule(self, db_session_with_containers): + def test_get_process_rules_without_existing_rule(self, db_session_with_containers: Session): """Test retrieval of process rules when no rule exists (returns defaults).""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_process_rules(dataset.id) @@ -572,16 +632,19 @@ class TestDatasetServiceGetProcessRules: class TestDatasetServiceGetDatasetQueries: """Comprehensive integration tests for DatasetService.get_dataset_queries method.""" - def test_get_dataset_queries_success(self, db_session_with_containers): + def test_get_dataset_queries_success(self, db_session_with_containers: Session): """Test successful retrieval of dataset queries.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) page = 1 per_page = 20 for i in range(3): DatasetRetrievalTestDataFactory.create_dataset_query( + db_session_with_containers, dataset_id=dataset.id, created_by=account.id, content=f"query-{i}", @@ -595,11 +658,13 @@ class TestDatasetServiceGetDatasetQueries: assert total == 3 assert all(query.dataset_id == dataset.id for query in queries) - def test_get_dataset_queries_empty_result(self, db_session_with_containers): + def test_get_dataset_queries_empty_result(self, db_session_with_containers: Session): """Test retrieval when no queries exist.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + 
dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) page = 1 per_page = 20 @@ -614,14 +679,16 @@ class TestDatasetServiceGetDatasetQueries: class TestDatasetServiceGetRelatedApps: """Comprehensive integration tests for DatasetService.get_related_apps method.""" - def test_get_related_apps_success(self, db_session_with_containers): + def test_get_related_apps_success(self, db_session_with_containers: Session): """Test successful retrieval of related apps.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) for _ in range(2): - DatasetRetrievalTestDataFactory.create_app_dataset_join(dataset.id) + DatasetRetrievalTestDataFactory.create_app_dataset_join(db_session_with_containers, dataset.id) # Act result = DatasetService.get_related_apps(dataset.id) @@ -630,11 +697,13 @@ class TestDatasetServiceGetRelatedApps: assert len(result) == 2 assert all(join.dataset_id == dataset.id for join in result) - def test_get_related_apps_empty_result(self, db_session_with_containers): + def test_get_related_apps_empty_result(self, db_session_with_containers: Session): """Test retrieval when no related apps exist.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_related_apps(dataset.id) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 7f9135bb81..ebaa3b4637 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -2,9 +2,9 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from dify_graph.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings from services.dataset_service import DatasetService @@ -15,7 +15,9 @@ class DatasetUpdateTestDataFactory: """Factory class for creating real test data for dataset update integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> tuple[Account, Tenant]: """Create a real account and tenant with the given role.""" account = Account( email=f"{uuid4()}@example.com", @@ -23,12 +25,12 @@ class DatasetUpdateTestDataFactory: 
interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant(name=f"tenant-{account.id}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -36,14 +38,15 @@ class DatasetUpdateTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, provider: str = "vendor", @@ -71,12 +74,13 @@ class DatasetUpdateTestDataFactory: embedding_model=embedding_model, collection_binding_id=collection_binding_id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod def create_external_binding( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str, @@ -93,8 +97,8 @@ class DatasetUpdateTestDataFactory: external_knowledge_id=external_knowledge_id, external_knowledge_api_id=external_knowledge_api_id, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return binding @@ -112,10 +116,11 @@ class TestDatasetServiceUpdateDataset: # ==================== External Dataset Tests ==================== - def test_update_external_dataset_success(self, db_session_with_containers): + def test_update_external_dataset_success(self, db_session_with_containers: Session): """Test successful update of external dataset.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", @@ -124,12 +129,13 @@ class TestDatasetServiceUpdateDataset: retrieval_model="old_model", ) binding = DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, ) binding_id = binding.id - db.session.expunge(binding) + db_session_with_containers.expunge(binding) update_data = { "name": "new_name", @@ -142,8 +148,8 @@ class TestDatasetServiceUpdateDataset: result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) - updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() + db_session_with_containers.refresh(dataset) + updated_binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() assert dataset.name == "new_name" assert dataset.description == "new_description" @@ -153,15 +159,17 @@ class TestDatasetServiceUpdateDataset: assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"] assert result.id == dataset.id - def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): + def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): """Test error when external knowledge id is missing.""" - user, tenant = 
DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", ) DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, @@ -173,17 +181,19 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge id is required" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() - def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers): + def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers: Session): """Test error when external knowledge api id is missing.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", ) DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, @@ -195,12 +205,13 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge api id is required" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() - def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers): + def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers: Session): """Test error when external knowledge binding is not found.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", @@ -216,15 +227,16 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge binding not found" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() # ==================== Internal Dataset Basic Tests ==================== - def test_update_internal_dataset_basic_success(self, db_session_with_containers): + def test_update_internal_dataset_basic_success(self, db_session_with_containers: Session): """Test successful update of internal dataset with basic fields.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -244,7 +256,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.description == "new_description" @@ -254,11 +266,12 @@ class TestDatasetServiceUpdateDataset: assert 
dataset.embedding_model == "text-embedding-ada-002" assert result.id == dataset.id - def test_update_internal_dataset_filter_none_values(self, db_session_with_containers): + def test_update_internal_dataset_filter_none_values(self, db_session_with_containers: Session): """Test that None values are filtered out except for description field.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -278,7 +291,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.description is None @@ -289,11 +302,12 @@ class TestDatasetServiceUpdateDataset: # ==================== Indexing Technique Switch Tests ==================== - def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers): + def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers: Session): """Test updating internal dataset indexing technique to economy.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -312,7 +326,7 @@ class TestDatasetServiceUpdateDataset: result = DatasetService.update_dataset(dataset.id, update_data, user) mock_task.delay.assert_called_once_with(dataset.id, "remove") - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.indexing_technique == "economy" assert dataset.embedding_model is None assert dataset.embedding_model_provider is None @@ -320,10 +334,11 @@ class TestDatasetServiceUpdateDataset: assert dataset.retrieval_model == "new_model" assert result.id == dataset.id - def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers): + def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers: Session): """Test updating internal dataset indexing technique to high_quality.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -366,7 +381,7 @@ class TestDatasetServiceUpdateDataset: mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002") mock_task.delay.assert_called_once_with(dataset.id, "add") - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.indexing_technique == "high_quality" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" @@ -380,9 +395,10 @@ class TestDatasetServiceUpdateDataset: self, db_session_with_containers ): """Test preserving embedding settings when indexing technique remains unchanged.""" - user, tenant = 
DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -399,7 +415,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.indexing_technique == "high_quality" @@ -409,11 +425,12 @@ class TestDatasetServiceUpdateDataset: assert dataset.retrieval_model == "new_model" assert result.id == dataset.id - def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers): + def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers: Session): """Test updating internal dataset with new embedding model.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -465,7 +482,7 @@ class TestDatasetServiceUpdateDataset: regenerate_vectors_only=True, ) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.embedding_model == "text-embedding-3-small" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -474,9 +491,9 @@ class TestDatasetServiceUpdateDataset: # ==================== Error Handling Tests ==================== - def test_update_dataset_not_found_error(self, db_session_with_containers): + def test_update_dataset_not_found_error(self, db_session_with_containers: Session): """Test error when dataset is not found.""" - user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) update_data = {"name": "new_name"} with pytest.raises(ValueError) as context: @@ -484,11 +501,16 @@ class TestDatasetServiceUpdateDataset: assert "Dataset not found" in str(context.value) - def test_update_dataset_permission_error(self, db_session_with_containers): + def test_update_dataset_permission_error(self, db_session_with_containers: Session): """Test error when user doesn't have permission.""" - owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, provider="vendor", @@ -500,10 +522,11 @@ class TestDatasetServiceUpdateDataset: with pytest.raises(NoPermissionError): DatasetService.update_dataset(dataset.id, update_data, outsider) - def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers): + def 
test_update_internal_dataset_embedding_model_error(self, db_session_with_containers: Session): """Test error when embedding model is not available.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 93516a0030..6712fe8454 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -5,6 +5,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker from sqlalchemy import Engine +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from configs import dify_config @@ -19,7 +20,7 @@ class TestFileService: """Integration tests for FileService using testcontainers.""" @pytest.fixture - def engine(self, db_session_with_containers): + def engine(self, db_session_with_containers: Session): bind = db_session_with_containers.get_bind() assert isinstance(bind, Engine) return bind @@ -46,7 +47,7 @@ class TestFileService: "extract_processor": mock_extract_processor, } - def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account for testing. @@ -67,18 +68,16 @@ class TestFileService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -89,15 +88,15 @@ class TestFileService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account - def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test end user for testing. @@ -118,14 +117,14 @@ class TestFileService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account): + def _create_test_upload_file( + self, db_session_with_containers: Session, mock_external_service_dependencies, account + ): """ Helper method to create a test upload file for testing. 
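The hunks in this file swap the app-global `db.session` for the injected `db_session_with_containers` fixture. For orientation, a minimal sketch of how such a fixture is commonly wired to a testcontainers-backed Postgres; the image tag, fixture names, and scoping below are assumptions, not this repository's actual conftest:

# Minimal sketch of a testcontainers-backed Session fixture; image tag,
# fixture names, and scoping are assumptions, not Dify's conftest.
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from testcontainers.postgres import PostgresContainer

@pytest.fixture(scope="session")
def pg_container():
    # One throwaway Postgres for the whole test session.
    with PostgresContainer("postgres:16-alpine") as pg:
        yield pg

@pytest.fixture
def db_session_with_containers(pg_container):
    # A fresh Session per test, bound to the container's engine, so
    # helpers can add/commit without touching the app-global session.
    engine = create_engine(pg_container.get_connection_url())
    session = sessionmaker(bind=engine)()
    try:
        yield session
    finally:
        session.close()
        engine.dispose()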
@@ -155,15 +154,13 @@ class TestFileService: source_url="", ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file # Test upload_file method - def test_upload_file_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies): """ Test successful file upload with valid parameters. """ @@ -196,7 +193,9 @@ class TestFileService: assert upload_file.id is not None - def test_upload_file_with_end_user(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_with_end_user( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with end user instead of account. """ @@ -219,7 +218,7 @@ class TestFileService: assert upload_file.created_by_role == CreatorUserRole.END_USER def test_upload_file_with_datasets_source( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with datasets source parameter. @@ -244,7 +243,7 @@ class TestFileService: assert upload_file.source_url == "https://example.com/source" def test_upload_file_invalid_filename_characters( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with invalid filename characters. @@ -265,7 +264,7 @@ class TestFileService: ) def test_upload_file_filename_too_long( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with filename that exceeds length limit. @@ -295,7 +294,7 @@ class TestFileService: assert len(base_name) <= 200 def test_upload_file_datasets_unsupported_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload for datasets with unsupported file type. @@ -316,7 +315,9 @@ class TestFileService: source="datasets", ) - def test_upload_file_too_large(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_too_large( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with file size exceeding limit. """ @@ -338,7 +339,7 @@ class TestFileService: # Test is_file_size_within_limit method def test_is_file_size_within_limit_image_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for image files within limit. @@ -351,7 +352,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_video_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for video files within limit. 
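Every signature change in these hunks applies the same pattern: accept the SQLAlchemy `Session` as a typed parameter and call it directly, instead of importing the flask-sqlalchemy global inside the function body. Reduced to a sketch (the helper name is illustrative):

# The recurring refactor in these hunks, reduced to its essentials.
from sqlalchemy.orm import Session

# Before (app-global session, hidden dependency):
#     from extensions.ext_database import db
#     db.session.add(obj)
#     db.session.commit()

def persist(session: Session, obj) -> None:
    # After: the test-owned Session is explicit, so each test writes
    # to the containerized database it controls.
    session.add(obj)
    session.commit()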
@@ -364,7 +365,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_audio_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for audio files within limit. @@ -377,7 +378,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_document_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for document files within limit. @@ -390,7 +391,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_image_exceeded( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for image files exceeding limit. @@ -403,7 +404,7 @@ class TestFileService: assert result is False def test_is_file_size_within_limit_unknown_extension( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for unknown file extension. @@ -416,7 +417,7 @@ class TestFileService: assert result is True # Test upload_text method - def test_upload_text_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies): """ Test successful text upload. """ @@ -447,7 +448,9 @@ class TestFileService: # Verify storage was called mock_external_service_dependencies["storage"].save.assert_called_once() - def test_upload_text_name_too_long(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_name_too_long( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test text upload with name that exceeds length limit. """ @@ -472,7 +475,9 @@ class TestFileService: assert upload_file.name == "a" * 200 # Test get_file_preview method - def test_get_file_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_get_file_preview_success( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test successful file preview generation. """ @@ -484,9 +489,8 @@ class TestFileService: # Update file to have document extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() result = FileService(engine).get_file_preview(file_id=upload_file.id) @@ -494,7 +498,7 @@ class TestFileService: mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once() def test_get_file_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with non-existent file. 
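The `is_file_size_within_limit` tests above pin down the observable rule: per-category limits for image, video, audio, and document files, with unknown extensions allowed through. A sketch of that rule; the numeric limits and the extension map are placeholders, not Dify's configuration:

# Sketch of the size-limit rule the tests exercise; limit values and
# the extension-to-category map are placeholders, not Dify's defaults.
LIMITS_MB = {"image": 10, "video": 100, "audio": 50, "document": 15}  # placeholders

def is_file_size_within_limit(extension: str, file_size: int) -> bool:
    category = {
        "jpg": "image", "png": "image",
        "mp4": "video", "mp3": "audio",
        "pdf": "document",
    }.get(extension.lower())
    if category is None:
        # Unknown extensions fall through as allowed, matching
        # test_is_file_size_within_limit_unknown_extension above.
        return True
    return file_size <= LIMITS_MB[category] * 1024 * 1024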
@@ -506,7 +510,7 @@ class TestFileService: FileService(engine).get_file_preview(file_id=non_existent_id) def test_get_file_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with unsupported file type. @@ -519,15 +523,14 @@ class TestFileService: # Update file to have non-document extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): FileService(engine).get_file_preview(file_id=upload_file.id) def test_get_file_preview_text_truncation( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with text that exceeds preview limit. @@ -540,9 +543,8 @@ class TestFileService: # Update file to have document extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Mock long text content long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT @@ -554,7 +556,9 @@ class TestFileService: assert result == "x" * 3000 # Test get_image_preview method - def test_get_image_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_get_image_preview_success( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test successful image preview generation. """ @@ -566,9 +570,8 @@ class TestFileService: # Update file to have image extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() timestamp = "1234567890" nonce = "test_nonce" @@ -586,7 +589,7 @@ class TestFileService: mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once() def test_get_image_preview_invalid_signature( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with invalid signature. @@ -613,7 +616,7 @@ class TestFileService: ) def test_get_image_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with non-existent file. @@ -634,7 +637,7 @@ class TestFileService: ) def test_get_image_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with non-image file type. 
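The truncation test above feeds in 5000 characters and expects 3000 back, so the preview path clips extracted text to a fixed budget. As a sketch (the constant's value is inferred from the assertion, not read from the source):

# Sketch of the preview truncation the test asserts; the constant's
# value is inferred from the `"x" * 3000` expectation, not the source.
PREVIEW_WORDS_LIMIT = 3000  # assumed value

def truncate_preview(text: str) -> str:
    # Clip extracted document text to the preview budget.
    return text[:PREVIEW_WORDS_LIMIT]

assert truncate_preview("x" * 5000) == "x" * 3000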
@@ -647,9 +650,8 @@ class TestFileService: # Update file to have non-image extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() timestamp = "1234567890" nonce = "test_nonce" @@ -665,7 +667,7 @@ class TestFileService: # Test get_file_generator_by_file_id method def test_get_file_generator_by_file_id_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test successful file generator retrieval. @@ -692,7 +694,7 @@ class TestFileService: mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once() def test_get_file_generator_by_file_id_invalid_signature( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file generator retrieval with invalid signature. @@ -719,7 +721,7 @@ class TestFileService: ) def test_get_file_generator_by_file_id_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file generator retrieval with non-existent file. @@ -741,7 +743,7 @@ class TestFileService: # Test get_public_image_preview method def test_get_public_image_preview_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test successful public image preview generation. @@ -754,9 +756,8 @@ class TestFileService: # Update file to have image extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() generator, mime_type = FileService(engine).get_public_image_preview(file_id=upload_file.id) @@ -765,7 +766,7 @@ class TestFileService: mock_external_service_dependencies["storage"].load.assert_called_once() def test_get_public_image_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test public image preview with non-existent file. @@ -777,7 +778,7 @@ class TestFileService: FileService(engine).get_public_image_preview(file_id=non_existent_id) def test_get_public_image_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test public image preview with non-image file type. @@ -790,15 +791,16 @@ class TestFileService: # Update file to have non-image extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): FileService(engine).get_public_image_preview(file_id=upload_file.id) # Test edge cases and boundary conditions - def test_upload_file_empty_content(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_empty_content( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with empty content. 
""" @@ -820,7 +822,7 @@ class TestFileService: assert upload_file.size == 0 def test_upload_file_special_characters_in_name( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with special characters in filename (but valid ones). @@ -843,7 +845,7 @@ class TestFileService: assert upload_file.name == filename def test_upload_file_different_case_extensions( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with different case extensions. @@ -865,7 +867,9 @@ class TestFileService: assert upload_file is not None assert upload_file.extension == "pdf" # Should be converted to lowercase - def test_upload_text_empty_text(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_empty_text( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test text upload with empty text. """ @@ -888,7 +892,9 @@ class TestFileService: assert upload_file is not None assert upload_file.size == 0 - def test_file_size_limits_edge_cases(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_file_size_limits_edge_cases( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file size limits with edge case values. """ @@ -908,7 +914,9 @@ class TestFileService: result = FileService(engine).is_file_size_within_limit(extension=extension, file_size=file_size) assert result is False - def test_upload_file_with_source_url(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_with_source_url( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with source URL that gets overridden by signed URL. """ @@ -946,7 +954,7 @@ class TestFileService: # Test file extension blacklist def test_upload_file_blocked_extension( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with blocked extension. @@ -969,7 +977,7 @@ class TestFileService: ) def test_upload_file_blocked_extension_case_insensitive( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with blocked extension (case insensitive). @@ -992,7 +1000,9 @@ class TestFileService: user=account, ) - def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_not_in_blacklist( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with extension not in blacklist. """ @@ -1016,7 +1026,9 @@ class TestFileService: assert upload_file.name == filename assert upload_file.extension == "pdf" - def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_empty_blacklist( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with empty blacklist (default behavior). 
""" @@ -1041,7 +1053,7 @@ class TestFileService: assert upload_file.extension == "sh" def test_upload_file_multiple_blocked_extensions( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with multiple blocked extensions. @@ -1066,7 +1078,7 @@ class TestFileService: ) def test_upload_file_no_extension_with_blacklist( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with no extension when blacklist is configured. diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index ece6de6cdf..19a684a58a 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.model import MessageFeedback from services.app_service import AppService @@ -69,7 +70,7 @@ class TestMessageService: # "current_user": mock_current_user, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -127,11 +128,10 @@ class TestMessageService: # mock_external_service_dependencies["current_user"].id = account_id # mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id - def _create_test_conversation(self, app, account, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake): """ Helper method to create a test conversation with all required fields. """ - from extensions.ext_database import db from models.model import Conversation conversation = Conversation( @@ -153,17 +153,16 @@ class TestMessageService: from_account_id=account.id, ) - db.session.add(conversation) - db.session.flush() + db_session_with_containers.add(conversation) + db_session_with_containers.flush() return conversation - def _create_test_message(self, app, conversation, account, fake): + def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake): """ Helper method to create a test message with all required fields. """ import json - from extensions.ext_database import db from models.model import Message message = Message( @@ -192,11 +191,13 @@ class TestMessageService: from_account_id=account.id, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message - def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_first_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by first ID. 
""" @@ -204,10 +205,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination by first ID @@ -228,7 +229,9 @@ class TestMessageService: # Verify messages are in ascending order assert result.data[0].created_at <= result.data[1].created_at - def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_first_id_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pagination by first ID when no user is provided. """ @@ -246,7 +249,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_first_id_no_conversation_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by first ID when no conversation ID is provided. @@ -265,7 +268,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_first_id_invalid_first_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by first ID with invalid first_id. @@ -274,8 +277,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test pagination with invalid first_id with pytest.raises(FirstMessageNotExistsError): @@ -287,7 +290,9 @@ class TestMessageService: limit=10, ) - def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by last ID. 
""" @@ -295,10 +300,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination by last ID @@ -319,7 +324,7 @@ class TestMessageService: assert result.data[0].created_at >= result.data[1].created_at def test_pagination_by_last_id_with_include_ids( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with include_ids filter. @@ -328,10 +333,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination with include_ids @@ -347,7 +352,9 @@ class TestMessageService: for message in result.data: assert message.id in include_ids - def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pagination by last ID when no user is provided. """ @@ -363,7 +370,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_last_id_invalid_last_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with invalid last_id. @@ -372,8 +379,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test pagination with invalid last_id with pytest.raises(LastMessageNotExistsError): @@ -385,7 +392,7 @@ class TestMessageService: conversation_id=conversation.id, ) - def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful creation of feedback. 
""" @@ -393,8 +400,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create feedback rating = "like" @@ -413,7 +420,7 @@ class TestMessageService: assert feedback.from_account_id == account.id assert feedback.from_end_user_id is None - def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test creating feedback when no user is provided. """ @@ -421,8 +428,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test creating feedback with no user with pytest.raises(ValueError, match="user cannot be None"): @@ -430,7 +437,9 @@ class TestMessageService: app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100) ) - def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_update_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating existing feedback. """ @@ -438,8 +447,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback initial_rating = "like" @@ -462,7 +471,9 @@ class TestMessageService: assert updated_feedback.rating != initial_rating assert updated_feedback.content != initial_content - def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_delete_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting existing feedback by setting rating to None. 
""" @@ -470,8 +481,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback feedback = MessageService.create_feedback( @@ -482,13 +493,14 @@ class TestMessageService: MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None) # Verify feedback was deleted - from extensions.ext_database import db - deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + deleted_feedback = ( + db_session_with_containers.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + ) assert deleted_feedback is None def test_create_feedback_no_rating_when_not_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating feedback with no rating when feedback doesn't exist. @@ -497,8 +509,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test creating feedback with no rating when no feedback exists with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"): @@ -506,7 +518,9 @@ class TestMessageService: app_model=app, message_id=message.id, user=account, rating=None, content=None ) - def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_messages_feedbacks_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of all message feedbacks. """ @@ -516,8 +530,8 @@ class TestMessageService: # Create multiple conversations and messages with feedbacks feedbacks = [] for i in range(3): - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) feedback = MessageService.create_feedback( app_model=app, @@ -539,7 +553,7 @@ class TestMessageService: assert result[i]["created_at"] >= result[i + 1]["created_at"] def test_get_all_messages_feedbacks_pagination( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of message feedbacks. 
@@ -549,8 +563,8 @@ class TestMessageService: # Create multiple conversations and messages with feedbacks for i in range(5): - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}" @@ -569,7 +583,7 @@ class TestMessageService: page_2_ids = {feedback["id"] for feedback in result_page_2} assert len(page_1_ids.intersection(page_2_ids)) == 0 - def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of message. """ @@ -577,8 +591,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Get message retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id) @@ -590,7 +604,7 @@ class TestMessageService: assert retrieved_message.from_source == "console" assert retrieved_message.from_account_id == account.id - def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_not_exists(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting message that doesn't exist. """ @@ -601,7 +615,7 @@ class TestMessageService: with pytest.raises(MessageNotExistsError): MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4()) - def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_wrong_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting message with wrong user (different account). """ @@ -609,8 +623,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create another account from services.account_service import AccountService, TenantService @@ -628,7 +642,7 @@ class TestMessageService: MessageService.get_message(app_model=app, user=other_account, message_id=message.id) def test_get_suggested_questions_after_answer_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful generation of suggested questions after answer. 
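The feedback checks in the hunks above move from `db.session.query(...)` to the injected session's `query(...)`. The same existence check can also be phrased in 2.0-style `select` form, shown here only for comparison; the tests themselves keep the `query(...)` form:

# 2.0-style equivalent of the deleted-feedback assertion, for
# comparison only; the tests keep the query(...) form.
from sqlalchemy import select
from sqlalchemy.orm import Session

from models.model import MessageFeedback

def feedback_exists(session: Session, feedback_id: str) -> bool:
    # scalar() returns the first column of the first row, or None.
    row = session.scalar(select(MessageFeedback).where(MessageFeedback.id == feedback_id))
    return row is not None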
@@ -637,8 +651,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock the LLMGenerator to return specific questions mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"] @@ -665,7 +679,7 @@ class TestMessageService: mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once() def test_get_suggested_questions_after_answer_no_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when no user is provided. @@ -674,8 +688,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test getting suggested questions with no user from core.app.entities.app_invoke_entities import InvokeFrom @@ -686,7 +700,7 @@ class TestMessageService: ) def test_get_suggested_questions_after_answer_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when feature is disabled. @@ -695,8 +709,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock the feature to be disabled mock_external_service_dependencies[ @@ -712,7 +726,7 @@ class TestMessageService: ) def test_get_suggested_questions_after_answer_no_workflow( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when no workflow exists. 
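The suggested-question tests steer behavior purely through the mocked collaborators in `mock_external_service_dependencies`. The underlying pattern, in isolation; the patch target string below is illustrative, not the fixture's real target:

# The mock-steering pattern these tests rely on, in isolation; the
# patch target string is illustrative, not the fixture's real target.
from unittest.mock import patch

def test_no_published_workflow_sketch():
    with patch("services.message_service.WorkflowService") as workflow_service:
        # Simulate "no published workflow": the service is expected to
        # short-circuit and return [] instead of invoking the LLM.
        workflow_service.return_value.get_published_workflow.return_value = None
        ...  # call MessageService.get_suggested_questions_after_answer here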
@@ -721,8 +735,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock no workflow mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None @@ -738,7 +752,7 @@ class TestMessageService: assert result == [] def test_get_suggested_questions_after_answer_debugger_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions in debugger mode. @@ -747,8 +761,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock questions mock_questions = ["Debug question 1", "Debug question 2"] diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 5b6db64c09..6fe40c0744 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -6,9 +6,9 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import ( @@ -40,25 +40,25 @@ class TestMessagesCleanServiceIntegration: PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before and after each test to ensure isolation.""" yield # Clear all test data in correct order (respecting foreign key constraints) - db.session.query(DatasetRetrieverResource).delete() - db.session.query(AppAnnotationHitHistory).delete() - db.session.query(SavedMessage).delete() - db.session.query(MessageFile).delete() - db.session.query(MessageAgentThought).delete() - db.session.query(MessageChain).delete() - db.session.query(MessageAnnotation).delete() - db.session.query(MessageFeedback).delete() - db.session.query(Message).delete() - db.session.query(Conversation).delete() - db.session.query(App).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + db_session_with_containers.query(DatasetRetrieverResource).delete() + 
db_session_with_containers.query(AppAnnotationHitHistory).delete() + db_session_with_containers.query(SavedMessage).delete() + db_session_with_containers.query(MessageFile).delete() + db_session_with_containers.query(MessageAgentThought).delete() + db_session_with_containers.query(MessageChain).delete() + db_session_with_containers.query(MessageAnnotation).delete() + db_session_with_containers.query(MessageFeedback).delete() + db_session_with_containers.query(Message).delete() + db_session_with_containers.query(Conversation).delete() + db_session_with_containers.query(App).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() @pytest.fixture(autouse=True) def cleanup_redis(self): @@ -100,7 +100,7 @@ class TestMessagesCleanServiceIntegration: with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False): yield - def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX): + def _create_account_and_tenant(self, db_session_with_containers: Session, plan: str = CloudPlan.SANDBOX): """Helper to create account and tenant.""" fake = Faker() @@ -110,28 +110,28 @@ class TestMessagesCleanServiceIntegration: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.flush() + db_session_with_containers.add(account) + db_session_with_containers.flush() tenant = Tenant( name=fake.company(), plan=str(plan), status="normal", ) - db.session.add(tenant) - db.session.flush() + db_session_with_containers.add(tenant) + db_session_with_containers.flush() tenant_account_join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER, ) - db.session.add(tenant_account_join) - db.session.commit() + db_session_with_containers.add(tenant_account_join) + db_session_with_containers.commit() return account, tenant - def _create_app(self, tenant, account): + def _create_app(self, db_session_with_containers: Session, tenant, account): """Helper to create an app.""" fake = Faker() @@ -149,12 +149,12 @@ class TestMessagesCleanServiceIntegration: created_by=account.id, updated_by=account.id, ) - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_conversation(self, app): + def _create_conversation(self, db_session_with_containers: Session, app): """Helper to create a conversation.""" conversation = Conversation( app_id=app.id, @@ -168,12 +168,14 @@ class TestMessagesCleanServiceIntegration: from_source="api", from_end_user_id=str(uuid.uuid4()), ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() return conversation - def _create_message(self, app, conversation, created_at=None, with_relations=True): + def _create_message( + self, db_session_with_containers: Session, app, conversation, created_at=None, with_relations=True + ): """Helper to create a message with optional related records.""" if created_at is None: created_at = datetime.datetime.now() @@ -197,16 +199,16 @@ class TestMessagesCleanServiceIntegration: from_account_id=conversation.from_end_user_id, created_at=created_at, ) - db.session.add(message) - db.session.flush() + db_session_with_containers.add(message) + db_session_with_containers.flush() if with_relations: - self._create_message_relations(message) + 
self._create_message_relations(db_session_with_containers, message) - db.session.commit() + db_session_with_containers.commit() return message - def _create_message_relations(self, message): + def _create_message_relations(self, db_session_with_containers: Session, message): """Helper to create all message-related records.""" # MessageFeedback feedback = MessageFeedback( @@ -217,7 +219,7 @@ class TestMessagesCleanServiceIntegration: from_source="api", from_end_user_id=str(uuid.uuid4()), ) - db.session.add(feedback) + db_session_with_containers.add(feedback) # MessageAnnotation annotation = MessageAnnotation( @@ -228,7 +230,7 @@ class TestMessagesCleanServiceIntegration: content="Test annotation", account_id=message.from_account_id, ) - db.session.add(annotation) + db_session_with_containers.add(annotation) # MessageChain chain = MessageChain( @@ -237,8 +239,8 @@ class TestMessagesCleanServiceIntegration: input=json.dumps({"test": "input"}), output=json.dumps({"test": "output"}), ) - db.session.add(chain) - db.session.flush() + db_session_with_containers.add(chain) + db_session_with_containers.flush() # MessageFile file = MessageFile( @@ -250,7 +252,7 @@ class TestMessagesCleanServiceIntegration: created_by_role="end_user", created_by=str(uuid.uuid4()), ) - db.session.add(file) + db_session_with_containers.add(file) # SavedMessage saved = SavedMessage( @@ -259,9 +261,9 @@ class TestMessagesCleanServiceIntegration: created_by_role="end_user", created_by=str(uuid.uuid4()), ) - db.session.add(saved) + db_session_with_containers.add(saved) - db.session.flush() + db_session_with_containers.flush() # AppAnnotationHitHistory hit = AppAnnotationHitHistory( @@ -275,7 +277,7 @@ class TestMessagesCleanServiceIntegration: annotation_question="Test annotation question", annotation_content="Test annotation content", ) - db.session.add(hit) + db_session_with_containers.add(hit) # DatasetRetrieverResource resource = DatasetRetrieverResource( @@ -296,25 +298,29 @@ class TestMessagesCleanServiceIntegration: retriever_from="dataset", created_by=message.from_account_id, ) - db.session.add(resource) + db_session_with_containers.add(resource) def test_billing_disabled_deletes_all_messages_in_time_range( - self, db_session_with_containers, mock_billing_disabled + self, db_session_with_containers: Session, mock_billing_disabled ): """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan.""" # Arrange - Create tenant with messages (plan doesn't matter for billing disabled) - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create messages: in-range (should be deleted) and out-of-range (should be kept) in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0) out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0) - in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True) + in_range_msg = self._create_message( + db_session_with_containers, app, conv, created_at=in_range_date, with_relations=True + ) in_range_msg_id = in_range_msg.id - out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True) + out_of_range_msg = self._create_message( + 
db_session_with_containers, app, conv, created_at=out_of_range_date, with_relations=True + ) out_of_range_msg_id = out_of_range_msg.id # Act - create_message_clean_policy should return BillingDisabledPolicy @@ -336,17 +342,34 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 # In-range message deleted - assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == in_range_msg_id).count() == 0 # Out-of-range message kept - assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 # Related records of in-range message deleted - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0 - assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0 + assert ( + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id == in_range_msg_id) + .count() + == 0 + ) + assert ( + db_session_with_containers.query(MessageAnnotation) + .where(MessageAnnotation.message_id == in_range_msg_id) + .count() + == 0 + ) # Related records of out-of-range message kept - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1 + assert ( + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id == out_of_range_msg_id) + .count() + == 1 + ) - def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_no_messages_returns_empty_stats( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cleaning when there are no messages to delete (B1).""" # Arrange end_before = datetime.datetime.now() - datetime.timedelta(days=30) @@ -371,36 +394,42 @@ class TestMessagesCleanServiceIntegration: assert stats["filtered_messages"] == 0 assert stats["total_deleted"] == 0 - def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_mixed_sandbox_and_paid_tenants( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cleaning with mixed sandbox and paid tenants (B2).""" # Arrange - Create sandbox tenants with expired messages sandbox_tenants = [] sandbox_message_ids = [] for i in range(2): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) sandbox_tenants.append(tenant) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 3 expired messages per sandbox tenant expired_date = datetime.datetime.now() - datetime.timedelta(days=35) for j in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j) + ) sandbox_message_ids.append(msg.id) # Create paid tenants with expired messages (should NOT be deleted) paid_tenants = [] paid_message_ids = [] for i in range(2): - account, tenant = 
self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL) paid_tenants.append(tenant) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 2 expired messages per paid tenant expired_date = datetime.datetime.now() - datetime.timedelta(days=35) for j in range(2): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j) + ) paid_message_ids.append(msg.id) # Mock billing service - return plan and expiration_date @@ -442,29 +471,39 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 6 # Only sandbox messages should be deleted - assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 # Paid messages should remain - assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 + assert db_session_with_containers.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 # Related records of sandbox messages should be deleted - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0 assert ( - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id.in_(sandbox_message_ids)) + .count() + == 0 + ) + assert ( + db_session_with_containers.query(MessageAnnotation) + .where(MessageAnnotation.message_id.in_(sandbox_message_ids)) + .count() == 0 ) - def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_cursor_pagination_multiple_batches( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cursor pagination works correctly across multiple batches (B3).""" # Arrange - Create sandbox tenant with messages that will span multiple batches - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 10 expired messages with different timestamps base_date = datetime.datetime.now() - datetime.timedelta(days=35) message_ids = [] for i in range(10): msg = self._create_message( + db_session_with_containers, app, conv, created_at=base_date + datetime.timedelta(hours=i), @@ -498,20 +537,22 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 10 # All messages should be deleted - assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 0 - def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_dry_run_does_not_delete(self, 
db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test dry_run mode does not delete messages (B4).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create expired messages expired_date = datetime.datetime.now() - datetime.timedelta(days=35) message_ids = [] for i in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i) + ) message_ids.append(msg.id) with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -540,21 +581,26 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # But NOT deleted # All messages should still exist - assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 + assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 3 # Related records should also still exist - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 + assert ( + db_session_with_containers.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() + == 3 + ) - def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_partial_plan_data_safe_default( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test when billing returns partial data, unknown tenants are preserved (B5).""" # Arrange - Create 3 tenants tenants_data = [] for i in range(3): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date) + msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date) tenants_data.append( { @@ -600,28 +646,30 @@ class TestMessagesCleanServiceIntegration: # Check which messages were deleted assert ( - db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 + db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 ) # Sandbox tenant's message deleted assert ( - db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 ) # Professional tenant's message preserved assert ( - db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 + db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 ) # Unknown tenant's message preserved (safe default) - def 
test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_empty_plan_data_skips_deletion( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test when billing returns empty data, skip deletion entirely (B6).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date) + msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date) msg_id = msg.id - db.session.commit() + db_session_with_containers.commit() # Mock billing service to return empty data (simulating failure/no data scenario) with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -644,17 +692,20 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # Message should still exist (safe default - don't delete if plan is unknown) - assert db.session.query(Message).where(Message.id == msg_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_id).count() == 1 - def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_time_range_boundary_behavior( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test that messages are correctly filtered by [start_from, end_before) time range (B7).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create messages: before range, in range, after range msg_before = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from @@ -663,6 +714,7 @@ class TestMessagesCleanServiceIntegration: msg_before_id = msg_before.id msg_at_start = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) @@ -671,6 +723,7 @@ class TestMessagesCleanServiceIntegration: msg_at_start_id = msg_at_start.id msg_in_range = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range @@ -679,6 +732,7 @@ class TestMessagesCleanServiceIntegration: msg_in_range_id = msg_in_range.id msg_at_end = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) @@ -687,6 +741,7 @@ class TestMessagesCleanServiceIntegration: msg_at_end_id = msg_at_end.id msg_after = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before @@ -694,7 +749,7 @@ class TestMessagesCleanServiceIntegration: ) msg_after_id = 
msg_after.id - db.session.commit() + db_session_with_containers.commit() # Mock billing service with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -722,17 +777,17 @@ class TestMessagesCleanServiceIntegration: # Verify specific messages using stored IDs # Before range, kept - assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_before_id).count() == 1 # At start (inclusive), deleted - assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == msg_at_start_id).count() == 0 # In range, deleted - assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == msg_in_range_id).count() == 0 # At end (exclusive), kept - assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_at_end_id).count() == 1 # After range, kept - assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_after_id).count() == 1 - def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_grace_period_scenarios(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test cleaning with different graceful period scenarios (B8).""" # Arrange - Create 5 different tenants with different plan and expiration scenarios now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) @@ -740,50 +795,60 @@ class TestMessagesCleanServiceIntegration: # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) # Should NOT be deleted - account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app1 = self._create_app(tenant1, account1) - conv1 = self._create_conversation(app1) + account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app1 = self._create_app(db_session_with_containers, tenant1, account1) + conv1 = self._create_conversation(db_session_with_containers, app1) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1 = self._create_message( + db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False + ) msg1_id = msg1.id expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) # Should be deleted - account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app2 = self._create_app(tenant2, account2) - conv2 = self._create_conversation(app2) - msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app2 = self._create_app(db_session_with_containers, tenant2, account2) + conv2 = self._create_conversation(db_session_with_containers, app2) + msg2 = self._create_message( + db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False + ) msg2_id = msg2.id expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace 
period # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) # Should be deleted - account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app3 = self._create_app(tenant3, account3) - conv3 = self._create_conversation(app3) - msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False) + account3, tenant3 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app3 = self._create_app(db_session_with_containers, tenant3, account3) + conv3 = self._create_conversation(db_session_with_containers, app3) + msg3 = self._create_message( + db_session_with_containers, app3, conv3, created_at=expired_date, with_relations=False + ) msg3_id = msg3.id # Scenario 4: Non-sandbox plan (professional) with no expiration (future date) # Should NOT be deleted - account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) - app4 = self._create_app(tenant4, account4) - conv4 = self._create_conversation(app4) - msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) + account4, tenant4 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL) + app4 = self._create_app(db_session_with_containers, tenant4, account4) + conv4 = self._create_conversation(db_session_with_containers, app4) + msg4 = self._create_message( + db_session_with_containers, app4, conv4, created_at=expired_date, with_relations=False + ) msg4_id = msg4.id future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) # Should NOT be deleted (boundary is exclusive: > graceful_period) - account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app5 = self._create_app(tenant5, account5) - conv5 = self._create_conversation(app5) - msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) + account5, tenant5 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app5 = self._create_app(db_session_with_containers, tenant5, account5) + conv5 = self._create_conversation(db_session_with_containers, app5) + msg5 = self._create_message( + db_session_with_containers, app5, conv5, created_at=expired_date, with_relations=False + ) msg5_id = msg5.id expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary - db.session.commit() + db_session_with_containers.commit() # Mock billing service with all scenarios plan_map = { @@ -832,23 +897,31 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 2 # Verify each scenario using saved IDs - assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept - assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted - assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted - assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept - assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept + assert db_session_with_containers.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept + assert ( + db_session_with_containers.query(Message).where(Message.id == msg2_id).count() == 0 + ) # Beyond grace, deleted + assert ( + db_session_with_containers.query(Message).where(Message.id == 
msg3_id).count() == 0 + ) # No subscription, deleted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg4_id).count() == 1 + ) # Professional plan, kept + assert db_session_with_containers.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept - def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_tenant_whitelist(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test that whitelisted tenants' messages are not deleted (B9).""" # Arrange - Create 3 sandbox tenants with expired messages tenants_data = [] for i in range(3): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date, with_relations=False + ) tenants_data.append( { @@ -897,27 +970,33 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 # Verify tenant0's message still exists (whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 # Verify tenant1's message still exists (whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 # Verify tenant2's message was deleted (not whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 - def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_from_days_cleans_old_messages( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test from_days correctly cleans messages older than N days (B11).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create old messages (should be deleted - older than 30 days) old_date = datetime.datetime.now() - datetime.timedelta(days=45) old_msg_ids = [] for i in range(3): msg = self._create_message( - app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False + db_session_with_containers, + app, + conv, + created_at=old_date - datetime.timedelta(hours=i), + with_relations=False, ) old_msg_ids.append(msg.id) @@ -926,11 +1005,15 @@ class TestMessagesCleanServiceIntegration: recent_msg_ids = [] for i in range(2): msg = 
self._create_message( - app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False + db_session_with_containers, + app, + conv, + created_at=recent_date - datetime.timedelta(hours=i), + with_relations=False, ) recent_msg_ids.append(msg.id) - db.session.commit() + db_session_with_containers.commit() with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: mock_billing.return_value = { @@ -955,30 +1038,34 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 3 # Old messages deleted - assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 # Recent messages kept - assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 + assert db_session_with_containers.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 def test_whitelist_precedence_over_grace_period( - self, db_session_with_containers, mock_billing_enabled, mock_whitelist + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist ): """Test that whitelist takes precedence over grace period logic.""" # Arrange - Create 2 sandbox tenants now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) # Tenant1: whitelisted, expired beyond grace period - account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app1 = self._create_app(tenant1, account1) - conv1 = self._create_conversation(app1) + account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app1 = self._create_app(db_session_with_containers, tenant1, account1) + conv1 = self._create_conversation(db_session_with_containers, app1) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1 = self._create_message( + db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False + ) expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace # Tenant2: not whitelisted, within grace period - account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app2 = self._create_app(tenant2, account2) - conv2 = self._create_conversation(app2) - msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app2 = self._create_app(db_session_with_containers, tenant2, account2) + conv2 = self._create_conversation(db_session_with_containers, app2) + msg2 = self._create_message( + db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False + ) expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace # Mock billing service @@ -1019,22 +1106,26 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # Verify both messages still exist - assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted - assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period + assert db_session_with_containers.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg2.id).count() == 1 + ) # Within grace period def 
test_empty_whitelist_deletes_eligible_messages( - self, db_session_with_containers, mock_billing_enabled, mock_whitelist + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist ): """Test that empty whitelist behaves as no whitelist (all eligible messages deleted).""" # Arrange - Create sandbox tenant with expired messages - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) msg_ids = [] for i in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i) + ) msg_ids.append(msg.id) # Mock billing service @@ -1068,4 +1159,4 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 3 # Verify all messages were deleted - assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0 diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e04725627b..694dc1c1b9 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -32,7 +33,7 @@ class TestMetadataService: "document_service": mock_document_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. 
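
Every hunk in this patch makes the same move: the process-global `db.session` from `extensions.ext_database` is replaced by a SQLAlchemy `Session` injected through the `db_session_with_containers` fixture. As a rough sketch of the fixture shape these signatures now assume (the names, connection URL, and cleanup policy here are guesses, not the repository's actual conftest):

    from collections.abc import Iterator

    import pytest
    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session, sessionmaker

    @pytest.fixture
    def db_session_with_containers() -> Iterator[Session]:
        # Hypothetical stand-in: the real fixture would point at the
        # testcontainers-managed database, not a fixed URL.
        engine = create_engine("postgresql+psycopg2://test:test@localhost:5432/test")
        session = sessionmaker(bind=engine)()
        try:
            yield session  # tests call .add()/.commit()/.query()/.refresh() on this
        finally:
            session.rollback()
            session.close()

Typing the parameter as `Session`, as the rewritten signatures do, then lets type checkers verify every `.add()`/`.commit()` call site.
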
@@ -53,18 +54,16 @@ class TestMetadataService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -73,15 +72,17 @@ class TestMetadataService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant): + def _create_test_dataset( + self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant + ): """ Helper method to create a test dataset for testing. @@ -105,14 +106,14 @@ class TestMetadataService: built_in_field_enabled=False, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account): + def _create_test_document( + self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account + ): """ Helper method to create a test document for testing. @@ -141,14 +142,12 @@ class TestMetadataService: doc_language="en", ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful metadata creation with valid parameters. """ @@ -178,13 +177,14 @@ class TestMetadataService: assert result.created_by == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.created_at is not None - def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_name_too_long( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata creation fails when name exceeds 255 characters. """ @@ -207,7 +207,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): MetadataService.create_metadata(dataset.id, metadata_args) - def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_name_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata creation fails when name already exists in the same dataset. 
""" @@ -235,7 +237,7 @@ class TestMetadataService: MetadataService.create_metadata(dataset.id, second_metadata_args) def test_create_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata creation fails when name conflicts with built-in field names. @@ -260,7 +262,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.create_metadata(dataset.id, metadata_args) - def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful metadata name update with valid parameters. """ @@ -291,12 +295,13 @@ class TestMetadataService: assert result.updated_at is not None # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.name == new_name - def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_too_long( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when new name exceeds 255 characters. """ @@ -323,7 +328,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) - def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when new name already exists in the same dataset. """ @@ -351,7 +358,7 @@ class TestMetadataService: MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") def test_update_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata name update fails when new name conflicts with built-in field names. @@ -378,7 +385,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) - def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when metadata ID does not exist. """ @@ -406,7 +415,7 @@ class TestMetadataService: # Assert: Verify the method returns None when metadata is not found assert result is None - def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful metadata deletion with valid parameters. 
""" @@ -434,12 +443,11 @@ class TestMetadataService: assert result.id == metadata.id # Verify metadata was deleted from database - from extensions.ext_database import db - deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None - def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test metadata deletion fails when metadata ID does not exist. """ @@ -467,7 +475,7 @@ class TestMetadataService: assert result is None def test_delete_metadata_with_document_bindings( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata deletion successfully removes document metadata bindings. @@ -500,15 +508,13 @@ class TestMetadataService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Set document metadata document.doc_metadata = {"test_metadata": "test_value"} - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Act: Execute the method under test result = MetadataService.delete_metadata(dataset.id, metadata.id) @@ -517,13 +523,13 @@ class TestMetadataService: assert result is not None # Verify metadata was deleted from database - deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None # Note: The service attempts to update document metadata but may not succeed # due to mock configuration. The main functionality (metadata deletion) is verified. - def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of built-in metadata fields. """ @@ -548,7 +554,9 @@ class TestMetadataService: assert "string" in field_types assert "time" in field_types - def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_built_in_field_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful enabling of built-in fields for a dataset. """ @@ -579,16 +587,15 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the expected outcomes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (enabling built-in fields) is verified def test_enable_built_in_field_already_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test enabling built-in fields when they are already enabled. 
@@ -607,10 +614,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] @@ -619,11 +625,11 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the method returns early without changes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True def test_enable_built_in_field_with_no_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test enabling built-in fields for a dataset with no documents. @@ -647,12 +653,13 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the expected outcomes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True - def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_built_in_field_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful disabling of built-in fields for a dataset. """ @@ -673,10 +680,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Set document metadata with built-in fields document.doc_metadata = { @@ -686,8 +692,8 @@ class TestMetadataService: BuiltInField.last_update_date: 1234567890.0, BuiltInField.source: "test_source", } - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ @@ -698,14 +704,14 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the expected outcomes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (disabling built-in fields) is verified def test_disable_built_in_field_already_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test disabling built-in fields when they are already disabled. 
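
Across these enable/disable tests the Assert phase always re-reads the row through the injected session rather than trusting the in-memory object. The idiom, condensed into one illustrative helper (not project code):

    from sqlalchemy.orm import Session

    def assert_built_in_flag(session: Session, dataset, expected: bool) -> None:
        # refresh() issues a SELECT for this row, so the assertion checks
        # what the service call actually committed, not a stale attribute.
        session.refresh(dataset)
        assert dataset.built_in_field_enabled is expected
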
@@ -732,13 +738,12 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the method returns early without changes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False def test_disable_built_in_field_with_no_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test disabling built-in fields for a dataset with no documents. @@ -757,10 +762,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id to return empty list mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] @@ -769,10 +773,12 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the expected outcomes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False - def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_documents_metadata_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful update of documents metadata. """ @@ -815,24 +821,25 @@ class TestMetadataService: MetadataService.update_documents_metadata(dataset, operation_data) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify document metadata was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.doc_metadata is not None assert "test_metadata" in document.doc_metadata assert document.doc_metadata["test_metadata"] == "test_value" # Verify metadata binding was created binding = ( - db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first() + db_session_with_containers.query(DatasetMetadataBinding) + .filter_by(metadata_id=metadata.id, document_id=document.id) + .first() ) assert binding is not None assert binding.tenant_id == tenant.id assert binding.dataset_id == dataset.id def test_update_documents_metadata_with_built_in_fields_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of documents metadata when built-in fields are enabled. 
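
Two lookup spellings survive the rewrite and both run against the injected session: keyword-style `filter_by()` for the binding check above, and expression-style `where()` elsewhere. For a simple equality they are interchangeable (sketch; the import path is an assumption):

    from sqlalchemy.orm import Session

    from models.dataset import DatasetMetadata  # assumed import path

    def fetch_metadata(session: Session, metadata_id: str):
        # filter_by() takes keyword equality tests on mapped columns;
        # where()/filter() take full SQL expressions. Same rows either way.
        return session.query(DatasetMetadata).filter_by(id=metadata_id).first()
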
@@ -850,10 +857,9 @@ class TestMetadataService: # Enable built-in fields dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id @@ -884,7 +890,7 @@ class TestMetadataService: # Assert: Verify the expected outcomes # Verify document metadata was updated with both custom and built-in fields - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.doc_metadata is not None assert "test_metadata" in document.doc_metadata assert document.doc_metadata["test_metadata"] == "test_value" @@ -893,7 +899,7 @@ class TestMetadataService: # The main functionality (custom metadata update) is verified def test_update_documents_metadata_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of documents metadata when document is not found. @@ -936,7 +942,7 @@ class TestMetadataService: MetadataService.update_documents_metadata(dataset, operation_data) def test_knowledge_base_metadata_lock_check_dataset_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check for dataset operations. @@ -959,7 +965,7 @@ class TestMetadataService: assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" def test_knowledge_base_metadata_lock_check_document_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check for document operations. @@ -982,7 +988,7 @@ class TestMetadataService: assert call_args[0][0] == f"document_metadata_lock_{document_id}" def test_knowledge_base_metadata_lock_check_lock_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check when lock already exists. @@ -999,7 +1005,7 @@ class TestMetadataService: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) def test_knowledge_base_metadata_lock_check_document_lock_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check when document lock already exists. @@ -1013,7 +1019,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."): MetadataService.knowledge_base_metadata_lock_check(None, document_id) - def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_metadatas_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of dataset metadata information. 
""" @@ -1046,10 +1054,8 @@ class TestMetadataService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Act: Execute the method under test result = MetadataService.get_dataset_metadatas(dataset) @@ -1071,7 +1077,7 @@ class TestMetadataService: assert result["built_in_field_enabled"] is False def test_get_dataset_metadatas_with_built_in_fields_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of dataset metadata when built-in fields are enabled. @@ -1086,10 +1092,9 @@ class TestMetadataService: # Enable built-in fields dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id @@ -1114,7 +1119,9 @@ class TestMetadataService: # Verify built-in field status assert result["built_in_field_enabled"] is True - def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_metadatas_no_metadata( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of dataset metadata when no metadata exists. """ diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 7c8472e819..989df42499 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from models.account import TenantAccountJoin, TenantAccountRole from models.model import Account, Tenant @@ -67,7 +68,7 @@ class TestModelLoadBalancingService: "credential_schema": mock_credential_schema, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. 
@@ -88,18 +89,16 @@ class TestModelLoadBalancingService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -108,8 +107,8 @@ class TestModelLoadBalancingService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -117,7 +116,7 @@ class TestModelLoadBalancingService: return account, tenant def _create_test_provider_and_setting( - self, db_session_with_containers, tenant_id, mock_external_service_dependencies + self, db_session_with_containers: Session, tenant_id, mock_external_service_dependencies ): """ Helper method to create a test provider and provider model setting. @@ -132,8 +131,6 @@ class TestModelLoadBalancingService: """ fake = Faker() - from extensions.ext_database import db - # Create provider provider = Provider( tenant_id=tenant_id, @@ -141,8 +138,8 @@ class TestModelLoadBalancingService: provider_type="custom", is_valid=True, ) - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Create provider model setting provider_model_setting = ProviderModelSetting( @@ -153,12 +150,14 @@ class TestModelLoadBalancingService: enabled=True, load_balancing_enabled=False, ) - db.session.add(provider_model_setting) - db.session.commit() + db_session_with_containers.add(provider_model_setting) + db_session_with_containers.commit() return provider, provider_model_setting - def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_model_load_balancing_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful model load balancing enablement. @@ -193,14 +192,15 @@ class TestModelLoadBalancingService: assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value # Verify database state - from extensions.ext_database import db - db.session.refresh(provider) - db.session.refresh(provider_model_setting) + db_session_with_containers.refresh(provider) + db_session_with_containers.refresh(provider_model_setting) assert provider.id is not None assert provider_model_setting.id is not None - def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_model_load_balancing_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful model load balancing disablement. 
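
The error-path tests below keep an explicit `rollback()` after the expected exception. The point, in miniature (the failing service call is a hypothetical placeholder):

    import pytest
    from sqlalchemy.orm import Session

    def call_service_with_bad_provider(session: Session) -> None:
        # Hypothetical placeholder for the service call under test.
        raise ValueError("Provider nonexistent_provider does not exist.")

    def test_error_path_sketch(db_session_with_containers: Session) -> None:
        with pytest.raises(ValueError):
            call_service_with_bad_provider(db_session_with_containers)
        # Clear the aborted transaction so the shared session stays usable
        # for later assertions and for fixture teardown.
        db_session_with_containers.rollback()
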
@@ -235,15 +235,14 @@ class TestModelLoadBalancingService: assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value # Verify database state - from extensions.ext_database import db - db.session.refresh(provider) - db.session.refresh(provider_model_setting) + db_session_with_containers.refresh(provider) + db_session_with_containers.refresh(provider_model_setting) assert provider.id is not None assert provider_model_setting.id is not None def test_enable_model_load_balancing_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when provider does not exist. @@ -275,11 +274,12 @@ class TestModelLoadBalancingService: assert "Provider nonexistent_provider does not exist." in str(exc_info.value) # Verify no database state changes occurred - from extensions.ext_database import db - db.session.rollback() + db_session_with_containers.rollback() - def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_load_balancing_configs_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of load balancing configurations. @@ -298,7 +298,6 @@ class TestModelLoadBalancingService: ) # Create load balancing config - from extensions.ext_database import db load_balancing_config = LoadBalancingModelConfig( tenant_id=tenant.id, @@ -309,11 +308,11 @@ class TestModelLoadBalancingService: encrypted_config='{"api_key": "test_key"}', enabled=True, ) - db.session.add(load_balancing_config) - db.session.commit() + db_session_with_containers.add(load_balancing_config) + db_session_with_containers.commit() # Verify the config was created - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None # Setup mocks for get_load_balancing_configs method @@ -358,11 +357,11 @@ class TestModelLoadBalancingService: assert configs[0]["ttl"] == 0 # Verify database state - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None def test_get_load_balancing_configs_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when provider does not exist in get_load_balancing_configs. @@ -394,12 +393,11 @@ class TestModelLoadBalancingService: assert "Provider nonexistent_provider does not exist." in str(exc_info.value) # Verify no database state changes occurred - from extensions.ext_database import db - db.session.rollback() + db_session_with_containers.rollback() def test_get_load_balancing_configs_with_inherit_config( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test load balancing configs retrieval with inherit configuration. 
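
The inherit-config check below already uses the 2.0-style `session.scalars(select(...))` form instead of the legacy `Query` API used by most other assertions in this patch. Both run against the injected session (sketch; the model's import path is an assumption):

    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from models.provider import LoadBalancingModelConfig  # assumed import path

    def inherit_configs(session: Session):
        # 2.0-style form, equivalent to the legacy
        # session.query(LoadBalancingModelConfig).where(...).all() spelling.
        stmt = select(LoadBalancingModelConfig).where(
            LoadBalancingModelConfig.name == "__inherit__"
        )
        return session.scalars(stmt).all()
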
@@ -419,7 +417,6 @@ class TestModelLoadBalancingService: ) # Create load balancing config - from extensions.ext_database import db load_balancing_config = LoadBalancingModelConfig( tenant_id=tenant.id, @@ -430,8 +427,8 @@ class TestModelLoadBalancingService: encrypted_config='{"api_key": "test_key"}', enabled=True, ) - db.session.add(load_balancing_config) - db.session.commit() + db_session_with_containers.add(load_balancing_config) + db_session_with_containers.commit() # Setup mocks for inherit config scenario mock_provider_config = mock_external_service_dependencies["provider_config"] @@ -467,11 +464,11 @@ class TestModelLoadBalancingService: assert configs[1]["name"] == "config1" # Verify database state - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None # Verify inherit config was created in database - inherit_configs = db.session.scalars( + inherit_configs = db_session_with_containers.scalars( select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") ).all() assert len(inherit_configs) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 7a4662055c..6afc5aa43c 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType @@ -29,7 +30,7 @@ class TestModelProviderService: "model_provider_factory": mock_model_provider_factory, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. 
@@ -50,18 +51,16 @@ class TestModelProviderService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -70,8 +69,8 @@ class TestModelProviderService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -80,7 +79,7 @@ class TestModelProviderService: def _create_test_provider( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str = "openai", @@ -109,16 +108,14 @@ class TestModelProviderService: quota_used=0, ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() return provider def _create_test_provider_model( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str, @@ -149,16 +146,14 @@ class TestModelProviderService: is_valid=True, ) - from extensions.ext_database import db - - db.session.add(provider_model) - db.session.commit() + db_session_with_containers.add(provider_model) + db_session_with_containers.commit() return provider_model def _create_test_provider_model_setting( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str, @@ -190,14 +185,12 @@ class TestModelProviderService: load_balancing_enabled=False, ) - from extensions.ext_database import db - - db.session.add(provider_model_setting) - db.session.commit() + db_session_with_containers.add(provider_model_setting) + db_session_with_containers.commit() return provider_model_setting - def test_get_provider_list_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_provider_list_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful provider list retrieval. @@ -275,7 +268,7 @@ class TestModelProviderService: mock_provider_config.is_custom_configuration_available.assert_called_once() def test_get_provider_list_with_model_type_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test provider list retrieval with model type filtering. @@ -374,7 +367,9 @@ class TestModelProviderService: assert result[0].provider == "cohere" assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types - def test_get_models_by_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_models_by_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of models by provider. 
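
Most of the remaining hunks in this file only widen signatures with the `Session` annotation; the payoff is static checking of the session usage inside them. A minimal sketch of what the annotation buys:

    from sqlalchemy.orm import Session

    def _persist(db_session_with_containers: Session, obj) -> None:
        db_session_with_containers.add(obj)      # checked: Session.add exists
        db_session_with_containers.commit()      # checked: Session.commit exists
        # A typo such as db_session_with_containers.comit() would now be
        # flagged by mypy/pyright instead of failing at runtime.
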
@@ -485,7 +480,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_configurations.get_models.assert_called_once_with(provider="openai") - def test_get_provider_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_provider_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of provider credentials. @@ -543,7 +540,7 @@ class TestModelProviderService: mock_method.assert_called_once_with(tenant.id, "openai") def test_provider_credentials_validate_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful validation of provider credentials. @@ -585,7 +582,7 @@ class TestModelProviderService: mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) def test_provider_credentials_validate_invalid_provider( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test validation failure for non-existent provider. @@ -617,7 +614,7 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) def test_get_default_model_of_model_type_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of default model for a specific model type. @@ -673,7 +670,7 @@ class TestModelProviderService: mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM) def test_update_default_model_of_model_type_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of default model for a specific model type. @@ -706,7 +703,9 @@ class TestModelProviderService: tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4" ) - def test_get_model_provider_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_provider_icon_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model provider icon. @@ -743,7 +742,9 @@ class TestModelProviderService: # Verify mock interactions mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US") - def test_switch_preferred_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_preferred_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful switching of preferred provider type. @@ -779,7 +780,7 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configuration.switch_preferred_provider_type.assert_called_once() - def test_enable_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_model_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful enabling of a model. 
@@ -815,7 +816,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4") - def test_get_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model credentials. @@ -872,7 +875,9 @@ class TestModelProviderService: # Verify the method was called with correct parameters mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) - def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_model_credentials_validate_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful validation of model credentials. @@ -914,7 +919,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials ) - def test_save_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful saving of model credentials. @@ -955,7 +962,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" ) - def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful removal of model credentials. @@ -993,7 +1002,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" ) - def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_models_by_model_type_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of models by model type. @@ -1070,7 +1081,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) - def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_parameter_rules_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model parameter rules. @@ -1137,7 +1150,7 @@ class TestModelProviderService: ) def test_get_model_parameter_rules_no_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parameter rules retrieval when no credentials are available. @@ -1181,7 +1194,7 @@ class TestModelProviderService: ) def test_get_model_parameter_rules_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parameter rules retrieval when provider does not exist. 
diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 9e6b9837ae..e3ec1d1df3 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.model import EndUser, Message from models.web import SavedMessage @@ -38,7 +39,7 @@ class TestSavedMessageService: "message_service": mock_message_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -85,7 +86,7 @@ class TestSavedMessageService: return app, account - def _create_test_end_user(self, db_session_with_containers, app): + def _create_test_end_user(self, db_session_with_containers: Session, app): """ Helper method to create a test end user for testing. @@ -108,14 +109,12 @@ class TestSavedMessageService: is_anonymous=False, ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_message(self, db_session_with_containers, app, user): + def _create_test_message(self, db_session_with_containers: Session, app, user): """ Helper method to create a test message for testing. @@ -143,10 +142,8 @@ class TestSavedMessageService: mode="chat", ) - from extensions.ext_database import db - - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message( @@ -168,13 +165,13 @@ class TestSavedMessageService: status="success", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message def test_pagination_by_last_id_success_with_account_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination by last ID with account user. 
@@ -207,10 +204,8 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add_all([saved_message1, saved_message2]) - db.session.commit() + db_session_with_containers.add_all([saved_message1, saved_message2]) + db_session_with_containers.commit() # Mock MessageService.pagination_by_last_id return value from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -240,15 +235,15 @@ class TestSavedMessageService: assert actual_include_ids == expected_include_ids # Verify database state - db.session.refresh(saved_message1) - db.session.refresh(saved_message2) + db_session_with_containers.refresh(saved_message1) + db_session_with_containers.refresh(saved_message2) assert saved_message1.id is not None assert saved_message2.id is not None assert saved_message1.created_by_role == "account" assert saved_message2.created_by_role == "account" def test_pagination_by_last_id_success_with_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination by last ID with end user. @@ -282,10 +277,8 @@ class TestSavedMessageService: created_by=end_user.id, ) - from extensions.ext_database import db - - db.session.add_all([saved_message1, saved_message2]) - db.session.commit() + db_session_with_containers.add_all([saved_message1, saved_message2]) + db_session_with_containers.commit() # Mock MessageService.pagination_by_last_id return value from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -317,14 +310,16 @@ class TestSavedMessageService: assert actual_include_ids == expected_include_ids # Verify database state - db.session.refresh(saved_message1) - db.session.refresh(saved_message2) + db_session_with_containers.refresh(saved_message1) + db_session_with_containers.refresh(saved_message2) assert saved_message1.id is not None assert saved_message2.id is not None assert saved_message1.created_by_role == "end_user" assert saved_message2.created_by_role == "end_user" - def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_success_with_new_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful save of a new message. @@ -347,10 +342,9 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was created in database - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -373,10 +367,12 @@ class TestSavedMessageService: ) # Verify database state - db.session.refresh(saved_message) + db_session_with_containers.refresh(saved_message) assert saved_message.id is not None - def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_error_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when no user is provided. 
@@ -396,12 +392,11 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) # Verify no database operations were performed - from extensions.ext_database import db - saved_messages = db.session.query(SavedMessage).all() + saved_messages = db_session_with_containers.query(SavedMessage).all() assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -422,10 +417,9 @@ class TestSavedMessageService: assert result is None # Verify no saved message was created - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -435,7 +429,9 @@ class TestSavedMessageService: assert saved_message is None - def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_success_existing_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of an existing saved message. @@ -457,14 +453,12 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(saved_message) - db.session.commit() + db_session_with_containers.add(saved_message) + db_session_with_containers.commit() # Verify saved message exists assert ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -481,7 +475,7 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was deleted from database deleted_saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -494,11 +488,13 @@ class TestSavedMessageService: assert deleted_saved_message is None # Verify database state - db.session.commit() + db_session_with_containers.commit() # The message should still exist, only the saved_message should be deleted - assert db.session.query(Message).where(Message.id == message.id).first() is not None + assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_error_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when no user is provided. @@ -522,7 +518,7 @@ class TestSavedMessageService: # Instead, we verify that the error was properly raised pass - def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. 
@@ -543,10 +539,9 @@ class TestSavedMessageService: assert result is None # Verify no saved message was created - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -556,7 +551,9 @@ class TestSavedMessageService: assert saved_message is None - def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_success_existing_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of an existing saved message. @@ -578,14 +575,12 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(saved_message) - db.session.commit() + db_session_with_containers.add(saved_message) + db_session_with_containers.commit() # Verify saved message exists assert ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -602,7 +597,7 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was deleted from database deleted_saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -615,6 +610,6 @@ class TestSavedMessageService: assert deleted_saved_message is None # Verify database state - db.session.commit() + db_session_with_containers.commit() # The message should still exist, only the saved_message should be deleted - assert db.session.query(Message).where(Message.id == message.id).first() is not None + assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index e8c7f17e0b..597ba6b75b 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -4,6 +4,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -29,7 +30,7 @@ class TestTagService: "current_user": mock_current_user, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. 
@@ -50,18 +51,16 @@ class TestTagService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -70,8 +69,8 @@ class TestTagService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -82,7 +81,7 @@ class TestTagService: return account, tenant - def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + def _create_test_dataset(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id): """ Helper method to create a test dataset for testing. @@ -107,14 +106,12 @@ class TestTagService: created_by=mock_external_service_dependencies["current_user"].id, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id): """ Helper method to create a test app for testing. @@ -141,15 +138,13 @@ class TestTagService: created_by=mock_external_service_dependencies["current_user"].id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app def _create_test_tags( - self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, tag_type, count=3 ): """ Helper method to create test tags for testing. @@ -176,16 +171,14 @@ class TestTagService: ) tags.append(tag) - from extensions.ext_database import db - for tag in tags: - db.session.add(tag) - db.session.commit() + db_session_with_containers.add(tag) + db_session_with_containers.commit() return tags def _create_test_tag_bindings( - self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id + self, db_session_with_containers: Session, mock_external_service_dependencies, tags, target_id, tenant_id ): """ Helper method to create test tag bindings for testing. @@ -211,15 +204,13 @@ class TestTagService: ) tag_bindings.append(tag_binding) - from extensions.ext_database import db - for tag_binding in tag_bindings: - db.session.add(tag_binding) - db.session.commit() + db_session_with_containers.add(tag_binding) + db_session_with_containers.commit() return tag_bindings - def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tags with binding count. 
@@ -270,7 +261,9 @@ class TestTagService: # The ordering is handled by the database, we just verify the results are returned assert len(result) == 3 - def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_with_keyword_filter( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval with keyword filtering. @@ -291,12 +284,11 @@ class TestTagService: ) # Update tag names to make them searchable - from extensions.ext_database import db tags[0].name = "python_development" tags[1].name = "machine_learning" tags[2].name = "web_development" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test with keyword filter result = TagService.get_tags("app", tenant.id, keyword="development") @@ -314,7 +306,7 @@ class TestTagService: assert len(result_no_match) == 0 def test_get_tags_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test tag retrieval with special characters in keyword to verify SQL injection prevention. @@ -330,8 +322,6 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - from extensions.ext_database import db - # Create tags with special characters in names tag_with_percent = Tag( name="50% discount", @@ -340,7 +330,7 @@ class TestTagService: created_by=account.id, ) tag_with_percent.id = str(uuid.uuid4()) - db.session.add(tag_with_percent) + db_session_with_containers.add(tag_with_percent) tag_with_underscore = Tag( name="test_data_tag", @@ -349,7 +339,7 @@ class TestTagService: created_by=account.id, ) tag_with_underscore.id = str(uuid.uuid4()) - db.session.add(tag_with_underscore) + db_session_with_containers.add(tag_with_underscore) tag_with_backslash = Tag( name="path\\to\\tag", @@ -358,7 +348,7 @@ class TestTagService: created_by=account.id, ) tag_with_backslash.id = str(uuid.uuid4()) - db.session.add(tag_with_backslash) + db_session_with_containers.add(tag_with_backslash) # Create tag that should NOT match tag_no_match = Tag( @@ -368,9 +358,9 @@ class TestTagService: created_by=account.id, ) tag_no_match.id = str(uuid.uuid4()) - db.session.add(tag_no_match) + db_session_with_containers.add(tag_no_match) - db.session.commit() + db_session_with_containers.commit() # Act & Assert: Test 1 - Search with % character result = TagService.get_tags("app", tenant.id, keyword="50%") @@ -392,7 +382,7 @@ class TestTagService: assert len(result) == 1 assert all("50%" in item.name for item in result) - def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_empty_result(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. @@ -414,7 +404,9 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_target_ids_by_tag_ids_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of target IDs by tag IDs. 
@@ -469,7 +461,7 @@ class TestTagService: assert second_dataset_count == 1 def test_get_target_ids_by_tag_ids_empty_tag_ids( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target ID retrieval with empty tag IDs list. @@ -493,7 +485,7 @@ class TestTagService: assert isinstance(result, list) def test_get_target_ids_by_tag_ids_no_matching_tags( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target ID retrieval when no tags match the criteria. @@ -521,7 +513,7 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tags by tag name. @@ -542,11 +534,10 @@ class TestTagService: ) # Update tag names to make them searchable - from extensions.ext_database import db tags[0].name = "python_tag" tags[1].name = "ml_tag" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag") @@ -558,7 +549,9 @@ class TestTagService: assert result[0].type == "app" assert result[0].tenant_id == tenant.id - def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_no_matches( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by name when no matches exist. @@ -580,7 +573,9 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_empty_parameters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by name with empty parameters. @@ -605,7 +600,9 @@ class TestTagService: assert result_empty_name is not None assert len(result_empty_name) == 0 - def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_by_target_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of tags by target ID. @@ -644,7 +641,9 @@ class TestTagService: assert tag.tenant_id == tenant.id assert tag.id in [t.id for t in tags] - def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_by_target_id_no_bindings( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by target ID when no tags are bound. @@ -669,7 +668,7 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag creation. 
@@ -698,17 +697,18 @@ class TestTagService: assert result.id is not None # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None # Verify tag was actually saved to database - saved_tag = db.session.query(Tag).where(Tag.id == result.id).first() + saved_tag = db_session_with_containers.query(Tag).where(Tag.id == result.id).first() assert saved_tag is not None assert saved_tag.name == "test_tag_name" - def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tags_duplicate_name_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag creation with duplicate name. @@ -731,7 +731,7 @@ class TestTagService: TagService.save_tags(tag_args) assert "Tag name already exists" in str(exc_info.value) - def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag update. @@ -763,17 +763,16 @@ class TestTagService: assert result.id == tag.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.name == "updated_name" # Verify tag was actually updated in database - updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first() + updated_tag = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert updated_tag is not None assert updated_tag.name == "updated_name" - def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag update for non-existent tag. @@ -799,7 +798,9 @@ class TestTagService: TagService.update_tags(update_args, non_existent_tag_id) assert "Tag not found" in str(exc_info.value) - def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_duplicate_name_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag update with duplicate name. @@ -828,7 +829,9 @@ class TestTagService: TagService.update_tags(update_args, tag2.id) assert "Tag name already exists" in str(exc_info.value) - def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_binding_count_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of tag binding count. @@ -863,7 +866,7 @@ class TestTagService: assert result_tag_without_bindings == 0 def test_get_tag_binding_count_non_existent_tag( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test binding count retrieval for non-existent tag. @@ -889,7 +892,7 @@ class TestTagService: # Assert: Verify the expected outcomes assert result == 0 - def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag deletion. 
@@ -916,12 +919,11 @@ class TestTagService: ) # Verify tag and binding exist before deletion - from extensions.ext_database import db - tag_before = db.session.query(Tag).where(Tag.id == tag.id).first() + tag_before = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert tag_before is not None - binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + binding_before = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first() assert binding_before is not None # Act: Execute the method under test @@ -929,14 +931,14 @@ class TestTagService: # Assert: Verify the expected outcomes # Verify tag was deleted - tag_after = db.session.query(Tag).where(Tag.id == tag.id).first() + tag_after = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert tag_after is None # Verify tag binding was deleted - binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + binding_after = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first() assert binding_after is None - def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag deletion for non-existent tag. @@ -960,7 +962,7 @@ class TestTagService: TagService.delete_tag(non_existent_tag_id) assert "Tag not found" in str(exc_info.value) - def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag binding creation. @@ -988,12 +990,11 @@ class TestTagService: TagService.save_tag_binding(binding_args) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify tag bindings were created for tag in tags: binding = ( - db.session.query(TagBinding) + db_session_with_containers.query(TagBinding) .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) .first() ) @@ -1001,7 +1002,9 @@ class TestTagService: assert binding.tenant_id == tenant.id assert binding.created_by == account.id - def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_duplicate_handling( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag binding creation with duplicate bindings. @@ -1032,15 +1035,16 @@ class TestTagService: TagService.save_tag_binding(binding_args) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify only one binding exists - bindings = db.session.scalars( + bindings = db_session_with_containers.scalars( select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) ).all() assert len(bindings) == 1 - def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_invalid_target_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag binding creation with invalid target type. 
@@ -1071,7 +1075,7 @@ class TestTagService: TagService.save_tag_binding(binding_args) assert "Invalid binding type" in str(exc_info.value) - def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag binding deletion. @@ -1098,10 +1102,11 @@ class TestTagService: ) # Verify binding exists before deletion - from extensions.ext_database import db binding_before = ( - db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + db_session_with_containers.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() ) assert binding_before is not None @@ -1112,12 +1117,14 @@ class TestTagService: # Assert: Verify the expected outcomes # Verify tag binding was deleted binding_after = ( - db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + db_session_with_containers.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() ) assert binding_after is None def test_delete_tag_binding_non_existent_binding( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tag binding deletion for non-existent binding. @@ -1145,15 +1152,14 @@ class TestTagService: # Assert: Verify the expected outcomes # No error should be raised, and database state should remain unchanged - from extensions.ext_database import db - bindings = db.session.scalars( + bindings = db_session_with_containers.scalars( select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) ).all() assert len(bindings) == 0 def test_check_target_exists_knowledge_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful target existence check for knowledge type. @@ -1179,7 +1185,7 @@ class TestTagService: # No exception should be raised for existing dataset def test_check_target_exists_knowledge_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target existence check for non-existent knowledge dataset. @@ -1204,7 +1210,9 @@ class TestTagService: TagService.check_target_exists("knowledge", non_existent_dataset_id) assert "Dataset not found" in str(exc_info.value) - def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_app_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful target existence check for app type. @@ -1228,7 +1236,9 @@ class TestTagService: # Assert: Verify the expected outcomes # No exception should be raised for existing app - def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test target existence check for non-existent app. 
@@ -1252,7 +1262,9 @@ class TestTagService: TagService.check_target_exists("app", non_existent_app_id) assert "App not found" in str(exc_info.value) - def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_invalid_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test target existence check for invalid type. diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 5315960d73..912aa3dd2f 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -2,11 +2,11 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity -from extensions.ext_database import db from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService @@ -47,7 +47,7 @@ class TestTriggerProviderService: "account_feature_service": mock_account_feature_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -84,7 +84,7 @@ class TestTriggerProviderService: def _create_test_subscription( self, - db_session_with_containers, + db_session_with_containers: Session, tenant_id, user_id, provider_id, @@ -135,14 +135,14 @@ class TestTriggerProviderService: expires_at=-1, ) - db.session.add(subscription) - db.session.commit() - db.session.refresh(subscription) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + db_session_with_containers.refresh(subscription) return subscription def test_rebuild_trigger_subscription_success_with_merged_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful rebuild with credential merging (HIDDEN_VALUE handling). @@ -217,7 +217,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "new-secret-value" # New value # Verify database state was updated - db.session.refresh(subscription) + db_session_with_containers.refresh(subscription) assert subscription.name == "updated_name" assert subscription.parameters == {"param1": "updated_value"} @@ -244,7 +244,7 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_with_all_new_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when all credentials are new (no HIDDEN_VALUE). 
@@ -304,7 +304,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "completely-new-secret" def test_rebuild_trigger_subscription_with_all_hidden_values( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing). @@ -363,7 +363,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == original_credentials["api_secret"] def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original. @@ -422,7 +422,7 @@ class TestTriggerProviderService: assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE def test_rebuild_trigger_subscription_rollback_on_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that transaction is rolled back on error. @@ -470,12 +470,12 @@ class TestTriggerProviderService: ) # Verify subscription state was not changed (rolled back) - db.session.refresh(subscription) + db_session_with_containers.refresh(subscription) assert subscription.name == original_name assert subscription.parameters == original_parameters def test_rebuild_trigger_subscription_subscription_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error when subscription is not found. @@ -501,7 +501,7 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_name_uniqueness_check( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that name uniqueness is checked when updating name. diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index bbbf48ede9..f1e8c152f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models import Account @@ -45,7 +46,7 @@ class TestWebConversationService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -90,7 +91,7 @@ class TestWebConversationService: return app, account - def _create_test_end_user(self, db_session_with_containers, app): + def _create_test_end_user(self, db_session_with_containers: Session, app): """ Helper method to create a test end user for testing. 
@@ -111,14 +112,12 @@ class TestWebConversationService: tenant_id=app.tenant_id, ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_conversation(self, db_session_with_containers, app, user, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, user, fake): """ Helper method to create a test conversation for testing. @@ -152,14 +151,14 @@ class TestWebConversationService: is_deleted=False, ) - from extensions.ext_database import db - - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() return conversation - def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by last ID with basic parameters. """ @@ -194,7 +193,7 @@ class TestWebConversationService: assert result.data[1].updated_at >= result.data[2].updated_at def test_pagination_by_last_id_with_pinned_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with pinned conversation filter. @@ -222,11 +221,9 @@ class TestWebConversationService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(pinned_conversation1) - db.session.add(pinned_conversation2) - db.session.commit() + db_session_with_containers.add(pinned_conversation1) + db_session_with_containers.add(pinned_conversation2) + db_session_with_containers.commit() # Test pagination with pinned filter result = WebConversationService.pagination_by_last_id( @@ -251,7 +248,7 @@ class TestWebConversationService: assert set(returned_ids) == set(expected_ids) def test_pagination_by_last_id_with_unpinned_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with unpinned conversation filter. @@ -273,10 +270,8 @@ class TestWebConversationService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(pinned_conversation) - db.session.commit() + db_session_with_containers.add(pinned_conversation) + db_session_with_containers.commit() # Test pagination with unpinned filter result = WebConversationService.pagination_by_last_id( @@ -303,7 +298,7 @@ class TestWebConversationService: expected_unpinned_ids = [conv.id for conv in conversations[1:]] assert set(returned_ids) == set(expected_unpinned_ids) - def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful pinning of a conversation. 
""" @@ -317,10 +312,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify the conversation was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -336,7 +330,9 @@ class TestWebConversationService: assert pinned_conversation.created_by_role == "account" assert pinned_conversation.created_by == account.id - def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_already_pinned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pinning a conversation that is already pinned (should not create duplicate). """ @@ -353,9 +349,8 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify only one pinned conversation record exists - from extensions.ext_database import db - pinned_conversations = db.session.scalars( + pinned_conversations = db_session_with_containers.scalars( select(PinnedConversation).where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -366,7 +361,9 @@ class TestWebConversationService: assert len(pinned_conversations) == 1 - def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_with_end_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pinning a conversation with an end user. """ @@ -383,10 +380,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, end_user) # Verify the conversation was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -402,7 +398,7 @@ class TestWebConversationService: assert pinned_conversation.created_by_role == "end_user" assert pinned_conversation.created_by == end_user.id - def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful unpinning of a conversation. 
""" @@ -416,10 +412,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify it was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -436,7 +431,7 @@ class TestWebConversationService: # Verify it was unpinned pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -448,7 +443,9 @@ class TestWebConversationService: assert pinned_conversation is None - def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_not_pinned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test unpinning a conversation that is not pinned (should not cause error). """ @@ -462,10 +459,9 @@ class TestWebConversationService: WebConversationService.unpin(app, conversation.id, account) # Verify no pinned conversation record exists - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -478,7 +474,7 @@ class TestWebConversationService: assert pinned_conversation is None def test_pagination_by_last_id_user_required_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that pagination_by_last_id raises ValueError when user is None. @@ -499,7 +495,7 @@ class TestWebConversationService: sort_by="-updated_at", ) - def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_user_none(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test that pin method returns early when user is None. """ @@ -513,10 +509,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, None) # Verify no pinned conversation was created - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -526,7 +521,9 @@ class TestWebConversationService: assert pinned_conversation is None - def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_user_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that unpin method returns early when user is None. 
""" @@ -540,10 +537,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify it was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -560,7 +556,7 @@ class TestWebConversationService: # Verify the conversation is still pinned pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index d1c566e477..9a1595d266 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized from libs.password import hash_password @@ -45,7 +46,7 @@ class TestWebAppAuthService: "enterprise_service": mock_enterprise_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -68,18 +69,16 @@ class TestWebAppAuthService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -88,15 +87,17 @@ class TestWebAppAuthService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_with_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test account with password for testing. 
@@ -131,18 +132,16 @@ class TestWebAppAuthService:
         account.password = base64.b64encode(password_hash).decode()
         account.password_salt = base64.b64encode(salt).decode()
 
-        from extensions.ext_database import db
-
-        db.session.add(account)
-        db.session.commit()
+        db_session_with_containers.add(account)
+        db_session_with_containers.commit()
 
         # Create tenant for the account
         tenant = Tenant(
             name=fake.company(),
             status="normal",
         )
-        db.session.add(tenant)
-        db.session.commit()
+        db_session_with_containers.add(tenant)
+        db_session_with_containers.commit()
 
         # Create tenant-account join
         join = TenantAccountJoin(
@@ -151,15 +150,17 @@ class TestWebAppAuthService:
             role=TenantAccountRole.OWNER,
             current=True,
         )
-        db.session.add(join)
-        db.session.commit()
+        db_session_with_containers.add(join)
+        db_session_with_containers.commit()
 
         # Set current tenant for account
         account.current_tenant = tenant
 
         return account, tenant, password
 
-    def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant):
+    def _create_test_app_and_site(
+        self, db_session_with_containers: Session, mock_external_service_dependencies, tenant
+    ):
         """
         Helper method to create a test app and site for testing.
 
@@ -188,10 +189,8 @@ class TestWebAppAuthService:
             enable_api=True,
         )
 
-        from extensions.ext_database import db
-
-        db.session.add(app)
-        db.session.commit()
+        db_session_with_containers.add(app)
+        db_session_with_containers.commit()
 
         # Create site
         site = Site(
@@ -203,12 +202,12 @@ class TestWebAppAuthService:
             status="normal",
             customize_token_strategy="not_allow",
         )
-        db.session.add(site)
-        db.session.commit()
+        db_session_with_containers.add(site)
+        db_session_with_containers.commit()
 
         return app, site
 
-    def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_authenticate_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test successful authentication with valid email and password.
 
@@ -233,14 +232,15 @@ class TestWebAppAuthService:
         assert result.status == AccountStatus.ACTIVE
 
         # Verify database state
-        from extensions.ext_database import db
 
-        db.session.refresh(result)
+        db_session_with_containers.refresh(result)
         assert result.id is not None
         assert result.password is not None
         assert result.password_salt is not None
 
-    def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_authenticate_account_not_found(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test authentication with non-existent email.
 
@@ -262,7 +262,7 @@
         with pytest.raises(AccountNotFoundError):
             WebAppAuthService.authenticate(non_existent_email, "any_password")
 
-    def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_authenticate_account_banned(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test authentication with banned account.
@@ -292,10 +292,8 @@ class TestWebAppAuthService:
         account.password = base64.b64encode(password_hash).decode()
         account.password_salt = base64.b64encode(salt).decode()
 
-        from extensions.ext_database import db
-
-        db.session.add(account)
-        db.session.commit()
+        db_session_with_containers.add(account)
+        db_session_with_containers.commit()
 
         # Act & Assert: Verify proper error handling
         with pytest.raises(AccountLoginError) as exc_info:
@@ -303,7 +301,9 @@
 
         assert "Account is banned." in str(exc_info.value)
 
-    def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_authenticate_invalid_password(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test authentication with invalid password.
 
@@ -323,7 +323,7 @@
         assert "Invalid email or password." in str(exc_info.value)
 
     def test_authenticate_account_without_password(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test authentication for account without password.
 
@@ -345,10 +345,8 @@
             status="active",
         )
 
-        from extensions.ext_database import db
-
-        db.session.add(account)
-        db.session.commit()
+        db_session_with_containers.add(account)
+        db_session_with_containers.commit()
 
         # Act & Assert: Verify proper error handling
         with pytest.raises(AccountPasswordError) as exc_info:
@@ -356,7 +354,7 @@
 
         assert "Invalid email or password." in str(exc_info.value)
 
-    def test_login_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test successful login and JWT token generation.
 
@@ -388,7 +386,9 @@
         assert call_args["auth_type"] == "internal"
         assert "exp" in call_args
 
-    def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_user_through_email_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test successful user retrieval through email.
 
@@ -413,12 +413,13 @@
         assert result.status == AccountStatus.ACTIVE
 
         # Verify database state
-        from extensions.ext_database import db
 
-        db.session.refresh(result)
+        db_session_with_containers.refresh(result)
         assert result.id is not None
 
-    def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_user_through_email_not_found(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test user retrieval with non-existent email.
 
@@ -435,7 +436,9 @@
         # Assert: Verify proper handling
         assert result is None
 
-    def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_user_through_email_banned(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test user retrieval with banned account.
@@ -456,10 +459,8 @@ class TestWebAppAuthService:
             status=AccountStatus.BANNED,
         )
 
-        from extensions.ext_database import db
-
-        db.session.add(account)
-        db.session.commit()
+        db_session_with_containers.add(account)
+        db_session_with_containers.commit()
 
         # Act & Assert: Verify proper error handling
         with pytest.raises(Unauthorized) as exc_info:
@@ -468,7 +469,7 @@
         assert "Account is banned." in str(exc_info.value)
 
     def test_send_email_code_login_email_with_account(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
         """
         Test sending email code login email with account.
 
@@ -509,7 +510,7 @@
         assert "code" in mail_call_args[1]
 
     def test_send_email_code_login_email_with_email_only(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test sending email code login email with email only.
 
@@ -549,7 +550,7 @@
         assert "code" in mail_call_args[1]
 
     def test_send_email_code_login_email_no_email_provided(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test sending email code login email without providing email.
 
@@ -566,7 +567,9 @@
 
         assert "Email must be provided." in str(exc_info.value)
 
-    def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_email_code_login_data_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test successful retrieval of email code login data.
 
@@ -593,7 +596,9 @@
             "mock_token", "email_code_login"
         )
 
-    def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_email_code_login_data_no_data(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test email code login data retrieval when no data exists.
 
@@ -617,7 +622,7 @@
         )
 
     def test_revoke_email_code_login_token_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test successful revocation of email code login token.
 
@@ -636,7 +641,7 @@
             "mock_token", "email_code_login"
         )
 
-    def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_create_end_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test successful end user creation.
 
@@ -668,14 +673,15 @@
         assert result.external_user_id == "enterpriseuser"
 
         # Verify database state
-        from extensions.ext_database import db
 
-        db.session.refresh(result)
+        db_session_with_containers.refresh(result)
         assert result.id is not None
         assert result.created_at is not None
         assert result.updated_at is not None
 
-    def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_create_end_user_site_not_found(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test end user creation with non-existent site code.
@@ -693,7 +699,9 @@
 
         assert "Site not found." in str(exc_info.value)
 
-    def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_create_end_user_app_not_found(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test end user creation when app is not found.
 
@@ -708,10 +716,8 @@
             status="normal",
         )
 
-        from extensions.ext_database import db
-
-        db.session.add(tenant)
-        db.session.commit()
+        db_session_with_containers.add(tenant)
+        db_session_with_containers.commit()
 
         site = Site(
             app_id="00000000-0000-0000-0000-000000000000",
@@ -722,8 +728,8 @@
             status="normal",
             customize_token_strategy="not_allow",
         )
-        db.session.add(site)
-        db.session.commit()
+        db_session_with_containers.add(site)
+        db_session_with_containers.commit()
 
         # Act & Assert: Verify proper error handling
         with pytest.raises(NotFound) as exc_info:
@@ -732,7 +738,7 @@
 
         assert "App not found." in str(exc_info.value)
 
     def test_is_app_require_permission_check_with_access_mode_private(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test permission check requirement for private access mode.
 
@@ -751,7 +757,7 @@
         assert result is True
 
     def test_is_app_require_permission_check_with_access_mode_public(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test permission check requirement for public access mode.
 
@@ -770,7 +776,7 @@
         assert result is False
 
     def test_is_app_require_permission_check_with_app_code(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test permission check requirement using app code.
 
@@ -796,7 +802,7 @@
         ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id")
 
     def test_is_app_require_permission_check_no_parameters(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test permission check requirement with no parameters.
 
@@ -814,7 +820,7 @@
         assert "Either app_code or app_id must be provided." in str(exc_info.value)
 
     def test_get_app_auth_type_with_access_mode_public(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test app authentication type for public access mode.
 
@@ -833,7 +839,7 @@
         assert result == WebAppAuthType.PUBLIC
 
     def test_get_app_auth_type_with_access_mode_private(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test app authentication type for private access mode.
@@ -851,7 +857,9 @@
         # Assert: Verify correct result
         assert result == WebAppAuthType.INTERNAL
 
-    def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_app_auth_type_with_app_code(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test app authentication type using app code.
 
@@ -878,7 +886,9 @@
             "enterprise_service"
         ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id")
 
-    def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_app_auth_type_no_parameters(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test app authentication type with no parameters.
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
index ce25eec6f0..a3440b6b67 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
@@ -5,6 +5,7 @@ from unittest.mock import patch
 
 import pytest
 from faker import Faker
+from sqlalchemy.orm import Session
 
 from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
 from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
@@ -48,7 +49,7 @@ class TestWorkflowAppService:
             "account_feature_service": mock_account_feature_service,
         }
 
-    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
+    def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Helper method to create a test app and account for testing.
 
@@ -96,7 +97,7 @@
 
         return app, account
 
-    def _create_test_tenant_and_account(self, db_session_with_containers, mock_external_service_dependencies):
+    def _create_test_tenant_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Helper method to create a test tenant and account for testing.
 
@@ -126,7 +127,7 @@
 
         return tenant, account
 
-    def _create_test_app(self, db_session_with_containers, tenant, account):
+    def _create_test_app(self, db_session_with_containers: Session, tenant, account):
         """
         Helper method to create a test app for testing.
 
@@ -160,7 +161,7 @@
 
         return app
 
-    def _create_test_workflow_data(self, db_session_with_containers, app, account):
+    def _create_test_workflow_data(self, db_session_with_containers: Session, app, account):
         """
         Helper method to create test workflow data for testing.
@@ -174,8 +175,6 @@ class TestWorkflowAppService:
         """
         fake = Faker()
 
-        from extensions.ext_database import db
-
         # Create workflow
         workflow = Workflow(
             id=str(uuid.uuid4()),
@@ -188,8 +187,8 @@ class TestWorkflowAppService:
             created_by=account.id,
             updated_by=account.id,
         )
-        db.session.add(workflow)
-        db.session.commit()
+        db_session_with_containers.add(workflow)
+        db_session_with_containers.commit()
 
         # Create workflow run
         workflow_run = WorkflowRun(
@@ -212,8 +211,8 @@ class TestWorkflowAppService:
             created_at=datetime.now(UTC),
             finished_at=datetime.now(UTC),
         )
-        db.session.add(workflow_run)
-        db.session.commit()
+        db_session_with_containers.add(workflow_run)
+        db_session_with_containers.commit()
 
         # Create workflow app log
         workflow_app_log = WorkflowAppLog(
@@ -227,13 +226,13 @@ class TestWorkflowAppService:
         )
         workflow_app_log.id = str(uuid.uuid4())
         workflow_app_log.created_at = datetime.now(UTC)
-        db.session.add(workflow_app_log)
-        db.session.commit()
+        db_session_with_containers.add(workflow_app_log)
+        db_session_with_containers.commit()
 
         return workflow, workflow_run, workflow_app_log
 
     def test_get_paginate_workflow_app_logs_basic_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test successful pagination of workflow app logs with basic parameters.
 
@@ -268,13 +267,12 @@ class TestWorkflowAppService:
         assert log_entry.workflow_run_id == workflow_run.id
 
         # Verify database state
-        from extensions.ext_database import db
 
-        db.session.refresh(workflow_app_log)
+        db_session_with_containers.refresh(workflow_app_log)
         assert workflow_app_log.id is not None
 
     def test_get_paginate_workflow_app_logs_with_keyword_search(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with keyword search functionality.
 
@@ -287,11 +285,10 @@ class TestWorkflowAppService:
         )
 
         # Update workflow run with searchable content
-        from extensions.ext_database import db
 
         workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"})
         workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"})
-        db.session.commit()
+        db_session_with_containers.commit()
 
         # Act: Execute the method under test with keyword search
         service = WorkflowAppService()
@@ -317,7 +314,7 @@ class TestWorkflowAppService:
         assert len(result_no_match["data"]) == 0
 
     def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         r"""
         Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.
@@ -332,8 +329,6 @@ class TestWorkflowAppService:
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
         workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
 
-        from extensions.ext_database import db
-
         service = WorkflowAppService()
 
         # Test 1: Search with % character
@@ -353,8 +348,8 @@ class TestWorkflowAppService:
             created_by=account.id,
             created_at=datetime.now(UTC),
         )
-        db.session.add(workflow_run_1)
-        db.session.flush()
+        db_session_with_containers.add(workflow_run_1)
+        db_session_with_containers.flush()
 
         workflow_app_log_1 = WorkflowAppLog(
             tenant_id=app.tenant_id,
@@ -367,8 +362,8 @@ class TestWorkflowAppService:
         )
         workflow_app_log_1.id = str(uuid.uuid4())
         workflow_app_log_1.created_at = datetime.now(UTC)
-        db.session.add(workflow_app_log_1)
-        db.session.commit()
+        db_session_with_containers.add(workflow_app_log_1)
+        db_session_with_containers.commit()
 
         result = service.get_paginate_workflow_app_logs(
             session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
@@ -395,8 +390,8 @@
             created_by=account.id,
             created_at=datetime.now(UTC),
         )
-        db.session.add(workflow_run_2)
-        db.session.flush()
+        db_session_with_containers.add(workflow_run_2)
+        db_session_with_containers.flush()
 
         workflow_app_log_2 = WorkflowAppLog(
             tenant_id=app.tenant_id,
@@ -409,8 +404,8 @@
         )
         workflow_app_log_2.id = str(uuid.uuid4())
         workflow_app_log_2.created_at = datetime.now(UTC)
-        db.session.add(workflow_app_log_2)
-        db.session.commit()
+        db_session_with_containers.add(workflow_app_log_2)
+        db_session_with_containers.commit()
 
         result = service.get_paginate_workflow_app_logs(
             session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20
@@ -437,8 +432,8 @@
             created_by=account.id,
             created_at=datetime.now(UTC),
         )
-        db.session.add(workflow_run_4)
-        db.session.flush()
+        db_session_with_containers.add(workflow_run_4)
+        db_session_with_containers.flush()
 
         workflow_app_log_4 = WorkflowAppLog(
             tenant_id=app.tenant_id,
@@ -451,8 +446,8 @@
         )
         workflow_app_log_4.id = str(uuid.uuid4())
         workflow_app_log_4.created_at = datetime.now(UTC)
-        db.session.add(workflow_app_log_4)
-        db.session.commit()
+        db_session_with_containers.add(workflow_app_log_4)
+        db_session_with_containers.commit()
 
         result = service.get_paginate_workflow_app_logs(
             session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
@@ -467,7 +462,7 @@
         assert workflow_run_4.id not in found_run_ids
 
     def test_get_paginate_workflow_app_logs_with_status_filter(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with status filtering.
@@ -476,8 +471,6 @@ class TestWorkflowAppService:
         fake = Faker()
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
 
-        from extensions.ext_database import db
-
         # Create workflow
         workflow = Workflow(
             id=str(uuid.uuid4()),
@@ -490,8 +483,8 @@ class TestWorkflowAppService:
             created_by=account.id,
             updated_by=account.id,
         )
-        db.session.add(workflow)
-        db.session.commit()
+        db_session_with_containers.add(workflow)
+        db_session_with_containers.commit()
 
         # Create workflow runs with different statuses
         statuses = ["succeeded", "failed", "running", "stopped"]
@@ -519,8 +512,8 @@ class TestWorkflowAppService:
                 created_at=datetime.now(UTC) + timedelta(minutes=i),
                 finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None,
             )
-            db.session.add(workflow_run)
-            db.session.commit()
+            db_session_with_containers.add(workflow_run)
+            db_session_with_containers.commit()
 
             workflow_app_log = WorkflowAppLog(
                 tenant_id=app.tenant_id,
@@ -533,8 +526,8 @@ class TestWorkflowAppService:
             )
             workflow_app_log.id = str(uuid.uuid4())
             workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
-            db.session.add(workflow_app_log)
-            db.session.commit()
+            db_session_with_containers.add(workflow_app_log)
+            db_session_with_containers.commit()
 
             workflow_runs.append(workflow_run)
             workflow_app_logs.append(workflow_app_log)
@@ -568,7 +561,7 @@ class TestWorkflowAppService:
         assert result_running["data"][0].workflow_run.status == "running"
 
     def test_get_paginate_workflow_app_logs_with_time_filtering(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with time-based filtering.
 
@@ -577,8 +570,6 @@ class TestWorkflowAppService:
         fake = Faker()
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
 
-        from extensions.ext_database import db
-
         # Create workflow
         workflow = Workflow(
             id=str(uuid.uuid4()),
@@ -591,8 +582,8 @@ class TestWorkflowAppService:
             created_by=account.id,
             updated_by=account.id,
        )
-        db.session.add(workflow)
-        db.session.commit()
+        db_session_with_containers.add(workflow)
+        db_session_with_containers.commit()
 
         # Create workflow runs with different timestamps
         base_time = datetime.now(UTC)
@@ -627,8 +618,8 @@ class TestWorkflowAppService:
                 created_at=timestamp,
                 finished_at=timestamp + timedelta(minutes=1),
             )
-            db.session.add(workflow_run)
-            db.session.commit()
+            db_session_with_containers.add(workflow_run)
+            db_session_with_containers.commit()
 
             workflow_app_log = WorkflowAppLog(
                 tenant_id=app.tenant_id,
@@ -641,8 +632,8 @@ class TestWorkflowAppService:
             )
             workflow_app_log.id = str(uuid.uuid4())
             workflow_app_log.created_at = timestamp
-            db.session.add(workflow_app_log)
-            db.session.commit()
+            db_session_with_containers.add(workflow_app_log)
+            db_session_with_containers.commit()
 
             workflow_runs.append(workflow_run)
             workflow_app_logs.append(workflow_app_log)
@@ -682,7 +673,7 @@ class TestWorkflowAppService:
         assert result_range["total"] == 2  # Should get logs from 2 hours ago and 1 hour ago
 
     def test_get_paginate_workflow_app_logs_with_pagination(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with different page sizes and limits.
@@ -691,8 +682,6 @@ class TestWorkflowAppService:
         fake = Faker()
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
 
-        from extensions.ext_database import db
-
         # Create workflow
         workflow = Workflow(
             id=str(uuid.uuid4()),
@@ -705,8 +694,8 @@ class TestWorkflowAppService:
             created_by=account.id,
             updated_by=account.id,
         )
-        db.session.add(workflow)
-        db.session.commit()
+        db_session_with_containers.add(workflow)
+        db_session_with_containers.commit()
 
         # Create 25 workflow runs and logs
         total_logs = 25
@@ -734,8 +723,8 @@ class TestWorkflowAppService:
                 created_at=datetime.now(UTC) + timedelta(minutes=i),
                 finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
             )
-            db.session.add(workflow_run)
-            db.session.commit()
+            db_session_with_containers.add(workflow_run)
+            db_session_with_containers.commit()
 
             workflow_app_log = WorkflowAppLog(
                 tenant_id=app.tenant_id,
@@ -748,8 +737,8 @@ class TestWorkflowAppService:
             )
             workflow_app_log.id = str(uuid.uuid4())
             workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
-            db.session.add(workflow_app_log)
-            db.session.commit()
+            db_session_with_containers.add(workflow_app_log)
+            db_session_with_containers.commit()
 
             workflow_runs.append(workflow_run)
             workflow_app_logs.append(workflow_app_log)
@@ -798,7 +787,7 @@ class TestWorkflowAppService:
         assert len(result_large_limit["data"]) == total_logs
 
     def test_get_paginate_workflow_app_logs_with_user_role_filtering(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with user role and session filtering.
 
@@ -807,8 +796,6 @@ class TestWorkflowAppService:
         fake = Faker()
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
 
-        from extensions.ext_database import db
-
         # Create workflow
         workflow = Workflow(
             id=str(uuid.uuid4()),
@@ -821,8 +808,8 @@ class TestWorkflowAppService:
             created_by=account.id,
             updated_by=account.id,
         )
-        db.session.add(workflow)
-        db.session.commit()
+        db_session_with_containers.add(workflow)
+        db_session_with_containers.commit()
 
         # Create end user
         end_user = EndUser(
@@ -835,8 +822,8 @@ class TestWorkflowAppService:
             created_at=datetime.now(UTC),
             updated_at=datetime.now(UTC),
         )
-        db.session.add(end_user)
-        db.session.commit()
+        db_session_with_containers.add(end_user)
+        db_session_with_containers.commit()
 
         # Create workflow runs and logs for both account and end user
         workflow_runs = []
@@ -864,8 +851,8 @@ class TestWorkflowAppService:
                 created_at=datetime.now(UTC) + timedelta(minutes=i),
                 finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
            )
-            db.session.add(workflow_run)
-            db.session.commit()
+            db_session_with_containers.add(workflow_run)
+            db_session_with_containers.commit()
 
             workflow_app_log = WorkflowAppLog(
                 tenant_id=app.tenant_id,
@@ -878,8 +865,8 @@ class TestWorkflowAppService:
             )
             workflow_app_log.id = str(uuid.uuid4())
             workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
-            db.session.add(workflow_app_log)
-            db.session.commit()
+            db_session_with_containers.add(workflow_app_log)
+            db_session_with_containers.commit()
 
             workflow_runs.append(workflow_run)
             workflow_app_logs.append(workflow_app_log)
@@ -906,8 +893,8 @@ class TestWorkflowAppService:
                 created_at=datetime.now(UTC) + timedelta(minutes=i + 10),
                 finished_at=datetime.now(UTC) + timedelta(minutes=i + 11),
             )
-            db.session.add(workflow_run)
-            db.session.commit()
+            db_session_with_containers.add(workflow_run)
+            db_session_with_containers.commit()
 
             workflow_app_log = WorkflowAppLog(
                 tenant_id=app.tenant_id,
@@ -920,8 +907,8 @@ class TestWorkflowAppService:
             )
             workflow_app_log.id = str(uuid.uuid4())
             workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i + 10)
-            db.session.add(workflow_app_log)
-            db.session.commit()
+            db_session_with_containers.add(workflow_app_log)
+            db_session_with_containers.commit()
 
             workflow_runs.append(workflow_run)
             workflow_app_logs.append(workflow_app_log)
@@ -994,7 +981,7 @@ class TestWorkflowAppService:
         assert "Account not found" in str(exc_info.value)
 
     def test_get_paginate_workflow_app_logs_with_uuid_keyword_search(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with UUID keyword search functionality.
 
@@ -1003,8 +990,6 @@ class TestWorkflowAppService:
         fake = Faker()
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
 
-        from extensions.ext_database import db
-
         # Create workflow
         workflow = Workflow(
             id=str(uuid.uuid4()),
@@ -1017,8 +1002,8 @@ class TestWorkflowAppService:
             created_by=account.id,
             updated_by=account.id,
         )
-        db.session.add(workflow)
-        db.session.commit()
+        db_session_with_containers.add(workflow)
+        db_session_with_containers.commit()
 
         # Create workflow run with specific UUID
         workflow_run_id = str(uuid.uuid4())
@@ -1042,8 +1027,8 @@ class TestWorkflowAppService:
             created_at=datetime.now(UTC),
             finished_at=datetime.now(UTC) + timedelta(minutes=1),
         )
-        db.session.add(workflow_run)
-        db.session.commit()
+        db_session_with_containers.add(workflow_run)
+        db_session_with_containers.commit()
 
         # Create workflow app log
         workflow_app_log = WorkflowAppLog(
@@ -1057,8 +1042,8 @@ class TestWorkflowAppService:
         )
         workflow_app_log.id = str(uuid.uuid4())
         workflow_app_log.created_at = datetime.now(UTC)
-        db.session.add(workflow_app_log)
-        db.session.commit()
+        db_session_with_containers.add(workflow_app_log)
+        db_session_with_containers.commit()
 
         # Act & Assert: Test UUID keyword search
         service = WorkflowAppService()
@@ -1085,7 +1070,7 @@ class TestWorkflowAppService:
         assert result_invalid_uuid["total"] == 0
 
     def test_get_paginate_workflow_app_logs_with_edge_cases(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with edge cases and boundary conditions.
@@ -1094,8 +1079,6 @@ class TestWorkflowAppService:
         fake = Faker()
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
 
-        from extensions.ext_database import db
-
         # Create workflow
         workflow = Workflow(
             id=str(uuid.uuid4()),
@@ -1108,8 +1091,8 @@ class TestWorkflowAppService:
             created_by=account.id,
             updated_by=account.id,
         )
-        db.session.add(workflow)
-        db.session.commit()
+        db_session_with_containers.add(workflow)
+        db_session_with_containers.commit()
 
         # Create workflow run with edge case data
         workflow_run = WorkflowRun(
@@ -1132,8 +1115,8 @@ class TestWorkflowAppService:
             created_at=datetime.now(UTC),
             finished_at=datetime.now(UTC),
         )
-        db.session.add(workflow_run)
-        db.session.commit()
+        db_session_with_containers.add(workflow_run)
+        db_session_with_containers.commit()
 
         # Create workflow app log
         workflow_app_log = WorkflowAppLog(
@@ -1147,8 +1130,8 @@ class TestWorkflowAppService:
         )
         workflow_app_log.id = str(uuid.uuid4())
         workflow_app_log.created_at = datetime.now(UTC)
-        db.session.add(workflow_app_log)
-        db.session.commit()
+        db_session_with_containers.add(workflow_app_log)
+        db_session_with_containers.commit()
 
         # Act & Assert: Test edge cases
         service = WorkflowAppService()
@@ -1185,7 +1168,7 @@ class TestWorkflowAppService:
         assert result_high_page["has_more"] is False
 
     def test_get_paginate_workflow_app_logs_with_empty_results(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with empty results and no data scenarios.
 
@@ -1252,7 +1235,7 @@ class TestWorkflowAppService:
         assert "Account not found" in str(exc_info.value)
 
     def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with complex query combinations.
 
@@ -1352,7 +1335,7 @@ class TestWorkflowAppService:
         assert len(result_time_status_limit["data"]) <= 2
 
     def test_get_paginate_workflow_app_logs_with_large_dataset_performance(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with large dataset for performance validation.
 
@@ -1444,7 +1427,7 @@ class TestWorkflowAppService:
         assert result_last_page["page"] == 3
 
     def test_get_paginate_workflow_app_logs_with_tenant_isolation(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test workflow app logs pagination with proper tenant isolation.
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py
index 624251cd6c..ab409deb89 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py
@@ -1,5 +1,6 @@
 import pytest
 from faker import Faker
+from sqlalchemy.orm import Session
 
 from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from dify_graph.variables.segments import StringSegment
@@ -44,7 +45,7 @@ class TestWorkflowDraftVariableService:
         # WorkflowDraftVariableService doesn't have external dependencies that need mocking
         return {}
 
-    def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None):
+    def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None):
         """
         Helper method to create a test app with realistic data for testing.
 
@@ -75,13 +76,11 @@ class TestWorkflowDraftVariableService:
         app.created_by = fake.uuid4()
         app.updated_by = app.created_by
 
-        from extensions.ext_database import db
-
-        db.session.add(app)
-        db.session.commit()
+        db_session_with_containers.add(app)
+        db_session_with_containers.commit()
         return app
 
-    def _create_test_workflow(self, db_session_with_containers, app, fake=None):
+    def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None):
         """
         Helper method to create a test workflow associated with an app.
 
@@ -110,15 +109,14 @@ class TestWorkflowDraftVariableService:
             conversation_variables=[],
             rag_pipeline_variables=[],
         )
-        from extensions.ext_database import db
 
-        db.session.add(workflow)
-        db.session.commit()
+        db_session_with_containers.add(workflow)
+        db_session_with_containers.commit()
         return workflow
 
     def _create_test_variable(
         self,
-        db_session_with_containers,
+        db_session_with_containers: Session,
         app_id,
         node_id,
         name,
@@ -174,13 +172,12 @@ class TestWorkflowDraftVariableService:
             visible=True,
             editable=True,
         )
-        from extensions.ext_database import db
 
-        db.session.add(variable)
-        db.session.commit()
+        db_session_with_containers.add(variable)
+        db_session_with_containers.commit()
         return variable
 
-    def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test getting a single variable by ID successfully.
 
@@ -202,7 +199,7 @@ class TestWorkflowDraftVariableService:
         assert retrieved_variable.app_id == app.id
         assert retrieved_variable.get_value().value == test_value.value
 
-    def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_variable_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test getting a variable that doesn't exist.
 
@@ -217,7 +214,7 @@ class TestWorkflowDraftVariableService:
         assert retrieved_variable is None
 
     def test_get_draft_variables_by_selectors_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test getting variables by selectors successfully.
@@ -268,7 +265,7 @@ class TestWorkflowDraftVariableService:
             assert var.get_value().value == var3_value.value
 
     def test_list_variables_without_values_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test listing variables without values successfully with pagination.
 
@@ -300,7 +297,7 @@ class TestWorkflowDraftVariableService:
             assert var.name is not None
             assert var.app_id == app.id
 
-    def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_list_node_variables_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test listing variables for a specific node successfully.
 
@@ -352,7 +349,9 @@ class TestWorkflowDraftVariableService:
         assert "var2" in var_names
         assert "var3" not in var_names
 
-    def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_list_conversation_variables_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test listing conversation variables successfully.
 
@@ -393,7 +392,7 @@ class TestWorkflowDraftVariableService:
         assert "conv_var2" in var_names
         assert "sys_var" not in var_names
 
-    def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_update_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test updating a variable's name and value successfully.
 
@@ -418,14 +417,15 @@ class TestWorkflowDraftVariableService:
         assert updated_variable.name == "new_name"
         assert updated_variable.get_value().value == new_value.value
         assert updated_variable.last_edited_at is not None
-        from extensions.ext_database import db
 
-        db.session.refresh(variable)
+        db_session_with_containers.refresh(variable)
         assert variable.name == "new_name"
         assert variable.get_value().value == new_value.value
         assert variable.last_edited_at is not None
 
-    def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_update_variable_not_editable(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test that updating a non-editable variable raises an exception.
 
@@ -445,17 +445,18 @@ class TestWorkflowDraftVariableService:
             node_execution_id=fake.uuid4(),
             editable=False,  # Set as non-editable
         )
-        from extensions.ext_database import db
 
-        db.session.add(variable)
-        db.session.commit()
+        db_session_with_containers.add(variable)
+        db_session_with_containers.commit()
         service = WorkflowDraftVariableService(db_session_with_containers)
         with pytest.raises(UpdateNotSupportedError) as exc_info:
             service.update_variable(variable, name="new_name", value=new_value)
         assert "variable not support updating" in str(exc_info.value)
         assert variable.id in str(exc_info.value)
 
-    def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_reset_conversation_variable_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test resetting conversation variable successfully.
@@ -476,9 +477,8 @@ class TestWorkflowDraftVariableService:
             selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
         )
         workflow.conversation_variables = [conv_var]
-        from extensions.ext_database import db
 
-        db.session.commit()
+        db_session_with_containers.commit()
         modified_value = StringSegment(value=fake.word())
         variable = self._create_test_variable(
             db_session_with_containers,
@@ -489,17 +489,17 @@ class TestWorkflowDraftVariableService:
             fake=fake,
         )
         variable.last_edited_at = fake.date_time()
-        db.session.commit()
+        db_session_with_containers.commit()
         service = WorkflowDraftVariableService(db_session_with_containers)
         reset_variable = service.reset_variable(workflow, variable)
         assert reset_variable is not None
         assert reset_variable.get_value().value == "default_value"
         assert reset_variable.last_edited_at is None
-        db.session.refresh(variable)
+        db_session_with_containers.refresh(variable)
         assert variable.get_value().value == "default_value"
         assert variable.last_edited_at is None
 
-    def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_delete_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test deleting a single variable successfully.
 
@@ -513,14 +513,15 @@ class TestWorkflowDraftVariableService:
         variable = self._create_test_variable(
             db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
         )
-        from extensions.ext_database import db
 
-        assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
+        assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
         service = WorkflowDraftVariableService(db_session_with_containers)
         service.delete_variable(variable)
-        assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
+        assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
 
-    def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_delete_workflow_variables_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test deleting all variables for a workflow successfully.
@@ -550,20 +551,25 @@ class TestWorkflowDraftVariableService:
             other_value,
             fake=fake,
         )
-        from extensions.ext_database import db
 
-        app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
-        other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
+        app_variables = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
+        other_app_variables = (
+            db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
+        )
         assert len(app_variables) == 3
         assert len(other_app_variables) == 1
         service = WorkflowDraftVariableService(db_session_with_containers)
         service.delete_workflow_variables(app.id)
-        app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
-        other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
+        app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
+        other_app_variables_after = (
+            db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
+        )
         assert len(app_variables_after) == 0
         assert len(other_app_variables_after) == 1
 
-    def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_delete_node_variables_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test deleting all variables for a specific node successfully.
 
@@ -605,14 +611,15 @@ class TestWorkflowDraftVariableService:
             conv_value,
             fake=fake,
         )
-        from extensions.ext_database import db
 
-        target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
+        target_node_variables = (
+            db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
+        )
         other_node_variables = (
-            db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
+            db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
         )
         conv_variables = (
-            db.session.query(WorkflowDraftVariable)
+            db_session_with_containers.query(WorkflowDraftVariable)
             .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
             .all()
         )
@@ -622,13 +629,13 @@ class TestWorkflowDraftVariableService:
         service = WorkflowDraftVariableService(db_session_with_containers)
         service.delete_node_variables(app.id, node_id)
         target_node_variables_after = (
-            db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
+            db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
         )
         other_node_variables_after = (
-            db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
+            db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
        )
         conv_variables_after = (
-            db.session.query(WorkflowDraftVariable)
+            db_session_with_containers.query(WorkflowDraftVariable)
             .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
             .all()
         )
@@ -637,7 +644,7 @@ class TestWorkflowDraftVariableService:
         assert len(conv_variables_after) == 1
 
     def test_prefill_conversation_variable_default_values_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test prefill conversation variable default values successfully.
 
@@ -665,13 +672,12 @@ class TestWorkflowDraftVariableService:
             selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
         )
         workflow.conversation_variables = [conv_var1, conv_var2]
-        from extensions.ext_database import db
 
-        db.session.commit()
+        db_session_with_containers.commit()
         service = WorkflowDraftVariableService(db_session_with_containers)
         service.prefill_conversation_variable_default_values(workflow)
         draft_variables = (
-            db.session.query(WorkflowDraftVariable)
+            db_session_with_containers.query(WorkflowDraftVariable)
             .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
             .all()
         )
@@ -686,7 +692,7 @@ class TestWorkflowDraftVariableService:
             assert var.get_variable_type() == DraftVariableType.CONVERSATION
 
     def test_get_conversation_id_from_draft_variable_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test getting conversation ID from draft variable successfully.
 
@@ -713,7 +719,7 @@ class TestWorkflowDraftVariableService:
         assert retrieved_conv_id == conversation_id
 
     def test_get_conversation_id_from_draft_variable_not_found(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test getting conversation ID when it doesn't exist.
 
@@ -728,7 +734,9 @@ class TestWorkflowDraftVariableService:
         retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
         assert retrieved_conv_id is None
 
-    def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_list_system_variables_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test listing system variables successfully.
 
@@ -775,7 +783,9 @@ class TestWorkflowDraftVariableService:
         assert "sys_var2" in var_names
         assert "conv_var" not in var_names
 
-    def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_variable_by_name_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test getting variables by name successfully for different types.
 
@@ -822,7 +832,9 @@ class TestWorkflowDraftVariableService:
         assert retrieved_node_var.name == "test_node_var"
         assert retrieved_node_var.node_id == "test_node"
 
-    def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_variable_by_name_not_found(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test getting variables by name when they don't exist.
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py
index 3a88081db3..38ef3975b7 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py
@@ -5,6 +5,7 @@ from unittest.mock import patch
 
 import pytest
 from faker import Faker
+from sqlalchemy.orm import Session
 
 from models.enums import CreatorUserRole
 from models.model import (
@@ -48,7 +49,7 @@ class TestWorkflowRunService:
             "account_feature_service": mock_account_feature_service,
         }
 
-    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
+    def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Helper method to create a test app and account for testing.
 
@@ -94,7 +95,7 @@ class TestWorkflowRunService:
         return app, account
 
     def _create_test_workflow_run(
-        self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0
+        self, db_session_with_containers: Session, app, account, triggered_from="debugging", offset_minutes=0
     ):
         """
         Helper method to create a test workflow run for testing.
 
@@ -110,8 +111,6 @@ class TestWorkflowRunService:
         """
         fake = Faker()
 
-        from extensions.ext_database import db
-
         # Create workflow run with offset timestamp
         base_time = datetime.now(UTC)
         created_time = base_time - timedelta(minutes=offset_minutes)
@@ -136,12 +135,12 @@ class TestWorkflowRunService:
             finished_at=created_time,
         )
 
-        db.session.add(workflow_run)
-        db.session.commit()
+        db_session_with_containers.add(workflow_run)
+        db_session_with_containers.commit()
 
         return workflow_run
 
-    def _create_test_message(self, db_session_with_containers, app, account, workflow_run):
+    def _create_test_message(self, db_session_with_containers: Session, app, account, workflow_run):
         """
         Helper method to create a test message for testing.
 
@@ -156,8 +155,6 @@ class TestWorkflowRunService:
         """
         fake = Faker()
 
-        from extensions.ext_database import db
-
         # Create conversation first (required for message)
         from models.model import Conversation
 
@@ -170,8 +167,8 @@ class TestWorkflowRunService:
             from_source=CreatorUserRole.ACCOUNT,
             from_account_id=account.id,
         )
-        db.session.add(conversation)
-        db.session.commit()
+        db_session_with_containers.add(conversation)
+        db_session_with_containers.commit()
 
         # Create message
         message = Message()
@@ -193,12 +190,14 @@ class TestWorkflowRunService:
         message.workflow_run_id = workflow_run.id
         message.inputs = {"input": "test input"}
 
-        db.session.add(message)
-        db.session.commit()
+        db_session_with_containers.add(message)
+        db_session_with_containers.commit()
 
         return message
 
-    def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_paginate_workflow_runs_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test successful pagination of workflow runs with debugging trigger.
 
@@ -239,7 +238,7 @@ class TestWorkflowRunService:
             assert workflow_run.tenant_id == app.tenant_id
 
     def test_get_paginate_workflow_runs_with_last_id(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test pagination of workflow runs with last_id parameter.
@@ -282,7 +281,7 @@
             assert workflow_run.tenant_id == app.tenant_id
 
     def test_get_paginate_workflow_runs_default_limit(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test pagination of workflow runs with default limit.
 
@@ -320,7 +319,7 @@
             assert workflow_run_result.tenant_id == app.tenant_id
 
     def test_get_paginate_advanced_chat_workflow_runs_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test successful pagination of advanced chat workflow runs with message information.
 
@@ -365,7 +364,7 @@
             assert workflow_run.app_id == app.id
             assert workflow_run.tenant_id == app.tenant_id
 
-    def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_workflow_run_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test successful retrieval of workflow run by ID.
 
@@ -395,7 +394,7 @@
         assert result.type == "chat"
         assert result.version == "1.0.0"
 
-    def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_workflow_run_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test workflow run retrieval when run ID does not exist.
 
@@ -419,7 +418,7 @@
         assert result is None
 
     def test_get_workflow_run_node_executions_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test successful retrieval of workflow run node executions.
 
@@ -438,7 +437,6 @@
         workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
 
         # Create node executions
-        from extensions.ext_database import db
         from models.workflow import WorkflowNodeExecutionModel
 
         node_executions = []
@@ -462,7 +460,7 @@
                 created_by=account.id,
                 created_at=datetime.now(UTC),
             )
-            db.session.add(node_execution)
+            db_session_with_containers.add(node_execution)
             node_executions.append(node_execution)
 
         paused_node_execution = WorkflowNodeExecutionModel(
@@ -484,9 +482,9 @@
             created_by=account.id,
             created_at=datetime.now(UTC),
         )
-        db.session.add(paused_node_execution)
+        db_session_with_containers.add(paused_node_execution)
 
-        db.session.commit()
+        db_session_with_containers.commit()
 
         # Act: Execute the method under test
         workflow_run_service = WorkflowRunService()
@@ -509,7 +507,7 @@
             assert node_execution.node_id.startswith("node_")
 
     def test_get_workflow_run_node_executions_empty(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test getting node executions for a workflow run with no executions.
@@ -560,7 +558,7 @@
         assert len(result) == 0
 
     def test_get_workflow_run_node_executions_invalid_workflow_run_id(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test getting node executions with invalid workflow run ID.
 
@@ -611,7 +609,7 @@
         assert len(result) == 0
 
     def test_get_workflow_run_node_executions_database_error(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test getting node executions when database encounters an error.
 
@@ -662,7 +660,7 @@
         )
 
     def test_get_workflow_run_node_executions_end_user(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test node execution retrieval for end user.
 
@@ -680,7 +678,6 @@
         workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
 
         # Create end user
-        from extensions.ext_database import db
         from models.model import EndUser
 
         end_user = EndUser(
@@ -692,8 +689,8 @@
             external_user_id=str(uuid.uuid4()),
             name=fake.name(),
         )
-        db.session.add(end_user)
-        db.session.commit()
+        db_session_with_containers.add(end_user)
+        db_session_with_containers.commit()
 
         # Create node execution
         from models.workflow import WorkflowNodeExecutionModel
@@ -717,8 +714,8 @@
             created_by=end_user.id,
             created_at=datetime.now(UTC),
         )
-        db.session.add(node_execution)
-        db.session.commit()
+        db_session_with_containers.add(node_execution)
+        db_session_with_containers.commit()
 
         # Act: Execute the method under test
         workflow_run_service = WorkflowRunService()
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py
index ef575a9b69..bfb23bac68 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py
@@ -10,6 +10,7 @@ from unittest.mock import MagicMock
 
 import pytest
 from faker import Faker
+from sqlalchemy.orm import Session
 
 from models import Account, App, Workflow
 from models.model import AppMode
@@ -32,7 +33,7 @@ class TestWorkflowService:
     and realistic testing environment with actual database interactions.
     """
 
-    def _create_test_account(self, db_session_with_containers, fake=None):
+    def _create_test_account(self, db_session_with_containers: Session, fake=None):
         """
         Helper method to create a test account with realistic data.
 
@@ -67,18 +68,16 @@ class TestWorkflowService:
         tenant.created_at = fake.date_time_this_year()
         tenant.updated_at = tenant.created_at
 
-        from extensions.ext_database import db
-
-        db.session.add(tenant)
-        db.session.add(account)
-        db.session.commit()
+        db_session_with_containers.add(tenant)
+        db_session_with_containers.add(account)
+        db_session_with_containers.commit()
 
         # Set the current tenant for the account
         account.current_tenant = tenant
 
         return account
 
-    def _create_test_app(self, db_session_with_containers, fake=None):
+    def _create_test_app(self, db_session_with_containers: Session, fake=None):
         """
         Helper method to create a test app with realistic data.
@@ -106,13 +105,11 @@ class TestWorkflowService: ) app.updated_by = app.created_by - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers, app, account, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None): """ Helper method to create a test workflow associated with an app. @@ -141,13 +138,11 @@ class TestWorkflowService: conversation_variables=[], ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow - def test_get_node_last_run_success(self, db_session_with_containers): + def test_get_node_last_run_success(self, db_session_with_containers: Session): """ Test successful retrieval of the most recent execution for a specific node. @@ -180,10 +175,8 @@ class TestWorkflowService: node_execution.created_by = account.id # Required field node_execution.created_at = fake.date_time_this_year() - from extensions.ext_database import db - - db.session.add(node_execution) - db.session.commit() + db_session_with_containers.add(node_execution) + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -196,7 +189,7 @@ class TestWorkflowService: assert result.workflow_id == workflow.id assert result.status == "succeeded" - def test_get_node_last_run_not_found(self, db_session_with_containers): + def test_get_node_last_run_not_found(self, db_session_with_containers: Session): """ Test retrieval when no execution record exists for the specified node. @@ -217,7 +210,7 @@ class TestWorkflowService: # Assert assert result is None - def test_is_workflow_exist_true(self, db_session_with_containers): + def test_is_workflow_exist_true(self, db_session_with_containers: Session): """ Test workflow existence check when a draft workflow exists. @@ -238,7 +231,7 @@ class TestWorkflowService: # Assert assert result is True - def test_is_workflow_exist_false(self, db_session_with_containers): + def test_is_workflow_exist_false(self, db_session_with_containers: Session): """ Test workflow existence check when no draft workflow exists. @@ -258,7 +251,7 @@ class TestWorkflowService: # Assert assert result is False - def test_get_draft_workflow_success(self, db_session_with_containers): + def test_get_draft_workflow_success(self, db_session_with_containers: Session): """ Test successful retrieval of a draft workflow. @@ -284,7 +277,7 @@ class TestWorkflowService: assert result.app_id == app.id assert result.tenant_id == app.tenant_id - def test_get_draft_workflow_not_found(self, db_session_with_containers): + def test_get_draft_workflow_not_found(self, db_session_with_containers: Session): """ Test draft workflow retrieval when no draft workflow exists. @@ -304,7 +297,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_published_workflow_by_id_success(self, db_session_with_containers): + def test_get_published_workflow_by_id_success(self, db_session_with_containers: Session): """ Test successful retrieval of a published workflow by ID. 
@@ -321,9 +314,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -336,7 +327,7 @@ class TestWorkflowService: assert result.version != Workflow.VERSION_DRAFT assert result.app_id == app.id - def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers): + def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers: Session): """ Test error when trying to retrieve a draft workflow as published. @@ -359,7 +350,7 @@ class TestWorkflowService: with pytest.raises(IsDraftWorkflowError): workflow_service.get_published_workflow_by_id(app, workflow.id) - def test_get_published_workflow_by_id_not_found(self, db_session_with_containers): + def test_get_published_workflow_by_id_not_found(self, db_session_with_containers: Session): """ Test retrieval when no workflow exists with the specified ID. @@ -379,7 +370,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_published_workflow_success(self, db_session_with_containers): + def test_get_published_workflow_success(self, db_session_with_containers: Session): """ Test successful retrieval of the current published workflow for an app. @@ -395,10 +386,8 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - app.workflow_id = workflow.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -411,7 +400,7 @@ class TestWorkflowService: assert result.version != Workflow.VERSION_DRAFT assert result.app_id == app.id - def test_get_published_workflow_no_workflow_id(self, db_session_with_containers): + def test_get_published_workflow_no_workflow_id(self, db_session_with_containers: Session): """ Test retrieval when app has no associated workflow ID. @@ -431,7 +420,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_all_published_workflow_pagination(self, db_session_with_containers): + def test_get_all_published_workflow_pagination(self, db_session_with_containers: Session): """ Test pagination of published workflows. @@ -455,15 +444,13 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflows[0].id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - First page result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, + session=db_session_with_containers, app_model=app, page=1, limit=3, @@ -476,7 +463,7 @@ class TestWorkflowService: # Act - Second page result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, + session=db_session_with_containers, app_model=app, page=2, limit=3, @@ -487,7 +474,7 @@ class TestWorkflowService: assert len(result_workflows) == 2 assert has_more is False - def test_get_all_published_workflow_user_filter(self, db_session_with_containers): + def test_get_all_published_workflow_user_filter(self, db_session_with_containers: Session): """ Test filtering published workflows by user. 
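The pagination assertions above (three rows plus `has_more` on the first page of five workflows, two rows and no more on the second) imply that the service fetches one row beyond the requested limit. A sketch of that query in the same session-injection style the test now uses; the signature is inferred from the call site, and the real method in `api/services/workflow_service.py` may differ.

from sqlalchemy import select
from sqlalchemy.orm import Session

from models import App, Workflow


def get_all_published_workflow(
    session: Session, app_model: App, page: int, limit: int
) -> tuple[list[Workflow], bool]:
    # Fetch limit + 1 rows; the extra row, if present, signals has_more.
    stmt = (
        select(Workflow)
        .where(Workflow.app_id == app_model.id)
        .where(Workflow.version != Workflow.VERSION_DRAFT)
        .order_by(Workflow.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit + 1)
    )
    rows = list(session.scalars(stmt))
    return rows[:limit], len(rows) > limit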
@@ -513,22 +500,20 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflow1.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - Filter by account1 result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, app_model=app, page=1, limit=10, user_id=account1.id + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=account1.id ) # Assert assert len(result_workflows) == 1 assert result_workflows[0].created_by == account1.id - def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers): + def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers: Session): """ Test filtering published workflows to show only named workflows. @@ -557,22 +542,20 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflow1.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - Filter named only result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, app_model=app, page=1, limit=10, user_id=None, named_only=True + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None, named_only=True ) # Assert assert len(result_workflows) == 2 assert all(wf.marked_name for wf in result_workflows) - def test_sync_draft_workflow_create_new(self, db_session_with_containers): + def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session): """ Test creating a new draft workflow through sync operation. @@ -624,7 +607,7 @@ class TestWorkflowService: assert result.features == json.dumps(features) assert result.created_by == account.id - def test_sync_draft_workflow_update_existing(self, db_session_with_containers): + def test_sync_draft_workflow_update_existing(self, db_session_with_containers: Session): """ Test updating an existing draft workflow through sync operation. @@ -688,7 +671,7 @@ class TestWorkflowService: assert result.features == json.dumps(new_features) assert result.updated_by == account.id - def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers): + def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers: Session): """ Test error when sync is attempted with mismatched hash. @@ -738,7 +721,7 @@ class TestWorkflowService: conversation_variables=conversation_variables, ) - def test_publish_workflow_success(self, db_session_with_containers): + def test_publish_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow publishing. @@ -755,9 +738,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = Workflow.VERSION_DRAFT - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -777,7 +758,7 @@ class TestWorkflowService: assert len(result.version) > 10 # Should be a reasonable timestamp length assert result.created_by == account.id - def test_publish_workflow_no_draft_error(self, db_session_with_containers): + def test_publish_workflow_no_draft_error(self, db_session_with_containers: Session): """ Test error when publishing workflow without draft. 
@@ -797,7 +778,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) - def test_publish_workflow_already_published_error(self, db_session_with_containers): + def test_publish_workflow_already_published_error(self, db_session_with_containers: Session): """ Test error when publishing already published workflow. @@ -813,9 +794,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Already published - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -823,7 +802,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) - def test_get_default_block_configs(self, db_session_with_containers): + def test_get_default_block_configs(self, db_session_with_containers: Session): """ Test retrieval of default block configurations for all node types. @@ -847,7 +826,7 @@ class TestWorkflowService: assert isinstance(config, dict) # The structure can vary, so we just check it's a dict - def test_get_default_block_config_specific_type(self, db_session_with_containers): + def test_get_default_block_config_specific_type(self, db_session_with_containers: Session): """ Test retrieval of default block configuration for a specific node type. @@ -867,7 +846,7 @@ class TestWorkflowService: # This is acceptable behavior assert result is None or isinstance(result, dict) - def test_get_default_block_config_invalid_type(self, db_session_with_containers): + def test_get_default_block_config_invalid_type(self, db_session_with_containers: Session): """ Test retrieval of default block configuration for invalid node type. @@ -887,7 +866,7 @@ class TestWorkflowService: # It's also acceptable for the service to raise a ValueError for invalid types pass - def test_get_default_block_config_with_filters(self, db_session_with_containers): + def test_get_default_block_config_with_filters(self, db_session_with_containers: Session): """ Test retrieval of default block configuration with filters. @@ -907,7 +886,7 @@ class TestWorkflowService: # Result might be None if filters don't match, but should not raise error assert result is None or isinstance(result, dict) - def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers): + def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers: Session): """ Test successful conversion from chat mode app to workflow mode. 
@@ -944,11 +923,9 @@ class TestWorkflowService: ) app_model_config.id = fake.uuid4() - from extensions.ext_database import db - - db.session.add(app_model_config) + db_session_with_containers.add(app_model_config) app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = { @@ -969,7 +946,7 @@ class TestWorkflowService: assert result.icon_type == conversion_args["icon_type"] assert result.icon_background == conversion_args["icon_background"] - def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers): + def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers: Session): """ Test successful conversion from completion mode app to workflow mode. @@ -1006,11 +983,9 @@ class TestWorkflowService: ) app_model_config.id = fake.uuid4() - from extensions.ext_database import db - - db.session.add(app_model_config) + db_session_with_containers.add(app_model_config) app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = { @@ -1031,7 +1006,7 @@ class TestWorkflowService: assert result.icon_type == conversion_args["icon_type"] assert result.icon_background == conversion_args["icon_background"] - def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers): + def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers: Session): """ Test error when attempting to convert unsupported app mode. @@ -1046,9 +1021,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.WORKFLOW - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = {"name": "Test"} @@ -1057,7 +1030,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Current App mode: workflow is not supported convert to workflow"): workflow_service.convert_to_workflow(app_model=app, account=account, args=conversion_args) - def test_validate_features_structure_advanced_chat(self, db_session_with_containers): + def test_validate_features_structure_advanced_chat(self, db_session_with_containers: Session): """ Test feature structure validation for advanced chat mode apps. @@ -1069,9 +1042,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.ADVANCED_CHAT - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = { @@ -1088,7 +1059,7 @@ class TestWorkflowService: # The exact behavior depends on the AdvancedChatAppConfigManager implementation assert result is not None or isinstance(result, dict) - def test_validate_features_structure_workflow(self, db_session_with_containers): + def test_validate_features_structure_workflow(self, db_session_with_containers: Session): """ Test feature structure validation for workflow mode apps. 
@@ -1100,9 +1071,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.WORKFLOW - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = {"workflow_config": {"max_steps": 10, "timeout": 300}} @@ -1115,7 +1084,7 @@ class TestWorkflowService: # The exact behavior depends on the WorkflowAppConfigManager implementation assert result is not None or isinstance(result, dict) - def test_validate_features_structure_invalid_mode(self, db_session_with_containers): + def test_validate_features_structure_invalid_mode(self, db_session_with_containers: Session): """ Test error when validating features for invalid app mode. @@ -1127,9 +1096,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = "invalid_mode" # Invalid mode - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = {"test": "value"} @@ -1138,7 +1105,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"): workflow_service.validate_features_structure(app_model=app, features=features) - def test_update_workflow_success(self, db_session_with_containers): + def test_update_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow update with allowed fields. @@ -1152,16 +1119,14 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() update_data = {"marked_name": "Updated Workflow Name", "marked_comment": "Updated workflow comment"} # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id, account_id=account.id, @@ -1174,7 +1139,7 @@ class TestWorkflowService: assert result.marked_comment == update_data["marked_comment"] assert result.updated_by == account.id - def test_update_workflow_not_found(self, db_session_with_containers): + def test_update_workflow_not_found(self, db_session_with_containers: Session): """ Test workflow update when workflow doesn't exist. @@ -1186,15 +1151,13 @@ class TestWorkflowService: account = self._create_test_account(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake) - from extensions.ext_database import db - workflow_service = WorkflowService() non_existent_workflow_id = fake.uuid4() update_data = {"marked_name": "Test"} # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id, account_id=account.id, @@ -1204,7 +1167,7 @@ class TestWorkflowService: # Assert assert result is None - def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers): + def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers: Session): """ Test that workflow update ignores disallowed fields. 
@@ -1218,9 +1181,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) original_name = workflow.marked_name - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() update_data = { @@ -1231,7 +1192,7 @@ class TestWorkflowService: # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id, account_id=account.id, @@ -1245,7 +1206,7 @@ class TestWorkflowService: assert result.graph == workflow.graph assert result.features == workflow.features - def test_delete_workflow_success(self, db_session_with_containers): + def test_delete_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow deletion. @@ -1262,25 +1223,23 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act result = workflow_service.delete_workflow( - session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id ) # Assert assert result is True # Verify workflow is actually deleted - deleted_workflow = db.session.query(Workflow).filter_by(id=workflow.id).first() + deleted_workflow = db_session_with_containers.query(Workflow).filter_by(id=workflow.id).first() assert deleted_workflow is None - def test_delete_workflow_draft_error(self, db_session_with_containers): + def test_delete_workflow_draft_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a draft workflow. @@ -1296,9 +1255,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) # Keep as draft version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -1306,9 +1263,11 @@ class TestWorkflowService: from services.errors.workflow_service import DraftWorkflowDeletionError with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow versions"): - workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) + workflow_service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id + ) - def test_delete_workflow_in_use_error(self, db_session_with_containers): + def test_delete_workflow_in_use_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a workflow that's in use by an app. 
@@ -1327,9 +1286,7 @@ class TestWorkflowService: # Associate workflow with app app.workflow_id = workflow.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -1337,9 +1294,11 @@ class TestWorkflowService: from services.errors.workflow_service import WorkflowInUseError with pytest.raises(WorkflowInUseError, match="Cannot delete workflow that is currently in use by app"): - workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) + workflow_service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id + ) - def test_delete_workflow_not_found_error(self, db_session_with_containers): + def test_delete_workflow_not_found_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a non-existent workflow. @@ -1351,17 +1310,15 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) non_existent_workflow_id = fake.uuid4() - from extensions.ext_database import db - workflow_service = WorkflowService() # Act & Assert with pytest.raises(ValueError, match=f"Workflow with ID {non_existent_workflow_id} not found"): workflow_service.delete_workflow( - session=db.session, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id + session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id ) - def test_run_free_workflow_node_success(self, db_session_with_containers): + def test_run_free_workflow_node_success(self, db_session_with_containers: Session): """ Test successful execution of a free workflow node. @@ -1413,7 +1370,7 @@ class TestWorkflowService: assert result.workflow_id == "" # No workflow ID for free nodes assert result.index == 1 - def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers): + def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers: Session): """ Test execution of a free workflow node with complex input data. @@ -1454,7 +1411,7 @@ class TestWorkflowService: error_msg = str(exc_info.value).lower() assert any(keyword in error_msg for keyword in ["start", "not supported", "external"]) - def test_handle_node_run_result_success(self, db_session_with_containers): + def test_handle_node_run_result_success(self, db_session_with_containers: Session): """ Test successful handling of node run results. @@ -1529,7 +1486,7 @@ class TestWorkflowService: assert result.outputs is not None assert result.process_data is not None - def test_handle_node_run_result_failure(self, db_session_with_containers): + def test_handle_node_run_result_failure(self, db_session_with_containers: Session): """ Test handling of failed node run results. @@ -1598,7 +1555,7 @@ class TestWorkflowService: assert result.error is not None assert "Test error message" in str(result.error) - def test_handle_node_run_result_continue_on_error(self, db_session_with_containers): + def test_handle_node_run_result_continue_on_error(self, db_session_with_containers: Session): """ Test handling of node run results with continue_on_error strategy. 
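Across these two test files the swap is mechanical because a plain ORM `Session` and the Flask-SQLAlchemy scoped `db.session` expose the same core methods (`add`, `add_all`, `commit`, `query`, `refresh`), so the new `Session` annotation narrows the type without changing runtime behavior. A minimal illustration of that interchangeability; `persist` is a hypothetical helper, not part of this patch, and a scoped session also passes at runtime because it proxies the same methods even though it is not a `Session` subclass.

from sqlalchemy.orm import Session


def persist(session: Session, *objects) -> None:
    # Works with the fixture session as well as the legacy db.session.
    session.add_all(objects)
    session.commit()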
diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 4249642bc9..92dec24c7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.workspace_service import WorkspaceService @@ -29,7 +30,7 @@ class TestWorkspaceService: "dify_config": mock_dify_config, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,10 +51,8 @@ class TestWorkspaceService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -62,8 +61,8 @@ class TestWorkspaceService: plan="basic", custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}', ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -72,15 +71,15 @@ class TestWorkspaceService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_info_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tenant information with all features enabled. @@ -121,13 +120,12 @@ class TestWorkspaceService: assert "replace_webapp_logo" in result["custom_config"] # Verify database state - from extensions.ext_database import db - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_without_custom_config( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval when custom config features are disabled. @@ -167,13 +165,12 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - from extensions.ext_database import db - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_normal_user_role( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for normal user role without privileged features. 
@@ -191,11 +188,14 @@ class TestWorkspaceService: ) # Update the join to have normal role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.NORMAL - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -220,11 +220,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_admin_role_and_logo_replacement( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for admin role with logo replacement enabled. @@ -242,11 +242,14 @@ class TestWorkspaceService: ) # Update the join to have admin role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.ADMIN - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -268,10 +271,12 @@ class TestWorkspaceService: assert "replace_webapp_logo" in result["custom_config"] # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None - def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_info_with_tenant_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant info retrieval when tenant parameter is None. @@ -290,7 +295,7 @@ class TestWorkspaceService: assert result is None def test_get_tenant_info_with_custom_config_variations( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval with various custom config configurations. @@ -323,10 +328,8 @@ class TestWorkspaceService: # Update tenant custom config import json - from extensions.ext_database import db - tenant.custom_config = json.dumps(config) - db.session.commit() + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -353,11 +356,11 @@ class TestWorkspaceService: assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"] # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_editor_role_and_limited_permissions( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for editor role with limited permissions. 
@@ -375,11 +378,14 @@ class TestWorkspaceService: ) # Update the join to have editor role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.EDITOR - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -400,11 +406,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_dataset_operator_role( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for dataset operator role. @@ -422,11 +428,14 @@ class TestWorkspaceService: ) # Update the join to have dataset operator role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.DATASET_OPERATOR - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -447,11 +456,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_complex_custom_config_scenarios( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval with complex custom config scenarios. 
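The role-based tests above all re-read the `TenantAccountJoin` row through the same fixture session before mutating it; since that session also created the row, the query, the update, and the later `refresh()` all see one consistent unit of work with no detached instances. Reduced to a helper under the same assumptions (models as imported in this file; the helper itself is illustrative):

from sqlalchemy.orm import Session

from models import Tenant, TenantAccountJoin, TenantAccountRole


def set_member_role(session: Session, tenant: Tenant, account_id: str, role: TenantAccountRole) -> None:
    # Re-read the join row on the session that created it, mutate, commit.
    join = (
        session.query(TenantAccountJoin)
        .filter_by(tenant_id=tenant.id, account_id=account_id)
        .first()
    )
    assert join is not None
    join.role = role
    session.commit()
    # refresh() re-reads from the database, so later assertions check
    # persisted state rather than cached in-memory attributes.
    session.refresh(tenant)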
@@ -491,10 +500,8 @@ class TestWorkspaceService: # Update tenant custom config import json - from extensions.ext_database import db - tenant.custom_config = json.dumps(config) - db.session.commit() + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -525,5 +532,5 @@ class TestWorkspaceService: assert result["custom_config"]["remove_webapp_brand"] is False # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 2ff71ea6ea..bffdca623a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker from pydantic import TypeAdapter, ValidationError +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant @@ -34,7 +35,7 @@ class TestApiToolManageService: "provider_controller": mock_provider_controller, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -55,18 +56,16 @@ class TestApiToolManageService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -77,8 +76,8 @@ class TestApiToolManageService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -118,7 +117,7 @@ class TestApiToolManageService: """ def test_parser_api_schema_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful parsing of API schema. @@ -163,7 +162,7 @@ class TestApiToolManageService: assert api_key_value_field["default"] == "" def test_parser_api_schema_invalid_schema( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parsing of invalid API schema. 
@@ -183,7 +182,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_parser_api_schema_malformed_json( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parsing of malformed JSON schema. @@ -203,7 +202,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_convert_schema_to_tool_bundles_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of schema to tool bundles. @@ -233,7 +232,7 @@ class TestApiToolManageService: assert tool_bundle.operation_id == "testOperation" def test_convert_schema_to_tool_bundles_with_extra_info( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of schema to tool bundles with extra info. @@ -259,7 +258,7 @@ class TestApiToolManageService: assert isinstance(schema_type, str) def test_convert_schema_to_tool_bundles_invalid_schema( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of invalid schema to tool bundles. @@ -279,7 +278,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_create_api_tool_provider_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful creation of API tool provider. @@ -324,10 +323,9 @@ class TestApiToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db provider = ( - db.session.query(ApiToolProvider) + db_session_with_containers.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .first() ) @@ -347,7 +345,7 @@ class TestApiToolManageService: mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once() def test_create_api_tool_provider_duplicate_name( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with duplicate name. @@ -404,7 +402,7 @@ class TestApiToolManageService: assert f"provider {provider_name} already exists" in str(exc_info.value) def test_create_api_tool_provider_invalid_schema_type( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with invalid schema type. 
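`create_api_tool_provider` persists the provider through the service's own machinery, yet the test verifies the row with the injected fixture session. That only works if the service commits before returning (or shares the same underlying session), since an uncommitted write on one session is invisible to another. The verification step, isolated into a helper; the `models.tools` import path for `ApiToolProvider` is an assumption here, matching how the test file appears to use it.

from sqlalchemy.orm import Session

from models.tools import ApiToolProvider  # assumed import path


def fetch_created_provider(session: Session, tenant_id: str, name: str) -> ApiToolProvider | None:
    # Returns the row only once the creating transaction has committed.
    return (
        session.query(ApiToolProvider)
        .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == name)
        .first()
    )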
@@ -436,7 +434,7 @@ class TestApiToolManageService: assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with missing auth type. @@ -479,7 +477,7 @@ class TestApiToolManageService: assert "auth_type is required" in str(exc_info.value) def test_create_api_tool_provider_with_api_key_auth( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful creation of API tool provider with API key authentication. @@ -522,10 +520,9 @@ class TestApiToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db provider = ( - db.session.query(ApiToolProvider) + db_session_with_containers.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .first() ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 6cae83ac37..0f2e3980af 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ToolProviderType from models import Account, Tenant @@ -41,7 +42,7 @@ class TestMCPToolManageService: "tool_transform_service": mock_tool_transform_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -62,18 +63,16 @@ class TestMCPToolManageService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -84,8 +83,8 @@ class TestMCPToolManageService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -93,7 +92,7 @@ class TestMCPToolManageService: return account, tenant def _create_test_mcp_provider( - self, db_session_with_containers, mock_external_service_dependencies, tenant_id, user_id + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, user_id ): """ Helper method to create a test MCP tool provider for testing. 
@@ -124,15 +123,13 @@ class TestMCPToolManageService: sse_read_timeout=300.0, ) - from extensions.ext_database import db - - db.session.add(mcp_provider) - db.session.commit() + db_session_with_containers.add(mcp_provider) + db_session_with_containers.commit() return mcp_provider def test_get_mcp_provider_by_provider_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of MCP provider by provider ID. @@ -153,9 +150,8 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id) # Assert: Verify the expected outcomes @@ -166,12 +162,12 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.server_identifier == mcp_provider.server_identifier def test_get_mcp_provider_by_provider_id_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP provider is not found by provider ID. @@ -190,14 +186,13 @@ class TestMCPToolManageService: non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id) def test_get_mcp_provider_by_provider_id_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation when retrieving MCP provider by provider ID. @@ -223,14 +218,13 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id) def test_get_mcp_provider_by_server_identifier_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of MCP provider by server identifier. 
@@ -251,9 +245,8 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id) # Assert: Verify the expected outcomes @@ -264,12 +257,12 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.name == mcp_provider.name def test_get_mcp_provider_by_server_identifier_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP provider is not found by server identifier. @@ -288,14 +281,13 @@ class TestMCPToolManageService: non_existent_identifier = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id) def test_get_mcp_provider_by_server_identifier_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation when retrieving MCP provider by server identifier. @@ -321,13 +313,12 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id) - def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful creation of MCP provider. 
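Unlike the other services in this series, `MCPToolManageService` receives its session through the constructor, which is why the tests can now pass the fixture directly where they previously built the service with `db.session()`. The sketch below reconstructs the lookup that the `get_provider` tests pin down (tenant-scoped filtering plus `ValueError("MCP tool not found")` on a miss); the import path and exact signature are assumptions, and the real implementation may differ.

from sqlalchemy.orm import Session

from models.tools import MCPToolProvider  # assumed import path


class MCPToolManageService:
    def __init__(self, session: Session) -> None:
        # Hold the caller-supplied session instead of importing the
        # global Flask-SQLAlchemy session.
        self._session = session

    def get_provider(
        self,
        *,
        tenant_id: str,
        provider_id: str | None = None,
        server_identifier: str | None = None,
    ) -> MCPToolProvider:
        # Every lookup is tenant-scoped, which is what the tenant
        # isolation tests above rely on.
        query = self._session.query(MCPToolProvider).filter_by(tenant_id=tenant_id)
        if provider_id is not None:
            query = query.filter_by(id=provider_id)
        if server_identifier is not None:
            query = query.filter_by(server_identifier=server_identifier)
        provider = query.first()
        if provider is None:
            raise ValueError("MCP tool not found")
        return provider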
@@ -365,9 +356,8 @@ class TestMCPToolManageService: # Act: Execute the method under test from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", @@ -389,10 +379,9 @@ class TestMCPToolManageService: assert result.type == ToolProviderType.MCP # Verify database state - from extensions.ext_database import db created_provider = ( - db.session.query(MCPToolProvider) + db_session_with_containers.query(MCPToolProvider) .filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider") .first() ) @@ -410,7 +399,9 @@ class TestMCPToolManageService: ) mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once() - def test_create_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_mcp_provider_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when creating MCP provider with duplicate name. @@ -427,9 +418,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", @@ -463,7 +453,7 @@ class TestMCPToolManageService: ) def test_create_mcp_provider_duplicate_server_url( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when creating MCP provider with duplicate server URL. @@ -481,9 +471,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", @@ -517,7 +506,7 @@ class TestMCPToolManageService: ) def test_create_mcp_provider_duplicate_server_identifier( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when creating MCP provider with duplicate server identifier. @@ -535,9 +524,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", @@ -570,7 +558,7 @@ class TestMCPToolManageService: ), ) - def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of MCP tools for a tenant. 
@@ -602,9 +590,7 @@ class TestMCPToolManageService: ) provider3.name = "Gamma Provider" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Setup mock for transformation service from core.tools.entities.api_entities import ToolProviderApiEntity @@ -647,9 +633,8 @@ class TestMCPToolManageService: ] # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_providers(tenant_id=tenant.id, for_list=True) # Assert: Verify the expected outcomes @@ -666,7 +651,9 @@ class TestMCPToolManageService: mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3 ) - def test_retrieve_mcp_tools_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_empty_list( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of MCP tools when tenant has no providers. @@ -684,9 +671,8 @@ class TestMCPToolManageService: # No MCP providers created for this tenant # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_providers(tenant_id=tenant.id, for_list=False) # Assert: Verify the expected outcomes @@ -697,7 +683,9 @@ class TestMCPToolManageService: # Verify no transformation service calls for empty list mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called() - def test_retrieve_mcp_tools_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_tenant_isolation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant isolation when retrieving MCP tools. @@ -756,9 +744,8 @@ class TestMCPToolManageService: ] # Act: Execute the method under test for both tenants - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result1 = service.list_providers(tenant_id=tenant1.id, for_list=True) result2 = service.list_providers(tenant_id=tenant2.id, for_list=True) @@ -769,7 +756,7 @@ class TestMCPToolManageService: assert result2[0].id == provider2.id def test_list_mcp_tool_from_remote_server_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful listing of MCP tools from remote server. 
@@ -797,9 +784,7 @@ class TestMCPToolManageService: mcp_provider.authed = True # Provider must be authenticated to list tools mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -821,9 +806,8 @@ class TestMCPToolManageService: mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes @@ -834,7 +818,7 @@ class TestMCPToolManageService: # Note: server_url is mocked, so we skip that assertion to avoid encryption issues # Verify database state was updated - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True assert mcp_provider.tools != "[]" assert mcp_provider.updated_at is not None @@ -844,7 +828,7 @@ class TestMCPToolManageService: mock_mcp_client.assert_called_once() def test_list_mcp_tool_from_remote_server_auth_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP server requires authentication. @@ -871,9 +855,7 @@ class TestMCPToolManageService: mcp_provider.authed = False mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -887,19 +869,18 @@ class TestMCPToolManageService: mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="Please auth the tool first"): service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is False assert mcp_provider.tools == "[]" def test_list_mcp_tool_from_remote_server_connection_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP server connection fails. 
@@ -926,9 +907,7 @@ class TestMCPToolManageService: mcp_provider.authed = True # Provider must be authenticated to test connection errors mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -942,18 +921,17 @@ class TestMCPToolManageService: mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"): service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True # Provider remains authenticated assert mcp_provider.tools == "[]" - def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful deletion of MCP tool. @@ -974,20 +952,19 @@ class TestMCPToolManageService: ) # Verify provider exists - from extensions.ext_database import db - assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None + assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None # Act: Execute the method under test - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes # Provider should be deleted from database - deleted_provider = db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() + deleted_provider = db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() assert deleted_provider is None - def test_delete_mcp_tool_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when deleting non-existent MCP tool. @@ -1005,13 +982,14 @@ class TestMCPToolManageService: non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id) - def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_tenant_isolation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant isolation when deleting MCP tool. 
@@ -1036,18 +1014,16 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id) # Verify provider still exists in tenant1 - from extensions.ext_database import db - assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None + assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None - def test_update_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful update of MCP provider. @@ -1070,14 +1046,12 @@ class TestMCPToolManageService: original_name = mcp_provider.name original_icon = mcp_provider.icon - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test from core.entities.mcp_provider import MCPConfiguration - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider( tenant_id=tenant.id, provider_id=mcp_provider.id, @@ -1094,7 +1068,7 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.name == "Updated MCP Provider" assert mcp_provider.server_identifier == "updated_identifier_123" assert mcp_provider.timeout == 45.0 @@ -1108,7 +1082,9 @@ class TestMCPToolManageService: assert icon_data["content"] == "🚀" assert icon_data["background"] == "#4ECDC4" - def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_mcp_provider_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when updating MCP provider with duplicate name. @@ -1134,15 +1110,12 @@ class TestMCPToolManageService: ) provider2.name = "Second Provider" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Act & Assert: Verify proper error handling for duplicate name from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool First Provider already exists"): service.update_provider( tenant_id=tenant.id, @@ -1160,7 +1133,7 @@ class TestMCPToolManageService: ) def test_update_mcp_provider_credentials_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of MCP provider credentials. 
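The delete-path tests above pin down the tenant-isolation contract: a lookup scoped to the wrong tenant behaves exactly like a missing row and raises the same "MCP tool not found" error. A self-contained sketch of that contract (Provider again stands in for the real model; the service's actual query may differ):

# Self-contained sketch of the tenant-scoped delete these tests assert.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker

Base = declarative_base()


class Provider(Base):
    __tablename__ = "providers"
    id = Column(Integer, primary_key=True)
    tenant_id = Column(String, nullable=False)


def delete_provider(session: Session, tenant_id: str, provider_id: int) -> None:
    # Filtering on tenant_id is the isolation guarantee: another tenant's
    # provider id behaves exactly like a missing row.
    provider = session.query(Provider).filter_by(tenant_id=tenant_id, id=provider_id).first()
    if provider is None:
        raise ValueError("MCP tool not found")
    session.delete(provider)
    session.commit()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add_all([Provider(id=1, tenant_id="t1"), Provider(id=2, tenant_id="t2")])
session.commit()

try:
    delete_provider(session, "t2", 1)  # wrong tenant: must not touch t1's row
except ValueError:
    pass
assert session.query(Provider).filter_by(id=1).first() is not None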
@@ -1185,9 +1158,7 @@ class TestMCPToolManageService: mcp_provider.authed = False mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the provider controller and encryption with ( @@ -1202,9 +1173,8 @@ class TestMCPToolManageService: mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider_credentials( provider_id=mcp_provider.id, tenant_id=tenant.id, @@ -1213,7 +1183,7 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True assert mcp_provider.updated_at is not None @@ -1225,7 +1195,7 @@ class TestMCPToolManageService: assert "new_key" in credentials def test_update_mcp_provider_credentials_not_authed( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of MCP provider credentials when not authenticated. @@ -1249,9 +1219,7 @@ class TestMCPToolManageService: mcp_provider.authed = True mcp_provider.tools = '[{"name": "test_tool"}]' - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the provider controller and encryption with ( @@ -1266,9 +1234,8 @@ class TestMCPToolManageService: mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider_credentials( provider_id=mcp_provider.id, tenant_id=tenant.id, @@ -1277,12 +1244,14 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is False assert mcp_provider.tools == "[]" assert mcp_provider.updated_at is not None - def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_re_connect_mcp_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful reconnection to MCP provider. @@ -1343,7 +1312,9 @@ class TestMCPToolManageService: sse_read_timeout=mcp_provider.sse_read_timeout, ) - def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_re_connect_mcp_provider_auth_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test reconnection to MCP provider when authentication fails. @@ -1385,7 +1356,7 @@ class TestMCPToolManageService: assert result.encrypted_credentials == "{}" def test_re_connect_mcp_provider_connection_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test reconnection to MCP provider when connection fails. 
diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index fa13790942..f3736333ea 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject @@ -27,7 +28,7 @@ class TestToolTransformService: } def _create_test_tool_provider( - self, db_session_with_containers, mock_external_service_dependencies, provider_type="api" + self, db_session_with_containers: Session, mock_external_service_dependencies, provider_type="api" ): """ Helper method to create a test tool provider for testing. @@ -89,14 +90,12 @@ class TestToolTransformService: else: raise ValueError(f"Unknown provider type: {provider_type}") - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() return provider - def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_plugin_icon_url_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful plugin icon URL generation. @@ -126,7 +125,7 @@ class TestToolTransformService: assert result == expected_url def test_get_plugin_icon_url_with_empty_console_url( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test plugin icon URL generation when CONSOLE_API_URL is empty. @@ -156,7 +155,7 @@ class TestToolTransformService: assert result == expected_url def test_get_tool_provider_icon_url_builtin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for builtin providers. @@ -194,7 +193,7 @@ class TestToolTransformService: assert result == expected_encoded def test_get_tool_provider_icon_url_api_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for API providers. @@ -220,7 +219,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_api_invalid_json( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tool provider icon URL generation for API providers with invalid JSON. @@ -246,7 +245,7 @@ class TestToolTransformService: assert result["content"] == "😁" or result["content"] == "\ud83d\ude01" def test_get_tool_provider_icon_url_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for workflow providers. 
@@ -271,7 +270,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_mcp_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for MCP providers. @@ -296,7 +295,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_unknown_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tool provider icon URL generation for unknown provider types. @@ -317,7 +316,9 @@ class TestToolTransformService: # Assert: Verify the expected outcomes assert result == "" - def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_dict_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful provider repacking with dictionary input. @@ -341,7 +342,9 @@ class TestToolTransformService: # Note: provider name may contain spaces that get URL encoded assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"] - def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_entity_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful provider repacking with ToolProviderApiEntity input. @@ -389,7 +392,7 @@ class TestToolTransformService: assert "test_icon_dark.png" in provider.icon_dark def test_repack_provider_entity_no_plugin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful provider repacking with ToolProviderApiEntity input without plugin_id. @@ -435,7 +438,9 @@ class TestToolTransformService: assert provider.icon_dark["background"] == "#252525" assert provider.icon_dark["content"] == "🔧" - def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_entity_no_dark_icon( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test provider repacking with ToolProviderApiEntity input without dark icon. @@ -477,7 +482,7 @@ class TestToolTransformService: assert provider.icon_dark == "" def test_builtin_provider_to_user_provider_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of builtin provider to user provider. @@ -545,7 +550,7 @@ class TestToolTransformService: assert result.original_credentials == {"api_key": "decrypted_key"} def test_builtin_provider_to_user_provider_plugin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of builtin provider to user provider with plugin. 
@@ -589,7 +594,7 @@ class TestToolTransformService: assert result.allow_delete is False def test_builtin_provider_to_user_provider_no_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of builtin provider to user provider without credentials. @@ -630,7 +635,9 @@ class TestToolTransformService: assert result.allow_delete is False assert result.masked_credentials == {"api_key": ""} - def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_api_provider_to_controller_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion of API provider to controller. @@ -655,10 +662,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -669,7 +674,7 @@ class TestToolTransformService: # Additional assertions would depend on the actual controller implementation def test_api_provider_to_controller_api_key_query( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of API provider to controller with api_key_query auth type. @@ -693,10 +698,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -706,7 +709,7 @@ class TestToolTransformService: assert hasattr(result, "from_db") def test_api_provider_to_controller_backward_compatibility( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of API provider to controller with backward compatibility auth types. @@ -731,10 +734,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -744,7 +745,7 @@ class TestToolTransformService: assert hasattr(result, "from_db") def test_workflow_provider_to_controller_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of workflow provider to controller. 
@@ -769,10 +770,8 @@ class TestToolTransformService: parameter_configuration="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Mock the WorkflowToolProviderController.from_db method to avoid app dependency with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db: diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 24fe5c4670..0b3c1112bd 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker from pydantic import ValidationError +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError @@ -63,7 +64,7 @@ class TestWorkflowToolManageService: "tool_transform_service": mock_tool_transform_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -119,14 +120,12 @@ class TestWorkflowToolManageService: conversation_variables=[], ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Update app to reference the workflow app.workflow_id = workflow.id - db.session.commit() + db_session_with_containers.commit() return app, account, workflow @@ -153,7 +152,9 @@ class TestWorkflowToolManageService: ), ] - def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful workflow tool creation with valid parameters. @@ -198,11 +199,10 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db # Check if workflow tool provider was created created_tool_provider = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -230,7 +230,7 @@ class TestWorkflowToolManageService: ].workflow_provider_to_controller.assert_called_once() def test_create_workflow_tool_duplicate_name_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when name already exists. 
@@ -280,10 +280,9 @@ class TestWorkflowToolManageService: assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value) # Verify only one tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -293,7 +292,7 @@ class TestWorkflowToolManageService: assert tool_count == 1 def test_create_workflow_tool_invalid_app_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app does not exist. @@ -331,10 +330,9 @@ class TestWorkflowToolManageService: assert f"App {non_existent_app_id} not found" in str(exc_info.value) # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -344,7 +342,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_invalid_parameters_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when parameters are invalid. @@ -387,10 +385,9 @@ class TestWorkflowToolManageService: assert "validation error" in str(exc_info.value).lower() # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -400,7 +397,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_duplicate_app_id_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app_id already exists. @@ -450,10 +447,9 @@ class TestWorkflowToolManageService: assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value) # Verify only one tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -463,7 +459,7 @@ class TestWorkflowToolManageService: assert tool_count == 1 def test_create_workflow_tool_workflow_not_found_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app has no workflow. 
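The two duplicate-creation tests encode a single rule: within a tenant, a workflow tool clashes when either its name or its app_id is already taken, and the failed create must leave the table untouched. A dependency-free sketch of the rule; the dict store is illustrative only, since the real check presumably runs as a SQLAlchemy query inside WorkflowToolManageService:

# Illustrative only: a plain list keeps the uniqueness rule easy to see.
def check_unique(existing: list[dict], tenant_id: str, name: str, app_id: str) -> None:
    for tool in existing:
        if tool["tenant_id"] != tenant_id:
            continue  # other tenants never clash
        if tool["name"] == name or tool["app_id"] == app_id:
            raise ValueError(f"Tool with name {name} or app_id {app_id} already exists")


tools = [{"tenant_id": "t1", "name": "demo", "app_id": "app-1"}]
check_unique(tools, "t1", "fresh", "app-2")  # passes: both fields are new
try:
    check_unique(tools, "t1", "demo", "app-2")  # same name within tenant clashes
except ValueError as err:
    assert "already exists" in str(err)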
@@ -481,10 +477,9 @@ class TestWorkflowToolManageService: ) # Remove workflow reference from app - from extensions.ext_database import db app.workflow_id = None - db.session.commit() + db_session_with_containers.commit() # Attempt to create workflow tool for app without workflow tool_parameters = self._create_test_workflow_tool_parameters() @@ -505,7 +500,7 @@ class TestWorkflowToolManageService: # Verify no workflow tool was created tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -515,7 +510,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when workflow contains human input nodes. @@ -558,10 +553,8 @@ class TestWorkflowToolManageService: assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - from extensions.ext_database import db - tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -570,7 +563,9 @@ class TestWorkflowToolManageService: assert tool_count == 0 - def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful workflow tool update with valid parameters. @@ -603,10 +598,9 @@ class TestWorkflowToolManageService: ) # Get the created tool - from extensions.ext_database import db created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -641,7 +635,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify database state was updated - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label @@ -658,7 +652,7 @@ class TestWorkflowToolManageService: mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() def test_update_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool update fails when workflow contains human input nodes. 
@@ -689,10 +683,8 @@ class TestWorkflowToolManageService: parameters=initial_tool_parameters, ) - from extensions.ext_database import db - created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -712,7 +704,7 @@ class TestWorkflowToolManageService: ] } ) - db.session.commit() + db_session_with_containers.commit() with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: WorkflowToolManageService.update_workflow_tool( @@ -728,10 +720,12 @@ class TestWorkflowToolManageService: assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool.name == original_name - def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_workflow_tool_not_found_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test workflow tool update fails when tool does not exist. @@ -768,10 +762,9 @@ class TestWorkflowToolManageService: assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value) # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -781,7 +774,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_update_workflow_tool_same_name_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool update succeeds when keeping the same name. @@ -813,10 +806,9 @@ class TestWorkflowToolManageService: ) # Get the created tool - from extensions.ext_database import db created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -840,12 +832,12 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify tool still exists with the same name - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool.name == first_tool_name assert created_tool.updated_at is not None def test_create_workflow_tool_with_file_parameter_default( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation with FILE parameter having a file object as default. @@ -916,7 +908,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} def test_create_workflow_tool_with_files_parameter_default( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation with FILES (Array[File]) parameter having file objects as default. 
@@ -991,7 +983,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} def test_create_workflow_tool_db_commit_before_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that database commit happens before validation, causing DB pollution on validation failure. @@ -1035,10 +1027,9 @@ class TestWorkflowToolManageService: # Verify the tool was NOT created in database # This is the expected behavior (no pollution) - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name, diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 0c2ccaa051..8c007877fd 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -79,7 +80,7 @@ class TestWorkflowConverter: mock_config.app_model_config_dict = {} return mock_config - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -100,18 +101,16 @@ class TestWorkflowConverter: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -122,15 +121,17 @@ class TestWorkflowConverter: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account): + def _create_test_app( + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant, account + ): """ Helper method to create a test app for testing. 
@@ -163,10 +164,8 @@ class TestWorkflowConverter: updated_by=account.id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -177,16 +176,16 @@ class TestWorkflowConverter: created_by=account.id, updated_by=account.id, ) - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Link app model config to app app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() return app - def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_workflow_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful conversion of app to workflow. @@ -225,19 +224,18 @@ class TestWorkflowConverter: assert new_app.created_by == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(new_app) + db_session_with_containers.refresh(new_app) assert new_app.id is not None # Verify workflow was created - workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first() + workflow = db_session_with_containers.query(Workflow).where(Workflow.app_id == new_app.id).first() assert workflow is not None assert workflow.tenant_id == app.tenant_id assert workflow.type == "chat" def test_convert_to_workflow_without_app_model_config_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when app model config is missing. @@ -270,16 +268,14 @@ class TestWorkflowConverter: updated_by=account.id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling workflow_converter = WorkflowConverter() # Check initial state - initial_workflow_count = db.session.query(Workflow).count() + initial_workflow_count = db_session_with_containers.query(Workflow).count() with pytest.raises(ValueError, match="App model config is required"): workflow_converter.convert_to_workflow( @@ -294,12 +290,12 @@ class TestWorkflowConverter: # Verify database state remains unchanged # The workflow creation happens in convert_app_model_config_to_workflow # which is called before the app_model_config check, so we need to clean up - db.session.rollback() - final_workflow_count = db.session.query(Workflow).count() + db_session_with_containers.rollback() + final_workflow_count = db_session_with_containers.query(Workflow).count() assert final_workflow_count == initial_workflow_count def test_convert_app_model_config_to_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of app model config to workflow. 
@@ -356,16 +352,17 @@ class TestWorkflowConverter: assert answer_node["id"] == "answer" # Verify database state - from extensions.ext_database import db - db.session.refresh(workflow) + db_session_with_containers.refresh(workflow) assert workflow.id is not None # Verify features were set features = json.loads(workflow._features) if workflow._features else {} assert isinstance(features, dict) - def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_start_node_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion to start node. @@ -410,7 +407,9 @@ class TestWorkflowConverter: assert second_variable["label"] == "Number Input" assert second_variable["type"] == "number" - def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_http_request_node_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion to HTTP request node. @@ -436,10 +435,8 @@ class TestWorkflowConverter: api_endpoint="https://api.example.com/test", ) - from extensions.ext_database import db - - db.session.add(api_based_extension) - db.session.commit() + db_session_with_containers.add(api_based_extension) + db_session_with_containers.commit() # Mock encrypter mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key" @@ -489,7 +486,7 @@ class TestWorkflowConverter: assert external_data_variable_node_mapping["external_data"] == code_node["id"] def test_convert_to_knowledge_retrieval_node_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion to knowledge retrieval node. diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 8bb536c34a..efeb29cf20 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType -from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -31,7 +31,9 @@ class TestAddDocumentToIndexTask: "index_processor": mock_processor, } - def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_dataset_and_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test dataset and document for testing. 
@@ -51,15 +53,15 @@ class TestAddDocumentToIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -68,8 +70,8 @@ class TestAddDocumentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -81,8 +83,8 @@ class TestAddDocumentToIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document document = Document( @@ -99,15 +101,15 @@ class TestAddDocumentToIndexTask: enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property works correctly - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, document - def _create_test_segments(self, db_session_with_containers, document, dataset): + def _create_test_segments(self, db_session_with_containers: Session, document, dataset): """ Helper method to create test document segments. @@ -138,13 +140,15 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() return segments - def test_add_document_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_add_document_to_index_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful document indexing with paragraph index type. @@ -180,9 +184,9 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -191,7 +195,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_different_index_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing with different index types. 
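One invariant threads through this whole file: add_document_to_index_task must clear the document_<id>_indexing Redis key on success and on error alike, which is why each test seeds the key and ends with an exists() == 0 assertion. A sketch of that contract, using fakeredis (an assumed stand-in here) so the example runs without a server:

# fakeredis is an assumed substitute so this sketch runs without Redis;
# the tests themselves use the app's real redis_client.
import fakeredis

redis_client = fakeredis.FakeStrictRedis()

document_id = "doc-123"
indexing_cache_key = f"document_{document_id}_indexing"
redis_client.set(indexing_cache_key, "processing", ex=300)

try:
    pass  # the indexing task runs here and may raise
finally:
    # Success or failure, the task is expected to release the cache key.
    redis_client.delete(indexing_cache_key)

assert redis_client.exists(indexing_cache_key) == 0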
@@ -209,10 +213,10 @@ class TestAddDocumentToIndexTask: # Update document to use different index type document.doc_form = IndexStructureType.QA_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -237,9 +241,9 @@ class TestAddDocumentToIndexTask: assert len(documents) == 3 # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -248,7 +252,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent document. @@ -275,7 +279,7 @@ class TestAddDocumentToIndexTask: # because indexing_cache_key is not defined in that case def test_add_document_to_index_invalid_indexing_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of document with invalid indexing status. @@ -294,7 +298,7 @@ class TestAddDocumentToIndexTask: # Set invalid indexing status document.indexing_status = "processing" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task add_document_to_index_task(document.id) @@ -304,7 +308,7 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_add_document_to_index_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when document's dataset doesn't exist. @@ -326,14 +330,14 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Delete the dataset to simulate dataset not found scenario - db.session.delete(dataset) - db.session.commit() + db_session_with_containers.delete(dataset) + db_session_with_containers.commit() # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False assert document.indexing_status == "error" assert document.error is not None @@ -348,7 +352,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_parent_child_structure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing with parent-child structure. 
@@ -367,10 +371,10 @@ class TestAddDocumentToIndexTask: # Update document to use parent-child index type document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments with mock child chunks segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -413,9 +417,9 @@ class TestAddDocumentToIndexTask: assert len(doc.children) == 2 # Each document has 2 children # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -424,7 +428,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_already_enabled_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing when segments are already enabled. @@ -459,10 +463,10 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -488,7 +492,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_auto_disable_log_deletion( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that auto disable logs are properly deleted during indexing. 
@@ -515,10 +519,10 @@ class TestAddDocumentToIndexTask: document_id=document.id, ) log_entry.id = str(fake.uuid4()) - db.session.add(log_entry) + db_session_with_containers.add(log_entry) auto_disable_logs.append(log_entry) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -526,7 +530,9 @@ class TestAddDocumentToIndexTask: # Verify logs exist before processing existing_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() + db_session_with_containers.query(DatasetAutoDisableLog) + .where(DatasetAutoDisableLog.document_id == document.id) + .all() ) assert len(existing_logs) == 2 @@ -535,7 +541,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify auto disable logs were deleted remaining_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() + db_session_with_containers.query(DatasetAutoDisableLog) + .where(DatasetAutoDisableLog.document_id == document.id) + .all() ) assert len(remaining_logs) == 0 @@ -547,14 +555,14 @@ class TestAddDocumentToIndexTask: # Verify segments were enabled for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True # Verify redis cache was cleared assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_general_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test general exception handling during indexing process. @@ -584,7 +592,7 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False assert document.indexing_status == "error" assert document.error is not None @@ -593,14 +601,14 @@ class TestAddDocumentToIndexTask: # Verify segments were not enabled due to error for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False # Should remain disabled due to error # Verify redis cache was still cleared despite error assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_segment_filtering_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment filtering with various edge cases. 
@@ -638,7 +646,7 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment1) + db_session_with_containers.add(segment1) segments.append(segment1) # Segment 2: Should be processed (enabled=True, status="completed") @@ -658,7 +666,7 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment2) + db_session_with_containers.add(segment2) segments.append(segment2) # Segment 3: Should NOT be processed (enabled=False, status="processing") @@ -677,7 +685,7 @@ class TestAddDocumentToIndexTask: status="processing", # Not completed created_by=document.created_by, ) - db.session.add(segment3) + db_session_with_containers.add(segment3) segments.append(segment3) # Segment 4: Should be processed (enabled=False, status="completed") @@ -696,10 +704,10 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment4) + db_session_with_containers.add(segment4) segments.append(segment4) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -728,11 +736,11 @@ class TestAddDocumentToIndexTask: assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3 # Verify database state changes - db.session.refresh(document) - db.session.refresh(segment1) - db.session.refresh(segment2) - db.session.refresh(segment3) - db.session.refresh(segment4) + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(segment1) + db_session_with_containers.refresh(segment2) + db_session_with_containers.refresh(segment3) + db_session_with_containers.refresh(segment4) # All segments should be enabled because the task updates ALL segments for the document assert segment1.enabled is True @@ -744,7 +752,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_comprehensive_error_scenarios( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test comprehensive error scenarios and recovery. 
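Taken together, the edge-case assertions imply the selection rule: segments with status "completed" are indexed in position order regardless of their enabled flag, non-completed segments are skipped, and the task then enables every segment of the document. A reconstruction inferred from those assertions, not copied from the task's source:

# Inferred from the test's assertions; the real task filters DocumentSegment
# rows with a SQLAlchemy query rather than plain dicts.
segments = [
    {"doc_id": "node_0", "position": 0, "enabled": True, "status": "completed"},
    {"doc_id": "node_1", "position": 1, "enabled": True, "status": "completed"},
    {"doc_id": "node_2", "position": 2, "enabled": False, "status": "processing"},
    {"doc_id": "node_3", "position": 3, "enabled": False, "status": "completed"},
]

to_index = sorted(
    (s for s in segments if s["status"] == "completed"),  # enabled is ignored here
    key=lambda s: s["position"],
)
assert [s["doc_id"] for s in to_index] == ["node_0", "node_1", "node_3"]

for s in segments:  # the task then enables every segment of the document
    s["enabled"] = True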
@@ -779,7 +787,7 @@ class TestAddDocumentToIndexTask: document.indexing_status = "completed" document.error = None document.disabled_at = None - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -789,7 +797,7 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify consistent error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False, f"Document should be disabled for {error_name}" assert document.indexing_status == "error", f"Document status should be error for {error_name}" assert document.error is not None, f"Error should be recorded for {error_name}" @@ -798,7 +806,7 @@ class TestAddDocumentToIndexTask: # Verify segments remain disabled due to error for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False, f"Segments should remain disabled for {error_name}" # Verify redis cache was still cleared despite error diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index f94c5b19e6..ec789418a8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,8 +11,8 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -49,7 +49,7 @@ class TestBatchCleanDocumentTask: "get_image_ids": mock_get_image_ids, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -69,16 +69,16 @@ class TestBatchCleanDocumentTask: status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -87,15 +87,15 @@ class TestBatchCleanDocumentTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account - def _create_test_dataset(self, db_session_with_containers, account): + def _create_test_dataset(self, db_session_with_containers: Session, account): """ Helper method to create a test dataset for testing. 
@@ -119,12 +119,12 @@ class TestBatchCleanDocumentTask: embedding_model_provider="openai", ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, dataset, account): + def _create_test_document(self, db_session_with_containers: Session, dataset, account): """ Helper method to create a test document for testing. @@ -153,12 +153,12 @@ class TestBatchCleanDocumentTask: doc_form="text_model", ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def _create_test_document_segment(self, db_session_with_containers, document, account): + def _create_test_document_segment(self, db_session_with_containers: Session, document, account): """ Helper method to create a test document segment for testing. @@ -186,12 +186,12 @@ class TestBatchCleanDocumentTask: status="completed", ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment - def _create_test_upload_file(self, db_session_with_containers, account): + def _create_test_upload_file(self, db_session_with_containers: Session, account): """ Helper method to create a test upload file for testing. @@ -220,13 +220,13 @@ class TestBatchCleanDocumentTask: used=False, ) - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file def test_batch_clean_document_task_successful_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful cleanup of documents with segments and files. @@ -245,7 +245,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -261,18 +261,18 @@ class TestBatchCleanDocumentTask: # The task should have processed the segment and cleaned up the database # Verify database cleanup - db.session.commit() # Ensure all changes are committed + db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_with_image_files( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup of documents containing image references. 
@@ -300,8 +300,8 @@ class TestBatchCleanDocumentTask: status="completed", ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Store original IDs for verification segment_id = segment.id @@ -313,17 +313,17 @@ class TestBatchCleanDocumentTask: ) # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Verify that the task completed successfully by checking the log output # The task should have processed the segment and cleaned up the database def test_batch_clean_document_task_no_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when document has no segments. @@ -339,7 +339,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -354,21 +354,21 @@ class TestBatchCleanDocumentTask: # Since there are no segments, the task should handle this gracefully # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when dataset is not found. 
@@ -386,8 +386,8 @@ class TestBatchCleanDocumentTask: dataset_id = dataset.id # Delete the dataset to simulate not found scenario - db.session.delete(dataset) - db.session.commit() + db_session_with_containers.delete(dataset) + db_session_with_containers.commit() # Execute the task with non-existent dataset batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) @@ -399,14 +399,14 @@ class TestBatchCleanDocumentTask: mock_external_service_dependencies["storage"].delete.assert_not_called() # Verify that no database cleanup occurred - db.session.commit() + db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db.session.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when storage operations fail. @@ -423,7 +423,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -442,18 +442,18 @@ class TestBatchCleanDocumentTask: # The task should continue processing even when storage operations fail # Verify database cleanup still occurred despite storage failure - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_multiple_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup of multiple documents in a single batch operation. 
@@ -482,7 +482,7 @@ class TestBatchCleanDocumentTask: segments.append(segment) upload_files.append(upload_file) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_ids = [doc.id for doc in documents] @@ -498,20 +498,20 @@ class TestBatchCleanDocumentTask: # The task should process all documents and clean up all associated resources # Verify database cleanup for all resources - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup with different document form types. @@ -527,12 +527,12 @@ class TestBatchCleanDocumentTask: for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) - db.session.commit() + db_session_with_containers.commit() document = self._create_test_document(db_session_with_containers, dataset, account) # Update document doc_form document.doc_form = doc_form - db.session.commit() + db_session_with_containers.commit() segment = self._create_test_document_segment(db_session_with_containers, document, account) @@ -549,20 +549,20 @@ class TestBatchCleanDocumentTask: # The task should handle different document forms correctly # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None except Exception as e: # If the task fails due to external service issues (e.g., plugin daemon), # we should still verify that the database state is consistent # This is a common scenario in test environments where external services may not be available - db.session.commit() + db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -572,7 +572,7 @@ class TestBatchCleanDocumentTask: pass def test_batch_clean_document_task_large_batch_performance( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup performance with a large batch of documents. 
@@ -604,7 +604,7 @@ class TestBatchCleanDocumentTask: segments.append(segment) upload_files.append(upload_file) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_ids = [doc.id for doc in documents] @@ -629,20 +629,20 @@ class TestBatchCleanDocumentTask: # The task should handle large batches efficiently # Verify database cleanup for all resources - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test full integration with real database operations. @@ -683,12 +683,12 @@ class TestBatchCleanDocumentTask: # Add all to database for segment in segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Verify initial state - assert db.session.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db.session.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 + assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None # Store original IDs for verification document_id = document.id @@ -704,17 +704,17 @@ class TestBatchCleanDocumentTask: # The task should process all segments and clean up all associated resources # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None # Verify final database state - assert db.session.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db.session.query(UploadFile).filter_by(id=file_id).first() is None + assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 + assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 2156743c17..a2324979db 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ 
b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -29,20 +30,19 @@ class TestBatchCreateSegmentToIndexTask: """Integration tests for batch_create_segment_to_index_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" - from extensions.ext_database import db from extensions.ext_redis import redis_client # Clear all test data - db.session.query(DocumentSegment).delete() - db.session.query(Document).delete() - db.session.query(Dataset).delete() - db.session.query(UploadFile).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + db_session_with_containers.query(DocumentSegment).delete() + db_session_with_containers.query(Document).delete() + db_session_with_containers.query(Dataset).delete() + db_session_with_containers.query(UploadFile).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() # Clear Redis cache redis_client.flushdb() @@ -75,7 +75,7 @@ class TestBatchCreateSegmentToIndexTask: "embedding_model": mock_embedding_model, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -95,18 +95,16 @@ class TestBatchCreateSegmentToIndexTask: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -115,15 +113,15 @@ class TestBatchCreateSegmentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, db_session_with_containers, account, tenant): + def _create_test_dataset(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test dataset for testing. @@ -148,14 +146,12 @@ class TestBatchCreateSegmentToIndexTask: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, account, tenant, dataset): + def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset): """ Helper method to create a test document for testing. 
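The cleanup_database fixture above deletes child tables before their parents. The ordering matters because Query.delete() emits a bulk DELETE and skips ORM relationship cascades, so on a foreign-key-enforcing backend (Postgres in the real suite) deleting parents first would fail. A self-contained sketch of the principle, with toy models standing in for the Dify ones:

```python
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Parent(Base):
    __tablename__ = "parent"
    id: Mapped[int] = mapped_column(primary_key=True)


class Child(Base):
    __tablename__ = "child"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with sessionmaker(bind=engine)() as session:
    session.add_all([Parent(id=1), Child(id=1, parent_id=1)])
    session.commit()
    # Child first, then Parent: the same ordering the fixture uses.
    session.query(Child).delete()
    session.query(Parent).delete()
    session.commit()
```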
@@ -186,14 +182,12 @@ class TestBatchCreateSegmentToIndexTask: word_count=0, ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def _create_test_upload_file(self, db_session_with_containers, account, tenant): + def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test upload file for testing. @@ -221,10 +215,8 @@ class TestBatchCreateSegmentToIndexTask: used=False, ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file @@ -252,7 +244,7 @@ class TestBatchCreateSegmentToIndexTask: return csv_content def test_batch_create_segment_to_index_task_success_text_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful batch creation of segments for text model documents. @@ -293,11 +285,10 @@ class TestBatchCreateSegmentToIndexTask: ) # Verify results - from extensions.ext_database import db # Check that segments were created segments = ( - db.session.query(DocumentSegment) + db_session_with_containers.query(DocumentSegment) .filter_by(document_id=document.id) .order_by(DocumentSegment.position) .all() @@ -316,7 +307,7 @@ class TestBatchCreateSegmentToIndexTask: assert segment.answer is None # text_model doesn't have answers # Check that document word count was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count > 0 # Verify vector service was called @@ -331,7 +322,7 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"completed" def test_batch_create_segment_to_index_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when dataset does not exist. @@ -370,17 +361,16 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created (since dataset doesn't exist) - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify no documents were modified - documents = db.session.query(Document).all() + documents = db_session_with_containers.query(Document).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when document does not exist. 
@@ -419,18 +409,17 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) - db.session.refresh(dataset) - segments_for_dataset = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + db_session_with_containers.refresh(dataset) + segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when document is not available for indexing. @@ -498,11 +487,9 @@ class TestBatchCreateSegmentToIndexTask: ), ] - from extensions.ext_database import db - for document in test_cases: - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Test each unavailable document for document in test_cases: @@ -524,11 +511,11 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db.session.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when upload file does not exist. @@ -567,17 +554,16 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify document remains unchanged - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count == 0 def test_batch_create_segment_to_index_task_empty_csv_file( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when CSV file is empty. @@ -619,17 +605,16 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify document remains unchanged - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count == 0 def test_batch_create_segment_to_index_task_position_calculation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test proper position calculation for segments when existing segments exist. 
@@ -664,11 +649,9 @@ class TestBatchCreateSegmentToIndexTask: ) existing_segments.append(segment) - from extensions.ext_database import db - for segment in existing_segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Create CSV content csv_content = self._create_test_csv_content("text_model") @@ -695,7 +678,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions all_segments = ( - db.session.query(DocumentSegment) + db_session_with_containers.query(DocumentSegment) .filter_by(document_id=document.id) .order_by(DocumentSegment.position) .all() @@ -716,7 +699,7 @@ class TestBatchCreateSegmentToIndexTask: assert segment.completed_at is not None # Check that document word count was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count > 0 # Verify vector service was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index cd99b2965f..8eb881258a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -37,7 +38,7 @@ class TestCleanDatasetTask: """Integration tests for clean_dataset_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" from extensions.ext_redis import redis_client @@ -82,7 +83,7 @@ class TestCleanDatasetTask: "index_processor": mock_index_processor, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -127,7 +128,7 @@ class TestCleanDatasetTask: return account, tenant - def _create_test_dataset(self, db_session_with_containers, account, tenant): + def _create_test_dataset(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test dataset for testing. @@ -157,7 +158,7 @@ class TestCleanDatasetTask: return dataset - def _create_test_document(self, db_session_with_containers, account, tenant, dataset): + def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset): """ Helper method to create a test document for testing. @@ -194,7 +195,7 @@ class TestCleanDatasetTask: return document - def _create_test_segment(self, db_session_with_containers, account, tenant, dataset, document): + def _create_test_segment(self, db_session_with_containers: Session, account, tenant, dataset, document): """ Helper method to create a test document segment for testing. @@ -230,7 +231,7 @@ class TestCleanDatasetTask: return segment - def _create_test_upload_file(self, db_session_with_containers, account, tenant): + def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test upload file for testing. 
@@ -264,7 +265,7 @@ class TestCleanDatasetTask: return upload_file def test_clean_dataset_task_success_basic_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful basic dataset cleanup with minimal data. @@ -325,7 +326,7 @@ class TestCleanDatasetTask: mock_storage.delete.assert_not_called() def test_clean_dataset_task_success_with_documents_and_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful dataset cleanup with documents and segments. @@ -433,7 +434,7 @@ class TestCleanDatasetTask: assert mock_storage.delete.call_count == 3 def test_clean_dataset_task_success_with_invalid_doc_form( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful dataset cleanup with invalid doc_form handling. @@ -493,7 +494,7 @@ class TestCleanDatasetTask: assert mock_factory.call_count == 4 def test_clean_dataset_task_error_handling_and_rollback( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling and rollback mechanism when database operations fail. @@ -542,7 +543,7 @@ class TestCleanDatasetTask: # This demonstrates the resilience of the cleanup process def test_clean_dataset_task_with_image_file_references( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup with image file references in document segments. @@ -634,7 +635,7 @@ class TestCleanDatasetTask: mock_get_image_ids.assert_called_once() def test_clean_dataset_task_performance_with_large_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup performance with large amounts of data. @@ -704,11 +705,9 @@ class TestCleanDatasetTask: binding.created_at = datetime.now() bindings.append(binding) - from extensions.ext_database import db - - db.session.add_all(metadata_items) - db.session.add_all(bindings) - db.session.commit() + db_session_with_containers.add_all(metadata_items) + db_session_with_containers.add_all(bindings) + db_session_with_containers.commit() # Measure cleanup performance import time @@ -772,7 +771,7 @@ class TestCleanDatasetTask: print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds") def test_clean_dataset_task_storage_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup when storage operations fail. @@ -838,7 +837,7 @@ class TestCleanDatasetTask: # consistency in the database def test_clean_dataset_task_edge_cases_and_boundary_conditions( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup with edge cases and boundary conditions. 
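A second pattern runs through these helpers: rather than reaching for the db global, every _create_test_* method now receives the session explicitly, so each helper works inside the same transaction the test observes. A toy version of that shape (the model and helper below are stand-ins, not the Dify ones):

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Account(Base):  # toy stand-in for the real Account model
    __tablename__ = "account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


def _create_test_account(session: Session, name: str) -> Account:
    # Same shape as the patched helpers: take the session, add, commit,
    # and return the persisted object.
    account = Account(name=name)
    session.add(account)
    session.commit()
    return account


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with sessionmaker(bind=engine)() as session:
    account = _create_test_account(session, "tester")
    assert session.get(Account, account.id) is not None
```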
diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index 8785c948d1..ab9e5b639a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -13,8 +13,8 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -34,7 +34,7 @@ class TestDisableSegmentFromIndexTask: mock_processor.clean.return_value = None yield mock_processor - def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]: + def _create_test_account_and_tenant(self, db_session_with_containers: Session) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. @@ -53,8 +53,8 @@ class TestDisableSegmentFromIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -62,8 +62,8 @@ class TestDisableSegmentFromIndexTask: status="normal", plan="basic", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -72,15 +72,15 @@ class TestDisableSegmentFromIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset: + def _create_test_dataset(self, db_session_with_containers: Session, tenant: Tenant, account: Account) -> Dataset: """ Helper method to create a test dataset. @@ -101,13 +101,18 @@ class TestDisableSegmentFromIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset def _create_test_document( - self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model" + self, + db_session_with_containers: Session, + dataset: Dataset, + tenant: Tenant, + account: Account, + doc_form: str = "text_model", ) -> Document: """ Helper method to create a test document. 
@@ -140,13 +145,14 @@ class TestDisableSegmentFromIndexTask: tokens=500, completed_at=datetime.now(UTC), ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document def _create_test_segment( self, + db_session_with_containers: Session, document: Document, dataset: Dataset, tenant: Tenant, @@ -185,12 +191,12 @@ class TestDisableSegmentFromIndexTask: created_by=account.id, completed_at=datetime.now(UTC) if status == "completed" else None, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment - def test_disable_segment_success(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_success(self, db_session_with_containers: Session, mock_index_processor): """ Test successful segment disabling from index. @@ -202,9 +208,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Set up Redis cache indexing_cache_key = f"segment_{segment.id}_indexing" @@ -226,10 +232,10 @@ class TestDisableSegmentFromIndexTask: assert redis_client.get(indexing_cache_key) is None # Verify segment is still in database - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.id is not None - def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_not_found(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment is not found. @@ -251,7 +257,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_not_completed(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment is not in completed status. 
@@ -262,9 +268,11 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with non-completed segment account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment( + db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True + ) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -275,7 +283,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_no_dataset(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment has no associated dataset. @@ -286,13 +294,13 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Manually remove dataset association segment.dataset_id = "00000000-0000-0000-0000-000000000000" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -303,7 +311,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_no_document(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment has no associated document. 
@@ -314,13 +322,13 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Manually remove document association segment.document_id = "00000000-0000-0000-0000-000000000000" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -331,7 +339,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_disabled(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when document is disabled. @@ -342,12 +350,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with disabled document account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.enabled = False - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -358,7 +366,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_archived(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when document is archived. 
@@ -369,12 +377,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with archived document account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.archived = True - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -385,7 +393,9 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_indexing_not_completed( + self, db_session_with_containers: Session, mock_index_processor + ): """ Test handling when document indexing is not completed. @@ -396,12 +406,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with incomplete indexing account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.indexing_status = "indexing" - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -412,7 +422,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_index_processor_exception(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when index processor raises an exception. 
@@ -424,9 +434,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Set up Redis cache indexing_cache_key = f"segment_{segment.id}_indexing" @@ -449,13 +459,13 @@ class TestDisableSegmentFromIndexTask: assert call_args[0][1] == [segment.index_node_id] # Check index node IDs # Verify segment was re-enabled - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True # Verify Redis cache was still cleared assert redis_client.get(indexing_cache_key) is None - def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_different_doc_forms(self, db_session_with_containers: Session, mock_index_processor): """ Test disabling segments with different document forms. @@ -470,9 +480,11 @@ class TestDisableSegmentFromIndexTask: for doc_form in doc_forms: # Arrange: Create test data for each form account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account, doc_form=doc_form) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document( + db_session_with_containers, dataset, tenant, account, doc_form=doc_form + ) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Reset mock for each iteration mock_index_processor.reset_mock() @@ -489,7 +501,7 @@ class TestDisableSegmentFromIndexTask: assert call_args[0][0].id == dataset.id # Check dataset ID assert call_args[0][1] == [segment.index_node_id] # Check index node IDs - def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_redis_cache_handling(self, db_session_with_containers: Session, mock_index_processor): """ Test Redis cache handling during segment disabling. 
@@ -500,9 +512,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Test with cache present indexing_cache_key = f"segment_{segment.id}_indexing" @@ -517,13 +529,13 @@ class TestDisableSegmentFromIndexTask: assert redis_client.get(indexing_cache_key) is None # Test with no cache present - segment2 = self._create_test_segment(document, dataset, tenant, account) + segment2 = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) result2 = disable_segment_from_index_task(segment2.id) # Assert: Verify task still works without cache assert result2 is None - def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_performance_timing(self, db_session_with_containers: Session, mock_index_processor): """ Test performance timing of segment disabling task. @@ -534,9 +546,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task and measure time start_time = time.perf_counter() @@ -548,7 +560,9 @@ class TestDisableSegmentFromIndexTask: execution_time = end_time - start_time assert execution_time < 5.0 # Should complete within 5 seconds - def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_database_session_management( + self, db_session_with_containers: Session, mock_index_processor + ): """ Test database session management during task execution. 
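All of these cache assertions rest on one convention: in-flight indexing is marked with a segment_{id}_indexing (or document_{id}_indexing) key that the task deletes when it finishes. A sketch of that round trip, using fakeredis as a stand-in so it runs without a server (an assumption; the real suite talks to redis_client directly):

```python
import fakeredis

redis_client = fakeredis.FakeRedis()  # stand-in for extensions.ext_redis

segment_id = "00000000-0000-0000-0000-000000000001"
indexing_cache_key = f"segment_{segment_id}_indexing"

# The caller marks the segment as being (re)indexed with a short TTL...
redis_client.setex(indexing_cache_key, 600, 1)
assert redis_client.exists(indexing_cache_key) == 1

# ...and the task deletes the key when it finishes, which is exactly what
# the tests assert via get(...) is None / exists(...) == 0.
redis_client.delete(indexing_cache_key)
assert redis_client.get(indexing_cache_key) is None
```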
@@ -559,9 +573,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -570,10 +584,10 @@ class TestDisableSegmentFromIndexTask: assert result is None # Verify segment is still accessible (session was properly managed) - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.id is not None - def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_concurrent_execution(self, db_session_with_containers: Session, mock_index_processor): """ Test concurrent execution of segment disabling tasks. @@ -584,12 +598,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create multiple test segments account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) segments = [] for i in range(3): - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) segments.append(segment) # Act: Execute tasks concurrently (simulated) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index a93a80e231..8f47b48ae2 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy.orm import Session from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument @@ -31,7 +32,7 @@ class TestDisableSegmentsFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test account with realistic data. @@ -79,7 +80,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers, account, fake=None): + def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None): """ Helper method to create a test dataset with realistic data. 
@@ -113,7 +114,7 @@ class TestDisableSegmentsFromIndexTask: return dataset - def _create_test_document(self, db_session_with_containers, dataset, account, fake=None): + def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None): """ Helper method to create a test document with realistic data. @@ -158,7 +159,9 @@ class TestDisableSegmentsFromIndexTask: return document - def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None): + def _create_test_segments( + self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None + ): """ Helper method to create test document segments with realistic data. @@ -210,7 +213,7 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None): + def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None): """ Helper method to create a dataset process rule. @@ -239,14 +242,12 @@ class TestDisableSegmentsFromIndexTask: process_rule.created_by = dataset.created_by process_rule.updated_by = dataset.updated_by - from extensions.ext_database import db - - db.session.add(process_rule) - db.session.commit() + db_session_with_containers.add(process_rule) + db_session_with_containers.commit() return process_rule - def test_disable_segments_success(self, db_session_with_containers): + def test_disable_segments_success(self, db_session_with_containers: Session): """ Test successful disabling of segments from index. @@ -297,7 +298,7 @@ class TestDisableSegmentsFromIndexTask: expected_key = f"segment_{segment.id}_indexing" mock_redis.delete.assert_any_call(expected_key) - def test_disable_segments_dataset_not_found(self, db_session_with_containers): + def test_disable_segments_dataset_not_found(self, db_session_with_containers: Session): """ Test handling when dataset is not found. @@ -320,7 +321,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when dataset is not found mock_redis.delete.assert_not_called() - def test_disable_segments_document_not_found(self, db_session_with_containers): + def test_disable_segments_document_not_found(self, db_session_with_containers: Session): """ Test handling when document is not found. @@ -344,7 +345,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when document is not found mock_redis.delete.assert_not_called() - def test_disable_segments_document_invalid_status(self, db_session_with_containers): + def test_disable_segments_document_invalid_status(self, db_session_with_containers: Session): """ Test handling when document has invalid status for disabling. 
@@ -360,9 +361,8 @@ class TestDisableSegmentsFromIndexTask: # Test case 1: Document not enabled document.enabled = False - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() segment_ids = [segment.id for segment in segments] @@ -379,7 +379,7 @@ class TestDisableSegmentsFromIndexTask: # Test case 2: Document archived document.enabled = True document.archived = True - db.session.commit() + db_session_with_containers.commit() with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: # Act @@ -393,7 +393,7 @@ class TestDisableSegmentsFromIndexTask: document.enabled = True document.archived = False document.indexing_status = "indexing" - db.session.commit() + db_session_with_containers.commit() with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: # Act @@ -403,7 +403,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_redis.delete.assert_not_called() - def test_disable_segments_no_segments_found(self, db_session_with_containers): + def test_disable_segments_no_segments_found(self, db_session_with_containers: Session): """ Test handling when no segments are found for the given IDs. @@ -430,7 +430,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when no segments are found mock_redis.delete.assert_not_called() - def test_disable_segments_index_processor_error(self, db_session_with_containers): + def test_disable_segments_index_processor_error(self, db_session_with_containers: Session): """ Test handling when index processor encounters an error. @@ -464,13 +464,14 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value # Verify segments were rolled back to enabled state - from extensions.ext_database import db - db.session.refresh(segments[0]) - db.session.refresh(segments[1]) + db_session_with_containers.refresh(segments[0]) + db_session_with_containers.refresh(segments[1]) # Check that segments are re-enabled after error - updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + updated_segments = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + ) for segment in updated_segments: assert segment.enabled is True @@ -480,7 +481,7 @@ class TestDisableSegmentsFromIndexTask: # Verify Redis cache cleanup was still called assert mock_redis.delete.call_count == len(segments) - def test_disable_segments_with_different_doc_forms(self, db_session_with_containers): + def test_disable_segments_with_different_doc_forms(self, db_session_with_containers: Session): """ Test disabling segments with different document forms. 
@@ -503,9 +504,8 @@ class TestDisableSegmentsFromIndexTask: for doc_form in doc_forms: # Update document form document.doc_form = doc_form - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Mock the index processor factory with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: @@ -523,7 +523,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_factory.assert_called_with(doc_form) - def test_disable_segments_performance_timing(self, db_session_with_containers): + def test_disable_segments_performance_timing(self, db_session_with_containers: Session): """ Test that the task properly measures and logs performance timing. @@ -568,7 +568,7 @@ class TestDisableSegmentsFromIndexTask: assert performance_log is not None assert "0.5" in performance_log # Should log the execution time - def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers): + def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session): """ Test that Redis cache is properly cleaned up for all segments. @@ -610,7 +610,7 @@ class TestDisableSegmentsFromIndexTask: for expected_key in expected_keys: assert expected_key in actual_calls - def test_disable_segments_database_session_cleanup(self, db_session_with_containers): + def test_disable_segments_database_session_cleanup(self, db_session_with_containers: Session): """ Test that database session is properly closed after task execution. @@ -643,7 +643,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value # Session lifecycle is managed by context manager; no explicit close assertion - def test_disable_segments_empty_segment_ids(self, db_session_with_containers): + def test_disable_segments_empty_segment_ids(self, db_session_with_containers: Session): """ Test handling when empty segment IDs list is provided. @@ -669,7 +669,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when no segments are provided mock_redis.delete.assert_not_called() - def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers): + def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers: Session): """ Test handling when some segment IDs are valid and others are invalid. 
diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index b3d9e49b30..bc29395545 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -31,7 +31,9 @@ class TestEnableSegmentsToIndexTask: "index_processor": mock_processor, } - def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_dataset_and_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test dataset and document for testing. @@ -51,15 +53,15 @@ class TestEnableSegmentsToIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -68,8 +70,8 @@ class TestEnableSegmentsToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -81,8 +83,8 @@ class TestEnableSegmentsToIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document document = Document( @@ -99,16 +101,16 @@ class TestEnableSegmentsToIndexTask: enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property works correctly - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, document def _create_test_segments( - self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed" + self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed" ): """ Helper method to create test document segments. @@ -144,14 +146,14 @@ class TestEnableSegmentsToIndexTask: status=status, created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() return segments def test_enable_segments_to_index_with_different_index_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segments indexing with different index types. 
@@ -169,10 +171,10 @@ class TestEnableSegmentsToIndexTask: # Update document to use different index type document.doc_form = IndexStructureType.QA_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -204,7 +206,7 @@ class TestEnableSegmentsToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_enable_segments_to_index_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -229,7 +231,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent document. @@ -256,7 +258,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_invalid_document_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of document with invalid status. @@ -284,12 +286,12 @@ class TestEnableSegmentsToIndexTask: document.enabled = True document.archived = False document.indexing_status = "completed" - db.session.commit() + db_session_with_containers.commit() # Set invalid status for attr, value in status_attrs.items(): setattr(document, attr, value) - db.session.commit() + db_session_with_containers.commit() # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -304,11 +306,11 @@ class TestEnableSegmentsToIndexTask: # Clean up segments for next iteration for segment in segments: - db.session.delete(segment) - db.session.commit() + db_session_with_containers.delete(segment) + db_session_with_containers.commit() def test_enable_segments_to_index_segments_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when no segments are found. @@ -338,7 +340,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segments indexing with parent-child structure. 
@@ -357,10 +359,10 @@ class TestEnableSegmentsToIndexTask: # Update document to use parent-child index type document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments with mock child chunks segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -410,7 +412,7 @@ class TestEnableSegmentsToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_enable_segments_to_index_general_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test general exception handling during indexing process. @@ -443,7 +445,7 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify error handling for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False assert segment.status == "error" assert segment.error is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py index 6c3a9ef20a..ff72232d12 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py @@ -2,8 +2,8 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task @@ -30,7 +30,7 @@ class TestMailAccountDeletionTask: "email_service": mock_email_service, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -49,16 +49,16 @@ class TestMailAccountDeletionTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -67,12 +67,14 @@ class TestMailAccountDeletionTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return account - def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_deletion_success_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful account deletion success email sending. 
@@ -109,7 +111,7 @@ class TestMailAccountDeletionTask: ) def test_send_deletion_success_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion success email when mail service is not initialized. @@ -132,7 +134,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_not_called() def test_send_deletion_success_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion success email when email service raises exception. @@ -154,7 +156,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_called_once() def test_send_account_deletion_verification_code_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful account deletion verification code email sending. @@ -193,7 +195,7 @@ class TestMailAccountDeletionTask: ) def test_send_account_deletion_verification_code_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion verification code email when mail service is not initialized. @@ -217,7 +219,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_not_called() def test_send_account_deletion_verification_code_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion verification code email when email service raises exception. diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index b9977b1fb6..ef7191299a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -4,11 +4,11 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Pipeline from models.workflow import Workflow @@ -52,7 +52,7 @@ class TestRagPipelineRunTasks: "delete_file": mock_delete_file, } - def _create_test_pipeline_and_workflow(self, db_session_with_containers): + def _create_test_pipeline_and_workflow(self, db_session_with_containers: Session): """ Helper method to create test pipeline and workflow for testing. 
@@ -71,15 +71,15 @@ class TestRagPipelineRunTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -88,8 +88,8 @@ class TestRagPipelineRunTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create workflow workflow = Workflow( @@ -107,8 +107,8 @@ class TestRagPipelineRunTasks: conversation_variables=[], rag_pipeline_variables=[], ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create pipeline pipeline = Pipeline( @@ -119,14 +119,14 @@ class TestRagPipelineRunTasks: created_by=account.id, ) pipeline.id = str(uuid.uuid4()) - db.session.add(pipeline) - db.session.commit() + db_session_with_containers.add(pipeline) + db_session_with_containers.commit() # Refresh entities to ensure they're properly loaded - db.session.refresh(account) - db.session.refresh(tenant) - db.session.refresh(workflow) - db.session.refresh(pipeline) + db_session_with_containers.refresh(account) + db_session_with_containers.refresh(tenant) + db_session_with_containers.refresh(workflow) + db_session_with_containers.refresh(pipeline) return account, tenant, pipeline, workflow @@ -209,7 +209,7 @@ class TestRagPipelineRunTasks: return json.dumps(entities_data) def test_priority_rag_pipeline_run_task_success( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test successful priority RAG pipeline run task execution. @@ -254,7 +254,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_rag_pipeline_run_task_success( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test successful regular RAG pipeline run task execution. @@ -299,7 +299,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_priority_rag_pipeline_run_task_with_waiting_tasks( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test priority RAG pipeline run task with waiting tasks in queue using real Redis. @@ -351,7 +351,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining def test_rag_pipeline_run_task_legacy_compatibility( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility. 
@@ -419,7 +419,7 @@ class TestRagPipelineRunTasks: redis_client.delete(legacy_task_key) def test_rag_pipeline_run_task_with_waiting_tasks( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with waiting tasks in queue using real Redis. @@ -469,7 +469,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining def test_priority_rag_pipeline_run_task_error_handling( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test error handling in priority RAG pipeline run task using real Redis. @@ -526,7 +526,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_rag_pipeline_run_task_error_handling( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test error handling in regular RAG pipeline run task using real Redis. @@ -581,7 +581,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_priority_rag_pipeline_run_task_tenant_isolation( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test tenant isolation in priority RAG pipeline run task using real Redis. @@ -648,7 +648,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_rag_pipeline_run_task_tenant_isolation( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test tenant isolation in regular RAG pipeline run task using real Redis. @@ -713,7 +713,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_run_single_rag_pipeline_task_success( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test successful run_single_rag_pipeline_task execution. @@ -748,7 +748,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_run_single_rag_pipeline_task_entity_validation_error( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test run_single_rag_pipeline_task with invalid entity data. @@ -793,7 +793,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_run_single_rag_pipeline_task_database_entity_not_found( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test run_single_rag_pipeline_task with non-existent database entities. 
@@ -838,7 +838,7 @@ class TestRagPipelineRunTasks:
         mock_pipeline_generator.assert_not_called()

     def test_priority_rag_pipeline_run_task_file_not_found(
-        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
+        self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
     ):
         """
         Test priority RAG pipeline run task with non-existent file.
@@ -888,7 +888,7 @@ class TestRagPipelineRunTasks:
         assert len(remaining_tasks) == 0

     def test_rag_pipeline_run_task_file_not_found(
-        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
+        self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
     ):
         """
         Test regular RAG pipeline run task with non-existent file.

From dfc6de69c359284c9a0a9f3128f236c91e987ab3 Mon Sep 17 00:00:00 2001
From: yyh <92089059+lyzno1@users.noreply.github.com>
Date: Wed, 4 Mar 2026 13:55:13 +0800
Subject: [PATCH 020/256] refactor(web): migrate Button to Base UI with focus-visible (#32941)

Signed-off-by: yyh
---
 .../base/button/__tests__/index.spec.tsx      | 222 +++++++++++-------
 web/app/components/base/button/index.css      |  27 ++-
 .../components/base/button/index.stories.tsx  |  45 +++-
 web/app/components/base/button/index.tsx      |  42 +++-
 .../create/step-two/__tests__/index.spec.tsx  |   6 +-
 .../hit-testing/__tests__/index.spec.tsx      |  48 +++-
 .../header/account-dropdown/compliance.tsx    |   3 +-
 web/app/styles/globals.css                    |   4 -
 8 files changed, 267 insertions(+), 130 deletions(-)

diff --git a/web/app/components/base/button/__tests__/index.spec.tsx b/web/app/components/base/button/__tests__/index.spec.tsx
index b43ae89403..4fe0ab3570 100644
--- a/web/app/components/base/button/__tests__/index.spec.tsx
+++ b/web/app/components/base/button/__tests__/index.spec.tsx
@@ -1,110 +1,156 @@
-import { cleanup, fireEvent, render } from '@testing-library/react'
-import * as React from 'react'
+import { cleanup, fireEvent, render, screen } from '@testing-library/react'
 import Button from '../index'

 afterEach(cleanup)
-// https://testing-library.com/docs/queries/about
+
 describe('Button', () => {
-  describe('Button text', () => {
-    it('Button text should be same as children', async () => {
-      const { getByRole, container } = render(<Button>Click me</Button>)
-      expect(getByRole('button').textContent).toBe('Click me')
-      expect(container.querySelector('button')?.textContent).toBe('Click me')
+  describe('rendering', () => {
+    it('renders children text', () => {
+      render(<Button>Click me</Button>)
+      expect(screen.getByRole('button')).toHaveTextContent('Click me')
+    })
+
+    it('renders as a native button element by default', () => {
+      render(<Button>Button</Button>)
+      expect(screen.getByRole('button').tagName).toBe('BUTTON')
+    })
+
+    it('defaults to type="button"', () => {
+      render(<Button>Button</Button>)
+      expect(screen.getByRole('button')).toHaveAttribute('type', 'button')
+    })
+
+    it('allows type override to submit', () => {
+      render(<Button type="submit">Submit</Button>)
+      expect(screen.getByRole('button')).toHaveAttribute('type', 'submit')
+    })
+
+    it('renders custom element via render prop', () => {
+      render(<Button render={<a href="/test" />}>Link</Button>)
+      const link = screen.getByRole('link')
+      expect(link).toHaveTextContent('Link')
+      expect(link).toHaveAttribute('href', '/test')
     })
   })

-  describe('Button loading', () => {
-    it('Loading button text should include same as children', async () => {
-      const { getByRole } = render(<Button loading>Loading</Button>)
-      expect(getByRole('button').textContent?.includes('Loading')).toBe(true)
-    })
-    it('Not loading button text should include same as children', async () => {
-      const { getByRole } = render(<Button loading={false}>Click me</Button>)
-      expect(getByRole('button').textContent?.includes('Loading')).toBe(false)
     })
-    it('Loading button should have loading classname', async () => {
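+  // Sketch of the render-prop contract the 'rendering' tests above exercise.
+  // Assumption (illustrative, not asserted here): Base UI merges Button's own
+  // props and className onto the element passed via `render`, e.g.
+  //   render(<Button render={<a href="/docs" />}>Docs</Button>)
+  //   // -> <a href="/docs" class="btn btn-secondary btn-medium">Docs</a>
+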
+  describe('variants', () => {
+    it('applies default secondary variant', () => {
+      render(<Button>Default</Button>)
+      expect(screen.getByRole('button').className).toContain('btn-secondary')
+    })
+
+    it.each([
+      'primary',
+      'warning',
+      'secondary',
+      'secondary-accent',
+      'ghost',
+      'ghost-accent',
+      'tertiary',
+    ] as const)('applies %s variant', (variant) => {
+      render(<Button variant={variant}>Variant</Button>)
+      expect(screen.getByRole('button').className).toContain(`btn-${variant}`)
+    })
+
+    it('applies destructive modifier', () => {
+      render(<Button destructive>Destructive</Button>)
+      expect(screen.getByRole('button').className).toContain('btn-destructive')
+    })
+  })
+
+  describe('sizes', () => {
+    it('applies default medium size', () => {
+      render(<Button>Default</Button>)
+      expect(screen.getByRole('button').className).toContain('btn-medium')
+    })
+
+    it.each(['small', 'medium', 'large'] as const)('applies %s size', (size) => {
+      render(<Button size={size}>Size</Button>)
+      expect(screen.getByRole('button').className).toContain(`btn-${size}`)
+    })
+  })
+
+  describe('loading', () => {
+    it('shows spinner when loading', () => {
+      render(<Button loading>Loading</Button>)
+      expect(screen.getByRole('button').querySelector('.animate-spin')).toBeInTheDocument()
+    })
+
+    it('hides spinner when not loading', () => {
+      render(<Button>Loading</Button>)
+      expect(screen.getByRole('button').querySelector('.animate-spin')).not.toBeInTheDocument()
+    })
+
+    it('auto-disables when loading', () => {
+      render(<Button loading>Loading</Button>)
+      expect(screen.getByRole('button')).toBeDisabled()
+    })
+
+    it('sets aria-busy when loading', () => {
+      render(<Button loading>Loading</Button>)
+      expect(screen.getByRole('button')).toHaveAttribute('aria-busy', 'true')
+    })
+
+    it('does not set aria-busy when not loading', () => {
+      render(<Button>Loading</Button>)
+      expect(screen.getByRole('button')).not.toHaveAttribute('aria-busy')
+    })
+
+    it('applies custom spinnerClassName', () => {
       const animClassName = 'anim-breath'
-      const { getByRole } = render(<Button loading spinnerClassName={animClassName}>Click me</Button>)
-      expect(getByRole('button').getElementsByClassName('animate-spin')[0]?.className).toContain(animClassName)
+      render(<Button loading spinnerClassName={animClassName}>Loading</Button>)
+      expect(screen.getByRole('button').querySelector('.animate-spin')?.className).toContain(animClassName)
     })
   })

-  describe('Button style', () => {
-    it('Button should have default variant', async () => {
-      const { getByRole } = render(<Button>Click me</Button>)
-      expect(getByRole('button').className).toContain('btn-secondary')
+  describe('disabled', () => {
+    it('disables button when disabled prop is set', () => {
+      render(<Button disabled>Disabled</Button>)
+      expect(screen.getByRole('button')).toBeDisabled()
     })

-    it('Button should have primary variant', async () => {
-      const { getByRole } = render(<Button variant='primary'>Click me</Button>)
-      expect(getByRole('button').className).toContain('btn-primary')
-    })
-
-    it('Button should have warning variant', async () => {
-      const { getByRole } = render(<Button variant='warning'>Click me</Button>)
-      expect(getByRole('button').className).toContain('btn-warning')
-    })
-
-    it('Button should have secondary variant', async () => {
-      const { getByRole } = render(<Button variant='secondary'>Click me</Button>)
-      expect(getByRole('button').className).toContain('btn-secondary')
-    })
-
-    it('Button should have secondary-accent variant', async () => {
-      const { getByRole } = render(<Button variant='secondary-accent'>Click me</Button>)
-      expect(getByRole('button').className).toContain('btn-secondary-accent')
-    })
-    it('Button should have ghost variant', async () => {
-      const { getByRole } = render(<Button variant='ghost'>Click me</Button>)
-      expect(getByRole('button').className).toContain('btn-ghost')
-    })
-    it('Button should have ghost-accent variant', async () => {
-      const { getByRole } = render(<Button variant='ghost-accent'>Click me</Button>)
-      expect(getByRole('button').className).toContain('btn-ghost-accent')
-    })
-
-    it('Button disabled should have disabled variant', async () => {
-      const { getByRole } = render(<Button disabled>Click me</Button>)
-
expect(getByRole('button').className).toContain('btn-disabled') + it('keeps focusable when loading with focusableWhenDisabled', () => { + render() + const button = screen.getByRole('button') + expect(button).toHaveAttribute('aria-disabled', 'true') }) }) - describe('Button size', () => { - it('Button should have default size', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-medium') - }) - - it('Button should have small size', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-small') - }) - - it('Button should have medium size', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-medium') - }) - - it('Button should have large size', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-large') - }) - }) - - describe('Button destructive', () => { - it('Button should have destructive classname', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-destructive') - }) - }) - - describe('Button events', () => { - it('onClick should been call after clicked', async () => { + describe('events', () => { + it('fires onClick when clicked', () => { const onClick = vi.fn() - const { getByRole } = render() - fireEvent.click(getByRole('button')) - expect(onClick).toHaveBeenCalled() + render() + fireEvent.click(screen.getByRole('button')) + expect(onClick).toHaveBeenCalledTimes(1) + }) + + it('does not fire onClick when disabled', () => { + const onClick = vi.fn() + render() + fireEvent.click(screen.getByRole('button')) + expect(onClick).not.toHaveBeenCalled() + }) + + it('does not fire onClick when loading', () => { + const onClick = vi.fn() + render() + fireEvent.click(screen.getByRole('button')) + expect(onClick).not.toHaveBeenCalled() + }) + }) + + describe('ref forwarding', () => { + it('forwards ref to the button element', () => { + let buttonRef: HTMLButtonElement | null = null + render( + , + ) + expect(buttonRef).toBeInstanceOf(HTMLButtonElement) }) }) }) diff --git a/web/app/components/base/button/index.css b/web/app/components/base/button/index.css index 5899c027d3..6360ed9d0c 100644 --- a/web/app/components/base/button/index.css +++ b/web/app/components/base/button/index.css @@ -2,10 +2,11 @@ @layer components { .btn { - @apply inline-flex justify-center items-center cursor-pointer whitespace-nowrap; + @apply inline-flex justify-center items-center cursor-pointer whitespace-nowrap + outline-none focus-visible:ring-2 focus-visible:ring-state-accent-solid; } - .btn-disabled { + .btn:is(:disabled, [data-disabled]) { @apply cursor-not-allowed; } @@ -40,7 +41,7 @@ text-components-button-destructive-primary-text; } - .btn-primary.btn-disabled { + .btn-primary:is(:disabled, [data-disabled]) { @apply shadow-none bg-components-button-primary-bg-disabled @@ -48,7 +49,7 @@ text-components-button-primary-text-disabled; } - .btn-primary.btn-destructive.btn-disabled { + .btn-primary.btn-destructive:is(:disabled, [data-disabled]) { @apply shadow-none bg-components-button-destructive-primary-bg-disabled @@ -68,7 +69,7 @@ text-components-button-secondary-text; } - .btn-secondary.btn-disabled { + .btn-secondary:is(:disabled, [data-disabled]) { @apply backdrop-blur-sm bg-components-button-secondary-bg-disabled @@ -85,7 +86,7 @@ text-components-button-destructive-secondary-text; } - .btn-secondary.btn-destructive.btn-disabled { + 
.btn-secondary.btn-destructive:is(:disabled, [data-disabled]) { @apply bg-components-button-destructive-secondary-bg-disabled border-components-button-destructive-secondary-border-disabled @@ -104,7 +105,7 @@ text-components-button-secondary-accent-text; } - .btn-secondary-accent.btn-disabled { + .btn-secondary-accent:is(:disabled, [data-disabled]) { @apply bg-components-button-secondary-bg-disabled border-components-button-secondary-border-disabled @@ -120,7 +121,7 @@ text-components-button-destructive-primary-text; } - .btn-warning.btn-disabled { + .btn-warning:is(:disabled, [data-disabled]) { @apply bg-components-button-destructive-primary-bg-disabled border-components-button-destructive-primary-border-disabled @@ -134,7 +135,7 @@ text-components-button-tertiary-text; } - .btn-tertiary.btn-disabled { + .btn-tertiary:is(:disabled, [data-disabled]) { @apply bg-components-button-tertiary-bg-disabled text-components-button-tertiary-text-disabled; @@ -147,7 +148,7 @@ text-components-button-destructive-tertiary-text; } - .btn-tertiary.btn-destructive.btn-disabled { + .btn-tertiary.btn-destructive:is(:disabled, [data-disabled]) { @apply bg-components-button-destructive-tertiary-bg-disabled text-components-button-destructive-tertiary-text-disabled; @@ -159,7 +160,7 @@ text-components-button-ghost-text; } - .btn-ghost.btn-disabled { + .btn-ghost:is(:disabled, [data-disabled]) { @apply text-components-button-ghost-text-disabled; } @@ -170,7 +171,7 @@ text-components-button-destructive-ghost-text; } - .btn-ghost.btn-destructive.btn-disabled { + .btn-ghost.btn-destructive:is(:disabled, [data-disabled]) { @apply text-components-button-destructive-ghost-text-disabled; } @@ -181,7 +182,7 @@ text-components-button-secondary-accent-text; } - .btn-ghost-accent.btn-disabled { + .btn-ghost-accent:is(:disabled, [data-disabled]) { @apply text-components-button-secondary-accent-text-disabled; } diff --git a/web/app/components/base/button/index.stories.tsx b/web/app/components/base/button/index.stories.tsx index 25bd5957e1..5a7ec55e8f 100644 --- a/web/app/components/base/button/index.stories.tsx +++ b/web/app/components/base/button/index.stories.tsx @@ -1,6 +1,5 @@ import type { Meta, StoryObj } from '@storybook/nextjs-vite' -import { RocketLaunchIcon } from '@heroicons/react/20/solid' import { Button } from '.' 
 const meta = {
@@ -12,10 +11,16 @@ const meta = {
   tags: ['autodocs'],
   argTypes: {
     loading: { control: 'boolean' },
+    destructive: { control: 'boolean' },
+    disabled: { control: 'boolean' },
     variant: {
       control: 'select',
       options: ['primary', 'warning', 'secondary', 'secondary-accent', 'ghost', 'ghost-accent', 'tertiary'],
     },
+    size: {
+      control: 'select',
+      options: ['small', 'medium', 'large'],
+    },
   },
   args: {
     variant: 'ghost',
@@ -29,11 +34,7 @@ type Story = StoryObj<typeof meta>

 export const Default: Story = {
   args: {
     variant: 'primary',
-    loading: false,
     children: 'Primary Button',
-    styleCss: {},
-    spinnerClassName: '',
-    destructive: false,
   },
 }

@@ -95,14 +96,46 @@ export const Loading: Story = {
   },
 }

+export const Destructive: Story = {
+  args: {
+    variant: 'primary',
+    destructive: true,
+    children: 'Delete',
+  },
+}
+
 export const WithIcon: Story = {
   args: {
     variant: 'primary',
     children: (
       <>
-        <RocketLaunchIcon className='h-4 w-4' />
+
         Launch
       </>
     ),
   },
 }
+
+export const SmallSize: Story = {
+  args: {
+    variant: 'secondary',
+    size: 'small',
+    children: 'Small',
+  },
+}
+
+export const LargeSize: Story = {
+  args: {
+    variant: 'primary',
+    size: 'large',
+    children: 'Large Button',
+  },
+}
+
+export const AsLink: Story = {
+  args: {
+    variant: 'ghost-accent',
+    render: ,
+    children: 'Link Button',
+  },
+}
diff --git a/web/app/components/base/button/index.tsx b/web/app/components/base/button/index.tsx
index 0de57617af..047ced4c53 100644
--- a/web/app/components/base/button/index.tsx
+++ b/web/app/components/base/button/index.tsx
@@ -1,12 +1,12 @@
 import type { VariantProps } from 'class-variance-authority'
-import type { CSSProperties } from 'react'
+import { Button as BaseButton } from '@base-ui/react/button'
 import { cva } from 'class-variance-authority'
 import * as React from 'react'
 import { cn } from '@/utils/classnames'
 import Spinner from '../spinner'

 const buttonVariants = cva(
-  'btn disabled:btn-disabled',
+  'btn',
   {
     variants: {
       variant: {
@@ -23,6 +23,9 @@ const buttonVariants = cva(
         medium: 'btn-medium',
         large: 'btn-large',
       },
+      destructive: {
+        true: 'btn-destructive',
+      },
     },
     defaultVariants: {
       variant: 'secondary',
@@ -32,25 +35,44 @@ const buttonVariants = cva(
 )

 export type ButtonProps = {
-  destructive?: boolean
   loading?: boolean
-  styleCss?: CSSProperties
   spinnerClassName?: string
   ref?: React.Ref<HTMLButtonElement>
+  render?: React.ReactElement
+  focusableWhenDisabled?: boolean
 } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof buttonVariants>

-const Button = ({ className, variant, size, destructive, loading, styleCss, children, spinnerClassName, ref, ...props }: ButtonProps) => {
+const Button = ({
+  className,
+  variant,
+  size,
+  destructive,
+  loading,
+  children,
+  spinnerClassName,
+  ref,
+  render,
+  focusableWhenDisabled,
+  disabled,
+  type = 'button',
+  ...props
+}: ButtonProps) => {
+  const isDisabled = disabled || loading
+
   return (
-    <button
-      ref={ref}
-      className={cn(buttonVariants({ variant, size }), destructive && 'btn-destructive', className)}
-      style={styleCss}
-      {...props}
-    >
-      {children}
-      {loading && <Spinner loading={loading} className={cn('!text-white !h-3 !w-3 shrink-0', spinnerClassName)} />}
-    </button>
+    <BaseButton
+      ref={ref}
+      render={render}
+      type={type}
+      disabled={isDisabled}
+      focusableWhenDisabled={focusableWhenDisabled}
+      aria-busy={loading || undefined}
+      className={cn(buttonVariants({ variant, size, destructive }), className)}
+      {...props}
+    >
+      {children}
+      {loading && <Spinner loading className={cn('!text-white !h-3 !w-3 shrink-0', spinnerClassName)} />}
+    </BaseButton>
   )
 }
 Button.displayName = 'Button'
diff --git a/web/app/components/datasets/create/step-two/__tests__/index.spec.tsx b/web/app/components/datasets/create/step-two/__tests__/index.spec.tsx
index 9a0a9630ea..86e8ec2ab5 100644
--- a/web/app/components/datasets/create/step-two/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/create/step-two/__tests__/index.spec.tsx
@@ -1772,16 +1772,14 @@ describe('StepTwoFooter', () => {
       render()

       const nextButton = screen.getByText(/nextStep/i).closest('button')
-      // Button has disabled:btn-disabled class which handles the loading state
-      expect(nextButton).toHaveClass('disabled:btn-disabled')
+      expect(nextButton).toBeDisabled()
     })

     it('should show loading state on Save button when creating in setting mode', () => {
       render()

       const saveButton = screen.getByText(/save/i).closest('button')
-      // Button has disabled:btn-disabled class which handles the loading state
-      expect(saveButton).toHaveClass('disabled:btn-disabled')
+      expect(saveButton).toBeDisabled()
     })
   })
 })
diff --git a/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx b/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx
index 0a5a55b744..fe7510b498 100644
--- a/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx
@@ -579,10 +579,20 @@ describe('HitTestingPage', () => {
 })

 describe('Integration: Hit Testing Flow', () => {
-  beforeEach(() => {
+  beforeEach(async () => {
     vi.clearAllMocks()
     mockHitTestingMutateAsync.mockReset()
     mockExternalHitTestingMutateAsync.mockReset()
+
+    const { useHitTesting, useExternalKnowledgeBaseHitTesting } = await import('@/service/knowledge/use-hit-testing')
+    vi.mocked(useHitTesting).mockReturnValue({
+      mutateAsync: mockHitTestingMutateAsync,
+      isPending: false,
+    } as unknown as ReturnType<typeof useHitTesting>)
+    vi.mocked(useExternalKnowledgeBaseHitTesting).mockReturnValue({
+      mutateAsync: mockExternalHitTestingMutateAsync,
+      isPending: false,
+    } as unknown as ReturnType<typeof useExternalKnowledgeBaseHitTesting>)
   })

   it('should complete a full hit testing flow', async () => {
@@ -781,8 +791,18 @@ describe('Integration: Hit Testing Flow', () => {

 // Drawer and Modal Interaction Tests
 describe('Drawer and Modal Interactions', () => {
-  beforeEach(() => {
+  beforeEach(async () => {
     vi.clearAllMocks()
+
+    const { useHitTesting, useExternalKnowledgeBaseHitTesting } = await import('@/service/knowledge/use-hit-testing')
+    vi.mocked(useHitTesting).mockReturnValue({
+      mutateAsync: mockHitTestingMutateAsync,
+      isPending: false,
+    } as unknown as ReturnType<typeof useHitTesting>)
+    vi.mocked(useExternalKnowledgeBaseHitTesting).mockReturnValue({
+      mutateAsync: mockExternalHitTestingMutateAsync,
+      isPending: false,
+    } as unknown as ReturnType<typeof useExternalKnowledgeBaseHitTesting>)
   })

   it('should save retrieval config when ModifyRetrievalModal onSave is called', async () => {
@@ -828,9 +848,19 @@ describe('Drawer and Modal Interactions', () => {

 // renderHitResults Coverage Tests
 describe('renderHitResults Coverage', () => {
-  beforeEach(() => {
+  beforeEach(async () => {
     vi.clearAllMocks()
     mockHitTestingMutateAsync.mockReset()
+
+    const { useHitTesting, useExternalKnowledgeBaseHitTesting } = await import('@/service/knowledge/use-hit-testing')
+    vi.mocked(useHitTesting).mockReturnValue({
+      mutateAsync: mockHitTestingMutateAsync,
+      isPending: false,
+    } as unknown as ReturnType<typeof useHitTesting>)
+    vi.mocked(useExternalKnowledgeBaseHitTesting).mockReturnValue({
+      mutateAsync: mockExternalHitTestingMutateAsync,
+      isPending: false,
+    } as unknown as ReturnType<typeof useExternalKnowledgeBaseHitTesting>)
   })

   it('should render hit results panel with records count', async () => {
@@ -952,10 +982,20 @@ describe('ModifyRetrievalModal onSave Coverage', () => {

 // Direct Component Coverage Tests
 describe('HitTestingPage Internal Functions Coverage', () => {
-  beforeEach(() => {
+  beforeEach(async () => {
     vi.clearAllMocks()
     mockHitTestingMutateAsync.mockReset()
     mockExternalHitTestingMutateAsync.mockReset()
+
+    const { useHitTesting, useExternalKnowledgeBaseHitTesting } = await import('@/service/knowledge/use-hit-testing')
+    vi.mocked(useHitTesting).mockReturnValue({
+      mutateAsync: mockHitTestingMutateAsync,
+      isPending: false,
+    } as unknown as ReturnType<typeof useHitTesting>)
+    vi.mocked(useExternalKnowledgeBaseHitTesting).mockReturnValue({
+      mutateAsync: mockExternalHitTestingMutateAsync,
+      isPending: false,
+    } as unknown as ReturnType<typeof useExternalKnowledgeBaseHitTesting>)
   })

   it('should trigger renderHitResults when mutation succeeds with records', async () => {
diff --git a/web/app/components/header/account-dropdown/compliance.tsx b/web/app/components/header/account-dropdown/compliance.tsx
index 1627a5a052..c048f25c1e 100644
--- a/web/app/components/header/account-dropdown/compliance.tsx
+++ b/web/app/components/header/account-dropdown/compliance.tsx
@@ -46,9 +46,10 @@ function ComplianceDocActionVisual({
   return (
diff --git a/web/app/styles/globals.css b/web/app/styles/globals.css index cf183cad4e..f99371d180 100644 --- a/web/app/styles/globals.css +++ b/web/app/styles/globals.css @@ -121,10 +121,6 @@ a { outline: none; } -button:focus-within { - outline: none; -} - /* @media (prefers-color-scheme: dark) { html { color-scheme: dark; From 84dca83ecdd8708d64ca5fb82ba70aa9cad17986 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:56:27 +0800 Subject: [PATCH 021/256] feat(web): add base AlertDialog with app-card migration example (#32933) Signed-off-by: yyh --- .gitignore | 1 + .../apps/app-card-operations-flow.test.tsx | 14 +- .../apps/app-list-browsing-flow.test.tsx | 4 + web/__tests__/apps/create-app-flow.test.tsx | 4 + .../apps/__tests__/app-card.spec.tsx | 93 ++++------- .../components/apps/__tests__/list.spec.tsx | 4 + web/app/components/apps/app-card.tsx | 65 +++++--- web/app/components/base/confirm/index.tsx | 6 + .../ui/alert-dialog/__tests__/index.spec.tsx | 145 ++++++++++++++++++ .../components/base/ui/alert-dialog/index.tsx | 106 +++++++++++++ web/contract/console/apps.ts | 14 ++ web/contract/router.ts | 4 + web/docs/overlay-migration.md | 2 + web/eslint-suppressions.json | 142 ++++++++++++++--- web/eslint.config.mjs | 6 + web/service/use-apps.ts | 25 +++ 16 files changed, 529 insertions(+), 106 deletions(-) create mode 100644 web/app/components/base/ui/alert-dialog/__tests__/index.spec.tsx create mode 100644 web/app/components/base/ui/alert-dialog/index.tsx create mode 100644 web/contract/console/apps.ts diff --git a/.gitignore b/.gitignore index 7bd919f095..8200d70afe 100644 --- a/.gitignore +++ b/.gitignore @@ -222,6 +222,7 @@ mise.toml # AI Assistant .roo/ +/.claude/worktrees/ api/.env.backup /clickzetta diff --git a/web/__tests__/apps/app-card-operations-flow.test.tsx b/web/__tests__/apps/app-card-operations-flow.test.tsx index 1aa6706b82..c3e8410955 100644 --- a/web/__tests__/apps/app-card-operations-flow.test.tsx +++ b/web/__tests__/apps/app-card-operations-flow.test.tsx @@ -14,7 +14,7 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import AppCard from '@/app/components/apps/app-card' import { AccessMode } from '@/models/access-control' -import { deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' +import { exportAppConfig, updateAppInfo } from '@/service/apps' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -26,6 +26,8 @@ let mockSystemFeatures = { const mockRouterPush = vi.fn() const mockNotify = vi.fn() const mockOnPlanInfoChanged = vi.fn() +const mockDeleteAppMutation = vi.fn().mockResolvedValue(undefined) +let mockDeleteMutationPending = false vi.mock('next/navigation', () => ({ useRouter: () => ({ @@ -117,6 +119,13 @@ vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([]), })) +vi.mock('@/service/use-apps', () => ({ + useDeleteAppMutation: () => ({ + mutateAsync: mockDeleteAppMutation, + isPending: mockDeleteMutationPending, + }), +})) + vi.mock('@/service/apps', () => ({ deleteApp: vi.fn().mockResolvedValue({}), updateAppInfo: vi.fn().mockResolvedValue({}), @@ -270,6 +279,7 @@ const renderAppCard = (app?: Partial) => { describe('App Card Operations Flow', () => { beforeEach(() => { vi.clearAllMocks() + mockDeleteMutationPending = false mockIsCurrentWorkspaceEditor = true mockSystemFeatures = { branding: { enabled: false }, @@ -341,7 +351,7 @@ describe('App Card 
Operations Flow', () => { fireEvent.click(confirmBtn) await waitFor(() => { - expect(deleteApp).toHaveBeenCalledWith('app-to-delete') + expect(mockDeleteAppMutation).toHaveBeenCalledWith('app-to-delete') }) } } diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 19288ecd95..079f667dbc 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -104,6 +104,10 @@ vi.mock('@/service/use-apps', () => ({ error: mockError, refetch: mockRefetch, }), + useDeleteAppMutation: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), })) vi.mock('@/hooks/use-pay', () => ({ diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index a0976d32cc..4ac9824ddd 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -91,6 +91,10 @@ vi.mock('@/service/use-apps', () => ({ error: null, refetch: mockRefetch, }), + useDeleteAppMutation: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), })) vi.mock('@/hooks/use-pay', () => ({ diff --git a/web/app/components/apps/__tests__/app-card.spec.tsx b/web/app/components/apps/__tests__/app-card.spec.tsx index ee36d471fd..9bc23ce199 100644 --- a/web/app/components/apps/__tests__/app-card.spec.tsx +++ b/web/app/components/apps/__tests__/app-card.spec.tsx @@ -63,6 +63,15 @@ vi.mock('@/service/apps', () => ({ exportAppConfig: vi.fn(() => Promise.resolve({ data: 'yaml: content' })), })) +const mockDeleteAppMutation = vi.fn(() => Promise.resolve()) +let mockDeleteMutationPending = false +vi.mock('@/service/use-apps', () => ({ + useDeleteAppMutation: () => ({ + mutateAsync: mockDeleteAppMutation, + isPending: mockDeleteMutationPending, + }), +})) + vi.mock('@/service/workflow', () => ({ fetchWorkflowDraft: vi.fn(() => Promise.resolve({ environment_variables: [] })), })) @@ -146,13 +155,6 @@ vi.mock('next/dynamic', () => ({ return React.createElement('div', { 'data-testid': 'switch-modal' }, React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-switch-modal' }, 'Close'), React.createElement('button', { 'onClick': onSuccess, 'data-testid': 'confirm-switch-modal' }, 'Switch')) } } - if (fnString.includes('base/confirm')) { - return function MockConfirm({ isShow, onCancel, onConfirm }: { isShow: boolean, onCancel: () => void, onConfirm: () => void }) { - if (!isShow) - return null - return React.createElement('div', { 'data-testid': 'confirm-dialog' }, React.createElement('button', { 'onClick': onCancel, 'data-testid': 'cancel-confirm' }, 'Cancel'), React.createElement('button', { 'onClick': onConfirm, 'data-testid': 'confirm-confirm' }, 'Confirm')) - } - } if (fnString.includes('dsl-export-confirm-modal')) { return function MockDSLExportModal({ onClose, onConfirm }: { onClose?: () => void, onConfirm?: (withSecrets: boolean) => void }) { return React.createElement('div', { 'data-testid': 'dsl-export-modal' }, React.createElement('button', { 'onClick': () => onClose?.(), 'data-testid': 'close-dsl-export' }, 'Close'), React.createElement('button', { 'onClick': () => onConfirm?.(true), 'data-testid': 'confirm-dsl-export' }, 'Export with secrets'), React.createElement('button', { 'onClick': () => onConfirm?.(false), 'data-testid': 'confirm-dsl-export-no-secrets' }, 'Export without secrets')) @@ -235,6 +237,7 @@ describe('AppCard', () => { vi.clearAllMocks() mockOpenAsyncWindow.mockReset() mockWebappAuthEnabled = false + 
mockDeleteMutationPending = false }) describe('Rendering', () => { @@ -461,35 +464,19 @@ describe('AppCard', () => { render() fireEvent.click(screen.getByTestId('popover-trigger')) - - await waitFor(() => { - const deleteButton = screen.getByText('common.operation.delete') - fireEvent.click(deleteButton) - }) - - await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() }) it('should close confirm dialog when cancel is clicked', async () => { render() fireEvent.click(screen.getByTestId('popover-trigger')) - + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) await waitFor(() => { - const deleteButton = screen.getByText('common.operation.delete') - fireEvent.click(deleteButton) - }) - - await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByTestId('cancel-confirm')) - - await waitFor(() => { - expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() }) }) @@ -554,59 +541,41 @@ describe('AppCard', () => { // Open popover and click delete fireEvent.click(screen.getByTestId('popover-trigger')) - await waitFor(() => { - fireEvent.click(screen.getByText('common.operation.delete')) - }) - - // Confirm delete - await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByTestId('confirm-confirm')) + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { - expect(appsService.deleteApp).toHaveBeenCalled() + expect(mockDeleteAppMutation).toHaveBeenCalled() }) }) - it('should call onRefresh after successful delete', async () => { + it('should not call onRefresh after successful delete', async () => { render() fireEvent.click(screen.getByTestId('popover-trigger')) - await waitFor(() => { - fireEvent.click(screen.getByText('common.operation.delete')) - }) + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByTestId('confirm-confirm')) - - await waitFor(() => { - expect(mockOnRefresh).toHaveBeenCalled() + expect(mockDeleteAppMutation).toHaveBeenCalled() }) + expect(mockOnRefresh).not.toHaveBeenCalled() }) it('should handle delete failure', async () => { - (appsService.deleteApp as Mock).mockRejectedValueOnce(new Error('Delete failed')) + ;(mockDeleteAppMutation as Mock).mockRejectedValueOnce(new Error('Delete failed')) render() fireEvent.click(screen.getByTestId('popover-trigger')) - await waitFor(() => { - fireEvent.click(screen.getByText('common.operation.delete')) - }) + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await 
screen.findByRole('alertdialog')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByTestId('confirm-confirm')) - - await waitFor(() => { - expect(appsService.deleteApp).toHaveBeenCalled() + expect(mockDeleteAppMutation).toHaveBeenCalled() expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Delete failed') }) }) }) diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index a9bef08243..989bf6a788 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -106,6 +106,10 @@ vi.mock('@/service/use-apps', () => ({ error: mockServiceState.error, refetch: mockRefetch, }), + useDeleteAppMutation: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), })) vi.mock('@/service/tag', () => ({ diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 8f268da02c..8dc5c82e13 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -20,6 +20,15 @@ import CustomPopover from '@/app/components/base/popover' import TagSelector from '@/app/components/base/tag-management/selector' import Toast, { ToastContext } from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { + AlertDialog, + AlertDialogActions, + AlertDialogCancelButton, + AlertDialogConfirmButton, + AlertDialogContent, + AlertDialogDescription, + AlertDialogTitle, +} from '@/app/components/base/ui/alert-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -27,8 +36,9 @@ import { useProviderContext } from '@/context/provider-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { AccessMode } from '@/models/access-control' import { useGetUserCanAccessApp } from '@/service/access-control' -import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' +import { copyApp, exportAppConfig, updateAppInfo } from '@/service/apps' import { fetchInstalledAppList } from '@/service/explore' +import { useDeleteAppMutation } from '@/service/use-apps' import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' @@ -46,9 +56,6 @@ const DuplicateAppModal = dynamic(() => import('@/app/components/app/duplicate-m const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false, }) -const Confirm = dynamic(() => import('@/app/components/base/confirm'), { - ssr: false, -}) const DSLExportConfirmModal = dynamic(() => import('@/app/components/workflow/dsl-export-confirm-modal'), { ssr: false, }) @@ -76,13 +83,12 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { const [showConfirmDelete, setShowConfirmDelete] = useState(false) const [showAccessControl, setShowAccessControl] = useState(false) const [secretEnvList, setSecretEnvList] = useState([]) + const { mutateAsync: mutateDeleteApp, isPending: isDeleting } = useDeleteAppMutation() const onConfirmDelete = useCallback(async () => { try { - await deleteApp(app.id) + await mutateDeleteApp(app.id) notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) }) - 
if (onRefresh) - onRefresh() onPlanInfoChanged() } catch (e: any) { @@ -91,8 +97,17 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { message: `${t('appDeleteFailed', { ns: 'app' })}${'message' in e ? `: ${e.message}` : ''}`, }) } - setShowConfirmDelete(false) - }, [app.id, notify, onPlanInfoChanged, onRefresh, t]) + finally { + setShowConfirmDelete(false) + } + }, [app.id, mutateDeleteApp, notify, onPlanInfoChanged, t]) + + const onDeleteDialogOpenChange = useCallback((open: boolean) => { + if (isDeleting) + return + + setShowConfirmDelete(open) + }, [isDeleting]) const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, @@ -438,7 +453,8 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
- + {t('operation.more', { ns: 'common' })} +
)} btnClassName={open =>
@@ -495,15 +511,26 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
           onSuccess={onSwitch}
         />
       )}
-      {showConfirmDelete && (
-        <Confirm
-          onCancel={() => setShowConfirmDelete(false)}
-        />
-      )}
+      <AlertDialog open={showConfirmDelete} onOpenChange={onDeleteDialogOpenChange}>
+        <AlertDialogContent>
+          <AlertDialogTitle>
+            {t('deleteAppConfirmTitle', { ns: 'app' })}
+          </AlertDialogTitle>
+          <AlertDialogDescription>
+            {t('deleteAppConfirmContent', { ns: 'app' })}
+          </AlertDialogDescription>
+          <AlertDialogActions>
+            <AlertDialogCancelButton disabled={isDeleting}>
+              {t('operation.cancel', { ns: 'common' })}
+            </AlertDialogCancelButton>
+            <AlertDialogConfirmButton loading={isDeleting} onClick={onConfirmDelete}>
+              {t('operation.confirm', { ns: 'common' })}
+            </AlertDialogConfirmButton>
+          </AlertDialogActions>
+        </AlertDialogContent>
+      </AlertDialog>
+
{secretEnvList.length > 0 && ( { + describe('Rendering', () => { + it('should render alert dialog content when dialog is open', () => { + render( + + + Confirm Delete + This action cannot be undone. + + , + ) + + const dialog = screen.getByRole('alertdialog') + expect(dialog).toHaveTextContent('Confirm Delete') + expect(dialog).toHaveTextContent('This action cannot be undone.') + }) + + it('should not render content when dialog is closed', () => { + render( + + + Hidden Title + + , + ) + + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() + }) + }) + + describe('Props', () => { + it('should apply custom className to popup', () => { + render( + + + Title + + , + ) + + const dialog = screen.getByRole('alertdialog') + expect(dialog).toHaveClass('custom-class') + }) + + it('should not render a close button by default', () => { + render( + + + Title + + , + ) + + expect(screen.queryByRole('button', { name: 'Close' })).not.toBeInTheDocument() + }) + }) + + describe('User Interactions', () => { + it('should open and close dialog when trigger and close are clicked', async () => { + render( + + Open Dialog + + Action Required + Please confirm the action. + Cancel + + , + ) + + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'Open Dialog' })) + expect(await screen.findByRole('alertdialog')).toHaveTextContent('Action Required') + + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + await waitFor(() => { + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() + }) + }) + }) + + describe('Composition Helpers', () => { + it('should render actions wrapper and default confirm button styles', () => { + render( + + + Action Required + + Confirm + + + , + ) + + expect(screen.getByTestId('actions')).toHaveClass('flex', 'items-start', 'justify-end', 'gap-2', 'self-stretch', 'p-6', 'custom-actions') + const confirmButton = screen.getByRole('button', { name: 'Confirm' }) + expect(confirmButton).toHaveClass('btn-primary') + expect(confirmButton).toHaveClass('btn-destructive') + }) + + it('should keep dialog open after confirm click and close via cancel helper', async () => { + const onConfirm = vi.fn() + + render( + + Open Dialog + + Action Required + + Cancel + Confirm + + + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'Open Dialog' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + expect(onConfirm).toHaveBeenCalledTimes(1) + expect(screen.getByRole('alertdialog')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + await waitFor(() => { + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() + }) + }) + }) +}) diff --git a/web/app/components/base/ui/alert-dialog/index.tsx b/web/app/components/base/ui/alert-dialog/index.tsx new file mode 100644 index 0000000000..8d48c5b998 --- /dev/null +++ b/web/app/components/base/ui/alert-dialog/index.tsx @@ -0,0 +1,106 @@ +'use client' + +import type { ButtonProps } from '@/app/components/base/button' +import { AlertDialog as BaseAlertDialog } from '@base-ui/react/alert-dialog' +import * as React from 'react' +import Button from '@/app/components/base/button' +import { cn } from '@/utils/classnames' + +// z-index strategy (relies on root `isolation: isolate` in layout.tsx): +// All overlay primitives (Tooltip / Popover / Dropdown / Select / Dialog / AlertDialog) — z-50 +// Overlays share the same 
z-index; DOM order handles stacking when multiple are open. +// This ensures overlays inside an AlertDialog (e.g. a Tooltip on a dialog button) render +// above the dialog backdrop instead of being clipped by it. +// Toast — z-[99], always on top (defined in toast component) + +export const AlertDialog = BaseAlertDialog.Root +export const AlertDialogTrigger = BaseAlertDialog.Trigger +export const AlertDialogTitle = BaseAlertDialog.Title +export const AlertDialogDescription = BaseAlertDialog.Description +export const AlertDialogClose = BaseAlertDialog.Close + +type AlertDialogContentProps = { + children: React.ReactNode + className?: string + overlayClassName?: string + popupProps?: Omit, 'children' | 'className'> + backdropProps?: Omit, 'className'> +} + +export function AlertDialogContent({ + children, + className, + overlayClassName, + popupProps, + backdropProps, +}: AlertDialogContentProps) { + return ( + + + + {children} + + + ) +} + +type AlertDialogActionsProps = React.ComponentPropsWithoutRef<'div'> + +export function AlertDialogActions({ className, ...props }: AlertDialogActionsProps) { + return ( +
+ ) +} + +type AlertDialogCancelButtonProps = Omit & { + children: React.ReactNode + closeProps?: Omit, 'children' | 'render'> +} + +export function AlertDialogCancelButton({ + children, + closeProps, + ...buttonProps +}: AlertDialogCancelButtonProps) { + return ( + } + > + {children} + + ) +} + +type AlertDialogConfirmButtonProps = ButtonProps + +export function AlertDialogConfirmButton({ + variant = 'primary', + destructive = true, + ...props +}: AlertDialogConfirmButtonProps) { + return ( +
{systemFeatures.branding.enabled === false && ( -
+
© {' '} {new Date().getFullYear()} diff --git a/web/app/components/header/account-setting/model-provider-page/index.tsx b/web/app/components/header/account-setting/model-provider-page/index.tsx index 7606bbc04f..4fb1b770e1 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/index.tsx @@ -99,7 +99,7 @@ const ModelProviderPage = ({ searchText }: Props) => { return (
-
{t('modelProvider.models', { ns: 'common' })}
+
{t('modelProvider.models', { ns: 'common' })}
{ > {defaultModelNotConfigured &&
} {defaultModelNotConfigured && ( -
+
{t('modelProvider.notConfigured', { ns: 'common' })}
@@ -129,8 +129,8 @@ const ModelProviderPage = ({ searchText }: Props) => {
-
{t('modelProvider.emptyProviderTitle', { ns: 'common' })}
-
{t('modelProvider.emptyProviderTip', { ns: 'common' })}
+
{t('modelProvider.emptyProviderTitle', { ns: 'common' })}
+
{t('modelProvider.emptyProviderTip', { ns: 'common' })}
)} {!!filteredConfiguredProviders?.length && ( @@ -145,7 +145,7 @@ const ModelProviderPage = ({ searchText }: Props) => { )} {!!filteredNotConfiguredProviders?.length && ( <> -
{t('modelProvider.toBeConfigured', { ns: 'common' })}
+
{t('modelProvider.toBeConfigured', { ns: 'common' })}
{filteredNotConfiguredProviders?.map(provider => ( ({ - userProfile: userProfilePlaceholder, - currentWorkspace: initialWorkspaceInfo, - isCurrentWorkspaceManager: false, - isCurrentWorkspaceOwner: false, - isCurrentWorkspaceEditor: false, - isCurrentWorkspaceDatasetOperator: false, - mutateUserProfile: noop, - mutateCurrentWorkspace: noop, - langGeniusVersionInfo: initialLangGeniusVersionInfo, - useSelector, - isLoadingCurrentWorkspace: false, - isValidatingCurrentWorkspace: false, -}) - -export function useSelector(selector: (value: AppContextValue) => T): T { - return useContextSelector(AppContext, selector) -} - export type AppContextProviderProps = { children: ReactNode } @@ -170,7 +109,7 @@ export const AppContextProvider: FC = ({ children }) => // Report user and workspace info to Amplitude when loaded if (userProfile?.id) { setUserId(userProfile.email) - const properties: Record = { + const properties: Record = { email: userProfile.email, name: userProfile.name, has_password: userProfile.is_password_set, @@ -213,7 +152,3 @@ export const AppContextProvider: FC = ({ children }) => ) } - -export const useAppContext = () => useContext(AppContext) - -export default AppContext diff --git a/web/context/app-context.ts b/web/context/app-context.ts new file mode 100644 index 0000000000..298e213e7d --- /dev/null +++ b/web/context/app-context.ts @@ -0,0 +1,73 @@ +'use client' + +import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common' +import { noop } from 'es-toolkit/function' +import { createContext, useContext, useContextSelector } from 'use-context-selector' + +export type AppContextValue = { + userProfile: UserProfileResponse + mutateUserProfile: VoidFunction + currentWorkspace: ICurrentWorkspace + isCurrentWorkspaceManager: boolean + isCurrentWorkspaceOwner: boolean + isCurrentWorkspaceEditor: boolean + isCurrentWorkspaceDatasetOperator: boolean + mutateCurrentWorkspace: VoidFunction + langGeniusVersionInfo: LangGeniusVersionResponse + useSelector: typeof useSelector + isLoadingCurrentWorkspace: boolean + isValidatingCurrentWorkspace: boolean +} + +export const userProfilePlaceholder = { + id: '', + name: '', + email: '', + avatar: '', + avatar_url: '', + is_password_set: false, +} + +export const initialLangGeniusVersionInfo = { + current_env: '', + current_version: '', + latest_version: '', + release_date: '', + release_notes: '', + version: '', + can_auto_update: false, +} + +export const initialWorkspaceInfo: ICurrentWorkspace = { + id: '', + name: '', + plan: '', + status: '', + created_at: 0, + role: 'normal', + providers: [], + trial_credits: 200, + trial_credits_used: 0, + next_credit_reset_date: 0, +} + +export const AppContext = createContext({ + userProfile: userProfilePlaceholder, + currentWorkspace: initialWorkspaceInfo, + isCurrentWorkspaceManager: false, + isCurrentWorkspaceOwner: false, + isCurrentWorkspaceEditor: false, + isCurrentWorkspaceDatasetOperator: false, + mutateUserProfile: noop, + mutateCurrentWorkspace: noop, + langGeniusVersionInfo: initialLangGeniusVersionInfo, + useSelector, + isLoadingCurrentWorkspace: false, + isValidatingCurrentWorkspace: false, +}) + +export function useSelector(selector: (value: AppContextValue) => T): T { + return useContextSelector(AppContext, selector) +} + +export const useAppContext = () => useContext(AppContext) diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 00c817eac9..0880084320 100644 --- a/web/eslint-suppressions.json +++ 
b/web/eslint-suppressions.json @@ -301,9 +301,6 @@ } }, "app/account/oauth/authorize/layout.tsx": { - "tailwindcss/enforce-consistent-class-order": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -4697,11 +4694,6 @@ "count": 3 } }, - "app/components/header/account-setting/model-provider-page/index.tsx": { - "tailwindcss/enforce-consistent-class-order": { - "count": 5 - } - }, "app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -6472,11 +6464,6 @@ "count": 2 } }, - "app/components/workflow/__tests__/trigger-status-sync.test.tsx": { - "ts/no-explicit-any": { - "count": 2 - } - }, "app/components/workflow/block-selector/all-start-blocks.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -8488,11 +8475,6 @@ "count": 5 } }, - "app/components/workflow/nodes/tool/__tests__/output-schema-utils.test.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, "app/components/workflow/nodes/tool/components/copy-id.tsx": { "no-restricted-imports": { "count": 1 @@ -8617,11 +8599,6 @@ "count": 1 } }, - "app/components/workflow/nodes/trigger-plugin/utils/__tests__/form-helpers.test.ts": { - "ts/no-explicit-any": { - "count": 2 - } - }, "app/components/workflow/nodes/trigger-plugin/utils/form-helpers.ts": { "ts/no-explicit-any": { "count": 7 @@ -9594,14 +9571,6 @@ "count": 5 } }, - "context/app-context.tsx": { - "react-refresh/only-export-components": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, "context/datasets-context.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -9762,17 +9731,6 @@ "count": 1 } }, - "lib/utils.ts": { - "import/consistent-type-specifier-style": { - "count": 1 - }, - "perfectionist/sort-named-imports": { - "count": 1 - }, - "style/quotes": { - "count": 2 - } - }, "models/common.ts": { "ts/no-explicit-any": { "count": 3 diff --git a/web/package.json b/web/package.json index 5527dcaa37..0633a499b4 100644 --- a/web/package.json +++ b/web/package.json @@ -243,8 +243,9 @@ "tsx": "4.21.0", "typescript": "5.9.3", "uglify-js": "3.19.3", - "vinext": "https://pkg.pr.new/hyoban/vinext@a30ba79", + "vinext": "https://pkg.pr.new/hyoban/vinext@556a6d6", "vite": "8.0.0-beta.16", + "vite-plugin-inspect": "11.3.3", "vite-tsconfig-paths": "6.1.1", "vitest": "4.0.18", "vitest-canvas-mock": "1.1.3" diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 1c5347dd2e..3620dbaae2 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -596,11 +596,14 @@ importers: specifier: 3.19.3 version: 3.19.3 vinext: - specifier: https://pkg.pr.new/hyoban/vinext@a30ba79 - version: https://pkg.pr.new/hyoban/vinext@a30ba79(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + specifier: https://pkg.pr.new/hyoban/vinext@556a6d6 + version: https://pkg.pr.new/hyoban/vinext@556a6d6(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) vite: specifier: 8.0.0-beta.16 version: 
8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite-plugin-inspect: + specifier: 11.3.3 + version: 11.3.3(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) vite-tsconfig-paths: specifier: 6.1.1 version: 6.1.1(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) @@ -2175,20 +2178,13 @@ packages: resolution: {integrity: sha512-tmmZ3lQxAe/k/+rNnXQRawJ4NjxO2hqiOLTHvWchtGZULp4RyFeh6aU4XdOYBFe2KE1oShQTv4AblOs2iOrNnQ==} engines: {node: '>= 10.0.0'} - '@pivanov/utils@0.0.2': - resolution: {integrity: sha512-q9CN0bFWxWgMY5hVVYyBgez1jGiLBa6I+LkG37ycylPhFvEGOOeaADGtUSu46CaZasPnlY8fCdVJZmrgKb1EPA==} - peerDependencies: - react: '>=18' - react-dom: '>=18' - - '@pkgjs/parseargs@0.11.0': - resolution: {integrity: sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==} - engines: {node: '>=14'} - '@pkgr/core@0.2.9': resolution: {integrity: sha512-QNqXyfVS2wm9hweSYD2O7F0G06uurj9kZ96TRQE5Y9hU7+tgdZwIkbAKc5Ocy1HxEY2kuDQa6cQ1WRs/O5LFKA==} engines: {node: ^12.20.0 || ^14.18.0 || >=16.0.0} + '@polka/url@1.0.0-next.29': + resolution: {integrity: sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww==} + '@preact/signals-core@1.12.2': resolution: {integrity: sha512-5Yf8h1Ke3SMHr15xl630KtwPTW4sYDFkkxS0vQ8UiQLWwZQnrF9IKaVG1mN5VcJz52EcWs2acsc/Npjha/7ysA==} @@ -3874,6 +3870,9 @@ packages: birecord@0.1.1: resolution: {integrity: sha512-VUpsf/qykW0heRlC8LooCq28Kxn3mAqKohhDG/49rrsQ1dT1CXyj/pgXS+5BSRzFTR/3DyIBOqQOrGyZOh71Aw==} + birpc@2.9.0: + resolution: {integrity: sha512-KrayHS5pBi69Xi9JmvoqrIgYGDkD6mcSe/i6YKi3w5kekCLzrX4+nawcXqrj2tIp50Kw/mT/s3p+GVK0A0sKxw==} + bl@4.1.0: resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==} @@ -4580,6 +4579,9 @@ packages: resolution: {integrity: sha512-xUtoPkMggbz0MPyPiIWr1Kp4aeWJjDZ6SMvURhimjdZgsRuDplF5/s9hcgGhyXMhs+6vpnuoiZ2kFiu3FMnS8Q==} engines: {node: '>=18'} + error-stack-parser-es@1.0.5: + resolution: {integrity: sha512-5qucVt2XcuGMcEGgWI7i+yZpmpByQ8J1lHhcL7PwqCwu9FPP3VUXzT4ltHe5i2z9dePwEHcDVOAfSnHsOlCXRA==} + es-module-lexer@1.7.0: resolution: {integrity: sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==} @@ -6109,6 +6111,10 @@ packages: moo-color@1.0.3: resolution: {integrity: sha512-i/+ZKXMDf6aqYtBhuOcej71YSlbjT3wCO/4H1j8rPvxDJEifdwgg5MaFyu6iYAT8GBZJg2z0dkgK4YMzvURALQ==} + mrmime@2.0.1: + resolution: {integrity: sha512-Y3wQdFg2Va6etvQ5I82yUhGdsKrcYox6p7FfL1LbK2J4V01F9TGlepTIhnK24t7koZibmg82KGglhA1XK5IsLQ==} + engines: {node: '>=10'} + ms@2.1.3: resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} @@ -6231,6 +6237,9 @@ packages: obug@2.1.1: resolution: {integrity: sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==} + ohash@2.0.11: + resolution: {integrity: sha512-RdR9FQrFwNBNXAr4GixM8YaRZRJ5PUWbKYbE5eOsrwAjJW0q2REGcf79oYPsLyskQCZG1PLN+S/K1V00joZAoQ==} + once@1.4.0: resolution: {integrity: sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==} @@ -6354,6 +6363,9 @@ packages: pend@1.2.0: resolution: {integrity: sha512-F3asv42UuXchdzt+xXqfW1OGlVBe+mxa2mqI0pg5yAHZPvFmY3Y6drSf/GQ1A86WgWEN9Kzh/WrgKa6iGcHXLg==} 
+ perfect-debounce@2.1.0: + resolution: {integrity: sha512-LjgdTytVFXeUgtHZr9WYViYSM/g8MkcTPYDlPa3cDqMirHjKiSZPYd6DoL7pK8AJQr+uWkQvCjHNdiMqsrJs+g==} + periscopic@4.0.2: resolution: {integrity: sha512-sqpQDUy8vgB7ycLkendSKS6HnVz1Rneoc3Rc+ZBUCe2pbqlVuCC5vF52l0NJ1aiMg/r1qfYF9/myz8CZeI2rjA==} @@ -6982,6 +6994,10 @@ packages: simple-swizzle@0.2.4: resolution: {integrity: sha512-nAu1WFPQSMNr2Zn9PGSZK9AGn4t/y97lEm+MXTtUDwfP0ksAIX4nO+6ruD9Jwut4C49SB1Ws+fbXsm/yScWOHw==} + sirv@3.0.2: + resolution: {integrity: sha512-2wcC/oGxHis/BoHkkPwldgiPSYcpZK3JU28WoMVv55yHJgcZ8rlXvuG9iZggz+sU1d4bRgIGASwyWqjxu3FM0g==} + engines: {node: '>=18'} + sisteransi@1.0.5: resolution: {integrity: sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==} @@ -7289,6 +7305,10 @@ packages: resolution: {integrity: sha512-A5F0cM6+mDleacLIEUkmfpkBbnHJFV1d2rprHU2MXNk7mlxHq2zGojA+SRvQD1RoMo9gqjZPWEaKG4v1BQ48lw==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} + totalist@3.0.1: + resolution: {integrity: sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==} + engines: {node: '>=6'} + tough-cookie@6.0.0: resolution: {integrity: sha512-kXuRi1mtaKMrsLUxz3sQYvVl37B0Ns6MzfrtV5DvJceE9bPyspOqk9xxv7XbZWcfLWbFmm997vl83qUWVJA64w==} engines: {node: '>=16'} @@ -7436,6 +7456,10 @@ packages: unpic@4.2.2: resolution: {integrity: sha512-z6T2ScMgRV2y2H8MwwhY5xHZWXhUx/YxtOCGJwfURSl7ypVy4HpLIMWoIZKnnxQa/RKzM0kg8hUh0paIrpLfvw==} + unplugin-utils@0.3.1: + resolution: {integrity: sha512-5lWVjgi6vuHhJ526bI4nlCOmkCIF3nnfXkCMDeMJrtdvxTs6ZFCM8oNufGTsDbKv/tJ/xj8RpvXjRuPBZJuJog==} + engines: {node: '>=20.19.0'} + unplugin@2.1.0: resolution: {integrity: sha512-us4j03/499KhbGP8BU7Hrzrgseo+KdfJYWcbcajCOqsAyb8Gk0Yn2kiUIcZISYCb1JFaZfIuG3b42HmguVOKCQ==} engines: {node: '>=18.12.0'} @@ -7542,8 +7566,8 @@ packages: vfile@6.0.3: resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} - vinext@https://pkg.pr.new/hyoban/vinext@a30ba79: - resolution: {integrity: sha512-yx/gCneOli5eGTrLUq6/M7A6DGQs14qOJW/Xp9RN6sTI0mErKyWWjO5E7FZT98BJbqH5xzI5nk8EOCLF3bojkA==, tarball: https://pkg.pr.new/hyoban/vinext@a30ba79} + vinext@https://pkg.pr.new/hyoban/vinext@556a6d6: + resolution: {integrity: sha512-Sz8RkTDsY6cnGrevlQi4nXgahu8okEGsdKY5m31d/L9tXo35bNETMHfVee5gaI2UKZS9LMcffWaTOxxINUgogQ==, tarball: https://pkg.pr.new/hyoban/vinext@556a6d6} version: 0.0.5 engines: {node: '>=22'} hasBin: true @@ -7552,12 +7576,32 @@ packages: react-dom: '>=19.2.0' vite: ^7.0.0 + vite-dev-rpc@1.1.0: + resolution: {integrity: sha512-pKXZlgoXGoE8sEKiKJSng4hI1sQ4wi5YT24FCrwrLt6opmkjlqPPVmiPWWJn8M8byMxRGzp1CrFuqQs4M/Z39A==} + peerDependencies: + vite: ^2.9.0 || ^3.0.0-0 || ^4.0.0-0 || ^5.0.0-0 || ^6.0.1 || ^7.0.0-0 + + vite-hot-client@2.1.0: + resolution: {integrity: sha512-7SpgZmU7R+dDnSmvXE1mfDtnHLHQSisdySVR7lO8ceAXvM0otZeuQQ6C8LrS5d/aYyP/QZ0hI0L+dIPrm4YlFQ==} + peerDependencies: + vite: ^2.6.0 || ^3.0.0 || ^4.0.0 || ^5.0.0-0 || ^6.0.0-0 || ^7.0.0-0 + vite-plugin-commonjs@0.10.4: resolution: {integrity: sha512-eWQuvQKCcx0QYB5e5xfxBNjQKyrjEWZIR9UOkOV6JAgxVhtbZvCOF+FNC2ZijBJ3U3Px04ZMMyyMyFBVWIJ5+g==} vite-plugin-dynamic-import@1.6.0: resolution: {integrity: sha512-TM0sz70wfzTIo9YCxVFwS8OA9lNREsh+0vMHGSkWDTZ7bgd1Yjs5RV8EgB634l/91IsXJReg0xtmuQqP0mf+rg==} + vite-plugin-inspect@11.3.3: + resolution: {integrity: sha512-u2eV5La99oHoYPHE6UvbwgEqKKOQGz86wMg40CCosP6q8BkB6e5xPneZfYagK4ojPJSj5anHCrnvC20DpwVdRA==} + engines: {node: '>=14'} + peerDependencies: + '@nuxt/kit': 
'*' + vite: ^6.0.0 || ^7.0.0-0 + peerDependenciesMeta: + '@nuxt/kit': + optional: true + vite-plugin-storybook-nextjs@3.2.2: resolution: {integrity: sha512-ZJXCrhi9mW4jEJTKhJ5sUtpBe84mylU40me2aMuLSgIJo4gE/Rc559hZvMYLFTWta1gX7Rm8Co5EEHakPct+wA==} peerDependencies: @@ -9707,16 +9751,10 @@ snapshots: '@parcel/watcher-win32-x64': 2.5.6 optional: true - '@pivanov/utils@0.0.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': - dependencies: - react: 19.2.4 - react-dom: 19.2.4(react@19.2.4) - - '@pkgjs/parseargs@0.11.0': - optional: true - '@pkgr/core@0.2.9': {} + '@polka/url@1.0.0-next.29': {} + '@preact/signals-core@1.12.2': {} '@preact/signals@1.3.2(preact@10.28.2)': @@ -11533,6 +11571,8 @@ snapshots: birecord@0.1.1: {} + birpc@2.9.0: {} + bl@4.1.0: dependencies: buffer: 5.7.1 @@ -12237,6 +12277,8 @@ snapshots: environment@1.1.0: {} + error-stack-parser-es@1.0.5: {} + es-module-lexer@1.7.0: {} es-module-lexer@2.0.0: {} @@ -14300,6 +14342,8 @@ snapshots: dependencies: color-name: 1.1.4 + mrmime@2.0.1: {} + ms@2.1.3: {} mz@2.7.0: @@ -14396,6 +14440,8 @@ snapshots: obug@2.1.1: {} + ohash@2.0.11: {} + once@1.4.0: dependencies: wrappy: 1.0.2 @@ -14549,6 +14595,8 @@ snapshots: pend@1.2.0: {} + perfect-debounce@2.1.0: {} + periscopic@4.0.2: dependencies: '@types/estree': 1.0.8 @@ -15381,6 +15429,12 @@ snapshots: dependencies: is-arrayish: 0.3.4 + sirv@3.0.2: + dependencies: + '@polka/url': 1.0.0-next.29 + mrmime: 2.0.1 + totalist: 3.0.1 + sisteransi@1.0.5: {} size-sensor@1.0.3: {} @@ -15696,6 +15750,8 @@ snapshots: dependencies: eslint-visitor-keys: 5.0.0 + totalist@3.0.1: {} + tough-cookie@6.0.0: dependencies: tldts: 7.0.17 @@ -15840,6 +15896,11 @@ snapshots: unpic@4.2.2: {} + unplugin-utils@0.3.1: + dependencies: + pathe: 2.0.3 + picomatch: 4.0.3 + unplugin@2.1.0: dependencies: acorn: 8.16.0 @@ -15933,7 +15994,7 @@ snapshots: '@types/unist': 3.0.3 vfile-message: 4.0.3 - vinext@https://pkg.pr.new/hyoban/vinext@a30ba79(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)): + vinext@https://pkg.pr.new/hyoban/vinext@556a6d6(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2))(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: '@unpic/react': 1.0.2(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@vercel/og': 0.8.6 @@ -15952,6 +16013,16 @@ snapshots: - typescript - webpack + vite-dev-rpc@1.1.0(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + dependencies: + birpc: 2.9.0 + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite-hot-client: 2.1.0(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + + vite-hot-client@2.1.0(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + dependencies: + vite: 
8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite-plugin-commonjs@0.10.4: dependencies: acorn: 8.16.0 @@ -15965,6 +16036,21 @@ snapshots: fast-glob: 3.3.3 magic-string: 0.30.21 + vite-plugin-inspect@11.3.3(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): + dependencies: + ansis: 4.2.0 + debug: 4.4.3 + error-stack-parser-es: 1.0.5 + ohash: 2.0.11 + open: 10.2.0 + perfect-debounce: 2.1.0 + sirv: 3.0.2 + unplugin-utils: 0.3.1 + vite: 8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vite-dev-rpc: 1.1.0(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)) + transitivePeerDependencies: + - supports-color + vite-plugin-storybook-nextjs@3.2.2(next@16.1.5(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.93.2))(storybook@10.2.13(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(vite@8.0.0-beta.16(@types/node@24.10.12)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)): dependencies: '@next/env': 16.0.0 diff --git a/web/vite.config.ts b/web/vite.config.ts index 152df913f8..b07e7ea7be 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -6,6 +6,7 @@ import react from '@vitejs/plugin-react' import { codeInspectorPlugin } from 'code-inspector-plugin' import vinext from 'vinext' import { defineConfig } from 'vite' +import Inspect from 'vite-plugin-inspect' import tsconfigPaths from 'vite-tsconfig-paths' const __dirname = path.dirname(fileURLToPath(import.meta.url)) @@ -70,6 +71,93 @@ const createForceInspectorClientInjectionPlugin = (): Plugin => { } } +function customI18nHmrPlugin(): Plugin { + const injectTarget = inspectorInjectTarget + const i18nHmrClientMarker = 'custom-i18n-hmr-client' + const i18nHmrClientSnippet = `/* ${i18nHmrClientMarker} */ +if (import.meta.hot) { + const getI18nUpdateTarget = (file) => { + const match = file.match(/[/\\\\]i18n[/\\\\]([^/\\\\]+)[/\\\\]([^/\\\\]+)\\.json$/) + if (!match) + return null + const [, locale, namespaceFile] = match + return { locale, namespaceFile } + } + + import.meta.hot.on('i18n-update', async ({ file, content }) => { + const target = getI18nUpdateTarget(file) + if (!target) + return + + const [{ getI18n }, { camelCase }] = await Promise.all([ + import('react-i18next'), + import('es-toolkit/string'), + ]) + + const i18n = getI18n() + if (!i18n) + return + if (target.locale !== i18n.language) + return + + let resources + try { + resources = JSON.parse(content) + } + catch { + return + } + + const namespace = camelCase(target.namespaceFile) + i18n.addResourceBundle(target.locale, namespace, resources, true, true) + i18n.emit('languageChanged', i18n.language) + }) +} +` + + const injectI18nHmrClient = (code: string) => { + if (code.includes(i18nHmrClientMarker)) + return code + + const useClientMatch = code.match(/(['"])use client\1;?\s*\n/) + if (!useClientMatch) + return `${i18nHmrClientSnippet}\n${code}` + + const insertAt = (useClientMatch.index ?? 
0) + useClientMatch[0].length + return `${code.slice(0, insertAt)}\n${i18nHmrClientSnippet}\n${code.slice(insertAt)}` + } + + return { + name: 'custom-i18n-hmr', + apply: 'serve', + handleHotUpdate({ file, server }) { + if (file.endsWith('.json') && file.includes('/i18n/')) { + server.ws.send({ + type: 'custom', + event: 'i18n-update', + data: { + file, + content: fs.readFileSync(file, 'utf-8'), + }, + }) + + // return empty array to prevent the default HMR + return [] + } + }, + transform(code, id) { + const cleanId = normalizeInspectorModuleId(id) + if (cleanId !== injectTarget) + return null + + const nextCode = injectI18nHmrClient(code) + if (nextCode === code) + return null + return { code: nextCode, map: null } + }, + } +} + export default defineConfig(({ mode }) => { const isTest = mode === 'test' @@ -89,10 +177,12 @@ export default defineConfig(({ mode }) => { } as Plugin, ] : [ + Inspect(), createCodeInspectorPlugin(), createForceInspectorClientInjectionPlugin(), react(), vinext(), + customI18nHmrPlugin(), ], resolve: { alias: { From 89a859ae328200c5f7be80ac88a435a8d43071be Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Thu, 5 Mar 2026 13:53:20 +0800 Subject: [PATCH 039/256] refactor: simplify oauthNewUser state handling in AppInitializer component (#33019) Co-authored-by: CodingOnStar --- web/app/components/app-initializer.tsx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index e4cd10175a..bf7aa39580 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -26,11 +26,10 @@ export const AppInitializer = ({ // Tokens are now stored in cookies, no need to check localStorage const pathname = usePathname() const [init, setInit] = useState(false) - const [oauthNewUser, setOauthNewUser] = useQueryState( + const [oauthNewUser] = useQueryState( 'oauth_new_user', parseAsBoolean.withOptions({ history: 'replace' }), ) - const isSetupFinished = useCallback(async () => { try { const setUpStatus = await fetchSetupStatusWithCache() @@ -69,11 +68,12 @@ export const AppInitializer = ({ ...utmInfo, }) - // Clean up: remove utm_info cookie and URL params Cookies.remove('utm_info') - setOauthNewUser(null) } + if (oauthNewUser !== null) + router.replace(pathname) + if (action === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION) localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes') @@ -96,7 +96,7 @@ export const AppInitializer = ({ router.replace('/signin') } })() - }, [isSetupFinished, router, pathname, searchParams, oauthNewUser, setOauthNewUser]) + }, [isSetupFinished, router, pathname, searchParams, oauthNewUser]) return init ? 
children : null } From 7432b58f82e84408fe0828b7f8cec5d99f212944 Mon Sep 17 00:00:00 2001 From: 99 Date: Thu, 5 Mar 2026 14:31:28 +0800 Subject: [PATCH 040/256] refactor(dify_graph): introduce run_context and delegate child engine creation (#32964) --- api/.importlinter | 13 -- api/core/app/apps/pipeline/pipeline_runner.py | 16 +- api/core/app/apps/workflow_app_runner.py | 27 ++-- api/core/app/entities/app_invoke_entities.py | 66 +++++++- api/core/app/workflow/layers/llm_quota.py | 3 +- api/core/workflow/node_factory.py | 18 ++- api/core/workflow/workflow_entry.py | 92 ++++++++++-- api/dify_graph/entities/graph_init_params.py | 8 +- api/dify_graph/enums.py | 33 ---- api/dify_graph/graph_engine/graph_engine.py | 27 +++- api/dify_graph/nodes/agent/agent_node.py | 32 ++-- api/dify_graph/nodes/base/node.py | 59 +++++++- .../nodes/datasource/datasource_node.py | 9 +- api/dify_graph/nodes/http_request/node.py | 7 +- .../nodes/human_input/human_input_node.py | 32 ++-- .../nodes/iteration/iteration_node.py | 53 ++----- .../knowledge_index/knowledge_index_node.py | 6 +- .../knowledge_retrieval_node.py | 22 +-- api/dify_graph/nodes/llm/node.py | 9 +- api/dify_graph/nodes/loop/loop_node.py | 34 +---- .../parameter_extractor_node.py | 2 +- .../question_classifier_node.py | 7 +- api/dify_graph/nodes/tool/tool_node.py | 17 ++- api/dify_graph/nodes/trigger_webhook/node.py | 3 +- api/dify_graph/runtime/__init__.py | 10 +- api/dify_graph/runtime/graph_runtime_state.py | 52 +++++++ api/services/workflow_app_service.py | 11 +- api/services/workflow_service.py | 16 +- .../workflow/nodes/test_code.py | 12 +- .../workflow/nodes/test_http.py | 18 +-- .../workflow/nodes/test_llm.py | 12 +- .../nodes/test_parameter_extractor.py | 12 +- .../workflow/nodes/test_template_transform.py | 12 +- .../workflow/nodes/test_tool.py | 12 +- .../test_human_input_resume_node_execution.py | 8 +- .../core/app/apps/test_pause_resume.py | 8 +- .../graph/test_graph_skip_validation.py | 14 +- .../workflow/graph/test_graph_validation.py | 14 +- .../graph_engine/layers/test_llm_quota.py | 5 + .../graph_engine/test_auto_mock_system.py | 23 ++- .../graph_engine/test_command_system.py | 47 +++--- .../graph_engine/test_graph_state_snapshot.py | 7 +- .../test_human_input_pause_multi_branch.py | 8 +- .../test_human_input_pause_single_branch.py | 8 +- .../graph_engine/test_if_else_streaming.py | 9 +- .../test_mock_iteration_simple.py | 53 ++++--- .../workflow/graph_engine/test_mock_nodes.py | 12 +- .../test_mock_nodes_template_code.py | 141 +++++++++++------- .../workflow/graph_engine/test_mock_simple.py | 36 +++-- .../test_parallel_human_input_join_resume.py | 8 +- ...rallel_human_input_pause_missing_finish.py | 8 +- .../test_parallel_streaming_workflow.py | 16 +- .../test_pause_deferred_ready_nodes.py | 8 +- .../graph_engine/test_pause_resume_state.py | 8 +- .../graph_engine/test_table_runner.py | 71 +++++++-- .../core/workflow/nodes/answer/test_answer.py | 12 +- .../nodes/datasource/test_datasource_node.py | 15 +- .../http_request/test_http_request_node.py | 12 +- .../nodes/human_input/test_entities.py | 57 ++++--- .../test_human_input_form_filled_event.py | 34 +++-- .../test_iteration_child_engine_errors.py | 100 +++++++++++++ .../test_knowledge_index_node.py | 12 +- .../test_knowledge_retrieval_node.py | 12 +- .../workflow/nodes/list_operator/node_spec.py | 112 ++++++-------- .../core/workflow/nodes/llm/test_node.py | 10 +- .../template_transform_node_spec.py | 13 +- .../core/workflow/nodes/test_base_node.py | 36 ++--- 
.../nodes/test_document_extractor_node.py | 11 +- .../core/workflow/nodes/test_if_else.py | 51 ++++--- .../core/workflow/nodes/test_list_operator.py | 19 ++- .../nodes/test_start_node_json_object.py | 8 +- .../workflow/nodes/tool/test_tool_node.py | 8 +- .../v1/test_variable_assigner_v1.py | 46 +++--- .../v2/test_variable_assigner_v2.py | 74 +++++---- .../webhook/test_webhook_file_conversion.py | 21 +-- .../nodes/webhook/test_webhook_node.py | 21 +-- .../test_workflow_entry_redis_channel.py | 3 +- api/tests/workflow_test_utils.py | 53 +++++++ 78 files changed, 1281 insertions(+), 733 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py create mode 100644 api/tests/workflow_test_utils.py diff --git a/api/.importlinter b/api/.importlinter index 10faeb448a..e4536b1f10 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -28,17 +28,8 @@ ignore_imports = dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events dify_graph.nodes.loop.loop_node -> dify_graph.graph_events - dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory - dify_graph.nodes.loop.loop_node -> core.workflow.node_factory - dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota - dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine.command_channels dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine - dify_graph.nodes.loop.loop_node -> dify_graph.graph - dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine.command_channels # TODO(QuantumGhost): fix the import violation later dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities @@ -101,12 +92,9 @@ forbidden_modules = core.trigger core.variables ignore_imports = - dify_graph.nodes.loop.loop_node -> core.workflow.node_factory dify_graph.nodes.agent.agent_node -> core.model_manager dify_graph.nodes.agent.agent_node -> core.provider_manager dify_graph.nodes.agent.agent_node -> core.tools.tool_manager - dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory - dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota dify_graph.nodes.llm.llm_utils -> core.model_manager dify_graph.nodes.llm.protocols -> core.model_manager dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model @@ -151,7 +139,6 @@ ignore_imports = dify_graph.nodes.llm.node -> extensions.ext_database dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.nodes.agent.agent_node -> models - dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota dify_graph.nodes.llm.node -> models.model dify_graph.nodes.agent.agent_node -> services dify_graph.nodes.tool.tool_node -> services diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 748edb7956..4222aae809 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -8,12 +8,14 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, + UserFrom, + build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from 
core.workflow.node_factory import DifyNodeFactory from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.graph_init_params import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowType +from dify_graph.enums import WorkflowType from dify_graph.graph import Graph from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository @@ -256,13 +258,15 @@ class PipelineRunner(WorkflowBasedAppRunner): # init graph # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, workflow_id=workflow.id, graph_config=graph_config, - user_id=self.application_generate_entity.user_id, - user_from=user_from, - invoke_from=invoke_from, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id=self.application_generate_entity.user_id, + user_from=user_from, + invoke_from=invoke_from, + ), call_depth=0, ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index c5a00aa4ff..7ef6ff7cc2 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, @@ -33,7 +33,6 @@ from core.workflow.node_factory import DifyNodeFactory from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import UserFrom from dify_graph.graph import Graph from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import ( @@ -119,13 +118,15 @@ class WorkflowBasedAppRunner: # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=tenant_id or "", - app_id=self._app_id, workflow_id=workflow_id, graph_config=graph_config, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, + run_context=build_dify_run_context( + tenant_id=tenant_id or "", + app_id=self._app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ), call_depth=0, ) @@ -267,13 +268,15 @@ class WorkflowBasedAppRunner: # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, workflow_id=workflow.id, graph_config=graph_config, - user_id="", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id="", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 6ecca84425..ecbb1cf2f3 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,4 +1,5 @@ from collections.abc import Mapping, Sequence +from enum import StrEnum from typing import TYPE_CHECKING, Any, Optional from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, 
field_validator @@ -6,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from dify_graph.enums import InvokeFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.file import File, FileUploadConfig from dify_graph.model_runtime.entities.model_entities import AIModelEntity @@ -14,6 +15,69 @@ if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +class UserFrom(StrEnum): + ACCOUNT = "account" + END_USER = "end-user" + + +class InvokeFrom(StrEnum): + SERVICE_API = "service-api" + WEB_APP = "web-app" + TRIGGER = "trigger" + EXPLORE = "explore" + DEBUGGER = "debugger" + PUBLISHED_PIPELINE = "published" + VALIDATION = "validation" + + @classmethod + def value_of(cls, value: str) -> "InvokeFrom": + return cls(value) + + def to_source(self) -> str: + source_mapping = { + InvokeFrom.WEB_APP: "web_app", + InvokeFrom.DEBUGGER: "dev", + InvokeFrom.EXPLORE: "explore_app", + InvokeFrom.TRIGGER: "trigger", + InvokeFrom.SERVICE_API: "api", + } + return source_mapping.get(self, "dev") + + +class DifyRunContext(BaseModel): + tenant_id: str + app_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + +def build_dify_run_context( + *, + tenant_id: str, + app_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + extra_context: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + """ + Build graph run_context with the reserved Dify runtime payload. + + `extra_context` can carry user-defined context keys. The reserved `_dify` + payload is always overwritten by this function to keep one canonical source. + """ + run_context = dict(extra_context) if extra_context else {} + run_context[DIFY_RUN_CONTEXT_KEY] = DifyRunContext( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + return run_context + + class ModelConfigWithCredentialsEntity(BaseModel): """ Model Config With Credentials Entity. 
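A minimal sketch of how the new run_context plumbing composes, included here for illustration only (it is not part of the patch). All ids and the "trace_id" key are placeholder values; the load-bearing names — build_dify_run_context, GraphInitParams, and Node.require_dify_context — are exactly the ones introduced above:

    from core.app.entities.app_invoke_entities import (
        InvokeFrom,
        UserFrom,
        build_dify_run_context,
    )
    from dify_graph.entities import GraphInitParams

    # The reserved "_dify" payload is always written by the helper; any
    # caller-supplied keys in extra_context are carried alongside it.
    run_context = build_dify_run_context(
        tenant_id="tenant-1",        # placeholder
        app_id="app-1",              # placeholder
        user_id="user-1",            # placeholder
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        extra_context={"trace_id": "trace-1"},  # placeholder user-defined key
    )

    # GraphInitParams no longer carries tenant/app/user fields directly;
    # they travel inside run_context instead.
    params = GraphInitParams(
        workflow_id="wf-1",          # placeholder
        graph_config={"nodes": [], "edges": []},
        run_context=run_context,
        call_depth=0,
    )

    # Inside a node implementation, the typed payload is read back through
    # the accessor added in this patch; it raises ValueError if the "_dify"
    # key is absent:
    #     dify_ctx = self.require_dify_context()
    #     tenant_id = dify_ctx.tenant_id
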
diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index f6c80d25c5..2e930a1f58 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -75,8 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer): return try: + dify_ctx = node.require_dify_context() deduct_llm_quota( - tenant_id=node.tenant_id, + tenant_id=dify_ctx.tenant_id, model_instance=model_instance, usage=result_event.node_run_result.llm_usage, ) diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 714b0ca3d0..4cbee08a65 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config +from core.app.entities.app_invoke_entities import DifyRunContext from core.app.llm.model_access import build_dify_model_access from core.datasource.datasource_manager import DatasourceManager from core.helper.code_executor.code_executor import ( @@ -22,6 +23,7 @@ from core.rag.summary_index.summary_index import SummaryIndex from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.tools.tool_file_manager import ToolFileManager from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, SystemVariableKey from dify_graph.file.file_manager import file_manager from dify_graph.graph.graph import NodeFactory @@ -110,6 +112,7 @@ class DifyNodeFactory(NodeFactory): ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state + self._dify_context = self._resolve_dify_context(graph_init_params.run_context) self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor() self._code_limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, @@ -141,7 +144,16 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + + @staticmethod + def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext: + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) @override def create_node(self, node_config: NodeConfigDict) -> Node: @@ -213,7 +225,7 @@ class DifyNodeFactory(NodeFactory): config=node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=self.graph_init_params.tenant_id), + form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), ) if node_type == NodeType.KNOWLEDGE_INDEX: @@ -356,7 +368,7 @@ class DifyNodeFactory(NodeFactory): ) return fetch_memory( conversation_id=conversation_id, - app_id=self.graph_init_params.app_id, + app_id=self._dify_context.app_id, node_data_memory=node_memory, model_instance=model_instance, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 6210a81c4e..c259e7ac08 100644 --- 
a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,26 +5,26 @@ from typing import Any, cast from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer from core.workflow.node_factory import DifyNodeFactory from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict -from dify_graph.enums import UserFrom from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.file.models import File from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_engine.protocols.command_channel import CommandChannel from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from dify_graph.nodes import NodeType from dify_graph.nodes.base.node import Node from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from extensions.otel.runtime import is_instrument_flag_enabled @@ -34,6 +34,66 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) +class _WorkflowChildEngineBuilder: + @staticmethod + def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None: + """ + Return whether `graph_config["nodes"]` contains the given node id. + + Returns `None` when the nodes payload shape is unexpected, so graph-level + validation can surface the original configuration error. 
+ """ + nodes = graph_config.get("nodes") + if not isinstance(nodes, list): + return None + + for node in nodes: + if not isinstance(node, Mapping): + return None + current_id = node.get("id") + if isinstance(current_id, str) and current_id == node_id: + return True + return False + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) + if has_root_node is False: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + child_graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=root_node_id, + ) + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + child_engine.layer(LLMQuotaLayer()) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + class WorkflowEntry: def __init__( self, @@ -77,6 +137,7 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, graph=graph, @@ -88,6 +149,7 @@ class WorkflowEntry: scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD, scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME, ), + child_engine_builder=self._child_engine_builder, ) # Add debug logging layer when in debug mode @@ -154,13 +216,15 @@ class WorkflowEntry: # init graph init params and runtime state graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, workflow_id=workflow.id, graph_config=workflow.graph_dict, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -293,13 +357,15 @@ class WorkflowEntry: # init graph init params and runtime state graph_init_params = GraphInitParams( - tenant_id=tenant_id, - app_id="", workflow_id="", graph_config=graph_dict, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=tenant_id, + app_id="", + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/dify_graph/entities/graph_init_params.py b/api/dify_graph/entities/graph_init_params.py index 3712842aaf..f785d58a52 100644 --- a/api/dify_graph/entities/graph_init_params.py +++ b/api/dify_graph/entities/graph_init_params.py @@ -3,7 +3,7 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.enums import InvokeFrom, UserFrom +DIFY_RUN_CONTEXT_KEY = "_dify" class GraphInitParams(BaseModel): @@ -18,11 +18,7 @@ class 
GraphInitParams(BaseModel): """ # init params - tenant_id: str = Field(..., description="tenant / workspace id") - app_id: str = Field(..., description="app id") workflow_id: str = Field(..., description="workflow id") graph_config: Mapping[str, Any] = Field(..., description="graph config") - user_id: str = Field(..., description="user id") - user_from: UserFrom = Field(..., description="user from, account or end-user") - invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") + run_context: Mapping[str, Any] = Field(..., description="runtime context") call_depth: int = Field(..., description="call depth") diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py index 6c0593945e..bb3b13e8c6 100644 --- a/api/dify_graph/enums.py +++ b/api/dify_graph/enums.py @@ -33,39 +33,6 @@ class SystemVariableKey(StrEnum): INVOKE_FROM = "invoke_from" -class UserFrom(StrEnum): - ACCOUNT = "account" - END_USER = "end-user" - - -class InvokeFrom(StrEnum): - SERVICE_API = "service-api" - WEB_APP = "web-app" - TRIGGER = "trigger" - EXPLORE = "explore" - DEBUGGER = "debugger" - PUBLISHED_PIPELINE = "published" - VALIDATION = "validation" - - @classmethod - def value_of(cls, value: str) -> "InvokeFrom": - return cls(value) - - def to_source(self) -> str: - """Get source of invoke from. - - :return: source - """ - source_mapping = { - InvokeFrom.WEB_APP: "web_app", - InvokeFrom.DEBUGGER: "dev", - InvokeFrom.EXPLORE: "explore_app", - InvokeFrom.TRIGGER: "trigger", - InvokeFrom.SERVICE_API: "api", - } - return source_mapping.get(self, "dev") - - class NodeType(StrEnum): START = "start" END = "end" diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py index 772e607328..ea98a46b06 100644 --- a/api/dify_graph/graph_engine/graph_engine.py +++ b/api/dify_graph/graph_engine/graph_engine.py @@ -9,7 +9,7 @@ from __future__ import annotations import logging import queue -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, cast, final from dify_graph.context import capture_current_context @@ -27,6 +27,7 @@ from dify_graph.graph_events import ( GraphRunSucceededEvent, ) from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol if TYPE_CHECKING: # pragma: no cover - used only for static analysis from dify_graph.runtime.graph_runtime_state import GraphProtocol @@ -49,6 +50,7 @@ from .protocols.command_channel import CommandChannel from .worker_management import WorkerPool if TYPE_CHECKING: + from dify_graph.entities import GraphInitParams from dify_graph.graph_engine.domain.graph_execution import GraphExecution from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator @@ -74,6 +76,7 @@ class GraphEngine: graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, config: GraphEngineConfig = _DEFAULT_CONFIG, + child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" @@ -83,6 +86,9 @@ class GraphEngine: self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel self._config = config + self._child_engine_builder = child_engine_builder + if child_engine_builder is not None: + self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) # Graph execution 
tracks the overall execution state self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) @@ -214,6 +220,25 @@ class GraphEngine: self._bind_layer_context(layer) return self + def create_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: dict[str, object] | Mapping[str, object], + root_node_id: str, + layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (), + ) -> GraphEngine: + return self._graph_runtime_state.create_child_engine( + workflow_id=workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + graph_config=graph_config, + root_node_id=root_node_id, + layers=layers, + ) + def run(self) -> Generator[GraphEngineEvent, None, None]: """ Execute the graph using the modular architecture. diff --git a/api/dify_graph/nodes/agent/agent_node.py b/api/dify_graph/nodes/agent/agent_node.py index f55871718f..d770f7afd1 100644 --- a/api/dify_graph/nodes/agent/agent_node.py +++ b/api/dify_graph/nodes/agent/agent_node.py @@ -80,9 +80,11 @@ class AgentNode(Node[AgentNodeData]): def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError + dify_ctx = self.require_dify_context() + try: strategy = get_plugin_agent_strategy( - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, agent_strategy_name=self.node_data.agent_strategy_name, ) @@ -120,8 +122,8 @@ class AgentNode(Node[AgentNodeData]): try: message_stream = strategy.invoke( params=parameters, - user_id=self.user_id, - app_id=self.app_id, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, conversation_id=conversation_id.text if conversation_id else None, credentials=credentials, ) @@ -144,8 +146,8 @@ class AgentNode(Node[AgentNodeData]): "agent_strategy": self.node_data.agent_strategy_name, }, parameters_for_log=parameters_for_log, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, @@ -283,8 +285,13 @@ class AgentNode(Node[AgentNodeData]): runtime_variable_pool: VariablePool | None = None if node_data.version != "1" or node_data.tool_node_version is not None: runtime_variable_pool = variable_pool + dify_ctx = self.require_dify_context() tool_runtime = ToolManager.get_agent_tool_runtime( - self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool + dify_ctx.tenant_id, + dify_ctx.app_id, + entity, + dify_ctx.invoke_from, + runtime_variable_pool, ) if tool_runtime.entity.description: tool_runtime.entity.description.llm = ( @@ -396,7 +403,8 @@ class AgentNode(Node[AgentNodeData]): from core.plugin.impl.plugin import PluginInstaller manager = PluginInstaller() - plugins = manager.list_plugins(self.tenant_id) + dify_ctx = self.require_dify_context() + plugins = manager.list_plugins(dify_ctx.tenant_id) try: current_plugin = next( plugin @@ -417,8 +425,11 @@ class AgentNode(Node[AgentNodeData]): return None conversation_id = conversation_id_variable.value + dify_ctx = self.require_dify_context() with Session(db.engine, expire_on_commit=False) as session: - stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + stmt = select(Conversation).where( + Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id + ) conversation = 
diff --git a/api/dify_graph/nodes/agent/agent_node.py b/api/dify_graph/nodes/agent/agent_node.py
index f55871718f..d770f7afd1 100644
--- a/api/dify_graph/nodes/agent/agent_node.py
+++ b/api/dify_graph/nodes/agent/agent_node.py
@@ -80,9 +80,11 @@ class AgentNode(Node[AgentNodeData]):
     def _run(self) -> Generator[NodeEventBase, None, None]:
         from core.plugin.impl.exc import PluginDaemonClientSideError
 
+        dify_ctx = self.require_dify_context()
+
         try:
             strategy = get_plugin_agent_strategy(
-                tenant_id=self.tenant_id,
+                tenant_id=dify_ctx.tenant_id,
                 agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
                 agent_strategy_name=self.node_data.agent_strategy_name,
             )
@@ -120,8 +122,8 @@ class AgentNode(Node[AgentNodeData]):
         try:
             message_stream = strategy.invoke(
                 params=parameters,
-                user_id=self.user_id,
-                app_id=self.app_id,
+                user_id=dify_ctx.user_id,
+                app_id=dify_ctx.app_id,
                 conversation_id=conversation_id.text if conversation_id else None,
                 credentials=credentials,
             )
@@ -144,8 +146,8 @@ class AgentNode(Node[AgentNodeData]):
                 "agent_strategy": self.node_data.agent_strategy_name,
             },
             parameters_for_log=parameters_for_log,
-            user_id=self.user_id,
-            tenant_id=self.tenant_id,
+            user_id=dify_ctx.user_id,
+            tenant_id=dify_ctx.tenant_id,
             node_type=self.node_type,
             node_id=self._node_id,
             node_execution_id=self.id,
@@ -283,8 +285,13 @@ class AgentNode(Node[AgentNodeData]):
         runtime_variable_pool: VariablePool | None = None
         if node_data.version != "1" or node_data.tool_node_version is not None:
             runtime_variable_pool = variable_pool
+        dify_ctx = self.require_dify_context()
         tool_runtime = ToolManager.get_agent_tool_runtime(
-            self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
+            dify_ctx.tenant_id,
+            dify_ctx.app_id,
+            entity,
+            dify_ctx.invoke_from,
+            runtime_variable_pool,
         )
         if tool_runtime.entity.description:
             tool_runtime.entity.description.llm = (
@@ -396,7 +403,8 @@ class AgentNode(Node[AgentNodeData]):
         from core.plugin.impl.plugin import PluginInstaller
 
         manager = PluginInstaller()
-        plugins = manager.list_plugins(self.tenant_id)
+        dify_ctx = self.require_dify_context()
+        plugins = manager.list_plugins(dify_ctx.tenant_id)
         try:
             current_plugin = next(
                 plugin
@@ -417,8 +425,11 @@ class AgentNode(Node[AgentNodeData]):
             return None
         conversation_id = conversation_id_variable.value
 
+        dify_ctx = self.require_dify_context()
         with Session(db.engine, expire_on_commit=False) as session:
-            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
+            stmt = select(Conversation).where(
+                Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
+            )
             conversation = session.scalar(stmt)
 
             if not conversation:
@@ -429,9 +440,10 @@ class AgentNode(Node[AgentNodeData]):
         return memory
 
     def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
+        dify_ctx = self.require_dify_context()
         provider_manager = ProviderManager()
         provider_model_bundle = provider_manager.get_provider_model_bundle(
-            tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
+            tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
        )
         model_name = value.get("model", "")
         model_credentials = provider_model_bundle.configuration.get_current_credentials(
@@ -440,7 +452,7 @@ class AgentNode(Node[AgentNodeData]):
         provider_name = provider_model_bundle.configuration.provider.provider
         model_type_instance = provider_model_bundle.model_type_instance
         model_instance = ModelManager().get_model_instance(
-            tenant_id=self.tenant_id,
+            tenant_id=dify_ctx.tenant_id,
             provider=provider_name,
             model_type=ModelType(value.get("model_type", "")),
             model=model_name,
diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py
index 8eaf0b16b3..1f99a0a6e2 100644
--- a/api/dify_graph/nodes/base/node.py
+++ b/api/dify_graph/nodes/base/node.py
@@ -8,10 +8,11 @@ from abc import abstractmethod
 from collections.abc import Generator, Mapping, Sequence
 from functools import singledispatchmethod
 from types import MappingProxyType
-from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
+from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
 from uuid import uuid4
 
 from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
+from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
 from dify_graph.enums import (
     ErrorStrategy,
     NodeExecutionType,
@@ -64,10 +65,28 @@ from libs.datetime_utils import naive_utc_now
 from .entities import BaseNodeData, RetryConfig
 
 NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
+_MISSING_RUN_CONTEXT_VALUE = object()
 
 logger = logging.getLogger(__name__)
 
 
+class DifyRunContextProtocol(Protocol):
+    tenant_id: str
+    app_id: str
+    user_id: str
+    user_from: Any
+    invoke_from: Any
+
+
+class _MappingDifyRunContext:
+    def __init__(self, mapping: Mapping[str, Any]) -> None:
+        self.tenant_id = str(mapping["tenant_id"])
+        self.app_id = str(mapping["app_id"])
+        self.user_id = str(mapping["user_id"])
+        self.user_from = mapping["user_from"]
+        self.invoke_from = mapping["invoke_from"]
+
+
 class Node(Generic[NodeDataT]):
     """BaseNode serves as the foundational class for all node implementations.
@@ -227,14 +246,10 @@ class Node(Generic[NodeDataT]):
         graph_runtime_state: GraphRuntimeState,
     ) -> None:
         self._graph_init_params = graph_init_params
+        self._run_context = MappingProxyType(dict(graph_init_params.run_context))
         self.id = id
-        self.tenant_id = graph_init_params.tenant_id
-        self.app_id = graph_init_params.app_id
         self.workflow_id = graph_init_params.workflow_id
         self.graph_config = graph_init_params.graph_config
-        self.user_id = graph_init_params.user_id
-        self.user_from = graph_init_params.user_from
-        self.invoke_from = graph_init_params.invoke_from
         self.workflow_call_depth = graph_init_params.call_depth
         self.graph_runtime_state = graph_runtime_state
         self.state: NodeState = NodeState.UNKNOWN  # node execution state
@@ -263,6 +278,38 @@ class Node(Generic[NodeDataT]):
     def graph_init_params(self) -> GraphInitParams:
         return self._graph_init_params
 
+    @property
+    def run_context(self) -> Mapping[str, Any]:
+        return self._run_context
+
+    def get_run_context_value(self, key: str, default: Any = None) -> Any:
+        return self._run_context.get(key, default)
+
+    def require_run_context_value(self, key: str) -> Any:
+        value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE)
+        if value is _MISSING_RUN_CONTEXT_VALUE:
+            raise ValueError(f"run_context missing required key: {key}")
+        return value
+
+    def require_dify_context(self) -> DifyRunContextProtocol:
+        raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
+        if raw_ctx is None:
+            raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
+
+        if isinstance(raw_ctx, Mapping):
+            missing_keys = [
+                key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
+            ]
+            if missing_keys:
+                raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
+            return _MappingDifyRunContext(raw_ctx)
+
+        for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
+            if not hasattr(raw_ctx, attr):
+                raise TypeError(f"invalid dify context object, missing attribute: {attr}")
+
+        return cast(DifyRunContextProtocol, raw_ctx)
+
     @property
     def execution_id(self) -> str:
         return self._node_execution_id
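The accessor above accepts two context shapes: a plain mapping or any object exposing the five identity attributes. A standalone sketch of that duck-typing rule, simplified from require_dify_context (the mapping wrapper is elided; this is illustrative, not the shipped code):

    from collections.abc import Mapping
    from typing import Any

    _REQUIRED = ("tenant_id", "app_id", "user_id", "user_from", "invoke_from")


    def resolve_dify_context(raw_ctx: Any) -> Any:
        # Simplified mirror of Node.require_dify_context's branching: mappings
        # are validated key-by-key (the real code wraps them in
        # _MappingDifyRunContext), other objects are duck-typed by attribute.
        if isinstance(raw_ctx, Mapping):
            missing = [key for key in _REQUIRED if key not in raw_ctx]
            if missing:
                raise ValueError(f"dify context missing required keys: {', '.join(missing)}")
            return raw_ctx
        for attr in _REQUIRED:
            if not hasattr(raw_ctx, attr):
                raise TypeError(f"invalid dify context object, missing attribute: {attr}")
        return raw_ctx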
diff --git a/api/dify_graph/nodes/datasource/datasource_node.py b/api/dify_graph/nodes/datasource/datasource_node.py
index 802d34d2d0..b97394744e 100644
--- a/api/dify_graph/nodes/datasource/datasource_node.py
+++ b/api/dify_graph/nodes/datasource/datasource_node.py
@@ -52,6 +52,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
         Run the datasource node
         """
 
+        dify_ctx = self.require_dify_context()
         node_data = self.node_data
         variable_pool = self.graph_runtime_state.variable_pool
         datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
@@ -75,7 +76,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
                 datasource_info["icon"] = self.datasource_manager.get_icon_url(
                     provider_id=provider_id,
                     datasource_name=node_data.datasource_name or "",
-                    tenant_id=self.tenant_id,
+                    tenant_id=dify_ctx.tenant_id,
                     datasource_type=datasource_type.value,
                 )
 
@@ -104,11 +105,11 @@ class DatasourceNode(Node[DatasourceNodeData]):
 
                 yield from self.datasource_manager.stream_node_events(
                     node_id=self._node_id,
-                    user_id=self.user_id,
+                    user_id=dify_ctx.user_id,
                     datasource_name=node_data.datasource_name or "",
                     datasource_type=datasource_type.value,
                     provider_id=provider_id,
-                    tenant_id=self.tenant_id,
+                    tenant_id=dify_ctx.tenant_id,
                     provider=node_data.provider_name,
                     plugin_id=node_data.plugin_id,
                     credential_id=credential_id,
@@ -136,7 +137,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
                         raise DatasourceNodeError("File is not exist")
                     file_info = self.datasource_manager.get_upload_file_by_id(
-                        file_id=related_id, tenant_id=self.tenant_id
+                        file_id=related_id, tenant_id=dify_ctx.tenant_id
                     )
                     variable_pool.add([self._node_id, "file"], file_info)
                     # variable_pool.add([self.node_id, "file"], file_info.to_dict())
diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py
index ae0faa8a56..2e48d5502a 100644
--- a/api/dify_graph/nodes/http_request/node.py
+++ b/api/dify_graph/nodes/http_request/node.py
@@ -212,6 +212,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
         """
         Extract files from response by checking both Content-Type header and URL
         """
+        dify_ctx = self.require_dify_context()
         files: list[File] = []
         is_file = response.is_file
         content_type = response.content_type
@@ -236,8 +237,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
 
             tool_file_manager = self._tool_file_manager_factory()
             tool_file = tool_file_manager.create_file_by_raw(
-                user_id=self.user_id,
-                tenant_id=self.tenant_id,
+                user_id=dify_ctx.user_id,
+                tenant_id=dify_ctx.tenant_id,
                 conversation_id=None,
                 file_binary=content,
                 mimetype=mime_type,
@@ -249,7 +250,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
             }
             file = file_factory.build_from_mapping(
                 mapping=mapping,
-                tenant_id=self.tenant_id,
+                tenant_id=dify_ctx.tenant_id,
             )
             files.append(file)
diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py
index e54650898d..03c2d17b1d 100644
--- a/api/dify_graph/nodes/human_input/human_input_node.py
+++ b/api/dify_graph/nodes/human_input/human_input_node.py
@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
 from dify_graph.entities.pause_reason import HumanInputRequired
-from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
+from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
 from dify_graph.node_events import (
     HumanInputFormFilledEvent,
     HumanInputFormTimeoutEvent,
@@ -31,6 +31,8 @@ if TYPE_CHECKING:
 
 _SELECTED_BRANCH_KEY = "selected_branch"
+_INVOKE_FROM_DEBUGGER = "debugger"
+_INVOKE_FROM_EXPLORE = "explore"
 
 logger = logging.getLogger(__name__)
@@ -155,30 +157,39 @@ class HumanInputNode(Node[HumanInputNodeData]):
         return resolved_defaults
 
     def _should_require_console_recipient(self) -> bool:
-        if self.invoke_from == InvokeFrom.DEBUGGER:
+        invoke_from = self._invoke_from_value()
+        if invoke_from == _INVOKE_FROM_DEBUGGER:
             return True
-        if self.invoke_from == InvokeFrom.EXPLORE:
+        if invoke_from == _INVOKE_FROM_EXPLORE:
             return self._node_data.is_webapp_enabled()
         return False
 
     def _display_in_ui(self) -> bool:
-        if self.invoke_from == InvokeFrom.DEBUGGER:
+        if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER:
             return True
         return self._node_data.is_webapp_enabled()
 
     def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
+        dify_ctx = self.require_dify_context()
+        invoke_from = self._invoke_from_value()
         enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
-        if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
+        if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}:
             enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
         return [
             apply_debug_email_recipient(
                 method,
-                enabled=self.invoke_from == InvokeFrom.DEBUGGER,
-                user_id=self.user_id or "",
+                enabled=invoke_from == _INVOKE_FROM_DEBUGGER,
+                user_id=dify_ctx.user_id,
             )
             for method in enabled_methods
         ]
 
+    def _invoke_from_value(self) -> str:
+        invoke_from = self.require_dify_context().invoke_from
+        if isinstance(invoke_from, str):
+            return invoke_from
+        return str(getattr(invoke_from, "value", invoke_from))
+
     def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
         node_data = self._node_data
         resolved_default_values = self.resolve_default_values()
@@ -212,10 +223,11 @@ class HumanInputNode(Node[HumanInputNodeData]):
         """
         repo = self._form_repository
         form = repo.get_form(self._workflow_execution_id, self.id)
+        dify_ctx = self.require_dify_context()
         if form is None:
             display_in_ui = self._display_in_ui()
             params = FormCreateParams(
-                app_id=self.app_id,
+                app_id=dify_ctx.app_id,
                 workflow_execution_id=self._workflow_execution_id,
                 node_id=self.id,
                 form_config=self._node_data,
@@ -225,7 +237,9 @@ class HumanInputNode(Node[HumanInputNodeData]):
                 resolved_default_values=self.resolve_default_values(),
                 console_recipient_required=self._should_require_console_recipient(),
                 console_creator_account_id=(
-                    self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
+                    dify_ctx.user_id
+                    if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
+                    else None
                 ),
                 backstage_recipient_required=True,
             )
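Since the node no longer imports InvokeFrom, comparisons funnel through _invoke_from_value(), which normalizes both plain strings and enum members. A small self-contained sketch of why that normalization is safe for StrEnum-style members (the stand-in enum below is hypothetical):

    from enum import StrEnum


    def normalize(invoke_from: object) -> str:
        # Mirrors _invoke_from_value(): strings pass through, enum members
        # fall back to their .value (stringified defensively).
        if isinstance(invoke_from, str):
            return invoke_from
        return str(getattr(invoke_from, "value", invoke_from))


    class FakeInvokeFrom(StrEnum):  # hypothetical stand-in for the app-layer enum
        DEBUGGER = "debugger"


    assert normalize("debugger") == "debugger"
    assert normalize(FakeInvokeFrom.DEBUGGER) == "debugger"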
diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py
index ed3634fa91..6d26cbfce4 100644
--- a/api/dify_graph/nodes/iteration/iteration_node.py
+++ b/api/dify_graph/nodes/iteration/iteration_node.py
@@ -587,24 +587,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
         return
 
     def _create_graph_engine(self, index: int, item: object):
-        # Import dependencies
-        from core.app.workflow.layers.llm_quota import LLMQuotaLayer
-        from core.workflow.node_factory import DifyNodeFactory
         from dify_graph.entities import GraphInitParams
-        from dify_graph.graph import Graph
-        from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
-        from dify_graph.graph_engine.command_channels import InMemoryChannel
-        from dify_graph.runtime import GraphRuntimeState
+        from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState
 
-        # Create GraphInitParams from node attributes
+        # Create GraphInitParams for child graph execution.
         graph_init_params = GraphInitParams(
-            tenant_id=self.tenant_id,
-            app_id=self.app_id,
             workflow_id=self.workflow_id,
             graph_config=self.graph_config,
-            user_id=self.user_id,
-            user_from=self.user_from,
-            invoke_from=self.invoke_from,
+            run_context=self.run_context,
             call_depth=self.workflow_call_depth,
         )
         # Create a deep copy of the variable pool for each iteration
@@ -621,28 +611,17 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
             total_tokens=0,
             node_run_steps=0,
         )
+        root_node_id = self.node_data.start_node_id
+        if root_node_id is None:
+            raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
 
-        # Create a new node factory with the new GraphRuntimeState
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
-        )
-
-        # Initialize the iteration graph with the new node factory
-        iteration_graph = Graph.init(
-            graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
-        )
-
-        if not iteration_graph:
-            raise IterationGraphNotFoundError("iteration graph not found")
-
-        # Create a new GraphEngine for this iteration
-        graph_engine = GraphEngine(
-            workflow_id=self.workflow_id,
-            graph=iteration_graph,
-            graph_runtime_state=graph_runtime_state_copy,
-            command_channel=InMemoryChannel(),  # Use InMemoryChannel for sub-graphs
-            config=GraphEngineConfig(),
-        )
-        graph_engine.layer(LLMQuotaLayer())
-
-        return graph_engine
+        try:
+            return self.graph_runtime_state.create_child_engine(
+                workflow_id=self.workflow_id,
+                graph_init_params=graph_init_params,
+                graph_runtime_state=graph_runtime_state_copy,
+                graph_config=self.graph_config,
+                root_node_id=root_node_id,
+            )
+        except ChildGraphNotFoundError as exc:
+            raise IterationGraphNotFoundError("iteration graph not found") from exc
diff --git a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py
index daf97d6ca9..eeb4f3c229 100644
--- a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py
@@ -3,7 +3,7 @@ from collections.abc import Mapping
 from typing import TYPE_CHECKING, Any
 
 from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, SystemVariableKey
+from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
 from dify_graph.node_events import NodeRunResult
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.base.template import Template
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
     from dify_graph.runtime import GraphRuntimeState
 
 logger = logging.getLogger(__name__)
+_INVOKE_FROM_DEBUGGER = "debugger"
 
 
 class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
@@ -58,7 +59,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
         if not variable:
             raise KnowledgeIndexNodeError("Index chunk variable is required.")
         invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
-        is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False
+        invoke_from_value = str(invoke_from.value) if invoke_from else None
+        is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER
         chunks = variable.value
         variables = {"chunks": chunks}
diff --git a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 97c013812e..d84dda42d6 100644
--- a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -66,9 +66,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
         self._rag_retrieval = rag_retrieval
 
         if llm_file_saver is None:
+            dify_ctx = self.require_dify_context()
             llm_file_saver = FileSaverImpl(
-                user_id=graph_init_params.user_id,
-                tenant_id=graph_init_params.tenant_id,
+                user_id=dify_ctx.user_id,
+                tenant_id=dify_ctx.tenant_id,
             )
         self._llm_file_saver = llm_file_saver
 
@@ -160,6 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
     def _fetch_dataset_retriever(
         self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
     ) -> tuple[list[Source], LLMUsage]:
+        dify_ctx = self.require_dify_context()
         dataset_ids = node_data.dataset_ids
         query = variables.get("query")
         attachments = variables.get("attachments")
@@ -176,10 +178,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
             model = node_data.single_retrieval_config.model
             retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
                 request=KnowledgeRetrievalRequest(
-                    tenant_id=self.tenant_id,
-                    user_id=self.user_id,
-                    app_id=self.app_id,
-                    user_from=self.user_from.value,
+                    tenant_id=dify_ctx.tenant_id,
+                    user_id=dify_ctx.user_id,
+                    app_id=dify_ctx.app_id,
+                    user_from=dify_ctx.user_from.value,
                     dataset_ids=dataset_ids,
                     retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
                     completion_params=model.completion_params,
@@ -229,10 +231,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
 
             retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
                 request=KnowledgeRetrievalRequest(
-                    app_id=self.app_id,
-                    tenant_id=self.tenant_id,
-                    user_id=self.user_id,
-                    user_from=self.user_from.value,
+                    app_id=dify_ctx.app_id,
+                    tenant_id=dify_ctx.tenant_id,
+                    user_id=dify_ctx.user_id,
+                    user_from=dify_ctx.user_from.value,
                     dataset_ids=dataset_ids,
                     query=query,
                     retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py
index 65b92b3bcc..c7697a0972 100644
--- a/api/dify_graph/nodes/llm/node.py
+++ b/api/dify_graph/nodes/llm/node.py
@@ -145,9 +145,10 @@ class LLMNode(Node[LLMNodeData]):
         self._memory = memory
 
         if llm_file_saver is None:
+            dify_ctx = self.require_dify_context()
             llm_file_saver = FileSaverImpl(
-                user_id=graph_init_params.user_id,
-                tenant_id=graph_init_params.tenant_id,
+                user_id=dify_ctx.user_id,
+                tenant_id=dify_ctx.tenant_id,
             )
         self._llm_file_saver = llm_file_saver
 
@@ -242,7 +243,7 @@ class LLMNode(Node[LLMNodeData]):
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop,
-            user_id=self.user_id,
+            user_id=self.require_dify_context().user_id,
             structured_output_enabled=self.node_data.structured_output_enabled,
             structured_output=self.node_data.structured_output,
             file_saver=self._llm_file_saver,
@@ -702,7 +703,7 @@ class LLMNode(Node[LLMNodeData]):
             filename=upload_file.name,
             extension="." + upload_file.extension,
             mime_type=upload_file.mime_type,
-            tenant_id=self.tenant_id,
+            tenant_id=self.require_dify_context().tenant_id,
             type=FileType.IMAGE,
             transfer_method=FileTransferMethod.LOCAL_FILE,
             remote_url=upload_file.source_url,
diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py
index 93a9b4d7eb..8279f0fc66 100644
--- a/api/dify_graph/nodes/loop/loop_node.py
+++ b/api/dify_graph/nodes/loop/loop_node.py
@@ -412,24 +412,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
         return build_segment_with_type(var_type, value)
 
     def _create_graph_engine(self, start_at: datetime, root_node_id: str):
-        # Import dependencies
-        from core.app.workflow.layers.llm_quota import LLMQuotaLayer
-        from core.workflow.node_factory import DifyNodeFactory
         from dify_graph.entities import GraphInitParams
-        from dify_graph.graph import Graph
-        from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
-        from dify_graph.graph_engine.command_channels import InMemoryChannel
         from dify_graph.runtime import GraphRuntimeState
 
-        # Create GraphInitParams from node attributes
+        # Create GraphInitParams for child graph execution.
         graph_init_params = GraphInitParams(
-            tenant_id=self.tenant_id,
-            app_id=self.app_id,
             workflow_id=self.workflow_id,
             graph_config=self.graph_config,
-            user_id=self.user_id,
-            user_from=self.user_from,
-            invoke_from=self.invoke_from,
+            run_context=self.run_context,
             call_depth=self.workflow_call_depth,
         )
 
@@ -439,22 +429,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
             start_at=start_at.timestamp(),
         )
 
-        # Create a new node factory with the new GraphRuntimeState
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
-        )
-
-        # Initialize the loop graph with the new node factory
-        loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
-
-        # Create a new GraphEngine for this iteration
-        graph_engine = GraphEngine(
+        return self.graph_runtime_state.create_child_engine(
             workflow_id=self.workflow_id,
-            graph=loop_graph,
+            graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state_copy,
-            command_channel=InMemoryChannel(),  # Use InMemoryChannel for sub-graphs
-            config=GraphEngineConfig(),
+            graph_config=self.graph_config,
+            root_node_id=root_node_id,
         )
-        graph_engine.layer(LLMQuotaLayer())
-
-        return graph_engine
diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
index a9b21d83b1..1325a6a09a 100644
--- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
@@ -297,7 +297,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
             tools=tools,
             stop=list(stop),
             stream=False,
-            user=self.user_id,
+            user=self.require_dify_context().user_id,
         )
 
         # handle invoke result
diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py
index 03ddf9ab5f..97535d832d 100644
--- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py
+++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py
@@ -86,9 +86,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         self._memory = memory
 
         if llm_file_saver is None:
+            dify_ctx = self.require_dify_context()
             llm_file_saver = FileSaverImpl(
-                user_id=graph_init_params.user_id,
-                tenant_id=graph_init_params.tenant_id,
+                user_id=dify_ctx.user_id,
+                tenant_id=dify_ctx.tenant_id,
             )
         self._llm_file_saver = llm_file_saver
 
@@ -160,7 +161,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop,
-            user_id=self.user_id,
+            user_id=self.require_dify_context().user_id,
             structured_output_enabled=False,
             structured_output=None,
             file_saver=self._llm_file_saver,
diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py
index eee065c311..57fb946559 100644
--- a/api/dify_graph/nodes/tool/tool_node.py
+++ b/api/dify_graph/nodes/tool/tool_node.py
@@ -56,6 +56,8 @@ class ToolNode(Node[ToolNodeData]):
         """
         from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
 
+        dify_ctx = self.require_dify_context()
+
         # fetch tool icon
         tool_info = {
             "provider_type": self.node_data.provider_type.value,
@@ -75,7 +77,12 @@ class ToolNode(Node[ToolNodeData]):
             if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
                 variable_pool = self.graph_runtime_state.variable_pool
             tool_runtime = ToolManager.get_workflow_tool_runtime(
-                self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
+                dify_ctx.tenant_id,
+                dify_ctx.app_id,
+                self._node_id,
+                self.node_data,
+                dify_ctx.invoke_from,
+                variable_pool,
             )
         except ToolNodeError as e:
             yield StreamCompletedEvent(
@@ -109,10 +116,10 @@ class ToolNode(Node[ToolNodeData]):
             message_stream = ToolEngine.generic_invoke(
                 tool=tool_runtime,
                 tool_parameters=parameters,
-                user_id=self.user_id,
+                user_id=dify_ctx.user_id,
                 workflow_tool_callback=DifyWorkflowCallbackHandler(),
                 workflow_call_depth=self.workflow_call_depth,
-                app_id=self.app_id,
+                app_id=dify_ctx.app_id,
                 conversation_id=conversation_id.text if conversation_id else None,
             )
         except ToolNodeError as e:
@@ -133,8 +140,8 @@ class ToolNode(Node[ToolNodeData]):
             messages=message_stream,
             tool_info=tool_info,
             parameters_for_log=parameters_for_log,
-            user_id=self.user_id,
-            tenant_id=self.tenant_id,
+            user_id=dify_ctx.user_id,
+            tenant_id=dify_ctx.tenant_id,
             node_id=self._node_id,
             tool_runtime=tool_runtime,
         )
diff --git a/api/dify_graph/nodes/trigger_webhook/node.py b/api/dify_graph/nodes/trigger_webhook/node.py
index 1b8167e799..e466541908 100644
--- a/api/dify_graph/nodes/trigger_webhook/node.py
+++ b/api/dify_graph/nodes/trigger_webhook/node.py
@@ -69,6 +69,7 @@ class TriggerWebhookNode(Node[WebhookData]):
         )
 
     def generate_file_var(self, param_name: str, file: dict):
+        dify_ctx = self.require_dify_context()
         related_id = file.get("related_id")
         transfer_method_value = file.get("transfer_method")
         if transfer_method_value:
@@ -84,7 +85,7 @@ class TriggerWebhookNode(Node[WebhookData]):
         try:
             file_obj = file_factory.build_from_mapping(
                 mapping=file,
-                tenant_id=self.tenant_id,
+                tenant_id=dify_ctx.tenant_id,
             )
             file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
             return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
diff --git a/api/dify_graph/runtime/__init__.py b/api/dify_graph/runtime/__init__.py
index 10014c7182..adca07e59a 100644
--- a/api/dify_graph/runtime/__init__.py
+++ b/api/dify_graph/runtime/__init__.py
@@ -1,9 +1,17 @@
-from .graph_runtime_state import GraphRuntimeState
+from .graph_runtime_state import (
+    ChildEngineBuilderNotConfiguredError,
+    ChildEngineError,
+    ChildGraphNotFoundError,
+    GraphRuntimeState,
+)
 from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
 from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
 from .variable_pool import VariablePool, VariableValue
 
 __all__ = [
+    "ChildEngineBuilderNotConfiguredError",
+    "ChildEngineError",
+    "ChildGraphNotFoundError",
     "GraphRuntimeState",
     "ReadOnlyGraphRuntimeState",
     "ReadOnlyGraphRuntimeStateWrapper",
diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py
index 6b88dd683c..41acc6db35 100644
--- a/api/dify_graph/runtime/graph_runtime_state.py
+++ b/api/dify_graph/runtime/graph_runtime_state.py
@@ -15,6 +15,7 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
 from dify_graph.runtime.variable_pool import VariablePool
 
 if TYPE_CHECKING:
+    from dify_graph.entities import GraphInitParams
     from dify_graph.entities.pause_reason import PauseReason
 
 
@@ -135,6 +136,31 @@ class GraphProtocol(Protocol):
     def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
 
 
+class ChildGraphEngineBuilderProtocol(Protocol):
+    def build_child_engine(
+        self,
+        *,
+        workflow_id: str,
+        graph_init_params: GraphInitParams,
+        graph_runtime_state: GraphRuntimeState,
+        graph_config: Mapping[str, Any],
+        root_node_id: str,
+        layers: Sequence[object] = (),
+    ) -> Any: ...
+
+
+class ChildEngineError(ValueError):
+    """Base error type for child-engine creation failures."""
+
+
+class ChildEngineBuilderNotConfiguredError(ChildEngineError):
+    """Raised when child-engine creation is requested without a bound builder."""
+
+
+class ChildGraphNotFoundError(ChildEngineError):
+    """Raised when the requested child graph entry point cannot be resolved."""
+
+
 class _GraphStateSnapshot(BaseModel):
     """Serializable graph state snapshot for node/edge states."""
 
@@ -209,6 +235,7 @@ class GraphRuntimeState:
         self._pending_graph_execution_workflow_id: str | None = None
         self._paused_nodes: set[str] = set()
         self._deferred_nodes: set[str] = set()
+        self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None
 
         # Node and edges states needed to be restored into
        # graph object.
@@ -250,6 +277,31 @@ class GraphRuntimeState:
         if self._graph is not None:
             _ = self.response_coordinator
 
+    def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None:
+        self._child_engine_builder = builder
+
+    def create_child_engine(
+        self,
+        *,
+        workflow_id: str,
+        graph_init_params: GraphInitParams,
+        graph_runtime_state: GraphRuntimeState,
+        graph_config: Mapping[str, Any],
+        root_node_id: str,
+        layers: Sequence[object] = (),
+    ) -> Any:
+        if self._child_engine_builder is None:
+            raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.")
+
+        return self._child_engine_builder.build_child_engine(
+            workflow_id=workflow_id,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+            graph_config=graph_config,
+            root_node_id=root_node_id,
+            layers=layers,
+        )
+
     # ------------------------------------------------------------------
     # Primary collaborators
     # ------------------------------------------------------------------
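A sketch of one possible ChildGraphEngineBuilderProtocol implementation, recreating the wiring the iteration/loop nodes previously inlined (node factory, Graph.init, a fresh engine on an InMemoryChannel). The class name and the layer handling are assumptions; the production builder is not shown in this patch:

    from core.workflow.node_factory import DifyNodeFactory
    from dify_graph.graph import Graph
    from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
    from dify_graph.graph_engine.command_channels import InMemoryChannel
    from dify_graph.runtime import ChildGraphNotFoundError


    class InlineChildEngineBuilder:  # hypothetical name; not the shipped builder
        def build_child_engine(
            self,
            *,
            workflow_id,
            graph_init_params,
            graph_runtime_state,
            graph_config,
            root_node_id,
            layers=(),
        ):
            # Same wiring the iteration/loop nodes used to perform by hand.
            node_factory = DifyNodeFactory(
                graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state
            )
            graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
            if not graph:
                raise ChildGraphNotFoundError(f"no child graph rooted at {root_node_id}")
            engine = GraphEngine(
                workflow_id=workflow_id,
                graph=graph,
                graph_runtime_state=graph_runtime_state,
                command_channel=InMemoryChannel(),  # sub-graphs keep using an in-memory channel
                config=GraphEngineConfig(),
            )
            for layer in layers:
                engine.layer(layer)
            return engine

Binding happens either through GraphEngine(..., child_engine_builder=builder) or state.bind_child_engine_builder(builder); nodes then reach the builder indirectly via graph_runtime_state.create_child_engine(...).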
diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py
index 44bc64ec11..7147fe1eab 100644
--- a/api/services/workflow_app_service.py
+++ b/api/services/workflow_app_service.py
@@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select
 from sqlalchemy.orm import Session
 
 from dify_graph.enums import WorkflowExecutionStatus
-from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
+from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
 from models.enums import AppTriggerType, CreatorUserRole
 from models.trigger import WorkflowTriggerLog
 from services.plugin.plugin_service import PluginService
@@ -132,7 +132,14 @@ class WorkflowAppService:
                 ),
             )
         if created_by_account:
-            account = session.scalar(select(Account).where(Account.email == created_by_account))
+            account = session.scalar(
+                select(Account)
+                .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
+                .where(
+                    Account.email == created_by_account,
+                    TenantAccountJoin.tenant_id == app_model.tenant_id,
+                )
+            )
             if not account:
                 raise ValueError(f"Account not found: {created_by_account}")
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 21bc95136e..6d462b60b9 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -11,13 +11,13 @@ from sqlalchemy.orm import Session, sessionmaker
 from configs import dify_config
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
 from core.repositories import DifyCoreRepositoryFactory
 from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
 from core.workflow.workflow_entry import WorkflowEntry
 from dify_graph.entities import GraphInitParams, WorkflowNodeExecution
 from dify_graph.entities.pause_reason import HumanInputRequired
-from dify_graph.enums import ErrorStrategy, UserFrom, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from dify_graph.errors import WorkflowNodeRunFailedError
 from dify_graph.file import File
 from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
@@ -1063,13 +1063,15 @@ class WorkflowService:
         variable_pool: VariablePool,
     ) -> HumanInputNode:
         graph_init_params = GraphInitParams(
-            tenant_id=workflow.tenant_id,
-            app_id=workflow.app_id,
             workflow_id=workflow.id,
             graph_config=workflow.graph_dict,
-            user_id=account.id,
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.DEBUGGER,
+            run_context=build_dify_run_context(
+                tenant_id=workflow.tenant_id,
+                app_id=workflow.app_id,
+                user_id=account.id,
+                user_from=UserFrom.ACCOUNT,
+                invoke_from=InvokeFrom.DEBUGGER,
+            ),
             call_depth=0,
         )
         graph_runtime_state = GraphRuntimeState(
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index 9971e357d2..f8b7f95493 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -4,10 +4,9 @@ import uuid
 
 import pytest
 
 from configs import dify_config
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.workflow.node_factory import DifyNodeFactory
-from dify_graph.entities import GraphInitParams
-from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
+from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.graph import Graph
 from dify_graph.node_events import NodeRunResult
 from dify_graph.nodes.code.code_node import CodeNode
@@ -15,6 +14,7 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
+from tests.workflow_test_utils import build_test_graph_init_params
 
 CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
 
@@ -31,11 +31,11 @@ def init_code_node(code_config: dict):
         "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
     }
 
-    init_params = GraphInitParams(
-        tenant_id="1",
-        app_id="1",
+    init_params = build_test_graph_init_params(
         workflow_id="1",
         graph_config=graph_config,
+        tenant_id="1",
+        app_id="1",
         user_id="1",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index 6e7b3a573a..f691113511 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -5,18 +5,18 @@ from urllib.parse import urlencode
 
 import pytest
 
 from configs import dify_config
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.helper.ssrf_proxy import ssrf_proxy
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.node_factory import DifyNodeFactory
-from dify_graph.entities import GraphInitParams
-from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
+from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.file.file_manager import file_manager
 from dify_graph.graph import Graph
 from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
+from tests.workflow_test_utils import build_test_graph_init_params
 
 HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
     max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@@ -41,11 +41,11 @@ def init_http_node(config: dict):
         "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
     }
 
-    init_params = GraphInitParams(
-        tenant_id="1",
-        app_id="1",
+    init_params = build_test_graph_init_params(
         workflow_id="1",
         graph_config=graph_config,
+        tenant_id="1",
+        app_id="1",
         user_id="1",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
@@ -685,11 +685,11 @@ def test_nested_object_variable_selector(setup_http_mock):
         ],
     }
 
-    init_params = GraphInitParams(
-        tenant_id="1",
-        app_id="1",
+    init_params = build_test_graph_init_params(
         workflow_id="1",
         graph_config=graph_config,
+        tenant_id="1",
+        app_id="1",
         user_id="1",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index cc83f0ea16..b4779ebcdd 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -4,17 +4,17 @@ import uuid
 from collections.abc import Generator
 from unittest.mock import MagicMock, patch
 
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.llm_generator.output_parser.structured_output import _parse_structured_output
 from core.model_manager import ModelInstance
-from dify_graph.entities import GraphInitParams
-from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
+from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.node_events import StreamCompletedEvent
 from dify_graph.nodes.llm.node import LLMNode
 from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from extensions.ext_database import db
+from tests.workflow_test_utils import build_test_graph_init_params
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 
@@ -37,11 +37,11 @@ def init_llm_node(config: dict) -> LLMNode:
     workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
     user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
 
-    init_params = GraphInitParams(
-        tenant_id=tenant_id,
-        app_id=app_id,
+    init_params = build_test_graph_init_params(
         workflow_id=workflow_id,
         graph_config=graph_config,
+        tenant_id=tenant_id,
+        app_id=app_id,
         user_id=user_id,
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
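These tests all import build_test_graph_init_params, whose definition is not shown in this patch's hunks. A plausible sketch inferred from the call sites (keyword-only arguments, call_depth defaulting to 0, identity fields accepted as enum members or plain strings):

    from collections.abc import Mapping
    from typing import Any

    from dify_graph.entities import GraphInitParams
    from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY


    def build_test_graph_init_params(
        *,
        workflow_id: str,
        graph_config: Mapping[str, Any],
        tenant_id: str,
        app_id: str,
        user_id: str,
        user_from: Any,
        invoke_from: Any,
        call_depth: int = 0,
    ) -> GraphInitParams:
        # Packs the legacy keyword arguments into the new run_context bundle.
        return GraphInitParams(
            workflow_id=workflow_id,
            graph_config=graph_config,
            run_context={
                DIFY_RUN_CONTEXT_KEY: {
                    "tenant_id": tenant_id,
                    "app_id": app_id,
                    "user_id": user_id,
                    "user_from": user_from,
                    "invoke_from": invoke_from,
                }
            },
            call_depth=call_depth,
        )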
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 7310c40c50..62d9af0196 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -3,10 +3,9 @@ import time
 import uuid
 from unittest.mock import MagicMock
 
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.model_manager import ModelInstance
-from dify_graph.entities import GraphInitParams
-from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
+from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
 from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
@@ -14,6 +13,7 @@ from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from extensions.ext_database import db
 from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
+from tests.workflow_test_utils import build_test_graph_init_params
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@@ -43,11 +43,11 @@ def init_parameter_extractor_node(config: dict, memory=None):
         "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
     }
 
-    init_params = GraphInitParams(
-        tenant_id="1",
-        app_id="1",
+    init_params = build_test_graph_init_params(
         workflow_id="1",
         graph_config=graph_config,
+        tenant_id="1",
+        app_id="1",
         user_id="1",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 9b4274c667..970e2cae00 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -1,15 +1,15 @@
 import time
 import uuid
 
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.workflow.node_factory import DifyNodeFactory
-from dify_graph.entities import GraphInitParams
-from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
+from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.graph import Graph
 from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError
 from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
+from tests.workflow_test_utils import build_test_graph_init_params
 
 
 class _SimpleJinja2Renderer:
@@ -53,11 +53,11 @@ def test_execute_template_transform():
         "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
     }
 
-    init_params = GraphInitParams(
-        tenant_id="1",
-        app_id="1",
+    init_params = build_test_graph_init_params(
         workflow_id="1",
         graph_config=graph_config,
+        tenant_id="1",
+        app_id="1",
         user_id="1",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index fdc690e4cb..f70bf46979 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -2,16 +2,16 @@ import time
 import uuid
 from unittest.mock import MagicMock
 
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.tools.utils.configuration import ToolParameterConfigurationManager
 from core.workflow.node_factory import DifyNodeFactory
-from dify_graph.entities import GraphInitParams
-from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
+from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.graph import Graph
 from dify_graph.node_events import StreamCompletedEvent
 from dify_graph.nodes.tool.tool_node import ToolNode
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
+from tests.workflow_test_utils import build_test_graph_init_params
 
 
 def init_tool_node(config: dict):
@@ -26,11 +26,11 @@ def init_tool_node(config: dict):
         "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
     }
 
-    init_params = GraphInitParams(
-        tenant_id="1",
-        app_id="1",
+    init_params = build_test_graph_init_params(
         workflow_id="1",
         graph_config=graph_config,
+        tenant_id="1",
+        app_id="1",
         user_id="1",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index c4c9d62c84..9733735df3 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -12,7 +12,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
 from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from dify_graph.entities import GraphInitParams
 from dify_graph.enums import WorkflowType
 from dify_graph.graph import Graph
 from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel
@@ -33,6 +32,7 @@ from models.account import Tenant, TenantAccountJoin, TenantAccountRole
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.model import App, AppMode, IconType
 from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
+from tests.workflow_test_utils import build_test_graph_init_params
 
 
 def _mock_form_repository_without_submission() -> HumanInputFormRepository:
@@ -87,11 +87,11 @@ def _build_graph(
     form_repository: HumanInputFormRepository,
 ) -> Graph:
     graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    params = GraphInitParams(
-        tenant_id=tenant_id,
-        app_id=app_id,
+    params = build_test_graph_init_params(
         workflow_id=workflow_id,
         graph_config=graph_config,
+        tenant_id=tenant_id,
+        app_id=app_id,
         user_id=user_id,
         user_from="account",
         invoke_from="debugger",
diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
index dd722c52b2..44af89601c 100644
--- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py
+++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
@@ -13,7 +13,6 @@ from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
 from core.app.apps.workflow import app_generator as wf_app_gen_module
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.node_factory import DifyNodeFactory
-from dify_graph.entities import GraphInitParams
 from dify_graph.entities.pause_reason import SchedulingPause
 from dify_graph.entities.workflow_start_reason import WorkflowStartReason
 from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
@@ -34,6 +33,7 @@ from dify_graph.nodes.end.entities import EndNodeData
 from dify_graph.nodes.start.entities import StartNodeData
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
+from tests.workflow_test_utils import build_test_graph_init_params
 
 if "core.ops.ops_trace_manager" not in sys.modules:
     ops_stub = ModuleType("core.ops.ops_trace_manager")
@@ -142,11 +142,11 @@ def _build_graph_config(*, pause_on: str | None) -> dict[str, object]:
 
 def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph:
     graph_config = _build_graph_config(pause_on=pause_on)
-    params = GraphInitParams(
-        tenant_id="tenant",
-        app_id="app",
+    params = build_test_graph_init_params(
         workflow_id="workflow",
         graph_config=graph_config,
+        tenant_id="tenant",
+        app_id="app",
         user_id="user",
         user_from="account",
         invoke_from="service-api",
diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py
index d7d2258df8..b93f18c5bd 100644
--- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py
+++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py
@@ -4,15 +4,13 @@ from typing import Any
 
 import pytest
 
-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.node_factory import DifyNodeFactory
-from dify_graph.entities import GraphInitParams
-from dify_graph.enums import UserFrom
 from dify_graph.graph import Graph
 from dify_graph.graph.validation import GraphValidationError
 from dify_graph.nodes import NodeType
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
+from tests.workflow_test_utils import build_test_graph_init_params
 
 
 def _build_iteration_graph(node_id: str) -> dict[str, Any]:
@@ -53,14 +51,14 @@ def _build_loop_graph(node_id: str) -> dict[str, Any]:
 
 def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
-    graph_init_params = GraphInitParams(
-        tenant_id="tenant",
-        app_id="app",
+    graph_init_params = build_test_graph_init_params(
         workflow_id="workflow",
         graph_config=graph_config,
+        tenant_id="tenant",
+        app_id="app",
         user_id="user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
+        user_from="account",
+        invoke_from="debugger",
         call_depth=0,
     )
     graph_runtime_state = GraphRuntimeState(
diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
index d3ef971e6a..b98d56147e 100644
--- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
+++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
@@ -6,15 +6,15 @@ from dataclasses import dataclass
 
 import pytest
 
-from core.app.entities.app_invoke_entities import InvokeFrom
 from dify_graph.entities import GraphInitParams
-from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType, UserFrom
+from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType
 from dify_graph.graph import Graph
 from dify_graph.graph.validation import GraphValidationError
 from dify_graph.nodes.base.entities import BaseNodeData
 from dify_graph.nodes.base.node import Node
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
+from tests.workflow_test_utils import build_test_graph_init_params
 
 
 class _TestNodeData(BaseNodeData):
@@ -91,14 +91,14 @@ class _SimpleNodeFactory:
 @pytest.fixture
 def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]:
     graph_config: dict[str, object] = {"edges": [], "nodes": []}
-    init_params = GraphInitParams(
-        tenant_id="tenant",
-        app_id="app",
+    init_params = build_test_graph_init_params(
         workflow_id="workflow",
         graph_config=graph_config,
+        tenant_id="tenant",
+        app_id="app",
         user_id="user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.SERVICE_API,
+        user_from="account",
+        invoke_from="service-api",
         call_depth=0,
     )
     variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={})
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
index 352e270fe4..819fd67f9d 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
@@ -32,6 +32,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None:
     node.execution_id = "execution-id"
     node.node_type = NodeType.LLM
     node.tenant_id = "tenant-id"
+    node.require_dify_context.return_value.tenant_id = "tenant-id"
     node.model_instance = object()
 
     result_event = _build_succeeded_event()
@@ -52,6 +53,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None:
     node.execution_id = "execution-id"
     node.node_type = NodeType.QUESTION_CLASSIFIER
     node.tenant_id = "tenant-id"
+    node.require_dify_context.return_value.tenant_id = "tenant-id"
     node.model_instance = object()
 
     result_event = _build_succeeded_event()
@@ -72,6 +74,7 @@ def test_non_llm_node_is_ignored() -> None:
     node.execution_id = "execution-id"
     node.node_type = NodeType.START
     node.tenant_id = "tenant-id"
+    node.require_dify_context.return_value.tenant_id = "tenant-id"
     node._model_instance = object()
 
     result_event = _build_succeeded_event()
@@ -88,6 +91,7 @@ def test_quota_error_is_handled_in_layer() -> None:
     node.execution_id = "execution-id"
     node.node_type = NodeType.LLM
     node.tenant_id = "tenant-id"
+    node.require_dify_context.return_value.tenant_id = "tenant-id"
     node.model_instance = object()
 
     result_event = _build_succeeded_event()
@@ -109,6 +113,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
     node.execution_id = "execution-id"
     node.node_type = NodeType.LLM
     node.tenant_id = "tenant-id"
+    node.require_dify_context.return_value.tenant_id = "tenant-id"
     node.model_instance = object()
     node.graph_runtime_state = MagicMock()
     node.graph_runtime_state.stop_event = stop_event
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
index 5196af277e..f886ae1c2b 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
@@ -8,6 +8,7 @@ for workflows containing nodes that require third-party services.
 import pytest
 
 from dify_graph.enums import NodeType
+from tests.workflow_test_utils import build_test_graph_init_params
 
 from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
 from .test_table_runner import TableTestRunner, WorkflowTestCase
@@ -199,22 +200,19 @@ def test_mock_config_builder():
 
 def test_mock_factory_node_type_detection():
     """Test that MockNodeFactory correctly identifies nodes to mock."""
-    from core.app.entities.app_invoke_entities import InvokeFrom
-    from dify_graph.entities import GraphInitParams
-    from dify_graph.enums import UserFrom
+    from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
     from dify_graph.runtime import GraphRuntimeState, VariablePool
 
     from .test_mock_factory import MockNodeFactory
 
-    graph_init_params = GraphInitParams(
-        tenant_id="test",
-        app_id="test",
+    graph_init_params = build_test_graph_init_params(
         workflow_id="test",
         graph_config={},
+        tenant_id="test",
+        app_id="test",
         user_id="test",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.SERVICE_API,
-        call_depth=0,
     )
     graph_runtime_state = GraphRuntimeState(
         variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
@@ -309,9 +307,7 @@ def test_workflow_without_auto_mock():
 
 def test_register_custom_mock_node():
     """Test registering a custom mock implementation for a node type."""
-    from core.app.entities.app_invoke_entities import InvokeFrom
-    from dify_graph.entities import GraphInitParams
-    from dify_graph.enums import UserFrom
+    from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
     from dify_graph.nodes.template_transform import TemplateTransformNode
     from dify_graph.runtime import GraphRuntimeState, VariablePool
 
@@ -323,15 +319,14 @@ def test_register_custom_mock_node():
         # Custom mock implementation
         pass
 
-    graph_init_params = GraphInitParams(
-        tenant_id="test",
-        app_id="test",
+    graph_init_params = build_test_graph_init_params(
         workflow_id="test",
         graph_config={},
+        tenant_id="test",
+        app_id="test",
         user_id="test",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.SERVICE_API,
-        call_depth=0,
     )
     graph_runtime_state = GraphRuntimeState(
         variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
index 09fc412e7f..765c4deba3 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
@@ -3,10 +3,9 @@
 import time
 from unittest.mock import MagicMock
 
-from core.app.entities.app_invoke_entities import InvokeFrom
-from dify_graph.entities.graph_init_params import GraphInitParams
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams
 from dify_graph.entities.pause_reason import SchedulingPause
-from dify_graph.enums import UserFrom
 from dify_graph.graph import Graph
 from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
 from dify_graph.graph_engine.command_channels import InMemoryChannel
@@ -41,13 +40,17 @@ def test_abort_command():
         id="start",
         config={"id": "start", "data": {"title": "start", "variables": []}},
         graph_init_params=GraphInitParams(
-            tenant_id="test_tenant",
-            app_id="test_app",
             workflow_id="test_workflow",
             graph_config={},
-            user_id="test_user",
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.DEBUGGER,
+            run_context={
+                DIFY_RUN_CONTEXT_KEY: {
+                    "tenant_id": "test_tenant",
+                    "app_id": "test_app",
+                    "user_id": "test_user",
+                    "user_from": UserFrom.ACCOUNT,
+                    "invoke_from": InvokeFrom.DEBUGGER,
+                }
+            },
             call_depth=0,
         ),
         graph_runtime_state=shared_runtime_state,
@@ -151,13 +154,17 @@ def test_pause_command():
         id="start",
         config={"id": "start", "data": {"title": "start", "variables": []}},
         graph_init_params=GraphInitParams(
-            tenant_id="test_tenant",
-            app_id="test_app",
             workflow_id="test_workflow",
             graph_config={},
-            user_id="test_user",
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.DEBUGGER,
+            run_context={
+                DIFY_RUN_CONTEXT_KEY: {
+                    "tenant_id": "test_tenant",
+                    "app_id": "test_app",
+                    "user_id": "test_user",
+                    "user_from": UserFrom.ACCOUNT,
+                    "invoke_from": InvokeFrom.DEBUGGER,
+                }
+            },
             call_depth=0,
         ),
         graph_runtime_state=shared_runtime_state,
@@ -207,13 +214,17 @@ def test_update_variables_command_updates_pool():
         id="start",
         config={"id": "start", "data": {"title": "start", "variables": []}},
         graph_init_params=GraphInitParams(
-            tenant_id="test_tenant",
-            app_id="test_app",
             workflow_id="test_workflow",
             graph_config={},
-            user_id="test_user",
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.DEBUGGER,
+            run_context={
+                DIFY_RUN_CONTEXT_KEY: {
+                    "tenant_id": "test_tenant",
+                    "app_id": "test_app",
+                    "user_id": "test_user",
+                    "user_from": UserFrom.ACCOUNT,
+                    "invoke_from": InvokeFrom.DEBUGGER,
+                }
+            },
             call_depth=0,
         ),
         graph_runtime_state=shared_runtime_state,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py
index 84e033156d..d54f0be190 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py
@@ -21,6 +21,7 @@ from dify_graph.nodes.start.entities import StartNodeData
 from dify_graph.nodes.start.start_node import StartNode
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
+from tests.workflow_test_utils import build_test_graph_init_params
 
 from .test_mock_config import MockConfig
 from .test_mock_nodes import MockLLMNode
@@ -73,11 +74,11 @@ def _build_llm_node(
 
 def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
     graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    graph_init_params = GraphInitParams(
-        tenant_id="tenant",
-        app_id="app",
+    graph_init_params = build_test_graph_init_params(
         workflow_id="workflow",
         graph_config=graph_config,
+        tenant_id="tenant",
+        app_id="app",
         user_id="user",
         user_from="account",
         invoke_from="debugger",
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py
index 695e99c1cf..538f53c603 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py
@@ -4,7 +4,6 @@ from collections.abc import Iterable
 from unittest import mock
 from unittest.mock import MagicMock
 
-from dify_graph.entities import GraphInitParams
 from dify_graph.graph import Graph
 from dify_graph.graph_events import (
     GraphRunPausedEvent,
@@ -35,6 +34,7 @@ from dify_graph.repositories.human_input_form_repository import HumanInputFormEn
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from libs.datetime_utils import naive_utc_now
+from tests.workflow_test_utils import build_test_graph_init_params
 
 from .test_mock_config import MockConfig
 from .test_mock_nodes import MockLLMNode
@@ -47,11 +47,11 @@ def _build_branching_graph(
     graph_runtime_state: GraphRuntimeState | None = None,
 ) -> tuple[Graph, GraphRuntimeState]:
     graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    graph_init_params = GraphInitParams(
-        tenant_id="tenant",
-        app_id="app",
+    graph_init_params = build_test_graph_init_params(
         workflow_id="workflow",
         graph_config=graph_config,
+        tenant_id="tenant",
+        app_id="app",
         user_id="user",
         user_from="account",
         invoke_from="debugger",
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py
index 0275062c41..36bba6deb6 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py
@@ -3,7 +3,6 @@ import time
 from unittest import mock
 from unittest.mock import MagicMock
 
-from dify_graph.entities import GraphInitParams
 from dify_graph.graph import Graph
 from dify_graph.graph_events import (
     GraphRunPausedEvent,
@@ -34,6 +33,7 @@ from dify_graph.repositories.human_input_form_repository import HumanInputFormEn
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from libs.datetime_utils import naive_utc_now
+from tests.workflow_test_utils import build_test_graph_init_params
 
 from .test_mock_config import MockConfig
 from .test_mock_nodes import MockLLMNode
@@ -46,11 +46,11 @@ def _build_llm_human_llm_graph(
     graph_runtime_state: GraphRuntimeState | None = None,
 ) -> tuple[Graph, GraphRuntimeState]:
     graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    graph_init_params = GraphInitParams(
-        tenant_id="tenant",
-        app_id="app",
+    graph_init_params = build_test_graph_init_params(
         workflow_id="workflow",
         graph_config=graph_config,
+        tenant_id="tenant",
+        app_id="app",
         user_id="user",
         user_from="account",
         invoke_from="debugger",
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py
index fbcb8d7155..8da179c15e 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py
@@ -1,7 +1,6 @@
 import time
 from unittest import mock
 
-from dify_graph.entities import GraphInitParams
 from dify_graph.graph import Graph
 from dify_graph.graph_events import (
     GraphRunStartedEvent,
@@ -29,6 +28,7 @@ from dify_graph.nodes.start.start_node import StartNode
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from dify_graph.utils.condition.entities import Condition
+from tests.workflow_test_utils import build_test_graph_init_params
 
 from .test_mock_config import MockConfig
 from .test_mock_nodes import MockLLMNode
@@ -37,15 +37,10 @@ from .test_table_runner import TableTestRunner, WorkflowTestCase
 
 def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
     graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    graph_init_params = GraphInitParams(
-
tenant_id="tenant", - app_id="app", - workflow_id="workflow", + graph_init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index da666ce987..eb449e6d75 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -5,6 +5,8 @@ Simple test to verify MockNodeFactory works with iteration nodes. import sys from pathlib import Path +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY + # Add api directory to path api_dir = Path(__file__).parent.parent.parent.parent.parent.parent sys.path.insert(0, str(api_dir)) @@ -16,20 +18,23 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo def test_mock_factory_registers_iteration_node(): """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool # Create a MockNodeFactory instance graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -65,9 +70,8 @@ def test_mock_factory_registers_iteration_node(): def test_mock_iteration_node_preserves_config(): """Test that MockIterationNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode @@ -76,13 +80,17 @@ def test_mock_iteration_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -127,9 +135,8 @@ def test_mock_iteration_node_preserves_config(): def test_mock_loop_node_preserves_config(): """Test that MockLoopNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool from 
tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode @@ -138,13 +145,17 @@ def test_mock_loop_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 22afbb4909..3f458e9de9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -603,13 +603,9 @@ class MockIterationNode(MockNodeMixin, IterationNode): # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) @@ -679,13 +675,9 @@ class MockLoopNode(MockNodeMixin, LoopNode): # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 0942d15073..1550dca402 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,6 +6,7 @@ to ensure they work correctly with the TableTestRunner. 
""" from configs import dify_config +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig @@ -44,13 +45,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -103,13 +108,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -163,13 +172,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -221,13 +234,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -286,13 +303,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -348,13 +369,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -418,13 +443,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + 
DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -490,13 +519,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -531,13 +564,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -582,13 +619,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 2376423738..84d1444585 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -5,6 +5,8 @@ Simple test to validate the auto-mock system without external dependencies. 
"""
import sys from pathlib import Path +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY + # Add api directory to path api_dir = Path(__file__).parent.parent.parent.parent.parent.parent sys.path.insert(0, str(api_dir)) @@ -101,21 +103,24 @@ def test_node_mock_config(): def test_mock_factory_detection(): """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory detection...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -154,21 +159,24 @@ def test_mock_factory_detection(): def test_mock_factory_registration(): """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory registration...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e269263bde..e681b39cc7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -32,6 +31,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params class PauseStateStore(Protocol): @@ -126,11 +126,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = 
build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 910292b52c..60167c0441 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -39,6 +38,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -129,11 +129,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index 8e6b30896f..0ac9d6618d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -12,11 +12,10 @@ import time from unittest.mock import MagicMock, patch from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import NodeType, UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel @@ -30,6 +29,7 @@ from dify_graph.node_events import NodeRunResult, StreamCompletedEvent from dify_graph.nodes.llm.node import LLMNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -86,11 +86,11 @@ def test_parallel_streaming_workflow(): graph_config = workflow_config.get("graph", {}) # Create graph initialization parameters - 
init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + init_params = build_test_graph_init_params( workflow_id="test_workflow", graph_config=graph_config, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, @@ -99,8 +99,8 @@ def test_parallel_streaming_workflow(): # Create variable pool with system variables system_variables = SystemVariable( - user_id=init_params.user_id, - app_id=init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=init_params.workflow_id, files=[], query="Tell me about yourself", # User query diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index e5a9a29a1f..7328ce443f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -40,6 +39,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -121,11 +121,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 183d589e2b..15a7de3c52 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,7 +3,6 @@ import time from typing import Any from unittest.mock import MagicMock -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -30,6 +29,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: @@ -79,11 +79,11 @@ def _build_human_input_graph( form_repository: HumanInputFormRepository, ) -> Graph: graph_config: dict[str, object] = {"nodes": [], 
"edges": []} - params = GraphInitParams( - tenant_id="tenant", - app_id="app", + params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="service-api", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ee36c976f0..767a8f60ce 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,19 +12,21 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, cast +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_init_params import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, @@ -48,6 +50,47 @@ from .test_mock_factory import MockNodeFactory logger = logging.getLogger(__name__) +class _TableTestChildEngineBuilder: + def __init__(self, *, use_mock_factory: bool, mock_config: MockConfig | None) -> None: + self._use_mock_factory = use_mock_factory + self._mock_config = mock_config + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + if self._use_mock_factory: + node_factory = MockNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + mock_config=self._mock_config, + ) + else: + node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + + child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) + if not child_graph: + raise ValueError("child graph not found") + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + @dataclass class WorkflowTestCase: """Represents a single test case for table-driven testing.""" @@ -149,19 +192,23 @@ class WorkflowRunner: raise ValueError("Fixture missing workflow.graph configuration") graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config=graph_config, - user_id="test_user", - user_from="account", - invoke_from="debugger", # Set to debugger to avoid conversation_id requirement + 
run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, # Set to debugger to avoid conversation_id requirement + } + }, call_depth=0, ) system_variables = SystemVariable( - user_id=graph_init_params.user_id, - app_id=graph_init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, @@ -315,6 +362,10 @@ class TableTestRunner: scale_up_threshold=self.graph_engine_scale_up_threshold, scale_down_idle_time=self.graph_engine_scale_down_idle_time, ), + child_engine_builder=_TableTestChildEngineBuilder( + use_mock_factory=test_case.use_auto_mock, + mock_config=test_case.mock_config, + ), ) # Execute and collect events diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index b1351c9fc3..f0d80af1ed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,15 +2,15 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.answer.answer_node import AnswerNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from extensions.ext_database import db +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_answer(): @@ -35,11 +35,11 @@ def test_execute_answer(): ], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index b100c7c02c..db096b1aed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,3 +1,4 @@ +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.datasource.datasource_node import DatasourceNode @@ -28,13 +29,17 @@ class _GraphState: class _GraphParams: - tenant_id = "t1" - app_id = "app-1" workflow_id = "wf-1" graph_config = {} - user_id = "u1" - user_from = "account" - invoke_from = "debugger" + run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "t1", + "app_id": "app-1", + "user_id": "u1", + "user_from": "account", + "invoke_from": "debugger", + } + } call_depth = 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 38e68dcdc9..5e34bf1d94 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -4,16 +4,16 @@ from typing import Any import httpx import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=10, @@ -98,11 +98,11 @@ def _build_http_node( ], "edges": [], } - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 3b2a81ccef..55aa62a1c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -9,6 +9,7 @@ import pytest from pydantic import ValidationError from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.node_events import PauseRequestedEvent from dify_graph.node_events.node import StreamCompletedEvent from dify_graph.nodes.human_input.entities import ( @@ -314,13 +315,17 @@ class TestHumanInputNodeVariableResolution: variable_pool.add(("start", "name"), "Jane Doe") runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -384,13 +389,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -439,13 +448,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", 
workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user-123", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user-123", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -550,13 +563,17 @@ class TestHumanInputNodeRenderedContent: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 46b4f1ed37..1fea19e795 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,9 +1,9 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams -from dify_graph.enums import NodeType, UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.enums import NodeType from dify_graph.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, @@ -31,13 +31,17 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -91,13 +95,17 @@ def _build_timeout_node() -> HumanInputNode: start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py new file mode 100644 index 0000000000..2eb4feef5f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -0,0 +1,100 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +import pytest + +from dify_graph.entities import GraphInitParams +from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError +from dify_graph.nodes.iteration.iteration_node import IterationNode 
+from dify_graph.runtime import ( + ChildEngineBuilderNotConfiguredError, + ChildGraphNotFoundError, + GraphRuntimeState, + VariablePool, +) +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params + + +class _MissingGraphBuilder: + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> object: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + +def _build_runtime_state() -> GraphRuntimeState: + return GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + start_at=0.0, + ) + + +def _build_iteration_node( + *, + graph_config: Mapping[str, Any], + runtime_state: GraphRuntimeState, + start_node_id: str, +) -> IterationNode: + init_params = build_test_graph_init_params(graph_config=graph_config) + return IterationNode( + id="iteration-node", + config={ + "id": "iteration-node", + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": ["iteration-node", "output"], + "start_node_id": start_node_id, + }, + }, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + +def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing(): + runtime_state = _build_runtime_state() + graph_init_params = build_test_graph_init_params() + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + runtime_state.create_child_engine( + workflow_id="workflow", + graph_init_params=graph_init_params, + graph_runtime_state=_build_runtime_state(), + graph_config={}, + root_node_id="root", + ) + + +def test_iteration_node_only_translates_child_graph_not_found_error(): + runtime_state = _build_runtime_state() + runtime_state.bind_child_engine_builder(_MissingGraphBuilder()) + node = _build_iteration_node( + graph_config={"nodes": [{"id": "present-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="missing-node", + ) + + with pytest.raises(IterationGraphNotFoundError): + node._create_graph_engine(index=0, item="item") + + +def test_iteration_node_propagates_non_graph_not_found_errors(): + runtime_state = _build_runtime_state() + node = _build_iteration_node( + graph_config={"nodes": [{"id": "start-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="start-node", + ) + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + node._create_graph_engine(index=0, item="item") diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index ffb9c0a43f..8116fc8b3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -4,9 +4,8 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities import GraphInitParams -from dify_graph.enums import SystemVariableKey, UserFrom, WorkflowNodeExecutionStatus +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus from dify_graph.nodes.knowledge_index.entities import 
KnowledgeIndexNodeData from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode @@ -15,16 +14,17 @@ from dify_graph.repositories.summary_index_service_protocol import SummaryIndexS from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import StringSegment +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 45a4316ead..e194d66ee3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -4,9 +4,8 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.nodes.knowledge_retrieval.entities import ( KnowledgeRetrievalNodeData, @@ -20,16 +19,17 @@ from dify_graph.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import StringSegment +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 65228df517..25760ba352 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,13 @@ from unittest.mock import MagicMock import pytest -from dify_graph.graph_engine.entities.graph import Graph -from dify_graph.graph_engine.entities.graph_init_params import GraphInitParams -from dify_graph.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.nodes.list_operator.node import ListOperatorNode +from dify_graph.runtime import GraphRuntimeState from 
dify_graph.variables import ArrayNumberSegment, ArrayStringSegment -from models.workflow import WorkflowType class TestListOperatorNode: @@ -22,43 +21,40 @@ class TestListOperatorNode: mock_state.variable_pool = mock_variable_pool return mock_state - @pytest.fixture - def mock_graph(self): - """Create mock Graph.""" - return MagicMock(spec=Graph) - @pytest.fixture def graph_init_params(self): """Create GraphInitParams fixture.""" return GraphInitParams( - tenant_id="test", - app_id="test", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test", graph_config={}, - user_id="test", - user_from="test", - invoke_from="test", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": "test", + "invoke_from": "test", + } + }, call_depth=0, ) @pytest.fixture - def list_operator_node_factory(self, graph_init_params, mock_graph, mock_graph_runtime_state): + def list_operator_node_factory(self, graph_init_params, mock_graph_runtime_state): """Factory fixture for creating ListOperatorNode instances.""" def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable return ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) return _create_node - def test_node_initialization(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_node_initialization(self, mock_graph_runtime_state, graph_init_params): """Test node initializes correctly.""" config = { "title": "List Operator", @@ -70,9 +66,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -101,7 +96,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_empty_array(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_empty_array(self, mock_graph_runtime_state, graph_init_params): """Test with empty array.""" config = { "title": "Test", @@ -116,9 +111,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -129,7 +123,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] is None assert result.outputs["last_record"] is None - def test_run_with_filter_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with contains condition.""" config = { "title": "Test", @@ -148,9 +142,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -159,7 +152,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple"] - def test_run_with_filter_not_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_not_contains(self, 
mock_graph_runtime_state, graph_init_params): """Test filter with not contains condition.""" config = { "title": "Test", @@ -178,9 +171,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -189,7 +181,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["banana", "cherry"] - def test_run_with_number_filter_greater_than(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_greater_than(self, mock_graph_runtime_state, graph_init_params): """Test filter with greater than condition on numbers.""" config = { "title": "Test", @@ -208,9 +200,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -219,7 +210,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [7, 9, 11] - def test_run_with_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in ascending order.""" config = { "title": "Test", @@ -237,9 +228,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -248,7 +238,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_order_descending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_descending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in descending order.""" config = { "title": "Test", @@ -266,9 +256,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -277,7 +266,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["cherry", "banana", "apple"] - def test_run_with_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_limit(self, mock_graph_runtime_state, graph_init_params): """Test with limit enabled.""" config = { "title": "Test", @@ -295,9 +284,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -306,7 +294,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana"] - def test_run_with_filter_order_and_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_order_and_limit(self, mock_graph_runtime_state, graph_init_params): """Test with filter, order, and limit combined.""" config = { "title": "Test", @@ -331,9 
+319,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -342,7 +329,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [9, 8, 7] - def test_run_with_variable_not_found(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_variable_not_found(self, mock_graph_runtime_state, graph_init_params): """Test when variable is not found.""" config = { "title": "Test", @@ -356,9 +343,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -367,7 +353,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Variable not found" in result.error - def test_run_with_first_and_last_record(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_first_and_last_record(self, mock_graph_runtime_state, graph_init_params): """Test first_record and last_record outputs.""" config = { "title": "Test", @@ -382,9 +368,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -394,7 +379,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] == "first" assert result.outputs["last_record"] == "last" - def test_run_with_filter_startswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_startswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with startswith condition.""" config = { "title": "Test", @@ -413,9 +398,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -424,7 +408,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "application"] - def test_run_with_filter_endswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_endswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with endswith condition.""" config = { "title": "Test", @@ -443,9 +427,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -454,7 +437,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple", "table"] - def test_run_with_number_filter_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with equals condition.""" config = { "title": "Test", @@ -473,9 +456,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, 
- graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -484,7 +466,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [5, 5] - def test_run_with_number_filter_not_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_not_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with not equals condition.""" config = { "title": "Test", @@ -503,9 +485,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -514,7 +495,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [1, 3, 7, 9] - def test_run_with_number_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test number ordering in ascending order.""" config = { "title": "Test", @@ -532,9 +513,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index b822f1fbe4..90308facc3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,14 +5,13 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.message_entities import ( @@ -41,6 +40,7 @@ from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.provider import ProviderType +from tests.workflow_test_utils import build_test_graph_init_params class MockTokenBufferMemory: @@ -76,11 +76,11 @@ def llm_node_data() -> LLMNodeData: @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="1", - app_id="1", + return build_test_graph_init_params( workflow_id="1", graph_config={}, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py 
b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 48d76d9b9b..6831626f58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -2,15 +2,13 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState -from models.enums import UserFrom -from models.workflow import WorkflowType +from tests.workflow_test_utils import build_test_graph_init_params class TestTemplateTransformNode: @@ -32,12 +30,11 @@ class TestTemplateTransformNode: @pytest.fixture def graph_init_params(self): """Create a mock GraphInitParams.""" - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", - workflow_type=WorkflowType.WORKFLOW, + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index ff9247059b..44abf430c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -2,13 +2,14 @@ from collections.abc import Mapping import pytest -from core.app.entities.app_invoke_entities import InvokeFrom as LegacyInvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams -from dify_graph.enums import InvokeFrom, NodeType, UserFrom +from dify_graph.enums import NodeType from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _SampleNodeData(BaseNodeData): @@ -27,15 +28,10 @@ class _SampleNode(Node[_SampleNodeData]): def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), @@ -57,21 +53,17 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" - assert node.user_from == UserFrom.ACCOUNT - assert node.invoke_from == InvokeFrom.DEBUGGER + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == "account" + assert dify_ctx.invoke_from == "debugger" -def test_node_normalizes_legacy_invoke_from_enum(): +def 
test_node_accepts_invoke_from_enum(): graph_config: dict[str, object] = {} - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from=UserFrom.ACCOUNT, - invoke_from=LegacyInvokeFrom.DEBUGGER, - call_depth=0, + invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), @@ -85,8 +77,12 @@ def test_node_normalizes_legacy_invoke_from_enum(): graph_runtime_state=runtime_state, ) - assert node.user_from == UserFrom.ACCOUNT - assert node.invoke_from == InvokeFrom.DEBUGGER + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == UserFrom.ACCOUNT + assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER + assert node.get_run_context_value("missing") is None + with pytest.raises(ValueError): + node.require_run_context_value("missing") def test_missing_generic_argument_raises_type_error(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index dff84b580a..5e20b1e12f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -5,9 +5,9 @@ import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams -from dify_graph.enums import NodeType, UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod from dify_graph.node_events import NodeRunResult from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData @@ -20,15 +20,16 @@ from dify_graph.nodes.document_extractor.node import ( from dify_graph.variables import ArrayFileSegment from dify_graph.variables.segments import ArrayStringSegment from dify_graph.variables.variables import StringVariable +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 23b96e7b25..041bd66d03 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock, Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType 
from dify_graph.graph import Graph from dify_graph.nodes.if_else.entities import IfElseNodeData @@ -17,16 +17,17 @@ from dify_graph.system_variable import SystemVariable from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition from dify_graph.variables import ArrayFileSegment from extensions.ext_database import db +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -128,11 +129,11 @@ def test_execute_if_else_result_false(): # Create a simple graph for IfElse node testing graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -229,14 +230,18 @@ def test_array_file_contains_file_name(): # Create properly configured mock for graph_init_params graph_init_params = Mock() - graph_init_params.tenant_id = "test_tenant" - graph_init_params.app_id = "test_app" graph_init_params.workflow_id = "test_workflow" graph_init_params.graph_config = {} - graph_init_params.user_id = "test_user" - graph_init_params.user_from = UserFrom.ACCOUNT - graph_init_params.invoke_from = InvokeFrom.SERVICE_API graph_init_params.call_depth = 0 + graph_init_params.run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + } node = IfElseNode( id=str(uuid.uuid4()), @@ -298,11 +303,11 @@ def test_execute_if_else_boolean_conditions(condition: Condition): """Test IfElseNode with boolean conditions using various operators""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -353,11 +358,11 @@ def test_execute_if_else_boolean_false_conditions(): """Test IfElseNode with boolean conditions that should evaluate to false""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -422,11 +427,11 @@ def test_execute_if_else_boolean_cases_structure(): """Test IfElseNode with boolean conditions using the new cases structure""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", 
user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 4c43f63c74..6ca72b64b2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,8 +2,9 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.nodes.list_operator.entities import ( ExtractConfig, @@ -41,14 +42,18 @@ def list_operator_node(): } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() - graph_init_params.tenant_id = "test_tenant" - graph_init_params.app_id = "test_app" graph_init_params.workflow_id = "test_workflow" graph_init_params.graph_config = {} - graph_init_params.user_id = "test_user" - graph_init_params.user_from = UserFrom.ACCOUNT - graph_init_params.invoke_from = InvokeFrom.SERVICE_API graph_init_params.call_depth = 0 + graph_init_params.run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + } node = ListOperatorNode( id="test_node_id", diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 5fd1e33768..b8f0e25e91 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -4,12 +4,12 @@ import time import pytest from pydantic import ValidationError as PydanticValidationError -from dify_graph.entities import GraphInitParams from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from tests.workflow_test_utils import build_test_graph_init_params def make_start_node(user_inputs, variables): @@ -32,11 +32,11 @@ def make_start_node(user_inputs, variables): return StartNode( id="start", config=config, - graph_init_params=GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, + tenant_id="tenant", + app_id="app", user_id="u", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 3d88baa272..11554169e1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -10,13 +10,13 @@ import pytest from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.entities import GraphInitParams from dify_graph.file import File, 
FileTransferMethod, FileType from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import ArrayFileSegment +from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only from dify_graph.nodes.tool.tool_node import ToolNode @@ -54,11 +54,11 @@ def tool_node(monkeypatch) -> ToolNode: "edges": [], } - init_params = GraphInitParams( - tenant_id="tenant-id", - app_id="app-id", + init_params = build_test_graph_init_params( workflow_id="workflow-id", graph_config=graph_config, + tenant_id="tenant-id", + app_id="app-id", user_id="user-id", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index d5a593ffab..2cd3a38fa6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -2,10 +2,10 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.graph import Graph from dify_graph.graph_events.node import NodeRunSucceededEvent from dify_graph.nodes.variable_assigner.common import helpers as common_helpers @@ -43,13 +43,17 @@ def test_overwrite_string_variable(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -141,13 +145,17 @@ def test_append_variable_to_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -236,13 +244,17 @@ def test_clear_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index ef816b9ddc..5b285c2681 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py 
+++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -2,10 +2,10 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.graph import Graph from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation @@ -85,13 +85,17 @@ def test_remove_first_from_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -169,13 +173,17 @@ def test_remove_last_from_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -250,13 +258,17 @@ def test_remove_first_from_empty_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -331,13 +343,17 @@ def test_remove_last_from_empty_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -404,13 +420,17 @@ def test_node_factory_creates_variable_assigner_node(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 5c89ba7d34..c750e74182 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,10 +8,9 @@ when passing files to downstream LLM nodes. 
from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import UserFrom from dify_graph.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -22,7 +21,6 @@ from dify_graph.nodes.trigger_webhook.node import TriggerWebhookNode from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool from dify_graph.system_variable import SystemVariable -from models.workflow import WorkflowType def create_webhook_node( @@ -37,14 +35,17 @@ def create_webhook_node( } graph_init_params = GraphInitParams( - tenant_id=tenant_id, - app_id="test-app", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test-workflow", graph_config={}, - user_id="test-user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": tenant_id, + "app_id": "test-app", + "user_id": "test-user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 066ec5542d..df13bbb92f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -2,10 +2,9 @@ from unittest.mock import patch import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import UserFrom from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.nodes.trigger_webhook.entities import ( ContentType, @@ -19,7 +18,6 @@ from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import FileVariable, StringVariable -from models.workflow import WorkflowType def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -30,14 +28,17 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) } graph_init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, workflow_id="1", graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index 31644edcd8..9969c953e8 100644 --- 
a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,9 +2,8 @@ from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import UserFrom from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel from dify_graph.runtime import GraphRuntimeState, VariablePool diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py new file mode 100644 index 0000000000..1f0bf8ef37 --- /dev/null +++ b/api/tests/workflow_test_utils.py @@ -0,0 +1,53 @@ +from collections.abc import Mapping +from typing import Any + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from dify_graph.entities.graph_init_params import GraphInitParams + + +def build_test_run_context( + *, + tenant_id: str = "tenant", + app_id: str = "app", + user_id: str = "user", + user_from: UserFrom | str = UserFrom.ACCOUNT, + invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER, + extra_context: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + normalized_user_from = user_from if isinstance(user_from, UserFrom) else UserFrom(user_from) + normalized_invoke_from = invoke_from if isinstance(invoke_from, InvokeFrom) else InvokeFrom(invoke_from) + return build_dify_run_context( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=normalized_user_from, + invoke_from=normalized_invoke_from, + extra_context=extra_context, + ) + + +def build_test_graph_init_params( + *, + workflow_id: str = "workflow", + graph_config: Mapping[str, Any] | None = None, + call_depth: int = 0, + tenant_id: str = "tenant", + app_id: str = "app", + user_id: str = "user", + user_from: UserFrom | str = UserFrom.ACCOUNT, + invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER, + extra_context: Mapping[str, Any] | None = None, +) -> GraphInitParams: + return GraphInitParams( + workflow_id=workflow_id, + graph_config=graph_config or {}, + run_context=build_test_run_context( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + extra_context=extra_context, + ), + call_depth=call_depth, + ) From 1819b87a562e98632a6a3263dd4d170ba99ef9ef Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Thu, 5 Mar 2026 14:34:07 +0800 Subject: [PATCH 041/256] test(workflow): add validation tests for workflow and node component rendering part 3 (#33012) Co-authored-by: CodingOnStar --- .../components/__tests__/index.spec.tsx | 8 +- .../__tests__/workflow-test-env.spec.tsx | 136 ++++++++++++ .../workflow/__tests__/workflow-test-env.tsx | 199 ++++++++++++++---- 3 files changed, 301 insertions(+), 42 deletions(-) create mode 100644 web/app/components/workflow/__tests__/workflow-test-env.spec.tsx diff --git a/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx b/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx index 99318b07b3..0d3b638bab 100644 --- a/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx +++ b/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx @@ -1,5 +1,5 @@ import type { EnvironmentVariable } from '@/app/components/workflow/types' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { act, fireEvent, render, 
screen, waitFor } from '@testing-library/react'
 import { createMockProviderContextValue } from '@/__mocks__/provider-context'
 import Conversion from '../conversion'
@@ -9,6 +9,12 @@ import PublishToast from '../publish-toast'
 import RagPipelineChildren from '../rag-pipeline-children'
 import PipelineScreenShot from '../screenshot'
 
+afterEach(async () => {
+  await act(async () => {
+    await new Promise(resolve => setTimeout(resolve, 0))
+  })
+})
+
 const mockPush = vi.fn()
 vi.mock('next/navigation', () => ({
   useParams: () => ({ datasetId: 'test-dataset-id' }),
diff --git a/web/app/components/workflow/__tests__/workflow-test-env.spec.tsx b/web/app/components/workflow/__tests__/workflow-test-env.spec.tsx
new file mode 100644
index 0000000000..d9a4efa12e
--- /dev/null
+++ b/web/app/components/workflow/__tests__/workflow-test-env.spec.tsx
@@ -0,0 +1,136 @@
+/**
+ * Validation tests for renderWorkflowComponent and renderNodeComponent.
+ */
+import type { Shape } from '../store/workflow'
+import { act, screen } from '@testing-library/react'
+import * as React from 'react'
+import { FlowType } from '@/types/common'
+import { useHooksStore } from '../hooks-store/store'
+import { useStore, useWorkflowStore } from '../store/workflow'
+import { renderNodeComponent, renderWorkflowComponent } from './workflow-test-env'
+
+// ---------------------------------------------------------------------------
+// Test components that read from workflow contexts
+// ---------------------------------------------------------------------------
+
+function StoreReader() {
+  const showConfirm = useStore(s => s.showConfirm)
+  return React.createElement('div', { 'data-testid': 'store-reader' }, showConfirm ? 'has-confirm' : 'no-confirm')
+}
+
+function StoreWriter() {
+  const store = useWorkflowStore()
+  return React.createElement(
+    'button',
+    {
+      'data-testid': 'store-writer',
+      'onClick': () => store.setState({ showConfirm: { title: 'Test', onConfirm: () => {} } } as Partial<Shape>),
+    },
+    'Write',
+  )
+}
+
+function HooksStoreReader() {
+  const flowId = useHooksStore(s => s.configsMap?.flowId ?? 'none')
+  return React.createElement('div', { 'data-testid': 'hooks-reader' }, flowId)
+}
+
+function NodeRenderer(props: { id: string, data: { title: string }, selected?: boolean }) {
+  return React.createElement(
+    'div',
+    { 'data-testid': 'node-render' },
+    `${props.id}:${props.data.title}:${props.selected ? 'sel' : 'nosel'}`,
+  )
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+describe('renderWorkflowComponent', () => {
+  it('should provide WorkflowContext with default store', () => {
+    renderWorkflowComponent(React.createElement(StoreReader))
+    expect(screen.getByTestId('store-reader')).toHaveTextContent('no-confirm')
+  })
+
+  it('should apply initialStoreState', () => {
+    renderWorkflowComponent(React.createElement(StoreReader), {
+      initialStoreState: { showConfirm: { title: 'Hey', onConfirm: () => {} } },
+    })
+    expect(screen.getByTestId('store-reader')).toHaveTextContent('has-confirm')
+  })
+
+  it('should return a live store that components can mutate', () => {
+    const { store } = renderWorkflowComponent(
+      React.createElement(React.Fragment, null, React.createElement(StoreReader), React.createElement(StoreWriter)),
+    )
+
+    expect(store.getState().showConfirm).toBeUndefined()
+
+    act(() => {
+      screen.getByTestId('store-writer').click()
+    })
+
+    expect(store.getState().showConfirm).toBeDefined()
+    expect(screen.getByTestId('store-reader')).toHaveTextContent('has-confirm')
+  })
+
+  it('should provide HooksStoreContext when hooksStoreProps given', () => {
+    renderWorkflowComponent(React.createElement(HooksStoreReader), {
+      hooksStoreProps: { configsMap: { flowId: 'test-123', flowType: FlowType.appFlow, fileSettings: {} } },
+    })
+    expect(screen.getByTestId('hooks-reader')).toHaveTextContent('test-123')
+  })
+
+  it('should throw when HooksStoreContext is not provided', () => {
+    const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
+    try {
+      expect(() => {
+        renderWorkflowComponent(React.createElement(HooksStoreReader))
+      }).toThrow('Missing HooksStoreContext.Provider')
+    }
+    finally {
+      consoleSpy.mockRestore()
+    }
+  })
+
+  it('should forward extra render options (container)', () => {
+    const container = document.createElement('section')
+    document.body.appendChild(container)
+
+    try {
+      renderWorkflowComponent(React.createElement(StoreReader), { container })
+      expect(container.querySelector('[data-testid="store-reader"]')).toBeTruthy()
+    }
+    finally {
+      document.body.removeChild(container)
+    }
+  })
+})
+
+describe('renderNodeComponent', () => {
+  it('should render node with default id and selected=false', () => {
+    renderNodeComponent(NodeRenderer, { title: 'Hello' })
+    expect(screen.getByTestId('node-render')).toHaveTextContent('test-node-1:Hello:nosel')
+  })
+
+  it('should accept custom nodeId and selected', () => {
+    renderNodeComponent(NodeRenderer, { title: 'World' }, {
+      nodeId: 'custom-42',
+      selected: true,
+    })
+    expect(screen.getByTestId('node-render')).toHaveTextContent('custom-42:World:sel')
+  })
+
+  it('should provide WorkflowContext to node components', () => {
+    function NodeWithStore(props: { id: string, data: Record<string, unknown> }) {
+      const controlMode = useStore(s => s.controlMode)
+      return React.createElement('div', { 'data-testid': 'node-store' }, `${props.id}:${controlMode}`)
+    }
+
+    renderNodeComponent(NodeWithStore, {}, {
+      initialStoreState: { controlMode: 'hand' as Shape['controlMode'] },
+    })
+    expect(screen.getByTestId('node-store')).toHaveTextContent('test-node-1:hand')
+  })
+})
diff --git a/web/app/components/workflow/__tests__/workflow-test-env.tsx b/web/app/components/workflow/__tests__/workflow-test-env.tsx
index 6109d8a7f4..00d6829964 100644
--- a/web/app/components/workflow/__tests__/workflow-test-env.tsx
+++ 
b/web/app/components/workflow/__tests__/workflow-test-env.tsx @@ -1,7 +1,7 @@ /** * Workflow test environment — composable providers + render helpers. * - * ## Quick start + * ## Quick start (hook) * * ```ts * import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' @@ -29,13 +29,43 @@ * expect(rfState.setNodes).toHaveBeenCalled() * }) * ``` + * + * ## Quick start (component) + * + * ```ts + * import { renderWorkflowComponent } from '../../__tests__/workflow-test-env' + * + * it('renders correctly', () => { + * const { getByText, store } = renderWorkflowComponent( + * , + * { initialStoreState: { showConfirm: undefined } }, + * ) + * expect(getByText('value')).toBeInTheDocument() + * expect(store.getState().showConfirm).toBeUndefined() + * }) + * ``` + * + * ## Quick start (node component) + * + * ```ts + * import { renderNodeComponent } from '../../__tests__/workflow-test-env' + * + * it('renders node', () => { + * const { getByText, store } = renderNodeComponent( + * MyNodeComponent, + * { type: BlockEnum.Code, title: 'My Node', desc: '' }, + * { nodeId: 'n-1', initialStoreState: { ... } }, + * ) + * expect(getByText('My Node')).toBeInTheDocument() + * }) + * ``` */ -import type { RenderHookOptions, RenderHookResult } from '@testing-library/react' +import type { RenderHookOptions, RenderHookResult, RenderOptions, RenderResult } from '@testing-library/react' import type { Shape as HooksStoreShape } from '../hooks-store/store' import type { Shape } from '../store/workflow' import type { Edge, Node, WorkflowRunningData } from '../types' import type { WorkflowHistoryStoreApi } from '../workflow-history-store' -import { renderHook } from '@testing-library/react' +import { render, renderHook } from '@testing-library/react' import isDeepEqual from 'fast-deep-equal' import * as React from 'react' import { temporal } from 'zundo' @@ -83,11 +113,14 @@ export function createTestWorkflowStore(initialState?: Partial): Workflow } export function createTestHooksStore(props?: Partial): HooksStore { - return createHooksStore(props ?? {}) + const store = createHooksStore(props ?? {}) + if (props) + store.setState(props) + return store } // --------------------------------------------------------------------------- -// renderWorkflowHook — composable hook renderer +// Shared provider options & wrapper factory // --------------------------------------------------------------------------- type HistoryStoreConfig = { @@ -95,17 +128,68 @@ type HistoryStoreConfig = { edges?: Edge[] } -type WorkflowTestOptions
<P> = Omit<RenderHookOptions<P>, 'wrapper'> & {
+type WorkflowProviderOptions = {
   initialStoreState?: Partial<Shape>
   hooksStoreProps?: Partial<HooksStoreShape>
   historyStore?: HistoryStoreConfig
 }
 
-type WorkflowTestResult<R, P> = RenderHookResult<R, P> & {
+type StoreInstances = {
   store: WorkflowStore
   hooksStore?: HooksStore
 }
 
+function createStoresFromOptions(options: WorkflowProviderOptions): StoreInstances {
+  const store = createTestWorkflowStore(options.initialStoreState)
+  const hooksStore = options.hooksStoreProps !== undefined
+    ? createTestHooksStore(options.hooksStoreProps)
+    : undefined
+  return { store, hooksStore }
+}
+
+function createWorkflowWrapper(
+  stores: StoreInstances,
+  historyConfig?: HistoryStoreConfig,
+) {
+  const historyCtxValue = historyConfig
+    ? createTestHistoryStoreContext(historyConfig)
+    : undefined
+
+  return ({ children }: { children: React.ReactNode }) => {
+    let inner: React.ReactNode = children
+
+    if (historyCtxValue) {
+      inner = React.createElement(
+        WorkflowHistoryStoreContext.Provider,
+        { value: historyCtxValue },
+        inner,
+      )
+    }
+
+    if (stores.hooksStore) {
+      inner = React.createElement(
+        HooksStoreContext.Provider,
+        { value: stores.hooksStore },
+        inner,
+      )
+    }
+
+    return React.createElement(
+      WorkflowContext.Provider,
+      { value: stores.store },
+      inner,
+    )
+  }
+}
+
+// ---------------------------------------------------------------------------
+// renderWorkflowHook — composable hook renderer
+// ---------------------------------------------------------------------------
+
+type WorkflowHookTestOptions<P> = Omit<RenderHookOptions<P>, 'wrapper'> & WorkflowProviderOptions
+
+type WorkflowHookTestResult<R, P> = RenderHookResult<R, P> & StoreInstances
+
 /**
  * Renders a hook inside composable workflow providers.
  *
@@ -116,44 +200,77 @@
  */
 export function renderWorkflowHook<R, P>(
   hook: (props: P) => R,
-  options?: WorkflowTestOptions<P>,
-): WorkflowTestResult<R, P> {
+  options?: WorkflowHookTestOptions<P>,
+): WorkflowHookTestResult<R, P> {
   const { initialStoreState, hooksStoreProps, historyStore: historyConfig, ...rest } = options ?? {}
 
-  const store = createTestWorkflowStore(initialStoreState)
-  const hooksStore = hooksStoreProps !== undefined
-    ? createTestHooksStore(hooksStoreProps)
-    : undefined
-
-  const wrapper = ({ children }: { children: React.ReactNode }) => {
-    let inner: React.ReactNode = children
-
-    if (historyConfig) {
-      const historyCtxValue = createTestHistoryStoreContext(historyConfig)
-      inner = React.createElement(
-        WorkflowHistoryStoreContext.Provider,
-        { value: historyCtxValue },
-        inner,
-      )
-    }
-
-    if (hooksStore) {
-      inner = React.createElement(
-        HooksStoreContext.Provider,
-        { value: hooksStore },
-        inner,
-      )
-    }
-
-    return React.createElement(
-      WorkflowContext.Provider,
-      { value: store },
-      inner,
-    )
-  }
+  const stores = createStoresFromOptions({ initialStoreState, hooksStoreProps })
+  const wrapper = createWorkflowWrapper(stores, historyConfig)
 
   const renderResult = renderHook(hook, { wrapper, ...rest })
-  return { ...renderResult, store, hooksStore }
+  return { ...renderResult, ...stores }
 }
+
+// ---------------------------------------------------------------------------
+// renderWorkflowComponent — composable component renderer
+// ---------------------------------------------------------------------------
+
+type WorkflowComponentTestOptions = Omit<RenderOptions, 'wrapper'> & WorkflowProviderOptions
+
+type WorkflowComponentTestResult = RenderResult & StoreInstances
+
+/**
+ * Renders a React element inside composable workflow providers.
+ *
+ * Provides the same context layers as `renderWorkflowHook`:
+ * - **Always**: `WorkflowContext` (real zustand store)
+ * - **hooksStoreProps**: `HooksStoreContext` (real zustand store)
+ * - **historyStore**: `WorkflowHistoryStoreContext` (real zundo temporal store)
+ */
+export function renderWorkflowComponent(
+  ui: React.ReactElement,
+  options?: WorkflowComponentTestOptions,
+): WorkflowComponentTestResult {
+  const { initialStoreState, hooksStoreProps, historyStore: historyConfig, ...renderOptions } = options ?? {}
+
+  const stores = createStoresFromOptions({ initialStoreState, hooksStoreProps })
+  const wrapper = createWorkflowWrapper(stores, historyConfig)
+
+  const renderResult = render(ui, { wrapper, ...renderOptions })
+  return { ...renderResult, ...stores }
+}
+
+// ---------------------------------------------------------------------------
+// renderNodeComponent — convenience wrapper for node components
+// ---------------------------------------------------------------------------
+
+type NodeComponentProps<T extends Record<string, unknown>> = {
+  id: string
+  data: T
+  selected?: boolean
+}
+
+type NodeTestOptions = WorkflowComponentTestOptions & {
+  nodeId?: string
+  selected?: boolean
+}
+
+/**
+ * Renders a workflow node component inside composable workflow providers.
+ *
+ * Automatically provides `id`, `data`, and `selected` props that
+ * ReactFlow would normally inject into custom node components.
+ */
+export function renderNodeComponent<T extends Record<string, unknown>>(
+  Component: React.ComponentType<NodeComponentProps<T>>,
+  data: T,
+  options?: NodeTestOptions,
+): WorkflowComponentTestResult {
+  const { nodeId = 'test-node-1', selected = false, ...rest } = options ??
{} + return renderWorkflowComponent( + React.createElement(Component, { id: nodeId, data, selected }), + rest, + ) } // --------------------------------------------------------------------------- From f3c840a60ea139203b7e50b189a69f23f904b38b Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Thu, 5 Mar 2026 15:08:37 +0800 Subject: [PATCH 042/256] fix: workflow canvas sync (#33030) --- .../__tests__/use-nodes-sync-draft.spec.ts | 111 ++++++++++++++++++ .../hooks/__tests__/use-workflow-init.spec.ts | 107 +++++++++++++++++ .../use-workflow-refresh-draft.spec.ts | 80 +++++++++++++ .../hooks/use-nodes-sync-draft.ts | 2 +- .../workflow-app/hooks/use-workflow-init.ts | 1 + .../hooks/use-workflow-refresh-draft.ts | 14 ++- 6 files changed, 308 insertions(+), 7 deletions(-) create mode 100644 web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts create mode 100644 web/app/components/workflow-app/hooks/__tests__/use-workflow-init.spec.ts create mode 100644 web/app/components/workflow-app/hooks/__tests__/use-workflow-refresh-draft.spec.ts diff --git a/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts b/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts new file mode 100644 index 0000000000..d35e6e3612 --- /dev/null +++ b/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts @@ -0,0 +1,111 @@ +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { useNodesSyncDraft } from '../use-nodes-sync-draft' + +const mockGetNodes = vi.fn() +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ getState: () => ({ getNodes: mockGetNodes, edges: [], transform: [0, 0, 1] }) }), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + appId: 'app-1', + isWorkflowDataLoaded: true, + syncWorkflowDraftHash: 'hash-123', + environmentVariables: [], + conversationVariables: [], + setSyncWorkflowDraftHash: vi.fn(), + setDraftUpdatedAt: vi.fn(), + }), + }), +})) + +vi.mock('@/app/components/base/features/hooks', () => ({ + useFeaturesStore: () => ({ + getState: () => ({ + features: { + opening: { enabled: false, opening_statement: '', suggested_questions: [] }, + suggested: {}, + text2speech: {}, + speech2text: {}, + citation: {}, + moderation: {}, + file: {}, + }, + }), + }), +})) + +vi.mock('@/app/components/workflow/hooks/use-workflow', () => ({ + useNodesReadOnly: () => ({ getNodesReadOnly: () => false }), +})) + +vi.mock('@/app/components/workflow/hooks/use-serial-async-callback', () => ({ + useSerialAsyncCallback: (fn: (...args: unknown[]) => Promise, checkFn: () => boolean) => + (...args: unknown[]) => { + if (!checkFn()) + return fn(...args) + }, +})) + +const mockSyncWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + syncWorkflowDraft: (p: unknown) => mockSyncWorkflowDraft(p), +})) + +vi.mock('@/service/fetch', () => ({ postWithKeepalive: vi.fn() })) +vi.mock('@/config', () => ({ API_PREFIX: '/api' })) + +const mockHandleRefreshWorkflowDraft = vi.fn() +vi.mock('@/app/components/workflow-app/hooks', () => ({ + useWorkflowRefreshDraft: () => ({ handleRefreshWorkflowDraft: mockHandleRefreshWorkflowDraft }), +})) + +describe('useNodesSyncDraft — handleRefreshWorkflowDraft(true) on 409', () => { + beforeEach(() => { + vi.clearAllMocks() + mockGetNodes.mockReturnValue([{ id: 'n1', position: { x: 0, y: 0 }, data: { type: 'start' } }]) + mockSyncWorkflowDraft.mockResolvedValue({ hash: 
'new', updated_at: 1 }) + }) + + it('should call handleRefreshWorkflowDraft(true) — not updating canvas — on draft_workflow_not_sync', async () => { + const error = { json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_sync' }), bodyUsed: false } + mockSyncWorkflowDraft.mockRejectedValue(error) + + const { result } = renderHook(() => useNodesSyncDraft()) + await act(async () => { + await result.current.doSyncWorkflowDraft(false) + }) + await new Promise(r => setTimeout(r, 0)) + + expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalledWith(true) + }) + + it('should NOT refresh when notRefreshWhenSyncError=true', async () => { + const error = { json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_sync' }), bodyUsed: false } + mockSyncWorkflowDraft.mockRejectedValue(error) + + const { result } = renderHook(() => useNodesSyncDraft()) + await act(async () => { + await result.current.doSyncWorkflowDraft(true) + }) + await new Promise(r => setTimeout(r, 0)) + + expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled() + }) + + it('should NOT refresh for a different error code', async () => { + const error = { json: vi.fn().mockResolvedValue({ code: 'other_error' }), bodyUsed: false } + mockSyncWorkflowDraft.mockRejectedValue(error) + + const { result } = renderHook(() => useNodesSyncDraft()) + await act(async () => { + await result.current.doSyncWorkflowDraft(false) + }) + await new Promise(r => setTimeout(r, 0)) + + expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow-app/hooks/__tests__/use-workflow-init.spec.ts b/web/app/components/workflow-app/hooks/__tests__/use-workflow-init.spec.ts new file mode 100644 index 0000000000..42e4b593ed --- /dev/null +++ b/web/app/components/workflow-app/hooks/__tests__/use-workflow-init.spec.ts @@ -0,0 +1,107 @@ +import { renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { useWorkflowInit } from '../use-workflow-init' + +const mockSetSyncWorkflowDraftHash = vi.fn() +const mockSetDraftUpdatedAt = vi.fn() +const mockSetToolPublished = vi.fn() +const mockSetPublishedAt = vi.fn() +const mockSetLastPublishedHasUserInput = vi.fn() +const mockSetFileUploadConfig = vi.fn() +const mockWorkflowStoreSetState = vi.fn() +const mockWorkflowStoreGetState = vi.fn() + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { setSyncWorkflowDraftHash: ReturnType }) => T): T => + selector({ setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash }), + useWorkflowStore: () => ({ + setState: mockWorkflowStoreSetState, + getState: mockWorkflowStoreGetState, + }), +})) + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: { appDetail: { id: string, name: string, mode: string } }) => T): T => + selector({ appDetail: { id: 'app-1', name: 'Test', mode: 'workflow' } }), +})) + +vi.mock('../use-workflow-template', () => ({ + useWorkflowTemplate: () => ({ nodes: [], edges: [] }), +})) + +vi.mock('@/service/use-workflow', () => ({ + useWorkflowConfig: () => ({ data: null, isLoading: false }), +})) + +const mockFetchWorkflowDraft = vi.fn() +const mockSyncWorkflowDraft = vi.fn() + +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), + syncWorkflowDraft: (...args: unknown[]) => mockSyncWorkflowDraft(...args), + fetchNodesDefaultConfigs: () => Promise.resolve([]), + fetchPublishedWorkflow: () => Promise.resolve({ created_at: 
0, graph: { nodes: [], edges: [] } }), +})) + +const notExistError = () => ({ + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }), + bodyUsed: false, +}) + +const draftResponse = { + id: 'draft-id', + graph: { nodes: [], edges: [] }, + hash: 'server-hash', + created_at: 0, + created_by: { id: '', name: '', email: '' }, + updated_at: 1, + updated_by: { id: '', name: '', email: '' }, + tool_published: false, + environment_variables: [], + conversation_variables: [], + version: '1', + marked_name: '', + marked_comment: '', +} + +describe('useWorkflowInit — hash fix (draft_workflow_not_exist)', () => { + beforeEach(() => { + vi.clearAllMocks() + mockWorkflowStoreGetState.mockReturnValue({ + setDraftUpdatedAt: mockSetDraftUpdatedAt, + setToolPublished: mockSetToolPublished, + setPublishedAt: mockSetPublishedAt, + setLastPublishedHasUserInput: mockSetLastPublishedHasUserInput, + setFileUploadConfig: mockSetFileUploadConfig, + }) + mockFetchWorkflowDraft + .mockRejectedValueOnce(notExistError()) + .mockResolvedValueOnce(draftResponse) + mockSyncWorkflowDraft.mockResolvedValue({ hash: 'new-hash', updated_at: 1 }) + }) + + it('should call setSyncWorkflowDraftHash with hash returned by syncWorkflowDraft', async () => { + renderHook(() => useWorkflowInit()) + await waitFor(() => expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash')) + }) + + it('should store hash BEFORE making the recursive fetchWorkflowDraft call', async () => { + const order: string[] = [] + mockSetSyncWorkflowDraftHash.mockImplementation((h: string) => order.push(`hash:${h}`)) + mockFetchWorkflowDraft + .mockReset() + .mockRejectedValueOnce(notExistError()) + .mockImplementationOnce(async () => { + order.push('fetch:2') + return draftResponse + }) + mockSyncWorkflowDraft.mockResolvedValue({ hash: 'new-hash', updated_at: 1 }) + + renderHook(() => useWorkflowInit()) + + await waitFor(() => expect(order).toContain('fetch:2')) + expect(order).toContain('hash:new-hash') + expect(order.indexOf('hash:new-hash')).toBeLessThan(order.indexOf('fetch:2')) + }) +}) diff --git a/web/app/components/workflow-app/hooks/__tests__/use-workflow-refresh-draft.spec.ts b/web/app/components/workflow-app/hooks/__tests__/use-workflow-refresh-draft.spec.ts new file mode 100644 index 0000000000..2fd06e587b --- /dev/null +++ b/web/app/components/workflow-app/hooks/__tests__/use-workflow-refresh-draft.spec.ts @@ -0,0 +1,80 @@ +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { useWorkflowRefreshDraft } from '../use-workflow-refresh-draft' + +const mockHandleUpdateWorkflowCanvas = vi.fn() +const mockSetSyncWorkflowDraftHash = vi.fn() + +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + appId: 'app-1', + isWorkflowDataLoaded: true, + debouncedSyncWorkflowDraft: undefined, + setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash, + setIsSyncingWorkflowDraft: vi.fn(), + setEnvironmentVariables: vi.fn(), + setEnvSecrets: vi.fn(), + setConversationVariables: vi.fn(), + setIsWorkflowDataLoaded: vi.fn(), + }), + }), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowUpdate: () => ({ handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas }), +})) + +const mockFetchWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), +})) + +const draftResponse = { + hash: 'server-hash', + graph: { nodes: 
[{ id: 'n1' }], edges: [], viewport: { x: 1, y: 2, zoom: 1 } }, + environment_variables: [], + conversation_variables: [], +} + +describe('useWorkflowRefreshDraft — notUpdateCanvas parameter', () => { + beforeEach(() => { + vi.clearAllMocks() + mockFetchWorkflowDraft.mockResolvedValue(draftResponse) + }) + + it('should update canvas by default (notUpdateCanvas omitted)', async () => { + const { result } = renderHook(() => useWorkflowRefreshDraft()) + await act(async () => { + result.current.handleRefreshWorkflowDraft() + }) + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledTimes(1) + }) + + it('should update canvas when notUpdateCanvas=false', async () => { + const { result } = renderHook(() => useWorkflowRefreshDraft()) + await act(async () => { + result.current.handleRefreshWorkflowDraft(false) + }) + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledTimes(1) + }) + + it('should NOT update canvas when notUpdateCanvas=true', async () => { + // This is the key change: when called from a 409 error during editing, + // canvas must not be overwritten with server state. + const { result } = renderHook(() => useWorkflowRefreshDraft()) + await act(async () => { + result.current.handleRefreshWorkflowDraft(true) + }) + expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled() + }) + + it('should still update hash even when notUpdateCanvas=true', async () => { + const { result } = renderHook(() => useWorkflowRefreshDraft()) + await act(async () => { + result.current.handleRefreshWorkflowDraft(true) + }) + expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('server-hash') + }) +}) diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 5dc0741324..4f9e529d92 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -132,7 +132,7 @@ export const useNodesSyncDraft = () => { if (error && error.json && !error.bodyUsed) { error.json().then((err: any) => { if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) - handleRefreshWorkflowDraft() + handleRefreshWorkflowDraft(true) }) } callback?.onError?.() diff --git a/web/app/components/workflow-app/hooks/use-workflow-init.ts b/web/app/components/workflow-app/hooks/use-workflow-init.ts index 8e976937b5..00bff2919f 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-init.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-init.ts @@ -100,6 +100,7 @@ export const useWorkflowInit = () => { }, }).then((res) => { workflowStore.getState().setDraftUpdatedAt(res.updated_at) + setSyncWorkflowDraftHash(res.hash) handleGetInitialWorkflowData() }) } diff --git a/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts b/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts index fa4a44d894..a7283c0078 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts @@ -8,7 +8,7 @@ export const useWorkflowRefreshDraft = () => { const workflowStore = useWorkflowStore() const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() - const handleRefreshWorkflowDraft = useCallback(() => { + const handleRefreshWorkflowDraft = useCallback((notUpdateCanvas?: boolean) => { const { appId, setSyncWorkflowDraftHash, @@ -31,12 +31,14 @@ export const useWorkflowRefreshDraft = () => { fetchWorkflowDraft(`/apps/${appId}/workflows/draft`) 
.then((response) => { // Ensure we have a valid workflow structure with viewport - const workflowData: WorkflowDataUpdater = { - nodes: response.graph?.nodes || [], - edges: response.graph?.edges || [], - viewport: response.graph?.viewport || { x: 0, y: 0, zoom: 1 }, + if (!notUpdateCanvas) { + const workflowData: WorkflowDataUpdater = { + nodes: response.graph?.nodes || [], + edges: response.graph?.edges || [], + viewport: response.graph?.viewport || { x: 0, y: 0, zoom: 1 }, + } + handleUpdateWorkflowCanvas(workflowData) } - handleUpdateWorkflowCanvas(workflowData) setSyncWorkflowDraftHash(response.hash) setEnvSecrets((response.environment_variables || []).filter(env => env.value_type === 'secret').reduce((acc, env) => { acc[env.id] = env.value From f487b680f57e59da1a9aa5dafb285568ca62302f Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Thu, 5 Mar 2026 15:54:56 +0800 Subject: [PATCH 043/256] refactor: split context for better hmr (#33033) --- .../dsl-export-import-flow.test.ts | 2 +- .../[appId]/overview/card-view.tsx | 2 +- web/app/(commonLayout)/layout.tsx | 6 +- .../account-page/AvatarWithEdit.tsx | 2 +- .../account-page/email-change-modal.tsx | 2 +- .../(commonLayout)/account-page/index.tsx | 2 +- web/app/account/(commonLayout)/layout.tsx | 6 +- .../__tests__/use-app-info-actions.spec.ts | 2 +- .../app-info/use-app-info-actions.ts | 2 +- .../csv-uploader.spec.tsx | 2 +- .../csv-uploader.tsx | 2 +- .../config-prompt/advanced-prompt-input.tsx | 2 +- .../config-prompt/simple-prompt-input.tsx | 2 +- .../config/agent/prompt-editor.tsx | 2 +- .../settings-modal/index.spec.tsx | 2 +- .../dataset-config/settings-modal/index.tsx | 2 +- .../context-provider.tsx | 28 ++++++ .../context.spec.tsx | 6 +- .../{context.tsx => context.ts} | 26 +---- .../debug/debug-with-multiple-model/index.tsx | 6 +- .../debug-with-single-model/index.spec.tsx | 2 +- .../app/configuration/debug/index.tsx | 2 +- .../components/app/configuration/index.tsx | 5 +- .../tools/external-data-tool-modal.tsx | 2 +- .../app/configuration/tools/index.tsx | 2 +- .../app/create-app-modal/index.spec.tsx | 2 +- .../components/app/create-app-modal/index.tsx | 2 +- .../app/create-from-dsl-modal/index.tsx | 2 +- .../app/create-from-dsl-modal/uploader.tsx | 2 +- web/app/components/app/log/list.tsx | 2 +- .../app/overview/settings/index.spec.tsx | 16 ++-- .../app/overview/settings/index.tsx | 2 +- .../app/switch-app-modal/index.spec.tsx | 2 +- .../components/app/switch-app-modal/index.tsx | 2 +- web/app/components/apps/app-card.tsx | 3 +- .../agent-log-modal/__tests__/detail.spec.tsx | 2 +- .../agent-log-modal/__tests__/index.spec.tsx | 2 +- .../base/agent-log-modal/detail.tsx | 2 +- .../{context.tsx => context.ts} | 0 .../base/chat/chat-with-history/hooks.tsx | 2 +- .../base/chat/chat/__tests__/context.spec.tsx | 3 +- .../chat/chat/__tests__/question.spec.tsx | 2 +- .../chat-input-area/__tests__/index.spec.tsx | 12 +-- .../base/chat/chat/chat-input-area/index.tsx | 2 +- .../base/chat/chat/check-input-forms-hooks.ts | 2 +- .../{context.tsx => context-provider.tsx} | 30 +----- web/app/components/base/chat/chat/context.ts | 30 ++++++ web/app/components/base/chat/chat/hooks.ts | 2 +- web/app/components/base/chat/chat/index.tsx | 2 +- .../{context.tsx => context.ts} | 0 .../base/chat/embedded-chatbot/hooks.tsx | 2 +- .../inputs-form/__tests__/content.spec.tsx | 2 +- .../moderation-setting-modal.spec.tsx | 2 +- .../moderation/moderation-setting-modal.tsx | 2 +- .../file-uploader/__tests__/hooks.spec.ts | 2 +-
.../components/base/file-uploader/hooks.ts | 2 +- .../__tests__/use-check-validated.spec.ts | 2 +- .../base/form/hooks/use-check-validated.ts | 2 +- .../image-uploader/__tests__/hooks.spec.ts | 2 +- .../components/base/image-uploader/hooks.ts | 2 +- .../markdown-blocks/__tests__/button.spec.tsx | 2 +- .../__tests__/think-block.spec.tsx | 2 +- .../markdown-blocks/think-block.stories.tsx | 2 +- .../__tests__/index.spec.tsx | 3 +- .../radio/context/{index.tsx => index.ts} | 0 .../base/tag-input/__tests__/index.spec.tsx | 2 +- web/app/components/base/tag-input/index.tsx | 2 +- .../tag-management/__tests__/panel.spec.tsx | 2 +- .../__tests__/selector.spec.tsx | 2 +- .../components/base/tag-management/index.tsx | 2 +- .../components/base/tag-management/panel.tsx | 2 +- .../base/tag-management/tag-item-editor.tsx | 2 +- .../text-generation/__tests__/hooks.spec.ts | 2 +- .../components/base/text-generation/hooks.ts | 2 +- .../base/toast/__tests__/index.spec.tsx | 3 +- web/app/components/base/toast/context.ts | 23 +++++ .../components/base/toast/index.stories.tsx | 3 +- web/app/components/base/toast/index.tsx | 27 ++---- .../__tests__/index.spec.tsx | 4 +- .../custom/custom-web-app-brand/index.tsx | 2 +- .../__tests__/uploader.spec.tsx | 2 +- .../hooks/use-dsl-import.ts | 2 +- .../create-from-dsl-modal/uploader.tsx | 2 +- .../empty-dataset-creation-modal/index.tsx | 2 +- .../hooks/__tests__/use-file-upload.spec.tsx | 2 +- .../file-uploader/hooks/use-file-upload.ts | 2 +- .../components/__tests__/operations.spec.tsx | 2 +- .../documents/components/operations.tsx | 2 +- .../__tests__/use-local-file-upload.spec.tsx | 4 +- .../__tests__/csv-uploader.spec.tsx | 2 +- .../detail/batch-modal/csv-uploader.tsx | 2 +- .../detail/completed/__tests__/index.spec.tsx | 2 +- .../__tests__/regeneration-modal.spec.tsx | 3 +- .../__tests__/use-child-segment-data.spec.ts | 2 +- .../__tests__/use-segment-list-data.spec.ts | 2 +- .../completed/hooks/use-child-segment-data.ts | 2 +- .../completed/hooks/use-segment-list-data.ts | 2 +- .../detail/completed/new-child-segment.tsx | 2 +- .../documents/detail/embedding/index.tsx | 2 +- .../__tests__/use-metadata-state.spec.ts | 2 +- .../metadata/hooks/use-metadata-state.ts | 2 +- .../datasets/documents/detail/new-segment.tsx | 2 +- .../datasets/documents/status-item/index.tsx | 2 +- .../__tests__/index.spec.tsx | 2 +- .../external-api/external-api-modal/index.tsx | 2 +- .../connector/__tests__/index.spec.tsx | 2 +- .../connector/index.tsx | 2 +- .../workplace-selector/index.spec.tsx | 2 +- .../workplace-selector/index.tsx | 2 +- .../api-based-extension-page/modal.spec.tsx | 4 +- .../api-based-extension-page/modal.tsx | 2 +- .../account-setting/language-page/index.tsx | 2 +- .../edit-workspace-modal/index.spec.tsx | 2 +- .../edit-workspace-modal/index.tsx | 2 +- .../members-page/invite-modal/index.spec.tsx | 2 +- .../members-page/invite-modal/index.tsx | 2 +- .../members-page/operation/index.spec.tsx | 2 +- .../members-page/operation/index.tsx | 2 +- .../transfer-ownership-modal/index.spec.tsx | 2 +- .../transfer-ownership-modal/index.tsx | 2 +- .../model-auth/hooks/use-auth.spec.tsx | 2 +- .../model-auth/hooks/use-auth.ts | 2 +- .../credential-panel.spec.tsx | 2 +- .../provider-added-card/credential-panel.tsx | 2 +- .../model-load-balancing-modal.spec.tsx | 2 +- .../model-load-balancing-modal.tsx | 2 +- .../system-model-selector/index.spec.tsx | 2 +- .../system-model-selector/index.tsx | 2 +- .../plugin-page/SerpapiPlugin.spec.tsx | 4 +- .../plugin-page/SerpapiPlugin.tsx | 2 
+- .../plugin-page/index.spec.tsx | 2 +- web/app/components/header/index.spec.tsx | 2 +- web/app/components/header/index.tsx | 2 +- .../__tests__/authorized-in-node.spec.tsx | 2 +- .../__tests__/plugin-auth-in-agent.spec.tsx | 2 +- .../__tests__/api-key-modal.spec.tsx | 2 +- .../__tests__/authorize-components.spec.tsx | 2 +- .../__tests__/oauth-client-settings.spec.tsx | 2 +- .../plugin-auth/authorize/api-key-modal.tsx | 2 +- .../authorize/oauth-client-settings.tsx | 2 +- .../authorized/__tests__/index.spec.tsx | 2 +- .../plugins/plugin-auth/authorized/index.tsx | 2 +- .../__tests__/use-plugin-auth-action.spec.ts | 2 +- .../hooks/use-plugin-auth-action.ts | 2 +- .../edit/__tests__/apikey-edit-modal.spec.tsx | 11 ++- .../edit/__tests__/manual-edit-modal.spec.tsx | 11 ++- .../edit/__tests__/oauth-edit-modal.spec.tsx | 11 ++- .../__tests__/tool-credentials-form.spec.tsx | 3 + .../plugin-page/__tests__/context.spec.tsx | 3 +- .../{context.tsx => context-provider.tsx} | 46 +-------- .../components/plugins/plugin-page/context.ts | 46 +++++++++ .../components/plugins/plugin-page/index.tsx | 6 +- .../__tests__/update-dsl-modal.spec.tsx | 2 +- .../__tests__/index.spec.tsx | 2 +- .../publisher/__tests__/index.spec.tsx | 2 +- .../publisher/__tests__/popup.spec.tsx | 2 +- .../rag-pipeline-header/publisher/popup.tsx | 2 +- .../hooks/__tests__/index.spec.ts | 2 +- .../hooks/__tests__/use-DSL.spec.ts | 2 +- .../__tests__/use-update-dsl-modal.spec.ts | 2 +- .../components/rag-pipeline/hooks/use-DSL.ts | 2 +- .../hooks/use-update-dsl-modal.ts | 2 +- .../__tests__/features-trigger.spec.tsx | 2 +- .../workflow-header/features-trigger.tsx | 2 +- .../components/workflow-app/hooks/use-DSL.ts | 2 +- .../components/workflow/header/run-mode.tsx | 2 +- .../hooks/__tests__/use-checklist.spec.ts | 2 +- .../workflow/hooks/use-checklist.ts | 2 +- .../plugins/link-editor-plugin/hooks.ts | 2 +- .../components/object-value-item.tsx | 2 +- .../components/variable-modal.tsx | 2 +- .../workflow/panel/debug-and-preview/hooks.ts | 2 +- .../panel/env-panel/variable-modal.tsx | 2 +- web/app/components/workflow/run/index.tsx | 2 +- .../components/workflow/update-dsl-modal.tsx | 2 +- .../education-apply/education-apply-page.tsx | 2 +- ...tasets-context.tsx => datasets-context.ts} | 0 web/context/event-emitter-provider.tsx | 22 +++++ .../{event-emitter.tsx => event-emitter.ts} | 18 +--- web/context/mitt-context-provider.tsx | 19 ++++ .../{mitt-context.tsx => mitt-context.ts} | 14 +-- ...context.tsx => modal-context-provider.tsx} | 93 ++---------------- web/context/modal-context.test.tsx | 2 +- web/context/modal-context.ts | 92 ++++++++++++++++++ ...text.tsx => provider-context-provider.tsx} | 96 +------------------ web/context/provider-context.ts | 90 +++++++++++++++++ web/context/workspace-context-provider.tsx | 24 +++++ web/context/workspace-context.ts | 16 ++++ web/context/workspace-context.tsx | 36 ------- web/eslint-suppressions.json | 60 ++---------- web/hooks/use-import-dsl.ts | 2 +- 191 files changed, 645 insertions(+), 615 deletions(-) create mode 100644 web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx rename web/app/components/app/configuration/debug/debug-with-multiple-model/{context.tsx => context.ts} (50%) rename web/app/components/base/chat/chat-with-history/{context.tsx => context.ts} (100%) rename web/app/components/base/chat/chat/{context.tsx => context-provider.tsx} (58%) create mode 100644 web/app/components/base/chat/chat/context.ts rename 
web/app/components/base/chat/embedded-chatbot/{context.tsx => context.ts} (100%) rename web/app/components/base/radio/context/{index.tsx => index.ts} (100%) create mode 100644 web/app/components/base/toast/context.ts rename web/app/components/plugins/plugin-page/{context.tsx => context-provider.tsx} (58%) create mode 100644 web/app/components/plugins/plugin-page/context.ts rename web/context/{datasets-context.tsx => datasets-context.ts} (100%) create mode 100644 web/context/event-emitter-provider.tsx rename web/context/{event-emitter.tsx => event-emitter.ts} (53%) create mode 100644 web/context/mitt-context-provider.tsx rename web/context/{mitt-context.tsx => mitt-context.ts} (66%) rename web/context/{modal-context.tsx => modal-context-provider.tsx} (81%) create mode 100644 web/context/modal-context.ts rename web/context/{provider-context.tsx => provider-context-provider.tsx} (70%) create mode 100644 web/context/provider-context.ts create mode 100644 web/context/workspace-context-provider.tsx create mode 100644 web/context/workspace-context.ts delete mode 100644 web/context/workspace-context.tsx diff --git a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts index 578552840d..dc5ab3fc86 100644 --- a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts +++ b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts @@ -19,7 +19,7 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index f07b2932c9..8c1df8d63d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -12,7 +12,7 @@ import AppCard from '@/app/components/app/overview/app-card' import TriggerCard from '@/app/components/app/overview/trigger-card' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' import { isTriggerNode } from '@/app/components/workflow/types' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index bd067fde6a..db2786f6cf 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -9,9 +9,9 @@ import Header from '@/app/components/header' import HeaderWrapper from '@/app/components/header/header-wrapper' import ReadmePanel from '@/app/components/plugins/readme-panel' import { AppContextProvider } from '@/context/app-context-provider' -import { EventEmitterContextProvider } from '@/context/event-emitter' -import { ModalContextProvider } from '@/context/modal-context' -import { ProviderContextProvider } from '@/context/provider-context' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' +import { ModalContextProvider } from '@/context/modal-context-provider' +import { ProviderContextProvider } from '@/context/provider-context-provider' import PartnerStack from '../components/billing/partner-stack' import Splash 
from '../components/splash' import RoleRouteGuard from './role-route-guard' diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 49b59704b1..0a17822187 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -16,7 +16,7 @@ import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' import { updateUserProfile } from '@/service/common' diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index 87ca6a689c..463c27294a 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -9,7 +9,7 @@ import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { checkEmailExisted, resetEmail, diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index f01efc002c..908ef9c2e8 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -12,7 +12,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import PremiumBadge from '@/app/components/base/premium-badge' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Collapse from '@/app/components/header/account-setting/collapse' import { IS_CE_EDITION, validPassword } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/account/(commonLayout)/layout.tsx b/web/app/account/(commonLayout)/layout.tsx index 47fb47b02b..8fdbd8a238 100644 --- a/web/app/account/(commonLayout)/layout.tsx +++ b/web/app/account/(commonLayout)/layout.tsx @@ -5,9 +5,9 @@ import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import HeaderWrapper from '@/app/components/header/header-wrapper' import { AppContextProvider } from '@/context/app-context-provider' -import { EventEmitterContextProvider } from '@/context/event-emitter' -import { ModalContextProvider } from '@/context/modal-context' -import { ProviderContextProvider } from '@/context/provider-context' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' +import { ModalContextProvider } from '@/context/modal-context-provider' +import { ProviderContextProvider } from '@/context/provider-context-provider' import Header from './header' const Layout = ({ children }: { children: ReactNode }) => { diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts 
b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts index e5966ed972..6104e2b641 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -42,7 +42,7 @@ vi.mock('@/app/components/app/store', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: {}, })) diff --git a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts index 13880acbed..800f21de44 100644 --- a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -6,7 +6,7 @@ import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useProviderContext } from '@/context/provider-context' import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx index 6a67ba3207..55f5ee0564 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -1,7 +1,7 @@ import type { Props } from './csv-uploader' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import CSVUploader from './csv-uploader' describe('CSVUploader', () => { diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index 6d5eb1ef95..118eaea58e 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' export type Props = { diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index d0e9eb586c..9625204d81 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -20,7 +20,7 @@ import { } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from 
'@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index bc94f87838..c33d55873d 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -17,7 +17,7 @@ import { useFeaturesStore } from '@/app/components/base/features/hooks' import PromptEditor from '@/app/components/base/prompt-editor' import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index e9e3b60859..9f1f04ba3c 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -12,7 +12,7 @@ import { CopyCheck, } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx index c4fdfb7553..ea70725ea8 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx @@ -3,7 +3,7 @@ import type { DataSet } from '@/models/datasets' import type { RetrievalConfig } from '@/types/app' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { IndexingType } from '@/app/components/datasets/create/step-two' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index e910702b66..c9c8d080f2 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ 
b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { IndexingType } from '@/app/components/datasets/create/step-two' import IndexMethod from '@/app/components/datasets/settings/index-method' diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx new file mode 100644 index 0000000000..74aed2d1e2 --- /dev/null +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx @@ -0,0 +1,28 @@ +'use client' + +import type { ReactNode } from 'react' +import type { DebugWithMultipleModelContextType } from './context' +import { DebugWithMultipleModelContext } from './context' + +type DebugWithMultipleModelContextProviderProps = { + children: ReactNode +} & DebugWithMultipleModelContextType +export const DebugWithMultipleModelContextProvider = ({ + children, + onMultipleModelConfigsChange, + multipleModelConfigs, + onDebugWithMultipleModelChange, + checkCanSend, +}: DebugWithMultipleModelContextProviderProps) => { + return ( + <DebugWithMultipleModelContext.Provider value={{ + onMultipleModelConfigsChange, + multipleModelConfigs, + onDebugWithMultipleModelChange, + checkCanSend, + }}> + {children} + </DebugWithMultipleModelContext.Provider> + ) +} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx index e26fcec607..989285f812 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx @@ -1,10 +1,8 @@ import type { ModelAndParameter } from '../types' import type { DebugWithMultipleModelContextType } from './context' import { render, screen } from '@testing-library/react' -import { - DebugWithMultipleModelContextProvider, - useDebugWithMultipleModelContext, -} from './context' +import { useDebugWithMultipleModelContext } from './context' +import { DebugWithMultipleModelContextProvider } from './context-provider' const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({ id: 'model-1', diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts similarity index 50% rename from web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx rename to web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts index 38f803f8ab..e3ad06f1b9 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts @@ -10,7 +10,8 @@ export type DebugWithMultipleModelContextType = { onDebugWithMultipleModelChange: (singleModelConfig: ModelAndParameter) => void checkCanSend?: () => boolean } -const DebugWithMultipleModelContext = createContext<DebugWithMultipleModelContextType>({ + +export const DebugWithMultipleModelContext = createContext<DebugWithMultipleModelContextType>({ multipleModelConfigs: [], onMultipleModelConfigsChange: noop, onDebugWithMultipleModelChange: noop, @@ -18,27 +19,4 @@ const
DebugWithMultipleModelContext = createContext<DebugWithMultipleModelContextType>({ }) export const useDebugWithMultipleModelContext = () => useContext(DebugWithMultipleModelContext) -type DebugWithMultipleModelContextProviderProps = { - children: React.ReactNode -} & DebugWithMultipleModelContextType -export const DebugWithMultipleModelContextProvider = ({ - children, - onMultipleModelConfigsChange, - multipleModelConfigs, - onDebugWithMultipleModelChange, - checkCanSend, -}: DebugWithMultipleModelContextProviderProps) => { - return ( - <DebugWithMultipleModelContext.Provider value={{ - onMultipleModelConfigsChange, - multipleModelConfigs, - onDebugWithMultipleModelChange, - checkCanSend, - }}> - {children} - </DebugWithMultipleModelContext.Provider> - ) -} - export default DebugWithMultipleModelContext diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx index c73eb54329..f98e8c1f06 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx @@ -14,10 +14,8 @@ import { useDebugConfigurationContext } from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' import { AppModeEnum } from '@/types/app' import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types' -import { - DebugWithMultipleModelContextProvider, - useDebugWithMultipleModelContext, -} from './context' +import { useDebugWithMultipleModelContext } from './context' +import { DebugWithMultipleModelContextProvider } from './context-provider' import DebugItem from './debug-item' const DebugWithMultipleModel = () => { diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx index 08bdd2bfcb..48141d0045 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx @@ -387,7 +387,7 @@ vi.mock('@/context/event-emitter', () => ({ })) // Mock toast context -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: vi.fn(() => ({ notify: vi.fn(), })), diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index 14a00e85c7..1a6f9e264f 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -29,7 +29,7 @@ import Button from '@/app/components/base/button' import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows' import PromptLogModal from '@/app/components/base/prompt-log-modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import TooltipPlus from '@/app/components/base/tooltip' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 16cf9454ca..a39eb3c63e 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -49,7 +49,8 @@ import { FeaturesProvider } from '@/app/components/base/features' import NewFeaturePanel from '@/app/components/base/features/new-feature-panel' import Loading from
'@/app/components/base/loading' import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' -import Toast, { ToastContext } from '@/app/components/base/toast' +import Toast from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { @@ -66,7 +67,7 @@ import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { ANNOTATION_DEFAULT, DATASET_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' import { useAppContext } from '@/context/app-context' import ConfigContext from '@/context/debug-configuration' -import { MittProvider } from '@/context/mitt-context' +import { MittProvider } from '@/context/mitt-context-provider' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' diff --git a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx index fece5598e1..35ba59a5fb 100644 --- a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx +++ b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx @@ -13,7 +13,7 @@ import FormGeneration from '@/app/components/base/features/new-feature-panel/mod import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education' import Modal from '@/app/components/base/modal' import { SimpleSelect } from '@/app/components/base/select' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ApiBasedExtensionSelector from '@/app/components/header/account-setting/api-based-extension-page/selector' import { useDocLink, useLocale } from '@/context/i18n' import { LanguagesSupported } from '@/i18n-config/language' diff --git a/web/app/components/app/configuration/tools/index.tsx b/web/app/components/app/configuration/tools/index.tsx index d2873b0be3..f348a7718d 100644 --- a/web/app/components/app/configuration/tools/index.tsx +++ b/web/app/components/app/configuration/tools/index.tsx @@ -15,7 +15,7 @@ import { } from '@/app/components/base/icons/src/vender/line/general' import { Tool03 } from '@/app/components/base/icons/src/vender/solid/general' import Switch from '@/app/components/base/switch' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index 8c368df62c..a9adb17582 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -4,7 +4,7 @@ import { useRouter } from 'next/navigation' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' import { trackEvent } from '@/app/components/base/amplitude' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from 
'@/app/components/base/toast/context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 66c7bce80c..16ca4bdaff 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -17,7 +17,7 @@ import FullScreenModal from '@/app/components/base/fullscreen-modal' import { BubbleTextMod, ChatBot, ListSparkle, Logic } from '@/app/components/base/icons/src/vender/solid/communication' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 04d8b1e754..a542ef059f 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -12,7 +12,7 @@ import { trackEvent } from '@/app/components/base/amplitude' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index 778a2c1420..6b7ca4db61 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -10,7 +10,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import ActionButton from '@/app/components/base/action-button' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' import { formatFileSize } from '@/utils/format' diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 40519dcb36..4eb733a02c 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -31,7 +31,7 @@ import Drawer from '@/app/components/base/drawer' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import Loading from '@/app/components/base/loading' import MessageLogModal from '@/app/components/base/message-log-modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { WorkflowContextProvider } from '@/app/components/workflow/context' diff --git 
a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index c9cbe0b724..d98e02ad57 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -59,16 +59,12 @@ vi.mock('@/context/modal-context', () => ({ useModalContext: () => buildModalContext(), })) -vi.mock('@/app/components/base/toast', async () => { - const actual = await vi.importActual('@/app/components/base/toast') - return { - ...actual, - useToastContext: () => ({ - notify: mockNotify, - close: vi.fn(), - }), - } -}) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ + notify: mockNotify, + close: vi.fn(), + }), +})) vi.mock('@/context/i18n', async () => { const actual = await vi.importActual('@/context/i18n') diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 040703f41c..92bfdc5d31 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -20,7 +20,7 @@ import PremiumBadge from '@/app/components/base/premium-badge' import { SimpleSelect } from '@/app/components/base/select' import Switch from '@/app/components/base/switch' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useModalContext } from '@/context/modal-context' diff --git a/web/app/components/app/switch-app-modal/index.spec.tsx b/web/app/components/app/switch-app-modal/index.spec.tsx index f0400f8b75..c905d79b31 100644 --- a/web/app/components/app/switch-app-modal/index.spec.tsx +++ b/web/app/components/app/switch-app-modal/index.spec.tsx @@ -3,7 +3,7 @@ import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { Plan } from '@/app/components/billing/type' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { AppModeEnum } from '@/types/app' diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index 30d7877ed0..8caa07c187 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -15,7 +15,7 @@ import Confirm from '@/app/components/base/confirm' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 8dc5c82e13..471b3420d1 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -18,7 +18,8 @@ import 
AppIcon from '@/app/components/base/app-icon' import Divider from '@/app/components/base/divider' import CustomPopover from '@/app/components/base/popover' import TagSelector from '@/app/components/base/tag-management/selector' -import Toast, { ToastContext } from '@/app/components/base/toast' +import Toast from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { AlertDialog, diff --git a/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx b/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx index c77f144da2..47d854e028 100644 --- a/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx +++ b/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx @@ -2,7 +2,7 @@ import type { ComponentProps } from 'react' import type { IChatItem } from '@/app/components/base/chat/chat/type' import type { AgentLogDetailResponse } from '@/models/log' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import AgentLogDetail from '../detail' diff --git a/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx b/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx index 6b59e90c77..6437ae5b43 100644 --- a/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx +++ b/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx @@ -1,7 +1,7 @@ import type { IChatItem } from '@/app/components/base/chat/chat/type' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { useClickAway } from 'ahooks' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import AgentLogModal from '../index' diff --git a/web/app/components/base/agent-log-modal/detail.tsx b/web/app/components/base/agent-log-modal/detail.tsx index 36b502e9a5..21ed0be7e8 100644 --- a/web/app/components/base/agent-log-modal/detail.tsx +++ b/web/app/components/base/agent-log-modal/detail.tsx @@ -10,7 +10,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import { cn } from '@/utils/classnames' import ResultPanel from './result' diff --git a/web/app/components/base/chat/chat-with-history/context.tsx b/web/app/components/base/chat/chat-with-history/context.ts similarity index 100% rename from web/app/components/base/chat/chat-with-history/context.tsx rename to web/app/components/base/chat/chat-with-history/context.ts diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index da344a9789..23936111ce 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -23,7 +23,7 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import { getProcessedFilesFromResponse } from 
'@/app/components/base/file-uploader/utils' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' diff --git a/web/app/components/base/chat/chat/__tests__/context.spec.tsx b/web/app/components/base/chat/chat/__tests__/context.spec.tsx index fd00156e59..aeba073a7b 100644 --- a/web/app/components/base/chat/chat/__tests__/context.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/context.spec.tsx @@ -3,7 +3,8 @@ import type { ChatContextValue } from '../context' import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ChatContextProvider, useChatContext } from '../context' +import { useChatContext } from '../context' +import { ChatContextProvider } from '../context-provider' const TestConsumer = () => { const context = useChatContext() diff --git a/web/app/components/base/chat/chat/__tests__/question.spec.tsx b/web/app/components/base/chat/chat/__tests__/question.spec.tsx index 1d0584805b..7c717b6e31 100644 --- a/web/app/components/base/chat/chat/__tests__/question.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/question.spec.tsx @@ -9,7 +9,7 @@ import { vi } from 'vitest' import Toast from '../../../toast' import { ThemeBuilder } from '../../embedded-chatbot/theme/theme-context' -import { ChatContextProvider } from '../context' +import { ChatContextProvider } from '../context-provider' import Question from '../question' // Global Mocks diff --git a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx index a6d4570fcb..03f3c673ce 100644 --- a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx @@ -91,15 +91,9 @@ vi.mock('@/app/components/base/features/hooks', () => ({ // --------------------------------------------------------------------------- // Toast context // --------------------------------------------------------------------------- -vi.mock('@/app/components/base/toast', async () => { - const actual = await vi.importActual( - '@/app/components/base/toast', - ) - return { - ...actual, - useToastContext: () => ({ notify: mockNotify }), - } -}) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify, close: vi.fn() }), +})) // --------------------------------------------------------------------------- // Internal layout hook – controls single/multi-line textarea mode diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 9de52cb18c..170cccaaca 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -22,7 +22,7 @@ import { FileContextProvider, useFileStore, } from '@/app/components/base/file-uploader/store' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import VoiceInput from '@/app/components/base/voice-input' import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' diff --git 
a/web/app/components/base/chat/chat/check-input-forms-hooks.ts b/web/app/components/base/chat/chat/check-input-forms-hooks.ts index 2da57b289e..842e89070b 100644 --- a/web/app/components/base/chat/chat/check-input-forms-hooks.ts +++ b/web/app/components/base/chat/chat/check-input-forms-hooks.ts @@ -1,7 +1,7 @@ import type { InputForm } from './type' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { InputVarType } from '@/app/components/workflow/types' import { TransferMethod } from '@/types/app' diff --git a/web/app/components/base/chat/chat/context.tsx b/web/app/components/base/chat/chat/context-provider.tsx similarity index 58% rename from web/app/components/base/chat/chat/context.tsx rename to web/app/components/base/chat/chat/context-provider.tsx index 7843665ad7..02503521e5 100644 --- a/web/app/components/base/chat/chat/context.tsx +++ b/web/app/components/base/chat/chat/context-provider.tsx @@ -1,30 +1,8 @@ 'use client' import type { ReactNode } from 'react' -import type { ChatProps } from './index' -import { createContext, useContext } from 'use-context-selector' - -export type ChatContextValue = Pick & { - readonly?: boolean - } - -const ChatContext = createContext<ChatContextValue>({ - chatList: [], - readonly: false, -}) +import type { ChatContextValue } from './context' +import { ChatContext } from './context' type ChatContextProviderProps = { children: ReactNode @@ -71,7 +49,3 @@ export const ChatContextProvider = ({ </ChatContext.Provider> ) } - -export const useChatContext = () => useContext(ChatContext) - -export default ChatContext diff --git a/web/app/components/base/chat/chat/context.ts b/web/app/components/base/chat/chat/context.ts new file mode 100644 index 0000000000..ff0bd26336 --- /dev/null +++ b/web/app/components/base/chat/chat/context.ts @@ -0,0 +1,30 @@ +'use client' + +import type { ChatProps } from './index' +import { createContext, useContext } from 'use-context-selector' + +export type ChatContextValue = Pick & { + readonly?: boolean + } + +export const ChatContext = createContext<ChatContextValue>({ + chatList: [], + readonly: false, +}) + +export const useChatContext = () => useContext(ChatContext) + +export default ChatContext diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 3e01fa5ade..4828ef2a47 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -30,7 +30,7 @@ import { getProcessedFiles, getProcessedFilesFromResponse, } from '@/app/components/base/file-uploader/utils' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types' import useTimestamp from '@/hooks/use-timestamp' import { diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index a77911d895..69c064e3e2 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -30,7 +30,7 @@ import PromptLogModal from '@/app/components/base/prompt-log-modal' import { cn } from '@/utils/classnames' import Answer from './answer' import ChatInputArea from './chat-input-area' -import { ChatContextProvider } from './context' +import { ChatContextProvider } from './context-provider' import Question from './question' import
TryToAsk from './try-to-ask'
diff --git a/web/app/components/base/chat/embedded-chatbot/context.tsx b/web/app/components/base/chat/embedded-chatbot/context.ts
similarity index 100%
rename from web/app/components/base/chat/embedded-chatbot/context.tsx
rename to web/app/components/base/chat/embedded-chatbot/context.ts
diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx
index bffee78792..da142a69ec 100644
--- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx
+++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx
@@ -21,7 +21,7 @@ import {
   useState,
 } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils'
 import { InputVarType } from '@/app/components/workflow/types'
 import { useWebAppStore } from '@/context/web-app-context'
diff --git a/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx b/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx
index 08c9a035e7..aad2d3d09b 100644
--- a/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx
+++ b/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx
@@ -16,7 +16,7 @@ vi.mock('next/navigation', () => ({
   useSearchParams: () => new URLSearchParams(),
 }))
 
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: vi.fn() }),
 }))
 
diff --git a/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx b/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx
index 3c690635da..88f74d2686 100644
--- a/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx
+++ b/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx
@@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react'
 import ModerationSettingModal from '../moderation-setting-modal'
 
 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: mockNotify }),
 }))
 
diff --git a/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx b/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx
index c68abfd7b1..4c0682d182 100644
--- a/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx
+++ b/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx
@@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next'
 import Button from '@/app/components/base/button'
 import Divider from '@/app/components/base/divider'
 import Modal from '@/app/components/base/modal'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import ApiBasedExtensionSelector from '@/app/components/header/account-setting/api-based-extension-page/selector'
 import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
 import { CustomConfigurationStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
diff --git a/web/app/components/base/file-uploader/__tests__/hooks.spec.ts b/web/app/components/base/file-uploader/__tests__/hooks.spec.ts
index 00c64224aa..8343974967 100644
--- a/web/app/components/base/file-uploader/__tests__/hooks.spec.ts
+++ b/web/app/components/base/file-uploader/__tests__/hooks.spec.ts
@@ -11,7 +11,7 @@ vi.mock('next/navigation', () => ({
 }))
 
 // Exception: hook requires toast context that isn't available without a provider wrapper
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts
index 14e46548d8..4aab60175c 100644
--- a/web/app/components/base/file-uploader/hooks.ts
+++ b/web/app/components/base/file-uploader/hooks.ts
@@ -18,7 +18,7 @@ import {
   MAX_FILE_UPLOAD_LIMIT,
   VIDEO_SIZE_LIMIT,
 } from '@/app/components/base/file-uploader/constants'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { SupportUploadFileTypes } from '@/app/components/workflow/types'
 import { uploadRemoteFileInfo } from '@/service/common'
 import { TransferMethod } from '@/types/app'
diff --git a/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts b/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts
index b9d60594f7..28eb5bd5ed 100644
--- a/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts
+++ b/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts
@@ -5,7 +5,7 @@ import { useCheckValidated } from '../use-check-validated'
 
 const mockNotify = vi.fn()
 
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/base/form/hooks/use-check-validated.ts b/web/app/components/base/form/hooks/use-check-validated.ts
index 7ed6164bb2..d186996035 100644
--- a/web/app/components/base/form/hooks/use-check-validated.ts
+++ b/web/app/components/base/form/hooks/use-check-validated.ts
@@ -1,7 +1,7 @@
 import type { AnyFormApi } from '@tanstack/react-form'
 import type { FormSchema } from '@/app/components/base/form/types'
 import { useCallback } from 'react'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 
 export const useCheckValidated = (form: AnyFormApi, FormSchemas: FormSchema[]) => {
   const { notify } = useToastContext()
diff --git a/web/app/components/base/image-uploader/__tests__/hooks.spec.ts b/web/app/components/base/image-uploader/__tests__/hooks.spec.ts
index 4d150830d0..f79ea98081 100644
--- a/web/app/components/base/image-uploader/__tests__/hooks.spec.ts
+++ b/web/app/components/base/image-uploader/__tests__/hooks.spec.ts
@@ -5,7 +5,7 @@ import { Resolution, TransferMethod } from '@/types/app'
 import { useClipboardUploader, useDraggableUploader, useImageFiles, useLocalFileUploader } from '../hooks'
 
 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: mockNotify }),
 }))
 
diff --git a/web/app/components/base/image-uploader/hooks.ts b/web/app/components/base/image-uploader/hooks.ts
index cd309a1f7b..03cf0feeca 100644
--- a/web/app/components/base/image-uploader/hooks.ts
+++ b/web/app/components/base/image-uploader/hooks.ts
@@ -3,7 +3,7 @@ import type { ImageFile, VisionSettings } from '@/types/app'
 import { useParams } from 'next/navigation'
 import { useCallback, useMemo, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { ALLOW_FILE_EXTENSIONS, TransferMethod } from '@/types/app'
 import { getImageUploadErrorMessage, imageUpload } from './utils'
 
diff --git a/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx
index 305896f4f1..95ed788db3 100644
--- a/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx
+++ b/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx
@@ -5,7 +5,7 @@ import userEvent from '@testing-library/user-event'
 // markdown-button.spec.tsx
 import * as React from 'react'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
-import { ChatContextProvider } from '@/app/components/base/chat/chat/context'
+import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider'
 import MarkdownButton from '../button'
 
diff --git a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx
index 2cd31f9a49..e8b956cbbf 100644
--- a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx
+++ b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx
@@ -1,6 +1,6 @@
 import { act, render, screen } from '@testing-library/react'
 import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
-import { ChatContextProvider } from '@/app/components/base/chat/chat/context'
+import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider'
 import ThinkBlock from '../think-block'
 
 // Mock react-i18next
diff --git a/web/app/components/base/markdown-blocks/think-block.stories.tsx b/web/app/components/base/markdown-blocks/think-block.stories.tsx
index 23713fb263..7c3f809ee7 100644
--- a/web/app/components/base/markdown-blocks/think-block.stories.tsx
+++ b/web/app/components/base/markdown-blocks/think-block.stories.tsx
@@ -1,6 +1,6 @@
 import type { Meta, StoryObj } from '@storybook/nextjs-vite'
 import { useState } from 'react'
-import { ChatContextProvider } from '@/app/components/base/chat/chat/context'
+import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider'
 import ThinkBlock from './think-block'
 
 const THOUGHT_TEXT = `
diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx
index 2deec561e9..6cc6c3a67f 100644
--- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx
+++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx
@@ -28,7 +28,8 @@ import {
 import * as React from 'react'
 import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types'
 import { VarType } from '@/app/components/workflow/types'
-import { EventEmitterContextProvider, useEventEmitterContextContext } from '@/context/event-emitter'
+import { useEventEmitterContextContext } from '@/context/event-emitter'
+import { EventEmitterContextProvider } from '@/context/event-emitter-provider'
 import { INSERT_CONTEXT_BLOCK_COMMAND } from '../../context-block'
 import { INSERT_CURRENT_BLOCK_COMMAND } from '../../current-block'
 import { INSERT_ERROR_MESSAGE_BLOCK_COMMAND } from '../../error-message-block'
diff --git a/web/app/components/base/radio/context/index.tsx b/web/app/components/base/radio/context/index.ts
similarity index 100%
rename from web/app/components/base/radio/context/index.tsx
rename to web/app/components/base/radio/context/index.ts
diff --git a/web/app/components/base/tag-input/__tests__/index.spec.tsx b/web/app/components/base/tag-input/__tests__/index.spec.tsx
index b091d9cd03..f07399f8af 100644
--- a/web/app/components/base/tag-input/__tests__/index.spec.tsx
+++ b/web/app/components/base/tag-input/__tests__/index.spec.tsx
@@ -5,7 +5,7 @@ import TagInput from '../index'
 
 const mockNotify = vi.fn()
 
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/base/tag-input/index.tsx b/web/app/components/base/tag-input/index.tsx
index 1c49b026fb..377e68abe4 100644
--- a/web/app/components/base/tag-input/index.tsx
+++ b/web/app/components/base/tag-input/index.tsx
@@ -2,7 +2,7 @@ import type { ChangeEvent, FC, KeyboardEvent } from 'react'
 import { useCallback, useState } from 'react'
 import AutosizeInput from 'react-18-input-autosize'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { cn } from '@/utils/classnames'
 
 type TagInputProps = {
diff --git a/web/app/components/base/tag-management/__tests__/panel.spec.tsx b/web/app/components/base/tag-management/__tests__/panel.spec.tsx
index c91c72e583..cd9e37e286 100644
--- a/web/app/components/base/tag-management/__tests__/panel.spec.tsx
+++ b/web/app/components/base/tag-management/__tests__/panel.spec.tsx
@@ -3,7 +3,7 @@ import { render, screen, waitFor, within } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import * as React from 'react'
 import { act } from 'react'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import Panel from '../panel'
 import { useStore as useTagStore } from '../store'
 
diff --git a/web/app/components/base/tag-management/__tests__/selector.spec.tsx b/web/app/components/base/tag-management/__tests__/selector.spec.tsx
index dc58ca37e6..43f17a1e8c 100644
--- a/web/app/components/base/tag-management/__tests__/selector.spec.tsx
+++ b/web/app/components/base/tag-management/__tests__/selector.spec.tsx
@@ -3,7 +3,7 @@ import { render, screen, waitFor, within } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import * as React from 'react'
 import { act } from 'react'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import TagSelector from '../selector'
 import { useStore as useTagStore } from '../store'
 
diff --git a/web/app/components/base/tag-management/index.tsx b/web/app/components/base/tag-management/index.tsx
index e9ce85ecc0..b7682bcdad 100644
--- a/web/app/components/base/tag-management/index.tsx
+++ b/web/app/components/base/tag-management/index.tsx
@@ -4,7 +4,7 @@ import { useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import Modal from '@/app/components/base/modal'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import {
   createTag,
   fetchTagList,
diff --git a/web/app/components/base/tag-management/panel.tsx b/web/app/components/base/tag-management/panel.tsx
index 4174ba0476..cebed74f3b 100644
--- a/web/app/components/base/tag-management/panel.tsx
+++ b/web/app/components/base/tag-management/panel.tsx
@@ -10,7 +10,7 @@ import { useContext } from 'use-context-selector'
 import Checkbox from '@/app/components/base/checkbox'
 import Divider from '@/app/components/base/divider'
 import Input from '@/app/components/base/input'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { bindTag, createTag, unBindTag } from '@/service/tag'
 import { useStore as useTagStore } from './store'
 
diff --git a/web/app/components/base/tag-management/tag-item-editor.tsx b/web/app/components/base/tag-management/tag-item-editor.tsx
index a37e42dcce..3cff335f58 100644
--- a/web/app/components/base/tag-management/tag-item-editor.tsx
+++ b/web/app/components/base/tag-management/tag-item-editor.tsx
@@ -6,7 +6,7 @@ import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import Confirm from '@/app/components/base/confirm'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import Tooltip from '@/app/components/base/tooltip'
 import {
   deleteTag,
diff --git a/web/app/components/base/text-generation/__tests__/hooks.spec.ts b/web/app/components/base/text-generation/__tests__/hooks.spec.ts
index cab06f1c8a..a5b5578158 100644
--- a/web/app/components/base/text-generation/__tests__/hooks.spec.ts
+++ b/web/app/components/base/text-generation/__tests__/hooks.spec.ts
@@ -5,7 +5,7 @@ import { useTextGeneration } from '../hooks'
 
 const mockNotify = vi.fn()
 const mockSsePost = vi.fn<(url: string, fetchOptions: { body: Record }, otherOptions: IOtherOptions) => void>()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/base/text-generation/hooks.ts b/web/app/components/base/text-generation/hooks.ts
index c5d008956b..4314a81925 100644
--- a/web/app/components/base/text-generation/hooks.ts
+++ b/web/app/components/base/text-generation/hooks.ts
@@ -1,6 +1,6 @@
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { ssePost } from '@/service/base'
 
 export const useTextGeneration = () => {
diff --git a/web/app/components/base/toast/__tests__/index.spec.tsx b/web/app/components/base/toast/__tests__/index.spec.tsx
index f526290fa1..2f5fa49823 100644
--- a/web/app/components/base/toast/__tests__/index.spec.tsx
+++ b/web/app/components/base/toast/__tests__/index.spec.tsx
@@ -2,7 +2,8 @@ import type { ReactNode } from 'react'
 import { act, render, screen, waitFor } from '@testing-library/react'
 import { noop } from 'es-toolkit/function'
 import * as React from 'react'
-import Toast, { ToastProvider, useToastContext } from '..'
+import Toast, { ToastProvider } from '..'
+import { useToastContext } from '../context'
 
 const TestComponent = () => {
   const { notify, close } = useToastContext()
diff --git a/web/app/components/base/toast/context.ts b/web/app/components/base/toast/context.ts
new file mode 100644
index 0000000000..ddd8f91336
--- /dev/null
+++ b/web/app/components/base/toast/context.ts
@@ -0,0 +1,23 @@
+'use client'
+
+import type { ReactNode } from 'react'
+import { createContext, useContext } from 'use-context-selector'
+
+export type IToastProps = {
+  type?: 'success' | 'error' | 'warning' | 'info'
+  size?: 'md' | 'sm'
+  duration?: number
+  message: string
+  children?: ReactNode
+  onClose?: () => void
+  className?: string
+  customComponent?: ReactNode
+}
+
+type IToastContext = {
+  notify: (props: IToastProps) => void
+  close: () => void
+}
+
+export const ToastContext = createContext({} as IToastContext)
+export const useToastContext = () => useContext(ToastContext)
diff --git a/web/app/components/base/toast/index.stories.tsx b/web/app/components/base/toast/index.stories.tsx
index 4ab9138070..40d6fecfc2 100644
--- a/web/app/components/base/toast/index.stories.tsx
+++ b/web/app/components/base/toast/index.stories.tsx
@@ -1,6 +1,7 @@
 import type { Meta, StoryObj } from '@storybook/nextjs-vite'
 import { useCallback } from 'react'
-import Toast, { ToastProvider, useToastContext } from '.'
+import Toast, { ToastProvider } from '.'
+import { useToastContext } from './context'
 
 const ToastControls = () => {
   const { notify } = useToastContext()
diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx
index 1b9ae4eedb..a70a0db06c 100644
--- a/web/app/components/base/toast/index.tsx
+++ b/web/app/components/base/toast/index.tsx
@@ -1,5 +1,6 @@
 'use client'
 import type { ReactNode } from 'react'
+import type { IToastProps } from './context'
 import {
   RiAlertFill,
   RiCheckboxCircleFill,
@@ -11,31 +12,13 @@ import { noop } from 'es-toolkit/function'
 import * as React from 'react'
 import { useEffect, useState } from 'react'
 import { createRoot } from 'react-dom/client'
-import { createContext, useContext } from 'use-context-selector'
 import ActionButton from '@/app/components/base/action-button'
 import { cn } from '@/utils/classnames'
-
-export type IToastProps = {
-  type?: 'success' | 'error' | 'warning' | 'info'
-  size?: 'md' | 'sm'
-  duration?: number
-  message: string
-  children?: ReactNode
-  onClose?: () => void
-  className?: string
-  customComponent?: ReactNode
-}
-type IToastContext = {
-  notify: (props: IToastProps) => void
-  close: () => void
-}
+import { ToastContext, useToastContext } from './context'
 
 export type ToastHandle = {
   clear?: VoidFunction
 }
-
-export const ToastContext = createContext({} as IToastContext)
-export const useToastContext = () => useContext(ToastContext)
 const Toast = ({
   type = 'info',
   size = 'md',
@@ -77,11 +60,11 @@

-          {message}
+          {message}
           {customComponent}
           {!!children && (
-
+
             {children}
           )}
@@ -183,3 +166,5 @@ Toast.notify = ({
 }
 
 export default Toast
+
+export type { IToastProps } from './context'
diff --git a/web/app/components/custom/custom-web-app-brand/__tests__/index.spec.tsx b/web/app/components/custom/custom-web-app-brand/__tests__/index.spec.tsx
index 2ceb45235c..1d17a2ae0f 100644
--- a/web/app/components/custom/custom-web-app-brand/__tests__/index.spec.tsx
+++ b/web/app/components/custom/custom-web-app-brand/__tests__/index.spec.tsx
@@ -1,7 +1,7 @@
 import { fireEvent, render, screen, waitFor } from '@testing-library/react'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
 import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { Plan } from '@/app/components/billing/type'
 import { useAppContext } from '@/context/app-context'
 import { useGlobalPublicStore } from '@/context/global-public-context'
@@ -9,7 +9,7 @@ import { useProviderContext } from '@/context/provider-context'
 import { updateCurrentWorkspace } from '@/service/common'
 import CustomWebAppBrand from '../index'
 
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: vi.fn(),
 }))
 vi.mock('@/service/common', () => ({
diff --git a/web/app/components/custom/custom-web-app-brand/index.tsx b/web/app/components/custom/custom-web-app-brand/index.tsx
index 438e69894d..e6f9a3837b 100644
--- a/web/app/components/custom/custom-web-app-brand/index.tsx
+++ b/web/app/components/custom/custom-web-app-brand/index.tsx
@@ -16,7 +16,7 @@ import { BubbleTextMod } from '@/app/components/base/icons/src/vender/solid/comm
 import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
 import DifyLogo from '@/app/components/base/logo/dify-logo'
 import Switch from '@/app/components/base/switch'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { Plan } from '@/app/components/billing/type'
 import { useAppContext } from '@/context/app-context'
 import { useGlobalPublicStore } from '@/context/global-public-context'
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
index db67c91bc0..e47a876fe8 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
@@ -4,7 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
 import Uploader from '../uploader'
 
 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   ToastContext: {
     Provider: ({ children }: { children: React.ReactNode }) => children,
     Consumer: ({ children }: { children: (value: { notify: typeof mockNotify }) => React.ReactNode }) => children({ notify: mockNotify }),
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
index 87e55ea740..c839fad3a2 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
@@ -4,7 +4,7 @@ import { useRouter } from 'next/navigation'
 import { useCallback, useMemo, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
 import {
   DSLImportMode,
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx
index 3fa940c60d..faf168c73a 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx
@@ -10,7 +10,7 @@ import { useEffect, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import ActionButton from '@/app/components/base/action-button'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { cn } from '@/utils/classnames'
 import { formatFileSize } from '@/utils/format'
 
diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx
index a9f55c8b84..0a4064de2a 100644
--- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx
+++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx
@@ -8,7 +8,7 @@ import { trackEvent } from '@/app/components/base/amplitude'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Modal from '@/app/components/base/modal'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { createEmptyDataset } from '@/service/datasets'
 import { useInvalidDatasetList } from '@/service/knowledge/use-dataset'
 
diff --git a/web/app/components/datasets/create/file-uploader/hooks/__tests__/use-file-upload.spec.tsx b/web/app/components/datasets/create/file-uploader/hooks/__tests__/use-file-upload.spec.tsx
index b5d1a96554..80331afe2a 100644
--- a/web/app/components/datasets/create/file-uploader/hooks/__tests__/use-file-upload.spec.tsx
+++ b/web/app/components/datasets/create/file-uploader/hooks/__tests__/use-file-upload.spec.tsx
@@ -2,7 +2,7 @@ import type { ReactNode } from 'react'
 import type { CustomFile, FileItem } from '@/models/datasets'
 import { act, render, renderHook, waitFor } from '@testing-library/react'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../../constants'
 
 // Import after mocks
diff --git a/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts
index e097bab755..ada60fac1a 100644
--- a/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts
+++ b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts
@@ -5,7 +5,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { IS_CE_EDITION } from '@/config'
 import { useLocale } from '@/context/i18n'
 import { LanguagesSupported } from '@/i18n-config/language'
diff --git a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx
index 5aae8dda73..b988b1aeab 100644
--- a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx
+++ b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx
@@ -10,7 +10,7 @@ vi.mock('next/navigation', () => ({
 }))
 
 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   ToastContext: {
     Provider: ({ children }: { children: React.ReactNode }) => children,
   },
diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx
index 15c89a9b26..84e16c7c48 100644
--- a/web/app/components/datasets/documents/components/operations.tsx
+++ b/web/app/components/datasets/documents/components/operations.tsx
@@ -24,7 +24,7 @@ import Divider from '@/app/components/base/divider'
 import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
 import CustomPopover from '@/app/components/base/popover'
 import Switch from '@/app/components/base/switch'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import Tooltip from '@/app/components/base/tooltip'
 import { IS_CE_EDITION } from '@/config'
 import { DataSourceType, DocumentActionType } from '@/models/datasets'
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx
index bc9ce04beb..efd1f2a483 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx
@@ -9,7 +9,7 @@ const mockNotify = vi.fn()
 const mockClose = vi.fn()
 
 // Mock ToastContext with factory function
-vi.mock('@/app/components/base/toast', async () => {
+vi.mock('@/app/components/base/toast/context', async () => {
   const { createContext, useContext } = await import('use-context-selector')
   const context = createContext({ notify: mockNotify, close: mockClose })
   return {
@@ -87,7 +87,7 @@ vi.mock('@/service/base', () => ({
 
 // Import after all mocks are set up
 const { useLocalFileUpload } = await import('../use-local-file-upload')
-const { ToastContext } = await import('@/app/components/base/toast')
+const { ToastContext } = await import('@/app/components/base/toast/context')
 
 const createWrapper = () => {
   return ({ children }: { children: ReactNode }) => (
diff --git a/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx b/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx
index 7fb1de7cf9..6876753714 100644
--- a/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx
+++ b/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx
@@ -25,7 +25,7 @@ vi.mock('@/hooks/use-theme', () => ({
 }))
 
 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   ToastContext: {
     Provider: ({ children }: { children: ReactNode }) => children,
     Consumer: ({ children }: { children: (ctx: { notify: typeof mockNotify }) => ReactNode }) => children({ notify: mockNotify }),
diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx
index f3a86e910d..63efc766b7 100644
--- a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx
+++ b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx
@@ -12,7 +12,7 @@ import Button from '@/app/components/base/button'
 import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils'
 import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files'
 import SimplePieChart from '@/app/components/base/simple-pie-chart'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import useTheme from '@/hooks/use-theme'
 import { upload } from '@/service/base'
 import { useFileUploadConfig } from '@/service/use-common'
diff --git a/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx b/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx
index 5802fb8b82..59ecbf5f25 100644
--- a/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx
@@ -65,7 +65,7 @@ vi.mock('../../context', () => ({
   },
 }))
 
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   ToastContext: { Provider: ({ children }: { children: React.ReactNode }) => children, Consumer: () => null },
   useToastContext: () => ({ notify: mockNotify }),
 }))
diff --git a/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx b/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx
index 719e2867b7..cc7f1aafa4 100644
--- a/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx
+++ b/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx
@@ -1,7 +1,8 @@
 import type { ReactNode } from 'react'
 import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
-import { EventEmitterContextProvider, useEventEmitterContextContext } from '@/context/event-emitter'
+import { useEventEmitterContextContext } from '@/context/event-emitter'
+import { EventEmitterContextProvider } from '@/context/event-emitter-provider'
 import RegenerationModal from '../regeneration-modal'
 
 // Store emit function for triggering events in tests
diff --git a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts
index 83918a3f30..4cfb4d5927 100644
--- a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts
+++ b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts
@@ -59,7 +59,7 @@ vi.mock('../../../context', () => ({
   },
 }))
 
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: mockNotify }),
 }))
 
diff --git a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts
index aef2053298..f54c00e3e7 100644
--- a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts
+++ b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts
@@ -92,7 +92,7 @@ vi.mock('../../../context', () => ({
   },
 }))
 
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: mockNotify }),
 }))
 
diff --git a/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts b/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts
index 4f4c6a532d..cdc8a0b22d 100644
--- a/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts
+++ b/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts
@@ -2,7 +2,7 @@ import type { ChildChunkDetail, ChildSegmentsResponse, SegmentDetailModel, Segme
 import { useQueryClient } from '@tanstack/react-query'
 import { useCallback, useEffect, useMemo, useRef } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { useEventEmitterContextContext } from '@/context/event-emitter'
 import {
   useChildSegmentList,
diff --git a/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts b/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts
index fd391d2864..aa91e9f464 100644
--- a/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts
+++ b/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts
@@ -4,7 +4,7 @@ import { useQueryClient } from '@tanstack/react-query'
 import { usePathname } from 'next/navigation'
 import { useCallback, useEffect, useMemo, useRef } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { useEventEmitterContextContext } from '@/context/event-emitter'
 import { ChunkingMode } from '@/models/datasets'
 import {
diff --git a/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx b/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx
index 89143662c6..7a6ae4b306 100644
--- a/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx
+++ b/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx
@@ -8,7 +8,7 @@ import { useContext } from 'use-context-selector'
 import { useShallow } from 'zustand/react/shallow'
 import { useStore as useAppStore } from '@/app/components/app/store'
 import Divider from '@/app/components/base/divider'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { ChunkingMode } from '@/models/datasets'
 import { useAddChildSegment } from '@/service/knowledge/use-segment'
 import { cn } from '@/utils/classnames'
diff --git a/web/app/components/datasets/documents/detail/embedding/index.tsx b/web/app/components/datasets/documents/detail/embedding/index.tsx
index e89a85c6de..bd344800db 100644
--- a/web/app/components/datasets/documents/detail/embedding/index.tsx
+++ b/web/app/components/datasets/documents/detail/embedding/index.tsx
@@ -5,7 +5,7 @@ import * as React from 'react'
 import { useCallback } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { useProcessRule } from '@/service/knowledge/use-dataset'
 import { useDocumentContext } from '../context'
 import { ProgressBar, RuleDetail, SegmentProgress, StatusHeader } from './components'
diff --git a/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts b/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts
index ab1d45338f..3d7b28c78c 100644
--- a/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts
+++ b/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts
@@ -3,7 +3,7 @@ import type { FullDocumentDetail } from '@/models/datasets'
 import { act, renderHook } from '@testing-library/react'
 import * as React from 'react'
 import { describe, expect, it, vi } from 'vitest'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 
 import { useMetadataState } from '../use-metadata-state'
 
diff --git a/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts b/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts
index 08651b699e..f786609981 100644
--- a/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts
+++ b/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts
@@ -4,7 +4,7 @@ import type { DocType, FullDocumentDetail } from '@/models/datasets'
 import { useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { modifyDocMetadata } from '@/service/datasets'
 import { asyncRunSafe } from '@/utils'
 import { useDocumentContext } from '../../context'
diff --git a/web/app/components/datasets/documents/detail/new-segment.tsx b/web/app/components/datasets/documents/detail/new-segment.tsx
index 3a58d6ac06..f32c94bf70 100644
--- a/web/app/components/datasets/documents/detail/new-segment.tsx
+++ b/web/app/components/datasets/documents/detail/new-segment.tsx
@@ -9,7 +9,7 @@ import { useContext } from 'use-context-selector'
 import { useShallow } from 'zustand/react/shallow'
 import { useStore as useAppStore } from '@/app/components/app/store'
 import Divider from '@/app/components/base/divider'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import ImageUploaderInChunk from '@/app/components/datasets/common/image-uploader/image-uploader-in-chunk'
 import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
 import { ChunkingMode } from '@/models/datasets'
diff --git a/web/app/components/datasets/documents/status-item/index.tsx b/web/app/components/datasets/documents/status-item/index.tsx
index 60d837fd81..8d3abed7cf 100644
--- a/web/app/components/datasets/documents/status-item/index.tsx
+++ b/web/app/components/datasets/documents/status-item/index.tsx
@@ -8,7 +8,7 @@ import { useMemo } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import Switch from '@/app/components/base/switch'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import Tooltip from '@/app/components/base/tooltip'
 import Indicator from '@/app/components/header/indicator'
 import { useDocumentDelete, useDocumentDisable, useDocumentEnable } from '@/service/knowledge/use-document'
diff --git a/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx b/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx
index a631de3ea0..66d9a163be 100644
--- a/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx
@@ -12,7 +12,7 @@ vi.mock('@/service/datasets', () => ({
 }))
 
 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/datasets/external-api/external-api-modal/index.tsx b/web/app/components/datasets/external-api/external-api-modal/index.tsx
index a9a87d11bd..b6e870cdc1 100644
--- a/web/app/components/datasets/external-api/external-api-modal/index.tsx
+++ b/web/app/components/datasets/external-api/external-api-modal/index.tsx
@@ -19,7 +19,7 @@ import {
   PortalToFollowElem,
   PortalToFollowElemContent,
 } from '@/app/components/base/portal-to-follow-elem'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import Tooltip from '@/app/components/base/tooltip'
 import { createExternalAPI } from '@/service/datasets'
 import Form from './Form'
diff --git a/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx b/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx
index ccd637887b..a6a60aa856 100644
--- a/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx
@@ -22,7 +22,7 @@ vi.mock('@/context/i18n', () => ({
 }))
 
 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/datasets/external-knowledge-base/connector/index.tsx b/web/app/components/datasets/external-knowledge-base/connector/index.tsx
index 1545c0d232..cf36eed382 100644
--- a/web/app/components/datasets/external-knowledge-base/connector/index.tsx
+++ b/web/app/components/datasets/external-knowledge-base/connector/index.tsx
@@ -5,7 +5,7 @@ import { useRouter } from 'next/navigation'
 import * as React from 'react'
 import { useState } from 'react'
 import { trackEvent } from '@/app/components/base/amplitude'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create'
 import { createExternalKnowledgeBase } from '@/service/datasets'
 
diff --git a/web/app/components/header/account-dropdown/workplace-selector/index.spec.tsx b/web/app/components/header/account-dropdown/workplace-selector/index.spec.tsx
index fc32b5f8df..20104b572c 100644
--- a/web/app/components/header/account-dropdown/workplace-selector/index.spec.tsx
+++ b/web/app/components/header/account-dropdown/workplace-selector/index.spec.tsx
@@ -1,7 +1,7 @@
 import type { ProviderContextState } from '@/context/provider-context'
 import type { IWorkspace } from '@/models/common'
 import { fireEvent, render, screen, waitFor } from '@testing-library/react'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { baseProviderContextValue, useProviderContext } from '@/context/provider-context'
 import { useWorkspacesContext } from '@/context/workspace-context'
 import { switchWorkspace } from '@/service/common'
diff --git a/web/app/components/header/account-dropdown/workplace-selector/index.tsx b/web/app/components/header/account-dropdown/workplace-selector/index.tsx
index 058935aa27..528686a26a 100644
--- a/web/app/components/header/account-dropdown/workplace-selector/index.tsx
+++ b/web/app/components/header/account-dropdown/workplace-selector/index.tsx
@@ -4,7 +4,7 @@ import { RiArrowDownSLine } from '@remixicon/react'
 import { Fragment } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import PlanBadge from '@/app/components/header/plan-badge'
 import { useWorkspacesContext } from '@/context/workspace-context'
 import { switchWorkspace } from '@/service/common'
diff --git a/web/app/components/header/account-setting/api-based-extension-page/modal.spec.tsx b/web/app/components/header/account-setting/api-based-extension-page/modal.spec.tsx
index 3903fbfcf3..884ee8df33 100644
--- a/web/app/components/header/account-setting/api-based-extension-page/modal.spec.tsx
+++ b/web/app/components/header/account-setting/api-based-extension-page/modal.spec.tsx
@@ -1,8 +1,8 @@
 import type { TFunction } from 'i18next'
-import type { IToastProps } from '@/app/components/base/toast'
+import type { IToastProps } from '@/app/components/base/toast/context'
 import { fireEvent, render as RTLRender, screen, waitFor } from '@testing-library/react'
 import * as reactI18next from 'react-i18next'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { useDocLink } from '@/context/i18n'
 import { addApiBasedExtension, updateApiBasedExtension } from '@/service/common'
 import ApiBasedExtensionModal from './modal'
diff --git a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx
index b04981bf3c..efe6c46dcc 100644
--- a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx
+++ b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx
@@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'
 import Button from '@/app/components/base/button'
 import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education'
 import Modal from '@/app/components/base/modal'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { useDocLink } from '@/context/i18n'
 import {
   addApiBasedExtension,
diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/language-page/index.tsx
index 2a0604421f..5751e88285 100644
--- a/web/app/components/header/account-setting/language-page/index.tsx
+++ b/web/app/components/header/account-setting/language-page/index.tsx
@@ -7,7 +7,7 @@ import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import { SimpleSelect } from '@/app/components/base/select'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { useAppContext } from '@/context/app-context'
 import { useLocale } from '@/context/i18n'
 import { setLocaleOnClient } from '@/i18n-config'
diff --git a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.spec.tsx b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.spec.tsx
index 46ce3f1992..ae0dd8cd4d 100644
--- a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.spec.tsx
+++ b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.spec.tsx
@@ -3,7 +3,7 @@ import type { ICurrentWorkspace } from '@/models/common'
 import { render, screen, waitFor } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import { vi } from 'vitest'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { useAppContext } from '@/context/app-context'
 import { updateWorkspaceInfo } from '@/service/common'
 import EditWorkspaceModal from './index'
diff --git a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx
index a702a83da9..1c3984b0b5 100644
--- a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx
+++ b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx
@@ -6,7 +6,7 @@ import { useContext } from 'use-context-selector'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Modal from '@/app/components/base/modal'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { useAppContext } from '@/context/app-context'
 import { updateWorkspaceInfo } from '@/service/common'
 import { cn } from '@/utils/classnames'
diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.spec.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.spec.tsx
ef55425ee0..82882c8be5 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.spec.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.spec.tsx @@ -2,7 +2,7 @@ import type { InvitationResponse } from '@/models/common' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useProviderContextSelector } from '@/context/provider-context' import { inviteMember } from '@/service/common' import InviteModal from './index' diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx index a8c0da40bf..8e4e47e0b8 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx @@ -9,7 +9,7 @@ import { ReactMultiEmail } from 'react-multi-email' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { useProviderContextSelector } from '@/context/provider-context' diff --git a/web/app/components/header/account-setting/members-page/operation/index.spec.tsx b/web/app/components/header/account-setting/members-page/operation/index.spec.tsx index 661b2fbc83..e5e7fac10f 100644 --- a/web/app/components/header/account-setting/members-page/operation/index.spec.tsx +++ b/web/app/components/header/account-setting/members-page/operation/index.spec.tsx @@ -2,7 +2,7 @@ import type { Member } from '@/models/common' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Operation from './index' const mockUpdateMemberRole = vi.fn() diff --git a/web/app/components/header/account-setting/members-page/operation/index.tsx b/web/app/components/header/account-setting/members-page/operation/index.tsx index 88c8e250ea..306a67a093 100644 --- a/web/app/components/header/account-setting/members-page/operation/index.tsx +++ b/web/app/components/header/account-setting/members-page/operation/index.tsx @@ -9,7 +9,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useProviderContext } from '@/context/provider-context' import { deleteMemberOrCancelInvitation, updateMemberRole } from '@/service/common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.spec.tsx b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.spec.tsx index 11a0a2db4a..4baa90a7fa 100644 --- a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.spec.tsx +++ 
b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.spec.tsx @@ -3,7 +3,7 @@ import type { ICurrentWorkspace } from '@/models/common' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import { ownershipTransfer, sendOwnerEmail, verifyOwnerEmail } from '@/service/common' import { useMembers } from '@/service/use-common' diff --git a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx index 21ea8aa1e9..c4f614737a 100644 --- a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx @@ -6,7 +6,7 @@ import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import { ownershipTransfer, diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.spec.tsx index c2259f543c..b637fed894 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.spec.tsx @@ -20,7 +20,7 @@ const mockAddModelCredential = vi.fn() const mockEditProviderCredential = vi.fn() const mockEditModelCredential = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts index 3576c749b2..4e01677b29 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts @@ -12,7 +12,7 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useModelModalHandler, useRefreshModel, diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx index 554efc93d2..9f493d25e5 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx @@ -24,7 +24,7 @@ vi.mock('@/config', async (importOriginal) => { } }) 
-vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx index c46f9d56bd..efa768e7f5 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx @@ -3,7 +3,7 @@ import type { } from '../declarations' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth' import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' import Indicator from '@/app/components/header/indicator' diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.spec.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.spec.tsx index ea78234612..b945b50e9b 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.spec.tsx @@ -43,7 +43,7 @@ let mockCredentialData: CredentialData | undefined = { current_credential_name: 'Default', } -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx index 93229f1257..0009237edc 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx @@ -12,7 +12,7 @@ import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' import Loading from '@/app/components/base/loading' import Modal from '@/app/components/base/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { SwitchCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' import { useGetModelCredential, diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx index 819bb71164..22186b34e1 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx @@ -42,7 +42,7 @@ vi.mock('@/context/provider-context', () => ({ }), })) 
-vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx index 29c71e04fc..5df062789b 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx @@ -12,7 +12,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' diff --git a/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.spec.tsx b/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.spec.tsx index de480d47a1..97a79815ff 100644 --- a/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.spec.tsx +++ b/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.spec.tsx @@ -1,6 +1,6 @@ import type { PluginProvider } from '@/models/common' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import SerpapiPlugin from './SerpapiPlugin' import { updatePluginKey, validatePluginKey } from './utils' @@ -20,7 +20,7 @@ const mockEventEmitter = vi.hoisted(() => { } }) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: vi.fn(), })) diff --git a/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx b/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx index f6909fad28..fe8832e84b 100644 --- a/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx +++ b/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx @@ -2,7 +2,7 @@ import type { Form, ValidateValue } from '../key-validator/declarations' import type { PluginProvider } from '@/models/common' import Image from 'next/image' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import SerpapiLogo from '../../assets/serpapi.png' import KeyValidator from '../key-validator' diff --git a/web/app/components/header/account-setting/plugin-page/index.spec.tsx b/web/app/components/header/account-setting/plugin-page/index.spec.tsx index 654292443f..0654bb68aa 100644 --- a/web/app/components/header/account-setting/plugin-page/index.spec.tsx +++ b/web/app/components/header/account-setting/plugin-page/index.spec.tsx @@ -14,7 +14,7 @@ vi.mock('@/context/app-context', () => ({ useAppContext: vi.fn(), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: vi.fn(), }), diff --git a/web/app/components/header/index.spec.tsx 
index 2d5ce50fb8..36c85e6f08 100644
--- a/web/app/components/header/index.spec.tsx
+++ b/web/app/components/header/index.spec.tsx
@@ -52,7 +52,7 @@ vi.mock('@/app/components/header/plan-badge', () => ({
   ),
 }))

-vi.mock('@/context/workspace-context', () => ({
+vi.mock('@/context/workspace-context-provider', () => ({
   WorkspaceProvider: ({ children }: { children?: React.ReactNode }) => children,
 }))

diff --git a/web/app/components/header/index.tsx b/web/app/components/header/index.tsx
index 210c62b660..64498a82f8 100644
--- a/web/app/components/header/index.tsx
+++ b/web/app/components/header/index.tsx
@@ -8,7 +8,7 @@ import { useAppContext } from '@/context/app-context'
 import { useGlobalPublicStore } from '@/context/global-public-context'
 import { useModalContext } from '@/context/modal-context'
 import { useProviderContext } from '@/context/provider-context'
-import { WorkspaceProvider } from '@/context/workspace-context'
+import { WorkspaceProvider } from '@/context/workspace-context-provider'
 import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
 import { Plan } from '../billing/type'
 import AccountDropdown from './account-dropdown'
diff --git a/web/app/components/plugins/plugin-auth/__tests__/authorized-in-node.spec.tsx b/web/app/components/plugins/plugin-auth/__tests__/authorized-in-node.spec.tsx
index 7e8208b995..91ffbaa24a 100644
--- a/web/app/components/plugins/plugin-auth/__tests__/authorized-in-node.spec.tsx
+++ b/web/app/components/plugins/plugin-auth/__tests__/authorized-in-node.spec.tsx
@@ -42,7 +42,7 @@ vi.mock('@/context/app-context', () => ({
   }),
 }))

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: vi.fn() }),
 }))

diff --git a/web/app/components/plugins/plugin-auth/__tests__/plugin-auth-in-agent.spec.tsx b/web/app/components/plugins/plugin-auth/__tests__/plugin-auth-in-agent.spec.tsx
index 6b66aca9dd..901c3ab49a 100644
--- a/web/app/components/plugins/plugin-auth/__tests__/plugin-auth-in-agent.spec.tsx
+++ b/web/app/components/plugins/plugin-auth/__tests__/plugin-auth-in-agent.spec.tsx
@@ -42,7 +42,7 @@ vi.mock('@/context/app-context', () => ({
   }),
 }))

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: vi.fn() }),
 }))

diff --git a/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx b/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx
index a99b3363d6..430149e50b 100644
--- a/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx
+++ b/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx
@@ -9,7 +9,7 @@ const mockAddPluginCredential = vi.fn().mockResolvedValue({})
 const mockUpdatePluginCredential = vi.fn().mockResolvedValue({})
 const mockFormValues = { isCheckValidated: true, values: { __name__: 'My Key', api_key: 'sk-123' } }

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/plugins/plugin-auth/authorize/__tests__/authorize-components.spec.tsx b/web/app/components/plugins/plugin-auth/authorize/__tests__/authorize-components.spec.tsx
index 51aa287fea..d120902e6d 100644
--- a/web/app/components/plugins/plugin-auth/authorize/__tests__/authorize-components.spec.tsx
+++ b/web/app/components/plugins/plugin-auth/authorize/__tests__/authorize-components.spec.tsx
@@ -95,7 +95,7 @@ vi.mock('@/app/components/base/form/form-scenarios/auth', () => ({
 // Mock useToastContext
 const mockNotify = vi.fn()

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: mockNotify }),
 }))

diff --git a/web/app/components/plugins/plugin-auth/authorize/__tests__/oauth-client-settings.spec.tsx b/web/app/components/plugins/plugin-auth/authorize/__tests__/oauth-client-settings.spec.tsx
index 61920e2869..f1b86f80ea 100644
--- a/web/app/components/plugins/plugin-auth/authorize/__tests__/oauth-client-settings.spec.tsx
+++ b/web/app/components/plugins/plugin-auth/authorize/__tests__/oauth-client-settings.spec.tsx
@@ -9,7 +9,7 @@ const mockDeletePluginOAuthCustomClient = vi.fn().mockResolvedValue({})
 const mockInvalidPluginOAuthClientSchema = vi.fn()
 const mockFormValues = { isCheckValidated: true, values: { __oauth_client__: 'custom', client_id: 'test-id' } }

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx
index cc98ca3731..d8ab1cafdd 100644
--- a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx
+++ b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx
@@ -16,7 +16,7 @@ import AuthForm from '@/app/components/base/form/form-scenarios/auth'
 import { FormTypeEnum } from '@/app/components/base/form/types'
 import Loading from '@/app/components/base/loading'
 import Modal from '@/app/components/base/modal/modal'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { ReadmeEntrance } from '../../readme-panel/entrance'
 import { ReadmeShowType } from '../../readme-panel/store'
 import {
diff --git a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx
index 7eb22ee4ac..28989da77c 100644
--- a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx
+++ b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx
@@ -17,7 +17,7 @@ import { useTranslation } from 'react-i18next'
 import Button from '@/app/components/base/button'
 import AuthForm from '@/app/components/base/form/form-scenarios/auth'
 import Modal from '@/app/components/base/modal/modal'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { ReadmeEntrance } from '../../readme-panel/entrance'
 import { ReadmeShowType } from '../../readme-panel/store'
 import {
diff --git a/web/app/components/plugins/plugin-auth/authorized/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-auth/authorized/__tests__/index.spec.tsx
index f56c814222..a617c2543e 100644
--- a/web/app/components/plugins/plugin-auth/authorized/__tests__/index.spec.tsx
+++ b/web/app/components/plugins/plugin-auth/authorized/__tests__/index.spec.tsx
@@ -52,7 +52,7 @@ vi.mock('../../hooks/use-credential', () => ({
 // Mock toast context
 const mockNotify = vi.fn()

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/plugins/plugin-auth/authorized/index.tsx b/web/app/components/plugins/plugin-auth/authorized/index.tsx
index ea58cd16c9..daee3131ad 100644
--- a/web/app/components/plugins/plugin-auth/authorized/index.tsx
+++ b/web/app/components/plugins/plugin-auth/authorized/index.tsx
@@ -19,7 +19,7 @@ import {
   PortalToFollowElemContent,
   PortalToFollowElemTrigger,
 } from '@/app/components/base/portal-to-follow-elem'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import Indicator from '@/app/components/header/indicator'
 import { cn } from '@/utils/classnames'
 import Authorize from '../authorize'
diff --git a/web/app/components/plugins/plugin-auth/hooks/__tests__/use-plugin-auth-action.spec.ts b/web/app/components/plugins/plugin-auth/hooks/__tests__/use-plugin-auth-action.spec.ts
index d31b29ab85..f779623697 100644
--- a/web/app/components/plugins/plugin-auth/hooks/__tests__/use-plugin-auth-action.spec.ts
+++ b/web/app/components/plugins/plugin-auth/hooks/__tests__/use-plugin-auth-action.spec.ts
@@ -11,7 +11,7 @@ const mockSetPluginDefaultCredential = vi.fn().mockResolvedValue({})
 const mockUpdatePluginCredential = vi.fn().mockResolvedValue({})
 const mockNotify = vi.fn()

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts b/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts
index e9218e2d3d..5628c76cc3 100644
--- a/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts
+++ b/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts
@@ -5,7 +5,7 @@ import {
   useState,
 } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import {
   useDeletePluginCredentialHook,
   useSetPluginDefaultCredentialHook,
diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx
index af145df2da..e0fb7455ce 100644
--- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx
+++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx
@@ -54,13 +54,16 @@ vi.mock('@/app/components/base/toast', async (importOriginal) => {
     default: {
       notify: (args: { type: string, message: string }) => mockToast(args),
     },
-    useToastContext: () => ({
-      notify: (args: { type: string, message: string }) => mockToast(args),
-      close: vi.fn(),
-    }),
   }
 })

+vi.mock('@/app/components/base/toast/context', () => ({
+  useToastContext: () => ({
+    notify: (args: { type: string, message: string }) => mockToast(args),
+    close: vi.fn(),
+  }),
+}))
+
 const createSubscription = (overrides: Partial<TriggerSubscription> = {}): TriggerSubscription => ({
   id: 'sub-1',
   name: 'Subscription One',
diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx
index c6144542ab..60a8428287 100644
--- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx
+++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx
@@ -37,13 +37,16 @@ vi.mock('@/app/components/base/toast', async (importOriginal) => {
     default: {
       notify: (args: { type: string, message: string }) => mockToast(args),
     },
-    useToastContext: () => ({
-      notify: (args: { type: string, message: string }) => mockToast(args),
-      close: vi.fn(),
-    }),
   }
 })

+vi.mock('@/app/components/base/toast/context', () => ({
+  useToastContext: () => ({
+    notify: (args: { type: string, message: string }) => mockToast(args),
+    close: vi.fn(),
+  }),
+}))
+
 const createSubscription = (overrides: Partial<TriggerSubscription> = {}): TriggerSubscription => ({
   id: 'sub-1',
   name: 'Subscription One',
diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx
index 7bdcdbc936..8835b46695 100644
--- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx
+++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx
@@ -37,13 +37,16 @@ vi.mock('@/app/components/base/toast', async (importOriginal) => {
     default: {
       notify: (args: { type: string, message: string }) => mockToast(args),
     },
-    useToastContext: () => ({
-      notify: (args: { type: string, message: string }) => mockToast(args),
-      close: vi.fn(),
-    }),
   }
 })

+vi.mock('@/app/components/base/toast/context', () => ({
+  useToastContext: () => ({
+    notify: (args: { type: string, message: string }) => mockToast(args),
+    close: vi.fn(),
+  }),
+}))
+
 const createSubscription = (overrides: Partial<TriggerSubscription> = {}): TriggerSubscription => ({
   id: 'sub-1',
   name: 'Subscription One',
diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx
index 20655d0139..cb5b929d29 100644
--- a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx
+++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx
@@ -12,6 +12,9 @@ vi.mock('@/utils/classnames', () => ({

 vi.mock('@/app/components/base/toast', () => ({
   default: { notify: vi.fn() },
+}))
+
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: vi.fn() }),
 }))

diff --git a/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx
index ab0e35e042..389c161e8a 100644
--- a/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx
+++ b/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx
@@ -3,7 +3,8 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'

 // Import mocks
 import { useGlobalPublicStore } from '@/context/global-public-context'
-import { PluginPageContext, PluginPageContextProvider, usePluginPageContext } from '../context'
+import { PluginPageContext, usePluginPageContext } from '../context'
+import { PluginPageContextProvider } from '../context-provider'

 // Mock dependencies
 vi.mock('nuqs', () => ({
diff --git a/web/app/components/plugins/plugin-page/context.tsx b/web/app/components/plugins/plugin-page/context-provider.tsx
similarity index 58%
rename from web/app/components/plugins/plugin-page/context.tsx
rename to web/app/components/plugins/plugin-page/context-provider.tsx
index 01ec518347..83776b48f9 100644
--- a/web/app/components/plugins/plugin-page/context.tsx
+++ b/web/app/components/plugins/plugin-page/context-provider.tsx
@@ -1,24 +1,20 @@
 'use client'

-import type { ReactNode, RefObject } from 'react'
+import type { ReactNode } from 'react'
+import type { PluginPageTab } from './context'
 import type { FilterState } from './filter-management'
-import { noop } from 'es-toolkit/function'
 import { parseAsStringEnum, useQueryState } from 'nuqs'
 import {
   useMemo,
   useRef,
   useState,
 } from 'react'
-import {
-  createContext,
-  useContextSelector,
-} from 'use-context-selector'
 import { useGlobalPublicStore } from '@/context/global-public-context'
 import { PLUGIN_PAGE_TABS_MAP, usePluginPageTabs } from '../hooks'
 import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants'
-
-export type PluginPageTab = typeof PLUGIN_PAGE_TABS_MAP[keyof typeof PLUGIN_PAGE_TABS_MAP]
-  | (typeof PLUGIN_TYPE_SEARCH_MAP)[keyof typeof PLUGIN_TYPE_SEARCH_MAP]
+import {
+  PluginPageContext,
+} from './context'

 const PLUGIN_PAGE_TAB_VALUES: PluginPageTab[] = [
   PLUGIN_PAGE_TABS_MAP.plugins,
@@ -29,42 +25,10 @@ const PLUGIN_PAGE_TAB_VALUES: PluginPageTab[] = [
 const parseAsPluginPageTab = parseAsStringEnum(PLUGIN_PAGE_TAB_VALUES)
   .withDefault(PLUGIN_PAGE_TABS_MAP.plugins)

-export type PluginPageContextValue = {
-  containerRef: RefObject
-  currentPluginID: string | undefined
-  setCurrentPluginID: (pluginID?: string) => void
-  filters: FilterState
-  setFilters: (filter: FilterState) => void
-  activeTab: PluginPageTab
-  setActiveTab: (tab: PluginPageTab) => void
-  options: Array<{ value: string, text: string }>
-}
-
-const emptyContainerRef: RefObject = { current: null }
-
-export const PluginPageContext = createContext({
-  containerRef: emptyContainerRef,
-  currentPluginID: undefined,
-  setCurrentPluginID: noop,
-  filters: {
-    categories: [],
-    tags: [],
-    searchQuery: '',
-  },
-  setFilters: noop,
-  activeTab: PLUGIN_PAGE_TABS_MAP.plugins,
-  setActiveTab: noop,
-  options: [],
-})
-
 type PluginPageContextProviderProps = {
   children: ReactNode
 }

-export function usePluginPageContext(selector: (value: PluginPageContextValue) => any) {
-  return useContextSelector(PluginPageContext, selector)
-}
-
 export const PluginPageContextProvider = ({
   children,
 }: PluginPageContextProviderProps) => {
diff --git a/web/app/components/plugins/plugin-page/context.ts b/web/app/components/plugins/plugin-page/context.ts
new file mode 100644
index 0000000000..04ab1eef19
--- /dev/null
+++ b/web/app/components/plugins/plugin-page/context.ts
@@ -0,0 +1,46 @@
+'use client'
+
+import type { RefObject } from 'react'
+import type { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants'
+import type { FilterState } from './filter-management'
+import { noop } from 'es-toolkit/function'
+import {
+  createContext,
+  useContextSelector,
+} from 'use-context-selector'
+import { PLUGIN_PAGE_TABS_MAP } from '../hooks'
+
+export type PluginPageTab = typeof PLUGIN_PAGE_TABS_MAP[keyof typeof PLUGIN_PAGE_TABS_MAP]
+  | (typeof PLUGIN_TYPE_SEARCH_MAP)[keyof typeof PLUGIN_TYPE_SEARCH_MAP]
+
+export type PluginPageContextValue = {
+  containerRef: RefObject
+  currentPluginID: string | undefined
+  setCurrentPluginID: (pluginID?: string) => void
+  filters: FilterState
+  setFilters: (filter: FilterState) => void
+  activeTab: PluginPageTab
+  setActiveTab: (tab: PluginPageTab) => void
+  options: Array<{ value: string, text: string }>
+}
+
+const emptyContainerRef: RefObject = { current: null }
+
+export const PluginPageContext = createContext({
+  containerRef: emptyContainerRef,
+  currentPluginID: undefined,
+  setCurrentPluginID: noop,
+  filters: {
+    categories: [],
+    tags: [],
+    searchQuery: '',
+  },
+  setFilters: noop,
+  activeTab: PLUGIN_PAGE_TABS_MAP.plugins,
+  setActiveTab: noop,
+  options: [],
+})
+
+export function usePluginPageContext(selector: (value: PluginPageContextValue) => any) {
+  return useContextSelector(PluginPageContext, selector)
+}
diff --git a/web/app/components/plugins/plugin-page/index.tsx b/web/app/components/plugins/plugin-page/index.tsx
index bec1eb60ab..6768361acf 100644
--- a/web/app/components/plugins/plugin-page/index.tsx
+++ b/web/app/components/plugins/plugin-page/index.tsx
@@ -28,10 +28,8 @@ import { PLUGIN_PAGE_TABS_MAP } from '../hooks'
 import InstallFromLocalPackage from '../install-plugin/install-from-local-package'
 import InstallFromMarketplace from '../install-plugin/install-from-marketplace'
 import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants'
-import {
-  PluginPageContextProvider,
-  usePluginPageContext,
-} from './context'
+import { usePluginPageContext } from './context'
+import { PluginPageContextProvider } from './context-provider'
 import DebugInfo from './debug-info'
 import InstallPluginDropdown from './install-plugin-dropdown'
 import PluginTasks from './plugin-tasks'
diff --git a/web/app/components/rag-pipeline/components/__tests__/update-dsl-modal.spec.tsx b/web/app/components/rag-pipeline/components/__tests__/update-dsl-modal.spec.tsx
index 2f9b2172bd..98ad5f78f4 100644
--- a/web/app/components/rag-pipeline/components/__tests__/update-dsl-modal.spec.tsx
+++ b/web/app/components/rag-pipeline/components/__tests__/update-dsl-modal.spec.tsx
@@ -20,7 +20,7 @@ vi.mock('use-context-selector', () => ({
   useContext: () => ({ notify: mockNotify }),
 }))

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   ToastContext: { Provider: ({ children }: PropsWithChildren) => children },
 }))

diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/index.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/index.spec.tsx
index 0b858eaaa7..00c989acb0 100644
--- a/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/index.spec.tsx
+++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/index.spec.tsx
@@ -155,7 +155,7 @@ vi.mock('@/app/components/base/amplitude', () => ({
 }))

 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/index.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/index.spec.tsx
index 6129d3fe73..9cd1af2736 100644
--- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/index.spec.tsx
+++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/index.spec.tsx
@@ -3,7 +3,7 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
 import { fireEvent, render, screen, waitFor } from '@testing-library/react'
 import * as React from 'react'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import Publisher from '../index'
 import Popup from '../popup'

diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/popup.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/popup.spec.tsx
index 71707721a4..48282820d8 100644
--- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/popup.spec.tsx
+++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/popup.spec.tsx
@@ -57,7 +57,7 @@ vi.mock('@/app/components/workflow/store', () => ({
   }),
 }))

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: mockNotify }),
 }))

diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx
index 2dd56b4277..371e4fd721 100644
--- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx
+++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx
@@ -24,7 +24,7 @@ import Confirm from '@/app/components/base/confirm'
 import Divider from '@/app/components/base/divider'
 import { SparklesSoft } from '@/app/components/base/icons/src/public/common'
 import PremiumBadge from '@/app/components/base/premium-badge'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import {
   useChecklistBeforePublish,
 } from '@/app/components/workflow/hooks'
diff --git a/web/app/components/rag-pipeline/hooks/__tests__/index.spec.ts b/web/app/components/rag-pipeline/hooks/__tests__/index.spec.ts
index 4c60e5133c..95fe763d61 100644
--- a/web/app/components/rag-pipeline/hooks/__tests__/index.spec.ts
+++ b/web/app/components/rag-pipeline/hooks/__tests__/index.spec.ts
@@ -31,7 +31,7 @@ vi.mock('@/app/components/workflow/store', () => ({
 }))

 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({
     notify: mockNotify,
   }),
diff --git a/web/app/components/rag-pipeline/hooks/__tests__/use-DSL.spec.ts b/web/app/components/rag-pipeline/hooks/__tests__/use-DSL.spec.ts
index c0b983052d..c5603f2fc7 100644
--- a/web/app/components/rag-pipeline/hooks/__tests__/use-DSL.spec.ts
+++ b/web/app/components/rag-pipeline/hooks/__tests__/use-DSL.spec.ts
@@ -3,7 +3,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
 import { useDSL } from '../use-DSL'

 const mockNotify = vi.fn()
-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: mockNotify }),
 }))

diff --git a/web/app/components/rag-pipeline/hooks/__tests__/use-update-dsl-modal.spec.ts b/web/app/components/rag-pipeline/hooks/__tests__/use-update-dsl-modal.spec.ts
index 942e337ad8..a50965fb3b 100644
--- a/web/app/components/rag-pipeline/hooks/__tests__/use-update-dsl-modal.spec.ts
+++ b/web/app/components/rag-pipeline/hooks/__tests__/use-update-dsl-modal.spec.ts
@@ -23,7 +23,7 @@ vi.mock('use-context-selector', () => ({
   useContext: () => ({ notify: mockNotify }),
 }))

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   ToastContext: {},
 }))

diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.ts b/web/app/components/rag-pipeline/hooks/use-DSL.ts
index 5c0f9def1c..f45cf35bdf 100644
--- a/web/app/components/rag-pipeline/hooks/use-DSL.ts
+++ b/web/app/components/rag-pipeline/hooks/use-DSL.ts
@@ -3,7 +3,7 @@ import {
   useState,
 } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import {
   DSL_EXPORT_CHECK,
 } from '@/app/components/workflow/constants'
diff --git a/web/app/components/rag-pipeline/hooks/use-update-dsl-modal.ts b/web/app/components/rag-pipeline/hooks/use-update-dsl-modal.ts
index 3b86937417..48d2bc78b4 100644
--- a/web/app/components/rag-pipeline/hooks/use-update-dsl-modal.ts
+++ b/web/app/components/rag-pipeline/hooks/use-update-dsl-modal.ts
@@ -6,7 +6,7 @@ import {
 } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { WORKFLOW_DATA_UPDATE } from '@/app/components/workflow/constants'
 import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
 import { useWorkflowStore } from '@/app/components/workflow/store'
diff --git a/web/app/components/workflow-app/components/workflow-header/__tests__/features-trigger.spec.tsx b/web/app/components/workflow-app/components/workflow-header/__tests__/features-trigger.spec.tsx
index 4a7fd1275f..2318a1c7bc 100644
--- a/web/app/components/workflow-app/components/workflow-header/__tests__/features-trigger.spec.tsx
+++ b/web/app/components/workflow-app/components/workflow-header/__tests__/features-trigger.spec.tsx
@@ -4,7 +4,7 @@ import type { App } from '@/types/app'
 import { render, screen, waitFor } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import { useStore as useAppStore } from '@/app/components/app/store'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { Plan } from '@/app/components/billing/type'
 import { BlockEnum, InputVarType } from '@/app/components/workflow/types'
 import FeaturesTrigger from '../features-trigger'
diff --git a/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx
index d58eb6c669..84603e9a13 100644
--- a/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx
+++ b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx
@@ -17,7 +17,7 @@ import AppPublisher from '@/app/components/app/app-publisher'
 import { useStore as useAppStore } from '@/app/components/app/store'
 import Button from '@/app/components/base/button'
 import { useFeatures } from '@/app/components/base/features/hooks'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { Plan } from '@/app/components/billing/type'
 import {
   useChecklist,
diff --git a/web/app/components/workflow-app/hooks/use-DSL.ts b/web/app/components/workflow-app/hooks/use-DSL.ts
index 939e43b554..918a60f185 100644
--- a/web/app/components/workflow-app/hooks/use-DSL.ts
+++ b/web/app/components/workflow-app/hooks/use-DSL.ts
@@ -4,7 +4,7 @@ import {
 } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useStore as useAppStore } from '@/app/components/app/store'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import {
   DSL_EXPORT_CHECK,
 } from '@/app/components/workflow/constants'
diff --git a/web/app/components/workflow/header/run-mode.tsx b/web/app/components/workflow/header/run-mode.tsx
index 63c48ff0dc..943af13a92 100644
--- a/web/app/components/workflow/header/run-mode.tsx
+++ b/web/app/components/workflow/header/run-mode.tsx
@@ -5,7 +5,7 @@ import { useCallback, useEffect, useRef } from 'react'
 import { useTranslation } from 'react-i18next'
 import { trackEvent } from '@/app/components/base/amplitude'
 import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { useWorkflowRun, useWorkflowRunValidation, useWorkflowStartRun } from '@/app/components/workflow/hooks'
 import ShortcutsName from '@/app/components/workflow/shortcuts-name'
 import { useStore } from '@/app/components/workflow/store'
diff --git a/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts b/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts
index d72d001e0b..1b37055134 100644
--- a/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts
+++ b/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts
@@ -82,7 +82,7 @@ vi.mock('../index', () => ({
   useNodesMetaData: () => ({ nodes: [], nodesMap: mockNodesMap }),
 }))

-vi.mock('@/app/components/base/toast', () => ({
+vi.mock('@/app/components/base/toast/context', () => ({
   useToastContext: () => ({ notify: vi.fn() }),
 }))

diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts
index 642179aed7..dd1d012aed 100644
--- a/web/app/components/workflow/hooks/use-checklist.ts
+++ b/web/app/components/workflow/hooks/use-checklist.ts
@@ -22,7 +22,7 @@ import {
 import { useTranslation } from 'react-i18next'
 import { useEdges, useStoreApi } from 'reactflow'
 import { useStore as useAppStore } from '@/app/components/app/store'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import useNodes from '@/app/components/workflow/store/workflow/use-nodes'
diff --git a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts
index 3a084ef0af..511940d142 100644
--- a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts
+++ b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts
@@ -15,7 +15,7 @@ import {
   useEffect,
 } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { useNoteEditorStore } from '../../store'
 import { urlRegExp } from '../../utils'
diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx
index ef4b4d71a9..74765c2b9e 100644
--- a/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx
+++ b/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx
@@ -5,7 +5,7 @@ import * as React from 'react'
 import { useCallback, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import RemoveButton from '@/app/components/workflow/nodes/_base/components/remove-button'
 import VariableTypeSelector from '@/app/components/workflow/panel/chat-variable-panel/components/variable-type-select'
 import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type'
diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
index 5c07cca3df..dd14bcf75a 100644
--- a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
+++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
@@ -7,7 +7,7 @@ import { useContext } from 'use-context-selector'
 import { v4 as uuid4 } from 'uuid'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
 import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
 import ArrayValueList from '@/app/components/workflow/panel/chat-variable-panel/components/array-value-list'
diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts
index b31673ee26..3481733cd2 100644
--- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts
+++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts
@@ -26,7 +26,7 @@ import {
   getProcessedFiles,
   getProcessedFilesFromResponse,
 } from '@/app/components/base/file-uploader/utils'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import {
   CUSTOM_NODE,
 } from '@/app/components/workflow/constants'
diff --git a/web/app/components/workflow/panel/env-panel/variable-modal.tsx b/web/app/components/workflow/panel/env-panel/variable-modal.tsx
index 0e120ac77c..493a73405a 100644
--- a/web/app/components/workflow/panel/env-panel/variable-modal.tsx
+++ b/web/app/components/workflow/panel/env-panel/variable-modal.tsx
@@ -7,7 +7,7 @@ import { useContext } from 'use-context-selector'
 import { v4 as uuid4 } from 'uuid'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import Tooltip from '@/app/components/base/tooltip'
 import { useWorkflowStore } from '@/app/components/workflow/store'
 import { cn } from '@/utils/classnames'
diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx
index 441002c86c..96acb1a026 100644
--- a/web/app/components/workflow/run/index.tsx
+++ b/web/app/components/workflow/run/index.tsx
@@ -6,7 +6,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import Loading from '@/app/components/base/loading'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { WorkflowRunningStatus } from '@/app/components/workflow/types'
 import { fetchRunDetail, fetchTracingList } from '@/service/log'
 import { cn } from '@/utils/classnames'
diff --git a/web/app/components/workflow/update-dsl-modal.tsx b/web/app/components/workflow/update-dsl-modal.tsx
index d33679ff1b..08107a0c24 100644
--- a/web/app/components/workflow/update-dsl-modal.tsx
+++ b/web/app/components/workflow/update-dsl-modal.tsx
@@ -24,7 +24,7 @@ import { useStore as useAppStore } from '@/app/components/app/store'
 import Button from '@/app/components/base/button'
 import Modal from '@/app/components/base/modal'
 import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
-import { ToastContext } from '@/app/components/base/toast'
+import { ToastContext } from '@/app/components/base/toast/context'
 import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
 import { useEventEmitterContextContext } from '@/context/event-emitter'
 import {
diff --git a/web/app/education-apply/education-apply-page.tsx b/web/app/education-apply/education-apply-page.tsx
index 3f8d80a67e..636a471d3d 100644
--- a/web/app/education-apply/education-apply-page.tsx
+++ b/web/app/education-apply/education-apply-page.tsx
@@ -12,7 +12,7 @@ import {
 } from 'react'
 import { useTranslation } from 'react-i18next'
 import Button from '@/app/components/base/button'
 import Checkbox from '@/app/components/base/checkbox'
-import { useToastContext } from '@/app/components/base/toast'
+import { useToastContext } from '@/app/components/base/toast/context'
 import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM } from '@/app/education-apply/constants'
 import { useDocLink } from '@/context/i18n'
 import { useProviderContext } from '@/context/provider-context'
diff --git a/web/context/datasets-context.tsx b/web/context/datasets-context.ts
similarity index 100%
rename from web/context/datasets-context.tsx
rename to web/context/datasets-context.ts
diff --git a/web/context/event-emitter-provider.tsx b/web/context/event-emitter-provider.tsx
new file mode 100644
index 0000000000..da8d2d78c2
--- /dev/null
+++ b/web/context/event-emitter-provider.tsx
@@ -0,0 +1,22 @@
+'use client'
+
+import type { ReactNode } from 'react'
+import type { EventEmitterValue } from './event-emitter'
+import { useEventEmitter } from 'ahooks'
+import { EventEmitterContext } from './event-emitter'
+
+type EventEmitterContextProviderProps = {
+  children: ReactNode
+}
+
+export const EventEmitterContextProvider = ({
+  children,
+}: EventEmitterContextProviderProps) => {
+  const eventEmitter = useEventEmitter()
+
+  return (
+
+    {children}
+
+  )
+}
diff --git a/web/context/event-emitter.tsx b/web/context/event-emitter.ts
similarity index 53%
rename from web/context/event-emitter.tsx
rename to web/context/event-emitter.ts
index 14b81eacb6..eb7794dfe1 100644
--- a/web/context/event-emitter.tsx
+++ b/web/context/event-emitter.ts
@@ -1,7 +1,6 @@
 'use client'

 import type { EventEmitter } from 'ahooks/lib/useEventEmitter'
-import { useEventEmitter } from 'ahooks'
 import { createContext, useContext } from 'use-context-selector'

 /**
@@ -16,25 +15,10 @@ export type EventEmitterMessage = {

 export type EventEmitterValue = string | EventEmitterMessage

-const EventEmitterContext = createContext<{ eventEmitter: EventEmitter | null }>({
+export const EventEmitterContext = createContext<{ eventEmitter: EventEmitter | null }>({
   eventEmitter: null,
 })

 export const useEventEmitterContextContext = () => useContext(EventEmitterContext)

-type EventEmitterContextProviderProps = {
-  children: React.ReactNode
-}
-export const EventEmitterContextProvider = ({
-  children,
-}: EventEmitterContextProviderProps) => {
-  const eventEmitter = useEventEmitter()
-
-  return (
-
-    {children}
-
-  )
-}
-
 export default EventEmitterContext
diff --git a/web/context/mitt-context-provider.tsx b/web/context/mitt-context-provider.tsx
new file mode 100644
index 0000000000..b177694d8d
--- /dev/null
+++ b/web/context/mitt-context-provider.tsx
@@ -0,0 +1,19 @@
+'use client'
+
+import type { ReactNode } from 'react'
+import { useMitt } from '@/hooks/use-mitt'
+import { MittContext } from './mitt-context'
+
+type MittProviderProps = {
+  children: ReactNode
+}
+
+export const MittProvider = ({ children }: MittProviderProps) => {
+  const mitt = useMitt()
+
+  return (
+
+    {children}
+
+  )
+}
diff --git a/web/context/mitt-context.tsx b/web/context/mitt-context.ts
similarity index 66%
rename from web/context/mitt-context.tsx
rename to web/context/mitt-context.ts
index 4317fc5660..5c4a0771c5 100644
--- a/web/context/mitt-context.tsx
+++ b/web/context/mitt-context.ts
@@ -1,6 +1,8 @@
+'use client'
+
+import type { useMitt } from '@/hooks/use-mitt'
 import { noop } from 'es-toolkit/function'
 import { createContext, useContext, useContextSelector } from 'use-context-selector'
-import { useMitt } from '@/hooks/use-mitt'

 type ContextValueType = ReturnType<typeof useMitt>

 export const MittContext = createContext({
@@ -8,16 +10,6 @@ export const MittContext = createContext({
   useSubscribe: noop,
 })

-export const MittProvider = ({ children }: { children: React.ReactNode }) => {
-  const mitt = useMitt()
-
-  return (
-
-    {children}
-
-  )
-}
-
 export const useMittContext = () => {
   return useContext(MittContext)
 }
diff --git a/web/context/modal-context.tsx b/web/context/modal-context-provider.tsx
similarity index 81%
rename from web/context/modal-context.tsx
rename to web/context/modal-context-provider.tsx
index ae1184e5bf..8c64642f43 100644
--- a/web/context/modal-context.tsx
+++ b/web/context/modal-context-provider.tsx
@@ -1,32 +1,20 @@
 'use client'

-import type { Dispatch, SetStateAction } from 'react'
-import type { TriggerEventsLimitModalPayload } from './hooks/use-trigger-events-limit-modal'
+import type { ReactNode, SetStateAction } from 'react'
+import type { ModalState, ModelModalType } from './modal-context'
 import type { OpeningStatement } from '@/app/components/base/features/types'
 import type { CreateExternalAPIReq } from '@/app/components/datasets/external-api/declarations'
 import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
-import type {
-  ConfigurationMethodEnum,
-  Credential,
-  CustomConfigurationModelFixedFields,
-  CustomModel,
-  ModelModalModeEnum,
-  ModelProvider,
-} from '@/app/components/header/account-setting/model-provider-page/declarations'
 import type { ModelLoadBalancingModalProps } from '@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal'
 import type { UpdatePluginPayload } from '@/app/components/plugins/types'
 import type { InputVar } from '@/app/components/workflow/types'
 import type { ExpireNoticeModalPayloadProps } from '@/app/education-apply/expire-notice-modal'
-import type {
-  ApiBasedExtension,
-  ExternalDataTool,
-} from '@/models/common'
+import type { ApiBasedExtension, ExternalDataTool } from '@/models/common'
 import type { ModerationConfig, PromptVariable } from '@/models/debug'
-import { noop } from 'es-toolkit/function'
 import dynamic from 'next/dynamic'
 import { useCallback, useEffect, useRef, useState } from 'react'
-import { createContext, useContext, useContextSelector } from 'use-context-selector'
 import {
+  DEFAULT_ACCOUNT_SETTING_TAB,
   isValidAccountSettingTab,
 } from '@/app/components/header/account-setting/constants'
@@ -39,11 +27,10 @@ import {
   useAccountSettingModal,
   usePricingModal,
 } from '@/hooks/use-query-params'
-
+import { useTriggerEventsLimitModal } from './hooks/use-trigger-events-limit-modal'
 import {
-
-  useTriggerEventsLimitModal,
-} from './hooks/use-trigger-events-limit-modal'
+  ModalContext,
+} from './modal-context'

 const AccountSetting = dynamic(() => import('@/app/components/header/account-setting'), {
   ssr: false,
@@ -86,72 +73,8 @@ const TriggerEventsLimitModal = dynamic(() => import('@/app/components/billing/t
   ssr: false,
 })

-export type ModalState = {
-  payload: T
-  onCancelCallback?: () => void
-  onSaveCallback?: (newPayload?: T, formValues?: Record) => void
-  onRemoveCallback?: (newPayload?: T, formValues?: Record) => void
-  onEditCallback?: (newPayload: T) => void
-  onValidateBeforeSaveCallback?: (newPayload: T) => boolean
-  isEditMode?: boolean
-  datasetBindings?: { id: string, name: string }[]
-}
-
-export type ModelModalType = {
-  currentProvider: ModelProvider
-  currentConfigurationMethod: ConfigurationMethodEnum
-  currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields
-  isModelCredential?: boolean
-  credential?: Credential
-  model?: CustomModel
-  mode?: ModelModalModeEnum
-}
-
-export type ModalContextState = {
-  setShowAccountSettingModal: Dispatch | null>>
-  setShowApiBasedExtensionModal: Dispatch | null>>
-  setShowModerationSettingModal: Dispatch | null>>
-  setShowExternalDataToolModal: Dispatch | null>>
-  setShowPricingModal: () => void
-  setShowAnnotationFullModal: () => void
-  setShowModelModal: Dispatch | null>>
-  setShowExternalKnowledgeAPIModal: Dispatch | null>>
-  setShowModelLoadBalancingModal: Dispatch>
-  setShowOpeningModal: Dispatch void
-  }> | null>>
-  setShowUpdatePluginModal: Dispatch | null>>
-  setShowEducationExpireNoticeModal: Dispatch | null>>
-  setShowTriggerEventsLimitModal: Dispatch | null>>
-}
-
-const ModalContext = createContext({
-  setShowAccountSettingModal: noop,
-  setShowApiBasedExtensionModal: noop,
-  setShowModerationSettingModal: noop,
-  setShowExternalDataToolModal: noop,
-  setShowPricingModal: noop,
-  setShowAnnotationFullModal: noop,
-  setShowModelModal: noop,
-  setShowExternalKnowledgeAPIModal: noop,
-  setShowModelLoadBalancingModal: noop,
-  setShowOpeningModal: noop,
-  setShowUpdatePluginModal: noop,
-  setShowEducationExpireNoticeModal: noop,
-  setShowTriggerEventsLimitModal: noop,
-})
-
-export const useModalContext = () => useContext(ModalContext)
-
-// Adding a dangling comma to avoid the generic parsing issue in tsx, see:
-// https://github.com/microsoft/TypeScript/issues/15713
-export const useModalContextSelector = (selector: (state: ModalContextState) => T): T =>
-  useContextSelector(ModalContext, selector)
-
 type ModalContextProviderProps = {
-  children: React.ReactNode
+  children: ReactNode
 }

 export const ModalContextProvider = ({
   children,
diff --git a/web/context/modal-context.test.tsx b/web/context/modal-context.test.tsx
index a0f6ff35ec..98f67a5473 100644
--- a/web/context/modal-context.test.tsx
+++ b/web/context/modal-context.test.tsx
@@ -2,7 +2,7 @@ import { act, screen, waitFor } from '@testing-library/react'
 import * as React from 'react'
 import { defaultPlan } from '@/app/components/billing/config'
 import { Plan } from '@/app/components/billing/type'
-import { ModalContextProvider } from '@/context/modal-context'
+import { ModalContextProvider } from '@/context/modal-context-provider'
 import { renderWithNuqs } from '@/test/nuqs-testing'

 vi.mock('@/config', async (importOriginal) => {
diff --git a/web/context/modal-context.ts b/web/context/modal-context.ts
new file mode 100644
index 0000000000..cc0ca28a42
--- /dev/null
+++ b/web/context/modal-context.ts
@@ -0,0 +1,92 @@
+'use client'
+
+import type { Dispatch, SetStateAction } from 'react'
+import type { TriggerEventsLimitModalPayload } from './hooks/use-trigger-events-limit-modal'
+import type { OpeningStatement } from '@/app/components/base/features/types'
+import type { CreateExternalAPIReq } from '@/app/components/datasets/external-api/declarations'
+import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
+import type {
+  ConfigurationMethodEnum,
+  Credential,
+  CustomConfigurationModelFixedFields,
+  CustomModel,
+  ModelModalModeEnum,
+  ModelProvider,
+} from '@/app/components/header/account-setting/model-provider-page/declarations'
+import type { ModelLoadBalancingModalProps } from '@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal'
+import type { UpdatePluginPayload } from '@/app/components/plugins/types'
+import type { InputVar } from '@/app/components/workflow/types'
+import type { ExpireNoticeModalPayloadProps } from '@/app/education-apply/expire-notice-modal'
+import type {
+  ApiBasedExtension,
+  ExternalDataTool,
+} from '@/models/common'
+import type { ModerationConfig, PromptVariable } from '@/models/debug'
+import { noop } from 'es-toolkit/function'
+import { createContext, useContext, useContextSelector } from 'use-context-selector'
+
+export type ModalState = {
+  payload: T
+  onCancelCallback?: () => void
+  onSaveCallback?: (newPayload?: T, formValues?: Record) => void
+  onRemoveCallback?: (newPayload?: T, formValues?: Record) => void
+  onEditCallback?: (newPayload: T) => void
+  onValidateBeforeSaveCallback?: (newPayload: T) => boolean
+  isEditMode?: boolean
+  datasetBindings?: { id: string, name: string }[]
+}
+
+export type ModelModalType = {
+  currentProvider: ModelProvider
+  currentConfigurationMethod: ConfigurationMethodEnum
+  currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields
+  isModelCredential?: boolean
+  credential?: Credential
+  model?: CustomModel
+  mode?: ModelModalModeEnum
+}
+
+export type ModalContextState = {
+  setShowAccountSettingModal: Dispatch | null>>
+  setShowApiBasedExtensionModal: Dispatch | null>>
+  setShowModerationSettingModal: Dispatch | null>>
+  setShowExternalDataToolModal: Dispatch | null>>
+  setShowPricingModal: () => void
+  setShowAnnotationFullModal: () => void
+  setShowModelModal: Dispatch | null>>
+  setShowExternalKnowledgeAPIModal: Dispatch | null>>
+  setShowModelLoadBalancingModal: Dispatch>
+  setShowOpeningModal: Dispatch void
+  }> | null>>
+  setShowUpdatePluginModal: Dispatch | null>>
+  setShowEducationExpireNoticeModal: Dispatch | null>>
+  setShowTriggerEventsLimitModal: Dispatch | null>>
+}
+
+export const ModalContext = createContext({
+  setShowAccountSettingModal: noop,
+  setShowApiBasedExtensionModal: noop,
+  setShowModerationSettingModal: noop,
+  setShowExternalDataToolModal: noop,
+  setShowPricingModal: noop,
+  setShowAnnotationFullModal: noop,
+  setShowModelModal: noop,
+  setShowExternalKnowledgeAPIModal: noop,
+  setShowModelLoadBalancingModal: noop,
+  setShowOpeningModal: noop,
+  setShowUpdatePluginModal: noop,
+  setShowEducationExpireNoticeModal: noop,
+  setShowTriggerEventsLimitModal: noop,
+})
+
+export const useModalContext = () => useContext(ModalContext)
+
+// Adding a dangling comma to avoid the generic parsing issue in tsx, see:
+// https://github.com/microsoft/TypeScript/issues/15713
+export const useModalContextSelector = (selector: (state: ModalContextState) => T): T =>
+  useContextSelector(ModalContext, selector)
+
+export default ModalContext
diff --git a/web/context/provider-context.tsx b/web/context/provider-context-provider.tsx
similarity index 70%
rename from web/context/provider-context.tsx
rename to web/context/provider-context-provider.tsx
index 2a71d9cf93..ce7f2ba40c 100644
--- a/web/context/provider-context.tsx
+++ b/web/context/provider-context-provider.tsx
@@ -1,14 +1,10 @@
 'use client'

-import type { Plan, UsagePlanInfo, UsageResetInfo } from '@/app/components/billing/type'
-import type { Model, ModelProvider } from '@/app/components/header/account-setting/model-provider-page/declarations'
-import type { RETRIEVE_METHOD } from '@/types/app'
+import type { ReactNode } from 'react'
 import { useQueryClient } from '@tanstack/react-query'
 import dayjs from 'dayjs'
-import { noop } from 'es-toolkit/function'
 import { useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { createContext, useContext, useContextSelector } from 'use-context-selector'
 import Toast from '@/app/components/base/toast'
 import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils'
 import { defaultPlan } from '@/app/components/billing/config'
@@ -25,93 +21,13 @@ import {
   useModelProviders,
   useSupportRetrievalMethods,
 } from '@/service/use-common'
-import {
-  useEducationStatus,
-} from '@/service/use-education'
-
-export type ProviderContextState = {
-  modelProviders: ModelProvider[]
-  refreshModelProviders: () => void
-  textGenerationModelList: Model[]
-  supportRetrievalMethods: RETRIEVE_METHOD[]
-  isAPIKeySet: boolean
-  plan: {
-    type: Plan
-    usage: UsagePlanInfo
-    total: UsagePlanInfo
-    reset: UsageResetInfo
-  }
-  isFetchedPlan: boolean
-  enableBilling: boolean
-  onPlanInfoChanged: () => void
-  enableReplaceWebAppLogo: boolean
-  modelLoadBalancingEnabled: boolean
-  datasetOperatorEnabled: boolean
-  enableEducationPlan: boolean
-  isEducationWorkspace: boolean
-  isEducationAccount: boolean
-  allowRefreshEducationVerify: boolean
-  educationAccountExpireAt: number | null
-  isLoadingEducationAccountInfo: boolean
-  isFetchingEducationAccountInfo: boolean
-  webappCopyrightEnabled: boolean
-  licenseLimit: {
-    workspace_members: {
-      size: number
-      limit: number
-    }
-  }
-  refreshLicenseLimit: () => void
-  isAllowTransferWorkspace: boolean
-  isAllowPublishAsCustomKnowledgePipelineTemplate: boolean
-  humanInputEmailDeliveryEnabled: boolean
-}
-
-export const baseProviderContextValue: ProviderContextState = {
-  modelProviders: [],
-  refreshModelProviders: noop,
-  textGenerationModelList: [],
-  supportRetrievalMethods: [],
-  isAPIKeySet: true,
-  plan: defaultPlan,
-  isFetchedPlan: false,
-  enableBilling: false,
-  onPlanInfoChanged: noop,
-  enableReplaceWebAppLogo: false,
-  modelLoadBalancingEnabled: false,
-  datasetOperatorEnabled: false,
-  enableEducationPlan: false,
-  isEducationWorkspace: false,
-  isEducationAccount: false,
-  allowRefreshEducationVerify: false,
-  educationAccountExpireAt: null,
-  isLoadingEducationAccountInfo: false,
-  isFetchingEducationAccountInfo: false,
-  webappCopyrightEnabled: false,
-  licenseLimit: {
-    workspace_members: {
-      size: 0,
-      limit: 0,
-    },
-  },
-  refreshLicenseLimit: noop,
-  isAllowTransferWorkspace: false,
-  isAllowPublishAsCustomKnowledgePipelineTemplate: false,
-  humanInputEmailDeliveryEnabled: false,
-}
-
-const ProviderContext = createContext(baseProviderContextValue)
-
-export const useProviderContext = () => useContext(ProviderContext)
-
-// Adding a dangling comma to avoid the generic parsing issue in tsx, see:
-// https://github.com/microsoft/TypeScript/issues/15713
-export const useProviderContextSelector = (selector: (state: ProviderContextState) => T): T =>
-  useContextSelector(ProviderContext, selector)
+import { useEducationStatus } from '@/service/use-education'
+import { ProviderContext } from './provider-context'

 type ProviderContextProviderProps = {
-  children: React.ReactNode
+  children: ReactNode
 }
+
 export const ProviderContextProvider = ({
   children,
 }: ProviderContextProviderProps) => {
@@ -262,5 +178,3 @@ export const ProviderContextProvider = ({
   )
 }
-
-export default ProviderContext
diff --git a/web/context/provider-context.ts b/web/context/provider-context.ts
new file mode 100644
index 0000000000..27c43e7c91
--- /dev/null
+++ b/web/context/provider-context.ts
@@ -0,0 +1,90 @@
+'use client'
+
+import type { Plan, UsagePlanInfo, UsageResetInfo } from '@/app/components/billing/type'
+import type { Model, ModelProvider } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import type { RETRIEVE_METHOD } from '@/types/app'
+import { noop } from 'es-toolkit/function'
+import { createContext, useContext, useContextSelector } from 'use-context-selector'
+import { defaultPlan } from '@/app/components/billing/config'
+
+export type ProviderContextState = {
+  modelProviders: ModelProvider[]
+  refreshModelProviders: () => void
+  textGenerationModelList: Model[]
+  supportRetrievalMethods: RETRIEVE_METHOD[]
+  isAPIKeySet: boolean
+  plan: {
+    type: Plan
+    usage: UsagePlanInfo
+    total: UsagePlanInfo
+    reset: UsageResetInfo
+  }
+  isFetchedPlan: boolean
+  enableBilling: boolean
+  onPlanInfoChanged: () => void
+  enableReplaceWebAppLogo: boolean
+  modelLoadBalancingEnabled: boolean
+  datasetOperatorEnabled: boolean
+  enableEducationPlan: boolean
+  isEducationWorkspace: boolean
+  isEducationAccount: boolean
+  allowRefreshEducationVerify: boolean
+  educationAccountExpireAt: number | null
+  isLoadingEducationAccountInfo: boolean
+  isFetchingEducationAccountInfo: boolean
+  webappCopyrightEnabled: boolean
+  licenseLimit: {
+    workspace_members: {
+      size: number
+      limit: number
+    }
+  }
+  refreshLicenseLimit: () => void
+  isAllowTransferWorkspace: boolean
+  isAllowPublishAsCustomKnowledgePipelineTemplate: boolean
+  humanInputEmailDeliveryEnabled: boolean
+}
+
+export const baseProviderContextValue: ProviderContextState = {
+  modelProviders: [],
+  refreshModelProviders: noop,
+  textGenerationModelList: [],
+  supportRetrievalMethods: [],
+  isAPIKeySet: true,
+  plan: defaultPlan,
+  isFetchedPlan: false,
+  enableBilling: false,
+  onPlanInfoChanged: noop,
+  enableReplaceWebAppLogo: false,
+  modelLoadBalancingEnabled: false,
+  datasetOperatorEnabled: false,
+  enableEducationPlan: false,
+  isEducationWorkspace: false,
+  isEducationAccount: false,
+  allowRefreshEducationVerify: false,
+  educationAccountExpireAt: null,
+  isLoadingEducationAccountInfo: false,
+  isFetchingEducationAccountInfo: false,
+  webappCopyrightEnabled: false,
+  licenseLimit: {
+    workspace_members: {
+      size: 0,
+      limit: 0,
+    },
+  },
+  refreshLicenseLimit: noop,
+  isAllowTransferWorkspace: false,
+  isAllowPublishAsCustomKnowledgePipelineTemplate: false,
+  humanInputEmailDeliveryEnabled: false,
+}
+
+export const ProviderContext = createContext(baseProviderContextValue)
+
+export const useProviderContext = () => useContext(ProviderContext)
+
+// Adding a dangling comma to avoid the generic parsing issue in tsx, see:
+// https://github.com/microsoft/TypeScript/issues/15713
+export const useProviderContextSelector = (selector: (state: ProviderContextState) => T): T =>
+  useContextSelector(ProviderContext, selector)
+
+export default ProviderContext
diff --git a/web/context/workspace-context-provider.tsx b/web/context/workspace-context-provider.tsx
new file mode 100644
index 0000000000..afec62f710
--- /dev/null
+++ b/web/context/workspace-context-provider.tsx
@@ -0,0 +1,24 @@
+'use client'
+
+import type { ReactNode } from 'react'
+import { useWorkspaces } from '@/service/use-common'
+import { WorkspacesContext } from './workspace-context'
+
+type WorkspaceProviderProps = {
+  children: ReactNode
+}
+
+export const WorkspaceProvider = ({
+  children,
+}: WorkspaceProviderProps) => {
+  const { data } = useWorkspaces()
+
+  return (
+
+    {children}
+
+  )
+}
diff --git a/web/context/workspace-context.ts b/web/context/workspace-context.ts
new file mode 100644
index 0000000000..e088d12f4e
--- /dev/null
+++ b/web/context/workspace-context.ts
@@ -0,0 +1,16 @@
+'use client'
+
+import type { IWorkspace } from '@/models/common'
+import { createContext, useContext } from 'use-context-selector'
+
+export type WorkspacesContextValue = {
+  workspaces: IWorkspace[]
+}
+
+export const WorkspacesContext = createContext({
+  workspaces: [],
+})
+
+export const useWorkspacesContext = () => useContext(WorkspacesContext)
+
+export default WorkspacesContext
diff --git a/web/context/workspace-context.tsx b/web/context/workspace-context.tsx
deleted file mode 100644
index 3834641bc1..0000000000
--- a/web/context/workspace-context.tsx
+++ /dev/null
@@ -1,36 +0,0 @@
-'use client'
-
-import type { IWorkspace } from '@/models/common'
-import { createContext, useContext } from 'use-context-selector'
-import { useWorkspaces } from '@/service/use-common'
-
-export type WorkspacesContextValue = {
-  workspaces: IWorkspace[]
-}
-
-const WorkspacesContext = createContext({
-  workspaces: [],
-})
-
-type IWorkspaceProviderProps = {
-  children: React.ReactNode
-}
-
-export const WorkspaceProvider = ({
-  children,
-}: IWorkspaceProviderProps) => {
-  const { data } = useWorkspaces()
-
-  return (
-
-    {children}
-
-  )
-}
-
-export const useWorkspacesContext = () => useContext(WorkspacesContext)
-
-export default WorkspacesContext
diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json
index 0880084320..22f225d1d0 100644
--- a/web/eslint-suppressions.json
+++ b/web/eslint-suppressions.json
@@ -968,11 +968,6 @@
       "count": 6
     }
   },
-  "app/components/app/configuration/debug/debug-with-multiple-model/context.tsx": {
-    "react-refresh/only-export-components": {
-      "count": 1
-    }
-  },
"app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx": { "no-restricted-imports": { "count": 2 @@ -1561,7 +1556,7 @@ "count": 7 } }, - "app/components/base/chat/chat-with-history/context.tsx": { + "app/components/base/chat/chat-with-history/context.ts": { "ts/no-explicit-any": { "count": 7 } @@ -1715,11 +1710,6 @@ "count": 1 } }, - "app/components/base/chat/chat/context.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "app/components/base/chat/chat/hooks.ts": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 @@ -1759,7 +1749,7 @@ "count": 7 } }, - "app/components/base/chat/embedded-chatbot/context.tsx": { + "app/components/base/chat/embedded-chatbot/context.ts": { "ts/no-explicit-any": { "count": 7 } @@ -2763,7 +2753,7 @@ "count": 2 } }, - "app/components/base/radio/context/index.tsx": { + "app/components/base/radio/context/index.ts": { "ts/no-explicit-any": { "count": 1 } @@ -2915,14 +2905,6 @@ "count": 1 } }, - "app/components/base/toast/index.tsx": { - "react-refresh/only-export-components": { - "count": 2 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 2 - } - }, "app/components/base/video-gallery/VideoPlayer.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -5683,10 +5665,7 @@ "count": 1 } }, - "app/components/plugins/plugin-page/context.tsx": { - "react-refresh/only-export-components": { - "count": 2 - }, + "app/components/plugins/plugin-page/context.ts": { "ts/no-explicit-any": { "count": 1 } @@ -9571,16 +9550,6 @@ "count": 5 } }, - "context/datasets-context.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, - "context/event-emitter.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "context/external-api-panel-context.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -9601,8 +9570,8 @@ "count": 3 } }, - "context/mitt-context.tsx": { - "react-refresh/only-export-components": { + "context/modal-context-provider.tsx": { + "ts/no-explicit-any": { "count": 3 } }, @@ -9611,18 +9580,12 @@ "count": 3 } }, - "context/modal-context.tsx": { - "react-refresh/only-export-components": { - "count": 2 - }, + "context/modal-context.ts": { "ts/no-explicit-any": { - "count": 5 + "count": 2 } }, - "context/provider-context.tsx": { - "react-refresh/only-export-components": { - "count": 3 - }, + "context/provider-context-provider.tsx": { "ts/no-explicit-any": { "count": 1 } @@ -9632,11 +9595,6 @@ "count": 1 } }, - "context/workspace-context.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "hooks/use-async-window-open.spec.ts": { "ts/no-explicit-any": { "count": 6 diff --git a/web/hooks/use-import-dsl.ts b/web/hooks/use-import-dsl.ts index ba33db1e84..454f580b42 100644 --- a/web/hooks/use-import-dsl.ts +++ b/web/hooks/use-import-dsl.ts @@ -10,7 +10,7 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useSelector } from '@/context/app-context' From ebda5efe275154ca6015e665778a9f5c0678a6a8 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:13:02 +0800 Subject: [PATCH 044/256] chore: prevent Storybook crash caused by vite-plugin-inspect 
(#33039) --- web/vite.config.ts | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/web/vite.config.ts b/web/vite.config.ts index b07e7ea7be..83a35f8558 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -160,6 +160,8 @@ if (import.meta.hot) { export default defineConfig(({ mode }) => { const isTest = mode === 'test' + const isStorybook = process.env.STORYBOOK === 'true' + || process.argv.some(arg => arg.toLowerCase().includes('storybook')) return { plugins: isTest @@ -176,14 +178,19 @@ export default defineConfig(({ mode }) => { }, } as Plugin, ] - : [ - Inspect(), - createCodeInspectorPlugin(), - createForceInspectorClientInjectionPlugin(), - react(), - vinext(), - customI18nHmrPlugin(), - ], + : isStorybook + ? [ + tsconfigPaths(), + react(), + ] + : [ + Inspect(), + createCodeInspectorPlugin(), + createForceInspectorClientInjectionPlugin(), + react(), + vinext(), + customI18nHmrPlugin(), + ], resolve: { alias: { '~@': __dirname, @@ -191,7 +198,7 @@ export default defineConfig(({ mode }) => { }, // vinext related config - ...(!isTest + ...(!isTest && !isStorybook ? { optimizeDeps: { exclude: ['nuqs'], From c913a629df9c40f5602b58b536d4d262e26616d0 Mon Sep 17 00:00:00 2001 From: Novice Date: Thu, 5 Mar 2026 16:53:28 +0800 Subject: [PATCH 045/256] feat: add partial indexes on conversations for app_id with created_at and updated_at (#32616) --- ...4_add_partial_indexes_on_conversations_.py | 37 +++++++++++++++++++ api/models/model.py | 12 ++++++ 2 files changed, 49 insertions(+) create mode 100644 api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py diff --git a/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py b/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py new file mode 100644 index 0000000000..ed794178b3 --- /dev/null +++ b/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py @@ -0,0 +1,37 @@ +"""add partial indexes on conversations for app_id with created_at and updated_at + +Revision ID: e288952f2994 +Revises: fce013ca180e +Create Date: 2026-02-26 13:36:45.928922 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = 'e288952f2994'
+down_revision = 'fce013ca180e'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    with op.batch_alter_table('conversations', schema=None) as batch_op:
+        batch_op.create_index(
+            'conversation_app_created_at_idx',
+            ['app_id', sa.literal_column('created_at DESC')],
+            unique=False,
+            postgresql_where=sa.text('is_deleted IS false'),
+        )
+        batch_op.create_index(
+            'conversation_app_updated_at_idx',
+            ['app_id', sa.literal_column('updated_at DESC')],
+            unique=False,
+            postgresql_where=sa.text('is_deleted IS false'),
+        )
+
+
+def downgrade():
+    with op.batch_alter_table('conversations', schema=None) as batch_op:
+        batch_op.drop_index('conversation_app_updated_at_idx')
+        batch_op.drop_index('conversation_app_created_at_idx')
diff --git a/api/models/model.py b/api/models/model.py
index a5bca06666..2bf80edb80 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -711,6 +711,18 @@ class Conversation(Base):
     __table_args__ = (
         sa.PrimaryKeyConstraint("id", name="conversation_pkey"),
         sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
+        sa.Index(
+            "conversation_app_created_at_idx",
+            "app_id",
+            sa.text("created_at DESC"),
+            postgresql_where=sa.text("is_deleted IS false"),
+        ),
+        sa.Index(
+            "conversation_app_updated_at_idx",
+            "app_id",
+            sa.text("updated_at DESC"),
+            postgresql_where=sa.text("is_deleted IS false"),
+        ),
     )
 
     id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))

From 7d25415e4d766104c3d23341f589e2f7fb6de630 Mon Sep 17 00:00:00 2001
From: wangxiaolei
Date: Thu, 5 Mar 2026 17:04:19 +0800
Subject: [PATCH 046/256] feat: do not return content with a 204 http status
 code (#33023)

---
 api/controllers/service_api/app/annotation.py   | 2 +-
 api/controllers/service_api/app/conversation.py | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py
index ef254ca357..c22190cbc9 100644
--- a/api/controllers/service_api/app/annotation.py
+++ b/api/controllers/service_api/app/annotation.py
@@ -185,4 +185,4 @@ class AnnotationUpdateDeleteApi(Resource):
     def delete(self, app_model: App, annotation_id: str):
         """Delete an annotation."""
         AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
-        return {"result": "success"}, 204
+        return "", 204
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index 8e29c9ff0f..edbf011656 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from fields.conversation_fields import (
-    ConversationDelete,
     ConversationInfiniteScrollPagination,
     SimpleConversation,
 )
@@ -163,7 +162,7 @@ class ConversationDetailApi(Resource):
             ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
-        return ConversationDelete(result="success").model_dump(mode="json"), 204
+        return "", 204
 
 
 @service_api_ns.route("/conversations/<uuid:conversation_id>/name")

From 92bde3503b691b45ee47b1e9f5f7d586f04bd107 Mon Sep 17 00:00:00 2001
From: wangxiaolei
Date: Thu, 5 Mar 2026 17:13:35 +0800
Subject: [PATCH 047/256] fix: add missing existence check for workflow_run
 (#33028)
MIME-Version: 1.0
Content-Type:
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 非法操作 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/service_api/app/workflow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index b2148f4fa1..35dd22c801 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -132,6 +132,8 @@ class WorkflowRunDetailApi(Resource): app_id=app_model.id, run_id=workflow_run_id, ) + if not workflow_run: + raise NotFound("Workflow run not found.") return workflow_run From 187faed1c09044c8aac2f18ee422c5985daa760e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=A8=E4=B9=8B=E6=9C=AC=E6=BE=AA?= Date: Fri, 6 Mar 2026 05:06:14 +0800 Subject: [PATCH 048/256] test: migrate test_dataset_service_delete_dataset SQL tests to testcontainers (#32543) Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com> --- .../test_dataset_service_delete_dataset.py | 244 ++++++++++++++++++ .../test_dataset_service_delete_dataset.py | 216 ---------------- 2 files changed, 244 insertions(+), 216 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py delete mode 100644 api/tests/unit_tests/services/test_dataset_service_delete_dataset.py diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py new file mode 100644 index 0000000000..c47e35791d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -0,0 +1,244 @@ +"""Container-backed integration tests for DatasetService.delete_dataset real SQL paths.""" + +from unittest.mock import patch +from uuid import uuid4 + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document +from services.dataset_service import DatasetService + + +class DatasetDeleteIntegrationDataFactory: + """Create persisted entities used by delete_dataset integration tests.""" + + @staticmethod + def create_account_with_tenant(db_session_with_containers) -> tuple[Account, Tenant]: + """Persist an owner account, tenant, and tenant join for dataset deletion tests.""" + account = Account( + email=f"owner-{uuid4()}@example.com", + name="Owner", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant( + name=f"tenant-{uuid4()}", + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_dataset( + db_session_with_containers, + tenant_id: str, + created_by: str, + *, + indexing_technique: str | None, + chunk_structure: str | None, + index_struct: str | None = '{"type": "paragraph"}', + collection_binding_id: str | None = None, + pipeline_id: str | None = None, + ) -> Dataset: + """Persist a dataset with delete_dataset-relevant fields configured.""" + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + 
data_source_type="upload_file", + indexing_technique=indexing_technique, + index_struct=index_struct, + created_by=created_by, + collection_binding_id=collection_binding_id, + pipeline_id=pipeline_id, + chunk_structure=chunk_structure, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + tenant_id: str, + dataset_id: str, + created_by: str, + doc_form: str = "text_model", + ) -> Document: + """Persist a document so dataset.doc_form resolves through the real document path.""" + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type="upload_file", + batch=f"batch-{uuid4()}", + name="Document", + created_from="upload_file", + created_by=created_by, + doc_form=doc_form, + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + +class TestDatasetServiceDeleteDataset: + """Integration coverage for DatasetService.delete_dataset using testcontainers.""" + + def test_delete_dataset_with_documents_success(self, db_session_with_containers): + """Delete a dataset with documents and dispatch cleanup through the real signal handler.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique="high_quality", + chunk_structure=None, + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + DatasetDeleteIntegrationDataFactory.create_document( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + created_by=owner.id, + doc_form="text_model", + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_called_once_with( + dataset.id, + dataset.tenant_id, + dataset.indexing_technique, + dataset.index_struct, + dataset.collection_binding_id, + dataset.doc_form, + dataset.pipeline_id, + ) + + def test_delete_empty_dataset_success(self, db_session_with_containers): + """Delete an empty dataset without scheduling cleanup when both gating fields are absent.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique=None, + chunk_structure=None, + index_struct=None, + collection_binding_id=None, + pipeline_id=None, + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + """Delete a dataset without cleanup when 
indexing_technique is missing but doc_form resolves.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique=None, + chunk_structure="text_model", + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers): + """Delete a dataset without cleanup when indexing exists but doc_form resolves to None.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique="high_quality", + chunk_structure=None, + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_not_found(self, db_session_with_containers): + """Return False without scheduling cleanup when the target dataset does not exist.""" + # Arrange + owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + missing_dataset_id = str(uuid4()) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(missing_dataset_id, owner) + + # Assert + assert result is False + clean_dataset_delay.assert_not_called() diff --git a/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py b/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py deleted file mode 100644 index cc718c9997..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py +++ /dev/null @@ -1,216 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from models.account import Account, TenantAccountRole -from models.dataset import Dataset -from services.dataset_service import DatasetService - - -class DatasetDeleteTestDataFactory: - """Factory class for creating test data and mock objects for dataset delete tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "test-tenant-123", - created_by: str = "creator-456", - doc_form: str | None = None, - indexing_technique: str | None = "high_quality", - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - 
dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.doc_form = doc_form - dataset.indexing_technique = indexing_technique - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "test-tenant-123", - role: TenantAccountRole = TenantAccountRole.ADMIN, - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - -class TestDatasetServiceDeleteDataset: - """ - Comprehensive unit tests for DatasetService.delete_dataset method. - - This test suite covers all deletion scenarios including: - - Normal dataset deletion with documents - - Empty dataset deletion (no documents, doc_form is None) - - Dataset deletion with missing indexing_technique - - Permission checks - - Event handling - - This test suite provides regression protection for issue #27073. - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted, - ): - yield { - "get_dataset": mock_get_dataset, - "check_permission": mock_check_perm, - "db_session": mock_db, - "dataset_was_deleted": mock_dataset_was_deleted, - } - - def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies): - """ - Test successful deletion of a dataset with documents. - - This test verifies: - - Dataset is retrieved correctly - - Permission check is performed - - dataset_was_deleted event is sent - - Dataset is deleted from database - - Method returns True - """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" - ) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies): - """ - Test successful deletion of an empty dataset (no documents, doc_form is None). - - This test verifies that: - - Empty datasets can be deleted without errors - - dataset_was_deleted event is sent (event handler will skip cleanup if doc_form is None) - - Dataset is deleted from database - - Method returns True - - This is the primary test for issue #27073 where deleting an empty dataset - caused internal server error due to assertion failure in event handlers. 
- """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies): - """ - Test deletion of dataset with partial None values. - - This test verifies that datasets with partial None values (e.g., doc_form exists - but indexing_technique is None) can be deleted successfully. The event handler - will skip cleanup if any required field is None. - - Improvement based on Gemini Code Assist suggestion: Added comprehensive assertions - to verify all core deletion operations are performed, not just event sending. - """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow (Gemini suggestion implemented) - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, mock_dataset_service_dependencies): - """ - Test deletion of dataset where doc_form is None but indexing_technique exists. - - This edge case can occur in certain dataset configurations and should be handled - gracefully by the event handler's conditional check. 
- """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique="high_quality") - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_not_found(self, mock_dataset_service_dependencies): - """ - Test deletion attempt when dataset doesn't exist. - - This test verifies that: - - Method returns False when dataset is not found - - No deletion operations are performed - - No events are sent - """ - # Arrange - dataset_id = "non-existent-dataset" - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = None - - # Act - result = DatasetService.delete_dataset(dataset_id, user) - - # Assert - assert result is False - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id) - mock_dataset_service_dependencies["check_permission"].assert_not_called() - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called() - mock_dataset_service_dependencies["db_session"].delete.assert_not_called() - mock_dataset_service_dependencies["db_session"].commit.assert_not_called() From ed0b27e4d6a37f2ed818573c834617187377c524 Mon Sep 17 00:00:00 2001 From: Lovish Arora <46993225+lavish0000@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:26:28 +0100 Subject: [PATCH 049/256] chore(api): update Python type-checker versions (#33056) --- .../aliyun_trace/data_exporter/traceclient.py | 4 +- api/core/ops/tencent_trace/utils.py | 7 +- api/pyproject.toml | 6 +- api/uv.lock | 97 +++++++++++++------ 4 files changed, 74 insertions(+), 40 deletions(-) diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 7624586367..0e00e90520 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -7,7 +7,7 @@ import uuid from collections import deque from collections.abc import Sequence from datetime import datetime -from typing import Final, cast +from typing import Final from urllib.parse import urljoin import httpx @@ -201,7 +201,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int: raise ValueError("UUID cannot be None") try: uuid_obj = uuid.UUID(uuid_v4) - return cast(int, uuid_obj.int) + return uuid_obj.int except ValueError as e: raise ValueError(f"Invalid UUID input: {uuid_v4}") from e diff --git a/api/core/ops/tencent_trace/utils.py b/api/core/ops/tencent_trace/utils.py index 96087951ab..678287ae1d 100644 --- a/api/core/ops/tencent_trace/utils.py +++ b/api/core/ops/tencent_trace/utils.py @@ -6,7 +6,6 @@ import hashlib import random import uuid from datetime import datetime -from typing import cast from opentelemetry.trace import Link, SpanContext, TraceFlags @@ -23,7 +22,7 @@ class TencentTraceUtils: uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 
else uuid.uuid4() except Exception as e: raise ValueError(f"Invalid UUID input: {e}") - return cast(int, uuid_obj.int) + return uuid_obj.int @staticmethod def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int: @@ -52,9 +51,9 @@ class TencentTraceUtils: @staticmethod def create_link(trace_id_str: str) -> Link: try: - trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int) + trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else uuid.UUID(trace_id_str).int except (ValueError, TypeError): - trace_id = cast(int, uuid.uuid4().int) + trace_id = uuid.uuid4().int span_context = SpanContext( trace_id=trace_id, diff --git a/api/pyproject.toml b/api/pyproject.toml index 84b95fb226..f39ed910e7 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -113,7 +113,7 @@ dev = [ "dotenv-linter~=0.5.0", "faker~=38.2.0", "lxml-stubs~=0.5.1", - "basedpyright~=1.31.0", + "basedpyright~=1.38.2", "ruff~=0.14.0", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", @@ -167,12 +167,12 @@ dev = [ "import-linter>=2.3", "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", - "mypy~=1.17.1", + "mypy~=1.19.1", # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.54.0", + "pyrefly>=0.55.0", ] ############################################################ diff --git a/api/uv.lock b/api/uv.lock index 5a9ac096dc..7436167d07 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -505,14 +505,14 @@ wheels = [ [[package]] name = "basedpyright" -version = "1.31.7" +version = "1.38.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodejs-wheel-binaries" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c6/ba/ed69e8df732a09c8ca469f592c8e08707fe29149735b834c276d94d4a3da/basedpyright-1.31.7.tar.gz", hash = "sha256:394f334c742a19bcc5905b2455c9f5858182866b7679a6f057a70b44b049bceb", size = 22710948, upload-time = "2025-10-11T05:12:48.3Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/a3/20aa7c4e83f2f614e0036300f3c352775dede0655c66814da16c37b661a9/basedpyright-1.38.2.tar.gz", hash = "sha256:b433b2b8ba745ed7520cdc79a29a03682f3fb00346d272ece5944e9e5e5daa92", size = 25277019, upload-time = "2026-02-26T11:18:43.594Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/90/ce01ad2d0afdc1b82b8b5aaba27e60d2e138e39d887e71c35c55d8f1bfcd/basedpyright-1.31.7-py3-none-any.whl", hash = "sha256:7c54beb7828c9ed0028630aaa6904f395c27e5a9f5a313aa9e91fc1d11170831", size = 11817571, upload-time = "2025-10-11T05:12:45.432Z" }, + { url = "https://files.pythonhosted.org/packages/ac/12/736cab83626fea3fe65cdafb3ef3d2ee9480c56723f2fd33921537289a5e/basedpyright-1.38.2-py3-none-any.whl", hash = "sha256:153481d37fd19f9e3adedc8629d1d071b10c5f5e49321fb026b74444b7c70e24", size = 12312475, upload-time = "2026-02-26T11:18:40.373Z" }, ] [[package]] @@ -1660,7 +1660,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "basedpyright", specifier = "~=1.31.0" }, + { name = "basedpyright", specifier = "~=1.38.2" }, { name = "boto3-stubs", specifier = ">=1.38.20" }, { name = "celery-types", specifier = ">=0.23.0" }, { name = "coverage", specifier = "~=7.2.4" }, @@ -1669,9 +1669,9 @@ dev = [ { name = "hypothesis", specifier = ">=6.131.15" }, { name = "import-linter", specifier = ">=2.3" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, - { name = "mypy", specifier = "~=1.17.1" }, + { name = "mypy", 
specifier = "~=1.19.1" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, - { name = "pyrefly", specifier = ">=0.54.0" }, + { name = "pyrefly", specifier = ">=0.55.0" }, { name = "pytest", specifier = "~=8.3.2" }, { name = "pytest-benchmark", specifier = "~=4.0.0" }, { name = "pytest-cov", specifier = "~=4.1.0" }, @@ -3267,6 +3267,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/f0/63b06b99b730b9954f8709f6f7d9b8d076fa0a973e472efe278089bde42b/langsmith-0.1.147-py3-none-any.whl", hash = "sha256:7166fc23b965ccf839d64945a78e9f1157757add228b086141eb03a60d699a15", size = 311812, upload-time = "2024-11-27T17:32:39.569Z" }, ] +[[package]] +name = "librt" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/9c/b4b0c54d84da4a94b37bd44151e46d5e583c9534c7e02250b961b1b6d8a8/librt-0.8.1.tar.gz", hash = "sha256:be46a14693955b3bd96014ccbdb8339ee8c9346fbe11c1b78901b55125f14c73", size = 177471, upload-time = "2026-02-17T16:13:06.101Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/01/0e748af5e4fee180cf7cd12bd12b0513ad23b045dccb2a83191bde82d168/librt-0.8.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:681dc2451d6d846794a828c16c22dc452d924e9f700a485b7ecb887a30aad1fd", size = 65315, upload-time = "2026-02-17T16:11:25.152Z" }, + { url = "https://files.pythonhosted.org/packages/9d/4d/7184806efda571887c798d573ca4134c80ac8642dcdd32f12c31b939c595/librt-0.8.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3b4350b13cc0e6f5bec8fa7caf29a8fb8cdc051a3bae45cfbfd7ce64f009965", size = 68021, upload-time = "2026-02-17T16:11:26.129Z" }, + { url = "https://files.pythonhosted.org/packages/ae/88/c3c52d2a5d5101f28d3dc89298444626e7874aa904eed498464c2af17627/librt-0.8.1-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ac1e7817fd0ed3d14fd7c5df91daed84c48e4c2a11ee99c0547f9f62fdae13da", size = 194500, upload-time = "2026-02-17T16:11:27.177Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5d/6fb0a25b6a8906e85b2c3b87bee1d6ed31510be7605b06772f9374ca5cb3/librt-0.8.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:747328be0c5b7075cde86a0e09d7a9196029800ba75a1689332348e998fb85c0", size = 205622, upload-time = "2026-02-17T16:11:28.242Z" }, + { url = "https://files.pythonhosted.org/packages/b2/a6/8006ae81227105476a45691f5831499e4d936b1c049b0c1feb17c11b02d1/librt-0.8.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0af2bd2bc204fa27f3d6711d0f360e6b8c684a035206257a81673ab924aa11e", size = 218304, upload-time = "2026-02-17T16:11:29.344Z" }, + { url = "https://files.pythonhosted.org/packages/ee/19/60e07886ad16670aae57ef44dada41912c90906a6fe9f2b9abac21374748/librt-0.8.1-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d480de377f5b687b6b1bc0c0407426da556e2a757633cc7e4d2e1a057aa688f3", size = 211493, upload-time = "2026-02-17T16:11:30.445Z" }, + { url = "https://files.pythonhosted.org/packages/9c/cf/f666c89d0e861d05600438213feeb818c7514d3315bae3648b1fc145d2b6/librt-0.8.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d0ee06b5b5291f609ddb37b9750985b27bc567791bc87c76a569b3feed8481ac", size = 219129, upload-time = "2026-02-17T16:11:32.021Z" }, + { url = "https://files.pythonhosted.org/packages/8f/ef/f1bea01e40b4a879364c031476c82a0dc69ce068daad67ab96302fed2d45/librt-0.8.1-cp311-cp311-musllinux_1_2_i686.whl", hash = 
"sha256:9e2c6f77b9ad48ce5603b83b7da9ee3e36b3ab425353f695cba13200c5d96596", size = 213113, upload-time = "2026-02-17T16:11:33.192Z" }, + { url = "https://files.pythonhosted.org/packages/9b/80/cdab544370cc6bc1b72ea369525f547a59e6938ef6863a11ab3cd24759af/librt-0.8.1-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:439352ba9373f11cb8e1933da194dcc6206daf779ff8df0ed69c5e39113e6a99", size = 212269, upload-time = "2026-02-17T16:11:34.373Z" }, + { url = "https://files.pythonhosted.org/packages/9d/9c/48d6ed8dac595654f15eceab2035131c136d1ae9a1e3548e777bb6dbb95d/librt-0.8.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:82210adabbc331dbb65d7868b105185464ef13f56f7f76688565ad79f648b0fe", size = 234673, upload-time = "2026-02-17T16:11:36.063Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/35b68b1db517f27a01be4467593292eb5315def8900afad29fabf56304ba/librt-0.8.1-cp311-cp311-win32.whl", hash = "sha256:52c224e14614b750c0a6d97368e16804a98c684657c7518752c356834fff83bb", size = 54597, upload-time = "2026-02-17T16:11:37.544Z" }, + { url = "https://files.pythonhosted.org/packages/71/02/796fe8f02822235966693f257bf2c79f40e11337337a657a8cfebba5febc/librt-0.8.1-cp311-cp311-win_amd64.whl", hash = "sha256:c00e5c884f528c9932d278d5c9cbbea38a6b81eb62c02e06ae53751a83a4d52b", size = 61733, upload-time = "2026-02-17T16:11:38.691Z" }, + { url = "https://files.pythonhosted.org/packages/28/ad/232e13d61f879a42a4e7117d65e4984bb28371a34bb6fb9ca54ec2c8f54e/librt-0.8.1-cp311-cp311-win_arm64.whl", hash = "sha256:f7cdf7f26c2286ffb02e46d7bac56c94655540b26347673bea15fa52a6af17e9", size = 52273, upload-time = "2026-02-17T16:11:40.308Z" }, + { url = "https://files.pythonhosted.org/packages/95/21/d39b0a87ac52fc98f621fb6f8060efb017a767ebbbac2f99fbcbc9ddc0d7/librt-0.8.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a28f2612ab566b17f3698b0da021ff9960610301607c9a5e8eaca62f5e1c350a", size = 66516, upload-time = "2026-02-17T16:11:41.604Z" }, + { url = "https://files.pythonhosted.org/packages/69/f1/46375e71441c43e8ae335905e069f1c54febee63a146278bcee8782c84fd/librt-0.8.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:60a78b694c9aee2a0f1aaeaa7d101cf713e92e8423a941d2897f4fa37908dab9", size = 68634, upload-time = "2026-02-17T16:11:43.268Z" }, + { url = "https://files.pythonhosted.org/packages/0a/33/c510de7f93bf1fa19e13423a606d8189a02624a800710f6e6a0a0f0784b3/librt-0.8.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:758509ea3f1eba2a57558e7e98f4659d0ea7670bff49673b0dde18a3c7e6c0eb", size = 198941, upload-time = "2026-02-17T16:11:44.28Z" }, + { url = "https://files.pythonhosted.org/packages/dd/36/e725903416409a533d92398e88ce665476f275081d0d7d42f9c4951999e5/librt-0.8.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:039b9f2c506bd0ab0f8725aa5ba339c6f0cd19d3b514b50d134789809c24285d", size = 209991, upload-time = "2026-02-17T16:11:45.462Z" }, + { url = "https://files.pythonhosted.org/packages/30/7a/8d908a152e1875c9f8eac96c97a480df425e657cdb47854b9efaa4998889/librt-0.8.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5bb54f1205a3a6ab41a6fd71dfcdcbd278670d3a90ca502a30d9da583105b6f7", size = 224476, upload-time = "2026-02-17T16:11:46.542Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b8/a22c34f2c485b8903a06f3fe3315341fe6876ef3599792344669db98fcff/librt-0.8.1-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = 
"sha256:05bd41cdee35b0c59c259f870f6da532a2c5ca57db95b5f23689fcb5c9e42440", size = 217518, upload-time = "2026-02-17T16:11:47.746Z" }, + { url = "https://files.pythonhosted.org/packages/79/6f/5c6fea00357e4f82ba44f81dbfb027921f1ab10e320d4a64e1c408d035d9/librt-0.8.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:adfab487facf03f0d0857b8710cf82d0704a309d8ffc33b03d9302b4c64e91a9", size = 225116, upload-time = "2026-02-17T16:11:49.298Z" }, + { url = "https://files.pythonhosted.org/packages/f2/a0/95ced4e7b1267fe1e2720a111685bcddf0e781f7e9e0ce59d751c44dcfe5/librt-0.8.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:153188fe98a72f206042be10a2c6026139852805215ed9539186312d50a8e972", size = 217751, upload-time = "2026-02-17T16:11:50.49Z" }, + { url = "https://files.pythonhosted.org/packages/93/c2/0517281cb4d4101c27ab59472924e67f55e375bc46bedae94ac6dc6e1902/librt-0.8.1-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:dd3c41254ee98604b08bd5b3af5bf0a89740d4ee0711de95b65166bf44091921", size = 218378, upload-time = "2026-02-17T16:11:51.783Z" }, + { url = "https://files.pythonhosted.org/packages/43/e8/37b3ac108e8976888e559a7b227d0ceac03c384cfd3e7a1c2ee248dbae79/librt-0.8.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e0d138c7ae532908cbb342162b2611dbd4d90c941cd25ab82084aaf71d2c0bd0", size = 241199, upload-time = "2026-02-17T16:11:53.561Z" }, + { url = "https://files.pythonhosted.org/packages/4b/5b/35812d041c53967fedf551a39399271bbe4257e681236a2cf1a69c8e7fa1/librt-0.8.1-cp312-cp312-win32.whl", hash = "sha256:43353b943613c5d9c49a25aaffdba46f888ec354e71e3529a00cca3f04d66a7a", size = 54917, upload-time = "2026-02-17T16:11:54.758Z" }, + { url = "https://files.pythonhosted.org/packages/de/d1/fa5d5331b862b9775aaf2a100f5ef86854e5d4407f71bddf102f4421e034/librt-0.8.1-cp312-cp312-win_amd64.whl", hash = "sha256:ff8baf1f8d3f4b6b7257fcb75a501f2a5499d0dda57645baa09d4d0d34b19444", size = 62017, upload-time = "2026-02-17T16:11:55.748Z" }, + { url = "https://files.pythonhosted.org/packages/c7/7c/c614252f9acda59b01a66e2ddfd243ed1c7e1deab0293332dfbccf862808/librt-0.8.1-cp312-cp312-win_arm64.whl", hash = "sha256:0f2ae3725904f7377e11cc37722d5d401e8b3d5851fb9273d7f4fe04f6b3d37d", size = 52441, upload-time = "2026-02-17T16:11:56.801Z" }, +] + [[package]] name = "litellm" version = "1.77.1" @@ -3653,28 +3687,29 @@ wheels = [ [[package]] name = "mypy" -version = "1.17.1" +version = "1.19.1" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "librt", marker = "platform_python_implementation != 'PyPy'" }, { name = "mypy-extensions" }, { name = "pathspec" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8e/22/ea637422dedf0bf36f3ef238eab4e455e2a0dcc3082b5cc067615347ab8e/mypy-1.17.1.tar.gz", hash = "sha256:25e01ec741ab5bb3eec8ba9cdb0f769230368a22c959c4937360efb89b7e9f01", size = 3352570, upload-time = "2025-07-31T07:54:19.204Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/db/4efed9504bc01309ab9c2da7e352cc223569f05478012b5d9ece38fd44d2/mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba", size = 3582404, upload-time = "2025-12-15T05:03:48.42Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/cf/eadc80c4e0a70db1c08921dcc220357ba8ab2faecb4392e3cebeb10edbfa/mypy-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad37544be07c5d7fba814eb370e006df58fed8ad1ef33ed1649cb1889ba6ff58", size = 10921009, upload-time = "2025-07-31T07:53:23.037Z" }, - { url = 
"https://files.pythonhosted.org/packages/5d/c1/c869d8c067829ad30d9bdae051046561552516cfb3a14f7f0347b7d973ee/mypy-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:064e2ff508e5464b4bd807a7c1625bc5047c5022b85c70f030680e18f37273a5", size = 10047482, upload-time = "2025-07-31T07:53:26.151Z" }, - { url = "https://files.pythonhosted.org/packages/98/b9/803672bab3fe03cee2e14786ca056efda4bb511ea02dadcedde6176d06d0/mypy-1.17.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:70401bbabd2fa1aa7c43bb358f54037baf0586f41e83b0ae67dd0534fc64edfd", size = 11832883, upload-time = "2025-07-31T07:53:47.948Z" }, - { url = "https://files.pythonhosted.org/packages/88/fb/fcdac695beca66800918c18697b48833a9a6701de288452b6715a98cfee1/mypy-1.17.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e92bdc656b7757c438660f775f872a669b8ff374edc4d18277d86b63edba6b8b", size = 12566215, upload-time = "2025-07-31T07:54:04.031Z" }, - { url = "https://files.pythonhosted.org/packages/7f/37/a932da3d3dace99ee8eb2043b6ab03b6768c36eb29a02f98f46c18c0da0e/mypy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c1fdf4abb29ed1cb091cf432979e162c208a5ac676ce35010373ff29247bcad5", size = 12751956, upload-time = "2025-07-31T07:53:36.263Z" }, - { url = "https://files.pythonhosted.org/packages/8c/cf/6438a429e0f2f5cab8bc83e53dbebfa666476f40ee322e13cac5e64b79e7/mypy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:ff2933428516ab63f961644bc49bc4cbe42bbffb2cd3b71cc7277c07d16b1a8b", size = 9507307, upload-time = "2025-07-31T07:53:59.734Z" }, - { url = "https://files.pythonhosted.org/packages/17/a2/7034d0d61af8098ec47902108553122baa0f438df8a713be860f7407c9e6/mypy-1.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:69e83ea6553a3ba79c08c6e15dbd9bfa912ec1e493bf75489ef93beb65209aeb", size = 11086295, upload-time = "2025-07-31T07:53:28.124Z" }, - { url = "https://files.pythonhosted.org/packages/14/1f/19e7e44b594d4b12f6ba8064dbe136505cec813549ca3e5191e40b1d3cc2/mypy-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b16708a66d38abb1e6b5702f5c2c87e133289da36f6a1d15f6a5221085c6403", size = 10112355, upload-time = "2025-07-31T07:53:21.121Z" }, - { url = "https://files.pythonhosted.org/packages/5b/69/baa33927e29e6b4c55d798a9d44db5d394072eef2bdc18c3e2048c9ed1e9/mypy-1.17.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89e972c0035e9e05823907ad5398c5a73b9f47a002b22359b177d40bdaee7056", size = 11875285, upload-time = "2025-07-31T07:53:55.293Z" }, - { url = "https://files.pythonhosted.org/packages/90/13/f3a89c76b0a41e19490b01e7069713a30949d9a6c147289ee1521bcea245/mypy-1.17.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03b6d0ed2b188e35ee6d5c36b5580cffd6da23319991c49ab5556c023ccf1341", size = 12737895, upload-time = "2025-07-31T07:53:43.623Z" }, - { url = "https://files.pythonhosted.org/packages/23/a1/c4ee79ac484241301564072e6476c5a5be2590bc2e7bfd28220033d2ef8f/mypy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c837b896b37cd103570d776bda106eabb8737aa6dd4f248451aecf53030cdbeb", size = 12931025, upload-time = "2025-07-31T07:54:17.125Z" }, - { url = "https://files.pythonhosted.org/packages/89/b8/7409477be7919a0608900e6320b155c72caab4fef46427c5cc75f85edadd/mypy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:665afab0963a4b39dff7c1fa563cc8b11ecff7910206db4b2e64dd1ba25aed19", size = 9584664, upload-time = 
"2025-07-31T07:54:12.842Z" }, - { url = "https://files.pythonhosted.org/packages/1d/f3/8fcd2af0f5b806f6cf463efaffd3c9548a28f84220493ecd38d127b6b66d/mypy-1.17.1-py3-none-any.whl", hash = "sha256:a9f52c0351c21fe24c21d8c0eb1f62967b262d6729393397b6f443c3b773c3b9", size = 2283411, upload-time = "2025-07-31T07:53:24.664Z" }, + { url = "https://files.pythonhosted.org/packages/ef/47/6b3ebabd5474d9cdc170d1342fbf9dddc1b0ec13ec90bf9004ee6f391c31/mypy-1.19.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8dfc6ab58ca7dda47d9237349157500468e404b17213d44fc1cb77bce532288", size = 13028539, upload-time = "2025-12-15T05:03:44.129Z" }, + { url = "https://files.pythonhosted.org/packages/5c/a6/ac7c7a88a3c9c54334f53a941b765e6ec6c4ebd65d3fe8cdcfbe0d0fd7db/mypy-1.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3f276d8493c3c97930e354b2595a44a21348b320d859fb4a2b9f66da9ed27ab", size = 12083163, upload-time = "2025-12-15T05:03:37.679Z" }, + { url = "https://files.pythonhosted.org/packages/67/af/3afa9cf880aa4a2c803798ac24f1d11ef72a0c8079689fac5cfd815e2830/mypy-1.19.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2abb24cf3f17864770d18d673c85235ba52456b36a06b6afc1e07c1fdcd3d0e6", size = 12687629, upload-time = "2025-12-15T05:02:31.526Z" }, + { url = "https://files.pythonhosted.org/packages/2d/46/20f8a7114a56484ab268b0ab372461cb3a8f7deed31ea96b83a4e4cfcfca/mypy-1.19.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a009ffa5a621762d0c926a078c2d639104becab69e79538a494bcccb62cc0331", size = 13436933, upload-time = "2025-12-15T05:03:15.606Z" }, + { url = "https://files.pythonhosted.org/packages/5b/f8/33b291ea85050a21f15da910002460f1f445f8007adb29230f0adea279cb/mypy-1.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7cee03c9a2e2ee26ec07479f38ea9c884e301d42c6d43a19d20fb014e3ba925", size = 13661754, upload-time = "2025-12-15T05:02:26.731Z" }, + { url = "https://files.pythonhosted.org/packages/fd/a3/47cbd4e85bec4335a9cd80cf67dbc02be21b5d4c9c23ad6b95d6c5196bac/mypy-1.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b84a7a18f41e167f7995200a1d07a4a6810e89d29859df936f1c3923d263042", size = 10055772, upload-time = "2025-12-15T05:03:26.179Z" }, + { url = "https://files.pythonhosted.org/packages/06/8a/19bfae96f6615aa8a0604915512e0289b1fad33d5909bf7244f02935d33a/mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1", size = 13206053, upload-time = "2025-12-15T05:03:46.622Z" }, + { url = "https://files.pythonhosted.org/packages/a5/34/3e63879ab041602154ba2a9f99817bb0c85c4df19a23a1443c8986e4d565/mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e", size = 12219134, upload-time = "2025-12-15T05:03:24.367Z" }, + { url = "https://files.pythonhosted.org/packages/89/cc/2db6f0e95366b630364e09845672dbee0cbf0bbe753a204b29a944967cd9/mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2", size = 12731616, upload-time = "2025-12-15T05:02:44.725Z" }, + { url = "https://files.pythonhosted.org/packages/00/be/dd56c1fd4807bc1eba1cf18b2a850d0de7bacb55e158755eb79f77c41f8e/mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8", size = 
13620847, upload-time = "2025-12-15T05:03:39.633Z" }, + { url = "https://files.pythonhosted.org/packages/6d/42/332951aae42b79329f743bf1da088cd75d8d4d9acc18fbcbd84f26c1af4e/mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a", size = 13834976, upload-time = "2025-12-15T05:03:08.786Z" }, + { url = "https://files.pythonhosted.org/packages/6f/63/e7493e5f90e1e085c562bb06e2eb32cae27c5057b9653348d38b47daaecc/mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13", size = 10118104, upload-time = "2025-12-15T05:03:10.834Z" }, + { url = "https://files.pythonhosted.org/packages/8d/f4/4ce9a05ce5ded1de3ec1c1d96cf9f9504a04e54ce0ed55cfa38619a32b8d/mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247", size = 2471239, upload-time = "2025-12-15T05:03:07.248Z" }, ] [[package]] @@ -5140,18 +5175,18 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.54.0" +version = "0.55.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/81/44/c10b16a302fda90d0af1328f880b232761b510eab546616a7be2fdf35a57/pyrefly-0.54.0.tar.gz", hash = "sha256:c6663be64d492f0d2f2a411ada9f28a6792163d34133639378b7f3dd9a8dca94", size = 5098893, upload-time = "2026-02-23T15:44:35.111Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/c4/76e0797215e62d007f81f86c9c4fb5d6202685a3f5e70810f3fd94294f92/pyrefly-0.55.0.tar.gz", hash = "sha256:434c3282532dd4525c4840f2040ed0eb79b0ec8224fe18d957956b15471f2441", size = 5135682, upload-time = "2026-03-03T00:46:38.122Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/99/8fdcdb4e55f0227fdd9f6abce36b619bab1ecb0662b83b66adc8cba3c788/pyrefly-0.54.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:58a3f092b6dc25ef79b2dc6c69a40f36784ca157c312bfc0baea463926a9db6d", size = 12223973, upload-time = "2026-02-23T15:44:14.278Z" }, - { url = "https://files.pythonhosted.org/packages/90/35/c2aaf87a76003ad27b286594d2e5178f811eaa15bfe3d98dba2b47d56dd1/pyrefly-0.54.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:615081414106dd95873bc39c3a4bed68754c6cc24a8177ac51d22f88f88d3eb3", size = 11785585, upload-time = "2026-02-23T15:44:17.468Z" }, - { url = "https://files.pythonhosted.org/packages/c4/4a/ced02691ed67e5a897714979196f08ad279ec7ec7f63c45e00a75a7f3c0e/pyrefly-0.54.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbcaf20f5fe585079079a95205c1f3cd4542d17228cdf1df560288880623b70", size = 33381977, upload-time = "2026-02-23T15:44:19.736Z" }, - { url = "https://files.pythonhosted.org/packages/0b/ce/72a117ed437c8f6950862181014b41e36f3c3997580e29b772b71e78d587/pyrefly-0.54.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66d5da116c0d34acfbd66663addd3ca8aa78a636f6692a66e078126d3620a883", size = 35962821, upload-time = "2026-02-23T15:44:22.357Z" }, - { url = "https://files.pythonhosted.org/packages/85/de/89013f5ae0a35d2b6b01274a92a35ee91431ea001050edf0a16748d39875/pyrefly-0.54.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef3ac27f1a4baaf67aead64287d3163350844794aca6315ad1a9650b16ec26a", size = 38496689, upload-time = "2026-02-23T15:44:25.236Z" }, - { url = "https://files.pythonhosted.org/packages/9f/9a/33b097c7bf498b924742dca32dd5d9c6a3fa6c2b52b63a58eb9e1980ca89/pyrefly-0.54.0-py3-none-win32.whl", hash = 
"sha256:7d607d72200a8afbd2db10bfefb40160a7a5d709d207161c21649cedd5cfc09a", size = 11295268, upload-time = "2026-02-23T15:44:27.551Z" }, - { url = "https://files.pythonhosted.org/packages/d4/21/9263fd1144d2a3d7342b474f183f7785b3358a1565c864089b780110b933/pyrefly-0.54.0-py3-none-win_amd64.whl", hash = "sha256:fd416f04f89309385696f685bd5c9141011f18c8072f84d31ca20c748546e791", size = 12081810, upload-time = "2026-02-23T15:44:29.461Z" }, - { url = "https://files.pythonhosted.org/packages/ea/5b/fad062a196c064cbc8564de5b2f4d3cb6315f852e3b31e8a1ce74c69a1ea/pyrefly-0.54.0-py3-none-win_arm64.whl", hash = "sha256:f06ab371356c7b1925e0bffe193b738797e71e5dbbff7fb5a13f90ee7521211d", size = 11564930, upload-time = "2026-02-23T15:44:33.053Z" }, + { url = "https://files.pythonhosted.org/packages/39/b0/16e50cf716784513648e23e726a24f71f9544aa4f86103032dcaa5ff71a2/pyrefly-0.55.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:49aafcefe5e2dd4256147db93e5b0ada42bff7d9a60db70e03d1f7055338eec9", size = 12210073, upload-time = "2026-03-03T00:46:15.51Z" }, + { url = "https://files.pythonhosted.org/packages/3a/ad/89500c01bac3083383011600370289fbc67700c5be46e781787392628a3a/pyrefly-0.55.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2827426e6b28397c13badb93c0ede0fb0f48046a7a89e3d774cda04e8e2067cd", size = 11767474, upload-time = "2026-03-03T00:46:18.003Z" }, + { url = "https://files.pythonhosted.org/packages/78/68/4c66b260f817f304ead11176ff13985625f7c269e653304b4bdb546551af/pyrefly-0.55.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7346b2d64dc575bd61aa3bca854fbf8b5a19a471cbdb45e0ca1e09861b63488c", size = 33260395, upload-time = "2026-03-03T00:46:20.509Z" }, + { url = "https://files.pythonhosted.org/packages/47/09/10bd48c9f860064f29f412954126a827d60f6451512224912c265e26bbe6/pyrefly-0.55.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:233b861b4cff008b1aff62f4f941577ed752e4d0060834229eb9b6826e6973c9", size = 35848269, upload-time = "2026-03-03T00:46:23.418Z" }, + { url = "https://files.pythonhosted.org/packages/a9/39/bc65cdd5243eb2dfea25dd1321f9a5a93e8d9c3a308501c4c6c05d011585/pyrefly-0.55.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5aa85657d76da1d25d081a49f0e33c8fc3ec91c1a0f185a8ed393a5a3d9e178", size = 38449820, upload-time = "2026-03-03T00:46:26.309Z" }, + { url = "https://files.pythonhosted.org/packages/e5/64/58b38963b011af91209e87f868cc85cfc762ec49a4568ce610c45e7a5f40/pyrefly-0.55.0-py3-none-win32.whl", hash = "sha256:23f786a78536a56fed331b245b7d10ec8945bebee7b723491c8d66fdbc155fe6", size = 11259415, upload-time = "2026-03-03T00:46:30.875Z" }, + { url = "https://files.pythonhosted.org/packages/7a/0b/a4aa519ff632a1ea69eec942566951670b870b99b5c08407e1387b85b6a4/pyrefly-0.55.0-py3-none-win_amd64.whl", hash = "sha256:d465b49e999b50eeb069ad23f0f5710651cad2576f9452a82991bef557df91ee", size = 12043581, upload-time = "2026-03-03T00:46:33.674Z" }, + { url = "https://files.pythonhosted.org/packages/f1/51/89017636fbe1ffd166ad478990c6052df615b926182fa6d3c0842b407e89/pyrefly-0.55.0-py3-none-win_arm64.whl", hash = "sha256:732ff490e0e863b296e7c0b2471e08f8ba7952f9fa6e9de09d8347fd67dde77f", size = 11548076, upload-time = "2026-03-03T00:46:36.193Z" }, ] [[package]] From 98ba091a50e30898df78fccf7225909895c6838a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:48:59 +0900 Subject: [PATCH 050/256] chore(deps): bump dompurify from 3.3.0 to 3.3.2 in /web (#33062) 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- web/package.json | 2 +- web/pnpm-lock.yaml | 47 +++++++++++++++++----------------------------- 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/web/package.json b/web/package.json index 0633a499b4..e987b761e7 100644 --- a/web/package.json +++ b/web/package.json @@ -99,7 +99,7 @@ "cron-parser": "5.4.0", "dayjs": "1.11.19", "decimal.js": "10.6.0", - "dompurify": "3.3.0", + "dompurify": "3.3.2", "echarts": "5.6.0", "echarts-for-react": "3.0.5", "elkjs": "0.9.3", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 3620dbaae2..77109cdbb3 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -169,8 +169,8 @@ importers: specifier: 10.6.0 version: 10.6.0 dompurify: - specifier: 3.3.0 - version: 3.3.0 + specifier: 3.3.2 + version: 3.3.2 echarts: specifier: 5.6.0 version: 5.6.0 @@ -4492,8 +4492,9 @@ packages: dompurify@3.2.7: resolution: {integrity: sha512-WhL/YuveyGXJaerVlMYGWhvQswa7myDG17P7Vu65EWC05o8vfeNbvNf4d/BOvH99+ZW+LlQsc1GDKMa1vNK6dw==} - dompurify@3.3.0: - resolution: {integrity: sha512-r+f6MYR1gGN1eJv0TVQbhA7if/U7P87cdPl3HN5rikqaBSBxLiCb/b9O+2eG0cxz0ghyU+mU1QkbsOwERMYlWQ==} + dompurify@3.3.2: + resolution: {integrity: sha512-6obghkliLdmKa56xdbLOpUZ43pAR6xFy1uOrxBaIDjT+yaRuuybLjGS9eVBoSR/UPU5fq3OXClEHLJNGvbxKpQ==} + engines: {node: '>=20'} domutils@3.2.2: resolution: {integrity: sha512-6kZKyUajlDuqlHKVX1w7gyslj9MPIXzIFiz/rGu35uC1wMi+kMhQwGhl4lt9unC9Vb9INnY9Z3/ZA3+FhASLaw==} @@ -4657,7 +4658,7 @@ packages: eslint: '*' eslint-plugin-better-tailwindcss@https://pkg.pr.new/hyoban/eslint-plugin-better-tailwindcss@a520d15: - resolution: {integrity: sha512-hbxpqInIW0Q5UIwXEuQxSBjrMd5bYttXeSPU6dfK2zpECKNIzGR+KXZZEdZaPagEMDJosSyQ9RKievmBcCAxfA==, tarball: https://pkg.pr.new/hyoban/eslint-plugin-better-tailwindcss@a520d15} + resolution: {tarball: https://pkg.pr.new/hyoban/eslint-plugin-better-tailwindcss@a520d15} version: 4.3.1 engines: {node: ^20.19.0 || ^22.12.0 || >=23.0.0} peerDependencies: @@ -6548,9 +6549,6 @@ packages: resolution: {integrity: sha512-h36JMxKRqrAxVD8201FrCpyeNuUY9Y5zZwujr20fFO77tpUtGa6EZzfKw/3WaiBX95fq7+MpsuMLNdSnORAwSA==} engines: {node: '>=14.18.0'} - randombytes@2.1.0: - resolution: {integrity: sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==} - rc@1.2.8: resolution: {integrity: sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==} hasBin: true @@ -6946,9 +6944,6 @@ packages: engines: {node: '>=10'} hasBin: true - serialize-javascript@6.0.2: - resolution: {integrity: sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==} - seroval-plugins@1.5.0: resolution: {integrity: sha512-EAHqADIQondwRZIdeW2I636zgsODzoBDwb3PT/+7TLDWyw1Dy/Xv7iGUIEXXav7usHDE9HVhOU61irI3EnyyHA==} engines: {node: '>=10'} @@ -7223,8 +7218,8 @@ packages: engines: {node: '>=18'} deprecated: Old versions of tar are not supported, and contain widely publicized security vulnerabilities, which have been fixed in the current version. Please update. 
Support for old versions may be purchased (at exorbitant rates) by contacting i@izs.me - terser-webpack-plugin@5.3.16: - resolution: {integrity: sha512-h9oBFCWrq78NyWWVcSwZarJkZ01c2AyGrzs1crmHZO3QUg9D61Wu4NPjBy69n7JqylFF5y+CsUZYmYEIZ3mR+Q==} + terser-webpack-plugin@5.3.17: + resolution: {integrity: sha512-YR7PtUp6GMU91BgSJmlaX/rS2lGDbAF7D+Wtq7hRO+MiljNmodYvqslzCFiYVAgW+Qoaaia/QUIP4lGXufjdZw==} engines: {node: '>= 10.13.0'} peerDependencies: '@swc/core': '*' @@ -7567,7 +7562,7 @@ packages: resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} vinext@https://pkg.pr.new/hyoban/vinext@556a6d6: - resolution: {integrity: sha512-Sz8RkTDsY6cnGrevlQi4nXgahu8okEGsdKY5m31d/L9tXo35bNETMHfVee5gaI2UKZS9LMcffWaTOxxINUgogQ==, tarball: https://pkg.pr.new/hyoban/vinext@556a6d6} + resolution: {tarball: https://pkg.pr.new/hyoban/vinext@556a6d6} version: 0.0.5 engines: {node: '>=22'} hasBin: true @@ -12198,7 +12193,7 @@ snapshots: optionalDependencies: '@types/trusted-types': 2.0.7 - dompurify@3.3.0: + dompurify@3.3.2: optionalDependencies: '@types/trusted-types': 2.0.7 @@ -13978,7 +13973,7 @@ snapshots: d3-sankey: 0.12.3 dagre-d3-es: 7.0.11 dayjs: 1.11.19 - dompurify: 3.3.0 + dompurify: 3.3.2 katex: 0.16.25 khroma: 2.1.0 lodash-es: 4.17.23 @@ -14124,8 +14119,8 @@ snapshots: micromark-extension-mdxjs@3.0.0: dependencies: - acorn: 8.16.0 - acorn-jsx: 5.3.2(acorn@8.16.0) + acorn: 8.15.0 + acorn-jsx: 5.3.2(acorn@8.15.0) micromark-extension-mdx-expression: 3.0.1 micromark-extension-mdx-jsx: 3.0.2 micromark-extension-mdx-md: 2.0.0 @@ -14776,10 +14771,6 @@ snapshots: radash@12.1.1: {} - randombytes@2.1.0: - dependencies: - safe-buffer: 5.2.1 - rc@1.2.8: dependencies: deep-extend: 0.6.0 @@ -15284,7 +15275,8 @@ snapshots: dependencies: tslib: 2.8.1 - safe-buffer@5.2.1: {} + safe-buffer@5.2.1: + optional: true sass@1.93.2: dependencies: @@ -15335,10 +15327,6 @@ snapshots: semver@7.7.4: {} - serialize-javascript@6.0.2: - dependencies: - randombytes: 2.1.0 - seroval-plugins@1.5.0(seroval@1.5.0): dependencies: seroval: 1.5.0 @@ -15681,12 +15669,11 @@ snapshots: minizlib: 3.1.0 yallist: 5.0.0 - terser-webpack-plugin@5.3.16(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)): + terser-webpack-plugin@5.3.17(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: '@jridgewell/trace-mapping': 0.3.31 jest-worker: 27.5.1 schema-utils: 4.3.3 - serialize-javascript: 6.0.2 terser: 5.46.0 webpack: 5.104.1(esbuild@0.27.2)(uglify-js@3.19.3) optionalDependencies: @@ -16249,7 +16236,7 @@ snapshots: neo-async: 2.6.2 schema-utils: 4.3.3 tapable: 2.3.0 - terser-webpack-plugin: 5.3.16(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) + terser-webpack-plugin: 5.3.17(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.104.1(esbuild@0.27.2)(uglify-js@3.19.3)) watchpack: 2.5.1 webpack-sources: 3.3.4 transitivePeerDependencies: From f76de73be47aa29104f305737a482fb3f35d3fb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=A8=E4=B9=8B=E6=9C=AC=E6=BE=AA?= Date: Fri, 6 Mar 2026 06:21:25 +0800 Subject: [PATCH 051/256] test: migrate dataset permission service SQL tests to testcontainers (#32546) Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../test_dataset_permission_service.py | 497 +++++++++++++++ .../services/dataset_permission_service.py | 587 ------------------ 2 files 
changed, 497 insertions(+), 587 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py new file mode 100644 index 0000000000..44525e0036 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -0,0 +1,497 @@ +""" +Container-backed integration tests for dataset permission services on the real SQL path. + +This module exercises persisted DatasetPermission rows and dataset permission +checks with testcontainers-backed infrastructure instead of database-chain mocks. +""" + +from uuid import uuid4 + +import pytest + +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + Dataset, + DatasetPermission, + DatasetPermissionEnum, +) +from services.dataset_service import DatasetPermissionService, DatasetService +from services.errors.account import NoPermissionError + + +class DatasetPermissionTestDataFactory: + """Create persisted entities and request payloads for dataset permission integration tests.""" + + @staticmethod + def create_account_with_tenant( + role: TenantAccountRole = TenantAccountRole.NORMAL, + tenant: Tenant | None = None, + ) -> tuple[Account, Tenant]: + """Create a real account and tenant with specified role.""" + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + if tenant is None: + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db.session.add_all([account, tenant]) + else: + db.session.add(account) + + db.session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db.session.add(join) + db.session.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_dataset( + tenant_id: str, + created_by: str, + permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + name: str = "Test Dataset", + ) -> Dataset: + """Create a real dataset with specified attributes.""" + dataset = Dataset( + tenant_id=tenant_id, + name=name, + description="desc", + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=created_by, + permission=permission, + provider="vendor", + retrieval_model={"top_k": 2}, + ) + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def create_dataset_permission( + dataset_id: str, + account_id: str, + tenant_id: str, + has_permission: bool = True, + ) -> DatasetPermission: + """Create a real DatasetPermission instance.""" + permission = DatasetPermission( + dataset_id=dataset_id, + account_id=account_id, + tenant_id=tenant_id, + has_permission=has_permission, + ) + db.session.add(permission) + db.session.commit() + return permission + + @staticmethod + def build_user_list_payload(user_ids: list[str]) -> list[dict[str, str]]: + """Build the request payload shape used by partial-member list updates.""" + return [{"user_id": user_id} for user_id in user_ids] + + +class TestDatasetPermissionServiceGetPartialMemberList: + """Verify partial-member list reads against persisted DatasetPermission rows.""" + + def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers): + """ + Test 
retrieving partial member list with multiple members. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + user_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + user_3, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + expected_account_ids = [user_1.id, user_2.id, user_3.id] + for account_id in expected_account_ids: + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, account_id, tenant.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert set(result) == set(expected_account_ids) + assert len(result) == 3 + + def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers): + """ + Test retrieving partial member list with single member. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + expected_account_ids = [user.id] + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert set(result) == set(expected_account_ids) + assert len(result) == 1 + + def test_get_dataset_partial_member_list_empty(self, db_session_with_containers): + """ + Test retrieving partial member list when no members exist. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert result == [] + assert len(result) == 0 + + +class TestDatasetPermissionServiceUpdatePartialMemberList: + """Verify partial-member list updates against persisted DatasetPermission rows.""" + + def test_update_partial_member_list_add_new_members(self, db_session_with_containers): + """ + Test adding new partial members to a dataset. 
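+
+        The service deletes any existing DatasetPermission rows for the
+        dataset and commits one new row per user_list entry, which the
+        test reads back through get_dataset_partial_member_list.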
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + user_list = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert set(result) == {member_1.id, member_2.id} + + def test_update_partial_member_list_replace_existing(self, db_session_with_containers): + """ + Test replacing existing partial members with new ones. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + old_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + old_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + new_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + new_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + old_users = DatasetPermissionTestDataFactory.build_user_list_payload([old_member_1.id, old_member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, old_users) + + new_users = DatasetPermissionTestDataFactory.build_user_list_payload([new_member_1.id, new_member_2.id]) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, new_users) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert set(result) == {new_member_1.id, new_member_2.id} + + def test_update_partial_member_list_empty_list(self, db_session_with_containers): + """ + Test updating with empty member list (clearing all members). + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, []) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers): + """ + Test error handling and rollback on database error. 
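+
+        db.session.commit is monkeypatched to raise, and the test checks
+        that rollback runs exactly once and that the previously committed
+        permission row is left intact.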
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + existing_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + replacement_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + DatasetPermissionService.update_partial_member_list( + tenant.id, + dataset.id, + DatasetPermissionTestDataFactory.build_user_list_payload([existing_member.id]), + ) + user_list = DatasetPermissionTestDataFactory.build_user_list_payload([replacement_member.id]) + rollback_called = {"count": 0} + original_rollback = db.session.rollback + + # Act / Assert + with pytest.MonkeyPatch.context() as mp: + + def _raise_commit(): + raise Exception("Database connection error") + + def _rollback_and_mark(): + rollback_called["count"] += 1 + original_rollback() + + mp.setattr("services.dataset_service.db.session.commit", _raise_commit) + mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark) + with pytest.raises(Exception, match="Database connection error"): + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert rollback_called["count"] == 1 + assert result == [existing_member.id] + assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 1 + + +class TestDatasetPermissionServiceClearPartialMemberList: + """Verify partial-member clearing against persisted DatasetPermission rows.""" + + def test_clear_partial_member_list_success(self, db_session_with_containers): + """ + Test successful clearing of partial member list. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + + # Act + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_clear_partial_member_list_empty_list(self, db_session_with_containers): + """ + Test clearing partial member list when no members exist. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + # Act + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers): + """ + Test error handling and rollback on database error. 
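+
+        db.session.commit is monkeypatched to raise, and the test checks
+        that rollback runs exactly once and that both previously committed
+        permission rows survive the failed clear.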
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + rollback_called = {"count": 0} + original_rollback = db.session.rollback + + # Act / Assert + with pytest.MonkeyPatch.context() as mp: + + def _raise_commit(): + raise Exception("Database connection error") + + def _rollback_and_mark(): + rollback_called["count"] += 1 + original_rollback() + + mp.setattr("services.dataset_service.db.session.commit", _raise_commit) + mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark) + with pytest.raises(Exception, match="Database connection error"): + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert rollback_called["count"] == 1 + assert set(result) == {member_1.id, member_2.id} + assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 2 + + +class TestDatasetServiceCheckDatasetPermission: + """Verify dataset access checks against persisted partial-member permissions.""" + + def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): + """ + Test that user with explicit permission can access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act (should not raise) + DatasetService.check_dataset_permission(dataset, user) + + # Assert + permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert user.id in permissions + + def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers): + """ + Test error when user without permission tries to access partial_members dataset. 
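+
+        The dataset is created with DatasetPermissionEnum.PARTIAL_TEAM and
+        no DatasetPermission row is persisted for the user, so the access
+        check must raise NoPermissionError.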
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + # Act & Assert + with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): + DatasetService.check_dataset_permission(dataset, user) + + +class TestDatasetServiceCheckDatasetOperatorPermission: + """Verify operator permission checks against persisted partial-member permissions.""" + + def test_check_dataset_operator_permission_partial_members_with_permission_success( + self, db_session_with_containers + ): + """ + Test that user with explicit permission can access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act (should not raise) + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + # Assert + permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert user.id in permissions + + def test_check_dataset_operator_permission_partial_members_without_permission_error( + self, db_session_with_containers + ): + """ + Test error when user without permission tries to access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + # Act & Assert + with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) diff --git a/api/tests/unit_tests/services/dataset_permission_service.py b/api/tests/unit_tests/services/dataset_permission_service.py index b687f472a5..e098e90455 100644 --- a/api/tests/unit_tests/services/dataset_permission_service.py +++ b/api/tests/unit_tests/services/dataset_permission_service.py @@ -258,323 +258,6 @@ class DatasetPermissionTestDataFactory: return [{"user_id": user_id} for user_id in user_ids] -# ============================================================================ -# Tests for get_dataset_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceGetPartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.get_dataset_partial_member_list method. - - This test class covers the retrieval of partial member lists for datasets, - which returns a list of account IDs that have explicit permissions for - a given dataset. - - The get_dataset_partial_member_list method: - 1. Queries DatasetPermission table for the dataset ID - 2. 
Selects account_id values - 3. Returns list of account IDs - - Test scenarios include: - - Retrieving list with multiple members - - Retrieving list with single member - - Retrieving empty list (no partial members) - - Database query validation - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - query construction and execution. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_get_dataset_partial_member_list_with_members(self, mock_db_session): - """ - Test retrieving partial member list with multiple members. - - Verifies that when a dataset has multiple partial members, all - account IDs are returned correctly. - - This test ensures: - - Query is constructed correctly - - All account IDs are returned - - Database query is executed - """ - # Arrange - dataset_id = "dataset-123" - expected_account_ids = ["user-456", "user-789", "user-012"] - - # Mock the scalars query to return account IDs - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = expected_account_ids - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == expected_account_ids - assert len(result) == 3 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - def test_get_dataset_partial_member_list_with_single_member(self, mock_db_session): - """ - Test retrieving partial member list with single member. - - Verifies that when a dataset has only one partial member, the - single account ID is returned correctly. - - This test ensures: - - Query works correctly for single member - - Result is a list with one element - - Database query is executed - """ - # Arrange - dataset_id = "dataset-123" - expected_account_ids = ["user-456"] - - # Mock the scalars query to return single account ID - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = expected_account_ids - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == expected_account_ids - assert len(result) == 1 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - def test_get_dataset_partial_member_list_empty(self, mock_db_session): - """ - Test retrieving partial member list when no members exist. - - Verifies that when a dataset has no partial members, an empty - list is returned. 
- - This test ensures: - - Empty list is returned correctly - - Query is executed even when no results - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the scalars query to return empty list - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = [] - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == [] - assert len(result) == 0 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - -# ============================================================================ -# Tests for update_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceUpdatePartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.update_partial_member_list method. - - This test class covers the update of partial member lists for datasets, - which replaces the existing partial member list with a new one. - - The update_partial_member_list method: - 1. Deletes all existing DatasetPermission records for the dataset - 2. Creates new DatasetPermission records for each user in the list - 3. Adds all new permissions to the session - 4. Commits the transaction - 5. Rolls back on error - - Test scenarios include: - - Adding new partial members - - Updating existing partial members - - Replacing entire member list - - Handling empty member list - - Database transaction handling - - Error handling and rollback - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - database operations including queries, adds, commits, and rollbacks. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_update_partial_member_list_add_new_members(self, mock_db_session): - """ - Test adding new partial members to a dataset. - - Verifies that when updating with new members, the old members - are deleted and new members are added correctly. - - This test ensures: - - Old permissions are deleted - - New permissions are created - - All permissions are added to session - - Transaction is committed - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-456", "user-789"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - mock_query.where.assert_called() - - # Verify new permissions were added - mock_db_session.add_all.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - # Verify no rollback occurred - mock_db_session.rollback.assert_not_called() - - def test_update_partial_member_list_replace_existing(self, mock_db_session): - """ - Test replacing existing partial members with new ones. - - Verifies that when updating with a different member list, the - old members are removed and new members are added. 
- - This test ensures: - - Old permissions are deleted - - New permissions replace old ones - - Transaction is committed successfully - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-999", "user-888"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - - # Verify new permissions were added - mock_db_session.add_all.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_update_partial_member_list_empty_list(self, mock_db_session): - """ - Test updating with empty member list (clearing all members). - - Verifies that when updating with an empty list, all existing - permissions are deleted and no new permissions are added. - - This test ensures: - - Old permissions are deleted - - No new permissions are added - - Transaction is committed - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = [] - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - - # Verify add_all was called with empty list - mock_db_session.add_all.assert_called_once_with([]) - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_update_partial_member_list_database_error_rollback(self, mock_db_session): - """ - Test error handling and rollback on database error. - - Verifies that when a database error occurs during the update, - the transaction is rolled back and the error is re-raised. 
- - This test ensures: - - Error is caught and handled - - Transaction is rolled back - - Error is re-raised - - No commit occurs after error - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-456"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Mock commit to raise an error - database_error = Exception("Database connection error") - mock_db_session.commit.side_effect = database_error - - # Act & Assert - with pytest.raises(Exception, match="Database connection error"): - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Verify rollback was called - mock_db_session.rollback.assert_called_once() - - # ============================================================================ # Tests for check_permission # ============================================================================ @@ -776,144 +459,6 @@ class TestDatasetPermissionServiceCheckPermission: mock_get_partial_member_list.assert_called_once_with(dataset.id) -# ============================================================================ -# Tests for clear_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceClearPartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.clear_partial_member_list method. - - This test class covers the clearing of partial member lists, which removes - all DatasetPermission records for a given dataset. - - The clear_partial_member_list method: - 1. Deletes all DatasetPermission records for the dataset - 2. Commits the transaction - 3. Rolls back on error - - Test scenarios include: - - Clearing list with existing members - - Clearing empty list (no members) - - Database transaction handling - - Error handling and rollback - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - database operations including queries, deletes, commits, and rollbacks. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_clear_partial_member_list_success(self, mock_db_session): - """ - Test successful clearing of partial member list. - - Verifies that when clearing a partial member list, all permissions - are deleted and the transaction is committed. - - This test ensures: - - All permissions are deleted - - Transaction is committed - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Assert - # Verify query was executed - mock_db_session.query.assert_called() - - # Verify delete was called - mock_query.where.assert_called() - mock_query.delete.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - # Verify no rollback occurred - mock_db_session.rollback.assert_not_called() - - def test_clear_partial_member_list_empty_list(self, mock_db_session): - """ - Test clearing partial member list when no members exist. 
- - Verifies that when clearing an already empty list, the operation - completes successfully without errors. - - This test ensures: - - Operation works correctly for empty lists - - Transaction is committed - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Assert - # Verify query was executed - mock_db_session.query.assert_called() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_clear_partial_member_list_database_error_rollback(self, mock_db_session): - """ - Test error handling and rollback on database error. - - Verifies that when a database error occurs during clearing, - the transaction is rolled back and the error is re-raised. - - This test ensures: - - Error is caught and handled - - Transaction is rolled back - - Error is re-raised - - No commit occurs after error - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Mock commit to raise an error - database_error = Exception("Database connection error") - mock_db_session.commit.side_effect = database_error - - # Act & Assert - with pytest.raises(Exception, match="Database connection error"): - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Verify rollback was called - mock_db_session.rollback.assert_called_once() - - # ============================================================================ # Tests for DatasetService.check_dataset_permission # ============================================================================ @@ -1047,72 +592,6 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_members_with_permission_success(self, mock_db_session): - """ - Test that user with explicit permission can access partial_members dataset. - - Verifies that when a user has an explicit DatasetPermission record - for a partial_members dataset, they can access it successfully. 
- - This test ensures: - - Explicit permissions are checked correctly - - Users with permissions can access - - Database query is executed - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return permission record - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=user.id - ) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.first.return_value = mock_permission - mock_db_session.query.return_value = mock_query - - # Act (should not raise) - DatasetService.check_dataset_permission(dataset, user) - - # Assert - # Verify permission query was executed - mock_db_session.query.assert_called() - - def test_check_dataset_permission_partial_members_without_permission_error(self, mock_db_session): - """ - Test error when user without permission tries to access partial_members dataset. - - Verifies that when a user does not have an explicit DatasetPermission - record for a partial_members dataset, a NoPermissionError is raised. - - This test ensures: - - Missing permissions are detected - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return None (no permission) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.first.return_value = None # No permission found - mock_db_session.query.return_value = mock_query - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_members_creator_success(self, mock_db_session): """ Test that creator can access partial_members dataset without explicit permission. @@ -1311,72 +790,6 @@ class TestDatasetServiceCheckDatasetOperatorPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - def test_check_dataset_operator_permission_partial_members_with_permission_success(self, mock_db_session): - """ - Test that user with explicit permission can access partial_members dataset. - - Verifies that when a user has an explicit DatasetPermission record - for a partial_members dataset, they can access it successfully. 
- - This test ensures: - - Explicit permissions are checked correctly - - Users with permissions can access - - Database query is executed - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return permission records - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=user.id - ) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.all.return_value = [mock_permission] # User has permission - mock_db_session.query.return_value = mock_query - - # Act (should not raise) - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - # Assert - # Verify permission query was executed - mock_db_session.query.assert_called() - - def test_check_dataset_operator_permission_partial_members_without_permission_error(self, mock_db_session): - """ - Test error when user without permission tries to access partial_members dataset. - - Verifies that when a user does not have an explicit DatasetPermission - record for a partial_members dataset, a NoPermissionError is raised. - - This test ensures: - - Missing permissions are detected - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return empty list (no permission) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.all.return_value = [] # No permissions found - mock_db_session.query.return_value = mock_query - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - # ============================================================================ # Additional Documentation and Notes From 6bd1be9e1606e83917592809a3d7684468ce0ed8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 07:41:55 +0900 Subject: [PATCH 052/256] chore(deps): bump markdown from 3.5.2 to 3.8.1 in /api (#33064) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/pyproject.toml | 2 +- api/uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index f39ed910e7..3c89601dc5 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.1.77", - "markdown~=3.5.1", + "markdown~=3.8.1", "mlflow-skinny>=3.0.0", "numpy~=1.26.4", "openpyxl~=3.1.5", diff --git a/api/uv.lock b/api/uv.lock index 7436167d07..9828067e8b 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1606,7 +1606,7 @@ requires-dist = [ { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, { name = "litellm", specifier = "==1.77.1" }, - { name = "markdown", specifier = "~=3.5.1" 
}, + { name = "markdown", specifier = "~=3.8.1" }, { name = "mlflow-skinny", specifier = ">=3.0.0" }, { name = "numpy", specifier = "~=1.26.4" }, { name = "openpyxl", specifier = "~=3.1.5" }, @@ -3437,11 +3437,11 @@ wheels = [ [[package]] name = "markdown" -version = "3.5.2" +version = "3.8.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/11/28/c5441a6642681d92de56063fa7984df56f783d3f1eba518dc3e7a253b606/Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8", size = 349398, upload-time = "2024-01-10T15:19:38.261Z" } +sdist = { url = "https://files.pythonhosted.org/packages/db/7c/0738e5ff0adccd0b4e02c66d0446c03a3c557e02bb49b7c263d7ab56c57d/markdown-3.8.1.tar.gz", hash = "sha256:a2e2f01cead4828ee74ecca9623045f62216aef2212a7685d6eb9163f590b8c1", size = 361280, upload-time = "2025-06-18T14:50:49.618Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/f4/f0031854de10a0bc7821ef9fca0b92ca0d7aa6fbfbf504c5473ba825e49c/Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd", size = 103870, upload-time = "2024-01-10T15:19:36.071Z" }, + { url = "https://files.pythonhosted.org/packages/50/34/3d1ff0cb4843a33817d06800e9383a2b2a2df4d508e37f53a40e829905d9/markdown-3.8.1-py3-none-any.whl", hash = "sha256:46cc0c0f1e5211ab2e9d453582f0b28a1bfaf058a9f7d5c50386b99b588d8811", size = 106642, upload-time = "2025-06-18T14:50:48.52Z" }, ] [[package]] From 741d48560da5d22a868ab5a249650c4aaa5e55b8 Mon Sep 17 00:00:00 2001 From: statxc <181730535+statxc@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:42:54 +0200 Subject: [PATCH 053/256] refactor(api): add TypedDict definitions to models/model.py (#32925) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/explore/parameter.py | 6 +- api/controllers/service_api/app/app.py | 6 +- api/controllers/web/app.py | 5 +- .../sensitive_word_avoidance/manager.py | 7 +- .../easy_ui_based_app/agent/manager.py | 26 +- .../easy_ui_based_app/dataset/manager.py | 15 +- .../easy_ui_based_app/model_config/manager.py | 5 +- .../prompt_template/manager.py | 15 +- .../easy_ui_based_app/variables/manager.py | 14 +- api/core/app/app_config/entities.py | 2 +- .../app/apps/agent_chat/app_config_manager.py | 14 +- api/core/app/apps/chat/app_config_manager.py | 12 +- api/core/app/apps/chat/app_runner.py | 6 +- .../app/apps/completion/app_config_manager.py | 16 +- api/core/app/apps/completion/app_generator.py | 2 +- api/core/app/apps/completion/app_runner.py | 6 +- .../easy_ui_based_generate_task_pipeline.py | 6 +- api/core/plugin/backwards_invocation/app.py | 6 +- api/models/model.py | 408 +++++++++++++++--- api/services/app_dsl_service.py | 5 +- api/services/app_model_config_service.py | 4 +- api/services/app_service.py | 6 +- api/services/audio_service.py | 3 +- 23 files changed, 453 insertions(+), 142 deletions(-) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 660a4d5aea..0f29627746 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from controllers.common import fields from controllers.console import console_ns from controllers.console.app.error import AppUnavailableError @@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource): if workflow is None: raise 
AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 562f5e33cc..abcaa0e240 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from flask_restx import Resource from controllers.common.fields import Parameters @@ -33,14 +35,14 @@ class AppParameterApi(Resource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 62ea532eac..25bbedce54 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,5 @@ import logging +from typing import Any, cast from flask import request from flask_restx import Resource @@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index e925d6dd52..7d1b11c008 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,10 +1,13 @@ +from collections.abc import Mapping +from typing import Any + from core.app.app_config.entities import SensitiveWordAvoidanceEntity from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod - def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None: + def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None: sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None @@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager: if sensitive_word_avoidance_dict.get("enabled"): return SensitiveWordAvoidanceEntity( type=sensitive_word_avoidance_dict.get("type"), - config=sensitive_word_avoidance_dict.get("config"), + config=sensitive_word_avoidance_dict.get("config", {}), ) else: return None diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 
9b981dfc09..10db380d1f 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,10 +1,13 @@ +from typing import Any, cast + from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity from core.agent.prompt.template import REACT_PROMPT_TEMPLATES +from models.model import AppModelConfigDict class AgentConfigManager: @classmethod - def convert(cls, config: dict) -> AgentEntity | None: + def convert(cls, config: AppModelConfigDict) -> AgentEntity | None: """ Convert model config to model config @@ -28,17 +31,17 @@ class AgentConfigManager: agent_tools = [] for tool in agent_dict.get("tools", []): - keys = tool.keys() - if len(keys) >= 4: - if "enabled" not in tool or not tool["enabled"]: + tool_dict = cast(dict[str, Any], tool) + if len(tool_dict) >= 4: + if "enabled" not in tool_dict or not tool_dict["enabled"]: continue agent_tool_properties = { - "provider_type": tool["provider_type"], - "provider_id": tool["provider_id"], - "tool_name": tool["tool_name"], - "tool_parameters": tool.get("tool_parameters", {}), - "credential_id": tool.get("credential_id", None), + "provider_type": tool_dict["provider_type"], + "provider_id": tool_dict["provider_id"], + "tool_name": tool_dict["tool_name"], + "tool_parameters": tool_dict.get("tool_parameters", {}), + "credential_id": tool_dict.get("credential_id", None), } agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties)) @@ -47,7 +50,8 @@ class AgentConfigManager: "react_router", "router", }: - agent_prompt = agent_dict.get("prompt", None) or {} + agent_prompt_raw = agent_dict.get("prompt", None) + agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {} # check model mode model_mode = config.get("model", {}).get("mode", "completion") if model_mode == "completion": @@ -75,7 +79,7 @@ class AgentConfigManager: strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get("max_iteration", 10), + max_iteration=cast(int, agent_dict.get("max_iteration", 10)), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index aacafb2dad..70f43b2c83 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,5 +1,5 @@ import uuid -from typing import Literal, cast +from typing import Any, Literal, cast from core.app.app_config.entities import ( DatasetEntity, @@ -8,13 +8,13 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.entities.agent_entities import PlanningStrategy -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict from services.dataset_service import DatasetService class DatasetConfigManager: @classmethod - def convert(cls, config: dict) -> DatasetEntity | None: + def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None: """ Convert model config to model config @@ -25,11 +25,15 @@ class DatasetConfigManager: datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []}) for dataset in datasets.get("datasets", []): + if not isinstance(dataset, dict): + continue keys = list(dataset.keys()) if len(keys) == 0 or keys[0] != "dataset": continue dataset = dataset["dataset"] + if not isinstance(dataset, dict): + continue if "enabled" not in dataset or not dataset["enabled"]: continue @@ 
-47,15 +51,14 @@ class DatasetConfigManager: agent_dict = config.get("agent_mode", {}) for tool in agent_dict.get("tools", []): - keys = tool.keys() - if len(keys) == 1: + if len(tool) == 1: # old standard key = list(tool.keys())[0] if key != "dataset": continue - tool_item = tool[key] + tool_item = cast(dict[str, Any], tool)[key] if "enabled" not in tool_item or not tool_item["enabled"]: continue diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index e4e750c735..0929f52e33 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -5,12 +5,13 @@ from core.app.app_config.entities import ModelConfigEntity from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID class ModelConfigManager: @classmethod - def convert(cls, config: dict) -> ModelConfigEntity: + def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity: """ Convert model config to model config @@ -22,7 +23,7 @@ class ModelConfigManager: if not model_config: raise ValueError("model is required") - completion_params = model_config.get("completion_params") + completion_params = model_config.get("completion_params") or {} stop = [] if "stop" in completion_params: stop = completion_params["stop"] diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 01b9601965..b7073898d6 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,3 +1,5 @@ +from typing import Any + from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -6,12 +8,12 @@ from core.app.app_config.entities import ( ) from core.prompt.simple_prompt_transform import ModelMode from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict class PromptTemplateConfigManager: @classmethod - def convert(cls, config: dict) -> PromptTemplateEntity: + def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity: if not config.get("prompt_type"): raise ValueError("prompt_type is required") @@ -40,14 +42,15 @@ class PromptTemplateConfigManager: advanced_completion_prompt_template = None completion_prompt_config = config.get("completion_prompt_config", {}) if completion_prompt_config: - completion_prompt_template_params = { + completion_prompt_template_params: dict[str, Any] = { "prompt": completion_prompt_config["prompt"]["text"], } - if "conversation_histories_role" in completion_prompt_config: + conv_role = completion_prompt_config.get("conversation_histories_role") + if conv_role: completion_prompt_template_params["role_prefix"] = { - "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], - "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], + "user": conv_role["user_prefix"], + "assistant": conv_role["assistant_prefix"], } advanced_completion_prompt_template = 
AdvancedCompletionPromptTemplateEntity( diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 157e5d8bc0..8de1224a89 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,8 +1,10 @@ import re +from typing import cast from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( [ @@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( class BasicVariablesConfigManager: @classmethod - def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: + def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: """ Convert model config to model config @@ -51,7 +53,9 @@ class BasicVariablesConfigManager: external_data_variables.append( ExternalDataVariableEntity( - variable=variable["variable"], type=variable["type"], config=variable["config"] + variable=variable["variable"], + type=variable.get("type", ""), + config=variable.get("config", {}), ) ) elif variable_type in { @@ -64,10 +68,10 @@ class BasicVariablesConfigManager: variable = variables[variable_type] variable_entities.append( VariableEntity( - type=variable_type, - variable=variable.get("variable"), + type=cast(VariableEntityType, variable_type), + variable=variable["variable"], description=variable.get("description") or "", - label=variable.get("label"), + label=variable["label"], required=variable.get("required", False), max_length=variable.get("max_length"), options=variable.get("options") or [], diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index f26351d93e..ac21577d57 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig): app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str - app_model_config_dict: dict + app_model_config_dict: dict[str, Any] model: ModelConfigEntity prompt_template: PromptTemplateEntity dataset: DatasetEntity | None = None diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 801619ddbc..f0d81e0c59 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.entities.agent_entities import PlanningStrategy -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model: App, app_model_config: AppModelConfig, conversation: Conversation | None = None, - override_config_dict: dict | None = None, + override_config_dict: AppModelConfigDict | None = None, ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config @@ -61,7 +61,9 @@ class 
AgentChatAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict or {} + if not override_config_dict: + raise Exception("override_config_dict is required when config_from is ARGS") + config_dict = override_config_dict app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( @@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: Mapping[str, Any]): + def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict: """ Validate for agent chat app model config @@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) @classmethod def validate_agent_mode_and_set_defaults( diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 4b6720a3c3..5f087f6066 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor SuggestedQuestionsAfterAnswerConfigManager, ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation class ChatAppConfig(EasyUIBasedAppConfig): @@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager): app_model: App, app_model_config: AppModelConfig, conversation: Conversation | None = None, - override_config_dict: dict | None = None, + override_config_dict: AppModelConfigDict | None = None, ) -> ChatAppConfig: """ Convert app model config to chat app config @@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict): + def config_validate(cls, tenant_id: str, config: dict) -> 
AppModelConfigDict: """ Validate for chat app model config @@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 23546a47bb..f63b38fc86 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner): memory=memory, message_id=message.id, inputs=inputs, - vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( - "enabled", False + vision_enabled=bool( + application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}) + .get("image", {}) + .get("enabled", False) ), ) context_files = retrieved_files or [] diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index eb1902f12e..f49e7b8b5e 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict class CompletionAppConfig(EasyUIBasedAppConfig): @@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod def get_app_config( - cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None ) -> CompletionAppConfig: """ Convert app model config to completion app config @@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict or {} + if not override_config_dict: + raise Exception("override_config_dict is required when config_from is ARGS") + config_dict = override_config_dict app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( @@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): return app_config @classmethod - 
def config_validate(cls, tenant_id: str, config: dict): + def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict: """ Validate for completion app model config @@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index e8b0e4f179..002b914ef1 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise ValueError("Message app_model_config is None") override_model_config_dict = app_model_config.to_dict() model_dict = override_model_config_dict["model"] - completion_params = model_dict.get("completion_params") + completion_params = model_dict.get("completion_params", {}) completion_params["temperature"] = 0.9 model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index ac05172945..56a4519879 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner): hit_callback=hit_callback, message_id=message.id, inputs=inputs, - vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( - "enabled", False + vision_enabled=bool( + application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}) + .get("image", {}) + .get("enabled", False) ), ) context_files = retrieved_files or [] diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 1fa782eb6c..57ef0c078f 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Generator from threading import Thread -from typing import Union, cast +from typing import Any, Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -219,14 +219,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech")) if ( text_to_speech_dict and text_to_speech_dict.get("autoPlay") == "enabled" and text_to_speech_dict.get("enabled") ): publisher = AppGeneratorTTSPublisher( - tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None) + tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None) ) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 3c5df2b905..60d08b26c9 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 
+1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Union +from typing import Any, Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if workflow is None: raise ValueError("unexpected app type") - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app.app_model_config if app_model_config is None: raise ValueError("unexpected app type") - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/models/model.py b/api/models/model.py index 2bf80edb80..ed0614c195 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 import sqlalchemy as sa @@ -15,6 +15,7 @@ from flask import request from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column +from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS @@ -36,6 +37,259 @@ if TYPE_CHECKING: from .workflow import Workflow +# --- TypedDict definitions for structured dict return types --- + + +class EnabledConfig(TypedDict): + enabled: bool + + +class EmbeddingModelInfo(TypedDict): + embedding_provider_name: str + embedding_model_name: str + + +class AnnotationReplyDisabledConfig(TypedDict): + enabled: Literal[False] + + +class AnnotationReplyEnabledConfig(TypedDict): + id: str + enabled: Literal[True] + score_threshold: float + embedding_model: EmbeddingModelInfo + + +AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig + + +class SensitiveWordAvoidanceConfig(TypedDict): + enabled: bool + type: str + config: dict[str, Any] + + +class AgentToolConfig(TypedDict): + provider_type: str + provider_id: str + tool_name: str + tool_parameters: dict[str, Any] + plugin_unique_identifier: NotRequired[str | None] + credential_id: NotRequired[str | None] + + +class AgentModeConfig(TypedDict): + enabled: bool + strategy: str | None + tools: list[AgentToolConfig | dict[str, Any]] + prompt: str | None + + +class ImageUploadConfig(TypedDict): + enabled: bool + number_limits: int + detail: str + transfer_methods: list[str] + + +class FileUploadConfig(TypedDict): + image: ImageUploadConfig + + +class DeletedToolInfo(TypedDict): + type: str + tool_name: str + provider_id: str + + +class ExternalDataToolConfig(TypedDict): + enabled: bool + variable: str + type: str + config: dict[str, Any] + + +class UserInputFormItemConfig(TypedDict): + variable: str + label: str + description: NotRequired[str] + required: NotRequired[bool] + max_length: NotRequired[int] + options: NotRequired[list[str]] + default: NotRequired[str] + type: NotRequired[str] + config: NotRequired[dict[str, Any]] + + +# Each item is a single-key dict, e.g. 
{"text-input": UserInputFormItemConfig} +UserInputFormItem = dict[str, UserInputFormItemConfig] + + +class DatasetConfigs(TypedDict): + retrieval_model: str + datasets: NotRequired[dict[str, Any]] + top_k: NotRequired[int] + score_threshold: NotRequired[float] + score_threshold_enabled: NotRequired[bool] + reranking_model: NotRequired[dict[str, Any] | None] + weights: NotRequired[dict[str, Any] | None] + reranking_enabled: NotRequired[bool] + reranking_mode: NotRequired[str] + metadata_filtering_mode: NotRequired[str] + metadata_model_config: NotRequired[dict[str, Any] | None] + metadata_filtering_conditions: NotRequired[dict[str, Any] | None] + + +class ChatPromptMessage(TypedDict): + text: str + role: str + + +class ChatPromptConfig(TypedDict, total=False): + prompt: list[ChatPromptMessage] + + +class CompletionPromptText(TypedDict): + text: str + + +class ConversationHistoriesRole(TypedDict): + user_prefix: str + assistant_prefix: str + + +class CompletionPromptConfig(TypedDict): + prompt: CompletionPromptText + conversation_histories_role: NotRequired[ConversationHistoriesRole] + + +class ModelConfig(TypedDict): + provider: str + name: str + mode: str + completion_params: NotRequired[dict[str, Any]] + + +class AppModelConfigDict(TypedDict): + opening_statement: str | None + suggested_questions: list[str] + suggested_questions_after_answer: EnabledConfig + speech_to_text: EnabledConfig + text_to_speech: EnabledConfig + retriever_resource: EnabledConfig + annotation_reply: AnnotationReplyConfig + more_like_this: EnabledConfig + sensitive_word_avoidance: SensitiveWordAvoidanceConfig + external_data_tools: list[ExternalDataToolConfig] + model: ModelConfig + user_input_form: list[UserInputFormItem] + dataset_query_variable: str | None + pre_prompt: str | None + agent_mode: AgentModeConfig + prompt_type: str + chat_prompt_config: ChatPromptConfig + completion_prompt_config: CompletionPromptConfig + dataset_configs: DatasetConfigs + file_upload: FileUploadConfig + # Added dynamically in Conversation.model_config + model_id: NotRequired[str | None] + provider: NotRequired[str | None] + + +class ConversationDict(TypedDict): + id: str + app_id: str + app_model_config_id: str | None + model_provider: str | None + override_model_configs: str | None + model_id: str | None + mode: str + name: str + summary: str | None + inputs: dict[str, Any] + introduction: str | None + system_instruction: str | None + system_instruction_tokens: int + status: str + invoke_from: str | None + from_source: str + from_end_user_id: str | None + from_account_id: str | None + read_at: datetime | None + read_account_id: str | None + dialogue_count: int + created_at: datetime + updated_at: datetime + + +class MessageDict(TypedDict): + id: str + app_id: str + conversation_id: str + model_id: str | None + inputs: dict[str, Any] + query: str + total_price: Decimal | None + message: dict[str, Any] + answer: str + status: str + error: str | None + message_metadata: dict[str, Any] + from_source: str + from_end_user_id: str | None + from_account_id: str | None + created_at: str + updated_at: str + agent_based: bool + workflow_run_id: str | None + + +class MessageFeedbackDict(TypedDict): + id: str + app_id: str + conversation_id: str + message_id: str + rating: str + content: str | None + from_source: str + from_end_user_id: str | None + from_account_id: str | None + created_at: str + updated_at: str + + +class MessageFileInfo(TypedDict, total=False): + belongs_to: str | None + upload_file_id: str | None + id: str + tenant_id: str 
+ type: str + transfer_method: str + remote_url: str | None + related_id: str | None + filename: str | None + extension: str | None + mime_type: str | None + size: int + dify_model_identity: str + url: str | None + + +class ExtraContentDict(TypedDict, total=False): + type: str + workflow_run_id: str + + +class TraceAppConfigDict(TypedDict): + id: str + app_id: str + tracing_provider: str | None + tracing_config: dict[str, Any] + is_active: bool + created_at: str | None + updated_at: str | None + + class DifySetup(TypeBase): __tablename__ = "dify_setups" __table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -176,7 +430,7 @@ class App(Base): return str(self.mode) @property - def deleted_tools(self) -> list[dict[str, str]]: + def deleted_tools(self) -> list[DeletedToolInfo]: from core.tools.tool_manager import ToolManager, ToolProviderType from services.plugin.plugin_service import PluginService @@ -257,7 +511,7 @@ class App(Base): provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) } - deleted_tools: list[dict[str, str]] = [] + deleted_tools: list[DeletedToolInfo] = [] for tool in tools: keys = list(tool.keys()) @@ -364,35 +618,38 @@ class AppModelConfig(TypeBase): return app @property - def model_dict(self) -> dict[str, Any]: - return json.loads(self.model) if self.model else {} + def model_dict(self) -> ModelConfig: + return cast(ModelConfig, json.loads(self.model) if self.model else {}) @property def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] @property - def suggested_questions_after_answer_dict(self) -> dict[str, Any]: - return ( + def suggested_questions_after_answer_dict(self) -> EnabledConfig: + return cast( + EnabledConfig, json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer - else {"enabled": False} + else {"enabled": False}, ) @property - def speech_to_text_dict(self) -> dict[str, Any]: - return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} + def speech_to_text_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}) @property - def text_to_speech_dict(self) -> dict[str, Any]: - return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} + def text_to_speech_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}) @property - def retriever_resource_dict(self) -> dict[str, Any]: - return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} + def retriever_resource_dict(self) -> EnabledConfig: + return cast( + EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} + ) @property - def annotation_reply_dict(self) -> dict[str, Any]: + def annotation_reply_dict(self) -> AnnotationReplyConfig: annotation_setting = ( db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() ) @@ -415,56 +672,62 @@ class AppModelConfig(TypeBase): return {"enabled": False} @property - def more_like_this_dict(self) -> dict[str, Any]: - return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} + def more_like_this_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}) @property - 
def sensitive_word_avoidance_dict(self) -> dict[str, Any]: - return ( + def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig: + return cast( + SensitiveWordAvoidanceConfig, json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance - else {"enabled": False, "type": "", "configs": []} + else {"enabled": False, "type": "", "config": {}}, ) @property - def external_data_tools_list(self) -> list[dict[str, Any]]: + def external_data_tools_list(self) -> list[ExternalDataToolConfig]: return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def user_input_form_list(self) -> list[dict[str, Any]]: + def user_input_form_list(self) -> list[UserInputFormItem]: return json.loads(self.user_input_form) if self.user_input_form else [] @property - def agent_mode_dict(self) -> dict[str, Any]: - return ( + def agent_mode_dict(self) -> AgentModeConfig: + return cast( + AgentModeConfig, json.loads(self.agent_mode) if self.agent_mode - else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + else {"enabled": False, "strategy": None, "tools": [], "prompt": None}, ) @property - def chat_prompt_config_dict(self) -> dict[str, Any]: - return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} + def chat_prompt_config_dict(self) -> ChatPromptConfig: + return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}) @property - def completion_prompt_config_dict(self) -> dict[str, Any]: - return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} + def completion_prompt_config_dict(self) -> CompletionPromptConfig: + return cast( + CompletionPromptConfig, + json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}, + ) @property - def dataset_configs_dict(self) -> dict[str, Any]: + def dataset_configs_dict(self) -> DatasetConfigs: if self.dataset_configs: - dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) + dataset_configs = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: - return dataset_configs + return cast(DatasetConfigs, dataset_configs) return { "retrieval_model": "multiple", } @property - def file_upload_dict(self) -> dict[str, Any]: - return ( + def file_upload_dict(self) -> FileUploadConfig: + return cast( + FileUploadConfig, json.loads(self.file_upload) if self.file_upload else { @@ -474,10 +737,10 @@ class AppModelConfig(TypeBase): "detail": "high", "transfer_methods": ["remote_url", "local_file"], } - } + }, ) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> AppModelConfigDict: return { "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, @@ -501,36 +764,42 @@ class AppModelConfig(TypeBase): "file_upload": self.file_upload_dict, } - def from_model_config_dict(self, model_config: Mapping[str, Any]): + def from_model_config_dict(self, model_config: AppModelConfigDict): self.opening_statement = model_config.get("opening_statement") self.suggested_questions = ( - json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None ) self.suggested_questions_after_answer = ( - json.dumps(model_config["suggested_questions_after_answer"]) + json.dumps(model_config.get("suggested_questions_after_answer")) if 
model_config.get("suggested_questions_after_answer") else None ) - self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None - self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None - self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.speech_to_text = ( + json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None + ) + self.text_to_speech = ( + json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None + ) + self.more_like_this = ( + json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None + ) self.sensitive_word_avoidance = ( - json.dumps(model_config["sensitive_word_avoidance"]) + json.dumps(model_config.get("sensitive_word_avoidance")) if model_config.get("sensitive_word_avoidance") else None ) self.external_data_tools = ( - json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None ) - self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None self.user_input_form = ( - json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None ) self.dataset_query_variable = model_config.get("dataset_query_variable") - self.pre_prompt = model_config["pre_prompt"] - self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.pre_prompt = model_config.get("pre_prompt") + self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None self.retriever_resource = ( - json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None ) self.prompt_type = model_config.get("prompt_type", "simple") self.chat_prompt_config = ( @@ -823,24 +1092,26 @@ class Conversation(Base): self._inputs = inputs @property - def model_config(self): - model_config = {} + def model_config(self) -> AppModelConfigDict: + model_config = cast(AppModelConfigDict, {}) app_model_config: AppModelConfig | None = None if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - model_config = override_model_configs + model_config = cast(AppModelConfigDict, override_model_configs) else: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) if "model" in override_model_configs: # where is app_id? 
- app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs) + app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict( + cast(AppModelConfigDict, override_model_configs) + ) model_config = app_model_config.to_dict() else: - model_config["configs"] = override_model_configs + model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key] else: app_model_config = ( db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() @@ -1015,7 +1286,7 @@ class Conversation(Base): def in_debug_mode(self) -> bool: return self.override_model_configs is not None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> ConversationDict: return { "id": self.id, "app_id": self.app_id, @@ -1295,7 +1566,7 @@ class Message(Base): return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @property - def message_files(self) -> list[dict[str, Any]]: + def message_files(self) -> list[MessageFileInfo]: from factories import file_factory message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() @@ -1350,10 +1621,13 @@ class Message(Base): ) files.append(file) - result: list[dict[str, Any]] = [ - {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} - for (file, message_file) in zip(files, message_files) - ] + result = cast( + list[MessageFileInfo], + [ + {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} + for (file, message_file) in zip(files, message_files) + ], + ) db.session.commit() return result @@ -1363,7 +1637,7 @@ class Message(Base): self._extra_contents = list(contents) @property - def extra_contents(self) -> list[dict[str, Any]]: + def extra_contents(self) -> list[ExtraContentDict]: return getattr(self, "_extra_contents", []) @property @@ -1379,7 +1653,7 @@ class Message(Base): return None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MessageDict: return { "id": self.id, "app_id": self.app_id, @@ -1403,7 +1677,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> Message: + def from_dict(cls, data: MessageDict) -> Message: return cls( id=data["id"], app_id=data["app_id"], @@ -1463,7 +1737,7 @@ class MessageFeedback(TypeBase): account = db.session.query(Account).where(Account.id == self.from_account_id).first() return account - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MessageFeedbackDict: return { "id": str(self.id), "app_id": str(self.app_id), @@ -1726,8 +2000,8 @@ class AppMCPServer(TypeBase): return result @property - def parameters_dict(self) -> dict[str, Any]: - return cast(dict[str, Any], json.loads(self.parameters)) + def parameters_dict(self) -> dict[str, str]: + return cast(dict[str, str], json.loads(self.parameters)) class Site(Base): @@ -2167,7 +2441,7 @@ class TraceAppConfig(TypeBase): def tracing_config_str(self) -> str: return json.dumps(self.tracing_config_dict) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> TraceAppConfigDict: return { "id": self.id, "app_id": self.app_id, diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 5790c8b9ec..06f4ccb90e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,6 +4,7 @@ import logging import uuid from collections.abc import Mapping from enum import StrEnum +from typing import cast from 
urllib.parse import urlparse from uuid import uuid4 @@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig, IconType +from models.model import AppModelConfig, AppModelConfigDict, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -523,7 +524,7 @@ class AppDslService: if not app.app_model_config: app_model_config = AppModelConfig( app_id=app.id, created_by=account.id, updated_by=account.id - ).from_model_config_dict(model_config) + ).from_model_config_dict(cast(AppModelConfigDict, model_config)) app_model_config.id = str(uuid4()) app.app_model_config_id = app_model_config.id diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 6f54f90734..3bc30cb323 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,12 +1,12 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict class AppModelConfigService: @classmethod - def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode): + def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict: if app_mode == AppMode.CHAT: return ChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.AGENT_CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index ce6826ef5c..aba8954f1a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import TypedDict, cast +from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination @@ -187,7 +187,7 @@ class AppService: for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool)) # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( @@ -388,7 +388,7 @@ class AppService: agent_config = app_model_config.agent_mode_dict # get all tools - tools = agent_config.get("tools", []) + tools = cast(list[dict[str, Any]], agent_config.get("tools", [])) url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1b698fad17..1794ea9947 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -2,6 +2,7 @@ import io import logging import uuid from collections.abc import Generator +from typing import cast from flask import Response, stream_with_context from werkzeug.datastructures import FileStorage @@ -106,7 +107,7 @@ class AudioService: if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get("voice") + voice = cast(str | None, text_to_speech_dict.get("voice")) model_manager = ModelManager() model_instance = 
model_manager.get_default_model_instance( From 49dcf5e0d9d42cf46990c852422d07a53469089f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Fri, 6 Mar 2026 09:49:23 +0800 Subject: [PATCH 054/256] chore: add local pyrefly exclude workflow (#33059) --- Makefile | 5 +- api/pyproject.toml | 10 ++ api/pyrefly-local-excludes.txt | 200 +++++++++++++++++++++++++++++++++ api/pyrefly.toml | 8 -- dev/pyrefly-check-local | 34 ++++++ 5 files changed, 247 insertions(+), 10 deletions(-) create mode 100644 api/pyrefly-local-excludes.txt delete mode 100644 api/pyrefly.toml create mode 100755 dev/pyrefly-check-local diff --git a/Makefile b/Makefile index 0aff26b3e5..55871c86a7 100644 --- a/Makefile +++ b/Makefile @@ -68,8 +68,9 @@ lint: @echo "✅ Linting complete" type-check: - @echo "📝 Running type checks (basedpyright + mypy)..." + @echo "📝 Running type checks (basedpyright + pyrefly + mypy)..." @./dev/basedpyright-check $(PATH_TO_CHECK) + @./dev/pyrefly-check-local @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . @echo "✅ Type checks complete" @@ -131,7 +132,7 @@ help: @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" - @echo " make type-check - Run type checks (basedpyright, mypy)" + @echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/api/pyproject.toml b/api/pyproject.toml index 3c89601dc5..bf786f4584 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -247,3 +247,13 @@ module = [ "extensions.logstore.repositories.logstore_api_workflow_run_repository", ] ignore_errors = true + +[tool.pyrefly] +project-includes = ["."] +project-excludes = [ + ".venv", + "migrations/", +] +python-platform = "linux" +python-version = "3.11.0" +infer-with-first-use = false diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt new file mode 100644 index 0000000000..d3b2ede745 --- /dev/null +++ b/api/pyrefly-local-excludes.txt @@ -0,0 +1,200 @@ +configs/middleware/cache/redis_pubsub_config.py +controllers/console/app/annotation.py +controllers/console/app/app.py +controllers/console/app/app_import.py +controllers/console/app/mcp_server.py +controllers/console/app/site.py +controllers/console/auth/email_register.py +controllers/console/human_input_form.py +controllers/console/init_validate.py +controllers/console/ping.py +controllers/console/setup.py +controllers/console/version.py +controllers/console/workspace/trigger_providers.py +controllers/service_api/app/annotation.py +controllers/web/workflow_events.py +core/agent/fc_agent_runner.py +core/app/apps/advanced_chat/app_generator.py +core/app/apps/advanced_chat/app_runner.py +core/app/apps/advanced_chat/generate_task_pipeline.py +core/app/apps/agent_chat/app_generator.py +core/app/apps/base_app_generate_response_converter.py +core/app/apps/base_app_generator.py +core/app/apps/chat/app_generator.py +core/app/apps/common/workflow_response_converter.py +core/app/apps/completion/app_generator.py +core/app/apps/pipeline/pipeline_generator.py +core/app/apps/pipeline/pipeline_runner.py +core/app/apps/workflow/app_generator.py +core/app/apps/workflow/app_runner.py +core/app/apps/workflow/generate_task_pipeline.py +core/app/apps/workflow_app_runner.py 
+core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +core/datasource/datasource_manager.py +core/external_data_tool/api/api.py +core/llm_generator/llm_generator.py +core/llm_generator/output_parser/structured_output.py +core/mcp/mcp_client.py +core/ops/aliyun_trace/data_exporter/traceclient.py +core/ops/arize_phoenix_trace/arize_phoenix_trace.py +core/ops/mlflow_trace/mlflow_trace.py +core/ops/ops_trace_manager.py +core/ops/tencent_trace/client.py +core/ops/tencent_trace/utils.py +core/plugin/backwards_invocation/base.py +core/plugin/backwards_invocation/model.py +core/prompt/utils/extract_thread_messages.py +core/rag/datasource/keyword/jieba/jieba.py +core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +core/rag/datasource/vdb/baidu/baidu_vector.py +core/rag/datasource/vdb/chroma/chroma_vector.py +core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +core/rag/datasource/vdb/couchbase/couchbase_vector.py +core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +core/rag/datasource/vdb/lindorm/lindorm_vector.py +core/rag/datasource/vdb/matrixone/matrixone_vector.py +core/rag/datasource/vdb/milvus/milvus_vector.py +core/rag/datasource/vdb/myscale/myscale_vector.py +core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +core/rag/datasource/vdb/opensearch/opensearch_vector.py +core/rag/datasource/vdb/oracle/oraclevector.py +core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +core/rag/datasource/vdb/relyt/relyt_vector.py +core/rag/datasource/vdb/tablestore/tablestore_vector.py +core/rag/datasource/vdb/tencent/tencent_vector.py +core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +core/rag/datasource/vdb/tidb_vector/tidb_vector.py +core/rag/datasource/vdb/upstash/upstash_vector.py +core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +core/rag/datasource/vdb/weaviate/weaviate_vector.py +core/rag/extractor/csv_extractor.py +core/rag/extractor/excel_extractor.py +core/rag/extractor/firecrawl/firecrawl_app.py +core/rag/extractor/firecrawl/firecrawl_web_extractor.py +core/rag/extractor/html_extractor.py +core/rag/extractor/jina_reader_extractor.py +core/rag/extractor/markdown_extractor.py +core/rag/extractor/notion_extractor.py +core/rag/extractor/pdf_extractor.py +core/rag/extractor/text_extractor.py +core/rag/extractor/unstructured/unstructured_doc_extractor.py +core/rag/extractor/unstructured/unstructured_eml_extractor.py +core/rag/extractor/unstructured/unstructured_epub_extractor.py +core/rag/extractor/unstructured/unstructured_markdown_extractor.py +core/rag/extractor/unstructured/unstructured_msg_extractor.py +core/rag/extractor/unstructured/unstructured_ppt_extractor.py +core/rag/extractor/unstructured/unstructured_pptx_extractor.py +core/rag/extractor/unstructured/unstructured_xml_extractor.py +core/rag/extractor/watercrawl/client.py +core/rag/extractor/watercrawl/extractor.py +core/rag/extractor/watercrawl/provider.py +core/rag/extractor/word_extractor.py +core/rag/index_processor/processor/paragraph_index_processor.py +core/rag/index_processor/processor/parent_child_index_processor.py +core/rag/index_processor/processor/qa_index_processor.py +core/rag/retrieval/router/multi_dataset_function_call_router.py +core/rag/summary_index/summary_index.py +core/repositories/sqlalchemy_workflow_execution_repository.py 
+core/repositories/sqlalchemy_workflow_node_execution_repository.py +core/tools/__base/tool.py +core/tools/mcp_tool/provider.py +core/tools/plugin_tool/provider.py +core/tools/utils/message_transformer.py +core/tools/utils/web_reader_tool.py +core/tools/workflow_as_tool/provider.py +core/trigger/debug/event_selectors.py +core/trigger/entities/entities.py +core/trigger/provider.py +core/workflow/workflow_entry.py +dify_graph/entities/workflow_execution.py +dify_graph/file/file_manager.py +dify_graph/graph_engine/error_handler.py +dify_graph/graph_engine/layers/execution_limits.py +dify_graph/nodes/agent/agent_node.py +dify_graph/nodes/base/node.py +dify_graph/nodes/code/code_node.py +dify_graph/nodes/datasource/datasource_node.py +dify_graph/nodes/document_extractor/node.py +dify_graph/nodes/human_input/human_input_node.py +dify_graph/nodes/if_else/if_else_node.py +dify_graph/nodes/iteration/iteration_node.py +dify_graph/nodes/knowledge_index/knowledge_index_node.py +dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +dify_graph/nodes/list_operator/node.py +dify_graph/nodes/llm/node.py +dify_graph/nodes/loop/loop_node.py +dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +dify_graph/nodes/question_classifier/question_classifier_node.py +dify_graph/nodes/start/start_node.py +dify_graph/nodes/template_transform/template_transform_node.py +dify_graph/nodes/tool/tool_node.py +dify_graph/nodes/trigger_plugin/trigger_event_node.py +dify_graph/nodes/trigger_schedule/trigger_schedule_node.py +dify_graph/nodes/trigger_webhook/node.py +dify_graph/nodes/variable_aggregator/variable_aggregator_node.py +dify_graph/nodes/variable_assigner/v1/node.py +dify_graph/nodes/variable_assigner/v2/node.py +dify_graph/variables/types.py +extensions/ext_fastopenapi.py +extensions/logstore/repositories/logstore_api_workflow_run_repository.py +extensions/otel/instrumentation.py +extensions/otel/runtime.py +extensions/storage/aliyun_oss_storage.py +extensions/storage/aws_s3_storage.py +extensions/storage/azure_blob_storage.py +extensions/storage/baidu_obs_storage.py +extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +extensions/storage/clickzetta_volume/file_lifecycle.py +extensions/storage/google_cloud_storage.py +extensions/storage/huawei_obs_storage.py +extensions/storage/opendal_storage.py +extensions/storage/oracle_oci_storage.py +extensions/storage/supabase_storage.py +extensions/storage/tencent_cos_storage.py +extensions/storage/volcengine_tos_storage.py +factories/variable_factory.py +libs/external_api.py +libs/gmpy2_pkcs10aep_cipher.py +libs/helper.py +libs/login.py +libs/module_loading.py +libs/oauth.py +libs/oauth_data_source.py +models/trigger.py +models/workflow.py +repositories/sqlalchemy_api_workflow_node_execution_repository.py +repositories/sqlalchemy_api_workflow_run_repository.py +repositories/sqlalchemy_execution_extra_content_repository.py +schedule/queue_monitor_task.py +services/account_service.py +services/audio_service.py +services/auth/firecrawl/firecrawl.py +services/auth/jina.py +services/auth/jina/jina.py +services/auth/watercrawl/watercrawl.py +services/conversation_service.py +services/dataset_service.py +services/document_indexing_proxy/document_indexing_task_proxy.py +services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py +services/external_knowledge_service.py +services/plugin/plugin_migration.py +services/recommend_app/buildin/buildin_retrieval.py +services/recommend_app/database/database_retrieval.py 
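+# Every non-empty line in this file that does not start with "#" becomes a
+# --project-excludes=<path> flag in dev/pyrefly-check-local (shown below), so
+# the list can shrink entry by entry as files gain clean pyrefly results.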
+services/recommend_app/remote/remote_retrieval.py +services/summary_index_service.py +services/tools/tools_transform_service.py +services/trigger/trigger_provider_service.py +services/trigger/trigger_subscription_builder_service.py +services/trigger/webhook_service.py +services/workflow_draft_variable_service.py +services/workflow_event_snapshot_service.py +services/workflow_service.py +tasks/app_generate/workflow_execute_task.py +tasks/regenerate_summary_index_task.py +tasks/trigger_processing_tasks.py +tasks/workflow_cfs_scheduler/cfs_scheduler.py +tasks/workflow_execution_tasks.py diff --git a/api/pyrefly.toml b/api/pyrefly.toml deleted file mode 100644 index 01f4c5a529..0000000000 --- a/api/pyrefly.toml +++ /dev/null @@ -1,8 +0,0 @@ -project-includes = ["."] -project-excludes = [ - ".venv", - "migrations/", -] -python-platform = "linux" -python-version = "3.11.0" -infer-with-first-use = false diff --git a/dev/pyrefly-check-local b/dev/pyrefly-check-local new file mode 100755 index 0000000000..80f90927bb --- /dev/null +++ b/dev/pyrefly-check-local @@ -0,0 +1,34 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +REPO_ROOT="$SCRIPT_DIR/.." +cd "$REPO_ROOT" + +EXCLUDES_FILE="api/pyrefly-local-excludes.txt" + +pyrefly_args=( + "--summary=none" + "--project-excludes=.venv" + "--project-excludes=migrations/" + "--project-excludes=tests/" +) + +if [[ -f "$EXCLUDES_FILE" ]]; then + while IFS= read -r exclude; do + [[ -z "$exclude" || "${exclude:0:1}" == "#" ]] && continue + pyrefly_args+=("--project-excludes=$exclude") + done < "$EXCLUDES_FILE" +fi + +tmp_output="$(mktemp)" +set +e +uv run --directory api --dev pyrefly check "${pyrefly_args[@]}" >"$tmp_output" 2>&1 +pyrefly_status=$? +set -e + +uv run --directory api python libs/pyrefly_diagnostics.py < "$tmp_output" +rm -f "$tmp_output" + +exit "$pyrefly_status" From f751864ab37b21c3f84005bb4c1f2714febd8677 Mon Sep 17 00:00:00 2001 From: Lovish Arora <46993225+lavish0000@users.noreply.github.com> Date: Fri, 6 Mar 2026 02:49:53 +0100 Subject: [PATCH 055/256] fix(api): return inserted ids from Chroma and Clickzetta add_texts (#33065) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../datasource/vdb/chroma/chroma_vector.py | 3 ++- .../vdb/clickzetta/clickzetta_vector.py | 26 ++++++++++++------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index de1572410c..cbc846f716 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -65,7 +65,7 @@ class ChromaVector(BaseVector): self._client.get_or_create_collection(collection_name) redis_client.set(collection_exist_cache_key, 1, ex=3600) - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] @@ -73,6 +73,7 @@ class ChromaVector(BaseVector): collection = self._client.get_or_create_collection(self._collection_name) # FIXME: chromadb using numpy array, fix the type error later collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore + return uuids def delete_by_metadata_field(self, key: str, value: str): collection = 
self._client.get_or_create_collection(self._collection_name) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 91bb71bfa6..8e8120fc10 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -605,25 +605,36 @@ class ClickzettaVector(BaseVector): logger.warning("Failed to create inverted index: %s", e) # Continue without inverted index - full-text search will fall back to LIKE - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: """Add documents with embeddings to the collection.""" if not documents: - return + return [] batch_size = self._config.batch_size total_batches = (len(documents) + batch_size - 1) // batch_size + added_ids = [] for i in range(0, len(documents), batch_size): batch_docs = documents[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] + batch_doc_ids = [] + for doc in batch_docs: + metadata = doc.metadata if isinstance(doc.metadata, dict) else {} + batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))) + added_ids.extend(batch_doc_ids) # Execute batch insert through write queue - self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) + self._execute_write( + self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches + ) + + return added_ids def _insert_batch( self, batch_docs: list[Document], batch_embeddings: list[list[float]], + batch_doc_ids: list[str], batch_index: int, batch_size: int, total_batches: int, @@ -641,14 +652,9 @@ class ClickzettaVector(BaseVector): data_rows = [] vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 - for doc, embedding in zip(batch_docs, batch_embeddings): + for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids): # Optimized: minimal checks for common case, fallback for edge cases - metadata = doc.metadata or {} - - if not isinstance(metadata, dict): - metadata = {} - - doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) + metadata = doc.metadata if isinstance(doc.metadata, dict) else {} # Fast path for JSON serialization try: From ad81513b6aaf47f4132f1f02f89f8fb0b7005d0e Mon Sep 17 00:00:00 2001 From: kurokobo Date: Fri, 6 Mar 2026 10:56:14 +0900 Subject: [PATCH 056/256] fix: show citations in advanced chat apps (#32985) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../advanced_chat/generate_task_pipeline.py | 15 ++- .../knowledge_retrieval_node.py | 2 +- ...t_generate_task_pipeline_extra_contents.py | 104 +++++++++++++++++- .../test_knowledge_retrieval_node.py | 1 + 4 files changed, 119 insertions(+), 3 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index fbd5060b8c..a1cb375e24 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -516,8 +516,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): graph_runtime_state=validated_state, ) + yield from self._handle_advanced_chat_message_end_event( + QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state + ) 
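+            # Handling the end event inline (rather than re-publishing it to the
+            # queue as before) lets the message_end response, which carries the
+            # citation metadata, reach the stream before workflow_finished.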
yield workflow_finish_resp - self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_partial_success_event( self, @@ -538,6 +540,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): exceptions_count=event.exceptions_count, ) + yield from self._handle_advanced_chat_message_end_event( + QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state + ) yield workflow_finish_resp def _handle_workflow_paused_event( @@ -854,6 +859,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): yield from self._handle_workflow_paused_event(event) break + case QueueWorkflowSucceededEvent(): + yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager) + break + + case QueueWorkflowPartialSuccessEvent(): + yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager) + break + case QueueStopEvent(): yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager) break diff --git a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d84dda42d6..14744d0a74 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -116,7 +116,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD try: results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) - outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])} + outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py index be773557f6..83a6e0f231 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py @@ -9,8 +9,16 @@ import pytest from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent +from core.app.entities.queue_entities import ( + QueuePingEvent, + QueueTextChunkEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import StreamEvent from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import EndUser @@ -185,3 +193,97 @@ def test_resume_appends_chunks_to_paused_answer() -> None: assert message.answer == "beforeafter" assert message.status == MessageStatus.NORMAL + + +def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None: + pipeline = _build_pipeline() + pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") + pipeline._workflow_id = "workflow-1" + pipeline._ensure_workflow_initialized = mock.Mock() + 
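+    # A bare SimpleNamespace stands in for a full graph runtime state here;
+    # the test only asserts the relative order of the streamed events.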
runtime_state = SimpleNamespace() + pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state) + pipeline._handle_advanced_chat_message_end_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)]) + ) + pipeline._workflow_response_converter = mock.Mock() + pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace( + event=StreamEvent.WORKFLOW_FINISHED, + data=SimpleNamespace(status=WorkflowExecutionStatus.SUCCEEDED), + ) + + event = QueueWorkflowSucceededEvent(outputs={}) + responses = list(pipeline._handle_workflow_succeeded_event(event)) + + assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED] + + +def test_workflow_partial_success_emits_message_end_before_workflow_finished() -> None: + pipeline = _build_pipeline() + pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") + pipeline._workflow_id = "workflow-1" + pipeline._ensure_workflow_initialized = mock.Mock() + runtime_state = SimpleNamespace() + pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state) + pipeline._handle_advanced_chat_message_end_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)]) + ) + pipeline._workflow_response_converter = mock.Mock() + pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace( + event=StreamEvent.WORKFLOW_FINISHED, + data=SimpleNamespace(status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED), + ) + + event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + responses = list(pipeline._handle_workflow_partial_success_event(event)) + + assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED] + + +def test_process_stream_response_breaks_after_workflow_succeeded() -> None: + pipeline = _build_pipeline() + succeeded_event = QueueWorkflowSucceededEvent(outputs={}) + ping_event = QueuePingEvent() + queue_messages = [ + SimpleNamespace(event=succeeded_event), + SimpleNamespace(event=ping_event), + ] + + pipeline._conversation_name_generate_thread = None + pipeline._base_task_pipeline = mock.Mock() + pipeline._base_task_pipeline.queue_manager = mock.Mock() + pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages) + pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING)) + pipeline._handle_workflow_succeeded_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)]) + ) + + responses = list(pipeline._process_stream_response()) + + assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED] + pipeline._handle_workflow_succeeded_event.assert_called_once_with(succeeded_event, trace_manager=None) + pipeline._base_task_pipeline.ping_stream_response.assert_not_called() + + +def test_process_stream_response_breaks_after_workflow_partial_success() -> None: + pipeline = _build_pipeline() + partial_event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + ping_event = QueuePingEvent() + queue_messages = [ + SimpleNamespace(event=partial_event), + SimpleNamespace(event=ping_event), + ] + + pipeline._conversation_name_generate_thread = None + pipeline._base_task_pipeline = mock.Mock() + pipeline._base_task_pipeline.queue_manager = mock.Mock() + pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages) + 
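+    # ping_stream_response is stubbed purely so the final assertion can prove
+    # the loop broke at the terminal event and never drained the ping.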
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING)) + pipeline._handle_workflow_partial_success_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)]) + ) + + responses = list(pipeline._process_stream_response()) + + assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED] + pipeline._handle_workflow_partial_success_event.assert_called_once_with(partial_event, trace_manager=None) + pipeline._base_task_pipeline.ping_stream_response.assert_not_called() diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index e194d66ee3..6a538d81de 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -205,6 +205,7 @@ class TestKnowledgeRetrievalNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert "result" in result.outputs assert mock_rag_retrieval.knowledge_retrieval.called + mock_source.model_dump.assert_called_once_with(by_alias=True) def test_run_with_query_variable_multiple_mode( self, From 7ffa6c184940027053e8a8b8a0cca3d52ba90c39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Fri, 6 Mar 2026 09:57:09 +0800 Subject: [PATCH 057/256] fix: conversation var unexpected reset after HITL node (#32936) --- api/dify_graph/runtime/variable_pool.py | 10 ++++++++-- .../entities/test_graph_runtime_state.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/api/dify_graph/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py index a2b1af99bb..e3ef6a2897 100644 --- a/api/dify_graph/runtime/variable_pool.py +++ b/api/dify_graph/runtime/variable_pool.py @@ -65,9 +65,15 @@ class VariablePool(BaseModel): # Add environment variables to the variable pool for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool + # Add conversation variables to the variable pool. When restoring from a serialized + # snapshot, `variable_dictionary` already carries the latest runtime values. + # In that case, keep existing entries instead of overwriting them with the + # bootstrap list. 
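+        # Illustrative example: if a node ran
+        #     variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "x"), "updated")
+        # before the snapshot was taken, replaying this loop on restore must
+        # keep "updated" instead of re-seeding the bootstrap default.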
for var in self.conversation_variables: - self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) + if self._has(selector): + continue + self.add(selector, var) # Add rag pipeline variables to the variable pool if self.rag_pipeline_variables: rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 0df4927697..22792eb5b3 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,8 +4,10 @@ from unittest.mock import MagicMock, patch import pytest +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from dify_graph.variables.variables import StringVariable class StubCoordinator: @@ -278,3 +280,17 @@ class TestGraphRuntimeState: assert restored_execution.started is True assert new_stub.state == "configured" + + def test_snapshot_restore_preserves_updated_conversation_variable(self): + variable_pool = VariablePool( + conversation_variables=[StringVariable(name="session_name", value="before")], + ) + variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after") + + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + snapshot = state.dumps() + restored = GraphRuntimeState.from_snapshot(snapshot) + + restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name")) + assert restored_value is not None + assert restored_value.value == "after" From d1eaa41dd1768a26fc61c79afb754343d311aa9b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 09:57:43 +0800 Subject: [PATCH 058/256] fix(i18n): correct French translation of "disabled" from medical term to UI-appropriate term (#33067) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- web/i18n/fr-FR/dataset.json | 2 +- web/i18n/fr-FR/plugin.json | 2 +- web/i18n/fr-FR/workflow.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/i18n/fr-FR/dataset.json b/web/i18n/fr-FR/dataset.json index 43d5d9183c..9b20769fbe 100644 --- a/web/i18n/fr-FR/dataset.json +++ b/web/i18n/fr-FR/dataset.json @@ -124,7 +124,7 @@ "metadata.datasetMetadata.deleteContent": "Êtes-vous sûr de vouloir supprimer les métadonnées \"{{name}}\" ?", "metadata.datasetMetadata.deleteTitle": "Confirmer la suppression", "metadata.datasetMetadata.description": "Vous pouvez gérer toutes les métadonnées dans cette connaissance ici. 
Les modifications seront synchronisées avec chaque document.", - "metadata.datasetMetadata.disabled": "handicapés", + "metadata.datasetMetadata.disabled": "Désactivé", "metadata.datasetMetadata.name": "Nom", "metadata.datasetMetadata.namePlaceholder": "Nom de métadonnées", "metadata.datasetMetadata.rename": "Renommer", diff --git a/web/i18n/fr-FR/plugin.json b/web/i18n/fr-FR/plugin.json index d96d207de0..79f43acb8e 100644 --- a/web/i18n/fr-FR/plugin.json +++ b/web/i18n/fr-FR/plugin.json @@ -95,7 +95,7 @@ "detailPanel.deprecation.reason.businessAdjustments": "ajustements commerciaux", "detailPanel.deprecation.reason.noMaintainer": "aucun mainteneur", "detailPanel.deprecation.reason.ownershipTransferred": "propriété transférée", - "detailPanel.disabled": "Handicapé", + "detailPanel.disabled": "Désactivé", "detailPanel.endpointDeleteContent": "Souhaitez-vous supprimer {{name}} ?", "detailPanel.endpointDeleteTip": "Supprimer le point de terminaison", "detailPanel.endpointDisableContent": "Souhaitez-vous désactiver {{name}} ?", diff --git a/web/i18n/fr-FR/workflow.json b/web/i18n/fr-FR/workflow.json index b5f13ca3b1..631dc5d05b 100644 --- a/web/i18n/fr-FR/workflow.json +++ b/web/i18n/fr-FR/workflow.json @@ -687,7 +687,7 @@ "nodes.knowledgeRetrieval.metadata.options.automatic.subTitle": "Générer automatiquement des conditions de filtrage des métadonnées en fonction de la requête de l'utilisateur", "nodes.knowledgeRetrieval.metadata.options.automatic.title": "Automatique", "nodes.knowledgeRetrieval.metadata.options.disabled.subTitle": "Ne pas activer le filtrage des métadonnées", - "nodes.knowledgeRetrieval.metadata.options.disabled.title": "Handicapé", + "nodes.knowledgeRetrieval.metadata.options.disabled.title": "Désactivé", "nodes.knowledgeRetrieval.metadata.options.manual.subTitle": "Ajouter manuellement des conditions de filtrage des métadonnées", "nodes.knowledgeRetrieval.metadata.options.manual.title": "Manuel", "nodes.knowledgeRetrieval.metadata.panel.add": "Ajouter une condition", From dc31b075338d09c9a887f0f3295dbfd2fb5680c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Fri, 6 Mar 2026 11:45:51 +0800 Subject: [PATCH 059/256] fix(type-check): resolve missing-attribute in app dataset join update handler (#33071) --- ...date_app_dataset_join_when_app_model_config_updated.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 69959acd19..b70c2183d2 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from sqlalchemy import select from events.app_event import app_model_config_was_updated @@ -54,9 +56,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[s continue tool_type = list(tool.keys())[0] - tool_config = list(tool.values())[0] + tool_config = cast(dict[str, Any], list(tool.values())[0]) if tool_type == "dataset": - dataset_ids.add(tool_config.get("id")) + dataset_id = tool_config.get("id") + if isinstance(dataset_id, str): + dataset_ids.add(dataset_id) # get dataset from dataset_configs dataset_configs = app_model_config.dataset_configs_dict From 0490756ab23808418166e8df59d2a64c437c634e Mon Sep 17 00:00:00 2001 From: Nite Knite Date: Fri, 6 Mar 
2026 14:29:29 +0800 Subject: [PATCH 060/256] chore: add support email env (#33075) --- .../header/account-dropdown/index.spec.tsx | 2 ++ .../header/account-dropdown/support.spec.tsx | 31 +++++++++++++++++-- .../header/account-dropdown/support.tsx | 8 ++--- web/app/components/header/utils/util.ts | 4 +-- web/config/index.ts | 6 ++++ web/env.ts | 2 ++ 6 files changed, 45 insertions(+), 8 deletions(-) diff --git a/web/app/components/header/account-dropdown/index.spec.tsx b/web/app/components/header/account-dropdown/index.spec.tsx index a92f8503ee..c234d350d8 100644 --- a/web/app/components/header/account-dropdown/index.spec.tsx +++ b/web/app/components/header/account-dropdown/index.spec.tsx @@ -70,6 +70,7 @@ const { mockConfig, mockEnv } = vi.hoisted(() => ({ mockConfig: { IS_CLOUD_EDITION: false, ZENDESK_WIDGET_KEY: '', + SUPPORT_EMAIL_ADDRESS: '', }, mockEnv: { env: { @@ -80,6 +81,7 @@ const { mockConfig, mockEnv } = vi.hoisted(() => ({ vi.mock('@/config', () => ({ get IS_CLOUD_EDITION() { return mockConfig.IS_CLOUD_EDITION }, get ZENDESK_WIDGET_KEY() { return mockConfig.ZENDESK_WIDGET_KEY }, + get SUPPORT_EMAIL_ADDRESS() { return mockConfig.SUPPORT_EMAIL_ADDRESS }, IS_DEV: false, IS_CE_EDITION: false, })) diff --git a/web/app/components/header/account-dropdown/support.spec.tsx b/web/app/components/header/account-dropdown/support.spec.tsx index 90bcb9f3ec..de48b744c2 100644 --- a/web/app/components/header/account-dropdown/support.spec.tsx +++ b/web/app/components/header/account-dropdown/support.spec.tsx @@ -11,6 +11,10 @@ const { mockZendeskKey } = vi.hoisted(() => ({ mockZendeskKey: { value: 'test-key' }, })) +const { mockSupportEmailKey } = vi.hoisted(() => ({ + mockSupportEmailKey: { value: '' }, +})) + vi.mock('@/context/app-context', async (importOriginal) => { const actual = await importOriginal() return { @@ -33,6 +37,7 @@ vi.mock('@/config', async (importOriginal) => { ...actual, IS_CE_EDITION: false, get ZENDESK_WIDGET_KEY() { return mockZendeskKey.value }, + get SUPPORT_EMAIL_ADDRESS() { return mockSupportEmailKey.value }, } }) @@ -84,6 +89,7 @@ describe('Support', () => { vi.clearAllMocks() window.zE = vi.fn() mockZendeskKey.value = 'test-key' + mockSupportEmailKey.value = '' vi.mocked(useAppContext).mockReturnValue(baseAppContextValue) vi.mocked(useProviderContext).mockReturnValue({ ...baseProviderContextValue, @@ -96,7 +102,7 @@ describe('Support', () => { const renderSupport = () => { return render( - {}}> + { }}> open @@ -125,7 +131,7 @@ describe('Support', () => { }) }) - describe('Plan-based Channels', () => { + describe('Dedicated Channels', () => { it('should show "Contact Us" when ZENDESK_WIDGET_KEY is present', () => { // Act renderSupport() @@ -166,6 +172,27 @@ describe('Support', () => { expect(screen.getByText('common.userProfile.emailSupport')).toBeInTheDocument() expect(screen.queryByText('common.userProfile.contactUs')).not.toBeInTheDocument() }) + + it('should show email support if specified in the config', () => { + // Arrange + mockZendeskKey.value = '' + mockSupportEmailKey.value = 'support@example.com' + vi.mocked(useProviderContext).mockReturnValue({ + ...baseProviderContextValue, + plan: { + ...baseProviderContextValue.plan, + type: Plan.sandbox, + }, + }) + + // Act + renderSupport() + fireEvent.click(screen.getByText('common.userProfile.support')) + + // Assert + expect(screen.queryByText('common.userProfile.emailSupport')).toBeInTheDocument() + expect(screen.getByText('common.userProfile.emailSupport')?.closest('a')?.getAttribute('href')).toMatch(new 
RegExp(`^mailto:${mockSupportEmailKey.value}`)) + }) }) describe('Interactions and Links', () => { diff --git a/web/app/components/header/account-dropdown/support.tsx b/web/app/components/header/account-dropdown/support.tsx index 7ec2766977..ead4509cce 100644 --- a/web/app/components/header/account-dropdown/support.tsx +++ b/web/app/components/header/account-dropdown/support.tsx @@ -2,7 +2,7 @@ import { useTranslation } from 'react-i18next' import { DropdownMenuGroup, DropdownMenuItem, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger } from '@/app/components/base/ui/dropdown-menu' import { toggleZendeskWindow } from '@/app/components/base/zendesk/utils' import { Plan } from '@/app/components/billing/type' -import { ZENDESK_WIDGET_KEY } from '@/config' +import { SUPPORT_EMAIL_ADDRESS, ZENDESK_WIDGET_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import { mailToSupport } from '../utils/util' @@ -17,8 +17,8 @@ export default function Support({ closeAccountDropdown }: SupportProps) { const { t } = useTranslation() const { plan } = useProviderContext() const { userProfile, langGeniusVersionInfo } = useAppContext() - const hasDedicatedChannel = plan.type !== Plan.sandbox - const hasZendeskWidget = !!ZENDESK_WIDGET_KEY?.trim() + const hasDedicatedChannel = plan.type !== Plan.sandbox || Boolean(SUPPORT_EMAIL_ADDRESS.trim()) + const hasZendeskWidget = Boolean(ZENDESK_WIDGET_KEY.trim()) return ( @@ -49,7 +49,7 @@ export default function Support({ closeAccountDropdown }: SupportProps) { {hasDedicatedChannel && !hasZendeskWidget && ( } + render={} > { +export const mailToSupport = (account: string, plan: string, version: string, supportEmailAddress?: string) => { const subject = `Technical Support Request ${plan} ${account}` const body = ` Please do not remove the following information: @@ -21,5 +21,5 @@ export const mailToSupport = (account: string, plan: string, version: string) => Platform: Problem Description: ` - return generateMailToLink('support@dify.ai', subject, body) + return generateMailToLink(supportEmailAddress || 'support@dify.ai', subject, body) } diff --git a/web/config/index.ts b/web/config/index.ts index 35ea3780a8..e8526479a1 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -342,6 +342,12 @@ export const ZENDESK_FIELD_IDS = { '', ), } + +export const SUPPORT_EMAIL_ADDRESS = getStringConfig( + env.NEXT_PUBLIC_SUPPORT_EMAIL_ADDRESS, + '', +) + export const APP_VERSION = pkg.version export const IS_MARKETPLACE = env.NEXT_PUBLIC_IS_MARKETPLACE diff --git a/web/env.ts b/web/env.ts index f931c48677..8ecde76143 100644 --- a/web/env.ts +++ b/web/env.ts @@ -115,6 +115,7 @@ const clientSchema = { */ NEXT_PUBLIC_SENTRY_DSN: z.string().optional(), NEXT_PUBLIC_SITE_ABOUT: z.string().optional(), + NEXT_PUBLIC_SUPPORT_EMAIL_ADDRESS: z.email().optional(), NEXT_PUBLIC_SUPPORT_MAIL_LOGIN: coercedBoolean.default(false), /** * The timeout for the text generation in millisecond @@ -184,6 +185,7 @@ export const env = createEnv({ NEXT_PUBLIC_PUBLIC_API_PREFIX: isServer ? process.env.NEXT_PUBLIC_PUBLIC_API_PREFIX : getRuntimeEnvFromBody('publicApiPrefix'), NEXT_PUBLIC_SENTRY_DSN: isServer ? process.env.NEXT_PUBLIC_SENTRY_DSN : getRuntimeEnvFromBody('sentryDsn'), NEXT_PUBLIC_SITE_ABOUT: isServer ? process.env.NEXT_PUBLIC_SITE_ABOUT : getRuntimeEnvFromBody('siteAbout'), + NEXT_PUBLIC_SUPPORT_EMAIL_ADDRESS: isServer ? 
process.env.NEXT_PUBLIC_SUPPORT_EMAIL_ADDRESS : getRuntimeEnvFromBody('supportEmailAddress'), NEXT_PUBLIC_SUPPORT_MAIL_LOGIN: isServer ? process.env.NEXT_PUBLIC_SUPPORT_MAIL_LOGIN : getRuntimeEnvFromBody('supportMailLogin'), NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS: isServer ? process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS : getRuntimeEnvFromBody('textGenerationTimeoutMs'), NEXT_PUBLIC_TOP_K_MAX_VALUE: isServer ? process.env.NEXT_PUBLIC_TOP_K_MAX_VALUE : getRuntimeEnvFromBody('topKMaxValue'), From e74cda653534baa858334d6b35d6403faddc814c Mon Sep 17 00:00:00 2001 From: eux Date: Fri, 6 Mar 2026 14:35:28 +0800 Subject: [PATCH 061/256] feat(tasks): isolate summary generation to dedicated dataset_summary queue (#32972) --- .devcontainer/post_create_command.sh | 2 +- .vscode/launch.json.template | 2 +- api/docker/entrypoint.sh | 4 +- api/tasks/generate_summary_index_task.py | 2 +- api/tasks/regenerate_summary_index_task.py | 2 +- .../tasks/test_summary_queue_isolation.py | 40 +++++++++++++++++++ dev/start-worker | 5 ++- 7 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 api/tests/unit_tests/tasks/test_summary_queue_isolation.py diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 637593b9de..b92d4c35a8 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -7,7 +7,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index 700b815c3b..c3e2c50c52 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -37,7 +37,7 @@ "-c", "1", "-Q", - "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution", + "dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution", "--loglevel", "INFO" ], diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 1a675b3338..6b904b7d0d 100755 --- 
a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index e4273e16b5..6493833edc 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -14,7 +14,7 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) -@shared_task(queue="dataset") +@shared_task(queue="dataset_summary") def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None): """ Async generate summary index for document segments. diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index cf8988d13e..39c2f4103e 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -16,7 +16,7 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) -@shared_task(queue="dataset") +@shared_task(queue="dataset_summary") def regenerate_summary_index_task( dataset_id: str, regenerate_reason: str = "summary_model_changed", diff --git a/api/tests/unit_tests/tasks/test_summary_queue_isolation.py b/api/tests/unit_tests/tasks/test_summary_queue_isolation.py new file mode 100644 index 0000000000..f6632e0a8a --- /dev/null +++ b/api/tests/unit_tests/tasks/test_summary_queue_isolation.py @@ -0,0 +1,40 @@ +""" +Unit tests for summary index task queue isolation. + +These tasks must NOT run on the shared 'dataset' queue because they invoke LLMs +for each document segment and can occupy all worker slots for hours, blocking +document indexing tasks. 
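+
+Both tasks are declared with @shared_task(queue="dataset_summary"); the tests
+below pin that routing so a regression back to "dataset" fails fast.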
+""" + +import pytest + +from tasks.generate_summary_index_task import generate_summary_index_task +from tasks.regenerate_summary_index_task import regenerate_summary_index_task + +SUMMARY_QUEUE = "dataset_summary" +INDEXING_QUEUE = "dataset" + + +def _task_queue(task) -> str | None: + # Celery's @shared_task(queue=...) stores the routing key on the task instance + # at runtime, but type stubs don't declare it; use getattr to stay type-clean. + return getattr(task, "queue", None) + + +@pytest.mark.parametrize( + ("task", "task_name"), + [ + (generate_summary_index_task, "generate_summary_index_task"), + (regenerate_summary_index_task, "regenerate_summary_index_task"), + ], +) +def test_summary_task_uses_dedicated_queue(task, task_name): + """Summary tasks must use the dataset_summary queue, not the shared dataset queue. + + Summary generation is LLM-heavy and will block document indexing if placed + on the shared queue. + """ + assert _task_queue(task) == SUMMARY_QUEUE, ( + f"{task_name} must run on '{SUMMARY_QUEUE}' queue (not '{INDEXING_QUEUE}'). " + "Summary generation is LLM-heavy and will block document indexing if placed on the shared queue." + ) diff --git a/dev/start-worker b/dev/start-worker index 0450851b56..8baa36f1ed 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -21,6 +21,7 @@ show_help() { echo "" echo "Available queues:" echo " dataset - RAG indexing and document processing" + echo " dataset_summary - LLM-heavy summary index generation (isolated from indexing)" echo " workflow - Workflow triggers (community edition)" echo " workflow_professional - Professional tier workflows (cloud edition)" echo " workflow_team - Team tier workflows (cloud edition)" @@ -106,10 +107,10 @@ if [[ -z "${QUEUES}" ]]; then # Configure queues based on edition if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + QUEUES="dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset and workflow have separate queues - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + QUEUES="dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" fi echo "No queues specified, using edition-based defaults: ${QUEUES}" From f05f0be55ff51ddc0cd31e97127b957f03dcb652 Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Fri, 6 Mar 2026 14:54:24 +0800 Subject: [PATCH 062/256] chore: use react-grab to replace code-inspector-plugin (#33078) --- .../components/devtools/react-grab/loader.tsx | 17 ++ web/app/layout.tsx | 2 + web/eslint.config.mjs | 2 +- 
web/next.config.ts | 20 +- web/package.json | 2 - web/{eslint-rules => plugins/eslint}/index.js | 0 .../eslint}/namespaces.js | 0 .../eslint}/rules/consistent-placeholders.js | 0 .../eslint}/rules/no-as-any-in-t.js | 0 .../eslint}/rules/no-extra-keys.js | 0 .../rules/no-legacy-namespace-prefix.js | 0 .../eslint}/rules/require-ns-option.js | 0 web/{eslint-rules => plugins/eslint}/utils.js | 0 web/plugins/vite/custom-i18n-hmr.ts | 80 +++++ web/plugins/vite/react-grab-open-file.ts | 92 ++++++ web/plugins/vite/utils.ts | 20 ++ web/pnpm-lock.yaml | 276 ++---------------- web/vite.config.ts | 166 +---------- 18 files changed, 250 insertions(+), 427 deletions(-) create mode 100644 web/app/components/devtools/react-grab/loader.tsx rename web/{eslint-rules => plugins/eslint}/index.js (100%) rename web/{eslint-rules => plugins/eslint}/namespaces.js (100%) rename web/{eslint-rules => plugins/eslint}/rules/consistent-placeholders.js (100%) rename web/{eslint-rules => plugins/eslint}/rules/no-as-any-in-t.js (100%) rename web/{eslint-rules => plugins/eslint}/rules/no-extra-keys.js (100%) rename web/{eslint-rules => plugins/eslint}/rules/no-legacy-namespace-prefix.js (100%) rename web/{eslint-rules => plugins/eslint}/rules/require-ns-option.js (100%) rename web/{eslint-rules => plugins/eslint}/utils.js (100%) create mode 100644 web/plugins/vite/custom-i18n-hmr.ts create mode 100644 web/plugins/vite/react-grab-open-file.ts create mode 100644 web/plugins/vite/utils.ts diff --git a/web/app/components/devtools/react-grab/loader.tsx b/web/app/components/devtools/react-grab/loader.tsx new file mode 100644 index 0000000000..3a1ecc6be8 --- /dev/null +++ b/web/app/components/devtools/react-grab/loader.tsx @@ -0,0 +1,17 @@ +import Script from 'next/script' +import { IS_DEV } from '@/config' + +export function ReactGrabLoader() { + if (!IS_DEV) + return null + + return ( + <> + ') - }) - - it('renders empty script tag when child value is undefined', () => { - const node: ScriptNode = { - children: [{}], - } - - const { container } = render( - , - ) - - expect(container.textContent).toBe('') - }) - - it('renders empty script tag when children array is empty', () => { - const node: ScriptNode = { - children: [], - } - - const { container } = render( - , - ) - - expect(container.textContent).toBe('') - }) - - it('preserves multiline script content', () => { - const multi = `console.log("line1"); -console.log("line2");` - - const node: ScriptNode = { - children: [{ value: multi }], - } - - const { container } = render( - , - ) - - expect(container.textContent).toBe(``) - }) - - it('has displayName set correctly', () => { - expect(ScriptBlock.displayName).toBe('ScriptBlock') - }) -}) diff --git a/web/app/components/base/markdown-blocks/code-block.tsx b/web/app/components/base/markdown-blocks/code-block.tsx index 744a578ff6..837929cfff 100644 --- a/web/app/components/base/markdown-blocks/code-block.tsx +++ b/web/app/components/base/markdown-blocks/code-block.tsx @@ -399,7 +399,6 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any }} language={match?.[1]} showLineNumbers - PreTag="div" > {content} @@ -413,7 +412,7 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any return (
-
{languageShowName}
+
{languageShowName}
{language === 'svg' && } diff --git a/web/app/components/base/markdown-blocks/form.tsx b/web/app/components/base/markdown-blocks/form.tsx index bce05bc585..36597cd13c 100644 --- a/web/app/components/base/markdown-blocks/form.tsx +++ b/web/app/components/base/markdown-blocks/form.tsx @@ -1,14 +1,16 @@ +import type { Dayjs } from 'dayjs' +import type { ButtonProps } from '@/app/components/base/button' import * as React from 'react' -import { useEffect, useState } from 'react' +import { useCallback, useMemo, useState } from 'react' import Button from '@/app/components/base/button' import { useChatContext } from '@/app/components/base/chat/chat/context' import Checkbox from '@/app/components/base/checkbox' import DatePicker from '@/app/components/base/date-and-time-picker/date-picker' import TimePicker from '@/app/components/base/date-and-time-picker/time-picker' -import { formatDateForOutput } from '@/app/components/base/date-and-time-picker/utils/dayjs' +import { formatDateForOutput, toDayjs } from '@/app/components/base/date-and-time-picker/utils/dayjs' import Input from '@/app/components/base/input' -import Select from '@/app/components/base/select' import Textarea from '@/app/components/base/textarea' +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/app/components/base/ui/select' enum DATA_FORMAT { TEXT = 'text', @@ -32,238 +34,359 @@ enum SUPPORTED_TYPES { SELECT = 'select', HIDDEN = 'hidden', } -const MarkdownForm = ({ node }: any) => { - const { onSend } = useChatContext() - const [formValues, setFormValues] = useState<{ [key: string]: any }>({}) +const SUPPORTED_TYPES_SET = new Set(Object.values(SUPPORTED_TYPES)) - useEffect(() => { - const initialValues: { [key: string]: any } = {} - node.children.forEach((child: any) => { - if ([SUPPORTED_TAGS.INPUT, SUPPORTED_TAGS.TEXTAREA].includes(child.tagName)) { - initialValues[child.properties.name] - = (child.tagName === SUPPORTED_TAGS.INPUT && child.properties.type === SUPPORTED_TYPES.HIDDEN) - ? 
(child.properties.value || '') - : child.properties.value - } - }) - setFormValues(initialValues) - }, [node.children]) +const SAFE_NAME_RE = /^[a-z][\w-]*$/i +const PROTOTYPE_POISON_KEYS = new Set(['__proto__', 'constructor', 'prototype']) - const getFormValues = (children: any) => { - const values: { [key: string]: any } = {} - children.forEach((child: any) => { - if ([SUPPORTED_TAGS.INPUT, SUPPORTED_TAGS.TEXTAREA].includes(child.tagName)) { - let value = formValues[child.properties.name] +function isSafeName(name: unknown): name is string { + return typeof name === 'string' + && name.length > 0 + && name.length <= 128 + && SAFE_NAME_RE.test(name) + && !PROTOTYPE_POISON_KEYS.has(name) +} - if (child.tagName === SUPPORTED_TAGS.INPUT - && (child.properties.type === SUPPORTED_TYPES.DATE || child.properties.type === SUPPORTED_TYPES.DATETIME)) { - if (value && typeof value.format === 'function') { - // Format date output consistently - const includeTime = child.properties.type === SUPPORTED_TYPES.DATETIME - value = formatDateForOutput(value, includeTime) - } - } +const VALID_BUTTON_VARIANTS = new Set([ + 'primary', + 'warning', + 'secondary', + 'secondary-accent', + 'ghost', + 'ghost-accent', + 'tertiary', +]) +const VALID_BUTTON_SIZES = new Set(['small', 'medium', 'large']) - values[child.properties.name] = value - } - }) - return values - } +type HastText = { + type: 'text' + value: string +} - const onSubmit = (e: any) => { - e.preventDefault() - const format = node.properties.dataFormat || DATA_FORMAT.TEXT - const result = getFormValues(node.children) +type HastElement = { + type: 'element' + tagName: string + properties: Record + children: Array +} - if (format === DATA_FORMAT.JSON) { - onSend?.(JSON.stringify(result)) +type FormValue = string | boolean | Dayjs | undefined +type FormValues = Record +type EditState = { + source: HastElement[] + edits: FormValues +} + +function getTextContent(node: HastElement): string { + const textChild = node.children.find((c): c is HastText => c.type === 'text') + return textChild?.value ?? '' +} + +function str(val: unknown): string { + if (val == null) + return '' + return String(val) +} + +function computeInitialFormValues(children: HastElement[]): FormValues { + const init: FormValues = Object.create(null) as FormValues + for (const child of children) { + if (child.tagName !== SUPPORTED_TAGS.INPUT && child.tagName !== SUPPORTED_TAGS.TEXTAREA) + continue + const name = child.properties.name + if (!isSafeName(name)) + continue + + const type = child.tagName === SUPPORTED_TAGS.INPUT ? str(child.properties.type) : '' + + if (type === SUPPORTED_TYPES.HIDDEN) { + init[name] = str(child.properties.value) + } + else if (type === SUPPORTED_TYPES.DATE || type === SUPPORTED_TYPES.DATETIME || type === SUPPORTED_TYPES.TIME) { + const raw = child.properties.value + init[name] = raw != null ? toDayjs(String(raw)) : undefined + } + else if (type === SUPPORTED_TYPES.CHECKBOX) { + const { checked, value } = child.properties + init[name] = !!checked || value === true || value === 'true' } else { - const textResult = Object.entries(result) - .map(([key, value]) => `${key}: ${value}`) - .join('\n') - onSend?.(textResult) + init[name] = child.properties.value != null ? 
str(child.properties.value) : undefined } } + return init +} + +function getElementKey(child: HastElement, index: number): string { + const tag = child.tagName + const name = str(child.properties.name) + const htmlFor = str(child.properties.htmlFor) + const type = str(child.properties.type) + + if (tag === SUPPORTED_TAGS.LABEL) + return `label-${index}-${htmlFor || name}` + if (tag === SUPPORTED_TAGS.INPUT) + return `input-${index}-${type}-${name}` + if (tag === SUPPORTED_TAGS.TEXTAREA) + return `textarea-${index}-${name}` + if (tag === SUPPORTED_TAGS.BUTTON) + return `button-${index}-${getTextContent(child)}` + return `${tag}-${index}` +} + +const MarkdownForm = ({ node }: { node: HastElement }) => { + const typedNode = node + const { onSend } = useChatContext() + const [isSubmitting, setIsSubmitting] = useState(false) + + const elementChildren = useMemo( + () => typedNode.children.filter((c): c is HastElement => c.type === 'element'), + [typedNode.children], + ) + + const baseFormValues = useMemo( + () => computeInitialFormValues(elementChildren), + [elementChildren], + ) + + const [editState, setEditState] = useState(() => ({ + source: elementChildren, + edits: {}, + })) + + const formValues = useMemo(() => { + if (editState.source === elementChildren) + return { ...baseFormValues, ...editState.edits } + return baseFormValues + }, [editState, baseFormValues, elementChildren]) + + const updateValue = useCallback((name: string, value: FormValue) => { + if (!isSafeName(name)) + return + setEditState(prev => ({ + source: elementChildren, + edits: { + ...(prev.source === elementChildren ? prev.edits : {}), + [name]: value, + }, + })) + }, [elementChildren]) + + const getFormOutput = useCallback((): Record => { + const out = Object.create(null) as Record + for (const child of elementChildren) { + if (child.tagName !== SUPPORTED_TAGS.INPUT && child.tagName !== SUPPORTED_TAGS.TEXTAREA) + continue + const name = child.properties.name + if (!isSafeName(name)) + continue + let value: FormValue = formValues[name] + if ( + child.tagName === SUPPORTED_TAGS.INPUT + && (child.properties.type === SUPPORTED_TYPES.DATE || child.properties.type === SUPPORTED_TYPES.DATETIME) + && value != null + && typeof value === 'object' + && 'format' in value + ) { + const includeTime = child.properties.type === SUPPORTED_TYPES.DATETIME + value = formatDateForOutput(value as Dayjs, includeTime) + } + if (typeof value === 'boolean') + out[name] = value + else + out[name] = value != null ? String(value) : undefined + } + return out + }, [elementChildren, formValues]) + + const onSubmit = useCallback((e: React.MouseEvent) => { + e.preventDefault() + if (isSubmitting) + return + setIsSubmitting(true) + try { + const format = str(typedNode.properties.dataFormat) || DATA_FORMAT.TEXT + const result = getFormOutput() + if (format === DATA_FORMAT.JSON) { + onSend?.(JSON.stringify(result)) + } + else { + const textResult = Object.entries(result) + .map(([key, value]) => `${key}: ${value}`) + .join('\n') + onSend?.(textResult) + } + } + catch { + setIsSubmitting(false) + } + }, [isSubmitting, typedNode.properties.dataFormat, getFormOutput, onSend]) + return (
{ + onSubmit={(e) => { e.preventDefault() e.stopPropagation() }} > - {node.children.filter((i: any) => i.type === 'element').map((child: any, index: number) => { + {elementChildren.map((child, index) => { + const key = getElementKey(child, index) if (child.tagName === SUPPORTED_TAGS.LABEL) { return ( ) } - if (child.tagName === SUPPORTED_TAGS.INPUT && Object.values(SUPPORTED_TYPES).includes(child.properties.type)) { - if (child.properties.type === SUPPORTED_TYPES.DATE || child.properties.type === SUPPORTED_TYPES.DATETIME) { + + if (child.tagName === SUPPORTED_TAGS.INPUT && SUPPORTED_TYPES_SET.has(str(child.properties.type))) { + const name = str(child.properties.name) + if (!isSafeName(name)) + return null + + const type = str(child.properties.type) as SUPPORTED_TYPES + + if (type === SUPPORTED_TYPES.DATE || type === SUPPORTED_TYPES.DATETIME) { return ( { - setFormValues(prevValues => ({ - ...prevValues, - [child.properties.name]: date, - })) - }} - onClear={() => { - setFormValues(prevValues => ({ - ...prevValues, - [child.properties.name]: undefined, - })) - }} + key={key} + value={formValues[name] as Dayjs | undefined} + needTimePicker={type === SUPPORTED_TYPES.DATETIME} + onChange={date => updateValue(name, date)} + onClear={() => updateValue(name, undefined)} /> ) } - if (child.properties.type === SUPPORTED_TYPES.TIME) { + if (type === SUPPORTED_TYPES.TIME) { return ( { - setFormValues(prevValues => ({ - ...prevValues, - [child.properties.name]: time, - })) - }} - onClear={() => { - setFormValues(prevValues => ({ - ...prevValues, - [child.properties.name]: undefined, - })) - }} + key={key} + value={formValues[name] as Dayjs | string | undefined} + onChange={time => updateValue(name, time)} + onClear={() => updateValue(name, undefined)} /> ) } - if (child.properties.type === SUPPORTED_TYPES.CHECKBOX) { + if (type === SUPPORTED_TYPES.CHECKBOX) { return ( -
+
{ - setFormValues(prevValues => ({ - ...prevValues, - [child.properties.name]: !prevValues[child.properties.name], - })) - }} - id={child.properties.name} + checked={!!formValues[name]} + onCheck={() => updateValue(name, !formValues[name])} + id={name} /> - {child.properties.dataTip || child.properties['data-tip'] || ''} + {str(child.properties.dataTip || child.properties['data-tip'])}
) } - if (child.properties.type === SUPPORTED_TYPES.SELECT) { + if (type === SUPPORTED_TYPES.SELECT) { + const rawOptions = child.properties.dataOptions || child.properties['data-options'] || [] + let options: string[] = [] + if (typeof rawOptions === 'string') { + try { + const parsed: unknown = JSON.parse(rawOptions) + if (Array.isArray(parsed)) + options = parsed.filter((o): o is string => typeof o === 'string') + } + catch (error) { + console.error('Failed to parse data-options JSON:', rawOptions, error) + options = [] + } + } + else if (Array.isArray(rawOptions)) { + options = rawOptions.filter((o): o is string => typeof o === 'string') + } return ( ) } - if (child.properties.type === SUPPORTED_TYPES.HIDDEN) { + if (type === SUPPORTED_TYPES.HIDDEN) { return ( ) } return ( { - setFormValues(prevValues => ({ - ...prevValues, - [child.properties.name]: e.target.value, - })) - }} + key={key} + type={type} + name={name} + placeholder={str(child.properties.placeholder)} + value={str(formValues[name])} + onChange={e => updateValue(name, e.target.value)} /> ) } + if (child.tagName === SUPPORTED_TAGS.TEXTAREA) { + const name = str(child.properties.name) + if (!isSafeName(name)) + return null return (