diff --git a/api/.importlinter b/api/.importlinter index e30f498ba9..fd768ad48b 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -104,6 +104,8 @@ forbidden_modules = core.trigger core.variables ignore_imports = + core.workflow.nodes.agent.agent_node -> core.db.session_factory + core.workflow.nodes.agent.agent_node -> models.tools core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.workflow_entry -> core.app.workflow.layers.observability diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index d398997ccf..0e1818094a 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,9 +1,13 @@ +import logging from collections.abc import Sequence from typing import Any from flask_restx import Resource from pydantic import BaseModel, Field +logger = logging.getLogger(__name__) + + from controllers.console import console_ns from controllers.console.app.error import ( CompletionRequestError, diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 5b2c640265..09cf683cfd 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -32,6 +32,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.generator import WorkflowGenerator from extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel @@ -285,6 +286,35 @@ class LLMGenerator: return rule_config + @classmethod + def generate_workflow_flowchart( + cls, + tenant_id: str, + instruction: str, + model_config: dict, + available_nodes: Sequence[dict[str, object]] | None = None, + existing_nodes: Sequence[dict[str, object]] | None = None, + available_tools: Sequence[dict[str, object]] | None = None, + selected_node_ids: Sequence[str] | None = None, + previous_workflow: dict[str, object] | None = None, + regenerate_mode: bool = False, + preferred_language: str | None = None, + available_models: Sequence[dict[str, object]] | None = None, + ): + return WorkflowGenerator.generate_workflow_flowchart( + tenant_id=tenant_id, + instruction=instruction, + model_config=model_config, + available_nodes=available_nodes, + existing_nodes=existing_nodes, + available_tools=available_tools, + selected_node_ids=selected_node_ids, + previous_workflow=previous_workflow, + regenerate_mode=regenerate_mode, + preferred_language=preferred_language, + available_models=available_models, + ) + @classmethod def generate_code( cls, diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ee9a016c95..8faf56528c 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -143,6 +143,50 @@ Based on task description, please create a well-structured prompt template that Please generate the full prompt template with at least 300 words and output only the prompt template. """ # noqa: E501 +WORKFLOW_FLOWCHART_PROMPT_TEMPLATE = """ +You are an expert workflow designer. Generate a Mermaid flowchart based on the user's request. + +Constraints: +- Detect the language of the user's request. Generate all node titles in the same language as the user's input. +- If the input language cannot be determined, use {{PREFERRED_LANGUAGE}} as the fallback language. +- Use only node types listed in . +- Use only tools listed in . When using a tool node, set type=tool and tool=. +- Tools may include MCP providers (provider_type=mcp). Tool selection still uses tool_key. +- Prefer reusing node titles from when possible. +- Output must be valid Mermaid flowchart syntax, no markdown, no extra text. +- First line must be: flowchart LR +- Every node must be declared on its own line using: + ["type=|title=|tool=<tool_key>"] + - type is required and must match a type in <available_nodes>. + - title is required for non-tool nodes. + - tool is required only when type=tool, otherwise omit tool. +- Declare all node lines before any edges. +- Edges must use: + <id> --> <id> + <id> -->|true| <id> + <id> -->|false| <id> +- Keep node ids unique and simple (N1, N2, ...). +- For complex orchestration: + - Break the request into stages (ingest, transform, decision, action, output). + - Use IfElse for branching and label edges true/false only. + - Fan-in branches by connecting multiple nodes into a shared downstream node. + - Avoid cycles unless explicitly requested. + - Keep each branch complete with a clear downstream target. + +<user_request> +{{TASK_DESCRIPTION}} +</user_request> +<available_nodes> +{{AVAILABLE_NODES}} +</available_nodes> +<existing_nodes> +{{EXISTING_NODES}} +</existing_nodes> +<available_tools> +{{AVAILABLE_TOOLS}} +</available_tools> +""" + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ Here is a task description for which I would like you to create a high-quality prompt template for: <task_description> diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index e195aebe6d..e64a83034c 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,7 +2,7 @@ from __future__ import annotations import json from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Union, cast from packaging.version import Version from pydantic import ValidationError @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter +from core.db.session_factory import session_factory from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -49,6 +50,12 @@ from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy from models import ToolFile from models.model import Conversation +from models.tools import ( + ApiToolProvider, + BuiltinToolProvider, + MCPToolProvider, + WorkflowToolProvider, +) from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exc import ( @@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]): value = cast(list[dict[str, Any]], value) tool_value = [] for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) + provider_type = self._infer_tool_provider_type(tool, self.tenant_id) setting_params = tool.get("settings", {}) parameters = tool.get("parameters", {}) manual_input_params = [key for key, value in parameters.items() if value is not None] @@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]): llm_usage=llm_usage, ) ) + + @staticmethod + def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType: + provider_type_str = tool_config.get("type") + if provider_type_str: + return ToolProviderType(provider_type_str) + + provider_id = tool_config.get("provider_name") + if not provider_id: + return ToolProviderType.BUILT_IN + + with session_factory.create_session() as session: + provider_map: dict[ + type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]], + ToolProviderType, + ] = { + WorkflowToolProvider: ToolProviderType.WORKFLOW, + MCPToolProvider: ToolProviderType.MCP, + ApiToolProvider: ToolProviderType.API, + BuiltinToolProvider: ToolProviderType.BUILT_IN, + } + + for provider_model, provider_type in provider_map.items(): + stmt = select(provider_model).where( + provider_model.id == provider_id, + provider_model.tenant_id == tenant_id, + ) + if session.scalar(stmt): + return provider_type + + raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.") diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 2b773b537c..870c1f4090 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -212,6 +212,14 @@ class Node(Generic[NodeDataT]): return None + @classmethod + def get_default_config_schema(cls) -> dict[str, Any] | None: + """ + Get the default configuration schema for the node. + Used for LLM generation. + """ + return None + # Global registry populated via __init_subclass__ _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 2efcb4f418..299cbb90ad 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,3 +1,5 @@ +from typing import Any + from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node @@ -9,6 +11,24 @@ class EndNode(Node[EndNodeData]): node_type = NodeType.END execution_type = NodeExecutionType.RESPONSE + @classmethod + def get_default_config_schema(cls) -> dict[str, Any] | None: + return { + "description": "Workflow exit point - defines output variables", + "required": ["outputs"], + "parameters": { + "outputs": { + "type": "array", + "description": "Output variables to return", + "item_schema": { + "variable": "string - output variable name", + "type": "enum: string, number, object, array", + "value_selector": "array - path to source value, e.g. ['node_id', 'field']", + }, + }, + }, + } + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 53c1b4ee6b..3279edd78e 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -14,6 +14,27 @@ class StartNode(Node[StartNodeData]): node_type = NodeType.START execution_type = NodeExecutionType.ROOT + @classmethod + def get_default_config_schema(cls) -> dict[str, Any] | None: + return { + "description": "Workflow entry point - defines input variables", + "required": [], + "parameters": { + "variables": { + "type": "array", + "description": "Input variables for the workflow", + "item_schema": { + "variable": "string - variable name", + "label": "string - display label", + "type": "enum: text-input, paragraph, number, select, file, file-list", + "required": "boolean", + "max_length": "number (optional)", + }, + }, + }, + "outputs": ["All defined variables are available as {{#start.variable_name#}}"], + } + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 60d76db9b6..e5ed1679d0 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -50,6 +50,19 @@ class ToolNode(Node[ToolNodeData]): def version(cls) -> str: return "1" + @classmethod + def get_default_config_schema(cls) -> dict[str, Any] | None: + return { + "description": "Execute an external tool", + "required": ["provider_id", "tool_id", "tool_parameters"], + "parameters": { + "provider_id": {"type": "string"}, + "provider_type": {"type": "string"}, + "tool_id": {"type": "string"}, + "tool_parameters": {"type": "object"}, + }, + } + def _run(self) -> Generator[NodeEventBase, None, None]: """ Run the tool node diff --git a/api/tests/unit_tests/core/llm_generator/test_graph_builder.py b/api/tests/unit_tests/core/llm_generator/test_graph_builder.py new file mode 100644 index 0000000000..2d72c41684 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_graph_builder.py @@ -0,0 +1,400 @@ +""" +Unit tests for GraphBuilder. + +Tests the automatic graph construction from node lists with dependency declarations. +""" + +import pytest + +from core.workflow.generator.utils.graph_builder import ( + CyclicDependencyError, + GraphBuilder, +) + + +class TestGraphBuilderBasic: + """Basic functionality tests.""" + + def test_empty_nodes_creates_minimal_workflow(self): + """Empty node list creates start -> end workflow.""" + result_nodes, result_edges = GraphBuilder.build_graph([]) + + assert len(result_nodes) == 2 + assert result_nodes[0]["type"] == "start" + assert result_nodes[1]["type"] == "end" + assert len(result_edges) == 1 + assert result_edges[0]["source"] == "start" + assert result_edges[0]["target"] == "end" + + def test_simple_linear_workflow(self): + """Simple linear workflow: start -> fetch -> process -> end.""" + nodes = [ + {"id": "fetch", "type": "http-request", "depends_on": []}, + {"id": "process", "type": "llm", "depends_on": ["fetch"]}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should have: start + 2 user nodes + end = 4 + assert len(result_nodes) == 4 + assert result_nodes[0]["type"] == "start" + assert result_nodes[-1]["type"] == "end" + + # Should have: start->fetch, fetch->process, process->end = 3 + assert len(result_edges) == 3 + + # Verify edge connections + edge_pairs = [(e["source"], e["target"]) for e in result_edges] + assert ("start", "fetch") in edge_pairs + assert ("fetch", "process") in edge_pairs + assert ("process", "end") in edge_pairs + + +class TestParallelWorkflow: + """Tests for parallel node handling.""" + + def test_parallel_workflow(self): + """Parallel workflow: multiple nodes from start, merging to one.""" + nodes = [ + {"id": "api1", "type": "http-request", "depends_on": []}, + {"id": "api2", "type": "http-request", "depends_on": []}, + {"id": "merge", "type": "llm", "depends_on": ["api1", "api2"]}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # start should connect to both api1 and api2 + start_edges = [e for e in result_edges if e["source"] == "start"] + assert len(start_edges) == 2 + + start_targets = {e["target"] for e in start_edges} + assert start_targets == {"api1", "api2"} + + # Both api1 and api2 should connect to merge + merge_incoming = [e for e in result_edges if e["target"] == "merge"] + assert len(merge_incoming) == 2 + + def test_multiple_terminal_nodes(self): + """Multiple terminal nodes all connect to end.""" + nodes = [ + {"id": "branch1", "type": "llm", "depends_on": []}, + {"id": "branch2", "type": "llm", "depends_on": []}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Both branches should connect to end + end_incoming = [e for e in result_edges if e["target"] == "end"] + assert len(end_incoming) == 2 + + +class TestIfElseWorkflow: + """Tests for if-else branching.""" + + def test_if_else_workflow(self): + """Conditional branching workflow.""" + nodes = [ + { + "id": "check", + "type": "if-else", + "config": {"true_branch": "success", "false_branch": "fallback"}, + "depends_on": [], + }, + {"id": "success", "type": "llm", "depends_on": []}, + {"id": "fallback", "type": "code", "depends_on": []}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should have true and false branch edges + branch_edges = [e for e in result_edges if e["source"] == "check"] + assert len(branch_edges) == 2 + assert any(e.get("sourceHandle") == "true" for e in branch_edges) + assert any(e.get("sourceHandle") == "false" for e in branch_edges) + + # Verify targets + true_edge = next(e for e in branch_edges if e.get("sourceHandle") == "true") + false_edge = next(e for e in branch_edges if e.get("sourceHandle") == "false") + assert true_edge["target"] == "success" + assert false_edge["target"] == "fallback" + + def test_if_else_missing_branch_no_error(self): + """if-else with only true branch doesn't error (warning only).""" + nodes = [ + { + "id": "check", + "type": "if-else", + "config": {"true_branch": "success"}, + "depends_on": [], + }, + {"id": "success", "type": "llm", "depends_on": []}, + ] + # Should not raise + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should have one branch edge + branch_edges = [e for e in result_edges if e["source"] == "check"] + assert len(branch_edges) == 1 + assert branch_edges[0].get("sourceHandle") == "true" + + +class TestQuestionClassifierWorkflow: + """Tests for question-classifier branching.""" + + def test_question_classifier_workflow(self): + """Question classifier with multiple classes.""" + nodes = [ + { + "id": "classifier", + "type": "question-classifier", + "config": { + "query": ["start", "user_input"], + "classes": [ + {"id": "tech", "name": "技术问题", "target": "tech_handler"}, + {"id": "sales", "name": "销售咨询", "target": "sales_handler"}, + {"id": "other", "name": "其他问题", "target": "other_handler"}, + ], + }, + "depends_on": [], + }, + {"id": "tech_handler", "type": "llm", "depends_on": []}, + {"id": "sales_handler", "type": "llm", "depends_on": []}, + {"id": "other_handler", "type": "llm", "depends_on": []}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should have 3 branch edges from classifier + classifier_edges = [e for e in result_edges if e["source"] == "classifier"] + assert len(classifier_edges) == 3 + + # Each should use class id as sourceHandle + assert any(e.get("sourceHandle") == "tech" and e["target"] == "tech_handler" for e in classifier_edges) + assert any(e.get("sourceHandle") == "sales" and e["target"] == "sales_handler" for e in classifier_edges) + assert any(e.get("sourceHandle") == "other" and e["target"] == "other_handler" for e in classifier_edges) + + def test_question_classifier_missing_target(self): + """Classes without target connect to end.""" + nodes = [ + { + "id": "classifier", + "type": "question-classifier", + "config": { + "classes": [ + {"id": "known", "name": "已知问题", "target": "handler"}, + {"id": "unknown", "name": "未知问题"}, # Missing target + ], + }, + "depends_on": [], + }, + {"id": "handler", "type": "llm", "depends_on": []}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Missing target should connect to end + classifier_edges = [e for e in result_edges if e["source"] == "classifier"] + assert any(e.get("sourceHandle") == "unknown" and e["target"] == "end" for e in classifier_edges) + + +class TestVariableDependencyInference: + """Tests for automatic dependency inference from variables.""" + + def test_variable_dependency_inference(self): + """Dependencies inferred from variable references.""" + nodes = [ + {"id": "fetch", "type": "http-request", "depends_on": []}, + { + "id": "process", + "type": "llm", + "config": {"prompt_template": [{"text": "{{#fetch.body#}}"}]}, + # No explicit depends_on, but references fetch + }, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should automatically infer process depends on fetch + assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges) + + def test_system_variable_not_inferred(self): + """System variables (sys, start) not inferred as dependencies.""" + nodes = [ + { + "id": "process", + "type": "llm", + "config": {"prompt_template": [{"text": "{{#sys.query#}} {{#start.input#}}"}]}, + "depends_on": [], + }, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should connect to start, not create dependency on sys or start + edge_sources = {e["source"] for e in result_edges} + assert "sys" not in edge_sources + assert "start" in edge_sources + + +class TestCycleDetection: + """Tests for cyclic dependency detection.""" + + def test_cyclic_dependency_detected(self): + """Cyclic dependencies raise error.""" + nodes = [ + {"id": "a", "type": "llm", "depends_on": ["c"]}, + {"id": "b", "type": "llm", "depends_on": ["a"]}, + {"id": "c", "type": "llm", "depends_on": ["b"]}, + ] + + with pytest.raises(CyclicDependencyError): + GraphBuilder.build_graph(nodes) + + def test_self_dependency_detected(self): + """Self-dependency raises error.""" + nodes = [ + {"id": "a", "type": "llm", "depends_on": ["a"]}, + ] + + with pytest.raises(CyclicDependencyError): + GraphBuilder.build_graph(nodes) + + +class TestErrorRecovery: + """Tests for silent error recovery.""" + + def test_invalid_dependency_removed(self): + """Invalid dependencies (non-existent nodes) are silently removed.""" + nodes = [ + {"id": "process", "type": "llm", "depends_on": ["nonexistent"]}, + ] + # Should not raise, invalid dependency silently removed + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Process should connect from start (since invalid dep was removed) + assert any(e["source"] == "start" and e["target"] == "process" for e in result_edges) + + def test_depends_on_as_string(self): + """depends_on as string is converted to list.""" + nodes = [ + {"id": "fetch", "type": "http-request", "depends_on": []}, + {"id": "process", "type": "llm", "depends_on": "fetch"}, # String instead of list + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should work correctly + assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges) + + +class TestContainerNodes: + """Tests for container nodes (iteration, loop).""" + + def test_iteration_node_as_regular_node(self): + """Iteration nodes behave as regular single-in-single-out nodes.""" + nodes = [ + {"id": "prepare", "type": "code", "depends_on": []}, + { + "id": "loop", + "type": "iteration", + "config": {"iterator_selector": ["prepare", "items"]}, + "depends_on": ["prepare"], + }, + {"id": "process_result", "type": "llm", "depends_on": ["loop"]}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should have standard edges: start->prepare, prepare->loop, loop->process_result, process_result->end + edge_pairs = [(e["source"], e["target"]) for e in result_edges] + assert ("start", "prepare") in edge_pairs + assert ("prepare", "loop") in edge_pairs + assert ("loop", "process_result") in edge_pairs + assert ("process_result", "end") in edge_pairs + + def test_loop_node_as_regular_node(self): + """Loop nodes behave as regular single-in-single-out nodes.""" + nodes = [ + {"id": "init", "type": "code", "depends_on": []}, + { + "id": "repeat", + "type": "loop", + "config": {"loop_count": 5}, + "depends_on": ["init"], + }, + {"id": "finish", "type": "llm", "depends_on": ["repeat"]}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Standard edge flow + edge_pairs = [(e["source"], e["target"]) for e in result_edges] + assert ("init", "repeat") in edge_pairs + assert ("repeat", "finish") in edge_pairs + + def test_iteration_with_variable_inference(self): + """Iteration node dependencies can be inferred from iterator_selector.""" + nodes = [ + {"id": "data_source", "type": "http-request", "depends_on": []}, + { + "id": "process_each", + "type": "iteration", + "config": { + "iterator_selector": ["data_source", "items"], + }, + # No explicit depends_on, but references data_source + }, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Should infer dependency from iterator_selector reference + # Note: iterator_selector format is different from {{#...#}}, so this tests + # that explicit depends_on is properly handled when not provided + # In this case, process_each has no depends_on, so it connects to start + edge_pairs = [(e["source"], e["target"]) for e in result_edges] + # Without explicit depends_on, connects to start + assert ("start", "process_each") in edge_pairs or ("data_source", "process_each") in edge_pairs + + def test_loop_node_self_reference_not_cycle(self): + """Loop nodes referencing their own outputs should not create cycle.""" + nodes = [ + {"id": "init", "type": "code", "depends_on": []}, + { + "id": "my_loop", + "type": "loop", + "config": { + "loop_count": 5, + # Loop node referencing its own output (common pattern) + "prompt": "Previous: {{#my_loop.output#}}, continue...", + }, + "depends_on": ["init"], + }, + {"id": "finish", "type": "llm", "depends_on": ["my_loop"]}, + ] + # Should NOT raise CyclicDependencyError + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + # Verify the graph is built correctly + assert len(result_nodes) == 5 # start + 3 + end + edge_pairs = [(e["source"], e["target"]) for e in result_edges] + assert ("init", "my_loop") in edge_pairs + assert ("my_loop", "finish") in edge_pairs + + +class TestEdgeStructure: + """Tests for edge structure correctness.""" + + def test_edge_has_required_fields(self): + """Edges have all required fields.""" + nodes = [ + {"id": "node1", "type": "llm", "depends_on": []}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + for edge in result_edges: + assert "id" in edge + assert "source" in edge + assert "target" in edge + assert "sourceHandle" in edge + assert "targetHandle" in edge + + def test_edge_id_unique(self): + """Each edge has a unique ID.""" + nodes = [ + {"id": "a", "type": "llm", "depends_on": []}, + {"id": "b", "type": "llm", "depends_on": []}, + {"id": "c", "type": "llm", "depends_on": ["a", "b"]}, + ] + result_nodes, result_edges = GraphBuilder.build_graph(nodes) + + edge_ids = [e["id"] for e in result_edges] + assert len(edge_ids) == len(set(edge_ids)) # All unique diff --git a/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py b/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py new file mode 100644 index 0000000000..bdeff2258c --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py @@ -0,0 +1,287 @@ +""" +Unit tests for the Mermaid Generator. + +Tests cover: +- Basic workflow rendering +- Reserved word handling ('end' → 'end_node') +- Question classifier multi-branch edges +- If-else branch labels +- Edge validation and skipping +- Tool node formatting +""" + +from core.workflow.generator.utils.mermaid_generator import generate_mermaid + + +class TestBasicWorkflow: + """Tests for basic workflow Mermaid generation.""" + + def test_simple_start_end_workflow(self): + """Test simple Start → End workflow.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [{"source": "start", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + assert "flowchart TD" in result + assert 'start["type=start|title=Start"]' in result + assert 'end_node["type=end|title=End"]' in result + assert "start --> end_node" in result + + def test_start_llm_end_workflow(self): + """Test Start → LLM → End workflow.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "llm", "type": "llm", "title": "Generate"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [ + {"source": "start", "target": "llm"}, + {"source": "llm", "target": "end"}, + ], + } + result = generate_mermaid(workflow_data) + + assert 'llm["type=llm|title=Generate"]' in result + assert "start --> llm" in result + assert "llm --> end_node" in result + + def test_empty_workflow(self): + """Test empty workflow returns minimal output.""" + workflow_data = {"nodes": [], "edges": []} + result = generate_mermaid(workflow_data) + + assert result == "flowchart TD" + + def test_missing_keys_handled(self): + """Test workflow with missing keys doesn't crash.""" + workflow_data = {} + result = generate_mermaid(workflow_data) + + assert "flowchart TD" in result + + +class TestReservedWords: + """Tests for reserved word handling in node IDs.""" + + def test_end_node_id_is_replaced(self): + """Test 'end' node ID is replaced with 'end_node'.""" + workflow_data = { + "nodes": [{"id": "end", "type": "end", "title": "End"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Should use end_node instead of end + assert "end_node[" in result + assert '"type=end|title=End"' in result + + def test_subgraph_node_id_is_replaced(self): + """Test 'subgraph' node ID is replaced with 'subgraph_node'.""" + workflow_data = { + "nodes": [{"id": "subgraph", "type": "code", "title": "Process"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "subgraph_node[" in result + + def test_edge_uses_safe_ids(self): + """Test edges correctly reference safe IDs after replacement.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [{"source": "start", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + # Edge should use end_node, not end + assert "start --> end_node" in result + assert "start --> end\n" not in result + + +class TestBranchEdges: + """Tests for branching node edge labels.""" + + def test_question_classifier_source_handles(self): + """Test question-classifier edges with sourceHandle labels.""" + workflow_data = { + "nodes": [ + {"id": "classifier", "type": "question-classifier", "title": "Classify"}, + {"id": "refund", "type": "llm", "title": "Handle Refund"}, + {"id": "inquiry", "type": "llm", "title": "Handle Inquiry"}, + ], + "edges": [ + {"source": "classifier", "target": "refund", "sourceHandle": "refund"}, + {"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"}, + ], + } + result = generate_mermaid(workflow_data) + + assert "classifier -->|refund| refund" in result + assert "classifier -->|inquiry| inquiry" in result + + def test_if_else_true_false_handles(self): + """Test if-else edges with true/false labels.""" + workflow_data = { + "nodes": [ + {"id": "ifelse", "type": "if-else", "title": "Check"}, + {"id": "yes_branch", "type": "llm", "title": "Yes"}, + {"id": "no_branch", "type": "llm", "title": "No"}, + ], + "edges": [ + {"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"}, + {"source": "ifelse", "target": "no_branch", "sourceHandle": "false"}, + ], + } + result = generate_mermaid(workflow_data) + + assert "ifelse -->|true| yes_branch" in result + assert "ifelse -->|false| no_branch" in result + + def test_source_handle_source_is_ignored(self): + """Test sourceHandle='source' doesn't add label.""" + workflow_data = { + "nodes": [ + {"id": "llm1", "type": "llm", "title": "LLM 1"}, + {"id": "llm2", "type": "llm", "title": "LLM 2"}, + ], + "edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}], + } + result = generate_mermaid(workflow_data) + + # Should be plain arrow without label + assert "llm1 --> llm2" in result + assert "llm1 -->|source|" not in result + + +class TestEdgeValidation: + """Tests for edge validation and error handling.""" + + def test_edge_with_missing_source_is_skipped(self): + """Test edge with non-existent source node is skipped.""" + workflow_data = { + "nodes": [{"id": "end", "type": "end", "title": "End"}], + "edges": [{"source": "nonexistent", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + # Should not contain the invalid edge + assert "nonexistent" not in result + assert "-->" not in result or "nonexistent" not in result + + def test_edge_with_missing_target_is_skipped(self): + """Test edge with non-existent target node is skipped.""" + workflow_data = { + "nodes": [{"id": "start", "type": "start", "title": "Start"}], + "edges": [{"source": "start", "target": "nonexistent"}], + } + result = generate_mermaid(workflow_data) + + # Edge should be skipped + assert "start --> nonexistent" not in result + + def test_edge_without_source_or_target_is_skipped(self): + """Test edge missing source or target is skipped.""" + workflow_data = { + "nodes": [{"id": "start", "type": "start", "title": "Start"}], + "edges": [{"source": "start"}, {"target": "start"}, {}], + } + result = generate_mermaid(workflow_data) + + # No edges should be rendered + assert result.count("-->") == 0 + + +class TestToolNodes: + """Tests for tool node formatting.""" + + def test_tool_node_includes_tool_key(self): + """Test tool node includes tool_key in label.""" + workflow_data = { + "nodes": [ + { + "id": "search", + "type": "tool", + "title": "Search", + "config": {"tool_key": "google/search"}, + } + ], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert 'search["type=tool|title=Search|tool=google/search"]' in result + + def test_tool_node_with_tool_name_fallback(self): + """Test tool node uses tool_name as fallback.""" + workflow_data = { + "nodes": [ + { + "id": "tool1", + "type": "tool", + "title": "My Tool", + "config": {"tool_name": "my_tool"}, + } + ], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "tool=my_tool" in result + + def test_tool_node_missing_tool_key_shows_unknown(self): + """Test tool node without tool_key shows 'unknown'.""" + workflow_data = { + "nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "tool=unknown" in result + + +class TestNodeFormatting: + """Tests for node label formatting.""" + + def test_quotes_in_title_are_escaped(self): + """Test double quotes in title are replaced with single quotes.""" + workflow_data = { + "nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Double quotes should be replaced + assert "Say 'Hello'" in result + assert 'Say "Hello"' not in result + + def test_node_without_id_is_skipped(self): + """Test node without id is skipped.""" + workflow_data = { + "nodes": [{"type": "llm", "title": "No ID"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Should only have flowchart header + lines = [line for line in result.split("\n") if line.strip()] + assert len(lines) == 1 + + def test_node_default_values(self): + """Test node with missing type/title uses defaults.""" + workflow_data = { + "nodes": [{"id": "node1"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "type=unknown" in result + assert "title=Untitled" in result diff --git a/api/tests/unit_tests/core/llm_generator/test_node_repair.py b/api/tests/unit_tests/core/llm_generator/test_node_repair.py new file mode 100644 index 0000000000..a92a7d0125 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_node_repair.py @@ -0,0 +1,81 @@ +from core.workflow.generator.utils.node_repair import NodeRepair + + +class TestNodeRepair: + """Tests for NodeRepair utility.""" + + def test_repair_if_else_valid_operators(self): + """Test that valid operators remain unchanged.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + "config": { + "cases": [ + { + "conditions": [ + {"comparison_operator": "≥", "value": "1"}, + {"comparison_operator": "=", "value": "2"}, + ] + } + ] + }, + } + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is False + assert result.nodes == nodes + + def test_repair_if_else_invalid_operators(self): + """Test that invalid operators are normalized.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + "config": { + "cases": [ + { + "conditions": [ + {"comparison_operator": ">=", "value": "1"}, + {"comparison_operator": "<=", "value": "2"}, + {"comparison_operator": "!=", "value": "3"}, + {"comparison_operator": "==", "value": "4"}, + ] + } + ] + }, + } + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is True + assert len(result.repairs_made) == 4 + + conditions = result.nodes[0]["config"]["cases"][0]["conditions"] + assert conditions[0]["comparison_operator"] == "≥" + assert conditions[1]["comparison_operator"] == "≤" + assert conditions[2]["comparison_operator"] == "≠" + assert conditions[3]["comparison_operator"] == "=" + + def test_repair_ignores_other_nodes(self): + """Test that other node types are ignored.""" + nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}] + result = NodeRepair.repair(nodes) + assert result.was_repaired is False + assert result.nodes[0]["config"]["some_field"] == ">=" + + def test_repair_handles_missing_config(self): + """Test robustness against missing fields.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + # Missing config + }, + { + "id": "node2", + "type": "if-else", + "config": {}, # Missing cases + }, + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is False diff --git a/api/tests/unit_tests/core/llm_generator/test_node_schemas_validation.py b/api/tests/unit_tests/core/llm_generator/test_node_schemas_validation.py new file mode 100644 index 0000000000..eccfd93207 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_node_schemas_validation.py @@ -0,0 +1,99 @@ +""" +Tests for node schemas validation. + +Ensures that the node configuration stays in sync with registered node types. +""" + +from core.workflow.generator.config.node_schemas import ( + get_builtin_node_schemas, + validate_node_schemas, +) + + +class TestNodeSchemasValidation: + """Tests for node schema validation utilities.""" + + def test_validate_node_schemas_returns_no_warnings(self): + """Ensure all registered node types have corresponding schemas.""" + warnings = validate_node_schemas() + # If this test fails, it means a new node type was added but + # no schema was defined for it in node_schemas.py + assert len(warnings) == 0, ( + f"Missing schemas for node types: {warnings}. " + "Please add schemas for these node types in node_schemas.py " + "or add them to _INTERNAL_NODE_TYPES if they don't need schemas." + ) + + def test_builtin_node_schemas_not_empty(self): + """Ensure BUILTIN_NODE_SCHEMAS contains expected node types.""" + # get_builtin_node_schemas() includes dynamic schemas + all_schemas = get_builtin_node_schemas() + assert len(all_schemas) > 0 + # Core node types should always be present + expected_types = ["llm", "code", "http-request", "if-else"] + for node_type in expected_types: + assert node_type in all_schemas, f"Missing schema for core node type: {node_type}" + + def test_schema_structure(self): + """Ensure each schema has required fields.""" + all_schemas = get_builtin_node_schemas() + for node_type, schema in all_schemas.items(): + assert "description" in schema, f"Missing 'description' in schema for {node_type}" + # 'parameters' is optional but if present should be a dict + if "parameters" in schema: + assert isinstance(schema["parameters"], dict), ( + f"'parameters' in schema for {node_type} should be a dict" + ) + + +class TestNodeSchemasMerged: + """Tests to verify the merged configuration works correctly.""" + + def test_fallback_rules_available(self): + """Ensure FALLBACK_RULES is available from node_schemas.""" + from core.workflow.generator.config.node_schemas import FALLBACK_RULES + + assert len(FALLBACK_RULES) > 0 + assert "http-request" in FALLBACK_RULES + assert "code" in FALLBACK_RULES + assert "llm" in FALLBACK_RULES + + def test_node_type_aliases_available(self): + """Ensure NODE_TYPE_ALIASES is available from node_schemas.""" + from core.workflow.generator.config.node_schemas import NODE_TYPE_ALIASES + + assert len(NODE_TYPE_ALIASES) > 0 + assert NODE_TYPE_ALIASES.get("gpt") == "llm" + assert NODE_TYPE_ALIASES.get("api") == "http-request" + + def test_field_name_corrections_available(self): + """Ensure FIELD_NAME_CORRECTIONS is available from node_schemas.""" + from core.workflow.generator.config.node_schemas import ( + FIELD_NAME_CORRECTIONS, + get_corrected_field_name, + ) + + assert len(FIELD_NAME_CORRECTIONS) > 0 + # Test the helper function + assert get_corrected_field_name("http-request", "text") == "body" + assert get_corrected_field_name("llm", "response") == "text" + assert get_corrected_field_name("code", "unknown") == "unknown" + + def test_config_init_exports(self): + """Ensure config __init__.py exports all needed symbols.""" + from core.workflow.generator.config import ( + BUILTIN_NODE_SCHEMAS, + FALLBACK_RULES, + FIELD_NAME_CORRECTIONS, + NODE_TYPE_ALIASES, + get_corrected_field_name, + validate_node_schemas, + ) + + # Just verify imports work + assert BUILTIN_NODE_SCHEMAS is not None + assert FALLBACK_RULES is not None + assert FIELD_NAME_CORRECTIONS is not None + assert NODE_TYPE_ALIASES is not None + assert callable(get_corrected_field_name) + assert callable(validate_node_schemas) diff --git a/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py b/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py new file mode 100644 index 0000000000..ae632e0bb4 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py @@ -0,0 +1,172 @@ +""" +Unit tests for the Planner Prompts. + +Tests cover: +- Tool formatting for planner context +- Edge cases with missing fields +- Empty tool lists +""" + +from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner + + +class TestFormatToolsForPlanner: + """Tests for format_tools_for_planner function.""" + + def test_empty_tools_returns_default_message(self): + """Test empty tools list returns default message.""" + result = format_tools_for_planner([]) + + assert result == "No external tools available." + + def test_none_tools_returns_default_message(self): + """Test None tools list returns default message.""" + result = format_tools_for_planner(None) + + assert result == "No external tools available." + + def test_single_tool_formatting(self): + """Test single tool is formatted correctly.""" + tools = [ + { + "provider_id": "google", + "tool_key": "search", + "tool_label": "Google Search", + "tool_description": "Search the web using Google", + } + ] + result = format_tools_for_planner(tools) + + assert "[google/search]" in result + assert "Google Search" in result + assert "Search the web using Google" in result + + def test_multiple_tools_formatting(self): + """Test multiple tools are formatted correctly.""" + tools = [ + { + "provider_id": "google", + "tool_key": "search", + "tool_label": "Search", + "tool_description": "Web search", + }, + { + "provider_id": "slack", + "tool_key": "send_message", + "tool_label": "Send Message", + "tool_description": "Send a Slack message", + }, + ] + result = format_tools_for_planner(tools) + + lines = result.strip().split("\n") + assert len(lines) == 2 + assert "[google/search]" in result + assert "[slack/send_message]" in result + + def test_tool_without_provider_uses_key_only(self): + """Test tool without provider_id uses tool_key only.""" + tools = [ + { + "tool_key": "my_tool", + "tool_label": "My Tool", + "tool_description": "A custom tool", + } + ] + result = format_tools_for_planner(tools) + + # Should format as [my_tool] without provider prefix + assert "[my_tool]" in result + assert "My Tool" in result + + def test_tool_with_tool_name_fallback(self): + """Test tool uses tool_name when tool_key is missing.""" + tools = [ + { + "tool_name": "fallback_tool", + "description": "Fallback description", + } + ] + result = format_tools_for_planner(tools) + + assert "fallback_tool" in result + assert "Fallback description" in result + + def test_tool_with_missing_description(self): + """Test tool with missing description doesn't crash.""" + tools = [ + { + "provider_id": "test", + "tool_key": "tool1", + "tool_label": "Tool 1", + } + ] + result = format_tools_for_planner(tools) + + assert "[test/tool1]" in result + assert "Tool 1" in result + + def test_tool_with_all_missing_fields(self): + """Test tool with all fields missing uses defaults.""" + tools = [{}] + result = format_tools_for_planner(tools) + + # Should not crash, may produce minimal output + assert isinstance(result, str) + + def test_tool_uses_provider_fallback(self): + """Test tool uses 'provider' when 'provider_id' is missing.""" + tools = [ + { + "provider": "openai", + "tool_key": "dalle", + "tool_label": "DALL-E", + "tool_description": "Generate images", + } + ] + result = format_tools_for_planner(tools) + + assert "[openai/dalle]" in result + + def test_tool_label_fallback_to_key(self): + """Test tool_label falls back to tool_key when missing.""" + tools = [ + { + "provider_id": "test", + "tool_key": "my_key", + "tool_description": "Description here", + } + ] + result = format_tools_for_planner(tools) + + # Label should fallback to key + assert "my_key" in result + assert "Description here" in result + + +class TestPlannerPromptConstants: + """Tests for planner prompt constant availability.""" + + def test_planner_system_prompt_exists(self): + """Test PLANNER_SYSTEM_PROMPT is defined.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT + + assert PLANNER_SYSTEM_PROMPT is not None + assert len(PLANNER_SYSTEM_PROMPT) > 0 + assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT + + def test_planner_user_prompt_exists(self): + """Test PLANNER_USER_PROMPT is defined.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT + + assert PLANNER_USER_PROMPT is not None + assert "{instruction}" in PLANNER_USER_PROMPT + + def test_planner_system_prompt_has_required_sections(self): + """Test PLANNER_SYSTEM_PROMPT has required XML sections.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT + + assert "<role>" in PLANNER_SYSTEM_PROMPT + assert "<task>" in PLANNER_SYSTEM_PROMPT + assert "<available_tools>" in PLANNER_SYSTEM_PROMPT + assert "<response_format>" in PLANNER_SYSTEM_PROMPT diff --git a/api/tests/unit_tests/core/llm_generator/test_validation_engine.py b/api/tests/unit_tests/core/llm_generator/test_validation_engine.py new file mode 100644 index 0000000000..5b4f8757dc --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_validation_engine.py @@ -0,0 +1,510 @@ +""" +Unit tests for the Validation Rule Engine. + +Tests cover: +- Structure rules (required fields, types, formats) +- Semantic rules (variable references, edge connections) +- Reference rules (model exists, tool configured, dataset valid) +- ValidationEngine integration +""" + +from core.workflow.generator.validation import ( + ValidationContext, + ValidationEngine, +) +from core.workflow.generator.validation.rules import ( + extract_variable_refs, + is_placeholder, +) + + +class TestPlaceholderDetection: + """Tests for placeholder detection utility.""" + + def test_detects_please_select(self): + assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True + + def test_detects_your_prefix(self): + assert is_placeholder("YOUR_API_KEY") is True + + def test_detects_todo(self): + assert is_placeholder("TODO: fill this in") is True + + def test_detects_placeholder(self): + assert is_placeholder("PLACEHOLDER_VALUE") is True + + def test_detects_example_prefix(self): + assert is_placeholder("EXAMPLE_URL") is True + + def test_detects_replace_prefix(self): + assert is_placeholder("REPLACE_WITH_ACTUAL") is True + + def test_case_insensitive(self): + assert is_placeholder("please_select") is True + assert is_placeholder("Please_Select") is True + + def test_valid_values_not_detected(self): + assert is_placeholder("https://api.example.com") is False + assert is_placeholder("gpt-4") is False + assert is_placeholder("my_variable") is False + + def test_non_string_returns_false(self): + assert is_placeholder(123) is False + assert is_placeholder(None) is False + assert is_placeholder(["list"]) is False + + +class TestVariableRefExtraction: + """Tests for variable reference extraction.""" + + def test_extracts_simple_ref(self): + refs = extract_variable_refs("Hello {{#start.query#}}") + assert refs == [("start", "query")] + + def test_extracts_multiple_refs(self): + refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}") + assert refs == [("node1", "output"), ("node2", "text")] + + def test_extracts_nested_field(self): + refs = extract_variable_refs("{{#http_request.body#}}") + assert refs == [("http_request", "body")] + + def test_no_refs_returns_empty(self): + refs = extract_variable_refs("No references here") + assert refs == [] + + def test_handles_malformed_refs(self): + refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}") + assert refs == [] + + +class TestValidationContext: + """Tests for ValidationContext.""" + + def test_node_map_lookup(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start"}, + {"id": "llm_1", "type": "llm"}, + ] + ) + assert ctx.get_node("start") == {"id": "start", "type": "start"} + assert ctx.get_node("nonexistent") is None + + def test_model_set(self): + ctx = ValidationContext( + available_models=[ + {"provider": "openai", "model": "gpt-4"}, + {"provider": "anthropic", "model": "claude-3"}, + ] + ) + assert ctx.has_model("openai", "gpt-4") is True + assert ctx.has_model("anthropic", "claude-3") is True + assert ctx.has_model("openai", "gpt-3.5") is False + + def test_tool_set(self): + ctx = ValidationContext( + available_tools=[ + {"provider_id": "google", "tool_key": "search", "is_team_authorization": True}, + {"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False}, + ] + ) + assert ctx.has_tool("google/search") is True + assert ctx.has_tool("search") is True + assert ctx.is_tool_configured("google/search") is True + assert ctx.is_tool_configured("slack/send_message") is False + + def test_upstream_downstream_nodes(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start"}, + {"id": "llm", "type": "llm"}, + {"id": "end", "type": "end"}, + ], + edges=[ + {"source": "start", "target": "llm"}, + {"source": "llm", "target": "end"}, + ], + ) + assert ctx.get_upstream_nodes("llm") == ["start"] + assert ctx.get_downstream_nodes("llm") == ["end"] + + +class TestStructureRules: + """Tests for structure validation rules.""" + + def test_llm_missing_prompt_template(self): + ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}]) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_errors + errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_llm_with_prompt_template_passes(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [ + {"role": "system", "text": "You are helpful"}, + {"role": "user", "text": "Hello"}, + ] + }, + } + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # No prompt_template errors + errors = [e for e in result.all_errors if "prompt_template" in e.rule_id] + assert len(errors) == 0 + + def test_http_request_missing_url(self): + ctx = ValidationContext(nodes=[{"id": "http_1", "type": "http-request", "config": {}}]) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "http.url" in e.rule_id] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_http_request_placeholder_url(self): + ctx = ValidationContext( + nodes=[ + { + "id": "http_1", + "type": "http-request", + "config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"}, + } + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "placeholder" in e.rule_id] + assert len(errors) == 1 + + def test_code_node_missing_fields(self): + ctx = ValidationContext(nodes=[{"id": "code_1", "type": "code", "config": {}}]) + engine = ValidationEngine() + result = engine.validate(ctx) + + error_rules = {e.rule_id for e in result.all_errors} + assert "code.code.required" in error_rules + assert "code.language.required" in error_rules + + def test_knowledge_retrieval_missing_dataset(self): + ctx = ValidationContext(nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}]) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id] + assert len(errors) == 1 + assert errors[0].is_fixable is False # User must configure + + +class TestSemanticRules: + """Tests for semantic validation rules.""" + + def test_valid_variable_reference(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Process: {{#start.query#}}"}]}, + }, + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # No variable reference errors + errors = [e for e in result.all_errors if "variable.ref" in e.rule_id] + assert len(errors) == 0 + + def test_invalid_variable_reference(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Process: {{#nonexistent.field#}}"}]}, + }, + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "variable.ref" in e.rule_id] + assert len(errors) == 1 + assert "nonexistent" in errors[0].message + + def test_edge_validation(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + edges=[ + {"source": "start", "target": "end"}, + {"source": "nonexistent", "target": "end"}, + ], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "edge" in e.rule_id] + assert len(errors) == 1 + assert "nonexistent" in errors[0].message + + +class TestReferenceRules: + """Tests for reference validation rules (models, tools).""" + + def test_llm_missing_model_with_available(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.required"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_llm_missing_model_no_available(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[], # No models available + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.no_available"] + assert len(errors) == 1 + assert errors[0].is_fixable is False + + def test_llm_with_valid_model(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [{"role": "user", "text": "Hi"}], + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "model" in e.rule_id] + assert len(errors) == 0 + + def test_llm_with_invalid_model(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [{"role": "user", "text": "Hi"}], + "model": {"provider": "openai", "name": "gpt-99"}, + }, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.not_found"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_tool_node_not_found(self): + ctx = ValidationContext( + nodes=[ + { + "id": "tool_1", + "type": "tool", + "config": {"tool_key": "nonexistent/tool"}, + } + ], + available_tools=[], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"] + assert len(errors) == 1 + + def test_tool_node_not_configured(self): + ctx = ValidationContext( + nodes=[ + { + "id": "tool_1", + "type": "tool", + "config": {"tool_key": "google/search"}, + } + ], + available_tools=[{"provider_id": "google", "tool_key": "search", "is_team_authorization": False}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"] + assert len(errors) == 1 + assert errors[0].is_fixable is False + + +class TestValidationResult: + """Tests for ValidationResult classification.""" + + def test_has_errors(self): + ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}]) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_errors is True + assert result.is_valid is False + + def test_has_fixable_errors(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_fixable_errors is True + assert len(result.fixable_errors) > 0 + + def test_get_fixable_by_node(self): + ctx = ValidationContext( + nodes=[ + {"id": "llm_1", "type": "llm", "config": {}}, + {"id": "http_1", "type": "http-request", "config": {}}, + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + by_node = result.get_fixable_by_node() + assert "llm_1" in by_node + assert "http_1" in by_node + + def test_to_dict(self): + ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}]) + engine = ValidationEngine() + result = engine.validate(ctx) + + d = result.to_dict() + assert "fixable" in d + assert "user_required" in d + assert "warnings" in d + assert "all_warnings" in d + assert "stats" in d + + +class TestIntegration: + """Integration tests for the full validation pipeline.""" + + def test_complete_workflow_validation(self): + """Test validation of a complete workflow.""" + ctx = ValidationContext( + nodes=[ + { + "id": "start", + "type": "start", + "config": {"variables": [{"variable": "query", "type": "text-input"}]}, + }, + { + "id": "llm_1", + "type": "llm", + "config": { + "model": {"provider": "openai", "name": "gpt-4"}, + "prompt_template": [{"role": "user", "text": "{{#start.query#}}"}], + }, + }, + { + "id": "end", + "type": "end", + "config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]}, + }, + ], + edges=[ + {"source": "start", "target": "llm_1"}, + {"source": "llm_1", "target": "end"}, + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # Should have no errors + assert result.is_valid is True + assert len(result.fixable_errors) == 0 + assert len(result.user_required_errors) == 0 + + def test_workflow_with_multiple_errors(self): + """Test workflow with multiple types of errors.""" + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": {}, # Missing prompt_template and model + }, + { + "id": "kb_1", + "type": "knowledge-retrieval", + "config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]}, + }, + {"id": "end", "type": "end", "config": {}}, + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # Should have multiple errors + assert result.has_errors is True + assert len(result.fixable_errors) >= 2 # model, prompt_template + assert len(result.user_required_errors) >= 1 # dataset placeholder + + # Check stats + assert result.stats["total_nodes"] == 4 + assert result.stats["total_errors"] >= 3 diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py new file mode 100644 index 0000000000..a95892d0b6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py @@ -0,0 +1,197 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.tools.entities.tool_entities import ToolProviderType +from core.workflow.nodes.agent.agent_node import AgentNode + + +class TestInferToolProviderType: + """Test cases for AgentNode._infer_tool_provider_type method.""" + + def test_infer_type_from_config_workflow(self): + """Test inferring workflow provider type from config.""" + tool_config = { + "type": "workflow", + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.WORKFLOW + + def test_infer_type_from_config_builtin(self): + """Test inferring builtin provider type from config.""" + tool_config = { + "type": "builtin", + "provider_name": "builtin-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + + def test_infer_type_from_config_api(self): + """Test inferring API provider type from config.""" + tool_config = { + "type": "api", + "provider_name": "api-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.API + + def test_infer_type_from_config_mcp(self): + """Test inferring MCP provider type from config.""" + tool_config = { + "type": "mcp", + "provider_name": "mcp-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.MCP + + def test_infer_type_invalid_config_value_raises_error(self): + """Test that invalid type value in config raises ValueError.""" + tool_config = { + "type": "invalid-type", + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + with pytest.raises(ValueError): + AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + def test_infer_workflow_type_from_database(self): + """Test inferring workflow provider type from database.""" + tool_config = { + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns a result + mock_session.scalar.return_value = True + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.WORKFLOW + # Should only query once (after finding WorkflowToolProvider) + assert mock_session.scalar.call_count == 1 + + def test_infer_mcp_type_from_database(self): + """Test inferring MCP provider type from database.""" + tool_config = { + "provider_name": "mcp-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns None + # Second query (MCPToolProvider) returns a result + mock_session.scalar.side_effect = [None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.MCP + assert mock_session.scalar.call_count == 2 + + def test_infer_api_type_from_database(self): + """Test inferring API provider type from database.""" + tool_config = { + "provider_name": "api-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns None + # Second query (MCPToolProvider) returns None + # Third query (ApiToolProvider) returns a result + mock_session.scalar.side_effect = [None, None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.API + assert mock_session.scalar.call_count == 3 + + def test_infer_builtin_type_from_database(self): + """Test inferring builtin provider type from database.""" + tool_config = { + "provider_name": "builtin-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First three queries return None + # Fourth query (BuiltinToolProvider) returns a result + mock_session.scalar.side_effect = [None, None, None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + assert mock_session.scalar.call_count == 4 + + def test_infer_type_default_when_not_found(self): + """Test raising AgentNodeError when provider is not found in database.""" + tool_config = { + "provider_name": "unknown-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # All queries return None + mock_session.scalar.return_value = None + + # Current implementation raises AgentNodeError when provider not found + from core.workflow.nodes.agent.exc import AgentNodeError + + with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"): + AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + def test_infer_type_default_when_no_provider_name(self): + """Test defaulting to BUILT_IN when provider_name is missing.""" + tool_config = {} + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + + def test_infer_type_database_exception_propagates(self): + """Test that database exception propagates (current implementation doesn't catch it).""" + tool_config = { + "provider_name": "provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # Database query raises exception + mock_session.scalar.side_effect = Exception("Database error") + + # Current implementation doesn't catch exceptions, so it propagates + with pytest.raises(Exception, match="Database error"): + AgentNode._infer_tool_provider_type(tool_config, tenant_id) diff --git a/web/app/components/app/configuration/config/automatic/version-selector.tsx b/web/app/components/app/configuration/config/automatic/version-selector.tsx index 91fb3950d2..6518fd1a03 100644 --- a/web/app/components/app/configuration/config/automatic/version-selector.tsx +++ b/web/app/components/app/configuration/config/automatic/version-selector.tsx @@ -10,9 +10,15 @@ type VersionSelectorProps = { versionLen: number value: number onChange: (index: number) => void + contentClassName?: string } -const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, onChange }) => { +const VersionSelector: React.FC<VersionSelectorProps> = ({ + versionLen, + value, + onChange, + contentClassName, +}) => { const { t } = useTranslation() const [isOpen, { setFalse: handleOpenFalse, @@ -64,6 +70,7 @@ const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, on </PortalToFollowElemTrigger> <PortalToFollowElemContent className={cn( 'z-[99]', + contentClassName, )} > <div diff --git a/web/app/components/goto-anything/actions/app.tsx b/web/app/components/goto-anything/actions/app.tsx index 496475eacb..739357388d 100644 --- a/web/app/components/goto-anything/actions/app.tsx +++ b/web/app/components/goto-anything/actions/app.tsx @@ -1,9 +1,10 @@ -import type { ActionItem, AppSearchResult } from './types' +import type { AppSearchResult, ScopeDescriptor } from './types' import type { App } from '@/types/app' -import { fetchAppList } from '@/service/apps' +import { searchApps } from '@/service/use-goto-anything' import { getRedirectionPath } from '@/utils/app-redirection' import { AppTypeIcon } from '../../app/type-selector' import AppIcon from '../../base/app-icon' +import { ACTION_KEYS } from '../constants' const parser = (apps: App[]): AppSearchResult[] => { return apps.map(app => ({ @@ -35,21 +36,14 @@ const parser = (apps: App[]): AppSearchResult[] => { })) } -export const appAction: ActionItem = { - key: '@app', - shortcut: '@app', +export const appScope: ScopeDescriptor = { + id: 'app', + shortcut: ACTION_KEYS.APP, title: 'Search Applications', description: 'Search and navigate to your applications', - // action, search: async (_, searchTerm = '', _locale) => { try { - const response = await fetchAppList({ - url: 'apps', - params: { - page: 1, - name: searchTerm, - }, - }) + const response = await searchApps(searchTerm) const apps = response?.data || [] return parser(apps) } diff --git a/web/app/components/goto-anything/actions/commands/index.ts b/web/app/components/goto-anything/actions/commands/index.ts index 72388f6565..7258840d7e 100644 --- a/web/app/components/goto-anything/actions/commands/index.ts +++ b/web/app/components/goto-anything/actions/commands/index.ts @@ -9,7 +9,7 @@ export { export { slashCommandRegistry, SlashCommandRegistry } from './registry' // Command system exports -export { slashAction } from './slash' +export { slashScope } from './slash' export { registerSlashCommands, SlashCommandProvider, unregisterSlashCommands } from './slash' export type { SlashCommandHandler } from './types' diff --git a/web/app/components/goto-anything/actions/commands/language.tsx b/web/app/components/goto-anything/actions/commands/language.tsx index f4bafc1d58..f3244f4cd1 100644 --- a/web/app/components/goto-anything/actions/commands/language.tsx +++ b/web/app/components/goto-anything/actions/commands/language.tsx @@ -1,12 +1,13 @@ import type { CommandSearchResult } from '../types' import type { SlashCommandHandler } from './types' +import type { Locale } from '@/i18n-config/language' import { getI18n } from 'react-i18next' import { languages } from '@/i18n-config/language' import { registerCommands, unregisterCommands } from './command-bus' // Language dependency types type LanguageDeps = { - setLocale?: (locale: string) => Promise<void> + setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void> } const buildLanguageCommands = (query: string): CommandSearchResult[] => { diff --git a/web/app/components/goto-anything/actions/commands/registry.ts b/web/app/components/goto-anything/actions/commands/registry.ts index 51beef4c0b..94321a1916 100644 --- a/web/app/components/goto-anything/actions/commands/registry.ts +++ b/web/app/components/goto-anything/actions/commands/registry.ts @@ -6,20 +6,21 @@ import type { SlashCommandHandler } from './types' * Responsible for managing registration, lookup, and search of all slash commands */ export class SlashCommandRegistry { - private commands = new Map<string, SlashCommandHandler>() - private commandDeps = new Map<string, any>() + private commands = new Map<string, SlashCommandHandler<unknown>>() + private commandDeps = new Map<string, unknown>() /** * Register command handler */ - register<TDeps = any>(handler: SlashCommandHandler<TDeps>, deps?: TDeps) { + register<TDeps = unknown>(handler: SlashCommandHandler<TDeps>, deps?: TDeps) { // Register main command name - this.commands.set(handler.name, handler) + // Cast to unknown first, then to SlashCommandHandler<unknown> to handle generic type variance + this.commands.set(handler.name, handler as SlashCommandHandler<unknown>) // Register aliases if (handler.aliases) { handler.aliases.forEach((alias) => { - this.commands.set(alias, handler) + this.commands.set(alias, handler as SlashCommandHandler<unknown>) }) } @@ -57,7 +58,7 @@ export class SlashCommandRegistry { /** * Find command handler */ - findCommand(commandName: string): SlashCommandHandler | undefined { + findCommand(commandName: string): SlashCommandHandler<unknown> | undefined { return this.commands.get(commandName) } @@ -65,7 +66,7 @@ export class SlashCommandRegistry { * Smart partial command matching * Prioritize alias matching, then match command name prefix */ - private findBestPartialMatch(partialName: string): SlashCommandHandler | undefined { + private findBestPartialMatch(partialName: string): SlashCommandHandler<unknown> | undefined { const lowerPartial = partialName.toLowerCase() // First check if any alias starts with this @@ -81,7 +82,7 @@ export class SlashCommandRegistry { /** * Find handler by alias prefix */ - private findHandlerByAliasPrefix(prefix: string): SlashCommandHandler | undefined { + private findHandlerByAliasPrefix(prefix: string): SlashCommandHandler<unknown> | undefined { for (const handler of this.getAllCommands()) { if (handler.aliases?.some(alias => alias.toLowerCase().startsWith(prefix))) return handler @@ -92,7 +93,7 @@ export class SlashCommandRegistry { /** * Find handler by name prefix */ - private findHandlerByNamePrefix(prefix: string): SlashCommandHandler | undefined { + private findHandlerByNamePrefix(prefix: string): SlashCommandHandler<unknown> | undefined { return this.getAllCommands().find(handler => handler.name.toLowerCase().startsWith(prefix), ) @@ -101,8 +102,8 @@ export class SlashCommandRegistry { /** * Get all registered commands (deduplicated) */ - getAllCommands(): SlashCommandHandler[] { - const uniqueCommands = new Map<string, SlashCommandHandler>() + getAllCommands(): SlashCommandHandler<unknown>[] { + const uniqueCommands = new Map<string, SlashCommandHandler<unknown>>() this.commands.forEach((handler) => { uniqueCommands.set(handler.name, handler) }) @@ -113,7 +114,7 @@ export class SlashCommandRegistry { * Get all available commands in current context (deduplicated and filtered) * Commands without isAvailable method are considered always available */ - getAvailableCommands(): SlashCommandHandler[] { + getAvailableCommands(): SlashCommandHandler<unknown>[] { return this.getAllCommands().filter(handler => this.isCommandAvailable(handler)) } @@ -228,7 +229,7 @@ export class SlashCommandRegistry { /** * Get command dependencies */ - getCommandDependencies(commandName: string): any { + getCommandDependencies(commandName: string): unknown { return this.commandDeps.get(commandName) } @@ -236,7 +237,7 @@ export class SlashCommandRegistry { * Determine if a command is available in the current context. * Defaults to true when a handler does not implement the guard. */ - private isCommandAvailable(handler: SlashCommandHandler) { + private isCommandAvailable(handler: SlashCommandHandler<unknown>) { return handler.isAvailable?.() ?? true } } diff --git a/web/app/components/goto-anything/actions/commands/slash.tsx b/web/app/components/goto-anything/actions/commands/slash.tsx index 3063634416..2d6019dac1 100644 --- a/web/app/components/goto-anything/actions/commands/slash.tsx +++ b/web/app/components/goto-anything/actions/commands/slash.tsx @@ -1,12 +1,13 @@ 'use client' -import type { ActionItem } from '../types' +import type { ScopeDescriptor } from '../types' +import type { SlashCommandDependencies } from './types' import { useTheme } from 'next-themes' import { useEffect } from 'react' import { getI18n } from 'react-i18next' import { setLocaleOnClient } from '@/i18n-config' +import { ACTION_KEYS } from '../../constants' import { accountCommand } from './account' import { bananaCommand } from './banana' -import { executeCommand } from './command-bus' import { communityCommand } from './community' import { docsCommand } from './docs' import { forumCommand } from './forum' @@ -17,17 +18,11 @@ import { zenCommand } from './zen' const i18n = getI18n() -export const slashAction: ActionItem = { - key: '/', - shortcut: '/', +export const slashScope: ScopeDescriptor = { + id: 'slash', + shortcut: ACTION_KEYS.SLASH, title: i18n.t('gotoAnything.actions.slashTitle', { ns: 'app' }), description: i18n.t('gotoAnything.actions.slashDesc', { ns: 'app' }), - action: (result) => { - if (result.type !== 'command') - return - const { command, args } = result.data - executeCommand(command, args) - }, search: async (query, _searchTerm = '') => { // Delegate all search logic to the command registry system return slashCommandRegistry.search(query, i18n.language) @@ -35,7 +30,7 @@ export const slashAction: ActionItem = { } // Register/unregister default handlers for slash commands with external dependencies. -export const registerSlashCommands = (deps: Record<string, any>) => { +export const registerSlashCommands = (deps: SlashCommandDependencies) => { // Register command handlers to the registry system with their respective dependencies slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme }) slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale }) diff --git a/web/app/components/goto-anything/actions/commands/types.ts b/web/app/components/goto-anything/actions/commands/types.ts index 528883c25f..ccf8cdd881 100644 --- a/web/app/components/goto-anything/actions/commands/types.ts +++ b/web/app/components/goto-anything/actions/commands/types.ts @@ -1,10 +1,11 @@ import type { CommandSearchResult } from '../types' +import type { Locale } from '@/i18n-config/language' /** * Slash command handler interface * Each slash command should implement this interface */ -export type SlashCommandHandler<TDeps = any> = { +export type SlashCommandHandler<TDeps = unknown> = { /** Command name (e.g., 'theme', 'language') */ name: string @@ -51,3 +52,31 @@ export type SlashCommandHandler<TDeps = any> = { */ unregister?: () => void } + +/** + * Theme command dependencies + */ +export type ThemeCommandDeps = { + setTheme?: (value: 'light' | 'dark' | 'system') => void +} + +/** + * Language command dependencies + */ +export type LanguageCommandDeps = { + setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void> +} + +/** + * Commands without external dependencies + */ +export type NoDepsCommandDeps = Record<string, never> + +/** + * Union type of all slash command dependencies + * Used for type-safe dependency injection in registerSlashCommands + */ +export type SlashCommandDependencies = { + setTheme?: (value: 'light' | 'dark' | 'system') => void + setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void> +} diff --git a/web/app/components/goto-anything/actions/index.ts b/web/app/components/goto-anything/actions/index.ts index 024b6bfd2c..b8866475be 100644 --- a/web/app/components/goto-anything/actions/index.ts +++ b/web/app/components/goto-anything/actions/index.ts @@ -3,228 +3,66 @@ * * This file defines the action registry for the goto-anything search system. * Actions handle different types of searches: apps, knowledge bases, plugins, workflow nodes, and commands. - * - * ## How to Add a New Slash Command - * - * 1. **Create Command Handler File** (in `./commands/` directory): - * ```typescript - * // commands/my-command.ts - * import type { SlashCommandHandler } from './types' - * import type { CommandSearchResult } from '../types' - * import { registerCommands, unregisterCommands } from './command-bus' - * - * interface MyCommandDeps { - * myService?: (data: any) => Promise<void> - * } - * - * export const myCommand: SlashCommandHandler<MyCommandDeps> = { - * name: 'mycommand', - * aliases: ['mc'], // Optional aliases - * description: 'My custom command description', - * - * async search(args: string, locale: string = 'en') { - * // Return search results based on args - * return [{ - * id: 'my-result', - * title: 'My Command Result', - * description: 'Description of the result', - * type: 'command' as const, - * data: { command: 'my.action', args: { value: args } } - * }] - * }, - * - * register(deps: MyCommandDeps) { - * registerCommands({ - * 'my.action': async (args) => { - * await deps.myService?.(args?.value) - * } - * }) - * }, - * - * unregister() { - * unregisterCommands(['my.action']) - * } - * } - * ``` - * - * **Example for Self-Contained Command (no external dependencies):** - * ```typescript - * // commands/calculator-command.ts - * export const calculatorCommand: SlashCommandHandler = { - * name: 'calc', - * aliases: ['calculator'], - * description: 'Simple calculator', - * - * async search(args: string) { - * if (!args.trim()) return [] - * try { - * // Safe math evaluation (implement proper parser in real use) - * const result = Function('"use strict"; return (' + args + ')')() - * return [{ - * id: 'calc-result', - * title: `${args} = ${result}`, - * description: 'Calculator result', - * type: 'command' as const, - * data: { command: 'calc.copy', args: { result: result.toString() } } - * }] - * } catch { - * return [{ - * id: 'calc-error', - * title: 'Invalid expression', - * description: 'Please enter a valid math expression', - * type: 'command' as const, - * data: { command: 'calc.noop', args: {} } - * }] - * } - * }, - * - * register() { - * registerCommands({ - * 'calc.copy': (args) => navigator.clipboard.writeText(args.result), - * 'calc.noop': () => {} // No operation - * }) - * }, - * - * unregister() { - * unregisterCommands(['calc.copy', 'calc.noop']) - * } - * } - * ``` - * - * 2. **Register Command** (in `./commands/slash.tsx`): - * ```typescript - * import { myCommand } from './my-command' - * import { calculatorCommand } from './calculator-command' // For self-contained commands - * - * export const registerSlashCommands = (deps: Record<string, any>) => { - * slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme }) - * slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale }) - * slashCommandRegistry.register(myCommand, { myService: deps.myService }) // With dependencies - * slashCommandRegistry.register(calculatorCommand) // Self-contained, no dependencies - * } - * - * export const unregisterSlashCommands = () => { - * slashCommandRegistry.unregister('theme') - * slashCommandRegistry.unregister('language') - * slashCommandRegistry.unregister('mycommand') - * slashCommandRegistry.unregister('calc') // Add this line - * } - * ``` - * - * - * 3. **Update SlashCommandProvider** (in `./commands/slash.tsx`): - * ```typescript - * export const SlashCommandProvider = () => { - * const theme = useTheme() - * const myService = useMyService() // Add external dependency if needed - * - * useEffect(() => { - * registerSlashCommands({ - * setTheme: theme.setTheme, // Required for theme command - * setLocale: setLocaleOnClient, // Required for language command - * myService: myService, // Required for your custom command - * // Note: calculatorCommand doesn't need dependencies, so not listed here - * }) - * return () => unregisterSlashCommands() - * }, [theme.setTheme, myService]) // Update dependency array for all dynamic deps - * - * return null - * } - * ``` - * - * **Note:** Self-contained commands (like calculator) don't require dependencies but are - * still registered through the same system for consistent lifecycle management. - * - * 4. **Usage**: Users can now type `/mycommand` or `/mc` to use your command - * - * ## Command System Architecture - * - Commands are registered via `SlashCommandRegistry` - * - Each command is self-contained with its own dependencies - * - Commands support aliases for easier access - * - Command execution is handled by the command bus system - * - All commands should be registered through `SlashCommandProvider` for consistent lifecycle management - * - * ## Command Types - * **Commands with External Dependencies:** - * - Require external services, APIs, or React hooks - * - Must provide dependencies in `SlashCommandProvider` - * - Example: theme commands (needs useTheme), API commands (needs service) - * - * **Self-Contained Commands:** - * - Pure logic operations, no external dependencies - * - Still recommended to register through `SlashCommandProvider` for consistency - * - Example: calculator, text manipulation commands - * - * ## Available Actions - * - `@app` - Search applications - * - `@knowledge` / `@kb` - Search knowledge bases - * - `@plugin` - Search plugins - * - `@node` - Search workflow nodes (workflow pages only) - * - `/` - Execute slash commands (theme, language, banana, etc.) */ -import type { ActionItem, SearchResult } from './types' -import { appAction } from './app' -import { slashAction } from './commands' +import type { ScopeContext, ScopeDescriptor, SearchResult } from './types' +import { ACTION_KEYS } from '../constants' +import { appScope } from './app' +import { slashScope } from './commands' import { slashCommandRegistry } from './commands/registry' -import { knowledgeAction } from './knowledge' -import { pluginAction } from './plugin' -import { ragPipelineNodesAction } from './rag-pipeline-nodes' -import { workflowNodesAction } from './workflow-nodes' +import { knowledgeScope } from './knowledge' +import { pluginScope } from './plugin' +import { registerRagPipelineNodeScope } from './rag-pipeline-nodes' +import { scopeRegistry, useScopeRegistry } from './scope-registry' +import { registerWorkflowNodeScope } from './workflow-nodes' -// Create dynamic Actions based on context -export const createActions = (isWorkflowPage: boolean, isRagPipelinePage: boolean) => { - const baseActions = { - slash: slashAction, - app: appAction, - knowledge: knowledgeAction, - plugin: pluginAction, - } +let scopesInitialized = false - // Add appropriate node search based on context - if (isRagPipelinePage) { - return { - ...baseActions, - node: ragPipelineNodesAction, - } - } - else if (isWorkflowPage) { - return { - ...baseActions, - node: workflowNodesAction, - } - } +export const initGotoAnythingScopes = () => { + if (scopesInitialized) + return - // Default actions without node search - return baseActions + scopesInitialized = true + + scopeRegistry.register(slashScope) + scopeRegistry.register(appScope) + scopeRegistry.register(knowledgeScope) + scopeRegistry.register(pluginScope) + registerWorkflowNodeScope() + registerRagPipelineNodeScope() } -// Legacy export for backward compatibility -export const Actions = { - slash: slashAction, - app: appAction, - knowledge: knowledgeAction, - plugin: pluginAction, - node: workflowNodesAction, +export const useGotoAnythingScopes = (context: ScopeContext) => { + initGotoAnythingScopes() + return useScopeRegistry(context) } +const isSlashScope = (scope: ScopeDescriptor) => { + if (scope.shortcut === ACTION_KEYS.SLASH) + return true + return scope.aliases?.includes(ACTION_KEYS.SLASH) ?? false +} + +const getScopeShortcuts = (scope: ScopeDescriptor) => [scope.shortcut, ...(scope.aliases ?? [])] + export const searchAnything = async ( locale: string, query: string, - actionItem?: ActionItem, - dynamicActions?: Record<string, ActionItem>, + scope: ScopeDescriptor | undefined, + scopes: ScopeDescriptor[], ): Promise<SearchResult[]> => { const trimmedQuery = query.trim() - if (actionItem) { + if (scope) { const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') - const prefixPattern = new RegExp(`^(${escapeRegExp(actionItem.key)}|${escapeRegExp(actionItem.shortcut)})\\s*`) + const shortcuts = getScopeShortcuts(scope).map(escapeRegExp) + const prefixPattern = new RegExp(`^(${shortcuts.join('|')})\\s*`) const searchTerm = trimmedQuery.replace(prefixPattern, '').trim() try { - return await actionItem.search(query, searchTerm, locale) + return await scope.search(query, searchTerm, locale) } catch (error) { - console.warn(`Search failed for ${actionItem.key}:`, error) + console.warn(`Search failed for ${scope.id}:`, error) return [] } } @@ -232,19 +70,19 @@ export const searchAnything = async ( if (trimmedQuery.startsWith('@') || trimmedQuery.startsWith('/')) return [] - const globalSearchActions = Object.values(dynamicActions || Actions) - // Exclude slash commands from general search results - .filter(action => action.key !== '/') + // Filter out slash commands from general search + const searchScopes = scopes.filter(scope => !isSlashScope(scope)) // Use Promise.allSettled to handle partial failures gracefully - const searchPromises = globalSearchActions.map(async (action) => { + const searchPromises = searchScopes.map(async (action) => { + const actionId = action.id try { const results = await action.search(query, query, locale) - return { success: true, data: results, actionType: action.key } + return { success: true, data: results, actionType: actionId } } catch (error) { - console.warn(`Search failed for ${action.key}:`, error) - return { success: false, data: [], actionType: action.key, error } + console.warn(`Search failed for ${actionId}:`, error) + return { success: false, data: [], actionType: actionId, error } } }) @@ -258,7 +96,7 @@ export const searchAnything = async ( allResults.push(...result.value.data) } else { - const actionKey = globalSearchActions[index]?.key || 'unknown' + const actionKey = searchScopes[index]?.id || 'unknown' failedActions.push(actionKey) } }) @@ -269,31 +107,31 @@ export const searchAnything = async ( return allResults } -export const matchAction = (query: string, actions: Record<string, ActionItem>) => { - return Object.values(actions).find((action) => { - // Special handling for slash commands - if (action.key === '/') { - // Get all registered commands from the registry - const allCommands = slashCommandRegistry.getAllCommands() +// ... - // Check if query matches any registered command +export const matchAction = (query: string, scopes: ScopeDescriptor[]) => { + const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') + return scopes.find((scope) => { + // Special handling for slash commands + if (isSlashScope(scope)) { + const allCommands = slashCommandRegistry.getAllCommands() return allCommands.some((cmd) => { const cmdPattern = `/${cmd.name}` - - // For direct mode commands, don't match (keep in command selector) if (cmd.mode === 'direct') return false - - // For submenu mode commands, match when complete command is entered return query === cmdPattern || query.startsWith(`${cmdPattern} `) }) } - const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`) + // Check if query matches shortcut (exact or prefix) + // Only match if it's the full shortcut followed by space + const shortcuts = getScopeShortcuts(scope).map(escapeRegExp) + const reg = new RegExp(`^(${shortcuts.join('|')})(?:\\s|$)`) return reg.test(query) }) } export * from './commands' +export * from './scope-registry' export * from './types' -export { appAction, knowledgeAction, pluginAction, workflowNodesAction } +export { appScope, knowledgeScope, pluginScope } diff --git a/web/app/components/goto-anything/actions/knowledge.tsx b/web/app/components/goto-anything/actions/knowledge.tsx index 9531a3551f..61cff3c00c 100644 --- a/web/app/components/goto-anything/actions/knowledge.tsx +++ b/web/app/components/goto-anything/actions/knowledge.tsx @@ -1,8 +1,9 @@ -import type { ActionItem, KnowledgeSearchResult } from './types' +import type { KnowledgeSearchResult, ScopeDescriptor } from './types' import type { DataSet } from '@/models/datasets' -import { fetchDatasets } from '@/service/datasets' +import { searchDatasets } from '@/service/use-goto-anything' import { cn } from '@/utils/classnames' import { Folder } from '../../base/icons/src/vender/solid/files' +import { ACTION_KEYS } from '../constants' const EXTERNAL_PROVIDER = 'external' as const const isExternalProvider = (provider: string): boolean => provider === EXTERNAL_PROVIDER @@ -30,22 +31,15 @@ const parser = (datasets: DataSet[]): KnowledgeSearchResult[] => { }) } -export const knowledgeAction: ActionItem = { - key: '@knowledge', - shortcut: '@kb', +export const knowledgeScope: ScopeDescriptor = { + id: 'knowledge', + shortcut: ACTION_KEYS.KNOWLEDGE, + aliases: ['@kb'], title: 'Search Knowledge Bases', description: 'Search and navigate to your knowledge bases', - // action, search: async (_, searchTerm = '', _locale) => { try { - const response = await fetchDatasets({ - url: '/datasets', - params: { - page: 1, - limit: 10, - keyword: searchTerm, - }, - }) + const response = await searchDatasets(searchTerm) const datasets = response?.data || [] return parser(datasets) } diff --git a/web/app/components/goto-anything/actions/plugin.tsx b/web/app/components/goto-anything/actions/plugin.tsx index 07197b8198..8b441d17f6 100644 --- a/web/app/components/goto-anything/actions/plugin.tsx +++ b/web/app/components/goto-anything/actions/plugin.tsx @@ -1,9 +1,10 @@ -import type { Plugin, PluginsFromMarketplaceResponse } from '../../plugins/types' -import type { ActionItem, PluginSearchResult } from './types' +import type { Plugin } from '../../plugins/types' +import type { PluginSearchResult, ScopeDescriptor } from './types' import { renderI18nObject } from '@/i18n-config' -import { postMarketplace } from '@/service/base' +import { searchPlugins } from '@/service/use-goto-anything' import Icon from '../../plugins/card/base/card-icon' import { getPluginIconInMarketplace } from '../../plugins/marketplace/utils' +import { ACTION_KEYS } from '../constants' const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => { return plugins.map((plugin) => { @@ -18,21 +19,14 @@ const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => { }) } -export const pluginAction: ActionItem = { - key: '@plugin', - shortcut: '@plugin', +export const pluginScope: ScopeDescriptor = { + id: 'plugin', + shortcut: ACTION_KEYS.PLUGIN, title: 'Search Plugins', description: 'Search and navigate to your plugins', search: async (_, searchTerm = '', locale) => { try { - const response = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>('/plugins/search/advanced', { - body: { - page: 1, - page_size: 10, - query: searchTerm, - type: 'plugin', - }, - }) + const response = await searchPlugins(searchTerm) if (!response?.data?.plugins) { console.warn('Plugin search: Unexpected response structure', response) diff --git a/web/app/components/goto-anything/actions/rag-pipeline-nodes.tsx b/web/app/components/goto-anything/actions/rag-pipeline-nodes.tsx index dc632e4999..14a4f8c3f3 100644 --- a/web/app/components/goto-anything/actions/rag-pipeline-nodes.tsx +++ b/web/app/components/goto-anything/actions/rag-pipeline-nodes.tsx @@ -1,24 +1,41 @@ -import type { ActionItem } from './types' +import type { ScopeSearchHandler } from './scope-registry' +import type { SearchResult } from './types' +import { ACTION_KEYS } from '../constants' +import { scopeRegistry } from './scope-registry' -// Create the RAG pipeline nodes action -export const ragPipelineNodesAction: ActionItem = { - key: '@node', - shortcut: '@node', - title: 'Search RAG Pipeline Nodes', - description: 'Find and jump to nodes in the current RAG pipeline by name or type', - searchFn: undefined, // Will be set by useRagPipelineSearch hook - search: async (_, searchTerm = '', _locale) => { +const scopeId = 'rag-pipeline-node' +let scopeRegistered = false + +const buildSearchHandler = (searchFn?: (searchTerm: string) => SearchResult[]): ScopeSearchHandler => { + return async (_, searchTerm = '', _locale) => { try { - // Use the searchFn if available (set by useRagPipelineSearch hook) - if (ragPipelineNodesAction.searchFn) - return ragPipelineNodesAction.searchFn(searchTerm) - - // If not in RAG pipeline context, return empty array + if (searchFn) + return searchFn(searchTerm) return [] } catch (error) { console.warn('RAG pipeline nodes search failed:', error) return [] } - }, + } +} + +export const registerRagPipelineNodeScope = () => { + if (scopeRegistered) + return + + scopeRegistered = true + scopeRegistry.register({ + id: scopeId, + shortcut: ACTION_KEYS.NODE, + title: 'Search RAG Pipeline Nodes', + description: 'Find and jump to nodes in the current RAG pipeline by name or type', + isAvailable: context => context.isRagPipelinePage, + search: buildSearchHandler(), + }) +} + +export const setRagPipelineNodesSearchFn = (fn: (searchTerm: string) => SearchResult[]) => { + registerRagPipelineNodeScope() + scopeRegistry.updateSearchHandler(scopeId, buildSearchHandler(fn)) } diff --git a/web/app/components/goto-anything/actions/scope-registry.ts b/web/app/components/goto-anything/actions/scope-registry.ts new file mode 100644 index 0000000000..fc27c3b9fb --- /dev/null +++ b/web/app/components/goto-anything/actions/scope-registry.ts @@ -0,0 +1,123 @@ +import type { SearchResult } from './types' + +import { useCallback, useMemo, useSyncExternalStore } from 'react' + +export type ScopeContext = { + isWorkflowPage: boolean + isRagPipelinePage: boolean + isAdmin?: boolean +} + +export type ScopeSearchHandler = ( + query: string, + searchTerm: string, + locale?: string, +) => Promise<SearchResult[]> | SearchResult[] + +export type ScopeDescriptor = { + /** + * Unique identifier for the scope (e.g. 'app', 'plugin') + */ + id: string + /** + * Shortcut to trigger this scope (e.g. '@app') + */ + shortcut: string + /** + * Additional shortcuts that map to this scope (e.g. ['@kb']) + */ + aliases?: string[] + /** + * I18n key or string for the scope title + */ + title: string + /** + * Description for help text + */ + description: string + /** + * Search handler function + */ + search: ScopeSearchHandler + /** + * Predicate to check if this scope is available in current context + */ + isAvailable?: (context: ScopeContext) => boolean +} + +type Listener = () => void + +class ScopeRegistry { + private scopes: Map<string, ScopeDescriptor> = new Map() + private listeners: Set<Listener> = new Set() + private version = 0 + + register(scope: ScopeDescriptor) { + this.scopes.set(scope.id, scope) + this.notify() + } + + unregister(id: string) { + if (this.scopes.delete(id)) + this.notify() + } + + getScope(id: string) { + return this.scopes.get(id) + } + + getScopes(context: ScopeContext): ScopeDescriptor[] { + return Array.from(this.scopes.values()) + .filter(scope => !scope.isAvailable || scope.isAvailable(context)) + .sort((a, b) => a.shortcut.localeCompare(b.shortcut)) + } + + updateSearchHandler(id: string, search: ScopeSearchHandler) { + const scope = this.scopes.get(id) + if (!scope) + return + this.scopes.set(id, { ...scope, search }) + this.notify() + } + + getVersion() { + return this.version + } + + subscribe(listener: Listener) { + this.listeners.add(listener) + return () => { + this.listeners.delete(listener) + } + } + + private notify() { + this.version += 1 + this.listeners.forEach(listener => listener()) + } +} + +export const scopeRegistry = new ScopeRegistry() + +export const useScopeRegistry = (context: ScopeContext) => { + const subscribe = useCallback( + (listener: Listener) => scopeRegistry.subscribe(listener), + [], + ) + + const getSnapshot = useCallback( + () => scopeRegistry.getVersion(), + [], + ) + + const version = useSyncExternalStore( + subscribe, + getSnapshot, + getSnapshot, + ) + + return useMemo( + () => scopeRegistry.getScopes(context), + [version, context.isWorkflowPage, context.isRagPipelinePage, context.isAdmin], + ) +} diff --git a/web/app/components/goto-anything/actions/types.ts b/web/app/components/goto-anything/actions/types.ts index 838195ad85..9e04832cd4 100644 --- a/web/app/components/goto-anything/actions/types.ts +++ b/web/app/components/goto-anything/actions/types.ts @@ -1,5 +1,4 @@ import type { ReactNode } from 'react' -import type { TypeWithI18N } from '../../base/form/types' import type { Plugin } from '../../plugins/types' import type { CommonNodeType } from '../../workflow/types' import type { DataSet } from '@/models/datasets' @@ -7,7 +6,7 @@ import type { App } from '@/types/app' export type SearchResultType = 'app' | 'knowledge' | 'plugin' | 'workflow-node' | 'command' -export type BaseSearchResult<T = any> = { +export type BaseSearchResult<T = unknown> = { id: string title: string description?: string @@ -39,20 +38,8 @@ export type WorkflowNodeSearchResult = { export type CommandSearchResult = { type: 'command' -} & BaseSearchResult<{ command: string, args?: Record<string, any> }> +} & BaseSearchResult<{ command: string, args?: Record<string, unknown> }> export type SearchResult = AppSearchResult | PluginSearchResult | KnowledgeSearchResult | WorkflowNodeSearchResult | CommandSearchResult -export type ActionItem = { - key: '@app' | '@knowledge' | '@plugin' | '@node' | '/' - shortcut: string - title: string | TypeWithI18N - description: string - action?: (data: SearchResult) => void - searchFn?: (searchTerm: string) => SearchResult[] - search: ( - query: string, - searchTerm: string, - locale?: string, - ) => (Promise<SearchResult[]> | SearchResult[]) -} +export type { ScopeContext, ScopeDescriptor } from './scope-registry' diff --git a/web/app/components/goto-anything/actions/workflow-nodes.tsx b/web/app/components/goto-anything/actions/workflow-nodes.tsx index b9aa61705b..d4de980011 100644 --- a/web/app/components/goto-anything/actions/workflow-nodes.tsx +++ b/web/app/components/goto-anything/actions/workflow-nodes.tsx @@ -1,24 +1,41 @@ -import type { ActionItem } from './types' +import type { ScopeSearchHandler } from './scope-registry' +import type { SearchResult } from './types' +import { ACTION_KEYS } from '../constants' +import { scopeRegistry } from './scope-registry' -// Create the workflow nodes action -export const workflowNodesAction: ActionItem = { - key: '@node', - shortcut: '@node', - title: 'Search Workflow Nodes', - description: 'Find and jump to nodes in the current workflow by name or type', - searchFn: undefined, // Will be set by useWorkflowSearch hook - search: async (_, searchTerm = '', _locale) => { +const scopeId = 'workflow-node' +let scopeRegistered = false + +const buildSearchHandler = (searchFn?: (searchTerm: string) => SearchResult[]): ScopeSearchHandler => { + return async (_, searchTerm = '', _locale) => { try { - // Use the searchFn if available (set by useWorkflowSearch hook) - if (workflowNodesAction.searchFn) - return workflowNodesAction.searchFn(searchTerm) - - // If not in workflow context, return empty array + if (searchFn) + return searchFn(searchTerm) return [] } catch (error) { console.warn('Workflow nodes search failed:', error) return [] } - }, + } +} + +export const registerWorkflowNodeScope = () => { + if (scopeRegistered) + return + + scopeRegistered = true + scopeRegistry.register({ + id: scopeId, + shortcut: ACTION_KEYS.NODE, + title: 'Search Workflow Nodes', + description: 'Find and jump to nodes in the current workflow by name or type', + isAvailable: context => context.isWorkflowPage, + search: buildSearchHandler(), + }) +} + +export const setWorkflowNodesSearchFn = (fn: (searchTerm: string) => SearchResult[]) => { + registerWorkflowNodeScope() + scopeRegistry.updateSearchHandler(scopeId, buildSearchHandler(fn)) } diff --git a/web/app/components/goto-anything/command-selector.spec.tsx b/web/app/components/goto-anything/command-selector.spec.tsx index 0712a1afd6..e8ec9a8231 100644 --- a/web/app/components/goto-anything/command-selector.spec.tsx +++ b/web/app/components/goto-anything/command-selector.spec.tsx @@ -1,5 +1,5 @@ -import type { ActionItem } from './actions/types' -import { render, screen } from '@testing-library/react' +import type { ScopeDescriptor } from './actions/scope-registry' +import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { Command } from 'cmdk' import * as React from 'react' @@ -22,263 +22,315 @@ vi.mock('./actions/commands/registry', () => ({ }, })) -const createActions = (): Record<string, ActionItem> => ({ - app: { - key: '@app', +type CommandSelectorProps = React.ComponentProps<typeof CommandSelector> + +const mockScopes: ScopeDescriptor[] = [ + { + id: 'app', shortcut: '@app', - title: 'Apps', + title: 'Search Applications', + description: 'Search apps', search: vi.fn(), - description: '', - } as ActionItem, - plugin: { - key: '@plugin', + }, + { + id: 'knowledge', + shortcut: '@knowledge', + title: 'Search Knowledge Bases', + description: 'Search knowledge bases', + search: vi.fn(), + }, + { + id: 'plugin', shortcut: '@plugin', - title: 'Plugins', + title: 'Search Plugins', + description: 'Search plugins', search: vi.fn(), - description: '', - } as ActionItem, -}) + }, + { + id: 'workflow-node', + shortcut: '@node', + title: 'Search Nodes', + description: 'Search workflow nodes', + search: vi.fn(), + }, +] + +const mockOnCommandSelect = vi.fn() +const mockOnCommandValueChange = vi.fn() + +const buildCommandSelector = (props: Partial<CommandSelectorProps> = {}) => ( + <Command> + <Command.List> + <CommandSelector + scopes={mockScopes} + onCommandSelect={mockOnCommandSelect} + {...props} + /> + </Command.List> + </Command> +) + +const renderCommandSelector = (props: Partial<CommandSelectorProps> = {}) => { + return render(buildCommandSelector(props)) +} describe('CommandSelector', () => { - it('should list contextual search actions and notify selection', async () => { - const actions = createActions() - const onSelect = vi.fn() - - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="app" - originalQuery="@app" - /> - </Command>, - ) - - const actionButton = screen.getByText('app.gotoAnything.actions.searchApplicationsDesc') - await userEvent.click(actionButton) - - expect(onSelect).toHaveBeenCalledWith('@app') + beforeEach(() => { + vi.clearAllMocks() }) - it('should render slash commands when query starts with slash', async () => { - const actions = createActions() - const onSelect = vi.fn() + describe('Basic Rendering', () => { + it('should render all scopes when no filter is provided', () => { + renderCommandSelector() - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="zen" - originalQuery="/zen" - /> - </Command>, - ) + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.getByText('@knowledge')).toBeInTheDocument() + expect(screen.getByText('@plugin')).toBeInTheDocument() + expect(screen.getByText('@node')).toBeInTheDocument() + }) - const slashItem = await screen.findByText('app.gotoAnything.actions.zenDesc') - await userEvent.click(slashItem) + it('should render empty filter as showing all scopes', () => { + renderCommandSelector({ searchFilter: '' }) - expect(onSelect).toHaveBeenCalledWith('/zen') + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.getByText('@knowledge')).toBeInTheDocument() + expect(screen.getByText('@plugin')).toBeInTheDocument() + expect(screen.getByText('@node')).toBeInTheDocument() + }) + }) + + describe('Filtering Functionality', () => { + it('should filter scopes based on searchFilter - single match', () => { + renderCommandSelector({ searchFilter: 'k' }) + + expect(screen.queryByText('@app')).not.toBeInTheDocument() + expect(screen.getByText('@knowledge')).toBeInTheDocument() + expect(screen.queryByText('@plugin')).not.toBeInTheDocument() + expect(screen.queryByText('@node')).not.toBeInTheDocument() + }) + + it('should filter scopes with multiple matches', () => { + renderCommandSelector({ searchFilter: 'p' }) + + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.queryByText('@knowledge')).not.toBeInTheDocument() + expect(screen.getByText('@plugin')).toBeInTheDocument() + expect(screen.queryByText('@node')).not.toBeInTheDocument() + }) + + it('should be case-insensitive when filtering', () => { + renderCommandSelector({ searchFilter: 'APP' }) + + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.queryByText('@knowledge')).not.toBeInTheDocument() + }) + + it('should match partial strings', () => { + renderCommandSelector({ searchFilter: 'od' }) + + expect(screen.queryByText('@app')).not.toBeInTheDocument() + expect(screen.queryByText('@knowledge')).not.toBeInTheDocument() + expect(screen.queryByText('@plugin')).not.toBeInTheDocument() + expect(screen.getByText('@node')).toBeInTheDocument() + }) + }) + + describe('Empty State', () => { + it('should show empty state when no matches found', () => { + renderCommandSelector({ searchFilter: 'xyz' }) + + expect(screen.queryByText('@app')).not.toBeInTheDocument() + expect(screen.queryByText('@knowledge')).not.toBeInTheDocument() + expect(screen.queryByText('@plugin')).not.toBeInTheDocument() + expect(screen.queryByText('@node')).not.toBeInTheDocument() + + expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument() + }) + + it('should not show empty state when filter is empty', () => { + renderCommandSelector({ searchFilter: '' }) + + expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument() + }) + }) + + describe('Selection and Highlight Management', () => { + it('should call onCommandValueChange when filter changes and first item differs', async () => { + const { rerender } = renderCommandSelector({ + searchFilter: '', + commandValue: '@app', + onCommandValueChange: mockOnCommandValueChange, + }) + + rerender(buildCommandSelector({ + searchFilter: 'k', + commandValue: '@app', + onCommandValueChange: mockOnCommandValueChange, + })) + + await waitFor(() => { + expect(mockOnCommandValueChange).toHaveBeenCalledWith('@knowledge') + }) + }) + + it('should not call onCommandValueChange if current value still exists', async () => { + const { rerender } = renderCommandSelector({ + searchFilter: '', + commandValue: '@app', + onCommandValueChange: mockOnCommandValueChange, + }) + + rerender(buildCommandSelector({ + searchFilter: 'a', + commandValue: '@app', + onCommandValueChange: mockOnCommandValueChange, + })) + + await waitFor(() => { + expect(mockOnCommandValueChange).not.toHaveBeenCalled() + }) + }) + + it('should handle onCommandSelect callback correctly', async () => { + const user = userEvent.setup() + renderCommandSelector({ searchFilter: 'k' }) + + await user.click(screen.getByText('@knowledge')) + + expect(mockOnCommandSelect).toHaveBeenCalledWith('@knowledge') + }) + }) + + describe('Edge Cases', () => { + it('should handle empty scopes array', () => { + renderCommandSelector({ scopes: [] }) + + expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() + }) + + it('should handle special characters in filter', () => { + renderCommandSelector({ searchFilter: '@' }) + + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.getByText('@knowledge')).toBeInTheDocument() + expect(screen.getByText('@plugin')).toBeInTheDocument() + expect(screen.getByText('@node')).toBeInTheDocument() + }) + + it('should handle undefined onCommandValueChange gracefully', () => { + const { rerender } = renderCommandSelector({ searchFilter: '' }) + + expect(() => { + rerender(buildCommandSelector({ searchFilter: 'k' })) + }).not.toThrow() + }) + }) + + describe('User Interactions', () => { + it('should list contextual scopes and notify selection', async () => { + const user = userEvent.setup() + renderCommandSelector({ searchFilter: 'app', originalQuery: '@app' }) + + await user.click(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')) + + expect(mockOnCommandSelect).toHaveBeenCalledWith('@app') + }) + + it('should render slash commands when query starts with slash', async () => { + const user = userEvent.setup() + renderCommandSelector({ searchFilter: 'zen', originalQuery: '/zen' }) + + const slashItem = await screen.findByText('app.gotoAnything.actions.zenDesc') + await user.click(slashItem) + + expect(mockOnCommandSelect).toHaveBeenCalledWith('/zen') + }) }) it('should show all slash commands when no filter provided', () => { - const actions = createActions() - const onSelect = vi.fn() - - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="" - originalQuery="/" - /> - </Command>, - ) + renderCommandSelector({ searchFilter: '', originalQuery: '/' }) // Should show the zen command from mock expect(screen.getByText('/zen')).toBeInTheDocument() }) - it('should exclude slash action when in @ mode', () => { - const actions = { - ...createActions(), - slash: { - key: '/', + it('should exclude slash scope when in @ mode', () => { + const scopesWithSlash: ScopeDescriptor[] = [ + ...mockScopes, + { + id: 'slash', shortcut: '/', title: 'Slash', - search: vi.fn(), description: '', - } as ActionItem, - } - const onSelect = vi.fn() + search: vi.fn(), + }, + ] - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="" - originalQuery="@" - /> - </Command>, - ) + renderCommandSelector({ scopes: scopesWithSlash, searchFilter: '', originalQuery: '@' }) // Should show @ commands but not / expect(screen.getByText('@app')).toBeInTheDocument() expect(screen.queryByText('/')).not.toBeInTheDocument() }) - it('should show all actions when no filter in @ mode', () => { - const actions = createActions() - const onSelect = vi.fn() - - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="" - originalQuery="@" - /> - </Command>, - ) + it('should show all scopes when no filter in @ mode', () => { + renderCommandSelector({ searchFilter: '', originalQuery: '@' }) expect(screen.getByText('@app')).toBeInTheDocument() expect(screen.getByText('@plugin')).toBeInTheDocument() }) it('should set default command value when items exist but value does not', () => { - const actions = createActions() - const onSelect = vi.fn() - const onCommandValueChange = vi.fn() + renderCommandSelector({ + searchFilter: '', + originalQuery: '@', + commandValue: 'non-existent', + onCommandValueChange: mockOnCommandValueChange, + }) - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="" - originalQuery="@" - commandValue="non-existent" - onCommandValueChange={onCommandValueChange} - /> - </Command>, - ) - - expect(onCommandValueChange).toHaveBeenCalledWith('@app') + expect(mockOnCommandValueChange).toHaveBeenCalledWith('@app') }) it('should NOT set command value when value already exists in items', () => { - const actions = createActions() - const onSelect = vi.fn() - const onCommandValueChange = vi.fn() + renderCommandSelector({ + searchFilter: '', + originalQuery: '@', + commandValue: '@app', + onCommandValueChange: mockOnCommandValueChange, + }) - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="" - originalQuery="@" - commandValue="@app" - onCommandValueChange={onCommandValueChange} - /> - </Command>, - ) - - expect(onCommandValueChange).not.toHaveBeenCalled() + expect(mockOnCommandValueChange).not.toHaveBeenCalled() }) it('should show no matching commands message when filter has no results', () => { - const actions = createActions() - const onSelect = vi.fn() - - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="nonexistent" - originalQuery="@nonexistent" - /> - </Command>, - ) + renderCommandSelector({ searchFilter: 'nonexistent', originalQuery: '@nonexistent' }) expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument() }) it('should show no matching commands for slash mode with no results', () => { - const actions = createActions() - const onSelect = vi.fn() - - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="nonexistentcommand" - originalQuery="/nonexistentcommand" - /> - </Command>, - ) + renderCommandSelector({ searchFilter: 'nonexistentcommand', originalQuery: '/nonexistentcommand' }) expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() }) it('should render description for @ commands', () => { - const actions = createActions() - const onSelect = vi.fn() - - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="" - originalQuery="@" - /> - </Command>, - ) + renderCommandSelector({ searchFilter: '', originalQuery: '@' }) expect(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')).toBeInTheDocument() expect(screen.getByText('app.gotoAnything.actions.searchPluginsDesc')).toBeInTheDocument() }) it('should render group header for @ mode', () => { - const actions = createActions() - const onSelect = vi.fn() - - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="" - originalQuery="@" - /> - </Command>, - ) + renderCommandSelector({ searchFilter: '', originalQuery: '@' }) expect(screen.getByText('app.gotoAnything.selectSearchType')).toBeInTheDocument() }) it('should render group header for slash mode', () => { - const actions = createActions() - const onSelect = vi.fn() - - render( - <Command> - <CommandSelector - actions={actions} - onCommandSelect={onSelect} - searchFilter="" - originalQuery="/" - /> - </Command>, - ) + renderCommandSelector({ searchFilter: '', originalQuery: '/' }) expect(screen.getByText('app.gotoAnything.groups.commands')).toBeInTheDocument() }) diff --git a/web/app/components/goto-anything/command-selector.tsx b/web/app/components/goto-anything/command-selector.tsx index 79d543faca..731796a320 100644 --- a/web/app/components/goto-anything/command-selector.tsx +++ b/web/app/components/goto-anything/command-selector.tsx @@ -1,13 +1,14 @@ import type { FC } from 'react' -import type { ActionItem } from './actions/types' +import type { ScopeDescriptor } from './actions/scope-registry' import { Command } from 'cmdk' import { usePathname } from 'next/navigation' import { useEffect, useMemo } from 'react' import { useTranslation } from 'react-i18next' import { slashCommandRegistry } from './actions/commands/registry' +import { ACTION_KEYS } from './constants' type Props = { - actions: Record<string, ActionItem> + scopes: ScopeDescriptor[] onCommandSelect: (commandKey: string) => void searchFilter?: string commandValue?: string @@ -15,7 +16,7 @@ type Props = { originalQuery?: string } -const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, commandValue, onCommandValueChange, originalQuery }) => { +const CommandSelector: FC<Props> = ({ scopes, onCommandSelect, searchFilter, commandValue, onCommandValueChange, originalQuery }) => { const { t } = useTranslation() const pathname = usePathname() @@ -43,22 +44,31 @@ const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, co })) }, [isSlashMode, searchFilter, pathname]) - const filteredActions = useMemo(() => { + const filteredScopes = useMemo(() => { if (isSlashMode) return [] - return Object.values(actions).filter((action) => { + return scopes.filter((scope) => { // Exclude slash action when in @ mode - if (action.key === '/') + if (scope.id === 'slash' || scope.shortcut === ACTION_KEYS.SLASH) return false if (!searchFilter) return true - const filterLower = searchFilter.toLowerCase() - return action.shortcut.toLowerCase().includes(filterLower) - }) - }, [actions, searchFilter, isSlashMode]) - const allItems = isSlashMode ? slashCommands : filteredActions + // Match against shortcut/aliases or title + const filterLower = searchFilter.toLowerCase() + const shortcuts = [scope.shortcut, ...(scope.aliases || [])] + return shortcuts.some(shortcut => shortcut.toLowerCase().includes(filterLower)) + || scope.title.toLowerCase().includes(filterLower) + }).map(scope => ({ + key: scope.shortcut, // Map to shortcut for UI display consistency + shortcut: scope.shortcut, + title: scope.title, + description: scope.description, + })) + }, [scopes, searchFilter, isSlashMode]) + + const allItems = isSlashMode ? slashCommands : filteredScopes useEffect(() => { if (allItems.length > 0 && onCommandValueChange) { diff --git a/web/app/components/goto-anything/components/empty-state.spec.tsx b/web/app/components/goto-anything/components/empty-state.spec.tsx index e1e5e0dc89..e3b2136397 100644 --- a/web/app/components/goto-anything/components/empty-state.spec.tsx +++ b/web/app/components/goto-anything/components/empty-state.spec.tsx @@ -83,10 +83,10 @@ describe('EmptyState', () => { }) it('should show specific search hint with shortcuts', () => { - const Actions = { - app: { key: '@app', shortcut: '@app' }, - plugin: { key: '@plugin', shortcut: '@plugin' }, - } as unknown as Record<string, import('../actions/types').ActionItem> + const Actions = [ + { id: 'app', shortcut: '@app', title: 'App', description: '', search: vi.fn() }, + { id: 'plugin', shortcut: '@plugin', title: 'Plugin', description: '', search: vi.fn() }, + ] as import('../actions/types').ScopeDescriptor[] render(<EmptyState variant="no-results" searchMode="general" Actions={Actions} />) expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:@app, @plugin')).toBeInTheDocument() diff --git a/web/app/components/goto-anything/components/empty-state.tsx b/web/app/components/goto-anything/components/empty-state.tsx index a07bc1d45a..a59e5070fd 100644 --- a/web/app/components/goto-anything/components/empty-state.tsx +++ b/web/app/components/goto-anything/components/empty-state.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' -import type { ActionItem } from '../actions/types' +import type { ScopeDescriptor } from '../actions/types' import { useTranslation } from 'react-i18next' export type EmptyStateVariant = 'no-results' | 'error' | 'default' | 'loading' @@ -10,14 +10,14 @@ export type EmptyStateProps = { variant: EmptyStateVariant searchMode?: string error?: Error | null - Actions?: Record<string, ActionItem> + Actions?: ScopeDescriptor[] } const EmptyState: FC<EmptyStateProps> = ({ variant, searchMode = 'general', error, - Actions = {}, + Actions = [], }) => { const { t } = useTranslation() @@ -88,7 +88,7 @@ const EmptyState: FC<EmptyStateProps> = ({ return t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' }) } - const shortcuts = Object.values(Actions).map(action => action.shortcut).join(', ') + const shortcuts = Actions.map(scope => scope.shortcut).join(', ') return t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts }) } diff --git a/web/app/components/goto-anything/constants.ts b/web/app/components/goto-anything/constants.ts new file mode 100644 index 0000000000..45ef9fd7cb --- /dev/null +++ b/web/app/components/goto-anything/constants.ts @@ -0,0 +1,20 @@ +/** + * Goto Anything Constants + * Centralized constants for action keys + */ + +/** + * Action keys for scope-based searches + */ +export const ACTION_KEYS = { + APP: '@app', + KNOWLEDGE: '@knowledge', + PLUGIN: '@plugin', + NODE: '@node', + SLASH: '/', +} as const + +/** + * Type-safe action key union type + */ +export type ActionKey = typeof ACTION_KEYS[keyof typeof ACTION_KEYS] diff --git a/web/app/components/goto-anything/hooks/use-goto-anything-navigation.spec.ts b/web/app/components/goto-anything/hooks/use-goto-anything-navigation.spec.ts index efb15f41b3..4d2603b947 100644 --- a/web/app/components/goto-anything/hooks/use-goto-anything-navigation.spec.ts +++ b/web/app/components/goto-anything/hooks/use-goto-anything-navigation.spec.ts @@ -32,23 +32,17 @@ vi.mock('../actions/commands/registry', () => ({ }, })) -const createMockActionItem = ( - key: '@app' | '@knowledge' | '@plugin' | '@node' | '/', - extra: Record<string, unknown> = {}, -) => ({ - key, - shortcut: key, - title: `${key} title`, - description: `${key} description`, - search: vi.fn().mockResolvedValue([]), - ...extra, -}) +const mockExecuteCommand = vi.fn() + +vi.mock('../actions/commands', () => ({ + executeCommand: (...args: unknown[]) => mockExecuteCommand(...args), +})) + +vi.mock('@/app/components/workflow/constants', () => ({ + VIBE_COMMAND_EVENT: 'vibe-command', +})) const createMockOptions = (overrides = {}) => ({ - Actions: { - slash: createMockActionItem('/', { action: vi.fn() }), - app: createMockActionItem('@app'), - }, setSearchQuery: vi.fn(), clearSelection: vi.fn(), inputRef: { current: { focus: vi.fn() } } as unknown as React.RefObject<HTMLInputElement>, @@ -60,6 +54,7 @@ describe('useGotoAnythingNavigation', () => { beforeEach(() => { vi.clearAllMocks() mockFindCommandResult = null + mockExecuteCommand.mockReset() vi.useFakeTimers() }) @@ -221,13 +216,8 @@ describe('useGotoAnythingNavigation', () => { expect(mockRouterPush).not.toHaveBeenCalled() }) - it('should execute slash command action for command type', () => { - const actionMock = vi.fn() - const options = createMockOptions({ - Actions: { - slash: { key: '/', shortcut: '/', action: actionMock }, - }, - }) + it('should execute command via executeCommand for command type', () => { + const options = createMockOptions() const { result } = renderHook(() => useGotoAnythingNavigation(options)) @@ -242,7 +232,7 @@ describe('useGotoAnythingNavigation', () => { result.current.handleNavigate(commandResult) }) - expect(actionMock).toHaveBeenCalledWith(commandResult) + expect(mockExecuteCommand).toHaveBeenCalledWith('theme.set', { theme: 'dark' }) }) it('should set activePlugin for plugin type', () => { @@ -368,10 +358,8 @@ describe('useGotoAnythingNavigation', () => { // No error should occur }) - it('should handle missing slash action', () => { - const options = createMockOptions({ - Actions: {}, - }) + it('should handle command execution without error', () => { + const options = createMockOptions() const { result } = renderHook(() => useGotoAnythingNavigation(options)) @@ -385,7 +373,7 @@ describe('useGotoAnythingNavigation', () => { }) }) - // No error should occur + expect(mockExecuteCommand).toHaveBeenCalledWith('test-command', undefined) }) }) }) diff --git a/web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts b/web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts index 73be6cd3ee..8b90e597c5 100644 --- a/web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts +++ b/web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts @@ -2,10 +2,12 @@ import type { RefObject } from 'react' import type { Plugin } from '../../plugins/types' -import type { ActionItem, SearchResult } from '../actions/types' +import type { SearchResult } from '../actions/types' import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' +import { VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants' import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation' +import { executeCommand } from '../actions/commands' import { slashCommandRegistry } from '../actions/commands/registry' export type UseGotoAnythingNavigationReturn = { @@ -16,7 +18,6 @@ export type UseGotoAnythingNavigationReturn = { } export type UseGotoAnythingNavigationOptions = { - Actions: Record<string, ActionItem> setSearchQuery: (query: string) => void clearSelection: () => void inputRef: RefObject<HTMLInputElement | null> @@ -27,7 +28,6 @@ export const useGotoAnythingNavigation = ( options: UseGotoAnythingNavigationOptions, ): UseGotoAnythingNavigationReturn => { const { - Actions, setSearchQuery, clearSelection, inputRef, @@ -67,9 +67,16 @@ export const useGotoAnythingNavigation = ( switch (result.type) { case 'command': { - // Execute slash commands - const action = Actions.slash - action?.action?.(result) + if (result.data.command === 'workflow.vibe') { + if (typeof document !== 'undefined') { + document.dispatchEvent(new CustomEvent(VIBE_COMMAND_EVENT, { detail: { dsl: result.data.args?.dsl } })) + } + break + } + + // Execute slash commands using the command bus + const { command, args } = result.data + executeCommand(command, args) break } case 'plugin': @@ -79,13 +86,12 @@ export const useGotoAnythingNavigation = ( // Handle workflow node selection and navigation if (result.metadata?.nodeId) selectWorkflowNode(result.metadata.nodeId, true) - break default: if (result.path) router.push(result.path) } - }, [router, Actions, onClose, setSearchQuery]) + }, [router, onClose, setSearchQuery]) return { handleCommandSelect, diff --git a/web/app/components/goto-anything/hooks/use-goto-anything-results.spec.ts b/web/app/components/goto-anything/hooks/use-goto-anything-results.spec.ts index ca95abeacd..d5e0c628e4 100644 --- a/web/app/components/goto-anything/hooks/use-goto-anything-results.spec.ts +++ b/web/app/components/goto-anything/hooks/use-goto-anything-results.spec.ts @@ -35,11 +35,11 @@ vi.mock('../actions', () => ({ searchAnything: (...args: unknown[]) => mockSearchAnything(...args), })) -const createMockActionItem = (key: '@app' | '@knowledge' | '@plugin' | '@node' | '/') => ({ - key, - shortcut: key, - title: `${key} title`, - description: `${key} description`, +const createMockScopeDescriptor = (id: string, shortcut: string) => ({ + id, + shortcut, + title: `${shortcut} title`, + description: `${shortcut} description`, search: vi.fn().mockResolvedValue([]), }) @@ -47,7 +47,7 @@ const createMockOptions = (overrides = {}) => ({ searchQueryDebouncedValue: '', searchMode: 'general', isCommandsMode: false, - Actions: { app: createMockActionItem('@app') }, + scopes: [createMockScopeDescriptor('app', '@app')], isWorkflowPage: false, isRagPipelinePage: false, cmdVal: '_', @@ -300,36 +300,36 @@ describe('useGotoAnythingResults', () => { describe('queryFn execution', () => { it('should call matchAction with lowercased query', async () => { - const mockActions = { app: createMockActionItem('@app') } - mockMatchAction.mockReturnValue({ key: '@app' }) + const mockScopes = [createMockScopeDescriptor('app', '@app')] + mockMatchAction.mockReturnValue(mockScopes[0]) mockSearchAnything.mockResolvedValue([]) renderHook(() => useGotoAnythingResults(createMockOptions({ searchQueryDebouncedValue: 'TEST QUERY', - Actions: mockActions, + scopes: mockScopes, }))) expect(capturedQueryFn).toBeDefined() await capturedQueryFn!() - expect(mockMatchAction).toHaveBeenCalledWith('test query', mockActions) + expect(mockMatchAction).toHaveBeenCalledWith('test query', mockScopes) }) it('should call searchAnything with correct parameters', async () => { - const mockActions = { app: createMockActionItem('@app') } - const mockAction = { key: '@app' } + const mockScopes = [createMockScopeDescriptor('app', '@app')] + const mockAction = mockScopes[0] mockMatchAction.mockReturnValue(mockAction) mockSearchAnything.mockResolvedValue([{ id: '1', type: 'app', title: 'Result' }]) renderHook(() => useGotoAnythingResults(createMockOptions({ searchQueryDebouncedValue: 'My Query', - Actions: mockActions, + scopes: mockScopes, }))) expect(capturedQueryFn).toBeDefined() const result = await capturedQueryFn!() - expect(mockSearchAnything).toHaveBeenCalledWith('en_US', 'my query', mockAction, mockActions) + expect(mockSearchAnything).toHaveBeenCalledWith('en_US', 'my query', mockAction, mockScopes) expect(result).toEqual([{ id: '1', type: 'app', title: 'Result' }]) }) diff --git a/web/app/components/goto-anything/hooks/use-goto-anything-results.ts b/web/app/components/goto-anything/hooks/use-goto-anything-results.ts index dabbd8039c..7051177e3d 100644 --- a/web/app/components/goto-anything/hooks/use-goto-anything-results.ts +++ b/web/app/components/goto-anything/hooks/use-goto-anything-results.ts @@ -1,6 +1,6 @@ 'use client' -import type { ActionItem, SearchResult } from '../actions/types' +import type { ScopeDescriptor, SearchResult } from '../actions/types' import { useQuery } from '@tanstack/react-query' import { useEffect, useMemo } from 'react' import { useGetLanguage } from '@/context/i18n' @@ -19,7 +19,7 @@ export type UseGotoAnythingResultsOptions = { searchQueryDebouncedValue: string searchMode: string isCommandsMode: boolean - Actions: Record<string, ActionItem> + scopes: ScopeDescriptor[] isWorkflowPage: boolean isRagPipelinePage: boolean cmdVal: string @@ -33,7 +33,7 @@ export const useGotoAnythingResults = ( searchQueryDebouncedValue, searchMode, isCommandsMode, - Actions, + scopes, isWorkflowPage, isRagPipelinePage, cmdVal, @@ -42,13 +42,9 @@ export const useGotoAnythingResults = ( const defaultLocale = useGetLanguage() - // Use action keys as stable cache key instead of the full Actions object - // (Actions contains functions which are not serializable) - const actionKeys = useMemo(() => Object.keys(Actions).sort(), [Actions]) - const { data: searchResults = [], isLoading, isError, error } = useQuery( { - // eslint-disable-next-line @tanstack/query/exhaustive-deps -- Actions intentionally excluded: contains non-serializable functions; actionKeys provides stable representation + // eslint-disable-next-line @tanstack/query/exhaustive-deps -- scopes intentionally excluded: contains non-serializable functions; scope IDs provide stable representation queryKey: [ 'goto-anything', 'search-result', @@ -57,12 +53,12 @@ export const useGotoAnythingResults = ( isWorkflowPage, isRagPipelinePage, defaultLocale, - actionKeys, + scopes.map(s => s.id).sort().join(','), ], queryFn: async () => { const query = searchQueryDebouncedValue.toLowerCase() - const action = matchAction(query, Actions) - return await searchAnything(defaultLocale, query, action, Actions) + const scope = matchAction(query, scopes) + return await searchAnything(defaultLocale, query, scope, scopes) }, enabled: !!searchQueryDebouncedValue && !isCommandsMode, staleTime: 30000, diff --git a/web/app/components/goto-anything/hooks/use-goto-anything-search.spec.ts b/web/app/components/goto-anything/hooks/use-goto-anything-search.spec.ts index d8987c2d9c..6d09bce82d 100644 --- a/web/app/components/goto-anything/hooks/use-goto-anything-search.spec.ts +++ b/web/app/components/goto-anything/hooks/use-goto-anything-search.spec.ts @@ -1,9 +1,25 @@ -import type { ActionItem } from '../actions/types' +import type { ScopeDescriptor } from '../actions/types' import { act, renderHook } from '@testing-library/react' import { useGotoAnythingSearch } from './use-goto-anything-search' let mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false } -let mockMatchActionResult: Partial<ActionItem> | undefined +let mockMatchActionResult: ScopeDescriptor | undefined + +const baseScopesMock: ScopeDescriptor[] = [ + { id: 'slash', shortcut: '/', title: 'Slash', description: 'Slash commands', search: vi.fn() }, + { id: 'app', shortcut: '@app', title: 'App', description: 'Search apps', search: vi.fn() }, + { id: 'knowledge', shortcut: '@knowledge', title: 'Knowledge', description: 'Search KB', search: vi.fn() }, +] + +const workflowScopesMock: ScopeDescriptor[] = [ + ...baseScopesMock, + { id: 'node', shortcut: '@node', title: 'Node', description: 'Search nodes', search: vi.fn() }, +] + +const ragScopesMock: ScopeDescriptor[] = [ + ...baseScopesMock, + { id: 'ragNode', shortcut: '@node', title: 'RAG Node', description: 'Search RAG nodes', search: vi.fn() }, +] vi.mock('ahooks', () => ({ useDebounce: <T>(value: T) => value, @@ -14,19 +30,12 @@ vi.mock('../context', () => ({ })) vi.mock('../actions', () => ({ - createActions: (isWorkflowPage: boolean, isRagPipelinePage: boolean) => { - const base = { - slash: { key: '/', shortcut: '/' }, - app: { key: '@app', shortcut: '@app' }, - knowledge: { key: '@knowledge', shortcut: '@kb' }, - } - if (isWorkflowPage) { - return { ...base, node: { key: '@node', shortcut: '@node' } } - } - if (isRagPipelinePage) { - return { ...base, ragNode: { key: '@node', shortcut: '@node' } } - } - return base + useGotoAnythingScopes: (context: { isWorkflowPage: boolean, isRagPipelinePage: boolean }) => { + if (context.isWorkflowPage) + return workflowScopesMock + if (context.isRagPipelinePage) + return ragScopesMock + return baseScopesMock }, matchAction: () => mockMatchActionResult, })) @@ -74,30 +83,30 @@ describe('useGotoAnythingSearch', () => { }) }) - describe('Actions', () => { - it('should provide Actions based on context', () => { + describe('scopes', () => { + it('should provide scopes based on context', () => { const { result } = renderHook(() => useGotoAnythingSearch()) - expect(result.current.Actions).toBeDefined() - expect(typeof result.current.Actions).toBe('object') + expect(result.current.scopes).toBeDefined() + expect(Array.isArray(result.current.scopes)).toBe(true) }) - it('should include node action when on workflow page', () => { + it('should include node scope when on workflow page', () => { mockContextValue = { isWorkflowPage: true, isRagPipelinePage: false } const { result } = renderHook(() => useGotoAnythingSearch()) - expect(result.current.Actions.node).toBeDefined() + expect(result.current.scopes.find(s => s.id === 'node')).toBeDefined() }) - it('should include ragNode action when on RAG pipeline page', () => { + it('should include ragNode scope when on RAG pipeline page', () => { mockContextValue = { isWorkflowPage: false, isRagPipelinePage: true } const { result } = renderHook(() => useGotoAnythingSearch()) - expect(result.current.Actions.ragNode).toBeDefined() + expect(result.current.scopes.find(s => s.id === 'ragNode')).toBeDefined() }) - it('should not include node actions when on regular page', () => { + it('should not include node scopes when on regular page', () => { mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false } const { result } = renderHook(() => useGotoAnythingSearch()) - expect(result.current.Actions.node).toBeUndefined() - expect(result.current.Actions.ragNode).toBeUndefined() + expect(result.current.scopes.find(s => s.id === 'node')).toBeUndefined() + expect(result.current.scopes.find(s => s.id === 'ragNode')).toBeUndefined() }) }) @@ -145,7 +154,7 @@ describe('useGotoAnythingSearch', () => { }) it('should return false when query starts with "@" and action matches', () => { - mockMatchActionResult = { key: '@app', shortcut: '@app' } + mockMatchActionResult = baseScopesMock.find(s => s.id === 'app') const { result } = renderHook(() => useGotoAnythingSearch()) act(() => { @@ -206,8 +215,8 @@ describe('useGotoAnythingSearch', () => { expect(result.current.searchMode).toBe('general') }) - it('should return action key when action matches', () => { - mockMatchActionResult = { key: '@app', shortcut: '@app' } + it('should return action shortcut when action matches', () => { + mockMatchActionResult = baseScopesMock.find(s => s.id === 'app') const { result } = renderHook(() => useGotoAnythingSearch()) act(() => { @@ -217,8 +226,8 @@ describe('useGotoAnythingSearch', () => { expect(result.current.searchMode).toBe('@app') }) - it('should return "@command" when action key is "/"', () => { - mockMatchActionResult = { key: '/', shortcut: '/' } + it('should return "@command" when action is slash', () => { + mockMatchActionResult = baseScopesMock.find(s => s.id === 'slash') const { result } = renderHook(() => useGotoAnythingSearch()) act(() => { diff --git a/web/app/components/goto-anything/hooks/use-goto-anything-search.ts b/web/app/components/goto-anything/hooks/use-goto-anything-search.ts index dd88fc315b..f614801200 100644 --- a/web/app/components/goto-anything/hooks/use-goto-anything-search.ts +++ b/web/app/components/goto-anything/hooks/use-goto-anything-search.ts @@ -1,9 +1,10 @@ 'use client' -import type { ActionItem } from '../actions/types' +import type { ScopeDescriptor } from '../actions/types' import { useDebounce } from 'ahooks' import { useCallback, useMemo, useState } from 'react' -import { createActions, matchAction } from '../actions' +import { matchAction, useGotoAnythingScopes } from '../actions' +import { ACTION_KEYS } from '../constants' import { useGotoAnythingContext } from '../context' export type UseGotoAnythingSearchReturn = { @@ -15,7 +16,7 @@ export type UseGotoAnythingSearchReturn = { cmdVal: string setCmdVal: (val: string) => void clearSelection: () => void - Actions: Record<string, ActionItem> + scopes: ScopeDescriptor[] } export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => { @@ -23,10 +24,8 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => { const [searchQuery, setSearchQuery] = useState<string>('') const [cmdVal, setCmdVal] = useState<string>('_') - // Filter actions based on context - const Actions = useMemo(() => { - return createActions(isWorkflowPage, isRagPipelinePage) - }, [isWorkflowPage, isRagPipelinePage]) + // Fetch scopes from registry based on context + const scopes = useGotoAnythingScopes({ isWorkflowPage, isRagPipelinePage }) const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), { wait: 300, @@ -35,28 +34,30 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => { const isCommandsMode = useMemo(() => { const trimmed = searchQuery.trim() return trimmed === '@' || trimmed === '/' - || (trimmed.startsWith('@') && !matchAction(trimmed, Actions)) - || (trimmed.startsWith('/') && !matchAction(trimmed, Actions)) - }, [searchQuery, Actions]) + || (trimmed.startsWith('@') && !matchAction(trimmed, scopes)) + || (trimmed.startsWith('/') && !matchAction(trimmed, scopes)) + }, [searchQuery, scopes]) const searchMode = useMemo(() => { if (isCommandsMode) { - // Distinguish between @ (scopes) and / (commands) mode if (searchQuery.trim().startsWith('@')) return 'scopes' else if (searchQuery.trim().startsWith('/')) return 'commands' - return 'commands' // default fallback + return 'commands' } const query = searchQueryDebouncedValue.toLowerCase() - const action = matchAction(query, Actions) + const action = matchAction(query, scopes) if (!action) return 'general' - return action.key === '/' ? '@command' : action.key - }, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery]) + if (action.id === 'slash' || action.shortcut === ACTION_KEYS.SLASH) + return '@command' + + return action.shortcut + }, [searchQueryDebouncedValue, scopes, isCommandsMode, searchQuery]) // Prevent automatic selection of the first option when cmdVal is not set const clearSelection = useCallback(() => { @@ -72,6 +73,6 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => { cmdVal, setCmdVal, clearSelection, - Actions, + scopes, } } diff --git a/web/app/components/goto-anything/hooks/use-search.ts b/web/app/components/goto-anything/hooks/use-search.ts new file mode 100644 index 0000000000..f6dcae70f1 --- /dev/null +++ b/web/app/components/goto-anything/hooks/use-search.ts @@ -0,0 +1,93 @@ +import { keepPreviousData, useQuery } from '@tanstack/react-query' +import { useDebounce } from 'ahooks' +import { useMemo } from 'react' +import { useGetLanguage } from '@/context/i18n' +import { matchAction, searchAnything, useGotoAnythingScopes } from '../actions' +import { ACTION_KEYS } from '../constants' +import { useGotoAnythingContext } from '../context' + +export const useSearch = (searchQuery: string) => { + const defaultLocale = useGetLanguage() + const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext() + + // Fetch scopes from registry based on context + const scopes = useGotoAnythingScopes({ isWorkflowPage, isRagPipelinePage }) + + const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), { + wait: 300, + }) + + const isCommandsMode = searchQuery.trim() === '@' || searchQuery.trim() === '/' + || (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), scopes)) + || (searchQuery.trim().startsWith('/') && !matchAction(searchQuery.trim(), scopes)) + + const searchMode = useMemo(() => { + if (isCommandsMode) { + // Distinguish between @ (scopes) and / (commands) mode + if (searchQuery.trim().startsWith('@')) + return 'scopes' + else if (searchQuery.trim().startsWith('/')) + return 'commands' + return 'commands' // default fallback + } + + const query = searchQueryDebouncedValue.toLowerCase() + const action = matchAction(query, scopes) + + if (!action) + return 'general' + + if (action.id === 'slash' || action.shortcut === ACTION_KEYS.SLASH) + return '@command' + + return action.shortcut + }, [searchQueryDebouncedValue, scopes, isCommandsMode, searchQuery]) + + const { data: searchResults = [], isLoading, isError, error } = useQuery( + { + queryKey: [ + 'goto-anything', + 'search-result', + searchQueryDebouncedValue, + searchMode, + isWorkflowPage, + isRagPipelinePage, + defaultLocale, + scopes.map(s => s.id).sort().join(','), + ], + queryFn: async () => { + const query = searchQueryDebouncedValue.toLowerCase() + const scope = matchAction(query, scopes) + return await searchAnything(defaultLocale, query, scope, scopes) + }, + enabled: !!searchQueryDebouncedValue && !isCommandsMode, + staleTime: 30000, + gcTime: 300000, + placeholderData: keepPreviousData, + }, + ) + + const dedupedResults = useMemo(() => { + if (!searchQuery.trim()) + return [] + + const seen = new Set<string>() + return searchResults.filter((result) => { + const key = `${result.type}-${result.id}` + if (seen.has(key)) + return false + seen.add(key) + return true + }) + }, [searchResults, searchQuery]) + + return { + scopes, + searchResults: dedupedResults, + isLoading, + isError, + error, + searchMode, + isCommandsMode, + } +} diff --git a/web/app/components/goto-anything/index.spec.tsx b/web/app/components/goto-anything/index.spec.tsx index 6a6143a6e2..b9545f690a 100644 --- a/web/app/components/goto-anything/index.spec.tsx +++ b/web/app/components/goto-anything/index.spec.tsx @@ -1,5 +1,6 @@ import type { ReactNode } from 'react' -import type { ActionItem, SearchResult } from './actions/types' +import type { ScopeDescriptor } from './actions/scope-registry' +import type { SearchResult } from './actions/types' import { act, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' @@ -58,6 +59,7 @@ const triggerKeyPress = (combo: string) => { let mockQueryResult = { data: [] as TestSearchResult[], isLoading: false, isError: false, error: null as Error | null } vi.mock('@tanstack/react-query', () => ({ useQuery: () => mockQueryResult, + keepPreviousData: (data: unknown) => data, })) vi.mock('@/context/i18n', () => ({ @@ -70,37 +72,30 @@ vi.mock('./context', () => ({ GotoAnythingProvider: ({ children }: { children: React.ReactNode }) => <>{children}</>, })) -vi.mock('@/app/components/workflow/utils', () => ({ - getKeyboardKeyNameBySystem: (key: string) => key, -})) +type MatchAction = typeof import('./actions').matchAction +type SearchAnything = typeof import('./actions').searchAnything -const createActionItem = (key: ActionItem['key'], shortcut: string): ActionItem => ({ - key, - shortcut, - title: `${key} title`, - description: `${key} desc`, - action: vi.fn(), - search: vi.fn(), +const mockState = vi.hoisted(() => { + const state = { + scopes: [] as ScopeDescriptor[], + useGotoAnythingScopesMock: vi.fn(() => state.scopes), + matchActionMock: vi.fn<MatchAction>(() => undefined), + searchAnythingMock: vi.fn<SearchAnything>(async () => []), + } + + return state }) -const actionsMock = { - slash: createActionItem('/', '/'), - app: createActionItem('@app', '@app'), - plugin: createActionItem('@plugin', '@plugin'), -} - -const createActionsMock = vi.fn(() => actionsMock) -const matchActionMock = vi.fn(() => undefined) -const searchAnythingMock = vi.fn(async () => mockQueryResult.data) - vi.mock('./actions', () => ({ - createActions: () => createActionsMock(), - matchAction: () => matchActionMock(), - searchAnything: () => searchAnythingMock(), + __esModule: true, + matchAction: (...args: Parameters<MatchAction>) => mockState.matchActionMock(...args), + searchAnything: (...args: Parameters<SearchAnything>) => mockState.searchAnythingMock(...args), + useGotoAnythingScopes: () => mockState.useGotoAnythingScopesMock(), })) vi.mock('./actions/commands', () => ({ SlashCommandProvider: () => null, + executeCommand: vi.fn(), })) type MockSlashCommand = { @@ -118,6 +113,20 @@ vi.mock('./actions/commands/registry', () => ({ }, })) +const createScope = (id: ScopeDescriptor['id'], shortcut: string): ScopeDescriptor => ({ + id, + shortcut, + title: `${id} title`, + description: `${id} desc`, + search: vi.fn(), +}) + +const scopesMock = [ + createScope('slash', '/'), + createScope('app', '@app'), + createScope('plugin', '@plugin'), +] + vi.mock('@/app/components/workflow/utils/common', () => ({ getKeyboardKeyCodeBySystem: () => 'ctrl', getKeyboardKeyNameBySystem: (key: string) => key, @@ -144,8 +153,10 @@ describe('GotoAnything', () => { routerPush.mockClear() Object.keys(keyPressHandlers).forEach(key => delete keyPressHandlers[key]) mockQueryResult = { data: [], isLoading: false, isError: false, error: null } - matchActionMock.mockReset() - searchAnythingMock.mockClear() + mockState.scopes = scopesMock + mockState.matchActionMock.mockReset() + mockState.searchAnythingMock.mockClear() + mockState.searchAnythingMock.mockImplementation(async () => mockQueryResult.data as SearchResult[]) mockFindCommand = null }) diff --git a/web/app/components/goto-anything/index.tsx b/web/app/components/goto-anything/index.tsx index 8ee2395cce..999c9a36bc 100644 --- a/web/app/components/goto-anything/index.tsx +++ b/web/app/components/goto-anything/index.tsx @@ -39,7 +39,7 @@ const GotoAnything: FC<Props> = ({ cmdVal, setCmdVal, clearSelection, - Actions, + scopes, } = useGotoAnythingSearch() // Modal state management @@ -76,7 +76,7 @@ const GotoAnything: FC<Props> = ({ searchQueryDebouncedValue, searchMode, isCommandsMode, - Actions, + scopes, isWorkflowPage, isRagPipelinePage, cmdVal, @@ -90,7 +90,6 @@ const GotoAnything: FC<Props> = ({ activePlugin, setActivePlugin, } = useGotoAnythingNavigation({ - Actions, setSearchQuery, clearSelection, inputRef, @@ -179,7 +178,7 @@ const GotoAnything: FC<Props> = ({ {isCommandsMode ? ( <CommandSelector - actions={Actions} + scopes={scopes} onCommandSelect={handleCommandSelect} searchFilter={searchQuery.trim().substring(1)} commandValue={cmdVal} @@ -198,7 +197,7 @@ const GotoAnything: FC<Props> = ({ <EmptyState variant="no-results" searchMode={searchMode} - Actions={Actions} + Actions={scopes} /> )} diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx index 2c9d0f5002..5088ca764d 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/common-modal.spec.tsx @@ -2052,9 +2052,6 @@ describe('CommonCreateModal', () => { expect(mockCreateBuilder).toHaveBeenCalled() }) - // Flush pending state updates from createBuilder promise resolution - await act(async () => {}) - const input = screen.getByTestId('form-field-webhook_url') fireEvent.change(input, { target: { value: 'test' } }) diff --git a/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx index 6643d8239d..14c8e607d0 100644 --- a/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx +++ b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx @@ -145,6 +145,22 @@ vi.mock('@/app/components/workflow/constants', () => ({ WORKFLOW_DATA_UPDATE: 'WORKFLOW_DATA_UPDATE', })) +// Mock FileReader +class MockFileReader { + result: string | null = null + onload: ((e: { target: { result: string | null } }) => void) | null = null + + readAsText(_file: File) { + // Simulate async file reading using queueMicrotask for more reliable async behavior + queueMicrotask(() => { + this.result = 'test file content' + if (this.onload) { + this.onload({ target: { result: this.result } }) + } + }) + } +} + afterEach(() => { cleanup() vi.clearAllMocks() @@ -154,6 +170,7 @@ describe('UpdateDSLModal', () => { const mockOnCancel = vi.fn() const mockOnBackup = vi.fn() const mockOnImport = vi.fn() + let originalFileReader: typeof FileReader const defaultProps = { onCancel: mockOnCancel, @@ -169,6 +186,14 @@ describe('UpdateDSLModal', () => { pipeline_id: 'test-pipeline-id', }) mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + // Mock FileReader + originalFileReader = globalThis.FileReader + globalThis.FileReader = MockFileReader as unknown as typeof FileReader + }) + + afterEach(() => { + globalThis.FileReader = originalFileReader }) describe('rendering', () => { @@ -538,7 +563,6 @@ describe('UpdateDSLModal', () => { const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) fireEvent.change(fileInput, { target: { files: [file] } }) - // Wait for FileReader to process and button to be enabled await waitFor(() => { const importButton = screen.getByText('common.overwriteAndImport') expect(importButton).not.toBeDisabled() @@ -563,12 +587,15 @@ describe('UpdateDSLModal', () => { const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) fireEvent.change(fileInput, { target: { files: [file] } }) - // Wait for FileReader to complete and button to be enabled + // Wait for FileReader to complete (setTimeout 0) and button to be enabled await waitFor(() => { const importButton = screen.getByText('common.overwriteAndImport') expect(importButton).not.toBeDisabled() }) + // Give extra time for the FileReader's setTimeout to complete + await new Promise(resolve => setTimeout(resolve, 10)) + const importButton = screen.getByText('common.overwriteAndImport') fireEvent.click(importButton) @@ -597,11 +624,6 @@ describe('UpdateDSLModal', () => { expect(importButton).not.toBeDisabled() }) - // Flush the FileReader microtask to ensure fileContent is set - await act(async () => { - await new Promise<void>(resolve => queueMicrotask(resolve)) - }) - const importButton = screen.getByText('common.overwriteAndImport') fireEvent.click(importButton) @@ -703,7 +725,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('1.0.0')).toBeInTheDocument() expect(screen.getByText('2.0.0')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) }) it('should close error modal when cancel button is clicked', async () => { @@ -732,7 +754,7 @@ describe('UpdateDSLModal', () => { // Wait for error modal await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) // Find and click cancel button in error modal - it should be the one with secondary variant const cancelButtons = screen.getAllByText('newApp.Cancel') @@ -750,8 +772,6 @@ describe('UpdateDSLModal', () => { }) it('should call importDSLConfirm when confirm button is clicked in error modal', async () => { - vi.useFakeTimers({ shouldAdvanceTime: true }) - mockImportDSL.mockResolvedValue({ id: 'import-id', status: DSLImportStatus.PENDING, @@ -769,27 +789,20 @@ describe('UpdateDSLModal', () => { const fileInput = screen.getByTestId('file-input') const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) - await act(async () => { - fireEvent.change(fileInput, { target: { files: [file] } }) - // Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask) - await new Promise<void>(resolve => queueMicrotask(resolve)) + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() }) const importButton = screen.getByText('common.overwriteAndImport') - expect(importButton).not.toBeDisabled() - - await act(async () => { - fireEvent.click(importButton) - // Flush the promise resolution from mockImportDSL - await Promise.resolve() - // Advance past the 300ms setTimeout in the component - await vi.advanceTimersByTimeAsync(350) - }) + fireEvent.click(importButton) + // Wait for error modal await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) // Click confirm button const confirmButton = screen.getByText('newApp.Confirm') @@ -798,8 +811,6 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(mockImportDSLConfirm).toHaveBeenCalledWith('import-id') }) - - vi.useRealTimers() }) it('should show success notification after confirm completes', async () => { @@ -832,7 +843,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -874,7 +885,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -913,7 +924,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -955,7 +966,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -997,7 +1008,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -1008,8 +1019,6 @@ describe('UpdateDSLModal', () => { }) it('should call handleCheckPluginDependencies after confirm', async () => { - vi.useFakeTimers({ shouldAdvanceTime: true }) - mockImportDSL.mockResolvedValue({ id: 'import-id', status: DSLImportStatus.PENDING, @@ -1027,27 +1036,19 @@ describe('UpdateDSLModal', () => { const fileInput = screen.getByTestId('file-input') const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) - await act(async () => { - fireEvent.change(fileInput, { target: { files: [file] } }) - // Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask) - await new Promise<void>(resolve => queueMicrotask(resolve)) + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() }) const importButton = screen.getByText('common.overwriteAndImport') - expect(importButton).not.toBeDisabled() - - await act(async () => { - fireEvent.click(importButton) - // Flush the promise resolution from mockImportDSL - await Promise.resolve() - // Advance past the 300ms setTimeout in the component - await vi.advanceTimersByTimeAsync(350) - }) + fireEvent.click(importButton) await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) const confirmButton = screen.getByText('newApp.Confirm') fireEvent.click(confirmButton) @@ -1055,8 +1056,6 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('test-pipeline-id', true) }) - - vi.useRealTimers() }) it('should handle undefined imported_dsl_version and current_dsl_version', async () => { @@ -1085,7 +1084,7 @@ describe('UpdateDSLModal', () => { // Should show error modal even with undefined versions await waitFor(() => { expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() - }, { timeout: 1000 }) + }, { timeout: 500 }) }) it('should not call importDSLConfirm when importId is not set', async () => { diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts index 295ed20bd8..0d217f3605 100644 --- a/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts +++ b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts @@ -1,49 +1,79 @@ -import { act, renderHook, waitFor } from '@testing-library/react' +import { renderHook } from '@testing-library/react' +import { act } from 'react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + import { useDSL } from './use-DSL' -// Mock dependencies -const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - useToastContext: () => ({ notify: mockNotify }), -})) - -const mockEventEmitter = { emit: vi.fn() } -vi.mock('@/context/event-emitter', () => ({ - useEventEmitterContextContext: () => ({ eventEmitter: mockEventEmitter }), -})) - -const mockDoSyncWorkflowDraft = vi.fn() -vi.mock('./use-nodes-sync-draft', () => ({ - useNodesSyncDraft: () => ({ doSyncWorkflowDraft: mockDoSyncWorkflowDraft }), -})) - -const mockGetState = vi.fn() -vi.mock('@/app/components/workflow/store', () => ({ - useWorkflowStore: () => ({ getState: mockGetState }), -})) - -const mockExportPipelineConfig = vi.fn() -vi.mock('@/service/use-pipeline', () => ({ - useExportPipelineDSL: () => ({ mutateAsync: mockExportPipelineConfig }), -})) - -const mockFetchWorkflowDraft = vi.fn() -vi.mock('@/service/workflow', () => ({ - fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), -})) - -const mockDownloadBlob = vi.fn() -vi.mock('@/utils/download', () => ({ - downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args), -})) +// ============================================================================ +// Mocks +// ============================================================================ +// Mock react-i18next vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string) => key, }), })) +// Mock toast context +const mockNotify = vi.fn() +vi.mock('@/app/components/base/toast', () => ({ + useToastContext: () => ({ + notify: mockNotify, + }), +})) + +// Mock event emitter context +const mockEmit = vi.fn() +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: mockEmit, + }, + }), +})) + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useNodesSyncDraft +const mockDoSyncWorkflowDraft = vi.fn() +vi.mock('./use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: mockDoSyncWorkflowDraft, + }), +})) + +// Mock pipeline service +const mockExportPipelineConfig = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipelineConfig, + }), +})) + +// Mock download utility +const mockDownloadBlob = vi.fn() +vi.mock('@/utils/download', () => ({ + downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args), +})) + +// Mock workflow service +const mockFetchWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url), +})) + +// Mock workflow constants vi.mock('@/app/components/workflow/constants', () => ({ DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK', })) @@ -53,63 +83,48 @@ vi.mock('@/app/components/workflow/constants', () => ({ // ============================================================================ describe('useDSL', () => { - let mockLink: { href: string, download: string, click: ReturnType<typeof vi.fn>, style: { display: string }, remove: ReturnType<typeof vi.fn> } - let originalCreateElement: typeof document.createElement - let originalAppendChild: typeof document.body.appendChild - let mockCreateObjectURL: ReturnType<typeof vi.spyOn> - let mockRevokeObjectURL: ReturnType<typeof vi.spyOn> - beforeEach(() => { vi.clearAllMocks() - // Create a proper mock link element with all required properties for downloadBlob - mockLink = { - href: '', - download: '', - click: vi.fn(), - style: { display: '' }, - remove: vi.fn(), - } - - // Save original and mock selectively - only intercept 'a' elements - originalCreateElement = document.createElement.bind(document) - document.createElement = vi.fn((tagName: string) => { - if (tagName === 'a') { - return mockLink as unknown as HTMLElement - } - return originalCreateElement(tagName) - }) as typeof document.createElement - - // Mock document.body.appendChild for downloadBlob - originalAppendChild = document.body.appendChild.bind(document.body) - document.body.appendChild = vi.fn(<T extends Node>(node: T): T => node) as typeof document.body.appendChild - - // downloadBlob uses window.URL, not URL - mockCreateObjectURL = vi.spyOn(window.URL, 'createObjectURL').mockReturnValue('blob:test-url') - mockRevokeObjectURL = vi.spyOn(window.URL, 'revokeObjectURL').mockImplementation(() => {}) - // Default store state - mockGetState.mockReturnValue({ + mockWorkflowStoreGetState.mockReturnValue({ pipelineId: 'test-pipeline-id', knowledgeName: 'Test Knowledge Base', }) mockDoSyncWorkflowDraft.mockResolvedValue(undefined) mockExportPipelineConfig.mockResolvedValue({ data: 'yaml-content' }) - mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: [] }) + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [], + }) }) afterEach(() => { - document.createElement = originalCreateElement - document.body.appendChild = originalAppendChild - mockCreateObjectURL.mockRestore() - mockRevokeObjectURL.mockRestore() vi.clearAllMocks() }) + describe('hook initialization', () => { + it('should return exportCheck function', () => { + const { result } = renderHook(() => useDSL()) + + expect(result.current.exportCheck).toBeDefined() + expect(typeof result.current.exportCheck).toBe('function') + }) + + it('should return handleExportDSL function', () => { + const { result } = renderHook(() => useDSL()) + + expect(result.current.handleExportDSL).toBeDefined() + expect(typeof result.current.handleExportDSL).toBe('function') + }) + }) + describe('handleExportDSL', () => { - it('should return early when pipelineId is not set', async () => { - mockGetState.mockReturnValue({ pipelineId: null, knowledgeName: 'test' }) + it('should not export when pipelineId is missing', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: undefined, + knowledgeName: 'Test', + }) const { result } = renderHook(() => useDSL()) @@ -118,6 +133,30 @@ describe('useDSL', () => { }) expect(mockDoSyncWorkflowDraft).not.toHaveBeenCalled() + expect(mockExportPipelineConfig).not.toHaveBeenCalled() + }) + + it('should sync workflow draft before export', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should call exportPipelineConfig with correct params', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL(true) + }) + + expect(mockExportPipelineConfig).toHaveBeenCalledWith({ + pipelineId: 'test-pipeline-id', + include: true, + }) }) it('should create and download file', async () => { @@ -130,7 +169,7 @@ describe('useDSL', () => { expect(mockDownloadBlob).toHaveBeenCalled() }) - it('should set correct download filename', async () => { + it('should use correct file extension for download', async () => { const { result } = renderHook(() => useDSL()) await act(async () => { @@ -158,7 +197,7 @@ describe('useDSL', () => { ) }) - it('should handle export error', async () => { + it('should show error notification on export failure', async () => { mockExportPipelineConfig.mockRejectedValue(new Error('Export failed')) const { result } = renderHook(() => useDSL()) @@ -167,33 +206,19 @@ describe('useDSL', () => { await result.current.handleExportDSL() }) - await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'exportFailed', - }) - }) - }) - - it('should pass include parameter', async () => { - const { result } = renderHook(() => useDSL()) - - await act(async () => { - await result.current.handleExportDSL(true) - }) - - await waitFor(() => { - expect(mockExportPipelineConfig).toHaveBeenCalledWith({ - pipelineId: 'test-pipeline-id', - include: true, - }) + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'exportFailed', }) }) }) describe('exportCheck', () => { - it('should return early when pipelineId is not set', async () => { - mockGetState.mockReturnValue({ pipelineId: null }) + it('should not check when pipelineId is missing', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: undefined, + knowledgeName: 'Test', + }) const { result } = renderHook(() => useDSL()) @@ -204,8 +229,22 @@ describe('useDSL', () => { expect(mockFetchWorkflowDraft).not.toHaveBeenCalled() }) - it('should call handleExportDSL directly when no secret variables', async () => { - mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: [] }) + it('should fetch workflow draft', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') + }) + + it('should directly export when no secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'string', value: 'test' }, + ], + }) const { result } = renderHook(() => useDSL()) @@ -213,15 +252,16 @@ describe('useDSL', () => { await result.current.exportCheck() }) - await waitFor(() => { - expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') - expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() - }) + // Should call doSyncWorkflowDraft (which means handleExportDSL was called) + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() }) - it('should emit event when secret variables exist', async () => { - const secretVars = [{ value_type: 'secret', name: 'API_KEY' }] - mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: secretVars }) + it('should emit DSL_EXPORT_CHECK event when secret variables exist', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'secret', value: 'secret-value' }, + ], + }) const { result } = renderHook(() => useDSL()) @@ -229,17 +269,15 @@ describe('useDSL', () => { await result.current.exportCheck() }) - await waitFor(() => { - expect(mockEventEmitter.emit).toHaveBeenCalledWith({ - type: expect.any(String), - payload: { - data: secretVars, - }, - }) + expect(mockEmit).toHaveBeenCalledWith({ + type: 'DSL_EXPORT_CHECK', + payload: { + data: [{ id: '1', value_type: 'secret', value: 'secret-value' }], + }, }) }) - it('should handle export check error', async () => { + it('should show error notification on check failure', async () => { mockFetchWorkflowDraft.mockRejectedValue(new Error('Fetch failed')) const { result } = renderHook(() => useDSL()) @@ -248,12 +286,68 @@ describe('useDSL', () => { await result.current.exportCheck() }) - await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'exportFailed', - }) + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'exportFailed', }) }) + + it('should filter only secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'string', value: 'plain' }, + { id: '2', value_type: 'secret', value: 'secret1' }, + { id: '3', value_type: 'number', value: '123' }, + { id: '4', value_type: 'secret', value: 'secret2' }, + ], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockEmit).toHaveBeenCalledWith({ + type: 'DSL_EXPORT_CHECK', + payload: { + data: [ + { id: '2', value_type: 'secret', value: 'secret1' }, + { id: '4', value_type: 'secret', value: 'secret2' }, + ], + }, + }) + }) + + it('should handle empty environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + // Should directly call handleExportDSL since no secrets + expect(mockEmit).not.toHaveBeenCalled() + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should handle undefined environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: undefined, + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + // Should directly call handleExportDSL since no secrets + expect(mockEmit).not.toHaveBeenCalled() + }) }) }) diff --git a/web/app/components/rag-pipeline/hooks/use-rag-pipeline-search.tsx b/web/app/components/rag-pipeline/hooks/use-rag-pipeline-search.tsx index b999f5ccc8..43479f3ea2 100644 --- a/web/app/components/rag-pipeline/hooks/use-rag-pipeline-search.tsx +++ b/web/app/components/rag-pipeline/hooks/use-rag-pipeline-search.tsx @@ -5,7 +5,7 @@ import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types' import type { ToolNodeType } from '@/app/components/workflow/nodes/tool/types' import type { CommonNodeType } from '@/app/components/workflow/types' import { useCallback, useEffect, useMemo } from 'react' -import { ragPipelineNodesAction } from '@/app/components/goto-anything/actions/rag-pipeline-nodes' +import { setRagPipelineNodesSearchFn } from '@/app/components/goto-anything/actions/rag-pipeline-nodes' import BlockIcon from '@/app/components/workflow/block-icon' import { useNodesInteractions } from '@/app/components/workflow/hooks/use-nodes-interactions' import { useGetToolIcon } from '@/app/components/workflow/hooks/use-tool-icon' @@ -153,16 +153,15 @@ export const useRagPipelineSearch = () => { return results }, [searchableNodes, calculateScore]) - // Directly set the search function on the action object + // Directly set the search function using the setter useEffect(() => { if (searchableNodes.length > 0) { - // Set the search function directly on the action - ragPipelineNodesAction.searchFn = searchRagPipelineNodes + setRagPipelineNodesSearchFn(searchRagPipelineNodes) } return () => { // Clean up when component unmounts - ragPipelineNodesAction.searchFn = undefined + setRagPipelineNodesSearchFn(() => []) } }, [searchableNodes, searchRagPipelineNodes]) diff --git a/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx b/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx index 97fc03175d..204772a3e2 100644 --- a/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/index.spec.tsx @@ -168,7 +168,6 @@ describe('EditCustomCollectionModal', () => { const schemaInput = screen.getByPlaceholderText('tools.createTool.schemaPlaceHolder') fireEvent.change(schemaInput, { target: { value: '{}' } }) - // Wait for parseParamsSchema to be called and state to be updated await waitFor(() => { expect(parseParamsSchemaMock).toHaveBeenCalledWith('{}') }) @@ -185,13 +184,13 @@ describe('EditCustomCollectionModal', () => { provider: 'provider', schema: '{}', schema_type: 'openapi', + credentials: { + auth_type: 'none', + }, icon: { content: '🕵️', background: '#FEF7C3', }, - credentials: { - auth_type: 'none', - }, labels: [], })) expect(toastNotifySpy).not.toHaveBeenCalled() diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx index 63d0344275..525946bb1c 100644 --- a/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/index.spec.tsx @@ -11,12 +11,7 @@ vi.mock('@/app/components/base/modal', () => ({ onClose, children, closable, - }: { - isShow: boolean - onClose?: () => void - children?: React.ReactNode - closable?: boolean - }) { + }: any) { if (!isShow) return null @@ -44,10 +39,7 @@ vi.mock('./start-node-selection-panel', () => ({ default: function MockStartNodeSelectionPanel({ onSelectUserInput, onSelectTrigger, - }: { - onSelectUserInput?: () => void - onSelectTrigger?: (type: BlockEnum, config?: Record<string, unknown>) => void - }) { + }: any) { return ( <div data-testid="start-node-selection-panel"> <button data-testid="select-user-input" onClick={onSelectUserInput}> @@ -55,13 +47,13 @@ vi.mock('./start-node-selection-panel', () => ({ </button> <button data-testid="select-trigger-schedule" - onClick={() => onSelectTrigger?.(BlockEnum.TriggerSchedule)} + onClick={() => onSelectTrigger(BlockEnum.TriggerSchedule)} > Select Trigger Schedule </button> <button data-testid="select-trigger-webhook" - onClick={() => onSelectTrigger?.(BlockEnum.TriggerWebhook, { config: 'test' })} + onClick={() => onSelectTrigger(BlockEnum.TriggerWebhook, { config: 'test' })} > Select Trigger Webhook </button> @@ -557,7 +549,7 @@ describe('WorkflowOnboardingModal', () => { // Arrange & Act renderComponent({ isShow: true }) - // Assert - ShortcutsName component renders keys in div elements with system-kbd class + // Assert const escKey = screen.getByText('workflow.onboarding.escTip.key') // ShortcutsName renders a <div> with class system-kbd, not a <kbd> element expect(escKey.closest('.system-kbd')).toBeInTheDocument() diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index e0e0b79b64..7a8acdb35a 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -10,8 +10,7 @@ export const X_OFFSET = 60 export const NODE_WIDTH_X_OFFSET = NODE_WIDTH + X_OFFSET export const Y_OFFSET = 39 export const VIBE_COMMAND_EVENT = 'workflow-vibe-command' -export const VIBE_REGENERATE_EVENT = 'workflow-vibe-regenerate' -export const VIBE_ACCEPT_EVENT = 'workflow-vibe-accept' +export const VIBE_APPLY_EVENT = 'workflow-vibe-apply' export const START_INITIAL_POSITION = { x: 80, y: 282 } export const AUTO_LAYOUT_OFFSET = { x: -42, diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 642179aed7..8a414b3a07 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -160,7 +160,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { } } else { - usedVars = getNodeUsedVars(node).filter(v => v.length > 0) + usedVars = getNodeUsedVars(node).filter(v => v && v.length > 0) } if (node.type === CUSTOM_NODE) { @@ -359,7 +359,7 @@ export const useChecklistBeforePublish = () => { } } else { - usedVars = getNodeUsedVars(node).filter(v => v.length > 0) + usedVars = getNodeUsedVars(node).filter(v => v && v.length > 0) } const checkData = getCheckData(node.data, datasets) const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid) diff --git a/web/app/components/workflow/hooks/use-workflow-search.tsx b/web/app/components/workflow/hooks/use-workflow-search.tsx index 8ca597f94e..8f7b1e59c7 100644 --- a/web/app/components/workflow/hooks/use-workflow-search.tsx +++ b/web/app/components/workflow/hooks/use-workflow-search.tsx @@ -5,7 +5,7 @@ import type { CommonNodeType } from '../types' import type { Emoji } from '@/app/components/tools/types' import { useCallback, useEffect, useMemo } from 'react' import { useNodes } from 'reactflow' -import { workflowNodesAction } from '@/app/components/goto-anything/actions/workflow-nodes' +import { setWorkflowNodesSearchFn } from '@/app/components/goto-anything/actions/workflow-nodes' import { CollectionType } from '@/app/components/tools/types' import BlockIcon from '@/app/components/workflow/block-icon' import { @@ -183,16 +183,15 @@ export const useWorkflowSearch = () => { return results }, [searchableNodes, calculateScore]) - // Directly set the search function on the action object + // Directly set the search function using the setter useEffect(() => { if (searchableNodes.length > 0) { - // Set the search function directly on the action - workflowNodesAction.searchFn = searchWorkflowNodes + setWorkflowNodesSearchFn(searchWorkflowNodes) } return () => { // Clean up when component unmounts - workflowNodesAction.searchFn = undefined + setWorkflowNodesSearchFn(() => []) } }, [searchableNodes, searchWorkflowNodes]) diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 0c6aa7466e..7cf2dca636 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -471,12 +471,14 @@ export const useNodesReadOnly = () => { const workflowRunningData = useStore(s => s.workflowRunningData) const historyWorkflowData = useStore(s => s.historyWorkflowData) const isRestoring = useStore(s => s.isRestoring) + // const showVibePanel = useStore(s => s.showVibePanel) const getNodesReadOnly = useCallback((): boolean => { const { workflowRunningData, historyWorkflowData, isRestoring, + // showVibePanel, } = workflowStore.getState() return !!( diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index 62516a797d..3095fcc676 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -68,6 +68,7 @@ import { useWorkflow, useWorkflowReadOnly, useWorkflowRefreshDraft, + useWorkflowVibe, } from './hooks' import { HooksStoreContextProvider, useHooksStore } from './hooks-store' import { useWorkflowSearch } from './hooks/use-workflow-search' @@ -329,6 +330,7 @@ export const Workflow: FC<WorkflowProps> = memo(({ useShortcuts() // Initialize workflow node search functionality useWorkflowSearch() + useWorkflowVibe() // Set up scroll to node event listener using the utility function useEffect(() => { diff --git a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx index 19dfbcad1c..5a87abae0d 100644 --- a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx +++ b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx @@ -33,9 +33,9 @@ const FileUploadSetting: FC<Props> = ({ const { t } = useTranslation() const { - allowed_file_upload_methods, + allowed_file_upload_methods = [], max_length, - allowed_file_types, + allowed_file_types = [], allowed_file_extensions, } = payload const { data: fileUploadConfigResponse } = useFileUploadConfig() diff --git a/web/app/components/workflow/nodes/_base/components/variable/utils.ts b/web/app/components/workflow/nodes/_base/components/variable/utils.ts index 867221ea31..6b3782662b 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/utils.ts +++ b/web/app/components/workflow/nodes/_base/components/variable/utils.ts @@ -1404,9 +1404,9 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => { payload.url, payload.headers, payload.params, - typeof payload.body.data === 'string' + typeof payload.body?.data === 'string' ? payload.body.data - : payload.body.data.map(d => d.value).join(''), + : (payload.body?.data?.map(d => d.value).join('') ?? ''), ]) break } diff --git a/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts b/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts index 650ae47156..d5a4f3d872 100644 --- a/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts +++ b/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts @@ -5,6 +5,9 @@ import { useCallback, useEffect, useState } from 'react' const UNIQUE_ID_PREFIX = 'key-value-' const strToKeyValueList = (value: string) => { + if (typeof value !== 'string' || !value) + return [] + return value.split('\n').map((item) => { const [key, ...others] = item.split(':') return { @@ -16,7 +19,7 @@ const strToKeyValueList = (value: string) => { } const useKeyValueList = (value: string, onChange: (value: string) => void, noFilter?: boolean) => { - const [list, doSetList] = useState<KeyValue[]>(() => value ? strToKeyValueList(value) : []) + const [list, doSetList] = useState<KeyValue[]>(() => typeof value === 'string' && value ? strToKeyValueList(value) : []) const setList = (l: KeyValue[]) => { doSetList(l.map((item) => { return { diff --git a/web/app/components/workflow/nodes/if-else/components/condition-value.tsx b/web/app/components/workflow/nodes/if-else/components/condition-value.tsx index 6afa708494..00d618bb89 100644 --- a/web/app/components/workflow/nodes/if-else/components/condition-value.tsx +++ b/web/app/components/workflow/nodes/if-else/components/condition-value.tsx @@ -49,7 +49,7 @@ const ConditionValue = ({ if (value === true || value === false) return value ? 'True' : 'False' - return value.replace(/\{\{#([^#]*)#\}\}/g, (a, b) => { + return String(value).replace(/\{\{#([^#]*)#\}\}/g, (a, b) => { const arr: string[] = b.split('.') if (isSystemVar(arr)) return `{{${b}}}` diff --git a/web/app/components/workflow/nodes/knowledge-base/panel.tsx b/web/app/components/workflow/nodes/knowledge-base/panel.tsx index 2845d605bf..0a275645a8 100644 --- a/web/app/components/workflow/nodes/knowledge-base/panel.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/panel.tsx @@ -18,7 +18,6 @@ import { Group, } from '@/app/components/workflow/nodes/_base/components/layout' import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker' -import { IS_CE_EDITION } from '@/config' import Split from '../_base/components/split' import ChunkStructure from './components/chunk-structure' import EmbeddingModel from './components/embedding-model' @@ -173,7 +172,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({ { data.indexing_technique === IndexMethodEnum.QUALIFIED && [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure) - && IS_CE_EDITION && ( + && ( <> <SummaryIndexSetting summaryIndexSetting={data.summary_index_setting} diff --git a/web/app/components/workflow/nodes/tool/use-config.ts b/web/app/components/workflow/nodes/tool/use-config.ts index 87e9186008..7e4594f4f2 100644 --- a/web/app/components/workflow/nodes/tool/use-config.ts +++ b/web/app/components/workflow/nodes/tool/use-config.ts @@ -1,7 +1,6 @@ import type { ToolNodeType, ToolVarInputs } from './types' import type { InputVar } from '@/app/components/workflow/types' import { useBoolean } from 'ahooks' -import { capitalize } from 'es-toolkit/string' import { produce } from 'immer' import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -26,12 +25,6 @@ import { } from '@/service/use-tools' import { canFindTool } from '@/utils' import { useWorkflowStore } from '../../store' -import { normalizeJsonSchemaType } from './output-schema-utils' - -const formatDisplayType = (output: Record<string, unknown>): string => { - const normalizedType = normalizeJsonSchemaType(output) || 'Unknown' - return capitalize(normalizedType) -} const useConfig = (id: string, payload: ToolNodeType) => { const workflowStore = useWorkflowStore() @@ -254,13 +247,20 @@ const useConfig = (id: string, payload: ToolNodeType) => { }) } else { - const normalizedType = normalizeJsonSchemaType(output) res.push({ name: outputKey, type: - normalizedType === 'array' - ? `Array[${output.items ? formatDisplayType(output.items) : 'Unknown'}]` - : formatDisplayType(output), + output.type === 'array' + ? `Array[${output.items?.type + ? output.items.type.slice(0, 1).toLocaleUpperCase() + + output.items.type.slice(1) + : 'Unknown' + }]` + : `${output.type + ? output.type.slice(0, 1).toLocaleUpperCase() + + output.type.slice(1) + : 'Unknown' + }`, description: output.description, }) } diff --git a/web/app/components/workflow/nodes/variable-assigner/components/node-group-item.tsx b/web/app/components/workflow/nodes/variable-assigner/components/node-group-item.tsx index d08a34701b..9e46239987 100644 --- a/web/app/components/workflow/nodes/variable-assigner/components/node-group-item.tsx +++ b/web/app/components/workflow/nodes/variable-assigner/components/node-group-item.tsx @@ -127,23 +127,30 @@ const NodeGroupItem = ({ !!item.variables.length && ( <div className="space-y-0.5"> { - item.variables.map((variable = [], index) => { - const isSystem = isSystemVar(variable) + item.variables + .map((variable = [], index) => { + // Ensure variable is an array + const safeVariable = Array.isArray(variable) ? variable : [] + if (!safeVariable.length) + return null - const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === variable[0]) - const varName = isSystem ? `sys.${variable[variable.length - 1]}` : variable.slice(1).join('.') - const isException = isExceptionVariable(varName, node?.data.type) + const isSystem = isSystemVar(safeVariable) - return ( - <VariableLabelInNode - key={index} - variables={variable} - nodeType={node?.data.type} - nodeTitle={node?.data.title} - isExceptionVariable={isException} - /> - ) - }) + const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === safeVariable[0]) + const varName = isSystem ? `sys.${safeVariable[safeVariable.length - 1]}` : safeVariable.slice(1).join('.') + const isException = isExceptionVariable(varName, node?.data.type) + + return ( + <VariableLabelInNode + key={index} + variables={safeVariable} + nodeType={node?.data.type} + nodeTitle={node?.data.title} + isExceptionVariable={isException} + /> + ) + }) + .filter(Boolean) } </div> ) diff --git a/web/app/components/workflow/panel/index.tsx b/web/app/components/workflow/panel/index.tsx index 88ada8b11e..8b7ebfda63 100644 --- a/web/app/components/workflow/panel/index.tsx +++ b/web/app/components/workflow/panel/index.tsx @@ -8,6 +8,7 @@ import { cn } from '@/utils/classnames' import { Panel as NodePanel } from '../nodes' import { useStore } from '../store' import EnvPanel from './env-panel' +import VibePanel from './vibe-panel' const VersionHistoryPanel = dynamic(() => import('@/app/components/workflow/panel/version-history-panel'), { ssr: false, @@ -85,6 +86,7 @@ const Panel: FC<PanelProps> = ({ const showEnvPanel = useStore(s => s.showEnvPanel) const isRestoring = useStore(s => s.isRestoring) const showWorkflowVersionHistoryPanel = useStore(s => s.showWorkflowVersionHistoryPanel) + const showVibePanel = useStore(s => s.showVibePanel) // widths used for adaptive layout const workflowCanvasWidth = useStore(s => s.workflowCanvasWidth) @@ -124,33 +126,36 @@ const Panel: FC<PanelProps> = ({ ) return ( - <div - ref={rightPanelRef} - tabIndex={-1} - className={cn('absolute bottom-1 right-0 top-14 z-10 flex outline-none')} - key={`${isRestoring}`} - > - {components?.left} - {!!selectedNode && <NodePanel {...selectedNode} />} + <> <div - className="relative" - ref={otherPanelRef} + ref={rightPanelRef} + tabIndex={-1} + className={cn('absolute bottom-1 right-0 top-14 z-10 flex outline-none')} + key={`${isRestoring}`} > - { - components?.right - } - { - showWorkflowVersionHistoryPanel && ( - <VersionHistoryPanel {...versionHistoryPanelProps} /> - ) - } - { - showEnvPanel && ( - <EnvPanel /> - ) - } + {components?.left} + {!!selectedNode && <NodePanel {...selectedNode} />} + <div + className="relative" + ref={otherPanelRef} + > + { + components?.right + } + { + showWorkflowVersionHistoryPanel && ( + <VersionHistoryPanel {...versionHistoryPanelProps} /> + ) + } + { + showEnvPanel && ( + <EnvPanel /> + ) + } + </div> </div> - </div> + {showVibePanel && <VibePanel />} + </> ) } diff --git a/web/app/components/workflow/store/workflow/index.ts b/web/app/components/workflow/store/workflow/index.ts index 923c9c8c72..e9416e6e1b 100644 --- a/web/app/components/workflow/store/workflow/index.ts +++ b/web/app/components/workflow/store/workflow/index.ts @@ -11,8 +11,8 @@ import type { LayoutSliceShape } from './layout-slice' import type { NodeSliceShape } from './node-slice' import type { PanelSliceShape } from './panel-slice' import type { ToolSliceShape } from './tool-slice' -import type { VibeWorkflowSliceShape } from './vibe-workflow-slice' import type { VersionSliceShape } from './version-slice' +import type { VibeWorkflowSliceShape } from './vibe-workflow-slice' import type { WorkflowDraftSliceShape } from './workflow-draft-slice' import type { WorkflowSliceShape } from './workflow-slice' import type { RagPipelineSliceShape } from '@/app/components/rag-pipeline/store' @@ -34,8 +34,8 @@ import { createNodeSlice } from './node-slice' import { createPanelSlice } from './panel-slice' import { createToolSlice } from './tool-slice' -import { createVibeWorkflowSlice } from './vibe-workflow-slice' import { createVersionSlice } from './version-slice' +import { createVibeWorkflowSlice } from './vibe-workflow-slice' import { createWorkflowDraftSlice } from './workflow-draft-slice' import { createWorkflowSlice } from './workflow-slice' @@ -57,8 +57,8 @@ export type Shape & WorkflowSliceShape & InspectVarsSliceShape & LayoutSliceShape - & VibeWorkflowSliceShape & SliceFromInjection + & VibeWorkflowSliceShape export type InjectWorkflowStoreSliceFn = StateCreator<SliceFromInjection> diff --git a/web/app/components/workflow/store/workflow/panel-slice.ts b/web/app/components/workflow/store/workflow/panel-slice.ts index 4848beeac5..00a3112857 100644 --- a/web/app/components/workflow/store/workflow/panel-slice.ts +++ b/web/app/components/workflow/store/workflow/panel-slice.ts @@ -1,4 +1,7 @@ import type { StateCreator } from 'zustand' +import type { BackendEdgeSpec, BackendNodeSpec } from '@/service/debug' + +export type VibeIntent = 'generate' | 'off_topic' | 'error' | '' export type PanelSliceShape = { panelWidth: number @@ -24,6 +27,26 @@ export type PanelSliceShape = { setShowVariableInspectPanel: (showVariableInspectPanel: boolean) => void initShowLastRunTab: boolean setInitShowLastRunTab: (initShowLastRunTab: boolean) => void + showVibePanel: boolean + setShowVibePanel: (showVibePanel: boolean) => void + vibePanelMermaidCode: string + setVibePanelMermaidCode: (vibePanelMermaidCode: string) => void + vibePanelBackendNodes?: BackendNodeSpec[] + setVibePanelBackendNodes: (nodes?: BackendNodeSpec[]) => void + vibePanelBackendEdges?: BackendEdgeSpec[] + setVibePanelBackendEdges: (edges?: BackendEdgeSpec[]) => void + isVibeGenerating: boolean + setIsVibeGenerating: (isVibeGenerating: boolean) => void + vibePanelInstruction: string + setVibePanelInstruction: (vibePanelInstruction: string) => void + vibePanelIntent: VibeIntent + setVibePanelIntent: (vibePanelIntent: VibeIntent) => void + vibePanelMessage: string + setVibePanelMessage: (vibePanelMessage: string) => void + vibePanelSuggestions: string[] + setVibePanelSuggestions: (vibePanelSuggestions: string[]) => void + vibePanelLastWarnings: string[] + setVibePanelLastWarnings: (vibePanelLastWarnings: string[]) => void } export const createPanelSlice: StateCreator<PanelSliceShape> = set => ({ @@ -44,4 +67,24 @@ export const createPanelSlice: StateCreator<PanelSliceShape> = set => ({ setShowVariableInspectPanel: showVariableInspectPanel => set(() => ({ showVariableInspectPanel })), initShowLastRunTab: false, setInitShowLastRunTab: initShowLastRunTab => set(() => ({ initShowLastRunTab })), + showVibePanel: false, + setShowVibePanel: showVibePanel => set(() => ({ showVibePanel })), + vibePanelMermaidCode: '', + setVibePanelMermaidCode: vibePanelMermaidCode => set(() => ({ vibePanelMermaidCode })), + vibePanelBackendNodes: undefined, + setVibePanelBackendNodes: vibePanelBackendNodes => set(() => ({ vibePanelBackendNodes })), + vibePanelBackendEdges: undefined, + setVibePanelBackendEdges: vibePanelBackendEdges => set(() => ({ vibePanelBackendEdges })), + isVibeGenerating: false, + setIsVibeGenerating: isVibeGenerating => set(() => ({ isVibeGenerating })), + vibePanelInstruction: '', + setVibePanelInstruction: vibePanelInstruction => set(() => ({ vibePanelInstruction })), + vibePanelIntent: '', + setVibePanelIntent: vibePanelIntent => set(() => ({ vibePanelIntent })), + vibePanelMessage: '', + setVibePanelMessage: vibePanelMessage => set(() => ({ vibePanelMessage })), + vibePanelSuggestions: [], + setVibePanelSuggestions: vibePanelSuggestions => set(() => ({ vibePanelSuggestions })), + vibePanelLastWarnings: [], + setVibePanelLastWarnings: vibePanelLastWarnings => set(() => ({ vibePanelLastWarnings })), }) diff --git a/web/app/components/workflow/utils/workflow-init.ts b/web/app/components/workflow/utils/workflow-init.ts index 77a2ccefac..a282f5109c 100644 --- a/web/app/components/workflow/utils/workflow-init.ts +++ b/web/app/components/workflow/utils/workflow-init.ts @@ -111,8 +111,8 @@ export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => { const currentNode = nodes[i] as Node<IterationNodeType | LoopNodeType> if (currentNode.data.type === BlockEnum.Iteration) { - if (currentNode.data.start_node_id) { - if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_ITERATION_START_NODE) + if (currentNode.data.start_node_id && nodesMap[currentNode.data.start_node_id]) { + if (nodesMap[currentNode.data.start_node_id].type !== CUSTOM_ITERATION_START_NODE) iterationNodesWithStartNode.push(currentNode) } else { @@ -121,8 +121,8 @@ export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => { } if (currentNode.data.type === BlockEnum.Loop) { - if (currentNode.data.start_node_id) { - if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_LOOP_START_NODE) + if (currentNode.data.start_node_id && nodesMap[currentNode.data.start_node_id]) { + if (nodesMap[currentNode.data.start_node_id].type !== CUSTOM_LOOP_START_NODE) loopNodesWithStartNode.push(currentNode) } else { diff --git a/web/app/components/workflow/workflow-preview/components/nodes/if-else/node.tsx b/web/app/components/workflow/workflow-preview/components/nodes/if-else/node.tsx index 1e79244a3f..e9af59a0a2 100644 --- a/web/app/components/workflow/workflow-preview/components/nodes/if-else/node.tsx +++ b/web/app/components/workflow/workflow-preview/components/nodes/if-else/node.tsx @@ -65,7 +65,7 @@ const IfElseNode: FC<NodeProps<IfElseNodeType>> = (props) => { </div> <div className="space-y-0.5"> {caseItem.conditions.map((condition, i) => ( - <div key={condition.id} className="relative"> + <div key={condition.id || i} className="relative"> { checkIsConditionSet(condition) ? ( diff --git a/web/app/components/workflow/workflow-preview/index.tsx b/web/app/components/workflow/workflow-preview/index.tsx index bb85e00b6b..9e6229f4c4 100644 --- a/web/app/components/workflow/workflow-preview/index.tsx +++ b/web/app/components/workflow/workflow-preview/index.tsx @@ -2,6 +2,7 @@ import type { EdgeChange, + FitViewOptions, NodeChange, Viewport, } from 'reactflow' @@ -59,8 +60,10 @@ const edgeTypes = { type WorkflowPreviewProps = { nodes: Node[] edges: Edge[] - viewport: Viewport + viewport?: Viewport className?: string + fitView?: boolean + fitViewOptions?: FitViewOptions miniMapToRight?: boolean } const WorkflowPreview = ({ @@ -68,6 +71,8 @@ const WorkflowPreview = ({ edges, viewport, className, + fitView, + fitViewOptions, miniMapToRight, }: WorkflowPreviewProps) => { const [nodesData, setNodesData] = useState(() => initialNodes(nodes, edges)) @@ -125,6 +130,8 @@ const WorkflowPreview = ({ selectionKeyCode={null} selectionMode={SelectionMode.Partial} minZoom={0.25} + fitView={fitView} + fitViewOptions={fitViewOptions} > <Background gap={[14, 14]} diff --git a/web/contract/console/goto-anything.ts b/web/contract/console/goto-anything.ts new file mode 100644 index 0000000000..35f61cf518 --- /dev/null +++ b/web/contract/console/goto-anything.ts @@ -0,0 +1,87 @@ +import type { AppListResponse } from '@/models/app' +import type { DataSetListResponse } from '@/models/datasets' +import type { BackendEdgeSpec, BackendNodeSpec, FlowchartGenRes } from '@/service/debug' +import { type } from '@orpc/contract' +import { base } from '../base' + +// Search APIs +export const searchAppsContract = base + .route({ + path: '/apps', + method: 'GET', + }) + .input(type<{ + query?: { + page?: number + limit?: number + name?: string + } + }>()) + .output(type<AppListResponse>()) + +export const searchDatasetsContract = base + .route({ + path: '/datasets', + method: 'GET', + }) + .input(type<{ + query?: { + page?: number + limit?: number + keyword?: string + } + }>()) + .output(type<DataSetListResponse>()) + +// Vibe Workflow API +export type GenerateFlowchartInput = { + instruction: string + model_config: { + provider: string + name: string + mode: string + completion_params: Record<string, unknown> + } | null + available_nodes: Array<{ + type: string + title?: string + description?: string + }> + existing_nodes?: Array<{ + id: string + type: string + title?: string + }> + existing_edges?: BackendEdgeSpec[] + available_tools: Array<{ + provider_id: string + provider_name?: string + provider_type?: string + tool_name: string + tool_label?: string + tool_key?: string + tool_description?: string + }> + selected_node_ids?: string[] + previous_workflow?: { + nodes: BackendNodeSpec[] + edges: BackendEdgeSpec[] + warnings?: string[] + } + regenerate_mode?: boolean + language: string + available_models?: Array<{ + provider: string + model: string + }> +} + +export const generateFlowchartContract = base + .route({ + path: '/flowchart-generate', + method: 'POST', + }) + .input(type<{ + body: GenerateFlowchartInput + }>()) + .output(type<FlowchartGenRes>()) diff --git a/web/contract/router.ts b/web/contract/router.ts index 33499b106f..dbb6f48a93 100644 --- a/web/contract/router.ts +++ b/web/contract/router.ts @@ -1,5 +1,6 @@ import type { InferContractRouterInputs } from '@orpc/contract' import { bindPartnerStackContract, invoicesContract } from './console/billing' +import { generateFlowchartContract, searchAppsContract, searchDatasetsContract } from './console/goto-anything' import { systemFeaturesContract } from './console/system' import { triggerOAuthConfigContract, @@ -58,6 +59,11 @@ export const consoleRouterContract = { oauthDelete: triggerOAuthDeleteContract, oauthInitiate: triggerOAuthInitiateContract, }, + gotoAnything: { + searchApps: searchAppsContract, + searchDatasets: searchDatasetsContract, + generateFlowchart: generateFlowchartContract, + }, } export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract> diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index f1e7af211d..e127d3a21d 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -269,6 +269,9 @@ } }, "app/components/app/app-publisher/index.tsx": { + "tailwindcss/no-unnecessary-whitespace": { + "count": 1 + }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 }, @@ -3204,6 +3207,11 @@ "count": 1 } }, + "app/components/share/text-generation/result/header.tsx": { + "tailwindcss/no-unnecessary-whitespace": { + "count": 3 + } + }, "app/components/share/text-generation/result/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 3 @@ -3643,6 +3651,11 @@ "count": 1 } }, + "app/components/workflow/nodes/_base/components/before-run-form/panel-wrap.tsx": { + "tailwindcss/no-unnecessary-whitespace": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -5447,4 +5460,4 @@ "count": 2 } } -} \ No newline at end of file +} diff --git a/web/i18n/en-US/app.json b/web/i18n/en-US/app.json index e4109db4b6..fcc220448e 100644 --- a/web/i18n/en-US/app.json +++ b/web/i18n/en-US/app.json @@ -75,6 +75,9 @@ "gotoAnything.actions.themeLightDesc": "Use light appearance", "gotoAnything.actions.themeSystem": "System Theme", "gotoAnything.actions.themeSystemDesc": "Follow your OS appearance", + "gotoAnything.actions.vibeDesc": "Generate workflow from natural language", + "gotoAnything.actions.vibeHint": "Try: {{prompt}}", + "gotoAnything.actions.vibeTitle": "Vibe", "gotoAnything.actions.zenDesc": "Toggle canvas focus mode", "gotoAnything.actions.zenTitle": "Zen Mode", "gotoAnything.clearToSearchAll": "Clear @ to search all", diff --git a/web/i18n/en-US/workflow.json b/web/i18n/en-US/workflow.json index 4d9f5adbac..f55f73da8f 100644 --- a/web/i18n/en-US/workflow.json +++ b/web/i18n/en-US/workflow.json @@ -1150,5 +1150,26 @@ "versionHistory.nameThisVersion": "Name this version", "versionHistory.releaseNotesPlaceholder": "Describe what changed", "versionHistory.restorationTip": "After version restoration, the current draft will be overwritten.", - "versionHistory.title": "Versions" + "versionHistory.title": "Versions", + "vibe.apply": "Apply", + "vibe.generateError": "Failed to generate workflow. Please try again.", + "vibe.generatingFlowchart": "Generating flowchart preview...", + "vibe.invalidFlowchart": "The generated flowchart could not be parsed.", + "vibe.missingFlowchart": "No flowchart was generated.", + "vibe.missingInstruction": "Describe the workflow you want to build.", + "vibe.modelUnavailable": "No model available for flowchart generation.", + "vibe.noFlowchart": "No flowchart provided", + "vibe.noFlowchartYet": "No flowchart preview available", + "vibe.nodeTypeUnavailable": "Node type \"{{type}}\" is not available in this workflow.", + "vibe.nodesUnavailable": "Workflow nodes are not available yet.", + "vibe.offTopicDefault": "I'm the Dify workflow design assistant. I can help you create AI automation workflows, but I can't answer general questions. Would you like to create a workflow instead?", + "vibe.offTopicTitle": "Off-Topic Request", + "vibe.panelTitle": "Workflow Preview", + "vibe.readOnly": "This workflow is read-only.", + "vibe.regenerate": "Regenerate", + "vibe.regenerateReminder": "Please verify your input and re-generate.", + "vibe.toolUnavailable": "Tool \"{{tool}}\" is not available in this workspace.", + "vibe.trySuggestion": "Try one of these suggestions:", + "vibe.unknownNodeId": "Node \"{{id}}\" is used before it is defined.", + "vibe.unsupportedEdgeLabel": "Unsupported edge label \"{{label}}\". Only true/false are allowed for if/else." } diff --git a/web/i18n/zh-Hans/workflow.json b/web/i18n/zh-Hans/workflow.json index acda7db2fc..40c4139aa6 100644 --- a/web/i18n/zh-Hans/workflow.json +++ b/web/i18n/zh-Hans/workflow.json @@ -1150,5 +1150,6 @@ "versionHistory.nameThisVersion": "命名", "versionHistory.releaseNotesPlaceholder": "请描述变更", "versionHistory.restorationTip": "版本回滚后,当前草稿将被覆盖。", - "versionHistory.title": "版本" + "versionHistory.title": "版本", + "vibe.regenerateReminder": "请检查输入并重新生成。" } diff --git a/web/package.json b/web/package.json index 2e44328f53..e617d3eb3f 100644 --- a/web/package.json +++ b/web/package.json @@ -236,7 +236,8 @@ "vite": "7.3.1", "vite-tsconfig-paths": "6.0.4", "vitest": "4.0.17", - "vitest-canvas-mock": "1.1.3" + "vitest-canvas-mock": "1.1.3", + "vitest-tiny-reporter": "1.3.1" }, "pnpm": { "overrides": { diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 169428bfbd..7024688ade 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -585,6 +585,9 @@ importers: vitest-canvas-mock: specifier: 1.1.3 version: 1.1.3(vitest@4.0.17) + vitest-tiny-reporter: + specifier: 1.3.1 + version: 1.3.1(@vitest/runner@4.0.17)(vitest@4.0.17) packages: @@ -7291,6 +7294,12 @@ packages: peerDependencies: vitest: ^3.0.0 || ^4.0.0 + vitest-tiny-reporter@1.3.1: + resolution: {integrity: sha512-9WfLruQBbxm4EqMIS0jDZmQjvMgsWgHUso9mHQWgjA6hM3tEVhjdG8wYo7ePFh1XbwEFzEo3XUQqkGoKZ/Td2Q==} + peerDependencies: + '@vitest/runner': ^2.0.0 || ^3.0.2 || ^4.0.0 + vitest: ^2.0.0 || ^3.0.2 || ^4.0.0 + vitest@4.0.17: resolution: {integrity: sha512-FQMeF0DJdWY0iOnbv466n/0BudNdKj1l5jYgl5JVTwjSsZSlqyXFt/9+1sEyhR6CLowbZpV7O1sCHrzBhucKKg==} engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} @@ -15342,6 +15351,12 @@ snapshots: moo-color: 1.0.3 vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vitest-tiny-reporter@1.3.1(@vitest/runner@4.0.17)(vitest@4.0.17): + dependencies: + '@vitest/runner': 4.0.17 + tinyrainbow: 3.0.3 + vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + vitest@4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2): dependencies: '@vitest/expect': 4.0.17 diff --git a/web/service/debug.ts b/web/service/debug.ts index 9f11643e7f..9e4377435e 100644 --- a/web/service/debug.ts +++ b/web/service/debug.ts @@ -19,6 +19,48 @@ export type GenRes = { error?: string } +export type ToolRecommendation = { + requested_capability: string + unconfigured_tools: Array<{ + provider_id: string + tool_name: string + description: string + }> + configured_alternatives: Array<{ + provider_id: string + tool_name: string + description: string + }> + recommendation: string +} + +export type BackendNodeSpec = { + id: string + type: string + title?: string + config?: Record<string, any> + position?: { x: number, y: number } +} + +export type BackendEdgeSpec = { + source: string + target: string + sourceHandle?: string + targetHandle?: string +} + +export type FlowchartGenRes = { + intent?: 'generate' | 'off_topic' | 'error' + flowchart: string + nodes?: BackendNodeSpec[] + edges?: BackendEdgeSpec[] + message?: string + warnings?: string[] + suggestions?: string[] + tool_recommendations?: ToolRecommendation[] + error?: string +} + export type CodeGenRes = { code: string language: string[] @@ -75,6 +117,12 @@ export const generateRule = (body: Record<string, any>) => { }) } +export const generateFlowchart = (body: Record<string, any>) => { + return post<FlowchartGenRes>('/flowchart-generate', { + body, + }) +} + export const fetchModelParams = (providerName: string, modelId: string) => { return get(`workspaces/current/model-providers/${providerName}/models/parameter-rules`, { params: { diff --git a/web/service/use-goto-anything.ts b/web/service/use-goto-anything.ts new file mode 100644 index 0000000000..a6ca833c25 --- /dev/null +++ b/web/service/use-goto-anything.ts @@ -0,0 +1,50 @@ +import type { GenerateFlowchartInput } from '@/contract/console/goto-anything' +import { consoleClient, consoleQuery, marketplaceClient, marketplaceQuery } from '@/service/client' + +// Search APIs +export const searchAppsQueryKey = consoleQuery.gotoAnything.searchApps.queryKey + +export const searchApps = async (name?: string) => { + return consoleClient.gotoAnything.searchApps({ + query: { + page: 1, + name, + }, + }) +} + +export const searchDatasetsQueryKey = consoleQuery.gotoAnything.searchDatasets.queryKey + +export const searchDatasets = async (keyword?: string) => { + return consoleClient.gotoAnything.searchDatasets({ + query: { + page: 1, + limit: 10, + keyword, + }, + }) +} + +export const searchPluginsQueryKey = marketplaceQuery.searchAdvanced.queryKey + +export const searchPlugins = async (query?: string) => { + return marketplaceClient.searchAdvanced({ + params: { + kind: 'plugins', + }, + body: { + query: query || '', + page: 1, + page_size: 10, + }, + }) +} + +// Vibe Workflow API +export const generateFlowchartMutationKey = consoleQuery.gotoAnything.generateFlowchart.mutationKey + +export const generateFlowchart = async (input: GenerateFlowchartInput) => { + return consoleClient.gotoAnything.generateFlowchart({ + body: input, + }) +}