From cd030d82e57550860e350140292bade050cf25b9 Mon Sep 17 00:00:00 2001 From: aqiu <819110812@qq.com> Date: Sat, 27 Dec 2025 15:06:44 +0800 Subject: [PATCH] refactor(vibe): extract workflow generation to dedicated module - Move workflow generation logic from LLMGenerator to WorkflowGenerator - Extract to api/core/workflow/generator/ with modular architecture - Implement Planner-Builder pattern for better separation of concerns - Add validation engine with rule-based error classification - Add node and edge repair utilities for auto-fixing common issues - Add deterministic Mermaid generator for consistent output - Reorganize configuration and prompts - Move vibe_config/ to generator/config/ - Move vibe_prompts.py to generator/prompts/ (split into multiple files) - Add builder_prompts.py and planner_prompts.py for new architecture - Enhance frontend workflow handling - Use standard node initialization for proper node setup - Improve variable reference replacement with better error handling - Add model fallback logic for better compatibility - Handle end node outputs format (value_selector vs legacy format) - Ensure parameter-extractor nodes have required 'required' field - Add comprehensive test coverage - Unit tests for mermaid generator, node repair, edge repair - Tests for validation engine and rule system - Tests for planner prompts formatting - Frontend tests for variable reference replacement - Add max_fix_iterations parameter for validate-fix loop configuration # Conflicts: # web/app/components/workflow/hooks/use-workflow-vibe.tsx --- api/controllers/console/app/generator.py | 20 +- api/core/llm_generator/llm_generator.py | 180 +-- api/core/workflow/generator/__init__.py | 1 + .../generator/config}/__init__.py | 6 +- .../generator/config}/fallback_rules.py | 0 .../generator/config}/node_definitions.json | 0 .../generator/config}/node_schemas.py | 88 +- .../generator/config}/responses.py | 0 .../workflow/generator/prompts/__init__.py | 0 .../generator/prompts/builder_prompts.py | 343 +++++ .../generator/prompts/planner_prompts.py | 75 ++ .../generator/prompts}/vibe_prompts.py | 243 +++- api/core/workflow/generator/runner.py | 194 +++ .../workflow/generator/utils/edge_repair.py | 372 ++++++ .../generator/utils/mermaid_generator.py | 138 ++ .../workflow/generator/utils/node_repair.py | 96 ++ .../generator/utils/workflow_validator.py | 96 ++ .../workflow/generator/validation/__init__.py | 45 + .../workflow/generator/validation/context.py | 123 ++ .../workflow/generator/validation/engine.py | 266 ++++ .../workflow/generator/validation/rules.py | 1148 +++++++++++++++++ .../llm_generator/test_mermaid_generator.py | 288 +++++ .../core/llm_generator/test_node_repair.py | 81 ++ .../llm_generator/test_planner_prompts.py | 173 +++ .../llm_generator/test_validation_engine.py | 536 ++++++++ .../test_workflow_validator_vibe.py | 435 +++++++ .../hooks/__tests__/use-workflow-vibe.test.ts | 82 ++ .../workflow/hooks/use-workflow-vibe.tsx | 284 ++-- 28 files changed, 5046 insertions(+), 267 deletions(-) create mode 100644 api/core/workflow/generator/__init__.py rename api/core/{llm_generator/vibe_config => workflow/generator/config}/__init__.py (74%) rename api/core/{llm_generator/vibe_config => workflow/generator/config}/fallback_rules.py (100%) rename api/core/{llm_generator/vibe_config => workflow/generator/config}/node_definitions.json (100%) rename api/core/{llm_generator/vibe_config => workflow/generator/config}/node_schemas.py (66%) rename api/core/{llm_generator/vibe_config => 
workflow/generator/config}/responses.py (100%) create mode 100644 api/core/workflow/generator/prompts/__init__.py create mode 100644 api/core/workflow/generator/prompts/builder_prompts.py create mode 100644 api/core/workflow/generator/prompts/planner_prompts.py rename api/core/{llm_generator => workflow/generator/prompts}/vibe_prompts.py (80%) create mode 100644 api/core/workflow/generator/runner.py create mode 100644 api/core/workflow/generator/utils/edge_repair.py create mode 100644 api/core/workflow/generator/utils/mermaid_generator.py create mode 100644 api/core/workflow/generator/utils/node_repair.py create mode 100644 api/core/workflow/generator/utils/workflow_validator.py create mode 100644 api/core/workflow/generator/validation/__init__.py create mode 100644 api/core/workflow/generator/validation/context.py create mode 100644 api/core/workflow/generator/validation/engine.py create mode 100644 api/core/workflow/generator/validation/rules.py create mode 100644 api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py create mode 100644 api/tests/unit_tests/core/llm_generator/test_node_repair.py create mode 100644 api/tests/unit_tests/core/llm_generator/test_planner_prompts.py create mode 100644 api/tests/unit_tests/core/llm_generator/test_validation_engine.py create mode 100644 api/tests/unit_tests/core/llm_generator/test_workflow_validator_vibe.py create mode 100644 web/app/components/workflow/hooks/__tests__/use-workflow-vibe.test.ts diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 2f1b3a0db4..eb9f12752b 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,9 +1,13 @@ +import logging from collections.abc import Sequence -from typing import Any +from typing import Any, cast from flask_restx import Resource from pydantic import BaseModel, Field +logger = logging.getLogger(__name__) + + from controllers.console import console_ns from controllers.console.app.error import ( CompletionRequestError, @@ -18,6 +22,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError +from core.workflow.generator import WorkflowGenerator from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App @@ -77,6 +82,13 @@ class FlowchartGeneratePayload(BaseModel): language: str | None = Field(default=None, description="Preferred language for generated content") # Available models that user has configured (for LLM/question-classifier nodes) available_models: list[dict[str, Any]] = Field(default_factory=list, description="User's configured models") + # Validate-fix iteration loop configuration + max_fix_iterations: int = Field( + default=2, + ge=0, + le=5, + description="Maximum number of validate-fix iterations (0 to disable auto-fix)", + ) def reg(cls: type[BaseModel]): @@ -305,7 +317,7 @@ class FlowchartGenerateApi(Resource): "warnings": args.previous_workflow.warnings, } - result = LLMGenerator.generate_workflow_flowchart( + result = WorkflowGenerator.generate_workflow_flowchart( tenant_id=current_tenant_id, instruction=args.instruction, model_config=args.model_config_data, @@ -313,11 +325,13 @@ class FlowchartGenerateApi(Resource): existing_nodes=args.existing_nodes, 
available_tools=args.available_tools, selected_node_ids=args.selected_node_ids, - previous_workflow=previous_workflow_dict, + previous_workflow=cast(dict[str, object], previous_workflow_dict), regenerate_mode=args.regenerate_mode, preferred_language=args.language, available_models=args.available_models, + max_fix_iterations=args.max_fix_iterations, ) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 21b4ed38f4..7a5c1d550c 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,6 +1,5 @@ import json import logging -import re from collections.abc import Sequence from typing import Protocol, cast @@ -12,8 +11,6 @@ from core.llm_generator.prompts import ( CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, - LLM_MODIFY_CODE_SYSTEM, - LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, SUGGESTED_QUESTIONS_MAX_TOKENS, SUGGESTED_QUESTIONS_TEMPERATURE, @@ -30,6 +27,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.generator import WorkflowGenerator from extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel @@ -299,178 +297,24 @@ class LLMGenerator: regenerate_mode: bool = False, preferred_language: str | None = None, available_models: Sequence[dict[str, object]] | None = None, + max_fix_iterations: int = 2, ): - """ - Generate workflow flowchart with enhanced prompts and inline intent classification. 
- Returns a dict with: - - intent: "generate" | "off_topic" | "error" - - flowchart: Mermaid syntax string (for generate intent) - - message: User-friendly explanation - - warnings: List of validation warnings - - suggestions: List of workflow suggestions (for off_topic intent) - - error: Error message if generation failed - """ - from core.llm_generator.vibe_prompts import ( - build_vibe_enhanced_prompt, - extract_mermaid_from_response, - parse_vibe_response, - sanitize_tool_nodes, - validate_node_parameters, - validate_tool_references, - ) - - model_parameters = model_config.get("completion_params", {}) - - # Build enhanced prompts with context - system_prompt, user_prompt = build_vibe_enhanced_prompt( + return WorkflowGenerator.generate_workflow_flowchart( + tenant_id=tenant_id, instruction=instruction, - available_nodes=list(available_nodes) if available_nodes else None, - available_tools=list(available_tools) if available_tools else None, - existing_nodes=list(existing_nodes) if existing_nodes else None, - selected_node_ids=list(selected_node_ids) if selected_node_ids else None, - previous_workflow=dict(previous_workflow) if previous_workflow else None, + model_config=model_config, + available_nodes=available_nodes, + existing_nodes=existing_nodes, + available_tools=available_tools, + selected_node_ids=selected_node_ids, + previous_workflow=previous_workflow, regenerate_mode=regenerate_mode, preferred_language=preferred_language, - available_models=list(available_models) if available_models else None, + available_models=available_models, + max_fix_iterations=max_fix_iterations, ) - prompt_messages: list[PromptMessage] = [ - SystemPromptMessage(content=system_prompt), - UserPromptMessage(content=user_prompt), - ] - - # DEBUG: Log model input - logger.debug("=" * 80) - logger.debug("[VIBE] generate_workflow_flowchart - MODEL INPUT") - logger.debug("=" * 80) - logger.debug("[VIBE] Instruction: %s", instruction) - logger.debug("[VIBE] Model: %s/%s", model_config.get("provider", ""), model_config.get("name", "")) - system_prompt_log = system_prompt[:2000] + "..." 
if len(system_prompt) > 2000 else system_prompt - logger.debug("[VIBE] System Prompt:\n%s", system_prompt_log) - logger.debug("[VIBE] User Prompt:\n%s", user_prompt) - logger.debug("=" * 80) - - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), - ) - - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), - model_parameters=model_parameters, - stream=False, - ) - content = response.message.get_text_content() - - # DEBUG: Log model output - logger.debug("=" * 80) - logger.debug("[VIBE] generate_workflow_flowchart - MODEL OUTPUT") - logger.debug("=" * 80) - logger.debug("[VIBE] Raw Response:\n%s", content) - logger.debug("=" * 80) - if not isinstance(content, str): - raise ValueError("Flowchart response is not a string") - - # Parse the enhanced response format - parsed = parse_vibe_response(content) - - # DEBUG: Log parsed result - logger.debug("[VIBE] Parsed Response:") - logger.debug("[VIBE] intent: %s", parsed.get("intent")) - logger.debug("[VIBE] message: %s", parsed.get("message", "")[:200] if parsed.get("message") else "") - logger.debug("[VIBE] mermaid: %s", parsed.get("mermaid", "")[:500] if parsed.get("mermaid") else "") - logger.debug("[VIBE] warnings: %s", parsed.get("warnings", [])) - logger.debug("[VIBE] suggestions: %s", parsed.get("suggestions", [])) - if parsed.get("error"): - logger.debug("[VIBE] error: %s", parsed.get("error")) - logger.debug("=" * 80) - - # Handle error case from parsing - if parsed.get("intent") == "error": - # Fall back to legacy parsing for backwards compatibility - match = re.search(r"```(?:mermaid)?\s*([\s\S]+?)```", content, flags=re.IGNORECASE) - flowchart = (match.group(1) if match else content).strip() - return { - "intent": "generate", - "flowchart": flowchart, - "message": "", - "warnings": [], - "tool_recommendations": [], - "error": "", - } - - # Handle off_topic case - if parsed.get("intent") == "off_topic": - return { - "intent": "off_topic", - "flowchart": "", - "message": parsed.get("message", ""), - "suggestions": parsed.get("suggestions", []), - "warnings": [], - "tool_recommendations": [], - "error": "", - } - - # Handle generate case - flowchart = extract_mermaid_from_response(parsed) - - # Sanitize tool nodes - replace invalid tools with fallback nodes - original_nodes = parsed.get("nodes", []) - sanitized_nodes, sanitize_warnings = sanitize_tool_nodes( - original_nodes, - list(available_tools) if available_tools else None, - ) - # Update parsed nodes with sanitized version - parsed["nodes"] = sanitized_nodes - - # Validate tool references and get recommendations for unconfigured tools - validation_warnings, tool_recommendations = validate_tool_references( - sanitized_nodes, - list(available_tools) if available_tools else None, - ) - - # Validate node parameters are properly filled (Phase 9: Auto-Fill) - param_warnings = validate_node_parameters(sanitized_nodes) - - existing_warnings = parsed.get("warnings", []) - all_warnings = existing_warnings + sanitize_warnings + validation_warnings + param_warnings - - return { - "intent": "generate", - "flowchart": flowchart, - "nodes": sanitized_nodes, # Include sanitized nodes in response - "edges": parsed.get("edges", []), - "message": parsed.get("message", ""), - "warnings": all_warnings, - "tool_recommendations": tool_recommendations, - "error": "", - } - - except InvokeError as e: 
- return { - "intent": "error", - "flowchart": "", - "message": "", - "warnings": [], - "tool_recommendations": [], - "error": str(e), - } - except Exception as e: - logger.exception("Failed to generate workflow flowchart, model: %s", model_config.get("name")) - return { - "intent": "error", - "flowchart": "", - "message": "", - "warnings": [], - "tool_recommendations": [], - "error": str(e), - } - @classmethod def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): if code_language == "python": diff --git a/api/core/workflow/generator/__init__.py b/api/core/workflow/generator/__init__.py new file mode 100644 index 0000000000..2b722441a9 --- /dev/null +++ b/api/core/workflow/generator/__init__.py @@ -0,0 +1 @@ +from .runner import WorkflowGenerator diff --git a/api/core/llm_generator/vibe_config/__init__.py b/api/core/workflow/generator/config/__init__.py similarity index 74% rename from api/core/llm_generator/vibe_config/__init__.py rename to api/core/workflow/generator/config/__init__.py index 85e8a87d18..19e321b083 100644 --- a/api/core/llm_generator/vibe_config/__init__.py +++ b/api/core/workflow/generator/config/__init__.py @@ -5,14 +5,14 @@ This module centralizes configuration for the Vibe workflow generation feature, including node schemas, fallback rules, and response templates. """ -from core.llm_generator.vibe_config.fallback_rules import ( +from core.workflow.generator.config.fallback_rules import ( FALLBACK_RULES, FIELD_NAME_CORRECTIONS, NODE_TYPE_ALIASES, get_corrected_field_name, ) -from core.llm_generator.vibe_config.node_schemas import BUILTIN_NODE_SCHEMAS -from core.llm_generator.vibe_config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES +from core.workflow.generator.config.node_schemas import BUILTIN_NODE_SCHEMAS +from core.workflow.generator.config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES __all__ = [ "BUILTIN_NODE_SCHEMAS", diff --git a/api/core/llm_generator/vibe_config/fallback_rules.py b/api/core/workflow/generator/config/fallback_rules.py similarity index 100% rename from api/core/llm_generator/vibe_config/fallback_rules.py rename to api/core/workflow/generator/config/fallback_rules.py diff --git a/api/core/llm_generator/vibe_config/node_definitions.json b/api/core/workflow/generator/config/node_definitions.json similarity index 100% rename from api/core/llm_generator/vibe_config/node_definitions.json rename to api/core/workflow/generator/config/node_definitions.json diff --git a/api/core/llm_generator/vibe_config/node_schemas.py b/api/core/workflow/generator/config/node_schemas.py similarity index 66% rename from api/core/llm_generator/vibe_config/node_schemas.py rename to api/core/workflow/generator/config/node_schemas.py index 779aba2efa..4af79e36f3 100644 --- a/api/core/llm_generator/vibe_config/node_schemas.py +++ b/api/core/workflow/generator/config/node_schemas.py @@ -137,19 +137,28 @@ BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = { }, "if-else": { "description": "Conditional branching based on conditions", - "required": ["conditions"], + "required": ["cases"], "parameters": { - "conditions": { + "cases": { "type": "array", - "description": "List of condition cases", + "description": "List of condition cases. 
Each case defines when 'true' branch is taken.", "item_schema": { - "case_id": "string - unique case identifier", - "logical_operator": "enum: and, or", - "conditions": "array of {variable_selector, comparison_operator, value}", + "case_id": "string - unique case identifier (e.g., 'case_1')", + "logical_operator": "enum: and, or - how multiple conditions combine", + "conditions": { + "type": "array", + "item_schema": { + "variable_selector": "array of strings - path to variable, e.g. ['node_id', 'field']", + "comparison_operator": ( + "enum: =, ≠, >, <, ≥, ≤, contains, not contains, is, is not, empty, not empty" + ), + "value": "string or number - value to compare against", + }, + }, }, }, }, - "outputs": ["Branches: true (conditions met), false (else)"], + "outputs": ["Branches: true (first case conditions met), false (else/no case matched)"], }, "knowledge-retrieval": { "description": "Query knowledge base for relevant content", @@ -207,5 +216,70 @@ BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = { }, "outputs": ["item (current iteration item)", "index (current index)"], }, + "parameter-extractor": { + "description": "Extract structured parameters from user input using LLM", + "required": ["query", "parameters"], + "parameters": { + "model": { + "type": "object", + "description": "Model configuration (provider, name, mode)", + }, + "query": { + "type": "array", + "description": "Path to input text to extract parameters from, e.g. ['start', 'user_input']", + }, + "parameters": { + "type": "array", + "description": "Parameters to extract from the input", + "item_schema": { + "name": "string - parameter name (required)", + "type": ( + "enum: string, number, boolean, array[string], array[number], " + "array[object], array[boolean]" + ), + "description": "string - description of what to extract (required)", + "required": "boolean - whether this parameter is required (MUST be specified)", + "options": "array of strings (optional) - for enum-like selection", + }, + }, + "instruction": { + "type": "string", + "description": "Additional instructions for extraction", + }, + "reasoning_mode": { + "type": "enum", + "options": ["function_call", "prompt"], + "description": "How to perform extraction (defaults to function_call)", + }, + }, + "outputs": ["Extracted parameters as defined in parameters array", "__is_success", "__reason"], + }, + "question-classifier": { + "description": "Classify user input into predefined categories using LLM", + "required": ["query", "classes"], + "parameters": { + "model": { + "type": "object", + "description": "Model configuration (provider, name, mode)", + }, + "query": { + "type": "array", + "description": "Path to input text to classify, e.g. 
['start', 'user_input']", + }, + "classes": { + "type": "array", + "description": "Classification categories", + "item_schema": { + "id": "string - unique class identifier", + "name": "string - class name/label", + }, + }, + "instruction": { + "type": "string", + "description": "Additional instructions for classification", + }, + }, + "outputs": ["class_name (selected class)"], + }, } diff --git a/api/core/llm_generator/vibe_config/responses.py b/api/core/workflow/generator/config/responses.py similarity index 100% rename from api/core/llm_generator/vibe_config/responses.py rename to api/core/workflow/generator/config/responses.py diff --git a/api/core/workflow/generator/prompts/__init__.py b/api/core/workflow/generator/prompts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/generator/prompts/builder_prompts.py b/api/core/workflow/generator/prompts/builder_prompts.py new file mode 100644 index 0000000000..d45c9528ff --- /dev/null +++ b/api/core/workflow/generator/prompts/builder_prompts.py @@ -0,0 +1,343 @@ +BUILDER_SYSTEM_PROMPT = """ +You are a Workflow Configuration Engineer. +Your goal is to implement the Architect's plan by generating a precise, runnable Dify Workflow JSON configuration. + + + + +{plan_context} + + + +{tool_schemas} + + + +{builtin_node_specs} + + + + +1. **Configuration**: + - You MUST fill ALL required parameters for every node. + - Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields. + - For 'start' node, define all necessary user inputs. + +2. **Variable References**: + - For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}` + - For 'end' node outputs: use `value_selector` array format `["node_id", "field"]` + - Example: to reference 'llm' node's 'text' output in end node, use `["llm", "text"]` + +3. **Tools**: + - ONLY use the tools listed in ``. + - If a planned tool is missing from schemas, fallback to `http-request` or `code`. + +4. **Node Specifics**: + - For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`). + +5. **Output**: + - Return ONLY the JSON object with `nodes` and `edges`. + - Do NOT generate Mermaid diagrams. + - Do NOT generate explanations. + + + +**EDGES ARE CRITICAL** - Every node except 'end' MUST have at least one outgoing edge. + +1. **Linear Flow**: Simple source -> target connection + ``` + {{"source": "node_a", "target": "node_b"}} + ``` + +2. **question-classifier Branching**: Each class MUST have a separate edge with `sourceHandle` = class `id` + - If you define classes: [{{"id": "cls_refund", "name": "Refund"}}, {{"id": "cls_inquiry", "name": "Inquiry"}}] + - You MUST create edges: + - {{"source": "classifier", "sourceHandle": "cls_refund", "target": "refund_handler"}} + - {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "inquiry_handler"}} + +3. **if-else Branching**: MUST have exactly TWO edges with sourceHandle "true" and "false" + - {{"source": "condition", "sourceHandle": "true", "target": "true_branch"}} + - {{"source": "condition", "sourceHandle": "false", "target": "false_branch"}} + +4. **Branch Convergence**: Multiple branches can connect to same downstream node + - Both true_branch and false_branch can connect to the same 'end' node + +5. 
**NEVER leave orphan nodes**: Every node must be connected in the graph + + + + +```json +{{ + "nodes": [ + {{ + "id": "start", + "type": "start", + "title": "Start", + "config": {{ + "variables": [{{"variable": "query", "label": "Query", "type": "text-input"}}] + }} + }}, + {{ + "id": "llm", + "type": "llm", + "title": "Generate Response", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}] + }} + }}, + {{ + "id": "end", + "type": "end", + "title": "End", + "config": {{ + "outputs": [ + {{"variable": "result", "value_selector": ["llm", "text"]}} + ] + }} + }} + ], + "edges": [ + {{"source": "start", "target": "llm"}}, + {{"source": "llm", "target": "end"}} + ] +}} +``` + + + +```json +{{ + "nodes": [ + {{ + "id": "start", + "type": "start", + "title": "Start", + "config": {{ + "variables": [{{"variable": "user_input", "label": "User Message", "type": "text-input", "required": true}}] + }} + }}, + {{ + "id": "classifier", + "type": "question-classifier", + "title": "Classify Intent", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "query_variable_selector": ["start", "user_input"], + "classes": [ + {{"id": "cls_refund", "name": "Refund Request"}}, + {{"id": "cls_inquiry", "name": "Product Inquiry"}}, + {{"id": "cls_complaint", "name": "Complaint"}}, + {{"id": "cls_other", "name": "Other"}} + ], + "instruction": "Classify the user's intent" + }} + }}, + {{ + "id": "handle_refund", + "type": "llm", + "title": "Handle Refund", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Extract order number and respond: {{{{#start.user_input#}}}}"}}] + }} + }}, + {{ + "id": "handle_inquiry", + "type": "llm", + "title": "Handle Inquiry", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Answer product question: {{{{#start.user_input#}}}}"}}] + }} + }}, + {{ + "id": "handle_complaint", + "type": "llm", + "title": "Handle Complaint", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Respond with empathy: {{{{#start.user_input#}}}}"}}] + }} + }}, + {{ + "id": "handle_other", + "type": "llm", + "title": "Handle Other", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Provide general response: {{{{#start.user_input#}}}}"}}] + }} + }}, + {{ + "id": "end", + "type": "end", + "title": "End", + "config": {{ + "outputs": [{{"variable": "response", "value_selector": ["handle_refund", "text"]}}] + }} + }} + ], + "edges": [ + {{"source": "start", "target": "classifier"}}, + {{"source": "classifier", "sourceHandle": "cls_refund", "target": "handle_refund"}}, + {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "handle_inquiry"}}, + {{"source": "classifier", "sourceHandle": "cls_complaint", "target": "handle_complaint"}}, + {{"source": "classifier", "sourceHandle": "cls_other", "target": "handle_other"}}, + {{"source": "handle_refund", "target": "end"}}, + {{"source": "handle_inquiry", "target": "end"}}, + {{"source": "handle_complaint", "target": "end"}}, + {{"source": "handle_other", "target": "end"}} + ] +}} +``` +CRITICAL: Notice that each class id (cls_refund, cls_inquiry, etc.) 
becomes a sourceHandle in the edges! + + + +```json +{{ + "nodes": [ + {{ + "id": "start", + "type": "start", + "title": "Start", + "config": {{ + "variables": [{{"variable": "years", "label": "Years of Experience", "type": "number", "required": true}}] + }} + }}, + {{ + "id": "check_experience", + "type": "if-else", + "title": "Check Experience", + "config": {{ + "cases": [ + {{ + "case_id": "case_1", + "logical_operator": "and", + "conditions": [ + {{ + "variable_selector": ["start", "years"], + "comparison_operator": "≥", + "value": "3" + }} + ] + }} + ] + }} + }}, + {{ + "id": "qualified", + "type": "llm", + "title": "Qualified Response", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Generate qualified candidate response"}}] + }} + }}, + {{ + "id": "not_qualified", + "type": "llm", + "title": "Not Qualified Response", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Generate rejection response"}}] + }} + }}, + {{ + "id": "end", + "type": "end", + "title": "End", + "config": {{ + "outputs": [{{"variable": "result", "value_selector": ["qualified", "text"]}}] + }} + }} + ], + "edges": [ + {{"source": "start", "target": "check_experience"}}, + {{"source": "check_experience", "sourceHandle": "true", "target": "qualified"}}, + {{"source": "check_experience", "sourceHandle": "false", "target": "not_qualified"}}, + {{"source": "qualified", "target": "end"}}, + {{"source": "not_qualified", "target": "end"}} + ] +}} +``` +CRITICAL: if-else MUST have exactly two edges with sourceHandle "true" and "false"! + + + +```json +{{ + "nodes": [ + {{ + "id": "start", + "type": "start", + "title": "Start", + "config": {{ + "variables": [{{"variable": "resume", "label": "Resume Text", "type": "paragraph", "required": true}}] + }} + }}, + {{ + "id": "extract", + "type": "parameter-extractor", + "title": "Extract Info", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "query": ["start", "resume"], + "parameters": [ + {{"name": "name", "type": "string", "description": "Candidate name", "required": true}}, + {{"name": "years", "type": "number", "description": "Years of experience", "required": true}}, + {{"name": "skills", "type": "array[string]", "description": "List of skills", "required": true}} + ], + "instruction": "Extract candidate information from resume" + }} + }}, + {{ + "id": "process", + "type": "llm", + "title": "Process Data", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Name: {{{{#extract.name#}}}}, Years: {{{{#extract.years#}}}}"}}] + }} + }}, + {{ + "id": "end", + "type": "end", + "title": "End", + "config": {{ + "outputs": [{{"variable": "result", "value_selector": ["process", "text"]}}] + }} + }} + ], + "edges": [ + {{"source": "start", "target": "extract"}}, + {{"source": "extract", "target": "process"}}, + {{"source": "process", "target": "end"}} + ] +}} +``` + + + + +Before finalizing, verify: +1. [ ] Every node (except 'end') has at least one outgoing edge +2. [ ] 'start' node has exactly one outgoing edge +3. [ ] 'question-classifier' has one edge per class, each with sourceHandle = class id +4. [ ] 'if-else' has exactly two edges: sourceHandle "true" and sourceHandle "false" +5. [ ] All branches eventually connect to 'end' (directly or through other nodes) +6. 
[ ] No orphan nodes exist (every node is reachable from 'start') + +""" + +BUILDER_USER_PROMPT = """ +{instruction} + + +Generate the full workflow configuration now. Pay special attention to: +1. Creating edges for ALL branches of question-classifier and if-else nodes +2. Using correct sourceHandle values for branching nodes +3. Ensuring every node is connected in the graph +""" diff --git a/api/core/workflow/generator/prompts/planner_prompts.py b/api/core/workflow/generator/prompts/planner_prompts.py new file mode 100644 index 0000000000..ada791bf94 --- /dev/null +++ b/api/core/workflow/generator/prompts/planner_prompts.py @@ -0,0 +1,75 @@ +PLANNER_SYSTEM_PROMPT = """ +You are an expert Workflow Architect. +Your job is to analyze user requests and plan a high-level automation workflow. + + + +1. **Classify Intent**: + - Is the user asking to create an automation/workflow? -> Intent: "generate" + - Is it general chat/weather/jokes? -> Intent: "off_topic" + +2. **Plan Steps** (if intent is "generate"): + - Break down the user's goal into logical steps. + - For each step, identify if a specific capability/tool is needed. + - Select the MOST RELEVANT tools from the available_tools list. + - DO NOT configure parameters yet. Just identify the tool. + +3. **Output Format**: + Return a JSON object. + + + +{tools_summary} + + + +If intent is "generate": +```json +{{ + "intent": "generate", + "plan_thought": "Brief explanation of the plan...", + "steps": [ + {{ "step": 1, "description": "Fetch data from URL", "tool": "http-request" }}, + {{ "step": 2, "description": "Summarize content", "tool": "llm" }}, + {{ "step": 3, "description": "Search for info", "tool": "google_search" }} + ], + "required_tool_keys": ["google_search"] +}} +``` +(Note: 'http-request', 'llm', 'code' are built-in, you don't need to list them in required_tool_keys, +only external tools) + +If intent is "off_topic": +```json +{{ + "intent": "off_topic", + "message": "I can only help you build workflows. Try asking me to 'Create a workflow that...'", + "suggestions": ["Scrape a website", "Summarize a PDF"] +}} +``` + +""" + +PLANNER_USER_PROMPT = """ +{instruction} + +""" + + +def format_tools_for_planner(tools: list[dict]) -> str: + """Format tools list for planner (Lightweight: Name + Description only).""" + if not tools: + return "No external tools available." 
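+    # Tool dicts come from multiple sources, so fall back across key name variants below.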
+ + lines = [] + for t in tools: + key = t.get("tool_key") or t.get("tool_name") + provider = t.get("provider_id") or t.get("provider", "") + desc = t.get("tool_description") or t.get("description", "") + label = t.get("tool_label") or key + + # Format: - [provider/key] Label: Description + full_key = f"{provider}/{key}" if provider else key + lines.append(f"- [{full_key}] {label}: {desc}") + + return "\n".join(lines) diff --git a/api/core/llm_generator/vibe_prompts.py b/api/core/workflow/generator/prompts/vibe_prompts.py similarity index 80% rename from api/core/llm_generator/vibe_prompts.py rename to api/core/workflow/generator/prompts/vibe_prompts.py index 59a07f3c71..ace209e063 100644 --- a/api/core/llm_generator/vibe_prompts.py +++ b/api/core/workflow/generator/prompts/vibe_prompts.py @@ -10,7 +10,7 @@ import json import re from typing import Any -from core.llm_generator.vibe_config import ( +from core.workflow.generator.config import ( BUILTIN_NODE_SCHEMAS, DEFAULT_SUGGESTIONS, FALLBACK_RULES, @@ -100,6 +100,13 @@ You help users create AI automation workflows by generating workflow configurati + + For LLM, question-classifier, parameter-extractor nodes: + - You MUST include a "model" config with provider and name from available_models section + - Copy the EXACT provider and name values from available_models + - NEVER use openai/gpt-4o, openai/gpt-3.5-turbo, openai/gpt-4 unless they appear in available_models + - If available_models is empty or not provided, omit the model config entirely + ONLY use tools with status="configured" from available_tools. NEVER invent tool names like "webscraper", "email_sender", etc. @@ -217,12 +224,14 @@ You help users create AI automation workflows by generating workflow configurati "type": "llm", "title": "Analyze Content", "config": {{{{ + "model": {{{{"provider": "USE_FROM_AVAILABLE_MODELS", "name": "USE_FROM_AVAILABLE_MODELS", "mode": "chat"}}}}, "prompt_template": [ {{{{"role": "system", "text": "You are a helpful analyst."}}}}, {{{{"role": "user", "text": "Analyze this content:\\n\\n{{{{#fetch.body#}}}}"}}}} ] }}}} }}}} + NOTE: Replace "USE_FROM_AVAILABLE_MODELS" with actual values from available_models section! {{{{ @@ -344,6 +353,7 @@ Generate your JSON response now. 
Remember: """ + def format_available_nodes(nodes: list[dict[str, Any]] | None) -> str: """Format available nodes as XML with parameter schemas.""" lines = [""] @@ -591,7 +601,7 @@ def format_previous_attempt( def format_available_models(models: list[dict[str, Any]] | None) -> str: """Format available models as XML for prompt inclusion.""" if not models: - return "\n \n" + return "\n \n" lines = [""] for model in models: @@ -600,16 +610,30 @@ def format_available_models(models: list[dict[str, Any]] | None) -> str: lines.append(f' ') lines.append("") - # Add model selection rule + # Add model selection rule with concrete example lines.append("") lines.append("") - lines.append(" CRITICAL: For LLM, question-classifier, and parameter-extractor nodes, you MUST select a model from available_models.") - if len(models) == 1: - first_model = models[0] - lines.append(f' Use provider="{first_model.get("provider")}" and name="{first_model.get("model")}" for all model-dependent nodes.') - else: - lines.append(" Choose the most suitable model for each task from the available options.") - lines.append(" NEVER use models not listed in available_models (e.g., openai/gpt-4o if not listed).") + lines.append(" CRITICAL: For LLM, question-classifier, and parameter-extractor nodes:") + lines.append(" - You MUST include a 'model' field in the config") + lines.append(" - You MUST use ONLY models from available_models above") + lines.append(" - NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4 unless they appear in available_models") + lines.append("") + + # Provide concrete JSON example to copy + first_model = models[0] + provider = first_model.get("provider", "unknown") + model_name = first_model.get("model", "unknown") + lines.append(" COPY THIS EXACT MODEL CONFIG for all LLM/question-classifier/parameter-extractor nodes:") + lines.append(f' "model": {{"provider": "{provider}", "name": "{model_name}", "mode": "chat"}}') + + if len(models) > 1: + lines.append("") + lines.append(" Alternative models you can use:") + for m in models[1:4]: # Show up to 3 alternatives + p = m.get("provider", "unknown") + n = m.get("model", "unknown") + lines.append(f' - "model": {{"provider": "{p}", "name": "{n}", "mode": "chat"}}') + lines.append("") return "\n".join(lines) @@ -1023,6 +1047,7 @@ def validate_node_parameters(nodes: list[dict[str, Any]]) -> list[str]: def extract_mermaid_from_response(data: dict[str, Any]) -> str: """Extract mermaid flowchart from parsed response.""" mermaid = data.get("mermaid", "") + if not mermaid: return "" @@ -1034,5 +1059,203 @@ def extract_mermaid_from_response(data: dict[str, Any]) -> str: if match: mermaid = match.group(1).strip() + # Sanitize edge labels to remove characters that break Mermaid parsing + # Edge labels in Mermaid are ONLY in the pattern: -->|label| + # We must NOT match |pipe| characters inside node labels like ["type=start|title=开始"] + def sanitize_edge_label(match: re.Match) -> str: + arrow = match.group(1) # --> or --- + label = match.group(2) # the label between pipes + # Remove or replace special characters that break Mermaid + # Parentheses, brackets, braces have special meaning in Mermaid + sanitized = re.sub(r'[(){}\[\]]', '', label) + return f"{arrow}|{sanitized}|" + + # Only match edge labels: --> or --- followed by |label| + # This pattern ensures we only sanitize actual edge labels, not node content + mermaid = re.sub(r'(-->|---)\|([^|]+)\|', sanitize_edge_label, mermaid) + return mermaid + +def classify_validation_errors( + nodes: list[dict[str, Any]], + 
available_models: list[dict[str, Any]] | None = None, + available_tools: list[dict[str, Any]] | None = None, + edges: list[dict[str, Any]] | None = None, +) -> dict[str, list[dict[str, Any]]]: + """ + Classify validation errors into fixable and user-required categories. + + This function uses the declarative rule engine to validate nodes. + The rule engine provides deterministic, testable validation without + relying on LLM judgment. + + Fixable errors can be automatically corrected by the LLM in subsequent + iterations. User-required errors need manual intervention. + + Args: + nodes: List of generated workflow nodes + available_models: List of models the user has configured + available_tools: List of available tools + edges: List of edges connecting nodes + + Returns: + dict with: + - "fixable": errors that LLM can fix automatically + - "user_required": errors that need user intervention + - "all_warnings": combined warning messages for backwards compatibility + - "stats": validation statistics + """ + from core.workflow.generator.validation import ValidationContext, ValidationEngine + + # Build validation context + context = ValidationContext( + nodes=nodes, + edges=edges or [], + available_models=available_models or [], + available_tools=available_tools or [], + ) + + # Run validation through rule engine + engine = ValidationEngine() + result = engine.validate(context) + + # Convert to legacy format for backwards compatibility + fixable: list[dict[str, Any]] = [] + user_required: list[dict[str, Any]] = [] + + for error in result.fixable_errors: + fixable.append({ + "node_id": error.node_id, + "node_type": error.node_type, + "error_type": error.rule_id, + "message": error.message, + "is_fixable": True, + "fix_hint": error.fix_hint, + "category": error.category.value, + "details": error.details, + }) + + for error in result.user_required_errors: + user_required.append({ + "node_id": error.node_id, + "node_type": error.node_type, + "error_type": error.rule_id, + "message": error.message, + "is_fixable": False, + "fix_hint": error.fix_hint, + "category": error.category.value, + "details": error.details, + }) + + # Include warnings in user_required (they're non-blocking but informative) + for error in result.warnings: + user_required.append({ + "node_id": error.node_id, + "node_type": error.node_type, + "error_type": error.rule_id, + "message": error.message, + "is_fixable": error.is_fixable, + "fix_hint": error.fix_hint, + "category": error.category.value, + "severity": "warning", + "details": error.details, + }) + + # Generate combined warnings for backwards compatibility + all_warnings = [e["message"] for e in fixable + user_required] + + return { + "fixable": fixable, + "user_required": user_required, + "all_warnings": all_warnings, + "stats": result.stats, + } + + +def build_fix_prompt( + fixable_errors: list[dict[str, Any]], + previous_nodes: list[dict[str, Any]], + available_models: list[dict[str, Any]] | None = None, +) -> str: + """ + Build a prompt for LLM to fix the identified errors. + + This creates a focused instruction that tells the LLM exactly what + to fix in the previous generation. 
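+    The workflow structure is preserved; only the listed errors are targeted.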
+ + Args: + fixable_errors: List of errors that can be automatically fixed + previous_nodes: The nodes from the previous generation attempt + available_models: Available models for model configuration fixes + + Returns: + Formatted prompt string for the fix iteration + """ + if not fixable_errors: + return "" + + parts = [""] + parts.append(" ") + parts.append(" Your previous generation has errors that need fixing.") + parts.append(" Please regenerate with the following corrections:") + parts.append(" ") + + # Group errors by node + errors_by_node: dict[str, list[dict[str, Any]]] = {} + for error in fixable_errors: + node_id = error["node_id"] + if node_id not in errors_by_node: + errors_by_node[node_id] = [] + errors_by_node[node_id].append(error) + + parts.append(" ") + for node_id, node_errors in errors_by_node.items(): + parts.append(f" ") + for error in node_errors: + error_type = error["error_type"] + message = error["message"] + fix_hint = error.get("fix_hint", "") + parts.append(f" ") + parts.append(f" {message}") + if fix_hint: + parts.append(f" {fix_hint}") + parts.append(" ") + parts.append(" ") + parts.append(" ") + + # Add model selection help if there are model-related errors + model_errors = [e for e in fixable_errors if "model" in e["error_type"]] + if model_errors and available_models: + parts.append(" ") + parts.append(" Use one of these models for nodes requiring model config:") + for model in available_models[:3]: # Show top 3 + provider = model.get("provider", "unknown") + name = model.get("model", "unknown") + parts.append(f' - {{"provider": "{provider}", "name": "{name}", "mode": "chat"}}') + parts.append(" ") + + # Add previous nodes summary for context + parts.append(" ") + for node in previous_nodes: + node_id = node.get("id", "unknown") + if node_id in errors_by_node: + # Only include nodes that have errors + node_type = node.get("type", "unknown") + title = node.get("title", "Untitled") + config_summary = json.dumps(node.get("config", {}), ensure_ascii=False)[:200] + parts.append(f" ") + parts.append(f" {config_summary}...") + parts.append(" ") + parts.append(" ") + + parts.append(" ") + parts.append(" 1. Keep the workflow structure and logic unchanged") + parts.append(" 2. Fix ONLY the errors listed above") + parts.append(" 3. Ensure all required fields are properly filled") + parts.append(" 4. 
Use variable references {{#node_id.field#}} where appropriate") + parts.append(" ") + parts.append("") + + return "\n".join(parts) + diff --git a/api/core/workflow/generator/runner.py b/api/core/workflow/generator/runner.py new file mode 100644 index 0000000000..1d81bc4483 --- /dev/null +++ b/api/core/workflow/generator/runner.py @@ -0,0 +1,194 @@ +import json +import logging +import re +from collections.abc import Sequence + +import json_repair + +from core.model_manager import ModelManager +from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.workflow.generator.prompts.builder_prompts import BUILDER_SYSTEM_PROMPT, BUILDER_USER_PROMPT +from core.workflow.generator.prompts.planner_prompts import ( + PLANNER_SYSTEM_PROMPT, + PLANNER_USER_PROMPT, + format_tools_for_planner, +) +from core.workflow.generator.prompts.vibe_prompts import ( + format_available_nodes, + format_available_tools, + parse_vibe_response, +) +from core.workflow.generator.utils.edge_repair import EdgeRepair +from core.workflow.generator.utils.mermaid_generator import generate_mermaid +from core.workflow.generator.utils.node_repair import NodeRepair +from core.workflow.generator.utils.workflow_validator import WorkflowValidator + +logger = logging.getLogger(__name__) + + +class WorkflowGenerator: + """ + Refactored Vibe Workflow Generator (Planner-Builder Architecture). + Extracts Vibe logic from the monolithic LLMGenerator. + """ + + @classmethod + def generate_workflow_flowchart( + cls, + tenant_id: str, + instruction: str, + model_config: dict, + available_nodes: Sequence[dict[str, object]] | None = None, + existing_nodes: Sequence[dict[str, object]] | None = None, + available_tools: Sequence[dict[str, object]] | None = None, + selected_node_ids: Sequence[str] | None = None, + previous_workflow: dict[str, object] | None = None, + regenerate_mode: bool = False, + preferred_language: str | None = None, + available_models: Sequence[dict[str, object]] | None = None, + max_fix_iterations: int = 2, + ): + """ + Generates a Dify Workflow Flowchart from natural language instruction. + + Pipeline: + 1. Planner: Analyze intent & select tools. + 2. Context Filter: Filter relevant tools (reduce tokens). + 3. Builder: Generate node configurations. + 4. Repair: Fix common node/edge issues (NodeRepair, EdgeRepair). + 5. Validator: Check for errors & generate friendly hints. + 6. Renderer: Deterministic Mermaid generation. 
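+
+        Returns a dict with: intent, flowchart (Mermaid), nodes, edges, message,
+        warnings, and fixed_issues (the auto-repairs applied).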
+        """
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
+        )
+        model_parameters = model_config.get("completion_params", {})
+        available_tools_list = list(available_tools) if available_tools else []
+
+        # --- STEP 1: PLANNER ---
+        planner_tools_context = format_tools_for_planner(available_tools_list)
+        planner_system = PLANNER_SYSTEM_PROMPT.format(tools_summary=planner_tools_context)
+        planner_user = PLANNER_USER_PROMPT.format(instruction=instruction)
+
+        try:
+            response = model_instance.invoke_llm(
+                prompt_messages=[SystemPromptMessage(content=planner_system), UserPromptMessage(content=planner_user)],
+                model_parameters=model_parameters,
+                stream=False,
+            )
+            plan_content = response.message.content
+            # Reuse parse_vibe_response to parse the planner's JSON output
+            plan_data = parse_vibe_response(plan_content)
+        except Exception as e:
+            logger.exception("Planner failed")
+            return {"intent": "error", "error": f"Planning failed: {str(e)}"}
+
+        if plan_data.get("intent") == "off_topic":
+            return {
+                "intent": "off_topic",
+                "message": plan_data.get("message", "I can only help with workflow creation."),
+                "suggestions": plan_data.get("suggestions", []),
+            }
+
+        # --- STEP 2: CONTEXT FILTERING ---
+        required_tools = plan_data.get("required_tool_keys", [])
+
+        filtered_tools = []
+        if required_tools:
+            # Simple linear search (an optimized version would use a map)
+            for tool in available_tools_list:
+                t_key = tool.get("tool_key") or tool.get("tool_name")
+                provider = tool.get("provider_id") or tool.get("provider")
+                full_key = f"{provider}/{t_key}" if provider else t_key
+
+                # Keep the tool if the plan references it by full key or short name
+                if t_key in required_tools or full_key in required_tools:
+                    filtered_tools.append(tool)
+        else:
+            # The plan uses only built-in logic nodes; no external tools are needed
+            filtered_tools = []
+
+        # --- STEP 3: BUILDER ---
+        # Prepare context
+        tool_schemas = format_available_tools(filtered_tools)
+        # format_available_nodes([]) renders the built-in node schemas
+        # (BUILTIN_NODE_SCHEMAS) without appending any custom nodes.
+        node_specs = format_available_nodes([])
+
+        builder_system = BUILDER_SYSTEM_PROMPT.format(
+            plan_context=json.dumps(plan_data.get("steps", []), indent=2),
+            tool_schemas=tool_schemas,
+            builtin_node_specs=node_specs,
+        )
+        builder_user = BUILDER_USER_PROMPT.format(instruction=instruction)
+
+        try:
+            build_res = model_instance.invoke_llm(
+                prompt_messages=[SystemPromptMessage(content=builder_system), UserPromptMessage(content=builder_user)],
+                model_parameters=model_parameters,
+                stream=False,
+            )
+            # Builder output is raw JSON with nodes/edges
+            build_content = build_res.message.content
+            match = re.search(r"```(?:json)?\s*([\s\S]+?)```", build_content)
+            if match:
+                build_content = match.group(1)
+
+            workflow_data = json_repair.loads(build_content)
+
+            if "nodes" not in workflow_data:
+                workflow_data["nodes"] = []
+            if "edges" not in workflow_data:
+                workflow_data["edges"] = []
+
+        except Exception as e:
+            logger.exception("Builder failed")
+            return {"intent": "error", "error": f"Building failed: {str(e)}"}
+
+        # --- STEP 3.4: NODE REPAIR ---
+        node_repair_result = NodeRepair.repair(workflow_data["nodes"])
+        workflow_data["nodes"] = node_repair_result.nodes
+
+        # --- STEP 3.5: EDGE REPAIR ---
+        repair_result = EdgeRepair.repair(workflow_data)
+        workflow_data = {
+            "nodes": repair_result.nodes,
+            "edges": repair_result.edges,
+        }
+
+        # --- STEP 4: VALIDATOR ---
+        is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools_list)
+
+        # --- STEP 5: RENDERER ---
+        mermaid_code = generate_mermaid(workflow_data)
+
+        # --- FINALIZE ---
+        # Combine validation hints with repair warnings
+        all_warnings = [h.message for h in hints] + repair_result.warnings + node_repair_result.warnings
+
+        # Add a stability warning so users know the result may need debugging
+        stability_warning = "The generated workflow may require debugging."
+        if preferred_language and preferred_language.startswith("zh"):
+            stability_warning = "生成的 Workflow 可能需要调试。"
+        all_warnings.append(stability_warning)
+
+        all_fixes = repair_result.repairs_made + node_repair_result.repairs_made
+
+        return {
+            "intent": "generate",
+            "flowchart": mermaid_code,
+            "nodes": workflow_data["nodes"],
+            "edges": workflow_data["edges"],
+            "message": plan_data.get("plan_thought", "Generated workflow based on your request."),
+            "warnings": all_warnings,
+            "tool_recommendations": [],  # Legacy field
+            "error": "",
+            "fix_iterations": 0,  # Legacy
+            "fixed_issues": all_fixes,  # Track what was auto-fixed
+        }
diff --git a/api/core/workflow/generator/utils/edge_repair.py b/api/core/workflow/generator/utils/edge_repair.py
new file mode 100644
index 0000000000..c1f37ae011
--- /dev/null
+++ b/api/core/workflow/generator/utils/edge_repair.py
@@ -0,0 +1,372 @@
+"""
+Edge Repair Utility for Vibe Workflow Generation.
+
+This module provides intelligent edge repair capabilities for generated workflows.
+It can detect and fix common edge issues:
+- Missing edges between sequential nodes
+- Incomplete branches for question-classifier and if-else nodes
+- Orphaned nodes without connections
+
+The repair logic is deterministic and doesn't require LLM calls.
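+Every repair is recorded in RepairResult.repairs_made so callers can see exactly what was auto-fixed.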
+""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class RepairResult: + """Result of edge repair operation.""" + + nodes: list[dict[str, Any]] + edges: list[dict[str, Any]] + repairs_made: list[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + + @property + def was_repaired(self) -> bool: + """Check if any repairs were made.""" + return len(self.repairs_made) > 0 + + +class EdgeRepair: + """ + Intelligent edge repair for workflow graphs. + + Repairs are applied in order: + 1. Infer linear connections from node order (if no edges exist) + 2. Add missing branch edges for question-classifier + 3. Add missing branch edges for if-else + 4. Connect orphaned nodes + """ + + @classmethod + def repair(cls, workflow_data: dict[str, Any]) -> RepairResult: + """ + Repair edges in the workflow data. + + Args: + workflow_data: Dict containing 'nodes' and 'edges' + + Returns: + RepairResult with repaired nodes, edges, and repair logs + """ + nodes = list(workflow_data.get("nodes", [])) + edges = list(workflow_data.get("edges", [])) + repairs: list[str] = [] + warnings: list[str] = [] + + logger.info("[EdgeRepair] Starting repair: %d nodes, %d edges", len(nodes), len(edges)) + + # Build node lookup + node_map = {n.get("id"): n for n in nodes if n.get("id")} + node_ids = set(node_map.keys()) + + # 1. If no edges at all, infer linear chain + if not edges and len(nodes) > 1: + edges, inferred_repairs = cls._infer_linear_chain(nodes) + repairs.extend(inferred_repairs) + + # 2. Build edge index for analysis + outgoing_edges: dict[str, list[dict[str, Any]]] = {} + incoming_edges: dict[str, list[dict[str, Any]]] = {} + for edge in edges: + src = edge.get("source") + tgt = edge.get("target") + if src: + outgoing_edges.setdefault(src, []).append(edge) + if tgt: + incoming_edges.setdefault(tgt, []).append(edge) + + # 3. Repair question-classifier branches + for node in nodes: + if node.get("type") == "question-classifier": + new_edges, branch_repairs, branch_warnings = cls._repair_classifier_branches( + node, edges, outgoing_edges, node_ids + ) + edges.extend(new_edges) + repairs.extend(branch_repairs) + warnings.extend(branch_warnings) + # Update outgoing index + for edge in new_edges: + outgoing_edges.setdefault(edge.get("source"), []).append(edge) + + # 4. Repair if-else branches + for node in nodes: + if node.get("type") == "if-else": + new_edges, branch_repairs, branch_warnings = cls._repair_if_else_branches( + node, edges, outgoing_edges, node_ids + ) + edges.extend(new_edges) + repairs.extend(branch_repairs) + warnings.extend(branch_warnings) + # Update outgoing index + for edge in new_edges: + outgoing_edges.setdefault(edge.get("source"), []).append(edge) + + # 5. Connect orphaned nodes (nodes with no incoming edge, except start) + new_edges, orphan_repairs = cls._connect_orphaned_nodes( + nodes, edges, outgoing_edges, incoming_edges + ) + edges.extend(new_edges) + repairs.extend(orphan_repairs) + + # 6. 
Connect nodes with no outgoing edge to 'end' (except end nodes) + new_edges, terminal_repairs = cls._connect_terminal_nodes( + nodes, edges, outgoing_edges + ) + edges.extend(new_edges) + repairs.extend(terminal_repairs) + + logger.info("[EdgeRepair] Completed: %d repairs made, %d warnings", len(repairs), len(warnings)) + for r in repairs: + logger.info("[EdgeRepair] Repair: %s", r) + for w in warnings: + logger.info("[EdgeRepair] Warning: %s", w) + + return RepairResult( + nodes=nodes, + edges=edges, + repairs_made=repairs, + warnings=warnings, + ) + + @classmethod + def _infer_linear_chain(cls, nodes: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[str]]: + """ + Infer a linear chain of edges from node order. + + This is used when no edges are provided at all. + """ + edges: list[dict[str, Any]] = [] + repairs: list[str] = [] + + # Filter to get ordered node IDs + node_ids = [n.get("id") for n in nodes if n.get("id")] + + if len(node_ids) < 2: + return edges, repairs + + # Create edges between consecutive nodes + for i in range(len(node_ids) - 1): + src = node_ids[i] + tgt = node_ids[i + 1] + edges.append({"source": src, "target": tgt}) + repairs.append(f"Inferred edge: {src} -> {tgt}") + + logger.info("[EdgeRepair] Inferred %d edges from node order (no edges provided)", len(edges)) + return edges, repairs + + @classmethod + def _repair_classifier_branches( + cls, + node: dict[str, Any], + edges: list[dict[str, Any]], + outgoing_edges: dict[str, list[dict[str, Any]]], + valid_node_ids: set[str], + ) -> tuple[list[dict[str, Any]], list[str], list[str]]: + """ + Repair missing branches for question-classifier nodes. + + For each class that doesn't have an edge, create one pointing to 'end'. + """ + new_edges: list[dict[str, Any]] = [] + repairs: list[str] = [] + warnings: list[str] = [] + + node_id = node.get("id") + if not node_id: + return new_edges, repairs, warnings + + config = node.get("config", {}) + classes = config.get("classes", []) + + if not classes: + return new_edges, repairs, warnings + + # Get existing sourceHandles for this node + existing_handles = set() + for edge in outgoing_edges.get(node_id, []): + handle = edge.get("sourceHandle") + if handle: + existing_handles.add(handle) + + # Find 'end' node as default target + end_node_id = "end" + if "end" not in valid_node_ids: + # Try to find an end node + for nid in valid_node_ids: + if "end" in nid.lower(): + end_node_id = nid + break + + # Add missing branches + for cls_def in classes: + if not isinstance(cls_def, dict): + continue + cls_id = cls_def.get("id") + cls_name = cls_def.get("name", cls_id) + + if cls_id and cls_id not in existing_handles: + new_edge = { + "source": node_id, + "sourceHandle": cls_id, + "target": end_node_id, + } + new_edges.append(new_edge) + repairs.append(f"Added missing branch edge for class '{cls_name}' -> {end_node_id}") + warnings.append( + f"Auto-connected question-classifier branch '{cls_name}' to '{end_node_id}'. " + "You may want to redirect this to a specific handler node." + ) + + return new_edges, repairs, warnings + + @classmethod + def _repair_if_else_branches( + cls, + node: dict[str, Any], + edges: list[dict[str, Any]], + outgoing_edges: dict[str, list[dict[str, Any]]], + valid_node_ids: set[str], + ) -> tuple[list[dict[str, Any]], list[str], list[str]]: + """ + Repair missing true/false branches for if-else nodes. 
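+
+        Missing 'true'/'false' branches are routed to the 'end' node as a safe default.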
+ """ + new_edges: list[dict[str, Any]] = [] + repairs: list[str] = [] + warnings: list[str] = [] + + node_id = node.get("id") + if not node_id: + return new_edges, repairs, warnings + + # Get existing sourceHandles + existing_handles = set() + for edge in outgoing_edges.get(node_id, []): + handle = edge.get("sourceHandle") + if handle: + existing_handles.add(handle) + + # Find 'end' node as default target + end_node_id = "end" + if "end" not in valid_node_ids: + for nid in valid_node_ids: + if "end" in nid.lower(): + end_node_id = nid + break + + # Add missing branches + required_branches = ["true", "false"] + for branch in required_branches: + if branch not in existing_handles: + new_edge = { + "source": node_id, + "sourceHandle": branch, + "target": end_node_id, + } + new_edges.append(new_edge) + repairs.append(f"Added missing if-else '{branch}' branch -> {end_node_id}") + warnings.append( + f"Auto-connected if-else '{branch}' branch to '{end_node_id}'. " + "You may want to redirect this to a specific handler node." + ) + + return new_edges, repairs, warnings + + @classmethod + def _connect_orphaned_nodes( + cls, + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]], + outgoing_edges: dict[str, list[dict[str, Any]]], + incoming_edges: dict[str, list[dict[str, Any]]], + ) -> tuple[list[dict[str, Any]], list[str]]: + """ + Connect orphaned nodes to the previous node in sequence. + + An orphaned node has no incoming edges and is not a 'start' node. + """ + new_edges: list[dict[str, Any]] = [] + repairs: list[str] = [] + + node_ids = [n.get("id") for n in nodes if n.get("id")] + node_types = {n.get("id"): n.get("type") for n in nodes} + + for i, node_id in enumerate(node_ids): + node_type = node_types.get(node_id) + + # Skip start nodes - they don't need incoming edges + if node_type == "start": + continue + + # Check if node has incoming edges + if node_id not in incoming_edges or not incoming_edges[node_id]: + # Find previous node to connect from + if i > 0: + prev_node_id = node_ids[i - 1] + new_edge = {"source": prev_node_id, "target": node_id} + new_edges.append(new_edge) + repairs.append(f"Connected orphaned node: {prev_node_id} -> {node_id}") + + # Update incoming_edges for subsequent checks + incoming_edges.setdefault(node_id, []).append(new_edge) + + return new_edges, repairs + + @classmethod + def _connect_terminal_nodes( + cls, + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]], + outgoing_edges: dict[str, list[dict[str, Any]]], + ) -> tuple[list[dict[str, Any]], list[str]]: + """ + Connect terminal nodes (no outgoing edges) to 'end'. + + A terminal node has no outgoing edges and is not an 'end' node. + This ensures all branches eventually reach 'end'. 
+ """ + new_edges: list[dict[str, Any]] = [] + repairs: list[str] = [] + + # Find end node + end_node_id = None + node_ids = set() + for n in nodes: + nid = n.get("id") + ntype = n.get("type") + if nid: + node_ids.add(nid) + if ntype == "end": + end_node_id = nid + + if not end_node_id: + # No end node found, can't connect + return new_edges, repairs + + for node in nodes: + node_id = node.get("id") + node_type = node.get("type") + + # Skip end nodes + if node_type == "end": + continue + + # Skip nodes that already have outgoing edges + if outgoing_edges.get(node_id): + continue + + # Connect to end + new_edge = {"source": node_id, "target": end_node_id} + new_edges.append(new_edge) + repairs.append(f"Connected terminal node to end: {node_id} -> {end_node_id}") + + # Update for subsequent checks + outgoing_edges.setdefault(node_id, []).append(new_edge) + + return new_edges, repairs + diff --git a/api/core/workflow/generator/utils/mermaid_generator.py b/api/core/workflow/generator/utils/mermaid_generator.py new file mode 100644 index 0000000000..135b5aa95d --- /dev/null +++ b/api/core/workflow/generator/utils/mermaid_generator.py @@ -0,0 +1,138 @@ +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def generate_mermaid(workflow_data: dict[str, Any]) -> str: + """ + Generate a Mermaid flowchart from workflow data consisting of nodes and edges. + + Args: + workflow_data: Dict containing 'nodes' (list) and 'edges' (list) + + Returns: + String containing the Mermaid flowchart syntax + """ + nodes = workflow_data.get("nodes", []) + edges = workflow_data.get("edges", []) + + # DEBUG: Log input data + logger.debug("[MERMAID] Input nodes count: %d", len(nodes)) + logger.debug("[MERMAID] Input edges count: %d", len(edges)) + for i, node in enumerate(nodes): + logger.debug( + "[MERMAID] Node %d: id=%s, type=%s, title=%s", i, node.get("id"), node.get("type"), node.get("title") + ) + for i, edge in enumerate(edges): + logger.debug( + "[MERMAID] Edge %d: source=%s, target=%s, sourceHandle=%s", + i, + edge.get("source"), + edge.get("target"), + edge.get("sourceHandle"), + ) + + lines = ["flowchart TD"] + + # 1. Define Nodes + # Format: node_id["title
type"] or similar + # We will use the Vibe Workflow standard format: id["type=TYPE|title=TITLE"] + # Or specifically for tool nodes: id["type=tool|title=TITLE|tool=TOOL_KEY"] + + # Map of original IDs to safe Mermaid IDs + id_map = {} + + def get_safe_id(original_id: str) -> str: + if original_id == "end": + return "end_node" + if original_id == "subgraph": + return "subgraph_node" + # Mermaid IDs should be alphanumeric. + # If the ID has special chars, we might need to escape or hash, but Vibe usually generates simple IDs. + # We'll trust standard IDs but handle the reserved keyword 'end'. + return original_id + + for node in nodes: + node_id = node.get("id") + if not node_id: + continue + + safe_id = get_safe_id(node_id) + id_map[node_id] = safe_id + + node_type = node.get("type", "unknown") + title = node.get("title", "Untitled") + + # Escape quotes in title + safe_title = title.replace('"', "'") + + if node_type == "tool": + config = node.get("config", {}) + # Try multiple fields for tool reference + tool_ref = ( + config.get("tool_key") + or config.get("tool") + or config.get("tool_name") + or node.get("tool_name") + or "unknown" + ) + node_def = f'{safe_id}["type={node_type}|title={safe_title}|tool={tool_ref}"]' + else: + node_def = f'{safe_id}["type={node_type}|title={safe_title}"]' + + lines.append(f" {node_def}") + + # 2. Define Edges + # Format: source --> target + + # Track defined nodes to avoid edge errors + defined_node_ids = {n.get("id") for n in nodes if n.get("id")} + + for edge in edges: + source = edge.get("source") + target = edge.get("target") + + # Skip invalid edges + if not source or not target: + continue + + if source not in defined_node_ids or target not in defined_node_ids: + # Log skipped edges for debugging + logger.warning( + "[MERMAID] Skipping edge: source=%s (exists=%s), target=%s (exists=%s)", + source, + source in defined_node_ids, + target, + target in defined_node_ids, + ) + continue + + safe_source = id_map.get(source, source) + safe_target = id_map.get(target, target) + + # Handle conditional branches (true/false) if present + # In Dify workflow, sourceHandle is often used for this + source_handle = edge.get("sourceHandle") + label = "" + + if source_handle == "true": + label = "|true|" + elif source_handle == "false": + label = "|false|" + elif source_handle and source_handle != "source": + # For question-classifier or other multi-path nodes + # Clean up handle for display if needed + safe_handle = str(source_handle).replace('"', "'") + label = f"|{safe_handle}|" + + edge_line = f" {safe_source} -->{label} {safe_target}" + logger.debug("[MERMAID] Adding edge: %s", edge_line) + lines.append(edge_line) + + # Start/End nodes are implicitly handled if they are in the 'nodes' list + # If not, we might need to add them, but usually the Builder should produce them. + + result = "\n".join(lines) + logger.debug("[MERMAID] Final output:\n%s", result) + return result diff --git a/api/core/workflow/generator/utils/node_repair.py b/api/core/workflow/generator/utils/node_repair.py new file mode 100644 index 0000000000..fa4d337635 --- /dev/null +++ b/api/core/workflow/generator/utils/node_repair.py @@ -0,0 +1,96 @@ +""" +Node Repair Utility for Vibe Workflow Generation. + +This module provides intelligent node configuration repair capabilities. +It can detect and fix common node configuration issues: +- Invalid comparison operators in if-else nodes (e.g. 
'>=' -> '≥')
+"""
+
+import copy
+import logging
+from dataclasses import dataclass, field
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class NodeRepairResult:
+    """Result of node repair operation."""
+
+    nodes: list[dict[str, Any]]
+    repairs_made: list[str] = field(default_factory=list)
+    warnings: list[str] = field(default_factory=list)
+
+    @property
+    def was_repaired(self) -> bool:
+        """Check if any repairs were made."""
+        return len(self.repairs_made) > 0
+
+
+class NodeRepair:
+    """
+    Intelligent node configuration repair.
+    """
+
+    OPERATOR_MAP = {
+        ">=": "≥",
+        "<=": "≤",
+        "!=": "≠",
+        "==": "=",
+    }
+
+    @classmethod
+    def repair(cls, nodes: list[dict[str, Any]]) -> NodeRepairResult:
+        """
+        Repair node configurations.
+
+        Args:
+            nodes: List of node dictionaries
+
+        Returns:
+            NodeRepairResult with repaired nodes and logs
+        """
+        # Deep copy to avoid mutating original
+        nodes = copy.deepcopy(nodes)
+        repairs: list[str] = []
+        warnings: list[str] = []
+
+        logger.info("[NodeRepair] Starting repair: %d nodes", len(nodes))
+
+        for node in nodes:
+            node_type = node.get("type")
+
+            if node_type == "if-else":
+                cls._repair_if_else_operators(node, repairs)
+
+            # Add other node type repairs here as needed
+
+        if repairs:
+            logger.info("[NodeRepair] Completed: %d repairs made", len(repairs))
+            for r in repairs:
+                logger.info("[NodeRepair] Repair: %s", r)
+
+        return NodeRepairResult(
+            nodes=nodes,
+            repairs_made=repairs,
+            warnings=warnings,
+        )
+
+    @classmethod
+    def _repair_if_else_operators(cls, node: dict[str, Any], repairs: list[str]) -> None:
+        """
+        Normalize comparison operators in if-else nodes.
+        """
+        node_id = node.get("id", "unknown")
+        config = node.get("config", {})
+        cases = config.get("cases", [])
+
+        for case in cases:
+            conditions = case.get("conditions", [])
+            for condition in conditions:
+                op = condition.get("comparison_operator")
+                if op in cls.OPERATOR_MAP:
+                    new_op = cls.OPERATOR_MAP[op]
+                    condition["comparison_operator"] = new_op
+                    repairs.append(f"Normalized operator '{op}' to '{new_op}' in node '{node_id}'")
diff --git a/api/core/workflow/generator/utils/workflow_validator.py b/api/core/workflow/generator/utils/workflow_validator.py
new file mode 100644
index 0000000000..858687f7b7
--- /dev/null
+++ b/api/core/workflow/generator/utils/workflow_validator.py
@@ -0,0 +1,96 @@
+import logging
+from dataclasses import dataclass
+from typing import Any
+
+from core.workflow.generator.validation.context import ValidationContext
+from core.workflow.generator.validation.engine import ValidationEngine
+from core.workflow.generator.validation.rules import Severity
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ValidationHint:
+    """Legacy compatibility class for validation hints."""
+
+    node_id: str
+    field: str
+    message: str
+    severity: str  # 'error', 'warning'
+    suggestion: str | None = None
+    node_type: str | None = None  # Added for test compatibility
+
+    # Alias for potential old code using 'type' instead of 'severity'
+    @property
+    def type(self) -> str:
+        return self.severity
+
+    @property
+    def element_id(self) -> str:
+        return self.node_id
+
+
+FriendlyHint = ValidationHint  # Alias for backward compatibility
+
+
+class WorkflowValidator:
+    """
+    Validates the generated workflow configuration (nodes and edges).
+    Wraps the new ValidationEngine for backward compatibility.
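+
+    Illustrative usage (a sketch; arguments mirror validate() below):
+
+        is_valid, hints = WorkflowValidator.validate(
+            workflow_data={"nodes": nodes, "edges": edges},
+            available_tools=[],
+            available_models=[],
+        )
+        for hint in hints:
+            logger.warning("[%s] %s", hint.severity, hint.message)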
+ """ + + @classmethod + def validate( + cls, + workflow_data: dict[str, Any], + available_tools: list[dict[str, Any]], + available_models: list[dict[str, Any]] | None = None, + ) -> tuple[bool, list[ValidationHint]]: + """ + Validate workflow data and return validity status and hints. + + Args: + workflow_data: Dict containing 'nodes' and 'edges' + available_tools: List of available tool configurations + available_models: List of available models (added for Vibe compat) + + Returns: + Tuple(max_severity_is_not_error, list_of_hints) + """ + nodes = workflow_data.get("nodes", []) + edges = workflow_data.get("edges", []) + + # Create context + context = ValidationContext( + nodes=nodes, + edges=edges, + available_models=available_models or [], + available_tools=available_tools or [], + ) + + # Run validation engine + engine = ValidationEngine() + result = engine.validate(context) + + # Convert engine errors to legacy hints + hints: list[ValidationHint] = [] + + for error in result.all_errors: + # Map severity + severity = "error" if error.severity == Severity.ERROR else "warning" + + # Map field from message or details if possible (heuristic) + field_name = error.details.get("field", "unknown") + + hints.append( + ValidationHint( + node_id=error.node_id, + field=field_name, + message=error.message, + severity=severity, + suggestion=error.fix_hint, + node_type=error.node_type, + ) + ) + + return result.is_valid, hints diff --git a/api/core/workflow/generator/validation/__init__.py b/api/core/workflow/generator/validation/__init__.py new file mode 100644 index 0000000000..4ce2d263ac --- /dev/null +++ b/api/core/workflow/generator/validation/__init__.py @@ -0,0 +1,45 @@ +""" +Validation Rule Engine for Vibe Workflow Generation. + +This module provides a declarative, schema-based validation system for +generated workflow nodes. It classifies errors into fixable (LLM can auto-fix) +and user-required (needs manual intervention) categories. + +Usage: + from core.workflow.generator.validation import ValidationEngine, ValidationContext + + context = ValidationContext( + available_models=[...], + available_tools=[...], + nodes=[...], + edges=[...], + ) + engine = ValidationEngine() + result = engine.validate(context) + + # Access classified errors + fixable_errors = result.fixable_errors + user_required_errors = result.user_required_errors +""" + +from core.workflow.generator.validation.context import ValidationContext +from core.workflow.generator.validation.engine import ValidationEngine, ValidationResult +from core.workflow.generator.validation.rules import ( + RuleCategory, + Severity, + ValidationError, + ValidationRule, +) + +__all__ = [ + "RuleCategory", + "Severity", + "ValidationContext", + "ValidationEngine", + "ValidationError", + "ValidationResult", + "ValidationRule", +] + + + diff --git a/api/core/workflow/generator/validation/context.py b/api/core/workflow/generator/validation/context.py new file mode 100644 index 0000000000..3cb44429f1 --- /dev/null +++ b/api/core/workflow/generator/validation/context.py @@ -0,0 +1,123 @@ +""" +Validation Context for the Rule Engine. + +The ValidationContext holds all the data needed for validation: +- Generated nodes and edges +- Available models, tools, and datasets +- Node output schemas for variable reference validation +""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class ValidationContext: + """ + Context object containing all data needed for validation. 
+ + This is passed to each validation rule, providing access to: + - The nodes being validated + - Edge connections between nodes + - Available external resources (models, tools) + """ + + # Generated workflow data + nodes: list[dict[str, Any]] = field(default_factory=list) + edges: list[dict[str, Any]] = field(default_factory=list) + + # Available external resources + available_models: list[dict[str, Any]] = field(default_factory=list) + available_tools: list[dict[str, Any]] = field(default_factory=list) + + # Cached lookups (populated lazily) + _node_map: dict[str, dict[str, Any]] | None = field(default=None, repr=False) + _model_set: set[tuple[str, str]] | None = field(default=None, repr=False) + _tool_set: set[str] | None = field(default=None, repr=False) + _configured_tool_set: set[str] | None = field(default=None, repr=False) + + @property + def node_map(self) -> dict[str, dict[str, Any]]: + """Get a map of node_id -> node for quick lookup.""" + if self._node_map is None: + self._node_map = {node.get("id", ""): node for node in self.nodes} + return self._node_map + + @property + def model_set(self) -> set[tuple[str, str]]: + """Get a set of (provider, model_name) tuples for quick lookup.""" + if self._model_set is None: + self._model_set = { + (m.get("provider", ""), m.get("model", "")) + for m in self.available_models + } + return self._model_set + + @property + def tool_set(self) -> set[str]: + """Get a set of all tool keys (both configured and unconfigured).""" + if self._tool_set is None: + self._tool_set = set() + for tool in self.available_tools: + provider = tool.get("provider_id") or tool.get("provider", "") + tool_key = tool.get("tool_key") or tool.get("tool_name", "") + if provider and tool_key: + self._tool_set.add(f"{provider}/{tool_key}") + if tool_key: + self._tool_set.add(tool_key) + return self._tool_set + + @property + def configured_tool_set(self) -> set[str]: + """Get a set of configured (authorized) tool keys.""" + if self._configured_tool_set is None: + self._configured_tool_set = set() + for tool in self.available_tools: + if not tool.get("is_team_authorization", False): + continue + provider = tool.get("provider_id") or tool.get("provider", "") + tool_key = tool.get("tool_key") or tool.get("tool_name", "") + if provider and tool_key: + self._configured_tool_set.add(f"{provider}/{tool_key}") + if tool_key: + self._configured_tool_set.add(tool_key) + return self._configured_tool_set + + def has_model(self, provider: str, model_name: str) -> bool: + """Check if a model is available.""" + return (provider, model_name) in self.model_set + + def has_tool(self, tool_key: str) -> bool: + """Check if a tool exists (configured or not).""" + return tool_key in self.tool_set + + def is_tool_configured(self, tool_key: str) -> bool: + """Check if a tool is configured and ready to use.""" + return tool_key in self.configured_tool_set + + def get_node(self, node_id: str) -> dict[str, Any] | None: + """Get a node by its ID.""" + return self.node_map.get(node_id) + + def get_node_ids(self) -> set[str]: + """Get all node IDs in the workflow.""" + return set(self.node_map.keys()) + + def get_upstream_nodes(self, node_id: str) -> list[str]: + """Get IDs of nodes that connect to this node (upstream).""" + return [ + edge.get("source", "") + for edge in self.edges + if edge.get("target") == node_id + ] + + def get_downstream_nodes(self, node_id: str) -> list[str]: + """Get IDs of nodes that this node connects to (downstream).""" + return [ + edge.get("target", "") + for edge in 
self.edges + if edge.get("source") == node_id + ] + + + diff --git a/api/core/workflow/generator/validation/engine.py b/api/core/workflow/generator/validation/engine.py new file mode 100644 index 0000000000..de585bef19 --- /dev/null +++ b/api/core/workflow/generator/validation/engine.py @@ -0,0 +1,266 @@ +""" +Validation Engine - Core validation logic. + +The ValidationEngine orchestrates rule execution and aggregates results. +It provides a clean interface for validating workflow nodes. +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +from core.workflow.generator.validation.context import ValidationContext +from core.workflow.generator.validation.rules import ( + RuleCategory, + Severity, + ValidationError, + get_registry, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ValidationResult: + """ + Result of validation containing all errors classified by fixability. + + Attributes: + all_errors: All validation errors found + fixable_errors: Errors that LLM can automatically fix + user_required_errors: Errors that require user intervention + warnings: Non-blocking warnings + stats: Validation statistics + """ + + all_errors: list[ValidationError] = field(default_factory=list) + fixable_errors: list[ValidationError] = field(default_factory=list) + user_required_errors: list[ValidationError] = field(default_factory=list) + warnings: list[ValidationError] = field(default_factory=list) + stats: dict[str, int] = field(default_factory=dict) + + @property + def has_errors(self) -> bool: + """Check if there are any errors (excluding warnings).""" + return len(self.fixable_errors) > 0 or len(self.user_required_errors) > 0 + + @property + def has_fixable_errors(self) -> bool: + """Check if there are fixable errors.""" + return len(self.fixable_errors) > 0 + + @property + def is_valid(self) -> bool: + """Check if validation passed (no errors, warnings are OK).""" + return not self.has_errors + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for API response.""" + return { + "fixable": [e.to_dict() for e in self.fixable_errors], + "user_required": [e.to_dict() for e in self.user_required_errors], + "warnings": [e.to_dict() for e in self.warnings], + "all_warnings": [e.message for e in self.all_errors], + "stats": self.stats, + } + + def get_error_messages(self) -> list[str]: + """Get all error messages as strings.""" + return [e.message for e in self.all_errors] + + def get_fixable_by_node(self) -> dict[str, list[ValidationError]]: + """Group fixable errors by node ID.""" + result: dict[str, list[ValidationError]] = {} + for error in self.fixable_errors: + if error.node_id not in result: + result[error.node_id] = [] + result[error.node_id].append(error) + return result + + +class ValidationEngine: + """ + The main validation engine. + + Usage: + engine = ValidationEngine() + context = ValidationContext(nodes=[...], available_models=[...]) + result = engine.validate(context) + """ + + def __init__(self): + self._registry = get_registry() + + def validate(self, context: ValidationContext) -> ValidationResult: + """ + Validate all nodes in the context. 
+ + Args: + context: ValidationContext with nodes, edges, and available resources + + Returns: + ValidationResult with classified errors + """ + result = ValidationResult() + stats = { + "total_nodes": len(context.nodes), + "total_rules_checked": 0, + "total_errors": 0, + "fixable_count": 0, + "user_required_count": 0, + "warning_count": 0, + } + + # Validate each node + for node in context.nodes: + node_type = node.get("type", "unknown") + node_id = node.get("id", "unknown") + + # Get applicable rules for this node type + rules = self._registry.get_rules_for_node(node_type) + + for rule in rules: + stats["total_rules_checked"] += 1 + + try: + errors = rule.check(node, context) + for error in errors: + result.all_errors.append(error) + stats["total_errors"] += 1 + + # Classify by severity and fixability + if error.severity == Severity.WARNING: + result.warnings.append(error) + stats["warning_count"] += 1 + elif error.is_fixable: + result.fixable_errors.append(error) + stats["fixable_count"] += 1 + else: + result.user_required_errors.append(error) + stats["user_required_count"] += 1 + + except Exception: + logger.exception( + "Rule '%s' failed for node '%s'", + rule.id, + node_id, + ) + # Don't let a rule failure break the entire validation + continue + + # Validate edges separately + edge_errors = self._validate_edges(context) + for error in edge_errors: + result.all_errors.append(error) + stats["total_errors"] += 1 + if error.is_fixable: + result.fixable_errors.append(error) + stats["fixable_count"] += 1 + else: + result.user_required_errors.append(error) + stats["user_required_count"] += 1 + + result.stats = stats + + logger.debug( + "[Validation] Completed: %d nodes, %d rules, %d errors (%d fixable, %d user-required)", + stats["total_nodes"], + stats["total_rules_checked"], + stats["total_errors"], + stats["fixable_count"], + stats["user_required_count"], + ) + + return result + + def _validate_edges(self, context: ValidationContext) -> list[ValidationError]: + """Validate edge connections.""" + errors: list[ValidationError] = [] + valid_node_ids = context.get_node_ids() + + for edge in context.edges: + source = edge.get("source", "") + target = edge.get("target", "") + + if source and source not in valid_node_ids: + errors.append( + ValidationError( + rule_id="edge.source.invalid", + node_id=source, + node_type="edge", + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Edge source '{source}' does not exist", + fix_hint="Update edge to reference existing node", + ) + ) + + if target and target not in valid_node_ids: + errors.append( + ValidationError( + rule_id="edge.target.invalid", + node_id=target, + node_type="edge", + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Edge target '{target}' does not exist", + fix_hint="Update edge to reference existing node", + ) + ) + + return errors + + def validate_single_node( + self, + node: dict[str, Any], + context: ValidationContext, + ) -> list[ValidationError]: + """ + Validate a single node. + + Useful for incremental validation when a node is added/modified. 
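+
+        Example (illustrative sketch):
+
+            errors = engine.validate_single_node(node, context)
+            fixable = [e for e in errors if e.is_fixable]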
+ """ + node_type = node.get("type", "unknown") + rules = self._registry.get_rules_for_node(node_type) + + errors: list[ValidationError] = [] + for rule in rules: + try: + errors.extend(rule.check(node, context)) + except Exception: + logger.exception("Rule '%s' failed", rule.id) + + return errors + + +def validate_nodes( + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]] | None = None, + available_models: list[dict[str, Any]] | None = None, + available_tools: list[dict[str, Any]] | None = None, +) -> ValidationResult: + """ + Convenience function to validate nodes without creating engine/context manually. + + Args: + nodes: List of workflow nodes to validate + edges: Optional list of edges + available_models: Optional list of available models + available_tools: Optional list of available tools + + Returns: + ValidationResult with classified errors + """ + context = ValidationContext( + nodes=nodes, + edges=edges or [], + available_models=available_models or [], + available_tools=available_tools or [], + ) + engine = ValidationEngine() + return engine.validate(context) + + + diff --git a/api/core/workflow/generator/validation/rules.py b/api/core/workflow/generator/validation/rules.py new file mode 100644 index 0000000000..761dde5e96 --- /dev/null +++ b/api/core/workflow/generator/validation/rules.py @@ -0,0 +1,1148 @@ +""" +Validation Rules Definition and Registry. + +This module defines: +- ValidationRule: The rule structure +- RuleCategory: Categories of validation rules +- Severity: Error severity levels +- ValidationError: Error output structure +- All built-in validation rules +""" + +import re +from collections.abc import Callable +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.workflow.generator.validation.context import ValidationContext + + +class RuleCategory(Enum): + """Categories of validation rules.""" + + STRUCTURE = "structure" # Field existence, types, formats + SEMANTIC = "semantic" # Variable references, edge connections + REFERENCE = "reference" # External resources (models, tools, datasets) + + +class Severity(Enum): + """Severity levels for validation errors.""" + + ERROR = "error" # Must be fixed + WARNING = "warning" # Should be fixed but not blocking + + +@dataclass +class ValidationError: + """ + Represents a validation error found during rule execution. 
+ + Attributes: + rule_id: The ID of the rule that generated this error + node_id: The ID of the node with the error + node_type: The type of the node + category: The rule category + severity: Error severity + is_fixable: Whether LLM can auto-fix this error + message: Human-readable error message + fix_hint: Hint for LLM to fix the error + details: Additional error details + """ + + rule_id: str + node_id: str + node_type: str + category: RuleCategory + severity: Severity + is_fixable: bool + message: str + fix_hint: str = "" + details: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for API response.""" + return { + "rule_id": self.rule_id, + "node_id": self.node_id, + "node_type": self.node_type, + "category": self.category.value, + "severity": self.severity.value, + "is_fixable": self.is_fixable, + "message": self.message, + "fix_hint": self.fix_hint, + "details": self.details, + } + + +# Type alias for rule check functions +RuleCheckFn = Callable[ + [dict[str, Any], "ValidationContext"], + list[ValidationError], +] + + +@dataclass +class ValidationRule: + """ + A validation rule definition. + + Attributes: + id: Unique rule identifier (e.g., "llm.model.required") + node_types: List of node types this rule applies to, or ["*"] for all + category: The rule category + severity: Default severity for errors from this rule + is_fixable: Whether errors from this rule can be auto-fixed by LLM + check: The validation function + description: Human-readable description of what this rule checks + fix_hint: Default hint for fixing errors from this rule + """ + + id: str + node_types: list[str] + category: RuleCategory + severity: Severity + is_fixable: bool + check: RuleCheckFn + description: str = "" + fix_hint: str = "" + + def applies_to(self, node_type: str) -> bool: + """Check if this rule applies to a given node type.""" + return "*" in self.node_types or node_type in self.node_types + + +# ============================================================================= +# Rule Registry +# ============================================================================= + + +class RuleRegistry: + """ + Registry for validation rules. + + Rules are registered here and can be retrieved by category or node type. 
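+
+    Example (illustrative sketch):
+
+        registry = get_registry()
+        llm_rules = registry.get_rules_for_node("llm")
+        structure_rules = registry.get_rules_by_category(RuleCategory.STRUCTURE)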
+ """ + + def __init__(self): + self._rules: list[ValidationRule] = [] + + def register(self, rule: ValidationRule) -> None: + """Register a validation rule.""" + self._rules.append(rule) + + def get_rules_for_node(self, node_type: str) -> list[ValidationRule]: + """Get all rules that apply to a given node type.""" + return [r for r in self._rules if r.applies_to(node_type)] + + def get_rules_by_category(self, category: RuleCategory) -> list[ValidationRule]: + """Get all rules in a given category.""" + return [r for r in self._rules if r.category == category] + + def get_all_rules(self) -> list[ValidationRule]: + """Get all registered rules.""" + return list(self._rules) + + +# Global rule registry instance +_registry = RuleRegistry() + + +def register_rule(rule: ValidationRule) -> ValidationRule: + """Decorator/function to register a rule with the global registry.""" + _registry.register(rule) + return rule + + +def get_registry() -> RuleRegistry: + """Get the global rule registry.""" + return _registry + + +# ============================================================================= +# Helper Functions for Rule Implementations +# ============================================================================= + +# Placeholder patterns that indicate user needs to fill in values +PLACEHOLDER_PATTERNS = [ + "PLEASE_SELECT", + "YOUR_", + "TODO", + "PLACEHOLDER", + "EXAMPLE_", + "REPLACE_", + "INSERT_", + "ADD_YOUR_", +] + +# Variable reference pattern: {{#node_id.field#}} +VARIABLE_REF_PATTERN = re.compile(r"\{\{#([^.#]+)\.([^#]+)#\}\}") + + +def is_placeholder(value: Any) -> bool: + """Check if a value appears to be a placeholder.""" + if not isinstance(value, str): + return False + value_upper = value.upper() + return any(p in value_upper for p in PLACEHOLDER_PATTERNS) + + +def extract_variable_refs(text: str) -> list[tuple[str, str]]: + """ + Extract variable references from text. + + Returns list of (node_id, field_name) tuples. 
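+
+    Example:
+        >>> extract_variable_refs("Summarize {{#start.query#}} for the user")
+        [('start', 'query')]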
+ """ + return VARIABLE_REF_PATTERN.findall(text) + + +def check_required_field( + config: dict[str, Any], + field_name: str, + node_id: str, + node_type: str, + rule_id: str, + fix_hint: str = "", +) -> ValidationError | None: + """Helper to check if a required field exists and is non-empty.""" + value = config.get(field_name) + if value is None or value == "" or (isinstance(value, list) and len(value) == 0): + return ValidationError( + rule_id=rule_id, + node_id=node_id, + node_type=node_type, + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': missing required field '{field_name}'", + fix_hint=fix_hint or f"Add '{field_name}' to the node config", + ) + return None + + +# ============================================================================= +# Structure Rules - Field existence, types, formats +# ============================================================================= + + +def _check_llm_prompt_template(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that LLM node has prompt_template.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + err = check_required_field( + config, + "prompt_template", + node_id, + "llm", + "llm.prompt_template.required", + "Add prompt_template with system and user messages", + ) + if err: + errors.append(err) + + return errors + + +def _check_http_request_url(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that http-request node has url and method.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + # Check url + url = config.get("url", "") + if not url: + errors.append( + ValidationError( + rule_id="http.url.required", + node_id=node_id, + node_type="http-request", + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': http-request missing required 'url'", + fix_hint="Add url - use {{#start.url#}} or a concrete URL", + ) + ) + elif is_placeholder(url): + errors.append( + ValidationError( + rule_id="http.url.placeholder", + node_id=node_id, + node_type="http-request", + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': url contains placeholder value", + fix_hint="Replace placeholder with actual URL or variable reference", + ) + ) + + # Check method + method = config.get("method", "") + if not method: + errors.append( + ValidationError( + rule_id="http.method.required", + node_id=node_id, + node_type="http-request", + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': http-request missing 'method'", + fix_hint="Add method: GET, POST, PUT, DELETE, or PATCH", + ) + ) + + return errors + + +def _check_code_node(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that code node has code and language.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + err = check_required_field( + config, + "code", + node_id, + "code", + "code.code.required", + "Add code with a main() function that returns a dict", + ) + if err: + errors.append(err) + + err = check_required_field( + config, + "language", + node_id, + "code", + "code.language.required", + "Add language: python3 or javascript", + ) + if err: + errors.append(err) + + return errors + + 
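+# How the helper above composes with a check (illustrative sketch, not a
+# registered rule; the literal arguments are hypothetical):
+#
+#     err = check_required_field(
+#         config={}, field_name="code", node_id="code_1",
+#         node_type="code", rule_id="code.code.required",
+#     )
+#     # err is a fixable STRUCTURE error because 'code' is missing
+
+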
+def _check_question_classifier(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that question-classifier has classes.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + err = check_required_field( + config, + "classes", + node_id, + "question-classifier", + "classifier.classes.required", + "Add classes array with id and name for each classification", + ) + if err: + errors.append(err) + + return errors + + +def _check_parameter_extractor(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that parameter-extractor has parameters and instruction.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + err = check_required_field( + config, + "parameters", + node_id, + "parameter-extractor", + "extractor.parameters.required", + "Add parameters array with name, type, description fields", + ) + if err: + errors.append(err) + else: + # Check individual parameters for required fields + parameters = config.get("parameters", []) + if isinstance(parameters, list): + for i, param in enumerate(parameters): + if isinstance(param, dict): + # Check for 'required' field (boolean) + if "required" not in param: + errors.append( + ValidationError( + rule_id="extractor.param.required_field.missing", + node_id=node_id, + node_type="parameter-extractor", + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': parameter[{i}] missing 'required' field", + fix_hint=f"Add 'required': True to parameter '{param.get('name', 'unknown')}'", + details={"param_index": i, "param_name": param.get("name")}, + ) + ) + + # instruction is recommended but not strictly required + if not config.get("instruction"): + errors.append( + ValidationError( + rule_id="extractor.instruction.recommended", + node_id=node_id, + node_type="parameter-extractor", + category=RuleCategory.STRUCTURE, + severity=Severity.WARNING, + is_fixable=True, + message=f"Node '{node_id}': parameter-extractor should have 'instruction'", + fix_hint="Add instruction describing what to extract", + ) + ) + + return errors + + +def _check_knowledge_retrieval(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that knowledge-retrieval has dataset_ids.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + dataset_ids = config.get("dataset_ids", []) + if not dataset_ids: + errors.append( + ValidationError( + rule_id="knowledge.dataset.required", + node_id=node_id, + node_type="knowledge-retrieval", + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=False, # User must select knowledge base + message=f"Node '{node_id}': knowledge-retrieval missing 'dataset_ids'", + fix_hint="User must select knowledge bases in the UI", + ) + ) + else: + # Check for placeholder values + for ds_id in dataset_ids: + if is_placeholder(ds_id): + errors.append( + ValidationError( + rule_id="knowledge.dataset.placeholder", + node_id=node_id, + node_type="knowledge-retrieval", + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=False, + message=f"Node '{node_id}': dataset_ids contains placeholder", + fix_hint="User must replace placeholder with actual knowledge base ID", + details={"placeholder_value": ds_id}, + ) + ) + break + + return errors + + +def _check_end_node(node: dict[str, Any], ctx: 
"ValidationContext") -> list[ValidationError]: + """Check that end node has outputs defined.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + outputs = config.get("outputs", []) + if not outputs: + errors.append( + ValidationError( + rule_id="end.outputs.recommended", + node_id=node_id, + node_type="end", + category=RuleCategory.STRUCTURE, + severity=Severity.WARNING, + is_fixable=True, + message="End node should define output variables", + fix_hint="Add outputs array with variable and value_selector", + ) + ) + + return errors + + +# ============================================================================= +# Semantic Rules - Variable references, edge connections +# ============================================================================= + + +def _check_variable_references(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that variable references point to valid nodes.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + config = node.get("config", {}) + + # Get all valid node IDs (including 'start' which is always valid) + valid_node_ids = ctx.get_node_ids() + valid_node_ids.add("start") + valid_node_ids.add("sys") # System variables + + def check_text_for_refs(text: str, field_path: str) -> None: + if not isinstance(text, str): + return + refs = extract_variable_refs(text) + for ref_node_id, ref_field in refs: + if ref_node_id not in valid_node_ids: + errors.append( + ValidationError( + rule_id="variable.ref.invalid_node", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': references non-existent node '{ref_node_id}'", + fix_hint=f"Change {{{{#{ref_node_id}.{ref_field}#}}}} to reference a valid node", + details={"field_path": field_path, "invalid_ref": ref_node_id}, + ) + ) + + # Check prompt_template for LLM nodes + prompt_template = config.get("prompt_template", []) + if isinstance(prompt_template, list): + for i, msg in enumerate(prompt_template): + if isinstance(msg, dict): + text = msg.get("text", "") + check_text_for_refs(text, f"prompt_template[{i}].text") + + # Check instruction field + instruction = config.get("instruction", "") + check_text_for_refs(instruction, "instruction") + + # Check url for http-request + url = config.get("url", "") + check_text_for_refs(url, "url") + + return errors + + +def _check_node_has_outgoing_edge(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that non-end nodes have at least one outgoing edge.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + # End nodes don't need outgoing edges + if node_type == "end": + return errors + + # Check if this node has any outgoing edges + downstream = ctx.get_downstream_nodes(node_id) + if not downstream: + errors.append( + ValidationError( + rule_id="edge.no_outgoing", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}' has no outgoing edge - workflow is disconnected", + fix_hint=f"Add an edge from '{node_id}' to the next node or to 'end'", + details={"field": "edges"}, + ) + ) + + return errors + + +def _check_node_has_incoming_edge(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that non-start 
nodes have at least one incoming edge.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + # Start nodes don't need incoming edges + if node_type == "start": + return errors + + # Check if this node has any incoming edges + upstream = ctx.get_upstream_nodes(node_id) + if not upstream: + errors.append( + ValidationError( + rule_id="edge.no_incoming", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}' is orphaned - no incoming edges", + fix_hint=f"Add an edge from a previous node to '{node_id}'", + details={"field": "edges"}, + ) + ) + + return errors + + +def _check_question_classifier_branches(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that question-classifier has edges for all defined classes.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + if node_type != "question-classifier": + return errors + + config = node.get("config", {}) + classes = config.get("classes", []) + + if not classes: + return errors # Already caught by structure validation + + # Get all class IDs + class_ids = set() + for cls in classes: + if isinstance(cls, dict) and cls.get("id"): + class_ids.add(cls["id"]) + + # Get all outgoing edges with their sourceHandles + outgoing_handles = set() + for edge in ctx.edges: + if edge.get("source") == node_id: + handle = edge.get("sourceHandle") + if handle: + outgoing_handles.add(handle) + + # Check for missing branches + missing_branches = class_ids - outgoing_handles + if missing_branches: + for branch_id in missing_branches: + # Find the class name for better error message + class_name = branch_id + for cls in classes: + if isinstance(cls, dict) and cls.get("id") == branch_id: + class_name = cls.get("name", branch_id) + break + + errors.append( + ValidationError( + rule_id="edge.classifier_branch.missing", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Question classifier '{node_id}' missing edge for class '{class_name}'", + fix_hint=f"Add edge: {{source: '{node_id}', sourceHandle: '{branch_id}', target: ''}}", + details={"missing_class_id": branch_id, "missing_class_name": class_name, "field": "edges"}, + ) + ) + + return errors + + +def _check_if_else_branches(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]: + """Check that if-else has both true and false branch edges.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + if node_type != "if-else": + return errors + + # Get all outgoing edges with their sourceHandles + outgoing_handles = set() + for edge in ctx.edges: + if edge.get("source") == node_id: + handle = edge.get("sourceHandle") + if handle: + outgoing_handles.add(handle) + + # Check for required branches + required_branches = {"true", "false"} + missing_branches = required_branches - outgoing_handles + + for branch in missing_branches: + errors.append( + ValidationError( + rule_id="edge.if_else_branch.missing", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"If-else node '{node_id}' missing '{branch}' branch edge", + fix_hint=f"Add edge: {{source: '{node_id}', sourceHandle: '{branch}', target: ''}}", + 
details={"missing_branch": branch, "field": "edges"},
+            )
+        )
+
+    return errors
+
+
+def _check_if_else_operators(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]:
+    """Check that if-else comparison operators are valid."""
+    errors: list[ValidationError] = []
+    node_id = node.get("id", "unknown")
+    node_type = node.get("type", "unknown")
+
+    if node_type != "if-else":
+        return errors
+
+    valid_operators = {
+        "contains",
+        "not contains",
+        "start with",
+        "end with",
+        "is",
+        "is not",
+        "empty",
+        "not empty",
+        "in",
+        "not in",
+        "all of",
+        "=",
+        "≠",
+        ">",
+        "<",
+        "≥",
+        "≤",
+        "null",
+        "not null",
+        "exists",
+        "not exists",
+    }
+
+    config = node.get("config", {})
+    cases = config.get("cases", [])
+
+    for case in cases:
+        conditions = case.get("conditions", [])
+        for condition in conditions:
+            op = condition.get("comparison_operator")
+            if op and op not in valid_operators:
+                errors.append(
+                    ValidationError(
+                        rule_id="ifelse.operator.invalid",
+                        node_id=node_id,
+                        node_type=node_type,
+                        category=RuleCategory.SEMANTIC,
+                        severity=Severity.ERROR,
+                        is_fixable=True,
+                        message=f"Invalid operator '{op}' in if-else node",
+                        fix_hint=f"Use one of: {', '.join(sorted(valid_operators))}",
+                        details={"invalid_operator": op, "field": "config.cases.conditions.comparison_operator"},
+                    )
+                )
+
+    return errors
+
+
+def _check_edge_targets_exist(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]:
+    """Check that edge targets reference existing nodes."""
+    errors: list[ValidationError] = []
+    node_id = node.get("id", "unknown")
+    node_type = node.get("type", "unknown")
+
+    valid_node_ids = ctx.get_node_ids()
+
+    # Check all outgoing edges from this node
+    for edge in ctx.edges:
+        if edge.get("source") == node_id:
+            target = edge.get("target")
+            if target and target not in valid_node_ids:
+                errors.append(
+                    ValidationError(
+                        rule_id="edge.target.invalid",
+                        node_id=node_id,
+                        node_type=node_type,
+                        category=RuleCategory.SEMANTIC,
+                        severity=Severity.ERROR,
+                        is_fixable=True,
+                        message=f"Edge from '{node_id}' targets non-existent node '{target}'",
+                        fix_hint=f"Change edge target from '{target}' to an existing node",
+                        details={"invalid_target": target, "field": "edges"},
+                    )
+                )
+
+    return errors
+
+
+# =============================================================================
+# Reference Rules - External resources (models, tools, datasets)
+# =============================================================================
+
+# Node types that require model configuration
+MODEL_REQUIRED_NODE_TYPES = {"llm", "question-classifier", "parameter-extractor"}
+
+
+def _check_model_config(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]:
+    """Check that model configuration is valid."""
+    errors: list[ValidationError] = []
+    node_id = node.get("id", "unknown")
+    node_type = node.get("type", "unknown")
+    config = node.get("config", {})
+
+    if node_type not in MODEL_REQUIRED_NODE_TYPES:
+        return errors
+
+    model = config.get("model")
+
+    # Check if model config exists
+    if not model:
+        if ctx.available_models:
+            errors.append(
+                ValidationError(
+                    rule_id="model.required",
+                    node_id=node_id,
+                    node_type=node_type,
+                    category=RuleCategory.REFERENCE,
+                    severity=Severity.ERROR,
+                    is_fixable=True,
+                    message=f"Node '{node_id}' ({node_type}): missing required 'model' configuration",
+                    fix_hint="Add model config using one of the available models",
+                )
+            )
+        else:
+            errors.append(
+                ValidationError(
+                    rule_id="model.no_available",
+                    node_id=node_id,
+                    node_type=node_type,
+                    category=RuleCategory.REFERENCE,
+                    severity=Severity.ERROR,
+                    is_fixable=False,
+                    message=f"Node '{node_id}' ({node_type}): needs model but no models available",
+                    fix_hint="User must configure a model provider first",
+                )
+            )
+        return errors
+
+    # Check if model config is valid
+    if isinstance(model, dict):
+        provider = model.get("provider", "")
+        name = model.get("name", "")
+
+        # Check for placeholder values
+        if is_placeholder(provider) or is_placeholder(name):
+            if ctx.available_models:
+                errors.append(
+                    ValidationError(
+                        rule_id="model.placeholder",
+                        node_id=node_id,
+                        node_type=node_type,
+                        category=RuleCategory.REFERENCE,
+                        severity=Severity.ERROR,
+                        is_fixable=True,
+                        message=f"Node '{node_id}': model config contains placeholder",
+                        fix_hint="Replace placeholder with actual model from available_models",
+                    )
+                )
+            return errors
+
+        # Check if model exists in available_models
+        if ctx.available_models and provider and name:
+            if not ctx.has_model(provider, name):
+                errors.append(
+                    ValidationError(
+                        rule_id="model.not_found",
+                        node_id=node_id,
+                        node_type=node_type,
+                        category=RuleCategory.REFERENCE,
+                        severity=Severity.ERROR,
+                        is_fixable=True,
+                        message=f"Node '{node_id}': model '{provider}/{name}' not in available models",
+                        fix_hint="Replace with a model from available_models",
+                        details={"provider": provider, "model": name},
+                    )
+                )
+
+    return errors
+
+
+def _check_tool_reference(node: dict[str, Any], ctx: "ValidationContext") -> list[ValidationError]:
+    """Check that tool references are valid and configured."""
+    errors: list[ValidationError] = []
+    node_id = node.get("id", "unknown")
+    node_type = node.get("type", "unknown")
+
+    if node_type != "tool":
+        return errors
+
+    config = node.get("config", {})
+    # Fall back through the known reference fields; keep tool_ref falsy when
+    # neither a key nor a name is present so the missing-key check below fires.
+    tool_ref = config.get("tool_key") or config.get("tool_name") or ""
+    if not tool_ref and config.get("provider_id") and config.get("tool_name"):
+        tool_ref = f"{config['provider_id']}/{config['tool_name']}"
+
+    if not tool_ref:
+        errors.append(
+            ValidationError(
+                rule_id="tool.key.required",
+                node_id=node_id,
+                node_type=node_type,
+                category=RuleCategory.REFERENCE,
+                severity=Severity.ERROR,
+                is_fixable=True,
+                message=f"Node '{node_id}': tool node missing tool_key",
+                fix_hint="Add tool_key from available_tools",
+            )
+        )
+        return errors
+
+    # Check if tool exists
+    if not ctx.has_tool(tool_ref):
+        errors.append(
+            ValidationError(
+                rule_id="tool.not_found",
+                node_id=node_id,
+                node_type=node_type,
+                category=RuleCategory.REFERENCE,
+                severity=Severity.ERROR,
+                is_fixable=True,  # Can be replaced with http-request fallback
+                message=f"Node '{node_id}': tool '{tool_ref}' not found",
+                fix_hint="Use http-request or code node as fallback",
+                details={"tool_ref": tool_ref},
+            )
+        )
+    elif not ctx.is_tool_configured(tool_ref):
+        errors.append(
+            ValidationError(
+                rule_id="tool.not_configured",
+                node_id=node_id,
+                node_type=node_type,
+                category=RuleCategory.REFERENCE,
+                severity=Severity.WARNING,
+                is_fixable=False,  # User needs to configure
+                message=f"Node '{node_id}': tool '{tool_ref}' requires configuration",
+                fix_hint="Configure the tool in Tools settings",
+                details={"tool_ref": tool_ref},
+            )
+        )
+
+    return errors
+
+
+# =============================================================================
+# Register All Rules
+# =============================================================================
+
+# Structure Rules
+register_rule(
+    ValidationRule(
+        id="llm.prompt_template.required",
+        node_types=["llm"],
+        category=RuleCategory.STRUCTURE,
+        
severity=Severity.ERROR, + is_fixable=True, + check=_check_llm_prompt_template, + description="LLM node must have prompt_template", + fix_hint="Add prompt_template with system and user messages", + ) +) + +register_rule( + ValidationRule( + id="http.config.required", + node_types=["http-request"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_http_request_url, + description="HTTP request node must have url and method", + fix_hint="Add url and method to config", + ) +) + +register_rule( + ValidationRule( + id="code.config.required", + node_types=["code"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_code_node, + description="Code node must have code and language", + fix_hint="Add code with main() function and language", + ) +) + +register_rule( + ValidationRule( + id="classifier.classes.required", + node_types=["question-classifier"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_question_classifier, + description="Question classifier must have classes", + fix_hint="Add classes array with classification options", + ) +) + +register_rule( + ValidationRule( + id="extractor.config.required", + node_types=["parameter-extractor"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_parameter_extractor, + description="Parameter extractor must have parameters", + fix_hint="Add parameters array", + ) +) + +register_rule( + ValidationRule( + id="knowledge.config.required", + node_types=["knowledge-retrieval"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=False, + check=_check_knowledge_retrieval, + description="Knowledge retrieval must have dataset_ids", + fix_hint="User must select knowledge base", + ) +) + +register_rule( + ValidationRule( + id="end.outputs.check", + node_types=["end"], + category=RuleCategory.STRUCTURE, + severity=Severity.WARNING, + is_fixable=True, + check=_check_end_node, + description="End node should have outputs", + fix_hint="Add outputs array", + ) +) + +# Semantic Rules +register_rule( + ValidationRule( + id="variable.references.valid", + node_types=["*"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_variable_references, + description="Variable references must point to valid nodes", + fix_hint="Fix variable reference to use valid node ID", + ) +) + +# Edge Validation Rules +register_rule( + ValidationRule( + id="edge.outgoing.required", + node_types=["*"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_node_has_outgoing_edge, + description="Non-end nodes must have outgoing edges", + fix_hint="Add an edge from this node to the next node", + ) +) + +register_rule( + ValidationRule( + id="edge.incoming.required", + node_types=["*"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_node_has_incoming_edge, + description="Non-start nodes must have incoming edges", + fix_hint="Add an edge from a previous node to this node", + ) +) + +register_rule( + ValidationRule( + id="edge.classifier_branches.complete", + node_types=["question-classifier"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_question_classifier_branches, + description="Question classifier must have edges for all classes", + fix_hint="Add edges with sourceHandle for each class ID", + ) +) + 
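+# Downstream code could register additional rules the same way (illustrative
+# sketch; "custom.example" is a hypothetical rule ID, not part of this change):
+#
+#     register_rule(
+#         ValidationRule(
+#             id="custom.example",
+#             node_types=["*"],
+#             category=RuleCategory.SEMANTIC,
+#             severity=Severity.WARNING,
+#             is_fixable=False,
+#             check=lambda node, ctx: [],
+#         )
+#     )
+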
+register_rule( + ValidationRule( + id="edge.if_else_branches.complete", + node_types=["if-else"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_if_else_branches, + description="If-else must have true and false branch edges", + fix_hint="Add edges with sourceHandle 'true' and 'false'", + ) +) + +register_rule( + ValidationRule( + id="edge.targets.valid", + node_types=["*"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_edge_targets_exist, + description="Edge targets must reference existing nodes", + fix_hint="Change edge target to an existing node ID", + ) +) + +# Reference Rules +register_rule( + ValidationRule( + id="model.config.valid", + node_types=["llm", "question-classifier", "parameter-extractor"], + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_model_config, + description="Model configuration must be valid", + fix_hint="Add valid model from available_models", + ) +) + +register_rule( + ValidationRule( + id="tool.reference.valid", + node_types=["tool"], + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_tool_reference, + description="Tool reference must be valid and configured", + fix_hint="Use valid tool or fallback node", + ) +) + +register_rule( + ValidationRule( + id="ifelse.operator.valid", + node_types=["if-else"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_if_else_operators, + description="If-else operators must be valid", + fix_hint="Use standard operators like ≥, ≤, =, ≠", + ) +) diff --git a/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py b/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py new file mode 100644 index 0000000000..9dbb486dd9 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py @@ -0,0 +1,288 @@ +""" +Unit tests for the Mermaid Generator. 
+ +Tests cover: +- Basic workflow rendering +- Reserved word handling ('end' → 'end_node') +- Question classifier multi-branch edges +- If-else branch labels +- Edge validation and skipping +- Tool node formatting +""" + + +from core.workflow.generator.utils.mermaid_generator import generate_mermaid + + +class TestBasicWorkflow: + """Tests for basic workflow Mermaid generation.""" + + def test_simple_start_end_workflow(self): + """Test simple Start → End workflow.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [{"source": "start", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + assert "flowchart TD" in result + assert 'start["type=start|title=Start"]' in result + assert 'end_node["type=end|title=End"]' in result + assert "start --> end_node" in result + + def test_start_llm_end_workflow(self): + """Test Start → LLM → End workflow.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "llm", "type": "llm", "title": "Generate"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [ + {"source": "start", "target": "llm"}, + {"source": "llm", "target": "end"}, + ], + } + result = generate_mermaid(workflow_data) + + assert 'llm["type=llm|title=Generate"]' in result + assert "start --> llm" in result + assert "llm --> end_node" in result + + def test_empty_workflow(self): + """Test empty workflow returns minimal output.""" + workflow_data = {"nodes": [], "edges": []} + result = generate_mermaid(workflow_data) + + assert result == "flowchart TD" + + def test_missing_keys_handled(self): + """Test workflow with missing keys doesn't crash.""" + workflow_data = {} + result = generate_mermaid(workflow_data) + + assert "flowchart TD" in result + + +class TestReservedWords: + """Tests for reserved word handling in node IDs.""" + + def test_end_node_id_is_replaced(self): + """Test 'end' node ID is replaced with 'end_node'.""" + workflow_data = { + "nodes": [{"id": "end", "type": "end", "title": "End"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Should use end_node instead of end + assert "end_node[" in result + assert '"type=end|title=End"' in result + + def test_subgraph_node_id_is_replaced(self): + """Test 'subgraph' node ID is replaced with 'subgraph_node'.""" + workflow_data = { + "nodes": [{"id": "subgraph", "type": "code", "title": "Process"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "subgraph_node[" in result + + def test_edge_uses_safe_ids(self): + """Test edges correctly reference safe IDs after replacement.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [{"source": "start", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + # Edge should use end_node, not end + assert "start --> end_node" in result + assert "start --> end\n" not in result + + +class TestBranchEdges: + """Tests for branching node edge labels.""" + + def test_question_classifier_source_handles(self): + """Test question-classifier edges with sourceHandle labels.""" + workflow_data = { + "nodes": [ + {"id": "classifier", "type": "question-classifier", "title": "Classify"}, + {"id": "refund", "type": "llm", "title": "Handle Refund"}, + {"id": "inquiry", "type": "llm", "title": "Handle Inquiry"}, + ], + "edges": [ + {"source": "classifier", "target": "refund", 
"sourceHandle": "refund"}, + {"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"}, + ], + } + result = generate_mermaid(workflow_data) + + assert "classifier -->|refund| refund" in result + assert "classifier -->|inquiry| inquiry" in result + + def test_if_else_true_false_handles(self): + """Test if-else edges with true/false labels.""" + workflow_data = { + "nodes": [ + {"id": "ifelse", "type": "if-else", "title": "Check"}, + {"id": "yes_branch", "type": "llm", "title": "Yes"}, + {"id": "no_branch", "type": "llm", "title": "No"}, + ], + "edges": [ + {"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"}, + {"source": "ifelse", "target": "no_branch", "sourceHandle": "false"}, + ], + } + result = generate_mermaid(workflow_data) + + assert "ifelse -->|true| yes_branch" in result + assert "ifelse -->|false| no_branch" in result + + def test_source_handle_source_is_ignored(self): + """Test sourceHandle='source' doesn't add label.""" + workflow_data = { + "nodes": [ + {"id": "llm1", "type": "llm", "title": "LLM 1"}, + {"id": "llm2", "type": "llm", "title": "LLM 2"}, + ], + "edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}], + } + result = generate_mermaid(workflow_data) + + # Should be plain arrow without label + assert "llm1 --> llm2" in result + assert "llm1 -->|source|" not in result + + +class TestEdgeValidation: + """Tests for edge validation and error handling.""" + + def test_edge_with_missing_source_is_skipped(self): + """Test edge with non-existent source node is skipped.""" + workflow_data = { + "nodes": [{"id": "end", "type": "end", "title": "End"}], + "edges": [{"source": "nonexistent", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + # Should not contain the invalid edge + assert "nonexistent" not in result + assert "-->" not in result or "nonexistent" not in result + + def test_edge_with_missing_target_is_skipped(self): + """Test edge with non-existent target node is skipped.""" + workflow_data = { + "nodes": [{"id": "start", "type": "start", "title": "Start"}], + "edges": [{"source": "start", "target": "nonexistent"}], + } + result = generate_mermaid(workflow_data) + + # Edge should be skipped + assert "start --> nonexistent" not in result + + def test_edge_without_source_or_target_is_skipped(self): + """Test edge missing source or target is skipped.""" + workflow_data = { + "nodes": [{"id": "start", "type": "start", "title": "Start"}], + "edges": [{"source": "start"}, {"target": "start"}, {}], + } + result = generate_mermaid(workflow_data) + + # No edges should be rendered + assert result.count("-->") == 0 + + +class TestToolNodes: + """Tests for tool node formatting.""" + + def test_tool_node_includes_tool_key(self): + """Test tool node includes tool_key in label.""" + workflow_data = { + "nodes": [ + { + "id": "search", + "type": "tool", + "title": "Search", + "config": {"tool_key": "google/search"}, + } + ], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert 'search["type=tool|title=Search|tool=google/search"]' in result + + def test_tool_node_with_tool_name_fallback(self): + """Test tool node uses tool_name as fallback.""" + workflow_data = { + "nodes": [ + { + "id": "tool1", + "type": "tool", + "title": "My Tool", + "config": {"tool_name": "my_tool"}, + } + ], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "tool=my_tool" in result + + def test_tool_node_missing_tool_key_shows_unknown(self): + """Test tool node without tool_key shows 'unknown'.""" + 
workflow_data = { + "nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "tool=unknown" in result + + +class TestNodeFormatting: + """Tests for node label formatting.""" + + def test_quotes_in_title_are_escaped(self): + """Test double quotes in title are replaced with single quotes.""" + workflow_data = { + "nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Double quotes should be replaced + assert "Say 'Hello'" in result + assert 'Say "Hello"' not in result + + def test_node_without_id_is_skipped(self): + """Test node without id is skipped.""" + workflow_data = { + "nodes": [{"type": "llm", "title": "No ID"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Should only have flowchart header + lines = [line for line in result.split("\n") if line.strip()] + assert len(lines) == 1 + + def test_node_default_values(self): + """Test node with missing type/title uses defaults.""" + workflow_data = { + "nodes": [{"id": "node1"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "type=unknown" in result + assert "title=Untitled" in result diff --git a/api/tests/unit_tests/core/llm_generator/test_node_repair.py b/api/tests/unit_tests/core/llm_generator/test_node_repair.py new file mode 100644 index 0000000000..a92a7d0125 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_node_repair.py @@ -0,0 +1,81 @@ +from core.workflow.generator.utils.node_repair import NodeRepair + + +class TestNodeRepair: + """Tests for NodeRepair utility.""" + + def test_repair_if_else_valid_operators(self): + """Test that valid operators remain unchanged.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + "config": { + "cases": [ + { + "conditions": [ + {"comparison_operator": "≥", "value": "1"}, + {"comparison_operator": "=", "value": "2"}, + ] + } + ] + }, + } + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is False + assert result.nodes == nodes + + def test_repair_if_else_invalid_operators(self): + """Test that invalid operators are normalized.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + "config": { + "cases": [ + { + "conditions": [ + {"comparison_operator": ">=", "value": "1"}, + {"comparison_operator": "<=", "value": "2"}, + {"comparison_operator": "!=", "value": "3"}, + {"comparison_operator": "==", "value": "4"}, + ] + } + ] + }, + } + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is True + assert len(result.repairs_made) == 4 + + conditions = result.nodes[0]["config"]["cases"][0]["conditions"] + assert conditions[0]["comparison_operator"] == "≥" + assert conditions[1]["comparison_operator"] == "≤" + assert conditions[2]["comparison_operator"] == "≠" + assert conditions[3]["comparison_operator"] == "=" + + def test_repair_ignores_other_nodes(self): + """Test that other node types are ignored.""" + nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}] + result = NodeRepair.repair(nodes) + assert result.was_repaired is False + assert result.nodes[0]["config"]["some_field"] == ">=" + + def test_repair_handles_missing_config(self): + """Test robustness against missing fields.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + # Missing config + }, + { + "id": "node2", + "type": "if-else", + "config": {}, # Missing cases + }, + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is 
False diff --git a/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py b/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py new file mode 100644 index 0000000000..a741c30c7a --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py @@ -0,0 +1,173 @@ +""" +Unit tests for the Planner Prompts. + +Tests cover: +- Tool formatting for planner context +- Edge cases with missing fields +- Empty tool lists +""" + + +from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner + + +class TestFormatToolsForPlanner: + """Tests for format_tools_for_planner function.""" + + def test_empty_tools_returns_default_message(self): + """Test empty tools list returns default message.""" + result = format_tools_for_planner([]) + + assert result == "No external tools available." + + def test_none_tools_returns_default_message(self): + """Test None tools list returns default message.""" + result = format_tools_for_planner(None) + + assert result == "No external tools available." + + def test_single_tool_formatting(self): + """Test single tool is formatted correctly.""" + tools = [ + { + "provider_id": "google", + "tool_key": "search", + "tool_label": "Google Search", + "tool_description": "Search the web using Google", + } + ] + result = format_tools_for_planner(tools) + + assert "[google/search]" in result + assert "Google Search" in result + assert "Search the web using Google" in result + + def test_multiple_tools_formatting(self): + """Test multiple tools are formatted correctly.""" + tools = [ + { + "provider_id": "google", + "tool_key": "search", + "tool_label": "Search", + "tool_description": "Web search", + }, + { + "provider_id": "slack", + "tool_key": "send_message", + "tool_label": "Send Message", + "tool_description": "Send a Slack message", + }, + ] + result = format_tools_for_planner(tools) + + lines = result.strip().split("\n") + assert len(lines) == 2 + assert "[google/search]" in result + assert "[slack/send_message]" in result + + def test_tool_without_provider_uses_key_only(self): + """Test tool without provider_id uses tool_key only.""" + tools = [ + { + "tool_key": "my_tool", + "tool_label": "My Tool", + "tool_description": "A custom tool", + } + ] + result = format_tools_for_planner(tools) + + # Should format as [my_tool] without provider prefix + assert "[my_tool]" in result + assert "My Tool" in result + + def test_tool_with_tool_name_fallback(self): + """Test tool uses tool_name when tool_key is missing.""" + tools = [ + { + "tool_name": "fallback_tool", + "description": "Fallback description", + } + ] + result = format_tools_for_planner(tools) + + assert "fallback_tool" in result + assert "Fallback description" in result + + def test_tool_with_missing_description(self): + """Test tool with missing description doesn't crash.""" + tools = [ + { + "provider_id": "test", + "tool_key": "tool1", + "tool_label": "Tool 1", + } + ] + result = format_tools_for_planner(tools) + + assert "[test/tool1]" in result + assert "Tool 1" in result + + def test_tool_with_all_missing_fields(self): + """Test tool with all fields missing uses defaults.""" + tools = [{}] + result = format_tools_for_planner(tools) + + # Should not crash, may produce minimal output + assert isinstance(result, str) + + def test_tool_uses_provider_fallback(self): + """Test tool uses 'provider' when 'provider_id' is missing.""" + tools = [ + { + "provider": "openai", + "tool_key": "dalle", + "tool_label": "DALL-E", + "tool_description": "Generate images", + } + 
] + result = format_tools_for_planner(tools) + + assert "[openai/dalle]" in result + + def test_tool_label_fallback_to_key(self): + """Test tool_label falls back to tool_key when missing.""" + tools = [ + { + "provider_id": "test", + "tool_key": "my_key", + "tool_description": "Description here", + } + ] + result = format_tools_for_planner(tools) + + # Label should fallback to key + assert "my_key" in result + assert "Description here" in result + + +class TestPlannerPromptConstants: + """Tests for planner prompt constant availability.""" + + def test_planner_system_prompt_exists(self): + """Test PLANNER_SYSTEM_PROMPT is defined.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT + + assert PLANNER_SYSTEM_PROMPT is not None + assert len(PLANNER_SYSTEM_PROMPT) > 0 + assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT + + def test_planner_user_prompt_exists(self): + """Test PLANNER_USER_PROMPT is defined.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT + + assert PLANNER_USER_PROMPT is not None + assert "{instruction}" in PLANNER_USER_PROMPT + + def test_planner_system_prompt_has_required_sections(self): + """Test PLANNER_SYSTEM_PROMPT has required XML sections.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT + + assert "" in PLANNER_SYSTEM_PROMPT + assert "" in PLANNER_SYSTEM_PROMPT + assert "" in PLANNER_SYSTEM_PROMPT + assert "" in PLANNER_SYSTEM_PROMPT diff --git a/api/tests/unit_tests/core/llm_generator/test_validation_engine.py b/api/tests/unit_tests/core/llm_generator/test_validation_engine.py new file mode 100644 index 0000000000..477b0cdcf7 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_validation_engine.py @@ -0,0 +1,536 @@ +""" +Unit tests for the Validation Rule Engine. 
+ +Tests cover: +- Structure rules (required fields, types, formats) +- Semantic rules (variable references, edge connections) +- Reference rules (model exists, tool configured, dataset valid) +- ValidationEngine integration +""" + + +from core.workflow.generator.validation import ( + ValidationContext, + ValidationEngine, +) +from core.workflow.generator.validation.rules import ( + extract_variable_refs, + is_placeholder, +) + + +class TestPlaceholderDetection: + """Tests for placeholder detection utility.""" + + def test_detects_please_select(self): + assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True + + def test_detects_your_prefix(self): + assert is_placeholder("YOUR_API_KEY") is True + + def test_detects_todo(self): + assert is_placeholder("TODO: fill this in") is True + + def test_detects_placeholder(self): + assert is_placeholder("PLACEHOLDER_VALUE") is True + + def test_detects_example_prefix(self): + assert is_placeholder("EXAMPLE_URL") is True + + def test_detects_replace_prefix(self): + assert is_placeholder("REPLACE_WITH_ACTUAL") is True + + def test_case_insensitive(self): + assert is_placeholder("please_select") is True + assert is_placeholder("Please_Select") is True + + def test_valid_values_not_detected(self): + assert is_placeholder("https://api.example.com") is False + assert is_placeholder("gpt-4") is False + assert is_placeholder("my_variable") is False + + def test_non_string_returns_false(self): + assert is_placeholder(123) is False + assert is_placeholder(None) is False + assert is_placeholder(["list"]) is False + + +class TestVariableRefExtraction: + """Tests for variable reference extraction.""" + + def test_extracts_simple_ref(self): + refs = extract_variable_refs("Hello {{#start.query#}}") + assert refs == [("start", "query")] + + def test_extracts_multiple_refs(self): + refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}") + assert refs == [("node1", "output"), ("node2", "text")] + + def test_extracts_nested_field(self): + refs = extract_variable_refs("{{#http_request.body#}}") + assert refs == [("http_request", "body")] + + def test_no_refs_returns_empty(self): + refs = extract_variable_refs("No references here") + assert refs == [] + + def test_handles_malformed_refs(self): + refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}") + assert refs == [] + + +class TestValidationContext: + """Tests for ValidationContext.""" + + def test_node_map_lookup(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start"}, + {"id": "llm_1", "type": "llm"}, + ] + ) + assert ctx.get_node("start") == {"id": "start", "type": "start"} + assert ctx.get_node("nonexistent") is None + + def test_model_set(self): + ctx = ValidationContext( + available_models=[ + {"provider": "openai", "model": "gpt-4"}, + {"provider": "anthropic", "model": "claude-3"}, + ] + ) + assert ctx.has_model("openai", "gpt-4") is True + assert ctx.has_model("anthropic", "claude-3") is True + assert ctx.has_model("openai", "gpt-3.5") is False + + def test_tool_set(self): + ctx = ValidationContext( + available_tools=[ + {"provider_id": "google", "tool_key": "search", "is_team_authorization": True}, + {"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False}, + ] + ) + assert ctx.has_tool("google/search") is True + assert ctx.has_tool("search") is True + assert ctx.is_tool_configured("google/search") is True + assert ctx.is_tool_configured("slack/send_message") is False + + def test_upstream_downstream_nodes(self): + ctx = 
ValidationContext( + nodes=[ + {"id": "start", "type": "start"}, + {"id": "llm", "type": "llm"}, + {"id": "end", "type": "end"}, + ], + edges=[ + {"source": "start", "target": "llm"}, + {"source": "llm", "target": "end"}, + ], + ) + assert ctx.get_upstream_nodes("llm") == ["start"] + assert ctx.get_downstream_nodes("llm") == ["end"] + + +class TestStructureRules: + """Tests for structure validation rules.""" + + def test_llm_missing_prompt_template(self): + ctx = ValidationContext( + nodes=[{"id": "llm_1", "type": "llm", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_errors + errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_llm_with_prompt_template_passes(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [ + {"role": "system", "text": "You are helpful"}, + {"role": "user", "text": "Hello"}, + ] + }, + } + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # No prompt_template errors + errors = [e for e in result.all_errors if "prompt_template" in e.rule_id] + assert len(errors) == 0 + + def test_http_request_missing_url(self): + ctx = ValidationContext( + nodes=[{"id": "http_1", "type": "http-request", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "http.url" in e.rule_id] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_http_request_placeholder_url(self): + ctx = ValidationContext( + nodes=[ + { + "id": "http_1", + "type": "http-request", + "config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"}, + } + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "placeholder" in e.rule_id] + assert len(errors) == 1 + + def test_code_node_missing_fields(self): + ctx = ValidationContext( + nodes=[{"id": "code_1", "type": "code", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + error_rules = {e.rule_id for e in result.all_errors} + assert "code.code.required" in error_rules + assert "code.language.required" in error_rules + + def test_knowledge_retrieval_missing_dataset(self): + ctx = ValidationContext( + nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id] + assert len(errors) == 1 + assert errors[0].is_fixable is False # User must configure + + +class TestSemanticRules: + """Tests for semantic validation rules.""" + + def test_valid_variable_reference(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [ + {"role": "user", "text": "Process: {{#start.query#}}"} + ] + }, + }, + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # No variable reference errors + errors = [e for e in result.all_errors if "variable.ref" in e.rule_id] + assert len(errors) == 0 + + def test_invalid_variable_reference(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [ + {"role": "user", "text": "Process: {{#nonexistent.field#}}"} + ] + }, + }, + ] + ) + 
engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "variable.ref" in e.rule_id] + assert len(errors) == 1 + assert "nonexistent" in errors[0].message + + def test_edge_validation(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + edges=[ + {"source": "start", "target": "end"}, + {"source": "nonexistent", "target": "end"}, + ], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "edge" in e.rule_id] + assert len(errors) == 1 + assert "nonexistent" in errors[0].message + + +class TestReferenceRules: + """Tests for reference validation rules (models, tools).""" + + def test_llm_missing_model_with_available(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.required"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_llm_missing_model_no_available(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[], # No models available + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.no_available"] + assert len(errors) == 1 + assert errors[0].is_fixable is False + + def test_llm_with_valid_model(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [{"role": "user", "text": "Hi"}], + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "model" in e.rule_id] + assert len(errors) == 0 + + def test_llm_with_invalid_model(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [{"role": "user", "text": "Hi"}], + "model": {"provider": "openai", "name": "gpt-99"}, + }, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.not_found"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_tool_node_not_found(self): + ctx = ValidationContext( + nodes=[ + { + "id": "tool_1", + "type": "tool", + "config": {"tool_key": "nonexistent/tool"}, + } + ], + available_tools=[], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"] + assert len(errors) == 1 + + def test_tool_node_not_configured(self): + ctx = ValidationContext( + nodes=[ + { + "id": "tool_1", + "type": "tool", + "config": {"tool_key": "google/search"}, + } + ], + available_tools=[ + {"provider_id": "google", "tool_key": "search", "is_team_authorization": False} + ], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"] + assert len(errors) == 1 + assert 
errors[0].is_fixable is False + + +class TestValidationResult: + """Tests for ValidationResult classification.""" + + def test_has_errors(self): + ctx = ValidationContext( + nodes=[{"id": "llm_1", "type": "llm", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_errors is True + assert result.is_valid is False + + def test_has_fixable_errors(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_fixable_errors is True + assert len(result.fixable_errors) > 0 + + def test_get_fixable_by_node(self): + ctx = ValidationContext( + nodes=[ + {"id": "llm_1", "type": "llm", "config": {}}, + {"id": "http_1", "type": "http-request", "config": {}}, + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + by_node = result.get_fixable_by_node() + assert "llm_1" in by_node + assert "http_1" in by_node + + def test_to_dict(self): + ctx = ValidationContext( + nodes=[{"id": "llm_1", "type": "llm", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + d = result.to_dict() + assert "fixable" in d + assert "user_required" in d + assert "warnings" in d + assert "all_warnings" in d + assert "stats" in d + + +class TestIntegration: + """Integration tests for the full validation pipeline.""" + + def test_complete_workflow_validation(self): + """Test validation of a complete workflow.""" + ctx = ValidationContext( + nodes=[ + { + "id": "start", + "type": "start", + "config": {"variables": [{"variable": "query", "type": "text-input"}]}, + }, + { + "id": "llm_1", + "type": "llm", + "config": { + "model": {"provider": "openai", "name": "gpt-4"}, + "prompt_template": [{"role": "user", "text": "{{#start.query#}}"}], + }, + }, + { + "id": "end", + "type": "end", + "config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]}, + }, + ], + edges=[ + {"source": "start", "target": "llm_1"}, + {"source": "llm_1", "target": "end"}, + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # Should have no errors + assert result.is_valid is True + assert len(result.fixable_errors) == 0 + assert len(result.user_required_errors) == 0 + + def test_workflow_with_multiple_errors(self): + """Test workflow with multiple types of errors.""" + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": {}, # Missing prompt_template and model + }, + { + "id": "kb_1", + "type": "knowledge-retrieval", + "config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]}, + }, + {"id": "end", "type": "end", "config": {}}, + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # Should have multiple errors + assert result.has_errors is True + assert len(result.fixable_errors) >= 2 # model, prompt_template + assert len(result.user_required_errors) >= 1 # dataset placeholder + + # Check stats + assert result.stats["total_nodes"] == 4 + assert result.stats["total_errors"] >= 3 + + + diff --git a/api/tests/unit_tests/core/llm_generator/test_workflow_validator_vibe.py b/api/tests/unit_tests/core/llm_generator/test_workflow_validator_vibe.py new file 
mode 100644 index 0000000000..39e2ba5a0e --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_workflow_validator_vibe.py @@ -0,0 +1,435 @@ +""" +Unit tests for the Vibe Workflow Validator. + +Tests cover: +- Basic validation function +- User-friendly validation hints +- Edge cases and error handling +""" + + +from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator + + +class TestValidationHint: + """Tests for ValidationHint dataclass.""" + + def test_hint_creation(self): + """Test creating a validation hint.""" + hint = ValidationHint( + node_id="llm_1", + field="model", + message="Model is not configured", + severity="error", + ) + assert hint.node_id == "llm_1" + assert hint.field == "model" + assert hint.message == "Model is not configured" + assert hint.severity == "error" + + def test_hint_with_suggestion(self): + """Test hint with suggestion.""" + hint = ValidationHint( + node_id="http_1", + field="url", + message="URL is required", + severity="error", + suggestion="Add a valid URL like https://api.example.com", + ) + assert hint.suggestion is not None + + +class TestWorkflowValidatorBasic: + """Tests for basic validation scenarios.""" + + def test_empty_workflow_is_valid(self): + """Test empty workflow passes validation.""" + workflow_data = {"nodes": [], "edges": []} + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + # Empty but valid structure + assert is_valid is True + assert len(hints) == 0 + + def test_minimal_valid_workflow(self): + """Test minimal Start → End workflow.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [{"source": "start", "target": "end"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + assert is_valid is True + + def test_complete_workflow_with_llm(self): + """Test complete workflow with LLM node.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {"variables": []}}, + { + "id": "llm", + "type": "llm", + "config": { + "model": {"provider": "openai", "name": "gpt-4"}, + "prompt_template": [{"role": "user", "text": "Hello"}], + }, + }, + {"id": "end", "type": "end", "config": {"outputs": []}}, + ], + "edges": [ + {"source": "start", "target": "llm"}, + {"source": "llm", "target": "end"}, + ], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + # Should pass with no critical errors + errors = [h for h in hints if h.severity == "error"] + assert len(errors) == 0 + + +class TestVariableReferenceValidation: + """Tests for variable reference validation.""" + + def test_valid_variable_reference(self): + """Test valid variable reference passes.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Query: {{#start.query#}}"}]}, + }, + ], + "edges": [{"source": "start", "target": "llm"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + ref_errors = [h for h in hints if "reference" in h.message.lower()] + assert len(ref_errors) == 0 + + def test_invalid_variable_reference(self): + """Test invalid variable reference generates hint.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "{{#nonexistent.field#}}"}]}, + }, + ], + "edges": [{"source": 
"start", "target": "llm"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + # Should have a hint about invalid reference + ref_hints = [h for h in hints if "nonexistent" in h.message or "reference" in h.message.lower()] + assert len(ref_hints) >= 1 + + +class TestEdgeValidation: + """Tests for edge validation.""" + + def test_edge_with_invalid_source(self): + """Test edge with non-existent source generates hint.""" + workflow_data = { + "nodes": [{"id": "end", "type": "end", "config": {}}], + "edges": [{"source": "nonexistent", "target": "end"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + # Should have hint about invalid edge + edge_hints = [h for h in hints if "edge" in h.message.lower() or "source" in h.message.lower()] + assert len(edge_hints) >= 1 + + def test_edge_with_invalid_target(self): + """Test edge with non-existent target generates hint.""" + workflow_data = { + "nodes": [{"id": "start", "type": "start", "config": {}}], + "edges": [{"source": "start", "target": "nonexistent"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + edge_hints = [h for h in hints if "edge" in h.message.lower() or "target" in h.message.lower()] + assert len(edge_hints) >= 1 + + +class TestToolValidation: + """Tests for tool node validation.""" + + def test_tool_node_found_in_available(self): + """Test tool node that exists in available tools.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "tool1", + "type": "tool", + "config": {"tool_key": "google/search"}, + }, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [{"source": "start", "target": "tool1"}, {"source": "tool1", "target": "end"}], + } + available_tools = [{"provider_id": "google", "tool_key": "search", "is_team_authorization": True}] + is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools) + + tool_errors = [h for h in hints if h.severity == "error" and "tool" in h.message.lower()] + assert len(tool_errors) == 0 + + def test_tool_node_not_found(self): + """Test tool node not in available tools generates hint.""" + workflow_data = { + "nodes": [ + { + "id": "tool1", + "type": "tool", + "config": {"tool_key": "unknown/tool"}, + } + ], + "edges": [], + } + available_tools = [] + is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools) + + tool_hints = [h for h in hints if "tool" in h.message.lower()] + assert len(tool_hints) >= 1 + + +class TestQuestionClassifierValidation: + """Tests for question-classifier node validation.""" + + def test_question_classifier_with_classes(self): + """Test question-classifier with valid classes.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "classifier", + "type": "question-classifier", + "config": { + "classes": [ + {"id": "class1", "name": "Class 1"}, + {"id": "class2", "name": "Class 2"}, + ], + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}, + }, + }, + {"id": "h1", "type": "llm", "config": {}}, + {"id": "h2", "type": "llm", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [ + {"source": "start", "target": "classifier"}, + {"source": "classifier", "sourceHandle": "class1", "target": "h1"}, + {"source": "classifier", "sourceHandle": "class2", "target": "h2"}, + {"source": "h1", "target": "end"}, + {"source": "h2", "target": "end"}, + ], + } + available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}] + 
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models) + + class_errors = [h for h in hints if "class" in h.message.lower() and h.severity == "error"] + assert len(class_errors) == 0 + + def test_question_classifier_missing_classes(self): + """Test question-classifier without classes generates hint.""" + workflow_data = { + "nodes": [ + { + "id": "classifier", + "type": "question-classifier", + "config": {"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}}, + } + ], + "edges": [], + } + available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}] + is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models) + + # Should have hint about missing classes + class_hints = [h for h in hints if "class" in h.message.lower()] + assert len(class_hints) >= 1 + + +class TestHttpRequestValidation: + """Tests for HTTP request node validation.""" + + def test_http_request_with_url(self): + """Test HTTP request with valid URL.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "http", + "type": "http-request", + "config": {"url": "https://api.example.com", "method": "GET"}, + }, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [{"source": "start", "target": "http"}, {"source": "http", "target": "end"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + url_errors = [h for h in hints if "url" in h.message.lower() and h.severity == "error"] + assert len(url_errors) == 0 + + def test_http_request_missing_url(self): + """Test HTTP request without URL generates hint.""" + workflow_data = { + "nodes": [ + { + "id": "http", + "type": "http-request", + "config": {"method": "GET"}, + } + ], + "edges": [], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + url_hints = [h for h in hints if "url" in h.message.lower()] + assert len(url_hints) >= 1 + + +class TestParameterExtractorValidation: + """Tests for parameter-extractor node validation.""" + + def test_parameter_extractor_valid_params(self): + """Test parameter-extractor with valid parameters.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "extractor", + "type": "parameter-extractor", + "config": { + "instruction": "Extract info", + "parameters": [ + { + "name": "name", + "type": "string", + "description": "Name", + "required": True, + } + ], + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}, + }, + }, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [{"source": "start", "target": "extractor"}, {"source": "extractor", "target": "end"}], + } + available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}] + is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models) + + errors = [h for h in hints if h.severity == "error"] + assert len(errors) == 0 + + def test_parameter_extractor_missing_required_field(self): + """Test parameter-extractor missing 'required' field in parameter item.""" + workflow_data = { + "nodes": [ + { + "id": "extractor", + "type": "parameter-extractor", + "config": { + "instruction": "Extract info", + "parameters": [ + { + "name": "name", + "type": "string", + "description": "Name", + # Missing 'required' + } + ], + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}, + }, + } + ], + "edges": [], + } + available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}] + 
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models) + + errors = [h for h in hints if "required" in h.message and h.severity == "error"] + assert len(errors) >= 1 + assert "parameter-extractor" in errors[0].node_type + + +class TestIfElseValidation: + """Tests for if-else node validation.""" + + def test_if_else_valid_operators(self): + """Test if-else with valid operators.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "ifelse", + "type": "if-else", + "config": { + "cases": [{"case_id": "c1", "conditions": [{"comparison_operator": "≥", "value": "1"}]}] + }, + }, + {"id": "t", "type": "llm", "config": {}}, + {"id": "f", "type": "llm", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [ + {"source": "start", "target": "ifelse"}, + {"source": "ifelse", "sourceHandle": "true", "target": "t"}, + {"source": "ifelse", "sourceHandle": "false", "target": "f"}, + {"source": "t", "target": "end"}, + {"source": "f", "target": "end"}, + ], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + errors = [h for h in hints if h.severity == "error"] + # The bare LLM nodes may also surface model-config errors here, + # so check specifically for operator errors.
+ operator_errors = [h for h in errors if "operator" in h.message] + assert len(operator_errors) == 0 + + def test_if_else_invalid_operators(self): + """Test if-else with invalid operators.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "ifelse", + "type": "if-else", + "config": { + "cases": [{"case_id": "c1", "conditions": [{"comparison_operator": ">=", "value": "1"}]}] + }, + }, + {"id": "t", "type": "llm", "config": {}}, + {"id": "f", "type": "llm", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [ + {"source": "start", "target": "ifelse"}, + {"source": "ifelse", "sourceHandle": "true", "target": "t"}, + {"source": "ifelse", "sourceHandle": "false", "target": "f"}, + {"source": "t", "target": "end"}, + {"source": "f", "target": "end"}, + ], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + operator_errors = [h for h in hints if "operator" in h.message and h.severity == "error"] + assert len(operator_errors) > 0 + assert "≥" in operator_errors[0].suggestion diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-vibe.test.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-vibe.test.ts new file mode 100644 index 0000000000..11d19ce9d2 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-vibe.test.ts @@ -0,0 +1,82 @@ + +import { describe, it, expect } from 'vitest' +import { replaceVariableReferences } from '../use-workflow-vibe' +import { BlockEnum } from '@/app/components/workflow/types' + +// Mock types needed for the test +interface NodeData { + title: string + [key: string]: any +} + +describe('use-workflow-vibe', () => { + describe('replaceVariableReferences', () => { + it('should replace variable references in strings', () => { + const data = { + title: 'Test Node', + prompt: 'Hello {{#old_id.query#}}', + } + const nodeIdMap = new Map() + nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } }) + + const result = replaceVariableReferences(data, nodeIdMap) as NodeData + expect(result.prompt).toBe('Hello {{#new_uuid.query#}}') + }) + + it('should handle multiple references in one string', () => { + const data = { + title: 'Test Node', + text: '{{#node1.out#}} and {{#node2.out#}}', + } + const nodeIdMap = new Map() + nodeIdMap.set('node1', { id: 'uuid1', data: { type: 'llm' } }) + nodeIdMap.set('node2', { id: 'uuid2', data: { type: 'llm' } }) + + const result = replaceVariableReferences(data, nodeIdMap) as NodeData + expect(result.text).toBe('{{#uuid1.out#}} and {{#uuid2.out#}}') + }) + + it('should replace variable references in value_selector arrays', () => { + const data = { + title: 'End Node', + outputs: [ + { + variable: 'result', + value_selector: ['old_id', 'text'], + }, + ], + } + const nodeIdMap = new Map() + nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } }) + + const result = replaceVariableReferences(data, nodeIdMap) as NodeData + expect(result.outputs[0].value_selector).toEqual(['new_uuid', 'text']) + }) + + it('should handle nested objects recursively', () => { + const data = { + config: { + model: { + prompt: '{{#old_id.text#}}', + }, + }, + } + const nodeIdMap = new Map() + nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } }) + + const result = replaceVariableReferences(data, nodeIdMap) as any + expect(result.config.model.prompt).toBe('{{#new_uuid.text#}}') + }) + + it('should ignore missing node mappings', () => { + const data = { + text: '{{#missing_id.text#}}', + } + const
nodeIdMap = new Map() + // missing_id is not in map + + const result = replaceVariableReferences(data, nodeIdMap) as NodeData + expect(result.text).toBe('{{#missing_id.text#}}') + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-vibe.tsx b/web/app/components/workflow/hooks/use-workflow-vibe.tsx index 72e98301b8..8ec2916513 100644 --- a/web/app/components/workflow/hooks/use-workflow-vibe.tsx +++ b/web/app/components/workflow/hooks/use-workflow-vibe.tsx @@ -39,6 +39,7 @@ import { getNodeCustomTypeByNodeDataType, getNodesConnectedSourceOrTargetHandleIdsMap, } from '../utils' +import { initialNodes as initializeNodeData } from '../utils/workflow-init' import { useNodesMetaData } from './use-nodes-meta-data' import { useNodesSyncDraft } from './use-nodes-sync-draft' import { useNodesReadOnly } from './use-workflow' @@ -115,7 +116,7 @@ const normalizeProviderIcon = (icon?: ToolWithProvider['icon']) => { * - Mixed content objects: {type: "mixed", value: "..."} → normalized to string * - Field name correction based on node type */ -const replaceVariableReferences = ( +export const replaceVariableReferences = ( data: unknown, nodeIdMap: Map, parentKey?: string, @@ -124,6 +125,11 @@ const replaceVariableReferences = ( // Replace {{#old_id.field#}} patterns and correct field names return data.replace(/\{\{#([^.#]+)\.([^#]+)#\}\}/g, (match, oldId, field) => { const newNode = nodeIdMap.get(oldId) + // #region agent log + if (!newNode) { + console.warn(`[VIBE DEBUG] replaceVariableReferences: No mapping for "${oldId}" in template "${match}"`) + } + // #endregion if (newNode) { const nodeType = newNode.data?.type as string || '' const correctedField = correctFieldName(field, nodeType) @@ -138,6 +144,11 @@ const replaceVariableReferences = ( if (data.length >= 2 && typeof data[0] === 'string' && typeof data[1] === 'string') { const potentialNodeId = data[0] const newNode = nodeIdMap.get(potentialNodeId) + // #region agent log + if (!newNode && !['sys', 'env', 'conversation'].includes(potentialNodeId)) { + console.warn(`[VIBE DEBUG] replaceVariableReferences: No mapping for "${potentialNodeId}" in selector [${data.join(', ')}]`) + } + // #endregion if (newNode) { const nodeType = newNode.data?.type as string || '' const correctedField = correctFieldName(data[1], nodeType) @@ -598,6 +609,8 @@ export const useWorkflowVibe = () => { const { getNodes } = store.getState() const nodes = getNodes() + + if (!nodesMetaDataMap) { Toast.notify({ type: 'error', message: t('workflow.vibe.nodesUnavailable') }) return { nodes: [], edges: [] } @@ -699,12 +712,59 @@ export const useWorkflowVibe = () => { } } - // For any node with model config, ALWAYS use user's default model - if (backendConfig.model && defaultModel) { - mergedConfig.model = { - provider: defaultModel.provider.provider, - name: defaultModel.model, - mode: 'chat', + // For End nodes, ensure outputs have value_selector format + // New format (preferred): {"outputs": [{"variable": "name", "value_selector": ["nodeId", "field"]}]} + // Legacy format (fallback): {"outputs": [{"variable": "name", "value": "{{#nodeId.field#}}"}]} + if (nodeType === BlockEnum.End && backendConfig.outputs) { + const outputs = backendConfig.outputs as Array<{ variable?: string, value?: string, value_selector?: string[] }> + mergedConfig.outputs = outputs.map((output) => { + // Preferred: value_selector array format (new LLM output format) + if (output.value_selector && Array.isArray(output.value_selector)) { + return output + } + // Parse value like 
"{{#nodeId.field#}}" into ["nodeId", "field"] + if (output.value) { + const match = output.value.match(/\{\{#([^.]+)\.([^#]+)#\}\}/) + if (match) { + return { + variable: output.variable, + value_selector: [match[1], match[2]], + } + } + } + // Fallback: return with empty value_selector to prevent crash + return { + variable: output.variable || 'output', + value_selector: [], + } + }) + } + + // For Parameter Extractor nodes, ensure each parameter has a 'required' field + // Backend may omit this field, but Dify's Pydantic model requires it + if (nodeType === BlockEnum.ParameterExtractor && backendConfig.parameters) { + const parameters = backendConfig.parameters as Array<{ name?: string, type?: string, description?: string, required?: boolean }> + mergedConfig.parameters = parameters.map((param) => ({ + ...param, + required: param.required ?? true, // Default to required if not specified + })) + } + + // For any node with model config, ALWAYS use user's configured model + // This prevents "Model not exist" errors when LLM generates models the user doesn't have configured + // Applies to: LLM, QuestionClassifier, ParameterExtractor, and any future model-dependent nodes + if (backendConfig.model) { + // Try to use defaultModel first, fallback to first available model from modelList + const fallbackModel = modelList?.[0]?.models?.[0] + const modelProvider = defaultModel?.provider?.provider || modelList?.[0]?.provider + const modelName = defaultModel?.model || fallbackModel?.model + + if (modelProvider && modelName) { + mergedConfig.model = { + provider: modelProvider, + name: modelName, + mode: 'chat', + } } } @@ -731,10 +791,19 @@ export const useWorkflowVibe = () => { } // Replace variable references in all node configs using the nodeIdMap + // This converts {{#old_id.field#}} to {{#new_uuid.field#}} + for (const node of newNodes) { node.data = replaceVariableReferences(node.data, nodeIdMap) as typeof node.data } + // Use Dify's standard node initialization to handle all node types generically + // This sets up _targetBranches for question-classifier/if-else, _children for iteration/loop, etc. 
+ const initializedNodes = initializeNodeData(newNodes, []) + + // Update newNodes with initialized data + newNodes.splice(0, newNodes.length, ...initializedNodes) + if (!newNodes.length) { Toast.notify({ type: 'error', message: t('workflow.vibe.invalidFlowchart') }) return { nodes: [], edges: [] } @@ -762,12 +831,16 @@ export const useWorkflowVibe = () => { zIndex: 0, }) + const newEdges: Edge[] = [] for (const edgeSpec of backendEdges) { const sourceNode = nodeIdMap.get(edgeSpec.source) const targetNode = nodeIdMap.get(edgeSpec.target) - if (!sourceNode || !targetNode) + + if (!sourceNode || !targetNode) { + console.warn(`[VIBE] Edge skipped: source=${edgeSpec.source} (found=${!!sourceNode}), target=${edgeSpec.target} (found=${!!targetNode})`) continue + } let sourceHandle = edgeSpec.sourceHandle || 'source' // Handle IfElse branch handles @@ -775,9 +848,11 @@ export const useWorkflowVibe = () => { sourceHandle = 'source' } + newEdges.push(buildEdge(sourceNode, targetNode, sourceHandle, edgeSpec.targetHandle || 'target')) } + // Layout nodes const bounds = nodes.reduce( (acc, node) => { @@ -878,11 +953,15 @@ export const useWorkflowVibe = () => { } }) + + setNodes(updatedNodes) setEdges([...edges, ...newEdges]) saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: newNodes[0].id }) handleSyncWorkflowDraft() + + workflowStore.setState(state => ({ ...state, showVibePanel: false, @@ -1194,81 +1273,128 @@ export const useWorkflowVibe = () => { output_schema: tool.output_schema, })) - const stream = await generateFlowchart({ - instruction: trimmed, - model_config: latestModelConfig!, - existing_nodes: existingNodesPayload, - tools: toolsPayload, - regenerate_mode: regenerateMode, - }) + const availableNodesPayload = availableNodesList.map(node => ({ + type: node.type, + title: node.title, + description: node.description, + })) let mermaidCode = '' let backendNodes: BackendNodeSpec[] | undefined let backendEdges: BackendEdgeSpec[] | undefined - const reader = stream.getReader() - const decoder = new TextDecoder() - - while (true) { - const { done, value } = await reader.read() - if (done) - break - - const chunk = decoder.decode(value) - const lines = chunk.split('\n') - - for (const line of lines) { - if (!line.trim() || !line.startsWith('data: ')) - continue - - try { - const data = JSON.parse(line.slice(6)) - if (data.event === 'message' || data.event === 'workflow_generated') { - if (data.data?.text) { - mermaidCode += data.data.text - workflowStore.setState(state => ({ - ...state, - vibePanelMermaidCode: mermaidCode, - })) - } - if (data.data?.nodes) { - backendNodes = data.data.nodes - workflowStore.setState(state => ({ - ...state, - vibePanelBackendNodes: backendNodes, - })) - } - if (data.data?.edges) { - backendEdges = data.data.edges - workflowStore.setState(state => ({ - ...state, - vibePanelBackendEdges: backendEdges, - })) - } - if (data.data?.intent) { - workflowStore.setState(state => ({ - ...state, - vibePanelIntent: data.data.intent, - })) - } - if (data.data?.message) { - workflowStore.setState(state => ({ - ...state, - vibePanelMessage: data.data.message, - })) - } - if (data.data?.suggestions) { - workflowStore.setState(state => ({ - ...state, - vibePanelSuggestions: data.data.suggestions, - })) - } - } - } - catch (e) { - console.error('Error parsing chunk:', e) + if (!isMermaidFlowchart(trimmed)) { + // Build previous workflow context if regenerating + const { vibePanelBackendNodes, vibePanelBackendEdges, vibePanelLastWarnings } = workflowStore.getState() + const 
previousWorkflow = regenerateMode && vibePanelBackendNodes && vibePanelBackendNodes.length > 0 + ? { + nodes: vibePanelBackendNodes, + edges: vibePanelBackendEdges || [], + warnings: vibePanelLastWarnings || [], } + : undefined + + // Map language code to human-readable language name for LLM + const languageNameMap: Record<string, string> = { + en_US: 'English', + zh_Hans: 'Chinese', + zh_Hant: 'Traditional Chinese', + ja_JP: 'Japanese', + ko_KR: 'Korean', + pt_BR: 'Portuguese', + es_ES: 'Spanish', + fr_FR: 'French', + de_DE: 'German', + it_IT: 'Italian', + ru_RU: 'Russian', + uk_UA: 'Ukrainian', + vi_VN: 'Vietnamese', + pl_PL: 'Polish', + ro_RO: 'Romanian', + tr_TR: 'Turkish', + fa_IR: 'Persian', + hi_IN: 'Hindi', } + const preferredLanguage = languageNameMap[language] || 'English' + + // Extract available models from user's configured model providers + const availableModelsPayload = modelList?.flatMap(provider => + provider.models.map(model => ({ + provider: provider.provider, + model: model.model, + })), + ) || [] + + const requestPayload = { + instruction: trimmed, + model_config: latestModelConfig, + available_nodes: availableNodesPayload, + existing_nodes: existingNodesPayload, + available_tools: toolsPayload, + selected_node_ids: [], + previous_workflow: previousWorkflow, + regenerate_mode: regenerateMode, + language: preferredLanguage, + available_models: availableModelsPayload, + } + + const response = await generateFlowchart(requestPayload) + + const { error, flowchart, nodes, edges, intent, message, warnings, suggestions } = response + + if (error) { + Toast.notify({ type: 'error', message: error }) + setIsVibeGenerating(false) + return + } + + // Handle off_topic intent - show rejection message and suggestions + if (intent === 'off_topic') { + workflowStore.setState(state => ({ + ...state, + vibePanelMermaidCode: '', + vibePanelMessage: message || t('workflow.vibe.offTopicDefault'), + vibePanelSuggestions: suggestions || [], + vibePanelIntent: 'off_topic', + isVibeGenerating: false, + })) + return + } + + if (!flowchart) { + Toast.notify({ type: 'error', message: t('workflow.vibe.missingFlowchart') }) + setIsVibeGenerating(false) + return + } + + // Show warnings if any (includes tool sanitization warnings) + const responseWarnings = warnings || [] + if (responseWarnings.length > 0) { + responseWarnings.forEach((warning) => { + Toast.notify({ type: 'warning', message: warning }) + }) + } + + mermaidCode = flowchart + // Store backend nodes/edges for direct use (bypasses mermaid re-parsing) + backendNodes = nodes + backendEdges = edges + // Store warnings for regeneration context + workflowStore.setState(state => ({ + ...state, + vibePanelLastWarnings: responseWarnings, + })) + + workflowStore.setState(state => ({ + ...state, + vibePanelMermaidCode: mermaidCode, + vibePanelBackendNodes: backendNodes, + vibePanelBackendEdges: backendEdges, + vibePanelMessage: '', + vibePanelSuggestions: [], + vibePanelIntent: 'generate', + isVibeGenerating: false, + })) } setIsVibeGenerating(false) @@ -1286,10 +1412,16 @@ if (skipPanelPreview) { // Prefer backend nodes (already sanitized) over mermaid re-parsing if (backendNodes && backendNodes.length > 0 && backendEdges) { + console.log('[VIBE] Applying backend nodes directly to workflow') + console.log('[VIBE] Backend nodes:', backendNodes.length) + console.log('[VIBE] Backend edges:', backendEdges.length) await applyBackendNodesToWorkflow(backendNodes, backendEdges) + console.log('[VIBE] Backend nodes applied
successfully') } else { + console.log('[VIBE] Applying mermaid flowchart to workflow') await applyFlowchartToWorkflow() + console.log('[VIBE] Mermaid flowchart applied successfully') } } }
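Taken together, the utilities exercised by these tests compose into a validate-and-repair loop. The sketch below is illustrative only: validate_and_repair is a hypothetical helper, and the actual orchestration (including how fixable errors are fed back to the model) lives in runner.py and may differ:

    from core.workflow.generator.utils.node_repair import NodeRepair
    from core.workflow.generator.validation import ValidationContext, ValidationEngine


    def validate_and_repair(workflow, available_models, max_fix_iterations=2):
        # Hypothetical helper: run deterministic repairs, then rule validation,
        # looping while fixable errors remain and the iteration budget allows.
        nodes = workflow.get("nodes", [])
        edges = workflow.get("edges", [])
        engine = ValidationEngine()
        result = None
        for _ in range(max_fix_iterations + 1):
            repair = NodeRepair.repair(nodes)  # e.g. normalizes '>=' to '≥' in if-else conditions
            if repair.was_repaired:
                nodes = repair.nodes
            ctx = ValidationContext(nodes=nodes, edges=edges, available_models=available_models)
            result = engine.validate(ctx)
            if not result.has_fixable_errors:
                break
            # In the real runner, result.fixable_errors would be sent back to the LLM here.
        return {"nodes": nodes, "edges": edges, "validation": result.to_dict()}

Passing max_fix_iterations=0 degenerates the loop to a single validation pass with no repair retries.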