From b466d8da92ae748bda74a84908aaeabb137c71b9 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Mon, 8 Dec 2025 16:55:53 +0800 Subject: [PATCH 01/23] fix(web): resolve no-unused-vars lint warning in index.spec.ts (#29273) --- web/utils/index.spec.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/utils/index.spec.ts b/web/utils/index.spec.ts index beda974e5c..645fc246c1 100644 --- a/web/utils/index.spec.ts +++ b/web/utils/index.spec.ts @@ -452,9 +452,9 @@ describe('fetchWithRetry extended', () => { }) it('should retry specified number of times', async () => { - let attempts = 0 + let _attempts = 0 const failingPromise = () => { - attempts++ + _attempts++ return Promise.reject(new Error('fail')) } From 0cb696b208efb97b12fcf7ed5c4371338ae621e4 Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 8 Dec 2025 17:23:45 +0800 Subject: [PATCH 02/23] chore: add provider context mock (#29201) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- web/__mocks__/provider-context.ts | 47 +++++++++++++ web/context/provider-context-mock.spec.tsx | 82 ++++++++++++++++++++++ web/context/provider-context-mock.tsx | 18 +++++ web/context/provider-context.tsx | 9 ++- web/testing/testing.md | 2 + 5 files changed, 155 insertions(+), 3 deletions(-) create mode 100644 web/__mocks__/provider-context.ts create mode 100644 web/context/provider-context-mock.spec.tsx create mode 100644 web/context/provider-context-mock.tsx diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts new file mode 100644 index 0000000000..594fe38f14 --- /dev/null +++ b/web/__mocks__/provider-context.ts @@ -0,0 +1,47 @@ +import { merge, noop } from 'lodash-es' +import { defaultPlan } from '@/app/components/billing/config' +import { baseProviderContextValue } from '@/context/provider-context' +import type { ProviderContextState } from '@/context/provider-context' +import type { Plan, UsagePlanInfo } from '@/app/components/billing/type' + +export const createMockProviderContextValue = (overrides: Partial<ProviderContextState> = {}): ProviderContextState => { + const merged = merge({}, baseProviderContextValue, overrides) + + return { + ...merged, + refreshModelProviders: merged.refreshModelProviders ?? noop, + onPlanInfoChanged: merged.onPlanInfoChanged ?? noop, + refreshLicenseLimit: merged.refreshLicenseLimit ?? 
noop, + } +} + +export const createMockPlan = (plan: Plan): ProviderContextState => + createMockProviderContextValue({ + plan: merge({}, defaultPlan, { + type: plan, + }), + }) + +export const createMockPlanUsage = (usage: UsagePlanInfo, ctx: Partial<ProviderContextState>): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx.plan, { + usage, + }), + }) + +export const createMockPlanTotal = (total: UsagePlanInfo, ctx: Partial<ProviderContextState>): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx.plan, { + total, + }), + }) + +export const createMockPlanReset = (reset: Partial<UsagePlanInfo>, ctx: Partial<ProviderContextState>): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx?.plan, { + reset, + }), + }) diff --git a/web/context/provider-context-mock.spec.tsx b/web/context/provider-context-mock.spec.tsx new file mode 100644 index 0000000000..ca7c6b884e --- /dev/null +++ b/web/context/provider-context-mock.spec.tsx @@ -0,0 +1,82 @@ +import { render } from '@testing-library/react' +import type { UsagePlanInfo } from '@/app/components/billing/type' +import { Plan } from '@/app/components/billing/type' +import ProviderContextMock from './provider-context-mock' +import { createMockPlan, createMockPlanReset, createMockPlanTotal, createMockPlanUsage } from '@/__mocks__/provider-context' + +let mockPlan: Plan = Plan.sandbox +const usage: UsagePlanInfo = { + vectorSpace: 1, + buildApps: 10, + teamMembers: 1, + annotatedResponse: 1, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, +} + +const total: UsagePlanInfo = { + vectorSpace: 100, + buildApps: 100, + teamMembers: 10, + annotatedResponse: 100, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, +} + +const reset = { + apiRateLimit: 100, + triggerEvents: 100, +} + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => { + const withPlan = createMockPlan(mockPlan) + const withUsage = createMockPlanUsage(usage, withPlan) + const withTotal = createMockPlanTotal(total, withUsage) + const withReset = createMockPlanReset(reset, withTotal) + console.log(JSON.stringify(withReset.plan, null, 2)) + return withReset + }, +})) + +const renderWithPlan = (plan: Plan) => { + mockPlan = plan + return render(<ProviderContextMock />) +} + +describe('ProviderContextMock', () => { + beforeEach(() => { + mockPlan = Plan.sandbox + jest.clearAllMocks() + }) + it('should display sandbox plan type when mocked with sandbox plan', async () => { + const { getByTestId } = renderWithPlan(Plan.sandbox) + expect(getByTestId('plan-type').textContent).toBe(Plan.sandbox) + }) + it('should display team plan type when mocked with team plan', () => { + const { getByTestId } = renderWithPlan(Plan.team) + expect(getByTestId('plan-type').textContent).toBe(Plan.team) + }) + it('should provide usage info from mocked plan', () => { + const { getByTestId } = renderWithPlan(Plan.team) + const buildApps = getByTestId('plan-usage-build-apps').textContent + + expect(Number(buildApps as string)).toEqual(usage.buildApps) + }) + + it('should provide total info from mocked plan', () => { + const { getByTestId } = renderWithPlan(Plan.team) + const buildApps = getByTestId('plan-total-build-apps').textContent + + expect(Number(buildApps as string)).toEqual(total.buildApps) + }) + + it('should provide reset info from mocked plan', () => { + const { getByTestId } = renderWithPlan(Plan.team) + const apiRateLimit = getByTestId('plan-reset-api-rate-limit').textContent + + expect(Number(apiRateLimit as string)).toEqual(reset.apiRateLimit) + }) 
+}) diff --git a/web/context/provider-context-mock.tsx b/web/context/provider-context-mock.tsx new file mode 100644 index 0000000000..b42847a9ec --- /dev/null +++ b/web/context/provider-context-mock.tsx @@ -0,0 +1,18 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useProviderContext } from '@/context/provider-context' + +const ProviderContextMock: FC = () => { + const { plan } = useProviderContext() + + return (
+ <div>
+ <div data-testid='plan-type'>{plan.type}</div>
+ <div data-testid='plan-usage-build-apps'>{plan.usage.buildApps}</div>
+ <div data-testid='plan-total-build-apps'>{plan.total.buildApps}</div>
+ <div data-testid='plan-reset-api-rate-limit'>{plan.reset.apiRateLimit}</div>
+ </div>
+ ) +} +export default React.memo(ProviderContextMock) diff --git a/web/context/provider-context.tsx b/web/context/provider-context.tsx index 26617921f1..70944d85f1 100644 --- a/web/context/provider-context.tsx +++ b/web/context/provider-context.tsx @@ -30,7 +30,7 @@ import { noop } from 'lodash-es' import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils' import { ZENDESK_FIELD_IDS } from '@/config' -type ProviderContextState = { +export type ProviderContextState = { modelProviders: ModelProvider[] refreshModelProviders: () => void textGenerationModelList: Model[] @@ -66,7 +66,8 @@ type ProviderContextState = { isAllowTransferWorkspace: boolean isAllowPublishAsCustomKnowledgePipelineTemplate: boolean } -const ProviderContext = createContext<ProviderContextState>({ + +export const baseProviderContextValue: ProviderContextState = { modelProviders: [], refreshModelProviders: noop, textGenerationModelList: [], @@ -96,7 +97,9 @@ const ProviderContext = createContext<ProviderContextState>({ refreshLicenseLimit: noop, isAllowTransferWorkspace: false, isAllowPublishAsCustomKnowledgePipelineTemplate: false, -}) +} + +const ProviderContext = createContext<ProviderContextState>(baseProviderContextValue) export const useProviderContext = () => useContext(ProviderContext) diff --git a/web/testing/testing.md b/web/testing/testing.md index e2df86c653..f03451230d 100644 --- a/web/testing/testing.md +++ b/web/testing/testing.md @@ -146,6 +146,8 @@ Treat component state as part of the public behavior: confirm the initial render - ✅ Reset shared stores (React context, Zustand, TanStack Query cache) between tests to avoid leaking state. Prefer helper factory functions over module-level singletons in specs. - ✅ For hooks that read from context, use `renderHook` with a custom wrapper that supplies required providers. +If you need to mock a common context provider that is used across many components (for example, `ProviderContext`), put the mock in `__mocks__` (for example, `__mocks__/provider-context`). To dynamically control the mock's behavior (for example, toggling the plan type), use a module-level variable to track state and reassign it between tests (for example, `context/provider-context-mock.spec.tsx`). + ### 4. Performance Optimization Cover memoized callbacks or values only when they influence observable behavior—memoized children, subscription updates, expensive computations. Trigger realistic re-renders and assert the outcomes (avoided rerenders, reused results) instead of inspecting hook internals. 
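The testing.md paragraph above describes exactly the pattern `provider-context-mock.spec.tsx` exercises. A condensed sketch of that pattern follows, using only identifiers introduced by this patch; the `@/context/provider-context-mock` alias path and the assumption that `defaultPlan` ships usage/total/reset defaults are mine, so treat it as a sketch rather than a drop-in spec:

import { render } from '@testing-library/react'
import { Plan } from '@/app/components/billing/type'
import ProviderContextMock from '@/context/provider-context-mock'
import { createMockPlan } from '@/__mocks__/provider-context'

// Module-level state: the jest.mock factory below reads mockPlan lazily on
// every useProviderContext() call, so a test can reassign it before rendering
// to steer what the mocked hook returns.
let mockPlan: Plan = Plan.sandbox

jest.mock('@/context/provider-context', () => ({
  // Assumes defaultPlan provides usage/total/reset defaults; if it does not,
  // layer createMockPlanUsage/Total/Reset here, as the spec above does.
  useProviderContext: () => createMockPlan(mockPlan),
}))

it('reflects the plan toggled before render', () => {
  mockPlan = Plan.team // flip the module-level variable, then render
  const { getByTestId } = render(<ProviderContextMock />)
  expect(getByTestId('plan-type').textContent).toBe(Plan.team)
})

Because jest.mock replaces the module for every importer, no wrapper provider is needed around the component under test, and resetting mockPlan in a beforeEach (as the spec does) keeps the tests independent of each other.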
From e6d504558a87a58b05e56ec88148a2fa61c6339e Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 8 Dec 2025 17:47:16 +0800 Subject: [PATCH 03/23] chore: remove log in test case (#29284) --- web/context/provider-context-mock.spec.tsx | 1 - 1 file changed, 1 deletion(-) diff --git a/web/context/provider-context-mock.spec.tsx b/web/context/provider-context-mock.spec.tsx index ca7c6b884e..5d83f7580d 100644 --- a/web/context/provider-context-mock.spec.tsx +++ b/web/context/provider-context-mock.spec.tsx @@ -36,7 +36,6 @@ jest.mock('@/context/provider-context', () => ({ const withUsage = createMockPlanUsage(usage, withPlan) const withTotal = createMockPlanTotal(total, withUsage) const withReset = createMockPlanReset(reset, withTotal) - console.log(JSON.stringify(withReset.plan, null, 2)) return withReset }, })) From 3cb944f31859b890ce1f78d5926a4dfeea9f2325 Mon Sep 17 00:00:00 2001 From: hj24 Date: Mon, 8 Dec 2025 17:54:57 +0800 Subject: [PATCH 04/23] feat: enable tenant isolation on duplicate document indexing tasks (#29080) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/dataset_service.py | 8 +- .../document_indexing_proxy/__init__.py | 11 + api/services/document_indexing_proxy/base.py | 111 +++ .../batch_indexing_base.py | 76 ++ .../document_indexing_task_proxy.py | 12 + .../duplicate_document_indexing_task_proxy.py | 15 + api/services/document_indexing_task_proxy.py | 83 -- .../rag_pipeline/rag_pipeline_task_proxy.py | 11 +- api/tasks/document_indexing_task.py | 10 +- api/tasks/duplicate_document_indexing_task.py | 82 ++ .../priority_rag_pipeline_run_task.py | 6 +- .../rag_pipeline/rag_pipeline_run_task.py | 6 +- .../test_duplicate_document_indexing_task.py | 763 ++++++++++++++++++ .../services/document_indexing_task_proxy.py | 70 +- .../test_document_indexing_task_proxy.py | 37 +- ..._duplicate_document_indexing_task_proxy.py | 363 +++++++++ .../tasks/test_dataset_indexing_task.py | 24 +- .../test_duplicate_document_indexing_task.py | 567 +++++++++++++ 18 files changed, 2097 insertions(+), 158 deletions(-) create mode 100644 api/services/document_indexing_proxy/__init__.py create mode 100644 api/services/document_indexing_proxy/base.py create mode 100644 api/services/document_indexing_proxy/batch_indexing_base.py create mode 100644 api/services/document_indexing_proxy/document_indexing_task_proxy.py create mode 100644 api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py delete mode 100644 api/services/document_indexing_task_proxy.py create mode 100644 api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py create mode 100644 api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py create mode 100644 api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 208ebcb018..bb09311349 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -51,7 +51,8 @@ from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding from models.workflow import Workflow -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy from 
services.entities.knowledge_entities.knowledge_entities import ( ChildChunkUpdateArgs, KnowledgeConfig, @@ -82,7 +83,6 @@ from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.disable_segments_from_index_task import disable_segments_from_index_task from tasks.document_indexing_update_task import document_indexing_update_task -from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.remove_document_from_index_task import remove_document_from_index_task @@ -1761,7 +1761,9 @@ class DocumentService: if document_ids: DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + DuplicateDocumentIndexingTaskProxy( + dataset.tenant_id, dataset.id, duplicate_document_ids + ).delay() except LockNotOwnedError: pass diff --git a/api/services/document_indexing_proxy/__init__.py b/api/services/document_indexing_proxy/__init__.py new file mode 100644 index 0000000000..74195adbe1 --- /dev/null +++ b/api/services/document_indexing_proxy/__init__.py @@ -0,0 +1,11 @@ +from .base import DocumentTaskProxyBase +from .batch_indexing_base import BatchDocumentIndexingProxy +from .document_indexing_task_proxy import DocumentIndexingTaskProxy +from .duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy + +__all__ = [ + "BatchDocumentIndexingProxy", + "DocumentIndexingTaskProxy", + "DocumentTaskProxyBase", + "DuplicateDocumentIndexingTaskProxy", +] diff --git a/api/services/document_indexing_proxy/base.py b/api/services/document_indexing_proxy/base.py new file mode 100644 index 0000000000..56e47857c9 --- /dev/null +++ b/api/services/document_indexing_proxy/base.py @@ -0,0 +1,111 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import cached_property +from typing import Any, ClassVar + +from enums.cloud_plan import CloudPlan +from services.feature_service import FeatureService + +logger = logging.getLogger(__name__) + + +class DocumentTaskProxyBase(ABC): + """ + Base proxy for all document processing tasks. + + Handles common logic: + - Feature/billing checks + - Dispatch routing based on plan + + Subclasses must define: + - QUEUE_NAME: Redis queue identifier + - NORMAL_TASK_FUNC: Task function for normal priority + - PRIORITY_TASK_FUNC: Task function for high priority + """ + + QUEUE_NAME: ClassVar[str] + NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]] + PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]] + + def __init__(self, tenant_id: str, dataset_id: str): + """ + Initialize with minimal required parameters. + + Args: + tenant_id: Tenant identifier for billing/features + dataset_id: Dataset identifier for logging + """ + self._tenant_id = tenant_id + self._dataset_id = dataset_id + + @cached_property + def features(self): + return FeatureService.get_features(self._tenant_id) + + @abstractmethod + def _send_to_direct_queue(self, task_func: Callable[..., Any]): + """ + Send task directly to Celery queue without tenant isolation. + + Subclasses implement this to pass task-specific parameters. 
+ + Args: + task_func: The Celery task function to call + """ + pass + + @abstractmethod + def _send_to_tenant_queue(self, task_func: Callable[..., Any]): + """ + Send task to tenant-isolated queue. + + Subclasses implement this to handle queue management. + + Args: + task_func: The Celery task function to call + """ + pass + + def _send_to_default_tenant_queue(self): + """Route to normal priority with tenant isolation.""" + self._send_to_tenant_queue(self.NORMAL_TASK_FUNC) + + def _send_to_priority_tenant_queue(self): + """Route to priority queue with tenant isolation.""" + self._send_to_tenant_queue(self.PRIORITY_TASK_FUNC) + + def _send_to_priority_direct_queue(self): + """Route to priority queue without tenant isolation.""" + self._send_to_direct_queue(self.PRIORITY_TASK_FUNC) + + def _dispatch(self): + """ + Dispatch task based on billing plan. + + Routing logic: + - Sandbox plan → normal queue + tenant isolation + - Paid plans → priority queue + tenant isolation + - Self-hosted → priority queue, no isolation + """ + logger.info( + "dispatch args: %s - %s - %s", + self._tenant_id, + self.features.billing.enabled, + self.features.billing.subscription.plan, + ) + # dispatch to different indexing queue with tenant isolation when billing enabled + if self.features.billing.enabled: + if self.features.billing.subscription.plan == CloudPlan.SANDBOX: + # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan + self._send_to_default_tenant_queue() + else: + # dispatch to priority pipeline queue with tenant self sub queue for other plans + self._send_to_priority_tenant_queue() + else: + # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise + self._send_to_priority_direct_queue() + + def delay(self): + """Public API: Queue the task asynchronously.""" + self._dispatch() diff --git a/api/services/document_indexing_proxy/batch_indexing_base.py b/api/services/document_indexing_proxy/batch_indexing_base.py new file mode 100644 index 0000000000..dd122f34a8 --- /dev/null +++ b/api/services/document_indexing_proxy/batch_indexing_base.py @@ -0,0 +1,76 @@ +import logging +from collections.abc import Callable, Sequence +from dataclasses import asdict +from typing import Any + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue + +from .base import DocumentTaskProxyBase + +logger = logging.getLogger(__name__) + + +class BatchDocumentIndexingProxy(DocumentTaskProxyBase): + """ + Base proxy for batch document indexing tasks (document_ids in plural). + + Adds: + - Tenant isolated queue management + - Batch document handling + """ + + def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Initialize with batch documents. + + Args: + tenant_id: Tenant identifier + dataset_id: Dataset identifier + document_ids: List of document IDs to process + """ + super().__init__(tenant_id, dataset_id) + self._document_ids = document_ids + self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, self.QUEUE_NAME) + + def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): + """ + Send batch task to direct queue. 
+ + Args: + task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids) + """ + logger.info("tenant %s send documents %s to direct queue", self._tenant_id, self._document_ids) + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + + def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): + """ + Send batch task to tenant-isolated queue. + + Args: + task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids) + """ + logger.info( + "tenant %s send documents %s to tenant queue %s", self._tenant_id, self._document_ids, self.QUEUE_NAME + ) + if self._tenant_isolated_task_queue.get_task_key(): + # Add to waiting queue using List operations (lpush) + self._tenant_isolated_task_queue.push_tasks( + [ + asdict( + DocumentTask( + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + ) + ] + ) + logger.info("tenant %s push tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids) + else: + # Set flag and execute task + self._tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + logger.info("tenant %s init tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids) diff --git a/api/services/document_indexing_proxy/document_indexing_task_proxy.py b/api/services/document_indexing_proxy/document_indexing_task_proxy.py new file mode 100644 index 0000000000..fce79a8387 --- /dev/null +++ b/api/services/document_indexing_proxy/document_indexing_task_proxy.py @@ -0,0 +1,12 @@ +from typing import ClassVar + +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy +from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task + + +class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy): + """Proxy for document indexing tasks.""" + + QUEUE_NAME: ClassVar[str] = "document_indexing" + NORMAL_TASK_FUNC = normal_document_indexing_task + PRIORITY_TASK_FUNC = priority_document_indexing_task diff --git a/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py new file mode 100644 index 0000000000..277cfbdcf1 --- /dev/null +++ b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py @@ -0,0 +1,15 @@ +from typing import ClassVar + +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy +from tasks.duplicate_document_indexing_task import ( + normal_duplicate_document_indexing_task, + priority_duplicate_document_indexing_task, +) + + +class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy): + """Proxy for duplicate document indexing tasks.""" + + QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing" + NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task + PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task diff --git a/api/services/document_indexing_task_proxy.py b/api/services/document_indexing_task_proxy.py deleted file mode 100644 index 861c84b586..0000000000 --- a/api/services/document_indexing_task_proxy.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -from collections.abc import Callable, Sequence -from dataclasses import asdict -from functools import cached_property - -from 
core.entities.document_task import DocumentTask -from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from enums.cloud_plan import CloudPlan -from services.feature_service import FeatureService -from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task - -logger = logging.getLogger(__name__) - - -class DocumentIndexingTaskProxy: - def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): - self._tenant_id = tenant_id - self._dataset_id = dataset_id - self._document_ids = document_ids - self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") - - @cached_property - def features(self): - return FeatureService.get_features(self._tenant_id) - - def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): - logger.info("send dataset %s to direct queue", self._dataset_id) - task_func.delay( # type: ignore - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - - def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): - logger.info("send dataset %s to tenant queue", self._dataset_id) - if self._tenant_isolated_task_queue.get_task_key(): - # Add to waiting queue using List operations (lpush) - self._tenant_isolated_task_queue.push_tasks( - [ - asdict( - DocumentTask( - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - ) - ] - ) - logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids) - else: - # Set flag and execute task - self._tenant_isolated_task_queue.set_task_waiting_time() - task_func.delay( # type: ignore - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids) - - def _send_to_default_tenant_queue(self): - self._send_to_tenant_queue(normal_document_indexing_task) - - def _send_to_priority_tenant_queue(self): - self._send_to_tenant_queue(priority_document_indexing_task) - - def _send_to_priority_direct_queue(self): - self._send_to_direct_queue(priority_document_indexing_task) - - def _dispatch(self): - logger.info( - "dispatch args: %s - %s - %s", - self._tenant_id, - self.features.billing.enabled, - self.features.billing.subscription.plan, - ) - # dispatch to different indexing queue with tenant isolation when billing enabled - if self.features.billing.enabled: - if self.features.billing.subscription.plan == CloudPlan.SANDBOX: - # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan - self._send_to_default_tenant_queue() - else: - # dispatch to priority pipeline queue with tenant self sub queue for other plans - self._send_to_priority_tenant_queue() - else: - # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise - self._send_to_priority_direct_queue() - - def delay(self): - self._dispatch() diff --git a/api/services/rag_pipeline/rag_pipeline_task_proxy.py b/api/services/rag_pipeline/rag_pipeline_task_proxy.py index 94dd7941da..1a7b104a70 100644 --- a/api/services/rag_pipeline/rag_pipeline_task_proxy.py +++ b/api/services/rag_pipeline/rag_pipeline_task_proxy.py @@ -38,21 +38,24 @@ class RagPipelineTaskProxy: upload_file = FileService(db.engine).upload_text( json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id ) + logger.info( + "tenant %s upload %d invoke entities", self._dataset_tenant_id, 
len(self._rag_pipeline_invoke_entities) + ) return upload_file.id def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): - logger.info("send file %s to direct queue", upload_file_id) + logger.info("tenant %s send file %s to direct queue", self._dataset_tenant_id, upload_file_id) task_func.delay( # type: ignore rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=self._dataset_tenant_id, ) def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): - logger.info("send file %s to tenant queue", upload_file_id) + logger.info("tenant %s send file %s to tenant queue", self._dataset_tenant_id, upload_file_id) if self._tenant_isolated_task_queue.get_task_key(): # Add to waiting queue using List operations (lpush) self._tenant_isolated_task_queue.push_tasks([upload_file_id]) - logger.info("push tasks: %s", upload_file_id) + logger.info("tenant %s push tasks: %s", self._dataset_tenant_id, upload_file_id) else: # Set flag and execute task self._tenant_isolated_task_queue.set_task_waiting_time() @@ -60,7 +63,7 @@ class RagPipelineTaskProxy: rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=self._dataset_tenant_id, ) - logger.info("init tasks: %s", upload_file_id) + logger.info("tenant %s init tasks: %s", self._dataset_tenant_id, upload_file_id) def _send_to_default_tenant_queue(self, upload_file_id: str): self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index fee4430612..acbdab631b 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -114,7 +114,13 @@ def _document_indexing_with_tenant_queue( try: _document_indexing(dataset_id, document_ids) except Exception: - logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id) + logger.exception( + "Error processing document indexing %s for tenant %s: %s", + dataset_id, + tenant_id, + document_ids, + exc_info=True, + ) finally: tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") @@ -122,7 +128,7 @@ def _document_indexing_with_tenant_queue( # Use rpop to get the next task from the queue (FIFO order) next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks) + logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) if next_tasks: for next_task in next_tasks: diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 6492e356a3..4078c8910e 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -1,13 +1,16 @@ import logging import time +from collections.abc import Callable, Sequence import click from celery import shared_task from sqlalchemy import select from configs import dify_config +from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -24,8 +27,55 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): :param 
dataset_id: :param document_ids: + .. warning:: TO BE DEPRECATED + This function will be deprecated and removed in a future version. + Use normal_duplicate_document_indexing_task or priority_duplicate_document_indexing_task instead. + Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids) """ + logger.warning("duplicate document indexing task received: %s - %s", dataset_id, document_ids) + _duplicate_document_indexing_task(dataset_id, document_ids) + + +def _duplicate_document_indexing_task_with_tenant_queue( + tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None] +): + try: + _duplicate_document_indexing_task(dataset_id, document_ids) + except Exception: + logger.exception( + "Error processing duplicate document indexing %s for tenant %s: %s", + dataset_id, + tenant_id, + document_ids, + exc_info=True, + ) + finally: + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "duplicate_document_indexing") + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + + logger.info("duplicate document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) + + if next_tasks: + for next_task in next_tasks: + document_task = DocumentTask(**next_task) + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=document_task.tenant_id, + dataset_id=document_task.dataset_id, + document_ids=document_task.document_ids, + ) + else: + # No more waiting tasks, clear the flag + tenant_isolated_task_queue.delete_task_key() + + +def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]): documents = [] start_at = time.perf_counter() @@ -110,3 +160,35 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) finally: db.session.close() + + +@shared_task(queue="dataset") +def normal_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process duplicate documents + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: normal_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("normal duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _duplicate_document_indexing_task_with_tenant_queue( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + +@shared_task(queue="priority_dataset") +def priority_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process duplicate documents + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: priority_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("priority duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _duplicate_document_indexing_task_with_tenant_queue( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index a7f61d9811..1eef361a92 
100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -47,6 +47,8 @@ def priority_rag_pipeline_run_task( ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities)) + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore @@ -66,7 +68,7 @@ def priority_rag_pipeline_run_task( end_at = time.perf_counter() logging.info( click.style( - f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" ) ) except Exception: @@ -78,7 +80,7 @@ def priority_rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids) + logger.info("priority rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: for next_file_id in next_file_ids: diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 92f1dfb73d..275f5abe6e 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -47,6 +47,8 @@ def rag_pipeline_run_task( ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities)) + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore @@ -66,7 +68,7 @@ def rag_pipeline_run_task( end_at = time.perf_counter() logging.info( click.style( - f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + f"tenant_id: {tenant_id}, Rag pipeline run completed. 
Latency: {end_at - start_at}s", fg="green" ) ) except Exception: @@ -78,7 +80,7 @@ def rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids) + logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: for next_file_id in next_file_ids: diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py new file mode 100644 index 0000000000..aca4be1ffd --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -0,0 +1,763 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, # Core function + _duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function + duplicate_document_indexing_task, # Deprecated old interface + normal_duplicate_document_indexing_task, # New normal task + priority_duplicate_document_indexing_task, # New priority task +) + + +class TestDuplicateDocumentIndexingTasks: + """Integration tests for duplicate document indexing tasks using testcontainers. + + This test class covers: + - Core _duplicate_document_indexing_task function + - Deprecated duplicate_document_indexing_task function + - New normal_duplicate_document_indexing_task function + - New priority_duplicate_document_indexing_task function + - Tenant queue wrapper _duplicate_document_indexing_task_with_tenant_queue function + - Document segment cleanup logic + """ + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner, + patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service, + patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory, + ): + # Setup mock indexing runner + mock_runner_instance = MagicMock() + mock_indexing_runner.return_value = mock_runner_instance + + # Setup mock feature service + mock_features = MagicMock() + mock_features.billing.enabled = False + mock_feature_service.get_features.return_value = mock_features + + # Setup mock index processor factory + mock_processor = MagicMock() + mock_processor.clean = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "indexing_runner": mock_indexing_runner, + "indexing_runner_instance": mock_runner_instance, + "feature_service": mock_feature_service, + "features": mock_features, + "index_processor_factory": mock_index_processor_factory, + "index_processor": mock_processor, + } + + def _create_test_dataset_and_documents( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + ): + """ + Helper method to create a test dataset and documents for testing. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(document_count): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def _create_test_dataset_with_segments( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2 + ): + """ + Helper method to create a test dataset with documents and segments. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + segments_per_doc: Number of segments per document + + Returns: + tuple: (dataset, documents, segments) - Created dataset, documents and segments + """ + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count + ) + + fake = Faker() + segments = [] + + # Create segments for each document + for document in documents: + for i in range(segments_per_doc): + segment = DocumentSegment( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + position=i, + index_node_id=f"{document.id}-node-{i}", + index_node_hash=fake.sha256(), + content=fake.text(max_nb_chars=200), + word_count=50, + tokens=100, + status="completed", + enabled=True, + indexing_at=fake.date_time_this_year(), + created_by=dataset.created_by, # Add required field + ) + db.session.add(segment) + segments.append(segment) + + db.session.commit() + + # Refresh to ensure all relationships are loaded + for document in documents: + db.session.refresh(document) + + return dataset, documents, segments + + def _create_test_dataset_with_billing_features( + self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ): + """ + Helper method to create a test dataset with billing features configured. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + billing_enabled: Whether billing is enabled + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(3): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Configure billing features + mock_external_service_dependencies["features"].billing.enabled = billing_enabled + if billing_enabled: + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 50 + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def test_duplicate_document_indexing_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful duplicate document indexing with multiple documents. 
+ + This test verifies: + - Proper dataset retrieval from database + - Correct document processing and status updates + - IndexingRunner integration + - Database state updates + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify the expected outcomes + # Verify indexing runner was called correctly + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated to parsing status + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with correct documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 3 + + def test_duplicate_document_indexing_task_with_segment_cleanup( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test duplicate document indexing with existing segments that need cleanup. + + This test verifies: + - Old segments are identified and cleaned + - Index processor clean method is called + - Segments are deleted from database + - New indexing proceeds after cleanup + """ + # Arrange: Create test data with existing segments + dataset, documents, segments = self._create_test_dataset_with_segments( + db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify segment cleanup + # Verify index processor clean was called for each document with segments + assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents) + + # Verify segments were deleted from database + # Re-query segments from database since _duplicate_document_indexing_task uses a different session + for segment in segments: + deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() + assert deleted_segment is None + + # Verify documents were updated to parsing status + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify indexing runner was called + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_duplicate_document_indexing_task_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent dataset. 
+ + This test verifies: + - Proper error handling for missing datasets + - Early return without processing + - Database session cleanup + - No unnecessary indexing runner calls + """ + # Arrange: Use non-existent dataset ID + fake = Faker() + non_existent_dataset_id = fake.uuid4() + document_ids = [fake.uuid4() for _ in range(3)] + + # Act: Execute the task with non-existent dataset + _duplicate_document_indexing_task(non_existent_dataset_id, document_ids) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["indexing_runner"].assert_not_called() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + mock_external_service_dependencies["index_processor"].clean.assert_not_called() + + def test_duplicate_document_indexing_task_document_not_found_in_dataset( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling when some documents don't exist in the dataset. + + This test verifies: + - Only existing documents are processed + - Non-existent documents are ignored + - Indexing runner receives only valid documents + - Database state updates correctly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + + # Mix existing and non-existent document IDs + fake = Faker() + existing_document_ids = [doc.id for doc in documents] + non_existent_document_ids = [fake.uuid4() for _ in range(2)] + all_document_ids = existing_document_ids + non_existent_document_ids + + # Act: Execute the task with mixed document IDs + _duplicate_document_indexing_task(dataset.id, all_document_ids) + + # Assert: Verify only existing documents were processed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify only existing documents were updated + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in existing_document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with only existing documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 2 # Only existing documents + + def test_duplicate_document_indexing_task_indexing_runner_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of IndexingRunner exceptions. 
+ + This test verifies: + - Exceptions from IndexingRunner are properly caught + - Task completes without raising exceptions + - Database session is properly closed + - Error logging occurs + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock IndexingRunner to raise an exception + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception( + "Indexing runner failed" + ) + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify exception was handled gracefully + # The task should complete without raising exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + # Re-query documents from database since _duplicate_document_indexing_task close the session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for sandbox plan batch upload limit. + + This test verifies: + - Sandbox plan batch upload limit enforcement + - Error handling for batch upload limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure sandbox plan with batch limit + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX + + # Create more documents than sandbox plan allows (limit is 1) + fake = Faker() + extra_documents = [] + for i in range(2): # Total will be 5 documents (3 existing + 2 new) + document = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=i + 3, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + extra_documents.append(document) + + db.session.commit() + all_documents = documents + extra_documents + document_ids = [doc.id for doc in all_documents] + + # Act: Execute the task with too many documents for sandbox plan + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "batch upload" in updated_document.error.lower() + assert updated_document.stopped_at is not None + + # Verify indexing runner was not called due to early validation error + 
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for vector space limit. + + This test verifies: + - Vector space limit enforcement + - Error handling for vector space limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure TEAM plan with vector space limit exceeded + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.TEAM + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 98 # Almost at limit + + document_ids = [doc.id for doc in documents] # 3 documents will exceed limit + + # Act: Execute the task with documents that will exceed vector space limit + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "limit" in updated_document.error.lower() + assert updated_document.stopped_at is not None + + # Verify indexing runner was not called due to early validation error + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_duplicate_document_indexing_task_with_empty_document_list( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of empty document list. + + This test verifies: + - Empty document list is handled gracefully + - No processing occurs + - No errors are raised + - Database session is properly closed + """ + # Arrange: Create test dataset + dataset, _ = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=0 + ) + document_ids = [] + + # Act: Execute the task with empty document list + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify IndexingRunner was called with empty list + # Note: The actual implementation does call run([]) with empty list + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) + + def test_deprecated_duplicate_document_indexing_task_delegates_to_core( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that deprecated duplicate_document_indexing_task delegates to core function. 
+ + This test verifies: + - Deprecated function calls core _duplicate_document_indexing_task + - Proper parameter passing + - Backward compatibility + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the deprecated task + duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify core function was executed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_normal_duplicate_document_indexing_task_with_tenant_queue( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test normal_duplicate_document_indexing_task with tenant isolation queue. + + This test verifies: + - Task uses tenant isolation queue correctly + - Core processing function is called + - Queue management (pull tasks, delete key) works properly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock tenant isolated queue to return no next tasks + mock_queue = MagicMock() + mock_queue.pull_tasks.return_value = [] + mock_queue_class.return_value = mock_queue + + # Act: Execute the normal task + normal_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify tenant queue was used + mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing") + mock_queue.pull_tasks.assert_called_once() + mock_queue.delete_task_key.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_priority_duplicate_document_indexing_task_with_tenant_queue( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test priority_duplicate_document_indexing_task with tenant isolation queue. 
+ + This test verifies: + - Task uses tenant isolation queue correctly + - Core processing function is called + - Queue management works properly + - Same behavior as normal task with different queue assignment + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock tenant isolated queue to return no next tasks + mock_queue = MagicMock() + mock_queue.pull_tasks.return_value = [] + mock_queue_class.return_value = mock_queue + + # Act: Execute the priority task + priority_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify tenant queue was used + mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing") + mock_queue.pull_tasks.assert_called_once() + mock_queue.delete_task_key.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_tenant_queue_wrapper_processes_next_tasks( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant queue wrapper processes next queued tasks. + + This test verifies: + - After completing current task, next tasks are pulled from queue + - Next tasks are executed correctly + - Task waiting time is set for next tasks + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Extract values before session detachment + tenant_id = dataset.tenant_id + dataset_id = dataset.id + + # Mock tenant isolated queue to return next task + mock_queue = MagicMock() + next_task = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": document_ids, + } + mock_queue.pull_tasks.return_value = [next_task] + mock_queue_class.return_value = mock_queue + + # Mock the task function to track calls + mock_task_func = MagicMock() + + # Act: Execute the wrapper function + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert: Verify next task was scheduled + mock_queue.pull_tasks.assert_called_once() + mock_queue.set_task_waiting_time.assert_called_once() + mock_task_func.delay.assert_called_once_with( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_ids=document_ids, + ) + mock_queue.delete_task_key.assert_not_called() diff --git a/api/tests/unit_tests/services/document_indexing_task_proxy.py b/api/tests/unit_tests/services/document_indexing_task_proxy.py index 765c4b5e32..ff243b8dc3 100644 --- a/api/tests/unit_tests/services/document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/document_indexing_task_proxy.py @@ -117,7 +117,7 @@ import pytest from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import 
TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy # ============================================================================ # Test Data Factory @@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy: # Features Property Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_features_property(self, mock_feature_service): """ Test cached_property features. @@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy: mock_feature_service.get_features.assert_called_once_with("tenant-123") - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_features_property_with_different_tenants(self, mock_feature_service): """ Test features property with different tenant IDs. @@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy: # Direct Queue Routing Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue(self, mock_task): """ Test _send_to_direct_queue method. @@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy: # Assert mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_direct_queue_with_priority_task(self, mock_task): """ Test _send_to_direct_queue with priority task function. @@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_with_single_document(self, mock_task): """ Test _send_to_direct_queue with single document ID. @@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_with_empty_documents(self, mock_task): """ Test _send_to_direct_queue with empty document_ids list. @@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy: # Tenant Queue Routing Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): """ Test _send_to_tenant_queue when task key exists. 
@@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy: mock_task.delay.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_without_task_key(self, mock_task): """ Test _send_to_tenant_queue when no task key exists. @@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy: proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_tenant_queue_with_priority_task(self, mock_task): """ Test _send_to_tenant_queue with priority task function. @@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_document_task_serialization(self, mock_task): """ Test DocumentTask serialization in _send_to_tenant_queue. @@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy: # Queue Type Selection Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_default_tenant_queue(self, mock_task): """ Test _send_to_default_tenant_queue method. @@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_priority_tenant_queue(self, mock_task): """ Test _send_to_priority_tenant_queue method. @@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_priority_direct_queue(self, mock_task): """ Test _send_to_priority_direct_queue method. @@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy: # Dispatch Logic Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with SANDBOX plan. @@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with TEAM plan. 
@@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with PROFESSIONAL plan. @@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_disabled(self, mock_feature_service): """ Test _dispatch method when billing is disabled. @@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_direct_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_edge_case_empty_plan(self, mock_feature_service): """ Test _dispatch method with empty plan string. @@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_edge_case_none_plan(self, mock_feature_service): """ Test _dispatch method with None plan. @@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy: # Delay Method Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method(self, mock_feature_service): """ Test delay method integration. @@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method_with_team_plan(self, mock_feature_service): """ Test delay method with TEAM plan. @@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method_with_billing_disabled(self, mock_feature_service): """ Test delay method with billing disabled. @@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy: # Batch Operations Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_batch_operation_with_multiple_documents(self, mock_task): """ Test batch operation with multiple documents. 
@@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_batch_operation_with_large_batch(self, mock_task): """ Test batch operation with large batch of documents. @@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy: # Error Handling Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_task_delay_failure(self, mock_task): """ Test _send_to_direct_queue when task.delay() raises an exception. @@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Task delay failed"): proxy._send_to_direct_queue(mock_task) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_push_tasks_failure(self, mock_task): """ Test _send_to_tenant_queue when push_tasks raises an exception. @@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Push tasks failed"): proxy._send_to_tenant_queue(mock_task) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task): """ Test _send_to_tenant_queue when set_task_waiting_time raises an exception. @@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Set waiting time failed"): proxy._send_to_tenant_queue(mock_task) - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_feature_service_failure(self, mock_feature_service): """ Test _dispatch when FeatureService.get_features raises an exception. @@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy: # Integration Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service): """ Test full flow for SANDBOX plan with tenant queue. 
@@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_full_flow_team_plan(self, mock_task, mock_feature_service): """ Test full flow for TEAM plan with priority tenant queue. @@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_full_flow_billing_disabled(self, mock_task, mock_feature_service): """ Test full flow for billing disabled (self-hosted/enterprise). diff --git a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py index d9183be9fb..98c30c3722 100644 --- a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py @@ -3,7 +3,7 @@ from unittest.mock import Mock, patch from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy class DocumentIndexingTaskProxyTestDataFactory: @@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy: assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing" - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_features_property(self, mock_feature_service): """Test cached_property features.""" # Arrange @@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy: assert features1 is features2 # Should be the same instance due to caching mock_feature_service.get_features.assert_called_once_with("tenant-123") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue(self, mock_task): """Test _send_to_direct_queue method.""" # Arrange @@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): """Test _send_to_tenant_queue when task key exists.""" # Arrange @@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy: assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] 
mock_task.delay.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_without_task_key(self, mock_task): """Test _send_to_tenant_queue when no task key exists.""" # Arrange @@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy: ) proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_default_tenant_queue(self, mock_task): + def test_send_to_default_tenant_queue(self): """Test _send_to_default_tenant_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_default_tenant_queue() # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_tenant_queue(self, mock_task): + def test_send_to_priority_tenant_queue(self): """Test _send_to_priority_tenant_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_priority_tenant_queue() # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_direct_queue(self, mock_task): + def test_send_to_priority_direct_queue(self): """Test _send_to_priority_direct_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_priority_direct_queue() # Assert - proxy._send_to_direct_queue.assert_called_once_with(mock_task) + proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): """Test _dispatch method when billing is enabled with sandbox plan.""" # Arrange @@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): """Test _dispatch method when billing is enabled with non-sandbox plan.""" # Arrange @@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy: # If billing enabled with non sandbox plan, should send to priority tenant queue proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_disabled(self, mock_feature_service): """Test _dispatch method when billing is disabled.""" # Arrange @@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy: # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue 
proxy._send_to_priority_direct_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_delay_method(self, mock_feature_service): """Test delay method integration.""" # Arrange @@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy: assert task.dataset_id == dataset_id assert task.document_ids == document_ids - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_edge_case_empty_plan(self, mock_feature_service): """Test _dispatch method with empty plan string.""" # Arrange @@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_edge_case_none_plan(self, mock_feature_service): """Test _dispatch method with None plan.""" # Arrange diff --git a/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py new file mode 100644 index 0000000000..68bafe3d5e --- /dev/null +++ b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py @@ -0,0 +1,363 @@ +from unittest.mock import Mock, patch + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import ( + DuplicateDocumentIndexingTaskProxy, +) + + +class DuplicateDocumentIndexingTaskProxyTestDataFactory: + """Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests.""" + + @staticmethod + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + """Create mock features with billing configuration.""" + features = Mock() + features.billing = Mock() + features.billing.enabled = billing_enabled + features.billing.subscription = Mock() + features.billing.subscription.plan = plan + return features + + @staticmethod + def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: + """Create mock TenantIsolatedTaskQueue.""" + queue = Mock(spec=TenantIsolatedTaskQueue) + queue.get_task_key.return_value = "task_key" if has_task_key else None + queue.push_tasks = Mock() + queue.set_task_waiting_time = Mock() + return queue + + @staticmethod + def create_duplicate_document_task_proxy( + tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None + ) -> DuplicateDocumentIndexingTaskProxy: + """Create DuplicateDocumentIndexingTaskProxy instance for testing.""" + if document_ids is None: + document_ids = ["doc-1", "doc-2", "doc-3"] + return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + +class TestDuplicateDocumentIndexingTaskProxy: + """Test cases for DuplicateDocumentIndexingTaskProxy class.""" + + def test_initialization(self): + """Test DuplicateDocumentIndexingTaskProxy initialization.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1", "doc-2", "doc-3"] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert 
proxy._document_ids == document_ids + assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) + assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id + assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing" + + def test_queue_name(self): + """Test QUEUE_NAME class variable.""" + # Arrange & Act + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Assert + assert proxy.QUEUE_NAME == "duplicate_document_indexing" + + def test_task_functions(self): + """Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables.""" + # Arrange & Act + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Assert + assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task" + assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task" + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_features_property(self, mock_feature_service): + """Test cached_property features.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features() + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Act + features1 = proxy.features + features2 = proxy.features # Second call should use cached property + + # Assert + assert features1 == mock_features + assert features2 == mock_features + assert features1 is features2 # Should be the same instance due to caching + mock_feature_service.get_features.assert_called_once_with("tenant-123") + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_direct_queue(self, mock_task): + """Test _send_to_direct_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + mock_task.delay = Mock() + + # Act + proxy._send_to_direct_queue(mock_task) + + # Assert + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + """Test _send_to_tenant_queue when task key exists.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=True + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() + pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] + assert len(pushed_tasks) == 1 + assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask) + assert pushed_tasks[0]["tenant_id"] == "tenant-123" + assert pushed_tasks[0]["dataset_id"] == "dataset-456" + assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] + mock_task.delay.assert_not_called() + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_tenant_queue_without_task_key(self, mock_task): + 
"""Test _send_to_tenant_queue when no task key exists.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=False + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + def test_send_to_default_tenant_queue(self): + """Test _send_to_default_tenant_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_default_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC) + + def test_send_to_priority_tenant_queue(self): + """Test _send_to_priority_tenant_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_priority_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) + + def test_send_to_priority_direct_queue(self): + """Test _send_to_priority_direct_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_direct_queue = Mock() + + # Act + proxy._send_to_priority_direct_queue() + + # Assert + proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with sandbox plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with non-sandbox plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.TEAM + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + # If billing enabled with non sandbox plan, should send to priority tenant queue + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_disabled(self, mock_feature_service): + """Test _dispatch method when billing is disabled.""" + # Arrange + 
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_direct_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue + proxy._send_to_priority_direct_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_delay_method(self, mock_feature_service): + """Test delay method integration.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy.delay() + + # Assert + # If billing enabled with sandbox plan, should send to default tenant queue + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_edge_case_empty_plan(self, mock_feature_service): + """Test _dispatch method with empty plan string.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan="" + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_edge_case_none_plan(self, mock_feature_service): + """Test _dispatch method with None plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=None + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + def test_initialization_with_empty_document_ids(self): + """Test initialization with empty document_ids list.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_single_document_id(self): + """Test initialization with single document_id.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1"] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_large_batch(self): + """Test initialization with large batch of document IDs.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = 
"dataset-456" + document_ids = [f"doc-{i}" for i in range(100)] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert len(proxy._document_ids) == 100 + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_professional_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with professional plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.PROFESSIONAL + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index b3b29fbe45..9d7599b8fe 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.dataset import Dataset, Document -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy from tasks.document_indexing_task import ( _document_indexing, _document_indexing_with_tenant_queue, @@ -138,7 +138,9 @@ class TestTaskEnqueuing: with patch.object(DocumentIndexingTaskProxy, "features") as mock_features: mock_features.billing.enabled = False - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -163,7 +165,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.SANDBOX - with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -187,7 +191,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -211,7 +217,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, 
"PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -1493,7 +1501,9 @@ class TestEdgeCases: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): # Act - Enqueue multiple tasks rapidly for doc_ids in document_ids_list: proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids) @@ -1898,7 +1908,7 @@ class TestRobustness: - Error is propagated appropriately """ # Arrange - with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features: + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features: # Simulate FeatureService failure mock_get_features.side_effect = Exception("Feature service unavailable") diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py new file mode 100644 index 0000000000..0be6ea045e --- /dev/null +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -0,0 +1,567 @@ +""" +Unit tests for duplicate document indexing tasks. + +This module tests the duplicate document indexing task functionality including: +- Task enqueuing to different queues (normal, priority, tenant-isolated) +- Batch processing of multiple duplicate documents +- Progress tracking through task lifecycle +- Error handling and retry mechanisms +- Cleanup of old document data before re-indexing +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from models.dataset import Dataset, Document, DocumentSegment +from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, + _duplicate_document_indexing_task_with_tenant_queue, + duplicate_document_indexing_task, + normal_duplicate_document_indexing_task, + priority_duplicate_document_indexing_task, +) + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def document_ids(): + """Generate a list of document IDs for testing.""" + return [str(uuid.uuid4()) for _ in range(3)] + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_documents(document_ids, dataset_id): + """Create mock Document objects.""" + documents = [] + for doc_id in document_ids: + doc = Mock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + 
doc.processing_started_at = None + doc.doc_form = "text_model" + documents.append(doc) + return documents + + +@pytest.fixture +def mock_document_segments(document_ids): + """Create mock DocumentSegment objects.""" + segments = [] + for doc_id in document_ids: + for i in range(3): + segment = Mock(spec=DocumentSegment) + segment.id = str(uuid.uuid4()) + segment.document_id = doc_id + segment.index_node_id = f"node-{doc_id}-{i}" + segments.append(segment) + return segments + + +@pytest.fixture +def mock_db_session(): + """Mock database session.""" + with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_session.scalars.return_value = MagicMock() + yield mock_session + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner.""" + with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock(spec=IndexingRunner) + mock_runner_class.return_value = mock_runner + yield mock_runner + + +@pytest.fixture +def mock_feature_service(): + """Mock FeatureService.""" + with patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_service: + mock_features = Mock() + mock_features.billing = Mock() + mock_features.billing.enabled = False + mock_features.vector_space = Mock() + mock_features.vector_space.size = 0 + mock_features.vector_space.limit = 1000 + mock_service.get_features.return_value = mock_features + yield mock_service + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean = Mock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + yield mock_factory + + +@pytest.fixture +def mock_tenant_isolated_queue(): + """Mock TenantIsolatedTaskQueue.""" + with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as mock_queue_class: + mock_queue = MagicMock(spec=TenantIsolatedTaskQueue) + mock_queue.pull_tasks.return_value = [] + mock_queue.delete_task_key = Mock() + mock_queue.set_task_waiting_time = Mock() + mock_queue_class.return_value = mock_queue + yield mock_queue + + +# ============================================================================ +# Tests for deprecated duplicate_document_indexing_task +# ============================================================================ + + +class TestDuplicateDocumentIndexingTask: + """Tests for the deprecated duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids): + """Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function.""" + # Act + duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id): + """Test duplicate_document_indexing_task with empty document_ids list.""" + # Arrange + document_ids = [] + + # Act + duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + 
mock_core_func.assert_called_once_with(dataset_id, document_ids) + + +# ============================================================================ +# Tests for _duplicate_document_indexing_task core function +# ============================================================================ + + +class TestDuplicateDocumentIndexingTaskCore: + """Tests for the _duplicate_document_indexing_task core function.""" + + def test_successful_duplicate_document_indexing( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + mock_document_segments, + dataset_id, + document_ids, + ): + """Test successful duplicate document indexing flow.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Verify IndexingRunner was called + mock_indexing_runner.run.assert_called_once() + + # Verify all documents were set to parsing status + for doc in mock_documents: + assert doc.indexing_status == "parsing" + assert doc.processing_started_at is not None + + # Verify session operations + assert mock_db_session.commit.called + assert mock_db_session.close.called + + def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids): + """Test duplicate document indexing when dataset is not found.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should close the session at least once + assert mock_db_session.close.called + + def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( + self, + mock_db_session, + mock_feature_service, + mock_dataset, + dataset_id, + document_ids, + ): + """Test duplicate document indexing with billing enabled and sandbox plan.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_features = mock_feature_service.get_features.return_value + mock_features.billing.enabled = True + mock_features.billing.subscription.plan = CloudPlan.SANDBOX + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # For sandbox plan with multiple documents, should fail + mock_db_session.commit.assert_called() + + def test_duplicate_document_indexing_with_billing_limit_exceeded( + self, + mock_db_session, + mock_feature_service, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when billing limit is exceeded.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean + mock_features = mock_feature_service.get_features.return_value + mock_features.billing.enabled = True + mock_features.billing.subscription.plan = CloudPlan.TEAM + mock_features.vector_space.size = 990 + mock_features.vector_space.limit = 1000 + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should commit the session + assert mock_db_session.commit.called + # Should close the session + assert mock_db_session.close.called + + def test_duplicate_document_indexing_runner_error( + self, + mock_db_session, + 
mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when IndexingRunner raises an error.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] + mock_indexing_runner.run.side_effect = Exception("Indexing error") + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should close the session even after error + mock_db_session.close.assert_called_once() + + def test_duplicate_document_indexing_document_is_paused( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when document is paused.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should handle DocumentIsPausedError gracefully + mock_db_session.close.assert_called_once() + + def test_duplicate_document_indexing_cleans_old_segments( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + mock_document_segments, + dataset_id, + document_ids, + ): + """Test that duplicate document indexing cleans old segments.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Verify clean was called for each document + assert mock_processor.clean.call_count == len(mock_documents) + + # Verify segments were deleted + for segment in mock_document_segments: + mock_db_session.delete.assert_any_call(segment) + + +# ============================================================================ +# Tests for tenant queue wrapper function +# ============================================================================ + + +class TestDuplicateDocumentIndexingTaskWithTenantQueue: + """Tests for _duplicate_document_indexing_task_with_tenant_queue function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_calls_core_function( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper calls the core function.""" + # Arrange + mock_task_func = Mock() + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_deletes_key_when_no_tasks( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper deletes task key when no more 
tasks.""" + # Arrange + mock_task_func = Mock() + mock_tenant_isolated_queue.pull_tasks.return_value = [] + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_tenant_isolated_queue.delete_task_key.assert_called_once() + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_processes_next_tasks( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper processes next tasks from queue.""" + # Arrange + mock_task_func = Mock() + next_task = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": document_ids, + } + mock_tenant_isolated_queue.pull_tasks.return_value = [next_task] + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_tenant_isolated_queue.set_task_waiting_time.assert_called_once() + mock_task_func.delay.assert_called_once_with( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_ids=document_ids, + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_handles_core_function_error( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper handles errors from core function.""" + # Arrange + mock_task_func = Mock() + mock_core_func.side_effect = Exception("Core function error") + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + # Should still check for next tasks even after error + mock_tenant_isolated_queue.pull_tasks.assert_called_once() + + +# ============================================================================ +# Tests for normal_duplicate_document_indexing_task +# ============================================================================ + + +class TestNormalDuplicateDocumentIndexingTask: + """Tests for normal_duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_normal_task_calls_tenant_queue_wrapper( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + document_ids, + ): + """Test that normal task calls tenant queue wrapper.""" + # Act + normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_normal_task_with_empty_document_ids( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test normal task with empty document_ids list.""" + # Arrange + document_ids = [] + + # Act + normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + +# ============================================================================ +# Tests for priority_duplicate_document_indexing_task +# ============================================================================ + + +class TestPriorityDuplicateDocumentIndexingTask: + """Tests for priority_duplicate_document_indexing_task 
function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_calls_tenant_queue_wrapper( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + document_ids, + ): + """Test that priority task calls tenant queue wrapper.""" + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_with_single_document( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test priority task with single document.""" + # Arrange + document_ids = ["doc-1"] + + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_with_large_batch( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test priority task with large batch of documents.""" + # Arrange + document_ids = [f"doc-{i}" for i in range(100)] + + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) From c6eb18daaec4ed543b6afa5ea355d754977612cd Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 9 Dec 2025 10:22:02 +0800 Subject: [PATCH 05/23] =?UTF-8?q?feat:=20charset=5Fnormalizer=20for=20bett?= =?UTF-8?q?er=20encoding=20detection=20than=20httpx's=20d=E2=80=A6=20(#292?= =?UTF-8?q?64)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../workflow/nodes/http_request/entities.py | 29 +++++- .../nodes/http_request/test_entities.py | 93 +++++++++++++++++++ 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 5a7db6e0e6..e323533835 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from email.message import Message from typing import Any, Literal +import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -96,10 +97,12 @@ class HttpRequestNodeData(BaseNodeData): class Response: headers: dict[str, str] response: httpx.Response + _cached_text: str | None def __init__(self, response: httpx.Response): self.response = response self.headers = dict(response.headers) + self._cached_text = None @property def is_file(self): @@ -159,7 +162,31 @@ class Response: @property def text(self) -> str: - return self.response.text + """ + Get response text with robust encoding detection. + + Uses charset_normalizer for better encoding detection than httpx's default, + which helps handle Chinese and other non-ASCII characters properly. 
+ """ + # Check cache first + if hasattr(self, "_cached_text") and self._cached_text is not None: + return self._cached_text + + # Try charset_normalizer for robust encoding detection first + detected_encoding = charset_normalizer.from_bytes(self.response.content).best() + if detected_encoding and detected_encoding.encoding: + try: + text = self.response.content.decode(detected_encoding.encoding) + self._cached_text = text + return text + except (UnicodeDecodeError, TypeError, LookupError): + # Fallback to httpx's encoding detection if charset_normalizer fails + pass + + # Fallback to httpx's built-in encoding detection + text = self.response.text + self._cached_text = text + return text @property def content(self) -> bytes: diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py index 0f6b7e4ab6..47a5df92a4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -1,3 +1,4 @@ +import json from unittest.mock import Mock, PropertyMock, patch import httpx @@ -138,3 +139,95 @@ def test_is_file_with_no_content_disposition(mock_response): type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) response = Response(mock_response) assert response.is_file + + +# UTF-8 Encoding Tests +@pytest.mark.parametrize( + ("content_bytes", "expected_text", "description"), + [ + # Chinese UTF-8 bytes + ( + b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', + '{"message": "你好世界"}', + "Chinese characters UTF-8", + ), + # Japanese UTF-8 bytes + ( + b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', + '{"message": "こんにちは"}', + "Japanese characters UTF-8", + ), + # Korean UTF-8 bytes + ( + b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', + '{"message": "안녕하세요"}', + "Korean characters UTF-8", + ), + # Arabic UTF-8 + (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"), + # European characters UTF-8 + (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"), + # Simple ASCII + (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), + ], +) +def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): + """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" + mock_response.headers = {"content-type": "application/json; charset=utf-8"} + type(mock_response).content = PropertyMock(return_value=content_bytes) + # Mock httpx response.text to return something different (simulating potential encoding issues) + mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property + + response = Response(mock_response) + + # Our enhanced text property should decode properly using charset_normalizer + assert response.text == expected_text, ( + f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" + ) + + +def test_text_property_fallback_to_httpx(mock_response): + """Test that Response.text falls back to httpx.text when charset_normalizer fails""" + mock_response.headers = {"content-type": "application/json"} + + # Create malformed UTF-8 bytes + malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' + type(mock_response).content = 
PropertyMock(return_value=malformed_bytes) + + # Mock httpx.text to return some fallback value + fallback_text = '{"text": "fallback"}' + mock_response.text = fallback_text + + response = Response(mock_response) + + # Should fall back to httpx's text when charset_normalizer fails + assert response.text == fallback_text + + +@pytest.mark.parametrize( + ("json_content", "description"), + [ + # JSON with escaped Unicode (like Flask jsonify()) + ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), + # JSON with mixed escape sequences and UTF-8 + ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), + # JSON with complex escape sequences + ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), + ], +) +def test_text_property_with_escaped_unicode(mock_response, json_content, description): + """Test Response.text with JSON containing Unicode escape sequences""" + mock_response.headers = {"content-type": "application/json"} + + content_bytes = json_content.encode("utf-8") + type(mock_response).content = PropertyMock(return_value=content_bytes) + mock_response.text = json_content # httpx would return the same for valid UTF-8 + + response = Response(mock_response) + + # Should preserve the escape sequences (valid JSON) + assert response.text == json_content, f"Failed for {description}" + + # The text should be valid JSON that can be parsed back to proper Unicode + parsed = json.loads(response.text) + assert isinstance(parsed, dict), f"Invalid JSON for {description}" From ca61bb5de0c10bd96ec6e8b604986629c2298b35 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 9 Dec 2025 10:23:29 +0800 Subject: [PATCH 06/23] fix: Weaviate was not closed properly (#29301) --- .../rag/datasource/vdb/weaviate/weaviate_vector.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 2c7bc592c0..84d1e26b34 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -79,6 +79,18 @@ class WeaviateVector(BaseVector): self._client = self._init_client(config) self._attributes = attributes + def __del__(self): + """ + Destructor to properly close the Weaviate client connection. + Prevents connection leaks and resource warnings. + """ + if hasattr(self, "_client") and self._client is not None: + try: + self._client.close() + except Exception as e: + # Ignore errors during cleanup as object is being destroyed + logger.warning("Error closing Weaviate client %s", e, exc_info=True) + def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient: """ Initializes and returns a connected Weaviate client. 
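
The __del__ hook added in PATCH 06/23 above is best-effort: CPython only runs
it when the object is actually finalized, which reference cycles can delay and
interpreter shutdown can skip entirely. Where the vector store's lifetime is
scoped, deterministic cleanup can be layered on top. A minimal sketch, assuming
only that the wrapped object holds a closable `_client`; the helper name is
illustrative, not part of the patch:

    from contextlib import contextmanager

    @contextmanager
    def closing_vector(vector):
        """Yield the vector store, then close its client on scope exit."""
        try:
            yield vector
        finally:
            client = getattr(vector, "_client", None)
            if client is not None:
                try:
                    client.close()
                except Exception:
                    # Mirror the patch: cleanup errors must never mask an
                    # exception raised inside the with-block.
                    pass

Used as `with closing_vector(store) as s: ...`, the client is closed at the end
of the block instead of whenever the garbage collector gets around to it.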
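
The encoding fix in PATCH 05/23 reduces to a detect-then-fallback decode. The
same idea as a standalone helper, shown as a minimal sketch: `decode_bytes`,
its `fallback` parameter, and the errors="replace" last resort are illustrative
here, not part of the patch (the patch itself falls back to httpx's own
detection instead):

    import charset_normalizer

    def decode_bytes(payload: bytes, fallback: str = "utf-8") -> str:
        # Ask charset_normalizer for its best guess first; this is what
        # lets Chinese, Japanese, or Korean payloads decode correctly even
        # when the server declares no charset, or the wrong one.
        best = charset_normalizer.from_bytes(payload).best()
        if best and best.encoding:
            try:
                return payload.decode(best.encoding)
            except (UnicodeDecodeError, LookupError):
                pass  # detection was wrong; fall through to the fallback
        # Last resort: decode with the fallback encoding, replacing
        # undecodable bytes instead of raising.
        return payload.decode(fallback, errors="replace")

For example, decode_bytes(b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}')
yields '{"message": "你好世界"}', the same pair asserted in the parametrized
tests above.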
From 97d671d9aadbd856a0adc00f9fe71aaf7fedb115 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 9 Dec 2025 10:24:56 +0800 Subject: [PATCH 07/23] feat: Allow Editor role to use Trigger Plugin subscriptions (#29292) --- .../console/workspace/trigger_providers.py | 18 +- .../test_trigger_provider_permissions.py | 244 ++++++++++++++++++ 2 files changed, 257 insertions(+), 5 deletions(-) create mode 100644 api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 69281c6214..268473d6d1 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService from .. import console_ns -from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from ..wraps import ( + account_initialization_required, + edit_permission_required, + is_admin_or_owner_required, + setup_required, +) logger = logging.getLogger(__name__) @@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource): class TriggerSubscriptionListApi(Resource): @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def get(self, provider): """List all trigger subscriptions for the current tenant's provider""" @@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource): @console_ns.expect(parser) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider): """Add a new subscription instance for a trigger provider""" @@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource): class TriggerSubscriptionBuilderGetApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required def get(self, provider, subscription_builder_id): """Get a subscription instance for a trigger provider""" @@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): @console_ns.expect(parser_api) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Verify a subscription instance for a trigger provider""" @@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): @console_ns.expect(parser_update_api) @setup_required @login_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Update a subscription instance for a trigger provider""" @@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): class TriggerSubscriptionBuilderLogsApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required def get(self, provider, subscription_builder_id): """Get the request logs for a subscription instance for a trigger provider""" @@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource): @console_ns.expect(parser_update_api) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Build a subscription instance 
for a trigger provider""" diff --git a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py new file mode 100644 index 0000000000..e55c12e678 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py @@ -0,0 +1,244 @@ +"""Integration tests for Trigger Provider subscription permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.workspace import trigger_providers as trigger_providers_api +from libs.datetime_utils import naive_utc_now +from models import Tenant +from models.account import Account, TenantAccountJoin, TenantAccountRole + + +class TestTriggerProviderSubscriptionPermissions: + """Test permission verification for Trigger Provider subscription endpoints.""" + + @pytest.fixture + def mock_account(self, monkeypatch: pytest.MonkeyPatch): + """Create a mock Account for testing.""" + + account = Account(name="Test User", email="test@example.com") + account.id = str(uuid.uuid4()) + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant(name="Test Tenant") + tenant.id = str(uuid.uuid4()) + + mock_session_instance = mock.Mock() + + mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) + monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) + + mock_scalars_result = mock.Mock() + mock_scalars_result.one.return_value = tenant + monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_session_instance + monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) + + account.current_tenant = tenant + account.current_tenant_id = tenant.id + return account + + @pytest.mark.parametrize( + ("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"), + [ + # Admin/Owner can do everything + (TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200), + (TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200), + # Editor can list, get, update (parameters), but not create, build, or delete + (TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403), + # Normal user cannot do anything + (TenantAccountRole.NORMAL, 403, 403, 403, 403, 403, 403), + # Dataset operator cannot do anything + (TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403), + ], + ) + def test_trigger_subscription_permissions( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_account, + role: TenantAccountRole, + list_status: int, + get_status: int, + update_status: int, + create_status: int, + build_status: int, + delete_status: int, + ): + """Test that different roles have appropriate permissions for trigger subscription operations.""" + # Set user role + mock_account.role = role + + # Mock current user + monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + # Test data + provider = 
"some_provider/some_trigger" + subscription_builder_id = str(uuid.uuid4()) + subscription_id = str(uuid.uuid4()) + + # Mock service methods + mock_list_subscriptions = mock.Mock(return_value=[]) + monkeypatch.setattr( + "services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions", + mock_list_subscriptions, + ) + + mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", + mock_get_subscription_builder, + ) + + mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", + mock_update_subscription_builder, + ) + + mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + mock_create_subscription_builder, + ) + + mock_update_and_build_builder = mock.Mock() + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder", + mock_update_and_build_builder, + ) + + mock_delete_provider = mock.Mock() + mock_delete_plugin_trigger = mock.Mock() + mock_db_session = mock.Mock() + mock_db_session.commit = mock.Mock() + + def mock_session_func(engine=None): + return mock_session_context + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_db_session + mock_session_context.__exit__.return_value = None + + monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func) + monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func) + + monkeypatch.setattr( + "services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider", + mock_delete_provider, + ) + monkeypatch.setattr( + "services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription", + mock_delete_plugin_trigger, + ) + + # Test 1: List subscriptions (should work for Editor, Admin, Owner) + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list", + headers=auth_header, + ) + assert response.status_code == list_status + + # Test 2: Get subscription builder (should work for Editor, Admin, Owner) + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}", + headers=auth_header, + ) + assert response.status_code == get_status + + # Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}", + headers=auth_header, + json={"parameters": {"webhook_url": "https://example.com/webhook"}}, + ) + assert response.status_code == update_status + + # Test 4: Create subscription builder (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create", + headers=auth_header, + json={"credential_type": 
"api_key"}, + ) + assert response.status_code == create_status + + # Test 5: Build/activate subscription (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}", + headers=auth_header, + json={"name": "Test Subscription"}, + ) + assert response.status_code == build_status + + # Test 6: Delete subscription (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete", + headers=auth_header, + ) + assert response.status_code == delete_status + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + # Editor should be able to access logs for debugging + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_trigger_subscription_logs_permissions( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test that different roles have appropriate permissions for accessing subscription logs.""" + # Set user role + mock_account.role = role + + # Mock current user + monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + # Test data + provider = "some_provider/some_trigger" + subscription_builder_id = str(uuid.uuid4()) + + # Mock service method + mock_list_logs = mock.Mock(return_value=[]) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs", + mock_list_logs, + ) + + # Test access to logs + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}", + headers=auth_header, + ) + assert response.status_code == status From a0c8ebf48741645c1eb02646c3f8abaaab9f9a06 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 9 Dec 2025 10:25:33 +0800 Subject: [PATCH 08/23] chore: not slient call external service error (#29290) --- api/services/external_knowledge_service.py | 3 +- .../services/test_external_dataset_service.py | 108 ++++++++++++++++-- 2 files changed, 102 insertions(+), 9 deletions(-) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 27936f6278..40faa85b9a 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -324,4 +324,5 @@ class ExternalDatasetService: ) if response.status_code == 200: return cast(list[Any], response.json().get("records", [])) - return [] + else: + raise ValueError(response.text) diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index c12ea2f7cb..e2d62583f8 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -6,6 +6,7 @@ Target: 1500+ lines of comprehensive test coverage. 
""" import json +import re from datetime import datetime from unittest.mock import MagicMock, Mock, patch @@ -1791,8 +1792,8 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_non_200_status(self, mock_db, mock_process, factory): - """Test retrieval returns empty list on non-200 status.""" + def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory): + """Test that non-200 status code raises Exception with response text.""" # Arrange binding = factory.create_external_knowledge_binding_mock() api = factory.create_external_knowledge_api_mock() @@ -1817,12 +1818,103 @@ class TestExternalDatasetServiceFetchRetrieval: mock_response = MagicMock() mock_response.status_code = 500 + mock_response.text = "Internal Server Error: Database connection failed" mock_process.return_value = mock_response - # Act - result = ExternalDatasetService.fetch_external_knowledge_retrieval( - "tenant-123", "dataset-123", "query", {"top_k": 5} - ) + # Act & Assert + with pytest.raises(Exception, match="Internal Server Error: Database connection failed"): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) - # Assert - assert result == [] + @pytest.mark.parametrize( + ("status_code", "error_message"), + [ + (400, "Bad Request: Invalid query parameters"), + (401, "Unauthorized: Invalid API key"), + (403, "Forbidden: Access denied to resource"), + (404, "Not Found: Knowledge base not found"), + (429, "Too Many Requests: Rate limit exceeded"), + (500, "Internal Server Error: Database connection failed"), + (502, "Bad Gateway: External service unavailable"), + (503, "Service Unavailable: Maintenance mode"), + ], + ) + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_various_error_status_codes( + self, mock_db, mock_process, factory, status_code, error_message + ): + """Test that various error status codes raise exceptions with response text.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-123" + + binding = factory.create_external_knowledge_binding_mock( + dataset_id=dataset_id, external_knowledge_api_id="api-123" + ) + api = factory.create_external_knowledge_api_mock(api_id="api-123") + + mock_binding_query = MagicMock() + mock_api_query = MagicMock() + + def query_side_effect(model): + if model == ExternalKnowledgeBindings: + return mock_binding_query + elif model == ExternalKnowledgeApis: + return mock_api_query + return MagicMock() + + mock_db.session.query.side_effect = query_side_effect + + mock_binding_query.filter_by.return_value = mock_binding_query + mock_binding_query.first.return_value = binding + + mock_api_query.filter_by.return_value = mock_api_query + mock_api_query.first.return_value = api + + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = error_message + mock_process.return_value = mock_response + + # Act & Assert + with pytest.raises(ValueError, match=re.escape(error_message)): + ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5}) + + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + 
@patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory): + """Test exception with empty response text.""" + # Arrange + binding = factory.create_external_knowledge_binding_mock() + api = factory.create_external_knowledge_api_mock() + + mock_binding_query = MagicMock() + mock_api_query = MagicMock() + + def query_side_effect(model): + if model == ExternalKnowledgeBindings: + return mock_binding_query + elif model == ExternalKnowledgeApis: + return mock_api_query + return MagicMock() + + mock_db.session.query.side_effect = query_side_effect + + mock_binding_query.filter_by.return_value = mock_binding_query + mock_binding_query.first.return_value = binding + + mock_api_query.filter_by.return_value = mock_api_query + mock_api_query.first.return_value = api + + mock_response = MagicMock() + mock_response.status_code = 503 + mock_response.text = "" + mock_process.return_value = mock_response + + # Act & Assert + with pytest.raises(Exception, match=""): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) From 48efd2d174cf457608b9d0913caa095671f2e449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Tue, 9 Dec 2025 11:00:37 +0800 Subject: [PATCH 09/23] fix: try-to-ask misalign (#29309) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- web/app/components/base/chat/chat/try-to-ask.tsx | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/web/app/components/base/chat/chat/try-to-ask.tsx b/web/app/components/base/chat/chat/try-to-ask.tsx index 7e3dcc95f9..3fc690361e 100644 --- a/web/app/components/base/chat/chat/try-to-ask.tsx +++ b/web/app/components/base/chat/chat/try-to-ask.tsx @@ -4,7 +4,6 @@ import { useTranslation } from 'react-i18next' import type { OnSend } from '../types' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' type TryToAskProps = { suggestedQuestions: string[] @@ -20,12 +19,12 @@ const TryToAsk: FC = ({ return (
-
- +
+
{t('appDebug.feature.suggestedQuestionsAfterAnswer.tryToAsk')}
- {!isMobile && } +
-
+
{ suggestedQuestions.map((suggestQuestion, index) => ( + + Esc + +
+ {cachedImages[currentImage.url].status === 'loading' && ( + + )} + {cachedImages[currentImage.url].status === 'error' && ( +
+ {`Failed to load image: ${currentImage.url}. Please try again.`} + +
+ )} + {cachedImages[currentImage.url].status === 'loaded' && ( +
+ {currentImage.name} +
+ {currentImage.name} + · + {`${cachedImages[currentImage.url].width} ×  ${cachedImages[currentImage.url].height}`} + · + {formatFileSize(currentImage.size)} +
+
+ )} + + +
, + document.body, + ) +} + +export default ImagePreviewer diff --git a/web/app/components/datasets/common/image-uploader/constants.ts b/web/app/components/datasets/common/image-uploader/constants.ts new file mode 100644 index 0000000000..671ed94fcf --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/constants.ts @@ -0,0 +1,7 @@ +export const ACCEPT_TYPES = ['jpg', 'jpeg', 'png', 'gif'] + +export const DEFAULT_IMAGE_FILE_SIZE_LIMIT = 2 + +export const DEFAULT_IMAGE_FILE_BATCH_LIMIT = 5 + +export const DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT = 10 diff --git a/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts b/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts new file mode 100644 index 0000000000..aefe48f0cd --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts @@ -0,0 +1,273 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useFileUploadConfig } from '@/service/use-common' +import type { FileEntity, FileUploadConfig } from '../types' +import { getFileType, getFileUploadConfig, traverseFileEntry } from '../utils' +import Toast from '@/app/components/base/toast' +import { useTranslation } from 'react-i18next' +import { ACCEPT_TYPES } from '../constants' +import { useFileStore } from '../store' +import { produce } from 'immer' +import { fileUpload, getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' +import { v4 as uuid4 } from 'uuid' + +export const useUpload = () => { + const { t } = useTranslation() + const fileStore = useFileStore() + + const [dragging, setDragging] = useState(false) + const uploaderRef = useRef(null) + const dragRef = useRef(null) + const dropRef = useRef(null) + + const { data: fileUploadConfigResponse } = useFileUploadConfig() + + const fileUploadConfig: FileUploadConfig = useMemo(() => { + return getFileUploadConfig(fileUploadConfigResponse) + }, [fileUploadConfigResponse]) + + const handleDragEnter = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target !== dragRef.current) + setDragging(true) + } + const handleDragOver = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + } + const handleDragLeave = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target === dragRef.current) + setDragging(false) + } + + const checkFileType = useCallback((file: File) => { + const ext = getFileType(file) + return ACCEPT_TYPES.includes(ext.toLowerCase()) + }, []) + + const checkFileSize = useCallback((file: File) => { + const { size } = file + return size <= fileUploadConfig.imageFileSizeLimit * 1024 * 1024 + }, [fileUploadConfig]) + + const showErrorMessage = useCallback((type: 'type' | 'size') => { + if (type === 'type') + Toast.notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') }) + else + Toast.notify({ type: 'error', message: t('dataset.imageUploader.fileSizeLimitExceeded', { size: fileUploadConfig.imageFileSizeLimit }) }) + }, [fileUploadConfig, t]) + + const getValidFiles = useCallback((files: File[]) => { + let validType = true + let validSize = true + const validFiles = files.filter((file) => { + if (!checkFileType(file)) { + validType = false + return false + } + if (!checkFileSize(file)) { + validSize = false + return false + } + return true + }) + if (!validType) + showErrorMessage('type') + else if (!validSize) + showErrorMessage('size') + + return validFiles + }, [checkFileType, checkFileSize, showErrorMessage]) + + const 
selectHandle = () => { + if (uploaderRef.current) + uploaderRef.current.click() + } + + const handleAddFile = useCallback((newFile: FileEntity) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = produce(files, (draft) => { + draft.push(newFile) + }) + setFiles(newFiles) + }, [fileStore]) + + const handleUpdateFile = useCallback((newFile: FileEntity) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = produce(files, (draft) => { + const index = draft.findIndex(file => file.id === newFile.id) + + if (index > -1) + draft[index] = newFile + }) + setFiles(newFiles) + }, [fileStore]) + + const handleRemoveFile = useCallback((fileId: string) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = files.filter(file => file.id !== fileId) + setFiles(newFiles) + }, [fileStore]) + + const handleReUploadFile = useCallback((fileId: string) => { + const { + files, + setFiles, + } = fileStore.getState() + const index = files.findIndex(file => file.id === fileId) + + if (index > -1) { + const uploadingFile = files[index] + const newFiles = produce(files, (draft) => { + draft[index].progress = 0 + }) + setFiles(newFiles) + fileUpload({ + file: uploadingFile.originalFile!, + onProgressCallback: (progress) => { + handleUpdateFile({ ...uploadingFile, progress }) + }, + onSuccessCallback: (res) => { + handleUpdateFile({ ...uploadingFile, uploadedId: res.id, progress: 100 }) + }, + onErrorCallback: (error?: any) => { + const errorMessage = getFileUploadErrorMessage(error, t('common.fileUploader.uploadFromComputerUploadError'), t) + Toast.notify({ type: 'error', message: errorMessage }) + handleUpdateFile({ ...uploadingFile, progress: -1 }) + }, + }) + } + }, [fileStore, t, handleUpdateFile]) + + const handleLocalFileUpload = useCallback((file: File) => { + const reader = new FileReader() + const isImage = file.type.startsWith('image') + + reader.addEventListener( + 'load', + () => { + const uploadingFile = { + id: uuid4(), + name: file.name, + extension: getFileType(file), + mimeType: file.type, + size: file.size, + progress: 0, + originalFile: file, + base64Url: isImage ? 
reader.result as string : '', + } + handleAddFile(uploadingFile) + fileUpload({ + file: uploadingFile.originalFile, + onProgressCallback: (progress) => { + handleUpdateFile({ ...uploadingFile, progress }) + }, + onSuccessCallback: (res) => { + handleUpdateFile({ + ...uploadingFile, + extension: res.extension, + mimeType: res.mime_type, + size: res.size, + uploadedId: res.id, + progress: 100, + }) + }, + onErrorCallback: (error?: any) => { + const errorMessage = getFileUploadErrorMessage(error, t('common.fileUploader.uploadFromComputerUploadError'), t) + Toast.notify({ type: 'error', message: errorMessage }) + handleUpdateFile({ ...uploadingFile, progress: -1 }) + }, + }) + }, + false, + ) + reader.addEventListener( + 'error', + () => { + Toast.notify({ type: 'error', message: t('common.fileUploader.uploadFromComputerReadError') }) + }, + false, + ) + reader.readAsDataURL(file) + }, [t, handleAddFile, handleUpdateFile]) + + const handleFileUpload = useCallback((newFiles: File[]) => { + const { files } = fileStore.getState() + const { singleChunkAttachmentLimit } = fileUploadConfig + if (newFiles.length === 0) return + if (files.length + newFiles.length > singleChunkAttachmentLimit) { + Toast.notify({ + type: 'error', + message: t('datasetHitTesting.imageUploader.singleChunkAttachmentLimitTooltip', { limit: singleChunkAttachmentLimit }), + }) + return + } + for (const file of newFiles) + handleLocalFileUpload(file) + }, [fileUploadConfig, fileStore, t, handleLocalFileUpload]) + + const fileChangeHandle = useCallback((e: React.ChangeEvent) => { + const { imageFileBatchLimit } = fileUploadConfig + const files = Array.from(e.target.files ?? []).slice(0, imageFileBatchLimit) + const validFiles = getValidFiles(files) + handleFileUpload(validFiles) + }, [getValidFiles, handleFileUpload, fileUploadConfig]) + + const handleDrop = useCallback(async (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + setDragging(false) + if (!e.dataTransfer) return + const nested = await Promise.all( + Array.from(e.dataTransfer.items).map((it) => { + const entry = (it as any).webkitGetAsEntry?.() + if (entry) return traverseFileEntry(entry) + const f = it.getAsFile?.() + return f ? 
Promise.resolve([f]) : Promise.resolve([]) + }), + ) + const files = nested.flat().slice(0, fileUploadConfig.imageFileBatchLimit) + const validFiles = getValidFiles(files) + handleFileUpload(validFiles) + }, [fileUploadConfig, handleFileUpload, getValidFiles]) + + useEffect(() => { + dropRef.current?.addEventListener('dragenter', handleDragEnter) + dropRef.current?.addEventListener('dragover', handleDragOver) + dropRef.current?.addEventListener('dragleave', handleDragLeave) + dropRef.current?.addEventListener('drop', handleDrop) + return () => { + dropRef.current?.removeEventListener('dragenter', handleDragEnter) + dropRef.current?.removeEventListener('dragover', handleDragOver) + dropRef.current?.removeEventListener('dragleave', handleDragLeave) + dropRef.current?.removeEventListener('drop', handleDrop) + } + }, [handleDrop]) + + return { + dragging, + fileUploadConfig, + dragRef, + dropRef, + uploaderRef, + fileChangeHandle, + selectHandle, + handleRemoveFile, + handleReUploadFile, + handleLocalFileUpload, + } +} diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx new file mode 100644 index 0000000000..3e15b92705 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx @@ -0,0 +1,64 @@ +import React from 'react' +import cn from '@/utils/classnames' +import { RiUploadCloud2Line } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useUpload } from '../hooks/use-upload' +import { ACCEPT_TYPES } from '../constants' + +const ImageUploader = () => { + const { t } = useTranslation() + + const { + dragging, + fileUploadConfig, + dragRef, + dropRef, + uploaderRef, + fileChangeHandle, + selectHandle, + } = useUpload() + + return ( +
+ `.${ext}`).join(',')} + onChange={fileChangeHandle} + /> +
+
+ +
+ {t('dataset.imageUploader.button')} + + {t('dataset.imageUploader.browse')} + +
+
+
+ {t('dataset.imageUploader.tip', { + size: fileUploadConfig.imageFileSizeLimit, + supportTypes: ACCEPT_TYPES.join(', '), + batchCount: fileUploadConfig.imageFileBatchLimit, + })} +
+ {dragging &&
} +
+
+ ) +} + +export default React.memo(ImageUploader) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx new file mode 100644 index 0000000000..a5bfb65fa2 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx @@ -0,0 +1,95 @@ +import { + memo, + useCallback, +} from 'react' +import { + RiCloseLine, +} from '@remixicon/react' +import FileImageRender from '@/app/components/base/file-uploader/file-image-render' +import type { FileEntity } from '../types' +import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import { ReplayLine } from '@/app/components/base/icons/src/vender/other' +import { fileIsUploaded } from '../utils' +import Button from '@/app/components/base/button' + +type ImageItemProps = { + file: FileEntity + showDeleteAction?: boolean + onRemove?: (fileId: string) => void + onReUpload?: (fileId: string) => void + onPreview?: (fileId: string) => void +} +const ImageItem = ({ + file, + showDeleteAction, + onRemove, + onReUpload, + onPreview, +}: ImageItemProps) => { + const { id, progress, base64Url, sourceUrl } = file + + const handlePreview = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onPreview?.(id) + }, [onPreview, id]) + + const handleRemove = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onRemove?.(id) + }, [onRemove, id]) + + const handleReUpload = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onReUpload?.(id) + }, [onReUpload, id]) + + return ( +
+ { + showDeleteAction && ( + + ) + } + + { + progress >= 0 && !fileIsUploaded(file) && ( +
+ +
+ ) + } + { + progress === -1 && ( +
+ +
+ ) + } +
+ ) +} + +export default memo(ImageItem) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx new file mode 100644 index 0000000000..3efa3a19d7 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx @@ -0,0 +1,94 @@ +import { + FileContextProvider, + useFileStoreWithSelector, +} from '../store' +import type { FileEntity } from '../types' +import FileItem from './image-item' +import { useUpload } from '../hooks/use-upload' +import ImageInput from './image-input' +import cn from '@/utils/classnames' +import { useCallback, useState } from 'react' +import type { ImageInfo } from '@/app/components/datasets/common/image-previewer' +import ImagePreviewer from '@/app/components/datasets/common/image-previewer' + +type ImageUploaderInChunkProps = { + disabled?: boolean + className?: string +} +const ImageUploaderInChunk = ({ + disabled, + className, +}: ImageUploaderInChunkProps) => { + const files = useFileStoreWithSelector(s => s.files) + const [previewIndex, setPreviewIndex] = useState(0) + const [previewImages, setPreviewImages] = useState([]) + + const handleImagePreview = useCallback((fileId: string) => { + const index = files.findIndex(item => item.id === fileId) + if (index === -1) return + setPreviewIndex(index) + setPreviewImages(files.map(item => ({ + url: item.base64Url || item.sourceUrl || '', + name: item.name, + size: item.size, + }))) + }, [files]) + + const handleClosePreview = useCallback(() => { + setPreviewImages([]) + }, []) + + const { + handleRemoveFile, + handleReUploadFile, + } = useUpload() + + return ( +
+ {!disabled && } +
+ { + files.map(file => ( + + )) + } +
+ {previewImages.length > 0 && ( + + )} +
+ ) +} + +export type ImageUploaderInChunkWrapperProps = { + value?: FileEntity[] + onChange: (files: FileEntity[]) => void +} & ImageUploaderInChunkProps + +const ImageUploaderInChunkWrapper = ({ + value, + onChange, + ...props +}: ImageUploaderInChunkWrapperProps) => { + return ( + + + + ) +} + +export default ImageUploaderInChunkWrapper diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx new file mode 100644 index 0000000000..4f230e3957 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx @@ -0,0 +1,64 @@ +import React from 'react' +import { useTranslation } from 'react-i18next' +import { useUpload } from '../hooks/use-upload' +import { ACCEPT_TYPES } from '../constants' +import { useFileStoreWithSelector } from '../store' +import { RiImageAddLine } from '@remixicon/react' +import Tooltip from '@/app/components/base/tooltip' + +const ImageUploader = () => { + const { t } = useTranslation() + const files = useFileStoreWithSelector(s => s.files) + + const { + fileUploadConfig, + uploaderRef, + fileChangeHandle, + selectHandle, + } = useUpload() + + return ( +
+ `.${ext}`).join(',')} + onChange={fileChangeHandle} + /> +
+ +
+
+ +
+ {files.length === 0 && ( + + {t('datasetHitTesting.imageUploader.tip', { + size: fileUploadConfig.imageFileSizeLimit, + batchCount: fileUploadConfig.imageFileBatchLimit, + })} + + )} +
+
+
+
+ ) +} + +export default React.memo(ImageUploader) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx new file mode 100644 index 0000000000..a47356e560 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx @@ -0,0 +1,95 @@ +import { + memo, + useCallback, +} from 'react' +import { + RiCloseLine, +} from '@remixicon/react' +import FileImageRender from '@/app/components/base/file-uploader/file-image-render' +import type { FileEntity } from '../types' +import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import { ReplayLine } from '@/app/components/base/icons/src/vender/other' +import { fileIsUploaded } from '../utils' +import Button from '@/app/components/base/button' + +type ImageItemProps = { + file: FileEntity + showDeleteAction?: boolean + onRemove?: (fileId: string) => void + onReUpload?: (fileId: string) => void + onPreview?: (fileId: string) => void +} +const ImageItem = ({ + file, + showDeleteAction, + onRemove, + onReUpload, + onPreview, +}: ImageItemProps) => { + const { id, progress, base64Url, sourceUrl } = file + + const handlePreview = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onPreview?.(id) + }, [onPreview, id]) + + const handleRemove = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onRemove?.(id) + }, [onRemove, id]) + + const handleReUpload = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onReUpload?.(id) + }, [onReUpload, id]) + + return ( +
+ { + showDeleteAction && ( + + ) + } + + { + progress >= 0 && !fileIsUploaded(file) && ( +
+ +
+ ) + } + { + progress === -1 && ( +
+ +
+ ) + } +
+ ) +} + +export default memo(ImageItem) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx new file mode 100644 index 0000000000..2d04132842 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx @@ -0,0 +1,131 @@ +import { + useCallback, + useState, +} from 'react' +import { + FileContextProvider, +} from '../store' +import type { FileEntity } from '../types' +import { useUpload } from '../hooks/use-upload' +import ImageInput from './image-input' +import cn from '@/utils/classnames' +import { useTranslation } from 'react-i18next' +import { useFileStoreWithSelector } from '../store' +import ImageItem from './image-item' +import type { ImageInfo } from '@/app/components/datasets/common/image-previewer' +import ImagePreviewer from '@/app/components/datasets/common/image-previewer' + +type ImageUploaderInRetrievalTestingProps = { + textArea: React.ReactNode + actionButton: React.ReactNode + showUploader?: boolean + className?: string + actionAreaClassName?: string +} +const ImageUploaderInRetrievalTesting = ({ + textArea, + actionButton, + showUploader = true, + className, + actionAreaClassName, +}: ImageUploaderInRetrievalTestingProps) => { + const { t } = useTranslation() + const files = useFileStoreWithSelector(s => s.files) + const [previewIndex, setPreviewIndex] = useState(0) + const [previewImages, setPreviewImages] = useState([]) + const { + dragging, + dragRef, + dropRef, + handleRemoveFile, + handleReUploadFile, + } = useUpload() + + const handleImagePreview = useCallback((fileId: string) => { + const index = files.findIndex(item => item.id === fileId) + if (index === -1) return + setPreviewIndex(index) + setPreviewImages(files.map(item => ({ + url: item.base64Url || item.sourceUrl || '', + name: item.name, + size: item.size, + }))) + }, [files]) + + const handleClosePreview = useCallback(() => { + setPreviewImages([]) + }, []) + + return ( +
+ {dragging && ( +
+
{t('datasetHitTesting.imageUploader.dropZoneTip')}
+
+
+ )} + {textArea} + { + showUploader && !!files.length && ( +
+ { + files.map(file => ( + + )) + } +
+ ) + } +
+ {showUploader && } + {actionButton} +
+ {previewImages.length > 0 && ( + + )} +
+  )
+}
+
+export type ImageUploaderInRetrievalTestingWrapperProps = {
+  value?: FileEntity[]
+  onChange: (files: FileEntity[]) => void
+} & ImageUploaderInRetrievalTestingProps
+
+const ImageUploaderInRetrievalTestingWrapper = ({
+  value,
+  onChange,
+  ...props
+}: ImageUploaderInRetrievalTestingWrapperProps) => {
+  return (
+    <FileContextProvider value={value} onChange={onChange}>
+      <ImageUploaderInRetrievalTesting {...props} />
+    </FileContextProvider>
+  )
+}
+
+export default ImageUploaderInRetrievalTestingWrapper
diff --git a/web/app/components/datasets/common/image-uploader/store.tsx b/web/app/components/datasets/common/image-uploader/store.tsx
new file mode 100644
index 0000000000..e3c9e28a84
--- /dev/null
+++ b/web/app/components/datasets/common/image-uploader/store.tsx
@@ -0,0 +1,67 @@
+import {
+  createContext,
+  useContext,
+  useRef,
+} from 'react'
+import {
+  create,
+  useStore,
+} from 'zustand'
+import type {
+  FileEntity,
+} from './types'
+
+type Shape = {
+  files: FileEntity[]
+  setFiles: (files: FileEntity[]) => void
+}
+
+export const createFileStore = (
+  value: FileEntity[] = [],
+  onChange?: (files: FileEntity[]) => void,
+) => {
+  return create<Shape>(set => ({
+    files: value ? [...value] : [],
+    setFiles: (files) => {
+      set({ files })
+      onChange?.(files)
+    },
+  }))
+}
+
+type FileStore = ReturnType<typeof createFileStore>
+export const FileContext = createContext<FileStore | null>(null)
+
+export function useFileStoreWithSelector<T>(selector: (state: Shape) => T): T {
+  const store = useContext(FileContext)
+  if (!store)
+    throw new Error('Missing FileContext.Provider in the tree')
+
+  return useStore(store, selector)
+}
+
+export const useFileStore = () => {
+  return useContext(FileContext)!
+}
+
+type FileProviderProps = {
+  children: React.ReactNode
+  value?: FileEntity[]
+  onChange?: (files: FileEntity[]) => void
+}
+export const FileContextProvider = ({
+  children,
+  value,
+  onChange,
+}: FileProviderProps) => {
+  const storeRef = useRef<FileStore | undefined>(undefined)
+
+  if (!storeRef.current)
+    storeRef.current = createFileStore(value, onChange)
+
+  return (
+    <FileContext.Provider value={storeRef.current}>
+      {children}
+    </FileContext.Provider>
+  )
+}
diff --git a/web/app/components/datasets/common/image-uploader/types.ts b/web/app/components/datasets/common/image-uploader/types.ts
new file mode 100644
index 0000000000..e918f2b41e
--- /dev/null
+++ b/web/app/components/datasets/common/image-uploader/types.ts
@@ -0,0 +1,18 @@
+export type FileEntity = {
+  id: string
+  name: string
+  size: number
+  extension: string
+  mimeType: string
+  progress: number // -1: error, 0 ~ 99: uploading, 100: uploaded
+  originalFile?: File // used for re-uploading
+  uploadedId?: string // for uploaded image id
+  sourceUrl?: string // for uploaded image
+  base64Url?: string // for image preview during uploading
+}
+
+export type FileUploadConfig = {
+  imageFileSizeLimit: number
+  imageFileBatchLimit: number
+  singleChunkAttachmentLimit: number
+}
diff --git a/web/app/components/datasets/common/image-uploader/utils.ts b/web/app/components/datasets/common/image-uploader/utils.ts
new file mode 100644
index 0000000000..842b279a98
--- /dev/null
+++ b/web/app/components/datasets/common/image-uploader/utils.ts
@@ -0,0 +1,92 @@
+import type { FileUploadConfigResponse } from '@/models/common'
+import type { FileEntity } from './types'
+import {
+  DEFAULT_IMAGE_FILE_BATCH_LIMIT,
+  DEFAULT_IMAGE_FILE_SIZE_LIMIT,
+  DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT,
+} from './constants'
+
+export const getFileType = (currentFile: File) => {
+  if (!currentFile)
+    return ''
+
+  const arr = currentFile.name.split('.')
+  return arr[arr.length - 1]
+}
+
+type FileWithPath = {
+  relativePath?: string
+} & File
+
+export const traverseFileEntry = (entry: any, prefix = ''): Promise<FileWithPath[]> => {
+  return new Promise((resolve) => {
+    if (entry.isFile) {
+      entry.file((file: FileWithPath) => {
+        file.relativePath = `${prefix}${file.name}`
+        resolve([file])
+      })
+    }
+    else if (entry.isDirectory) {
+      const reader = entry.createReader()
+      const entries: any[] = []
+      const read = () => {
+        reader.readEntries(async (results: FileSystemEntry[]) => {
+          if (!results.length) {
+            const files = await Promise.all(
+              entries.map(ent =>
+                traverseFileEntry(ent, `${prefix}${entry.name}/`),
+              ),
+            )
+            resolve(files.flat())
+          }
+          else {
+            entries.push(...results)
+            read()
+          }
+        })
+      }
+      read()
+    }
+    else {
+      resolve([])
+    }
+  })
+}
+
+export const fileIsUploaded = (file: FileEntity) => {
+  if (file.uploadedId || file.progress === 100)
+    return true
+}
+
+const getNumberValue = (value: number | string | undefined | null): number => {
+  if (value === undefined || value === null)
+    return 0
+  if (typeof value === 'number')
+    return value
+  if (typeof value === 'string')
+    return Number(value)
+  return 0
+}
+
+export const getFileUploadConfig = (fileUploadConfigResponse: FileUploadConfigResponse | undefined) => {
+  if (!fileUploadConfigResponse) {
+    return {
+      imageFileSizeLimit: DEFAULT_IMAGE_FILE_SIZE_LIMIT,
+      imageFileBatchLimit: DEFAULT_IMAGE_FILE_BATCH_LIMIT,
+      singleChunkAttachmentLimit: DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT,
+    }
+  }
+  const {
+    image_file_batch_limit,
+    single_chunk_attachment_limit,
+    attachment_image_file_size_limit,
+  } = fileUploadConfigResponse
+  const imageFileSizeLimit = getNumberValue(attachment_image_file_size_limit)
+  const imageFileBatchLimit = getNumberValue(image_file_batch_limit)
+  const singleChunkAttachmentLimit = getNumberValue(single_chunk_attachment_limit)
+  return {
+    imageFileSizeLimit: imageFileSizeLimit > 0 ? imageFileSizeLimit : DEFAULT_IMAGE_FILE_SIZE_LIMIT,
+    imageFileBatchLimit: imageFileBatchLimit > 0 ? imageFileBatchLimit : DEFAULT_IMAGE_FILE_BATCH_LIMIT,
+    singleChunkAttachmentLimit: singleChunkAttachmentLimit > 0 ? singleChunkAttachmentLimit : DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT,
+  }
+}
diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx
index ed230c52ce..c0952ed4a4 100644
--- a/web/app/components/datasets/common/retrieval-method-config/index.tsx
+++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx
@@ -20,12 +20,14 @@ import { EffectColor } from '../../settings/chunk-structure/types'
 type Props = {
   disabled?: boolean
   value: RetrievalConfig
+  showMultiModalTip?: boolean
   onChange: (value: RetrievalConfig) => void
 }
 
 const RetrievalMethodConfig: FC<Props> = ({
   disabled = false,
   value,
+  showMultiModalTip = false,
   onChange,
 }) => {
   const { t } = useTranslation()
@@ -110,6 +112,7 @@
           <RetrievalParamConfig
             type={RETRIEVE_METHOD.semantic}
             value={value}
             onChange={onChange}
+            showMultiModalTip={showMultiModalTip}
           />
         )}
@@ -132,6 +135,7 @@
           <RetrievalParamConfig
             type={RETRIEVE_METHOD.fullText}
             value={value}
             onChange={onChange}
+            showMultiModalTip={showMultiModalTip}
           />
         )}
@@ -155,6 +159,7 @@
           <RetrievalParamConfig
             type={RETRIEVE_METHOD.hybrid}
             value={value}
             onChange={onChange}
+            showMultiModalTip={showMultiModalTip}
           />
         )}
diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx
index 0c28149d56..2b703cc44d 100644
--- a/web/app/components/datasets/common/retrieval-param-config/index.tsx
+++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx
@@ -24,16 +24,19 @@
 import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
 import Toast from '@/app/components/base/toast'
 import RadioCard from '@/app/components/base/radio-card'
+import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
 
 type Props = {
   type: RETRIEVE_METHOD
   value: RetrievalConfig
+  showMultiModalTip?: boolean
   onChange: (value: RetrievalConfig) => void
 }
 
 const RetrievalParamConfig: FC<Props> = ({
   type,
   value,
+  showMultiModalTip = false,
   onChange,
 }) => {
   const { t } = useTranslation()
@@ -133,19 +136,32 @@
{ value.reranking_enable && ( - { - onChange({ - ...value, - reranking_model: { - reranking_provider_name: v.provider, - reranking_model_name: v.model, - }, - }) - }} - /> + <> + { + onChange({ + ...value, + reranking_model: { + reranking_provider_name: v.provider, + reranking_model_name: v.model, + }, + }) + }} + /> + {showMultiModalTip && ( +
+              <div>
+                <AlertTriangle />
+                <span>
+                  {t('datasetSettings.form.retrievalSetting.multiModalTip')}
+                </span>
+              </div>
+            )}
+          </>
         )
       }
@@ -239,19 +255,32 @@ const RetrievalParamConfig: FC = ({ } { value.reranking_mode !== RerankingModeEnum.WeightedScore && ( - { - onChange({ - ...value, - reranking_model: { - reranking_provider_name: v.provider, - reranking_model_name: v.model, - }, - }) - }} - /> + <> + { + onChange({ + ...value, + reranking_model: { + reranking_provider_name: v.provider, + reranking_model_name: v.model, + }, + }) + }} + /> + {showMultiModalTip && ( +
+              <div>
+                <AlertTriangle />
+                <span>
+                  {t('datasetSettings.form.retrievalSetting.multiModalTip')}
+                </span>
+              </div>
+            )}
+          </>
         )
       }
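A minimal caller sketch for the new `showMultiModalTip` prop wired through the two diffs above. The panel component and its state are hypothetical, and the RetrievalConfig import path is assumed; only the prop names and RetrievalMethodConfig come from the hunks:

import { useState } from 'react'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import type { RetrievalConfig } from '@/models/datasets' // assumed import path

const RetrievalSettingsPanel = ({ initialConfig, isMultiModalDataset }: {
  initialConfig: RetrievalConfig
  isMultiModalDataset: boolean
}) => {
  const [config, setConfig] = useState(initialConfig)
  return (
    <RetrievalMethodConfig
      value={config}
      onChange={setConfig}
      // The tip renders beside each rerank model selector when this is true.
      showMultiModalTip={isMultiModalDataset}
    />
  )
}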
diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx
index 4aec0d4082..d258ed694e 100644
--- a/web/app/components/datasets/create/file-uploader/index.tsx
+++ b/web/app/components/datasets/create/file-uploader/index.tsx
@@ -68,11 +68,11 @@ const FileUploader = ({
       .join(locale !== LanguagesSupported[1] ? ', ' : '、 ')
   })()
   const ACCEPTS = supportTypes.map((ext: string) => `.${ext}`)
-  const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? {
-    file_size_limit: 15,
-    batch_count_limit: 5,
-    file_upload_limit: 5,
-  }, [fileUploadConfigResponse])
+  const fileUploadConfig = useMemo(() => ({
+    file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15,
+    batch_count_limit: fileUploadConfigResponse?.batch_count_limit ?? 5,
+    file_upload_limit: fileUploadConfigResponse?.file_upload_limit ?? 5,
+  }), [fileUploadConfigResponse])
 
   const fileListRef = useRef<FileItem[]>([])
 
diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx
index 22d6837754..43be89c326 100644
--- a/web/app/components/datasets/create/step-two/index.tsx
+++ b/web/app/components/datasets/create/step-two/index.tsx
@@ -1,6 +1,6 @@
 'use client'
 import type { FC, PropsWithChildren } from 'react'
-import React, { useCallback, useEffect, useState } from 'react'
+import React, { useCallback, useEffect, useMemo, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import {
@@ -63,6 +63,7 @@ import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
 import { noop } from 'lodash-es'
 import { useDocLink } from '@/context/i18n'
 import { useInvalidDatasetList } from '@/service/knowledge/use-dataset'
+import { checkShowMultiModalTip } from '../../settings/utils'
 
 const TextLabel: FC<PropsWithChildren> = (props) => {
   return <label>{props.children}</label>
 }
@@ -495,12 +496,6 @@
       setDefaultConfig(data.rules)
       setLimitMaxChunkLength(data.limits.indexing_max_segmentation_tokens_length)
     },
-    onError(error) {
-      Toast.notify({
-        type: 'error',
-        message: `${error}`,
-      })
-    },
   })
 
   const getRulesFromDetail = () => {
@@ -538,22 +533,8 @@
       setSegmentationType(documentDetail.dataset_process_rule.mode)
   }
 
-  const createFirstDocumentMutation = useCreateFirstDocument({
-    onError(error) {
-      Toast.notify({
-        type: 'error',
-        message: `${error}`,
-      })
-    },
-  })
-  const createDocumentMutation = useCreateDocument(datasetId!, {
-    onError(error) {
-      Toast.notify({
-        type: 'error',
-        message: `${error}`,
-      })
-    },
-  })
+  const createFirstDocumentMutation = useCreateFirstDocument()
+  const createDocumentMutation = useCreateDocument(datasetId!)
   const isCreating = createFirstDocumentMutation.isPending || createDocumentMutation.isPending
 
   const invalidDatasetList = useInvalidDatasetList()
@@ -613,6 +594,20 @@
   const isModelAndRetrievalConfigDisabled = !!datasetId && !!currentDataset?.data_source_type
 
+  const showMultiModalTip = useMemo(() => {
+    return checkShowMultiModalTip({
+      embeddingModel,
+      rerankingEnable: retrievalConfig.reranking_enable,
+      rerankModel: {
+        rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name,
+        rerankingModelName: retrievalConfig.reranking_model.reranking_model_name,
+      },
+      indexMethod: indexType,
+      embeddingModelList,
+      rerankModelList,
+    })
+  }, [embeddingModel, retrievalConfig.reranking_enable, retrievalConfig.reranking_model, indexType, embeddingModelList, rerankModelList])
+
   return (
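checkShowMultiModalTip itself lives outside this patch (web/app/components/datasets/settings/utils.ts); from the call site in the hunk above, its contract is roughly the sketch below — every type here is inferred from the arguments, not taken from the real file:

// Inferred from the useMemo call site only; the actual signature may differ.
type RerankModelRef = {
  rerankingProviderName: string
  rerankingModelName: string
}

declare function checkShowMultiModalTip(args: {
  embeddingModel: unknown // the selected embedding model (shape assumed)
  rerankingEnable: boolean
  rerankModel: RerankModelRef
  indexMethod: string // `indexType` in step-two
  embeddingModelList: unknown[] // provider model lists (element type assumed)
  rerankModelList: unknown[]
}): boolean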
@@ -1012,6 +1007,7 @@
                     <RetrievalMethodConfig
                       disabled={isModelAndRetrievalConfigDisabled}
                       value={retrievalConfig}
                       onChange={setRetrievalConfig}
+                      showMultiModalTip={showMultiModalTip}
                     />
                   ) : (
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx
index 868621e1a3..555f2497ef 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx
@@ -21,8 +21,6 @@ import dynamic from 'next/dynamic'
 
 const SimplePieChart = dynamic(() => import('@/app/components/base/simple-pie-chart'), { ssr: false })
 
-const FILES_NUMBER_LIMIT = 20
-
 export type LocalFileProps = {
   allowedExtensions: string[]
   notSupportBatchUpload?: boolean
@@ -64,10 +62,11 @@ const LocalFile = ({
       .join(locale !== LanguagesSupported[1] ? ', ' : '、 ')
   }, [locale, allowedExtensions])
   const ACCEPTS = allowedExtensions.map((ext: string) => `.${ext}`)
-  const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? {
-    file_size_limit: 15,
-    batch_count_limit: 5,
-  }, [fileUploadConfigResponse])
+  const fileUploadConfig = useMemo(() => ({
+    file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15,
+    batch_count_limit: fileUploadConfigResponse?.batch_count_limit ?? 5,
+    file_upload_limit: fileUploadConfigResponse?.file_upload_limit ?? 5,
+  }), [fileUploadConfigResponse])
 
   const updateFile = useCallback((fileItem: FileItem, progress: number, list: FileItem[]) => {
     const { setLocalFileList } = dataSourceStore.getState()
@@ -186,11 +185,12 @@
   }, [fileUploadConfig, uploadBatchFiles])
 
   const initialUpload = useCallback((files: File[]) => {
+    const filesCountLimit = fileUploadConfig.file_upload_limit
     if (!files.length)
       return false
 
-    if (files.length + localFileList.length > FILES_NUMBER_LIMIT && !IS_CE_EDITION) {
-      notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.filesNumber', { filesNumber: FILES_NUMBER_LIMIT }) })
+    if (files.length + localFileList.length > filesCountLimit && !IS_CE_EDITION) {
+      notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.filesNumber', { filesNumber: filesCountLimit }) })
       return false
     }
 
     updateFileList(newFiles)
     fileListRef.current = newFiles
     uploadMultipleFiles(preparedFiles)
-  }, [updateFileList, uploadMultipleFiles, notify, t, localFileList])
+  }, [fileUploadConfig.file_upload_limit, localFileList.length, updateFileList, uploadMultipleFiles, notify, t])
 
   const handleDragEnter = (e: DragEvent) => {
     e.preventDefault()
@@ -250,9 +250,10 @@
     updateFileList([...fileListRef.current])
   }
   const fileChangeHandle = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
-    const files = [...(e.target.files ?? [])] as File[]
+    let files = [...(e.target.files ?? [])] as File[]
+    files = files.slice(0, fileUploadConfig.batch_count_limit)
     initialUpload(files.filter(isValid))
-  }, [isValid, initialUpload])
+  }, [isValid, initialUpload, fileUploadConfig.batch_count_limit])
 
   const { theme } = useTheme()
   const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme])
@@ -305,6 +306,7 @@
           size: fileUploadConfig.file_size_limit,
           supportTypes: supportTypesShowNames,
           batchCount: notSupportBatchUpload ? 1 : fileUploadConfig.batch_count_limit,
+          totalCount: fileUploadConfig.file_upload_limit,
         })}
{dragging &&
}
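Both uploader diffs above replace a whole-object fallback with per-field fallbacks, which matters when the API returns a partial config. A standalone sketch of the difference (field names and defaults are the ones used in the diffs):

type UploadConfigResponse = {
  file_size_limit?: number
  batch_count_limit?: number
  file_upload_limit?: number
}

// Old behavior: any non-undefined response wins wholesale, so a partial
// response silently loses the defaults for its missing fields.
const oldConfig = (resp?: UploadConfigResponse) =>
  resp ?? { file_size_limit: 15, batch_count_limit: 5, file_upload_limit: 5 }

// New behavior: each field falls back independently.
const newConfig = (resp?: UploadConfigResponse) => ({
  file_size_limit: resp?.file_size_limit ?? 15,
  batch_count_limit: resp?.batch_count_limit ?? 5,
  file_upload_limit: resp?.file_upload_limit ?? 5,
})

console.log(oldConfig({ file_size_limit: 30 }).batch_count_limit) // undefined
console.log(newConfig({ file_size_limit: 30 }).batch_count_limit) // 5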
diff --git a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx
index 4bed7b461d..c5d3bf5629 100644
--- a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx
+++ b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx
@@ -13,6 +13,7 @@ type IActionButtonsProps = {
   actionType?: 'edit' | 'add'
   handleRegeneration?: () => void
   isChildChunk?: boolean
+  showRegenerationButton?: boolean
 }
 
 const ActionButtons: FC<IActionButtonsProps> = ({
@@ -22,6 +23,7 @@
   actionType = 'edit',
   handleRegeneration,
   isChildChunk = false,
+  showRegenerationButton = true,
 }) => {
   const { t } = useTranslation()
   const docForm = useDocumentContext(s => s.docForm)
@@ -54,7 +56,7 @@
           ESC
- {(isParentChildParagraphMode && actionType === 'edit' && !isChildChunk) + {(isParentChildParagraphMode && actionType === 'edit' && !isChildChunk && showRegenerationButton) ?