diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py
index eef5937407..c5447c2b3f 100644
--- a/api/core/helper/tool_provider_cache.py
+++ b/api/core/helper/tool_provider_cache.py
@@ -1,6 +1,6 @@
import json
import logging
-from typing import Any
+from typing import Any, cast
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
@@ -50,7 +50,9 @@ class ToolProviderListCache:
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
- pattern = f"tool_providers:tenant_id:{tenant_id}:*"
- keys = list(redis_client.scan_iter(pattern))
- if keys:
- redis_client.delete(*keys)
+ keys = ["builtin", "model", "api", "workflow", "mcp"]
+ pipeline = redis_client.pipeline()
+ for key in keys:
+ cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
+ pipeline.delete(cache_key)
+ pipeline.execute()
diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py
index 1b6e0bbbf0..dece96dc2f 100644
--- a/api/core/logging/structured_formatter.py
+++ b/api/core/logging/structured_formatter.py
@@ -54,8 +54,8 @@ class StructuredJSONFormatter(logging.Formatter):
}
# Trace context (from TraceContextFilter)
- trace_id = getattr(record, "trace_id", "") or ""
- span_id = getattr(record, "span_id", "") or ""
+ trace_id = getattr(record, "trace_id", "")
+ span_id = getattr(record, "span_id", "")
if trace_id:
log_dict["trace_id"] = trace_id
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 6716603dd4..dbc6a2eb83 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
-version = "1.11.1"
+version = "1.11.2"
requires-python = ">=3.11,<3.13"
dependencies = [
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 970192fde5..ac4b25c5dc 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -3458,7 +3458,7 @@ class SegmentService:
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
- query = query.order_by(DocumentSegment.position.asc())
+ query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index cf1d39fa25..87951d53e6 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -286,12 +286,12 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
-
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
+
+ # Invalidate tool providers cache
+ ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
return {"result": "success"}
@staticmethod
diff --git a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py
index 00f7c9d7e9..d237c68f35 100644
--- a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py
+++ b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py
@@ -96,9 +96,6 @@ class TestToolProviderListCache:
ToolProviderListCache.invalidate_cache(tenant_id)
- mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
- mock_redis_client.delete.assert_called_once_with(*mock_keys)
-
def test_invalidate_cache_no_keys(self, mock_redis_client):
"""Test invalidate cache - no cache keys for tenant"""
tenant_id = "tenant_123"
diff --git a/api/tests/unit_tests/services/test_dataset_service_get_segments.py b/api/tests/unit_tests/services/test_dataset_service_get_segments.py
new file mode 100644
index 0000000000..360c8a3c7d
--- /dev/null
+++ b/api/tests/unit_tests/services/test_dataset_service_get_segments.py
@@ -0,0 +1,472 @@
+"""
+Unit tests for SegmentService.get_segments method.
+
+Tests the retrieval of document segments with pagination and filtering:
+- Basic pagination (page, limit)
+- Status filtering
+- Keyword search
+- Ordering by position and id (to avoid duplicate data)
+"""
+
+from unittest.mock import Mock, create_autospec, patch
+
+import pytest
+
+from models.dataset import DocumentSegment
+
+
+class SegmentServiceTestDataFactory:
+ """
+ Factory class for creating test data and mock objects for segment tests.
+ """
+
+ @staticmethod
+ def create_segment_mock(
+ segment_id: str = "segment-123",
+ document_id: str = "doc-123",
+ tenant_id: str = "tenant-123",
+ dataset_id: str = "dataset-123",
+ position: int = 1,
+ content: str = "Test content",
+ status: str = "completed",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock document segment.
+
+ Args:
+ segment_id: Unique identifier for the segment
+ document_id: Parent document ID
+ tenant_id: Tenant ID the segment belongs to
+ dataset_id: Parent dataset ID
+ position: Position within the document
+ content: Segment text content
+ status: Indexing status
+ **kwargs: Additional attributes
+
+ Returns:
+ Mock: DocumentSegment mock object
+ """
+ segment = create_autospec(DocumentSegment, instance=True)
+ segment.id = segment_id
+ segment.document_id = document_id
+ segment.tenant_id = tenant_id
+ segment.dataset_id = dataset_id
+ segment.position = position
+ segment.content = content
+ segment.status = status
+ for key, value in kwargs.items():
+ setattr(segment, key, value)
+ return segment
+
+
+class TestSegmentServiceGetSegments:
+ """
+ Comprehensive unit tests for SegmentService.get_segments method.
+
+ Tests cover:
+ - Basic pagination functionality
+ - Status list filtering
+ - Keyword search filtering
+ - Ordering (position + id for uniqueness)
+ - Empty results
+ - Combined filters
+ """
+
+ @pytest.fixture
+ def mock_segment_service_dependencies(self):
+ """
+ Common mock setup for segment service dependencies.
+
+ Patches:
+ - db: Database operations and pagination
+ - select: SQLAlchemy query builder
+ """
+ with (
+ patch("services.dataset_service.db") as mock_db,
+ patch("services.dataset_service.select") as mock_select,
+ ):
+ yield {
+ "db": mock_db,
+ "select": mock_select,
+ }
+
+ def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
+ """
+ Test basic pagination functionality.
+
+ Verifies:
+ - Query is built with document_id and tenant_id filters
+ - Pagination uses correct page and limit parameters
+ - Returns segments and total count
+ """
+ # Arrange
+ document_id = "doc-123"
+ tenant_id = "tenant-123"
+ page = 1
+ limit = 20
+
+ # Create mock segments
+ segment1 = SegmentServiceTestDataFactory.create_segment_mock(
+ segment_id="seg-1", position=1, content="First segment"
+ )
+ segment2 = SegmentServiceTestDataFactory.create_segment_mock(
+ segment_id="seg-2", position=2, content="Second segment"
+ )
+
+ # Mock pagination result
+ mock_paginated = Mock()
+ mock_paginated.items = [segment1, segment2]
+ mock_paginated.total = 2
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ # Mock select builder
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, page=page, limit=limit)
+
+ # Assert
+ assert len(items) == 2
+ assert total == 2
+ assert items[0].id == "seg-1"
+ assert items[1].id == "seg-2"
+ mock_segment_service_dependencies["db"].paginate.assert_called_once()
+ call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
+ assert call_kwargs["page"] == page
+ assert call_kwargs["per_page"] == limit
+ assert call_kwargs["max_per_page"] == 100
+ assert call_kwargs["error_out"] is False
+
+ def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
+ """
+ Test filtering by status list.
+
+ Verifies:
+ - Status list filter is applied to query
+ - Only segments with matching status are returned
+ """
+ # Arrange
+ document_id = "doc-123"
+ tenant_id = "tenant-123"
+ status_list = ["completed", "indexing"]
+
+ segment1 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
+ segment2 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
+
+ mock_paginated = Mock()
+ mock_paginated.items = [segment1, segment2]
+ mock_paginated.total = 2
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ items, total = SegmentService.get_segments(
+ document_id=document_id, tenant_id=tenant_id, status_list=status_list
+ )
+
+ # Assert
+ assert len(items) == 2
+ assert total == 2
+ # Verify where was called multiple times (base filters + status filter)
+ assert mock_query.where.call_count >= 2
+
+ def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
+ """
+ Test with empty status list.
+
+ Verifies:
+ - Empty status list is handled correctly
+ - No status filter is applied to avoid WHERE false condition
+ """
+ # Arrange
+ document_id = "doc-123"
+ tenant_id = "tenant-123"
+ status_list = []
+
+ segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
+
+ mock_paginated = Mock()
+ mock_paginated.items = [segment]
+ mock_paginated.total = 1
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ items, total = SegmentService.get_segments(
+ document_id=document_id, tenant_id=tenant_id, status_list=status_list
+ )
+
+ # Assert
+ assert len(items) == 1
+ assert total == 1
+ # Should only be called once (base filters, no status filter)
+ assert mock_query.where.call_count == 1
+
+ def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
+ """
+ Test keyword search functionality.
+
+ Verifies:
+ - Keyword filter uses ilike for case-insensitive search
+ - Search pattern includes wildcards (%keyword%)
+ """
+ # Arrange
+ document_id = "doc-123"
+ tenant_id = "tenant-123"
+ keyword = "search term"
+
+ segment = SegmentServiceTestDataFactory.create_segment_mock(
+ segment_id="seg-1", content="This contains search term"
+ )
+
+ mock_paginated = Mock()
+ mock_paginated.items = [segment]
+ mock_paginated.total = 1
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)
+
+ # Assert
+ assert len(items) == 1
+ assert total == 1
+ # Verify where was called for base filters + keyword filter
+ assert mock_query.where.call_count == 2
+
+ def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
+ """
+ Test ordering by position and id.
+
+ Verifies:
+ - Results are ordered by position ASC
+ - Results are secondarily ordered by id ASC to ensure uniqueness
+ - This prevents duplicate data across pages when positions are not unique
+ """
+ # Arrange
+ document_id = "doc-123"
+ tenant_id = "tenant-123"
+
+ # Create segments with same position but different ids
+ segment1 = SegmentServiceTestDataFactory.create_segment_mock(
+ segment_id="seg-1", position=1, content="Content 1"
+ )
+ segment2 = SegmentServiceTestDataFactory.create_segment_mock(
+ segment_id="seg-2", position=1, content="Content 2"
+ )
+ segment3 = SegmentServiceTestDataFactory.create_segment_mock(
+ segment_id="seg-3", position=2, content="Content 3"
+ )
+
+ mock_paginated = Mock()
+ mock_paginated.items = [segment1, segment2, segment3]
+ mock_paginated.total = 3
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
+
+ # Assert
+ assert len(items) == 3
+ assert total == 3
+ mock_query.order_by.assert_called_once()
+
+ def test_get_segments_empty_results(self, mock_segment_service_dependencies):
+ """
+ Test when no segments match the criteria.
+
+ Verifies:
+ - Empty list is returned for items
+ - Total count is 0
+ """
+ # Arrange
+ document_id = "non-existent-doc"
+ tenant_id = "tenant-123"
+
+ mock_paginated = Mock()
+ mock_paginated.items = []
+ mock_paginated.total = 0
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
+
+ # Assert
+ assert items == []
+ assert total == 0
+
+ def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
+ """
+ Test with multiple filters combined.
+
+ Verifies:
+ - All filters work together correctly
+ - Status list and keyword search both applied
+ """
+ # Arrange
+ document_id = "doc-123"
+ tenant_id = "tenant-123"
+ status_list = ["completed"]
+ keyword = "important"
+ page = 2
+ limit = 10
+
+ segment = SegmentServiceTestDataFactory.create_segment_mock(
+ segment_id="seg-1",
+ status="completed",
+ content="This is important information",
+ )
+
+ mock_paginated = Mock()
+ mock_paginated.items = [segment]
+ mock_paginated.total = 1
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ items, total = SegmentService.get_segments(
+ document_id=document_id,
+ tenant_id=tenant_id,
+ status_list=status_list,
+ keyword=keyword,
+ page=page,
+ limit=limit,
+ )
+
+ # Assert
+ assert len(items) == 1
+ assert total == 1
+ # Verify filters: base + status + keyword
+ assert mock_query.where.call_count == 3
+ # Verify pagination parameters
+ call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
+ assert call_kwargs["page"] == page
+ assert call_kwargs["per_page"] == limit
+
+ def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
+ """
+ Test with None status list.
+
+ Verifies:
+ - None status list is handled correctly
+ - No status filter is applied
+ """
+ # Arrange
+ document_id = "doc-123"
+ tenant_id = "tenant-123"
+
+ segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
+
+ mock_paginated = Mock()
+ mock_paginated.items = [segment]
+ mock_paginated.total = 1
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ items, total = SegmentService.get_segments(
+ document_id=document_id,
+ tenant_id=tenant_id,
+ status_list=None,
+ )
+
+ # Assert
+ assert len(items) == 1
+ assert total == 1
+ # Should only be called once (base filters only, no status filter)
+ assert mock_query.where.call_count == 1
+
+ def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
+ """
+ Test that max_per_page is correctly set to 100.
+
+ Verifies:
+ - max_per_page parameter is set to 100
+ - This prevents excessive page sizes
+ """
+ # Arrange
+ document_id = "doc-123"
+ tenant_id = "tenant-123"
+ limit = 200 # Request more than max_per_page
+
+ mock_paginated = Mock()
+ mock_paginated.items = []
+ mock_paginated.total = 0
+
+ mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
+
+ mock_query = Mock()
+ mock_segment_service_dependencies["select"].return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+
+ # Act
+ from services.dataset_service import SegmentService
+
+ SegmentService.get_segments(
+ document_id=document_id,
+ tenant_id=tenant_id,
+ limit=limit,
+ )
+
+ # Assert
+ call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
+ assert call_kwargs["max_per_page"] == 100
diff --git a/api/uv.lock b/api/uv.lock
index 4c2cb3c3f1..c31b7fe445 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
-version = "1.11.1"
+version = "1.11.2"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },
diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml
index 0de9d3e939..3c88cddf8c 100644
--- a/docker/docker-compose-template.yaml
+++ b/docker/docker-compose-template.yaml
@@ -21,7 +21,7 @@ services:
# API service
api:
- image: langgenius/dify-api:1.11.1
+ image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
- image: langgenius/dify-api:1.11.1
+ image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
- image: langgenius/dify-api:1.11.1
+ image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
- image: langgenius/dify-web:1.11.1
+ image: langgenius/dify-web:1.11.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 964b9fe724..3f2031dbd9 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -692,7 +692,7 @@ services:
# API service
api:
- image: langgenius/dify-api:1.11.1
+ image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@@ -734,7 +734,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
- image: langgenius/dify-api:1.11.1
+ image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@@ -773,7 +773,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
- image: langgenius/dify-api:1.11.1
+ image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@@ -803,7 +803,7 @@ services:
# Frontend web application.
web:
- image: langgenius/dify-web:1.11.1
+ image: langgenius/dify-web:1.11.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
diff --git a/web/app/components/billing/billing-page/index.spec.tsx b/web/app/components/billing/billing-page/index.spec.tsx
new file mode 100644
index 0000000000..2310baa4f4
--- /dev/null
+++ b/web/app/components/billing/billing-page/index.spec.tsx
@@ -0,0 +1,84 @@
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import Billing from './index'
+
+let currentBillingUrl: string | null = 'https://billing'
+let fetching = false
+let isManager = true
+let enableBilling = true
+
+const refetchMock = vi.fn()
+const openAsyncWindowMock = vi.fn()
+
+vi.mock('@/service/use-billing', () => ({
+ useBillingUrl: () => ({
+ data: currentBillingUrl,
+ isFetching: fetching,
+ refetch: refetchMock,
+ }),
+}))
+
+vi.mock('@/hooks/use-async-window-open', () => ({
+ useAsyncWindowOpen: () => openAsyncWindowMock,
+}))
+
+vi.mock('@/context/app-context', () => ({
+ useAppContext: () => ({
+ isCurrentWorkspaceManager: isManager,
+ }),
+}))
+
+vi.mock('@/context/provider-context', () => ({
+ useProviderContext: () => ({
+ enableBilling,
+ }),
+}))
+
+vi.mock('../plan', () => ({
+ __esModule: true,
+ default: ({ loc }: { loc: string }) =>
,
+}))
+
+describe('Billing', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ currentBillingUrl = 'https://billing'
+ fetching = false
+ isManager = true
+ enableBilling = true
+ refetchMock.mockResolvedValue({ data: 'https://billing' })
+ })
+
+ it('hides the billing action when user is not manager or billing is disabled', () => {
+ isManager = false
+ render()
+ expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
+
+ vi.clearAllMocks()
+ isManager = true
+ enableBilling = false
+ render()
+ expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
+ })
+
+ it('opens the billing window with the immediate url when the button is clicked', async () => {
+ render()
+
+ const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
+ fireEvent.click(actionButton)
+
+ await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled())
+ const [, options] = openAsyncWindowMock.mock.calls[0]
+ expect(options).toMatchObject({
+ immediateUrl: currentBillingUrl,
+ features: 'noopener,noreferrer',
+ })
+ })
+
+ it('disables the button while billing url is fetching', () => {
+ fetching = true
+ render()
+
+ const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
+ expect(actionButton).toBeDisabled()
+ })
+})
diff --git a/web/app/components/billing/header-billing-btn/index.spec.tsx b/web/app/components/billing/header-billing-btn/index.spec.tsx
new file mode 100644
index 0000000000..b87b733353
--- /dev/null
+++ b/web/app/components/billing/header-billing-btn/index.spec.tsx
@@ -0,0 +1,92 @@
+import { fireEvent, render, screen } from '@testing-library/react'
+import { Plan } from '../type'
+import HeaderBillingBtn from './index'
+
+type HeaderGlobal = typeof globalThis & {
+ __mockProviderContext?: ReturnType
+}
+
+function getHeaderGlobal(): HeaderGlobal {
+ return globalThis as HeaderGlobal
+}
+
+const ensureProviderContextMock = () => {
+ const globals = getHeaderGlobal()
+ if (!globals.__mockProviderContext)
+ throw new Error('Provider context mock not set')
+ return globals.__mockProviderContext
+}
+
+vi.mock('@/context/provider-context', () => {
+ const mock = vi.fn()
+ const globals = getHeaderGlobal()
+ globals.__mockProviderContext = mock
+ return {
+ useProviderContext: () => mock(),
+ }
+})
+
+vi.mock('../upgrade-btn', () => ({
+ __esModule: true,
+ default: () => ,
+}))
+
+describe('HeaderBillingBtn', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ ensureProviderContextMock().mockReturnValue({
+ plan: {
+ type: Plan.professional,
+ },
+ enableBilling: true,
+ isFetchedPlan: true,
+ })
+ })
+
+ it('renders nothing when billing is disabled or plan is not fetched', () => {
+ ensureProviderContextMock().mockReturnValueOnce({
+ plan: {
+ type: Plan.professional,
+ },
+ enableBilling: false,
+ isFetchedPlan: true,
+ })
+
+ const { container } = render()
+
+ expect(container.firstChild).toBeNull()
+ })
+
+ it('renders upgrade button for sandbox plan', () => {
+ ensureProviderContextMock().mockReturnValueOnce({
+ plan: {
+ type: Plan.sandbox,
+ },
+ enableBilling: true,
+ isFetchedPlan: true,
+ })
+
+ render()
+
+ expect(screen.getByTestId('upgrade-btn')).toBeInTheDocument()
+ })
+
+ it('renders plan badge and forwards clicks when not display-only', () => {
+ const onClick = vi.fn()
+
+ const { rerender } = render()
+
+ const badge = screen.getByText('pro').closest('div')
+
+ expect(badge).toHaveClass('cursor-pointer')
+
+ fireEvent.click(badge!)
+ expect(onClick).toHaveBeenCalledTimes(1)
+
+ rerender()
+ expect(screen.getByText('pro').closest('div')).toHaveClass('cursor-default')
+
+ fireEvent.click(screen.getByText('pro').closest('div')!)
+ expect(onClick).toHaveBeenCalledTimes(1)
+ })
+})
diff --git a/web/app/components/billing/partner-stack/index.spec.tsx b/web/app/components/billing/partner-stack/index.spec.tsx
new file mode 100644
index 0000000000..7b4658cf0f
--- /dev/null
+++ b/web/app/components/billing/partner-stack/index.spec.tsx
@@ -0,0 +1,44 @@
+import { render } from '@testing-library/react'
+import PartnerStack from './index'
+
+let isCloudEdition = true
+
+const saveOrUpdate = vi.fn()
+const bind = vi.fn()
+
+vi.mock('@/config', () => ({
+ get IS_CLOUD_EDITION() {
+ return isCloudEdition
+ },
+}))
+
+vi.mock('./use-ps-info', () => ({
+ __esModule: true,
+ default: () => ({
+ saveOrUpdate,
+ bind,
+ }),
+}))
+
+describe('PartnerStack', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ isCloudEdition = true
+ })
+
+ it('does not call partner stack helpers when not in cloud edition', () => {
+ isCloudEdition = false
+
+ render()
+
+ expect(saveOrUpdate).not.toHaveBeenCalled()
+ expect(bind).not.toHaveBeenCalled()
+ })
+
+ it('calls saveOrUpdate and bind once when running in cloud edition', () => {
+ render()
+
+ expect(saveOrUpdate).toHaveBeenCalledTimes(1)
+ expect(bind).toHaveBeenCalledTimes(1)
+ })
+})
diff --git a/web/app/components/billing/partner-stack/use-ps-info.spec.tsx b/web/app/components/billing/partner-stack/use-ps-info.spec.tsx
new file mode 100644
index 0000000000..14215f2514
--- /dev/null
+++ b/web/app/components/billing/partner-stack/use-ps-info.spec.tsx
@@ -0,0 +1,197 @@
+import { act, renderHook } from '@testing-library/react'
+import { PARTNER_STACK_CONFIG } from '@/config'
+import usePSInfo from './use-ps-info'
+
+let searchParamsValues: Record = {}
+const setSearchParams = (values: Record) => {
+ searchParamsValues = values
+}
+
+type PartnerStackGlobal = typeof globalThis & {
+ __partnerStackCookieMocks?: {
+ get: ReturnType
+ set: ReturnType
+ remove: ReturnType
+ }
+ __partnerStackMutateAsync?: ReturnType
+}
+
+function getPartnerStackGlobal(): PartnerStackGlobal {
+ return globalThis as PartnerStackGlobal
+}
+
+const ensureCookieMocks = () => {
+ const globals = getPartnerStackGlobal()
+ if (!globals.__partnerStackCookieMocks)
+ throw new Error('Cookie mocks not initialized')
+ return globals.__partnerStackCookieMocks
+}
+
+const ensureMutateAsync = () => {
+ const globals = getPartnerStackGlobal()
+ if (!globals.__partnerStackMutateAsync)
+ throw new Error('Mutate mock not initialized')
+ return globals.__partnerStackMutateAsync
+}
+
+vi.mock('js-cookie', () => {
+ const get = vi.fn()
+ const set = vi.fn()
+ const remove = vi.fn()
+ const globals = getPartnerStackGlobal()
+ globals.__partnerStackCookieMocks = { get, set, remove }
+ const cookieApi = { get, set, remove }
+ return {
+ __esModule: true,
+ default: cookieApi,
+ get,
+ set,
+ remove,
+ }
+})
+vi.mock('next/navigation', () => ({
+ useSearchParams: () => ({
+ get: (key: string) => searchParamsValues[key] ?? null,
+ }),
+}))
+vi.mock('@/service/use-billing', () => {
+ const mutateAsync = vi.fn()
+ const globals = getPartnerStackGlobal()
+ globals.__partnerStackMutateAsync = mutateAsync
+ return {
+ useBindPartnerStackInfo: () => ({
+ mutateAsync,
+ }),
+ }
+})
+
+describe('usePSInfo', () => {
+ const originalLocationDescriptor = Object.getOwnPropertyDescriptor(globalThis, 'location')
+
+ beforeAll(() => {
+ Object.defineProperty(globalThis, 'location', {
+ value: { hostname: 'cloud.dify.ai' },
+ configurable: true,
+ })
+ })
+
+ beforeEach(() => {
+ setSearchParams({})
+ const { get, set, remove } = ensureCookieMocks()
+ get.mockReset()
+ set.mockReset()
+ remove.mockReset()
+ const mutate = ensureMutateAsync()
+ mutate.mockReset()
+ mutate.mockResolvedValue(undefined)
+ get.mockReturnValue('{}')
+ })
+
+ afterAll(() => {
+ if (originalLocationDescriptor)
+ Object.defineProperty(globalThis, 'location', originalLocationDescriptor)
+ })
+
+ it('saves partner info when query params change', () => {
+ const { get, set } = ensureCookieMocks()
+ get.mockReturnValue(JSON.stringify({ partnerKey: 'old', clickId: 'old-click' }))
+ setSearchParams({
+ ps_partner_key: 'new-partner',
+ ps_xid: 'new-click',
+ })
+
+ const { result } = renderHook(() => usePSInfo())
+
+ expect(result.current.psPartnerKey).toBe('new-partner')
+ expect(result.current.psClickId).toBe('new-click')
+
+ act(() => {
+ result.current.saveOrUpdate()
+ })
+
+ expect(set).toHaveBeenCalledWith(
+ PARTNER_STACK_CONFIG.cookieName,
+ JSON.stringify({
+ partnerKey: 'new-partner',
+ clickId: 'new-click',
+ }),
+ {
+ expires: PARTNER_STACK_CONFIG.saveCookieDays,
+ path: '/',
+ domain: '.dify.ai',
+ },
+ )
+ })
+
+ it('does not overwrite cookie when params do not change', () => {
+ setSearchParams({
+ ps_partner_key: 'existing',
+ ps_xid: 'existing-click',
+ })
+ const { get } = ensureCookieMocks()
+ get.mockReturnValue(JSON.stringify({
+ partnerKey: 'existing',
+ clickId: 'existing-click',
+ }))
+
+ const { result } = renderHook(() => usePSInfo())
+
+ act(() => {
+ result.current.saveOrUpdate()
+ })
+
+ const { set } = ensureCookieMocks()
+ expect(set).not.toHaveBeenCalled()
+ })
+
+ it('binds partner info and clears cookie once', async () => {
+ setSearchParams({
+ ps_partner_key: 'bind-partner',
+ ps_xid: 'bind-click',
+ })
+
+ const { result } = renderHook(() => usePSInfo())
+
+ const mutate = ensureMutateAsync()
+ const { remove } = ensureCookieMocks()
+ await act(async () => {
+ await result.current.bind()
+ })
+
+ expect(mutate).toHaveBeenCalledWith({
+ partnerKey: 'bind-partner',
+ clickId: 'bind-click',
+ })
+ expect(remove).toHaveBeenCalledWith(PARTNER_STACK_CONFIG.cookieName, {
+ path: '/',
+ domain: '.dify.ai',
+ })
+
+ await act(async () => {
+ await result.current.bind()
+ })
+
+ expect(mutate).toHaveBeenCalledTimes(1)
+ })
+
+ it('still removes cookie when bind fails with status 400', async () => {
+ const mutate = ensureMutateAsync()
+ mutate.mockRejectedValueOnce({ status: 400 })
+ setSearchParams({
+ ps_partner_key: 'bind-partner',
+ ps_xid: 'bind-click',
+ })
+
+ const { result } = renderHook(() => usePSInfo())
+
+ await act(async () => {
+ await result.current.bind()
+ })
+
+ const { remove } = ensureCookieMocks()
+ expect(remove).toHaveBeenCalledWith(PARTNER_STACK_CONFIG.cookieName, {
+ path: '/',
+ domain: '.dify.ai',
+ })
+ })
+})
diff --git a/web/app/components/billing/plan/index.spec.tsx b/web/app/components/billing/plan/index.spec.tsx
new file mode 100644
index 0000000000..bcdb83b5df
--- /dev/null
+++ b/web/app/components/billing/plan/index.spec.tsx
@@ -0,0 +1,130 @@
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM } from '@/app/education-apply/constants'
+import { Plan } from '../type'
+import PlanComp from './index'
+
+let currentPath = '/billing'
+
+const push = vi.fn()
+
+vi.mock('next/navigation', () => ({
+ useRouter: () => ({ push }),
+ usePathname: () => currentPath,
+}))
+
+const setShowAccountSettingModalMock = vi.fn()
+vi.mock('@/context/modal-context', () => ({
+ // eslint-disable-next-line ts/no-explicit-any
+ useModalContextSelector: (selector: any) => selector({
+ setShowAccountSettingModal: setShowAccountSettingModalMock,
+ }),
+}))
+
+const providerContextMock = vi.fn()
+vi.mock('@/context/provider-context', () => ({
+ useProviderContext: () => providerContextMock(),
+}))
+
+vi.mock('@/context/app-context', () => ({
+ useAppContext: () => ({
+ userProfile: { email: 'user@example.com' },
+ isCurrentWorkspaceManager: true,
+ }),
+}))
+
+const mutateAsyncMock = vi.fn()
+let isPending = false
+vi.mock('@/service/use-education', () => ({
+ useEducationVerify: () => ({
+ mutateAsync: mutateAsyncMock,
+ isPending,
+ }),
+}))
+
+const verifyStateModalMock = vi.fn(props => (
+
+ {props.isShow ? 'visible' : 'hidden'}
+
+))
+vi.mock('@/app/education-apply/verify-state-modal', () => ({
+ __esModule: true,
+ // eslint-disable-next-line ts/no-explicit-any
+ default: (props: any) => verifyStateModalMock(props),
+}))
+
+vi.mock('../upgrade-btn', () => ({
+ __esModule: true,
+ default: () => ,
+}))
+
+describe('PlanComp', () => {
+ const planMock = {
+ type: Plan.professional,
+ usage: {
+ teamMembers: 4,
+ documentsUploadQuota: 3,
+ vectorSpace: 8,
+ annotatedResponse: 5,
+ triggerEvents: 60,
+ apiRateLimit: 100,
+ },
+ total: {
+ teamMembers: 10,
+ documentsUploadQuota: 20,
+ vectorSpace: 10,
+ annotatedResponse: 500,
+ triggerEvents: 100,
+ apiRateLimit: 200,
+ },
+ reset: {
+ triggerEvents: 2,
+ apiRateLimit: 1,
+ },
+ }
+
+ beforeEach(() => {
+ vi.clearAllMocks()
+ currentPath = '/billing'
+ isPending = false
+ providerContextMock.mockReturnValue({
+ plan: planMock,
+ enableEducationPlan: true,
+ allowRefreshEducationVerify: false,
+ isEducationAccount: false,
+ })
+ mutateAsyncMock.mockReset()
+ mutateAsyncMock.mockResolvedValue({ token: 'token' })
+ })
+
+ it('renders plan info and handles education verify success', async () => {
+ render()
+
+ expect(screen.getByText('billing.plans.professional.name')).toBeInTheDocument()
+ expect(screen.getByTestId('plan-upgrade-btn')).toBeInTheDocument()
+
+ const verifyBtn = screen.getByText('education.toVerified')
+ fireEvent.click(verifyBtn)
+
+ await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled())
+ await waitFor(() => expect(push).toHaveBeenCalledWith('/education-apply?token=token'))
+ expect(localStorage.removeItem).toHaveBeenCalledWith(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
+ })
+
+ it('shows modal when education verify fails', async () => {
+ mutateAsyncMock.mockRejectedValueOnce(new Error('boom'))
+ render()
+
+ const verifyBtn = screen.getByText('education.toVerified')
+ fireEvent.click(verifyBtn)
+
+ await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled())
+ await waitFor(() => expect(screen.getByTestId('verify-modal').getAttribute('data-is-show')).toBe('true'))
+ })
+
+ it('resets modal context when on education apply path', () => {
+ currentPath = '/education-apply/setup'
+ render()
+
+ expect(setShowAccountSettingModalMock).toHaveBeenCalledWith(null)
+ })
+})
diff --git a/web/app/components/billing/progress-bar/index.spec.tsx b/web/app/components/billing/progress-bar/index.spec.tsx
new file mode 100644
index 0000000000..a9c91468de
--- /dev/null
+++ b/web/app/components/billing/progress-bar/index.spec.tsx
@@ -0,0 +1,25 @@
+import { render, screen } from '@testing-library/react'
+import ProgressBar from './index'
+
+describe('ProgressBar', () => {
+ it('renders with provided percent and color', () => {
+ render()
+
+ const bar = screen.getByTestId('billing-progress-bar')
+ expect(bar).toHaveClass('bg-test-color')
+ expect(bar.getAttribute('style')).toContain('width: 42%')
+ })
+
+ it('caps width at 100% when percent exceeds max', () => {
+ render()
+
+ const bar = screen.getByTestId('billing-progress-bar')
+ expect(bar.getAttribute('style')).toContain('width: 100%')
+ })
+
+ it('uses the default color when no color prop is provided', () => {
+ render()
+
+ expect(screen.getByTestId('billing-progress-bar')).toHaveClass('#2970FF')
+ })
+})
diff --git a/web/app/components/billing/trigger-events-limit-modal/index.spec.tsx b/web/app/components/billing/trigger-events-limit-modal/index.spec.tsx
new file mode 100644
index 0000000000..a3d04c6031
--- /dev/null
+++ b/web/app/components/billing/trigger-events-limit-modal/index.spec.tsx
@@ -0,0 +1,70 @@
+import { render, screen } from '@testing-library/react'
+import TriggerEventsLimitModal from './index'
+
+const mockOnClose = vi.fn()
+const mockOnUpgrade = vi.fn()
+
+const planUpgradeModalMock = vi.fn((props: { show: boolean, title: string, description: string, extraInfo?: React.ReactNode, onClose: () => void, onUpgrade: () => void }) => (
+
+ {props.extraInfo}
+
+))
+
+vi.mock('@/app/components/billing/plan-upgrade-modal', () => ({
+ __esModule: true,
+ // eslint-disable-next-line ts/no-explicit-any
+ default: (props: any) => planUpgradeModalMock(props),
+}))
+
+describe('TriggerEventsLimitModal', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ it('passes the trigger usage props to the upgrade modal', () => {
+ render(
+ ,
+ )
+
+ const modal = screen.getByTestId('plan-upgrade-modal')
+ expect(modal.getAttribute('data-show')).toBe('true')
+ expect(modal.getAttribute('data-title')).toContain('billing.triggerLimitModal.title')
+ expect(modal.getAttribute('data-description')).toContain('billing.triggerLimitModal.description')
+ expect(planUpgradeModalMock).toHaveBeenCalled()
+
+ const passedProps = planUpgradeModalMock.mock.calls[0][0]
+ expect(passedProps.onClose).toBe(mockOnClose)
+ expect(passedProps.onUpgrade).toBe(mockOnUpgrade)
+
+ expect(screen.getByText('billing.triggerLimitModal.usageTitle')).toBeInTheDocument()
+ expect(screen.getByText('12')).toBeInTheDocument()
+ expect(screen.getByText('20')).toBeInTheDocument()
+ })
+
+ it('renders even when trigger modal is hidden', () => {
+ render(
+ ,
+ )
+
+ expect(planUpgradeModalMock).toHaveBeenCalled()
+ expect(screen.getByTestId('plan-upgrade-modal').getAttribute('data-show')).toBe('false')
+ })
+})
diff --git a/web/app/components/billing/usage-info/apps-info.spec.tsx b/web/app/components/billing/usage-info/apps-info.spec.tsx
new file mode 100644
index 0000000000..7289b474e5
--- /dev/null
+++ b/web/app/components/billing/usage-info/apps-info.spec.tsx
@@ -0,0 +1,35 @@
+import { render, screen } from '@testing-library/react'
+import { defaultPlan } from '../config'
+import AppsInfo from './apps-info'
+
+const appsUsage = 7
+const appsTotal = 15
+
+const mockPlan = {
+ ...defaultPlan,
+ usage: {
+ ...defaultPlan.usage,
+ buildApps: appsUsage,
+ },
+ total: {
+ ...defaultPlan.total,
+ buildApps: appsTotal,
+ },
+}
+
+vi.mock('@/context/provider-context', () => ({
+ useProviderContext: () => ({
+ plan: mockPlan,
+ }),
+}))
+
+describe('AppsInfo', () => {
+ it('renders build apps usage information with context data', () => {
+ render()
+
+ expect(screen.getByText('billing.usagePage.buildApps')).toBeInTheDocument()
+ expect(screen.getByText(`${appsUsage}`)).toBeInTheDocument()
+ expect(screen.getByText(`${appsTotal}`)).toBeInTheDocument()
+ expect(screen.getByText('billing.usagePage.buildApps').closest('.apps-info-class')).toBeInTheDocument()
+ })
+})
diff --git a/web/app/components/billing/usage-info/index.spec.tsx b/web/app/components/billing/usage-info/index.spec.tsx
new file mode 100644
index 0000000000..3137c4865f
--- /dev/null
+++ b/web/app/components/billing/usage-info/index.spec.tsx
@@ -0,0 +1,114 @@
+import { render, screen } from '@testing-library/react'
+import { NUM_INFINITE } from '../config'
+import UsageInfo from './index'
+
+const TestIcon = () =>
+
+describe('UsageInfo', () => {
+ it('renders the metric with a suffix unit and tooltip text', () => {
+ render(
+ ,
+ )
+
+ expect(screen.getByTestId('usage-icon')).toBeInTheDocument()
+ expect(screen.getByText('Apps')).toBeInTheDocument()
+ expect(screen.getByText('30')).toBeInTheDocument()
+ expect(screen.getByText('100')).toBeInTheDocument()
+ expect(screen.getByText('GB')).toBeInTheDocument()
+ })
+
+ it('renders inline unit when unitPosition is inline', () => {
+ render(
+ ,
+ )
+
+ expect(screen.getByText('100GB')).toBeInTheDocument()
+ })
+
+ it('shows reset hint text instead of the unit when resetHint is provided', () => {
+ const resetHint = 'Resets in 3 days'
+ render(
+ ,
+ )
+
+ expect(screen.getByText(resetHint)).toBeInTheDocument()
+ expect(screen.queryByText('GB')).not.toBeInTheDocument()
+ })
+
+ it('displays unlimited text when total is infinite', () => {
+ render(
+ ,
+ )
+
+ expect(screen.getByText('billing.plansCommon.unlimited')).toBeInTheDocument()
+ })
+
+ it('applies warning color when usage is close to the limit', () => {
+ render(
+ ,
+ )
+
+ const progressBar = screen.getByTestId('billing-progress-bar')
+ expect(progressBar).toHaveClass('bg-components-progress-warning-progress')
+ })
+
+ it('applies error color when usage exceeds the limit', () => {
+ render(
+ ,
+ )
+
+ const progressBar = screen.getByTestId('billing-progress-bar')
+ expect(progressBar).toHaveClass('bg-components-progress-error-progress')
+ })
+
+ it('does not render the icon when hideIcon is true', () => {
+ render(
+ ,
+ )
+
+ expect(screen.queryByTestId('usage-icon')).not.toBeInTheDocument()
+ })
+})
diff --git a/web/app/components/billing/vector-space-full/index.spec.tsx b/web/app/components/billing/vector-space-full/index.spec.tsx
new file mode 100644
index 0000000000..de5607df41
--- /dev/null
+++ b/web/app/components/billing/vector-space-full/index.spec.tsx
@@ -0,0 +1,58 @@
+import { render, screen } from '@testing-library/react'
+import VectorSpaceFull from './index'
+
+type VectorProviderGlobal = typeof globalThis & {
+ __vectorProviderContext?: ReturnType
+}
+
+function getVectorGlobal(): VectorProviderGlobal {
+ return globalThis as VectorProviderGlobal
+}
+
+vi.mock('@/context/provider-context', () => {
+ const mock = vi.fn()
+ getVectorGlobal().__vectorProviderContext = mock
+ return {
+ useProviderContext: () => mock(),
+ }
+})
+
+vi.mock('../upgrade-btn', () => ({
+ __esModule: true,
+ default: () => ,
+}))
+
+describe('VectorSpaceFull', () => {
+ const planMock = {
+ type: 'team',
+ usage: {
+ vectorSpace: 8,
+ },
+ total: {
+ vectorSpace: 10,
+ },
+ }
+
+ beforeEach(() => {
+ vi.clearAllMocks()
+ const globals = getVectorGlobal()
+ globals.__vectorProviderContext?.mockReturnValue({
+ plan: planMock,
+ })
+ })
+
+ it('renders tip text and upgrade button', () => {
+ render()
+
+ expect(screen.getByText('billing.vectorSpace.fullTip')).toBeInTheDocument()
+ expect(screen.getByText('billing.vectorSpace.fullSolution')).toBeInTheDocument()
+ expect(screen.getByTestId('vector-upgrade-btn')).toBeInTheDocument()
+ })
+
+ it('shows vector usage and total', () => {
+ render()
+
+ expect(screen.getByText('8')).toBeInTheDocument()
+ expect(screen.getByText('10MB')).toBeInTheDocument()
+ })
+})
diff --git a/web/app/components/header/account-setting/members-page/operation/index.spec.tsx b/web/app/components/header/account-setting/members-page/operation/index.spec.tsx
new file mode 100644
index 0000000000..fbe3959a0f
--- /dev/null
+++ b/web/app/components/header/account-setting/members-page/operation/index.spec.tsx
@@ -0,0 +1,91 @@
+import type { Member } from '@/models/common'
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import { vi } from 'vitest'
+import { ToastContext } from '@/app/components/base/toast'
+import Operation from './index'
+
+const mockUpdateMemberRole = vi.fn()
+const mockDeleteMemberOrCancelInvitation = vi.fn()
+
+vi.mock('@/service/common', () => ({
+ deleteMemberOrCancelInvitation: () => mockDeleteMemberOrCancelInvitation(),
+ updateMemberRole: () => mockUpdateMemberRole(),
+}))
+
+const mockUseProviderContext = vi.fn(() => ({
+ datasetOperatorEnabled: false,
+}))
+
+vi.mock('@/context/provider-context', () => ({
+ useProviderContext: () => mockUseProviderContext(),
+}))
+
+const defaultMember: Member = {
+ id: 'member-id',
+ name: 'Test Member',
+ email: 'test@example.com',
+ avatar: '',
+ avatar_url: null,
+ status: 'active',
+ role: 'editor',
+ last_login_at: '',
+ last_active_at: '',
+ created_at: '',
+}
+
+const renderOperation = (propsOverride: Partial = {}, operatorRole = 'owner', onOperate?: () => void) => {
+ const mergedMember = { ...defaultMember, ...propsOverride }
+ return render(
+
+
+ ,
+ )
+}
+
+describe('Operation', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ mockUseProviderContext.mockReturnValue({ datasetOperatorEnabled: false })
+ })
+
+ it('renders the current role label', () => {
+ renderOperation()
+
+ expect(screen.getByText('common.members.editor')).toBeInTheDocument()
+ })
+
+ it('shows dataset operator option when the feature flag is enabled', async () => {
+ mockUseProviderContext.mockReturnValue({ datasetOperatorEnabled: true })
+ renderOperation()
+
+ fireEvent.click(screen.getByText('common.members.editor'))
+
+ expect(await screen.findByText('common.members.datasetOperator')).toBeInTheDocument()
+ })
+
+ it('calls updateMemberRole and onOperate when selecting another role', async () => {
+ const onOperate = vi.fn()
+ renderOperation({}, 'owner', onOperate)
+
+ fireEvent.click(screen.getByText('common.members.editor'))
+ fireEvent.click(await screen.findByText('common.members.normal'))
+
+ await waitFor(() => {
+ expect(mockUpdateMemberRole).toHaveBeenCalled()
+ expect(onOperate).toHaveBeenCalled()
+ })
+ })
+
+ it('calls deleteMemberOrCancelInvitation when removing the member', async () => {
+ const onOperate = vi.fn()
+ renderOperation({}, 'owner', onOperate)
+
+ fireEvent.click(screen.getByText('common.members.editor'))
+ fireEvent.click(await screen.findByText('common.members.removeFromTeam'))
+
+ await waitFor(() => {
+ expect(mockDeleteMemberOrCancelInvitation).toHaveBeenCalled()
+ expect(onOperate).toHaveBeenCalled()
+ })
+ })
+})
diff --git a/web/app/components/header/account-setting/members-page/operation/index.tsx b/web/app/components/header/account-setting/members-page/operation/index.tsx
index da61746685..6effe8b058 100644
--- a/web/app/components/header/account-setting/members-page/operation/index.tsx
+++ b/web/app/components/header/account-setting/members-page/operation/index.tsx
@@ -1,10 +1,14 @@
'use client'
import type { Member } from '@/models/common'
-import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react'
import { CheckIcon, ChevronDownIcon } from '@heroicons/react/24/outline'
-import { Fragment, useMemo } from 'react'
+import { memo, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
+import {
+ PortalToFollowElem,
+ PortalToFollowElemContent,
+ PortalToFollowElemTrigger,
+} from '@/app/components/base/portal-to-follow-elem'
import { ToastContext } from '@/app/components/base/toast'
import { useProviderContext } from '@/context/provider-context'
import { deleteMemberOrCancelInvitation, updateMemberRole } from '@/service/common'
@@ -21,6 +25,7 @@ const Operation = ({
operatorRole,
onOperate,
}: IOperationProps) => {
+ const [open, setOpen] = useState(false)
const { t } = useTranslation()
const { datasetOperatorEnabled } = useProviderContext()
const RoleMap = {
@@ -51,6 +56,7 @@ const Operation = ({
const { notify } = useContext(ToastContext)
const toHump = (name: string) => name.replace(/_(\w)/g, (all, letter) => letter.toUpperCase())
const handleDeleteMemberOrCancelInvitation = async () => {
+ setOpen(false)
try {
await deleteMemberOrCancelInvitation({ url: `/workspaces/current/members/${member.id}` })
onOperate()
@@ -61,6 +67,7 @@ const Operation = ({
}
}
const handleUpdateMemberRole = async (role: string) => {
+ setOpen(false)
try {
await updateMemberRole({
url: `/workspaces/current/members/${member.id}/update-role`,
@@ -75,63 +82,50 @@ const Operation = ({
}
return (
-