diff --git a/api/.importlinter b/api/.importlinter
index b676e97591..cc7ffc15c8 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -104,9 +104,7 @@ forbidden_modules =
ignore_imports =
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
- core.workflow.graph_engine.layers.observability -> configs
- core.workflow.graph_engine.layers.observability -> extensions.otel.runtime
- core.workflow.graph_engine.layers.persistence -> core.ops.ops_trace_manager
+ core.workflow.workflow_entry -> core.app.workflow.layers.observability
core.workflow.graph_engine.worker_management.worker_pool -> configs
core.workflow.nodes.agent.agent_node -> core.model_manager
core.workflow.nodes.agent.agent_node -> core.provider_manager
@@ -147,7 +145,6 @@ ignore_imports =
core.workflow.workflow_entry -> models.workflow
core.workflow.nodes.agent.agent_node -> core.agent.entities
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
- core.workflow.graph_engine.layers.persistence -> core.app.entities.app_invoke_entities
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
@@ -217,7 +214,6 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
- core.workflow.graph_engine.layers.persistence -> core.ops.entities.trace_entity
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index a9e41bffdb..dc68df3687 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -21,6 +21,7 @@ from core.app.entities.queue_entities import (
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
@@ -29,7 +30,6 @@ from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 34d02a1e51..8ea34344b2 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -9,12 +9,12 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index 9bc0275f6e..df3a096bc9 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -8,10 +8,10 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.sandbox import Sandbox
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 2ef28ffbe5..3c780a6532 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -157,7 +157,7 @@ class WorkflowBasedAppRunner:
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
@@ -272,7 +272,9 @@ class WorkflowBasedAppRunner:
)
# init graph
- graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
+ graph = Graph.init(
+ graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
+ )
if not graph:
raise ValueError("graph not found in workflow")
diff --git a/api/core/app/workflow/layers/__init__.py b/api/core/app/workflow/layers/__init__.py
new file mode 100644
index 0000000000..945f75303c
--- /dev/null
+++ b/api/core/app/workflow/layers/__init__.py
@@ -0,0 +1,10 @@
+"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
+
+from .observability import ObservabilityLayer
+from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+
+__all__ = [
+ "ObservabilityLayer",
+ "PersistenceWorkflowInfo",
+ "WorkflowPersistenceLayer",
+]
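
These two layers pull in outer infrastructure (configs, the OTel runtime, the ops trace manager), which the .importlinter contract above no longer tolerates inside core.workflow; moving them under core.app and re-exporting them here keeps call sites short. A sketch of the resulting import change for downstream code:

    # old path (removed):
    #   from core.workflow.graph_engine.layers.persistence import WorkflowPersistenceLayer
    # new path:
    from core.app.workflow.layers import ObservabilityLayer, WorkflowPersistenceLayer
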
diff --git a/api/core/workflow/graph_engine/layers/observability.py b/api/core/app/workflow/layers/observability.py
similarity index 100%
rename from api/core/workflow/graph_engine/layers/observability.py
rename to api/core/app/workflow/layers/observability.py
diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/app/workflow/layers/persistence.py
similarity index 99%
rename from api/core/workflow/graph_engine/layers/persistence.py
rename to api/core/app/workflow/layers/persistence.py
index 509478b3ee..132302efe1 100644
--- a/api/core/workflow/graph_engine/layers/persistence.py
+++ b/api/core/app/workflow/layers/persistence.py
@@ -45,7 +45,6 @@ from core.workflow.graph_events import (
from core.workflow.node_events import NodeRunResult
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
-from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@@ -319,6 +318,9 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
# workflow inputs stay reusable without binding future runs to this conversation.
continue
inputs[f"sys.{field_name}"] = value
+ # Local import to avoid circular dependency during app bootstrapping.
+ from core.workflow.workflow_entry import WorkflowEntry
+
handled = WorkflowEntry.handle_special_values(inputs)
return handled or {}
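
Deferring the WorkflowEntry import into the method body is the usual way to break this cycle: workflow_entry now imports the app-level layers at module load, so persistence can no longer import workflow_entry at module load in return. A minimal sketch of the pattern (function name is illustrative):

    def _collect_inputs(inputs: dict):
        # Deferred import: both modules are fully initialized by the time this runs.
        from core.workflow.workflow_entry import WorkflowEntry

        return WorkflowEntry.handle_special_values(inputs) or {}
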
diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py
index bd2326e84f..d95390ae1b 100644
--- a/api/core/workflow/graph/graph.py
+++ b/api/core/workflow/graph/graph.py
@@ -288,6 +288,7 @@ class Graph:
graph_config: Mapping[str, object],
node_factory: NodeFactory,
root_node_id: str | None = None,
+ skip_validation: bool = False,
) -> Graph:
"""
Initialize graph
@@ -346,8 +347,9 @@ class Graph:
root_node=root_node,
)
- # Validate the graph structure using built-in validators
- get_graph_validator().validate(graph)
+ if not skip_validation:
+ # Validate the graph structure using built-in validators
+ get_graph_validator().validate(graph)
return graph
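
The skip_validation flag exists for single-node debug runs: a graph rooted at an iteration or loop node is intentionally partial and would fail the built-in validators, as the new unit tests below demonstrate. A sketch of the intended call, assuming a node id for the node under debug:

    graph = Graph.init(
        graph_config=graph_config,
        node_factory=node_factory,
        root_node_id="iteration-node",  # assumed: the single node being debugged
        skip_validation=True,
    )
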
diff --git a/api/core/workflow/graph_engine/layers/__init__.py b/api/core/workflow/graph_engine/layers/__init__.py
index 772433e48c..0a29a52993 100644
--- a/api/core/workflow/graph_engine/layers/__init__.py
+++ b/api/core/workflow/graph_engine/layers/__init__.py
@@ -8,11 +8,9 @@ with middleware-like components that can observe events and interact with execut
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
-from .observability import ObservabilityLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
- "ObservabilityLayer",
]
diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py
index 9e4c4e6757..0aecbc8ec9 100644
--- a/api/core/workflow/runtime/variable_pool.py
+++ b/api/core/workflow/runtime/variable_pool.py
@@ -44,7 +44,7 @@ class VariablePool(BaseModel):
)
system_variables: SystemVariable = Field(
description="System variables",
- default_factory=SystemVariable.empty,
+ default_factory=SystemVariable.default,
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
@@ -309,4 +309,4 @@ class VariablePool(BaseModel):
@classmethod
def empty(cls) -> VariablePool:
"""Create an empty variable pool."""
- return cls(system_variables=SystemVariable.empty())
+ return cls(system_variables=SystemVariable.default())
diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py
index cda8091771..6946e3e6ab 100644
--- a/api/core/workflow/system_variable.py
+++ b/api/core/workflow/system_variable.py
@@ -3,6 +3,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any
+from uuid import uuid4
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
@@ -72,8 +73,8 @@ class SystemVariable(BaseModel):
return data
@classmethod
- def empty(cls) -> SystemVariable:
- return cls()
+ def default(cls) -> SystemVariable:
+ return cls(workflow_execution_id=str(uuid4()))
def to_dict(self) -> dict[SystemVariableKey, Any]:
# NOTE: This method is provided for compatibility with legacy code.
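
The empty() -> default() rename tracks a behavioral change: every freshly constructed SystemVariable now carries a generated workflow_execution_id rather than leaving the field unset. A quick sketch of the new contract:

    from core.workflow.system_variable import SystemVariable

    a, b = SystemVariable.default(), SystemVariable.default()
    assert a.workflow_execution_id is not None
    assert a.workflow_execution_id != b.workflow_execution_id  # fresh uuid4 per call
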
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index 6e286cef9b..da056b3241 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -7,6 +7,7 @@ from typing import Any
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.file.models import File
from core.sandbox import Sandbox
@@ -16,7 +17,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
-from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
+from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
@@ -281,7 +282,7 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=[],
)
diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py
index 5f98aa7b67..e6c1bc6bee 100644
--- a/api/extensions/ext_fastopenapi.py
+++ b/api/extensions/ext_fastopenapi.py
@@ -36,7 +36,7 @@ def init_app(app: DifyApp) -> None:
router.include_router(console_router, prefix="/console/api")
CORS(
app,
- resources={r"/console/api/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
+ resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
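
Flask-CORS matches resources keys as regular expressions, and in regex syntax * only repeats the preceding character, so /console/api/* describes /console/api followed by slashes; .* is what matches nested endpoint paths. Illustrated under full-match semantics:

    import re

    assert re.fullmatch(r"/console/api/.*", "/console/api/apps")
    assert not re.fullmatch(r"/console/api/*", "/console/api/apps")
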
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 2d8418900c..ccc6abcc06 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -436,7 +436,7 @@ class RagPipelineService:
user_inputs=user_inputs,
user_id=account.id,
variable_pool=VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs=user_inputs,
environment_variables=[],
conversation_variables=[],
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 7e9605f6d3..e5cd2dd7b9 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -752,7 +752,7 @@ class WorkflowService:
else:
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=[],
@@ -1160,7 +1160,7 @@ def _setup_variable_pool(
system_variable.conversation_id = conversation_id
system_variable.dialogue_count = 1
else:
- system_variable = SystemVariable.empty()
+ system_variable = SystemVariable.default()
# init variable pool
variable_pool = VariablePool(
diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py
new file mode 100644
index 0000000000..f5903d28bd
--- /dev/null
+++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py
@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.workflow.app_runner import WorkflowAppRunner
+from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.workflow.runtime import GraphRuntimeState, VariablePool
+from core.workflow.system_variable import SystemVariable
+from models.workflow import Workflow
+
+
+def _make_graph_state():
+ variable_pool = VariablePool(
+ system_variables=SystemVariable.default(),
+ user_inputs={},
+ environment_variables=[],
+ conversation_variables=[],
+ )
+ return MagicMock(), variable_pool, GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
+
+
+@pytest.mark.parametrize(
+ ("single_iteration_run", "single_loop_run"),
+ [
+ (WorkflowAppGenerateEntity.SingleIterationRunEntity(node_id="iter", inputs={}), None),
+ (None, WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id="loop", inputs={})),
+ ],
+)
+def test_run_uses_single_node_execution_branch(
+ single_iteration_run: Any,
+ single_loop_run: Any,
+) -> None:
+ app_config = MagicMock()
+ app_config.app_id = "app"
+ app_config.tenant_id = "tenant"
+ app_config.workflow_id = "workflow"
+
+ app_generate_entity = MagicMock(spec=WorkflowAppGenerateEntity)
+ app_generate_entity.app_config = app_config
+ app_generate_entity.inputs = {}
+ app_generate_entity.files = []
+ app_generate_entity.user_id = "user"
+ app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
+ app_generate_entity.workflow_execution_id = "execution-id"
+ app_generate_entity.task_id = "task-id"
+ app_generate_entity.call_depth = 0
+ app_generate_entity.trace_manager = None
+ app_generate_entity.single_iteration_run = single_iteration_run
+ app_generate_entity.single_loop_run = single_loop_run
+
+ workflow = MagicMock(spec=Workflow)
+ workflow.tenant_id = "tenant"
+ workflow.app_id = "app"
+ workflow.id = "workflow"
+ workflow.type = "workflow"
+ workflow.version = "v1"
+ workflow.graph_dict = {"nodes": [], "edges": []}
+ workflow.environment_variables = []
+
+ runner = WorkflowAppRunner(
+ application_generate_entity=app_generate_entity,
+ queue_manager=MagicMock(spec=AppQueueManager),
+ variable_loader=MagicMock(),
+ workflow=workflow,
+ system_user_id="system-user",
+ workflow_execution_repository=MagicMock(),
+ workflow_node_execution_repository=MagicMock(),
+ )
+
+ graph, variable_pool, graph_runtime_state = _make_graph_state()
+ mock_workflow_entry = MagicMock()
+ mock_workflow_entry.graph_engine = MagicMock()
+ mock_workflow_entry.graph_engine.layer = MagicMock()
+ mock_workflow_entry.run.return_value = iter([])
+
+ with (
+ patch("core.app.apps.workflow.app_runner.RedisChannel"),
+ patch("core.app.apps.workflow.app_runner.redis_client"),
+ patch("core.app.apps.workflow.app_runner.WorkflowEntry", return_value=mock_workflow_entry) as entry_class,
+ patch.object(
+ runner,
+ "_prepare_single_node_execution",
+ return_value=(
+ graph,
+ variable_pool,
+ graph_runtime_state,
+ ),
+ ) as prepare_single,
+ patch.object(runner, "_init_graph") as init_graph,
+ ):
+ runner.run()
+
+ prepare_single.assert_called_once_with(
+ workflow=workflow,
+ single_iteration_run=single_iteration_run,
+ single_loop_run=single_loop_run,
+ )
+ init_graph.assert_not_called()
+
+ entry_kwargs = entry_class.call_args.kwargs
+ assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER
+ assert entry_kwargs["variable_pool"] is variable_pool
+ assert entry_kwargs["graph_runtime_state"] is graph_runtime_state
diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py
new file mode 100644
index 0000000000..6858120335
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
+from core.workflow.entities import GraphInitParams
+from core.workflow.graph import Graph
+from core.workflow.graph.validation import GraphValidationError
+from core.workflow.nodes import NodeType
+from core.workflow.runtime import GraphRuntimeState, VariablePool
+from core.workflow.system_variable import SystemVariable
+from models.enums import UserFrom
+
+
+def _build_iteration_graph(node_id: str) -> dict[str, Any]:
+ return {
+ "nodes": [
+ {
+ "id": node_id,
+ "data": {
+ "type": "iteration",
+ "title": "Iteration",
+ "iterator_selector": ["start", "items"],
+ "output_selector": [node_id, "output"],
+ },
+ }
+ ],
+ "edges": [],
+ }
+
+
+def _build_loop_graph(node_id: str) -> dict[str, Any]:
+ return {
+ "nodes": [
+ {
+ "id": node_id,
+ "data": {
+ "type": "loop",
+ "title": "Loop",
+ "loop_count": 1,
+ "break_conditions": [],
+ "logical_operator": "and",
+ "loop_variables": [],
+ "outputs": {},
+ },
+ }
+ ],
+ "edges": [],
+ }
+
+
+def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
+ graph_init_params = GraphInitParams(
+ tenant_id="tenant",
+ app_id="app",
+ workflow_id="workflow",
+ graph_config=graph_config,
+ user_id="user",
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
+ call_depth=0,
+ )
+ graph_runtime_state = GraphRuntimeState(
+ variable_pool=VariablePool(
+ system_variables=SystemVariable.default(),
+ user_inputs={},
+ environment_variables=[],
+ ),
+ start_at=0.0,
+ )
+ return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
+
+
+def test_iteration_root_requires_skip_validation():
+ node_id = "iteration-node"
+ graph_config = _build_iteration_graph(node_id)
+ node_factory = _make_factory(graph_config)
+
+ with pytest.raises(GraphValidationError):
+ Graph.init(
+ graph_config=graph_config,
+ node_factory=node_factory,
+ root_node_id=node_id,
+ )
+
+ graph = Graph.init(
+ graph_config=graph_config,
+ node_factory=node_factory,
+ root_node_id=node_id,
+ skip_validation=True,
+ )
+
+ assert graph.root_node.id == node_id
+ assert graph.root_node.node_type == NodeType.ITERATION
+
+
+def test_loop_root_requires_skip_validation():
+ node_id = "loop-node"
+ graph_config = _build_loop_graph(node_id)
+ node_factory = _make_factory(graph_config)
+
+ with pytest.raises(GraphValidationError):
+ Graph.init(
+ graph_config=graph_config,
+ node_factory=node_factory,
+ root_node_id=node_id,
+ )
+
+ graph = Graph.init(
+ graph_config=graph_config,
+ node_factory=node_factory,
+ root_node_id=node_id,
+ skip_validation=True,
+ )
+
+ assert graph.root_node.id == node_id
+ assert graph.root_node.node_type == NodeType.LOOP
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py
index 51da3b7d73..35a234be0b 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py
@@ -90,14 +90,14 @@ def mock_tool_node():
@pytest.fixture
def mock_is_instrument_flag_enabled_false():
"""Mock is_instrument_flag_enabled to return False."""
- with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False):
+ with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False):
yield
@pytest.fixture
def mock_is_instrument_flag_enabled_true():
"""Mock is_instrument_flag_enabled to return True."""
- with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True):
+ with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True):
yield
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
index 8cc080fe94..ade846df28 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
@@ -15,14 +15,14 @@ from unittest.mock import patch
import pytest
from opentelemetry.trace import StatusCode
+from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.enums import NodeType
-from core.workflow.graph_engine.layers.observability import ObservabilityLayer
class TestObservabilityLayerInitialization:
"""Test ObservabilityLayer initialization logic."""
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter):
"""Test that layer initializes correctly when OTel is enabled."""
@@ -32,7 +32,7 @@ class TestObservabilityLayerInitialization:
assert NodeType.TOOL in layer._parsers
assert layer._default_parser is not None
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true")
def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter):
"""Test that layer enables when instrument flag is enabled."""
@@ -46,7 +46,7 @@ class TestObservabilityLayerInitialization:
class TestObservabilityLayerNodeSpanLifecycle:
"""Test node span creation and lifecycle management."""
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_span_created_and_ended(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
@@ -63,7 +63,7 @@ class TestObservabilityLayerNodeSpanLifecycle:
assert spans[0].name == mock_llm_node.title
assert spans[0].status.status_code == StatusCode.OK
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_error_recorded_in_span(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
@@ -82,7 +82,7 @@ class TestObservabilityLayerNodeSpanLifecycle:
assert len(spans[0].events) > 0
assert any("exception" in event.name.lower() for event in spans[0].events)
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_end_without_start_handled_gracefully(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
@@ -100,7 +100,7 @@ class TestObservabilityLayerNodeSpanLifecycle:
class TestObservabilityLayerParserIntegration:
"""Test parser integration for different node types."""
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_default_parser_used_for_regular_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node
@@ -119,7 +119,7 @@ class TestObservabilityLayerParserIntegration:
assert attrs["node.execution_id"] == mock_start_node.execution_id
assert attrs["node.type"] == mock_start_node.node_type.value
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_tool_parser_used_for_tool_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node
@@ -138,7 +138,7 @@ class TestObservabilityLayerParserIntegration:
assert attrs["gen_ai.tool.name"] == mock_tool_node.title
assert attrs["gen_ai.tool.type"] == mock_tool_node._node_data.provider_type.value
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_llm_parser_used_for_llm_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event
@@ -176,7 +176,7 @@ class TestObservabilityLayerParserIntegration:
assert attrs["gen_ai.completion"] == "test completion"
assert attrs["gen_ai.response.finish_reason"] == "stop"
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_retrieval_parser_used_for_retrieval_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event
@@ -204,7 +204,7 @@ class TestObservabilityLayerParserIntegration:
assert attrs["retrieval.query"] == "test query"
assert "retrieval.document" in attrs
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_result_event_extracts_inputs_and_outputs(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event
@@ -235,7 +235,7 @@ class TestObservabilityLayerParserIntegration:
class TestObservabilityLayerGraphLifecycle:
"""Test graph lifecycle management."""
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node):
"""Test that on_graph_start clears node contexts."""
@@ -248,7 +248,7 @@ class TestObservabilityLayerGraphLifecycle:
layer.on_graph_start()
assert len(layer._node_contexts) == 0
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_end_with_no_unfinished_spans(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
@@ -264,7 +264,7 @@ class TestObservabilityLayerGraphLifecycle:
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_end_with_unfinished_spans_logs_warning(
self, tracer_provider_with_memory_exporter, mock_llm_node, caplog
@@ -285,7 +285,7 @@ class TestObservabilityLayerGraphLifecycle:
class TestObservabilityLayerDisabledMode:
"""Test behavior when layer is disabled."""
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node):
"""Test that disabled layer doesn't create spans on node start."""
@@ -299,7 +299,7 @@ class TestObservabilityLayerDisabledMode:
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 0
- @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
+ @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node):
"""Test that disabled layer doesn't process node end."""
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
index 2a9db2d328..cefc4967ac 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
@@ -16,7 +16,7 @@ from core.workflow.system_variable import SystemVariable
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
@@ -69,7 +69,7 @@ def test_executor_with_json_body_and_number_variable():
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -124,7 +124,7 @@ def test_executor_with_json_body_and_object_variable():
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -178,7 +178,7 @@ def test_executor_with_json_body_and_nested_object_variable():
def test_extract_selectors_from_template_with_newline():
- variable_pool = VariablePool(system_variables=SystemVariable.empty())
+ variable_pool = VariablePool(system_variables=SystemVariable.default())
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
@@ -205,7 +205,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
@@ -290,7 +290,7 @@ def test_init_headers():
return Executor(
node_data=node_data,
timeout=timeout,
- variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+ variable_pool=VariablePool(system_variables=SystemVariable.default()),
)
executor = create_executor("aa\n cc:")
@@ -324,7 +324,7 @@ def test_init_params():
return Executor(
node_data=node_data,
timeout=timeout,
- variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+ variable_pool=VariablePool(system_variables=SystemVariable.default()),
)
# Test basic key-value pairs
@@ -355,7 +355,7 @@ def test_init_params():
def test_empty_api_key_raises_error_bearer():
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
- variable_pool = VariablePool(system_variables=SystemVariable.empty())
+ variable_pool = VariablePool(system_variables=SystemVariable.default())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -379,7 +379,7 @@ def test_empty_api_key_raises_error_bearer():
def test_empty_api_key_raises_error_basic():
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
- variable_pool = VariablePool(system_variables=SystemVariable.empty())
+ variable_pool = VariablePool(system_variables=SystemVariable.default())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -403,7 +403,7 @@ def test_empty_api_key_raises_error_basic():
def test_empty_api_key_raises_error_custom():
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
- variable_pool = VariablePool(system_variables=SystemVariable.empty())
+ variable_pool = VariablePool(system_variables=SystemVariable.default())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -427,7 +427,7 @@ def test_empty_api_key_raises_error_custom():
def test_whitespace_only_api_key_raises_error():
"""Test that whitespace-only API key raises AuthorizationConfigError."""
- variable_pool = VariablePool(system_variables=SystemVariable.empty())
+ variable_pool = VariablePool(system_variables=SystemVariable.default())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -451,7 +451,7 @@ def test_whitespace_only_api_key_raises_error():
def test_valid_api_key_works():
"""Test that valid API key works correctly for bearer auth."""
- variable_pool = VariablePool(system_variables=SystemVariable.empty())
+ variable_pool = VariablePool(system_variables=SystemVariable.default())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -487,7 +487,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable():
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "uuid"], test_uuid)
@@ -531,7 +531,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "uuid"], test_uuid)
@@ -569,7 +569,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
def test_executor_with_json_body_preserves_numbers_and_strings():
"""Test that numbers are preserved and string values are properly quoted."""
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
variable_pool.add(["node", "count"], 42)
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 77264022bc..3d1b8b2f27 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -86,7 +86,7 @@ def graph_init_params() -> GraphInitParams:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
return GraphRuntimeState(
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
index ead2334473..d8f6b41f89 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
@@ -111,7 +111,7 @@ def test_webhook_node_file_conversion_to_file_variable():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -184,7 +184,7 @@ def test_webhook_node_file_conversion_with_missing_files():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -219,7 +219,7 @@ def test_webhook_node_file_conversion_with_none_file():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -256,7 +256,7 @@ def test_webhook_node_file_conversion_with_non_dict_file():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -300,7 +300,7 @@ def test_webhook_node_file_conversion_mixed_parameters():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -370,7 +370,7 @@ def test_webhook_node_different_file_types():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -430,7 +430,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
index bbb5511923..3b5aedebca 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
@@ -75,7 +75,7 @@ def test_webhook_node_basic_initialization():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
@@ -118,7 +118,7 @@ def test_webhook_node_run_with_headers():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {
@@ -154,7 +154,7 @@ def test_webhook_node_run_with_query_params():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -190,7 +190,7 @@ def test_webhook_node_run_with_body_params():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -249,7 +249,7 @@ def test_webhook_node_run_with_file_params():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
@@ -302,7 +302,7 @@ def test_webhook_node_run_mixed_parameters():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {"Authorization": "Bearer token"},
@@ -342,7 +342,7 @@ def test_webhook_node_run_empty_webhook_data():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={}, # No webhook_data
)
@@ -368,7 +368,7 @@ def test_webhook_node_run_case_insensitive_headers():
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {
@@ -398,7 +398,7 @@ def test_webhook_node_variable_pool_user_inputs():
# Add some additional variables to the pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}},
"other_var": "should_be_included",
@@ -429,7 +429,7 @@ def test_webhook_node_different_methods(method):
)
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={
"webhook_data": {
"headers": {},
diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py
index b38e070ffc..27ffa455d6 100644
--- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py
+++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py
@@ -127,7 +127,7 @@ class TestWorkflowEntry:
return node_config
workflow = StubWorkflow()
- variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
+ variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={})
expected_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
@@ -157,7 +157,7 @@ class TestWorkflowEntry:
# Initialize variable pool with environment variables
env_var = StringVariable(name="API_KEY", value="existing_key")
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
environment_variables=[env_var],
user_inputs={},
)
@@ -198,7 +198,7 @@ class TestWorkflowEntry:
# Initialize variable pool with conversation variables
conv_var = StringVariable(name="last_message", value="Hello")
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
conversation_variables=[conv_var],
user_inputs={},
)
@@ -239,7 +239,7 @@ class TestWorkflowEntry:
"""Test mapping regular node variables from user inputs to variable pool."""
# Initialize empty variable pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
@@ -281,7 +281,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_file_handling(self):
"""Test mapping file inputs from user inputs to variable pool."""
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
@@ -340,7 +340,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_missing_variable_error(self):
"""Test that mapping raises error when required variable is missing."""
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
@@ -366,7 +366,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_alternative_key_format(self):
"""Test mapping with alternative key format (without node prefix)."""
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
@@ -396,7 +396,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_complex_selectors(self):
"""Test mapping with complex node variable keys."""
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
@@ -432,7 +432,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_invalid_node_variable(self):
"""Test that mapping handles invalid node variable format."""
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
)
diff --git a/web/app/components/app/configuration/debug/chat-user-input.spec.tsx b/web/app/components/app/configuration/debug/chat-user-input.spec.tsx
new file mode 100644
index 0000000000..e6678ebf29
--- /dev/null
+++ b/web/app/components/app/configuration/debug/chat-user-input.spec.tsx
@@ -0,0 +1,710 @@
+import type { Inputs, ModelConfig } from '@/models/debug'
+import type { PromptVariable } from '@/types/app'
+import { fireEvent, render, screen } from '@testing-library/react'
+import ChatUserInput from './chat-user-input'
+
+const mockSetInputs = vi.fn()
+const mockUseContext = vi.fn()
+
+vi.mock('react-i18next', () => ({
+ useTranslation: () => ({
+ t: (key: string) => key,
+ }),
+}))
+
+vi.mock('use-context-selector', () => ({
+ useContext: () => mockUseContext(),
+ createContext: vi.fn(() => ({})),
+}))
+
+vi.mock('@/app/components/base/input', () => ({
+  default: ({ value, onChange, placeholder, autoFocus, maxLength, readOnly, type }: {
+    value: string
+    onChange: (e: { target: { value: string } }) => void
+    placeholder?: string
+    autoFocus?: boolean
+    maxLength?: number
+    readOnly?: boolean
+    type?: string
+  }) => (
+    <input
+      data-testid={`input-${placeholder}`}
+      value={value}
+      onChange={onChange}
+      placeholder={placeholder}
+      autoFocus={autoFocus}
+      maxLength={maxLength}
+      readOnly={readOnly}
+      type={type}
+    />
+  ),
+}))
+
+vi.mock('@/app/components/base/select', () => ({
+  default: ({ defaultValue, onSelect, items, disabled, className }: {
+    defaultValue: string
+    onSelect: (item: { value: string }) => void
+    items: { name: string, value: string }[]
+    allowSearch?: boolean
+    disabled?: boolean
+    className?: string
+  }) => (
+    <select
+      data-testid="select-input"
+      defaultValue={defaultValue}
+      onChange={e => onSelect({ value: e.target.value })}
+      disabled={disabled}
+      className={className}
+    >
+      {items.map(item => (
+        <option key={item.value} value={item.value}>{item.name}</option>
+      ))}
+    </select>
+  ),
+}))
+
+vi.mock('@/app/components/base/textarea', () => ({
+  default: ({ value, onChange, placeholder, readOnly, className }: {
+    value: string
+    onChange: (e: { target: { value: string } }) => void
+    placeholder?: string
+    readOnly?: boolean
+    className?: string
+  }) => (
+    <textarea
+      data-testid={`textarea-${placeholder}`}
+      value={value}
+      onChange={onChange}
+      placeholder={placeholder}
+      readOnly={readOnly}
+      className={className}
+    />
+  ),
+}))
+
+vi.mock('@/app/components/workflow/nodes/_base/components/before-run-form/bool-input', () => ({
+  default: ({ name, value, required, onChange, readonly }: {
+    name: string
+    value: boolean
+    required?: boolean
+    onChange: (value: boolean) => void
+    readonly?: boolean
+  }) => (
+    <div data-testid={`bool-input-${name}`}>
+      <input
+        type="checkbox"
+        checked={value}
+        onChange={e => onChange(e.target.checked)}
+        disabled={readonly}
+        data-required={required}
+      />
+      {name}
+    </div>
+  ),
+}))
+
+// Extended type to match runtime behavior (includes 'paragraph', 'checkbox', 'default')
+type ExtendedPromptVariable = {
+ key: string
+ name: string
+ type: 'string' | 'number' | 'select' | 'paragraph' | 'checkbox'
+ required: boolean
+ options?: string[]
+ max_length?: number
+ default?: string | null
+}
+
+const createPromptVariable = (overrides: Partial<ExtendedPromptVariable> = {}): ExtendedPromptVariable => ({
+ key: 'test-key',
+ name: 'Test Name',
+ type: 'string',
+ required: false,
+ ...overrides,
+})
+
+const createModelConfig = (promptVariables: ExtendedPromptVariable[] = []): ModelConfig => ({
+ provider: 'openai',
+ model_id: 'gpt-4',
+ mode: 'chat',
+ configs: {
+ prompt_template: '',
+ prompt_variables: promptVariables as PromptVariable[],
+ },
+} as ModelConfig)
+
+const createContextValue = (overrides: Partial<{
+ modelConfig: ModelConfig
+ setInputs: (inputs: Inputs) => void
+ readonly: boolean
+}> = {}) => ({
+ modelConfig: createModelConfig(),
+ setInputs: mockSetInputs,
+ readonly: false,
+ ...overrides,
+})
+
+describe('ChatUserInput', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ mockUseContext.mockReturnValue(createContextValue())
+ })
+
+ describe('Rendering', () => {
+ it('should return null when no prompt variables exist', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([]),
+ }))
+
+    const { container } = render(<ChatUserInput inputs={{}} />)
+ expect(container.firstChild).toBeNull()
+ })
+
+ it('should return null when prompt variables have empty keys', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: '', name: 'Test' }),
+ createPromptVariable({ key: ' ', name: 'Test2' }),
+ ]),
+ }))
+
+    const { container } = render(<ChatUserInput inputs={{}} />)
+ expect(container.firstChild).toBeNull()
+ })
+
+ it('should return null when prompt variables have empty names', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'key1', name: '' }),
+ createPromptVariable({ key: 'key2', name: ' ' }),
+ ]),
+ }))
+
+    const { container } = render(<ChatUserInput inputs={{}} />)
+ expect(container.firstChild).toBeNull()
+ })
+
+ it('should render string input type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ expect(screen.getByTestId('input-Name')).toBeInTheDocument()
+ })
+
+ it('should render paragraph input type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'description', name: 'Description', type: 'paragraph' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ expect(screen.getByTestId('textarea-Description')).toBeInTheDocument()
+ })
+
+ it('should render select input type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: ['A', 'B', 'C'] }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ expect(screen.getByTestId('select-input')).toBeInTheDocument()
+ expect(screen.getByText('A')).toBeInTheDocument()
+ expect(screen.getByText('B')).toBeInTheDocument()
+ expect(screen.getByText('C')).toBeInTheDocument()
+ })
+
+ it('should render number input type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'count', name: 'Count', type: 'number' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ const input = screen.getByTestId('input-Count')
+ expect(input).toBeInTheDocument()
+ expect(input).toHaveAttribute('type', 'number')
+ })
+
+ it('should render checkbox input type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ expect(screen.getByTestId('bool-input-Enabled')).toBeInTheDocument()
+ })
+
+ it('should render multiple input types', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
+ createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: ['X', 'Y'] }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ expect(screen.getByTestId('input-Name')).toBeInTheDocument()
+ expect(screen.getByTestId('textarea-Description')).toBeInTheDocument()
+ expect(screen.getByTestId('select-input')).toBeInTheDocument()
+ })
+
+ it('should show optional label for non-required fields', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string', required: false }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ expect(screen.getByText('panel.optional')).toBeInTheDocument()
+ })
+
+ it('should not show optional label for required fields', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string', required: true }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ expect(screen.queryByText('panel.optional')).not.toBeInTheDocument()
+ })
+
+ it('should use key as label when name is not provided', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'myKey', name: '', type: 'string' }),
+ ]),
+ }))
+
+ // This should actually return null because name is empty
+    const { container } = render(<ChatUserInput inputs={{}} />)
+ expect(container.firstChild).toBeNull()
+ })
+ })
+
+ describe('Input Values', () => {
+ it('should display existing input values for string type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{ name: 'John' }} />)
+ expect(screen.getByTestId('input-Name')).toHaveValue('John')
+ })
+
+ it('should display existing input values for paragraph type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{ desc: 'Long text here' }} />)
+ expect(screen.getByTestId('textarea-Description')).toHaveValue('Long text here')
+ })
+
+ it('should display existing input values for number type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'count', name: 'Count', type: 'number' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{ count: '42' }} />)
+ // Number type input still uses string value internally
+ expect(screen.getByTestId('input-Count')).toHaveValue(42)
+ })
+
+ it('should display checkbox as checked when value is truthy', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{ enabled: true }} />)
+ const checkbox = screen.getByTestId('bool-input-Enabled').querySelector('input')
+ expect(checkbox).toBeChecked()
+ })
+
+ it('should display checkbox as unchecked when value is falsy', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{ enabled: false }} />)
+ const checkbox = screen.getByTestId('bool-input-Enabled').querySelector('input')
+ expect(checkbox).not.toBeChecked()
+ })
+
+ it('should handle empty string values', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{ name: '' }} />)
+ expect(screen.getByTestId('input-Name')).toHaveValue('')
+ })
+
+ it('should handle undefined values', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ expect(screen.getByTestId('input-Name')).toHaveValue('')
+ })
+ })
+
+ describe('User Interactions', () => {
+ it('should call setInputs when string input changes', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ }))
+
+    render(<ChatUserInput inputs={{}} />)
+ fireEvent.change(screen.getByTestId('input-Name'), { target: { value: 'New Value' } })
+
+ expect(mockSetInputs).toHaveBeenCalledWith({ name: 'New Value' })
+ })
+
+ it('should call setInputs when paragraph input changes', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
+ ]),
+ }))
+
+ render()
+ fireEvent.change(screen.getByTestId('textarea-Description'), { target: { value: 'New Description' } })
+
+ expect(mockSetInputs).toHaveBeenCalledWith({ desc: 'New Description' })
+ })
+
+ it('should call setInputs when select input changes', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: ['A', 'B', 'C'] }),
+ ]),
+ }))
+
+ render()
+ fireEvent.change(screen.getByTestId('select-input'), { target: { value: 'B' } })
+
+ expect(mockSetInputs).toHaveBeenCalledWith({ choice: 'B' })
+ })
+
+ it('should call setInputs when number input changes', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'count', name: 'Count', type: 'number' }),
+ ]),
+ }))
+
+ render()
+ fireEvent.change(screen.getByTestId('input-Count'), { target: { value: '100' } })
+
+ expect(mockSetInputs).toHaveBeenCalledWith({ count: '100' })
+ })
+
+ it('should call setInputs when checkbox changes', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
+ ]),
+ }))
+
+ render()
+ const checkbox = screen.getByTestId('bool-input-Enabled').querySelector('input')!
+ fireEvent.click(checkbox)
+
+ expect(mockSetInputs).toHaveBeenCalledWith({ enabled: true })
+ })
+
+ it('should not call setInputs for unknown keys', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ }))
+
+ render()
+
+ // The component only maps known keys through promptVariableObj, so an
+ // unknown key can never reach setInputs; verify indirectly with a valid key
+ fireEvent.change(screen.getByTestId('input-Name'), { target: { value: 'Valid' } })
+
+ expect(mockSetInputs).toHaveBeenCalledTimes(1)
+ expect(mockSetInputs).toHaveBeenCalledWith({ name: 'Valid' })
+ })
+ })
+
+ describe('Readonly Mode', () => {
+ it('should set string input as readonly when readonly is true', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ readonly: true,
+ }))
+
+ render()
+ expect(screen.getByTestId('input-Name')).toHaveAttribute('readonly')
+ })
+
+ it('should set paragraph input as readonly when readonly is true', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
+ ]),
+ readonly: true,
+ }))
+
+ render()
+ expect(screen.getByTestId('textarea-Description')).toHaveAttribute('readonly')
+ })
+
+ it('should disable select when readonly is true', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: ['A', 'B'] }),
+ ]),
+ readonly: true,
+ }))
+
+ render()
+ expect(screen.getByTestId('select-input')).toBeDisabled()
+ })
+
+ it('should disable checkbox when readonly is true', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
+ ]),
+ readonly: true,
+ }))
+
+ render()
+ const checkbox = screen.getByTestId('bool-input-Enabled').querySelector('input')
+ expect(checkbox).toBeDisabled()
+ })
+ })
+
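+ // Default seeding is asserted through mockSetInputs: on mount the
+ // component is expected to call setInputs once with every non-empty
+ // default, and to stay silent when a value is already present.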
+ describe('Default Values', () => {
+ it('should initialize inputs with default values when field is empty', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: 'Default Name' }),
+ ]),
+ }))
+
+ render()
+
+ expect(mockSetInputs).toHaveBeenCalledWith({ name: 'Default Name' })
+ })
+
+ it('should not override existing values with defaults', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: 'Default' }),
+ ]),
+ }))
+
+ render()
+
+ // setInputs should not be called since there's already a value
+ expect(mockSetInputs).not.toHaveBeenCalled()
+ })
+
+ it('should handle multiple default values', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: 'Default Name' }),
+ createPromptVariable({ key: 'count', name: 'Count', type: 'number', default: '10' }),
+ ]),
+ }))
+
+ render()
+
+ expect(mockSetInputs).toHaveBeenCalledWith({
+ name: 'Default Name',
+ count: '10',
+ })
+ })
+
+ it('should not set default when default is empty string', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: '' }),
+ ]),
+ }))
+
+ render()
+
+ expect(mockSetInputs).not.toHaveBeenCalled()
+ })
+
+ it('should not set default when default is undefined', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ }))
+
+ render()
+
+ expect(mockSetInputs).not.toHaveBeenCalled()
+ })
+
+ it('should not set default when default is null', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: null as unknown as string }),
+ ]),
+ }))
+
+ render()
+
+ expect(mockSetInputs).not.toHaveBeenCalled()
+ })
+ })
+
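+ // autoFocus is surfaced by the mocked input as a data-autofocus
+ // attribute, likely because jsdom does not reliably reflect real focus
+ // at render time.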
+ describe('AutoFocus', () => {
+ it('should set autoFocus on first string input', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'first', name: 'First', type: 'string' }),
+ createPromptVariable({ key: 'second', name: 'Second', type: 'string' }),
+ ]),
+ }))
+
+ render()
+ expect(screen.getByTestId('input-First')).toHaveAttribute('data-autofocus', 'true')
+ expect(screen.getByTestId('input-Second')).not.toHaveAttribute('data-autofocus')
+ })
+
+ it('should set autoFocus on first number input when it is the first field', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'count', name: 'Count', type: 'number' }),
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ ]),
+ }))
+
+ render()
+ expect(screen.getByTestId('input-Count')).toHaveAttribute('data-autofocus', 'true')
+ })
+ })
+
+ describe('MaxLength', () => {
+ it('should pass maxLength to string input', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string', max_length: 50 }),
+ ]),
+ }))
+
+ render()
+ expect(screen.getByTestId('input-Name')).toHaveAttribute('maxLength', '50')
+ })
+
+ it('should pass maxLength to number input', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'count', name: 'Count', type: 'number', max_length: 10 }),
+ ]),
+ }))
+
+ render()
+ expect(screen.getByTestId('input-Count')).toHaveAttribute('maxLength', '10')
+ })
+ })
+
+ describe('Edge Cases', () => {
+ it('should handle select with empty options', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: [] }),
+ ]),
+ }))
+
+ render()
+ const select = screen.getByTestId('select-input')
+ expect(select).toBeInTheDocument()
+ expect(select.children).toHaveLength(0)
+ })
+
+ it('should handle select with undefined options', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'choice', name: 'Choice', type: 'select' }),
+ ]),
+ }))
+
+ render()
+ const select = screen.getByTestId('select-input')
+ expect(select).toBeInTheDocument()
+ })
+
+ it('should preserve other input values when updating one field', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
+ createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
+ ]),
+ }))
+
+ render()
+ fireEvent.change(screen.getByTestId('input-Name'), { target: { value: 'Updated' } })
+
+ expect(mockSetInputs).toHaveBeenCalledWith({
+ name: 'Updated',
+ desc: 'Also Existing',
+ })
+ })
+
+ it('should convert non-string values to string for display', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'value', name: 'Value', type: 'string' }),
+ ]),
+ }))
+
+ render()
+ expect(screen.getByTestId('input-Value')).toHaveValue('123')
+ })
+
+ it('should not hide label for checkbox type', () => {
+ mockUseContext.mockReturnValue(createContextValue({
+ modelConfig: createModelConfig([
+ createPromptVariable({ key: 'enabled', name: 'Is Enabled', type: 'checkbox' }),
+ ]),
+ }))
+
+ render()
+ // For checkbox, the label is rendered inside BoolInput, not in the header
+ expect(screen.queryByText('Is Enabled')).toBeInTheDocument()
+ })
+ })
+})
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx
new file mode 100644
index 0000000000..d621bb3941
--- /dev/null
+++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx
@@ -0,0 +1,641 @@
+import type { ModelAndParameter } from '../types'
+import type { ChatConfig, ChatItem as ChatItemType, OnSend } from '@/app/components/base/chat/types'
+import { render, screen } from '@testing-library/react'
+import { TransferMethod } from '@/app/components/base/chat/types'
+import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import { APP_CHAT_WITH_MULTIPLE_MODEL, APP_CHAT_WITH_MULTIPLE_MODEL_RESTART } from '../types'
+import ChatItem from './chat-item'
+
+const mockUseAppContext = vi.fn()
+const mockUseDebugConfigurationContext = vi.fn()
+const mockUseProviderContext = vi.fn()
+const mockUseFeatures = vi.fn()
+const mockUseConfigFromDebugContext = vi.fn()
+const mockUseFormattingChangedSubscription = vi.fn()
+const mockUseChat = vi.fn()
+const mockUseEventEmitterContextContext = vi.fn()
+const mockFetchConversationMessages = vi.fn()
+const mockFetchSuggestedQuestions = vi.fn()
+const mockStopChatMessageResponding = vi.fn()
+
+let capturedChatProps: {
+ config: ChatConfig
+ chatList: ChatItemType[]
+ isResponding: boolean
+ onSend: OnSend
+ suggestedQuestions: string[]
+ allToolIcons: Record<string, any>
+} | null = null
+
+let eventSubscriptionCallback: ((v: { type: string, payload?: Record<string, any> }) => void) | null = null
+
+vi.mock('@/context/app-context', () => ({
+ useAppContext: () => mockUseAppContext(),
+}))
+
+vi.mock('@/context/debug-configuration', () => ({
+ useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
+}))
+
+vi.mock('@/context/provider-context', () => ({
+ useProviderContext: () => mockUseProviderContext(),
+}))
+
+vi.mock('@/app/components/base/features/hooks', () => ({
+ useFeatures: (selector: (state: Record<string, any>) => unknown) => mockUseFeatures(selector),
+}))
+
+vi.mock('../hooks', () => ({
+ useConfigFromDebugContext: () => mockUseConfigFromDebugContext(),
+ useFormattingChangedSubscription: (chatList: ChatItemType[]) => mockUseFormattingChangedSubscription(chatList),
+}))
+
+vi.mock('@/app/components/base/chat/chat/hooks', () => ({
+ useChat: () => mockUseChat(),
+}))
+
+vi.mock('@/context/event-emitter', () => ({
+ useEventEmitterContextContext: () => mockUseEventEmitterContextContext(),
+}))
+
+vi.mock('@/service/debug', () => ({
+ fetchConversationMessages: (...args: unknown[]) => mockFetchConversationMessages(...args),
+ fetchSuggestedQuestions: (...args: unknown[]) => mockFetchSuggestedQuestions(...args),
+ stopChatMessageResponding: (...args: unknown[]) => mockStopChatMessageResponding(...args),
+}))
+
+vi.mock('@/app/components/base/chat/utils', () => ({
+ getLastAnswer: (chatList: ChatItemType[]) => chatList.find(item => item.isAnswer),
+}))
+
+vi.mock('@/utils', () => ({
+ canFindTool: (collectionId: string, providerId: string) => collectionId === providerId,
+}))
+
+vi.mock('@/app/components/base/chat/chat', () => ({
+ default: (props: typeof capturedChatProps) => {
+ capturedChatProps = props
+ return (
+ <div data-testid="chat-component">
+ <div data-testid="chat-list-length">{props?.chatList?.length || 0}</div>
+ <div data-testid="is-responding">{props?.isResponding ? 'yes' : 'no'}</div>
+ </div>
+ )
+ },
+}))
+
+vi.mock('@/app/components/base/avatar', () => ({
+ default: ({ name }: { name: string }) => <div>{name}</div>,
+}))
+
+const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
+ id: 'model-1',
+ model: 'gpt-3.5-turbo',
+ provider: 'openai',
+ parameters: { temperature: 0.7 },
+ ...overrides,
+})
+
+const createDefaultMocks = () => {
+ mockUseAppContext.mockReturnValue({
+ userProfile: { avatar_url: 'http://avatar.url', name: 'Test User' },
+ })
+
+ mockUseDebugConfigurationContext.mockReturnValue({
+ modelConfig: {
+ configs: { prompt_variables: [] },
+ agentConfig: { tools: [] },
+ },
+ appId: 'app-123',
+ inputs: { key: 'value' },
+ collectionList: [],
+ })
+
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: [
+ {
+ provider: 'openai',
+ models: [
+ {
+ model: 'gpt-3.5-turbo',
+ features: [ModelFeatureEnum.vision],
+ model_properties: { mode: 'chat' },
+ },
+ ],
+ },
+ ],
+ })
+
+ mockUseFeatures.mockImplementation((selector: (state: Record<string, any>) => unknown) => {
+ const state = {
+ features: {
+ moreLikeThis: { enabled: false },
+ opening: { enabled: true, opening_statement: 'Hello!', suggested_questions: ['Q1'] },
+ moderation: { enabled: false },
+ speech2text: { enabled: true },
+ text2speech: { enabled: false },
+ file: { enabled: true },
+ suggested: { enabled: true },
+ citation: { enabled: false },
+ annotationReply: { enabled: false },
+ },
+ }
+ return selector(state)
+ })
+
+ mockUseConfigFromDebugContext.mockReturnValue({
+ base_config: 'test',
+ })
+
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1', content: 'Hello', isAnswer: true }],
+ isResponding: false,
+ handleSend: vi.fn(),
+ suggestedQuestions: ['Question 1', 'Question 2'],
+ handleRestart: vi.fn(),
+ })
+
+ mockUseEventEmitterContextContext.mockReturnValue({
+ eventEmitter: {
+ useSubscription: (callback: (v: { type: string, payload?: Record<string, any> }) => void) => {
+ eventSubscriptionCallback = callback
+ },
+ },
+ })
+}
+
+const renderComponent = (props: Partial<{ modelAndParameter: ModelAndParameter }> = {}) => {
+ const defaultProps = {
+ modelAndParameter: createModelAndParameter(),
+ ...props,
+ }
+ return render(<ChatItem {...defaultProps} />)
+}
+
+describe('ChatItem', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ capturedChatProps = null
+ eventSubscriptionCallback = null
+ createDefaultMocks()
+ })
+
+ describe('rendering', () => {
+ it('should render Chat component when chatList is not empty', () => {
+ renderComponent()
+
+ expect(screen.getByTestId('chat-component')).toBeInTheDocument()
+ expect(screen.getByTestId('chat-list-length')).toHaveTextContent('1')
+ })
+
+ it('should not render when chatList is empty', () => {
+ mockUseChat.mockReturnValue({
+ chatList: [],
+ isResponding: false,
+ handleSend: vi.fn(),
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ renderComponent()
+
+ expect(screen.queryByTestId('chat-component')).not.toBeInTheDocument()
+ })
+
+ it('should pass correct config to Chat', () => {
+ renderComponent()
+
+ expect(capturedChatProps?.config).toMatchObject({
+ base_config: 'test',
+ opening_statement: 'Hello!',
+ suggested_questions: ['Q1'],
+ })
+ })
+
+ it('should pass suggestedQuestions to Chat', () => {
+ renderComponent()
+
+ expect(capturedChatProps?.suggestedQuestions).toEqual(['Question 1', 'Question 2'])
+ })
+
+ it('should pass isResponding to Chat', () => {
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1' }],
+ isResponding: true,
+ handleSend: vi.fn(),
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('is-responding')).toHaveTextContent('yes')
+ })
+ })
+
+ describe('config composition', () => {
+ it('should include opening statement when enabled', () => {
+ renderComponent()
+
+ expect(capturedChatProps?.config.opening_statement).toBe('Hello!')
+ })
+
+ it('should use empty opening statement when disabled', () => {
+ mockUseFeatures.mockImplementation((selector: (state: Record<string, any>) => unknown) => {
+ const state = {
+ features: {
+ moreLikeThis: { enabled: false },
+ opening: { enabled: false, opening_statement: 'Should not appear' },
+ moderation: { enabled: false },
+ speech2text: { enabled: false },
+ text2speech: { enabled: false },
+ file: { enabled: false },
+ suggested: { enabled: false },
+ citation: { enabled: false },
+ annotationReply: { enabled: false },
+ },
+ }
+ return selector(state)
+ })
+
+ renderComponent()
+
+ expect(capturedChatProps?.config.opening_statement).toBe('')
+ expect(capturedChatProps?.config.suggested_questions).toEqual([])
+ })
+ })
+
+ describe('inputsForm transformation', () => {
+ it('should filter out API type variables', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ modelConfig: {
+ configs: {
+ prompt_variables: [
+ { key: 'var1', name: 'Var 1', type: 'string' },
+ { key: 'var2', name: 'Var 2', type: 'api' },
+ { key: 'var3', name: 'Var 3', type: 'number' },
+ ],
+ },
+ agentConfig: { tools: [] },
+ },
+ appId: 'app-123',
+ inputs: {},
+ collectionList: [],
+ })
+
+ renderComponent()
+
+ // The component transforms prompt_variables into inputsForm
+ // inputsForm itself is internal, so we only assert the hook ran
+ expect(mockUseChat).toHaveBeenCalled()
+ })
+ })
+
+ describe('event subscription', () => {
+ it('should handle APP_CHAT_WITH_MULTIPLE_MODEL event', () => {
+ const handleSend = vi.fn()
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1' }],
+ isResponding: false,
+ handleSend,
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ renderComponent()
+
+ // Trigger the event
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Hello', files: [{ id: 'file-1' }] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ 'apps/app-123/chat-messages',
+ expect.objectContaining({
+ query: 'Hello',
+ inputs: { key: 'value' },
+ }),
+ expect.any(Object),
+ )
+ })
+
+ it('should handle APP_CHAT_WITH_MULTIPLE_MODEL_RESTART event', () => {
+ const handleRestart = vi.fn()
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1' }],
+ isResponding: false,
+ handleSend: vi.fn(),
+ suggestedQuestions: [],
+ handleRestart,
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL_RESTART,
+ })
+
+ expect(handleRestart).toHaveBeenCalled()
+ })
+
+ it('should ignore unrelated events', () => {
+ const handleSend = vi.fn()
+ const handleRestart = vi.fn()
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1' }],
+ isResponding: false,
+ handleSend,
+ suggestedQuestions: [],
+ handleRestart,
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: 'SOME_OTHER_EVENT',
+ payload: {},
+ })
+
+ expect(handleSend).not.toHaveBeenCalled()
+ expect(handleRestart).not.toHaveBeenCalled()
+ })
+ })
+
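+ // doSend is not exported, so these tests drive it by publishing
+ // APP_CHAT_WITH_MULTIPLE_MODEL through the captured subscription
+ // callback, the same path the event emitter takes at runtime.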
+ describe('doSend function', () => {
+ it('should include files when vision is supported and file upload enabled', () => {
+ const handleSend = vi.fn()
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1' }],
+ isResponding: false,
+ handleSend,
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Hello', files: [{ id: 'file-1' }] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.objectContaining({
+ files: [{ id: 'file-1' }],
+ }),
+ expect.any(Object),
+ )
+ })
+
+ it('should not include files when vision is not supported', () => {
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: [
+ {
+ provider: 'openai',
+ models: [
+ {
+ model: 'gpt-3.5-turbo',
+ features: [], // No vision support
+ model_properties: { mode: 'chat' },
+ },
+ ],
+ },
+ ],
+ })
+
+ const handleSend = vi.fn()
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1' }],
+ isResponding: false,
+ handleSend,
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Hello', files: [{ id: 'file-1' }] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.not.objectContaining({
+ files: expect.anything(),
+ }),
+ expect.any(Object),
+ )
+ })
+
+ it('should include model configuration in request', () => {
+ const handleSend = vi.fn()
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1' }],
+ isResponding: false,
+ handleSend,
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ const modelAndParameter = createModelAndParameter({
+ provider: 'openai',
+ model: 'gpt-3.5-turbo',
+ parameters: { temperature: 0.5 },
+ })
+
+ renderComponent({ modelAndParameter })
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Hello', files: [] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.objectContaining({
+ model_config: expect.objectContaining({
+ model: expect.objectContaining({
+ provider: 'openai',
+ name: 'gpt-3.5-turbo',
+ completion_params: { temperature: 0.5 },
+ }),
+ }),
+ }),
+ expect.any(Object),
+ )
+ })
+
+ it('should use parent_message_id from last answer', () => {
+ const handleSend = vi.fn()
+ mockUseChat.mockReturnValue({
+ chatList: [
+ { id: 'msg-1', content: 'Hi', isAnswer: false },
+ { id: 'msg-2', content: 'Hello', isAnswer: true },
+ ],
+ isResponding: false,
+ handleSend,
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Hello', files: [] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.objectContaining({
+ parent_message_id: 'msg-2',
+ }),
+ expect.any(Object),
+ )
+ })
+ })
+
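+ // allToolIcons pairs each agent tool with the icon of its matching
+ // collection; canFindTool is mocked above as plain id equality, so a
+ // tool without a matching collection maps to undefined.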
+ describe('allToolIcons', () => {
+ it('should compute tool icons from collectionList', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ modelConfig: {
+ configs: { prompt_variables: [] },
+ agentConfig: {
+ tools: [
+ { tool_name: 'tool1', provider_id: 'collection1' },
+ { tool_name: 'tool2', provider_id: 'collection2' },
+ ],
+ },
+ },
+ appId: 'app-123',
+ inputs: {},
+ collectionList: [
+ { id: 'collection1', icon: 'icon1' },
+ { id: 'collection2', icon: 'icon2' },
+ ],
+ })
+
+ renderComponent()
+
+ expect(capturedChatProps?.allToolIcons).toEqual({
+ tool1: 'icon1',
+ tool2: 'icon2',
+ })
+ })
+
+ it('should handle tools without matching collection', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ modelConfig: {
+ configs: { prompt_variables: [] },
+ agentConfig: {
+ tools: [
+ { tool_name: 'tool1', provider_id: 'nonexistent' },
+ ],
+ },
+ },
+ appId: 'app-123',
+ inputs: {},
+ collectionList: [],
+ })
+
+ renderComponent()
+
+ expect(capturedChatProps?.allToolIcons).toEqual({
+ tool1: undefined,
+ })
+ })
+
+ it('should handle empty tools array', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ modelConfig: {
+ configs: { prompt_variables: [] },
+ agentConfig: { tools: [] },
+ },
+ appId: 'app-123',
+ inputs: {},
+ collectionList: [],
+ })
+
+ renderComponent()
+
+ expect(capturedChatProps?.allToolIcons).toEqual({})
+ })
+ })
+
+ describe('useFormattingChangedSubscription', () => {
+ it('should call useFormattingChangedSubscription with chatList', () => {
+ const chatList = [{ id: 'msg-1', content: 'Hello' }]
+ mockUseChat.mockReturnValue({
+ chatList,
+ isResponding: false,
+ handleSend: vi.fn(),
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ renderComponent()
+
+ expect(mockUseFormattingChangedSubscription).toHaveBeenCalledWith(chatList)
+ })
+ })
+
+ describe('edge cases', () => {
+ it('should handle missing provider in textGenerationModelList', () => {
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: [],
+ })
+
+ const handleSend = vi.fn()
+ mockUseChat.mockReturnValue({
+ chatList: [{ id: 'msg-1' }],
+ isResponding: false,
+ handleSend,
+ suggestedQuestions: [],
+ handleRestart: vi.fn(),
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Hello', files: [] },
+ })
+
+ // Should still call handleSend without crashing
+ expect(handleSend).toHaveBeenCalled()
+ })
+
+ it('should handle null eventEmitter', () => {
+ mockUseEventEmitterContextContext.mockReturnValue({
+ eventEmitter: null,
+ })
+
+ expect(() => renderComponent()).not.toThrow()
+ })
+
+ it('should handle undefined tools in agentConfig', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ modelConfig: {
+ configs: { prompt_variables: [] },
+ agentConfig: { tools: undefined },
+ },
+ appId: 'app-123',
+ inputs: {},
+ collectionList: [],
+ })
+
+ // agentConfig.tools?.forEach uses optional chaining, so undefined
+ // tools should render without throwing
+ expect(() => renderComponent()).not.toThrow()
+ })
+ })
+})
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx
new file mode 100644
index 0000000000..e26fcec607
--- /dev/null
+++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx
@@ -0,0 +1,224 @@
+import type { ModelAndParameter } from '../types'
+import type { DebugWithMultipleModelContextType } from './context'
+import { render, screen } from '@testing-library/react'
+import {
+ DebugWithMultipleModelContextProvider,
+ useDebugWithMultipleModelContext,
+} from './context'
+
+const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
+ id: 'model-1',
+ model: 'gpt-3.5-turbo',
+ provider: 'openai',
+ parameters: {},
+ ...overrides,
+})
+
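+// TestConsumer surfaces the context values as text nodes and wires the
+// context callbacks to buttons, so tests can read and invoke everything
+// through the DOM.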
+const TestConsumer = () => {
+ const context = useDebugWithMultipleModelContext()
+ return (
+ <div>
+ <div data-testid="configs-count">{context.multipleModelConfigs.length}</div>
+ <div data-testid="has-check-can-send">{context.checkCanSend ? 'yes' : 'no'}</div>
+ <button data-testid="call-on-change" onClick={() => context.onMultipleModelConfigsChange(true, [])}>change</button>
+ <button data-testid="call-on-debug-change" onClick={() => context.onDebugWithMultipleModelChange(createModelAndParameter())}>debug</button>
+ </div>
+ )
+}
+
+describe('DebugWithMultipleModelContext', () => {
+ describe('useDebugWithMultipleModelContext', () => {
+ it('should return default values when used outside provider', () => {
+ render(<TestConsumer />)
+
+ expect(screen.getByTestId('configs-count')).toHaveTextContent('0')
+ expect(screen.getByTestId('has-check-can-send')).toHaveTextContent('no')
+ })
+
+ it('should return default noop functions that do not throw', () => {
+ render(<TestConsumer />)
+
+ // These should not throw when called
+ expect(() => {
+ screen.getByTestId('call-on-change').click()
+ }).not.toThrow()
+
+ expect(() => {
+ screen.getByTestId('call-on-debug-change').click()
+ }).not.toThrow()
+ })
+ })
+
+ describe('DebugWithMultipleModelContextProvider', () => {
+ it('should provide multipleModelConfigs to children', () => {
+ const multipleModelConfigs = [
+ createModelAndParameter({ id: 'model-1' }),
+ createModelAndParameter({ id: 'model-2' }),
+ ]
+
+ render(
+ <DebugWithMultipleModelContextProvider multipleModelConfigs={multipleModelConfigs}>
+ <TestConsumer />
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ expect(screen.getByTestId('configs-count')).toHaveTextContent('2')
+ })
+
+ it('should provide checkCanSend function to children', () => {
+ const checkCanSend = vi.fn(() => true)
+
+ render(
+ <DebugWithMultipleModelContextProvider multipleModelConfigs={[]} checkCanSend={checkCanSend}>
+ <TestConsumer />
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ expect(screen.getByTestId('has-check-can-send')).toHaveTextContent('yes')
+ })
+
+ it('should call onMultipleModelConfigsChange when invoked from context', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+
+ render(
+ <DebugWithMultipleModelContextProvider multipleModelConfigs={[]} onMultipleModelConfigsChange={onMultipleModelConfigsChange}>
+ <TestConsumer />
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ screen.getByTestId('call-on-change').click()
+
+ expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [])
+ })
+
+ it('should call onDebugWithMultipleModelChange when invoked from context', () => {
+ const onDebugWithMultipleModelChange = vi.fn()
+
+ render(
+ <DebugWithMultipleModelContextProvider multipleModelConfigs={[]} onDebugWithMultipleModelChange={onDebugWithMultipleModelChange}>
+ <TestConsumer />
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ screen.getByTestId('call-on-debug-change').click()
+
+ expect(onDebugWithMultipleModelChange).toHaveBeenCalledWith(
+ expect.objectContaining({ id: 'model-1' }),
+ )
+ })
+
+ it('should handle undefined checkCanSend', () => {
+ render(
+ <DebugWithMultipleModelContextProvider multipleModelConfigs={[]}>
+ <TestConsumer />
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ expect(screen.getByTestId('has-check-can-send')).toHaveTextContent('no')
+ })
+
+ it('should render children correctly', () => {
+ render(
+ <DebugWithMultipleModelContextProvider multipleModelConfigs={[]}>
+ <div data-testid="child-element">Child Content</div>
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ expect(screen.getByTestId('child-element')).toHaveTextContent('Child Content')
+ })
+
+ it('should update context when props change', () => {
+ const { rerender } = render(
+ <DebugWithMultipleModelContextProvider multipleModelConfigs={[createModelAndParameter({ id: '1' })]}>
+ <TestConsumer />
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ expect(screen.getByTestId('configs-count')).toHaveTextContent('1')
+
+ rerender(
+ <DebugWithMultipleModelContextProvider multipleModelConfigs={[createModelAndParameter({ id: '1' }), createModelAndParameter({ id: '2' })]}>
+ <TestConsumer />
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ expect(screen.getByTestId('configs-count')).toHaveTextContent('2')
+ })
+
+ it('should pass all context values correctly', () => {
+ const contextValues: DebugWithMultipleModelContextType = {
+ multipleModelConfigs: [createModelAndParameter()],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ checkCanSend: () => true,
+ }
+
+ const FullTestConsumer = () => {
+ const context = useDebugWithMultipleModelContext()
+ return (
+ <div>
+ <div data-testid="configs">{JSON.stringify(context.multipleModelConfigs)}</div>
+ <div data-testid="has-on-change">{typeof context.onMultipleModelConfigsChange}</div>
+ <div data-testid="has-on-debug-change">{typeof context.onDebugWithMultipleModelChange}</div>
+ <div data-testid="has-check">{typeof context.checkCanSend}</div>
+ </div>
+ )
+ }
+
+ render(
+ <DebugWithMultipleModelContextProvider {...contextValues}>
+ <FullTestConsumer />
+ </DebugWithMultipleModelContextProvider>,
+ )
+
+ expect(screen.getByTestId('configs')).toHaveTextContent('model-1')
+ expect(screen.getByTestId('has-on-change')).toHaveTextContent('function')
+ expect(screen.getByTestId('has-on-debug-change')).toHaveTextContent('function')
+ expect(screen.getByTestId('has-check')).toHaveTextContent('function')
+ })
+ })
+})
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.spec.tsx
new file mode 100644
index 0000000000..efc477fb47
--- /dev/null
+++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.spec.tsx
@@ -0,0 +1,552 @@
+import type { CSSProperties } from 'react'
+import type { ModelAndParameter } from '../types'
+import type { Item } from '@/app/components/base/dropdown'
+import { fireEvent, render, screen } from '@testing-library/react'
+import { ModelStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import { AppModeEnum } from '@/types/app'
+import DebugItem from './debug-item'
+
+const mockUseDebugConfigurationContext = vi.fn()
+const mockUseDebugWithMultipleModelContext = vi.fn()
+const mockUseProviderContext = vi.fn()
+
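+// The Dropdown and ModelParameterTrigger mocks below record their props
+// here so tests can assert on the menu items and trigger wiring; both
+// are reset in beforeEach.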
+let capturedDropdownProps: {
+ onSelect: (item: Item) => void
+ items: Item[]
+ secondItems?: Item[]
+} | null = null
+
+let capturedModelParameterTriggerProps: {
+ modelAndParameter: ModelAndParameter
+} | null = null
+
+vi.mock('@/context/debug-configuration', () => ({
+ useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
+}))
+
+vi.mock('./context', () => ({
+ useDebugWithMultipleModelContext: () => mockUseDebugWithMultipleModelContext(),
+}))
+
+vi.mock('@/context/provider-context', () => ({
+ useProviderContext: () => mockUseProviderContext(),
+}))
+
+vi.mock('./chat-item', () => ({
+ default: ({ modelAndParameter }: { modelAndParameter: ModelAndParameter }) => (
+ <div data-testid="chat-item" data-model-id={modelAndParameter.id}>ChatItem</div>
+ ),
+}))
+
+vi.mock('./text-generation-item', () => ({
+ default: ({ modelAndParameter }: { modelAndParameter: ModelAndParameter }) => (
+ <div data-testid="text-generation-item" data-model-id={modelAndParameter.id}>TextGenerationItem</div>
+ ),
+}))
+
+vi.mock('./model-parameter-trigger', () => ({
+ default: (props: { modelAndParameter: ModelAndParameter }) => {
+ capturedModelParameterTriggerProps = props
+ return <div data-testid="model-parameter-trigger">ModelParameterTrigger</div>
+ },
+}))
+
+vi.mock('@/app/components/base/dropdown', () => ({
+ default: (props: { onSelect: (item: Item) => void, items: Item[], secondItems?: Item[] }) => {
+ capturedDropdownProps = props
+ return (
+ <div data-testid="dropdown">
+ {props.items.map(item => (
+ <button key={item.value} data-testid={`dropdown-item-${item.value}`} onClick={() => props.onSelect(item)}>{item.text}</button>
+ ))}
+ {props.secondItems?.map(item => (
+ <button key={item.value} data-testid={`dropdown-second-item-${item.value}`} onClick={() => props.onSelect(item)}>{item.text}</button>
+ ))}
+ </div>
+ )
+ },
+}))
+
+const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
+ id: 'model-1',
+ model: 'gpt-3.5-turbo',
+ provider: 'openai',
+ parameters: {},
+ ...overrides,
+})
+
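+// Builds the provider-grouped shape that useProviderContext returns,
+// defaulting every model to active status.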
+const createTextGenerationModelList = (models: Array<{ provider: string, model: string, status?: ModelStatusEnum }> = []) => {
+ const providers: Record<string, { provider: string, models: Array<{ model: string, status: ModelStatusEnum }> }> = {}
+
+ models.forEach(({ provider, model, status = ModelStatusEnum.active }) => {
+ if (!providers[provider]) {
+ providers[provider] = { provider, models: [] }
+ }
+ providers[provider].models.push({ model, status })
+ })
+
+ return Object.values(providers)
+}
+
+type DebugItemProps = {
+ modelAndParameter: ModelAndParameter
+ className?: string
+ style?: CSSProperties
+}
+
+const renderComponent = (props: Partial = {}) => {
+ const defaultProps: DebugItemProps = {
+ modelAndParameter: createModelAndParameter(),
+ ...props,
+ }
+ return render(<DebugItem {...defaultProps} />)
+}
+
+describe('DebugItem', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ capturedDropdownProps = null
+ capturedModelParameterTriggerProps = null
+
+ mockUseDebugConfigurationContext.mockReturnValue({
+ mode: AppModeEnum.CHAT,
+ })
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [createModelAndParameter()],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'openai', model: 'gpt-3.5-turbo' },
+ ]),
+ })
+ })
+
+ describe('rendering', () => {
+ it('should render with basic props', () => {
+ renderComponent()
+
+ expect(screen.getByTestId('model-parameter-trigger')).toBeInTheDocument()
+ expect(screen.getByTestId('dropdown')).toBeInTheDocument()
+ })
+
+ it('should display correct index number', () => {
+ const modelConfigs = [
+ createModelAndParameter({ id: 'model-1' }),
+ createModelAndParameter({ id: 'model-2' }),
+ ]
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: modelConfigs,
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ const { container } = renderComponent({ modelAndParameter: createModelAndParameter({ id: 'model-2' }) })
+
+ // The index is displayed as "#2" in the component
+ const indexElement = container.querySelector('.font-medium.italic')
+ expect(indexElement?.textContent?.trim()).toContain('2')
+ })
+
+ it('should apply className and style props', () => {
+ const { container } = renderComponent({
+ className: 'custom-class',
+ style: { backgroundColor: 'red' },
+ })
+
+ const wrapper = container.firstChild as HTMLElement
+ expect(wrapper).toHaveClass('custom-class')
+ expect(wrapper.style.backgroundColor).toBe('red')
+ })
+
+ it('should pass modelAndParameter to ModelParameterTrigger', () => {
+ const modelAndParameter = createModelAndParameter({ id: 'test-model' })
+ renderComponent({ modelAndParameter })
+
+ expect(capturedModelParameterTriggerProps?.modelAndParameter).toEqual(modelAndParameter)
+ })
+ })
+
+ describe('ChatItem rendering', () => {
+ it('should render ChatItem in CHAT mode with active model', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'openai', model: 'gpt-3.5-turbo', status: ModelStatusEnum.active },
+ ]),
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('chat-item')).toBeInTheDocument()
+ expect(screen.queryByTestId('text-generation-item')).not.toBeInTheDocument()
+ })
+
+ it('should render ChatItem in AGENT_CHAT mode with active model', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.AGENT_CHAT })
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'openai', model: 'gpt-3.5-turbo', status: ModelStatusEnum.active },
+ ]),
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('chat-item')).toBeInTheDocument()
+ })
+
+ it('should not render ChatItem when model is not active', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'openai', model: 'gpt-3.5-turbo', status: ModelStatusEnum.disabled },
+ ]),
+ })
+
+ renderComponent()
+
+ expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
+ })
+
+ it('should not render ChatItem when provider not found', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'anthropic', model: 'claude-3', status: ModelStatusEnum.active },
+ ]),
+ })
+
+ renderComponent()
+
+ expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
+ })
+
+ it('should not render ChatItem when model not found', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'openai', model: 'gpt-4', status: ModelStatusEnum.active },
+ ]),
+ })
+
+ renderComponent()
+
+ expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
+ })
+ })
+
+ describe('TextGenerationItem rendering', () => {
+ it('should render TextGenerationItem in COMPLETION mode with active model', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.COMPLETION })
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'openai', model: 'gpt-3.5-turbo', status: ModelStatusEnum.active },
+ ]),
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('text-generation-item')).toBeInTheDocument()
+ expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
+ })
+
+ it('should not render TextGenerationItem when provider is not found', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.COMPLETION })
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'anthropic', model: 'claude-3', status: ModelStatusEnum.active },
+ ]),
+ })
+
+ renderComponent()
+
+ expect(screen.queryByTestId('text-generation-item')).not.toBeInTheDocument()
+ })
+ })
+
+ describe('dropdown menu', () => {
+ it('should show duplicate option when less than 4 models', () => {
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [createModelAndParameter()],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent()
+
+ expect(capturedDropdownProps?.items).toContainEqual(
+ expect.objectContaining({ value: 'duplicate' }),
+ )
+ })
+
+ it('should hide duplicate option when 4 or more models', () => {
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [
+ createModelAndParameter({ id: '1' }),
+ createModelAndParameter({ id: '2' }),
+ createModelAndParameter({ id: '3' }),
+ createModelAndParameter({ id: '4' }),
+ ],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent()
+
+ expect(capturedDropdownProps?.items).not.toContainEqual(
+ expect.objectContaining({ value: 'duplicate' }),
+ )
+ })
+
+ it('should show debug-as-single-model option when provider and model are set', () => {
+ renderComponent({
+ modelAndParameter: createModelAndParameter({
+ provider: 'openai',
+ model: 'gpt-3.5-turbo',
+ }),
+ })
+
+ expect(capturedDropdownProps?.items).toContainEqual(
+ expect.objectContaining({ value: 'debug-as-single-model' }),
+ )
+ })
+
+ it('should hide debug-as-single-model option when provider is missing', () => {
+ renderComponent({
+ modelAndParameter: createModelAndParameter({
+ provider: '',
+ model: 'gpt-3.5-turbo',
+ }),
+ })
+
+ expect(capturedDropdownProps?.items).not.toContainEqual(
+ expect.objectContaining({ value: 'debug-as-single-model' }),
+ )
+ })
+
+ it('should hide debug-as-single-model option when model is missing', () => {
+ renderComponent({
+ modelAndParameter: createModelAndParameter({
+ provider: 'openai',
+ model: '',
+ }),
+ })
+
+ expect(capturedDropdownProps?.items).not.toContainEqual(
+ expect.objectContaining({ value: 'debug-as-single-model' }),
+ )
+ })
+
+ it('should show remove option in secondItems when more than 2 models', () => {
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [
+ createModelAndParameter({ id: '1' }),
+ createModelAndParameter({ id: '2' }),
+ createModelAndParameter({ id: '3' }),
+ ],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent()
+
+ expect(capturedDropdownProps?.secondItems).toContainEqual(
+ expect.objectContaining({ value: 'remove' }),
+ )
+ })
+
+ it('should not show remove option when 2 or fewer models', () => {
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [
+ createModelAndParameter({ id: '1' }),
+ createModelAndParameter({ id: '2' }),
+ ],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent()
+
+ expect(capturedDropdownProps?.secondItems).toBeUndefined()
+ })
+ })
+
+ describe('dropdown actions', () => {
+ it('should duplicate model when duplicate is selected', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const originalModel = createModelAndParameter({ id: 'original' })
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [originalModel],
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter: originalModel })
+
+ fireEvent.click(screen.getByTestId('dropdown-item-duplicate'))
+
+ expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(
+ true,
+ expect.arrayContaining([
+ originalModel,
+ expect.objectContaining({
+ model: originalModel.model,
+ provider: originalModel.provider,
+ parameters: originalModel.parameters,
+ }),
+ ]),
+ )
+ })
+
+ it('should not duplicate when already at 4 models', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const models = [
+ createModelAndParameter({ id: '1' }),
+ createModelAndParameter({ id: '2' }),
+ createModelAndParameter({ id: '3' }),
+ createModelAndParameter({ id: '4' }),
+ ]
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: models,
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter: models[0] })
+
+ // duplicate is hidden at 4 models, so invoke the handler via the captured onSelect
+ capturedDropdownProps?.onSelect({ value: 'duplicate', text: 'Duplicate' })
+
+ expect(onMultipleModelConfigsChange).not.toHaveBeenCalled()
+ })
+
+ it('should call onDebugWithMultipleModelChange when debug-as-single-model is selected', () => {
+ const onDebugWithMultipleModelChange = vi.fn()
+ const modelAndParameter = createModelAndParameter()
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [modelAndParameter],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange,
+ })
+
+ renderComponent({ modelAndParameter })
+
+ fireEvent.click(screen.getByTestId('dropdown-item-debug-as-single-model'))
+
+ expect(onDebugWithMultipleModelChange).toHaveBeenCalledWith(modelAndParameter)
+ })
+
+ it('should remove model when remove is selected', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const models = [
+ createModelAndParameter({ id: '1' }),
+ createModelAndParameter({ id: '2' }),
+ createModelAndParameter({ id: '3' }),
+ ]
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: models,
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter: models[1] })
+
+ fireEvent.click(screen.getByTestId('dropdown-second-item-remove'))
+
+ expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(
+ true,
+ [models[0], models[2]],
+ )
+ })
+
+ it('should insert duplicated model at correct position', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const models = [
+ createModelAndParameter({ id: '1' }),
+ createModelAndParameter({ id: '2' }),
+ createModelAndParameter({ id: '3' }),
+ ]
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: models,
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ // Duplicate the second model
+ renderComponent({ modelAndParameter: models[1] })
+
+ fireEvent.click(screen.getByTestId('dropdown-item-duplicate'))
+
+ expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(
+ true,
+ expect.arrayContaining([
+ models[0],
+ models[1],
+ expect.objectContaining({ model: models[1].model }),
+ models[2],
+ ]),
+ )
+ })
+ })
+
+ describe('edge cases', () => {
+ it('should handle model not found in multipleModelConfigs', () => {
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ const { container } = renderComponent()
+
+ // findIndex returns -1 for a missing model, so the displayed index (-1 + 1) is 0
+ const indexElement = container.querySelector('.font-medium.italic')
+ expect(indexElement?.textContent?.trim()).toContain('0')
+ })
+
+ it('should handle empty textGenerationModelList', () => {
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: [],
+ })
+
+ renderComponent()
+
+ expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
+ expect(screen.queryByTestId('text-generation-item')).not.toBeInTheDocument()
+ })
+
+ it('should handle model with quotaExceeded status', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: createTextGenerationModelList([
+ { provider: 'anthropic', model: 'not-matching', status: ModelStatusEnum.quotaExceeded },
+ ]),
+ })
+
+ renderComponent()
+
+ // When provider/model doesn't match, ChatItem won't render
+ expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
+ })
+ })
+})
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/model-parameter-trigger.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/model-parameter-trigger.spec.tsx
new file mode 100644
index 0000000000..5ef1dcadbb
--- /dev/null
+++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/model-parameter-trigger.spec.tsx
@@ -0,0 +1,405 @@
+import type { ReactNode } from 'react'
+import type { ModelAndParameter } from '../types'
+import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import { render, screen } from '@testing-library/react'
+import { ModelStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import ModelParameterTrigger from './model-parameter-trigger'
+
+const mockUseDebugConfigurationContext = vi.fn()
+const mockUseDebugWithMultipleModelContext = vi.fn()
+const mockUseLanguage = vi.fn()
+
+type RenderTriggerProps = {
+ open: boolean
+ currentProvider: { provider: string } | null
+ currentModel: { model: string, status: ModelStatusEnum } | null
+}
+
+let capturedModalProps: {
+ isAdvancedMode: boolean
+ provider: string
+ modelId: string
+ completionParams: FormValue
+ onCompletionParamsChange: (params: FormValue) => void
+ setModel: (model: { modelId: string, provider: string }) => void
+ debugWithMultipleModel: boolean
+ onDebugWithMultipleModelChange: () => void
+ renderTrigger: (props: RenderTriggerProps) => ReactNode
+} | null = null
+
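+// The ModelParameterModal mock below invokes renderTrigger immediately
+// with a closed, empty state so the trigger markup is rendered and
+// assertable.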
+vi.mock('@/context/debug-configuration', () => ({
+ useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
+}))
+
+vi.mock('./context', () => ({
+ useDebugWithMultipleModelContext: () => mockUseDebugWithMultipleModelContext(),
+}))
+
+vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
+ useLanguage: () => mockUseLanguage(),
+}))
+
+vi.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({
+ default: (props: typeof capturedModalProps) => {
+ capturedModalProps = props
+ // Render the trigger that the component passes
+ const triggerContent = props?.renderTrigger({
+ open: false,
+ currentProvider: null,
+ currentModel: null,
+ })
+ return (
+ <div data-testid="model-parameter-modal">
+ {triggerContent}
+ </div>
+ )
+ },
+}))
+
+vi.mock('@/app/components/header/account-setting/model-provider-page/model-icon', () => ({
+ default: ({ provider, modelName }: { provider: { provider: string }, modelName?: string }) => (
+ <div data-provider={provider?.provider} data-model={modelName}>
+ ModelIcon
+ </div>
+ ),
+}))
+
+vi.mock('@/app/components/header/account-setting/model-provider-page/model-name', () => ({
+ default: ({ modelItem }: { modelItem: { model: string } }) => (
+ <div>{modelItem?.model}</div>
+ ),
+}))
+
+vi.mock('@/app/components/base/tooltip', () => ({
+ default: ({ children, popupContent }: { children: ReactNode, popupContent: string }) => (
+ <div data-tooltip={popupContent}>{children}</div>
+ ),
+}))
+
+const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
+ id: 'model-1',
+ model: 'gpt-3.5-turbo',
+ provider: 'openai',
+ parameters: { temperature: 0.7 },
+ ...overrides,
+})
+
+const renderComponent = (props: Partial<{ modelAndParameter: ModelAndParameter }> = {}) => {
+ const defaultProps = {
+ modelAndParameter: createModelAndParameter(),
+ ...props,
+ }
+ return render(<ModelParameterTrigger {...defaultProps} />)
+}
+
+describe('ModelParameterTrigger', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ capturedModalProps = null
+
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: false,
+ })
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [createModelAndParameter()],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ mockUseLanguage.mockReturnValue('en_US')
+ })
+
+ describe('rendering', () => {
+ it('should render ModelParameterModal', () => {
+ renderComponent()
+
+ expect(screen.getByTestId('model-parameter-modal')).toBeInTheDocument()
+ })
+
+ it('should pass correct props to ModelParameterModal', () => {
+ const modelAndParameter = createModelAndParameter({
+ provider: 'anthropic',
+ model: 'claude-3',
+ parameters: { max_tokens: 1000 },
+ })
+
+ renderComponent({ modelAndParameter })
+
+ expect(capturedModalProps?.provider).toBe('anthropic')
+ expect(capturedModalProps?.modelId).toBe('claude-3')
+ expect(capturedModalProps?.completionParams).toEqual({ max_tokens: 1000 })
+ expect(capturedModalProps?.debugWithMultipleModel).toBe(true)
+ })
+
+ it('should pass isAdvancedMode from context', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: true,
+ })
+
+ renderComponent()
+
+ expect(capturedModalProps?.isAdvancedMode).toBe(true)
+ })
+ })
+
+ describe('handleSelectModel', () => {
+ it('should call onMultipleModelConfigsChange with updated model', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const modelAndParameter = createModelAndParameter({ id: 'model-1' })
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [modelAndParameter],
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter })
+
+ // Directly call the setModel callback
+ capturedModalProps?.setModel({ modelId: 'gpt-4', provider: 'openai' })
+
+ expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [
+ expect.objectContaining({
+ id: 'model-1',
+ model: 'gpt-4',
+ provider: 'openai',
+ }),
+ ])
+ })
+
+ it('should update correct model in array', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const models = [
+ createModelAndParameter({ id: 'model-1' }),
+ createModelAndParameter({ id: 'model-2' }),
+ createModelAndParameter({ id: 'model-3' }),
+ ]
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: models,
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter: models[1] })
+
+ capturedModalProps?.setModel({ modelId: 'gpt-4', provider: 'openai' })
+
+ expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [
+ models[0],
+ expect.objectContaining({
+ id: 'model-2',
+ model: 'gpt-4',
+ provider: 'openai',
+ }),
+ models[2],
+ ])
+ })
+ })
+
+ describe('handleParamsChange', () => {
+ it('should call onMultipleModelConfigsChange with updated parameters', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const modelAndParameter = createModelAndParameter({ id: 'model-1' })
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [modelAndParameter],
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter })
+
+ capturedModalProps?.onCompletionParamsChange({ temperature: 0.8 })
+
+ expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [
+ expect.objectContaining({
+ id: 'model-1',
+ parameters: { temperature: 0.8 },
+ }),
+ ])
+ })
+
+ it('should preserve other model properties when changing params', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const modelAndParameter = createModelAndParameter({
+ id: 'model-1',
+ model: 'gpt-3.5-turbo',
+ provider: 'openai',
+ parameters: { temperature: 0.7 },
+ })
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [modelAndParameter],
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter })
+
+ capturedModalProps?.onCompletionParamsChange({ temperature: 0.8 })
+
+ expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [
+ expect.objectContaining({
+ id: 'model-1',
+ model: 'gpt-3.5-turbo',
+ provider: 'openai',
+ parameters: { temperature: 0.8 },
+ }),
+ ])
+ })
+ })
+
+ describe('onDebugWithMultipleModelChange', () => {
+ it('should call context onDebugWithMultipleModelChange with modelAndParameter', () => {
+ const onDebugWithMultipleModelChange = vi.fn()
+ const modelAndParameter = createModelAndParameter()
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [modelAndParameter],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange,
+ })
+
+ renderComponent({ modelAndParameter })
+
+ capturedModalProps?.onDebugWithMultipleModelChange()
+
+ expect(onDebugWithMultipleModelChange).toHaveBeenCalledWith(modelAndParameter)
+ })
+ })
+
+ describe('index calculation', () => {
+ it('should find correct index in multipleModelConfigs', () => {
+ const models = [
+ createModelAndParameter({ id: 'model-1' }),
+ createModelAndParameter({ id: 'model-2' }),
+ createModelAndParameter({ id: 'model-3' }),
+ ]
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: models,
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter: models[2] })
+
+ // The component uses the index to update the correct model
+ // We verify this through the handleSelectModel behavior
+ expect(capturedModalProps).not.toBeNull()
+ })
+
+ it('should handle model not found in configs', () => {
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [createModelAndParameter({ id: 'other' })],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ // Should not throw even if model is not found
+ expect(() => renderComponent()).not.toThrow()
+ })
+ })
+
+ describe('trigger rendering', () => {
+ it('should render trigger content from renderTrigger', () => {
+ renderComponent()
+
+ // The trigger is rendered via renderTrigger callback
+ expect(screen.getByTestId('model-parameter-modal')).toBeInTheDocument()
+ })
+
+ it('should render "Select Model" text when no provider/model', () => {
+ renderComponent()
+
+ // When currentProvider and currentModel are null, shows "Select Model"
+ expect(screen.getByText('common.modelProvider.selectModel')).toBeInTheDocument()
+ })
+ })
+
+ describe('language context', () => {
+ it('should use language from useLanguage hook', () => {
+ mockUseLanguage.mockReturnValue('zh_Hans')
+
+ renderComponent()
+
+ // The language is used for MODEL_STATUS_TEXT tooltip
+ // We verify the hook is called
+ expect(mockUseLanguage).toHaveBeenCalled()
+ })
+ })
+
+ describe('edge cases', () => {
+ it('should handle empty multipleModelConfigs', () => {
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [],
+ onMultipleModelConfigsChange: vi.fn(),
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ expect(() => renderComponent()).not.toThrow()
+ })
+
+ it('should handle undefined parameters', () => {
+ const modelAndParameter = createModelAndParameter({
+ parameters: undefined as unknown as FormValue,
+ })
+
+ expect(() => renderComponent({ modelAndParameter })).not.toThrow()
+ expect(capturedModalProps?.completionParams).toBeUndefined()
+ })
+
+ it('should handle model selection for model not in list', () => {
+ const onMultipleModelConfigsChange = vi.fn()
+ const modelAndParameter = createModelAndParameter({ id: 'not-in-list' })
+
+ mockUseDebugWithMultipleModelContext.mockReturnValue({
+ multipleModelConfigs: [createModelAndParameter({ id: 'different-model' })],
+ onMultipleModelConfigsChange,
+ onDebugWithMultipleModelChange: vi.fn(),
+ })
+
+ renderComponent({ modelAndParameter })
+
+ capturedModalProps?.setModel({ modelId: 'gpt-4', provider: 'openai' })
+
+ // index will be -1, so newModelConfigs[-1] will be undefined
+ // This tests the edge case behavior
+ expect(onMultipleModelConfigsChange).toHaveBeenCalled()
+ })
+ })
+
+ describe('renderTrigger with different states', () => {
+ it('should pass correct props to renderTrigger', () => {
+ renderComponent()
+
+ expect(capturedModalProps?.renderTrigger).toBeDefined()
+ expect(typeof capturedModalProps?.renderTrigger).toBe('function')
+ })
+
+ it('should render trigger with provider info when available', () => {
+ // Mock the modal to render trigger with provider
+ vi.doMock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({
+ default: (props: typeof capturedModalProps) => {
+ capturedModalProps = props
+ const triggerContent = props?.renderTrigger({
+ open: false,
+ currentProvider: { provider: 'openai' },
+ currentModel: { model: 'gpt-3.5-turbo', status: ModelStatusEnum.active },
+ })
+ return (
+        <div data-testid="model-parameter-modal">
+          {triggerContent}
+        </div>
+ )
+ },
+ }))
+
+ renderComponent()
+
+ expect(screen.getByTestId('model-parameter-modal')).toBeInTheDocument()
+ })
+ })
+})
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/text-generation-item.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/text-generation-item.spec.tsx
new file mode 100644
index 0000000000..1876a10a0c
--- /dev/null
+++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/text-generation-item.spec.tsx
@@ -0,0 +1,721 @@
+import type { ModelAndParameter } from '../types'
+import { render, screen } from '@testing-library/react'
+import { TransferMethod } from '@/app/components/base/chat/types'
+import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types'
+import TextGenerationItem from './text-generation-item'
+
+const mockUseDebugConfigurationContext = vi.fn()
+const mockUseProviderContext = vi.fn()
+const mockUseFeatures = vi.fn()
+const mockUseTextGeneration = vi.fn()
+const mockUseEventEmitterContextContext = vi.fn()
+const mockPromptVariablesToUserInputsForm = vi.fn()
+
+let capturedTextGenerationProps: {
+ content: string
+ isLoading: boolean
+ isResponding: boolean
+ messageId: string | null
+ className?: string
+} | null = null
+
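+// Captured by the event-emitter mock below so tests can fire events at the component directly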
+let eventSubscriptionCallback: ((v: { type: string, payload?: Record<string, any> }) => void) | null = null
+
+vi.mock('@/context/debug-configuration', () => ({
+ useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
+}))
+
+vi.mock('@/context/provider-context', () => ({
+ useProviderContext: () => mockUseProviderContext(),
+}))
+
+vi.mock('@/app/components/base/features/hooks', () => ({
+  useFeatures: (selector: (state: Record<string, any>) => unknown) => mockUseFeatures(selector),
+}))
+
+vi.mock('@/app/components/base/text-generation/hooks', () => ({
+ useTextGeneration: () => mockUseTextGeneration(),
+}))
+
+vi.mock('@/context/event-emitter', () => ({
+ useEventEmitterContextContext: () => mockUseEventEmitterContextContext(),
+}))
+
+vi.mock('@/utils/model-config', () => ({
+ promptVariablesToUserInputsForm: (...args: unknown[]) => mockPromptVariablesToUserInputsForm(...args),
+}))
+
+vi.mock('@/app/components/app/text-generate/item', () => ({
+ default: (props: typeof capturedTextGenerationProps) => {
+ capturedTextGenerationProps = props
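+    // Surface each prop through a data-testid node so assertions can read them individually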
+    return (
+      <div data-testid="text-generation">
+        <div data-testid="content">{props?.content}</div>
+        <div data-testid="is-loading">{props?.isLoading ? 'yes' : 'no'}</div>
+        <div data-testid="is-responding">{props?.isResponding ? 'yes' : 'no'}</div>
+        <div data-testid="message-id">{props?.messageId || 'null'}</div>
+      </div>
+    )
+ },
+}))
+
+const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
+ id: 'model-1',
+ model: 'gpt-3.5-turbo',
+ provider: 'openai',
+ parameters: { temperature: 0.7 },
+ ...overrides,
+})
+
+const createDefaultMocks = () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: false,
+ modelConfig: {
+ configs: {
+ prompt_template: 'Hello {{name}}',
+ prompt_variables: [
+ { key: 'name', name: 'Name', type: 'string', is_context_var: false },
+ ],
+ },
+ system_parameters: {},
+ },
+ appId: 'app-123',
+ inputs: { name: 'World' },
+ promptMode: 'simple',
+ speechToTextConfig: { enabled: true },
+ introduction: 'Welcome!',
+ suggestedQuestionsAfterAnswerConfig: { enabled: false },
+ citationConfig: { enabled: true },
+ externalDataToolsConfig: [],
+ chatPromptConfig: {},
+ completionPromptConfig: {},
+ dataSets: [{ id: 'ds-1' }],
+ datasetConfigs: { retrieval_model: 'single' },
+ })
+
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: [
+ {
+ provider: 'openai',
+ models: [
+ {
+ model: 'gpt-3.5-turbo',
+ model_properties: { mode: 'chat' },
+ },
+ ],
+ },
+ ],
+ })
+
+  mockUseFeatures.mockImplementation((selector: (state: Record<string, any>) => unknown) => {
+ const state = {
+ features: {
+ moreLikeThis: { enabled: false },
+ moderation: { enabled: false },
+ text2speech: { enabled: false },
+ file: { enabled: true },
+ },
+ }
+ return selector(state)
+ })
+
+ mockUseTextGeneration.mockReturnValue({
+ completion: 'Generated text',
+ handleSend: vi.fn(),
+ isResponding: false,
+ messageId: 'msg-123',
+ })
+
+ mockUseEventEmitterContextContext.mockReturnValue({
+ eventEmitter: {
+      useSubscription: (callback: (v: { type: string, payload?: Record<string, any> }) => void) => {
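+        // Store the subscriber so tests can emit APP_CHAT_WITH_MULTIPLE_MODEL by hand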
+ eventSubscriptionCallback = callback
+ },
+ },
+ })
+
+ mockPromptVariablesToUserInputsForm.mockReturnValue([
+ { variable: 'name', label: 'Name', type: 'text-input', required: true },
+ ])
+}
+
+const renderComponent = (props: Partial<{ modelAndParameter: ModelAndParameter }> = {}) => {
+ const defaultProps = {
+ modelAndParameter: createModelAndParameter(),
+ ...props,
+ }
+  return render(<TextGenerationItem {...defaultProps} />)
+}
+
+describe('TextGenerationItem', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ capturedTextGenerationProps = null
+ eventSubscriptionCallback = null
+ createDefaultMocks()
+ })
+
+ describe('rendering', () => {
+ it('should render TextGeneration component', () => {
+ renderComponent()
+
+ expect(screen.getByTestId('text-generation')).toBeInTheDocument()
+ })
+
+ it('should pass completion content to TextGeneration', () => {
+ mockUseTextGeneration.mockReturnValue({
+ completion: 'Hello World',
+ handleSend: vi.fn(),
+ isResponding: false,
+ messageId: 'msg-1',
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('content')).toHaveTextContent('Hello World')
+ })
+
+ it('should show loading when no completion and responding', () => {
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend: vi.fn(),
+ isResponding: true,
+ messageId: null,
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('is-loading')).toHaveTextContent('yes')
+ })
+
+ it('should not show loading when completion exists', () => {
+ mockUseTextGeneration.mockReturnValue({
+ completion: 'Some text',
+ handleSend: vi.fn(),
+ isResponding: true,
+ messageId: 'msg-1',
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('is-loading')).toHaveTextContent('no')
+ })
+
+ it('should pass isResponding to TextGeneration', () => {
+ mockUseTextGeneration.mockReturnValue({
+ completion: 'Text',
+ handleSend: vi.fn(),
+ isResponding: true,
+ messageId: 'msg-1',
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('is-responding')).toHaveTextContent('yes')
+ })
+
+ it('should pass messageId to TextGeneration', () => {
+ mockUseTextGeneration.mockReturnValue({
+ completion: 'Text',
+ handleSend: vi.fn(),
+ isResponding: false,
+ messageId: 'msg-456',
+ })
+
+ renderComponent()
+
+ expect(screen.getByTestId('message-id')).toHaveTextContent('msg-456')
+ })
+ })
+
+ describe('config composition', () => {
+ it('should use prompt_template in non-advanced mode', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: false,
+ modelConfig: {
+ configs: {
+ prompt_template: 'My Template',
+ prompt_variables: [],
+ },
+ system_parameters: {},
+ },
+ appId: 'app-123',
+ inputs: {},
+ promptMode: 'simple',
+ speechToTextConfig: {},
+ introduction: '',
+ suggestedQuestionsAfterAnswerConfig: {},
+ citationConfig: {},
+ externalDataToolsConfig: [],
+ chatPromptConfig: {},
+ completionPromptConfig: {},
+ dataSets: [],
+ datasetConfigs: {},
+ })
+
+ renderComponent()
+
+ // Config is built internally - we verify through the component rendering
+ expect(capturedTextGenerationProps).not.toBeNull()
+ })
+
+ it('should use empty pre_prompt in advanced mode', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: true,
+ modelConfig: {
+ configs: {
+ prompt_template: 'Should not be used',
+ prompt_variables: [],
+ },
+ system_parameters: {},
+ },
+ appId: 'app-123',
+ inputs: {},
+ promptMode: 'advanced',
+ speechToTextConfig: {},
+ introduction: '',
+ suggestedQuestionsAfterAnswerConfig: {},
+ citationConfig: {},
+ externalDataToolsConfig: [],
+ chatPromptConfig: { custom: true },
+ completionPromptConfig: { custom: true },
+ dataSets: [],
+ datasetConfigs: {},
+ })
+
+ renderComponent()
+
+ expect(capturedTextGenerationProps).not.toBeNull()
+ })
+
+ it('should find context variable from prompt_variables', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: false,
+ modelConfig: {
+ configs: {
+ prompt_template: '',
+ prompt_variables: [
+ { key: 'context', name: 'Context', type: 'string', is_context_var: true },
+ { key: 'query', name: 'Query', type: 'string', is_context_var: false },
+ ],
+ },
+ system_parameters: {},
+ },
+ appId: 'app-123',
+ inputs: {},
+ promptMode: 'simple',
+ speechToTextConfig: {},
+ introduction: '',
+ suggestedQuestionsAfterAnswerConfig: {},
+ citationConfig: {},
+ externalDataToolsConfig: [],
+ chatPromptConfig: {},
+ completionPromptConfig: {},
+ dataSets: [],
+ datasetConfigs: {},
+ })
+
+ renderComponent()
+
+ expect(capturedTextGenerationProps).not.toBeNull()
+ })
+ })
+
+ describe('dataset configuration', () => {
+ it('should transform dataSets to postDatasets format', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: false,
+ modelConfig: {
+ configs: { prompt_template: '', prompt_variables: [] },
+ system_parameters: {},
+ },
+ appId: 'app-123',
+ inputs: {},
+ promptMode: 'simple',
+ speechToTextConfig: {},
+ introduction: '',
+ suggestedQuestionsAfterAnswerConfig: {},
+ citationConfig: {},
+ externalDataToolsConfig: [],
+ chatPromptConfig: {},
+ completionPromptConfig: {},
+ dataSets: [{ id: 'ds-1' }, { id: 'ds-2' }],
+ datasetConfigs: { retrieval_model: 'multiple' },
+ })
+
+ renderComponent()
+
+ // postDatasets is used in config.dataset_configs.datasets
+ expect(capturedTextGenerationProps).not.toBeNull()
+ })
+ })
+
+ describe('event subscription', () => {
+ it('should handle APP_CHAT_WITH_MULTIPLE_MODEL event', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Generate text', files: [] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ 'apps/app-123/completion-messages',
+ expect.objectContaining({
+ inputs: { name: 'World' },
+ }),
+ )
+ })
+
+ it('should ignore other event types', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: 'OTHER_EVENT',
+ payload: {},
+ })
+
+ expect(handleSend).not.toHaveBeenCalled()
+ })
+ })
+
+ describe('doSend function', () => {
+ it('should include model configuration', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+ const modelAndParameter = createModelAndParameter({
+ provider: 'anthropic',
+ model: 'claude-3',
+ parameters: { max_tokens: 2000 },
+ })
+
+ renderComponent({ modelAndParameter })
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Test', files: [] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.objectContaining({
+ model_config: expect.objectContaining({
+ model: expect.objectContaining({
+ provider: 'anthropic',
+ name: 'claude-3',
+ completion_params: { max_tokens: 2000 },
+ }),
+ }),
+ }),
+ )
+ })
+
+ it('should include files with local_file transfer method handled', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+ renderComponent()
+
+ const files = [
+ { id: 'f1', transfer_method: TransferMethod.local_file, url: 'blob:123' },
+ { id: 'f2', transfer_method: TransferMethod.remote_url, url: 'https://example.com/file' },
+ ]
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Test', files },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.objectContaining({
+ files: [
+ expect.objectContaining({ id: 'f1', transfer_method: TransferMethod.local_file, url: '' }),
+ expect.objectContaining({ id: 'f2', transfer_method: TransferMethod.remote_url, url: 'https://example.com/file' }),
+ ],
+ }),
+ )
+ })
+
+ it('should not include files when file upload is disabled', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+    mockUseFeatures.mockImplementation((selector: (state: Record<string, any>) => unknown) => {
+ const state = {
+ features: {
+ moreLikeThis: { enabled: false },
+ moderation: { enabled: false },
+ text2speech: { enabled: false },
+ file: { enabled: false },
+ },
+ }
+ return selector(state)
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Test', files: [{ id: 'f1' }] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.not.objectContaining({
+ files: expect.anything(),
+ }),
+ )
+ })
+
+ it('should not include files when files array is empty', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Test', files: [] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.not.objectContaining({
+ files: expect.anything(),
+ }),
+ )
+ })
+
+ it('should not include files when files is undefined', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Test' },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.not.objectContaining({
+ files: expect.anything(),
+ }),
+ )
+ })
+ })
+
+ describe('model resolution', () => {
+ it('should find current provider and model', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: [
+ {
+ provider: 'openai',
+ models: [
+ { model: 'gpt-3.5-turbo', model_properties: { mode: 'chat' } },
+ { model: 'gpt-4', model_properties: { mode: 'chat' } },
+ ],
+ },
+ ],
+ })
+
+ const modelAndParameter = createModelAndParameter({
+ provider: 'openai',
+ model: 'gpt-4',
+ })
+
+ renderComponent({ modelAndParameter })
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Test', files: [] },
+ })
+
+ expect(handleSend).toHaveBeenCalledWith(
+ expect.any(String),
+ expect.objectContaining({
+ model_config: expect.objectContaining({
+ model: expect.objectContaining({
+ mode: 'chat',
+ }),
+ }),
+ }),
+ )
+ })
+
+ it('should handle provider not found', () => {
+ const handleSend = vi.fn()
+ mockUseTextGeneration.mockReturnValue({
+ completion: '',
+ handleSend,
+ isResponding: false,
+ messageId: null,
+ })
+
+ mockUseProviderContext.mockReturnValue({
+ textGenerationModelList: [],
+ })
+
+ renderComponent()
+
+ eventSubscriptionCallback?.({
+ type: APP_CHAT_WITH_MULTIPLE_MODEL,
+ payload: { message: 'Test', files: [] },
+ })
+
+ // Should still call handleSend without crashing
+ expect(handleSend).toHaveBeenCalled()
+ })
+ })
+
+ describe('edge cases', () => {
+ it('should handle null eventEmitter', () => {
+ mockUseEventEmitterContextContext.mockReturnValue({
+ eventEmitter: null,
+ })
+
+ expect(() => renderComponent()).not.toThrow()
+ })
+
+ it('should handle empty prompt_variables', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: false,
+ modelConfig: {
+ configs: { prompt_template: '', prompt_variables: [] },
+ system_parameters: {},
+ },
+ appId: 'app-123',
+ inputs: {},
+ promptMode: 'simple',
+ speechToTextConfig: {},
+ introduction: '',
+ suggestedQuestionsAfterAnswerConfig: {},
+ citationConfig: {},
+ externalDataToolsConfig: [],
+ chatPromptConfig: {},
+ completionPromptConfig: {},
+ dataSets: [],
+ datasetConfigs: {},
+ })
+
+ expect(() => renderComponent()).not.toThrow()
+ })
+
+ it('should handle no context variable found', () => {
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: false,
+ modelConfig: {
+ configs: {
+ prompt_template: '',
+ prompt_variables: [
+ { key: 'var1', name: 'Var1', type: 'string', is_context_var: false },
+ ],
+ },
+ system_parameters: {},
+ },
+ appId: 'app-123',
+ inputs: {},
+ promptMode: 'simple',
+ speechToTextConfig: {},
+ introduction: '',
+ suggestedQuestionsAfterAnswerConfig: {},
+ citationConfig: {},
+ externalDataToolsConfig: [],
+ chatPromptConfig: {},
+ completionPromptConfig: {},
+ dataSets: [],
+ datasetConfigs: {},
+ })
+
+ renderComponent()
+
+ // Should use empty string for dataset_query_variable
+ expect(capturedTextGenerationProps).not.toBeNull()
+ })
+ })
+
+ describe('promptVariablesToUserInputsForm', () => {
+ it('should call promptVariablesToUserInputsForm with prompt_variables', () => {
+ const promptVariables = [
+ { key: 'name', name: 'Name', type: 'string' },
+ { key: 'age', name: 'Age', type: 'number' },
+ ]
+
+ mockUseDebugConfigurationContext.mockReturnValue({
+ isAdvancedMode: false,
+ modelConfig: {
+ configs: { prompt_template: '', prompt_variables: promptVariables },
+ system_parameters: {},
+ },
+ appId: 'app-123',
+ inputs: {},
+ promptMode: 'simple',
+ speechToTextConfig: {},
+ introduction: '',
+ suggestedQuestionsAfterAnswerConfig: {},
+ citationConfig: {},
+ externalDataToolsConfig: [],
+ chatPromptConfig: {},
+ completionPromptConfig: {},
+ dataSets: [],
+ datasetConfigs: {},
+ })
+
+ renderComponent()
+
+ expect(mockPromptVariablesToUserInputsForm).toHaveBeenCalledWith(promptVariables)
+ })
+ })
+})
diff --git a/web/app/components/base/icons/assets/vender/knowledge/search-lines-sparkle.svg b/web/app/components/base/icons/assets/vender/knowledge/search-lines-sparkle.svg
new file mode 100644
index 0000000000..1eb2781715
--- /dev/null
+++ b/web/app/components/base/icons/assets/vender/knowledge/search-lines-sparkle.svg
@@ -0,0 +1,6 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M12 7.33337V2.66671H4.00002V13.3334H8.00002C8.36821 13.3334 8.66669 13.6319 8.66669 14C8.66669 14.3682 8.36821 14.6667 8.00002 14.6667H3.33335C2.96516 14.6667 2.66669 14.3682 2.66669 14V2.00004C2.66669 1.63185 2.96516 1.33337 3.33335 1.33337H12.6667C13.0349 1.33337 13.3334 1.63185 13.3334 2.00004V7.33337C13.3334 7.70156 13.0349 8.00004 12.6667 8.00004C12.2985 8.00004 12 7.70156 12 7.33337Z" fill="currentColor"/>
+<path d="M10 4.00004C10.3682 4.00004 10.6667 4.29852 10.6667 4.66671C10.6667 5.0349 10.3682 5.33337 10 5.33337H6.00002C5.63183 5.33337 5.33335 5.0349 5.33335 4.66671C5.33335 4.29852 5.63183 4.00004 6.00002 4.00004H10Z" fill="currentColor"/>
+<path d="M8.00002 6.66671C8.36821 6.66671 8.66669 6.96518 8.66669 7.33337C8.66669 7.70156 8.36821 8.00004 8.00002 8.00004H6.00002C5.63183 8.00004 5.33335 7.70156 5.33335 7.33337C5.33335 6.96518 5.63183 6.66671 6.00002 6.66671H8.00002Z" fill="currentColor"/>
+<path d="M12.827 10.7902L12.3624 9.58224C12.3048 9.43231 12.1607 9.33337 12 9.33337C11.8394 9.33337 11.6953 9.43231 11.6376 9.58224L11.173 10.7902C11.1054 10.9662 10.9662 11.1054 10.7902 11.173L9.58222 11.6376C9.43229 11.6953 9.33335 11.8394 9.33335 12C9.33335 12.1607 9.43229 12.3048 9.58222 12.3624L10.7902 12.827C10.9662 12.8947 11.1054 13.0338 11.173 13.2099L11.6376 14.4178C11.6953 14.5678 11.8394 14.6667 12 14.6667C12.1607 14.6667 12.3048 14.5678 12.3624 14.4178L12.827 13.2099C12.8947 13.0338 13.0338 12.8947 13.2099 12.827L14.4178 12.3624C14.5678 12.3048 14.6667 12.1607 14.6667 12C14.6667 11.8394 14.5678 11.6953 14.4178 11.6376L13.2099 11.173C13.0338 11.1054 12.8947 10.9662 12.827 10.7902Z" fill="currentColor"/>
+</svg>
diff --git a/web/app/components/base/icons/src/vender/knowledge/SearchLinesSparkle.json b/web/app/components/base/icons/src/vender/knowledge/SearchLinesSparkle.json
new file mode 100644
index 0000000000..7fa195092d
--- /dev/null
+++ b/web/app/components/base/icons/src/vender/knowledge/SearchLinesSparkle.json
@@ -0,0 +1,53 @@
+{
+ "icon": {
+ "type": "element",
+ "isRootNode": true,
+ "name": "svg",
+ "attributes": {
+ "width": "16",
+ "height": "16",
+ "viewBox": "0 0 16 16",
+ "fill": "none",
+ "xmlns": "http://www.w3.org/2000/svg"
+ },
+ "children": [
+ {
+ "type": "element",
+ "name": "path",
+ "attributes": {
+ "d": "M12 7.33337V2.66671H4.00002V13.3334H8.00002C8.36821 13.3334 8.66669 13.6319 8.66669 14C8.66669 14.3682 8.36821 14.6667 8.00002 14.6667H3.33335C2.96516 14.6667 2.66669 14.3682 2.66669 14V2.00004C2.66669 1.63185 2.96516 1.33337 3.33335 1.33337H12.6667C13.0349 1.33337 13.3334 1.63185 13.3334 2.00004V7.33337C13.3334 7.70156 13.0349 8.00004 12.6667 8.00004C12.2985 8.00004 12 7.70156 12 7.33337Z",
+ "fill": "currentColor"
+ },
+ "children": []
+ },
+ {
+ "type": "element",
+ "name": "path",
+ "attributes": {
+ "d": "M10 4.00004C10.3682 4.00004 10.6667 4.29852 10.6667 4.66671C10.6667 5.0349 10.3682 5.33337 10 5.33337H6.00002C5.63183 5.33337 5.33335 5.0349 5.33335 4.66671C5.33335 4.29852 5.63183 4.00004 6.00002 4.00004H10Z",
+ "fill": "currentColor"
+ },
+ "children": []
+ },
+ {
+ "type": "element",
+ "name": "path",
+ "attributes": {
+ "d": "M8.00002 6.66671C8.36821 6.66671 8.66669 6.96518 8.66669 7.33337C8.66669 7.70156 8.36821 8.00004 8.00002 8.00004H6.00002C5.63183 8.00004 5.33335 7.70156 5.33335 7.33337C5.33335 6.96518 5.63183 6.66671 6.00002 6.66671H8.00002Z",
+ "fill": "currentColor"
+ },
+ "children": []
+ },
+ {
+ "type": "element",
+ "name": "path",
+ "attributes": {
+ "d": "M12.827 10.7902L12.3624 9.58224C12.3048 9.43231 12.1607 9.33337 12 9.33337C11.8394 9.33337 11.6953 9.43231 11.6376 9.58224L11.173 10.7902C11.1054 10.9662 10.9662 11.1054 10.7902 11.173L9.58222 11.6376C9.43229 11.6953 9.33335 11.8394 9.33335 12C9.33335 12.1607 9.43229 12.3048 9.58222 12.3624L10.7902 12.827C10.9662 12.8947 11.1054 13.0338 11.173 13.2099L11.6376 14.4178C11.6953 14.5678 11.8394 14.6667 12 14.6667C12.1607 14.6667 12.3048 14.5678 12.3624 14.4178L12.827 13.2099C12.8947 13.0338 13.0338 12.8947 13.2099 12.827L14.4178 12.3624C14.5678 12.3048 14.6667 12.1607 14.6667 12C14.6667 11.8394 14.5678 11.6953 14.4178 11.6376L13.2099 11.173C13.0338 11.1054 12.8947 10.9662 12.827 10.7902Z",
+ "fill": "currentColor"
+ },
+ "children": []
+ }
+ ]
+ },
+ "name": "SearchLinesSparkle"
+}
diff --git a/web/app/components/base/icons/src/vender/knowledge/SearchLinesSparkle.tsx b/web/app/components/base/icons/src/vender/knowledge/SearchLinesSparkle.tsx
new file mode 100644
index 0000000000..3ae90b5fa1
--- /dev/null
+++ b/web/app/components/base/icons/src/vender/knowledge/SearchLinesSparkle.tsx
@@ -0,0 +1,20 @@
+// GENERATE BY script
+// DON NOT EDIT IT MANUALLY
+
+import type { IconData } from '@/app/components/base/icons/IconBase'
+import * as React from 'react'
+import IconBase from '@/app/components/base/icons/IconBase'
+import data from './SearchLinesSparkle.json'
+
+const Icon = (
+ {
+ ref,
+ ...props
+  }: React.SVGProps<SVGSVGElement> & {
+    ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>
+ },
+) => <IconBase {...props} ref={ref} data={data as IconData} />
+
+Icon.displayName = 'SearchLinesSparkle'
+
+export default Icon
diff --git a/web/app/components/base/icons/src/vender/knowledge/index.ts b/web/app/components/base/icons/src/vender/knowledge/index.ts
index 7239511af3..44055c4975 100644
--- a/web/app/components/base/icons/src/vender/knowledge/index.ts
+++ b/web/app/components/base/icons/src/vender/knowledge/index.ts
@@ -11,5 +11,6 @@ export { default as HighQuality } from './HighQuality'
export { default as HybridSearch } from './HybridSearch'
export { default as ParentChildChunk } from './ParentChildChunk'
export { default as QuestionAndAnswer } from './QuestionAndAnswer'
+export { default as SearchLinesSparkle } from './SearchLinesSparkle'
export { default as SearchMenu } from './SearchMenu'
export { default as VectorSearch } from './VectorSearch'
diff --git a/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx b/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx
index 5140c902f5..84d742d734 100644
--- a/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx
+++ b/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx
@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
-import type { PreProcessingRule } from '@/models/datasets'
+import type { PreProcessingRule, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
import {
RiAlertFill,
RiSearchEyeLine,
@@ -12,6 +12,7 @@ import Button from '@/app/components/base/button'
import Checkbox from '@/app/components/base/checkbox'
import Divider from '@/app/components/base/divider'
import Tooltip from '@/app/components/base/tooltip'
+import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
import { IS_CE_EDITION } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import SettingCog from '../../assets/setting-gear-mod.svg'
@@ -52,6 +53,9 @@ type GeneralChunkingOptionsProps = {
onReset: () => void
// Locale
locale: string
+ showSummaryIndexSetting?: boolean
+ summaryIndexSetting?: SummaryIndexSettingType
+ onSummaryIndexSettingChange?: (payload: SummaryIndexSettingType) => void
}
export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
@@ -74,6 +78,9 @@ export const GeneralChunkingOptions: FC = ({
onPreview,
onReset,
locale,
+ showSummaryIndexSetting,
+ summaryIndexSetting,
+ onSummaryIndexSettingChange,
}) => {
const { t } = useTranslation()
@@ -146,6 +153,17 @@ export const GeneralChunkingOptions: FC = ({
))}
+        {
+          showSummaryIndexSetting && (
+            <SummaryIndexSetting
+              summaryIndexSetting={summaryIndexSetting}
+              onSummaryIndexSettingChange={onSummaryIndexSettingChange}
+            />
+          )
+        }
{IS_CE_EDITION && (
<>
diff --git a/web/app/components/datasets/create/step-two/components/parent-child-options.tsx b/web/app/components/datasets/create/step-two/components/parent-child-options.tsx
index e46aa5817b..22b88037e1 100644
--- a/web/app/components/datasets/create/step-two/components/parent-child-options.tsx
+++ b/web/app/components/datasets/create/step-two/components/parent-child-options.tsx
@@ -2,7 +2,7 @@
import type { FC } from 'react'
import type { ParentChildConfig } from '../hooks'
-import type { ParentMode, PreProcessingRule } from '@/models/datasets'
+import type { ParentMode, PreProcessingRule, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
import { RiSearchEyeLine } from '@remixicon/react'
import Image from 'next/image'
import { useTranslation } from 'react-i18next'
@@ -11,6 +11,7 @@ import Checkbox from '@/app/components/base/checkbox'
import Divider from '@/app/components/base/divider'
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
import RadioCard from '@/app/components/base/radio-card'
+import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
import { ChunkingMode } from '@/models/datasets'
import FileList from '../../assets/file-list-3-fill.svg'
import Note from '../../assets/note-mod.svg'
@@ -31,6 +32,8 @@ type ParentChildOptionsProps = {
// State
parentChildConfig: ParentChildConfig
rules: PreProcessingRule[]
+ summaryIndexSetting?: SummaryIndexSettingType
+ onSummaryIndexSettingChange?: (payload: SummaryIndexSettingType) => void
currentDocForm: ChunkingMode
// Flags
isActive: boolean
@@ -46,11 +49,13 @@ type ParentChildOptionsProps = {
onRuleToggle: (id: string) => void
onPreview: () => void
onReset: () => void
+ showSummaryIndexSetting?: boolean
}
export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
parentChildConfig,
rules,
+ summaryIndexSetting,
currentDocForm: _currentDocForm,
isActive,
isInUpload,
@@ -62,8 +67,10 @@ export const ParentChildOptions: FC = ({
onChildDelimiterChange,
onChildMaxLengthChange,
onRuleToggle,
+ onSummaryIndexSettingChange,
onPreview,
onReset,
+ showSummaryIndexSetting,
}) => {
const { t } = useTranslation()
@@ -183,6 +190,17 @@ export const ParentChildOptions: FC = ({
))}
+      {
+        showSummaryIndexSetting && (
+          <SummaryIndexSetting
+            summaryIndexSetting={summaryIndexSetting}
+            onSummaryIndexSettingChange={onSummaryIndexSettingChange}
+          />
+        )
+      }
diff --git a/web/app/components/datasets/create/step-two/components/preview-panel.tsx b/web/app/components/datasets/create/step-two/components/preview-panel.tsx
index 4f25cee5bd..5cb33f2d6d 100644
--- a/web/app/components/datasets/create/step-two/components/preview-panel.tsx
+++ b/web/app/components/datasets/create/step-two/components/preview-panel.tsx
@@ -14,6 +14,7 @@ import { ChunkingMode } from '@/models/datasets'
import { cn } from '@/utils/classnames'
import { ChunkContainer, QAPreview } from '../../../chunk'
import PreviewDocumentPicker from '../../../common/document-picker/preview-document-picker'
+import SummaryLabel from '../../../documents/detail/completed/common/summary-label'
import { PreviewSlice } from '../../../formatted-text/flavours/preview-slice'
import { FormattedText } from '../../../formatted-text/formatted'
import PreviewContainer from '../../../preview/container'
@@ -99,6 +100,7 @@ export const PreviewPanel: FC = ({
characterCount={item.content.length}
>
{item.content}
+                {item.summary && <SummaryLabel summary={item.summary} />}
))
)}
@@ -131,6 +133,7 @@ export const PreviewPanel: FC = ({
)
})}
+              {item.summary && <SummaryLabel summary={item.summary} />}
)
})
diff --git a/web/app/components/datasets/create/step-two/hooks/use-document-creation.ts b/web/app/components/datasets/create/step-two/hooks/use-document-creation.ts
index fd132b38ef..eaa51e393a 100644
--- a/web/app/components/datasets/create/step-two/hooks/use-document-creation.ts
+++ b/web/app/components/datasets/create/step-two/hooks/use-document-creation.ts
@@ -9,6 +9,7 @@ import type {
CustomFile,
FullDocumentDetail,
ProcessRule,
+ SummaryIndexSetting as SummaryIndexSettingType,
} from '@/models/datasets'
import type { RetrievalConfig, RETRIEVE_METHOD } from '@/types/app'
import { useCallback } from 'react'
@@ -141,6 +142,7 @@ export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
retrievalConfig: RetrievalConfig,
embeddingModel: DefaultModel,
indexingTechnique: string,
+ summaryIndexSetting?: SummaryIndexSettingType,
): CreateDocumentReq | null => {
if (isSetting) {
return {
@@ -148,6 +150,7 @@ export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
doc_form: currentDocForm,
doc_language: docLanguage,
process_rule: processRule,
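+        // Optional summary-index settings ride along on the create-document request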
+ summary_index_setting: summaryIndexSetting,
retrieval_model: retrievalConfig,
embedding_model: embeddingModel.model,
embedding_model_provider: embeddingModel.provider,
@@ -164,6 +167,7 @@ export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
},
indexing_technique: indexingTechnique,
process_rule: processRule,
+ summary_index_setting: summaryIndexSetting,
doc_form: currentDocForm,
doc_language: docLanguage,
retrieval_model: retrievalConfig,
diff --git a/web/app/components/datasets/create/step-two/hooks/use-segmentation-state.ts b/web/app/components/datasets/create/step-two/hooks/use-segmentation-state.ts
index 69cc089b4f..503704276e 100644
--- a/web/app/components/datasets/create/step-two/hooks/use-segmentation-state.ts
+++ b/web/app/components/datasets/create/step-two/hooks/use-segmentation-state.ts
@@ -1,5 +1,5 @@
-import type { ParentMode, PreProcessingRule, ProcessRule, Rules } from '@/models/datasets'
-import { useCallback, useState } from 'react'
+import type { ParentMode, PreProcessingRule, ProcessRule, Rules, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
+import { useCallback, useRef, useState } from 'react'
import { ChunkingMode, ProcessMode } from '@/models/datasets'
import escape from './escape'
import unescape from './unescape'
@@ -39,10 +39,11 @@ export const defaultParentChildConfig: ParentChildConfig = {
export type UseSegmentationStateOptions = {
initialSegmentationType?: ProcessMode
+ initialSummaryIndexSetting?: SummaryIndexSettingType
}
export const useSegmentationState = (options: UseSegmentationStateOptions = {}) => {
- const { initialSegmentationType } = options
+ const { initialSegmentationType, initialSummaryIndexSetting } = options
// Segmentation type (general or parent-child)
const [segmentationType, setSegmentationType] = useState(
@@ -58,6 +59,15 @@ export const useSegmentationState = (options: UseSegmentationStateOptions = {})
// Pre-processing rules
const [rules, setRules] = useState<PreProcessingRule[]>([])
const [defaultConfig, setDefaultConfig] = useState<Rules>()
+ const [summaryIndexSetting, setSummaryIndexSetting] = useState(initialSummaryIndexSetting)
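+  // Mirror the latest value in a ref so getProcessRule can read it without adding a dependency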
+ const summaryIndexSettingRef = useRef(initialSummaryIndexSetting)
+ const handleSummaryIndexSettingChange = useCallback((payload: SummaryIndexSettingType) => {
+ setSummaryIndexSetting((prev) => {
+ const newSetting = { ...prev, ...payload }
+ summaryIndexSettingRef.current = newSetting
+ return newSetting
+ })
+ }, [])
// Parent-child config
const [parentChildConfig, setParentChildConfig] = useState(defaultParentChildConfig)
@@ -134,6 +144,7 @@ export const useSegmentationState = (options: UseSegmentationStateOptions = {})
},
},
mode: 'hierarchical',
+ summary_index_setting: summaryIndexSettingRef.current,
} as ProcessRule
}
@@ -147,6 +158,7 @@ export const useSegmentationState = (options: UseSegmentationStateOptions = {})
},
},
mode: segmentationType,
+ summary_index_setting: summaryIndexSettingRef.current,
} as ProcessRule
}, [rules, parentChildConfig, segmentIdentifier, maxChunkLength, overlap, segmentationType])
@@ -204,6 +216,8 @@ export const useSegmentationState = (options: UseSegmentationStateOptions = {})
defaultConfig,
setDefaultConfig,
toggleRule,
+ summaryIndexSetting,
+ handleSummaryIndexSettingChange,
// Parent-child config
parentChildConfig,
diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx
index b4d2c5f6e9..a77d829488 100644
--- a/web/app/components/datasets/create/step-two/index.tsx
+++ b/web/app/components/datasets/create/step-two/index.tsx
@@ -65,7 +65,9 @@ const StepTwo: FC = ({
// Custom hooks
const segmentation = useSegmentationState({
initialSegmentationType: currentDataset?.doc_form === ChunkingMode.parentChild ? ProcessMode.parentChild : ProcessMode.general,
+ initialSummaryIndexSetting: currentDataset?.summary_index_setting,
})
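+  // Summary indexing is only configurable while creating a new dataset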
+ const showSummaryIndexSetting = !currentDataset
const indexing = useIndexingConfig({
initialIndexType: propsIndexingType,
initialEmbeddingModel: currentDataset?.embedding_model ? { provider: currentDataset.embedding_model_provider, model: currentDataset.embedding_model } : undefined,
@@ -156,7 +158,7 @@ const StepTwo: FC = ({
})
if (!isValid)
return
- const params = creation.buildCreationParams(currentDocForm, docLanguage, segmentation.getProcessRule(currentDocForm), indexing.retrievalConfig, indexing.embeddingModel, indexing.getIndexingTechnique())
+ const params = creation.buildCreationParams(currentDocForm, docLanguage, segmentation.getProcessRule(currentDocForm), indexing.retrievalConfig, indexing.embeddingModel, indexing.getIndexingTechnique(), segmentation.summaryIndexSetting)
if (!params)
return
await creation.executeCreation(params, indexing.indexType, indexing.retrievalConfig)
@@ -217,6 +219,9 @@ const StepTwo: FC = ({
onPreview={updatePreview}
onReset={segmentation.resetToDefaults}
locale={locale}
+ showSummaryIndexSetting={showSummaryIndexSetting}
+ summaryIndexSetting={segmentation.summaryIndexSetting}
+ onSummaryIndexSettingChange={segmentation.handleSummaryIndexSettingChange}
/>
)}
{showParentChildOption && (
@@ -236,6 +241,9 @@ const StepTwo: FC = ({
onRuleToggle={segmentation.toggleRule}
onPreview={updatePreview}
onReset={segmentation.resetToDefaults}
+ showSummaryIndexSetting={showSummaryIndexSetting}
+ summaryIndexSetting={segmentation.summaryIndexSetting}
+ onSummaryIndexSettingChange={segmentation.handleSummaryIndexSettingChange}
/>
)}
diff --git a/web/app/components/datasets/create/website/watercrawl/index.spec.tsx b/web/app/components/datasets/create/website/watercrawl/index.spec.tsx
index 646c59eb75..c3caab895a 100644
--- a/web/app/components/datasets/create/website/watercrawl/index.spec.tsx
+++ b/web/app/components/datasets/create/website/watercrawl/index.spec.tsx
@@ -73,6 +73,12 @@ const createDefaultProps = (overrides: Partial<Parameters<typeof WaterCrawl>[0]>
describe('WaterCrawl', () => {
beforeEach(() => {
vi.clearAllMocks()
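+    // shouldAdvanceTime lets awaited async work keep progressing while timers are mocked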
+ vi.useFakeTimers({ shouldAdvanceTime: true })
+ })
+
+ afterEach(() => {
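+    // Flush outstanding timers before restoring real ones so no work leaks across tests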
+ vi.runOnlyPendingTimers()
+ vi.useRealTimers()
})
// Tests for initial component rendering
diff --git a/web/app/components/datasets/documents/components/list.tsx b/web/app/components/datasets/documents/components/list.tsx
index 01d4afb646..f63d6d987e 100644
--- a/web/app/components/datasets/documents/components/list.tsx
+++ b/web/app/components/datasets/documents/components/list.tsx
@@ -30,12 +30,13 @@ import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '
import useTimestamp from '@/hooks/use-timestamp'
import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets'
import { DatasourceType } from '@/models/pipeline'
-import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useDocumentDisable, useDocumentDownloadZip, useDocumentEnable } from '@/service/knowledge/use-document'
+import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useDocumentDisable, useDocumentDownloadZip, useDocumentEnable, useDocumentSummary } from '@/service/knowledge/use-document'
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'
import { downloadBlob } from '@/utils/download'
import { formatNumber } from '@/utils/format'
import BatchAction from '../detail/completed/common/batch-action'
+import SummaryStatus from '../detail/completed/common/summary-status'
import StatusItem from '../status-item'
import s from '../style.module.css'
import Operations from './operations'
@@ -219,6 +220,7 @@ const DocumentList: FC = ({
onSelectedIdChange(uniq([...selectedIds, ...localDocs.map(doc => doc.id)]))
}, [isAllSelected, localDocs, onSelectedIdChange, selectedIds])
const { mutateAsync: archiveDocument } = useDocumentArchive()
+ const { mutateAsync: generateSummary } = useDocumentSummary()
const { mutateAsync: enableDocument } = useDocumentEnable()
const { mutateAsync: disableDocument } = useDocumentDisable()
const { mutateAsync: deleteDocument } = useDocumentDelete()
@@ -232,6 +234,9 @@ const DocumentList: FC = ({
case DocumentActionType.archive:
opApi = archiveDocument
break
+ case DocumentActionType.summary:
+ opApi = generateSummary
+ break
case DocumentActionType.enable:
opApi = enableDocument
break
@@ -444,6 +449,13 @@ const DocumentList: FC = ({
>
{doc.name}
+              {
+                doc.summary_index_status && (
+                  <SummaryStatus
+                    status={doc.summary_index_status}
+                  />
+                )
+              }
= ({
className="absolute bottom-16 left-0 z-20"
selectedIds={selectedIds}
onArchive={handleAction(DocumentActionType.archive)}
+ onBatchSummary={handleAction(DocumentActionType.summary)}
onBatchEnable={handleAction(DocumentActionType.enable)}
onBatchDisable={handleAction(DocumentActionType.disable)}
onBatchDownload={downloadableSelectedIds.length > 0 ? handleBatchDownload : undefined}
diff --git a/web/app/components/datasets/documents/components/operations.spec.tsx b/web/app/components/datasets/documents/components/operations.spec.tsx
index 25d4accc25..f341931a4b 100644
--- a/web/app/components/datasets/documents/components/operations.spec.tsx
+++ b/web/app/components/datasets/documents/components/operations.spec.tsx
@@ -15,6 +15,7 @@ vi.mock('@/service/knowledge/use-document', () => ({
useSyncWebsite: () => ({ mutateAsync: vi.fn().mockResolvedValue({}) }),
useDocumentPause: () => ({ mutateAsync: vi.fn().mockResolvedValue({}) }),
useDocumentResume: () => ({ mutateAsync: vi.fn().mockResolvedValue({}) }),
+ useDocumentSummary: () => ({ mutateAsync: vi.fn().mockResolvedValue({}) }),
}))
// Mock utils
diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx
index ee638c5e12..d3dcc23121 100644
--- a/web/app/components/datasets/documents/components/operations.tsx
+++ b/web/app/components/datasets/documents/components/operations.tsx
@@ -21,6 +21,7 @@ import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
+import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
import CustomPopover from '@/app/components/base/popover'
import Switch from '@/app/components/base/switch'
import { ToastContext } from '@/app/components/base/toast'
@@ -34,6 +35,7 @@ import {
useDocumentEnable,
useDocumentPause,
useDocumentResume,
+ useDocumentSummary,
useDocumentUnArchive,
useSyncDocument,
useSyncWebsite,
@@ -87,6 +89,7 @@ const Operations = ({
const { mutateAsync: downloadDocument, isPending: isDownloading } = useDocumentDownload()
const { mutateAsync: syncDocument } = useSyncDocument()
const { mutateAsync: syncWebsite } = useSyncWebsite()
+ const { mutateAsync: generateSummary } = useDocumentSummary()
const { mutateAsync: pauseDocument } = useDocumentPause()
const { mutateAsync: resumeDocument } = useDocumentResume()
const isListScene = scene === 'list'
@@ -112,6 +115,9 @@ const Operations = ({
else
opApi = syncWebsite
break
+ case 'summary':
+ opApi = generateSummary
+ break
case 'pause':
opApi = pauseDocument
break
@@ -257,6 +263,10 @@ const Operations = ({
{t('list.action.sync', { ns: 'datasetDocuments' })}
)}
+              <div className={s.actionItem} onClick={() => onOperate('summary')}>
+                <SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
+                <span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
+              </div>
>
)}
diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/chunk-preview.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/chunk-preview.tsx
index abfdea319b..d57f06ed00 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/preview/chunk-preview.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/preview/chunk-preview.tsx
@@ -8,6 +8,7 @@ import { useTranslation } from 'react-i18next'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import { SkeletonContainer, SkeletonPoint, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton'
+import SummaryLabel from '@/app/components/datasets/documents/detail/completed/common/summary-label'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { ChunkingMode } from '@/models/datasets'
import { DatasourceType } from '@/models/pipeline'
@@ -181,6 +182,7 @@ const ChunkPreview = ({
characterCount={item.content.length}
>
{item.content}
+                {item.summary && <SummaryLabel summary={item.summary} />}
))
)}
@@ -207,6 +209,7 @@ const ChunkPreview = ({
/>
)
})}
+              {item.summary && <SummaryLabel summary={item.summary} />}
)
diff --git a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx
index 2de72d9ff6..486ba2ffdf 100644
--- a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx
+++ b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx
@@ -6,6 +6,7 @@ import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
+import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
import { cn } from '@/utils/classnames'
const i18nPrefix = 'batchAction'
@@ -16,6 +17,7 @@ type IBatchActionProps = {
onBatchDisable: () => void
onBatchDownload?: () => void
onBatchDelete: () => Promise
+ onBatchSummary?: () => void
onArchive?: () => void
onEditMetadata?: () => void
onBatchReIndex?: () => void
@@ -27,6 +29,7 @@ const BatchAction: FC = ({
selectedIds,
onBatchEnable,
onBatchDisable,
+ onBatchSummary,
onBatchDownload,
onArchive,
onBatchDelete,
@@ -84,7 +87,16 @@ const BatchAction: FC = ({
{t('metadata.metadata', { ns: 'dataset' })}
)}
-
+        {onBatchSummary && (
+          <Button
+            variant="ghost"
+            className="gap-x-0.5 px-3"
+            onClick={onBatchSummary}
+          >
+            <SearchLinesSparkle className="h-4 w-4" />
+            <span className="px-0.5">{t(`${i18nPrefix}.summary`, { ns: 'datasetDocuments' })}</span>
+          </Button>
+        )}
{onArchive && (