Merge branch 'main' into 12-24-json-for-translation

This commit is contained in:
Stephen Zhou 2025-12-25 16:06:52 +08:00 committed by GitHub
commit 6cd33fc1c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 2235 additions and 32 deletions

View File

@ -13,12 +13,28 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check Docker Compose inputs
id: docker-compose-changes
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v6
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- run: |
cd api
uv sync --dev

View File

@ -108,36 +108,6 @@ jobs:
working-directory: ./web
run: pnpm run type-check:tsgo
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- name: Generate Docker Compose
if: steps.changed-files.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- name: Check for changes
if: steps.changed-files.outputs.any_changed == 'true'
run: git diff --exit-code
superlinter:
name: SuperLinter
runs-on: ubuntu-latest

View File

@ -0,0 +1,57 @@
import os
from email.message import Message
from urllib.parse import quote
from flask import Response
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
HTML_EXTENSIONS = frozenset({"html", "htm"})
def _normalize_mime_type(mime_type: str | None) -> str:
if not mime_type:
return ""
message = Message()
message["Content-Type"] = mime_type
return message.get_content_type().strip().lower()
def _is_html_extension(extension: str | None) -> bool:
if not extension:
return False
return extension.lstrip(".").lower() in HTML_EXTENSIONS
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
normalized_mime_type = _normalize_mime_type(mime_type)
if normalized_mime_type in HTML_MIME_TYPES:
return True
if _is_html_extension(extension):
return True
if filename:
return _is_html_extension(os.path.splitext(filename)[1])
return False
def enforce_download_for_html(
response: Response,
*,
mime_type: str | None,
filename: str | None,
extension: str | None = None,
) -> bool:
if not is_html_content(mime_type, filename, extension):
return False
if filename:
encoded_filename = quote(filename)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
else:
response.headers["Content-Disposition"] = "attachment"
response.headers["Content-Type"] = "application/octet-stream"
response.headers["X-Content-Type-Options"] = "nosniff"
return True

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from extensions.ext_database import db
from services.account_service import TenantService
@ -138,6 +139,13 @@ class FilePreviewApi(Resource):
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,
mime_type=upload_file.mime_type,
filename=upload_file.name,
extension=upload_file.extension,
)
return response

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
@ -78,4 +79,11 @@ class ToolFileApi(Resource):
encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
enforce_download_for_html(
response,
mime_type=tool_file.mimetype,
filename=tool_file.name,
extension=extension,
)
return response

View File

@ -5,6 +5,7 @@ from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.file_response import enforce_download_for_html
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@ -183,6 +184,13 @@ class FilePreviewApi(Resource):
# Override content-type for downloads to force download
response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,
mime_type=upload_file.mime_type,
filename=upload_file.name,
extension=upload_file.extension,
)
# Add caching headers for performance
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour

View File

@ -3458,7 +3458,7 @@ class SegmentService:
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.order_by(DocumentSegment.position.asc())
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total

View File

@ -0,0 +1,46 @@
from flask import Response
from controllers.common.file_response import enforce_download_for_html, is_html_content
class TestFileResponseHelpers:
def test_is_html_content_detects_mime_type(self):
mime_type = "text/html; charset=UTF-8"
result = is_html_content(mime_type, filename="file.txt", extension="txt")
assert result is True
def test_is_html_content_detects_extension(self):
result = is_html_content("text/plain", filename="report.html", extension=None)
assert result is True
def test_enforce_download_for_html_sets_headers(self):
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
response,
mime_type="text/html",
filename="unsafe.html",
extension="html",
)
assert updated is True
assert "attachment" in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_enforce_download_for_html_no_change_for_non_html(self):
response = Response("payload", mimetype="text/plain")
updated = enforce_download_for_html(
response,
mime_type="text/plain",
filename="notes.txt",
extension="txt",
)
assert updated is False
assert "Content-Disposition" not in response.headers
assert "X-Content-Type-Options" not in response.headers

View File

@ -41,6 +41,7 @@ class TestFilePreviewApi:
upload_file = Mock(spec=UploadFile)
upload_file.id = str(uuid.uuid4())
upload_file.name = "test_file.jpg"
upload_file.extension = "jpg"
upload_file.mime_type = "image/jpeg"
upload_file.size = 1024
upload_file.key = "storage/key/test_file.jpg"
@ -210,6 +211,19 @@ class TestFilePreviewApi:
assert mock_upload_file.name in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
def test_build_file_response_html_forces_attachment(self, file_preview_api, mock_upload_file):
"""Test HTML files are forced to download"""
mock_generator = Mock()
mock_upload_file.mime_type = "text/html"
mock_upload_file.name = "unsafe.html"
mock_upload_file.extension = "html"
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
assert "attachment" in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
"""Test file response building for audio/video files"""
mock_generator = Mock()

View File

@ -0,0 +1,472 @@
"""
Unit tests for SegmentService.get_segments method.
Tests the retrieval of document segments with pagination and filtering:
- Basic pagination (page, limit)
- Status filtering
- Keyword search
- Ordering by position and id (to avoid duplicate data)
"""
from unittest.mock import Mock, create_autospec, patch
import pytest
from models.dataset import DocumentSegment
class SegmentServiceTestDataFactory:
"""
Factory class for creating test data and mock objects for segment tests.
"""
@staticmethod
def create_segment_mock(
segment_id: str = "segment-123",
document_id: str = "doc-123",
tenant_id: str = "tenant-123",
dataset_id: str = "dataset-123",
position: int = 1,
content: str = "Test content",
status: str = "completed",
**kwargs,
) -> Mock:
"""
Create a mock document segment.
Args:
segment_id: Unique identifier for the segment
document_id: Parent document ID
tenant_id: Tenant ID the segment belongs to
dataset_id: Parent dataset ID
position: Position within the document
content: Segment text content
status: Indexing status
**kwargs: Additional attributes
Returns:
Mock: DocumentSegment mock object
"""
segment = create_autospec(DocumentSegment, instance=True)
segment.id = segment_id
segment.document_id = document_id
segment.tenant_id = tenant_id
segment.dataset_id = dataset_id
segment.position = position
segment.content = content
segment.status = status
for key, value in kwargs.items():
setattr(segment, key, value)
return segment
class TestSegmentServiceGetSegments:
"""
Comprehensive unit tests for SegmentService.get_segments method.
Tests cover:
- Basic pagination functionality
- Status list filtering
- Keyword search filtering
- Ordering (position + id for uniqueness)
- Empty results
- Combined filters
"""
@pytest.fixture
def mock_segment_service_dependencies(self):
"""
Common mock setup for segment service dependencies.
Patches:
- db: Database operations and pagination
- select: SQLAlchemy query builder
"""
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.select") as mock_select,
):
yield {
"db": mock_db,
"select": mock_select,
}
def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
"""
Test basic pagination functionality.
Verifies:
- Query is built with document_id and tenant_id filters
- Pagination uses correct page and limit parameters
- Returns segments and total count
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
page = 1
limit = 20
# Create mock segments
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", position=1, content="First segment"
)
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-2", position=2, content="Second segment"
)
# Mock pagination result
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2]
mock_paginated.total = 2
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
# Mock select builder
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, page=page, limit=limit)
# Assert
assert len(items) == 2
assert total == 2
assert items[0].id == "seg-1"
assert items[1].id == "seg-2"
mock_segment_service_dependencies["db"].paginate.assert_called_once()
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["page"] == page
assert call_kwargs["per_page"] == limit
assert call_kwargs["max_per_page"] == 100
assert call_kwargs["error_out"] is False
def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
"""
Test filtering by status list.
Verifies:
- Status list filter is applied to query
- Only segments with matching status are returned
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = ["completed", "indexing"]
segment1 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
segment2 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2]
mock_paginated.total = 2
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id, tenant_id=tenant_id, status_list=status_list
)
# Assert
assert len(items) == 2
assert total == 2
# Verify where was called multiple times (base filters + status filter)
assert mock_query.where.call_count >= 2
def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
"""
Test with empty status list.
Verifies:
- Empty status list is handled correctly
- No status filter is applied to avoid WHERE false condition
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = []
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id, tenant_id=tenant_id, status_list=status_list
)
# Assert
assert len(items) == 1
assert total == 1
# Should only be called once (base filters, no status filter)
assert mock_query.where.call_count == 1
def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
"""
Test keyword search functionality.
Verifies:
- Keyword filter uses ilike for case-insensitive search
- Search pattern includes wildcards (%keyword%)
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
keyword = "search term"
segment = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", content="This contains search term"
)
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)
# Assert
assert len(items) == 1
assert total == 1
# Verify where was called for base filters + keyword filter
assert mock_query.where.call_count == 2
def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
"""
Test ordering by position and id.
Verifies:
- Results are ordered by position ASC
- Results are secondarily ordered by id ASC to ensure uniqueness
- This prevents duplicate data across pages when positions are not unique
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
# Create segments with same position but different ids
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", position=1, content="Content 1"
)
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-2", position=1, content="Content 2"
)
segment3 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-3", position=2, content="Content 3"
)
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2, segment3]
mock_paginated.total = 3
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
# Assert
assert len(items) == 3
assert total == 3
mock_query.order_by.assert_called_once()
def test_get_segments_empty_results(self, mock_segment_service_dependencies):
"""
Test when no segments match the criteria.
Verifies:
- Empty list is returned for items
- Total count is 0
"""
# Arrange
document_id = "non-existent-doc"
tenant_id = "tenant-123"
mock_paginated = Mock()
mock_paginated.items = []
mock_paginated.total = 0
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
# Assert
assert items == []
assert total == 0
def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
"""
Test with multiple filters combined.
Verifies:
- All filters work together correctly
- Status list and keyword search both applied
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = ["completed"]
keyword = "important"
page = 2
limit = 10
segment = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1",
status="completed",
content="This is important information",
)
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
status_list=status_list,
keyword=keyword,
page=page,
limit=limit,
)
# Assert
assert len(items) == 1
assert total == 1
# Verify filters: base + status + keyword
assert mock_query.where.call_count == 3
# Verify pagination parameters
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["page"] == page
assert call_kwargs["per_page"] == limit
def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
"""
Test with None status list.
Verifies:
- None status list is handled correctly
- No status filter is applied
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
status_list=None,
)
# Assert
assert len(items) == 1
assert total == 1
# Should only be called once (base filters only, no status filter)
assert mock_query.where.call_count == 1
def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
"""
Test that max_per_page is correctly set to 100.
Verifies:
- max_per_page parameter is set to 100
- This prevents excessive page sizes
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
limit = 200 # Request more than max_per_page
mock_paginated = Mock()
mock_paginated.items = []
mock_paginated.total = 0
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
limit=limit,
)
# Assert
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["max_per_page"] == 100

View File

@ -0,0 +1,360 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Badge, { BadgeState, BadgeVariants } from './index'
describe('Badge', () => {
describe('Rendering', () => {
it('should render as a div element with badge class', () => {
render(<Badge>Test Badge</Badge>)
const badge = screen.getByText('Test Badge')
expect(badge).toHaveClass('badge')
expect(badge.tagName).toBe('DIV')
})
it.each([
{ children: undefined, label: 'no children' },
{ children: '', label: 'empty string' },
])('should render correctly when provided $label', ({ children }) => {
const { container } = render(<Badge>{children}</Badge>)
expect(container.firstChild).toHaveClass('badge')
})
it('should render React Node children correctly', () => {
render(
<Badge data-testid="badge-with-icon">
<span data-testid="custom-icon">🔔</span>
</Badge>,
)
expect(screen.getByTestId('badge-with-icon')).toBeInTheDocument()
expect(screen.getByTestId('custom-icon')).toBeInTheDocument()
})
})
describe('size prop', () => {
it.each([
{ size: undefined, label: 'medium (default)' },
{ size: 's', label: 'small' },
{ size: 'm', label: 'medium' },
{ size: 'l', label: 'large' },
] as const)('should render with $label size', ({ size }) => {
render(<Badge size={size}>Test</Badge>)
const expectedSize = size || 'm'
expect(screen.getByText('Test')).toHaveClass('badge', `badge-${expectedSize}`)
})
})
describe('state prop', () => {
it.each([
{ state: BadgeState.Warning, label: 'warning', expectedClass: 'badge-warning' },
{ state: BadgeState.Accent, label: 'accent', expectedClass: 'badge-accent' },
])('should render with $label state', ({ state, expectedClass }) => {
render(<Badge state={state}>State Test</Badge>)
expect(screen.getByText('State Test')).toHaveClass(expectedClass)
})
it.each([
{ state: undefined, label: 'default (undefined)' },
{ state: BadgeState.Default, label: 'default (explicit)' },
])('should use default styles when state is $label', ({ state }) => {
render(<Badge state={state}>State Test</Badge>)
const badge = screen.getByText('State Test')
expect(badge).not.toHaveClass('badge-warning', 'badge-accent')
})
})
describe('iconOnly prop', () => {
it.each([
{ size: 's', iconOnly: false, label: 'small with text' },
{ size: 's', iconOnly: true, label: 'small icon-only' },
{ size: 'm', iconOnly: false, label: 'medium with text' },
{ size: 'm', iconOnly: true, label: 'medium icon-only' },
{ size: 'l', iconOnly: false, label: 'large with text' },
{ size: 'l', iconOnly: true, label: 'large icon-only' },
] as const)('should render correctly for $label', ({ size, iconOnly }) => {
const { container } = render(<Badge size={size} iconOnly={iconOnly}>🔔</Badge>)
const badge = screen.getByText('🔔')
// Verify badge renders with correct size
expect(badge).toHaveClass('badge', `badge-${size}`)
// Verify the badge is in the DOM and contains the content
expect(badge).toBeInTheDocument()
expect(container.firstChild).toBe(badge)
})
it('should apply icon-only padding when iconOnly is true', () => {
render(<Badge iconOnly>🔔</Badge>)
// When iconOnly is true, the badge should have uniform padding (all sides equal)
const badge = screen.getByText('🔔')
expect(badge).toHaveClass('p-1')
})
it('should apply asymmetric padding when iconOnly is false', () => {
render(<Badge iconOnly={false}>Badge</Badge>)
// When iconOnly is false, the badge should have different horizontal and vertical padding
const badge = screen.getByText('Badge')
expect(badge).toHaveClass('px-[5px]', 'py-[2px]')
})
})
describe('uppercase prop', () => {
it.each([
{ uppercase: undefined, label: 'default (undefined)', expected: 'system-2xs-medium' },
{ uppercase: false, label: 'explicitly false', expected: 'system-2xs-medium' },
{ uppercase: true, label: 'true', expected: 'system-2xs-medium-uppercase' },
])('should apply $expected class when uppercase is $label', ({ uppercase, expected }) => {
render(<Badge uppercase={uppercase}>Text</Badge>)
expect(screen.getByText('Text')).toHaveClass(expected)
})
})
describe('styleCss prop', () => {
it('should apply custom inline styles correctly', () => {
const customStyles = {
backgroundColor: 'rgb(0, 0, 255)',
color: 'rgb(255, 255, 255)',
padding: '10px',
}
render(<Badge styleCss={customStyles}>Styled Badge</Badge>)
expect(screen.getByText('Styled Badge')).toHaveStyle(customStyles)
})
it('should apply inline styles without overriding core classes', () => {
render(<Badge styleCss={{ backgroundColor: 'rgb(255, 0, 0)', margin: '5px' }}>Custom</Badge>)
const badge = screen.getByText('Custom')
expect(badge).toHaveStyle({ backgroundColor: 'rgb(255, 0, 0)', margin: '5px' })
expect(badge).toHaveClass('badge')
})
})
describe('className prop', () => {
it.each([
{
props: { className: 'custom-badge' },
expected: ['badge', 'custom-badge'],
label: 'single custom class',
},
{
props: { className: 'custom-class another-class', size: 'l' as const },
expected: ['badge', 'badge-l', 'custom-class', 'another-class'],
label: 'multiple classes with size variant',
},
])('should merge $label with default classes', ({ props, expected }) => {
render(<Badge {...props}>Test</Badge>)
expect(screen.getByText('Test')).toHaveClass(...expected)
})
})
describe('HTML attributes passthrough', () => {
it.each([
{ attr: 'data-testid', value: 'custom-badge-id', label: 'data attribute' },
{ attr: 'id', value: 'unique-badge', label: 'id attribute' },
{ attr: 'aria-label', value: 'Notification badge', label: 'aria-label' },
{ attr: 'title', value: 'Hover tooltip', label: 'title attribute' },
{ attr: 'role', value: 'status', label: 'ARIA role' },
])('should pass through $label correctly', ({ attr, value }) => {
render(<Badge {...{ [attr]: value }}>Test</Badge>)
expect(screen.getByText('Test')).toHaveAttribute(attr, value)
})
it('should support multiple HTML attributes simultaneously', () => {
render(
<Badge
data-testid="multi-attr-badge"
id="badge-123"
aria-label="Status indicator"
title="Current status"
>
Test
</Badge>,
)
const badge = screen.getByTestId('multi-attr-badge')
expect(badge).toHaveAttribute('id', 'badge-123')
expect(badge).toHaveAttribute('aria-label', 'Status indicator')
expect(badge).toHaveAttribute('title', 'Current status')
})
})
describe('Event handlers', () => {
it.each([
{ handler: 'onClick', trigger: fireEvent.click, label: 'click' },
{ handler: 'onMouseEnter', trigger: fireEvent.mouseEnter, label: 'mouse enter' },
{ handler: 'onMouseLeave', trigger: fireEvent.mouseLeave, label: 'mouse leave' },
])('should trigger $handler when $label occurs', ({ handler, trigger }) => {
const mockHandler = vi.fn()
render(<Badge {...{ [handler]: mockHandler }}>Badge</Badge>)
trigger(screen.getByText('Badge'))
expect(mockHandler).toHaveBeenCalledTimes(1)
})
it('should handle user interaction flow with multiple events', () => {
const handlers = {
onClick: vi.fn(),
onMouseEnter: vi.fn(),
onMouseLeave: vi.fn(),
}
render(<Badge {...handlers}>Interactive</Badge>)
const badge = screen.getByText('Interactive')
fireEvent.mouseEnter(badge)
fireEvent.click(badge)
fireEvent.mouseLeave(badge)
expect(handlers.onMouseEnter).toHaveBeenCalledTimes(1)
expect(handlers.onClick).toHaveBeenCalledTimes(1)
expect(handlers.onMouseLeave).toHaveBeenCalledTimes(1)
})
it('should pass event object to handler with correct properties', () => {
const handleClick = vi.fn()
render(<Badge onClick={handleClick}>Event Badge</Badge>)
fireEvent.click(screen.getByText('Event Badge'))
expect(handleClick).toHaveBeenCalledWith(expect.objectContaining({
type: 'click',
}))
})
})
describe('Combined props', () => {
it('should correctly apply all props when used together', () => {
render(
<Badge
size="l"
state={BadgeState.Warning}
uppercase
className="custom-badge"
styleCss={{ backgroundColor: 'rgb(0, 0, 255)' }}
data-testid="combined-badge"
>
Full Featured
</Badge>,
)
const badge = screen.getByTestId('combined-badge')
expect(badge).toHaveClass('badge', 'badge-l', 'badge-warning', 'system-2xs-medium-uppercase', 'custom-badge')
expect(badge).toHaveStyle({ backgroundColor: 'rgb(0, 0, 255)' })
expect(badge).toHaveTextContent('Full Featured')
})
it.each([
{
props: { size: 'l' as const, state: BadgeState.Accent },
expected: ['badge', 'badge-l', 'badge-accent'],
label: 'size and state variants',
},
{
props: { iconOnly: true, uppercase: true },
expected: ['badge', 'system-2xs-medium-uppercase'],
label: 'iconOnly and uppercase',
},
])('should combine $label correctly', ({ props, expected }) => {
render(<Badge {...props}>Test</Badge>)
expect(screen.getByText('Test')).toHaveClass(...expected)
})
it('should handle event handlers with combined props', () => {
const handleClick = vi.fn()
render(
<Badge size="s" state={BadgeState.Warning} onClick={handleClick} className="interactive">
Test
</Badge>,
)
const badge = screen.getByText('Test')
expect(badge).toHaveClass('badge', 'badge-s', 'badge-warning', 'interactive')
fireEvent.click(badge)
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
describe('Edge cases', () => {
it.each([
{ children: 42, text: '42', label: 'numeric value' },
{ children: 0, text: '0', label: 'zero' },
])('should render $label correctly', ({ children, text }) => {
render(<Badge>{children}</Badge>)
expect(screen.getByText(text)).toBeInTheDocument()
})
it.each([
{ children: null, label: 'null' },
{ children: false, label: 'boolean false' },
])('should handle $label children without errors', ({ children }) => {
const { container } = render(<Badge>{children}</Badge>)
expect(container.firstChild).toHaveClass('badge')
})
it('should render complex nested content correctly', () => {
render(
<Badge>
<span data-testid="icon">🔔</span>
<span data-testid="count">5</span>
</Badge>,
)
expect(screen.getByTestId('icon')).toBeInTheDocument()
expect(screen.getByTestId('count')).toBeInTheDocument()
})
})
describe('Component metadata and exports', () => {
it('should have correct displayName for debugging', () => {
expect(Badge.displayName).toBe('Badge')
})
describe('BadgeState enum', () => {
it.each([
{ key: 'Warning', value: 'warning' },
{ key: 'Accent', value: 'accent' },
{ key: 'Default', value: '' },
])('should export $key state with value "$value"', ({ key, value }) => {
expect(BadgeState[key as keyof typeof BadgeState]).toBe(value)
})
})
describe('BadgeVariants utility', () => {
it('should be a function', () => {
expect(typeof BadgeVariants).toBe('function')
})
it('should generate base badge class with default medium size', () => {
const result = BadgeVariants({})
expect(result).toContain('badge')
expect(result).toContain('badge-m')
})
it.each([
{ size: 's' },
{ size: 'm' },
{ size: 'l' },
] as const)('should generate correct classes for size=$size', ({ size }) => {
const result = BadgeVariants({ size })
expect(result).toContain('badge')
expect(result).toContain(`badge-${size}`)
})
})
})
})

View File

@ -0,0 +1,394 @@
import type { Item } from './index'
import { cleanup, fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import Chip from './index'
afterEach(cleanup)
// Test data factory
const createTestItems = (): Item[] => [
{ value: 'all', name: 'All Items' },
{ value: 'active', name: 'Active' },
{ value: 'archived', name: 'Archived' },
]
describe('Chip', () => {
// Shared test props
let items: Item[]
let onSelect: (item: Item) => void
let onClear: () => void
beforeEach(() => {
vi.clearAllMocks()
items = createTestItems()
onSelect = vi.fn()
onClear = vi.fn()
})
// Helper function to render Chip with default props
const renderChip = (props: Partial<React.ComponentProps<typeof Chip>> = {}) => {
return render(
<Chip
value="all"
items={items}
onSelect={onSelect}
onClear={onClear}
{...props}
/>,
)
}
// Helper function to get the trigger element
const getTrigger = (container: HTMLElement) => {
return container.querySelector('[data-state]')
}
// Helper function to open dropdown panel
const openPanel = (container: HTMLElement) => {
const trigger = getTrigger(container)
if (trigger)
fireEvent.click(trigger)
}
describe('Rendering', () => {
it('should render without crashing', () => {
renderChip()
expect(screen.getByText('All Items')).toBeInTheDocument()
})
it('should display current selected item name', () => {
renderChip({ value: 'active' })
expect(screen.getByText('Active')).toBeInTheDocument()
})
it('should display empty content when value does not match any item', () => {
const { container } = renderChip({ value: 'nonexistent' })
// When value doesn't match, no text should be displayed in trigger
const trigger = getTrigger(container)
// Check that there's no item name text (only icons should be present)
expect(trigger?.textContent?.trim()).toBeFalsy()
})
})
describe('Props', () => {
it('should update displayed item name when value prop changes', () => {
const { rerender } = renderChip({ value: 'all' })
expect(screen.getByText('All Items')).toBeInTheDocument()
rerender(
<Chip
value="archived"
items={items}
onSelect={onSelect}
onClear={onClear}
/>,
)
expect(screen.getByText('Archived')).toBeInTheDocument()
})
it('should show left icon by default', () => {
const { container } = renderChip()
// The filter icon should be visible
const svg = container.querySelector('svg')
expect(svg).toBeInTheDocument()
})
it('should hide left icon when showLeftIcon is false', () => {
renderChip({ showLeftIcon: false })
// When showLeftIcon is false, there should be no filter icon before the text
const textElement = screen.getByText('All Items')
const parent = textElement.closest('div[data-state]')
const icons = parent?.querySelectorAll('svg')
// Should only have the arrow icon, not the filter icon
expect(icons?.length).toBe(1)
})
it('should render custom left icon', () => {
const CustomIcon = () => <span data-testid="custom-icon"></span>
renderChip({ leftIcon: <CustomIcon /> })
expect(screen.getByTestId('custom-icon')).toBeInTheDocument()
})
it('should apply custom className to trigger', () => {
const customClass = 'custom-chip-class'
const { container } = renderChip({ className: customClass })
const chipElement = container.querySelector(`.${customClass}`)
expect(chipElement).toBeInTheDocument()
})
it('should apply custom panelClassName to dropdown panel', () => {
const customPanelClass = 'custom-panel-class'
const { container } = renderChip({ panelClassName: customPanelClass })
openPanel(container)
// Panel is rendered in a portal, so check document.body
const panel = document.body.querySelector(`.${customPanelClass}`)
expect(panel).toBeInTheDocument()
})
})
describe('State Management', () => {
it('should toggle dropdown panel on trigger click', () => {
const { container } = renderChip()
// Initially closed - check data-state attribute
const trigger = getTrigger(container)
expect(trigger).toHaveAttribute('data-state', 'closed')
// Open panel
openPanel(container)
expect(trigger).toHaveAttribute('data-state', 'open')
// Panel items should be visible
expect(screen.getAllByText('All Items').length).toBeGreaterThan(1)
// Close panel
if (trigger)
fireEvent.click(trigger)
expect(trigger).toHaveAttribute('data-state', 'closed')
})
it('should close panel after selecting an item', () => {
const { container } = renderChip()
openPanel(container)
const trigger = getTrigger(container)
expect(trigger).toHaveAttribute('data-state', 'open')
// Click on an item in the dropdown panel
const activeItems = screen.getAllByText('Active')
// The second one should be in the dropdown
fireEvent.click(activeItems[activeItems.length - 1])
expect(trigger).toHaveAttribute('data-state', 'closed')
})
})
describe('Event Handlers', () => {
it('should call onSelect with correct item when item is clicked', () => {
const { container } = renderChip()
openPanel(container)
// Get all "Active" texts and click the one in the dropdown (should be the last one)
const activeItems = screen.getAllByText('Active')
fireEvent.click(activeItems[activeItems.length - 1])
expect(onSelect).toHaveBeenCalledTimes(1)
expect(onSelect).toHaveBeenCalledWith(items[1])
})
it('should call onClear when clear button is clicked', () => {
const { container } = renderChip({ value: 'active' })
// Find the close icon (last SVG in the trigger) and click its parent
const trigger = getTrigger(container)
const svgs = trigger?.querySelectorAll('svg')
// The close icon should be the last SVG element
const closeIcon = svgs?.[svgs.length - 1]
const clearButton = closeIcon?.parentElement
expect(clearButton).toBeInTheDocument()
if (clearButton)
fireEvent.click(clearButton)
expect(onClear).toHaveBeenCalledTimes(1)
})
it('should stop event propagation when clear button is clicked', () => {
const { container } = renderChip({ value: 'active' })
const trigger = getTrigger(container)
expect(trigger).toHaveAttribute('data-state', 'closed')
// Find the close icon (last SVG) and click its parent
const svgs = trigger?.querySelectorAll('svg')
const closeIcon = svgs?.[svgs.length - 1]
const clearButton = closeIcon?.parentElement
if (clearButton)
fireEvent.click(clearButton)
// Panel should remain closed
expect(trigger).toHaveAttribute('data-state', 'closed')
expect(onClear).toHaveBeenCalledTimes(1)
})
it('should handle multiple rapid clicks on trigger', () => {
const { container } = renderChip()
const trigger = getTrigger(container)
// Click 1: open
if (trigger)
fireEvent.click(trigger)
expect(trigger).toHaveAttribute('data-state', 'open')
// Click 2: close
if (trigger)
fireEvent.click(trigger)
expect(trigger).toHaveAttribute('data-state', 'closed')
// Click 3: open again
if (trigger)
fireEvent.click(trigger)
expect(trigger).toHaveAttribute('data-state', 'open')
})
})
describe('Conditional Rendering', () => {
it('should show arrow down icon when no value is selected', () => {
const { container } = renderChip({ value: '' })
// Should have SVG icons (filter icon and arrow down icon)
const svgs = container.querySelectorAll('svg')
expect(svgs.length).toBeGreaterThan(0)
})
it('should show clear button when value is selected', () => {
const { container } = renderChip({ value: 'active' })
// When value is selected, there should be an icon (the close icon)
const svgs = container.querySelectorAll('svg')
expect(svgs.length).toBeGreaterThan(0)
})
it('should not show clear button when no value is selected', () => {
const { container } = renderChip({ value: '' })
const trigger = getTrigger(container)
// When value is empty, the trigger should only have 2 SVGs (filter icon + arrow)
// When value is selected, it would have 2 SVGs (filter icon + close icon)
const svgs = trigger?.querySelectorAll('svg')
// Arrow icon should be present, close icon should not
expect(svgs?.length).toBe(2)
// Verify onClear hasn't been called
expect(onClear).not.toHaveBeenCalled()
})
it('should show dropdown content only when panel is open', () => {
const { container } = renderChip()
const trigger = getTrigger(container)
// Closed by default
expect(trigger).toHaveAttribute('data-state', 'closed')
openPanel(container)
expect(trigger).toHaveAttribute('data-state', 'open')
// Items should be duplicated (once in trigger, once in panel)
expect(screen.getAllByText('All Items').length).toBeGreaterThan(1)
})
it('should show check icon on selected item in dropdown', () => {
const { container } = renderChip({ value: 'active' })
openPanel(container)
// Find the dropdown panel items
const allActiveTexts = screen.getAllByText('Active')
// The dropdown item should be the last one
const dropdownItem = allActiveTexts[allActiveTexts.length - 1]
const parentContainer = dropdownItem.parentElement
// The check icon should be a sibling within the parent
const checkIcon = parentContainer?.querySelector('svg')
expect(checkIcon).toBeInTheDocument()
})
it('should render all items in dropdown when open', () => {
const { container } = renderChip()
openPanel(container)
// Each item should appear at least twice (once in potential selected state, once in dropdown)
// Use getAllByText to handle multiple occurrences
expect(screen.getAllByText('All Items').length).toBeGreaterThan(0)
expect(screen.getAllByText('Active').length).toBeGreaterThan(0)
expect(screen.getAllByText('Archived').length).toBeGreaterThan(0)
})
})
describe('Edge Cases', () => {
it('should handle empty items array', () => {
const { container } = renderChip({ items: [], value: '' })
// Trigger should still render
const trigger = container.querySelector('[data-state]')
expect(trigger).toBeInTheDocument()
})
it('should handle value not in items list', () => {
const { container } = renderChip({ value: 'nonexistent' })
const trigger = getTrigger(container)
expect(trigger).toBeInTheDocument()
// The trigger should not display any item name text
expect(trigger?.textContent?.trim()).toBeFalsy()
})
it('should allow selecting already selected item', () => {
const { container } = renderChip({ value: 'active' })
openPanel(container)
// Click on the already selected item in the dropdown
const activeItems = screen.getAllByText('Active')
fireEvent.click(activeItems[activeItems.length - 1])
expect(onSelect).toHaveBeenCalledTimes(1)
expect(onSelect).toHaveBeenCalledWith(items[1])
})
it('should handle numeric values', () => {
const numericItems: Item[] = [
{ value: 1, name: 'First' },
{ value: 2, name: 'Second' },
{ value: 3, name: 'Third' },
]
const { container } = renderChip({ value: 2, items: numericItems })
expect(screen.getByText('Second')).toBeInTheDocument()
// Open panel and select Third
openPanel(container)
const thirdItems = screen.getAllByText('Third')
fireEvent.click(thirdItems[thirdItems.length - 1])
expect(onSelect).toHaveBeenCalledWith(numericItems[2])
})
it('should handle items with additional properties', () => {
const itemsWithExtra: Item[] = [
{ value: 'a', name: 'Item A', customProp: 'extra1' },
{ value: 'b', name: 'Item B', customProp: 'extra2' },
]
const { container } = renderChip({ value: 'a', items: itemsWithExtra })
expect(screen.getByText('Item A')).toBeInTheDocument()
// Open panel and select Item B
openPanel(container)
const itemBs = screen.getAllByText('Item B')
fireEvent.click(itemBs[itemBs.length - 1])
expect(onSelect).toHaveBeenCalledWith(itemsWithExtra[1])
})
})
})

View File

@ -0,0 +1,84 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import Billing from './index'
let currentBillingUrl: string | null = 'https://billing'
let fetching = false
let isManager = true
let enableBilling = true
const refetchMock = vi.fn()
const openAsyncWindowMock = vi.fn()
vi.mock('@/service/use-billing', () => ({
useBillingUrl: () => ({
data: currentBillingUrl,
isFetching: fetching,
refetch: refetchMock,
}),
}))
vi.mock('@/hooks/use-async-window-open', () => ({
useAsyncWindowOpen: () => openAsyncWindowMock,
}))
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({
isCurrentWorkspaceManager: isManager,
}),
}))
vi.mock('@/context/provider-context', () => ({
useProviderContext: () => ({
enableBilling,
}),
}))
vi.mock('../plan', () => ({
__esModule: true,
default: ({ loc }: { loc: string }) => <div data-testid="plan-component" data-loc={loc} />,
}))
describe('Billing', () => {
beforeEach(() => {
vi.clearAllMocks()
currentBillingUrl = 'https://billing'
fetching = false
isManager = true
enableBilling = true
refetchMock.mockResolvedValue({ data: 'https://billing' })
})
it('hides the billing action when user is not manager or billing is disabled', () => {
isManager = false
render(<Billing />)
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
vi.clearAllMocks()
isManager = true
enableBilling = false
render(<Billing />)
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
})
it('opens the billing window with the immediate url when the button is clicked', async () => {
render(<Billing />)
const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
fireEvent.click(actionButton)
await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled())
const [, options] = openAsyncWindowMock.mock.calls[0]
expect(options).toMatchObject({
immediateUrl: currentBillingUrl,
features: 'noopener,noreferrer',
})
})
it('disables the button while billing url is fetching', () => {
fetching = true
render(<Billing />)
const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
expect(actionButton).toBeDisabled()
})
})

View File

@ -0,0 +1,92 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { Plan } from '../type'
import HeaderBillingBtn from './index'
type HeaderGlobal = typeof globalThis & {
__mockProviderContext?: ReturnType<typeof vi.fn>
}
function getHeaderGlobal(): HeaderGlobal {
return globalThis as HeaderGlobal
}
const ensureProviderContextMock = () => {
const globals = getHeaderGlobal()
if (!globals.__mockProviderContext)
throw new Error('Provider context mock not set')
return globals.__mockProviderContext
}
vi.mock('@/context/provider-context', () => {
const mock = vi.fn()
const globals = getHeaderGlobal()
globals.__mockProviderContext = mock
return {
useProviderContext: () => mock(),
}
})
vi.mock('../upgrade-btn', () => ({
__esModule: true,
default: () => <button data-testid="upgrade-btn" type="button">Upgrade</button>,
}))
describe('HeaderBillingBtn', () => {
beforeEach(() => {
vi.clearAllMocks()
ensureProviderContextMock().mockReturnValue({
plan: {
type: Plan.professional,
},
enableBilling: true,
isFetchedPlan: true,
})
})
it('renders nothing when billing is disabled or plan is not fetched', () => {
ensureProviderContextMock().mockReturnValueOnce({
plan: {
type: Plan.professional,
},
enableBilling: false,
isFetchedPlan: true,
})
const { container } = render(<HeaderBillingBtn />)
expect(container.firstChild).toBeNull()
})
it('renders upgrade button for sandbox plan', () => {
ensureProviderContextMock().mockReturnValueOnce({
plan: {
type: Plan.sandbox,
},
enableBilling: true,
isFetchedPlan: true,
})
render(<HeaderBillingBtn />)
expect(screen.getByTestId('upgrade-btn')).toBeInTheDocument()
})
it('renders plan badge and forwards clicks when not display-only', () => {
const onClick = vi.fn()
const { rerender } = render(<HeaderBillingBtn onClick={onClick} />)
const badge = screen.getByText('pro').closest('div')
expect(badge).toHaveClass('cursor-pointer')
fireEvent.click(badge!)
expect(onClick).toHaveBeenCalledTimes(1)
rerender(<HeaderBillingBtn onClick={onClick} isDisplayOnly />)
expect(screen.getByText('pro').closest('div')).toHaveClass('cursor-default')
fireEvent.click(screen.getByText('pro').closest('div')!)
expect(onClick).toHaveBeenCalledTimes(1)
})
})

View File

@ -0,0 +1,44 @@
import { render } from '@testing-library/react'
import PartnerStack from './index'
let isCloudEdition = true
const saveOrUpdate = vi.fn()
const bind = vi.fn()
vi.mock('@/config', () => ({
get IS_CLOUD_EDITION() {
return isCloudEdition
},
}))
vi.mock('./use-ps-info', () => ({
__esModule: true,
default: () => ({
saveOrUpdate,
bind,
}),
}))
describe('PartnerStack', () => {
beforeEach(() => {
vi.clearAllMocks()
isCloudEdition = true
})
it('does not call partner stack helpers when not in cloud edition', () => {
isCloudEdition = false
render(<PartnerStack />)
expect(saveOrUpdate).not.toHaveBeenCalled()
expect(bind).not.toHaveBeenCalled()
})
it('calls saveOrUpdate and bind once when running in cloud edition', () => {
render(<PartnerStack />)
expect(saveOrUpdate).toHaveBeenCalledTimes(1)
expect(bind).toHaveBeenCalledTimes(1)
})
})

View File

@ -0,0 +1,197 @@
import { act, renderHook } from '@testing-library/react'
import { PARTNER_STACK_CONFIG } from '@/config'
import usePSInfo from './use-ps-info'
let searchParamsValues: Record<string, string | null> = {}
const setSearchParams = (values: Record<string, string | null>) => {
searchParamsValues = values
}
type PartnerStackGlobal = typeof globalThis & {
__partnerStackCookieMocks?: {
get: ReturnType<typeof vi.fn>
set: ReturnType<typeof vi.fn>
remove: ReturnType<typeof vi.fn>
}
__partnerStackMutateAsync?: ReturnType<typeof vi.fn>
}
function getPartnerStackGlobal(): PartnerStackGlobal {
return globalThis as PartnerStackGlobal
}
const ensureCookieMocks = () => {
const globals = getPartnerStackGlobal()
if (!globals.__partnerStackCookieMocks)
throw new Error('Cookie mocks not initialized')
return globals.__partnerStackCookieMocks
}
const ensureMutateAsync = () => {
const globals = getPartnerStackGlobal()
if (!globals.__partnerStackMutateAsync)
throw new Error('Mutate mock not initialized')
return globals.__partnerStackMutateAsync
}
vi.mock('js-cookie', () => {
const get = vi.fn()
const set = vi.fn()
const remove = vi.fn()
const globals = getPartnerStackGlobal()
globals.__partnerStackCookieMocks = { get, set, remove }
const cookieApi = { get, set, remove }
return {
__esModule: true,
default: cookieApi,
get,
set,
remove,
}
})
vi.mock('next/navigation', () => ({
useSearchParams: () => ({
get: (key: string) => searchParamsValues[key] ?? null,
}),
}))
vi.mock('@/service/use-billing', () => {
const mutateAsync = vi.fn()
const globals = getPartnerStackGlobal()
globals.__partnerStackMutateAsync = mutateAsync
return {
useBindPartnerStackInfo: () => ({
mutateAsync,
}),
}
})
describe('usePSInfo', () => {
const originalLocationDescriptor = Object.getOwnPropertyDescriptor(globalThis, 'location')
beforeAll(() => {
Object.defineProperty(globalThis, 'location', {
value: { hostname: 'cloud.dify.ai' },
configurable: true,
})
})
beforeEach(() => {
setSearchParams({})
const { get, set, remove } = ensureCookieMocks()
get.mockReset()
set.mockReset()
remove.mockReset()
const mutate = ensureMutateAsync()
mutate.mockReset()
mutate.mockResolvedValue(undefined)
get.mockReturnValue('{}')
})
afterAll(() => {
if (originalLocationDescriptor)
Object.defineProperty(globalThis, 'location', originalLocationDescriptor)
})
it('saves partner info when query params change', () => {
const { get, set } = ensureCookieMocks()
get.mockReturnValue(JSON.stringify({ partnerKey: 'old', clickId: 'old-click' }))
setSearchParams({
ps_partner_key: 'new-partner',
ps_xid: 'new-click',
})
const { result } = renderHook(() => usePSInfo())
expect(result.current.psPartnerKey).toBe('new-partner')
expect(result.current.psClickId).toBe('new-click')
act(() => {
result.current.saveOrUpdate()
})
expect(set).toHaveBeenCalledWith(
PARTNER_STACK_CONFIG.cookieName,
JSON.stringify({
partnerKey: 'new-partner',
clickId: 'new-click',
}),
{
expires: PARTNER_STACK_CONFIG.saveCookieDays,
path: '/',
domain: '.dify.ai',
},
)
})
it('does not overwrite cookie when params do not change', () => {
setSearchParams({
ps_partner_key: 'existing',
ps_xid: 'existing-click',
})
const { get } = ensureCookieMocks()
get.mockReturnValue(JSON.stringify({
partnerKey: 'existing',
clickId: 'existing-click',
}))
const { result } = renderHook(() => usePSInfo())
act(() => {
result.current.saveOrUpdate()
})
const { set } = ensureCookieMocks()
expect(set).not.toHaveBeenCalled()
})
it('binds partner info and clears cookie once', async () => {
setSearchParams({
ps_partner_key: 'bind-partner',
ps_xid: 'bind-click',
})
const { result } = renderHook(() => usePSInfo())
const mutate = ensureMutateAsync()
const { remove } = ensureCookieMocks()
await act(async () => {
await result.current.bind()
})
expect(mutate).toHaveBeenCalledWith({
partnerKey: 'bind-partner',
clickId: 'bind-click',
})
expect(remove).toHaveBeenCalledWith(PARTNER_STACK_CONFIG.cookieName, {
path: '/',
domain: '.dify.ai',
})
await act(async () => {
await result.current.bind()
})
expect(mutate).toHaveBeenCalledTimes(1)
})
it('still removes cookie when bind fails with status 400', async () => {
const mutate = ensureMutateAsync()
mutate.mockRejectedValueOnce({ status: 400 })
setSearchParams({
ps_partner_key: 'bind-partner',
ps_xid: 'bind-click',
})
const { result } = renderHook(() => usePSInfo())
await act(async () => {
await result.current.bind()
})
const { remove } = ensureCookieMocks()
expect(remove).toHaveBeenCalledWith(PARTNER_STACK_CONFIG.cookieName, {
path: '/',
domain: '.dify.ai',
})
})
})

View File

@ -0,0 +1,130 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM } from '@/app/education-apply/constants'
import { Plan } from '../type'
import PlanComp from './index'
let currentPath = '/billing'
const push = vi.fn()
vi.mock('next/navigation', () => ({
useRouter: () => ({ push }),
usePathname: () => currentPath,
}))
const setShowAccountSettingModalMock = vi.fn()
vi.mock('@/context/modal-context', () => ({
// eslint-disable-next-line ts/no-explicit-any
useModalContextSelector: (selector: any) => selector({
setShowAccountSettingModal: setShowAccountSettingModalMock,
}),
}))
const providerContextMock = vi.fn()
vi.mock('@/context/provider-context', () => ({
useProviderContext: () => providerContextMock(),
}))
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({
userProfile: { email: 'user@example.com' },
isCurrentWorkspaceManager: true,
}),
}))
const mutateAsyncMock = vi.fn()
let isPending = false
vi.mock('@/service/use-education', () => ({
useEducationVerify: () => ({
mutateAsync: mutateAsyncMock,
isPending,
}),
}))
const verifyStateModalMock = vi.fn(props => (
<div data-testid="verify-modal" data-is-show={props.isShow ? 'true' : 'false'}>
{props.isShow ? 'visible' : 'hidden'}
</div>
))
vi.mock('@/app/education-apply/verify-state-modal', () => ({
__esModule: true,
// eslint-disable-next-line ts/no-explicit-any
default: (props: any) => verifyStateModalMock(props),
}))
vi.mock('../upgrade-btn', () => ({
__esModule: true,
default: () => <button data-testid="plan-upgrade-btn" type="button">Upgrade</button>,
}))
describe('PlanComp', () => {
const planMock = {
type: Plan.professional,
usage: {
teamMembers: 4,
documentsUploadQuota: 3,
vectorSpace: 8,
annotatedResponse: 5,
triggerEvents: 60,
apiRateLimit: 100,
},
total: {
teamMembers: 10,
documentsUploadQuota: 20,
vectorSpace: 10,
annotatedResponse: 500,
triggerEvents: 100,
apiRateLimit: 200,
},
reset: {
triggerEvents: 2,
apiRateLimit: 1,
},
}
beforeEach(() => {
vi.clearAllMocks()
currentPath = '/billing'
isPending = false
providerContextMock.mockReturnValue({
plan: planMock,
enableEducationPlan: true,
allowRefreshEducationVerify: false,
isEducationAccount: false,
})
mutateAsyncMock.mockReset()
mutateAsyncMock.mockResolvedValue({ token: 'token' })
})
it('renders plan info and handles education verify success', async () => {
render(<PlanComp loc="billing-page" />)
expect(screen.getByText('billing.plans.professional.name')).toBeInTheDocument()
expect(screen.getByTestId('plan-upgrade-btn')).toBeInTheDocument()
const verifyBtn = screen.getByText('education.toVerified')
fireEvent.click(verifyBtn)
await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled())
await waitFor(() => expect(push).toHaveBeenCalledWith('/education-apply?token=token'))
expect(localStorage.removeItem).toHaveBeenCalledWith(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
})
it('shows modal when education verify fails', async () => {
mutateAsyncMock.mockRejectedValueOnce(new Error('boom'))
render(<PlanComp loc="billing-page" />)
const verifyBtn = screen.getByText('education.toVerified')
fireEvent.click(verifyBtn)
await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled())
await waitFor(() => expect(screen.getByTestId('verify-modal').getAttribute('data-is-show')).toBe('true'))
})
it('resets modal context when on education apply path', () => {
currentPath = '/education-apply/setup'
render(<PlanComp loc="billing-page" />)
expect(setShowAccountSettingModalMock).toHaveBeenCalledWith(null)
})
})

View File

@ -0,0 +1,25 @@
import { render, screen } from '@testing-library/react'
import ProgressBar from './index'
describe('ProgressBar', () => {
it('renders with provided percent and color', () => {
render(<ProgressBar percent={42} color="bg-test-color" />)
const bar = screen.getByTestId('billing-progress-bar')
expect(bar).toHaveClass('bg-test-color')
expect(bar.getAttribute('style')).toContain('width: 42%')
})
it('caps width at 100% when percent exceeds max', () => {
render(<ProgressBar percent={150} color="bg-test-color" />)
const bar = screen.getByTestId('billing-progress-bar')
expect(bar.getAttribute('style')).toContain('width: 100%')
})
it('uses the default color when no color prop is provided', () => {
render(<ProgressBar percent={20} color={undefined as unknown as string} />)
expect(screen.getByTestId('billing-progress-bar')).toHaveClass('#2970FF')
})
})

View File

@ -0,0 +1,70 @@
import { render, screen } from '@testing-library/react'
import TriggerEventsLimitModal from './index'
const mockOnClose = vi.fn()
const mockOnUpgrade = vi.fn()
const planUpgradeModalMock = vi.fn((props: { show: boolean, title: string, description: string, extraInfo?: React.ReactNode, onClose: () => void, onUpgrade: () => void }) => (
<div
data-testid="plan-upgrade-modal"
data-show={props.show}
data-title={props.title}
data-description={props.description}
>
{props.extraInfo}
</div>
))
vi.mock('@/app/components/billing/plan-upgrade-modal', () => ({
__esModule: true,
// eslint-disable-next-line ts/no-explicit-any
default: (props: any) => planUpgradeModalMock(props),
}))
describe('TriggerEventsLimitModal', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('passes the trigger usage props to the upgrade modal', () => {
render(
<TriggerEventsLimitModal
show
onClose={mockOnClose}
onUpgrade={mockOnUpgrade}
usage={12}
total={20}
resetInDays={5}
/>,
)
const modal = screen.getByTestId('plan-upgrade-modal')
expect(modal.getAttribute('data-show')).toBe('true')
expect(modal.getAttribute('data-title')).toContain('billing.triggerLimitModal.title')
expect(modal.getAttribute('data-description')).toContain('billing.triggerLimitModal.description')
expect(planUpgradeModalMock).toHaveBeenCalled()
const passedProps = planUpgradeModalMock.mock.calls[0][0]
expect(passedProps.onClose).toBe(mockOnClose)
expect(passedProps.onUpgrade).toBe(mockOnUpgrade)
expect(screen.getByText('billing.triggerLimitModal.usageTitle')).toBeInTheDocument()
expect(screen.getByText('12')).toBeInTheDocument()
expect(screen.getByText('20')).toBeInTheDocument()
})
it('renders even when trigger modal is hidden', () => {
render(
<TriggerEventsLimitModal
show={false}
onClose={mockOnClose}
onUpgrade={mockOnUpgrade}
usage={0}
total={0}
/>,
)
expect(planUpgradeModalMock).toHaveBeenCalled()
expect(screen.getByTestId('plan-upgrade-modal').getAttribute('data-show')).toBe('false')
})
})

View File

@ -0,0 +1,35 @@
import { render, screen } from '@testing-library/react'
import { defaultPlan } from '../config'
import AppsInfo from './apps-info'
const appsUsage = 7
const appsTotal = 15
const mockPlan = {
...defaultPlan,
usage: {
...defaultPlan.usage,
buildApps: appsUsage,
},
total: {
...defaultPlan.total,
buildApps: appsTotal,
},
}
vi.mock('@/context/provider-context', () => ({
useProviderContext: () => ({
plan: mockPlan,
}),
}))
describe('AppsInfo', () => {
it('renders build apps usage information with context data', () => {
render(<AppsInfo className="apps-info-class" />)
expect(screen.getByText('billing.usagePage.buildApps')).toBeInTheDocument()
expect(screen.getByText(`${appsUsage}`)).toBeInTheDocument()
expect(screen.getByText(`${appsTotal}`)).toBeInTheDocument()
expect(screen.getByText('billing.usagePage.buildApps').closest('.apps-info-class')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,114 @@
import { render, screen } from '@testing-library/react'
import { NUM_INFINITE } from '../config'
import UsageInfo from './index'
const TestIcon = () => <span data-testid="usage-icon" />
describe('UsageInfo', () => {
it('renders the metric with a suffix unit and tooltip text', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Apps"
usage={30}
total={100}
unit="GB"
tooltip="tooltip text"
/>,
)
expect(screen.getByTestId('usage-icon')).toBeInTheDocument()
expect(screen.getByText('Apps')).toBeInTheDocument()
expect(screen.getByText('30')).toBeInTheDocument()
expect(screen.getByText('100')).toBeInTheDocument()
expect(screen.getByText('GB')).toBeInTheDocument()
})
it('renders inline unit when unitPosition is inline', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={20}
total={100}
unit="GB"
unitPosition="inline"
/>,
)
expect(screen.getByText('100GB')).toBeInTheDocument()
})
it('shows reset hint text instead of the unit when resetHint is provided', () => {
const resetHint = 'Resets in 3 days'
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={20}
total={100}
unit="GB"
resetHint={resetHint}
/>,
)
expect(screen.getByText(resetHint)).toBeInTheDocument()
expect(screen.queryByText('GB')).not.toBeInTheDocument()
})
it('displays unlimited text when total is infinite', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={10}
total={NUM_INFINITE}
unit="GB"
/>,
)
expect(screen.getByText('billing.plansCommon.unlimited')).toBeInTheDocument()
})
it('applies warning color when usage is close to the limit', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={85}
total={100}
/>,
)
const progressBar = screen.getByTestId('billing-progress-bar')
expect(progressBar).toHaveClass('bg-components-progress-warning-progress')
})
it('applies error color when usage exceeds the limit', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={120}
total={100}
/>,
)
const progressBar = screen.getByTestId('billing-progress-bar')
expect(progressBar).toHaveClass('bg-components-progress-error-progress')
})
it('does not render the icon when hideIcon is true', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={5}
total={100}
hideIcon
/>,
)
expect(screen.queryByTestId('usage-icon')).not.toBeInTheDocument()
})
})

View File

@ -0,0 +1,58 @@
import { render, screen } from '@testing-library/react'
import VectorSpaceFull from './index'
type VectorProviderGlobal = typeof globalThis & {
__vectorProviderContext?: ReturnType<typeof vi.fn>
}
function getVectorGlobal(): VectorProviderGlobal {
return globalThis as VectorProviderGlobal
}
vi.mock('@/context/provider-context', () => {
const mock = vi.fn()
getVectorGlobal().__vectorProviderContext = mock
return {
useProviderContext: () => mock(),
}
})
vi.mock('../upgrade-btn', () => ({
__esModule: true,
default: () => <button data-testid="vector-upgrade-btn" type="button">Upgrade</button>,
}))
describe('VectorSpaceFull', () => {
const planMock = {
type: 'team',
usage: {
vectorSpace: 8,
},
total: {
vectorSpace: 10,
},
}
beforeEach(() => {
vi.clearAllMocks()
const globals = getVectorGlobal()
globals.__vectorProviderContext?.mockReturnValue({
plan: planMock,
})
})
it('renders tip text and upgrade button', () => {
render(<VectorSpaceFull />)
expect(screen.getByText('billing.vectorSpace.fullTip')).toBeInTheDocument()
expect(screen.getByText('billing.vectorSpace.fullSolution')).toBeInTheDocument()
expect(screen.getByTestId('vector-upgrade-btn')).toBeInTheDocument()
})
it('shows vector usage and total', () => {
render(<VectorSpaceFull />)
expect(screen.getByText('8')).toBeInTheDocument()
expect(screen.getByText('10MB')).toBeInTheDocument()
})
})

View File

@ -195,9 +195,11 @@ export const Workflow: FC<WorkflowProps> = memo(({
const { nodesReadOnly } = useNodesReadOnly()
const { eventEmitter } = useEventEmitterContextContext()
const store = useStoreApi()
eventEmitter?.useSubscription((v: any) => {
if (v.type === WORKFLOW_DATA_UPDATE) {
setNodes(v.payload.nodes)
store.getState().setNodes(v.payload.nodes)
setEdges(v.payload.edges)
if (v.payload.viewport)
@ -359,7 +361,6 @@ export const Workflow: FC<WorkflowProps> = memo(({
}
}, [schemaTypeDefinitions, fetchInspectVars, isLoadedVars, vars, customTools, buildInTools, workflowTools, mcpTools, dataSourceList])
const store = useStoreApi()
if (process.env.NODE_ENV === 'development') {
store.getState().onError = (code, message) => {
if (code === '002')