refactor(vibe): extract workflow generation to dedicated module

- Move workflow generation logic from LLMGenerator to WorkflowGenerator
  - Extract to api/core/workflow/generator/ with modular architecture
  - Implement Planner-Builder pattern for better separation of concerns
  - Add validation engine with rule-based error classification
  - Add node and edge repair utilities for auto-fixing common issues
  - Add deterministic Mermaid generator for consistent output

- Reorganize configuration and prompts
  - Move vibe_config/ to generator/config/
  - Move vibe_prompts.py to generator/prompts/ (split into multiple files)
  - Add builder_prompts.py and planner_prompts.py for new architecture

- Enhance frontend workflow handling
  - Initialize generated nodes via the standard node-initialization path
  - Improve variable reference replacement with better error handling
  - Add model fallback logic for better compatibility
  - Handle end node outputs format (value_selector vs legacy format)
  - Ensure parameter-extractor nodes have required 'required' field

- Add comprehensive test coverage
  - Unit tests for mermaid generator, node repair, edge repair
  - Tests for validation engine and rule system
  - Tests for planner prompts formatting
  - Frontend tests for variable reference replacement

- Add max_fix_iterations parameter for validate-fix loop configuration

# Conflicts:
#	web/app/components/workflow/hooks/use-workflow-vibe.tsx
aqiu 2025-12-27 15:06:44 +08:00
parent c4eee28fd8
commit cd030d82e5
28 changed files with 5046 additions and 267 deletions

View File

@ -1,9 +1,13 @@
import logging
from collections.abc import Sequence
from typing import Any
from typing import Any, cast
from flask_restx import Resource
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@ -18,6 +22,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.generator import WorkflowGenerator
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
@ -77,6 +82,13 @@ class FlowchartGeneratePayload(BaseModel):
language: str | None = Field(default=None, description="Preferred language for generated content")
# Available models that user has configured (for LLM/question-classifier nodes)
available_models: list[dict[str, Any]] = Field(default_factory=list, description="User's configured models")
# Validate-fix iteration loop configuration
max_fix_iterations: int = Field(
default=2,
ge=0,
le=5,
description="Maximum number of validate-fix iterations (0 to disable auto-fix)",
)
def reg(cls: type[BaseModel]):
@ -305,7 +317,7 @@ class FlowchartGenerateApi(Resource):
"warnings": args.previous_workflow.warnings,
}
result = LLMGenerator.generate_workflow_flowchart(
result = WorkflowGenerator.generate_workflow_flowchart(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
@ -313,11 +325,13 @@ class FlowchartGenerateApi(Resource):
existing_nodes=args.existing_nodes,
available_tools=args.available_tools,
selected_node_ids=args.selected_node_ids,
previous_workflow=previous_workflow_dict,
previous_workflow=cast(dict[str, object], previous_workflow_dict),
regenerate_mode=args.regenerate_mode,
preferred_language=args.language,
available_models=args.available_models,
max_fix_iterations=args.max_fix_iterations,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:

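For reference, a minimal sketch of a client payload exercising the new `max_fix_iterations` field. The endpoint path, auth header, and any field names not shown in the hunk above are assumptions for illustration, not taken from this diff:

```python
import requests  # hypothetical client; any HTTP library works

payload = {
    "instruction": "Create a workflow that classifies support emails",
    # Field names follow FlowchartGeneratePayload; values are placeholders.
    "language": "en-US",
    "available_models": [{"provider": "openai", "model": "gpt-4o"}],
    # New: validate-fix loop budget, constrained to 0..5 (0 disables auto-fix).
    "max_fix_iterations": 2,
}
resp = requests.post(
    "https://<host>/console/api/apps/<app_id>/flowchart-generate",  # assumed path
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # assumed auth scheme
)
print(resp.json().get("warnings", []))
```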
View File

@ -1,6 +1,5 @@
import json
import logging
import re
from collections.abc import Sequence
from typing import Protocol, cast
@ -12,8 +11,6 @@ from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SUGGESTED_QUESTIONS_MAX_TOKENS,
SUGGESTED_QUESTIONS_TEMPERATURE,
@ -30,6 +27,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.generator import WorkflowGenerator
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import App, Message, WorkflowNodeExecutionModel
@ -299,178 +297,24 @@ class LLMGenerator:
regenerate_mode: bool = False,
preferred_language: str | None = None,
available_models: Sequence[dict[str, object]] | None = None,
max_fix_iterations: int = 2,
):
"""
Generate workflow flowchart with enhanced prompts and inline intent classification.
Returns a dict with:
- intent: "generate" | "off_topic" | "error"
- flowchart: Mermaid syntax string (for generate intent)
- message: User-friendly explanation
- warnings: List of validation warnings
- suggestions: List of workflow suggestions (for off_topic intent)
- error: Error message if generation failed
"""
from core.llm_generator.vibe_prompts import (
build_vibe_enhanced_prompt,
extract_mermaid_from_response,
parse_vibe_response,
sanitize_tool_nodes,
validate_node_parameters,
validate_tool_references,
)
model_parameters = model_config.get("completion_params", {})
# Build enhanced prompts with context
system_prompt, user_prompt = build_vibe_enhanced_prompt(
return WorkflowGenerator.generate_workflow_flowchart(
tenant_id=tenant_id,
instruction=instruction,
available_nodes=list(available_nodes) if available_nodes else None,
available_tools=list(available_tools) if available_tools else None,
existing_nodes=list(existing_nodes) if existing_nodes else None,
selected_node_ids=list(selected_node_ids) if selected_node_ids else None,
previous_workflow=dict(previous_workflow) if previous_workflow else None,
model_config=model_config,
available_nodes=available_nodes,
existing_nodes=existing_nodes,
available_tools=available_tools,
selected_node_ids=selected_node_ids,
previous_workflow=previous_workflow,
regenerate_mode=regenerate_mode,
preferred_language=preferred_language,
available_models=list(available_models) if available_models else None,
available_models=available_models,
max_fix_iterations=max_fix_iterations,
)
prompt_messages: list[PromptMessage] = [
SystemPromptMessage(content=system_prompt),
UserPromptMessage(content=user_prompt),
]
# DEBUG: Log model input
logger.debug("=" * 80)
logger.debug("[VIBE] generate_workflow_flowchart - MODEL INPUT")
logger.debug("=" * 80)
logger.debug("[VIBE] Instruction: %s", instruction)
logger.debug("[VIBE] Model: %s/%s", model_config.get("provider", ""), model_config.get("name", ""))
system_prompt_log = system_prompt[:2000] + "..." if len(system_prompt) > 2000 else system_prompt
logger.debug("[VIBE] System Prompt:\n%s", system_prompt_log)
logger.debug("[VIBE] User Prompt:\n%s", user_prompt)
logger.debug("=" * 80)
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=model_parameters,
stream=False,
)
content = response.message.get_text_content()
# DEBUG: Log model output
logger.debug("=" * 80)
logger.debug("[VIBE] generate_workflow_flowchart - MODEL OUTPUT")
logger.debug("=" * 80)
logger.debug("[VIBE] Raw Response:\n%s", content)
logger.debug("=" * 80)
if not isinstance(content, str):
raise ValueError("Flowchart response is not a string")
# Parse the enhanced response format
parsed = parse_vibe_response(content)
# DEBUG: Log parsed result
logger.debug("[VIBE] Parsed Response:")
logger.debug("[VIBE] intent: %s", parsed.get("intent"))
logger.debug("[VIBE] message: %s", parsed.get("message", "")[:200] if parsed.get("message") else "")
logger.debug("[VIBE] mermaid: %s", parsed.get("mermaid", "")[:500] if parsed.get("mermaid") else "")
logger.debug("[VIBE] warnings: %s", parsed.get("warnings", []))
logger.debug("[VIBE] suggestions: %s", parsed.get("suggestions", []))
if parsed.get("error"):
logger.debug("[VIBE] error: %s", parsed.get("error"))
logger.debug("=" * 80)
# Handle error case from parsing
if parsed.get("intent") == "error":
# Fall back to legacy parsing for backwards compatibility
match = re.search(r"```(?:mermaid)?\s*([\s\S]+?)```", content, flags=re.IGNORECASE)
flowchart = (match.group(1) if match else content).strip()
return {
"intent": "generate",
"flowchart": flowchart,
"message": "",
"warnings": [],
"tool_recommendations": [],
"error": "",
}
# Handle off_topic case
if parsed.get("intent") == "off_topic":
return {
"intent": "off_topic",
"flowchart": "",
"message": parsed.get("message", ""),
"suggestions": parsed.get("suggestions", []),
"warnings": [],
"tool_recommendations": [],
"error": "",
}
# Handle generate case
flowchart = extract_mermaid_from_response(parsed)
# Sanitize tool nodes - replace invalid tools with fallback nodes
original_nodes = parsed.get("nodes", [])
sanitized_nodes, sanitize_warnings = sanitize_tool_nodes(
original_nodes,
list(available_tools) if available_tools else None,
)
# Update parsed nodes with sanitized version
parsed["nodes"] = sanitized_nodes
# Validate tool references and get recommendations for unconfigured tools
validation_warnings, tool_recommendations = validate_tool_references(
sanitized_nodes,
list(available_tools) if available_tools else None,
)
# Validate node parameters are properly filled (Phase 9: Auto-Fill)
param_warnings = validate_node_parameters(sanitized_nodes)
existing_warnings = parsed.get("warnings", [])
all_warnings = existing_warnings + sanitize_warnings + validation_warnings + param_warnings
return {
"intent": "generate",
"flowchart": flowchart,
"nodes": sanitized_nodes, # Include sanitized nodes in response
"edges": parsed.get("edges", []),
"message": parsed.get("message", ""),
"warnings": all_warnings,
"tool_recommendations": tool_recommendations,
"error": "",
}
except InvokeError as e:
return {
"intent": "error",
"flowchart": "",
"message": "",
"warnings": [],
"tool_recommendations": [],
"error": str(e),
}
except Exception as e:
logger.exception("Failed to generate workflow flowchart, model: %s", model_config.get("name"))
return {
"intent": "error",
"flowchart": "",
"message": "",
"warnings": [],
"tool_recommendations": [],
"error": str(e),
}
@classmethod
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
if code_language == "python":

View File

@ -0,0 +1 @@
from .runner import WorkflowGenerator

View File

@ -5,14 +5,14 @@ This module centralizes configuration for the Vibe workflow generation feature,
including node schemas, fallback rules, and response templates.
"""
from core.llm_generator.vibe_config.fallback_rules import (
from core.workflow.generator.config.fallback_rules import (
FALLBACK_RULES,
FIELD_NAME_CORRECTIONS,
NODE_TYPE_ALIASES,
get_corrected_field_name,
)
from core.llm_generator.vibe_config.node_schemas import BUILTIN_NODE_SCHEMAS
from core.llm_generator.vibe_config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES
from core.workflow.generator.config.node_schemas import BUILTIN_NODE_SCHEMAS
from core.workflow.generator.config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES
__all__ = [
"BUILTIN_NODE_SCHEMAS",

View File

@ -137,19 +137,28 @@ BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = {
},
"if-else": {
"description": "Conditional branching based on conditions",
"required": ["conditions"],
"required": ["cases"],
"parameters": {
"conditions": {
"cases": {
"type": "array",
"description": "List of condition cases",
"description": "List of condition cases. Each case defines when 'true' branch is taken.",
"item_schema": {
"case_id": "string - unique case identifier",
"logical_operator": "enum: and, or",
"conditions": "array of {variable_selector, comparison_operator, value}",
"case_id": "string - unique case identifier (e.g., 'case_1')",
"logical_operator": "enum: and, or - how multiple conditions combine",
"conditions": {
"type": "array",
"item_schema": {
"variable_selector": "array of strings - path to variable, e.g. ['node_id', 'field']",
"comparison_operator": (
"enum: =, ≠, >, <, ≥, ≤, contains, not contains, is, is not, empty, not empty"
),
"value": "string or number - value to compare against",
},
},
},
},
},
"outputs": ["Branches: true (conditions met), false (else)"],
"outputs": ["Branches: true (first case conditions met), false (else/no case matched)"],
},
"knowledge-retrieval": {
"description": "Query knowledge base for relevant content",
@ -207,5 +216,70 @@ BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = {
},
"outputs": ["item (current iteration item)", "index (current index)"],
},
"parameter-extractor": {
"description": "Extract structured parameters from user input using LLM",
"required": ["query", "parameters"],
"parameters": {
"model": {
"type": "object",
"description": "Model configuration (provider, name, mode)",
},
"query": {
"type": "array",
"description": "Path to input text to extract parameters from, e.g. ['start', 'user_input']",
},
"parameters": {
"type": "array",
"description": "Parameters to extract from the input",
"item_schema": {
"name": "string - parameter name (required)",
"type": (
"enum: string, number, boolean, array[string], array[number], "
"array[object], array[boolean]"
),
"description": "string - description of what to extract (required)",
"required": "boolean - whether this parameter is required (MUST be specified)",
"options": "array of strings (optional) - for enum-like selection",
},
},
"instruction": {
"type": "string",
"description": "Additional instructions for extraction",
},
"reasoning_mode": {
"type": "enum",
"options": ["function_call", "prompt"],
"description": "How to perform extraction (defaults to function_call)",
},
},
"outputs": ["Extracted parameters as defined in parameters array", "__is_success", "__reason"],
},
"question-classifier": {
"description": "Classify user input into predefined categories using LLM",
"required": ["query", "classes"],
"parameters": {
"model": {
"type": "object",
"description": "Model configuration (provider, name, mode)",
},
"query": {
"type": "array",
"description": "Path to input text to classify, e.g. ['start', 'user_input']",
},
"classes": {
"type": "array",
"description": "Classification categories",
"item_schema": {
"id": "string - unique class identifier",
"name": "string - class name/label",
},
},
"instruction": {
"type": "string",
"description": "Additional instructions for classification",
},
},
"outputs": ["class_name (selected class)"],
},
}
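As a reading aid for the two schemas added above, a minimal sketch of node configs shaped to them. Provider/model values are placeholders, the enclosing workflow is omitted, and note the Builder examples later in this commit use a slightly different key name for the classifier query, so treat these purely as schema illustrations:

```python
# Hypothetical node dicts matching the schemas above; provider/name are placeholders.
extractor_node = {
    "id": "extract",
    "type": "parameter-extractor",
    "config": {
        "model": {"provider": "<provider>", "name": "<model>", "mode": "chat"},
        "query": ["start", "user_input"],
        "parameters": [
            # 'required' must be specified explicitly (see the frontend fix in this commit).
            {"name": "email", "type": "string", "description": "Contact email", "required": True},
        ],
    },
}

classifier_node = {
    "id": "classify",
    "type": "question-classifier",
    "config": {
        "model": {"provider": "<provider>", "name": "<model>", "mode": "chat"},
        "query": ["start", "user_input"],
        "classes": [
            {"id": "cls_support", "name": "Support"},
            {"id": "cls_sales", "name": "Sales"},
        ],
    },
}
```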

View File

@ -0,0 +1,343 @@
BUILDER_SYSTEM_PROMPT = """<role>
You are a Workflow Configuration Engineer.
Your goal is to implement the Architect's plan by generating a precise, runnable Dify Workflow JSON configuration.
</role>
<inputs>
<plan>
{plan_context}
</plan>
<tool_schemas>
{tool_schemas}
</tool_schemas>
<node_specs>
{builtin_node_specs}
</node_specs>
</inputs>
<rules>
1. **Configuration**:
- You MUST fill ALL required parameters for every node.
- Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields.
- For 'start' node, define all necessary user inputs.
2. **Variable References**:
- For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}`
- For 'end' node outputs: use `value_selector` array format `["node_id", "field"]`
- Example: to reference 'llm' node's 'text' output in end node, use `["llm", "text"]`
3. **Tools**:
- ONLY use the tools listed in `<tool_schemas>`.
- If a planned tool is missing from schemas, fallback to `http-request` or `code`.
4. **Node Specifics**:
- For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`).
5. **Output**:
- Return ONLY the JSON object with `nodes` and `edges`.
- Do NOT generate Mermaid diagrams.
- Do NOT generate explanations.
</rules>
<edge_rules priority="critical">
**EDGES ARE CRITICAL** - Every node except 'end' MUST have at least one outgoing edge.
1. **Linear Flow**: Simple source -> target connection
```
{{"source": "node_a", "target": "node_b"}}
```
2. **question-classifier Branching**: Each class MUST have a separate edge with `sourceHandle` = class `id`
- If you define classes: [{{"id": "cls_refund", "name": "Refund"}}, {{"id": "cls_inquiry", "name": "Inquiry"}}]
- You MUST create edges:
- {{"source": "classifier", "sourceHandle": "cls_refund", "target": "refund_handler"}}
- {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "inquiry_handler"}}
3. **if-else Branching**: MUST have exactly TWO edges with sourceHandle "true" and "false"
- {{"source": "condition", "sourceHandle": "true", "target": "true_branch"}}
- {{"source": "condition", "sourceHandle": "false", "target": "false_branch"}}
4. **Branch Convergence**: Multiple branches can connect to same downstream node
- Both true_branch and false_branch can connect to the same 'end' node
5. **NEVER leave orphan nodes**: Every node must be connected in the graph
</edge_rules>
<examples>
<example name="simple_linear">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "query", "label": "Query", "type": "text-input"}}]
}}
}},
{{
"id": "llm",
"type": "llm",
"title": "Generate Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [
{{"variable": "result", "value_selector": ["llm", "text"]}}
]
}}
}}
],
"edges": [
{{"source": "start", "target": "llm"}},
{{"source": "llm", "target": "end"}}
]
}}
```
</example>
<example name="question_classifier_branching" description="Customer service with intent classification">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "user_input", "label": "User Message", "type": "text-input", "required": true}}]
}}
}},
{{
"id": "classifier",
"type": "question-classifier",
"title": "Classify Intent",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"query_variable_selector": ["start", "user_input"],
"classes": [
{{"id": "cls_refund", "name": "Refund Request"}},
{{"id": "cls_inquiry", "name": "Product Inquiry"}},
{{"id": "cls_complaint", "name": "Complaint"}},
{{"id": "cls_other", "name": "Other"}}
],
"instruction": "Classify the user's intent"
}}
}},
{{
"id": "handle_refund",
"type": "llm",
"title": "Handle Refund",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Extract order number and respond: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_inquiry",
"type": "llm",
"title": "Handle Inquiry",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Answer product question: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_complaint",
"type": "llm",
"title": "Handle Complaint",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Respond with empathy: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_other",
"type": "llm",
"title": "Handle Other",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Provide general response: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "response", "value_selector": ["handle_refund", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "classifier"}},
{{"source": "classifier", "sourceHandle": "cls_refund", "target": "handle_refund"}},
{{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "handle_inquiry"}},
{{"source": "classifier", "sourceHandle": "cls_complaint", "target": "handle_complaint"}},
{{"source": "classifier", "sourceHandle": "cls_other", "target": "handle_other"}},
{{"source": "handle_refund", "target": "end"}},
{{"source": "handle_inquiry", "target": "end"}},
{{"source": "handle_complaint", "target": "end"}},
{{"source": "handle_other", "target": "end"}}
]
}}
```
CRITICAL: Notice that each class id (cls_refund, cls_inquiry, etc.) becomes a sourceHandle in the edges!
</example>
<example name="if_else_branching" description="Conditional logic with if-else">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "years", "label": "Years of Experience", "type": "number", "required": true}}]
}}
}},
{{
"id": "check_experience",
"type": "if-else",
"title": "Check Experience",
"config": {{
"cases": [
{{
"case_id": "case_1",
"logical_operator": "and",
"conditions": [
{{
"variable_selector": ["start", "years"],
"comparison_operator": "",
"value": "3"
}}
]
}}
]
}}
}},
{{
"id": "qualified",
"type": "llm",
"title": "Qualified Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Generate qualified candidate response"}}]
}}
}},
{{
"id": "not_qualified",
"type": "llm",
"title": "Not Qualified Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Generate rejection response"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "result", "value_selector": ["qualified", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "check_experience"}},
{{"source": "check_experience", "sourceHandle": "true", "target": "qualified"}},
{{"source": "check_experience", "sourceHandle": "false", "target": "not_qualified"}},
{{"source": "qualified", "target": "end"}},
{{"source": "not_qualified", "target": "end"}}
]
}}
```
CRITICAL: if-else MUST have exactly two edges with sourceHandle "true" and "false"!
</example>
<example name="parameter_extractor" description="Extract structured data from text">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "resume", "label": "Resume Text", "type": "paragraph", "required": true}}]
}}
}},
{{
"id": "extract",
"type": "parameter-extractor",
"title": "Extract Info",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"query": ["start", "resume"],
"parameters": [
{{"name": "name", "type": "string", "description": "Candidate name", "required": true}},
{{"name": "years", "type": "number", "description": "Years of experience", "required": true}},
{{"name": "skills", "type": "array[string]", "description": "List of skills", "required": true}}
],
"instruction": "Extract candidate information from resume"
}}
}},
{{
"id": "process",
"type": "llm",
"title": "Process Data",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Name: {{{{#extract.name#}}}}, Years: {{{{#extract.years#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "result", "value_selector": ["process", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "extract"}},
{{"source": "extract", "target": "process"}},
{{"source": "process", "target": "end"}}
]
}}
```
</example>
</examples>
<edge_checklist>
Before finalizing, verify:
1. [ ] Every node (except 'end') has at least one outgoing edge
2. [ ] 'start' node has exactly one outgoing edge
3. [ ] 'question-classifier' has one edge per class, each with sourceHandle = class id
4. [ ] 'if-else' has exactly two edges: sourceHandle "true" and sourceHandle "false"
5. [ ] All branches eventually connect to 'end' (directly or through other nodes)
6. [ ] No orphan nodes exist (every node is reachable from 'start')
</edge_checklist>
"""
BUILDER_USER_PROMPT = """<instruction>
{instruction}
</instruction>
Generate the full workflow configuration now. Pay special attention to:
1. Creating edges for ALL branches of question-classifier and if-else nodes
2. Using correct sourceHandle values for branching nodes
3. Ensuring every node is connected in the graph
"""

View File

@ -0,0 +1,75 @@
PLANNER_SYSTEM_PROMPT = """<role>
You are an expert Workflow Architect.
Your job is to analyze user requests and plan a high-level automation workflow.
</role>
<task>
1. **Classify Intent**:
- Is the user asking to create an automation/workflow? -> Intent: "generate"
- Is it general chat/weather/jokes? -> Intent: "off_topic"
2. **Plan Steps** (if intent is "generate"):
- Break down the user's goal into logical steps.
- For each step, identify if a specific capability/tool is needed.
- Select the MOST RELEVANT tools from the available_tools list.
- DO NOT configure parameters yet. Just identify the tool.
3. **Output Format**:
Return a JSON object.
</task>
<available_tools>
{tools_summary}
</available_tools>
<response_format>
If intent is "generate":
```json
{{
"intent": "generate",
"plan_thought": "Brief explanation of the plan...",
"steps": [
{{ "step": 1, "description": "Fetch data from URL", "tool": "http-request" }},
{{ "step": 2, "description": "Summarize content", "tool": "llm" }},
{{ "step": 3, "description": "Search for info", "tool": "google_search" }}
],
"required_tool_keys": ["google_search"]
}}
```
(Note: 'http-request', 'llm', and 'code' are built-in; list only external tools in required_tool_keys.)
If intent is "off_topic":
```json
{{
"intent": "off_topic",
"message": "I can only help you build workflows. Try asking me to 'Create a workflow that...'",
"suggestions": ["Scrape a website", "Summarize a PDF"]
}}
```
</response_format>
"""
PLANNER_USER_PROMPT = """<user_request>
{instruction}
</user_request>
"""
def format_tools_for_planner(tools: list[dict]) -> str:
"""Format tools list for planner (Lightweight: Name + Description only)."""
if not tools:
return "No external tools available."
lines = []
for t in tools:
key = t.get("tool_key") or t.get("tool_name")
provider = t.get("provider_id") or t.get("provider", "")
desc = t.get("tool_description") or t.get("description", "")
label = t.get("tool_label") or key
# Format: - [provider/key] Label: Description
full_key = f"{provider}/{key}" if provider else key
lines.append(f"- [{full_key}] {label}: {desc}")
return "\n".join(lines)
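A quick usage sketch of `format_tools_for_planner`; the dict keys follow the fallbacks the function reads (`tool_key`/`tool_name`, `provider_id`/`provider`, and so on), and the sample tool itself is hypothetical:

```python
tools = [
    {
        "tool_key": "google_search",
        "provider_id": "google",
        "tool_label": "Google Search",
        "tool_description": "Search the web",
    }
]
print(format_tools_for_planner(tools))
# -> - [google/google_search] Google Search: Search the web
print(format_tools_for_planner([]))
# -> No external tools available.
```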

View File

@ -10,7 +10,7 @@ import json
import re
from typing import Any
from core.llm_generator.vibe_config import (
from core.workflow.generator.config import (
BUILTIN_NODE_SCHEMAS,
DEFAULT_SUGGESTIONS,
FALLBACK_RULES,
@ -100,6 +100,13 @@ You help users create AI automation workflows by generating workflow configurati
</variable_syntax>
<rules>
<rule id="model_selection" priority="critical">
For LLM, question-classifier, parameter-extractor nodes:
- You MUST include a "model" config with provider and name from available_models section
- Copy the EXACT provider and name values from available_models
- NEVER use openai/gpt-4o, openai/gpt-3.5-turbo, openai/gpt-4 unless they appear in available_models
- If available_models is empty or not provided, omit the model config entirely
</rule>
<rule id="tool_usage" priority="critical">
ONLY use tools with status="configured" from available_tools.
NEVER invent tool names like "webscraper", "email_sender", etc.
@ -217,12 +224,14 @@ You help users create AI automation workflows by generating workflow configurati
"type": "llm",
"title": "Analyze Content",
"config": {{{{
"model": {{{{"provider": "USE_FROM_AVAILABLE_MODELS", "name": "USE_FROM_AVAILABLE_MODELS", "mode": "chat"}}}},
"prompt_template": [
{{{{"role": "system", "text": "You are a helpful analyst."}}}},
{{{{"role": "user", "text": "Analyze this content:\\n\\n{{{{#fetch.body#}}}}"}}}}
]
}}}}
}}}}
NOTE: Replace "USE_FROM_AVAILABLE_MODELS" with actual values from available_models section!
</example>
<example type="code" title="Process data">
{{{{
@ -344,6 +353,7 @@ Generate your JSON response now. Remember:
</output_instruction>
"""
def format_available_nodes(nodes: list[dict[str, Any]] | None) -> str:
"""Format available nodes as XML with parameter schemas."""
lines = ["<available_nodes>"]
@ -591,7 +601,7 @@ def format_previous_attempt(
def format_available_models(models: list[dict[str, Any]] | None) -> str:
"""Format available models as XML for prompt inclusion."""
if not models:
return "<available_models>\n <!-- No models configured -->\n</available_models>"
return "<available_models>\n <!-- No models configured - omit model config from nodes -->\n</available_models>"
lines = ["<available_models>"]
for model in models:
@ -600,16 +610,30 @@ def format_available_models(models: list[dict[str, Any]] | None) -> str:
lines.append(f' <model provider="{provider}" name="{model_name}" />')
lines.append("</available_models>")
# Add model selection rule
# Add model selection rule with concrete example
lines.append("")
lines.append("<model_selection_rule>")
lines.append(" CRITICAL: For LLM, question-classifier, and parameter-extractor nodes, you MUST select a model from available_models.")
if len(models) == 1:
first_model = models[0]
lines.append(f' Use provider="{first_model.get("provider")}" and name="{first_model.get("model")}" for all model-dependent nodes.')
else:
lines.append(" Choose the most suitable model for each task from the available options.")
lines.append(" NEVER use models not listed in available_models (e.g., openai/gpt-4o if not listed).")
lines.append(" CRITICAL: For LLM, question-classifier, and parameter-extractor nodes:")
lines.append(" - You MUST include a 'model' field in the config")
lines.append(" - You MUST use ONLY models from available_models above")
lines.append(" - NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4 unless they appear in available_models")
lines.append("")
# Provide concrete JSON example to copy
first_model = models[0]
provider = first_model.get("provider", "unknown")
model_name = first_model.get("model", "unknown")
lines.append(" COPY THIS EXACT MODEL CONFIG for all LLM/question-classifier/parameter-extractor nodes:")
lines.append(f' "model": {{"provider": "{provider}", "name": "{model_name}", "mode": "chat"}}')
if len(models) > 1:
lines.append("")
lines.append(" Alternative models you can use:")
for m in models[1:4]: # Show up to 3 alternatives
p = m.get("provider", "unknown")
n = m.get("model", "unknown")
lines.append(f' - "model": {{"provider": "{p}", "name": "{n}", "mode": "chat"}}')
lines.append("</model_selection_rule>")
return "\n".join(lines)
@ -1023,6 +1047,7 @@ def validate_node_parameters(nodes: list[dict[str, Any]]) -> list[str]:
def extract_mermaid_from_response(data: dict[str, Any]) -> str:
"""Extract mermaid flowchart from parsed response."""
mermaid = data.get("mermaid", "")
if not mermaid:
return ""
@ -1034,5 +1059,203 @@ def extract_mermaid_from_response(data: dict[str, Any]) -> str:
if match:
mermaid = match.group(1).strip()
# Sanitize edge labels to remove characters that break Mermaid parsing
# Edge labels in Mermaid are ONLY in the pattern: -->|label|
# We must NOT match |pipe| characters inside node labels like ["type=start|title=开始"]
def sanitize_edge_label(match: re.Match) -> str:
arrow = match.group(1) # --> or ---
label = match.group(2) # the label between pipes
# Remove or replace special characters that break Mermaid
# Parentheses, brackets, braces have special meaning in Mermaid
sanitized = re.sub(r'[(){}\[\]]', '', label)
return f"{arrow}|{sanitized}|"
# Only match edge labels: --> or --- followed by |label|
# This pattern ensures we only sanitize actual edge labels, not node content
mermaid = re.sub(r'(-->|---)\|([^|]+)\|', sanitize_edge_label, mermaid)
return mermaid
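To make the sanitization above concrete, here is a standalone re-run of the same regex (same pattern as in the function; the sample string is made up):

```python
import re

def _sanitize(match: re.Match) -> str:
    # Strip Mermaid-breaking brackets from the edge label only.
    label = re.sub(r"[(){}\[\]]", "", match.group(2))
    return f"{match.group(1)}|{label}|"

mermaid = 'a["type=start|title=Start"] -->|route (a)| b["type=end|title=End"]'
print(re.sub(r"(-->|---)\|([^|]+)\|", _sanitize, mermaid))
# -> a["type=start|title=Start"] -->|route a| b["type=end|title=End"]
# Pipes inside node labels are untouched; only the -->|...| edge label is cleaned.
```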
def classify_validation_errors(
nodes: list[dict[str, Any]],
available_models: list[dict[str, Any]] | None = None,
available_tools: list[dict[str, Any]] | None = None,
edges: list[dict[str, Any]] | None = None,
) -> dict[str, list[dict[str, Any]]]:
"""
Classify validation errors into fixable and user-required categories.
This function uses the declarative rule engine to validate nodes.
The rule engine provides deterministic, testable validation without
relying on LLM judgment.
Fixable errors can be automatically corrected by the LLM in subsequent
iterations. User-required errors need manual intervention.
Args:
nodes: List of generated workflow nodes
available_models: List of models the user has configured
available_tools: List of available tools
edges: List of edges connecting nodes
Returns:
dict with:
- "fixable": errors that LLM can fix automatically
- "user_required": errors that need user intervention
- "all_warnings": combined warning messages for backwards compatibility
- "stats": validation statistics
"""
from core.workflow.generator.validation import ValidationContext, ValidationEngine
# Build validation context
context = ValidationContext(
nodes=nodes,
edges=edges or [],
available_models=available_models or [],
available_tools=available_tools or [],
)
# Run validation through rule engine
engine = ValidationEngine()
result = engine.validate(context)
# Convert to legacy format for backwards compatibility
fixable: list[dict[str, Any]] = []
user_required: list[dict[str, Any]] = []
for error in result.fixable_errors:
fixable.append({
"node_id": error.node_id,
"node_type": error.node_type,
"error_type": error.rule_id,
"message": error.message,
"is_fixable": True,
"fix_hint": error.fix_hint,
"category": error.category.value,
"details": error.details,
})
for error in result.user_required_errors:
user_required.append({
"node_id": error.node_id,
"node_type": error.node_type,
"error_type": error.rule_id,
"message": error.message,
"is_fixable": False,
"fix_hint": error.fix_hint,
"category": error.category.value,
"details": error.details,
})
# Include warnings in user_required (they're non-blocking but informative)
for error in result.warnings:
user_required.append({
"node_id": error.node_id,
"node_type": error.node_type,
"error_type": error.rule_id,
"message": error.message,
"is_fixable": error.is_fixable,
"fix_hint": error.fix_hint,
"category": error.category.value,
"severity": "warning",
"details": error.details,
})
# Generate combined warnings for backwards compatibility
all_warnings = [e["message"] for e in fixable + user_required]
return {
"fixable": fixable,
"user_required": user_required,
"all_warnings": all_warnings,
"stats": result.stats,
}
def build_fix_prompt(
fixable_errors: list[dict[str, Any]],
previous_nodes: list[dict[str, Any]],
available_models: list[dict[str, Any]] | None = None,
) -> str:
"""
Build a prompt for LLM to fix the identified errors.
This creates a focused instruction that tells the LLM exactly what
to fix in the previous generation.
Args:
fixable_errors: List of errors that can be automatically fixed
previous_nodes: The nodes from the previous generation attempt
available_models: Available models for model configuration fixes
Returns:
Formatted prompt string for the fix iteration
"""
if not fixable_errors:
return ""
parts = ["<fix_required>"]
parts.append(" <description>")
parts.append(" Your previous generation has errors that need fixing.")
parts.append(" Please regenerate with the following corrections:")
parts.append(" </description>")
# Group errors by node
errors_by_node: dict[str, list[dict[str, Any]]] = {}
for error in fixable_errors:
node_id = error["node_id"]
if node_id not in errors_by_node:
errors_by_node[node_id] = []
errors_by_node[node_id].append(error)
parts.append(" <errors_to_fix>")
for node_id, node_errors in errors_by_node.items():
parts.append(f" <node id=\"{node_id}\">")
for error in node_errors:
error_type = error["error_type"]
message = error["message"]
fix_hint = error.get("fix_hint", "")
parts.append(f" <error type=\"{error_type}\">")
parts.append(f" <message>{message}</message>")
if fix_hint:
parts.append(f" <fix_hint>{fix_hint}</fix_hint>")
parts.append(" </error>")
parts.append(" </node>")
parts.append(" </errors_to_fix>")
# Add model selection help if there are model-related errors
model_errors = [e for e in fixable_errors if "model" in e["error_type"]]
if model_errors and available_models:
parts.append(" <model_selection_help>")
parts.append(" Use one of these models for nodes requiring model config:")
for model in available_models[:3]: # Show top 3
provider = model.get("provider", "unknown")
name = model.get("model", "unknown")
parts.append(f' - {{"provider": "{provider}", "name": "{name}", "mode": "chat"}}')
parts.append(" </model_selection_help>")
# Add previous nodes summary for context
parts.append(" <previous_nodes_to_fix>")
for node in previous_nodes:
node_id = node.get("id", "unknown")
if node_id in errors_by_node:
# Only include nodes that have errors
node_type = node.get("type", "unknown")
title = node.get("title", "Untitled")
config_summary = json.dumps(node.get("config", {}), ensure_ascii=False)[:200]
parts.append(f" <node id=\"{node_id}\" type=\"{node_type}\" title=\"{title}\">")
parts.append(f" <current_config>{config_summary}...</current_config>")
parts.append(" </node>")
parts.append(" </previous_nodes_to_fix>")
parts.append(" <instructions>")
parts.append(" 1. Keep the workflow structure and logic unchanged")
parts.append(" 2. Fix ONLY the errors listed above")
parts.append(" 3. Ensure all required fields are properly filled")
parts.append(" 4. Use variable references {{#node_id.field#}} where appropriate")
parts.append(" </instructions>")
parts.append("</fix_required>")
return "\n".join(parts)
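A minimal sketch of how these two helpers compose into the validate-fix loop; the error dict mirrors the legacy-format entries built above, but the rule id and node content are hypothetical:

```python
# Hypothetical fixable error in the shape classify_validation_errors returns.
fixable = [
    {
        "node_id": "llm_1",
        "node_type": "llm",
        "error_type": "model_not_configured",  # hypothetical rule id
        "message": "Node 'llm_1' references a model that is not configured",
        "is_fixable": True,
        "fix_hint": "Pick a model from available_models",
        "category": "model",
        "details": {},
    }
]
nodes = [{"id": "llm_1", "type": "llm", "title": "Answer", "config": {}}]
models = [{"provider": "openai", "model": "gpt-4o"}]

prompt = build_fix_prompt(fixable, nodes, available_models=models)
# The returned <fix_required> block is appended to the next generation request,
# and the loop repeats up to max_fix_iterations times.
```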

View File

@ -0,0 +1,194 @@
import json
import logging
import re
from collections.abc import Sequence
import json_repair
from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.generator.prompts.builder_prompts import BUILDER_SYSTEM_PROMPT, BUILDER_USER_PROMPT
from core.workflow.generator.prompts.planner_prompts import (
PLANNER_SYSTEM_PROMPT,
PLANNER_USER_PROMPT,
format_tools_for_planner,
)
from core.workflow.generator.prompts.vibe_prompts import (
format_available_nodes,
format_available_tools,
parse_vibe_response,
)
from core.workflow.generator.utils.edge_repair import EdgeRepair
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
from core.workflow.generator.utils.node_repair import NodeRepair
from core.workflow.generator.utils.workflow_validator import WorkflowValidator
logger = logging.getLogger(__name__)
class WorkflowGenerator:
"""
Refactored Vibe Workflow Generator (Planner-Builder Architecture).
Extracts Vibe logic from the monolithic LLMGenerator.
"""
@classmethod
def generate_workflow_flowchart(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
available_nodes: Sequence[dict[str, object]] | None = None,
existing_nodes: Sequence[dict[str, object]] | None = None,
available_tools: Sequence[dict[str, object]] | None = None,
selected_node_ids: Sequence[str] | None = None,
previous_workflow: dict[str, object] | None = None,
regenerate_mode: bool = False,
preferred_language: str | None = None,
available_models: Sequence[dict[str, object]] | None = None,
max_fix_iterations: int = 2,
):
"""
Generates a Dify Workflow Flowchart from natural language instruction.
Pipeline:
1. Planner: Analyze intent & select tools.
2. Context Filter: Filter relevant tools (reduce tokens).
3. Builder: Generate node configurations.
4. Repair: Fix common node/edge issues (NodeRepair, EdgeRepair).
5. Validator: Check for errors & generate friendly hints.
6. Renderer: Deterministic Mermaid generation.
"""
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
model_parameters = model_config.get("completion_params", {})
available_tools_list = list(available_tools) if available_tools else []
# --- STEP 1: PLANNER ---
planner_tools_context = format_tools_for_planner(available_tools_list)
planner_system = PLANNER_SYSTEM_PROMPT.format(tools_summary=planner_tools_context)
planner_user = PLANNER_USER_PROMPT.format(instruction=instruction)
try:
response = model_instance.invoke_llm(
prompt_messages=[SystemPromptMessage(content=planner_system), UserPromptMessage(content=planner_user)],
model_parameters=model_parameters,
stream=False,
)
plan_content = response.message.content
# Reuse parse_vibe_response to extract the JSON plan from the raw model output
plan_data = parse_vibe_response(plan_content)
except Exception as e:
logger.exception("Planner failed")
return {"intent": "error", "error": f"Planning failed: {str(e)}"}
if plan_data.get("intent") == "off_topic":
return {
"intent": "off_topic",
"message": plan_data.get("message", "I can only help with workflow creation."),
"suggestions": plan_data.get("suggestions", []),
}
# --- STEP 2: CONTEXT FILTERING ---
required_tools = plan_data.get("required_tool_keys", [])
filtered_tools = []
if required_tools:
# Simple linear search (optimized version would use a map)
for tool in available_tools_list:
t_key = tool.get("tool_key") or tool.get("tool_name")
provider = tool.get("provider_id") or tool.get("provider")
full_key = f"{provider}/{t_key}" if provider else t_key
# Check if this tool is in required list (match either full key or short name)
if t_key in required_tools or full_key in required_tools:
filtered_tools.append(tool)
else:
# If logic only, no tools needed
filtered_tools = []
# --- STEP 3: BUILDER ---
# Prepare context
tool_schemas = format_available_tools(filtered_tools)
# format_available_nodes renders BUILTIN_NODE_SCHEMAS internally; passing an empty
# list (rather than custom nodes) surfaces only the built-in node specs the Builder needs.
node_specs = format_available_nodes([])
builder_system = BUILDER_SYSTEM_PROMPT.format(
plan_context=json.dumps(plan_data.get("steps", []), indent=2),
tool_schemas=tool_schemas,
builtin_node_specs=node_specs,
)
builder_user = BUILDER_USER_PROMPT.format(instruction=instruction)
try:
build_res = model_instance.invoke_llm(
prompt_messages=[SystemPromptMessage(content=builder_system), UserPromptMessage(content=builder_user)],
model_parameters=model_parameters,
stream=False,
)
# Builder output is raw JSON nodes/edges
build_content = build_res.message.content
match = re.search(r"```(?:json)?\s*([\s\S]+?)```", build_content)
if match:
build_content = match.group(1)
workflow_data = json_repair.loads(build_content)
if "nodes" not in workflow_data:
workflow_data["nodes"] = []
if "edges" not in workflow_data:
workflow_data["edges"] = []
except Exception as e:
logger.exception("Builder failed")
return {"intent": "error", "error": f"Building failed: {str(e)}"}
# --- STEP 3.4: NODE REPAIR ---
node_repair_result = NodeRepair.repair(workflow_data["nodes"])
workflow_data["nodes"] = node_repair_result.nodes
# --- STEP 3.5: EDGE REPAIR ---
repair_result = EdgeRepair.repair(workflow_data)
workflow_data = {
"nodes": repair_result.nodes,
"edges": repair_result.edges,
}
# --- STEP 4: VALIDATOR ---
is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools_list)
# --- STEP 5: RENDERER ---
mermaid_code = generate_mermaid(workflow_data)
# --- FINALIZE ---
# Combine validation hints with repair warnings
all_warnings = [h.message for h in hints] + repair_result.warnings + node_repair_result.warnings
# Add stability warning (as requested by user)
stability_warning = "The generated workflow may require debugging."
if preferred_language and preferred_language.startswith("zh"):
stability_warning = "生成的 Workflow 可能需要调试。"
all_warnings.append(stability_warning)
all_fixes = repair_result.repairs_made + node_repair_result.repairs_made
return {
"intent": "generate",
"flowchart": mermaid_code,
"nodes": workflow_data["nodes"],
"edges": workflow_data["edges"],
"message": plan_data.get("plan_thought", "Generated workflow based on your request."),
"warnings": all_warnings,
"tool_recommendations": [], # Legacy field
"error": "",
"fix_iterations": 0, # Legacy
"fixed_issues": all_fixes, # Track what was auto-fixed
}
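End to end, a caller drives the whole pipeline with one classmethod call; the tenant id, model config, and tool entry here are placeholders:

```python
result = WorkflowGenerator.generate_workflow_flowchart(
    tenant_id="<tenant_id>",  # placeholder
    instruction="Classify support emails and draft a reply",
    model_config={"provider": "openai", "name": "gpt-4o", "completion_params": {}},
    available_tools=[{"tool_key": "google_search", "provider_id": "google"}],
    max_fix_iterations=2,
)
if result["intent"] == "generate":
    print(result["flowchart"])   # deterministic Mermaid from generate_mermaid()
    print(result["warnings"])    # validator hints + repair notes + stability warning
```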

View File

@ -0,0 +1,372 @@
"""
Edge Repair Utility for Vibe Workflow Generation.
This module provides intelligent edge repair capabilities for generated workflows.
It can detect and fix common edge issues:
- Missing edges between sequential nodes
- Incomplete branches for question-classifier and if-else nodes
- Orphaned nodes without connections
The repair logic is deterministic and doesn't require LLM calls.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class RepairResult:
"""Result of edge repair operation."""
nodes: list[dict[str, Any]]
edges: list[dict[str, Any]]
repairs_made: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
@property
def was_repaired(self) -> bool:
"""Check if any repairs were made."""
return len(self.repairs_made) > 0
class EdgeRepair:
"""
Intelligent edge repair for workflow graphs.
Repairs are applied in order:
1. Infer linear connections from node order (if no edges exist)
2. Add missing branch edges for question-classifier
3. Add missing branch edges for if-else
4. Connect orphaned nodes
"""
@classmethod
def repair(cls, workflow_data: dict[str, Any]) -> RepairResult:
"""
Repair edges in the workflow data.
Args:
workflow_data: Dict containing 'nodes' and 'edges'
Returns:
RepairResult with repaired nodes, edges, and repair logs
"""
nodes = list(workflow_data.get("nodes", []))
edges = list(workflow_data.get("edges", []))
repairs: list[str] = []
warnings: list[str] = []
logger.info("[EdgeRepair] Starting repair: %d nodes, %d edges", len(nodes), len(edges))
# Build node lookup
node_map = {n.get("id"): n for n in nodes if n.get("id")}
node_ids = set(node_map.keys())
# 1. If no edges at all, infer linear chain
if not edges and len(nodes) > 1:
edges, inferred_repairs = cls._infer_linear_chain(nodes)
repairs.extend(inferred_repairs)
# 2. Build edge index for analysis
outgoing_edges: dict[str, list[dict[str, Any]]] = {}
incoming_edges: dict[str, list[dict[str, Any]]] = {}
for edge in edges:
src = edge.get("source")
tgt = edge.get("target")
if src:
outgoing_edges.setdefault(src, []).append(edge)
if tgt:
incoming_edges.setdefault(tgt, []).append(edge)
# 3. Repair question-classifier branches
for node in nodes:
if node.get("type") == "question-classifier":
new_edges, branch_repairs, branch_warnings = cls._repair_classifier_branches(
node, edges, outgoing_edges, node_ids
)
edges.extend(new_edges)
repairs.extend(branch_repairs)
warnings.extend(branch_warnings)
# Update outgoing index
for edge in new_edges:
outgoing_edges.setdefault(edge.get("source"), []).append(edge)
# 4. Repair if-else branches
for node in nodes:
if node.get("type") == "if-else":
new_edges, branch_repairs, branch_warnings = cls._repair_if_else_branches(
node, edges, outgoing_edges, node_ids
)
edges.extend(new_edges)
repairs.extend(branch_repairs)
warnings.extend(branch_warnings)
# Update outgoing index
for edge in new_edges:
outgoing_edges.setdefault(edge.get("source"), []).append(edge)
# 5. Connect orphaned nodes (nodes with no incoming edge, except start)
new_edges, orphan_repairs = cls._connect_orphaned_nodes(
nodes, edges, outgoing_edges, incoming_edges
)
edges.extend(new_edges)
repairs.extend(orphan_repairs)
# 6. Connect nodes with no outgoing edge to 'end' (except end nodes)
new_edges, terminal_repairs = cls._connect_terminal_nodes(
nodes, edges, outgoing_edges
)
edges.extend(new_edges)
repairs.extend(terminal_repairs)
logger.info("[EdgeRepair] Completed: %d repairs made, %d warnings", len(repairs), len(warnings))
for r in repairs:
logger.info("[EdgeRepair] Repair: %s", r)
for w in warnings:
logger.info("[EdgeRepair] Warning: %s", w)
return RepairResult(
nodes=nodes,
edges=edges,
repairs_made=repairs,
warnings=warnings,
)
@classmethod
def _infer_linear_chain(cls, nodes: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[str]]:
"""
Infer a linear chain of edges from node order.
This is used when no edges are provided at all.
"""
edges: list[dict[str, Any]] = []
repairs: list[str] = []
# Filter to get ordered node IDs
node_ids = [n.get("id") for n in nodes if n.get("id")]
if len(node_ids) < 2:
return edges, repairs
# Create edges between consecutive nodes
for i in range(len(node_ids) - 1):
src = node_ids[i]
tgt = node_ids[i + 1]
edges.append({"source": src, "target": tgt})
repairs.append(f"Inferred edge: {src} -> {tgt}")
logger.info("[EdgeRepair] Inferred %d edges from node order (no edges provided)", len(edges))
return edges, repairs
@classmethod
def _repair_classifier_branches(
cls,
node: dict[str, Any],
edges: list[dict[str, Any]],
outgoing_edges: dict[str, list[dict[str, Any]]],
valid_node_ids: set[str],
) -> tuple[list[dict[str, Any]], list[str], list[str]]:
"""
Repair missing branches for question-classifier nodes.
For each class that doesn't have an edge, create one pointing to 'end'.
"""
new_edges: list[dict[str, Any]] = []
repairs: list[str] = []
warnings: list[str] = []
node_id = node.get("id")
if not node_id:
return new_edges, repairs, warnings
config = node.get("config", {})
classes = config.get("classes", [])
if not classes:
return new_edges, repairs, warnings
# Get existing sourceHandles for this node
existing_handles = set()
for edge in outgoing_edges.get(node_id, []):
handle = edge.get("sourceHandle")
if handle:
existing_handles.add(handle)
# Find 'end' node as default target
end_node_id = "end"
if "end" not in valid_node_ids:
# Try to find an end node
for nid in valid_node_ids:
if "end" in nid.lower():
end_node_id = nid
break
# Add missing branches
for cls_def in classes:
if not isinstance(cls_def, dict):
continue
cls_id = cls_def.get("id")
cls_name = cls_def.get("name", cls_id)
if cls_id and cls_id not in existing_handles:
new_edge = {
"source": node_id,
"sourceHandle": cls_id,
"target": end_node_id,
}
new_edges.append(new_edge)
repairs.append(f"Added missing branch edge for class '{cls_name}' -> {end_node_id}")
warnings.append(
f"Auto-connected question-classifier branch '{cls_name}' to '{end_node_id}'. "
"You may want to redirect this to a specific handler node."
)
return new_edges, repairs, warnings
@classmethod
def _repair_if_else_branches(
cls,
node: dict[str, Any],
edges: list[dict[str, Any]],
outgoing_edges: dict[str, list[dict[str, Any]]],
valid_node_ids: set[str],
) -> tuple[list[dict[str, Any]], list[str], list[str]]:
"""
Repair missing true/false branches for if-else nodes.
"""
new_edges: list[dict[str, Any]] = []
repairs: list[str] = []
warnings: list[str] = []
node_id = node.get("id")
if not node_id:
return new_edges, repairs, warnings
# Get existing sourceHandles
existing_handles = set()
for edge in outgoing_edges.get(node_id, []):
handle = edge.get("sourceHandle")
if handle:
existing_handles.add(handle)
# Find 'end' node as default target
end_node_id = "end"
if "end" not in valid_node_ids:
for nid in valid_node_ids:
if "end" in nid.lower():
end_node_id = nid
break
# Add missing branches
required_branches = ["true", "false"]
for branch in required_branches:
if branch not in existing_handles:
new_edge = {
"source": node_id,
"sourceHandle": branch,
"target": end_node_id,
}
new_edges.append(new_edge)
repairs.append(f"Added missing if-else '{branch}' branch -> {end_node_id}")
warnings.append(
f"Auto-connected if-else '{branch}' branch to '{end_node_id}'. "
"You may want to redirect this to a specific handler node."
)
return new_edges, repairs, warnings
@classmethod
def _connect_orphaned_nodes(
cls,
nodes: list[dict[str, Any]],
edges: list[dict[str, Any]],
outgoing_edges: dict[str, list[dict[str, Any]]],
incoming_edges: dict[str, list[dict[str, Any]]],
) -> tuple[list[dict[str, Any]], list[str]]:
"""
Connect orphaned nodes to the previous node in sequence.
An orphaned node has no incoming edges and is not a 'start' node.
"""
new_edges: list[dict[str, Any]] = []
repairs: list[str] = []
node_ids = [n.get("id") for n in nodes if n.get("id")]
node_types = {n.get("id"): n.get("type") for n in nodes}
for i, node_id in enumerate(node_ids):
node_type = node_types.get(node_id)
# Skip start nodes - they don't need incoming edges
if node_type == "start":
continue
# Check if node has incoming edges
if node_id not in incoming_edges or not incoming_edges[node_id]:
# Find previous node to connect from
if i > 0:
prev_node_id = node_ids[i - 1]
new_edge = {"source": prev_node_id, "target": node_id}
new_edges.append(new_edge)
repairs.append(f"Connected orphaned node: {prev_node_id} -> {node_id}")
# Update incoming_edges for subsequent checks
incoming_edges.setdefault(node_id, []).append(new_edge)
return new_edges, repairs
@classmethod
def _connect_terminal_nodes(
cls,
nodes: list[dict[str, Any]],
edges: list[dict[str, Any]],
outgoing_edges: dict[str, list[dict[str, Any]]],
) -> tuple[list[dict[str, Any]], list[str]]:
"""
Connect terminal nodes (no outgoing edges) to 'end'.
A terminal node has no outgoing edges and is not an 'end' node.
This ensures all branches eventually reach 'end'.
"""
new_edges: list[dict[str, Any]] = []
repairs: list[str] = []
# Find end node
end_node_id = None
node_ids = set()
for n in nodes:
nid = n.get("id")
ntype = n.get("type")
if nid:
node_ids.add(nid)
if ntype == "end":
end_node_id = nid
if not end_node_id:
# No end node found, can't connect
return new_edges, repairs
for node in nodes:
node_id = node.get("id")
node_type = node.get("type")
# Skip end nodes
if node_type == "end":
continue
# Skip nodes that already have outgoing edges
if outgoing_edges.get(node_id):
continue
# Connect to end
new_edge = {"source": node_id, "target": end_node_id}
new_edges.append(new_edge)
repairs.append(f"Connected terminal node to end: {node_id} -> {end_node_id}")
# Update for subsequent checks
outgoing_edges.setdefault(node_id, []).append(new_edge)
return new_edges, repairs
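A small demonstration of the repair pass on a graph missing its 'false' branch (node configs trimmed to the fields EdgeRepair actually reads):

```python
workflow = {
    "nodes": [
        {"id": "start", "type": "start"},
        {"id": "check", "type": "if-else", "config": {"cases": [{"case_id": "case_1"}]}},
        {"id": "qualified", "type": "llm"},
        {"id": "end", "type": "end"},
    ],
    "edges": [
        {"source": "start", "target": "check"},
        {"source": "check", "sourceHandle": "true", "target": "qualified"},
        {"source": "qualified", "target": "end"},
    ],
}
result = EdgeRepair.repair(workflow)
print(result.was_repaired)   # True
print(result.repairs_made)   # ["Added missing if-else 'false' branch -> end"]
```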

View File

@ -0,0 +1,138 @@
import logging
from typing import Any
logger = logging.getLogger(__name__)
def generate_mermaid(workflow_data: dict[str, Any]) -> str:
"""
Generate a Mermaid flowchart from workflow data consisting of nodes and edges.
Args:
workflow_data: Dict containing 'nodes' (list) and 'edges' (list)
Returns:
String containing the Mermaid flowchart syntax
"""
nodes = workflow_data.get("nodes", [])
edges = workflow_data.get("edges", [])
# DEBUG: Log input data
logger.debug("[MERMAID] Input nodes count: %d", len(nodes))
logger.debug("[MERMAID] Input edges count: %d", len(edges))
for i, node in enumerate(nodes):
logger.debug(
"[MERMAID] Node %d: id=%s, type=%s, title=%s", i, node.get("id"), node.get("type"), node.get("title")
)
for i, edge in enumerate(edges):
logger.debug(
"[MERMAID] Edge %d: source=%s, target=%s, sourceHandle=%s",
i,
edge.get("source"),
edge.get("target"),
edge.get("sourceHandle"),
)
lines = ["flowchart TD"]
# 1. Define Nodes
# Format: node_id["title<br/>type"] or similar
# We will use the Vibe Workflow standard format: id["type=TYPE|title=TITLE"]
# Or specifically for tool nodes: id["type=tool|title=TITLE|tool=TOOL_KEY"]
# Map of original IDs to safe Mermaid IDs
id_map = {}
def get_safe_id(original_id: str) -> str:
if original_id == "end":
return "end_node"
if original_id == "subgraph":
return "subgraph_node"
# Mermaid IDs should be alphanumeric.
# If the ID has special chars, we might need to escape or hash, but Vibe usually generates simple IDs.
# We'll trust standard IDs but handle the reserved keyword 'end'.
return original_id
for node in nodes:
node_id = node.get("id")
if not node_id:
continue
safe_id = get_safe_id(node_id)
id_map[node_id] = safe_id
node_type = node.get("type", "unknown")
title = node.get("title", "Untitled")
# Escape quotes in title
safe_title = title.replace('"', "'")
if node_type == "tool":
config = node.get("config", {})
# Try multiple fields for tool reference
tool_ref = (
config.get("tool_key")
or config.get("tool")
or config.get("tool_name")
or node.get("tool_name")
or "unknown"
)
node_def = f'{safe_id}["type={node_type}|title={safe_title}|tool={tool_ref}"]'
else:
node_def = f'{safe_id}["type={node_type}|title={safe_title}"]'
lines.append(f" {node_def}")
# 2. Define Edges
# Format: source --> target
# Track defined nodes to avoid edge errors
defined_node_ids = {n.get("id") for n in nodes if n.get("id")}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
# Skip invalid edges
if not source or not target:
continue
if source not in defined_node_ids or target not in defined_node_ids:
# Log skipped edges for debugging
logger.warning(
"[MERMAID] Skipping edge: source=%s (exists=%s), target=%s (exists=%s)",
source,
source in defined_node_ids,
target,
target in defined_node_ids,
)
continue
safe_source = id_map.get(source, source)
safe_target = id_map.get(target, target)
# Handle conditional branches (true/false) if present
# In Dify workflow, sourceHandle is often used for this
source_handle = edge.get("sourceHandle")
label = ""
if source_handle == "true":
label = "|true|"
elif source_handle == "false":
label = "|false|"
elif source_handle and source_handle != "source":
# For question-classifier or other multi-path nodes
# Clean up handle for display if needed
safe_handle = str(source_handle).replace('"', "'")
label = f"|{safe_handle}|"
edge_line = f" {safe_source} -->{label} {safe_target}"
logger.debug("[MERMAID] Adding edge: %s", edge_line)
lines.append(edge_line)
# Start/End nodes are implicitly handled if they are in the 'nodes' list
# If not, we might need to add them, but usually the Builder should produce them.
result = "\n".join(lines)
logger.debug("[MERMAID] Final output:\n%s", result)
return result
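# Illustrative usage (not part of the committed file): the generator is
# deterministic, so a tiny workflow renders to a stable Mermaid string.
#
# workflow = {
#     "nodes": [
#         {"id": "start", "type": "start", "title": "Start"},
#         {"id": "end", "type": "end", "title": "End"},
#     ],
#     "edges": [{"source": "start", "target": "end"}],
# }
# print(generate_mermaid(workflow))
# Expected lines (the reserved ID 'end' is rewritten to 'end_node'):
#   flowchart TD
#   start["type=start|title=Start"]
#   end_node["type=end|title=End"]
#   start --> end_node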

View File

@ -0,0 +1,96 @@
"""
Node Repair Utility for Vibe Workflow Generation.
This module provides intelligent node configuration repair capabilities.
It can detect and fix common node configuration issues:
- Invalid comparison operators in if-else nodes (e.g. '>=' -> '≥')
"""
import copy
import logging
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class NodeRepairResult:
"""Result of node repair operation."""
nodes: list[dict[str, Any]]
repairs_made: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
@property
def was_repaired(self) -> bool:
"""Check if any repairs were made."""
return len(self.repairs_made) > 0
class NodeRepair:
"""
Intelligent node configuration repair.
"""
OPERATOR_MAP = {
">=": "≥",
"<=": "≤",
"!=": "≠",
"==": "=",
}
@classmethod
def repair(cls, nodes: list[dict[str, Any]]) -> NodeRepairResult:
"""
Repair node configurations.
Args:
nodes: List of node dictionaries
Returns:
NodeRepairResult with repaired nodes and logs
"""
# Deep copy to avoid mutating original
nodes = copy.deepcopy(nodes)
repairs: list[str] = []
warnings: list[str] = []
logger.info("[NodeRepair] Starting repair: %d nodes", len(nodes))
for node in nodes:
node_type = node.get("type")
if node_type == "if-else":
cls._repair_if_else_operators(node, repairs)
# Add other node type repairs here as needed
if repairs:
logger.info("[NodeRepair] Completed: %d repairs made", len(repairs))
for r in repairs:
logger.info("[NodeRepair] Repair: %s", r)
return NodeRepairResult(
nodes=nodes,
repairs_made=repairs,
warnings=warnings,
)
@classmethod
def _repair_if_else_operators(cls, node: dict[str, Any], repairs: list[str]):
"""
Normalize comparison operators in if-else nodes.
"""
node_id = node.get("id", "unknown")
config = node.get("config", {})
cases = config.get("cases", [])
for case in cases:
conditions = case.get("conditions", [])
for condition in conditions:
op = condition.get("comparison_operator")
if op in cls.OPERATOR_MAP:
new_op = cls.OPERATOR_MAP[op]
condition["comparison_operator"] = new_op
repairs.append(f"Normalized operator '{op}' to '{new_op}' in node '{node_id}'")

View File

@ -0,0 +1,96 @@
import logging
from dataclasses import dataclass
from typing import Any
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.engine import ValidationEngine
from core.workflow.generator.validation.rules import Severity
logger = logging.getLogger(__name__)
@dataclass
class ValidationHint:
"""Legacy compatibility class for validation hints."""
node_id: str
field: str
message: str
severity: str # 'error', 'warning'
suggestion: str | None = None
node_type: str | None = None # Added for test compatibility
# Alias for potential old code using 'type' instead of 'severity'
@property
def type(self) -> str:
return self.severity
@property
def element_id(self) -> str:
return self.node_id
FriendlyHint = ValidationHint # Alias for backward compatibility
class WorkflowValidator:
"""
Validates the generated workflow configuration (nodes and edges).
Wraps the new ValidationEngine for backward compatibility.
"""
@classmethod
def validate(
cls,
workflow_data: dict[str, Any],
available_tools: list[dict[str, Any]],
available_models: list[dict[str, Any]] | None = None,
) -> tuple[bool, list[ValidationHint]]:
"""
Validate workflow data and return validity status and hints.
Args:
workflow_data: Dict containing 'nodes' and 'edges'
available_tools: List of available tool configurations
available_models: List of available models (added for Vibe compat)
Returns:
Tuple of (is_valid, hints); is_valid is True when no error-severity issues were found
"""
nodes = workflow_data.get("nodes", [])
edges = workflow_data.get("edges", [])
# Create context
context = ValidationContext(
nodes=nodes,
edges=edges,
available_models=available_models or [],
available_tools=available_tools or [],
)
# Run validation engine
engine = ValidationEngine()
result = engine.validate(context)
# Convert engine errors to legacy hints
hints: list[ValidationHint] = []
for error in result.all_errors:
# Map severity
severity = "error" if error.severity == Severity.ERROR else "warning"
# Map field from message or details if possible (heuristic)
field_name = error.details.get("field", "unknown")
hints.append(
ValidationHint(
node_id=error.node_id,
field=field_name,
message=error.message,
severity=severity,
suggestion=error.fix_hint,
node_type=error.node_type,
)
)
return result.is_valid, hints
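# Illustrative usage (not part of the committed file): the wrapper keeps the
# legacy (is_valid, hints) return shape while delegating to ValidationEngine.
#
# workflow_data = {
#     "nodes": [{"id": "llm_1", "type": "llm", "config": {}}],
#     "edges": [],
# }
# is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools=[])
# assert is_valid is False  # missing prompt_template and model
# for h in hints:
#     print(h.severity, h.node_id, h.message)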

View File

@ -0,0 +1,45 @@
"""
Validation Rule Engine for Vibe Workflow Generation.
This module provides a declarative, schema-based validation system for
generated workflow nodes. It classifies errors into fixable (LLM can auto-fix)
and user-required (needs manual intervention) categories.
Usage:
from core.workflow.generator.validation import ValidationEngine, ValidationContext
context = ValidationContext(
available_models=[...],
available_tools=[...],
nodes=[...],
edges=[...],
)
engine = ValidationEngine()
result = engine.validate(context)
# Access classified errors
fixable_errors = result.fixable_errors
user_required_errors = result.user_required_errors
"""
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.engine import ValidationEngine, ValidationResult
from core.workflow.generator.validation.rules import (
RuleCategory,
Severity,
ValidationError,
ValidationRule,
)
__all__ = [
"RuleCategory",
"Severity",
"ValidationContext",
"ValidationEngine",
"ValidationError",
"ValidationResult",
"ValidationRule",
]

View File

@ -0,0 +1,123 @@
"""
Validation Context for the Rule Engine.
The ValidationContext holds all the data needed for validation:
- Generated nodes and edges
- Available models, tools, and datasets
- Node output schemas for variable reference validation
"""
from dataclasses import dataclass, field
from typing import Any
@dataclass
class ValidationContext:
"""
Context object containing all data needed for validation.
This is passed to each validation rule, providing access to:
- The nodes being validated
- Edge connections between nodes
- Available external resources (models, tools)
"""
# Generated workflow data
nodes: list[dict[str, Any]] = field(default_factory=list)
edges: list[dict[str, Any]] = field(default_factory=list)
# Available external resources
available_models: list[dict[str, Any]] = field(default_factory=list)
available_tools: list[dict[str, Any]] = field(default_factory=list)
# Cached lookups (populated lazily)
_node_map: dict[str, dict[str, Any]] | None = field(default=None, repr=False)
_model_set: set[tuple[str, str]] | None = field(default=None, repr=False)
_tool_set: set[str] | None = field(default=None, repr=False)
_configured_tool_set: set[str] | None = field(default=None, repr=False)
@property
def node_map(self) -> dict[str, dict[str, Any]]:
"""Get a map of node_id -> node for quick lookup."""
if self._node_map is None:
self._node_map = {node.get("id", ""): node for node in self.nodes}
return self._node_map
@property
def model_set(self) -> set[tuple[str, str]]:
"""Get a set of (provider, model_name) tuples for quick lookup."""
if self._model_set is None:
self._model_set = {
(m.get("provider", ""), m.get("model", ""))
for m in self.available_models
}
return self._model_set
@property
def tool_set(self) -> set[str]:
"""Get a set of all tool keys (both configured and unconfigured)."""
if self._tool_set is None:
self._tool_set = set()
for tool in self.available_tools:
provider = tool.get("provider_id") or tool.get("provider", "")
tool_key = tool.get("tool_key") or tool.get("tool_name", "")
if provider and tool_key:
self._tool_set.add(f"{provider}/{tool_key}")
if tool_key:
self._tool_set.add(tool_key)
return self._tool_set
@property
def configured_tool_set(self) -> set[str]:
"""Get a set of configured (authorized) tool keys."""
if self._configured_tool_set is None:
self._configured_tool_set = set()
for tool in self.available_tools:
if not tool.get("is_team_authorization", False):
continue
provider = tool.get("provider_id") or tool.get("provider", "")
tool_key = tool.get("tool_key") or tool.get("tool_name", "")
if provider and tool_key:
self._configured_tool_set.add(f"{provider}/{tool_key}")
if tool_key:
self._configured_tool_set.add(tool_key)
return self._configured_tool_set
def has_model(self, provider: str, model_name: str) -> bool:
"""Check if a model is available."""
return (provider, model_name) in self.model_set
def has_tool(self, tool_key: str) -> bool:
"""Check if a tool exists (configured or not)."""
return tool_key in self.tool_set
def is_tool_configured(self, tool_key: str) -> bool:
"""Check if a tool is configured and ready to use."""
return tool_key in self.configured_tool_set
def get_node(self, node_id: str) -> dict[str, Any] | None:
"""Get a node by its ID."""
return self.node_map.get(node_id)
def get_node_ids(self) -> set[str]:
"""Get all node IDs in the workflow."""
return set(self.node_map.keys())
def get_upstream_nodes(self, node_id: str) -> list[str]:
"""Get IDs of nodes that connect to this node (upstream)."""
return [
edge.get("source", "")
for edge in self.edges
if edge.get("target") == node_id
]
def get_downstream_nodes(self, node_id: str) -> list[str]:
"""Get IDs of nodes that this node connects to (downstream)."""
return [
edge.get("target", "")
for edge in self.edges
if edge.get("source") == node_id
]
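# Illustrative usage (not part of the committed file): the lazily cached
# lookup helpers in action.
#
# ctx = ValidationContext(
#     nodes=[{"id": "start", "type": "start"}, {"id": "llm", "type": "llm"}],
#     edges=[{"source": "start", "target": "llm"}],
#     available_models=[{"provider": "openai", "model": "gpt-4"}],
#     available_tools=[{"provider_id": "google", "tool_key": "search", "is_team_authorization": True}],
# )
# assert ctx.has_model("openai", "gpt-4")
# assert ctx.has_tool("google/search") and ctx.is_tool_configured("google/search")
# assert ctx.get_upstream_nodes("llm") == ["start"]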

View File

@ -0,0 +1,266 @@
"""
Validation Engine - Core validation logic.
The ValidationEngine orchestrates rule execution and aggregates results.
It provides a clean interface for validating workflow nodes.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.rules import (
RuleCategory,
Severity,
ValidationError,
get_registry,
)
logger = logging.getLogger(__name__)
@dataclass
class ValidationResult:
"""
Result of validation containing all errors classified by fixability.
Attributes:
all_errors: All validation errors found
fixable_errors: Errors that LLM can automatically fix
user_required_errors: Errors that require user intervention
warnings: Non-blocking warnings
stats: Validation statistics
"""
all_errors: list[ValidationError] = field(default_factory=list)
fixable_errors: list[ValidationError] = field(default_factory=list)
user_required_errors: list[ValidationError] = field(default_factory=list)
warnings: list[ValidationError] = field(default_factory=list)
stats: dict[str, int] = field(default_factory=dict)
@property
def has_errors(self) -> bool:
"""Check if there are any errors (excluding warnings)."""
return len(self.fixable_errors) > 0 or len(self.user_required_errors) > 0
@property
def has_fixable_errors(self) -> bool:
"""Check if there are fixable errors."""
return len(self.fixable_errors) > 0
@property
def is_valid(self) -> bool:
"""Check if validation passed (no errors, warnings are OK)."""
return not self.has_errors
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for API response."""
return {
"fixable": [e.to_dict() for e in self.fixable_errors],
"user_required": [e.to_dict() for e in self.user_required_errors],
"warnings": [e.to_dict() for e in self.warnings],
"all_warnings": [e.message for e in self.all_errors],
"stats": self.stats,
}
def get_error_messages(self) -> list[str]:
"""Get all error messages as strings."""
return [e.message for e in self.all_errors]
def get_fixable_by_node(self) -> dict[str, list[ValidationError]]:
"""Group fixable errors by node ID."""
result: dict[str, list[ValidationError]] = {}
for error in self.fixable_errors:
if error.node_id not in result:
result[error.node_id] = []
result[error.node_id].append(error)
return result
class ValidationEngine:
"""
The main validation engine.
Usage:
engine = ValidationEngine()
context = ValidationContext(nodes=[...], available_models=[...])
result = engine.validate(context)
"""
def __init__(self):
self._registry = get_registry()
def validate(self, context: ValidationContext) -> ValidationResult:
"""
Validate all nodes in the context.
Args:
context: ValidationContext with nodes, edges, and available resources
Returns:
ValidationResult with classified errors
"""
result = ValidationResult()
stats = {
"total_nodes": len(context.nodes),
"total_rules_checked": 0,
"total_errors": 0,
"fixable_count": 0,
"user_required_count": 0,
"warning_count": 0,
}
# Validate each node
for node in context.nodes:
node_type = node.get("type", "unknown")
node_id = node.get("id", "unknown")
# Get applicable rules for this node type
rules = self._registry.get_rules_for_node(node_type)
for rule in rules:
stats["total_rules_checked"] += 1
try:
errors = rule.check(node, context)
for error in errors:
result.all_errors.append(error)
stats["total_errors"] += 1
# Classify by severity and fixability
if error.severity == Severity.WARNING:
result.warnings.append(error)
stats["warning_count"] += 1
elif error.is_fixable:
result.fixable_errors.append(error)
stats["fixable_count"] += 1
else:
result.user_required_errors.append(error)
stats["user_required_count"] += 1
except Exception:
logger.exception(
"Rule '%s' failed for node '%s'",
rule.id,
node_id,
)
# Don't let a rule failure break the entire validation
continue
# Validate edges separately
edge_errors = self._validate_edges(context)
for error in edge_errors:
result.all_errors.append(error)
stats["total_errors"] += 1
if error.is_fixable:
result.fixable_errors.append(error)
stats["fixable_count"] += 1
else:
result.user_required_errors.append(error)
stats["user_required_count"] += 1
result.stats = stats
logger.debug(
"[Validation] Completed: %d nodes, %d rules, %d errors (%d fixable, %d user-required)",
stats["total_nodes"],
stats["total_rules_checked"],
stats["total_errors"],
stats["fixable_count"],
stats["user_required_count"],
)
return result
def _validate_edges(self, context: ValidationContext) -> list[ValidationError]:
"""Validate edge connections."""
errors: list[ValidationError] = []
valid_node_ids = context.get_node_ids()
for edge in context.edges:
source = edge.get("source", "")
target = edge.get("target", "")
if source and source not in valid_node_ids:
errors.append(
ValidationError(
rule_id="edge.source.invalid",
node_id=source,
node_type="edge",
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
message=f"Edge source '{source}' does not exist",
fix_hint="Update edge to reference existing node",
)
)
if target and target not in valid_node_ids:
errors.append(
ValidationError(
rule_id="edge.target.invalid",
node_id=target,
node_type="edge",
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
message=f"Edge target '{target}' does not exist",
fix_hint="Update edge to reference existing node",
)
)
return errors
def validate_single_node(
self,
node: dict[str, Any],
context: ValidationContext,
) -> list[ValidationError]:
"""
Validate a single node.
Useful for incremental validation when a node is added/modified.
"""
node_type = node.get("type", "unknown")
rules = self._registry.get_rules_for_node(node_type)
errors: list[ValidationError] = []
for rule in rules:
try:
errors.extend(rule.check(node, context))
except Exception:
logger.exception("Rule '%s' failed", rule.id)
return errors
def validate_nodes(
nodes: list[dict[str, Any]],
edges: list[dict[str, Any]] | None = None,
available_models: list[dict[str, Any]] | None = None,
available_tools: list[dict[str, Any]] | None = None,
) -> ValidationResult:
"""
Convenience function to validate nodes without creating engine/context manually.
Args:
nodes: List of workflow nodes to validate
edges: Optional list of edges
available_models: Optional list of available models
available_tools: Optional list of available tools
Returns:
ValidationResult with classified errors
"""
context = ValidationContext(
nodes=nodes,
edges=edges or [],
available_models=available_models or [],
available_tools=available_tools or [],
)
engine = ValidationEngine()
return engine.validate(context)
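# Illustrative usage (not part of the committed file): the convenience wrapper
# builds the context and engine in one call; the fixable/user-required split
# then drives the validate-fix loop.
#
# result = validate_nodes(
#     nodes=[{"id": "llm_1", "type": "llm", "config": {}}],
#     available_models=[{"provider": "openai", "model": "gpt-4"}],
# )
# assert result.has_errors and result.has_fixable_errors
# for err in result.fixable_errors:
#     print(err.rule_id, "->", err.fix_hint)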

File diff suppressed because it is too large

View File

@ -0,0 +1,288 @@
"""
Unit tests for the Mermaid Generator.
Tests cover:
- Basic workflow rendering
- Reserved word handling ('end' → 'end_node')
- Question classifier multi-branch edges
- If-else branch labels
- Edge validation and skipping
- Tool node formatting
"""
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
class TestBasicWorkflow:
"""Tests for basic workflow Mermaid generation."""
def test_simple_start_end_workflow(self):
"""Test simple Start → End workflow."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [{"source": "start", "target": "end"}],
}
result = generate_mermaid(workflow_data)
assert "flowchart TD" in result
assert 'start["type=start|title=Start"]' in result
assert 'end_node["type=end|title=End"]' in result
assert "start --> end_node" in result
def test_start_llm_end_workflow(self):
"""Test Start → LLM → End workflow."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "llm", "type": "llm", "title": "Generate"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [
{"source": "start", "target": "llm"},
{"source": "llm", "target": "end"},
],
}
result = generate_mermaid(workflow_data)
assert 'llm["type=llm|title=Generate"]' in result
assert "start --> llm" in result
assert "llm --> end_node" in result
def test_empty_workflow(self):
"""Test empty workflow returns minimal output."""
workflow_data = {"nodes": [], "edges": []}
result = generate_mermaid(workflow_data)
assert result == "flowchart TD"
def test_missing_keys_handled(self):
"""Test workflow with missing keys doesn't crash."""
workflow_data = {}
result = generate_mermaid(workflow_data)
assert "flowchart TD" in result
class TestReservedWords:
"""Tests for reserved word handling in node IDs."""
def test_end_node_id_is_replaced(self):
"""Test 'end' node ID is replaced with 'end_node'."""
workflow_data = {
"nodes": [{"id": "end", "type": "end", "title": "End"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Should use end_node instead of end
assert "end_node[" in result
assert '"type=end|title=End"' in result
def test_subgraph_node_id_is_replaced(self):
"""Test 'subgraph' node ID is replaced with 'subgraph_node'."""
workflow_data = {
"nodes": [{"id": "subgraph", "type": "code", "title": "Process"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "subgraph_node[" in result
def test_edge_uses_safe_ids(self):
"""Test edges correctly reference safe IDs after replacement."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [{"source": "start", "target": "end"}],
}
result = generate_mermaid(workflow_data)
# Edge should use end_node, not end
assert "start --> end_node" in result
assert "start --> end\n" not in result
class TestBranchEdges:
"""Tests for branching node edge labels."""
def test_question_classifier_source_handles(self):
"""Test question-classifier edges with sourceHandle labels."""
workflow_data = {
"nodes": [
{"id": "classifier", "type": "question-classifier", "title": "Classify"},
{"id": "refund", "type": "llm", "title": "Handle Refund"},
{"id": "inquiry", "type": "llm", "title": "Handle Inquiry"},
],
"edges": [
{"source": "classifier", "target": "refund", "sourceHandle": "refund"},
{"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"},
],
}
result = generate_mermaid(workflow_data)
assert "classifier -->|refund| refund" in result
assert "classifier -->|inquiry| inquiry" in result
def test_if_else_true_false_handles(self):
"""Test if-else edges with true/false labels."""
workflow_data = {
"nodes": [
{"id": "ifelse", "type": "if-else", "title": "Check"},
{"id": "yes_branch", "type": "llm", "title": "Yes"},
{"id": "no_branch", "type": "llm", "title": "No"},
],
"edges": [
{"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"},
{"source": "ifelse", "target": "no_branch", "sourceHandle": "false"},
],
}
result = generate_mermaid(workflow_data)
assert "ifelse -->|true| yes_branch" in result
assert "ifelse -->|false| no_branch" in result
def test_source_handle_source_is_ignored(self):
"""Test sourceHandle='source' doesn't add label."""
workflow_data = {
"nodes": [
{"id": "llm1", "type": "llm", "title": "LLM 1"},
{"id": "llm2", "type": "llm", "title": "LLM 2"},
],
"edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}],
}
result = generate_mermaid(workflow_data)
# Should be plain arrow without label
assert "llm1 --> llm2" in result
assert "llm1 -->|source|" not in result
class TestEdgeValidation:
"""Tests for edge validation and error handling."""
def test_edge_with_missing_source_is_skipped(self):
"""Test edge with non-existent source node is skipped."""
workflow_data = {
"nodes": [{"id": "end", "type": "end", "title": "End"}],
"edges": [{"source": "nonexistent", "target": "end"}],
}
result = generate_mermaid(workflow_data)
# The invalid edge should be skipped; with no valid edges there are no arrows
assert "nonexistent" not in result
assert "-->" not in result
def test_edge_with_missing_target_is_skipped(self):
"""Test edge with non-existent target node is skipped."""
workflow_data = {
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
"edges": [{"source": "start", "target": "nonexistent"}],
}
result = generate_mermaid(workflow_data)
# Edge should be skipped
assert "start --> nonexistent" not in result
def test_edge_without_source_or_target_is_skipped(self):
"""Test edge missing source or target is skipped."""
workflow_data = {
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
"edges": [{"source": "start"}, {"target": "start"}, {}],
}
result = generate_mermaid(workflow_data)
# No edges should be rendered
assert result.count("-->") == 0
class TestToolNodes:
"""Tests for tool node formatting."""
def test_tool_node_includes_tool_key(self):
"""Test tool node includes tool_key in label."""
workflow_data = {
"nodes": [
{
"id": "search",
"type": "tool",
"title": "Search",
"config": {"tool_key": "google/search"},
}
],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert 'search["type=tool|title=Search|tool=google/search"]' in result
def test_tool_node_with_tool_name_fallback(self):
"""Test tool node uses tool_name as fallback."""
workflow_data = {
"nodes": [
{
"id": "tool1",
"type": "tool",
"title": "My Tool",
"config": {"tool_name": "my_tool"},
}
],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "tool=my_tool" in result
def test_tool_node_missing_tool_key_shows_unknown(self):
"""Test tool node without tool_key shows 'unknown'."""
workflow_data = {
"nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "tool=unknown" in result
class TestNodeFormatting:
"""Tests for node label formatting."""
def test_quotes_in_title_are_escaped(self):
"""Test double quotes in title are replaced with single quotes."""
workflow_data = {
"nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Double quotes should be replaced
assert "Say 'Hello'" in result
assert 'Say "Hello"' not in result
def test_node_without_id_is_skipped(self):
"""Test node without id is skipped."""
workflow_data = {
"nodes": [{"type": "llm", "title": "No ID"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Should only have flowchart header
lines = [line for line in result.split("\n") if line.strip()]
assert len(lines) == 1
def test_node_default_values(self):
"""Test node with missing type/title uses defaults."""
workflow_data = {
"nodes": [{"id": "node1"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "type=unknown" in result
assert "title=Untitled" in result

View File

@ -0,0 +1,81 @@
from core.workflow.generator.utils.node_repair import NodeRepair
class TestNodeRepair:
"""Tests for NodeRepair utility."""
def test_repair_if_else_valid_operators(self):
"""Test that valid operators remain unchanged."""
nodes = [
{
"id": "node1",
"type": "if-else",
"config": {
"cases": [
{
"conditions": [
{"comparison_operator": "", "value": "1"},
{"comparison_operator": "=", "value": "2"},
]
}
]
},
}
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False
assert result.nodes == nodes
def test_repair_if_else_invalid_operators(self):
"""Test that invalid operators are normalized."""
nodes = [
{
"id": "node1",
"type": "if-else",
"config": {
"cases": [
{
"conditions": [
{"comparison_operator": ">=", "value": "1"},
{"comparison_operator": "<=", "value": "2"},
{"comparison_operator": "!=", "value": "3"},
{"comparison_operator": "==", "value": "4"},
]
}
]
},
}
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is True
assert len(result.repairs_made) == 4
conditions = result.nodes[0]["config"]["cases"][0]["conditions"]
assert conditions[0]["comparison_operator"] == ""
assert conditions[1]["comparison_operator"] == ""
assert conditions[2]["comparison_operator"] == ""
assert conditions[3]["comparison_operator"] == "="
def test_repair_ignores_other_nodes(self):
"""Test that other node types are ignored."""
nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False
assert result.nodes[0]["config"]["some_field"] == ">="
def test_repair_handles_missing_config(self):
"""Test robustness against missing fields."""
nodes = [
{
"id": "node1",
"type": "if-else",
# Missing config
},
{
"id": "node2",
"type": "if-else",
"config": {}, # Missing cases
},
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False

View File

@ -0,0 +1,173 @@
"""
Unit tests for the Planner Prompts.
Tests cover:
- Tool formatting for planner context
- Edge cases with missing fields
- Empty tool lists
"""
from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner
class TestFormatToolsForPlanner:
"""Tests for format_tools_for_planner function."""
def test_empty_tools_returns_default_message(self):
"""Test empty tools list returns default message."""
result = format_tools_for_planner([])
assert result == "No external tools available."
def test_none_tools_returns_default_message(self):
"""Test None tools list returns default message."""
result = format_tools_for_planner(None)
assert result == "No external tools available."
def test_single_tool_formatting(self):
"""Test single tool is formatted correctly."""
tools = [
{
"provider_id": "google",
"tool_key": "search",
"tool_label": "Google Search",
"tool_description": "Search the web using Google",
}
]
result = format_tools_for_planner(tools)
assert "[google/search]" in result
assert "Google Search" in result
assert "Search the web using Google" in result
def test_multiple_tools_formatting(self):
"""Test multiple tools are formatted correctly."""
tools = [
{
"provider_id": "google",
"tool_key": "search",
"tool_label": "Search",
"tool_description": "Web search",
},
{
"provider_id": "slack",
"tool_key": "send_message",
"tool_label": "Send Message",
"tool_description": "Send a Slack message",
},
]
result = format_tools_for_planner(tools)
lines = result.strip().split("\n")
assert len(lines) == 2
assert "[google/search]" in result
assert "[slack/send_message]" in result
def test_tool_without_provider_uses_key_only(self):
"""Test tool without provider_id uses tool_key only."""
tools = [
{
"tool_key": "my_tool",
"tool_label": "My Tool",
"tool_description": "A custom tool",
}
]
result = format_tools_for_planner(tools)
# Should format as [my_tool] without provider prefix
assert "[my_tool]" in result
assert "My Tool" in result
def test_tool_with_tool_name_fallback(self):
"""Test tool uses tool_name when tool_key is missing."""
tools = [
{
"tool_name": "fallback_tool",
"description": "Fallback description",
}
]
result = format_tools_for_planner(tools)
assert "fallback_tool" in result
assert "Fallback description" in result
def test_tool_with_missing_description(self):
"""Test tool with missing description doesn't crash."""
tools = [
{
"provider_id": "test",
"tool_key": "tool1",
"tool_label": "Tool 1",
}
]
result = format_tools_for_planner(tools)
assert "[test/tool1]" in result
assert "Tool 1" in result
def test_tool_with_all_missing_fields(self):
"""Test tool with all fields missing uses defaults."""
tools = [{}]
result = format_tools_for_planner(tools)
# Should not crash, may produce minimal output
assert isinstance(result, str)
def test_tool_uses_provider_fallback(self):
"""Test tool uses 'provider' when 'provider_id' is missing."""
tools = [
{
"provider": "openai",
"tool_key": "dalle",
"tool_label": "DALL-E",
"tool_description": "Generate images",
}
]
result = format_tools_for_planner(tools)
assert "[openai/dalle]" in result
def test_tool_label_fallback_to_key(self):
"""Test tool_label falls back to tool_key when missing."""
tools = [
{
"provider_id": "test",
"tool_key": "my_key",
"tool_description": "Description here",
}
]
result = format_tools_for_planner(tools)
# Label should fallback to key
assert "my_key" in result
assert "Description here" in result
class TestPlannerPromptConstants:
"""Tests for planner prompt constant availability."""
def test_planner_system_prompt_exists(self):
"""Test PLANNER_SYSTEM_PROMPT is defined."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
assert PLANNER_SYSTEM_PROMPT is not None
assert len(PLANNER_SYSTEM_PROMPT) > 0
assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT
def test_planner_user_prompt_exists(self):
"""Test PLANNER_USER_PROMPT is defined."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT
assert PLANNER_USER_PROMPT is not None
assert "{instruction}" in PLANNER_USER_PROMPT
def test_planner_system_prompt_has_required_sections(self):
"""Test PLANNER_SYSTEM_PROMPT has required XML sections."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
assert "<role>" in PLANNER_SYSTEM_PROMPT
assert "<task>" in PLANNER_SYSTEM_PROMPT
assert "<available_tools>" in PLANNER_SYSTEM_PROMPT
assert "<response_format>" in PLANNER_SYSTEM_PROMPT

View File

@ -0,0 +1,536 @@
"""
Unit tests for the Validation Rule Engine.
Tests cover:
- Structure rules (required fields, types, formats)
- Semantic rules (variable references, edge connections)
- Reference rules (model exists, tool configured, dataset valid)
- ValidationEngine integration
"""
from core.workflow.generator.validation import (
ValidationContext,
ValidationEngine,
)
from core.workflow.generator.validation.rules import (
extract_variable_refs,
is_placeholder,
)
class TestPlaceholderDetection:
"""Tests for placeholder detection utility."""
def test_detects_please_select(self):
assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True
def test_detects_your_prefix(self):
assert is_placeholder("YOUR_API_KEY") is True
def test_detects_todo(self):
assert is_placeholder("TODO: fill this in") is True
def test_detects_placeholder(self):
assert is_placeholder("PLACEHOLDER_VALUE") is True
def test_detects_example_prefix(self):
assert is_placeholder("EXAMPLE_URL") is True
def test_detects_replace_prefix(self):
assert is_placeholder("REPLACE_WITH_ACTUAL") is True
def test_case_insensitive(self):
assert is_placeholder("please_select") is True
assert is_placeholder("Please_Select") is True
def test_valid_values_not_detected(self):
assert is_placeholder("https://api.example.com") is False
assert is_placeholder("gpt-4") is False
assert is_placeholder("my_variable") is False
def test_non_string_returns_false(self):
assert is_placeholder(123) is False
assert is_placeholder(None) is False
assert is_placeholder(["list"]) is False
class TestVariableRefExtraction:
"""Tests for variable reference extraction."""
def test_extracts_simple_ref(self):
refs = extract_variable_refs("Hello {{#start.query#}}")
assert refs == [("start", "query")]
def test_extracts_multiple_refs(self):
refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}")
assert refs == [("node1", "output"), ("node2", "text")]
def test_extracts_nested_field(self):
refs = extract_variable_refs("{{#http_request.body#}}")
assert refs == [("http_request", "body")]
def test_no_refs_returns_empty(self):
refs = extract_variable_refs("No references here")
assert refs == []
def test_handles_malformed_refs(self):
refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}")
assert refs == []
class TestValidationContext:
"""Tests for ValidationContext."""
def test_node_map_lookup(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start"},
{"id": "llm_1", "type": "llm"},
]
)
assert ctx.get_node("start") == {"id": "start", "type": "start"}
assert ctx.get_node("nonexistent") is None
def test_model_set(self):
ctx = ValidationContext(
available_models=[
{"provider": "openai", "model": "gpt-4"},
{"provider": "anthropic", "model": "claude-3"},
]
)
assert ctx.has_model("openai", "gpt-4") is True
assert ctx.has_model("anthropic", "claude-3") is True
assert ctx.has_model("openai", "gpt-3.5") is False
def test_tool_set(self):
ctx = ValidationContext(
available_tools=[
{"provider_id": "google", "tool_key": "search", "is_team_authorization": True},
{"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False},
]
)
assert ctx.has_tool("google/search") is True
assert ctx.has_tool("search") is True
assert ctx.is_tool_configured("google/search") is True
assert ctx.is_tool_configured("slack/send_message") is False
def test_upstream_downstream_nodes(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start"},
{"id": "llm", "type": "llm"},
{"id": "end", "type": "end"},
],
edges=[
{"source": "start", "target": "llm"},
{"source": "llm", "target": "end"},
],
)
assert ctx.get_upstream_nodes("llm") == ["start"]
assert ctx.get_downstream_nodes("llm") == ["end"]
class TestStructureRules:
"""Tests for structure validation rules."""
def test_llm_missing_prompt_template(self):
ctx = ValidationContext(
nodes=[{"id": "llm_1", "type": "llm", "config": {}}]
)
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_errors
errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_llm_with_prompt_template_passes(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [
{"role": "system", "text": "You are helpful"},
{"role": "user", "text": "Hello"},
]
},
}
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
# No prompt_template errors
errors = [e for e in result.all_errors if "prompt_template" in e.rule_id]
assert len(errors) == 0
def test_http_request_missing_url(self):
ctx = ValidationContext(
nodes=[{"id": "http_1", "type": "http-request", "config": {}}]
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "http.url" in e.rule_id]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_http_request_placeholder_url(self):
ctx = ValidationContext(
nodes=[
{
"id": "http_1",
"type": "http-request",
"config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"},
}
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "placeholder" in e.rule_id]
assert len(errors) == 1
def test_code_node_missing_fields(self):
ctx = ValidationContext(
nodes=[{"id": "code_1", "type": "code", "config": {}}]
)
engine = ValidationEngine()
result = engine.validate(ctx)
error_rules = {e.rule_id for e in result.all_errors}
assert "code.code.required" in error_rules
assert "code.language.required" in error_rules
def test_knowledge_retrieval_missing_dataset(self):
ctx = ValidationContext(
nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}]
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id]
assert len(errors) == 1
assert errors[0].is_fixable is False # User must configure
class TestSemanticRules:
"""Tests for semantic validation rules."""
def test_valid_variable_reference(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [
{"role": "user", "text": "Process: {{#start.query#}}"}
]
},
},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
# No variable reference errors
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
assert len(errors) == 0
def test_invalid_variable_reference(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [
{"role": "user", "text": "Process: {{#nonexistent.field#}}"}
]
},
},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
assert len(errors) == 1
assert "nonexistent" in errors[0].message
def test_edge_validation(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{"id": "end", "type": "end", "config": {}},
],
edges=[
{"source": "start", "target": "end"},
{"source": "nonexistent", "target": "end"},
],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "edge" in e.rule_id]
assert len(errors) == 1
assert "nonexistent" in errors[0].message
class TestReferenceRules:
"""Tests for reference validation rules (models, tools)."""
def test_llm_missing_model_with_available(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.required"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_llm_missing_model_no_available(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[], # No models available
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.no_available"]
assert len(errors) == 1
assert errors[0].is_fixable is False
def test_llm_with_valid_model(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [{"role": "user", "text": "Hi"}],
"model": {"provider": "openai", "name": "gpt-4"},
},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "model" in e.rule_id]
assert len(errors) == 0
def test_llm_with_invalid_model(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [{"role": "user", "text": "Hi"}],
"model": {"provider": "openai", "name": "gpt-99"},
},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.not_found"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_tool_node_not_found(self):
ctx = ValidationContext(
nodes=[
{
"id": "tool_1",
"type": "tool",
"config": {"tool_key": "nonexistent/tool"},
}
],
available_tools=[],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"]
assert len(errors) == 1
def test_tool_node_not_configured(self):
ctx = ValidationContext(
nodes=[
{
"id": "tool_1",
"type": "tool",
"config": {"tool_key": "google/search"},
}
],
available_tools=[
{"provider_id": "google", "tool_key": "search", "is_team_authorization": False}
],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"]
assert len(errors) == 1
assert errors[0].is_fixable is False
class TestValidationResult:
"""Tests for ValidationResult classification."""
def test_has_errors(self):
ctx = ValidationContext(
nodes=[{"id": "llm_1", "type": "llm", "config": {}}]
)
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_errors is True
assert result.is_valid is False
def test_has_fixable_errors(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_fixable_errors is True
assert len(result.fixable_errors) > 0
def test_get_fixable_by_node(self):
ctx = ValidationContext(
nodes=[
{"id": "llm_1", "type": "llm", "config": {}},
{"id": "http_1", "type": "http-request", "config": {}},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
by_node = result.get_fixable_by_node()
assert "llm_1" in by_node
assert "http_1" in by_node
def test_to_dict(self):
ctx = ValidationContext(
nodes=[{"id": "llm_1", "type": "llm", "config": {}}]
)
engine = ValidationEngine()
result = engine.validate(ctx)
d = result.to_dict()
assert "fixable" in d
assert "user_required" in d
assert "warnings" in d
assert "all_warnings" in d
assert "stats" in d
class TestIntegration:
"""Integration tests for the full validation pipeline."""
def test_complete_workflow_validation(self):
"""Test validation of a complete workflow."""
ctx = ValidationContext(
nodes=[
{
"id": "start",
"type": "start",
"config": {"variables": [{"variable": "query", "type": "text-input"}]},
},
{
"id": "llm_1",
"type": "llm",
"config": {
"model": {"provider": "openai", "name": "gpt-4"},
"prompt_template": [{"role": "user", "text": "{{#start.query#}}"}],
},
},
{
"id": "end",
"type": "end",
"config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]},
},
],
edges=[
{"source": "start", "target": "llm_1"},
{"source": "llm_1", "target": "end"},
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
# Should have no errors
assert result.is_valid is True
assert len(result.fixable_errors) == 0
assert len(result.user_required_errors) == 0
def test_workflow_with_multiple_errors(self):
"""Test workflow with multiple types of errors."""
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {}, # Missing prompt_template and model
},
{
"id": "kb_1",
"type": "knowledge-retrieval",
"config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]},
},
{"id": "end", "type": "end", "config": {}},
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
# Should have multiple errors
assert result.has_errors is True
assert len(result.fixable_errors) >= 2 # model, prompt_template
assert len(result.user_required_errors) >= 1 # dataset placeholder
# Check stats
assert result.stats["total_nodes"] == 4
assert result.stats["total_errors"] >= 3

View File

@ -0,0 +1,435 @@
"""
Unit tests for the Vibe Workflow Validator.
Tests cover:
- Basic validation function
- User-friendly validation hints
- Edge cases and error handling
"""
from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator
class TestValidationHint:
"""Tests for ValidationHint dataclass."""
def test_hint_creation(self):
"""Test creating a validation hint."""
hint = ValidationHint(
node_id="llm_1",
field="model",
message="Model is not configured",
severity="error",
)
assert hint.node_id == "llm_1"
assert hint.field == "model"
assert hint.message == "Model is not configured"
assert hint.severity == "error"
def test_hint_with_suggestion(self):
"""Test hint with suggestion."""
hint = ValidationHint(
node_id="http_1",
field="url",
message="URL is required",
severity="error",
suggestion="Add a valid URL like https://api.example.com",
)
assert hint.suggestion is not None
class TestWorkflowValidatorBasic:
"""Tests for basic validation scenarios."""
def test_empty_workflow_is_valid(self):
"""Test empty workflow passes validation."""
workflow_data = {"nodes": [], "edges": []}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
# Empty but valid structure
assert is_valid is True
assert len(hints) == 0
def test_minimal_valid_workflow(self):
"""Test minimal Start → End workflow."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{"id": "end", "type": "end", "config": {}},
],
"edges": [{"source": "start", "target": "end"}],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
assert is_valid is True
def test_complete_workflow_with_llm(self):
"""Test complete workflow with LLM node."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {"variables": []}},
{
"id": "llm",
"type": "llm",
"config": {
"model": {"provider": "openai", "name": "gpt-4"},
"prompt_template": [{"role": "user", "text": "Hello"}],
},
},
{"id": "end", "type": "end", "config": {"outputs": []}},
],
"edges": [
{"source": "start", "target": "llm"},
{"source": "llm", "target": "end"},
],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
# Should pass with no critical errors
errors = [h for h in hints if h.severity == "error"]
assert len(errors) == 0
class TestVariableReferenceValidation:
"""Tests for variable reference validation."""
def test_valid_variable_reference(self):
"""Test valid variable reference passes."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{
"id": "llm",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Query: {{#start.query#}}"}]},
},
],
"edges": [{"source": "start", "target": "llm"}],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
ref_errors = [h for h in hints if "reference" in h.message.lower()]
assert len(ref_errors) == 0
def test_invalid_variable_reference(self):
"""Test invalid variable reference generates hint."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{
"id": "llm",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "{{#nonexistent.field#}}"}]},
},
],
"edges": [{"source": "start", "target": "llm"}],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
# Should have a hint about invalid reference
ref_hints = [h for h in hints if "nonexistent" in h.message or "reference" in h.message.lower()]
assert len(ref_hints) >= 1
class TestEdgeValidation:
"""Tests for edge validation."""
def test_edge_with_invalid_source(self):
"""Test edge with non-existent source generates hint."""
workflow_data = {
"nodes": [{"id": "end", "type": "end", "config": {}}],
"edges": [{"source": "nonexistent", "target": "end"}],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
# Should have hint about invalid edge
edge_hints = [h for h in hints if "edge" in h.message.lower() or "source" in h.message.lower()]
assert len(edge_hints) >= 1
def test_edge_with_invalid_target(self):
"""Test edge with non-existent target generates hint."""
workflow_data = {
"nodes": [{"id": "start", "type": "start", "config": {}}],
"edges": [{"source": "start", "target": "nonexistent"}],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
edge_hints = [h for h in hints if "edge" in h.message.lower() or "target" in h.message.lower()]
assert len(edge_hints) >= 1
class TestToolValidation:
"""Tests for tool node validation."""
def test_tool_node_found_in_available(self):
"""Test tool node that exists in available tools."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{
"id": "tool1",
"type": "tool",
"config": {"tool_key": "google/search"},
},
{"id": "end", "type": "end", "config": {}},
],
"edges": [{"source": "start", "target": "tool1"}, {"source": "tool1", "target": "end"}],
}
available_tools = [{"provider_id": "google", "tool_key": "search", "is_team_authorization": True}]
is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools)
tool_errors = [h for h in hints if h.severity == "error" and "tool" in h.message.lower()]
assert len(tool_errors) == 0
def test_tool_node_not_found(self):
"""Test tool node not in available tools generates hint."""
workflow_data = {
"nodes": [
{
"id": "tool1",
"type": "tool",
"config": {"tool_key": "unknown/tool"},
}
],
"edges": [],
}
available_tools = []
is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools)
tool_hints = [h for h in hints if "tool" in h.message.lower()]
assert len(tool_hints) >= 1
class TestQuestionClassifierValidation:
"""Tests for question-classifier node validation."""
def test_question_classifier_with_classes(self):
"""Test question-classifier with valid classes."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{
"id": "classifier",
"type": "question-classifier",
"config": {
"classes": [
{"id": "class1", "name": "Class 1"},
{"id": "class2", "name": "Class 2"},
],
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
},
},
{"id": "h1", "type": "llm", "config": {}},
{"id": "h2", "type": "llm", "config": {}},
{"id": "end", "type": "end", "config": {}},
],
"edges": [
{"source": "start", "target": "classifier"},
{"source": "classifier", "sourceHandle": "class1", "target": "h1"},
{"source": "classifier", "sourceHandle": "class2", "target": "h2"},
{"source": "h1", "target": "end"},
{"source": "h2", "target": "end"},
],
}
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
class_errors = [h for h in hints if "class" in h.message.lower() and h.severity == "error"]
assert len(class_errors) == 0
def test_question_classifier_missing_classes(self):
"""Test question-classifier without classes generates hint."""
workflow_data = {
"nodes": [
{
"id": "classifier",
"type": "question-classifier",
"config": {"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}},
}
],
"edges": [],
}
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
# Should have hint about missing classes
class_hints = [h for h in hints if "class" in h.message.lower()]
assert len(class_hints) >= 1
class TestHttpRequestValidation:
"""Tests for HTTP request node validation."""
def test_http_request_with_url(self):
"""Test HTTP request with valid URL."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{
"id": "http",
"type": "http-request",
"config": {"url": "https://api.example.com", "method": "GET"},
},
{"id": "end", "type": "end", "config": {}},
],
"edges": [{"source": "start", "target": "http"}, {"source": "http", "target": "end"}],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
url_errors = [h for h in hints if "url" in h.message.lower() and h.severity == "error"]
assert len(url_errors) == 0
def test_http_request_missing_url(self):
"""Test HTTP request without URL generates hint."""
workflow_data = {
"nodes": [
{
"id": "http",
"type": "http-request",
"config": {"method": "GET"},
}
],
"edges": [],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
url_hints = [h for h in hints if "url" in h.message.lower()]
assert len(url_hints) >= 1
class TestParameterExtractorValidation:
"""Tests for parameter-extractor node validation."""
def test_parameter_extractor_valid_params(self):
"""Test parameter-extractor with valid parameters."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{
"id": "extractor",
"type": "parameter-extractor",
"config": {
"instruction": "Extract info",
"parameters": [
{
"name": "name",
"type": "string",
"description": "Name",
"required": True,
}
],
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
},
},
{"id": "end", "type": "end", "config": {}},
],
"edges": [{"source": "start", "target": "extractor"}, {"source": "extractor", "target": "end"}],
}
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
errors = [h for h in hints if h.severity == "error"]
assert len(errors) == 0
def test_parameter_extractor_missing_required_field(self):
"""Test parameter-extractor missing 'required' field in parameter item."""
workflow_data = {
"nodes": [
{
"id": "extractor",
"type": "parameter-extractor",
"config": {
"instruction": "Extract info",
"parameters": [
{
"name": "name",
"type": "string",
"description": "Name",
# Missing 'required'
}
],
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
},
}
],
"edges": [],
}
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
errors = [h for h in hints if "required" in h.message and h.severity == "error"]
assert len(errors) >= 1
assert "parameter-extractor" in errors[0].node_type
class TestIfElseValidation:
"""Tests for if-else node validation."""
def test_if_else_valid_operators(self):
"""Test if-else with valid operators."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{
"id": "ifelse",
"type": "if-else",
"config": {
"cases": [{"case_id": "c1", "conditions": [{"comparison_operator": "", "value": "1"}]}]
},
},
{"id": "t", "type": "llm", "config": {}},
{"id": "f", "type": "llm", "config": {}},
{"id": "end", "type": "end", "config": {}},
],
"edges": [
{"source": "start", "target": "ifelse"},
{"source": "ifelse", "sourceHandle": "true", "target": "t"},
{"source": "ifelse", "sourceHandle": "false", "target": "f"},
{"source": "t", "target": "end"},
{"source": "f", "target": "end"},
],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
errors = [h for h in hints if h.severity == "error"]
# Filter out LLM model errors if any (available tools/models check might trigger)
# (actually available_models empty list might trigger model error?
# No, model config validation skips if model field not present? No, LLM has model config.
# But logic skips check if key missing? Let's check logic.
# _check_model_config checks if provider/name match available. If available is empty, it fails.
# But wait, validate default available_models is None?
# I should provide mock available_models or ignore model errors.
# Actually LLM node "config": {} implies missing model config. Rules check if config structure is valid?
# Let's filter specifically for operator errors.
operator_errors = [h for h in errors if "operator" in h.message]
assert len(operator_errors) == 0
def test_if_else_invalid_operators(self):
"""Test if-else with invalid operators."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "config": {}},
{
"id": "ifelse",
"type": "if-else",
"config": {
"cases": [{"case_id": "c1", "conditions": [{"comparison_operator": ">=", "value": "1"}]}]
},
},
{"id": "t", "type": "llm", "config": {}},
{"id": "f", "type": "llm", "config": {}},
{"id": "end", "type": "end", "config": {}},
],
"edges": [
{"source": "start", "target": "ifelse"},
{"source": "ifelse", "sourceHandle": "true", "target": "t"},
{"source": "ifelse", "sourceHandle": "false", "target": "f"},
{"source": "t", "target": "end"},
{"source": "f", "target": "end"},
],
}
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
operator_errors = [h for h in hints if "operator" in h.message and h.severity == "error"]
assert len(operator_errors) > 0
assert "" in operator_errors[0].suggestion

View File

@@ -0,0 +1,82 @@
import { describe, it, expect } from 'vitest'
import { replaceVariableReferences } from '../use-workflow-vibe'
import { BlockEnum } from '@/app/components/workflow/types'
// Mock types needed for the test
interface NodeData {
title: string
[key: string]: any
}
describe('use-workflow-vibe', () => {
describe('replaceVariableReferences', () => {
it('should replace variable references in strings', () => {
const data = {
title: 'Test Node',
prompt: 'Hello {{#old_id.query#}}',
}
const nodeIdMap = new Map<string, any>()
nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
const result = replaceVariableReferences(data, nodeIdMap) as NodeData
expect(result.prompt).toBe('Hello {{#new_uuid.query#}}')
})
it('should handle multiple references in one string', () => {
const data = {
title: 'Test Node',
text: '{{#node1.out#}} and {{#node2.out#}}',
}
const nodeIdMap = new Map<string, any>()
nodeIdMap.set('node1', { id: 'uuid1', data: { type: 'llm' } })
nodeIdMap.set('node2', { id: 'uuid2', data: { type: 'llm' } })
const result = replaceVariableReferences(data, nodeIdMap) as NodeData
expect(result.text).toBe('{{#uuid1.out#}} and {{#uuid2.out#}}')
})
it('should replace variable references in value_selector arrays', () => {
const data = {
title: 'End Node',
outputs: [
{
variable: 'result',
value_selector: ['old_id', 'text'],
},
],
}
const nodeIdMap = new Map<string, any>()
nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
const result = replaceVariableReferences(data, nodeIdMap) as NodeData
expect(result.outputs[0].value_selector).toEqual(['new_uuid', 'text'])
})
it('should handle nested objects recursively', () => {
const data = {
config: {
model: {
prompt: '{{#old_id.text#}}',
},
},
}
const nodeIdMap = new Map<string, any>()
nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
const result = replaceVariableReferences(data, nodeIdMap) as any
expect(result.config.model.prompt).toBe('{{#new_uuid.text#}}')
})
    it('should ignore missing node mappings', () => {
const data = {
text: '{{#missing_id.text#}}',
}
const nodeIdMap = new Map<string, any>()
// missing_id is not in map
const result = replaceVariableReferences(data, nodeIdMap) as NodeData
expect(result.text).toBe('{{#missing_id.text#}}')
})
})
})
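
For orientation, the string pass these tests exercise amounts to one regex replace over {{#id.field#}} templates. A simplified standalone sketch (the real hook additionally corrects field names per the mapped node's type):

// Simplified sketch of the string pass exercised above; the real
// implementation also corrects field names based on the mapped node's type.
const replaceTemplates = (text: string, idMap: Map<string, { id: string }>): string =>
  text.replace(/\{\{#([^.#]+)\.([^#]+)#\}\}/g, (match, oldId, field) => {
    const mapped = idMap.get(oldId)
    return mapped ? `{{#${mapped.id}.${field}#}}` : match // unmapped ids pass through unchanged
  })

// replaceTemplates('Hello {{#old_id.query#}}', new Map([['old_id', { id: 'new_uuid' }]]))
// => 'Hello {{#new_uuid.query#}}'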

View File

@@ -39,6 +39,7 @@ import {
getNodeCustomTypeByNodeDataType,
getNodesConnectedSourceOrTargetHandleIdsMap,
} from '../utils'
import { initialNodes as initializeNodeData } from '../utils/workflow-init'
import { useNodesMetaData } from './use-nodes-meta-data'
import { useNodesSyncDraft } from './use-nodes-sync-draft'
import { useNodesReadOnly } from './use-workflow'
@@ -115,7 +116,7 @@ const normalizeProviderIcon = (icon?: ToolWithProvider['icon']) => {
* - Mixed content objects: {type: "mixed", value: "..."} normalized to string
* - Field name correction based on node type
*/
const replaceVariableReferences = (
export const replaceVariableReferences = (
data: unknown,
nodeIdMap: Map<string, Node>,
parentKey?: string,
@@ -124,6 +125,11 @@ const replaceVariableReferences = (
// Replace {{#old_id.field#}} patterns and correct field names
return data.replace(/\{\{#([^.#]+)\.([^#]+)#\}\}/g, (match, oldId, field) => {
const newNode = nodeIdMap.get(oldId)
      // Surface dropped references: warn when a template points at an id with no mapping
      if (!newNode) {
        console.warn(`[VIBE DEBUG] replaceVariableReferences: No mapping for "${oldId}" in template "${match}"`)
      }
if (newNode) {
const nodeType = newNode.data?.type as string || ''
const correctedField = correctFieldName(field, nodeType)
@@ -138,6 +144,11 @@ const replaceVariableReferences = (
if (data.length >= 2 && typeof data[0] === 'string' && typeof data[1] === 'string') {
const potentialNodeId = data[0]
const newNode = nodeIdMap.get(potentialNodeId)
      // Warn on unmapped selector roots, skipping reserved prefixes (sys/env/conversation)
      if (!newNode && !['sys', 'env', 'conversation'].includes(potentialNodeId)) {
        console.warn(`[VIBE DEBUG] replaceVariableReferences: No mapping for "${potentialNodeId}" in selector [${data.join(', ')}]`)
      }
if (newNode) {
const nodeType = newNode.data?.type as string || ''
const correctedField = correctFieldName(data[1], nodeType)
@@ -598,6 +609,8 @@ export const useWorkflowVibe = () => {
const { getNodes } = store.getState()
const nodes = getNodes()
if (!nodesMetaDataMap) {
Toast.notify({ type: 'error', message: t('workflow.vibe.nodesUnavailable') })
return { nodes: [], edges: [] }
@@ -699,12 +712,59 @@ export const useWorkflowVibe = () => {
}
}
// For any node with model config, ALWAYS use user's default model
if (backendConfig.model && defaultModel) {
mergedConfig.model = {
provider: defaultModel.provider.provider,
name: defaultModel.model,
mode: 'chat',
// For End nodes, ensure outputs have value_selector format
// New format (preferred): {"outputs": [{"variable": "name", "value_selector": ["nodeId", "field"]}]}
// Legacy format (fallback): {"outputs": [{"variable": "name", "value": "{{#nodeId.field#}}"}]}
if (nodeType === BlockEnum.End && backendConfig.outputs) {
const outputs = backendConfig.outputs as Array<{ variable?: string, value?: string, value_selector?: string[] }>
mergedConfig.outputs = outputs.map((output) => {
// Preferred: value_selector array format (new LLM output format)
if (output.value_selector && Array.isArray(output.value_selector)) {
return output
}
// Parse value like "{{#nodeId.field#}}" into ["nodeId", "field"]
if (output.value) {
const match = output.value.match(/\{\{#([^.]+)\.([^#]+)#\}\}/)
if (match) {
return {
variable: output.variable,
value_selector: [match[1], match[2]],
}
}
}
// Fallback: return with empty value_selector to prevent crash
return {
variable: output.variable || 'output',
value_selector: [],
}
})
}
// For Parameter Extractor nodes, ensure each parameter has a 'required' field
// Backend may omit this field, but Dify's Pydantic model requires it
if (nodeType === BlockEnum.ParameterExtractor && backendConfig.parameters) {
const parameters = backendConfig.parameters as Array<{ name?: string, type?: string, description?: string, required?: boolean }>
mergedConfig.parameters = parameters.map((param) => ({
...param,
required: param.required ?? true, // Default to required if not specified
}))
}
// For any node with model config, ALWAYS use user's configured model
// This prevents "Model not exist" errors when LLM generates models the user doesn't have configured
// Applies to: LLM, QuestionClassifier, ParameterExtractor, and any future model-dependent nodes
if (backendConfig.model) {
// Try to use defaultModel first, fallback to first available model from modelList
const fallbackModel = modelList?.[0]?.models?.[0]
const modelProvider = defaultModel?.provider?.provider || modelList?.[0]?.provider
const modelName = defaultModel?.model || fallbackModel?.model
if (modelProvider && modelName) {
mergedConfig.model = {
provider: modelProvider,
name: modelName,
mode: 'chat',
}
}
}
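
The legacy-output conversion in the hunk above reduces to a single regex parse. A standalone sketch under the same assumed shapes (toValueSelector is an illustrative name, not part of the hook):

// Standalone sketch of the legacy End-node output conversion above:
// "{{#nodeId.field#}}" strings become ["nodeId", "field"] selectors.
type EndOutput = { variable?: string, value?: string, value_selector?: string[] }

const toValueSelector = (output: EndOutput): EndOutput => {
  if (Array.isArray(output.value_selector))
    return output // already in the preferred format
  const match = output.value?.match(/\{\{#([^.]+)\.([^#]+)#\}\}/)
  return match
    ? { variable: output.variable, value_selector: [match[1], match[2]] }
    : { variable: output.variable || 'output', value_selector: [] } // safe fallback to avoid a crash
}

// toValueSelector({ variable: 'result', value: '{{#llm.text#}}' })
// => { variable: 'result', value_selector: ['llm', 'text'] }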
@@ -731,10 +791,19 @@ export const useWorkflowVibe = () => {
}
// Replace variable references in all node configs using the nodeIdMap
// This converts {{#old_id.field#}} to {{#new_uuid.field#}}
for (const node of newNodes) {
node.data = replaceVariableReferences(node.data, nodeIdMap) as typeof node.data
}
// Use Dify's standard node initialization to handle all node types generically
// This sets up _targetBranches for question-classifier/if-else, _children for iteration/loop, etc.
const initializedNodes = initializeNodeData(newNodes, [])
// Update newNodes with initialized data
newNodes.splice(0, newNodes.length, ...initializedNodes)
if (!newNodes.length) {
Toast.notify({ type: 'error', message: t('workflow.vibe.invalidFlowchart') })
return { nodes: [], edges: [] }
@@ -762,12 +831,16 @@ export const useWorkflowVibe = () => {
zIndex: 0,
})
const newEdges: Edge[] = []
for (const edgeSpec of backendEdges) {
const sourceNode = nodeIdMap.get(edgeSpec.source)
const targetNode = nodeIdMap.get(edgeSpec.target)
if (!sourceNode || !targetNode)
if (!sourceNode || !targetNode) {
console.warn(`[VIBE] Edge skipped: source=${edgeSpec.source} (found=${!!sourceNode}), target=${edgeSpec.target} (found=${!!targetNode})`)
continue
}
let sourceHandle = edgeSpec.sourceHandle || 'source'
// Handle IfElse branch handles
@@ -775,9 +848,11 @@ export const useWorkflowVibe = () => {
sourceHandle = 'source'
}
newEdges.push(buildEdge(sourceNode, targetNode, sourceHandle, edgeSpec.targetHandle || 'target'))
}
// Layout nodes
const bounds = nodes.reduce(
(acc, node) => {
@@ -878,11 +953,15 @@ export const useWorkflowVibe = () => {
}
})
setNodes(updatedNodes)
setEdges([...edges, ...newEdges])
saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: newNodes[0].id })
handleSyncWorkflowDraft()
workflowStore.setState(state => ({
...state,
showVibePanel: false,
@@ -1194,81 +1273,128 @@ export const useWorkflowVibe = () => {
output_schema: tool.output_schema,
}))
const stream = await generateFlowchart({
instruction: trimmed,
model_config: latestModelConfig!,
existing_nodes: existingNodesPayload,
tools: toolsPayload,
regenerate_mode: regenerateMode,
})
const availableNodesPayload = availableNodesList.map(node => ({
type: node.type,
title: node.title,
description: node.description,
}))
let mermaidCode = ''
let backendNodes: BackendNodeSpec[] | undefined
let backendEdges: BackendEdgeSpec[] | undefined
const reader = stream.getReader()
const decoder = new TextDecoder()
while (true) {
const { done, value } = await reader.read()
if (done)
break
const chunk = decoder.decode(value)
const lines = chunk.split('\n')
for (const line of lines) {
if (!line.trim() || !line.startsWith('data: '))
continue
try {
const data = JSON.parse(line.slice(6))
if (data.event === 'message' || data.event === 'workflow_generated') {
if (data.data?.text) {
mermaidCode += data.data.text
workflowStore.setState(state => ({
...state,
vibePanelMermaidCode: mermaidCode,
}))
}
if (data.data?.nodes) {
backendNodes = data.data.nodes
workflowStore.setState(state => ({
...state,
vibePanelBackendNodes: backendNodes,
}))
}
if (data.data?.edges) {
backendEdges = data.data.edges
workflowStore.setState(state => ({
...state,
vibePanelBackendEdges: backendEdges,
}))
}
if (data.data?.intent) {
workflowStore.setState(state => ({
...state,
vibePanelIntent: data.data.intent,
}))
}
if (data.data?.message) {
workflowStore.setState(state => ({
...state,
vibePanelMessage: data.data.message,
}))
}
if (data.data?.suggestions) {
workflowStore.setState(state => ({
...state,
vibePanelSuggestions: data.data.suggestions,
}))
}
}
}
catch (e) {
console.error('Error parsing chunk:', e)
if (!isMermaidFlowchart(trimmed)) {
// Build previous workflow context if regenerating
const { vibePanelBackendNodes, vibePanelBackendEdges, vibePanelLastWarnings } = workflowStore.getState()
const previousWorkflow = regenerateMode && vibePanelBackendNodes && vibePanelBackendNodes.length > 0
? {
nodes: vibePanelBackendNodes,
edges: vibePanelBackendEdges || [],
warnings: vibePanelLastWarnings || [],
}
: undefined
// Map language code to human-readable language name for LLM
const languageNameMap: Record<string, string> = {
en_US: 'English',
zh_Hans: 'Chinese',
zh_Hant: 'Traditional Chinese',
ja_JP: 'Japanese',
ko_KR: 'Korean',
pt_BR: 'Portuguese',
es_ES: 'Spanish',
fr_FR: 'French',
de_DE: 'German',
it_IT: 'Italian',
ru_RU: 'Russian',
uk_UA: 'Ukrainian',
vi_VN: 'Vietnamese',
pl_PL: 'Polish',
ro_RO: 'Romanian',
tr_TR: 'Turkish',
fa_IR: 'Persian',
hi_IN: 'Hindi',
}
const preferredLanguage = languageNameMap[language] || 'English'
// Extract available models from user's configured model providers
const availableModelsPayload = modelList?.flatMap(provider =>
provider.models.map(model => ({
provider: provider.provider,
model: model.model,
})),
) || []
const requestPayload = {
instruction: trimmed,
model_config: latestModelConfig,
available_nodes: availableNodesPayload,
existing_nodes: existingNodesPayload,
available_tools: toolsPayload,
selected_node_ids: [],
previous_workflow: previousWorkflow,
regenerate_mode: regenerateMode,
language: preferredLanguage,
available_models: availableModelsPayload,
}
const response = await generateFlowchart(requestPayload)
const { error, flowchart, nodes, edges, intent, message, warnings, suggestions } = response
if (error) {
Toast.notify({ type: 'error', message: error })
setIsVibeGenerating(false)
return
}
// Handle off_topic intent - show rejection message and suggestions
if (intent === 'off_topic') {
workflowStore.setState(state => ({
...state,
vibePanelMermaidCode: '',
vibePanelMessage: message || t('workflow.vibe.offTopicDefault'),
vibePanelSuggestions: suggestions || [],
vibePanelIntent: 'off_topic',
isVibeGenerating: false,
}))
return
}
if (!flowchart) {
Toast.notify({ type: 'error', message: t('workflow.vibe.missingFlowchart') })
setIsVibeGenerating(false)
return
}
// Show warnings if any (includes tool sanitization warnings)
const responseWarnings = warnings || []
if (responseWarnings.length > 0) {
responseWarnings.forEach((warning) => {
Toast.notify({ type: 'warning', message: warning })
})
}
mermaidCode = flowchart
// Store backend nodes/edges for direct use (bypasses mermaid re-parsing)
backendNodes = nodes
backendEdges = edges
// Store warnings for regeneration context
workflowStore.setState(state => ({
...state,
vibePanelLastWarnings: responseWarnings,
}))
workflowStore.setState(state => ({
...state,
vibePanelMermaidCode: mermaidCode,
vibePanelBackendNodes: backendNodes,
vibePanelBackendEdges: backendEdges,
vibePanelMessage: '',
vibePanelSuggestions: [],
vibePanelIntent: 'generate',
isVibeGenerating: false,
}))
}
setIsVibeGenerating(false)
@@ -1286,10 +1412,16 @@ export const useWorkflowVibe = () => {
if (skipPanelPreview) {
// Prefer backend nodes (already sanitized) over mermaid re-parsing
if (backendNodes && backendNodes.length > 0 && backendEdges) {
console.log('[VIBE] Applying backend nodes directly to workflow')
console.log('[VIBE] Backend nodes:', backendNodes.length)
console.log('[VIBE] Backend edges:', backendEdges.length)
await applyBackendNodesToWorkflow(backendNodes, backendEdges)
console.log('[VIBE] Backend nodes applied successfully')
}
else {
console.log('[VIBE] Applying mermaid flowchart to workflow')
await applyFlowchartToWorkflow()
console.log('[VIBE] Mermaid flowchart applied successfully')
}
}
}