mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
Merge remote-tracking branch 'origin/main' into feat/support-agent-sandbox
# Conflicts: # api/core/app/apps/workflow/app_runner.py
This commit is contained in:
commit
a38b8987b4
@ -104,9 +104,7 @@ forbidden_modules =
|
||||
ignore_imports =
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.layers.observability -> configs
|
||||
core.workflow.graph_engine.layers.observability -> extensions.otel.runtime
|
||||
core.workflow.graph_engine.layers.persistence -> core.ops.ops_trace_manager
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
||||
core.workflow.graph_engine.worker_management.worker_pool -> configs
|
||||
core.workflow.nodes.agent.agent_node -> core.model_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
||||
@ -147,7 +145,6 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> models.workflow
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.entities
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
core.workflow.graph_engine.layers.persistence -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
@ -217,7 +214,6 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
core.workflow.nodes.llm.node -> core.model_manager
|
||||
core.workflow.graph_engine.layers.persistence -> core.ops.entities.trace_entity
|
||||
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
|
||||
@ -21,6 +21,7 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
@ -29,7 +30,6 @@ from core.variables.variables import Variable
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
@ -9,12 +9,12 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
RagPipelineGenerateEntity,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
||||
@ -8,10 +8,10 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.sandbox import Sandbox
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
@ -157,7 +157,7 @@ class WorkflowBasedAppRunner:
|
||||
# Create initial runtime state with variable pool containing environment variables
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
),
|
||||
@ -272,7 +272,9 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
10
api/core/app/workflow/layers/__init__.py
Normal file
10
api/core/app/workflow/layers/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
|
||||
|
||||
from .observability import ObservabilityLayer
|
||||
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
|
||||
__all__ = [
|
||||
"ObservabilityLayer",
|
||||
"PersistenceWorkflowInfo",
|
||||
"WorkflowPersistenceLayer",
|
||||
]
|
||||
@ -45,7 +45,6 @@ from core.workflow.graph_events import (
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
@ -319,6 +318,9 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
# workflow inputs stay reusable without binding future runs to this conversation.
|
||||
continue
|
||||
inputs[f"sys.{field_name}"] = value
|
||||
# Local import to avoid circular dependency during app bootstrapping.
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
||||
handled = WorkflowEntry.handle_special_values(inputs)
|
||||
return handled or {}
|
||||
|
||||
@ -288,6 +288,7 @@ class Graph:
|
||||
graph_config: Mapping[str, object],
|
||||
node_factory: NodeFactory,
|
||||
root_node_id: str | None = None,
|
||||
skip_validation: bool = False,
|
||||
) -> Graph:
|
||||
"""
|
||||
Initialize graph
|
||||
@ -346,8 +347,9 @@ class Graph:
|
||||
root_node=root_node,
|
||||
)
|
||||
|
||||
# Validate the graph structure using built-in validators
|
||||
get_graph_validator().validate(graph)
|
||||
if not skip_validation:
|
||||
# Validate the graph structure using built-in validators
|
||||
get_graph_validator().validate(graph)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
@ -8,11 +8,9 @@ with middleware-like components that can observe events and interact with execut
|
||||
from .base import GraphEngineLayer
|
||||
from .debug_logging import DebugLoggingLayer
|
||||
from .execution_limits import ExecutionLimitsLayer
|
||||
from .observability import ObservabilityLayer
|
||||
|
||||
__all__ = [
|
||||
"DebugLoggingLayer",
|
||||
"ExecutionLimitsLayer",
|
||||
"GraphEngineLayer",
|
||||
"ObservabilityLayer",
|
||||
]
|
||||
|
||||
@ -44,7 +44,7 @@ class VariablePool(BaseModel):
|
||||
)
|
||||
system_variables: SystemVariable = Field(
|
||||
description="System variables",
|
||||
default_factory=SystemVariable.empty,
|
||||
default_factory=SystemVariable.default,
|
||||
)
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
@ -309,4 +309,4 @@ class VariablePool(BaseModel):
|
||||
@classmethod
|
||||
def empty(cls) -> VariablePool:
|
||||
"""Create an empty variable pool."""
|
||||
return cls(system_variables=SystemVariable.empty())
|
||||
return cls(system_variables=SystemVariable.default())
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping, Sequence
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
@ -72,8 +73,8 @@ class SystemVariable(BaseModel):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> SystemVariable:
|
||||
return cls()
|
||||
def default(cls) -> SystemVariable:
|
||||
return cls(workflow_execution_id=str(uuid4()))
|
||||
|
||||
def to_dict(self) -> dict[SystemVariableKey, Any]:
|
||||
# NOTE: This method is provided for compatibility with legacy code.
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.file.models import File
|
||||
from core.sandbox import Sandbox
|
||||
@ -16,7 +17,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
|
||||
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
@ -281,7 +282,7 @@ class WorkflowEntry:
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
|
||||
@ -36,7 +36,7 @@ def init_app(app: DifyApp) -> None:
|
||||
router.include_router(console_router, prefix="/console/api")
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/console/api/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(AUTHENTICATED_HEADERS),
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
|
||||
@ -436,7 +436,7 @@ class RagPipelineService:
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
|
||||
@ -752,7 +752,7 @@ class WorkflowService:
|
||||
|
||||
else:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
@ -1160,7 +1160,7 @@ def _setup_variable_pool(
|
||||
system_variable.conversation_id = conversation_id
|
||||
system_variable.dialogue_count = 1
|
||||
else:
|
||||
system_variable = SystemVariable.empty()
|
||||
system_variable = SystemVariable.default()
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
|
||||
@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
def _make_graph_state():
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
return MagicMock(), variable_pool, GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("single_iteration_run", "single_loop_run"),
|
||||
[
|
||||
(WorkflowAppGenerateEntity.SingleIterationRunEntity(node_id="iter", inputs={}), None),
|
||||
(None, WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id="loop", inputs={})),
|
||||
],
|
||||
)
|
||||
def test_run_uses_single_node_execution_branch(
|
||||
single_iteration_run: Any,
|
||||
single_loop_run: Any,
|
||||
) -> None:
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app"
|
||||
app_config.tenant_id = "tenant"
|
||||
app_config.workflow_id = "workflow"
|
||||
|
||||
app_generate_entity = MagicMock(spec=WorkflowAppGenerateEntity)
|
||||
app_generate_entity.app_config = app_config
|
||||
app_generate_entity.inputs = {}
|
||||
app_generate_entity.files = []
|
||||
app_generate_entity.user_id = "user"
|
||||
app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
app_generate_entity.workflow_execution_id = "execution-id"
|
||||
app_generate_entity.task_id = "task-id"
|
||||
app_generate_entity.call_depth = 0
|
||||
app_generate_entity.trace_manager = None
|
||||
app_generate_entity.single_iteration_run = single_iteration_run
|
||||
app_generate_entity.single_loop_run = single_loop_run
|
||||
|
||||
workflow = MagicMock(spec=Workflow)
|
||||
workflow.tenant_id = "tenant"
|
||||
workflow.app_id = "app"
|
||||
workflow.id = "workflow"
|
||||
workflow.type = "workflow"
|
||||
workflow.version = "v1"
|
||||
workflow.graph_dict = {"nodes": [], "edges": []}
|
||||
workflow.environment_variables = []
|
||||
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(spec=AppQueueManager),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="system-user",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
graph, variable_pool, graph_runtime_state = _make_graph_state()
|
||||
mock_workflow_entry = MagicMock()
|
||||
mock_workflow_entry.graph_engine = MagicMock()
|
||||
mock_workflow_entry.graph_engine.layer = MagicMock()
|
||||
mock_workflow_entry.run.return_value = iter([])
|
||||
|
||||
with (
|
||||
patch("core.app.apps.workflow.app_runner.RedisChannel"),
|
||||
patch("core.app.apps.workflow.app_runner.redis_client"),
|
||||
patch("core.app.apps.workflow.app_runner.WorkflowEntry", return_value=mock_workflow_entry) as entry_class,
|
||||
patch.object(
|
||||
runner,
|
||||
"_prepare_single_node_execution",
|
||||
return_value=(
|
||||
graph,
|
||||
variable_pool,
|
||||
graph_runtime_state,
|
||||
),
|
||||
) as prepare_single,
|
||||
patch.object(runner, "_init_graph") as init_graph,
|
||||
):
|
||||
runner.run()
|
||||
|
||||
prepare_single.assert_called_once_with(
|
||||
workflow=workflow,
|
||||
single_iteration_run=single_iteration_run,
|
||||
single_loop_run=single_loop_run,
|
||||
)
|
||||
init_graph.assert_not_called()
|
||||
|
||||
entry_kwargs = entry_class.call_args.kwargs
|
||||
assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER
|
||||
assert entry_kwargs["variable_pool"] is variable_pool
|
||||
assert entry_kwargs["graph_runtime_state"] is graph_runtime_state
|
||||
@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph.validation import GraphValidationError
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
def _build_iteration_graph(node_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"type": "iteration",
|
||||
"title": "Iteration",
|
||||
"iterator_selector": ["start", "items"],
|
||||
"output_selector": [node_id, "output"],
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
|
||||
def _build_loop_graph(node_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"type": "loop",
|
||||
"title": "Loop",
|
||||
"loop_count": 1,
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
"loop_variables": [],
|
||||
"outputs": {},
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
|
||||
def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
|
||||
|
||||
def test_iteration_root_requires_skip_validation():
|
||||
node_id = "iteration-node"
|
||||
graph_config = _build_iteration_graph(node_id)
|
||||
node_factory = _make_factory(graph_config)
|
||||
|
||||
with pytest.raises(GraphValidationError):
|
||||
Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=node_id,
|
||||
)
|
||||
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=node_id,
|
||||
skip_validation=True,
|
||||
)
|
||||
|
||||
assert graph.root_node.id == node_id
|
||||
assert graph.root_node.node_type == NodeType.ITERATION
|
||||
|
||||
|
||||
def test_loop_root_requires_skip_validation():
|
||||
node_id = "loop-node"
|
||||
graph_config = _build_loop_graph(node_id)
|
||||
node_factory = _make_factory(graph_config)
|
||||
|
||||
with pytest.raises(GraphValidationError):
|
||||
Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=node_id,
|
||||
)
|
||||
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=node_id,
|
||||
skip_validation=True,
|
||||
)
|
||||
|
||||
assert graph.root_node.id == node_id
|
||||
assert graph.root_node.node_type == NodeType.LOOP
|
||||
@ -90,14 +90,14 @@ def mock_tool_node():
|
||||
@pytest.fixture
|
||||
def mock_is_instrument_flag_enabled_false():
|
||||
"""Mock is_instrument_flag_enabled to return False."""
|
||||
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False):
|
||||
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_is_instrument_flag_enabled_true():
|
||||
"""Mock is_instrument_flag_enabled to return True."""
|
||||
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True):
|
||||
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True):
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@ -15,14 +15,14 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from opentelemetry.trace import StatusCode
|
||||
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.observability import ObservabilityLayer
|
||||
|
||||
|
||||
class TestObservabilityLayerInitialization:
|
||||
"""Test ObservabilityLayer initialization logic."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter):
|
||||
"""Test that layer initializes correctly when OTel is enabled."""
|
||||
@ -32,7 +32,7 @@ class TestObservabilityLayerInitialization:
|
||||
assert NodeType.TOOL in layer._parsers
|
||||
assert layer._default_parser is not None
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true")
|
||||
def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter):
|
||||
"""Test that layer enables when instrument flag is enabled."""
|
||||
@ -46,7 +46,7 @@ class TestObservabilityLayerInitialization:
|
||||
class TestObservabilityLayerNodeSpanLifecycle:
|
||||
"""Test node span creation and lifecycle management."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_node_span_created_and_ended(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
|
||||
@ -63,7 +63,7 @@ class TestObservabilityLayerNodeSpanLifecycle:
|
||||
assert spans[0].name == mock_llm_node.title
|
||||
assert spans[0].status.status_code == StatusCode.OK
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_node_error_recorded_in_span(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
|
||||
@ -82,7 +82,7 @@ class TestObservabilityLayerNodeSpanLifecycle:
|
||||
assert len(spans[0].events) > 0
|
||||
assert any("exception" in event.name.lower() for event in spans[0].events)
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_node_end_without_start_handled_gracefully(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
|
||||
@ -100,7 +100,7 @@ class TestObservabilityLayerNodeSpanLifecycle:
|
||||
class TestObservabilityLayerParserIntegration:
|
||||
"""Test parser integration for different node types."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_default_parser_used_for_regular_node(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node
|
||||
@ -119,7 +119,7 @@ class TestObservabilityLayerParserIntegration:
|
||||
assert attrs["node.execution_id"] == mock_start_node.execution_id
|
||||
assert attrs["node.type"] == mock_start_node.node_type.value
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_tool_parser_used_for_tool_node(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node
|
||||
@ -138,7 +138,7 @@ class TestObservabilityLayerParserIntegration:
|
||||
assert attrs["gen_ai.tool.name"] == mock_tool_node.title
|
||||
assert attrs["gen_ai.tool.type"] == mock_tool_node._node_data.provider_type.value
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_llm_parser_used_for_llm_node(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event
|
||||
@ -176,7 +176,7 @@ class TestObservabilityLayerParserIntegration:
|
||||
assert attrs["gen_ai.completion"] == "test completion"
|
||||
assert attrs["gen_ai.response.finish_reason"] == "stop"
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_retrieval_parser_used_for_retrieval_node(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event
|
||||
@ -204,7 +204,7 @@ class TestObservabilityLayerParserIntegration:
|
||||
assert attrs["retrieval.query"] == "test query"
|
||||
assert "retrieval.document" in attrs
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_result_event_extracts_inputs_and_outputs(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event
|
||||
@ -235,7 +235,7 @@ class TestObservabilityLayerParserIntegration:
|
||||
class TestObservabilityLayerGraphLifecycle:
|
||||
"""Test graph lifecycle management."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node):
|
||||
"""Test that on_graph_start clears node contexts."""
|
||||
@ -248,7 +248,7 @@ class TestObservabilityLayerGraphLifecycle:
|
||||
layer.on_graph_start()
|
||||
assert len(layer._node_contexts) == 0
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_on_graph_end_with_no_unfinished_spans(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
|
||||
@ -264,7 +264,7 @@ class TestObservabilityLayerGraphLifecycle:
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_on_graph_end_with_unfinished_spans_logs_warning(
|
||||
self, tracer_provider_with_memory_exporter, mock_llm_node, caplog
|
||||
@ -285,7 +285,7 @@ class TestObservabilityLayerGraphLifecycle:
|
||||
class TestObservabilityLayerDisabledMode:
|
||||
"""Test behavior when layer is disabled."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node):
|
||||
"""Test that disabled layer doesn't create spans on node start."""
|
||||
@ -299,7 +299,7 @@ class TestObservabilityLayerDisabledMode:
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 0
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node):
|
||||
"""Test that disabled layer doesn't process node end."""
|
||||
|
||||
@ -16,7 +16,7 @@ from core.workflow.system_variable import SystemVariable
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "number"], 42)
|
||||
@ -69,7 +69,7 @@ def test_executor_with_json_body_and_number_variable():
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
@ -124,7 +124,7 @@ def test_executor_with_json_body_and_object_variable():
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
@ -178,7 +178,7 @@ def test_executor_with_json_body_and_nested_object_variable():
|
||||
|
||||
|
||||
def test_extract_selectors_from_template_with_newline():
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default())
|
||||
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
@ -205,7 +205,7 @@ def test_extract_selectors_from_template_with_newline():
|
||||
def test_executor_with_form_data():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
|
||||
@ -290,7 +290,7 @@ def test_init_headers():
|
||||
return Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.empty()),
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
)
|
||||
|
||||
executor = create_executor("aa\n cc:")
|
||||
@ -324,7 +324,7 @@ def test_init_params():
|
||||
return Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.empty()),
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
)
|
||||
|
||||
# Test basic key-value pairs
|
||||
@ -355,7 +355,7 @@ def test_init_params():
|
||||
|
||||
def test_empty_api_key_raises_error_bearer():
|
||||
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -379,7 +379,7 @@ def test_empty_api_key_raises_error_bearer():
|
||||
|
||||
def test_empty_api_key_raises_error_basic():
|
||||
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -403,7 +403,7 @@ def test_empty_api_key_raises_error_basic():
|
||||
|
||||
def test_empty_api_key_raises_error_custom():
|
||||
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -427,7 +427,7 @@ def test_empty_api_key_raises_error_custom():
|
||||
|
||||
def test_whitespace_only_api_key_raises_error():
|
||||
"""Test that whitespace-only API key raises AuthorizationConfigError."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -451,7 +451,7 @@ def test_whitespace_only_api_key_raises_error():
|
||||
|
||||
def test_valid_api_key_works():
|
||||
"""Test that valid API key works correctly for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -487,7 +487,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable():
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "uuid"], test_uuid)
|
||||
@ -531,7 +531,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "uuid"], test_uuid)
|
||||
@ -569,7 +569,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
|
||||
def test_executor_with_json_body_preserves_numbers_and_strings():
|
||||
"""Test that numbers are preserved and string values are properly quoted."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["node", "count"], 42)
|
||||
|
||||
@ -86,7 +86,7 @@ def graph_init_params() -> GraphInitParams:
|
||||
@pytest.fixture
|
||||
def graph_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
return GraphRuntimeState(
|
||||
|
||||
@ -111,7 +111,7 @@ def test_webhook_node_file_conversion_to_file_variable():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -184,7 +184,7 @@ def test_webhook_node_file_conversion_with_missing_files():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -219,7 +219,7 @@ def test_webhook_node_file_conversion_with_none_file():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -256,7 +256,7 @@ def test_webhook_node_file_conversion_with_non_dict_file():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -300,7 +300,7 @@ def test_webhook_node_file_conversion_mixed_parameters():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -370,7 +370,7 @@ def test_webhook_node_different_file_types():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -430,7 +430,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
|
||||
@ -75,7 +75,7 @@ def test_webhook_node_basic_initialization():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
@ -118,7 +118,7 @@ def test_webhook_node_run_with_headers():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {
|
||||
@ -154,7 +154,7 @@ def test_webhook_node_run_with_query_params():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -190,7 +190,7 @@ def test_webhook_node_run_with_body_params():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -249,7 +249,7 @@ def test_webhook_node_run_with_file_params():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
@ -302,7 +302,7 @@ def test_webhook_node_run_mixed_parameters():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
@ -342,7 +342,7 @@ def test_webhook_node_run_empty_webhook_data():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={}, # No webhook_data
|
||||
)
|
||||
|
||||
@ -368,7 +368,7 @@ def test_webhook_node_run_case_insensitive_headers():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {
|
||||
@ -398,7 +398,7 @@ def test_webhook_node_variable_pool_user_inputs():
|
||||
|
||||
# Add some additional variables to the pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}},
|
||||
"other_var": "should_be_included",
|
||||
@ -429,7 +429,7 @@ def test_webhook_node_different_methods(method):
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
|
||||
@ -127,7 +127,7 @@ class TestWorkflowEntry:
|
||||
return node_config
|
||||
|
||||
workflow = StubWorkflow()
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={})
|
||||
expected_limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
@ -157,7 +157,7 @@ class TestWorkflowEntry:
|
||||
# Initialize variable pool with environment variables
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
environment_variables=[env_var],
|
||||
user_inputs={},
|
||||
)
|
||||
@ -198,7 +198,7 @@ class TestWorkflowEntry:
|
||||
# Initialize variable pool with conversation variables
|
||||
conv_var = StringVariable(name="last_message", value="Hello")
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
conversation_variables=[conv_var],
|
||||
user_inputs={},
|
||||
)
|
||||
@ -239,7 +239,7 @@ class TestWorkflowEntry:
|
||||
"""Test mapping regular node variables from user inputs to variable pool."""
|
||||
# Initialize empty variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
@ -281,7 +281,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_with_file_handling(self):
|
||||
"""Test mapping file inputs from user inputs to variable pool."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
@ -340,7 +340,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_missing_variable_error(self):
|
||||
"""Test that mapping raises error when required variable is missing."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
@ -366,7 +366,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_with_alternative_key_format(self):
|
||||
"""Test mapping with alternative key format (without node prefix)."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
@ -396,7 +396,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_with_complex_selectors(self):
|
||||
"""Test mapping with complex node variable keys."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
@ -432,7 +432,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_invalid_node_variable(self):
|
||||
"""Test that mapping handles invalid node variable format."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
|
||||
@ -0,0 +1,710 @@
|
||||
import type { Inputs, ModelConfig } from '@/models/debug'
|
||||
import type { PromptVariable } from '@/types/app'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import ChatUserInput from './chat-user-input'
|
||||
|
||||
const mockSetInputs = vi.fn()
|
||||
const mockUseContext = vi.fn()
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('use-context-selector', () => ({
|
||||
useContext: () => mockUseContext(),
|
||||
createContext: vi.fn(() => ({})),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/input', () => ({
|
||||
default: ({ value, onChange, placeholder, autoFocus, maxLength, readOnly, type }: {
|
||||
value: string
|
||||
onChange: (e: { target: { value: string } }) => void
|
||||
placeholder?: string
|
||||
autoFocus?: boolean
|
||||
maxLength?: number
|
||||
readOnly?: boolean
|
||||
type?: string
|
||||
}) => (
|
||||
<input
|
||||
data-testid={`input-${placeholder}`}
|
||||
data-autofocus={autoFocus ? 'true' : undefined}
|
||||
type={type || 'text'}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
placeholder={placeholder}
|
||||
maxLength={maxLength}
|
||||
readOnly={readOnly}
|
||||
/>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/select', () => ({
|
||||
default: ({ defaultValue, onSelect, items, disabled, className }: {
|
||||
defaultValue: string
|
||||
onSelect: (item: { value: string }) => void
|
||||
items: { name: string, value: string }[]
|
||||
allowSearch?: boolean
|
||||
disabled?: boolean
|
||||
className?: string
|
||||
}) => (
|
||||
<select
|
||||
data-testid="select-input"
|
||||
value={defaultValue}
|
||||
onChange={e => onSelect({ value: e.target.value })}
|
||||
disabled={disabled}
|
||||
className={className}
|
||||
>
|
||||
{items.map(item => (
|
||||
<option key={item.value} value={item.value}>{item.name}</option>
|
||||
))}
|
||||
</select>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/textarea', () => ({
|
||||
default: ({ value, onChange, placeholder, readOnly, className }: {
|
||||
value: string
|
||||
onChange: (e: { target: { value: string } }) => void
|
||||
placeholder?: string
|
||||
readOnly?: boolean
|
||||
className?: string
|
||||
}) => (
|
||||
<textarea
|
||||
data-testid={`textarea-${placeholder}`}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
placeholder={placeholder}
|
||||
readOnly={readOnly}
|
||||
className={className}
|
||||
/>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/before-run-form/bool-input', () => ({
|
||||
default: ({ name, value, required, onChange, readonly }: {
|
||||
name: string
|
||||
value: boolean
|
||||
required?: boolean
|
||||
onChange: (value: boolean) => void
|
||||
readonly?: boolean
|
||||
}) => (
|
||||
<div data-testid={`bool-input-${name}`}>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={value}
|
||||
onChange={e => onChange(e.target.checked)}
|
||||
disabled={readonly}
|
||||
data-required={required}
|
||||
/>
|
||||
<span>{name}</span>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
// Extended type to match runtime behavior (includes 'paragraph', 'checkbox', 'default')
|
||||
type ExtendedPromptVariable = {
|
||||
key: string
|
||||
name: string
|
||||
type: 'string' | 'number' | 'select' | 'paragraph' | 'checkbox'
|
||||
required: boolean
|
||||
options?: string[]
|
||||
max_length?: number
|
||||
default?: string | null
|
||||
}
|
||||
|
||||
const createPromptVariable = (overrides: Partial<ExtendedPromptVariable> = {}): ExtendedPromptVariable => ({
|
||||
key: 'test-key',
|
||||
name: 'Test Name',
|
||||
type: 'string',
|
||||
required: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createModelConfig = (promptVariables: ExtendedPromptVariable[] = []): ModelConfig => ({
|
||||
provider: 'openai',
|
||||
model_id: 'gpt-4',
|
||||
mode: 'chat',
|
||||
configs: {
|
||||
prompt_template: '',
|
||||
prompt_variables: promptVariables as PromptVariable[],
|
||||
},
|
||||
} as ModelConfig)
|
||||
|
||||
const createContextValue = (overrides: Partial<{
|
||||
modelConfig: ModelConfig
|
||||
setInputs: (inputs: Inputs) => void
|
||||
readonly: boolean
|
||||
}> = {}) => ({
|
||||
modelConfig: createModelConfig(),
|
||||
setInputs: mockSetInputs,
|
||||
readonly: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('ChatUserInput', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseContext.mockReturnValue(createContextValue())
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should return null when no prompt variables exist', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([]),
|
||||
}))
|
||||
|
||||
const { container } = render(<ChatUserInput inputs={{}} />)
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null when prompt variables have empty keys', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: '', name: 'Test' }),
|
||||
createPromptVariable({ key: ' ', name: 'Test2' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
const { container } = render(<ChatUserInput inputs={{}} />)
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null when prompt variables have empty names', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'key1', name: '' }),
|
||||
createPromptVariable({ key: 'key2', name: ' ' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
const { container } = render(<ChatUserInput inputs={{}} />)
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
|
||||
it('should render string input type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('input-Name')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render paragraph input type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'description', name: 'Description', type: 'paragraph' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('textarea-Description')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render select input type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: ['A', 'B', 'C'] }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('select-input')).toBeInTheDocument()
|
||||
expect(screen.getByText('A')).toBeInTheDocument()
|
||||
expect(screen.getByText('B')).toBeInTheDocument()
|
||||
expect(screen.getByText('C')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render number input type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'count', name: 'Count', type: 'number' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
const input = screen.getByTestId('input-Count')
|
||||
expect(input).toBeInTheDocument()
|
||||
expect(input).toHaveAttribute('type', 'number')
|
||||
})
|
||||
|
||||
it('should render checkbox input type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('bool-input-Enabled')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render multiple input types', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
|
||||
createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: ['X', 'Y'] }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('input-Name')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('textarea-Description')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('select-input')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show optional label for non-required fields', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string', required: false }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByText('panel.optional')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show optional label for required fields', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string', required: true }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.queryByText('panel.optional')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use key as label when name is not provided', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'myKey', name: '', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
// This should actually return null because name is empty
|
||||
const { container } = render(<ChatUserInput inputs={{}} />)
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Input Values', () => {
|
||||
it('should display existing input values for string type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ name: 'John' }} />)
|
||||
expect(screen.getByTestId('input-Name')).toHaveValue('John')
|
||||
})
|
||||
|
||||
it('should display existing input values for paragraph type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ desc: 'Long text here' }} />)
|
||||
expect(screen.getByTestId('textarea-Description')).toHaveValue('Long text here')
|
||||
})
|
||||
|
||||
it('should display existing input values for number type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'count', name: 'Count', type: 'number' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ count: 42 }} />)
|
||||
// Number type input still uses string value internally
|
||||
expect(screen.getByTestId('input-Count')).toHaveValue(42)
|
||||
})
|
||||
|
||||
it('should display checkbox as checked when value is truthy', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ enabled: true }} />)
|
||||
const checkbox = screen.getByTestId('bool-input-Enabled').querySelector('input')
|
||||
expect(checkbox).toBeChecked()
|
||||
})
|
||||
|
||||
it('should display checkbox as unchecked when value is falsy', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ enabled: false }} />)
|
||||
const checkbox = screen.getByTestId('bool-input-Enabled').querySelector('input')
|
||||
expect(checkbox).not.toBeChecked()
|
||||
})
|
||||
|
||||
it('should handle empty string values', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ name: '' }} />)
|
||||
expect(screen.getByTestId('input-Name')).toHaveValue('')
|
||||
})
|
||||
|
||||
it('should handle undefined values', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('input-Name')).toHaveValue('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should call setInputs when string input changes', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
fireEvent.change(screen.getByTestId('input-Name'), { target: { value: 'New Value' } })
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({ name: 'New Value' })
|
||||
})
|
||||
|
||||
it('should call setInputs when paragraph input changes', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
fireEvent.change(screen.getByTestId('textarea-Description'), { target: { value: 'New Description' } })
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({ desc: 'New Description' })
|
||||
})
|
||||
|
||||
it('should call setInputs when select input changes', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: ['A', 'B', 'C'] }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ choice: 'A' }} />)
|
||||
fireEvent.change(screen.getByTestId('select-input'), { target: { value: 'B' } })
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({ choice: 'B' })
|
||||
})
|
||||
|
||||
it('should call setInputs when number input changes', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'count', name: 'Count', type: 'number' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
fireEvent.change(screen.getByTestId('input-Count'), { target: { value: '100' } })
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({ count: '100' })
|
||||
})
|
||||
|
||||
it('should call setInputs when checkbox changes', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ enabled: false }} />)
|
||||
const checkbox = screen.getByTestId('bool-input-Enabled').querySelector('input')!
|
||||
fireEvent.click(checkbox)
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({ enabled: true })
|
||||
})
|
||||
|
||||
it('should not call setInputs for unknown keys', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
|
||||
// The component filters by promptVariableObj, so unknown keys won't trigger updates
|
||||
// This is tested indirectly - only valid keys should trigger setInputs
|
||||
fireEvent.change(screen.getByTestId('input-Name'), { target: { value: 'Valid' } })
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledTimes(1)
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({ name: 'Valid' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('Readonly Mode', () => {
|
||||
it('should set string input as readonly when readonly is true', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
readonly: true,
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('input-Name')).toHaveAttribute('readonly')
|
||||
})
|
||||
|
||||
it('should set paragraph input as readonly when readonly is true', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
|
||||
]),
|
||||
readonly: true,
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('textarea-Description')).toHaveAttribute('readonly')
|
||||
})
|
||||
|
||||
it('should disable select when readonly is true', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: ['A', 'B'] }),
|
||||
]),
|
||||
readonly: true,
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('select-input')).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should disable checkbox when readonly is true', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'enabled', name: 'Enabled', type: 'checkbox' }),
|
||||
]),
|
||||
readonly: true,
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
const checkbox = screen.getByTestId('bool-input-Enabled').querySelector('input')
|
||||
expect(checkbox).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Default Values', () => {
|
||||
it('should initialize inputs with default values when field is empty', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: 'Default Name' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({ name: 'Default Name' })
|
||||
})
|
||||
|
||||
it('should not override existing values with defaults', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: 'Default' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ name: 'Existing Value' }} />)
|
||||
|
||||
// setInputs should not be called since there's already a value
|
||||
expect(mockSetInputs).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle multiple default values', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: 'Default Name' }),
|
||||
createPromptVariable({ key: 'count', name: 'Count', type: 'number', default: '10' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({
|
||||
name: 'Default Name',
|
||||
count: '10',
|
||||
})
|
||||
})
|
||||
|
||||
it('should not set default when default is empty string', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: '' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
|
||||
expect(mockSetInputs).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not set default when default is undefined', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
|
||||
expect(mockSetInputs).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not set default when default is null', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string', default: null as unknown as string }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
|
||||
expect(mockSetInputs).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('AutoFocus', () => {
|
||||
it('should set autoFocus on first string input', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'first', name: 'First', type: 'string' }),
|
||||
createPromptVariable({ key: 'second', name: 'Second', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('input-First')).toHaveAttribute('data-autofocus', 'true')
|
||||
expect(screen.getByTestId('input-Second')).not.toHaveAttribute('data-autofocus')
|
||||
})
|
||||
|
||||
it('should set autoFocus on first number input when it is the first field', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'count', name: 'Count', type: 'number' }),
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('input-Count')).toHaveAttribute('data-autofocus', 'true')
|
||||
})
|
||||
})
|
||||
|
||||
describe('MaxLength', () => {
|
||||
it('should pass maxLength to string input', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string', max_length: 50 }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('input-Name')).toHaveAttribute('maxLength', '50')
|
||||
})
|
||||
|
||||
it('should pass maxLength to number input', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'count', name: 'Count', type: 'number', max_length: 10 }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
expect(screen.getByTestId('input-Count')).toHaveAttribute('maxLength', '10')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle select with empty options', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'choice', name: 'Choice', type: 'select', options: [] }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
const select = screen.getByTestId('select-input')
|
||||
expect(select).toBeInTheDocument()
|
||||
expect(select.children).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should handle select with undefined options', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'choice', name: 'Choice', type: 'select' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
const select = screen.getByTestId('select-input')
|
||||
expect(select).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should preserve other input values when updating one field', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'name', name: 'Name', type: 'string' }),
|
||||
createPromptVariable({ key: 'desc', name: 'Description', type: 'paragraph' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ name: 'Existing', desc: 'Also Existing' }} />)
|
||||
fireEvent.change(screen.getByTestId('input-Name'), { target: { value: 'Updated' } })
|
||||
|
||||
expect(mockSetInputs).toHaveBeenCalledWith({
|
||||
name: 'Updated',
|
||||
desc: 'Also Existing',
|
||||
})
|
||||
})
|
||||
|
||||
it('should convert non-string values to string for display', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'value', name: 'Value', type: 'string' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{ value: 123 as unknown as string }} />)
|
||||
expect(screen.getByTestId('input-Value')).toHaveValue('123')
|
||||
})
|
||||
|
||||
it('should not hide label for checkbox type', () => {
|
||||
mockUseContext.mockReturnValue(createContextValue({
|
||||
modelConfig: createModelConfig([
|
||||
createPromptVariable({ key: 'enabled', name: 'Is Enabled', type: 'checkbox' }),
|
||||
]),
|
||||
}))
|
||||
|
||||
render(<ChatUserInput inputs={{}} />)
|
||||
// For checkbox, the label is rendered inside BoolInput, not in the header
|
||||
expect(screen.queryByText('Is Enabled')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,641 @@
|
||||
import type { ModelAndParameter } from '../types'
|
||||
import type { ChatConfig, ChatItem as ChatItemType, OnSend } from '@/app/components/base/chat/types'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { TransferMethod } from '@/app/components/base/chat/types'
|
||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { APP_CHAT_WITH_MULTIPLE_MODEL, APP_CHAT_WITH_MULTIPLE_MODEL_RESTART } from '../types'
|
||||
import ChatItem from './chat-item'
|
||||
|
||||
const mockUseAppContext = vi.fn()
|
||||
const mockUseDebugConfigurationContext = vi.fn()
|
||||
const mockUseProviderContext = vi.fn()
|
||||
const mockUseFeatures = vi.fn()
|
||||
const mockUseConfigFromDebugContext = vi.fn()
|
||||
const mockUseFormattingChangedSubscription = vi.fn()
|
||||
const mockUseChat = vi.fn()
|
||||
const mockUseEventEmitterContextContext = vi.fn()
|
||||
const mockFetchConversationMessages = vi.fn()
|
||||
const mockFetchSuggestedQuestions = vi.fn()
|
||||
const mockStopChatMessageResponding = vi.fn()
|
||||
|
||||
let capturedChatProps: {
|
||||
config: ChatConfig
|
||||
chatList: ChatItemType[]
|
||||
isResponding: boolean
|
||||
onSend: OnSend
|
||||
suggestedQuestions: string[]
|
||||
allToolIcons: Record<string, string | undefined>
|
||||
} | null = null
|
||||
|
||||
let eventSubscriptionCallback: ((v: { type: string, payload?: Record<string, unknown> }) => void) | null = null
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => mockUseAppContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/debug-configuration', () => ({
|
||||
useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => mockUseProviderContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/features/hooks', () => ({
|
||||
useFeatures: (selector: (state: Record<string, unknown>) => unknown) => mockUseFeatures(selector),
|
||||
}))
|
||||
|
||||
vi.mock('../hooks', () => ({
|
||||
useConfigFromDebugContext: () => mockUseConfigFromDebugContext(),
|
||||
useFormattingChangedSubscription: (chatList: ChatItemType[]) => mockUseFormattingChangedSubscription(chatList),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/chat/chat/hooks', () => ({
|
||||
useChat: () => mockUseChat(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: () => mockUseEventEmitterContextContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/debug', () => ({
|
||||
fetchConversationMessages: (...args: unknown[]) => mockFetchConversationMessages(...args),
|
||||
fetchSuggestedQuestions: (...args: unknown[]) => mockFetchSuggestedQuestions(...args),
|
||||
stopChatMessageResponding: (...args: unknown[]) => mockStopChatMessageResponding(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/chat/utils', () => ({
|
||||
getLastAnswer: (chatList: ChatItemType[]) => chatList.find(item => item.isAnswer),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils', () => ({
|
||||
canFindTool: (collectionId: string, providerId: string) => collectionId === providerId,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
default: (props: typeof capturedChatProps) => {
|
||||
capturedChatProps = props
|
||||
return (
|
||||
<div data-testid="chat-component">
|
||||
<span data-testid="chat-list-length">{props?.chatList?.length || 0}</span>
|
||||
<span data-testid="is-responding">{props?.isResponding ? 'yes' : 'no'}</span>
|
||||
<button
|
||||
data-testid="send-button"
|
||||
onClick={() => props?.onSend?.('test message', [{ id: 'file-1', name: 'test.txt', size: 100, type: 'text/plain', progress: 100, transferMethod: TransferMethod.local_file, supportFileType: 'document' }])}
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/avatar', () => ({
|
||||
default: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
|
||||
}))
|
||||
|
||||
const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
|
||||
id: 'model-1',
|
||||
model: 'gpt-3.5-turbo',
|
||||
provider: 'openai',
|
||||
parameters: { temperature: 0.7 },
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createDefaultMocks = () => {
|
||||
mockUseAppContext.mockReturnValue({
|
||||
userProfile: { avatar_url: 'http://avatar.url', name: 'Test User' },
|
||||
})
|
||||
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
modelConfig: {
|
||||
configs: { prompt_variables: [] },
|
||||
agentConfig: { tools: [] },
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: { key: 'value' },
|
||||
collectionList: [],
|
||||
})
|
||||
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: [
|
||||
{
|
||||
provider: 'openai',
|
||||
models: [
|
||||
{
|
||||
model: 'gpt-3.5-turbo',
|
||||
features: [ModelFeatureEnum.vision],
|
||||
model_properties: { mode: 'chat' },
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
mockUseFeatures.mockImplementation((selector: (state: Record<string, unknown>) => unknown) => {
|
||||
const state = {
|
||||
features: {
|
||||
moreLikeThis: { enabled: false },
|
||||
opening: { enabled: true, opening_statement: 'Hello!', suggested_questions: ['Q1'] },
|
||||
moderation: { enabled: false },
|
||||
speech2text: { enabled: true },
|
||||
text2speech: { enabled: false },
|
||||
file: { enabled: true },
|
||||
suggested: { enabled: true },
|
||||
citation: { enabled: false },
|
||||
annotationReply: { enabled: false },
|
||||
},
|
||||
}
|
||||
return selector(state)
|
||||
})
|
||||
|
||||
mockUseConfigFromDebugContext.mockReturnValue({
|
||||
base_config: 'test',
|
||||
})
|
||||
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1', content: 'Hello', isAnswer: true }],
|
||||
isResponding: false,
|
||||
handleSend: vi.fn(),
|
||||
suggestedQuestions: ['Question 1', 'Question 2'],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
mockUseEventEmitterContextContext.mockReturnValue({
|
||||
eventEmitter: {
|
||||
useSubscription: (callback: (v: { type: string, payload?: Record<string, unknown> }) => void) => {
|
||||
eventSubscriptionCallback = callback
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
const renderComponent = (props: Partial<{ modelAndParameter: ModelAndParameter }> = {}) => {
|
||||
const defaultProps = {
|
||||
modelAndParameter: createModelAndParameter(),
|
||||
...props,
|
||||
}
|
||||
return render(<ChatItem {...defaultProps} />)
|
||||
}
|
||||
|
||||
describe('ChatItem', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
capturedChatProps = null
|
||||
eventSubscriptionCallback = null
|
||||
createDefaultMocks()
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render Chat component when chatList is not empty', () => {
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('chat-component')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('chat-list-length')).toHaveTextContent('1')
|
||||
})
|
||||
|
||||
it('should not render when chatList is empty', () => {
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [],
|
||||
isResponding: false,
|
||||
handleSend: vi.fn(),
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.queryByTestId('chat-component')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass correct config to Chat', () => {
|
||||
renderComponent()
|
||||
|
||||
expect(capturedChatProps?.config).toMatchObject({
|
||||
base_config: 'test',
|
||||
opening_statement: 'Hello!',
|
||||
suggested_questions: ['Q1'],
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass suggestedQuestions to Chat', () => {
|
||||
renderComponent()
|
||||
|
||||
expect(capturedChatProps?.suggestedQuestions).toEqual(['Question 1', 'Question 2'])
|
||||
})
|
||||
|
||||
it('should pass isResponding to Chat', () => {
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1' }],
|
||||
isResponding: true,
|
||||
handleSend: vi.fn(),
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('is-responding')).toHaveTextContent('yes')
|
||||
})
|
||||
})
|
||||
|
||||
describe('config composition', () => {
|
||||
it('should include opening statement when enabled', () => {
|
||||
renderComponent()
|
||||
|
||||
expect(capturedChatProps?.config.opening_statement).toBe('Hello!')
|
||||
})
|
||||
|
||||
it('should use empty opening statement when disabled', () => {
|
||||
mockUseFeatures.mockImplementation((selector: (state: Record<string, unknown>) => unknown) => {
|
||||
const state = {
|
||||
features: {
|
||||
moreLikeThis: { enabled: false },
|
||||
opening: { enabled: false, opening_statement: 'Should not appear' },
|
||||
moderation: { enabled: false },
|
||||
speech2text: { enabled: false },
|
||||
text2speech: { enabled: false },
|
||||
file: { enabled: false },
|
||||
suggested: { enabled: false },
|
||||
citation: { enabled: false },
|
||||
annotationReply: { enabled: false },
|
||||
},
|
||||
}
|
||||
return selector(state)
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedChatProps?.config.opening_statement).toBe('')
|
||||
expect(capturedChatProps?.config.suggested_questions).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('inputsForm transformation', () => {
|
||||
it('should filter out API type variables', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_variables: [
|
||||
{ key: 'var1', name: 'Var 1', type: 'string' },
|
||||
{ key: 'var2', name: 'Var 2', type: 'api' },
|
||||
{ key: 'var3', name: 'Var 3', type: 'number' },
|
||||
],
|
||||
},
|
||||
agentConfig: { tools: [] },
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
collectionList: [],
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
// The component transforms prompt_variables into inputsForm
|
||||
// We can verify this through the useChat call
|
||||
expect(mockUseChat).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('event subscription', () => {
|
||||
it('should handle APP_CHAT_WITH_MULTIPLE_MODEL event', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1' }],
|
||||
isResponding: false,
|
||||
handleSend,
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
// Trigger the event
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Hello', files: [{ id: 'file-1' }] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
'apps/app-123/chat-messages',
|
||||
expect.objectContaining({
|
||||
query: 'Hello',
|
||||
inputs: { key: 'value' },
|
||||
}),
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle APP_CHAT_WITH_MULTIPLE_MODEL_RESTART event', () => {
|
||||
const handleRestart = vi.fn()
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1' }],
|
||||
isResponding: false,
|
||||
handleSend: vi.fn(),
|
||||
suggestedQuestions: [],
|
||||
handleRestart,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL_RESTART,
|
||||
})
|
||||
|
||||
expect(handleRestart).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should ignore unrelated events', () => {
|
||||
const handleSend = vi.fn()
|
||||
const handleRestart = vi.fn()
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1' }],
|
||||
isResponding: false,
|
||||
handleSend,
|
||||
suggestedQuestions: [],
|
||||
handleRestart,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: 'SOME_OTHER_EVENT',
|
||||
payload: {},
|
||||
})
|
||||
|
||||
expect(handleSend).not.toHaveBeenCalled()
|
||||
expect(handleRestart).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('doSend function', () => {
|
||||
it('should include files when vision is supported and file upload enabled', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1' }],
|
||||
isResponding: false,
|
||||
handleSend,
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Hello', files: [{ id: 'file-1' }] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
files: [{ id: 'file-1' }],
|
||||
}),
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
|
||||
it('should not include files when vision is not supported', () => {
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: [
|
||||
{
|
||||
provider: 'openai',
|
||||
models: [
|
||||
{
|
||||
model: 'gpt-3.5-turbo',
|
||||
features: [], // No vision support
|
||||
model_properties: { mode: 'chat' },
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const handleSend = vi.fn()
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1' }],
|
||||
isResponding: false,
|
||||
handleSend,
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Hello', files: [{ id: 'file-1' }] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.not.objectContaining({
|
||||
files: expect.anything(),
|
||||
}),
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
|
||||
it('should include model configuration in request', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1' }],
|
||||
isResponding: false,
|
||||
handleSend,
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
const modelAndParameter = createModelAndParameter({
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
parameters: { temperature: 0.5 },
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Hello', files: [] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
model_config: expect.objectContaining({
|
||||
model: expect.objectContaining({
|
||||
provider: 'openai',
|
||||
name: 'gpt-3.5-turbo',
|
||||
completion_params: { temperature: 0.5 },
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
|
||||
it('should use parent_message_id from last answer', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [
|
||||
{ id: 'msg-1', content: 'Hi', isAnswer: false },
|
||||
{ id: 'msg-2', content: 'Hello', isAnswer: true },
|
||||
],
|
||||
isResponding: false,
|
||||
handleSend,
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Hello', files: [] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
parent_message_id: 'msg-2',
|
||||
}),
|
||||
expect.any(Object),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('allToolIcons', () => {
|
||||
it('should compute tool icons from collectionList', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
modelConfig: {
|
||||
configs: { prompt_variables: [] },
|
||||
agentConfig: {
|
||||
tools: [
|
||||
{ tool_name: 'tool1', provider_id: 'collection1' },
|
||||
{ tool_name: 'tool2', provider_id: 'collection2' },
|
||||
],
|
||||
},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
collectionList: [
|
||||
{ id: 'collection1', icon: 'icon1' },
|
||||
{ id: 'collection2', icon: 'icon2' },
|
||||
],
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedChatProps?.allToolIcons).toEqual({
|
||||
tool1: 'icon1',
|
||||
tool2: 'icon2',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle tools without matching collection', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
modelConfig: {
|
||||
configs: { prompt_variables: [] },
|
||||
agentConfig: {
|
||||
tools: [
|
||||
{ tool_name: 'tool1', provider_id: 'nonexistent' },
|
||||
],
|
||||
},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
collectionList: [],
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedChatProps?.allToolIcons).toEqual({
|
||||
tool1: undefined,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle empty tools array', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
modelConfig: {
|
||||
configs: { prompt_variables: [] },
|
||||
agentConfig: { tools: [] },
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
collectionList: [],
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedChatProps?.allToolIcons).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('useFormattingChangedSubscription', () => {
|
||||
it('should call useFormattingChangedSubscription with chatList', () => {
|
||||
const chatList = [{ id: 'msg-1', content: 'Hello' }]
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList,
|
||||
isResponding: false,
|
||||
handleSend: vi.fn(),
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(mockUseFormattingChangedSubscription).toHaveBeenCalledWith(chatList)
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle missing provider in textGenerationModelList', () => {
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: [],
|
||||
})
|
||||
|
||||
const handleSend = vi.fn()
|
||||
mockUseChat.mockReturnValue({
|
||||
chatList: [{ id: 'msg-1' }],
|
||||
isResponding: false,
|
||||
handleSend,
|
||||
suggestedQuestions: [],
|
||||
handleRestart: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Hello', files: [] },
|
||||
})
|
||||
|
||||
// Should still call handleSend without crashing
|
||||
expect(handleSend).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle null eventEmitter', () => {
|
||||
mockUseEventEmitterContextContext.mockReturnValue({
|
||||
eventEmitter: null,
|
||||
})
|
||||
|
||||
expect(() => renderComponent()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle undefined tools in agentConfig', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
modelConfig: {
|
||||
configs: { prompt_variables: [] },
|
||||
agentConfig: { tools: undefined },
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
collectionList: [],
|
||||
})
|
||||
|
||||
// This may throw since the code does agentConfig.tools?.forEach
|
||||
// But the optional chaining should handle it
|
||||
expect(() => renderComponent()).not.toThrow()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,224 @@
|
||||
import type { ModelAndParameter } from '../types'
|
||||
import type { DebugWithMultipleModelContextType } from './context'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import {
|
||||
DebugWithMultipleModelContextProvider,
|
||||
useDebugWithMultipleModelContext,
|
||||
} from './context'
|
||||
|
||||
const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
|
||||
id: 'model-1',
|
||||
model: 'gpt-3.5-turbo',
|
||||
provider: 'openai',
|
||||
parameters: {},
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const TestConsumer = () => {
|
||||
const context = useDebugWithMultipleModelContext()
|
||||
return (
|
||||
<div>
|
||||
<span data-testid="configs-count">{context.multipleModelConfigs.length}</span>
|
||||
<span data-testid="has-check-can-send">{context.checkCanSend ? 'yes' : 'no'}</span>
|
||||
<button
|
||||
data-testid="call-on-change"
|
||||
onClick={() => context.onMultipleModelConfigsChange(true, [])}
|
||||
>
|
||||
Change
|
||||
</button>
|
||||
<button
|
||||
data-testid="call-on-debug-change"
|
||||
onClick={() => context.onDebugWithMultipleModelChange(createModelAndParameter())}
|
||||
>
|
||||
Debug Change
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
describe('DebugWithMultipleModelContext', () => {
|
||||
describe('useDebugWithMultipleModelContext', () => {
|
||||
it('should return default values when used outside provider', () => {
|
||||
render(<TestConsumer />)
|
||||
|
||||
expect(screen.getByTestId('configs-count')).toHaveTextContent('0')
|
||||
expect(screen.getByTestId('has-check-can-send')).toHaveTextContent('no')
|
||||
})
|
||||
|
||||
it('should return default noop functions that do not throw', () => {
|
||||
render(<TestConsumer />)
|
||||
|
||||
// These should not throw when called
|
||||
expect(() => {
|
||||
screen.getByTestId('call-on-change').click()
|
||||
}).not.toThrow()
|
||||
|
||||
expect(() => {
|
||||
screen.getByTestId('call-on-debug-change').click()
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('DebugWithMultipleModelContextProvider', () => {
|
||||
it('should provide multipleModelConfigs to children', () => {
|
||||
const multipleModelConfigs = [
|
||||
createModelAndParameter({ id: 'model-1' }),
|
||||
createModelAndParameter({ id: 'model-2' }),
|
||||
]
|
||||
|
||||
render(
|
||||
<DebugWithMultipleModelContextProvider
|
||||
multipleModelConfigs={multipleModelConfigs}
|
||||
onMultipleModelConfigsChange={vi.fn()}
|
||||
onDebugWithMultipleModelChange={vi.fn()}
|
||||
>
|
||||
<TestConsumer />
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('configs-count')).toHaveTextContent('2')
|
||||
})
|
||||
|
||||
it('should provide checkCanSend function to children', () => {
|
||||
const checkCanSend = vi.fn(() => true)
|
||||
|
||||
render(
|
||||
<DebugWithMultipleModelContextProvider
|
||||
multipleModelConfigs={[]}
|
||||
onMultipleModelConfigsChange={vi.fn()}
|
||||
onDebugWithMultipleModelChange={vi.fn()}
|
||||
checkCanSend={checkCanSend}
|
||||
>
|
||||
<TestConsumer />
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('has-check-can-send')).toHaveTextContent('yes')
|
||||
})
|
||||
|
||||
it('should call onMultipleModelConfigsChange when invoked from context', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
|
||||
render(
|
||||
<DebugWithMultipleModelContextProvider
|
||||
multipleModelConfigs={[]}
|
||||
onMultipleModelConfigsChange={onMultipleModelConfigsChange}
|
||||
onDebugWithMultipleModelChange={vi.fn()}
|
||||
>
|
||||
<TestConsumer />
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
screen.getByTestId('call-on-change').click()
|
||||
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [])
|
||||
})
|
||||
|
||||
it('should call onDebugWithMultipleModelChange when invoked from context', () => {
|
||||
const onDebugWithMultipleModelChange = vi.fn()
|
||||
|
||||
render(
|
||||
<DebugWithMultipleModelContextProvider
|
||||
multipleModelConfigs={[]}
|
||||
onMultipleModelConfigsChange={vi.fn()}
|
||||
onDebugWithMultipleModelChange={onDebugWithMultipleModelChange}
|
||||
>
|
||||
<TestConsumer />
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
screen.getByTestId('call-on-debug-change').click()
|
||||
|
||||
expect(onDebugWithMultipleModelChange).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ id: 'model-1' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle undefined checkCanSend', () => {
|
||||
render(
|
||||
<DebugWithMultipleModelContextProvider
|
||||
multipleModelConfigs={[]}
|
||||
onMultipleModelConfigsChange={vi.fn()}
|
||||
onDebugWithMultipleModelChange={vi.fn()}
|
||||
checkCanSend={undefined}
|
||||
>
|
||||
<TestConsumer />
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('has-check-can-send')).toHaveTextContent('no')
|
||||
})
|
||||
|
||||
it('should render children correctly', () => {
|
||||
render(
|
||||
<DebugWithMultipleModelContextProvider
|
||||
multipleModelConfigs={[]}
|
||||
onMultipleModelConfigsChange={vi.fn()}
|
||||
onDebugWithMultipleModelChange={vi.fn()}
|
||||
>
|
||||
<div data-testid="child-element">Child Content</div>
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('child-element')).toHaveTextContent('Child Content')
|
||||
})
|
||||
|
||||
it('should update context when props change', () => {
|
||||
const { rerender } = render(
|
||||
<DebugWithMultipleModelContextProvider
|
||||
multipleModelConfigs={[createModelAndParameter()]}
|
||||
onMultipleModelConfigsChange={vi.fn()}
|
||||
onDebugWithMultipleModelChange={vi.fn()}
|
||||
>
|
||||
<TestConsumer />
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('configs-count')).toHaveTextContent('1')
|
||||
|
||||
rerender(
|
||||
<DebugWithMultipleModelContextProvider
|
||||
multipleModelConfigs={[createModelAndParameter(), createModelAndParameter({ id: 'model-2' })]}
|
||||
onMultipleModelConfigsChange={vi.fn()}
|
||||
onDebugWithMultipleModelChange={vi.fn()}
|
||||
>
|
||||
<TestConsumer />
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('configs-count')).toHaveTextContent('2')
|
||||
})
|
||||
|
||||
it('should pass all context values correctly', () => {
|
||||
const contextValues: DebugWithMultipleModelContextType = {
|
||||
multipleModelConfigs: [createModelAndParameter()],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
checkCanSend: () => true,
|
||||
}
|
||||
|
||||
const FullTestConsumer = () => {
|
||||
const context = useDebugWithMultipleModelContext()
|
||||
return (
|
||||
<div>
|
||||
<span data-testid="configs">{JSON.stringify(context.multipleModelConfigs)}</span>
|
||||
<span data-testid="has-on-change">{typeof context.onMultipleModelConfigsChange}</span>
|
||||
<span data-testid="has-on-debug-change">{typeof context.onDebugWithMultipleModelChange}</span>
|
||||
<span data-testid="has-check">{typeof context.checkCanSend}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
render(
|
||||
<DebugWithMultipleModelContextProvider {...contextValues}>
|
||||
<FullTestConsumer />
|
||||
</DebugWithMultipleModelContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('configs')).toHaveTextContent('model-1')
|
||||
expect(screen.getByTestId('has-on-change')).toHaveTextContent('function')
|
||||
expect(screen.getByTestId('has-on-debug-change')).toHaveTextContent('function')
|
||||
expect(screen.getByTestId('has-check')).toHaveTextContent('function')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,552 @@
|
||||
import type { CSSProperties } from 'react'
|
||||
import type { ModelAndParameter } from '../types'
|
||||
import type { Item } from '@/app/components/base/dropdown'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { ModelStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import DebugItem from './debug-item'
|
||||
|
||||
const mockUseDebugConfigurationContext = vi.fn()
|
||||
const mockUseDebugWithMultipleModelContext = vi.fn()
|
||||
const mockUseProviderContext = vi.fn()
|
||||
|
||||
let capturedDropdownProps: {
|
||||
onSelect: (item: Item) => void
|
||||
items: Item[]
|
||||
secondItems?: Item[]
|
||||
} | null = null
|
||||
|
||||
let capturedModelParameterTriggerProps: {
|
||||
modelAndParameter: ModelAndParameter
|
||||
} | null = null
|
||||
|
||||
vi.mock('@/context/debug-configuration', () => ({
|
||||
useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
|
||||
}))
|
||||
|
||||
vi.mock('./context', () => ({
|
||||
useDebugWithMultipleModelContext: () => mockUseDebugWithMultipleModelContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => mockUseProviderContext(),
|
||||
}))
|
||||
|
||||
vi.mock('./chat-item', () => ({
|
||||
default: ({ modelAndParameter }: { modelAndParameter: ModelAndParameter }) => (
|
||||
<div data-testid="chat-item" data-model-id={modelAndParameter.id}>ChatItem</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('./text-generation-item', () => ({
|
||||
default: ({ modelAndParameter }: { modelAndParameter: ModelAndParameter }) => (
|
||||
<div data-testid="text-generation-item" data-model-id={modelAndParameter.id}>TextGenerationItem</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('./model-parameter-trigger', () => ({
|
||||
default: (props: { modelAndParameter: ModelAndParameter }) => {
|
||||
capturedModelParameterTriggerProps = props
|
||||
return <div data-testid="model-parameter-trigger">ModelParameterTrigger</div>
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/dropdown', () => ({
|
||||
default: (props: { onSelect: (item: Item) => void, items: Item[], secondItems?: Item[] }) => {
|
||||
capturedDropdownProps = props
|
||||
return (
|
||||
<div data-testid="dropdown">
|
||||
{props.items.map(item => (
|
||||
<button
|
||||
key={item.value}
|
||||
data-testid={`dropdown-item-${item.value}`}
|
||||
onClick={() => props.onSelect(item)}
|
||||
>
|
||||
{item.text}
|
||||
</button>
|
||||
))}
|
||||
{props.secondItems?.map(item => (
|
||||
<button
|
||||
key={item.value}
|
||||
data-testid={`dropdown-second-item-${item.value}`}
|
||||
onClick={() => props.onSelect(item)}
|
||||
>
|
||||
{item.text}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
|
||||
id: 'model-1',
|
||||
model: 'gpt-3.5-turbo',
|
||||
provider: 'openai',
|
||||
parameters: {},
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createTextGenerationModelList = (models: Array<{ provider: string, model: string, status?: ModelStatusEnum }> = []) => {
|
||||
const providers: Record<string, { provider: string, models: Array<{ model: string, status: ModelStatusEnum }> }> = {}
|
||||
|
||||
models.forEach(({ provider, model, status = ModelStatusEnum.active }) => {
|
||||
if (!providers[provider]) {
|
||||
providers[provider] = { provider, models: [] }
|
||||
}
|
||||
providers[provider].models.push({ model, status })
|
||||
})
|
||||
|
||||
return Object.values(providers)
|
||||
}
|
||||
|
||||
type DebugItemProps = {
|
||||
modelAndParameter: ModelAndParameter
|
||||
className?: string
|
||||
style?: CSSProperties
|
||||
}
|
||||
|
||||
const renderComponent = (props: Partial<DebugItemProps> = {}) => {
|
||||
const defaultProps: DebugItemProps = {
|
||||
modelAndParameter: createModelAndParameter(),
|
||||
...props,
|
||||
}
|
||||
return render(<DebugItem {...defaultProps} />)
|
||||
}
|
||||
|
||||
describe('DebugItem', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
capturedDropdownProps = null
|
||||
capturedModelParameterTriggerProps = null
|
||||
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
mode: AppModeEnum.CHAT,
|
||||
})
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [createModelAndParameter()],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'openai', model: 'gpt-3.5-turbo' },
|
||||
]),
|
||||
})
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render with basic props', () => {
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('model-parameter-trigger')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('dropdown')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should display correct index number', () => {
|
||||
const modelConfigs = [
|
||||
createModelAndParameter({ id: 'model-1' }),
|
||||
createModelAndParameter({ id: 'model-2' }),
|
||||
]
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: modelConfigs,
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
const { container } = renderComponent({ modelAndParameter: createModelAndParameter({ id: 'model-2' }) })
|
||||
|
||||
// The index is displayed as "#2" in the component
|
||||
const indexElement = container.querySelector('.font-medium.italic')
|
||||
expect(indexElement?.textContent?.trim()).toContain('2')
|
||||
})
|
||||
|
||||
it('should apply className and style props', () => {
|
||||
const { container } = renderComponent({
|
||||
className: 'custom-class',
|
||||
style: { backgroundColor: 'red' },
|
||||
})
|
||||
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper).toHaveClass('custom-class')
|
||||
expect(wrapper.style.backgroundColor).toBe('red')
|
||||
})
|
||||
|
||||
it('should pass modelAndParameter to ModelParameterTrigger', () => {
|
||||
const modelAndParameter = createModelAndParameter({ id: 'test-model' })
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
expect(capturedModelParameterTriggerProps?.modelAndParameter).toEqual(modelAndParameter)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ChatItem rendering', () => {
|
||||
it('should render ChatItem in CHAT mode with active model', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'openai', model: 'gpt-3.5-turbo', status: ModelStatusEnum.active },
|
||||
]),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('chat-item')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('text-generation-item')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render ChatItem in AGENT_CHAT mode with active model', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.AGENT_CHAT })
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'openai', model: 'gpt-3.5-turbo', status: ModelStatusEnum.active },
|
||||
]),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('chat-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render ChatItem when model is not active', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'openai', model: 'gpt-3.5-turbo', status: ModelStatusEnum.disabled },
|
||||
]),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render ChatItem when provider not found', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'anthropic', model: 'claude-3', status: ModelStatusEnum.active },
|
||||
]),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render ChatItem when model not found', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'openai', model: 'gpt-4', status: ModelStatusEnum.active },
|
||||
]),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('TextGenerationItem rendering', () => {
|
||||
it('should render TextGenerationItem in COMPLETION mode with active model', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.COMPLETION })
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'openai', model: 'gpt-3.5-turbo', status: ModelStatusEnum.active },
|
||||
]),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('text-generation-item')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render TextGenerationItem when provider is not found', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.COMPLETION })
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'anthropic', model: 'claude-3', status: ModelStatusEnum.active },
|
||||
]),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.queryByTestId('text-generation-item')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('dropdown menu', () => {
|
||||
it('should show duplicate option when less than 4 models', () => {
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [createModelAndParameter()],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedDropdownProps?.items).toContainEqual(
|
||||
expect.objectContaining({ value: 'duplicate' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should hide duplicate option when 4 or more models', () => {
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [
|
||||
createModelAndParameter({ id: '1' }),
|
||||
createModelAndParameter({ id: '2' }),
|
||||
createModelAndParameter({ id: '3' }),
|
||||
createModelAndParameter({ id: '4' }),
|
||||
],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedDropdownProps?.items).not.toContainEqual(
|
||||
expect.objectContaining({ value: 'duplicate' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should show debug-as-single-model option when provider and model are set', () => {
|
||||
renderComponent({
|
||||
modelAndParameter: createModelAndParameter({
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
}),
|
||||
})
|
||||
|
||||
expect(capturedDropdownProps?.items).toContainEqual(
|
||||
expect.objectContaining({ value: 'debug-as-single-model' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should hide debug-as-single-model option when provider is missing', () => {
|
||||
renderComponent({
|
||||
modelAndParameter: createModelAndParameter({
|
||||
provider: '',
|
||||
model: 'gpt-3.5-turbo',
|
||||
}),
|
||||
})
|
||||
|
||||
expect(capturedDropdownProps?.items).not.toContainEqual(
|
||||
expect.objectContaining({ value: 'debug-as-single-model' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should hide debug-as-single-model option when model is missing', () => {
|
||||
renderComponent({
|
||||
modelAndParameter: createModelAndParameter({
|
||||
provider: 'openai',
|
||||
model: '',
|
||||
}),
|
||||
})
|
||||
|
||||
expect(capturedDropdownProps?.items).not.toContainEqual(
|
||||
expect.objectContaining({ value: 'debug-as-single-model' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should show remove option in secondItems when more than 2 models', () => {
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [
|
||||
createModelAndParameter({ id: '1' }),
|
||||
createModelAndParameter({ id: '2' }),
|
||||
createModelAndParameter({ id: '3' }),
|
||||
],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedDropdownProps?.secondItems).toContainEqual(
|
||||
expect.objectContaining({ value: 'remove' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should not show remove option when 2 or fewer models', () => {
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [
|
||||
createModelAndParameter({ id: '1' }),
|
||||
createModelAndParameter({ id: '2' }),
|
||||
],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedDropdownProps?.secondItems).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('dropdown actions', () => {
|
||||
it('should duplicate model when duplicate is selected', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const originalModel = createModelAndParameter({ id: 'original' })
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [originalModel],
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter: originalModel })
|
||||
|
||||
fireEvent.click(screen.getByTestId('dropdown-item-duplicate'))
|
||||
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(
|
||||
true,
|
||||
expect.arrayContaining([
|
||||
originalModel,
|
||||
expect.objectContaining({
|
||||
model: originalModel.model,
|
||||
provider: originalModel.provider,
|
||||
parameters: originalModel.parameters,
|
||||
}),
|
||||
]),
|
||||
)
|
||||
})
|
||||
|
||||
it('should not duplicate when already at 4 models', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const models = [
|
||||
createModelAndParameter({ id: '1' }),
|
||||
createModelAndParameter({ id: '2' }),
|
||||
createModelAndParameter({ id: '3' }),
|
||||
createModelAndParameter({ id: '4' }),
|
||||
]
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: models,
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter: models[0] })
|
||||
|
||||
// Since duplicate is not shown when >= 4 models, we need to manually call handleSelect
|
||||
capturedDropdownProps?.onSelect({ value: 'duplicate', text: 'Duplicate' })
|
||||
|
||||
expect(onMultipleModelConfigsChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onDebugWithMultipleModelChange when debug-as-single-model is selected', () => {
|
||||
const onDebugWithMultipleModelChange = vi.fn()
|
||||
const modelAndParameter = createModelAndParameter()
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [modelAndParameter],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange,
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
fireEvent.click(screen.getByTestId('dropdown-item-debug-as-single-model'))
|
||||
|
||||
expect(onDebugWithMultipleModelChange).toHaveBeenCalledWith(modelAndParameter)
|
||||
})
|
||||
|
||||
it('should remove model when remove is selected', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const models = [
|
||||
createModelAndParameter({ id: '1' }),
|
||||
createModelAndParameter({ id: '2' }),
|
||||
createModelAndParameter({ id: '3' }),
|
||||
]
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: models,
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter: models[1] })
|
||||
|
||||
fireEvent.click(screen.getByTestId('dropdown-second-item-remove'))
|
||||
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(
|
||||
true,
|
||||
[models[0], models[2]],
|
||||
)
|
||||
})
|
||||
|
||||
it('should insert duplicated model at correct position', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const models = [
|
||||
createModelAndParameter({ id: '1' }),
|
||||
createModelAndParameter({ id: '2' }),
|
||||
createModelAndParameter({ id: '3' }),
|
||||
]
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: models,
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
// Duplicate the second model
|
||||
renderComponent({ modelAndParameter: models[1] })
|
||||
|
||||
fireEvent.click(screen.getByTestId('dropdown-item-duplicate'))
|
||||
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(
|
||||
true,
|
||||
expect.arrayContaining([
|
||||
models[0],
|
||||
models[1],
|
||||
expect.objectContaining({ model: models[1].model }),
|
||||
models[2],
|
||||
]),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle model not found in multipleModelConfigs', () => {
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
const { container } = renderComponent()
|
||||
|
||||
// Should show index 0 (not found returns -1, but display shows index + 1)
|
||||
const indexElement = container.querySelector('.font-medium.italic')
|
||||
expect(indexElement?.textContent?.trim()).toContain('0')
|
||||
})
|
||||
|
||||
it('should handle empty textGenerationModelList', () => {
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: [],
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('text-generation-item')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle model with quotaExceeded status', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({ mode: AppModeEnum.CHAT })
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: createTextGenerationModelList([
|
||||
{ provider: 'anthropic', model: 'not-matching', status: ModelStatusEnum.quotaExceeded },
|
||||
]),
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
// When provider/model doesn't match, ChatItem won't render
|
||||
expect(screen.queryByTestId('chat-item')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,405 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { ModelAndParameter } from '../types'
|
||||
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { ModelStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import ModelParameterTrigger from './model-parameter-trigger'
|
||||
|
||||
const mockUseDebugConfigurationContext = vi.fn()
|
||||
const mockUseDebugWithMultipleModelContext = vi.fn()
|
||||
const mockUseLanguage = vi.fn()
|
||||
|
||||
type RenderTriggerProps = {
|
||||
open: boolean
|
||||
currentProvider: { provider: string } | null
|
||||
currentModel: { model: string, status: ModelStatusEnum } | null
|
||||
}
|
||||
|
||||
let capturedModalProps: {
|
||||
isAdvancedMode: boolean
|
||||
provider: string
|
||||
modelId: string
|
||||
completionParams: FormValue
|
||||
onCompletionParamsChange: (params: FormValue) => void
|
||||
setModel: (model: { modelId: string, provider: string }) => void
|
||||
debugWithMultipleModel: boolean
|
||||
onDebugWithMultipleModelChange: () => void
|
||||
renderTrigger: (props: RenderTriggerProps) => ReactNode
|
||||
} | null = null
|
||||
|
||||
vi.mock('@/context/debug-configuration', () => ({
|
||||
useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
|
||||
}))
|
||||
|
||||
vi.mock('./context', () => ({
|
||||
useDebugWithMultipleModelContext: () => mockUseDebugWithMultipleModelContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
|
||||
useLanguage: () => mockUseLanguage(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({
|
||||
default: (props: typeof capturedModalProps) => {
|
||||
capturedModalProps = props
|
||||
// Render the trigger that the component passes
|
||||
const triggerContent = props?.renderTrigger({
|
||||
open: false,
|
||||
currentProvider: null,
|
||||
currentModel: null,
|
||||
})
|
||||
return (
|
||||
<div data-testid="model-parameter-modal">
|
||||
{triggerContent}
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-icon', () => ({
|
||||
default: ({ provider, modelName }: { provider: { provider: string }, modelName?: string }) => (
|
||||
<div data-testid="model-icon" data-provider={provider?.provider} data-model={modelName}>
|
||||
ModelIcon
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-name', () => ({
|
||||
default: ({ modelItem }: { modelItem: { model: string } }) => (
|
||||
<div data-testid="model-name">{modelItem?.model}</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ children, popupContent }: { children: ReactNode, popupContent: string }) => (
|
||||
<div data-testid="tooltip" data-content={popupContent}>{children}</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
|
||||
id: 'model-1',
|
||||
model: 'gpt-3.5-turbo',
|
||||
provider: 'openai',
|
||||
parameters: { temperature: 0.7 },
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const renderComponent = (props: Partial<{ modelAndParameter: ModelAndParameter }> = {}) => {
|
||||
const defaultProps = {
|
||||
modelAndParameter: createModelAndParameter(),
|
||||
...props,
|
||||
}
|
||||
return render(<ModelParameterTrigger {...defaultProps} />)
|
||||
}
|
||||
|
||||
describe('ModelParameterTrigger', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
capturedModalProps = null
|
||||
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: false,
|
||||
})
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [createModelAndParameter()],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
mockUseLanguage.mockReturnValue('en_US')
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render ModelParameterModal', () => {
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('model-parameter-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass correct props to ModelParameterModal', () => {
|
||||
const modelAndParameter = createModelAndParameter({
|
||||
provider: 'anthropic',
|
||||
model: 'claude-3',
|
||||
parameters: { max_tokens: 1000 },
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
expect(capturedModalProps?.provider).toBe('anthropic')
|
||||
expect(capturedModalProps?.modelId).toBe('claude-3')
|
||||
expect(capturedModalProps?.completionParams).toEqual({ max_tokens: 1000 })
|
||||
expect(capturedModalProps?.debugWithMultipleModel).toBe(true)
|
||||
})
|
||||
|
||||
it('should pass isAdvancedMode from context', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: true,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedModalProps?.isAdvancedMode).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleSelectModel', () => {
|
||||
it('should call onMultipleModelConfigsChange with updated model', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const modelAndParameter = createModelAndParameter({ id: 'model-1' })
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [modelAndParameter],
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
// Directly call the setModel callback
|
||||
capturedModalProps?.setModel({ modelId: 'gpt-4', provider: 'openai' })
|
||||
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [
|
||||
expect.objectContaining({
|
||||
id: 'model-1',
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
}),
|
||||
])
|
||||
})
|
||||
|
||||
it('should update correct model in array', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const models = [
|
||||
createModelAndParameter({ id: 'model-1' }),
|
||||
createModelAndParameter({ id: 'model-2' }),
|
||||
createModelAndParameter({ id: 'model-3' }),
|
||||
]
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: models,
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter: models[1] })
|
||||
|
||||
capturedModalProps?.setModel({ modelId: 'gpt-4', provider: 'openai' })
|
||||
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [
|
||||
models[0],
|
||||
expect.objectContaining({
|
||||
id: 'model-2',
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
}),
|
||||
models[2],
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleParamsChange', () => {
|
||||
it('should call onMultipleModelConfigsChange with updated parameters', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const modelAndParameter = createModelAndParameter({ id: 'model-1' })
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [modelAndParameter],
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
capturedModalProps?.onCompletionParamsChange({ temperature: 0.8 })
|
||||
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [
|
||||
expect.objectContaining({
|
||||
id: 'model-1',
|
||||
parameters: { temperature: 0.8 },
|
||||
}),
|
||||
])
|
||||
})
|
||||
|
||||
it('should preserve other model properties when changing params', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const modelAndParameter = createModelAndParameter({
|
||||
id: 'model-1',
|
||||
model: 'gpt-3.5-turbo',
|
||||
provider: 'openai',
|
||||
parameters: { temperature: 0.7 },
|
||||
})
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [modelAndParameter],
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
capturedModalProps?.onCompletionParamsChange({ temperature: 0.8 })
|
||||
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalledWith(true, [
|
||||
expect.objectContaining({
|
||||
id: 'model-1',
|
||||
model: 'gpt-3.5-turbo',
|
||||
provider: 'openai',
|
||||
parameters: { temperature: 0.8 },
|
||||
}),
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('onDebugWithMultipleModelChange', () => {
|
||||
it('should call context onDebugWithMultipleModelChange with modelAndParameter', () => {
|
||||
const onDebugWithMultipleModelChange = vi.fn()
|
||||
const modelAndParameter = createModelAndParameter()
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [modelAndParameter],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange,
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
capturedModalProps?.onDebugWithMultipleModelChange()
|
||||
|
||||
expect(onDebugWithMultipleModelChange).toHaveBeenCalledWith(modelAndParameter)
|
||||
})
|
||||
})
|
||||
|
||||
describe('index calculation', () => {
|
||||
it('should find correct index in multipleModelConfigs', () => {
|
||||
const models = [
|
||||
createModelAndParameter({ id: 'model-1' }),
|
||||
createModelAndParameter({ id: 'model-2' }),
|
||||
createModelAndParameter({ id: 'model-3' }),
|
||||
]
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: models,
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter: models[2] })
|
||||
|
||||
// The component uses the index to update the correct model
|
||||
// We verify this through the handleSelectModel behavior
|
||||
expect(capturedModalProps).not.toBeNull()
|
||||
})
|
||||
|
||||
it('should handle model not found in configs', () => {
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [createModelAndParameter({ id: 'other' })],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
// Should not throw even if model is not found
|
||||
expect(() => renderComponent()).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('trigger rendering', () => {
|
||||
it('should render trigger content from renderTrigger', () => {
|
||||
renderComponent()
|
||||
|
||||
// The trigger is rendered via renderTrigger callback
|
||||
expect(screen.getByTestId('model-parameter-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render "Select Model" text when no provider/model', () => {
|
||||
renderComponent()
|
||||
|
||||
// When currentProvider and currentModel are null, shows "Select Model"
|
||||
expect(screen.getByText('common.modelProvider.selectModel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('language context', () => {
|
||||
it('should use language from useLanguage hook', () => {
|
||||
mockUseLanguage.mockReturnValue('zh_Hans')
|
||||
|
||||
renderComponent()
|
||||
|
||||
// The language is used for MODEL_STATUS_TEXT tooltip
|
||||
// We verify the hook is called
|
||||
expect(mockUseLanguage).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle empty multipleModelConfigs', () => {
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [],
|
||||
onMultipleModelConfigsChange: vi.fn(),
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
expect(() => renderComponent()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle undefined parameters', () => {
|
||||
const modelAndParameter = createModelAndParameter({
|
||||
parameters: undefined as unknown as FormValue,
|
||||
})
|
||||
|
||||
expect(() => renderComponent({ modelAndParameter })).not.toThrow()
|
||||
expect(capturedModalProps?.completionParams).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle model selection for model not in list', () => {
|
||||
const onMultipleModelConfigsChange = vi.fn()
|
||||
const modelAndParameter = createModelAndParameter({ id: 'not-in-list' })
|
||||
|
||||
mockUseDebugWithMultipleModelContext.mockReturnValue({
|
||||
multipleModelConfigs: [createModelAndParameter({ id: 'different-model' })],
|
||||
onMultipleModelConfigsChange,
|
||||
onDebugWithMultipleModelChange: vi.fn(),
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
capturedModalProps?.setModel({ modelId: 'gpt-4', provider: 'openai' })
|
||||
|
||||
// index will be -1, so newModelConfigs[-1] will be undefined
|
||||
// This tests the edge case behavior
|
||||
expect(onMultipleModelConfigsChange).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('renderTrigger with different states', () => {
|
||||
it('should pass correct props to renderTrigger', () => {
|
||||
renderComponent()
|
||||
|
||||
expect(capturedModalProps?.renderTrigger).toBeDefined()
|
||||
expect(typeof capturedModalProps?.renderTrigger).toBe('function')
|
||||
})
|
||||
|
||||
it('should render trigger with provider info when available', () => {
|
||||
// Mock the modal to render trigger with provider
|
||||
vi.doMock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({
|
||||
default: (props: typeof capturedModalProps) => {
|
||||
capturedModalProps = props
|
||||
const triggerContent = props?.renderTrigger({
|
||||
open: false,
|
||||
currentProvider: { provider: 'openai' },
|
||||
currentModel: { model: 'gpt-3.5-turbo', status: ModelStatusEnum.active },
|
||||
})
|
||||
return (
|
||||
<div data-testid="model-parameter-modal">
|
||||
{triggerContent}
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('model-parameter-modal')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,721 @@
|
||||
import type { ModelAndParameter } from '../types'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { TransferMethod } from '@/app/components/base/chat/types'
|
||||
import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types'
|
||||
import TextGenerationItem from './text-generation-item'
|
||||
|
||||
const mockUseDebugConfigurationContext = vi.fn()
|
||||
const mockUseProviderContext = vi.fn()
|
||||
const mockUseFeatures = vi.fn()
|
||||
const mockUseTextGeneration = vi.fn()
|
||||
const mockUseEventEmitterContextContext = vi.fn()
|
||||
const mockPromptVariablesToUserInputsForm = vi.fn()
|
||||
|
||||
let capturedTextGenerationProps: {
|
||||
content: string
|
||||
isLoading: boolean
|
||||
isResponding: boolean
|
||||
messageId: string | null
|
||||
className?: string
|
||||
} | null = null
|
||||
|
||||
let eventSubscriptionCallback: ((v: { type: string, payload?: Record<string, unknown> }) => void) | null = null
|
||||
|
||||
vi.mock('@/context/debug-configuration', () => ({
|
||||
useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => mockUseProviderContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/features/hooks', () => ({
|
||||
useFeatures: (selector: (state: Record<string, unknown>) => unknown) => mockUseFeatures(selector),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/text-generation/hooks', () => ({
|
||||
useTextGeneration: () => mockUseTextGeneration(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: () => mockUseEventEmitterContextContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/model-config', () => ({
|
||||
promptVariablesToUserInputsForm: (...args: unknown[]) => mockPromptVariablesToUserInputsForm(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/text-generate/item', () => ({
|
||||
default: (props: typeof capturedTextGenerationProps) => {
|
||||
capturedTextGenerationProps = props
|
||||
return (
|
||||
<div data-testid="text-generation">
|
||||
<span data-testid="content">{props?.content}</span>
|
||||
<span data-testid="is-loading">{props?.isLoading ? 'yes' : 'no'}</span>
|
||||
<span data-testid="is-responding">{props?.isResponding ? 'yes' : 'no'}</span>
|
||||
<span data-testid="message-id">{props?.messageId || 'null'}</span>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
|
||||
id: 'model-1',
|
||||
model: 'gpt-3.5-turbo',
|
||||
provider: 'openai',
|
||||
parameters: { temperature: 0.7 },
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createDefaultMocks = () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: false,
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_template: 'Hello {{name}}',
|
||||
prompt_variables: [
|
||||
{ key: 'name', name: 'Name', type: 'string', is_context_var: false },
|
||||
],
|
||||
},
|
||||
system_parameters: {},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: { name: 'World' },
|
||||
promptMode: 'simple',
|
||||
speechToTextConfig: { enabled: true },
|
||||
introduction: 'Welcome!',
|
||||
suggestedQuestionsAfterAnswerConfig: { enabled: false },
|
||||
citationConfig: { enabled: true },
|
||||
externalDataToolsConfig: [],
|
||||
chatPromptConfig: {},
|
||||
completionPromptConfig: {},
|
||||
dataSets: [{ id: 'ds-1' }],
|
||||
datasetConfigs: { retrieval_model: 'single' },
|
||||
})
|
||||
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: [
|
||||
{
|
||||
provider: 'openai',
|
||||
models: [
|
||||
{
|
||||
model: 'gpt-3.5-turbo',
|
||||
model_properties: { mode: 'chat' },
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
mockUseFeatures.mockImplementation((selector: (state: Record<string, unknown>) => unknown) => {
|
||||
const state = {
|
||||
features: {
|
||||
moreLikeThis: { enabled: false },
|
||||
moderation: { enabled: false },
|
||||
text2speech: { enabled: false },
|
||||
file: { enabled: true },
|
||||
},
|
||||
}
|
||||
return selector(state)
|
||||
})
|
||||
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: 'Generated text',
|
||||
handleSend: vi.fn(),
|
||||
isResponding: false,
|
||||
messageId: 'msg-123',
|
||||
})
|
||||
|
||||
mockUseEventEmitterContextContext.mockReturnValue({
|
||||
eventEmitter: {
|
||||
useSubscription: (callback: (v: { type: string, payload?: Record<string, unknown> }) => void) => {
|
||||
eventSubscriptionCallback = callback
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
mockPromptVariablesToUserInputsForm.mockReturnValue([
|
||||
{ variable: 'name', label: 'Name', type: 'text-input', required: true },
|
||||
])
|
||||
}
|
||||
|
||||
const renderComponent = (props: Partial<{ modelAndParameter: ModelAndParameter }> = {}) => {
|
||||
const defaultProps = {
|
||||
modelAndParameter: createModelAndParameter(),
|
||||
...props,
|
||||
}
|
||||
return render(<TextGenerationItem {...defaultProps} />)
|
||||
}
|
||||
|
||||
describe('TextGenerationItem', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
capturedTextGenerationProps = null
|
||||
eventSubscriptionCallback = null
|
||||
createDefaultMocks()
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render TextGeneration component', () => {
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('text-generation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass completion content to TextGeneration', () => {
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: 'Hello World',
|
||||
handleSend: vi.fn(),
|
||||
isResponding: false,
|
||||
messageId: 'msg-1',
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('content')).toHaveTextContent('Hello World')
|
||||
})
|
||||
|
||||
it('should show loading when no completion and responding', () => {
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend: vi.fn(),
|
||||
isResponding: true,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('is-loading')).toHaveTextContent('yes')
|
||||
})
|
||||
|
||||
it('should not show loading when completion exists', () => {
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: 'Some text',
|
||||
handleSend: vi.fn(),
|
||||
isResponding: true,
|
||||
messageId: 'msg-1',
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('is-loading')).toHaveTextContent('no')
|
||||
})
|
||||
|
||||
it('should pass isResponding to TextGeneration', () => {
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: 'Text',
|
||||
handleSend: vi.fn(),
|
||||
isResponding: true,
|
||||
messageId: 'msg-1',
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('is-responding')).toHaveTextContent('yes')
|
||||
})
|
||||
|
||||
it('should pass messageId to TextGeneration', () => {
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: 'Text',
|
||||
handleSend: vi.fn(),
|
||||
isResponding: false,
|
||||
messageId: 'msg-456',
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('message-id')).toHaveTextContent('msg-456')
|
||||
})
|
||||
})
|
||||
|
||||
describe('config composition', () => {
|
||||
it('should use prompt_template in non-advanced mode', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: false,
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_template: 'My Template',
|
||||
prompt_variables: [],
|
||||
},
|
||||
system_parameters: {},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
promptMode: 'simple',
|
||||
speechToTextConfig: {},
|
||||
introduction: '',
|
||||
suggestedQuestionsAfterAnswerConfig: {},
|
||||
citationConfig: {},
|
||||
externalDataToolsConfig: [],
|
||||
chatPromptConfig: {},
|
||||
completionPromptConfig: {},
|
||||
dataSets: [],
|
||||
datasetConfigs: {},
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
// Config is built internally - we verify through the component rendering
|
||||
expect(capturedTextGenerationProps).not.toBeNull()
|
||||
})
|
||||
|
||||
it('should use empty pre_prompt in advanced mode', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: true,
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_template: 'Should not be used',
|
||||
prompt_variables: [],
|
||||
},
|
||||
system_parameters: {},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
promptMode: 'advanced',
|
||||
speechToTextConfig: {},
|
||||
introduction: '',
|
||||
suggestedQuestionsAfterAnswerConfig: {},
|
||||
citationConfig: {},
|
||||
externalDataToolsConfig: [],
|
||||
chatPromptConfig: { custom: true },
|
||||
completionPromptConfig: { custom: true },
|
||||
dataSets: [],
|
||||
datasetConfigs: {},
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedTextGenerationProps).not.toBeNull()
|
||||
})
|
||||
|
||||
it('should find context variable from prompt_variables', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: false,
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_template: '',
|
||||
prompt_variables: [
|
||||
{ key: 'context', name: 'Context', type: 'string', is_context_var: true },
|
||||
{ key: 'query', name: 'Query', type: 'string', is_context_var: false },
|
||||
],
|
||||
},
|
||||
system_parameters: {},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
promptMode: 'simple',
|
||||
speechToTextConfig: {},
|
||||
introduction: '',
|
||||
suggestedQuestionsAfterAnswerConfig: {},
|
||||
citationConfig: {},
|
||||
externalDataToolsConfig: [],
|
||||
chatPromptConfig: {},
|
||||
completionPromptConfig: {},
|
||||
dataSets: [],
|
||||
datasetConfigs: {},
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(capturedTextGenerationProps).not.toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('dataset configuration', () => {
|
||||
it('should transform dataSets to postDatasets format', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: false,
|
||||
modelConfig: {
|
||||
configs: { prompt_template: '', prompt_variables: [] },
|
||||
system_parameters: {},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
promptMode: 'simple',
|
||||
speechToTextConfig: {},
|
||||
introduction: '',
|
||||
suggestedQuestionsAfterAnswerConfig: {},
|
||||
citationConfig: {},
|
||||
externalDataToolsConfig: [],
|
||||
chatPromptConfig: {},
|
||||
completionPromptConfig: {},
|
||||
dataSets: [{ id: 'ds-1' }, { id: 'ds-2' }],
|
||||
datasetConfigs: { retrieval_model: 'multiple' },
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
// postDatasets is used in config.dataset_configs.datasets
|
||||
expect(capturedTextGenerationProps).not.toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('event subscription', () => {
|
||||
it('should handle APP_CHAT_WITH_MULTIPLE_MODEL event', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Generate text', files: [] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
'apps/app-123/completion-messages',
|
||||
expect.objectContaining({
|
||||
inputs: { name: 'World' },
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should ignore other event types', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: 'OTHER_EVENT',
|
||||
payload: {},
|
||||
})
|
||||
|
||||
expect(handleSend).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('doSend function', () => {
|
||||
it('should include model configuration', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
const modelAndParameter = createModelAndParameter({
|
||||
provider: 'anthropic',
|
||||
model: 'claude-3',
|
||||
parameters: { max_tokens: 2000 },
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Test', files: [] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
model_config: expect.objectContaining({
|
||||
model: expect.objectContaining({
|
||||
provider: 'anthropic',
|
||||
name: 'claude-3',
|
||||
completion_params: { max_tokens: 2000 },
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should include files with local_file transfer method handled', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
const files = [
|
||||
{ id: 'f1', transfer_method: TransferMethod.local_file, url: 'blob:123' },
|
||||
{ id: 'f2', transfer_method: TransferMethod.remote_url, url: 'https://example.com/file' },
|
||||
]
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Test', files },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
files: [
|
||||
expect.objectContaining({ id: 'f1', transfer_method: TransferMethod.local_file, url: '' }),
|
||||
expect.objectContaining({ id: 'f2', transfer_method: TransferMethod.remote_url, url: 'https://example.com/file' }),
|
||||
],
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should not include files when file upload is disabled', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
mockUseFeatures.mockImplementation((selector: (state: Record<string, unknown>) => unknown) => {
|
||||
const state = {
|
||||
features: {
|
||||
moreLikeThis: { enabled: false },
|
||||
moderation: { enabled: false },
|
||||
text2speech: { enabled: false },
|
||||
file: { enabled: false },
|
||||
},
|
||||
}
|
||||
return selector(state)
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Test', files: [{ id: 'f1' }] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.not.objectContaining({
|
||||
files: expect.anything(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should not include files when files array is empty', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Test', files: [] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.not.objectContaining({
|
||||
files: expect.anything(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should not include files when files is undefined', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Test' },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.not.objectContaining({
|
||||
files: expect.anything(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('model resolution', () => {
|
||||
it('should find current provider and model', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: [
|
||||
{
|
||||
provider: 'openai',
|
||||
models: [
|
||||
{ model: 'gpt-3.5-turbo', model_properties: { mode: 'chat' } },
|
||||
{ model: 'gpt-4', model_properties: { mode: 'chat' } },
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const modelAndParameter = createModelAndParameter({
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
})
|
||||
|
||||
renderComponent({ modelAndParameter })
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Test', files: [] },
|
||||
})
|
||||
|
||||
expect(handleSend).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
model_config: expect.objectContaining({
|
||||
model: expect.objectContaining({
|
||||
mode: 'chat',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle provider not found', () => {
|
||||
const handleSend = vi.fn()
|
||||
mockUseTextGeneration.mockReturnValue({
|
||||
completion: '',
|
||||
handleSend,
|
||||
isResponding: false,
|
||||
messageId: null,
|
||||
})
|
||||
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
textGenerationModelList: [],
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
eventSubscriptionCallback?.({
|
||||
type: APP_CHAT_WITH_MULTIPLE_MODEL,
|
||||
payload: { message: 'Test', files: [] },
|
||||
})
|
||||
|
||||
// Should still call handleSend without crashing
|
||||
expect(handleSend).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle null eventEmitter', () => {
|
||||
mockUseEventEmitterContextContext.mockReturnValue({
|
||||
eventEmitter: null,
|
||||
})
|
||||
|
||||
expect(() => renderComponent()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle empty prompt_variables', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: false,
|
||||
modelConfig: {
|
||||
configs: { prompt_template: '', prompt_variables: [] },
|
||||
system_parameters: {},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
promptMode: 'simple',
|
||||
speechToTextConfig: {},
|
||||
introduction: '',
|
||||
suggestedQuestionsAfterAnswerConfig: {},
|
||||
citationConfig: {},
|
||||
externalDataToolsConfig: [],
|
||||
chatPromptConfig: {},
|
||||
completionPromptConfig: {},
|
||||
dataSets: [],
|
||||
datasetConfigs: {},
|
||||
})
|
||||
|
||||
expect(() => renderComponent()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle no context variable found', () => {
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: false,
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_template: '',
|
||||
prompt_variables: [
|
||||
{ key: 'var1', name: 'Var1', type: 'string', is_context_var: false },
|
||||
],
|
||||
},
|
||||
system_parameters: {},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
promptMode: 'simple',
|
||||
speechToTextConfig: {},
|
||||
introduction: '',
|
||||
suggestedQuestionsAfterAnswerConfig: {},
|
||||
citationConfig: {},
|
||||
externalDataToolsConfig: [],
|
||||
chatPromptConfig: {},
|
||||
completionPromptConfig: {},
|
||||
dataSets: [],
|
||||
datasetConfigs: {},
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
// Should use empty string for dataset_query_variable
|
||||
expect(capturedTextGenerationProps).not.toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('promptVariablesToUserInputsForm', () => {
|
||||
it('should call promptVariablesToUserInputsForm with prompt_variables', () => {
|
||||
const promptVariables = [
|
||||
{ key: 'name', name: 'Name', type: 'string' },
|
||||
{ key: 'age', name: 'Age', type: 'number' },
|
||||
]
|
||||
|
||||
mockUseDebugConfigurationContext.mockReturnValue({
|
||||
isAdvancedMode: false,
|
||||
modelConfig: {
|
||||
configs: { prompt_template: '', prompt_variables: promptVariables },
|
||||
system_parameters: {},
|
||||
},
|
||||
appId: 'app-123',
|
||||
inputs: {},
|
||||
promptMode: 'simple',
|
||||
speechToTextConfig: {},
|
||||
introduction: '',
|
||||
suggestedQuestionsAfterAnswerConfig: {},
|
||||
citationConfig: {},
|
||||
externalDataToolsConfig: [],
|
||||
chatPromptConfig: {},
|
||||
completionPromptConfig: {},
|
||||
dataSets: [],
|
||||
datasetConfigs: {},
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(mockPromptVariablesToUserInputsForm).toHaveBeenCalledWith(promptVariables)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,6 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12 7.33337V2.66671H4.00002V13.3334H8.00002C8.36821 13.3334 8.66669 13.6319 8.66669 14C8.66669 14.3682 8.36821 14.6667 8.00002 14.6667H3.33335C2.96516 14.6667 2.66669 14.3682 2.66669 14V2.00004C2.66669 1.63185 2.96516 1.33337 3.33335 1.33337H12.6667C13.0349 1.33337 13.3334 1.63185 13.3334 2.00004V7.33337C13.3334 7.70156 13.0349 8.00004 12.6667 8.00004C12.2985 8.00004 12 7.70156 12 7.33337Z" fill="#354052"/>
|
||||
<path d="M10 4.00004C10.3682 4.00004 10.6667 4.29852 10.6667 4.66671C10.6667 5.0349 10.3682 5.33337 10 5.33337H6.00002C5.63183 5.33337 5.33335 5.0349 5.33335 4.66671C5.33335 4.29852 5.63183 4.00004 6.00002 4.00004H10Z" fill="#354052"/>
|
||||
<path d="M8.00002 6.66671C8.36821 6.66671 8.66669 6.96518 8.66669 7.33337C8.66669 7.70156 8.36821 8.00004 8.00002 8.00004H6.00002C5.63183 8.00004 5.33335 7.70156 5.33335 7.33337C5.33335 6.96518 5.63183 6.66671 6.00002 6.66671H8.00002Z" fill="#354052"/>
|
||||
<path d="M12.827 10.7902L12.3624 9.58224C12.3048 9.43231 12.1607 9.33337 12 9.33337C11.8394 9.33337 11.6953 9.43231 11.6376 9.58224L11.173 10.7902C11.1054 10.9662 10.9662 11.1054 10.7902 11.173L9.58222 11.6376C9.43229 11.6953 9.33335 11.8394 9.33335 12C9.33335 12.1607 9.43229 12.3048 9.58222 12.3624L10.7902 12.827C10.9662 12.8947 11.1054 13.0338 11.173 13.2099L11.6376 14.4178C11.6953 14.5678 11.8394 14.6667 12 14.6667C12.1607 14.6667 12.3048 14.5678 12.3624 14.4178L12.827 13.2099C12.8947 13.0338 13.0338 12.8947 13.2099 12.827L14.4178 12.3624C14.5678 12.3048 14.6667 12.1607 14.6667 12C14.6667 11.8394 14.5678 11.6953 14.4178 11.6376L13.2099 11.173C13.0338 11.1054 12.8947 10.9662 12.827 10.7902Z" fill="#354052"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.7 KiB |
@ -0,0 +1,53 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "16",
|
||||
"height": "16",
|
||||
"viewBox": "0 0 16 16",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M12 7.33337V2.66671H4.00002V13.3334H8.00002C8.36821 13.3334 8.66669 13.6319 8.66669 14C8.66669 14.3682 8.36821 14.6667 8.00002 14.6667H3.33335C2.96516 14.6667 2.66669 14.3682 2.66669 14V2.00004C2.66669 1.63185 2.96516 1.33337 3.33335 1.33337H12.6667C13.0349 1.33337 13.3334 1.63185 13.3334 2.00004V7.33337C13.3334 7.70156 13.0349 8.00004 12.6667 8.00004C12.2985 8.00004 12 7.70156 12 7.33337Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M10 4.00004C10.3682 4.00004 10.6667 4.29852 10.6667 4.66671C10.6667 5.0349 10.3682 5.33337 10 5.33337H6.00002C5.63183 5.33337 5.33335 5.0349 5.33335 4.66671C5.33335 4.29852 5.63183 4.00004 6.00002 4.00004H10Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M8.00002 6.66671C8.36821 6.66671 8.66669 6.96518 8.66669 7.33337C8.66669 7.70156 8.36821 8.00004 8.00002 8.00004H6.00002C5.63183 8.00004 5.33335 7.70156 5.33335 7.33337C5.33335 6.96518 5.63183 6.66671 6.00002 6.66671H8.00002Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M12.827 10.7902L12.3624 9.58224C12.3048 9.43231 12.1607 9.33337 12 9.33337C11.8394 9.33337 11.6953 9.43231 11.6376 9.58224L11.173 10.7902C11.1054 10.9662 10.9662 11.1054 10.7902 11.173L9.58222 11.6376C9.43229 11.6953 9.33335 11.8394 9.33335 12C9.33335 12.1607 9.43229 12.3048 9.58222 12.3624L10.7902 12.827C10.9662 12.8947 11.1054 13.0338 11.173 13.2099L11.6376 14.4178C11.6953 14.5678 11.8394 14.6667 12 14.6667C12.1607 14.6667 12.3048 14.5678 12.3624 14.4178L12.827 13.2099C12.8947 13.0338 13.0338 12.8947 13.2099 12.827L14.4178 12.3624C14.5678 12.3048 14.6667 12.1607 14.6667 12C14.6667 11.8394 14.5678 11.6953 14.4178 11.6376L13.2099 11.173C13.0338 11.1054 12.8947 10.9662 12.827 10.7902Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "SearchLinesSparkle"
|
||||
}
|
||||
@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
import * as React from 'react'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import data from './SearchLinesSparkle.json'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'SearchLinesSparkle'
|
||||
|
||||
export default Icon
|
||||
@ -11,5 +11,6 @@ export { default as HighQuality } from './HighQuality'
|
||||
export { default as HybridSearch } from './HybridSearch'
|
||||
export { default as ParentChildChunk } from './ParentChildChunk'
|
||||
export { default as QuestionAndAnswer } from './QuestionAndAnswer'
|
||||
export { default as SearchLinesSparkle } from './SearchLinesSparkle'
|
||||
export { default as SearchMenu } from './SearchMenu'
|
||||
export { default as VectorSearch } from './VectorSearch'
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { PreProcessingRule } from '@/models/datasets'
|
||||
import type { PreProcessingRule, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
|
||||
import {
|
||||
RiAlertFill,
|
||||
RiSearchEyeLine,
|
||||
@ -12,6 +12,7 @@ import Button from '@/app/components/base/button'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import SettingCog from '../../assets/setting-gear-mod.svg'
|
||||
@ -52,6 +53,9 @@ type GeneralChunkingOptionsProps = {
|
||||
onReset: () => void
|
||||
// Locale
|
||||
locale: string
|
||||
showSummaryIndexSetting?: boolean
|
||||
summaryIndexSetting?: SummaryIndexSettingType
|
||||
onSummaryIndexSettingChange?: (payload: SummaryIndexSettingType) => void
|
||||
}
|
||||
|
||||
export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
|
||||
@ -74,6 +78,9 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
|
||||
onPreview,
|
||||
onReset,
|
||||
locale,
|
||||
showSummaryIndexSetting,
|
||||
summaryIndexSetting,
|
||||
onSummaryIndexSettingChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
@ -146,6 +153,17 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
{
|
||||
showSummaryIndexSetting && (
|
||||
<div className="mt-3">
|
||||
<SummaryIndexSetting
|
||||
entry="create-document"
|
||||
summaryIndexSetting={summaryIndexSetting}
|
||||
onSummaryIndexSettingChange={onSummaryIndexSettingChange}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{IS_CE_EDITION && (
|
||||
<>
|
||||
<Divider type="horizontal" className="my-4 bg-divider-subtle" />
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { ParentChildConfig } from '../hooks'
|
||||
import type { ParentMode, PreProcessingRule } from '@/models/datasets'
|
||||
import type { ParentMode, PreProcessingRule, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
|
||||
import { RiSearchEyeLine } from '@remixicon/react'
|
||||
import Image from 'next/image'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@ -11,6 +11,7 @@ import Checkbox from '@/app/components/base/checkbox'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import RadioCard from '@/app/components/base/radio-card'
|
||||
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import FileList from '../../assets/file-list-3-fill.svg'
|
||||
import Note from '../../assets/note-mod.svg'
|
||||
@ -31,6 +32,8 @@ type ParentChildOptionsProps = {
|
||||
// State
|
||||
parentChildConfig: ParentChildConfig
|
||||
rules: PreProcessingRule[]
|
||||
summaryIndexSetting?: SummaryIndexSettingType
|
||||
onSummaryIndexSettingChange?: (payload: SummaryIndexSettingType) => void
|
||||
currentDocForm: ChunkingMode
|
||||
// Flags
|
||||
isActive: boolean
|
||||
@ -46,11 +49,13 @@ type ParentChildOptionsProps = {
|
||||
onRuleToggle: (id: string) => void
|
||||
onPreview: () => void
|
||||
onReset: () => void
|
||||
showSummaryIndexSetting?: boolean
|
||||
}
|
||||
|
||||
export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
|
||||
parentChildConfig,
|
||||
rules,
|
||||
summaryIndexSetting,
|
||||
currentDocForm: _currentDocForm,
|
||||
isActive,
|
||||
isInUpload,
|
||||
@ -62,8 +67,10 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
|
||||
onChildDelimiterChange,
|
||||
onChildMaxLengthChange,
|
||||
onRuleToggle,
|
||||
onSummaryIndexSettingChange,
|
||||
onPreview,
|
||||
onReset,
|
||||
showSummaryIndexSetting,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
@ -183,6 +190,17 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
{
|
||||
showSummaryIndexSetting && (
|
||||
<div className="mt-3">
|
||||
<SummaryIndexSetting
|
||||
entry="create-document"
|
||||
summaryIndexSetting={summaryIndexSetting}
|
||||
onSummaryIndexSettingChange={onSummaryIndexSettingChange}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -14,6 +14,7 @@ import { ChunkingMode } from '@/models/datasets'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { ChunkContainer, QAPreview } from '../../../chunk'
|
||||
import PreviewDocumentPicker from '../../../common/document-picker/preview-document-picker'
|
||||
import SummaryLabel from '../../../documents/detail/completed/common/summary-label'
|
||||
import { PreviewSlice } from '../../../formatted-text/flavours/preview-slice'
|
||||
import { FormattedText } from '../../../formatted-text/formatted'
|
||||
import PreviewContainer from '../../../preview/container'
|
||||
@ -99,6 +100,7 @@ export const PreviewPanel: FC<PreviewPanelProps> = ({
|
||||
characterCount={item.content.length}
|
||||
>
|
||||
{item.content}
|
||||
{item.summary && <SummaryLabel summary={item.summary} />}
|
||||
</ChunkContainer>
|
||||
))
|
||||
)}
|
||||
@ -131,6 +133,7 @@ export const PreviewPanel: FC<PreviewPanelProps> = ({
|
||||
)
|
||||
})}
|
||||
</FormattedText>
|
||||
{item.summary && <SummaryLabel summary={item.summary} />}
|
||||
</ChunkContainer>
|
||||
)
|
||||
})
|
||||
|
||||
@ -9,6 +9,7 @@ import type {
|
||||
CustomFile,
|
||||
FullDocumentDetail,
|
||||
ProcessRule,
|
||||
SummaryIndexSetting as SummaryIndexSettingType,
|
||||
} from '@/models/datasets'
|
||||
import type { RetrievalConfig, RETRIEVE_METHOD } from '@/types/app'
|
||||
import { useCallback } from 'react'
|
||||
@ -141,6 +142,7 @@ export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
|
||||
retrievalConfig: RetrievalConfig,
|
||||
embeddingModel: DefaultModel,
|
||||
indexingTechnique: string,
|
||||
summaryIndexSetting?: SummaryIndexSettingType,
|
||||
): CreateDocumentReq | null => {
|
||||
if (isSetting) {
|
||||
return {
|
||||
@ -148,6 +150,7 @@ export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
|
||||
doc_form: currentDocForm,
|
||||
doc_language: docLanguage,
|
||||
process_rule: processRule,
|
||||
summary_index_setting: summaryIndexSetting,
|
||||
retrieval_model: retrievalConfig,
|
||||
embedding_model: embeddingModel.model,
|
||||
embedding_model_provider: embeddingModel.provider,
|
||||
@ -164,6 +167,7 @@ export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
|
||||
},
|
||||
indexing_technique: indexingTechnique,
|
||||
process_rule: processRule,
|
||||
summary_index_setting: summaryIndexSetting,
|
||||
doc_form: currentDocForm,
|
||||
doc_language: docLanguage,
|
||||
retrieval_model: retrievalConfig,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import type { ParentMode, PreProcessingRule, ProcessRule, Rules } from '@/models/datasets'
|
||||
import { useCallback, useState } from 'react'
|
||||
import type { ParentMode, PreProcessingRule, ProcessRule, Rules, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
|
||||
import { useCallback, useRef, useState } from 'react'
|
||||
import { ChunkingMode, ProcessMode } from '@/models/datasets'
|
||||
import escape from './escape'
|
||||
import unescape from './unescape'
|
||||
@ -39,10 +39,11 @@ export const defaultParentChildConfig: ParentChildConfig = {
|
||||
|
||||
export type UseSegmentationStateOptions = {
|
||||
initialSegmentationType?: ProcessMode
|
||||
initialSummaryIndexSetting?: SummaryIndexSettingType
|
||||
}
|
||||
|
||||
export const useSegmentationState = (options: UseSegmentationStateOptions = {}) => {
|
||||
const { initialSegmentationType } = options
|
||||
const { initialSegmentationType, initialSummaryIndexSetting } = options
|
||||
|
||||
// Segmentation type (general or parent-child)
|
||||
const [segmentationType, setSegmentationType] = useState<ProcessMode>(
|
||||
@ -58,6 +59,15 @@ export const useSegmentationState = (options: UseSegmentationStateOptions = {})
|
||||
// Pre-processing rules
|
||||
const [rules, setRules] = useState<PreProcessingRule[]>([])
|
||||
const [defaultConfig, setDefaultConfig] = useState<Rules>()
|
||||
const [summaryIndexSetting, setSummaryIndexSetting] = useState<SummaryIndexSettingType | undefined>(initialSummaryIndexSetting)
|
||||
const summaryIndexSettingRef = useRef<SummaryIndexSettingType | undefined>(initialSummaryIndexSetting)
|
||||
const handleSummaryIndexSettingChange = useCallback((payload: SummaryIndexSettingType) => {
|
||||
setSummaryIndexSetting((prev) => {
|
||||
const newSetting = { ...prev, ...payload }
|
||||
summaryIndexSettingRef.current = newSetting
|
||||
return newSetting
|
||||
})
|
||||
}, [])
|
||||
|
||||
// Parent-child config
|
||||
const [parentChildConfig, setParentChildConfig] = useState<ParentChildConfig>(defaultParentChildConfig)
|
||||
@ -134,6 +144,7 @@ export const useSegmentationState = (options: UseSegmentationStateOptions = {})
|
||||
},
|
||||
},
|
||||
mode: 'hierarchical',
|
||||
summary_index_setting: summaryIndexSettingRef.current,
|
||||
} as ProcessRule
|
||||
}
|
||||
|
||||
@ -147,6 +158,7 @@ export const useSegmentationState = (options: UseSegmentationStateOptions = {})
|
||||
},
|
||||
},
|
||||
mode: segmentationType,
|
||||
summary_index_setting: summaryIndexSettingRef.current,
|
||||
} as ProcessRule
|
||||
}, [rules, parentChildConfig, segmentIdentifier, maxChunkLength, overlap, segmentationType])
|
||||
|
||||
@ -204,6 +216,8 @@ export const useSegmentationState = (options: UseSegmentationStateOptions = {})
|
||||
defaultConfig,
|
||||
setDefaultConfig,
|
||||
toggleRule,
|
||||
summaryIndexSetting,
|
||||
handleSummaryIndexSettingChange,
|
||||
|
||||
// Parent-child config
|
||||
parentChildConfig,
|
||||
|
||||
@ -65,7 +65,9 @@ const StepTwo: FC<StepTwoProps> = ({
|
||||
// Custom hooks
|
||||
const segmentation = useSegmentationState({
|
||||
initialSegmentationType: currentDataset?.doc_form === ChunkingMode.parentChild ? ProcessMode.parentChild : ProcessMode.general,
|
||||
initialSummaryIndexSetting: currentDataset?.summary_index_setting,
|
||||
})
|
||||
const showSummaryIndexSetting = !currentDataset
|
||||
const indexing = useIndexingConfig({
|
||||
initialIndexType: propsIndexingType,
|
||||
initialEmbeddingModel: currentDataset?.embedding_model ? { provider: currentDataset.embedding_model_provider, model: currentDataset.embedding_model } : undefined,
|
||||
@ -156,7 +158,7 @@ const StepTwo: FC<StepTwoProps> = ({
|
||||
})
|
||||
if (!isValid)
|
||||
return
|
||||
const params = creation.buildCreationParams(currentDocForm, docLanguage, segmentation.getProcessRule(currentDocForm), indexing.retrievalConfig, indexing.embeddingModel, indexing.getIndexingTechnique())
|
||||
const params = creation.buildCreationParams(currentDocForm, docLanguage, segmentation.getProcessRule(currentDocForm), indexing.retrievalConfig, indexing.embeddingModel, indexing.getIndexingTechnique(), segmentation.summaryIndexSetting)
|
||||
if (!params)
|
||||
return
|
||||
await creation.executeCreation(params, indexing.indexType, indexing.retrievalConfig)
|
||||
@ -217,6 +219,9 @@ const StepTwo: FC<StepTwoProps> = ({
|
||||
onPreview={updatePreview}
|
||||
onReset={segmentation.resetToDefaults}
|
||||
locale={locale}
|
||||
showSummaryIndexSetting={showSummaryIndexSetting}
|
||||
summaryIndexSetting={segmentation.summaryIndexSetting}
|
||||
onSummaryIndexSettingChange={segmentation.handleSummaryIndexSettingChange}
|
||||
/>
|
||||
)}
|
||||
{showParentChildOption && (
|
||||
@ -236,6 +241,9 @@ const StepTwo: FC<StepTwoProps> = ({
|
||||
onRuleToggle={segmentation.toggleRule}
|
||||
onPreview={updatePreview}
|
||||
onReset={segmentation.resetToDefaults}
|
||||
showSummaryIndexSetting={showSummaryIndexSetting}
|
||||
summaryIndexSetting={segmentation.summaryIndexSetting}
|
||||
onSummaryIndexSettingChange={segmentation.handleSummaryIndexSettingChange}
|
||||
/>
|
||||
)}
|
||||
<Divider className="my-5" />
|
||||
|
||||
@ -73,6 +73,12 @@ const createDefaultProps = (overrides: Partial<Parameters<typeof WaterCrawl>[0]>
|
||||
describe('WaterCrawl', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true })
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.runOnlyPendingTimers()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
// Tests for initial component rendering
|
||||
|
||||
@ -30,12 +30,13 @@ import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '
|
||||
import useTimestamp from '@/hooks/use-timestamp'
|
||||
import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets'
|
||||
import { DatasourceType } from '@/models/pipeline'
|
||||
import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useDocumentDisable, useDocumentDownloadZip, useDocumentEnable } from '@/service/knowledge/use-document'
|
||||
import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useDocumentDisable, useDocumentDownloadZip, useDocumentEnable, useDocumentSummary } from '@/service/knowledge/use-document'
|
||||
import { asyncRunSafe } from '@/utils'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { downloadBlob } from '@/utils/download'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
import BatchAction from '../detail/completed/common/batch-action'
|
||||
import SummaryStatus from '../detail/completed/common/summary-status'
|
||||
import StatusItem from '../status-item'
|
||||
import s from '../style.module.css'
|
||||
import Operations from './operations'
|
||||
@ -219,6 +220,7 @@ const DocumentList: FC<IDocumentListProps> = ({
|
||||
onSelectedIdChange(uniq([...selectedIds, ...localDocs.map(doc => doc.id)]))
|
||||
}, [isAllSelected, localDocs, onSelectedIdChange, selectedIds])
|
||||
const { mutateAsync: archiveDocument } = useDocumentArchive()
|
||||
const { mutateAsync: generateSummary } = useDocumentSummary()
|
||||
const { mutateAsync: enableDocument } = useDocumentEnable()
|
||||
const { mutateAsync: disableDocument } = useDocumentDisable()
|
||||
const { mutateAsync: deleteDocument } = useDocumentDelete()
|
||||
@ -232,6 +234,9 @@ const DocumentList: FC<IDocumentListProps> = ({
|
||||
case DocumentActionType.archive:
|
||||
opApi = archiveDocument
|
||||
break
|
||||
case DocumentActionType.summary:
|
||||
opApi = generateSummary
|
||||
break
|
||||
case DocumentActionType.enable:
|
||||
opApi = enableDocument
|
||||
break
|
||||
@ -444,6 +449,13 @@ const DocumentList: FC<IDocumentListProps> = ({
|
||||
>
|
||||
<span className="grow-1 truncate text-sm">{doc.name}</span>
|
||||
</Tooltip>
|
||||
{
|
||||
doc.summary_index_status && (
|
||||
<div className="ml-1 hidden shrink-0 group-hover:flex">
|
||||
<SummaryStatus status={doc.summary_index_status} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<div className="hidden shrink-0 group-hover:ml-auto group-hover:flex">
|
||||
<Tooltip
|
||||
popupContent={t('list.table.rename', { ns: 'datasetDocuments' })}
|
||||
@ -496,6 +508,7 @@ const DocumentList: FC<IDocumentListProps> = ({
|
||||
className="absolute bottom-16 left-0 z-20"
|
||||
selectedIds={selectedIds}
|
||||
onArchive={handleAction(DocumentActionType.archive)}
|
||||
onBatchSummary={handleAction(DocumentActionType.summary)}
|
||||
onBatchEnable={handleAction(DocumentActionType.enable)}
|
||||
onBatchDisable={handleAction(DocumentActionType.disable)}
|
||||
onBatchDownload={downloadableSelectedIds.length > 0 ? handleBatchDownload : undefined}
|
||||
|
||||
@ -15,6 +15,7 @@ vi.mock('@/service/knowledge/use-document', () => ({
|
||||
useSyncWebsite: () => ({ mutateAsync: vi.fn().mockResolvedValue({}) }),
|
||||
useDocumentPause: () => ({ mutateAsync: vi.fn().mockResolvedValue({}) }),
|
||||
useDocumentResume: () => ({ mutateAsync: vi.fn().mockResolvedValue({}) }),
|
||||
useDocumentSummary: () => ({ mutateAsync: vi.fn().mockResolvedValue({}) }),
|
||||
}))
|
||||
|
||||
// Mock utils
|
||||
|
||||
@ -21,6 +21,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import CustomPopover from '@/app/components/base/popover'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
@ -34,6 +35,7 @@ import {
|
||||
useDocumentEnable,
|
||||
useDocumentPause,
|
||||
useDocumentResume,
|
||||
useDocumentSummary,
|
||||
useDocumentUnArchive,
|
||||
useSyncDocument,
|
||||
useSyncWebsite,
|
||||
@ -87,6 +89,7 @@ const Operations = ({
|
||||
const { mutateAsync: downloadDocument, isPending: isDownloading } = useDocumentDownload()
|
||||
const { mutateAsync: syncDocument } = useSyncDocument()
|
||||
const { mutateAsync: syncWebsite } = useSyncWebsite()
|
||||
const { mutateAsync: generateSummary } = useDocumentSummary()
|
||||
const { mutateAsync: pauseDocument } = useDocumentPause()
|
||||
const { mutateAsync: resumeDocument } = useDocumentResume()
|
||||
const isListScene = scene === 'list'
|
||||
@ -112,6 +115,9 @@ const Operations = ({
|
||||
else
|
||||
opApi = syncWebsite
|
||||
break
|
||||
case 'summary':
|
||||
opApi = generateSummary
|
||||
break
|
||||
case 'pause':
|
||||
opApi = pauseDocument
|
||||
break
|
||||
@ -257,6 +263,10 @@ const Operations = ({
|
||||
<span className={s.actionName}>{t('list.action.sync', { ns: 'datasetDocuments' })}</span>
|
||||
</div>
|
||||
)}
|
||||
<div className={s.actionItem} onClick={() => onOperate('summary')}>
|
||||
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
|
||||
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
|
||||
</div>
|
||||
<Divider className="my-1" />
|
||||
</>
|
||||
)}
|
||||
|
||||
@ -8,6 +8,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { SkeletonContainer, SkeletonPoint, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton'
|
||||
import SummaryLabel from '@/app/components/datasets/documents/detail/completed/common/summary-label'
|
||||
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import { DatasourceType } from '@/models/pipeline'
|
||||
@ -181,6 +182,7 @@ const ChunkPreview = ({
|
||||
characterCount={item.content.length}
|
||||
>
|
||||
{item.content}
|
||||
{item.summary && <SummaryLabel summary={item.summary} />}
|
||||
</ChunkContainer>
|
||||
))
|
||||
)}
|
||||
@ -207,6 +209,7 @@ const ChunkPreview = ({
|
||||
/>
|
||||
)
|
||||
})}
|
||||
{item.summary && <SummaryLabel summary={item.summary} />}
|
||||
</FormattedText>
|
||||
</ChunkContainer>
|
||||
)
|
||||
|
||||
@ -6,6 +6,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
const i18nPrefix = 'batchAction'
|
||||
@ -16,6 +17,7 @@ type IBatchActionProps = {
|
||||
onBatchDisable: () => void
|
||||
onBatchDownload?: () => void
|
||||
onBatchDelete: () => Promise<void>
|
||||
onBatchSummary?: () => void
|
||||
onArchive?: () => void
|
||||
onEditMetadata?: () => void
|
||||
onBatchReIndex?: () => void
|
||||
@ -27,6 +29,7 @@ const BatchAction: FC<IBatchActionProps> = ({
|
||||
selectedIds,
|
||||
onBatchEnable,
|
||||
onBatchDisable,
|
||||
onBatchSummary,
|
||||
onBatchDownload,
|
||||
onArchive,
|
||||
onBatchDelete,
|
||||
@ -84,7 +87,16 @@ const BatchAction: FC<IBatchActionProps> = ({
|
||||
<span className="px-0.5">{t('metadata.metadata', { ns: 'dataset' })}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{onBatchSummary && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="gap-x-0.5 px-3"
|
||||
onClick={onBatchSummary}
|
||||
>
|
||||
<SearchLinesSparkle className="size-4" />
|
||||
<span className="px-0.5">{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
|
||||
</Button>
|
||||
)}
|
||||
{onArchive && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
|
||||
@ -0,0 +1,26 @@
|
||||
import { memo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type SummaryLabelProps = {
|
||||
summary?: string
|
||||
className?: string
|
||||
}
|
||||
const SummaryLabel = ({
|
||||
summary,
|
||||
className,
|
||||
}: SummaryLabelProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className={cn('space-y-1', className)}>
|
||||
<div className="system-xs-medium-uppercase mt-2 flex items-center justify-between text-text-tertiary">
|
||||
{t('segment.summary', { ns: 'datasetDocuments' })}
|
||||
<div className="ml-2 h-px grow bg-divider-regular"></div>
|
||||
</div>
|
||||
<div className="body-xs-regular text-text-tertiary">{summary}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(SummaryLabel)
|
||||
@ -0,0 +1,37 @@
|
||||
import { memo, useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
||||
type SummaryStatusProps = {
|
||||
status: string
|
||||
}
|
||||
|
||||
const SummaryStatus = ({ status }: SummaryStatusProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const tip = useMemo(() => {
|
||||
if (status === 'SUMMARIZING') {
|
||||
return t('list.summary.generatingSummary', { ns: 'datasetDocuments' })
|
||||
}
|
||||
return ''
|
||||
}, [status, t])
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
popupContent={tip}
|
||||
>
|
||||
{
|
||||
status === 'SUMMARIZING' && (
|
||||
<Badge className="border-text-accent-secondary text-text-accent-secondary">
|
||||
<SearchLinesSparkle className="mr-0.5 h-3 w-3" />
|
||||
<span>{t('list.summary.generating', { ns: 'datasetDocuments' })}</span>
|
||||
</Badge>
|
||||
)
|
||||
}
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(SummaryStatus)
|
||||
@ -0,0 +1,35 @@
|
||||
import { memo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Textarea from 'react-textarea-autosize'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type SummaryTextProps = {
|
||||
value?: string
|
||||
onChange?: (value: string) => void
|
||||
disabled?: boolean
|
||||
}
|
||||
const SummaryText = ({
|
||||
value,
|
||||
onChange,
|
||||
disabled,
|
||||
}: SummaryTextProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
<div className="system-xs-medium-uppercase text-text-tertiary">{t('segment.summary', { ns: 'datasetDocuments' })}</div>
|
||||
<Textarea
|
||||
className={cn(
|
||||
'body-sm-regular w-full resize-none bg-transparent leading-6 text-text-secondary outline-none',
|
||||
)}
|
||||
placeholder={t('segment.summaryPlaceholder', { ns: 'datasetDocuments' })}
|
||||
minRows={1}
|
||||
value={value ?? ''}
|
||||
onChange={e => onChange?.(e.target.value)}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(SummaryText)
|
||||
@ -22,6 +22,7 @@ type DrawerGroupProps = {
|
||||
answer: string,
|
||||
keywords: string[],
|
||||
attachments: FileEntity[],
|
||||
summary?: string,
|
||||
needRegenerate?: boolean,
|
||||
) => Promise<void>
|
||||
isRegenerationModalOpen: boolean
|
||||
|
||||
@ -614,7 +614,7 @@ describe('useSegmentListData', () => {
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdateSegment('seg-1', 'content', '', [], [], true)
|
||||
await result.current.handleUpdateSegment('seg-1', 'content', '', [], [], 'summary', true)
|
||||
})
|
||||
|
||||
expect(onCloseSegmentDetail).not.toHaveBeenCalled()
|
||||
|
||||
@ -53,6 +53,7 @@ export type UseSegmentListDataReturn = {
|
||||
answer: string,
|
||||
keywords: string[],
|
||||
attachments: FileEntity[],
|
||||
summary?: string,
|
||||
needRegenerate?: boolean,
|
||||
) => Promise<void>
|
||||
resetList: () => void
|
||||
@ -248,6 +249,7 @@ export const useSegmentListData = (options: UseSegmentListDataOptions): UseSegme
|
||||
answer: string,
|
||||
keywords: string[],
|
||||
attachments: FileEntity[],
|
||||
summary?: string,
|
||||
needRegenerate = false,
|
||||
) => {
|
||||
const params: SegmentUpdater = { content: '', attachment_ids: [] }
|
||||
@ -285,6 +287,8 @@ export const useSegmentListData = (options: UseSegmentListDataOptions): UseSegme
|
||||
params.attachment_ids = attachments.map(item => item.uploadedId!)
|
||||
}
|
||||
|
||||
params.summary = summary ?? ''
|
||||
|
||||
if (needRegenerate)
|
||||
params.regenerate_child_chunks = needRegenerate
|
||||
|
||||
@ -302,6 +306,7 @@ export const useSegmentListData = (options: UseSegmentListDataOptions): UseSegme
|
||||
sign_content: res.data.sign_content,
|
||||
keywords: res.data.keywords,
|
||||
attachments: res.data.attachments,
|
||||
summary: res.data.summary,
|
||||
word_count: res.data.word_count,
|
||||
hit_count: res.data.hit_count,
|
||||
enabled: res.data.enabled,
|
||||
|
||||
@ -19,13 +19,14 @@ import { useDocumentContext } from '../../context'
|
||||
import ChildSegmentList from '../child-segment-list'
|
||||
import Dot from '../common/dot'
|
||||
import { SegmentIndexTag } from '../common/segment-index-tag'
|
||||
import SummaryLabel from '../common/summary-label'
|
||||
import Tag from '../common/tag'
|
||||
import ParentChunkCardSkeleton from '../skeleton/parent-chunk-card-skeleton'
|
||||
import ChunkContent from './chunk-content'
|
||||
|
||||
type ISegmentCardProps = {
|
||||
loading: boolean
|
||||
detail?: SegmentDetailModel & { document?: { name: string } }
|
||||
detail?: SegmentDetailModel & { document?: { name: string }, status?: string }
|
||||
onClick?: () => void
|
||||
onChangeSwitch?: (enabled: boolean, segId?: string) => Promise<void>
|
||||
onDelete?: (segId: string) => Promise<void>
|
||||
@ -43,7 +44,7 @@ type ISegmentCardProps = {
|
||||
}
|
||||
|
||||
const SegmentCard: FC<ISegmentCardProps> = ({
|
||||
detail = {},
|
||||
detail = { status: '' },
|
||||
onClick,
|
||||
onChangeSwitch,
|
||||
onDelete,
|
||||
@ -67,6 +68,7 @@ const SegmentCard: FC<ISegmentCardProps> = ({
|
||||
word_count,
|
||||
hit_count,
|
||||
answer,
|
||||
summary,
|
||||
keywords,
|
||||
child_chunks = [],
|
||||
created_at,
|
||||
@ -237,6 +239,11 @@ const SegmentCard: FC<ISegmentCardProps> = ({
|
||||
className={contentOpacity}
|
||||
/>
|
||||
{images.length > 0 && <ImageList images={images} size="md" className="py-1" />}
|
||||
{
|
||||
summary && (
|
||||
<SummaryLabel summary={summary} className="mt-2" />
|
||||
)
|
||||
}
|
||||
{isGeneralMode && (
|
||||
<div className={cn('flex flex-wrap items-center gap-2 py-1.5', contentOpacity)}>
|
||||
{keywords?.map(keyword => <Tag key={keyword} text={keyword} />)}
|
||||
|
||||
@ -356,6 +356,8 @@ describe('SegmentDetail', () => {
|
||||
expect.any(String),
|
||||
expect.any(Array),
|
||||
expect.any(Array),
|
||||
expect.any(String),
|
||||
expect.any(Boolean),
|
||||
)
|
||||
})
|
||||
|
||||
@ -545,6 +547,8 @@ describe('SegmentDetail', () => {
|
||||
expect.any(String),
|
||||
expect.any(Array),
|
||||
expect.arrayContaining([expect.objectContaining({ id: 'new-attachment' })]),
|
||||
expect.any(String),
|
||||
expect.any(Boolean),
|
||||
)
|
||||
})
|
||||
|
||||
@ -585,6 +589,7 @@ describe('SegmentDetail', () => {
|
||||
expect.any(String),
|
||||
expect.any(Array),
|
||||
expect.any(Array),
|
||||
expect.any(String),
|
||||
true,
|
||||
)
|
||||
})
|
||||
|
||||
@ -25,6 +25,7 @@ import Dot from './common/dot'
|
||||
import Keywords from './common/keywords'
|
||||
import RegenerationModal from './common/regeneration-modal'
|
||||
import { SegmentIndexTag } from './common/segment-index-tag'
|
||||
import SummaryText from './common/summary-text'
|
||||
import { useSegmentListContext } from './index'
|
||||
|
||||
type ISegmentDetailProps = {
|
||||
@ -35,6 +36,7 @@ type ISegmentDetailProps = {
|
||||
a: string,
|
||||
k: string[],
|
||||
attachments: FileEntity[],
|
||||
summary?: string,
|
||||
needRegenerate?: boolean,
|
||||
) => void
|
||||
onCancel: () => void
|
||||
@ -57,6 +59,7 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
|
||||
const { t } = useTranslation()
|
||||
const [question, setQuestion] = useState(isEditMode ? segInfo?.content || '' : segInfo?.sign_content || '')
|
||||
const [answer, setAnswer] = useState(segInfo?.answer || '')
|
||||
const [summary, setSummary] = useState(segInfo?.summary || '')
|
||||
const [attachments, setAttachments] = useState<FileEntity[]>(() => {
|
||||
return segInfo?.attachments?.map(item => ({
|
||||
id: uuid4(),
|
||||
@ -91,8 +94,8 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
|
||||
}, [onCancel])
|
||||
|
||||
const handleSave = useCallback(() => {
|
||||
onUpdate(segInfo?.id || '', question, answer, keywords, attachments)
|
||||
}, [onUpdate, segInfo?.id, question, answer, keywords, attachments])
|
||||
onUpdate(segInfo?.id || '', question, answer, keywords, attachments, summary, false)
|
||||
}, [onUpdate, segInfo?.id, question, answer, keywords, attachments, summary])
|
||||
|
||||
const handleRegeneration = useCallback(() => {
|
||||
setShowRegenerationModal(true)
|
||||
@ -111,8 +114,8 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
|
||||
}, [onCancel, onModalStateChange])
|
||||
|
||||
const onConfirmRegeneration = useCallback(() => {
|
||||
onUpdate(segInfo?.id || '', question, answer, keywords, attachments, true)
|
||||
}, [onUpdate, segInfo?.id, question, answer, keywords, attachments])
|
||||
onUpdate(segInfo?.id || '', question, answer, keywords, attachments, summary, true)
|
||||
}, [onUpdate, segInfo?.id, question, answer, keywords, attachments, summary])
|
||||
|
||||
const onAttachmentsChange = useCallback((attachments: FileEntity[]) => {
|
||||
setAttachments(attachments)
|
||||
@ -197,6 +200,11 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
|
||||
value={attachments}
|
||||
onChange={onAttachmentsChange}
|
||||
/>
|
||||
<SummaryText
|
||||
value={summary}
|
||||
onChange={summary => setSummary(summary)}
|
||||
disabled={!isEditMode}
|
||||
/>
|
||||
{isECOIndexing && (
|
||||
<Keywords
|
||||
className="w-full"
|
||||
|
||||
@ -1 +1 @@
|
||||
export type OperationName = 'delete' | 'archive' | 'enable' | 'disable' | 'sync' | 'un_archive' | 'pause' | 'resume'
|
||||
export type OperationName = 'delete' | 'archive' | 'enable' | 'disable' | 'sync' | 'un_archive' | 'pause' | 'resume' | 'summary'
|
||||
|
||||
@ -12,6 +12,7 @@ import { cn } from '@/utils/classnames'
|
||||
import ImageList from '../../common/image-list'
|
||||
import Dot from '../../documents/detail/completed/common/dot'
|
||||
import { SegmentIndexTag } from '../../documents/detail/completed/common/segment-index-tag'
|
||||
import SummaryText from '../../documents/detail/completed/common/summary-text'
|
||||
import ChildChunksItem from './child-chunks-item'
|
||||
import Mask from './mask'
|
||||
import Score from './score'
|
||||
@ -28,7 +29,7 @@ const ChunkDetailModal = ({
|
||||
onHide,
|
||||
}: ChunkDetailModalProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { segment, score, child_chunks, files } = payload
|
||||
const { segment, score, child_chunks, files, summary } = payload
|
||||
const { position, content, sign_content, keywords, document, answer } = segment
|
||||
const isParentChildRetrieval = !!(child_chunks && child_chunks.length > 0)
|
||||
const extension = document.name.split('.').slice(-1)[0] as FileAppearanceTypeEnum
|
||||
@ -104,11 +105,14 @@ const ChunkDetailModal = ({
|
||||
{/* Mask */}
|
||||
<Mask className="absolute inset-x-0 bottom-0" />
|
||||
</div>
|
||||
{(showImages || showKeywords) && (
|
||||
{(showImages || showKeywords || !!summary) && (
|
||||
<div className="flex flex-col gap-y-3 pt-3">
|
||||
{showImages && (
|
||||
<ImageList images={images} size="md" className="py-1" />
|
||||
)}
|
||||
{!!summary && (
|
||||
<SummaryText value={summary} disabled />
|
||||
)}
|
||||
{showKeywords && (
|
||||
<div className="flex flex-col gap-y-1">
|
||||
<div className="text-xs font-medium uppercase text-text-tertiary">{t(`${i18nPrefix}keyword`, { ns: 'datasetHitTesting' })}</div>
|
||||
|
||||
@ -7,6 +7,7 @@ import * as React from 'react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Markdown } from '@/app/components/base/markdown'
|
||||
import SummaryLabel from '@/app/components/datasets/documents/detail/completed/common/summary-label'
|
||||
import Tag from '@/app/components/datasets/documents/detail/completed/common/tag'
|
||||
import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type'
|
||||
import { cn } from '@/utils/classnames'
|
||||
@ -25,7 +26,7 @@ const ResultItem = ({
|
||||
payload,
|
||||
}: ResultItemProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { segment, score, child_chunks, files } = payload
|
||||
const { segment, score, child_chunks, files, summary } = payload
|
||||
const data = segment
|
||||
const { position, word_count, content, sign_content, keywords, document } = data
|
||||
const isParentChildRetrieval = !!(child_chunks && child_chunks.length > 0)
|
||||
@ -98,6 +99,9 @@ const ResultItem = ({
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{summary && (
|
||||
<SummaryLabel summary={summary} className="mt-2" />
|
||||
)}
|
||||
</div>
|
||||
{/* Foot */}
|
||||
<ResultItemFooter docType={fileType} docTitle={document.name} showDetailModal={showDetailModal} />
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
import type { AppIconSelection } from '@/app/components/base/app-icon-picker'
|
||||
import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { Member } from '@/models/common'
|
||||
import type { IconInfo } from '@/models/datasets'
|
||||
import type { IconInfo, SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
|
||||
import type { AppIconType, RetrievalConfig } from '@/types/app'
|
||||
import { RiAlertFill } from '@remixicon/react'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
@ -33,6 +33,7 @@ import RetrievalSettings from '../../external-knowledge-base/create/RetrievalSet
|
||||
import ChunkStructure from '../chunk-structure'
|
||||
import IndexMethod from '../index-method'
|
||||
import PermissionSelector from '../permission-selector'
|
||||
import SummaryIndexSetting from '../summary-index-setting'
|
||||
import { checkShowMultiModalTip } from '../utils'
|
||||
|
||||
const rowClass = 'flex gap-x-1'
|
||||
@ -76,6 +77,12 @@ const Form = () => {
|
||||
model: '',
|
||||
},
|
||||
)
|
||||
const [summaryIndexSetting, setSummaryIndexSetting] = useState(currentDataset?.summary_index_setting)
|
||||
const handleSummaryIndexSettingChange = useCallback((payload: SummaryIndexSettingType) => {
|
||||
setSummaryIndexSetting((prev) => {
|
||||
return { ...prev, ...payload }
|
||||
})
|
||||
}, [])
|
||||
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: membersData } = useMembers()
|
||||
@ -167,6 +174,7 @@ const Form = () => {
|
||||
},
|
||||
}),
|
||||
keyword_number: keywordNumber,
|
||||
summary_index_setting: summaryIndexSetting,
|
||||
},
|
||||
} as any
|
||||
if (permission === DatasetPermission.partialMembers) {
|
||||
@ -348,6 +356,23 @@ const Form = () => {
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{
|
||||
indexMethod === IndexingType.QUALIFIED
|
||||
&& [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
|
||||
&& (
|
||||
<>
|
||||
<Divider
|
||||
type="horizontal"
|
||||
className="my-1 h-px bg-divider-subtle"
|
||||
/>
|
||||
<SummaryIndexSetting
|
||||
entry="dataset-settings"
|
||||
summaryIndexSetting={summaryIndexSetting}
|
||||
onSummaryIndexSettingChange={handleSummaryIndexSettingChange}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
{/* Retrieval Method Config */}
|
||||
{currentDataset?.provider === 'external'
|
||||
? (
|
||||
|
||||
228
web/app/components/datasets/settings/summary-index-setting.tsx
Normal file
228
web/app/components/datasets/settings/summary-index-setting.tsx
Normal file
@ -0,0 +1,228 @@
|
||||
import type { ChangeEvent } from 'react'
|
||||
import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import Textarea from '@/app/components/base/textarea'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
|
||||
type SummaryIndexSettingProps = {
|
||||
entry?: 'knowledge-base' | 'dataset-settings' | 'create-document'
|
||||
summaryIndexSetting?: SummaryIndexSettingType
|
||||
onSummaryIndexSettingChange?: (payload: SummaryIndexSettingType) => void
|
||||
readonly?: boolean
|
||||
}
|
||||
const SummaryIndexSetting = ({
|
||||
entry = 'knowledge-base',
|
||||
summaryIndexSetting,
|
||||
onSummaryIndexSettingChange,
|
||||
readonly = false,
|
||||
}: SummaryIndexSettingProps) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
data: textGenerationModelList,
|
||||
} = useModelList(ModelTypeEnum.textGeneration)
|
||||
const summaryIndexModelConfig = useMemo(() => {
|
||||
if (!summaryIndexSetting?.model_name || !summaryIndexSetting?.model_provider_name)
|
||||
return undefined
|
||||
|
||||
return {
|
||||
providerName: summaryIndexSetting?.model_provider_name,
|
||||
modelName: summaryIndexSetting?.model_name,
|
||||
}
|
||||
}, [summaryIndexSetting?.model_name, summaryIndexSetting?.model_provider_name])
|
||||
|
||||
const handleSummaryIndexEnableChange = useCallback((value: boolean) => {
|
||||
onSummaryIndexSettingChange?.({
|
||||
enable: value,
|
||||
})
|
||||
}, [onSummaryIndexSettingChange])
|
||||
|
||||
const handleSummaryIndexModelChange = useCallback((model: DefaultModel) => {
|
||||
onSummaryIndexSettingChange?.({
|
||||
model_provider_name: model.provider,
|
||||
model_name: model.model,
|
||||
})
|
||||
}, [onSummaryIndexSettingChange])
|
||||
|
||||
const handleSummaryIndexPromptChange = useCallback((e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
onSummaryIndexSettingChange?.({
|
||||
summary_prompt: e.target.value,
|
||||
})
|
||||
}, [onSummaryIndexSettingChange])
|
||||
|
||||
if (entry === 'knowledge-base') {
|
||||
return (
|
||||
<div>
|
||||
<div className="flex h-6 items-center justify-between">
|
||||
<div className="system-sm-semibold-uppercase flex items-center text-text-secondary">
|
||||
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
|
||||
<Tooltip
|
||||
triggerClassName="ml-1 h-4 w-4 shrink-0"
|
||||
popupContent={t('form.summaryAutoGenTip', { ns: 'datasetSettings' })}
|
||||
>
|
||||
</Tooltip>
|
||||
</div>
|
||||
<Switch
|
||||
defaultValue={summaryIndexSetting?.enable ?? false}
|
||||
onChange={handleSummaryIndexEnableChange}
|
||||
size="md"
|
||||
/>
|
||||
</div>
|
||||
{
|
||||
summaryIndexSetting?.enable && (
|
||||
<div>
|
||||
<div className="system-xs-medium-uppercase mb-1.5 mt-2 flex h-6 items-center text-text-tertiary">
|
||||
{t('form.summaryModel', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<ModelSelector
|
||||
defaultModel={summaryIndexModelConfig && { provider: summaryIndexModelConfig.providerName, model: summaryIndexModelConfig.modelName }}
|
||||
modelList={textGenerationModelList}
|
||||
onSelect={handleSummaryIndexModelChange}
|
||||
readonly={readonly}
|
||||
showDeprecatedWarnIcon
|
||||
/>
|
||||
<div className="system-xs-medium-uppercase mt-3 flex h-6 items-center text-text-tertiary">
|
||||
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<Textarea
|
||||
value={summaryIndexSetting?.summary_prompt ?? ''}
|
||||
onChange={handleSummaryIndexPromptChange}
|
||||
disabled={readonly}
|
||||
placeholder={t('form.summaryInstructionsPlaceholder', { ns: 'datasetSettings' })}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (entry === 'dataset-settings') {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="flex gap-x-1">
|
||||
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
|
||||
<div className="system-sm-semibold text-text-secondary">
|
||||
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
<div className="py-1.5">
|
||||
<div className="system-sm-semibold flex items-center text-text-secondary">
|
||||
<Switch
|
||||
className="mr-2"
|
||||
defaultValue={summaryIndexSetting?.enable ?? false}
|
||||
onChange={handleSummaryIndexEnableChange}
|
||||
size="md"
|
||||
/>
|
||||
{
|
||||
summaryIndexSetting?.enable ? t('list.status.enabled', { ns: 'datasetDocuments' }) : t('list.status.disabled', { ns: 'datasetDocuments' })
|
||||
}
|
||||
</div>
|
||||
<div className="system-sm-regular mt-2 text-text-tertiary">
|
||||
{
|
||||
summaryIndexSetting?.enable && t('form.summaryAutoGenTip', { ns: 'datasetSettings' })
|
||||
}
|
||||
{
|
||||
!summaryIndexSetting?.enable && t('form.summaryAutoGenEnableTip', { ns: 'datasetSettings' })
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
summaryIndexSetting?.enable && (
|
||||
<>
|
||||
<div className="flex gap-x-1">
|
||||
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
|
||||
<div className="system-sm-medium text-text-tertiary">
|
||||
{t('form.summaryModel', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
<div className="grow">
|
||||
<ModelSelector
|
||||
defaultModel={summaryIndexModelConfig && { provider: summaryIndexModelConfig.providerName, model: summaryIndexModelConfig.modelName }}
|
||||
modelList={textGenerationModelList}
|
||||
onSelect={handleSummaryIndexModelChange}
|
||||
readonly={readonly}
|
||||
showDeprecatedWarnIcon
|
||||
triggerClassName="h-8"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex">
|
||||
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
|
||||
<div className="system-sm-medium text-text-tertiary">
|
||||
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
<div className="grow">
|
||||
<Textarea
|
||||
value={summaryIndexSetting?.summary_prompt ?? ''}
|
||||
onChange={handleSummaryIndexPromptChange}
|
||||
disabled={readonly}
|
||||
placeholder={t('form.summaryInstructionsPlaceholder', { ns: 'datasetSettings' })}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-3">
|
||||
<div className="flex h-6 items-center">
|
||||
<Switch
|
||||
className="mr-2"
|
||||
defaultValue={summaryIndexSetting?.enable ?? false}
|
||||
onChange={handleSummaryIndexEnableChange}
|
||||
size="md"
|
||||
/>
|
||||
<div className="system-sm-semibold text-text-secondary">
|
||||
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
summaryIndexSetting?.enable && (
|
||||
<>
|
||||
<div>
|
||||
<div className="system-sm-medium mb-1.5 flex h-6 items-center text-text-secondary">
|
||||
{t('form.summaryModel', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<ModelSelector
|
||||
defaultModel={summaryIndexModelConfig && { provider: summaryIndexModelConfig.providerName, model: summaryIndexModelConfig.modelName }}
|
||||
modelList={textGenerationModelList}
|
||||
onSelect={handleSummaryIndexModelChange}
|
||||
readonly={readonly}
|
||||
showDeprecatedWarnIcon
|
||||
triggerClassName="h-8"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div className="system-sm-medium mb-1.5 flex h-6 items-center text-text-secondary">
|
||||
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<Textarea
|
||||
value={summaryIndexSetting?.summary_prompt ?? ''}
|
||||
onChange={handleSummaryIndexPromptChange}
|
||||
disabled={readonly}
|
||||
placeholder={t('form.summaryInstructionsPlaceholder', { ns: 'datasetSettings' })}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default memo(SummaryIndexSetting)
|
||||
@ -1,10 +1,11 @@
|
||||
import type { QAChunk } from './types'
|
||||
import type { GeneralChunk, ParentChildChunk, QAChunk } from './types'
|
||||
import type { ParentMode } from '@/models/datasets'
|
||||
import * as React from 'react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Dot from '@/app/components/datasets/documents/detail/completed/common/dot'
|
||||
import SegmentIndexTag from '@/app/components/datasets/documents/detail/completed/common/segment-index-tag'
|
||||
import SummaryLabel from '@/app/components/datasets/documents/detail/completed/common/summary-label'
|
||||
import { PreviewSlice } from '@/app/components/datasets/formatted-text/flavours/preview-slice'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
@ -14,7 +15,7 @@ import { QAItemType } from './types'
|
||||
type ChunkCardProps = {
|
||||
chunkType: ChunkingMode
|
||||
parentMode?: ParentMode
|
||||
content: string | string[] | QAChunk
|
||||
content: ParentChildChunk | QAChunk | GeneralChunk
|
||||
positionId?: string | number
|
||||
wordCount: number
|
||||
}
|
||||
@ -33,7 +34,7 @@ const ChunkCard = (props: ChunkCardProps) => {
|
||||
|
||||
const contentElement = useMemo(() => {
|
||||
if (chunkType === ChunkingMode.parentChild) {
|
||||
return (content as string[]).map((child, index) => {
|
||||
return (content as ParentChildChunk).child_contents.map((child, index) => {
|
||||
const indexForLabel = index + 1
|
||||
return (
|
||||
<PreviewSlice
|
||||
@ -57,7 +58,17 @@ const ChunkCard = (props: ChunkCardProps) => {
|
||||
)
|
||||
}
|
||||
|
||||
return content as string
|
||||
return (content as GeneralChunk).content
|
||||
}, [content, chunkType])
|
||||
|
||||
const summaryElement = useMemo(() => {
|
||||
if (chunkType === ChunkingMode.parentChild) {
|
||||
return (content as ParentChildChunk).parent_summary
|
||||
}
|
||||
if (chunkType === ChunkingMode.text) {
|
||||
return (content as GeneralChunk).summary
|
||||
}
|
||||
return null
|
||||
}, [content, chunkType])
|
||||
|
||||
return (
|
||||
@ -73,6 +84,7 @@ const ChunkCard = (props: ChunkCardProps) => {
|
||||
</div>
|
||||
)}
|
||||
<div className="body-md-regular text-text-secondary">{contentElement}</div>
|
||||
{summaryElement && <SummaryLabel summary={summaryElement} />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -10,13 +10,13 @@ import { QAItemType } from './types'
|
||||
// Test Data Factories
|
||||
// =============================================================================
|
||||
|
||||
const createGeneralChunks = (overrides: string[] = []): GeneralChunks => {
|
||||
const createGeneralChunks = (overrides: GeneralChunks = []): GeneralChunks => {
|
||||
if (overrides.length > 0)
|
||||
return overrides
|
||||
return [
|
||||
'This is the first chunk of text content.',
|
||||
'This is the second chunk with different content.',
|
||||
'Third chunk here with more text.',
|
||||
{ content: 'This is the first chunk of text content.' },
|
||||
{ content: 'This is the second chunk with different content.' },
|
||||
{ content: 'Third chunk here with more text.' },
|
||||
]
|
||||
}
|
||||
|
||||
@ -152,14 +152,14 @@ describe('ChunkCard', () => {
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="This is plain text content."
|
||||
content={createGeneralChunks()[0]}
|
||||
wordCount={27}
|
||||
positionId={1}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('This is plain text content.')).toBeInTheDocument()
|
||||
expect(screen.getByText('This is the first chunk of text content.')).toBeInTheDocument()
|
||||
expect(screen.getByText(/Chunk-01/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@ -196,7 +196,7 @@ describe('ChunkCard', () => {
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.parentChild}
|
||||
parentMode="paragraph"
|
||||
content={childContents}
|
||||
content={createParentChildChunk({ child_contents: childContents })}
|
||||
wordCount={50}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -218,7 +218,7 @@ describe('ChunkCard', () => {
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.parentChild}
|
||||
parentMode="paragraph"
|
||||
content={['Child content']}
|
||||
content={createParentChildChunk({ child_contents: ['Child content'] })}
|
||||
wordCount={13}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -234,7 +234,7 @@ describe('ChunkCard', () => {
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.parentChild}
|
||||
parentMode="full-doc"
|
||||
content={['Child content']}
|
||||
content={createParentChildChunk({ child_contents: ['Child content'] })}
|
||||
wordCount={13}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -250,7 +250,7 @@ describe('ChunkCard', () => {
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Text content"
|
||||
content={createGeneralChunks()[0]}
|
||||
wordCount={12}
|
||||
positionId={5}
|
||||
/>,
|
||||
@ -268,7 +268,7 @@ describe('ChunkCard', () => {
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Some content"
|
||||
content={createGeneralChunks()[0]}
|
||||
wordCount={1234}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -283,7 +283,7 @@ describe('ChunkCard', () => {
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Some content"
|
||||
content={createGeneralChunks()[0]}
|
||||
wordCount={100}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -299,7 +299,7 @@ describe('ChunkCard', () => {
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.parentChild}
|
||||
parentMode="full-doc"
|
||||
content={['Child']}
|
||||
content={createParentChildChunk({ child_contents: ['Child'] })}
|
||||
wordCount={500}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -317,7 +317,7 @@ describe('ChunkCard', () => {
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Content"
|
||||
content={createGeneralChunks()[0]}
|
||||
wordCount={7}
|
||||
positionId={42}
|
||||
/>,
|
||||
@ -332,7 +332,7 @@ describe('ChunkCard', () => {
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Content"
|
||||
content={createGeneralChunks()[0]}
|
||||
wordCount={7}
|
||||
positionId="99"
|
||||
/>,
|
||||
@ -347,7 +347,7 @@ describe('ChunkCard', () => {
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Content"
|
||||
content={createGeneralChunks()[0]}
|
||||
wordCount={7}
|
||||
positionId={3}
|
||||
/>,
|
||||
@ -366,7 +366,7 @@ describe('ChunkCard', () => {
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.parentChild}
|
||||
parentMode="paragraph"
|
||||
content={['Child']}
|
||||
content={createParentChildChunk({ child_contents: ['Child'] })}
|
||||
wordCount={5}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -380,7 +380,7 @@ describe('ChunkCard', () => {
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.parentChild}
|
||||
parentMode="full-doc"
|
||||
content={['Child']}
|
||||
content={createParentChildChunk({ child_contents: ['Child'] })}
|
||||
wordCount={5}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -392,10 +392,13 @@ describe('ChunkCard', () => {
|
||||
|
||||
it('should update contentElement memo when content changes', () => {
|
||||
// Arrange
|
||||
const initialContent = { content: 'Initial content' }
|
||||
const updatedContent = { content: 'Updated content' }
|
||||
|
||||
const { rerender } = render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Initial content"
|
||||
content={initialContent}
|
||||
wordCount={15}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -408,7 +411,7 @@ describe('ChunkCard', () => {
|
||||
rerender(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Updated content"
|
||||
content={updatedContent}
|
||||
wordCount={15}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -421,10 +424,11 @@ describe('ChunkCard', () => {
|
||||
|
||||
it('should update contentElement memo when chunkType changes', () => {
|
||||
// Arrange
|
||||
const textContent = { content: 'Text content' }
|
||||
const { rerender } = render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content="Text content"
|
||||
content={textContent}
|
||||
wordCount={12}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -458,7 +462,7 @@ describe('ChunkCard', () => {
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.parentChild}
|
||||
parentMode="paragraph"
|
||||
content={[]}
|
||||
content={createParentChildChunk({ child_contents: [] })}
|
||||
wordCount={0}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -490,12 +494,13 @@ describe('ChunkCard', () => {
|
||||
it('should handle very long content', () => {
|
||||
// Arrange
|
||||
const longContent = 'A'.repeat(10000)
|
||||
const longContentChunk = { content: longContent }
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content={longContent}
|
||||
content={longContentChunk}
|
||||
wordCount={10000}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -510,7 +515,7 @@ describe('ChunkCard', () => {
|
||||
render(
|
||||
<ChunkCard
|
||||
chunkType={ChunkingMode.text}
|
||||
content=""
|
||||
content={createGeneralChunks()[0]}
|
||||
wordCount={0}
|
||||
positionId={1}
|
||||
/>,
|
||||
@ -546,9 +551,9 @@ describe('ChunkCardList', () => {
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText(chunks[0])).toBeInTheDocument()
|
||||
expect(screen.getByText(chunks[1])).toBeInTheDocument()
|
||||
expect(screen.getByText(chunks[2])).toBeInTheDocument()
|
||||
expect(screen.getByText(chunks[0].content)).toBeInTheDocument()
|
||||
expect(screen.getByText(chunks[1].content)).toBeInTheDocument()
|
||||
expect(screen.getByText(chunks[2].content)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render parent-child chunks correctly', () => {
|
||||
@ -594,7 +599,10 @@ describe('ChunkCardList', () => {
|
||||
describe('Memoization - chunkList', () => {
|
||||
it('should extract chunks from GeneralChunks for text mode', () => {
|
||||
// Arrange
|
||||
const chunks: GeneralChunks = ['Chunk 1', 'Chunk 2']
|
||||
const chunks: GeneralChunks = [
|
||||
{ content: 'Chunk 1' },
|
||||
{ content: 'Chunk 2' },
|
||||
]
|
||||
|
||||
// Act
|
||||
render(
|
||||
@ -653,7 +661,7 @@ describe('ChunkCardList', () => {
|
||||
|
||||
it('should update chunkList when chunkInfo changes', () => {
|
||||
// Arrange
|
||||
const initialChunks = createGeneralChunks(['Initial chunk'])
|
||||
const initialChunks = createGeneralChunks([{ content: 'Initial chunk' }])
|
||||
|
||||
const { rerender } = render(
|
||||
<ChunkCardList
|
||||
@ -666,7 +674,7 @@ describe('ChunkCardList', () => {
|
||||
expect(screen.getByText('Initial chunk')).toBeInTheDocument()
|
||||
|
||||
// Act - update chunks
|
||||
const updatedChunks = createGeneralChunks(['Updated chunk'])
|
||||
const updatedChunks = createGeneralChunks([{ content: 'Updated chunk' }])
|
||||
rerender(
|
||||
<ChunkCardList
|
||||
chunkType={ChunkingMode.text}
|
||||
@ -684,7 +692,7 @@ describe('ChunkCardList', () => {
|
||||
describe('Word Count Calculation', () => {
|
||||
it('should calculate word count for text chunks using string length', () => {
|
||||
// Arrange - "Hello" has 5 characters
|
||||
const chunks = createGeneralChunks(['Hello'])
|
||||
const chunks = createGeneralChunks([{ content: 'Hello' }])
|
||||
|
||||
// Act
|
||||
render(
|
||||
@ -747,7 +755,11 @@ describe('ChunkCardList', () => {
|
||||
describe('Position ID', () => {
|
||||
it('should assign 1-based position IDs to chunks', () => {
|
||||
// Arrange
|
||||
const chunks = createGeneralChunks(['First', 'Second', 'Third'])
|
||||
const chunks = createGeneralChunks([
|
||||
{ content: 'First' },
|
||||
{ content: 'Second' },
|
||||
{ content: 'Third' },
|
||||
])
|
||||
|
||||
// Act
|
||||
render(
|
||||
@ -768,7 +780,7 @@ describe('ChunkCardList', () => {
|
||||
describe('Custom className', () => {
|
||||
it('should apply custom className to container', () => {
|
||||
// Arrange
|
||||
const chunks = createGeneralChunks(['Test'])
|
||||
const chunks = createGeneralChunks([{ content: 'Test' }])
|
||||
|
||||
// Act
|
||||
const { container } = render(
|
||||
@ -785,7 +797,7 @@ describe('ChunkCardList', () => {
|
||||
|
||||
it('should merge custom className with default classes', () => {
|
||||
// Arrange
|
||||
const chunks = createGeneralChunks(['Test'])
|
||||
const chunks = createGeneralChunks([{ content: 'Test' }])
|
||||
|
||||
// Act
|
||||
const { container } = render(
|
||||
@ -805,7 +817,7 @@ describe('ChunkCardList', () => {
|
||||
|
||||
it('should render without className prop', () => {
|
||||
// Arrange
|
||||
const chunks = createGeneralChunks(['Test'])
|
||||
const chunks = createGeneralChunks([{ content: 'Test' }])
|
||||
|
||||
// Act
|
||||
const { container } = render(
|
||||
@ -860,7 +872,7 @@ describe('ChunkCardList', () => {
|
||||
|
||||
it('should not use parentMode for text type', () => {
|
||||
// Arrange
|
||||
const chunks = createGeneralChunks(['Text'])
|
||||
const chunks = createGeneralChunks([{ content: 'Text' }])
|
||||
|
||||
// Act
|
||||
render(
|
||||
@ -937,7 +949,7 @@ describe('ChunkCardList', () => {
|
||||
|
||||
it('should handle single item in chunks', () => {
|
||||
// Arrange
|
||||
const chunks = createGeneralChunks(['Single chunk'])
|
||||
const chunks = createGeneralChunks([{ content: 'Single chunk' }])
|
||||
|
||||
// Act
|
||||
render(
|
||||
@ -954,7 +966,7 @@ describe('ChunkCardList', () => {
|
||||
|
||||
it('should handle large number of chunks', () => {
|
||||
// Arrange
|
||||
const chunks = Array.from({ length: 100 }, (_, i) => `Chunk number ${i + 1}`)
|
||||
const chunks = Array.from({ length: 100 }, (_, i) => ({ content: `Chunk number ${i + 1}` }))
|
||||
|
||||
// Act
|
||||
render(
|
||||
@ -975,8 +987,11 @@ describe('ChunkCardList', () => {
|
||||
describe('Key Generation', () => {
|
||||
it('should generate unique keys for chunks', () => {
|
||||
// Arrange - chunks with same content
|
||||
const chunks = createGeneralChunks(['Same content', 'Same content', 'Same content'])
|
||||
|
||||
const chunks = createGeneralChunks([
|
||||
{ content: 'Same content' },
|
||||
{ content: 'Same content' },
|
||||
{ content: 'Same content' },
|
||||
])
|
||||
// Act
|
||||
const { container } = render(
|
||||
<ChunkCardList
|
||||
@ -1006,9 +1021,9 @@ describe('ChunkCardList Integration', () => {
|
||||
it('should render complete text chunking workflow', () => {
|
||||
// Arrange
|
||||
const textChunks = createGeneralChunks([
|
||||
'First paragraph of the document.',
|
||||
'Second paragraph with more information.',
|
||||
'Final paragraph concluding the content.',
|
||||
{ content: 'First paragraph of the document.' },
|
||||
{ content: 'Second paragraph with more information.' },
|
||||
{ content: 'Final paragraph concluding the content.' },
|
||||
])
|
||||
|
||||
// Act
|
||||
@ -1104,7 +1119,7 @@ describe('ChunkCardList Integration', () => {
|
||||
describe('Type Switching', () => {
|
||||
it('should handle switching from text to QA type', () => {
|
||||
// Arrange
|
||||
const textChunks = createGeneralChunks(['Text content'])
|
||||
const textChunks = createGeneralChunks([{ content: 'Text content' }])
|
||||
const qaChunks = createQAChunks()
|
||||
|
||||
const { rerender } = render(
|
||||
@ -1132,7 +1147,7 @@ describe('ChunkCardList Integration', () => {
|
||||
|
||||
it('should handle switching from text to parent-child type', () => {
|
||||
// Arrange
|
||||
const textChunks = createGeneralChunks(['Simple text'])
|
||||
const textChunks = createGeneralChunks([{ content: 'Simple text' }])
|
||||
const parentChildChunks = createParentChildChunks()
|
||||
|
||||
const { rerender } = render(
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { ChunkInfo, GeneralChunks, ParentChildChunk, ParentChildChunks, QAChunk, QAChunks } from './types'
|
||||
import type { ChunkInfo, GeneralChunk, GeneralChunks, ParentChildChunk, ParentChildChunks, QAChunk, QAChunks } from './types'
|
||||
import type { ParentMode } from '@/models/datasets'
|
||||
import { useMemo } from 'react'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
@ -21,13 +21,13 @@ export const ChunkCardList = (props: ChunkCardListProps) => {
|
||||
if (chunkType === ChunkingMode.parentChild)
|
||||
return (chunkInfo as ParentChildChunks).parent_child_chunks
|
||||
return (chunkInfo as QAChunks).qa_chunks
|
||||
}, [chunkInfo])
|
||||
}, [chunkInfo, chunkType])
|
||||
|
||||
const getWordCount = (seg: string | ParentChildChunk | QAChunk) => {
|
||||
const getWordCount = (seg: GeneralChunk | ParentChildChunk | QAChunk) => {
|
||||
if (chunkType === ChunkingMode.parentChild)
|
||||
return (seg as ParentChildChunk).parent_content.length
|
||||
return (seg as ParentChildChunk).parent_content?.length
|
||||
if (chunkType === ChunkingMode.text)
|
||||
return (seg as string).length
|
||||
return (seg as GeneralChunk).content.length
|
||||
return (seg as QAChunk).question.length + (seg as QAChunk).answer.length
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ export const ChunkCardList = (props: ChunkCardListProps) => {
|
||||
key={`${chunkType}-${index}`}
|
||||
chunkType={chunkType}
|
||||
parentMode={parentMode}
|
||||
content={chunkType === ChunkingMode.parentChild ? (seg as ParentChildChunk).child_contents : (seg as string | QAChunk)}
|
||||
content={seg}
|
||||
wordCount={wordCount}
|
||||
positionId={index + 1}
|
||||
/>
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
export type GeneralChunks = string[]
|
||||
|
||||
export type GeneralChunk = {
|
||||
content: string
|
||||
summary?: string
|
||||
}
|
||||
export type GeneralChunks = GeneralChunk[]
|
||||
export type ParentChildChunk = {
|
||||
child_contents: string[]
|
||||
parent_content: string
|
||||
parent_summary?: string
|
||||
parent_mode: string
|
||||
}
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import type { GeneralChunks } from '@/app/components/rag-pipeline/components/chunk-card-list/types'
|
||||
import type { WorkflowRunningData } from '@/app/components/workflow/types'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { WorkflowRunningStatus } from '@/app/components/workflow/types'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
|
||||
import Header from './header'
|
||||
// Import components after mocks
|
||||
import TestRunPanel from './index'
|
||||
@ -830,17 +830,27 @@ describe('formatPreviewChunks', () => {
|
||||
const outputs = createMockGeneralOutputs(['content1', 'content2', 'content3'])
|
||||
const result = formatPreviewChunks(outputs)
|
||||
|
||||
expect(result).toEqual(['content1', 'content2', 'content3'])
|
||||
expect(result).toEqual([
|
||||
{ content: 'content1', summary: undefined },
|
||||
{ content: 'content2', summary: undefined },
|
||||
{ content: 'content3', summary: undefined },
|
||||
])
|
||||
})
|
||||
|
||||
it('should limit to RAG_PIPELINE_PREVIEW_CHUNK_NUM chunks', () => {
|
||||
const manyChunks = Array.from({ length: 10 }, (_, i) => `chunk${i}`)
|
||||
const outputs = createMockGeneralOutputs(manyChunks)
|
||||
const result = formatPreviewChunks(outputs) as string[]
|
||||
const result = formatPreviewChunks(outputs) as GeneralChunks
|
||||
|
||||
// RAG_PIPELINE_PREVIEW_CHUNK_NUM is mocked to 5
|
||||
expect(result).toHaveLength(5)
|
||||
expect(result).toEqual(['chunk0', 'chunk1', 'chunk2', 'chunk3', 'chunk4'])
|
||||
expect(result).toEqual([
|
||||
{ content: 'chunk0', summary: undefined },
|
||||
{ content: 'chunk1', summary: undefined },
|
||||
{ content: 'chunk2', summary: undefined },
|
||||
{ content: 'chunk3', summary: undefined },
|
||||
{ content: 'chunk4', summary: undefined },
|
||||
])
|
||||
})
|
||||
|
||||
it('should handle empty preview array', () => {
|
||||
|
||||
@ -590,9 +590,9 @@ describe('formatPreviewChunks', () => {
|
||||
const result = formatPreviewChunks(outputs) as GeneralChunks
|
||||
|
||||
expect(result).toHaveLength(3)
|
||||
expect(result[0]).toBe('General chunk content 1')
|
||||
expect(result[1]).toBe('General chunk content 2')
|
||||
expect(result[2]).toBe('General chunk content 3')
|
||||
expect((result as GeneralChunks)[0].content).toBe('General chunk content 1')
|
||||
expect((result as GeneralChunks)[1].content).toBe('General chunk content 2')
|
||||
expect((result as GeneralChunks)[2].content).toBe('General chunk content 3')
|
||||
})
|
||||
|
||||
it('should limit chunks to RAG_PIPELINE_PREVIEW_CHUNK_NUM', () => {
|
||||
|
||||
@ -145,9 +145,9 @@ describe('formatPreviewChunks', () => {
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual([
|
||||
'First chunk content',
|
||||
'Second chunk content',
|
||||
'Third chunk content',
|
||||
{ content: 'First chunk content', summary: undefined },
|
||||
{ content: 'Second chunk content', summary: undefined },
|
||||
{ content: 'Third chunk content', summary: undefined },
|
||||
])
|
||||
})
|
||||
|
||||
@ -160,8 +160,8 @@ describe('formatPreviewChunks', () => {
|
||||
|
||||
// Assert
|
||||
expect(result).toHaveLength(20)
|
||||
expect(result[0]).toBe('Chunk content 1')
|
||||
expect(result[19]).toBe('Chunk content 20')
|
||||
expect((result as GeneralChunks)[0].content).toBe('Chunk content 1')
|
||||
expect((result as GeneralChunks)[19].content).toBe('Chunk content 20')
|
||||
})
|
||||
|
||||
it('should handle empty preview array for general chunks', () => {
|
||||
@ -186,7 +186,10 @@ describe('formatPreviewChunks', () => {
|
||||
const result = formatPreviewChunks(outputs) as GeneralChunks
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual(['', 'Valid content'])
|
||||
expect(result).toEqual([
|
||||
{ content: '', summary: undefined },
|
||||
{ content: 'Valid content', summary: undefined },
|
||||
])
|
||||
})
|
||||
|
||||
it('should handle general chunks with special characters', () => {
|
||||
@ -202,9 +205,9 @@ describe('formatPreviewChunks', () => {
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual([
|
||||
'<script>alert("xss")</script>',
|
||||
'中文内容 🎉',
|
||||
'Line1\nLine2\tTab',
|
||||
{ content: '<script>alert("xss")</script>', summary: undefined },
|
||||
{ content: '中文内容 🎉', summary: undefined },
|
||||
{ content: 'Line1\nLine2\tTab', summary: undefined },
|
||||
])
|
||||
})
|
||||
|
||||
@ -217,7 +220,7 @@ describe('formatPreviewChunks', () => {
|
||||
const result = formatPreviewChunks(outputs) as GeneralChunks
|
||||
|
||||
// Assert
|
||||
expect(result[0]).toHaveLength(10000)
|
||||
expect((result as GeneralChunks)[0].content).toHaveLength(10000)
|
||||
})
|
||||
})
|
||||
|
||||
@ -501,7 +504,7 @@ describe('formatPreviewChunks', () => {
|
||||
const result = formatPreviewChunks(outputs) as GeneralChunks
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual(['Test'])
|
||||
expect(result).toEqual([{ content: 'Test', summary: undefined }])
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -667,7 +670,10 @@ describe('ResultPreview', () => {
|
||||
// Assert
|
||||
const chunkList = screen.getByTestId('chunk-card-list')
|
||||
const chunkInfo = JSON.parse(chunkList.getAttribute('data-chunk-info') || '[]')
|
||||
expect(chunkInfo).toEqual(['Chunk 1', 'Chunk 2'])
|
||||
expect(chunkInfo).toEqual([
|
||||
{ content: 'Chunk 1' },
|
||||
{ content: 'Chunk 2' },
|
||||
])
|
||||
})
|
||||
|
||||
it('should handle parent-child outputs', () => {
|
||||
@ -792,7 +798,7 @@ describe('ResultPreview', () => {
|
||||
// Assert
|
||||
const chunkList = screen.getByTestId('chunk-card-list')
|
||||
const chunkInfo = JSON.parse(chunkList.getAttribute('data-chunk-info') || '[]')
|
||||
expect(chunkInfo).toEqual(['Second'])
|
||||
expect(chunkInfo).toEqual([{ content: 'Second' }])
|
||||
})
|
||||
})
|
||||
|
||||
@ -820,7 +826,7 @@ describe('ResultPreview', () => {
|
||||
|
||||
let chunkList = screen.getByTestId('chunk-card-list')
|
||||
let chunkInfo = JSON.parse(chunkList.getAttribute('data-chunk-info') || '[]')
|
||||
expect(chunkInfo).toEqual(['Original'])
|
||||
expect(chunkInfo).toEqual([{ content: 'Original' }])
|
||||
|
||||
// Act - Change outputs
|
||||
const outputs2 = createGeneralChunkOutputs([{ content: 'Updated' }])
|
||||
@ -829,7 +835,7 @@ describe('ResultPreview', () => {
|
||||
// Assert
|
||||
chunkList = screen.getByTestId('chunk-card-list')
|
||||
chunkInfo = JSON.parse(chunkList.getAttribute('data-chunk-info') || '[]')
|
||||
expect(chunkInfo).toEqual(['Updated'])
|
||||
expect(chunkInfo).toEqual([{ content: 'Updated' }])
|
||||
})
|
||||
|
||||
it('should handle undefined outputs in useMemo', () => {
|
||||
|
||||
@ -5,13 +5,17 @@ import { ChunkingMode } from '@/models/datasets'
|
||||
|
||||
type GeneralChunkPreview = {
|
||||
content: string
|
||||
summary?: string
|
||||
}
|
||||
|
||||
const formatGeneralChunks = (outputs: any) => {
|
||||
const chunkInfo: GeneralChunks = []
|
||||
const chunks = outputs.preview as GeneralChunkPreview[]
|
||||
chunks.slice(0, RAG_PIPELINE_PREVIEW_CHUNK_NUM).forEach((chunk) => {
|
||||
chunkInfo.push(chunk.content)
|
||||
chunkInfo.push({
|
||||
content: chunk.content,
|
||||
summary: chunk.summary,
|
||||
})
|
||||
})
|
||||
|
||||
return chunkInfo
|
||||
@ -20,6 +24,7 @@ const formatGeneralChunks = (outputs: any) => {
|
||||
type ParentChildChunkPreview = {
|
||||
content: string
|
||||
child_chunks: string[]
|
||||
summary?: string
|
||||
}
|
||||
|
||||
const formatParentChildChunks = (outputs: any, parentMode: ParentMode) => {
|
||||
@ -32,6 +37,7 @@ const formatParentChildChunks = (outputs: any, parentMode: ParentMode) => {
|
||||
chunks.slice(0, RAG_PIPELINE_PREVIEW_CHUNK_NUM).forEach((chunk) => {
|
||||
chunkInfo.parent_child_chunks?.push({
|
||||
parent_content: chunk.content,
|
||||
parent_summary: chunk.summary,
|
||||
child_contents: chunk.child_chunks,
|
||||
parent_mode: parentMode,
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type {
|
||||
KnowledgeBaseNodeType,
|
||||
RerankingModel,
|
||||
SummaryIndexSetting,
|
||||
} from '../types'
|
||||
import type { ValueSelector } from '@/app/components/workflow/types'
|
||||
import { produce } from 'immer'
|
||||
@ -246,6 +247,16 @@ export const useConfig = (id: string) => {
|
||||
})
|
||||
}, [handleNodeDataUpdate])
|
||||
|
||||
const handleSummaryIndexSettingChange = useCallback((summaryIndexSetting: SummaryIndexSetting) => {
|
||||
const nodeData = getNodeData()
|
||||
handleNodeDataUpdate({
|
||||
summary_index_setting: {
|
||||
...nodeData?.data.summary_index_setting,
|
||||
...summaryIndexSetting,
|
||||
},
|
||||
})
|
||||
}, [handleNodeDataUpdate, getNodeData])
|
||||
|
||||
return {
|
||||
handleChunkStructureChange,
|
||||
handleIndexMethodChange,
|
||||
@ -260,5 +271,6 @@ export const useConfig = (id: string) => {
|
||||
handleScoreThresholdChange,
|
||||
handleScoreThresholdEnabledChange,
|
||||
handleInputVariableChange,
|
||||
handleSummaryIndexSettingChange,
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ import {
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
|
||||
import { checkShowMultiModalTip } from '@/app/components/datasets/settings/utils'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
@ -51,6 +52,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
handleScoreThresholdChange,
|
||||
handleScoreThresholdEnabledChange,
|
||||
handleInputVariableChange,
|
||||
handleSummaryIndexSettingChange,
|
||||
} = useConfig(id)
|
||||
|
||||
const filterVar = useCallback((variable: Var) => {
|
||||
@ -167,6 +169,22 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
<div className="pt-1">
|
||||
<Split className="h-[1px]" />
|
||||
</div>
|
||||
{
|
||||
data.indexing_technique === IndexMethodEnum.QUALIFIED
|
||||
&& [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure)
|
||||
&& (
|
||||
<>
|
||||
<SummaryIndexSetting
|
||||
summaryIndexSetting={data.summary_index_setting}
|
||||
onSummaryIndexSettingChange={handleSummaryIndexSettingChange}
|
||||
readonly={nodesReadOnly}
|
||||
/>
|
||||
<div className="pt-1">
|
||||
<Split className="h-[1px]" />
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
<RetrievalSetting
|
||||
indexMethod={data.indexing_technique}
|
||||
searchMethod={data.retrieval_model.search_method}
|
||||
|
||||
@ -42,6 +42,12 @@ export type RetrievalSetting = {
|
||||
score_threshold: number
|
||||
reranking_mode?: RerankingModeEnum
|
||||
}
|
||||
export type SummaryIndexSetting = {
|
||||
enable?: boolean
|
||||
model_name?: string
|
||||
model_provider_name?: string
|
||||
summary_prompt?: string
|
||||
}
|
||||
export type KnowledgeBaseNodeType = CommonNodeType & {
|
||||
index_chunk_variable_selector: string[]
|
||||
chunk_structure?: ChunkStructureEnum
|
||||
@ -52,4 +58,5 @@ export type KnowledgeBaseNodeType = CommonNodeType & {
|
||||
retrieval_model: RetrievalSetting
|
||||
_embeddingModelList?: Model[]
|
||||
_rerankModelList?: Model[]
|
||||
summary_index_setting?: SummaryIndexSetting
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@
|
||||
"list.action.pause": "Pause",
|
||||
"list.action.resume": "Resume",
|
||||
"list.action.settings": "Chunking Settings",
|
||||
"list.action.summary": "Generate summary",
|
||||
"list.action.sync": "Sync",
|
||||
"list.action.unarchive": "Unarchive",
|
||||
"list.action.uploadFile": "Upload new file",
|
||||
@ -75,6 +76,9 @@
|
||||
"list.status.indexing": "Indexing",
|
||||
"list.status.paused": "Paused",
|
||||
"list.status.queuing": "Queuing",
|
||||
"list.summary.generating": "Generating...",
|
||||
"list.summary.generatingSummary": "Generating summary",
|
||||
"list.summary.ready": "Summary ready",
|
||||
"list.table.header.action": "ACTION",
|
||||
"list.table.header.chunkingMode": "CHUNKING MODE",
|
||||
"list.table.header.fileName": "NAME",
|
||||
@ -329,5 +333,7 @@
|
||||
"segment.searchResults_one": "RESULT",
|
||||
"segment.searchResults_other": "RESULTS",
|
||||
"segment.searchResults_zero": "RESULT",
|
||||
"segment.summary": "SUMMARY",
|
||||
"segment.summaryPlaceholder": "Write a brief summary for better retrieval…",
|
||||
"segment.vectorHash": "Vector hash: "
|
||||
}
|
||||
|
||||
@ -39,6 +39,12 @@
|
||||
"form.retrievalSettings": "Retrieval Settings",
|
||||
"form.save": "Save",
|
||||
"form.searchModel": "Search model",
|
||||
"form.summaryAutoGen": "Summary Auto-Gen",
|
||||
"form.summaryAutoGenEnableTip": "Once enabled, summaries will be generated automatically for newly added documents. Existing documents can still be summarized manually.",
|
||||
"form.summaryAutoGenTip": "Summaries are automatically generated for newly added documents. Existing documents can still be summarized manually.",
|
||||
"form.summaryInstructions": "Instructions",
|
||||
"form.summaryInstructionsPlaceholder": "Describe the rules or style for auto-generated summaries…",
|
||||
"form.summaryModel": "Summary Model",
|
||||
"form.upgradeHighQualityTip": "Once upgrading to High Quality mode, reverting to Economical mode is not available",
|
||||
"title": "Knowledge settings"
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@
|
||||
"list.action.pause": "暂停",
|
||||
"list.action.resume": "恢复",
|
||||
"list.action.settings": "分段设置",
|
||||
"list.action.summary": "生成摘要",
|
||||
"list.action.sync": "同步",
|
||||
"list.action.unarchive": "撤销归档",
|
||||
"list.action.uploadFile": "上传新文件",
|
||||
@ -75,6 +76,9 @@
|
||||
"list.status.indexing": "索引中",
|
||||
"list.status.paused": "已暂停",
|
||||
"list.status.queuing": "排队中",
|
||||
"list.summary.generating": "生成中...",
|
||||
"list.summary.generatingSummary": "生成摘要中",
|
||||
"list.summary.ready": "摘要已生成",
|
||||
"list.table.header.action": "操作",
|
||||
"list.table.header.chunkingMode": "分段模式",
|
||||
"list.table.header.fileName": "名称",
|
||||
@ -329,5 +333,7 @@
|
||||
"segment.searchResults_one": "搜索结果",
|
||||
"segment.searchResults_other": "搜索结果",
|
||||
"segment.searchResults_zero": "搜索结果",
|
||||
"segment.summary": "摘要",
|
||||
"segment.summaryPlaceholder": "写一个简短的摘要,以便更好地检索…",
|
||||
"segment.vectorHash": "向量哈希:"
|
||||
}
|
||||
|
||||
@ -39,6 +39,12 @@
|
||||
"form.retrievalSettings": "检索设置",
|
||||
"form.save": "保存",
|
||||
"form.searchModel": "搜索模型",
|
||||
"form.summaryAutoGen": "摘要自动生成",
|
||||
"form.summaryAutoGenEnableTip": "启用后,将自动为新添加的文档生成摘要。已有的文档仍可以手动摘要。",
|
||||
"form.summaryAutoGenTip": "将自动为新添加的文档生成摘要。已有的文档仍可以手动摘要。",
|
||||
"form.summaryInstructions": "指令",
|
||||
"form.summaryInstructionsPlaceholder": "描述自动生成摘要的规则或风格…",
|
||||
"form.summaryModel": "摘要模型",
|
||||
"form.upgradeHighQualityTip": "一旦升级为高质量模式,将无法切换回经济模式。",
|
||||
"title": "知识库设置"
|
||||
}
|
||||
|
||||
@ -42,6 +42,13 @@ export type IconInfo = {
|
||||
icon_url?: string
|
||||
}
|
||||
|
||||
export type SummaryIndexSetting = {
|
||||
enable?: boolean
|
||||
model_name?: string
|
||||
model_provider_name?: string
|
||||
summary_prompt?: string
|
||||
}
|
||||
|
||||
export type DataSet = {
|
||||
id: string
|
||||
name: string
|
||||
@ -88,6 +95,7 @@ export type DataSet = {
|
||||
runtime_mode: 'rag_pipeline' | 'general'
|
||||
enable_api: boolean // Indicates if the service API is enabled
|
||||
is_multimodal: boolean // Indicates if the dataset supports multimodal
|
||||
summary_index_setting?: SummaryIndexSetting
|
||||
}
|
||||
|
||||
export type ExternalAPIItem = {
|
||||
@ -225,7 +233,7 @@ export type IndexingEstimateResponse = {
|
||||
total_price: number
|
||||
currency: string
|
||||
total_segments: number
|
||||
preview: Array<{ content: string, child_chunks: string[] }>
|
||||
preview: Array<{ content: string, child_chunks: string[], summary?: string }>
|
||||
qa_preview?: QA[]
|
||||
}
|
||||
|
||||
@ -262,6 +270,7 @@ export type ProcessRuleResponse = {
|
||||
mode: ProcessMode
|
||||
rules: Rules
|
||||
limits: Limits
|
||||
summary_index_setting?: SummaryIndexSetting
|
||||
}
|
||||
|
||||
export type Rules = {
|
||||
@ -392,6 +401,7 @@ export type InitialDocumentDetail = {
|
||||
total_segments?: number
|
||||
doc_form: ChunkingMode
|
||||
doc_language: string
|
||||
summary_index_status?: string
|
||||
}
|
||||
|
||||
export type SimpleDocumentDetail = InitialDocumentDetail & {
|
||||
@ -425,6 +435,7 @@ export type DocumentReq = {
|
||||
doc_form: ChunkingMode
|
||||
doc_language: string
|
||||
process_rule: ProcessRule
|
||||
summary_index_setting?: SummaryIndexSetting
|
||||
}
|
||||
|
||||
export type CreateDocumentReq = DocumentReq & {
|
||||
@ -467,6 +478,7 @@ export type NotionPage = {
|
||||
export type ProcessRule = {
|
||||
mode: ProcessMode
|
||||
rules: Rules
|
||||
summary_index_setting?: SummaryIndexSetting
|
||||
}
|
||||
|
||||
export type createDocumentResponse = {
|
||||
@ -575,6 +587,7 @@ export type SegmentDetailModel = {
|
||||
error: string | null
|
||||
stopped_at: number
|
||||
answer?: string
|
||||
summary?: string
|
||||
child_chunks?: ChildChunkDetail[]
|
||||
updated_at: number
|
||||
attachments: Attachment[]
|
||||
@ -618,6 +631,7 @@ export type HitTesting = {
|
||||
tsne_position: TsnePosition
|
||||
child_chunks: HitTestingChildChunk[] | null
|
||||
files: Attachment[]
|
||||
summary?: string
|
||||
}
|
||||
|
||||
export type ExternalKnowledgeBaseHitTesting = {
|
||||
@ -697,6 +711,7 @@ export type RelatedAppResponse = {
|
||||
export type SegmentUpdater = {
|
||||
content: string
|
||||
answer?: string
|
||||
summary?: string
|
||||
keywords?: string[]
|
||||
regenerate_child_chunks?: boolean
|
||||
attachment_ids?: string[]
|
||||
@ -778,6 +793,7 @@ export enum DocumentActionType {
|
||||
archive = 'archive',
|
||||
unArchive = 'un_archive',
|
||||
delete = 'delete',
|
||||
summary = 'summary',
|
||||
}
|
||||
|
||||
export type UpdateDocumentBatchParams = {
|
||||
|
||||
@ -107,6 +107,18 @@ export const useSyncDocument = () => {
|
||||
})
|
||||
}
|
||||
|
||||
export const useDocumentSummary = () => {
|
||||
return useMutation({
|
||||
mutationFn: ({ datasetId, documentIds, documentId }: UpdateDocumentBatchParams) => {
|
||||
return post<CommonResponse>(`/datasets/${datasetId}/documents/generate-summary`, {
|
||||
body: {
|
||||
document_list: documentId ? [documentId] : documentIds!,
|
||||
},
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const useSyncWebsite = () => {
|
||||
return useMutation({
|
||||
mutationFn: ({ datasetId, documentId }: UpdateDocumentBatchParams) => {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user