mirror of https://github.com/langgenius/dify.git
refactor(vibe): extract workflow generation to dedicated module
- Move workflow generation logic from LLMGenerator to WorkflowGenerator - Extract to api/core/workflow/generator/ with modular architecture - Implement Planner-Builder pattern for better separation of concerns - Add validation engine with rule-based error classification - Add node and edge repair utilities for auto-fixing common issues - Add deterministic Mermaid generator for consistent output - Reorganize configuration and prompts - Move vibe_config/ to generator/config/ - Move vibe_prompts.py to generator/prompts/ (split into multiple files) - Add builder_prompts.py and planner_prompts.py for new architecture - Enhance frontend workflow handling - Use standard node initialization for proper node setup - Improve variable reference replacement with better error handling - Add model fallback logic for better compatibility - Handle end node outputs format (value_selector vs legacy format) - Ensure parameter-extractor nodes have required 'required' field - Add comprehensive test coverage - Unit tests for mermaid generator, node repair, edge repair - Tests for validation engine and rule system - Tests for planner prompts formatting - Frontend tests for variable reference replacement - Add max_fix_iterations parameter for validate-fix loop configuration # Conflicts: # web/app/components/workflow/hooks/use-workflow-vibe.tsx
This commit is contained in:
parent
c4eee28fd8
commit
cd030d82e5
|
|
@ -1,9 +1,13 @@
|
|||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
|
|
@ -18,6 +22,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
|||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.generator import WorkflowGenerator
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
|
|
@ -77,6 +82,13 @@ class FlowchartGeneratePayload(BaseModel):
|
|||
language: str | None = Field(default=None, description="Preferred language for generated content")
|
||||
# Available models that user has configured (for LLM/question-classifier nodes)
|
||||
available_models: list[dict[str, Any]] = Field(default_factory=list, description="User's configured models")
|
||||
# Validate-fix iteration loop configuration
|
||||
max_fix_iterations: int = Field(
|
||||
default=2,
|
||||
ge=0,
|
||||
le=5,
|
||||
description="Maximum number of validate-fix iterations (0 to disable auto-fix)",
|
||||
)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
|
|
@ -305,7 +317,7 @@ class FlowchartGenerateApi(Resource):
|
|||
"warnings": args.previous_workflow.warnings,
|
||||
}
|
||||
|
||||
result = LLMGenerator.generate_workflow_flowchart(
|
||||
result = WorkflowGenerator.generate_workflow_flowchart(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
|
|
@ -313,11 +325,13 @@ class FlowchartGenerateApi(Resource):
|
|||
existing_nodes=args.existing_nodes,
|
||||
available_tools=args.available_tools,
|
||||
selected_node_ids=args.selected_node_ids,
|
||||
previous_workflow=previous_workflow_dict,
|
||||
previous_workflow=cast(dict[str, object], previous_workflow_dict),
|
||||
regenerate_mode=args.regenerate_mode,
|
||||
preferred_language=args.language,
|
||||
available_models=args.available_models,
|
||||
max_fix_iterations=args.max_fix_iterations,
|
||||
)
|
||||
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol, cast
|
||||
|
||||
|
|
@ -12,8 +11,6 @@ from core.llm_generator.prompts import (
|
|||
CONVERSATION_TITLE_PROMPT,
|
||||
GENERATOR_QA_PROMPT,
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
LLM_MODIFY_PROMPT_SYSTEM,
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
|
|
@ -30,6 +27,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
|||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.generator import WorkflowGenerator
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
|
|
@ -299,178 +297,24 @@ class LLMGenerator:
|
|||
regenerate_mode: bool = False,
|
||||
preferred_language: str | None = None,
|
||||
available_models: Sequence[dict[str, object]] | None = None,
|
||||
max_fix_iterations: int = 2,
|
||||
):
|
||||
"""
|
||||
Generate workflow flowchart with enhanced prompts and inline intent classification.
|
||||
|
||||
Returns a dict with:
|
||||
- intent: "generate" | "off_topic" | "error"
|
||||
- flowchart: Mermaid syntax string (for generate intent)
|
||||
- message: User-friendly explanation
|
||||
- warnings: List of validation warnings
|
||||
- suggestions: List of workflow suggestions (for off_topic intent)
|
||||
- error: Error message if generation failed
|
||||
"""
|
||||
from core.llm_generator.vibe_prompts import (
|
||||
build_vibe_enhanced_prompt,
|
||||
extract_mermaid_from_response,
|
||||
parse_vibe_response,
|
||||
sanitize_tool_nodes,
|
||||
validate_node_parameters,
|
||||
validate_tool_references,
|
||||
)
|
||||
|
||||
model_parameters = model_config.get("completion_params", {})
|
||||
|
||||
# Build enhanced prompts with context
|
||||
system_prompt, user_prompt = build_vibe_enhanced_prompt(
|
||||
return WorkflowGenerator.generate_workflow_flowchart(
|
||||
tenant_id=tenant_id,
|
||||
instruction=instruction,
|
||||
available_nodes=list(available_nodes) if available_nodes else None,
|
||||
available_tools=list(available_tools) if available_tools else None,
|
||||
existing_nodes=list(existing_nodes) if existing_nodes else None,
|
||||
selected_node_ids=list(selected_node_ids) if selected_node_ids else None,
|
||||
previous_workflow=dict(previous_workflow) if previous_workflow else None,
|
||||
model_config=model_config,
|
||||
available_nodes=available_nodes,
|
||||
existing_nodes=existing_nodes,
|
||||
available_tools=available_tools,
|
||||
selected_node_ids=selected_node_ids,
|
||||
previous_workflow=previous_workflow,
|
||||
regenerate_mode=regenerate_mode,
|
||||
preferred_language=preferred_language,
|
||||
available_models=list(available_models) if available_models else None,
|
||||
available_models=available_models,
|
||||
max_fix_iterations=max_fix_iterations,
|
||||
)
|
||||
|
||||
prompt_messages: list[PromptMessage] = [
|
||||
SystemPromptMessage(content=system_prompt),
|
||||
UserPromptMessage(content=user_prompt),
|
||||
]
|
||||
|
||||
# DEBUG: Log model input
|
||||
logger.debug("=" * 80)
|
||||
logger.debug("[VIBE] generate_workflow_flowchart - MODEL INPUT")
|
||||
logger.debug("=" * 80)
|
||||
logger.debug("[VIBE] Instruction: %s", instruction)
|
||||
logger.debug("[VIBE] Model: %s/%s", model_config.get("provider", ""), model_config.get("name", ""))
|
||||
system_prompt_log = system_prompt[:2000] + "..." if len(system_prompt) > 2000 else system_prompt
|
||||
logger.debug("[VIBE] System Prompt:\n%s", system_prompt_log)
|
||||
logger.debug("[VIBE] User Prompt:\n%s", user_prompt)
|
||||
logger.debug("=" * 80)
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
)
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters,
|
||||
stream=False,
|
||||
)
|
||||
content = response.message.get_text_content()
|
||||
|
||||
# DEBUG: Log model output
|
||||
logger.debug("=" * 80)
|
||||
logger.debug("[VIBE] generate_workflow_flowchart - MODEL OUTPUT")
|
||||
logger.debug("=" * 80)
|
||||
logger.debug("[VIBE] Raw Response:\n%s", content)
|
||||
logger.debug("=" * 80)
|
||||
if not isinstance(content, str):
|
||||
raise ValueError("Flowchart response is not a string")
|
||||
|
||||
# Parse the enhanced response format
|
||||
parsed = parse_vibe_response(content)
|
||||
|
||||
# DEBUG: Log parsed result
|
||||
logger.debug("[VIBE] Parsed Response:")
|
||||
logger.debug("[VIBE] intent: %s", parsed.get("intent"))
|
||||
logger.debug("[VIBE] message: %s", parsed.get("message", "")[:200] if parsed.get("message") else "")
|
||||
logger.debug("[VIBE] mermaid: %s", parsed.get("mermaid", "")[:500] if parsed.get("mermaid") else "")
|
||||
logger.debug("[VIBE] warnings: %s", parsed.get("warnings", []))
|
||||
logger.debug("[VIBE] suggestions: %s", parsed.get("suggestions", []))
|
||||
if parsed.get("error"):
|
||||
logger.debug("[VIBE] error: %s", parsed.get("error"))
|
||||
logger.debug("=" * 80)
|
||||
|
||||
# Handle error case from parsing
|
||||
if parsed.get("intent") == "error":
|
||||
# Fall back to legacy parsing for backwards compatibility
|
||||
match = re.search(r"```(?:mermaid)?\s*([\s\S]+?)```", content, flags=re.IGNORECASE)
|
||||
flowchart = (match.group(1) if match else content).strip()
|
||||
return {
|
||||
"intent": "generate",
|
||||
"flowchart": flowchart,
|
||||
"message": "",
|
||||
"warnings": [],
|
||||
"tool_recommendations": [],
|
||||
"error": "",
|
||||
}
|
||||
|
||||
# Handle off_topic case
|
||||
if parsed.get("intent") == "off_topic":
|
||||
return {
|
||||
"intent": "off_topic",
|
||||
"flowchart": "",
|
||||
"message": parsed.get("message", ""),
|
||||
"suggestions": parsed.get("suggestions", []),
|
||||
"warnings": [],
|
||||
"tool_recommendations": [],
|
||||
"error": "",
|
||||
}
|
||||
|
||||
# Handle generate case
|
||||
flowchart = extract_mermaid_from_response(parsed)
|
||||
|
||||
# Sanitize tool nodes - replace invalid tools with fallback nodes
|
||||
original_nodes = parsed.get("nodes", [])
|
||||
sanitized_nodes, sanitize_warnings = sanitize_tool_nodes(
|
||||
original_nodes,
|
||||
list(available_tools) if available_tools else None,
|
||||
)
|
||||
# Update parsed nodes with sanitized version
|
||||
parsed["nodes"] = sanitized_nodes
|
||||
|
||||
# Validate tool references and get recommendations for unconfigured tools
|
||||
validation_warnings, tool_recommendations = validate_tool_references(
|
||||
sanitized_nodes,
|
||||
list(available_tools) if available_tools else None,
|
||||
)
|
||||
|
||||
# Validate node parameters are properly filled (Phase 9: Auto-Fill)
|
||||
param_warnings = validate_node_parameters(sanitized_nodes)
|
||||
|
||||
existing_warnings = parsed.get("warnings", [])
|
||||
all_warnings = existing_warnings + sanitize_warnings + validation_warnings + param_warnings
|
||||
|
||||
return {
|
||||
"intent": "generate",
|
||||
"flowchart": flowchart,
|
||||
"nodes": sanitized_nodes, # Include sanitized nodes in response
|
||||
"edges": parsed.get("edges", []),
|
||||
"message": parsed.get("message", ""),
|
||||
"warnings": all_warnings,
|
||||
"tool_recommendations": tool_recommendations,
|
||||
"error": "",
|
||||
}
|
||||
|
||||
except InvokeError as e:
|
||||
return {
|
||||
"intent": "error",
|
||||
"flowchart": "",
|
||||
"message": "",
|
||||
"warnings": [],
|
||||
"tool_recommendations": [],
|
||||
"error": str(e),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate workflow flowchart, model: %s", model_config.get("name"))
|
||||
return {
|
||||
"intent": "error",
|
||||
"flowchart": "",
|
||||
"message": "",
|
||||
"warnings": [],
|
||||
"tool_recommendations": [],
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
|
||||
if code_language == "python":
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .runner import WorkflowGenerator
|
||||
|
|
@ -5,14 +5,14 @@ This module centralizes configuration for the Vibe workflow generation feature,
|
|||
including node schemas, fallback rules, and response templates.
|
||||
"""
|
||||
|
||||
from core.llm_generator.vibe_config.fallback_rules import (
|
||||
from core.workflow.generator.config.fallback_rules import (
|
||||
FALLBACK_RULES,
|
||||
FIELD_NAME_CORRECTIONS,
|
||||
NODE_TYPE_ALIASES,
|
||||
get_corrected_field_name,
|
||||
)
|
||||
from core.llm_generator.vibe_config.node_schemas import BUILTIN_NODE_SCHEMAS
|
||||
from core.llm_generator.vibe_config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES
|
||||
from core.workflow.generator.config.node_schemas import BUILTIN_NODE_SCHEMAS
|
||||
from core.workflow.generator.config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES
|
||||
|
||||
__all__ = [
|
||||
"BUILTIN_NODE_SCHEMAS",
|
||||
|
|
@ -137,19 +137,28 @@ BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = {
|
|||
},
|
||||
"if-else": {
|
||||
"description": "Conditional branching based on conditions",
|
||||
"required": ["conditions"],
|
||||
"required": ["cases"],
|
||||
"parameters": {
|
||||
"conditions": {
|
||||
"cases": {
|
||||
"type": "array",
|
||||
"description": "List of condition cases",
|
||||
"description": "List of condition cases. Each case defines when 'true' branch is taken.",
|
||||
"item_schema": {
|
||||
"case_id": "string - unique case identifier",
|
||||
"logical_operator": "enum: and, or",
|
||||
"conditions": "array of {variable_selector, comparison_operator, value}",
|
||||
"case_id": "string - unique case identifier (e.g., 'case_1')",
|
||||
"logical_operator": "enum: and, or - how multiple conditions combine",
|
||||
"conditions": {
|
||||
"type": "array",
|
||||
"item_schema": {
|
||||
"variable_selector": "array of strings - path to variable, e.g. ['node_id', 'field']",
|
||||
"comparison_operator": (
|
||||
"enum: =, ≠, >, <, ≥, ≤, contains, not contains, is, is not, empty, not empty"
|
||||
),
|
||||
"value": "string or number - value to compare against",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"outputs": ["Branches: true (conditions met), false (else)"],
|
||||
"outputs": ["Branches: true (first case conditions met), false (else/no case matched)"],
|
||||
},
|
||||
"knowledge-retrieval": {
|
||||
"description": "Query knowledge base for relevant content",
|
||||
|
|
@ -207,5 +216,70 @@ BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = {
|
|||
},
|
||||
"outputs": ["item (current iteration item)", "index (current index)"],
|
||||
},
|
||||
"parameter-extractor": {
|
||||
"description": "Extract structured parameters from user input using LLM",
|
||||
"required": ["query", "parameters"],
|
||||
"parameters": {
|
||||
"model": {
|
||||
"type": "object",
|
||||
"description": "Model configuration (provider, name, mode)",
|
||||
},
|
||||
"query": {
|
||||
"type": "array",
|
||||
"description": "Path to input text to extract parameters from, e.g. ['start', 'user_input']",
|
||||
},
|
||||
"parameters": {
|
||||
"type": "array",
|
||||
"description": "Parameters to extract from the input",
|
||||
"item_schema": {
|
||||
"name": "string - parameter name (required)",
|
||||
"type": (
|
||||
"enum: string, number, boolean, array[string], array[number], "
|
||||
"array[object], array[boolean]"
|
||||
),
|
||||
"description": "string - description of what to extract (required)",
|
||||
"required": "boolean - whether this parameter is required (MUST be specified)",
|
||||
"options": "array of strings (optional) - for enum-like selection",
|
||||
},
|
||||
},
|
||||
"instruction": {
|
||||
"type": "string",
|
||||
"description": "Additional instructions for extraction",
|
||||
},
|
||||
"reasoning_mode": {
|
||||
"type": "enum",
|
||||
"options": ["function_call", "prompt"],
|
||||
"description": "How to perform extraction (defaults to function_call)",
|
||||
},
|
||||
},
|
||||
"outputs": ["Extracted parameters as defined in parameters array", "__is_success", "__reason"],
|
||||
},
|
||||
"question-classifier": {
|
||||
"description": "Classify user input into predefined categories using LLM",
|
||||
"required": ["query", "classes"],
|
||||
"parameters": {
|
||||
"model": {
|
||||
"type": "object",
|
||||
"description": "Model configuration (provider, name, mode)",
|
||||
},
|
||||
"query": {
|
||||
"type": "array",
|
||||
"description": "Path to input text to classify, e.g. ['start', 'user_input']",
|
||||
},
|
||||
"classes": {
|
||||
"type": "array",
|
||||
"description": "Classification categories",
|
||||
"item_schema": {
|
||||
"id": "string - unique class identifier",
|
||||
"name": "string - class name/label",
|
||||
},
|
||||
},
|
||||
"instruction": {
|
||||
"type": "string",
|
||||
"description": "Additional instructions for classification",
|
||||
},
|
||||
},
|
||||
"outputs": ["class_name (selected class)"],
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,343 @@
|
|||
BUILDER_SYSTEM_PROMPT = """<role>
|
||||
You are a Workflow Configuration Engineer.
|
||||
Your goal is to implement the Architect's plan by generating a precise, runnable Dify Workflow JSON configuration.
|
||||
</role>
|
||||
|
||||
<inputs>
|
||||
<plan>
|
||||
{plan_context}
|
||||
</plan>
|
||||
|
||||
<tool_schemas>
|
||||
{tool_schemas}
|
||||
</tool_schemas>
|
||||
|
||||
<node_specs>
|
||||
{builtin_node_specs}
|
||||
</node_specs>
|
||||
</inputs>
|
||||
|
||||
<rules>
|
||||
1. **Configuration**:
|
||||
- You MUST fill ALL required parameters for every node.
|
||||
- Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields.
|
||||
- For 'start' node, define all necessary user inputs.
|
||||
|
||||
2. **Variable References**:
|
||||
- For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}`
|
||||
- For 'end' node outputs: use `value_selector` array format `["node_id", "field"]`
|
||||
- Example: to reference 'llm' node's 'text' output in end node, use `["llm", "text"]`
|
||||
|
||||
3. **Tools**:
|
||||
- ONLY use the tools listed in `<tool_schemas>`.
|
||||
- If a planned tool is missing from schemas, fallback to `http-request` or `code`.
|
||||
|
||||
4. **Node Specifics**:
|
||||
- For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`).
|
||||
|
||||
5. **Output**:
|
||||
- Return ONLY the JSON object with `nodes` and `edges`.
|
||||
- Do NOT generate Mermaid diagrams.
|
||||
- Do NOT generate explanations.
|
||||
</rules>
|
||||
|
||||
<edge_rules priority="critical">
|
||||
**EDGES ARE CRITICAL** - Every node except 'end' MUST have at least one outgoing edge.
|
||||
|
||||
1. **Linear Flow**: Simple source -> target connection
|
||||
```
|
||||
{{"source": "node_a", "target": "node_b"}}
|
||||
```
|
||||
|
||||
2. **question-classifier Branching**: Each class MUST have a separate edge with `sourceHandle` = class `id`
|
||||
- If you define classes: [{{"id": "cls_refund", "name": "Refund"}}, {{"id": "cls_inquiry", "name": "Inquiry"}}]
|
||||
- You MUST create edges:
|
||||
- {{"source": "classifier", "sourceHandle": "cls_refund", "target": "refund_handler"}}
|
||||
- {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "inquiry_handler"}}
|
||||
|
||||
3. **if-else Branching**: MUST have exactly TWO edges with sourceHandle "true" and "false"
|
||||
- {{"source": "condition", "sourceHandle": "true", "target": "true_branch"}}
|
||||
- {{"source": "condition", "sourceHandle": "false", "target": "false_branch"}}
|
||||
|
||||
4. **Branch Convergence**: Multiple branches can connect to same downstream node
|
||||
- Both true_branch and false_branch can connect to the same 'end' node
|
||||
|
||||
5. **NEVER leave orphan nodes**: Every node must be connected in the graph
|
||||
</edge_rules>
|
||||
|
||||
<examples>
|
||||
<example name="simple_linear">
|
||||
```json
|
||||
{{
|
||||
"nodes": [
|
||||
{{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"title": "Start",
|
||||
"config": {{
|
||||
"variables": [{{"variable": "query", "label": "Query", "type": "text-input"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "llm",
|
||||
"type": "llm",
|
||||
"title": "Generate Response",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "end",
|
||||
"type": "end",
|
||||
"title": "End",
|
||||
"config": {{
|
||||
"outputs": [
|
||||
{{"variable": "result", "value_selector": ["llm", "text"]}}
|
||||
]
|
||||
}}
|
||||
}}
|
||||
],
|
||||
"edges": [
|
||||
{{"source": "start", "target": "llm"}},
|
||||
{{"source": "llm", "target": "end"}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
</example>
|
||||
|
||||
<example name="question_classifier_branching" description="Customer service with intent classification">
|
||||
```json
|
||||
{{
|
||||
"nodes": [
|
||||
{{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"title": "Start",
|
||||
"config": {{
|
||||
"variables": [{{"variable": "user_input", "label": "User Message", "type": "text-input", "required": true}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "classifier",
|
||||
"type": "question-classifier",
|
||||
"title": "Classify Intent",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"query_variable_selector": ["start", "user_input"],
|
||||
"classes": [
|
||||
{{"id": "cls_refund", "name": "Refund Request"}},
|
||||
{{"id": "cls_inquiry", "name": "Product Inquiry"}},
|
||||
{{"id": "cls_complaint", "name": "Complaint"}},
|
||||
{{"id": "cls_other", "name": "Other"}}
|
||||
],
|
||||
"instruction": "Classify the user's intent"
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "handle_refund",
|
||||
"type": "llm",
|
||||
"title": "Handle Refund",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"prompt_template": [{{"role": "user", "text": "Extract order number and respond: {{{{#start.user_input#}}}}"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "handle_inquiry",
|
||||
"type": "llm",
|
||||
"title": "Handle Inquiry",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"prompt_template": [{{"role": "user", "text": "Answer product question: {{{{#start.user_input#}}}}"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "handle_complaint",
|
||||
"type": "llm",
|
||||
"title": "Handle Complaint",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"prompt_template": [{{"role": "user", "text": "Respond with empathy: {{{{#start.user_input#}}}}"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "handle_other",
|
||||
"type": "llm",
|
||||
"title": "Handle Other",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"prompt_template": [{{"role": "user", "text": "Provide general response: {{{{#start.user_input#}}}}"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "end",
|
||||
"type": "end",
|
||||
"title": "End",
|
||||
"config": {{
|
||||
"outputs": [{{"variable": "response", "value_selector": ["handle_refund", "text"]}}]
|
||||
}}
|
||||
}}
|
||||
],
|
||||
"edges": [
|
||||
{{"source": "start", "target": "classifier"}},
|
||||
{{"source": "classifier", "sourceHandle": "cls_refund", "target": "handle_refund"}},
|
||||
{{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "handle_inquiry"}},
|
||||
{{"source": "classifier", "sourceHandle": "cls_complaint", "target": "handle_complaint"}},
|
||||
{{"source": "classifier", "sourceHandle": "cls_other", "target": "handle_other"}},
|
||||
{{"source": "handle_refund", "target": "end"}},
|
||||
{{"source": "handle_inquiry", "target": "end"}},
|
||||
{{"source": "handle_complaint", "target": "end"}},
|
||||
{{"source": "handle_other", "target": "end"}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
CRITICAL: Notice that each class id (cls_refund, cls_inquiry, etc.) becomes a sourceHandle in the edges!
|
||||
</example>
|
||||
|
||||
<example name="if_else_branching" description="Conditional logic with if-else">
|
||||
```json
|
||||
{{
|
||||
"nodes": [
|
||||
{{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"title": "Start",
|
||||
"config": {{
|
||||
"variables": [{{"variable": "years", "label": "Years of Experience", "type": "number", "required": true}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "check_experience",
|
||||
"type": "if-else",
|
||||
"title": "Check Experience",
|
||||
"config": {{
|
||||
"cases": [
|
||||
{{
|
||||
"case_id": "case_1",
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{{
|
||||
"variable_selector": ["start", "years"],
|
||||
"comparison_operator": "≥",
|
||||
"value": "3"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "qualified",
|
||||
"type": "llm",
|
||||
"title": "Qualified Response",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"prompt_template": [{{"role": "user", "text": "Generate qualified candidate response"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "not_qualified",
|
||||
"type": "llm",
|
||||
"title": "Not Qualified Response",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"prompt_template": [{{"role": "user", "text": "Generate rejection response"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "end",
|
||||
"type": "end",
|
||||
"title": "End",
|
||||
"config": {{
|
||||
"outputs": [{{"variable": "result", "value_selector": ["qualified", "text"]}}]
|
||||
}}
|
||||
}}
|
||||
],
|
||||
"edges": [
|
||||
{{"source": "start", "target": "check_experience"}},
|
||||
{{"source": "check_experience", "sourceHandle": "true", "target": "qualified"}},
|
||||
{{"source": "check_experience", "sourceHandle": "false", "target": "not_qualified"}},
|
||||
{{"source": "qualified", "target": "end"}},
|
||||
{{"source": "not_qualified", "target": "end"}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
CRITICAL: if-else MUST have exactly two edges with sourceHandle "true" and "false"!
|
||||
</example>
|
||||
|
||||
<example name="parameter_extractor" description="Extract structured data from text">
|
||||
```json
|
||||
{{
|
||||
"nodes": [
|
||||
{{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"title": "Start",
|
||||
"config": {{
|
||||
"variables": [{{"variable": "resume", "label": "Resume Text", "type": "paragraph", "required": true}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "extract",
|
||||
"type": "parameter-extractor",
|
||||
"title": "Extract Info",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"query": ["start", "resume"],
|
||||
"parameters": [
|
||||
{{"name": "name", "type": "string", "description": "Candidate name", "required": true}},
|
||||
{{"name": "years", "type": "number", "description": "Years of experience", "required": true}},
|
||||
{{"name": "skills", "type": "array[string]", "description": "List of skills", "required": true}}
|
||||
],
|
||||
"instruction": "Extract candidate information from resume"
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "process",
|
||||
"type": "llm",
|
||||
"title": "Process Data",
|
||||
"config": {{
|
||||
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
|
||||
"prompt_template": [{{"role": "user", "text": "Name: {{{{#extract.name#}}}}, Years: {{{{#extract.years#}}}}"}}]
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"id": "end",
|
||||
"type": "end",
|
||||
"title": "End",
|
||||
"config": {{
|
||||
"outputs": [{{"variable": "result", "value_selector": ["process", "text"]}}]
|
||||
}}
|
||||
}}
|
||||
],
|
||||
"edges": [
|
||||
{{"source": "start", "target": "extract"}},
|
||||
{{"source": "extract", "target": "process"}},
|
||||
{{"source": "process", "target": "end"}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
</example>
|
||||
</examples>
|
||||
|
||||
<edge_checklist>
|
||||
Before finalizing, verify:
|
||||
1. [ ] Every node (except 'end') has at least one outgoing edge
|
||||
2. [ ] 'start' node has exactly one outgoing edge
|
||||
3. [ ] 'question-classifier' has one edge per class, each with sourceHandle = class id
|
||||
4. [ ] 'if-else' has exactly two edges: sourceHandle "true" and sourceHandle "false"
|
||||
5. [ ] All branches eventually connect to 'end' (directly or through other nodes)
|
||||
6. [ ] No orphan nodes exist (every node is reachable from 'start')
|
||||
</edge_checklist>
|
||||
"""
|
||||
|
||||
BUILDER_USER_PROMPT = """<instruction>
|
||||
{instruction}
|
||||
</instruction>
|
||||
|
||||
Generate the full workflow configuration now. Pay special attention to:
|
||||
1. Creating edges for ALL branches of question-classifier and if-else nodes
|
||||
2. Using correct sourceHandle values for branching nodes
|
||||
3. Ensuring every node is connected in the graph
|
||||
"""
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
PLANNER_SYSTEM_PROMPT = """<role>
|
||||
You are an expert Workflow Architect.
|
||||
Your job is to analyze user requests and plan a high-level automation workflow.
|
||||
</role>
|
||||
|
||||
<task>
|
||||
1. **Classify Intent**:
|
||||
- Is the user asking to create an automation/workflow? -> Intent: "generate"
|
||||
- Is it general chat/weather/jokes? -> Intent: "off_topic"
|
||||
|
||||
2. **Plan Steps** (if intent is "generate"):
|
||||
- Break down the user's goal into logical steps.
|
||||
- For each step, identify if a specific capability/tool is needed.
|
||||
- Select the MOST RELEVANT tools from the available_tools list.
|
||||
- DO NOT configure parameters yet. Just identify the tool.
|
||||
|
||||
3. **Output Format**:
|
||||
Return a JSON object.
|
||||
</task>
|
||||
|
||||
<available_tools>
|
||||
{tools_summary}
|
||||
</available_tools>
|
||||
|
||||
<response_format>
|
||||
If intent is "generate":
|
||||
```json
|
||||
{{
|
||||
"intent": "generate",
|
||||
"plan_thought": "Brief explanation of the plan...",
|
||||
"steps": [
|
||||
{{ "step": 1, "description": "Fetch data from URL", "tool": "http-request" }},
|
||||
{{ "step": 2, "description": "Summarize content", "tool": "llm" }},
|
||||
{{ "step": 3, "description": "Search for info", "tool": "google_search" }}
|
||||
],
|
||||
"required_tool_keys": ["google_search"]
|
||||
}}
|
||||
```
|
||||
(Note: 'http-request', 'llm', 'code' are built-in, you don't need to list them in required_tool_keys,
|
||||
only external tools)
|
||||
|
||||
If intent is "off_topic":
|
||||
```json
|
||||
{{
|
||||
"intent": "off_topic",
|
||||
"message": "I can only help you build workflows. Try asking me to 'Create a workflow that...'",
|
||||
"suggestions": ["Scrape a website", "Summarize a PDF"]
|
||||
}}
|
||||
```
|
||||
</response_format>
|
||||
"""
|
||||
|
||||
PLANNER_USER_PROMPT = """<user_request>
|
||||
{instruction}
|
||||
</user_request>
|
||||
"""
|
||||
|
||||
|
||||
def format_tools_for_planner(tools: list[dict]) -> str:
|
||||
"""Format tools list for planner (Lightweight: Name + Description only)."""
|
||||
if not tools:
|
||||
return "No external tools available."
|
||||
|
||||
lines = []
|
||||
for t in tools:
|
||||
key = t.get("tool_key") or t.get("tool_name")
|
||||
provider = t.get("provider_id") or t.get("provider", "")
|
||||
desc = t.get("tool_description") or t.get("description", "")
|
||||
label = t.get("tool_label") or key
|
||||
|
||||
# Format: - [provider/key] Label: Description
|
||||
full_key = f"{provider}/{key}" if provider else key
|
||||
lines.append(f"- [{full_key}] {label}: {desc}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
|
@ -10,7 +10,7 @@ import json
|
|||
import re
|
||||
from typing import Any
|
||||
|
||||
from core.llm_generator.vibe_config import (
|
||||
from core.workflow.generator.config import (
|
||||
BUILTIN_NODE_SCHEMAS,
|
||||
DEFAULT_SUGGESTIONS,
|
||||
FALLBACK_RULES,
|
||||
|
|
@ -100,6 +100,13 @@ You help users create AI automation workflows by generating workflow configurati
|
|||
</variable_syntax>
|
||||
|
||||
<rules>
|
||||
<rule id="model_selection" priority="critical">
|
||||
For LLM, question-classifier, parameter-extractor nodes:
|
||||
- You MUST include a "model" config with provider and name from available_models section
|
||||
- Copy the EXACT provider and name values from available_models
|
||||
- NEVER use openai/gpt-4o, openai/gpt-3.5-turbo, openai/gpt-4 unless they appear in available_models
|
||||
- If available_models is empty or not provided, omit the model config entirely
|
||||
</rule>
|
||||
<rule id="tool_usage" priority="critical">
|
||||
ONLY use tools with status="configured" from available_tools.
|
||||
NEVER invent tool names like "webscraper", "email_sender", etc.
|
||||
|
|
@ -217,12 +224,14 @@ You help users create AI automation workflows by generating workflow configurati
|
|||
"type": "llm",
|
||||
"title": "Analyze Content",
|
||||
"config": {{{{
|
||||
"model": {{{{"provider": "USE_FROM_AVAILABLE_MODELS", "name": "USE_FROM_AVAILABLE_MODELS", "mode": "chat"}}}},
|
||||
"prompt_template": [
|
||||
{{{{"role": "system", "text": "You are a helpful analyst."}}}},
|
||||
{{{{"role": "user", "text": "Analyze this content:\\n\\n{{{{#fetch.body#}}}}"}}}}
|
||||
]
|
||||
}}}}
|
||||
}}}}
|
||||
NOTE: Replace "USE_FROM_AVAILABLE_MODELS" with actual values from available_models section!
|
||||
</example>
|
||||
<example type="code" title="Process data">
|
||||
{{{{
|
||||
|
|
@ -344,6 +353,7 @@ Generate your JSON response now. Remember:
|
|||
</output_instruction>
|
||||
"""
|
||||
|
||||
|
||||
def format_available_nodes(nodes: list[dict[str, Any]] | None) -> str:
|
||||
"""Format available nodes as XML with parameter schemas."""
|
||||
lines = ["<available_nodes>"]
|
||||
|
|
@ -591,7 +601,7 @@ def format_previous_attempt(
|
|||
def format_available_models(models: list[dict[str, Any]] | None) -> str:
|
||||
"""Format available models as XML for prompt inclusion."""
|
||||
if not models:
|
||||
return "<available_models>\n <!-- No models configured -->\n</available_models>"
|
||||
return "<available_models>\n <!-- No models configured - omit model config from nodes -->\n</available_models>"
|
||||
|
||||
lines = ["<available_models>"]
|
||||
for model in models:
|
||||
|
|
@ -600,16 +610,30 @@ def format_available_models(models: list[dict[str, Any]] | None) -> str:
|
|||
lines.append(f' <model provider="{provider}" name="{model_name}" />')
|
||||
lines.append("</available_models>")
|
||||
|
||||
# Add model selection rule
|
||||
# Add model selection rule with concrete example
|
||||
lines.append("")
|
||||
lines.append("<model_selection_rule>")
|
||||
lines.append(" CRITICAL: For LLM, question-classifier, and parameter-extractor nodes, you MUST select a model from available_models.")
|
||||
if len(models) == 1:
|
||||
first_model = models[0]
|
||||
lines.append(f' Use provider="{first_model.get("provider")}" and name="{first_model.get("model")}" for all model-dependent nodes.')
|
||||
else:
|
||||
lines.append(" Choose the most suitable model for each task from the available options.")
|
||||
lines.append(" NEVER use models not listed in available_models (e.g., openai/gpt-4o if not listed).")
|
||||
lines.append(" CRITICAL: For LLM, question-classifier, and parameter-extractor nodes:")
|
||||
lines.append(" - You MUST include a 'model' field in the config")
|
||||
lines.append(" - You MUST use ONLY models from available_models above")
|
||||
lines.append(" - NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4 unless they appear in available_models")
|
||||
lines.append("")
|
||||
|
||||
# Provide concrete JSON example to copy
|
||||
first_model = models[0]
|
||||
provider = first_model.get("provider", "unknown")
|
||||
model_name = first_model.get("model", "unknown")
|
||||
lines.append(" COPY THIS EXACT MODEL CONFIG for all LLM/question-classifier/parameter-extractor nodes:")
|
||||
lines.append(f' "model": {{"provider": "{provider}", "name": "{model_name}", "mode": "chat"}}')
|
||||
|
||||
if len(models) > 1:
|
||||
lines.append("")
|
||||
lines.append(" Alternative models you can use:")
|
||||
for m in models[1:4]: # Show up to 3 alternatives
|
||||
p = m.get("provider", "unknown")
|
||||
n = m.get("model", "unknown")
|
||||
lines.append(f' - "model": {{"provider": "{p}", "name": "{n}", "mode": "chat"}}')
|
||||
|
||||
lines.append("</model_selection_rule>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
|
@ -1023,6 +1047,7 @@ def validate_node_parameters(nodes: list[dict[str, Any]]) -> list[str]:
|
|||
def extract_mermaid_from_response(data: dict[str, Any]) -> str:
|
||||
"""Extract mermaid flowchart from parsed response."""
|
||||
mermaid = data.get("mermaid", "")
|
||||
|
||||
if not mermaid:
|
||||
return ""
|
||||
|
||||
|
|
@ -1034,5 +1059,203 @@ def extract_mermaid_from_response(data: dict[str, Any]) -> str:
|
|||
if match:
|
||||
mermaid = match.group(1).strip()
|
||||
|
||||
# Sanitize edge labels to remove characters that break Mermaid parsing
|
||||
# Edge labels in Mermaid are ONLY in the pattern: -->|label|
|
||||
# We must NOT match |pipe| characters inside node labels like ["type=start|title=开始"]
|
||||
def sanitize_edge_label(match: re.Match) -> str:
|
||||
arrow = match.group(1) # --> or ---
|
||||
label = match.group(2) # the label between pipes
|
||||
# Remove or replace special characters that break Mermaid
|
||||
# Parentheses, brackets, braces have special meaning in Mermaid
|
||||
sanitized = re.sub(r'[(){}\[\]]', '', label)
|
||||
return f"{arrow}|{sanitized}|"
|
||||
|
||||
# Only match edge labels: --> or --- followed by |label|
|
||||
# This pattern ensures we only sanitize actual edge labels, not node content
|
||||
mermaid = re.sub(r'(-->|---)\|([^|]+)\|', sanitize_edge_label, mermaid)
|
||||
|
||||
return mermaid
|
||||
|
||||
|
||||
def classify_validation_errors(
|
||||
nodes: list[dict[str, Any]],
|
||||
available_models: list[dict[str, Any]] | None = None,
|
||||
available_tools: list[dict[str, Any]] | None = None,
|
||||
edges: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
Classify validation errors into fixable and user-required categories.
|
||||
|
||||
This function uses the declarative rule engine to validate nodes.
|
||||
The rule engine provides deterministic, testable validation without
|
||||
relying on LLM judgment.
|
||||
|
||||
Fixable errors can be automatically corrected by the LLM in subsequent
|
||||
iterations. User-required errors need manual intervention.
|
||||
|
||||
Args:
|
||||
nodes: List of generated workflow nodes
|
||||
available_models: List of models the user has configured
|
||||
available_tools: List of available tools
|
||||
edges: List of edges connecting nodes
|
||||
|
||||
Returns:
|
||||
dict with:
|
||||
- "fixable": errors that LLM can fix automatically
|
||||
- "user_required": errors that need user intervention
|
||||
- "all_warnings": combined warning messages for backwards compatibility
|
||||
- "stats": validation statistics
|
||||
"""
|
||||
from core.workflow.generator.validation import ValidationContext, ValidationEngine
|
||||
|
||||
# Build validation context
|
||||
context = ValidationContext(
|
||||
nodes=nodes,
|
||||
edges=edges or [],
|
||||
available_models=available_models or [],
|
||||
available_tools=available_tools or [],
|
||||
)
|
||||
|
||||
# Run validation through rule engine
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(context)
|
||||
|
||||
# Convert to legacy format for backwards compatibility
|
||||
fixable: list[dict[str, Any]] = []
|
||||
user_required: list[dict[str, Any]] = []
|
||||
|
||||
for error in result.fixable_errors:
|
||||
fixable.append({
|
||||
"node_id": error.node_id,
|
||||
"node_type": error.node_type,
|
||||
"error_type": error.rule_id,
|
||||
"message": error.message,
|
||||
"is_fixable": True,
|
||||
"fix_hint": error.fix_hint,
|
||||
"category": error.category.value,
|
||||
"details": error.details,
|
||||
})
|
||||
|
||||
for error in result.user_required_errors:
|
||||
user_required.append({
|
||||
"node_id": error.node_id,
|
||||
"node_type": error.node_type,
|
||||
"error_type": error.rule_id,
|
||||
"message": error.message,
|
||||
"is_fixable": False,
|
||||
"fix_hint": error.fix_hint,
|
||||
"category": error.category.value,
|
||||
"details": error.details,
|
||||
})
|
||||
|
||||
# Include warnings in user_required (they're non-blocking but informative)
|
||||
for error in result.warnings:
|
||||
user_required.append({
|
||||
"node_id": error.node_id,
|
||||
"node_type": error.node_type,
|
||||
"error_type": error.rule_id,
|
||||
"message": error.message,
|
||||
"is_fixable": error.is_fixable,
|
||||
"fix_hint": error.fix_hint,
|
||||
"category": error.category.value,
|
||||
"severity": "warning",
|
||||
"details": error.details,
|
||||
})
|
||||
|
||||
# Generate combined warnings for backwards compatibility
|
||||
all_warnings = [e["message"] for e in fixable + user_required]
|
||||
|
||||
return {
|
||||
"fixable": fixable,
|
||||
"user_required": user_required,
|
||||
"all_warnings": all_warnings,
|
||||
"stats": result.stats,
|
||||
}
|
||||
|
||||
|
||||
def build_fix_prompt(
|
||||
fixable_errors: list[dict[str, Any]],
|
||||
previous_nodes: list[dict[str, Any]],
|
||||
available_models: list[dict[str, Any]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a prompt for LLM to fix the identified errors.
|
||||
|
||||
This creates a focused instruction that tells the LLM exactly what
|
||||
to fix in the previous generation.
|
||||
|
||||
Args:
|
||||
fixable_errors: List of errors that can be automatically fixed
|
||||
previous_nodes: The nodes from the previous generation attempt
|
||||
available_models: Available models for model configuration fixes
|
||||
|
||||
Returns:
|
||||
Formatted prompt string for the fix iteration
|
||||
"""
|
||||
if not fixable_errors:
|
||||
return ""
|
||||
|
||||
parts = ["<fix_required>"]
|
||||
parts.append(" <description>")
|
||||
parts.append(" Your previous generation has errors that need fixing.")
|
||||
parts.append(" Please regenerate with the following corrections:")
|
||||
parts.append(" </description>")
|
||||
|
||||
# Group errors by node
|
||||
errors_by_node: dict[str, list[dict[str, Any]]] = {}
|
||||
for error in fixable_errors:
|
||||
node_id = error["node_id"]
|
||||
if node_id not in errors_by_node:
|
||||
errors_by_node[node_id] = []
|
||||
errors_by_node[node_id].append(error)
|
||||
|
||||
parts.append(" <errors_to_fix>")
|
||||
for node_id, node_errors in errors_by_node.items():
|
||||
parts.append(f" <node id=\"{node_id}\">")
|
||||
for error in node_errors:
|
||||
error_type = error["error_type"]
|
||||
message = error["message"]
|
||||
fix_hint = error.get("fix_hint", "")
|
||||
parts.append(f" <error type=\"{error_type}\">")
|
||||
parts.append(f" <message>{message}</message>")
|
||||
if fix_hint:
|
||||
parts.append(f" <fix_hint>{fix_hint}</fix_hint>")
|
||||
parts.append(" </error>")
|
||||
parts.append(" </node>")
|
||||
parts.append(" </errors_to_fix>")
|
||||
|
||||
# Add model selection help if there are model-related errors
|
||||
model_errors = [e for e in fixable_errors if "model" in e["error_type"]]
|
||||
if model_errors and available_models:
|
||||
parts.append(" <model_selection_help>")
|
||||
parts.append(" Use one of these models for nodes requiring model config:")
|
||||
for model in available_models[:3]: # Show top 3
|
||||
provider = model.get("provider", "unknown")
|
||||
name = model.get("model", "unknown")
|
||||
parts.append(f' - {{"provider": "{provider}", "name": "{name}", "mode": "chat"}}')
|
||||
parts.append(" </model_selection_help>")
|
||||
|
||||
# Add previous nodes summary for context
|
||||
parts.append(" <previous_nodes_to_fix>")
|
||||
for node in previous_nodes:
|
||||
node_id = node.get("id", "unknown")
|
||||
if node_id in errors_by_node:
|
||||
# Only include nodes that have errors
|
||||
node_type = node.get("type", "unknown")
|
||||
title = node.get("title", "Untitled")
|
||||
config_summary = json.dumps(node.get("config", {}), ensure_ascii=False)[:200]
|
||||
parts.append(f" <node id=\"{node_id}\" type=\"{node_type}\" title=\"{title}\">")
|
||||
parts.append(f" <current_config>{config_summary}...</current_config>")
|
||||
parts.append(" </node>")
|
||||
parts.append(" </previous_nodes_to_fix>")
|
||||
|
||||
parts.append(" <instructions>")
|
||||
parts.append(" 1. Keep the workflow structure and logic unchanged")
|
||||
parts.append(" 2. Fix ONLY the errors listed above")
|
||||
parts.append(" 3. Ensure all required fields are properly filled")
|
||||
parts.append(" 4. Use variable references {{#node_id.field#}} where appropriate")
|
||||
parts.append(" </instructions>")
|
||||
parts.append("</fix_required>")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
|
||||
import json_repair
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.workflow.generator.prompts.builder_prompts import BUILDER_SYSTEM_PROMPT, BUILDER_USER_PROMPT
|
||||
from core.workflow.generator.prompts.planner_prompts import (
|
||||
PLANNER_SYSTEM_PROMPT,
|
||||
PLANNER_USER_PROMPT,
|
||||
format_tools_for_planner,
|
||||
)
|
||||
from core.workflow.generator.prompts.vibe_prompts import (
|
||||
format_available_nodes,
|
||||
format_available_tools,
|
||||
parse_vibe_response,
|
||||
)
|
||||
from core.workflow.generator.utils.edge_repair import EdgeRepair
|
||||
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
|
||||
from core.workflow.generator.utils.node_repair import NodeRepair
|
||||
from core.workflow.generator.utils.workflow_validator import WorkflowValidator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowGenerator:
|
||||
"""
|
||||
Refactored Vibe Workflow Generator (Planner-Builder Architecture).
|
||||
Extracts Vibe logic from the monolithic LLMGenerator.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def generate_workflow_flowchart(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
instruction: str,
|
||||
model_config: dict,
|
||||
available_nodes: Sequence[dict[str, object]] | None = None,
|
||||
existing_nodes: Sequence[dict[str, object]] | None = None,
|
||||
available_tools: Sequence[dict[str, object]] | None = None,
|
||||
selected_node_ids: Sequence[str] | None = None,
|
||||
previous_workflow: dict[str, object] | None = None,
|
||||
regenerate_mode: bool = False,
|
||||
preferred_language: str | None = None,
|
||||
available_models: Sequence[dict[str, object]] | None = None,
|
||||
max_fix_iterations: int = 2,
|
||||
):
|
||||
"""
|
||||
Generates a Dify Workflow Flowchart from natural language instruction.
|
||||
|
||||
Pipeline:
|
||||
1. Planner: Analyze intent & select tools.
|
||||
2. Context Filter: Filter relevant tools (reduce tokens).
|
||||
3. Builder: Generate node configurations.
|
||||
4. Repair: Fix common node/edge issues (NodeRepair, EdgeRepair).
|
||||
5. Validator: Check for errors & generate friendly hints.
|
||||
6. Renderer: Deterministic Mermaid generation.
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
)
|
||||
model_parameters = model_config.get("completion_params", {})
|
||||
available_tools_list = list(available_tools) if available_tools else []
|
||||
|
||||
# --- STEP 1: PLANNER ---
|
||||
planner_tools_context = format_tools_for_planner(available_tools_list)
|
||||
planner_system = PLANNER_SYSTEM_PROMPT.format(tools_summary=planner_tools_context)
|
||||
planner_user = PLANNER_USER_PROMPT.format(instruction=instruction)
|
||||
|
||||
try:
|
||||
response = model_instance.invoke_llm(
|
||||
prompt_messages=[SystemPromptMessage(content=planner_system), UserPromptMessage(content=planner_user)],
|
||||
model_parameters=model_parameters,
|
||||
stream=False,
|
||||
)
|
||||
plan_content = response.message.content
|
||||
# Reuse parse_vibe_response logic or simple load
|
||||
plan_data = parse_vibe_response(plan_content)
|
||||
except Exception as e:
|
||||
logger.exception("Planner failed")
|
||||
return {"intent": "error", "error": f"Planning failed: {str(e)}"}
|
||||
|
||||
if plan_data.get("intent") == "off_topic":
|
||||
return {
|
||||
"intent": "off_topic",
|
||||
"message": plan_data.get("message", "I can only help with workflow creation."),
|
||||
"suggestions": plan_data.get("suggestions", []),
|
||||
}
|
||||
|
||||
# --- STEP 2: CONTEXT FILTERING ---
|
||||
required_tools = plan_data.get("required_tool_keys", [])
|
||||
|
||||
filtered_tools = []
|
||||
if required_tools:
|
||||
# Simple linear search (optimized version would use a map)
|
||||
for tool in available_tools_list:
|
||||
t_key = tool.get("tool_key") or tool.get("tool_name")
|
||||
provider = tool.get("provider_id") or tool.get("provider")
|
||||
full_key = f"{provider}/{t_key}" if provider else t_key
|
||||
|
||||
# Check if this tool is in required list (match either full key or short name)
|
||||
if t_key in required_tools or full_key in required_tools:
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
# If logic only, no tools needed
|
||||
filtered_tools = []
|
||||
|
||||
# --- STEP 3: BUILDER ---
|
||||
# Prepare context
|
||||
tool_schemas = format_available_tools(filtered_tools)
|
||||
# We need to construct a fake list structure for builtin nodes formatting if using format_available_nodes
|
||||
# Actually format_available_nodes takes None to use defaults, or a list to add custom
|
||||
# But we want to SHOW the builtins. format_available_nodes internally uses BUILTIN_NODE_SCHEMAS.
|
||||
node_specs = format_available_nodes([])
|
||||
|
||||
builder_system = BUILDER_SYSTEM_PROMPT.format(
|
||||
plan_context=json.dumps(plan_data.get("steps", []), indent=2),
|
||||
tool_schemas=tool_schemas,
|
||||
builtin_node_specs=node_specs,
|
||||
)
|
||||
builder_user = BUILDER_USER_PROMPT.format(instruction=instruction)
|
||||
|
||||
try:
|
||||
build_res = model_instance.invoke_llm(
|
||||
prompt_messages=[SystemPromptMessage(content=builder_system), UserPromptMessage(content=builder_user)],
|
||||
model_parameters=model_parameters,
|
||||
stream=False,
|
||||
)
|
||||
# Builder output is raw JSON nodes/edges
|
||||
build_content = build_res.message.content
|
||||
match = re.search(r"```(?:json)?\s*([\s\S]+?)```", build_content)
|
||||
if match:
|
||||
build_content = match.group(1)
|
||||
|
||||
workflow_data = json_repair.loads(build_content)
|
||||
|
||||
if "nodes" not in workflow_data:
|
||||
workflow_data["nodes"] = []
|
||||
if "edges" not in workflow_data:
|
||||
workflow_data["edges"] = []
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Builder failed")
|
||||
return {"intent": "error", "error": f"Building failed: {str(e)}"}
|
||||
|
||||
# --- STEP 3.4: NODE REPAIR ---
|
||||
node_repair_result = NodeRepair.repair(workflow_data["nodes"])
|
||||
workflow_data["nodes"] = node_repair_result.nodes
|
||||
|
||||
# --- STEP 3.5: EDGE REPAIR ---
|
||||
repair_result = EdgeRepair.repair(workflow_data)
|
||||
workflow_data = {
|
||||
"nodes": repair_result.nodes,
|
||||
"edges": repair_result.edges,
|
||||
}
|
||||
|
||||
# --- STEP 4: VALIDATOR ---
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools_list)
|
||||
|
||||
# --- STEP 5: RENDERER ---
|
||||
mermaid_code = generate_mermaid(workflow_data)
|
||||
|
||||
# --- FINALIZE ---
|
||||
# Combine validation hints with repair warnings
|
||||
all_warnings = [h.message for h in hints] + repair_result.warnings + node_repair_result.warnings
|
||||
|
||||
# Add stability warning (as requested by user)
|
||||
stability_warning = "The generated workflow may require debugging."
|
||||
if preferred_language and preferred_language.startswith("zh"):
|
||||
stability_warning = "生成的 Workflow 可能需要调试。"
|
||||
all_warnings.append(stability_warning)
|
||||
|
||||
all_fixes = repair_result.repairs_made + node_repair_result.repairs_made
|
||||
|
||||
return {
|
||||
"intent": "generate",
|
||||
"flowchart": mermaid_code,
|
||||
"nodes": workflow_data["nodes"],
|
||||
"edges": workflow_data["edges"],
|
||||
"message": plan_data.get("plan_thought", "Generated workflow based on your request."),
|
||||
"warnings": all_warnings,
|
||||
"tool_recommendations": [], # Legacy field
|
||||
"error": "",
|
||||
"fix_iterations": 0, # Legacy
|
||||
"fixed_issues": all_fixes, # Track what was auto-fixed
|
||||
}
|
||||
|
|
@ -0,0 +1,372 @@
|
|||
"""
|
||||
Edge Repair Utility for Vibe Workflow Generation.
|
||||
|
||||
This module provides intelligent edge repair capabilities for generated workflows.
|
||||
It can detect and fix common edge issues:
|
||||
- Missing edges between sequential nodes
|
||||
- Incomplete branches for question-classifier and if-else nodes
|
||||
- Orphaned nodes without connections
|
||||
|
||||
The repair logic is deterministic and doesn't require LLM calls.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepairResult:
|
||||
"""Result of edge repair operation."""
|
||||
|
||||
nodes: list[dict[str, Any]]
|
||||
edges: list[dict[str, Any]]
|
||||
repairs_made: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def was_repaired(self) -> bool:
|
||||
"""Check if any repairs were made."""
|
||||
return len(self.repairs_made) > 0
|
||||
|
||||
|
||||
class EdgeRepair:
|
||||
"""
|
||||
Intelligent edge repair for workflow graphs.
|
||||
|
||||
Repairs are applied in order:
|
||||
1. Infer linear connections from node order (if no edges exist)
|
||||
2. Add missing branch edges for question-classifier
|
||||
3. Add missing branch edges for if-else
|
||||
4. Connect orphaned nodes
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def repair(cls, workflow_data: dict[str, Any]) -> RepairResult:
|
||||
"""
|
||||
Repair edges in the workflow data.
|
||||
|
||||
Args:
|
||||
workflow_data: Dict containing 'nodes' and 'edges'
|
||||
|
||||
Returns:
|
||||
RepairResult with repaired nodes, edges, and repair logs
|
||||
"""
|
||||
nodes = list(workflow_data.get("nodes", []))
|
||||
edges = list(workflow_data.get("edges", []))
|
||||
repairs: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
logger.info("[EdgeRepair] Starting repair: %d nodes, %d edges", len(nodes), len(edges))
|
||||
|
||||
# Build node lookup
|
||||
node_map = {n.get("id"): n for n in nodes if n.get("id")}
|
||||
node_ids = set(node_map.keys())
|
||||
|
||||
# 1. If no edges at all, infer linear chain
|
||||
if not edges and len(nodes) > 1:
|
||||
edges, inferred_repairs = cls._infer_linear_chain(nodes)
|
||||
repairs.extend(inferred_repairs)
|
||||
|
||||
# 2. Build edge index for analysis
|
||||
outgoing_edges: dict[str, list[dict[str, Any]]] = {}
|
||||
incoming_edges: dict[str, list[dict[str, Any]]] = {}
|
||||
for edge in edges:
|
||||
src = edge.get("source")
|
||||
tgt = edge.get("target")
|
||||
if src:
|
||||
outgoing_edges.setdefault(src, []).append(edge)
|
||||
if tgt:
|
||||
incoming_edges.setdefault(tgt, []).append(edge)
|
||||
|
||||
# 3. Repair question-classifier branches
|
||||
for node in nodes:
|
||||
if node.get("type") == "question-classifier":
|
||||
new_edges, branch_repairs, branch_warnings = cls._repair_classifier_branches(
|
||||
node, edges, outgoing_edges, node_ids
|
||||
)
|
||||
edges.extend(new_edges)
|
||||
repairs.extend(branch_repairs)
|
||||
warnings.extend(branch_warnings)
|
||||
# Update outgoing index
|
||||
for edge in new_edges:
|
||||
outgoing_edges.setdefault(edge.get("source"), []).append(edge)
|
||||
|
||||
# 4. Repair if-else branches
|
||||
for node in nodes:
|
||||
if node.get("type") == "if-else":
|
||||
new_edges, branch_repairs, branch_warnings = cls._repair_if_else_branches(
|
||||
node, edges, outgoing_edges, node_ids
|
||||
)
|
||||
edges.extend(new_edges)
|
||||
repairs.extend(branch_repairs)
|
||||
warnings.extend(branch_warnings)
|
||||
# Update outgoing index
|
||||
for edge in new_edges:
|
||||
outgoing_edges.setdefault(edge.get("source"), []).append(edge)
|
||||
|
||||
# 5. Connect orphaned nodes (nodes with no incoming edge, except start)
|
||||
new_edges, orphan_repairs = cls._connect_orphaned_nodes(
|
||||
nodes, edges, outgoing_edges, incoming_edges
|
||||
)
|
||||
edges.extend(new_edges)
|
||||
repairs.extend(orphan_repairs)
|
||||
|
||||
# 6. Connect nodes with no outgoing edge to 'end' (except end nodes)
|
||||
new_edges, terminal_repairs = cls._connect_terminal_nodes(
|
||||
nodes, edges, outgoing_edges
|
||||
)
|
||||
edges.extend(new_edges)
|
||||
repairs.extend(terminal_repairs)
|
||||
|
||||
logger.info("[EdgeRepair] Completed: %d repairs made, %d warnings", len(repairs), len(warnings))
|
||||
for r in repairs:
|
||||
logger.info("[EdgeRepair] Repair: %s", r)
|
||||
for w in warnings:
|
||||
logger.info("[EdgeRepair] Warning: %s", w)
|
||||
|
||||
return RepairResult(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
repairs_made=repairs,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _infer_linear_chain(cls, nodes: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""
|
||||
Infer a linear chain of edges from node order.
|
||||
|
||||
This is used when no edges are provided at all.
|
||||
"""
|
||||
edges: list[dict[str, Any]] = []
|
||||
repairs: list[str] = []
|
||||
|
||||
# Filter to get ordered node IDs
|
||||
node_ids = [n.get("id") for n in nodes if n.get("id")]
|
||||
|
||||
if len(node_ids) < 2:
|
||||
return edges, repairs
|
||||
|
||||
# Create edges between consecutive nodes
|
||||
for i in range(len(node_ids) - 1):
|
||||
src = node_ids[i]
|
||||
tgt = node_ids[i + 1]
|
||||
edges.append({"source": src, "target": tgt})
|
||||
repairs.append(f"Inferred edge: {src} -> {tgt}")
|
||||
|
||||
logger.info("[EdgeRepair] Inferred %d edges from node order (no edges provided)", len(edges))
|
||||
return edges, repairs
|
||||
|
||||
@classmethod
|
||||
def _repair_classifier_branches(
|
||||
cls,
|
||||
node: dict[str, Any],
|
||||
edges: list[dict[str, Any]],
|
||||
outgoing_edges: dict[str, list[dict[str, Any]]],
|
||||
valid_node_ids: set[str],
|
||||
) -> tuple[list[dict[str, Any]], list[str], list[str]]:
|
||||
"""
|
||||
Repair missing branches for question-classifier nodes.
|
||||
|
||||
For each class that doesn't have an edge, create one pointing to 'end'.
|
||||
"""
|
||||
new_edges: list[dict[str, Any]] = []
|
||||
repairs: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
node_id = node.get("id")
|
||||
if not node_id:
|
||||
return new_edges, repairs, warnings
|
||||
|
||||
config = node.get("config", {})
|
||||
classes = config.get("classes", [])
|
||||
|
||||
if not classes:
|
||||
return new_edges, repairs, warnings
|
||||
|
||||
# Get existing sourceHandles for this node
|
||||
existing_handles = set()
|
||||
for edge in outgoing_edges.get(node_id, []):
|
||||
handle = edge.get("sourceHandle")
|
||||
if handle:
|
||||
existing_handles.add(handle)
|
||||
|
||||
# Find 'end' node as default target
|
||||
end_node_id = "end"
|
||||
if "end" not in valid_node_ids:
|
||||
# Try to find an end node
|
||||
for nid in valid_node_ids:
|
||||
if "end" in nid.lower():
|
||||
end_node_id = nid
|
||||
break
|
||||
|
||||
# Add missing branches
|
||||
for cls_def in classes:
|
||||
if not isinstance(cls_def, dict):
|
||||
continue
|
||||
cls_id = cls_def.get("id")
|
||||
cls_name = cls_def.get("name", cls_id)
|
||||
|
||||
if cls_id and cls_id not in existing_handles:
|
||||
new_edge = {
|
||||
"source": node_id,
|
||||
"sourceHandle": cls_id,
|
||||
"target": end_node_id,
|
||||
}
|
||||
new_edges.append(new_edge)
|
||||
repairs.append(f"Added missing branch edge for class '{cls_name}' -> {end_node_id}")
|
||||
warnings.append(
|
||||
f"Auto-connected question-classifier branch '{cls_name}' to '{end_node_id}'. "
|
||||
"You may want to redirect this to a specific handler node."
|
||||
)
|
||||
|
||||
return new_edges, repairs, warnings
|
||||
|
||||
@classmethod
|
||||
def _repair_if_else_branches(
|
||||
cls,
|
||||
node: dict[str, Any],
|
||||
edges: list[dict[str, Any]],
|
||||
outgoing_edges: dict[str, list[dict[str, Any]]],
|
||||
valid_node_ids: set[str],
|
||||
) -> tuple[list[dict[str, Any]], list[str], list[str]]:
|
||||
"""
|
||||
Repair missing true/false branches for if-else nodes.
|
||||
"""
|
||||
new_edges: list[dict[str, Any]] = []
|
||||
repairs: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
node_id = node.get("id")
|
||||
if not node_id:
|
||||
return new_edges, repairs, warnings
|
||||
|
||||
# Get existing sourceHandles
|
||||
existing_handles = set()
|
||||
for edge in outgoing_edges.get(node_id, []):
|
||||
handle = edge.get("sourceHandle")
|
||||
if handle:
|
||||
existing_handles.add(handle)
|
||||
|
||||
# Find 'end' node as default target
|
||||
end_node_id = "end"
|
||||
if "end" not in valid_node_ids:
|
||||
for nid in valid_node_ids:
|
||||
if "end" in nid.lower():
|
||||
end_node_id = nid
|
||||
break
|
||||
|
||||
# Add missing branches
|
||||
required_branches = ["true", "false"]
|
||||
for branch in required_branches:
|
||||
if branch not in existing_handles:
|
||||
new_edge = {
|
||||
"source": node_id,
|
||||
"sourceHandle": branch,
|
||||
"target": end_node_id,
|
||||
}
|
||||
new_edges.append(new_edge)
|
||||
repairs.append(f"Added missing if-else '{branch}' branch -> {end_node_id}")
|
||||
warnings.append(
|
||||
f"Auto-connected if-else '{branch}' branch to '{end_node_id}'. "
|
||||
"You may want to redirect this to a specific handler node."
|
||||
)
|
||||
|
||||
return new_edges, repairs, warnings
|
||||
|
||||
@classmethod
|
||||
def _connect_orphaned_nodes(
|
||||
cls,
|
||||
nodes: list[dict[str, Any]],
|
||||
edges: list[dict[str, Any]],
|
||||
outgoing_edges: dict[str, list[dict[str, Any]]],
|
||||
incoming_edges: dict[str, list[dict[str, Any]]],
|
||||
) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""
|
||||
Connect orphaned nodes to the previous node in sequence.
|
||||
|
||||
An orphaned node has no incoming edges and is not a 'start' node.
|
||||
"""
|
||||
new_edges: list[dict[str, Any]] = []
|
||||
repairs: list[str] = []
|
||||
|
||||
node_ids = [n.get("id") for n in nodes if n.get("id")]
|
||||
node_types = {n.get("id"): n.get("type") for n in nodes}
|
||||
|
||||
for i, node_id in enumerate(node_ids):
|
||||
node_type = node_types.get(node_id)
|
||||
|
||||
# Skip start nodes - they don't need incoming edges
|
||||
if node_type == "start":
|
||||
continue
|
||||
|
||||
# Check if node has incoming edges
|
||||
if node_id not in incoming_edges or not incoming_edges[node_id]:
|
||||
# Find previous node to connect from
|
||||
if i > 0:
|
||||
prev_node_id = node_ids[i - 1]
|
||||
new_edge = {"source": prev_node_id, "target": node_id}
|
||||
new_edges.append(new_edge)
|
||||
repairs.append(f"Connected orphaned node: {prev_node_id} -> {node_id}")
|
||||
|
||||
# Update incoming_edges for subsequent checks
|
||||
incoming_edges.setdefault(node_id, []).append(new_edge)
|
||||
|
||||
return new_edges, repairs
|
||||
|
||||
@classmethod
|
||||
def _connect_terminal_nodes(
|
||||
cls,
|
||||
nodes: list[dict[str, Any]],
|
||||
edges: list[dict[str, Any]],
|
||||
outgoing_edges: dict[str, list[dict[str, Any]]],
|
||||
) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""
|
||||
Connect terminal nodes (no outgoing edges) to 'end'.
|
||||
|
||||
A terminal node has no outgoing edges and is not an 'end' node.
|
||||
This ensures all branches eventually reach 'end'.
|
||||
"""
|
||||
new_edges: list[dict[str, Any]] = []
|
||||
repairs: list[str] = []
|
||||
|
||||
# Find end node
|
||||
end_node_id = None
|
||||
node_ids = set()
|
||||
for n in nodes:
|
||||
nid = n.get("id")
|
||||
ntype = n.get("type")
|
||||
if nid:
|
||||
node_ids.add(nid)
|
||||
if ntype == "end":
|
||||
end_node_id = nid
|
||||
|
||||
if not end_node_id:
|
||||
# No end node found, can't connect
|
||||
return new_edges, repairs
|
||||
|
||||
for node in nodes:
|
||||
node_id = node.get("id")
|
||||
node_type = node.get("type")
|
||||
|
||||
# Skip end nodes
|
||||
if node_type == "end":
|
||||
continue
|
||||
|
||||
# Skip nodes that already have outgoing edges
|
||||
if outgoing_edges.get(node_id):
|
||||
continue
|
||||
|
||||
# Connect to end
|
||||
new_edge = {"source": node_id, "target": end_node_id}
|
||||
new_edges.append(new_edge)
|
||||
repairs.append(f"Connected terminal node to end: {node_id} -> {end_node_id}")
|
||||
|
||||
# Update for subsequent checks
|
||||
outgoing_edges.setdefault(node_id, []).append(new_edge)
|
||||
|
||||
return new_edges, repairs
|
||||
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_mermaid(workflow_data: dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate a Mermaid flowchart from workflow data consisting of nodes and edges.
|
||||
|
||||
Args:
|
||||
workflow_data: Dict containing 'nodes' (list) and 'edges' (list)
|
||||
|
||||
Returns:
|
||||
String containing the Mermaid flowchart syntax
|
||||
"""
|
||||
nodes = workflow_data.get("nodes", [])
|
||||
edges = workflow_data.get("edges", [])
|
||||
|
||||
# DEBUG: Log input data
|
||||
logger.debug("[MERMAID] Input nodes count: %d", len(nodes))
|
||||
logger.debug("[MERMAID] Input edges count: %d", len(edges))
|
||||
for i, node in enumerate(nodes):
|
||||
logger.debug(
|
||||
"[MERMAID] Node %d: id=%s, type=%s, title=%s", i, node.get("id"), node.get("type"), node.get("title")
|
||||
)
|
||||
for i, edge in enumerate(edges):
|
||||
logger.debug(
|
||||
"[MERMAID] Edge %d: source=%s, target=%s, sourceHandle=%s",
|
||||
i,
|
||||
edge.get("source"),
|
||||
edge.get("target"),
|
||||
edge.get("sourceHandle"),
|
||||
)
|
||||
|
||||
lines = ["flowchart TD"]
|
||||
|
||||
# 1. Define Nodes
|
||||
# Format: node_id["title<br/>type"] or similar
|
||||
# We will use the Vibe Workflow standard format: id["type=TYPE|title=TITLE"]
|
||||
# Or specifically for tool nodes: id["type=tool|title=TITLE|tool=TOOL_KEY"]
|
||||
|
||||
# Map of original IDs to safe Mermaid IDs
|
||||
id_map = {}
|
||||
|
||||
def get_safe_id(original_id: str) -> str:
|
||||
if original_id == "end":
|
||||
return "end_node"
|
||||
if original_id == "subgraph":
|
||||
return "subgraph_node"
|
||||
# Mermaid IDs should be alphanumeric.
|
||||
# If the ID has special chars, we might need to escape or hash, but Vibe usually generates simple IDs.
|
||||
# We'll trust standard IDs but handle the reserved keyword 'end'.
|
||||
return original_id
|
||||
|
||||
for node in nodes:
|
||||
node_id = node.get("id")
|
||||
if not node_id:
|
||||
continue
|
||||
|
||||
safe_id = get_safe_id(node_id)
|
||||
id_map[node_id] = safe_id
|
||||
|
||||
node_type = node.get("type", "unknown")
|
||||
title = node.get("title", "Untitled")
|
||||
|
||||
# Escape quotes in title
|
||||
safe_title = title.replace('"', "'")
|
||||
|
||||
if node_type == "tool":
|
||||
config = node.get("config", {})
|
||||
# Try multiple fields for tool reference
|
||||
tool_ref = (
|
||||
config.get("tool_key")
|
||||
or config.get("tool")
|
||||
or config.get("tool_name")
|
||||
or node.get("tool_name")
|
||||
or "unknown"
|
||||
)
|
||||
node_def = f'{safe_id}["type={node_type}|title={safe_title}|tool={tool_ref}"]'
|
||||
else:
|
||||
node_def = f'{safe_id}["type={node_type}|title={safe_title}"]'
|
||||
|
||||
lines.append(f" {node_def}")
|
||||
|
||||
# 2. Define Edges
|
||||
# Format: source --> target
|
||||
|
||||
# Track defined nodes to avoid edge errors
|
||||
defined_node_ids = {n.get("id") for n in nodes if n.get("id")}
|
||||
|
||||
for edge in edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
|
||||
# Skip invalid edges
|
||||
if not source or not target:
|
||||
continue
|
||||
|
||||
if source not in defined_node_ids or target not in defined_node_ids:
|
||||
# Log skipped edges for debugging
|
||||
logger.warning(
|
||||
"[MERMAID] Skipping edge: source=%s (exists=%s), target=%s (exists=%s)",
|
||||
source,
|
||||
source in defined_node_ids,
|
||||
target,
|
||||
target in defined_node_ids,
|
||||
)
|
||||
continue
|
||||
|
||||
safe_source = id_map.get(source, source)
|
||||
safe_target = id_map.get(target, target)
|
||||
|
||||
# Handle conditional branches (true/false) if present
|
||||
# In Dify workflow, sourceHandle is often used for this
|
||||
source_handle = edge.get("sourceHandle")
|
||||
label = ""
|
||||
|
||||
if source_handle == "true":
|
||||
label = "|true|"
|
||||
elif source_handle == "false":
|
||||
label = "|false|"
|
||||
elif source_handle and source_handle != "source":
|
||||
# For question-classifier or other multi-path nodes
|
||||
# Clean up handle for display if needed
|
||||
safe_handle = str(source_handle).replace('"', "'")
|
||||
label = f"|{safe_handle}|"
|
||||
|
||||
edge_line = f" {safe_source} -->{label} {safe_target}"
|
||||
logger.debug("[MERMAID] Adding edge: %s", edge_line)
|
||||
lines.append(edge_line)
|
||||
|
||||
# Start/End nodes are implicitly handled if they are in the 'nodes' list
|
||||
# If not, we might need to add them, but usually the Builder should produce them.
|
||||
|
||||
result = "\n".join(lines)
|
||||
logger.debug("[MERMAID] Final output:\n%s", result)
|
||||
return result
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
Node Repair Utility for Vibe Workflow Generation.
|
||||
|
||||
This module provides intelligent node configuration repair capabilities.
|
||||
It can detect and fix common node configuration issues:
|
||||
- Invalid comparison operators in if-else nodes (e.g. '>=' -> '≥')
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeRepairResult:
|
||||
"""Result of node repair operation."""
|
||||
|
||||
nodes: list[dict[str, Any]]
|
||||
repairs_made: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def was_repaired(self) -> bool:
|
||||
"""Check if any repairs were made."""
|
||||
return len(self.repairs_made) > 0
|
||||
|
||||
|
||||
class NodeRepair:
|
||||
"""
|
||||
Intelligent node configuration repair.
|
||||
"""
|
||||
|
||||
OPERATOR_MAP = {
|
||||
">=": "≥",
|
||||
"<=": "≤",
|
||||
"!=": "≠",
|
||||
"==": "=",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def repair(cls, nodes: list[dict[str, Any]]) -> NodeRepairResult:
|
||||
"""
|
||||
Repair node configurations.
|
||||
|
||||
Args:
|
||||
nodes: List of node dictionaries
|
||||
|
||||
Returns:
|
||||
NodeRepairResult with repaired nodes and logs
|
||||
"""
|
||||
# Deep copy to avoid mutating original
|
||||
nodes = copy.deepcopy(nodes)
|
||||
repairs: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
logger.info("[NodeRepair] Starting repair: %d nodes", len(nodes))
|
||||
|
||||
for node in nodes:
|
||||
node_type = node.get("type")
|
||||
|
||||
if node_type == "if-else":
|
||||
cls._repair_if_else_operators(node, repairs)
|
||||
|
||||
# Add other node type repairs here as needed
|
||||
|
||||
if repairs:
|
||||
logger.info("[NodeRepair] Completed: %d repairs made", len(repairs))
|
||||
for r in repairs:
|
||||
logger.info("[NodeRepair] Repair: %s", r)
|
||||
|
||||
return NodeRepairResult(
|
||||
nodes=nodes,
|
||||
repairs_made=repairs,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _repair_if_else_operators(cls, node: dict[str, Any], repairs: list[str]):
|
||||
"""
|
||||
Normalize comparison operators in if-else nodes.
|
||||
"""
|
||||
node_id = node.get("id", "unknown")
|
||||
config = node.get("config", {})
|
||||
cases = config.get("cases", [])
|
||||
|
||||
for case in cases:
|
||||
conditions = case.get("conditions", [])
|
||||
for condition in conditions:
|
||||
op = condition.get("comparison_operator")
|
||||
if op in cls.OPERATOR_MAP:
|
||||
new_op = cls.OPERATOR_MAP[op]
|
||||
condition["comparison_operator"] = new_op
|
||||
repairs.append(f"Normalized operator '{op}' to '{new_op}' in node '{node_id}'")
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.generator.validation.context import ValidationContext
|
||||
from core.workflow.generator.validation.engine import ValidationEngine
|
||||
from core.workflow.generator.validation.rules import Severity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationHint:
|
||||
"""Legacy compatibility class for validation hints."""
|
||||
|
||||
node_id: str
|
||||
field: str
|
||||
message: str
|
||||
severity: str # 'error', 'warning'
|
||||
suggestion: str = None
|
||||
node_type: str = None # Added for test compatibility
|
||||
|
||||
# Alias for potential old code using 'type' instead of 'severity'
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.severity
|
||||
|
||||
@property
|
||||
def element_id(self) -> str:
|
||||
return self.node_id
|
||||
|
||||
|
||||
FriendlyHint = ValidationHint # Alias for backward compatibility
|
||||
|
||||
|
||||
class WorkflowValidator:
|
||||
"""
|
||||
Validates the generated workflow configuration (nodes and edges).
|
||||
Wraps the new ValidationEngine for backward compatibility.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def validate(
|
||||
cls,
|
||||
workflow_data: dict[str, Any],
|
||||
available_tools: list[dict[str, Any]],
|
||||
available_models: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[bool, list[ValidationHint]]:
|
||||
"""
|
||||
Validate workflow data and return validity status and hints.
|
||||
|
||||
Args:
|
||||
workflow_data: Dict containing 'nodes' and 'edges'
|
||||
available_tools: List of available tool configurations
|
||||
available_models: List of available models (added for Vibe compat)
|
||||
|
||||
Returns:
|
||||
Tuple(max_severity_is_not_error, list_of_hints)
|
||||
"""
|
||||
nodes = workflow_data.get("nodes", [])
|
||||
edges = workflow_data.get("edges", [])
|
||||
|
||||
# Create context
|
||||
context = ValidationContext(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
available_models=available_models or [],
|
||||
available_tools=available_tools or [],
|
||||
)
|
||||
|
||||
# Run validation engine
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(context)
|
||||
|
||||
# Convert engine errors to legacy hints
|
||||
hints: list[ValidationHint] = []
|
||||
|
||||
for error in result.all_errors:
|
||||
# Map severity
|
||||
severity = "error" if error.severity == Severity.ERROR else "warning"
|
||||
|
||||
# Map field from message or details if possible (heuristic)
|
||||
field_name = error.details.get("field", "unknown")
|
||||
|
||||
hints.append(
|
||||
ValidationHint(
|
||||
node_id=error.node_id,
|
||||
field=field_name,
|
||||
message=error.message,
|
||||
severity=severity,
|
||||
suggestion=error.fix_hint,
|
||||
node_type=error.node_type,
|
||||
)
|
||||
)
|
||||
|
||||
return result.is_valid, hints
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
"""
|
||||
Validation Rule Engine for Vibe Workflow Generation.
|
||||
|
||||
This module provides a declarative, schema-based validation system for
|
||||
generated workflow nodes. It classifies errors into fixable (LLM can auto-fix)
|
||||
and user-required (needs manual intervention) categories.
|
||||
|
||||
Usage:
|
||||
from core.workflow.generator.validation import ValidationEngine, ValidationContext
|
||||
|
||||
context = ValidationContext(
|
||||
available_models=[...],
|
||||
available_tools=[...],
|
||||
nodes=[...],
|
||||
edges=[...],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(context)
|
||||
|
||||
# Access classified errors
|
||||
fixable_errors = result.fixable_errors
|
||||
user_required_errors = result.user_required_errors
|
||||
"""
|
||||
|
||||
from core.workflow.generator.validation.context import ValidationContext
|
||||
from core.workflow.generator.validation.engine import ValidationEngine, ValidationResult
|
||||
from core.workflow.generator.validation.rules import (
|
||||
RuleCategory,
|
||||
Severity,
|
||||
ValidationError,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RuleCategory",
|
||||
"Severity",
|
||||
"ValidationContext",
|
||||
"ValidationEngine",
|
||||
"ValidationError",
|
||||
"ValidationResult",
|
||||
"ValidationRule",
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
"""
|
||||
Validation Context for the Rule Engine.
|
||||
|
||||
The ValidationContext holds all the data needed for validation:
|
||||
- Generated nodes and edges
|
||||
- Available models, tools, and datasets
|
||||
- Node output schemas for variable reference validation
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationContext:
|
||||
"""
|
||||
Context object containing all data needed for validation.
|
||||
|
||||
This is passed to each validation rule, providing access to:
|
||||
- The nodes being validated
|
||||
- Edge connections between nodes
|
||||
- Available external resources (models, tools)
|
||||
"""
|
||||
|
||||
# Generated workflow data
|
||||
nodes: list[dict[str, Any]] = field(default_factory=list)
|
||||
edges: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
# Available external resources
|
||||
available_models: list[dict[str, Any]] = field(default_factory=list)
|
||||
available_tools: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
# Cached lookups (populated lazily)
|
||||
_node_map: dict[str, dict[str, Any]] | None = field(default=None, repr=False)
|
||||
_model_set: set[tuple[str, str]] | None = field(default=None, repr=False)
|
||||
_tool_set: set[str] | None = field(default=None, repr=False)
|
||||
_configured_tool_set: set[str] | None = field(default=None, repr=False)
|
||||
|
||||
@property
|
||||
def node_map(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get a map of node_id -> node for quick lookup."""
|
||||
if self._node_map is None:
|
||||
self._node_map = {node.get("id", ""): node for node in self.nodes}
|
||||
return self._node_map
|
||||
|
||||
@property
|
||||
def model_set(self) -> set[tuple[str, str]]:
|
||||
"""Get a set of (provider, model_name) tuples for quick lookup."""
|
||||
if self._model_set is None:
|
||||
self._model_set = {
|
||||
(m.get("provider", ""), m.get("model", ""))
|
||||
for m in self.available_models
|
||||
}
|
||||
return self._model_set
|
||||
|
||||
@property
|
||||
def tool_set(self) -> set[str]:
|
||||
"""Get a set of all tool keys (both configured and unconfigured)."""
|
||||
if self._tool_set is None:
|
||||
self._tool_set = set()
|
||||
for tool in self.available_tools:
|
||||
provider = tool.get("provider_id") or tool.get("provider", "")
|
||||
tool_key = tool.get("tool_key") or tool.get("tool_name", "")
|
||||
if provider and tool_key:
|
||||
self._tool_set.add(f"{provider}/{tool_key}")
|
||||
if tool_key:
|
||||
self._tool_set.add(tool_key)
|
||||
return self._tool_set
|
||||
|
||||
@property
|
||||
def configured_tool_set(self) -> set[str]:
|
||||
"""Get a set of configured (authorized) tool keys."""
|
||||
if self._configured_tool_set is None:
|
||||
self._configured_tool_set = set()
|
||||
for tool in self.available_tools:
|
||||
if not tool.get("is_team_authorization", False):
|
||||
continue
|
||||
provider = tool.get("provider_id") or tool.get("provider", "")
|
||||
tool_key = tool.get("tool_key") or tool.get("tool_name", "")
|
||||
if provider and tool_key:
|
||||
self._configured_tool_set.add(f"{provider}/{tool_key}")
|
||||
if tool_key:
|
||||
self._configured_tool_set.add(tool_key)
|
||||
return self._configured_tool_set
|
||||
|
||||
def has_model(self, provider: str, model_name: str) -> bool:
|
||||
"""Check if a model is available."""
|
||||
return (provider, model_name) in self.model_set
|
||||
|
||||
def has_tool(self, tool_key: str) -> bool:
|
||||
"""Check if a tool exists (configured or not)."""
|
||||
return tool_key in self.tool_set
|
||||
|
||||
def is_tool_configured(self, tool_key: str) -> bool:
|
||||
"""Check if a tool is configured and ready to use."""
|
||||
return tool_key in self.configured_tool_set
|
||||
|
||||
def get_node(self, node_id: str) -> dict[str, Any] | None:
|
||||
"""Get a node by its ID."""
|
||||
return self.node_map.get(node_id)
|
||||
|
||||
def get_node_ids(self) -> set[str]:
|
||||
"""Get all node IDs in the workflow."""
|
||||
return set(self.node_map.keys())
|
||||
|
||||
def get_upstream_nodes(self, node_id: str) -> list[str]:
|
||||
"""Get IDs of nodes that connect to this node (upstream)."""
|
||||
return [
|
||||
edge.get("source", "")
|
||||
for edge in self.edges
|
||||
if edge.get("target") == node_id
|
||||
]
|
||||
|
||||
def get_downstream_nodes(self, node_id: str) -> list[str]:
|
||||
"""Get IDs of nodes that this node connects to (downstream)."""
|
||||
return [
|
||||
edge.get("target", "")
|
||||
for edge in self.edges
|
||||
if edge.get("source") == node_id
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
"""
|
||||
Validation Engine - Core validation logic.
|
||||
|
||||
The ValidationEngine orchestrates rule execution and aggregates results.
|
||||
It provides a clean interface for validating workflow nodes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.generator.validation.context import ValidationContext
|
||||
from core.workflow.generator.validation.rules import (
|
||||
RuleCategory,
|
||||
Severity,
|
||||
ValidationError,
|
||||
get_registry,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""
|
||||
Result of validation containing all errors classified by fixability.
|
||||
|
||||
Attributes:
|
||||
all_errors: All validation errors found
|
||||
fixable_errors: Errors that LLM can automatically fix
|
||||
user_required_errors: Errors that require user intervention
|
||||
warnings: Non-blocking warnings
|
||||
stats: Validation statistics
|
||||
"""
|
||||
|
||||
all_errors: list[ValidationError] = field(default_factory=list)
|
||||
fixable_errors: list[ValidationError] = field(default_factory=list)
|
||||
user_required_errors: list[ValidationError] = field(default_factory=list)
|
||||
warnings: list[ValidationError] = field(default_factory=list)
|
||||
stats: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def has_errors(self) -> bool:
|
||||
"""Check if there are any errors (excluding warnings)."""
|
||||
return len(self.fixable_errors) > 0 or len(self.user_required_errors) > 0
|
||||
|
||||
@property
|
||||
def has_fixable_errors(self) -> bool:
|
||||
"""Check if there are fixable errors."""
|
||||
return len(self.fixable_errors) > 0
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if validation passed (no errors, warnings are OK)."""
|
||||
return not self.has_errors
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"fixable": [e.to_dict() for e in self.fixable_errors],
|
||||
"user_required": [e.to_dict() for e in self.user_required_errors],
|
||||
"warnings": [e.to_dict() for e in self.warnings],
|
||||
"all_warnings": [e.message for e in self.all_errors],
|
||||
"stats": self.stats,
|
||||
}
|
||||
|
||||
def get_error_messages(self) -> list[str]:
|
||||
"""Get all error messages as strings."""
|
||||
return [e.message for e in self.all_errors]
|
||||
|
||||
def get_fixable_by_node(self) -> dict[str, list[ValidationError]]:
|
||||
"""Group fixable errors by node ID."""
|
||||
result: dict[str, list[ValidationError]] = {}
|
||||
for error in self.fixable_errors:
|
||||
if error.node_id not in result:
|
||||
result[error.node_id] = []
|
||||
result[error.node_id].append(error)
|
||||
return result
|
||||
|
||||
|
||||
class ValidationEngine:
|
||||
"""
|
||||
The main validation engine.
|
||||
|
||||
Usage:
|
||||
engine = ValidationEngine()
|
||||
context = ValidationContext(nodes=[...], available_models=[...])
|
||||
result = engine.validate(context)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._registry = get_registry()
|
||||
|
||||
def validate(self, context: ValidationContext) -> ValidationResult:
|
||||
"""
|
||||
Validate all nodes in the context.
|
||||
|
||||
Args:
|
||||
context: ValidationContext with nodes, edges, and available resources
|
||||
|
||||
Returns:
|
||||
ValidationResult with classified errors
|
||||
"""
|
||||
result = ValidationResult()
|
||||
stats = {
|
||||
"total_nodes": len(context.nodes),
|
||||
"total_rules_checked": 0,
|
||||
"total_errors": 0,
|
||||
"fixable_count": 0,
|
||||
"user_required_count": 0,
|
||||
"warning_count": 0,
|
||||
}
|
||||
|
||||
# Validate each node
|
||||
for node in context.nodes:
|
||||
node_type = node.get("type", "unknown")
|
||||
node_id = node.get("id", "unknown")
|
||||
|
||||
# Get applicable rules for this node type
|
||||
rules = self._registry.get_rules_for_node(node_type)
|
||||
|
||||
for rule in rules:
|
||||
stats["total_rules_checked"] += 1
|
||||
|
||||
try:
|
||||
errors = rule.check(node, context)
|
||||
for error in errors:
|
||||
result.all_errors.append(error)
|
||||
stats["total_errors"] += 1
|
||||
|
||||
# Classify by severity and fixability
|
||||
if error.severity == Severity.WARNING:
|
||||
result.warnings.append(error)
|
||||
stats["warning_count"] += 1
|
||||
elif error.is_fixable:
|
||||
result.fixable_errors.append(error)
|
||||
stats["fixable_count"] += 1
|
||||
else:
|
||||
result.user_required_errors.append(error)
|
||||
stats["user_required_count"] += 1
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Rule '%s' failed for node '%s'",
|
||||
rule.id,
|
||||
node_id,
|
||||
)
|
||||
# Don't let a rule failure break the entire validation
|
||||
continue
|
||||
|
||||
# Validate edges separately
|
||||
edge_errors = self._validate_edges(context)
|
||||
for error in edge_errors:
|
||||
result.all_errors.append(error)
|
||||
stats["total_errors"] += 1
|
||||
if error.is_fixable:
|
||||
result.fixable_errors.append(error)
|
||||
stats["fixable_count"] += 1
|
||||
else:
|
||||
result.user_required_errors.append(error)
|
||||
stats["user_required_count"] += 1
|
||||
|
||||
result.stats = stats
|
||||
|
||||
logger.debug(
|
||||
"[Validation] Completed: %d nodes, %d rules, %d errors (%d fixable, %d user-required)",
|
||||
stats["total_nodes"],
|
||||
stats["total_rules_checked"],
|
||||
stats["total_errors"],
|
||||
stats["fixable_count"],
|
||||
stats["user_required_count"],
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _validate_edges(self, context: ValidationContext) -> list[ValidationError]:
|
||||
"""Validate edge connections."""
|
||||
errors: list[ValidationError] = []
|
||||
valid_node_ids = context.get_node_ids()
|
||||
|
||||
for edge in context.edges:
|
||||
source = edge.get("source", "")
|
||||
target = edge.get("target", "")
|
||||
|
||||
if source and source not in valid_node_ids:
|
||||
errors.append(
|
||||
ValidationError(
|
||||
rule_id="edge.source.invalid",
|
||||
node_id=source,
|
||||
node_type="edge",
|
||||
category=RuleCategory.SEMANTIC,
|
||||
severity=Severity.ERROR,
|
||||
is_fixable=True,
|
||||
message=f"Edge source '{source}' does not exist",
|
||||
fix_hint="Update edge to reference existing node",
|
||||
)
|
||||
)
|
||||
|
||||
if target and target not in valid_node_ids:
|
||||
errors.append(
|
||||
ValidationError(
|
||||
rule_id="edge.target.invalid",
|
||||
node_id=target,
|
||||
node_type="edge",
|
||||
category=RuleCategory.SEMANTIC,
|
||||
severity=Severity.ERROR,
|
||||
is_fixable=True,
|
||||
message=f"Edge target '{target}' does not exist",
|
||||
fix_hint="Update edge to reference existing node",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
def validate_single_node(
|
||||
self,
|
||||
node: dict[str, Any],
|
||||
context: ValidationContext,
|
||||
) -> list[ValidationError]:
|
||||
"""
|
||||
Validate a single node.
|
||||
|
||||
Useful for incremental validation when a node is added/modified.
|
||||
"""
|
||||
node_type = node.get("type", "unknown")
|
||||
rules = self._registry.get_rules_for_node(node_type)
|
||||
|
||||
errors: list[ValidationError] = []
|
||||
for rule in rules:
|
||||
try:
|
||||
errors.extend(rule.check(node, context))
|
||||
except Exception:
|
||||
logger.exception("Rule '%s' failed", rule.id)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_nodes(
|
||||
nodes: list[dict[str, Any]],
|
||||
edges: list[dict[str, Any]] | None = None,
|
||||
available_models: list[dict[str, Any]] | None = None,
|
||||
available_tools: list[dict[str, Any]] | None = None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Convenience function to validate nodes without creating engine/context manually.
|
||||
|
||||
Args:
|
||||
nodes: List of workflow nodes to validate
|
||||
edges: Optional list of edges
|
||||
available_models: Optional list of available models
|
||||
available_tools: Optional list of available tools
|
||||
|
||||
Returns:
|
||||
ValidationResult with classified errors
|
||||
"""
|
||||
context = ValidationContext(
|
||||
nodes=nodes,
|
||||
edges=edges or [],
|
||||
available_models=available_models or [],
|
||||
available_tools=available_tools or [],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
return engine.validate(context)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,288 @@
|
|||
"""
|
||||
Unit tests for the Mermaid Generator.
|
||||
|
||||
Tests cover:
|
||||
- Basic workflow rendering
|
||||
- Reserved word handling ('end' → 'end_node')
|
||||
- Question classifier multi-branch edges
|
||||
- If-else branch labels
|
||||
- Edge validation and skipping
|
||||
- Tool node formatting
|
||||
"""
|
||||
|
||||
|
||||
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
|
||||
|
||||
|
||||
class TestBasicWorkflow:
|
||||
"""Tests for basic workflow Mermaid generation."""
|
||||
|
||||
def test_simple_start_end_workflow(self):
|
||||
"""Test simple Start → End workflow."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "title": "Start"},
|
||||
{"id": "end", "type": "end", "title": "End"},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "end"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "flowchart TD" in result
|
||||
assert 'start["type=start|title=Start"]' in result
|
||||
assert 'end_node["type=end|title=End"]' in result
|
||||
assert "start --> end_node" in result
|
||||
|
||||
def test_start_llm_end_workflow(self):
|
||||
"""Test Start → LLM → End workflow."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "title": "Start"},
|
||||
{"id": "llm", "type": "llm", "title": "Generate"},
|
||||
{"id": "end", "type": "end", "title": "End"},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "llm"},
|
||||
{"source": "llm", "target": "end"},
|
||||
],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert 'llm["type=llm|title=Generate"]' in result
|
||||
assert "start --> llm" in result
|
||||
assert "llm --> end_node" in result
|
||||
|
||||
def test_empty_workflow(self):
|
||||
"""Test empty workflow returns minimal output."""
|
||||
workflow_data = {"nodes": [], "edges": []}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert result == "flowchart TD"
|
||||
|
||||
def test_missing_keys_handled(self):
|
||||
"""Test workflow with missing keys doesn't crash."""
|
||||
workflow_data = {}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "flowchart TD" in result
|
||||
|
||||
|
||||
class TestReservedWords:
|
||||
"""Tests for reserved word handling in node IDs."""
|
||||
|
||||
def test_end_node_id_is_replaced(self):
|
||||
"""Test 'end' node ID is replaced with 'end_node'."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "end", "type": "end", "title": "End"}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Should use end_node instead of end
|
||||
assert "end_node[" in result
|
||||
assert '"type=end|title=End"' in result
|
||||
|
||||
def test_subgraph_node_id_is_replaced(self):
|
||||
"""Test 'subgraph' node ID is replaced with 'subgraph_node'."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "subgraph", "type": "code", "title": "Process"}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "subgraph_node[" in result
|
||||
|
||||
def test_edge_uses_safe_ids(self):
|
||||
"""Test edges correctly reference safe IDs after replacement."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "title": "Start"},
|
||||
{"id": "end", "type": "end", "title": "End"},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "end"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Edge should use end_node, not end
|
||||
assert "start --> end_node" in result
|
||||
assert "start --> end\n" not in result
|
||||
|
||||
|
||||
class TestBranchEdges:
|
||||
"""Tests for branching node edge labels."""
|
||||
|
||||
def test_question_classifier_source_handles(self):
|
||||
"""Test question-classifier edges with sourceHandle labels."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "classifier", "type": "question-classifier", "title": "Classify"},
|
||||
{"id": "refund", "type": "llm", "title": "Handle Refund"},
|
||||
{"id": "inquiry", "type": "llm", "title": "Handle Inquiry"},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "classifier", "target": "refund", "sourceHandle": "refund"},
|
||||
{"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"},
|
||||
],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "classifier -->|refund| refund" in result
|
||||
assert "classifier -->|inquiry| inquiry" in result
|
||||
|
||||
def test_if_else_true_false_handles(self):
|
||||
"""Test if-else edges with true/false labels."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "ifelse", "type": "if-else", "title": "Check"},
|
||||
{"id": "yes_branch", "type": "llm", "title": "Yes"},
|
||||
{"id": "no_branch", "type": "llm", "title": "No"},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"},
|
||||
{"source": "ifelse", "target": "no_branch", "sourceHandle": "false"},
|
||||
],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "ifelse -->|true| yes_branch" in result
|
||||
assert "ifelse -->|false| no_branch" in result
|
||||
|
||||
def test_source_handle_source_is_ignored(self):
|
||||
"""Test sourceHandle='source' doesn't add label."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "llm1", "type": "llm", "title": "LLM 1"},
|
||||
{"id": "llm2", "type": "llm", "title": "LLM 2"},
|
||||
],
|
||||
"edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Should be plain arrow without label
|
||||
assert "llm1 --> llm2" in result
|
||||
assert "llm1 -->|source|" not in result
|
||||
|
||||
|
||||
class TestEdgeValidation:
|
||||
"""Tests for edge validation and error handling."""
|
||||
|
||||
def test_edge_with_missing_source_is_skipped(self):
|
||||
"""Test edge with non-existent source node is skipped."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "end", "type": "end", "title": "End"}],
|
||||
"edges": [{"source": "nonexistent", "target": "end"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Should not contain the invalid edge
|
||||
assert "nonexistent" not in result
|
||||
assert "-->" not in result or "nonexistent" not in result
|
||||
|
||||
def test_edge_with_missing_target_is_skipped(self):
|
||||
"""Test edge with non-existent target node is skipped."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
|
||||
"edges": [{"source": "start", "target": "nonexistent"}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Edge should be skipped
|
||||
assert "start --> nonexistent" not in result
|
||||
|
||||
def test_edge_without_source_or_target_is_skipped(self):
|
||||
"""Test edge missing source or target is skipped."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
|
||||
"edges": [{"source": "start"}, {"target": "start"}, {}],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# No edges should be rendered
|
||||
assert result.count("-->") == 0
|
||||
|
||||
|
||||
class TestToolNodes:
|
||||
"""Tests for tool node formatting."""
|
||||
|
||||
def test_tool_node_includes_tool_key(self):
|
||||
"""Test tool node includes tool_key in label."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "search",
|
||||
"type": "tool",
|
||||
"title": "Search",
|
||||
"config": {"tool_key": "google/search"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert 'search["type=tool|title=Search|tool=google/search"]' in result
|
||||
|
||||
def test_tool_node_with_tool_name_fallback(self):
|
||||
"""Test tool node uses tool_name as fallback."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "tool1",
|
||||
"type": "tool",
|
||||
"title": "My Tool",
|
||||
"config": {"tool_name": "my_tool"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "tool=my_tool" in result
|
||||
|
||||
def test_tool_node_missing_tool_key_shows_unknown(self):
|
||||
"""Test tool node without tool_key shows 'unknown'."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "tool=unknown" in result
|
||||
|
||||
|
||||
class TestNodeFormatting:
|
||||
"""Tests for node label formatting."""
|
||||
|
||||
def test_quotes_in_title_are_escaped(self):
|
||||
"""Test double quotes in title are replaced with single quotes."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Double quotes should be replaced
|
||||
assert "Say 'Hello'" in result
|
||||
assert 'Say "Hello"' not in result
|
||||
|
||||
def test_node_without_id_is_skipped(self):
|
||||
"""Test node without id is skipped."""
|
||||
workflow_data = {
|
||||
"nodes": [{"type": "llm", "title": "No ID"}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
# Should only have flowchart header
|
||||
lines = [line for line in result.split("\n") if line.strip()]
|
||||
assert len(lines) == 1
|
||||
|
||||
def test_node_default_values(self):
|
||||
"""Test node with missing type/title uses defaults."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "node1"}],
|
||||
"edges": [],
|
||||
}
|
||||
result = generate_mermaid(workflow_data)
|
||||
|
||||
assert "type=unknown" in result
|
||||
assert "title=Untitled" in result
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
from core.workflow.generator.utils.node_repair import NodeRepair
|
||||
|
||||
|
||||
class TestNodeRepair:
|
||||
"""Tests for NodeRepair utility."""
|
||||
|
||||
def test_repair_if_else_valid_operators(self):
|
||||
"""Test that valid operators remain unchanged."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "if-else",
|
||||
"config": {
|
||||
"cases": [
|
||||
{
|
||||
"conditions": [
|
||||
{"comparison_operator": "≥", "value": "1"},
|
||||
{"comparison_operator": "=", "value": "2"},
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
result = NodeRepair.repair(nodes)
|
||||
assert result.was_repaired is False
|
||||
assert result.nodes == nodes
|
||||
|
||||
def test_repair_if_else_invalid_operators(self):
|
||||
"""Test that invalid operators are normalized."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "if-else",
|
||||
"config": {
|
||||
"cases": [
|
||||
{
|
||||
"conditions": [
|
||||
{"comparison_operator": ">=", "value": "1"},
|
||||
{"comparison_operator": "<=", "value": "2"},
|
||||
{"comparison_operator": "!=", "value": "3"},
|
||||
{"comparison_operator": "==", "value": "4"},
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
result = NodeRepair.repair(nodes)
|
||||
assert result.was_repaired is True
|
||||
assert len(result.repairs_made) == 4
|
||||
|
||||
conditions = result.nodes[0]["config"]["cases"][0]["conditions"]
|
||||
assert conditions[0]["comparison_operator"] == "≥"
|
||||
assert conditions[1]["comparison_operator"] == "≤"
|
||||
assert conditions[2]["comparison_operator"] == "≠"
|
||||
assert conditions[3]["comparison_operator"] == "="
|
||||
|
||||
def test_repair_ignores_other_nodes(self):
|
||||
"""Test that other node types are ignored."""
|
||||
nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}]
|
||||
result = NodeRepair.repair(nodes)
|
||||
assert result.was_repaired is False
|
||||
assert result.nodes[0]["config"]["some_field"] == ">="
|
||||
|
||||
def test_repair_handles_missing_config(self):
|
||||
"""Test robustness against missing fields."""
|
||||
nodes = [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "if-else",
|
||||
# Missing config
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"type": "if-else",
|
||||
"config": {}, # Missing cases
|
||||
},
|
||||
]
|
||||
result = NodeRepair.repair(nodes)
|
||||
assert result.was_repaired is False
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
"""
|
||||
Unit tests for the Planner Prompts.
|
||||
|
||||
Tests cover:
|
||||
- Tool formatting for planner context
|
||||
- Edge cases with missing fields
|
||||
- Empty tool lists
|
||||
"""
|
||||
|
||||
|
||||
from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner
|
||||
|
||||
|
||||
class TestFormatToolsForPlanner:
|
||||
"""Tests for format_tools_for_planner function."""
|
||||
|
||||
def test_empty_tools_returns_default_message(self):
|
||||
"""Test empty tools list returns default message."""
|
||||
result = format_tools_for_planner([])
|
||||
|
||||
assert result == "No external tools available."
|
||||
|
||||
def test_none_tools_returns_default_message(self):
|
||||
"""Test None tools list returns default message."""
|
||||
result = format_tools_for_planner(None)
|
||||
|
||||
assert result == "No external tools available."
|
||||
|
||||
def test_single_tool_formatting(self):
|
||||
"""Test single tool is formatted correctly."""
|
||||
tools = [
|
||||
{
|
||||
"provider_id": "google",
|
||||
"tool_key": "search",
|
||||
"tool_label": "Google Search",
|
||||
"tool_description": "Search the web using Google",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
assert "[google/search]" in result
|
||||
assert "Google Search" in result
|
||||
assert "Search the web using Google" in result
|
||||
|
||||
def test_multiple_tools_formatting(self):
|
||||
"""Test multiple tools are formatted correctly."""
|
||||
tools = [
|
||||
{
|
||||
"provider_id": "google",
|
||||
"tool_key": "search",
|
||||
"tool_label": "Search",
|
||||
"tool_description": "Web search",
|
||||
},
|
||||
{
|
||||
"provider_id": "slack",
|
||||
"tool_key": "send_message",
|
||||
"tool_label": "Send Message",
|
||||
"tool_description": "Send a Slack message",
|
||||
},
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
assert "[google/search]" in result
|
||||
assert "[slack/send_message]" in result
|
||||
|
||||
def test_tool_without_provider_uses_key_only(self):
|
||||
"""Test tool without provider_id uses tool_key only."""
|
||||
tools = [
|
||||
{
|
||||
"tool_key": "my_tool",
|
||||
"tool_label": "My Tool",
|
||||
"tool_description": "A custom tool",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
# Should format as [my_tool] without provider prefix
|
||||
assert "[my_tool]" in result
|
||||
assert "My Tool" in result
|
||||
|
||||
def test_tool_with_tool_name_fallback(self):
|
||||
"""Test tool uses tool_name when tool_key is missing."""
|
||||
tools = [
|
||||
{
|
||||
"tool_name": "fallback_tool",
|
||||
"description": "Fallback description",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
assert "fallback_tool" in result
|
||||
assert "Fallback description" in result
|
||||
|
||||
def test_tool_with_missing_description(self):
|
||||
"""Test tool with missing description doesn't crash."""
|
||||
tools = [
|
||||
{
|
||||
"provider_id": "test",
|
||||
"tool_key": "tool1",
|
||||
"tool_label": "Tool 1",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
assert "[test/tool1]" in result
|
||||
assert "Tool 1" in result
|
||||
|
||||
def test_tool_with_all_missing_fields(self):
|
||||
"""Test tool with all fields missing uses defaults."""
|
||||
tools = [{}]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
# Should not crash, may produce minimal output
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_tool_uses_provider_fallback(self):
|
||||
"""Test tool uses 'provider' when 'provider_id' is missing."""
|
||||
tools = [
|
||||
{
|
||||
"provider": "openai",
|
||||
"tool_key": "dalle",
|
||||
"tool_label": "DALL-E",
|
||||
"tool_description": "Generate images",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
assert "[openai/dalle]" in result
|
||||
|
||||
def test_tool_label_fallback_to_key(self):
|
||||
"""Test tool_label falls back to tool_key when missing."""
|
||||
tools = [
|
||||
{
|
||||
"provider_id": "test",
|
||||
"tool_key": "my_key",
|
||||
"tool_description": "Description here",
|
||||
}
|
||||
]
|
||||
result = format_tools_for_planner(tools)
|
||||
|
||||
# Label should fallback to key
|
||||
assert "my_key" in result
|
||||
assert "Description here" in result
|
||||
|
||||
|
||||
class TestPlannerPromptConstants:
|
||||
"""Tests for planner prompt constant availability."""
|
||||
|
||||
def test_planner_system_prompt_exists(self):
|
||||
"""Test PLANNER_SYSTEM_PROMPT is defined."""
|
||||
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
|
||||
|
||||
assert PLANNER_SYSTEM_PROMPT is not None
|
||||
assert len(PLANNER_SYSTEM_PROMPT) > 0
|
||||
assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT
|
||||
|
||||
def test_planner_user_prompt_exists(self):
|
||||
"""Test PLANNER_USER_PROMPT is defined."""
|
||||
from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT
|
||||
|
||||
assert PLANNER_USER_PROMPT is not None
|
||||
assert "{instruction}" in PLANNER_USER_PROMPT
|
||||
|
||||
def test_planner_system_prompt_has_required_sections(self):
|
||||
"""Test PLANNER_SYSTEM_PROMPT has required XML sections."""
|
||||
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
|
||||
|
||||
assert "<role>" in PLANNER_SYSTEM_PROMPT
|
||||
assert "<task>" in PLANNER_SYSTEM_PROMPT
|
||||
assert "<available_tools>" in PLANNER_SYSTEM_PROMPT
|
||||
assert "<response_format>" in PLANNER_SYSTEM_PROMPT
|
||||
|
|
@ -0,0 +1,536 @@
|
|||
"""
|
||||
Unit tests for the Validation Rule Engine.
|
||||
|
||||
Tests cover:
|
||||
- Structure rules (required fields, types, formats)
|
||||
- Semantic rules (variable references, edge connections)
|
||||
- Reference rules (model exists, tool configured, dataset valid)
|
||||
- ValidationEngine integration
|
||||
"""
|
||||
|
||||
|
||||
from core.workflow.generator.validation import (
|
||||
ValidationContext,
|
||||
ValidationEngine,
|
||||
)
|
||||
from core.workflow.generator.validation.rules import (
|
||||
extract_variable_refs,
|
||||
is_placeholder,
|
||||
)
|
||||
|
||||
|
||||
class TestPlaceholderDetection:
|
||||
"""Tests for placeholder detection utility."""
|
||||
|
||||
def test_detects_please_select(self):
|
||||
assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True
|
||||
|
||||
def test_detects_your_prefix(self):
|
||||
assert is_placeholder("YOUR_API_KEY") is True
|
||||
|
||||
def test_detects_todo(self):
|
||||
assert is_placeholder("TODO: fill this in") is True
|
||||
|
||||
def test_detects_placeholder(self):
|
||||
assert is_placeholder("PLACEHOLDER_VALUE") is True
|
||||
|
||||
def test_detects_example_prefix(self):
|
||||
assert is_placeholder("EXAMPLE_URL") is True
|
||||
|
||||
def test_detects_replace_prefix(self):
|
||||
assert is_placeholder("REPLACE_WITH_ACTUAL") is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert is_placeholder("please_select") is True
|
||||
assert is_placeholder("Please_Select") is True
|
||||
|
||||
def test_valid_values_not_detected(self):
|
||||
assert is_placeholder("https://api.example.com") is False
|
||||
assert is_placeholder("gpt-4") is False
|
||||
assert is_placeholder("my_variable") is False
|
||||
|
||||
def test_non_string_returns_false(self):
|
||||
assert is_placeholder(123) is False
|
||||
assert is_placeholder(None) is False
|
||||
assert is_placeholder(["list"]) is False
|
||||
|
||||
|
||||
class TestVariableRefExtraction:
|
||||
"""Tests for variable reference extraction."""
|
||||
|
||||
def test_extracts_simple_ref(self):
|
||||
refs = extract_variable_refs("Hello {{#start.query#}}")
|
||||
assert refs == [("start", "query")]
|
||||
|
||||
def test_extracts_multiple_refs(self):
|
||||
refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}")
|
||||
assert refs == [("node1", "output"), ("node2", "text")]
|
||||
|
||||
def test_extracts_nested_field(self):
|
||||
refs = extract_variable_refs("{{#http_request.body#}}")
|
||||
assert refs == [("http_request", "body")]
|
||||
|
||||
def test_no_refs_returns_empty(self):
|
||||
refs = extract_variable_refs("No references here")
|
||||
assert refs == []
|
||||
|
||||
def test_handles_malformed_refs(self):
|
||||
refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}")
|
||||
assert refs == []
|
||||
|
||||
|
||||
class TestValidationContext:
|
||||
"""Tests for ValidationContext."""
|
||||
|
||||
def test_node_map_lookup(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start"},
|
||||
{"id": "llm_1", "type": "llm"},
|
||||
]
|
||||
)
|
||||
assert ctx.get_node("start") == {"id": "start", "type": "start"}
|
||||
assert ctx.get_node("nonexistent") is None
|
||||
|
||||
def test_model_set(self):
|
||||
ctx = ValidationContext(
|
||||
available_models=[
|
||||
{"provider": "openai", "model": "gpt-4"},
|
||||
{"provider": "anthropic", "model": "claude-3"},
|
||||
]
|
||||
)
|
||||
assert ctx.has_model("openai", "gpt-4") is True
|
||||
assert ctx.has_model("anthropic", "claude-3") is True
|
||||
assert ctx.has_model("openai", "gpt-3.5") is False
|
||||
|
||||
def test_tool_set(self):
|
||||
ctx = ValidationContext(
|
||||
available_tools=[
|
||||
{"provider_id": "google", "tool_key": "search", "is_team_authorization": True},
|
||||
{"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False},
|
||||
]
|
||||
)
|
||||
assert ctx.has_tool("google/search") is True
|
||||
assert ctx.has_tool("search") is True
|
||||
assert ctx.is_tool_configured("google/search") is True
|
||||
assert ctx.is_tool_configured("slack/send_message") is False
|
||||
|
||||
def test_upstream_downstream_nodes(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start"},
|
||||
{"id": "llm", "type": "llm"},
|
||||
{"id": "end", "type": "end"},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "llm"},
|
||||
{"source": "llm", "target": "end"},
|
||||
],
|
||||
)
|
||||
assert ctx.get_upstream_nodes("llm") == ["start"]
|
||||
assert ctx.get_downstream_nodes("llm") == ["end"]
|
||||
|
||||
|
||||
class TestStructureRules:
|
||||
"""Tests for structure validation rules."""
|
||||
|
||||
def test_llm_missing_prompt_template(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[{"id": "llm_1", "type": "llm", "config": {}}]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_errors
|
||||
errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_llm_with_prompt_template_passes(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "You are helpful"},
|
||||
{"role": "user", "text": "Hello"},
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# No prompt_template errors
|
||||
errors = [e for e in result.all_errors if "prompt_template" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_http_request_missing_url(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[{"id": "http_1", "type": "http-request", "config": {}}]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "http.url" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_http_request_placeholder_url(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "http_1",
|
||||
"type": "http-request",
|
||||
"config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"},
|
||||
}
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "placeholder" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_code_node_missing_fields(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[{"id": "code_1", "type": "code", "config": {}}]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
error_rules = {e.rule_id for e in result.all_errors}
|
||||
assert "code.code.required" in error_rules
|
||||
assert "code.language.required" in error_rules
|
||||
|
||||
def test_knowledge_retrieval_missing_dataset(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False # User must configure
|
||||
|
||||
|
||||
class TestSemanticRules:
|
||||
"""Tests for semantic validation rules."""
|
||||
|
||||
def test_valid_variable_reference(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [
|
||||
{"role": "user", "text": "Process: {{#start.query#}}"}
|
||||
]
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# No variable reference errors
|
||||
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_invalid_variable_reference(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [
|
||||
{"role": "user", "text": "Process: {{#nonexistent.field#}}"}
|
||||
]
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert "nonexistent" in errors[0].message
|
||||
|
||||
def test_edge_validation(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "end"},
|
||||
{"source": "nonexistent", "target": "end"},
|
||||
],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "edge" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert "nonexistent" in errors[0].message
|
||||
|
||||
|
||||
class TestReferenceRules:
|
||||
"""Tests for reference validation rules (models, tools)."""
|
||||
|
||||
def test_llm_missing_model_with_available(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.required"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_llm_missing_model_no_available(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[], # No models available
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.no_available"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False
|
||||
|
||||
def test_llm_with_valid_model(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [{"role": "user", "text": "Hi"}],
|
||||
"model": {"provider": "openai", "name": "gpt-4"},
|
||||
},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "model" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_llm_with_invalid_model(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [{"role": "user", "text": "Hi"}],
|
||||
"model": {"provider": "openai", "name": "gpt-99"},
|
||||
},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.not_found"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_tool_node_not_found(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "tool_1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "nonexistent/tool"},
|
||||
}
|
||||
],
|
||||
available_tools=[],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_tool_node_not_configured(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "tool_1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "google/search"},
|
||||
}
|
||||
],
|
||||
available_tools=[
|
||||
{"provider_id": "google", "tool_key": "search", "is_team_authorization": False}
|
||||
],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False
|
||||
|
||||
|
||||
class TestValidationResult:
|
||||
"""Tests for ValidationResult classification."""
|
||||
|
||||
def test_has_errors(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[{"id": "llm_1", "type": "llm", "config": {}}]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_errors is True
|
||||
assert result.is_valid is False
|
||||
|
||||
def test_has_fixable_errors(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_fixable_errors is True
|
||||
assert len(result.fixable_errors) > 0
|
||||
|
||||
def test_get_fixable_by_node(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "llm_1", "type": "llm", "config": {}},
|
||||
{"id": "http_1", "type": "http-request", "config": {}},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
by_node = result.get_fixable_by_node()
|
||||
assert "llm_1" in by_node
|
||||
assert "http_1" in by_node
|
||||
|
||||
def test_to_dict(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[{"id": "llm_1", "type": "llm", "config": {}}]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
d = result.to_dict()
|
||||
assert "fixable" in d
|
||||
assert "user_required" in d
|
||||
assert "warnings" in d
|
||||
assert "all_warnings" in d
|
||||
assert "stats" in d
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for the full validation pipeline."""
|
||||
|
||||
def test_complete_workflow_validation(self):
|
||||
"""Test validation of a complete workflow."""
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"config": {"variables": [{"variable": "query", "type": "text-input"}]},
|
||||
},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model": {"provider": "openai", "name": "gpt-4"},
|
||||
"prompt_template": [{"role": "user", "text": "{{#start.query#}}"}],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "end",
|
||||
"type": "end",
|
||||
"config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]},
|
||||
},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "llm_1"},
|
||||
{"source": "llm_1", "target": "end"},
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# Should have no errors
|
||||
assert result.is_valid is True
|
||||
assert len(result.fixable_errors) == 0
|
||||
assert len(result.user_required_errors) == 0
|
||||
|
||||
def test_workflow_with_multiple_errors(self):
|
||||
"""Test workflow with multiple types of errors."""
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {}, # Missing prompt_template and model
|
||||
},
|
||||
{
|
||||
"id": "kb_1",
|
||||
"type": "knowledge-retrieval",
|
||||
"config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# Should have multiple errors
|
||||
assert result.has_errors is True
|
||||
assert len(result.fixable_errors) >= 2 # model, prompt_template
|
||||
assert len(result.user_required_errors) >= 1 # dataset placeholder
|
||||
|
||||
# Check stats
|
||||
assert result.stats["total_nodes"] == 4
|
||||
assert result.stats["total_errors"] >= 3
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,435 @@
|
|||
"""
|
||||
Unit tests for the Vibe Workflow Validator.
|
||||
|
||||
Tests cover:
|
||||
- Basic validation function
|
||||
- User-friendly validation hints
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
|
||||
from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator
|
||||
|
||||
|
||||
class TestValidationHint:
|
||||
"""Tests for ValidationHint dataclass."""
|
||||
|
||||
def test_hint_creation(self):
|
||||
"""Test creating a validation hint."""
|
||||
hint = ValidationHint(
|
||||
node_id="llm_1",
|
||||
field="model",
|
||||
message="Model is not configured",
|
||||
severity="error",
|
||||
)
|
||||
assert hint.node_id == "llm_1"
|
||||
assert hint.field == "model"
|
||||
assert hint.message == "Model is not configured"
|
||||
assert hint.severity == "error"
|
||||
|
||||
def test_hint_with_suggestion(self):
|
||||
"""Test hint with suggestion."""
|
||||
hint = ValidationHint(
|
||||
node_id="http_1",
|
||||
field="url",
|
||||
message="URL is required",
|
||||
severity="error",
|
||||
suggestion="Add a valid URL like https://api.example.com",
|
||||
)
|
||||
assert hint.suggestion is not None
|
||||
|
||||
|
||||
class TestWorkflowValidatorBasic:
|
||||
"""Tests for basic validation scenarios."""
|
||||
|
||||
def test_empty_workflow_is_valid(self):
|
||||
"""Test empty workflow passes validation."""
|
||||
workflow_data = {"nodes": [], "edges": []}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
# Empty but valid structure
|
||||
assert is_valid is True
|
||||
assert len(hints) == 0
|
||||
|
||||
def test_minimal_valid_workflow(self):
|
||||
"""Test minimal Start → End workflow."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "end"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
assert is_valid is True
|
||||
|
||||
def test_complete_workflow_with_llm(self):
|
||||
"""Test complete workflow with LLM node."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {"variables": []}},
|
||||
{
|
||||
"id": "llm",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model": {"provider": "openai", "name": "gpt-4"},
|
||||
"prompt_template": [{"role": "user", "text": "Hello"}],
|
||||
},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {"outputs": []}},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "llm"},
|
||||
{"source": "llm", "target": "end"},
|
||||
],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
# Should pass with no critical errors
|
||||
errors = [h for h in hints if h.severity == "error"]
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
class TestVariableReferenceValidation:
|
||||
"""Tests for variable reference validation."""
|
||||
|
||||
def test_valid_variable_reference(self):
|
||||
"""Test valid variable reference passes."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Query: {{#start.query#}}"}]},
|
||||
},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "llm"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
ref_errors = [h for h in hints if "reference" in h.message.lower()]
|
||||
assert len(ref_errors) == 0
|
||||
|
||||
def test_invalid_variable_reference(self):
|
||||
"""Test invalid variable reference generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "{{#nonexistent.field#}}"}]},
|
||||
},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "llm"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
# Should have a hint about invalid reference
|
||||
ref_hints = [h for h in hints if "nonexistent" in h.message or "reference" in h.message.lower()]
|
||||
assert len(ref_hints) >= 1
|
||||
|
||||
|
||||
class TestEdgeValidation:
|
||||
"""Tests for edge validation."""
|
||||
|
||||
def test_edge_with_invalid_source(self):
|
||||
"""Test edge with non-existent source generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "end", "type": "end", "config": {}}],
|
||||
"edges": [{"source": "nonexistent", "target": "end"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
# Should have hint about invalid edge
|
||||
edge_hints = [h for h in hints if "edge" in h.message.lower() or "source" in h.message.lower()]
|
||||
assert len(edge_hints) >= 1
|
||||
|
||||
def test_edge_with_invalid_target(self):
|
||||
"""Test edge with non-existent target generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "start", "type": "start", "config": {}}],
|
||||
"edges": [{"source": "start", "target": "nonexistent"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
edge_hints = [h for h in hints if "edge" in h.message.lower() or "target" in h.message.lower()]
|
||||
assert len(edge_hints) >= 1
|
||||
|
||||
|
||||
class TestToolValidation:
|
||||
"""Tests for tool node validation."""
|
||||
|
||||
def test_tool_node_found_in_available(self):
|
||||
"""Test tool node that exists in available tools."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "tool1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "google/search"},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "tool1"}, {"source": "tool1", "target": "end"}],
|
||||
}
|
||||
available_tools = [{"provider_id": "google", "tool_key": "search", "is_team_authorization": True}]
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools)
|
||||
|
||||
tool_errors = [h for h in hints if h.severity == "error" and "tool" in h.message.lower()]
|
||||
assert len(tool_errors) == 0
|
||||
|
||||
def test_tool_node_not_found(self):
|
||||
"""Test tool node not in available tools generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "tool1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "unknown/tool"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
available_tools = []
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools)
|
||||
|
||||
tool_hints = [h for h in hints if "tool" in h.message.lower()]
|
||||
assert len(tool_hints) >= 1
|
||||
|
||||
|
||||
class TestQuestionClassifierValidation:
|
||||
"""Tests for question-classifier node validation."""
|
||||
|
||||
def test_question_classifier_with_classes(self):
|
||||
"""Test question-classifier with valid classes."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "classifier",
|
||||
"type": "question-classifier",
|
||||
"config": {
|
||||
"classes": [
|
||||
{"id": "class1", "name": "Class 1"},
|
||||
{"id": "class2", "name": "Class 2"},
|
||||
],
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
|
||||
},
|
||||
},
|
||||
{"id": "h1", "type": "llm", "config": {}},
|
||||
{"id": "h2", "type": "llm", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "classifier"},
|
||||
{"source": "classifier", "sourceHandle": "class1", "target": "h1"},
|
||||
{"source": "classifier", "sourceHandle": "class2", "target": "h2"},
|
||||
{"source": "h1", "target": "end"},
|
||||
{"source": "h2", "target": "end"},
|
||||
],
|
||||
}
|
||||
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
|
||||
|
||||
class_errors = [h for h in hints if "class" in h.message.lower() and h.severity == "error"]
|
||||
assert len(class_errors) == 0
|
||||
|
||||
def test_question_classifier_missing_classes(self):
|
||||
"""Test question-classifier without classes generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "classifier",
|
||||
"type": "question-classifier",
|
||||
"config": {"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
|
||||
|
||||
# Should have hint about missing classes
|
||||
class_hints = [h for h in hints if "class" in h.message.lower()]
|
||||
assert len(class_hints) >= 1
|
||||
|
||||
|
||||
class TestHttpRequestValidation:
|
||||
"""Tests for HTTP request node validation."""
|
||||
|
||||
def test_http_request_with_url(self):
|
||||
"""Test HTTP request with valid URL."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "http",
|
||||
"type": "http-request",
|
||||
"config": {"url": "https://api.example.com", "method": "GET"},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "http"}, {"source": "http", "target": "end"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
url_errors = [h for h in hints if "url" in h.message.lower() and h.severity == "error"]
|
||||
assert len(url_errors) == 0
|
||||
|
||||
def test_http_request_missing_url(self):
|
||||
"""Test HTTP request without URL generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "http",
|
||||
"type": "http-request",
|
||||
"config": {"method": "GET"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
url_hints = [h for h in hints if "url" in h.message.lower()]
|
||||
assert len(url_hints) >= 1
|
||||
|
||||
|
||||
class TestParameterExtractorValidation:
|
||||
"""Tests for parameter-extractor node validation."""
|
||||
|
||||
def test_parameter_extractor_valid_params(self):
|
||||
"""Test parameter-extractor with valid parameters."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "extractor",
|
||||
"type": "parameter-extractor",
|
||||
"config": {
|
||||
"instruction": "Extract info",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "Name",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
|
||||
},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "extractor"}, {"source": "extractor", "target": "end"}],
|
||||
}
|
||||
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
|
||||
|
||||
errors = [h for h in hints if h.severity == "error"]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_parameter_extractor_missing_required_field(self):
|
||||
"""Test parameter-extractor missing 'required' field in parameter item."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "extractor",
|
||||
"type": "parameter-extractor",
|
||||
"config": {
|
||||
"instruction": "Extract info",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "Name",
|
||||
# Missing 'required'
|
||||
}
|
||||
],
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
|
||||
|
||||
errors = [h for h in hints if "required" in h.message and h.severity == "error"]
|
||||
assert len(errors) >= 1
|
||||
assert "parameter-extractor" in errors[0].node_type
|
||||
|
||||
|
||||
class TestIfElseValidation:
|
||||
"""Tests for if-else node validation."""
|
||||
|
||||
def test_if_else_valid_operators(self):
|
||||
"""Test if-else with valid operators."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "ifelse",
|
||||
"type": "if-else",
|
||||
"config": {
|
||||
"cases": [{"case_id": "c1", "conditions": [{"comparison_operator": "≥", "value": "1"}]}]
|
||||
},
|
||||
},
|
||||
{"id": "t", "type": "llm", "config": {}},
|
||||
{"id": "f", "type": "llm", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "ifelse"},
|
||||
{"source": "ifelse", "sourceHandle": "true", "target": "t"},
|
||||
{"source": "ifelse", "sourceHandle": "false", "target": "f"},
|
||||
{"source": "t", "target": "end"},
|
||||
{"source": "f", "target": "end"},
|
||||
],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
errors = [h for h in hints if h.severity == "error"]
|
||||
# Filter out LLM model errors if any (available tools/models check might trigger)
|
||||
# (actually available_models empty list might trigger model error?
|
||||
# No, model config validation skips if model field not present? No, LLM has model config.
|
||||
# But logic skips check if key missing? Let's check logic.
|
||||
# _check_model_config checks if provider/name match available. If available is empty, it fails.
|
||||
# But wait, validate default available_models is None?
|
||||
# I should provide mock available_models or ignore model errors.
|
||||
|
||||
# Actually LLM node "config": {} implies missing model config. Rules check if config structure is valid?
|
||||
# Let's filter specifically for operator errors.
|
||||
operator_errors = [h for h in errors if "operator" in h.message]
|
||||
assert len(operator_errors) == 0
|
||||
|
||||
def test_if_else_invalid_operators(self):
|
||||
"""Test if-else with invalid operators."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "ifelse",
|
||||
"type": "if-else",
|
||||
"config": {
|
||||
"cases": [{"case_id": "c1", "conditions": [{"comparison_operator": ">=", "value": "1"}]}]
|
||||
},
|
||||
},
|
||||
{"id": "t", "type": "llm", "config": {}},
|
||||
{"id": "f", "type": "llm", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "ifelse"},
|
||||
{"source": "ifelse", "sourceHandle": "true", "target": "t"},
|
||||
{"source": "ifelse", "sourceHandle": "false", "target": "f"},
|
||||
{"source": "t", "target": "end"},
|
||||
{"source": "f", "target": "end"},
|
||||
],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
operator_errors = [h for h in hints if "operator" in h.message and h.severity == "error"]
|
||||
assert len(operator_errors) > 0
|
||||
assert "≥" in operator_errors[0].suggestion
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
|
||||
import { describe, it, expect } from 'vitest'
|
||||
import { replaceVariableReferences } from '../use-workflow-vibe'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
|
||||
// Mock types needed for the test
|
||||
interface NodeData {
|
||||
title: string
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
describe('use-workflow-vibe', () => {
|
||||
describe('replaceVariableReferences', () => {
|
||||
it('should replace variable references in strings', () => {
|
||||
const data = {
|
||||
title: 'Test Node',
|
||||
prompt: 'Hello {{#old_id.query#}}',
|
||||
}
|
||||
const nodeIdMap = new Map<string, any>()
|
||||
nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
|
||||
|
||||
const result = replaceVariableReferences(data, nodeIdMap) as NodeData
|
||||
expect(result.prompt).toBe('Hello {{#new_uuid.query#}}')
|
||||
})
|
||||
|
||||
it('should handle multiple references in one string', () => {
|
||||
const data = {
|
||||
title: 'Test Node',
|
||||
text: '{{#node1.out#}} and {{#node2.out#}}',
|
||||
}
|
||||
const nodeIdMap = new Map<string, any>()
|
||||
nodeIdMap.set('node1', { id: 'uuid1', data: { type: 'llm' } })
|
||||
nodeIdMap.set('node2', { id: 'uuid2', data: { type: 'llm' } })
|
||||
|
||||
const result = replaceVariableReferences(data, nodeIdMap) as NodeData
|
||||
expect(result.text).toBe('{{#uuid1.out#}} and {{#uuid2.out#}}')
|
||||
})
|
||||
|
||||
it('should replace variable references in value_selector arrays', () => {
|
||||
const data = {
|
||||
title: 'End Node',
|
||||
outputs: [
|
||||
{
|
||||
variable: 'result',
|
||||
value_selector: ['old_id', 'text'],
|
||||
},
|
||||
],
|
||||
}
|
||||
const nodeIdMap = new Map<string, any>()
|
||||
nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
|
||||
|
||||
const result = replaceVariableReferences(data, nodeIdMap) as NodeData
|
||||
expect(result.outputs[0].value_selector).toEqual(['new_uuid', 'text'])
|
||||
})
|
||||
|
||||
it('should handle nested objects recursively', () => {
|
||||
const data = {
|
||||
config: {
|
||||
model: {
|
||||
prompt: '{{#old_id.text#}}',
|
||||
},
|
||||
},
|
||||
}
|
||||
const nodeIdMap = new Map<string, any>()
|
||||
nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
|
||||
|
||||
const result = replaceVariableReferences(data, nodeIdMap) as any
|
||||
expect(result.config.model.prompt).toBe('{{#new_uuid.text#}}')
|
||||
})
|
||||
|
||||
it('should ignoring missing node mappings', () => {
|
||||
const data = {
|
||||
text: '{{#missing_id.text#}}',
|
||||
}
|
||||
const nodeIdMap = new Map<string, any>()
|
||||
// missing_id is not in map
|
||||
|
||||
const result = replaceVariableReferences(data, nodeIdMap) as NodeData
|
||||
expect(result.text).toBe('{{#missing_id.text#}}')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -39,6 +39,7 @@ import {
|
|||
getNodeCustomTypeByNodeDataType,
|
||||
getNodesConnectedSourceOrTargetHandleIdsMap,
|
||||
} from '../utils'
|
||||
import { initialNodes as initializeNodeData } from '../utils/workflow-init'
|
||||
import { useNodesMetaData } from './use-nodes-meta-data'
|
||||
import { useNodesSyncDraft } from './use-nodes-sync-draft'
|
||||
import { useNodesReadOnly } from './use-workflow'
|
||||
|
|
@ -115,7 +116,7 @@ const normalizeProviderIcon = (icon?: ToolWithProvider['icon']) => {
|
|||
* - Mixed content objects: {type: "mixed", value: "..."} → normalized to string
|
||||
* - Field name correction based on node type
|
||||
*/
|
||||
const replaceVariableReferences = (
|
||||
export const replaceVariableReferences = (
|
||||
data: unknown,
|
||||
nodeIdMap: Map<string, Node>,
|
||||
parentKey?: string,
|
||||
|
|
@ -124,6 +125,11 @@ const replaceVariableReferences = (
|
|||
// Replace {{#old_id.field#}} patterns and correct field names
|
||||
return data.replace(/\{\{#([^.#]+)\.([^#]+)#\}\}/g, (match, oldId, field) => {
|
||||
const newNode = nodeIdMap.get(oldId)
|
||||
// #region agent log
|
||||
if (!newNode) {
|
||||
console.warn(`[VIBE DEBUG] replaceVariableReferences: No mapping for "${oldId}" in template "${match}"`)
|
||||
}
|
||||
// #endregion
|
||||
if (newNode) {
|
||||
const nodeType = newNode.data?.type as string || ''
|
||||
const correctedField = correctFieldName(field, nodeType)
|
||||
|
|
@ -138,6 +144,11 @@ const replaceVariableReferences = (
|
|||
if (data.length >= 2 && typeof data[0] === 'string' && typeof data[1] === 'string') {
|
||||
const potentialNodeId = data[0]
|
||||
const newNode = nodeIdMap.get(potentialNodeId)
|
||||
// #region agent log
|
||||
if (!newNode && !['sys', 'env', 'conversation'].includes(potentialNodeId)) {
|
||||
console.warn(`[VIBE DEBUG] replaceVariableReferences: No mapping for "${potentialNodeId}" in selector [${data.join(', ')}]`)
|
||||
}
|
||||
// #endregion
|
||||
if (newNode) {
|
||||
const nodeType = newNode.data?.type as string || ''
|
||||
const correctedField = correctFieldName(data[1], nodeType)
|
||||
|
|
@ -598,6 +609,8 @@ export const useWorkflowVibe = () => {
|
|||
const { getNodes } = store.getState()
|
||||
const nodes = getNodes()
|
||||
|
||||
|
||||
|
||||
if (!nodesMetaDataMap) {
|
||||
Toast.notify({ type: 'error', message: t('workflow.vibe.nodesUnavailable') })
|
||||
return { nodes: [], edges: [] }
|
||||
|
|
@ -699,12 +712,59 @@ export const useWorkflowVibe = () => {
|
|||
}
|
||||
}
|
||||
|
||||
// For any node with model config, ALWAYS use user's default model
|
||||
if (backendConfig.model && defaultModel) {
|
||||
mergedConfig.model = {
|
||||
provider: defaultModel.provider.provider,
|
||||
name: defaultModel.model,
|
||||
mode: 'chat',
|
||||
// For End nodes, ensure outputs have value_selector format
|
||||
// New format (preferred): {"outputs": [{"variable": "name", "value_selector": ["nodeId", "field"]}]}
|
||||
// Legacy format (fallback): {"outputs": [{"variable": "name", "value": "{{#nodeId.field#}}"}]}
|
||||
if (nodeType === BlockEnum.End && backendConfig.outputs) {
|
||||
const outputs = backendConfig.outputs as Array<{ variable?: string, value?: string, value_selector?: string[] }>
|
||||
mergedConfig.outputs = outputs.map((output) => {
|
||||
// Preferred: value_selector array format (new LLM output format)
|
||||
if (output.value_selector && Array.isArray(output.value_selector)) {
|
||||
return output
|
||||
}
|
||||
// Parse value like "{{#nodeId.field#}}" into ["nodeId", "field"]
|
||||
if (output.value) {
|
||||
const match = output.value.match(/\{\{#([^.]+)\.([^#]+)#\}\}/)
|
||||
if (match) {
|
||||
return {
|
||||
variable: output.variable,
|
||||
value_selector: [match[1], match[2]],
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: return with empty value_selector to prevent crash
|
||||
return {
|
||||
variable: output.variable || 'output',
|
||||
value_selector: [],
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// For Parameter Extractor nodes, ensure each parameter has a 'required' field
|
||||
// Backend may omit this field, but Dify's Pydantic model requires it
|
||||
if (nodeType === BlockEnum.ParameterExtractor && backendConfig.parameters) {
|
||||
const parameters = backendConfig.parameters as Array<{ name?: string, type?: string, description?: string, required?: boolean }>
|
||||
mergedConfig.parameters = parameters.map((param) => ({
|
||||
...param,
|
||||
required: param.required ?? true, // Default to required if not specified
|
||||
}))
|
||||
}
|
||||
|
||||
// For any node with model config, ALWAYS use user's configured model
|
||||
// This prevents "Model not exist" errors when LLM generates models the user doesn't have configured
|
||||
// Applies to: LLM, QuestionClassifier, ParameterExtractor, and any future model-dependent nodes
|
||||
if (backendConfig.model) {
|
||||
// Try to use defaultModel first, fallback to first available model from modelList
|
||||
const fallbackModel = modelList?.[0]?.models?.[0]
|
||||
const modelProvider = defaultModel?.provider?.provider || modelList?.[0]?.provider
|
||||
const modelName = defaultModel?.model || fallbackModel?.model
|
||||
|
||||
if (modelProvider && modelName) {
|
||||
mergedConfig.model = {
|
||||
provider: modelProvider,
|
||||
name: modelName,
|
||||
mode: 'chat',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -731,10 +791,19 @@ export const useWorkflowVibe = () => {
|
|||
}
|
||||
|
||||
// Replace variable references in all node configs using the nodeIdMap
|
||||
// This converts {{#old_id.field#}} to {{#new_uuid.field#}}
|
||||
|
||||
for (const node of newNodes) {
|
||||
node.data = replaceVariableReferences(node.data, nodeIdMap) as typeof node.data
|
||||
}
|
||||
|
||||
// Use Dify's standard node initialization to handle all node types generically
|
||||
// This sets up _targetBranches for question-classifier/if-else, _children for iteration/loop, etc.
|
||||
const initializedNodes = initializeNodeData(newNodes, [])
|
||||
|
||||
// Update newNodes with initialized data
|
||||
newNodes.splice(0, newNodes.length, ...initializedNodes)
|
||||
|
||||
if (!newNodes.length) {
|
||||
Toast.notify({ type: 'error', message: t('workflow.vibe.invalidFlowchart') })
|
||||
return { nodes: [], edges: [] }
|
||||
|
|
@ -762,12 +831,16 @@ export const useWorkflowVibe = () => {
|
|||
zIndex: 0,
|
||||
})
|
||||
|
||||
|
||||
const newEdges: Edge[] = []
|
||||
for (const edgeSpec of backendEdges) {
|
||||
const sourceNode = nodeIdMap.get(edgeSpec.source)
|
||||
const targetNode = nodeIdMap.get(edgeSpec.target)
|
||||
if (!sourceNode || !targetNode)
|
||||
|
||||
if (!sourceNode || !targetNode) {
|
||||
console.warn(`[VIBE] Edge skipped: source=${edgeSpec.source} (found=${!!sourceNode}), target=${edgeSpec.target} (found=${!!targetNode})`)
|
||||
continue
|
||||
}
|
||||
|
||||
let sourceHandle = edgeSpec.sourceHandle || 'source'
|
||||
// Handle IfElse branch handles
|
||||
|
|
@ -775,9 +848,11 @@ export const useWorkflowVibe = () => {
|
|||
sourceHandle = 'source'
|
||||
}
|
||||
|
||||
|
||||
newEdges.push(buildEdge(sourceNode, targetNode, sourceHandle, edgeSpec.targetHandle || 'target'))
|
||||
}
|
||||
|
||||
|
||||
// Layout nodes
|
||||
const bounds = nodes.reduce(
|
||||
(acc, node) => {
|
||||
|
|
@ -878,11 +953,15 @@ export const useWorkflowVibe = () => {
|
|||
}
|
||||
})
|
||||
|
||||
|
||||
|
||||
setNodes(updatedNodes)
|
||||
setEdges([...edges, ...newEdges])
|
||||
saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: newNodes[0].id })
|
||||
handleSyncWorkflowDraft()
|
||||
|
||||
|
||||
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
showVibePanel: false,
|
||||
|
|
@ -1194,81 +1273,128 @@ export const useWorkflowVibe = () => {
|
|||
output_schema: tool.output_schema,
|
||||
}))
|
||||
|
||||
const stream = await generateFlowchart({
|
||||
instruction: trimmed,
|
||||
model_config: latestModelConfig!,
|
||||
existing_nodes: existingNodesPayload,
|
||||
tools: toolsPayload,
|
||||
regenerate_mode: regenerateMode,
|
||||
})
|
||||
const availableNodesPayload = availableNodesList.map(node => ({
|
||||
type: node.type,
|
||||
title: node.title,
|
||||
description: node.description,
|
||||
}))
|
||||
|
||||
let mermaidCode = ''
|
||||
let backendNodes: BackendNodeSpec[] | undefined
|
||||
let backendEdges: BackendEdgeSpec[] | undefined
|
||||
|
||||
const reader = stream.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done)
|
||||
break
|
||||
|
||||
const chunk = decoder.decode(value)
|
||||
const lines = chunk.split('\n')
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.trim() || !line.startsWith('data: '))
|
||||
continue
|
||||
|
||||
try {
|
||||
const data = JSON.parse(line.slice(6))
|
||||
if (data.event === 'message' || data.event === 'workflow_generated') {
|
||||
if (data.data?.text) {
|
||||
mermaidCode += data.data.text
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelMermaidCode: mermaidCode,
|
||||
}))
|
||||
}
|
||||
if (data.data?.nodes) {
|
||||
backendNodes = data.data.nodes
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelBackendNodes: backendNodes,
|
||||
}))
|
||||
}
|
||||
if (data.data?.edges) {
|
||||
backendEdges = data.data.edges
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelBackendEdges: backendEdges,
|
||||
}))
|
||||
}
|
||||
if (data.data?.intent) {
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelIntent: data.data.intent,
|
||||
}))
|
||||
}
|
||||
if (data.data?.message) {
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelMessage: data.data.message,
|
||||
}))
|
||||
}
|
||||
if (data.data?.suggestions) {
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelSuggestions: data.data.suggestions,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (e) {
|
||||
console.error('Error parsing chunk:', e)
|
||||
if (!isMermaidFlowchart(trimmed)) {
|
||||
// Build previous workflow context if regenerating
|
||||
const { vibePanelBackendNodes, vibePanelBackendEdges, vibePanelLastWarnings } = workflowStore.getState()
|
||||
const previousWorkflow = regenerateMode && vibePanelBackendNodes && vibePanelBackendNodes.length > 0
|
||||
? {
|
||||
nodes: vibePanelBackendNodes,
|
||||
edges: vibePanelBackendEdges || [],
|
||||
warnings: vibePanelLastWarnings || [],
|
||||
}
|
||||
: undefined
|
||||
|
||||
// Map language code to human-readable language name for LLM
|
||||
const languageNameMap: Record<string, string> = {
|
||||
en_US: 'English',
|
||||
zh_Hans: 'Chinese',
|
||||
zh_Hant: 'Traditional Chinese',
|
||||
ja_JP: 'Japanese',
|
||||
ko_KR: 'Korean',
|
||||
pt_BR: 'Portuguese',
|
||||
es_ES: 'Spanish',
|
||||
fr_FR: 'French',
|
||||
de_DE: 'German',
|
||||
it_IT: 'Italian',
|
||||
ru_RU: 'Russian',
|
||||
uk_UA: 'Ukrainian',
|
||||
vi_VN: 'Vietnamese',
|
||||
pl_PL: 'Polish',
|
||||
ro_RO: 'Romanian',
|
||||
tr_TR: 'Turkish',
|
||||
fa_IR: 'Persian',
|
||||
hi_IN: 'Hindi',
|
||||
}
|
||||
const preferredLanguage = languageNameMap[language] || 'English'
|
||||
|
||||
// Extract available models from user's configured model providers
|
||||
const availableModelsPayload = modelList?.flatMap(provider =>
|
||||
provider.models.map(model => ({
|
||||
provider: provider.provider,
|
||||
model: model.model,
|
||||
})),
|
||||
) || []
|
||||
|
||||
const requestPayload = {
|
||||
instruction: trimmed,
|
||||
model_config: latestModelConfig,
|
||||
available_nodes: availableNodesPayload,
|
||||
existing_nodes: existingNodesPayload,
|
||||
available_tools: toolsPayload,
|
||||
selected_node_ids: [],
|
||||
previous_workflow: previousWorkflow,
|
||||
regenerate_mode: regenerateMode,
|
||||
language: preferredLanguage,
|
||||
available_models: availableModelsPayload,
|
||||
}
|
||||
|
||||
const response = await generateFlowchart(requestPayload)
|
||||
|
||||
const { error, flowchart, nodes, edges, intent, message, warnings, suggestions } = response
|
||||
|
||||
if (error) {
|
||||
Toast.notify({ type: 'error', message: error })
|
||||
setIsVibeGenerating(false)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle off_topic intent - show rejection message and suggestions
|
||||
if (intent === 'off_topic') {
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelMermaidCode: '',
|
||||
vibePanelMessage: message || t('workflow.vibe.offTopicDefault'),
|
||||
vibePanelSuggestions: suggestions || [],
|
||||
vibePanelIntent: 'off_topic',
|
||||
isVibeGenerating: false,
|
||||
}))
|
||||
return
|
||||
}
|
||||
|
||||
if (!flowchart) {
|
||||
Toast.notify({ type: 'error', message: t('workflow.vibe.missingFlowchart') })
|
||||
setIsVibeGenerating(false)
|
||||
return
|
||||
}
|
||||
|
||||
// Show warnings if any (includes tool sanitization warnings)
|
||||
const responseWarnings = warnings || []
|
||||
if (responseWarnings.length > 0) {
|
||||
responseWarnings.forEach((warning) => {
|
||||
Toast.notify({ type: 'warning', message: warning })
|
||||
})
|
||||
}
|
||||
|
||||
mermaidCode = flowchart
|
||||
// Store backend nodes/edges for direct use (bypasses mermaid re-parsing)
|
||||
backendNodes = nodes
|
||||
backendEdges = edges
|
||||
// Store warnings for regeneration context
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelLastWarnings: responseWarnings,
|
||||
}))
|
||||
|
||||
workflowStore.setState(state => ({
|
||||
...state,
|
||||
vibePanelMermaidCode: mermaidCode,
|
||||
vibePanelBackendNodes: backendNodes,
|
||||
vibePanelBackendEdges: backendEdges,
|
||||
vibePanelMessage: '',
|
||||
vibePanelSuggestions: [],
|
||||
vibePanelIntent: 'generate',
|
||||
isVibeGenerating: false,
|
||||
}))
|
||||
}
|
||||
|
||||
setIsVibeGenerating(false)
|
||||
|
|
@ -1286,10 +1412,16 @@ export const useWorkflowVibe = () => {
|
|||
if (skipPanelPreview) {
|
||||
// Prefer backend nodes (already sanitized) over mermaid re-parsing
|
||||
if (backendNodes && backendNodes.length > 0 && backendEdges) {
|
||||
console.log('[VIBE] Applying backend nodes directly to workflow')
|
||||
console.log('[VIBE] Backend nodes:', backendNodes.length)
|
||||
console.log('[VIBE] Backend edges:', backendEdges.length)
|
||||
await applyBackendNodesToWorkflow(backendNodes, backendEdges)
|
||||
console.log('[VIBE] Backend nodes applied successfully')
|
||||
}
|
||||
else {
|
||||
console.log('[VIBE] Applying mermaid flowchart to workflow')
|
||||
await applyFlowchartToWorkflow()
|
||||
console.log('[VIBE] Mermaid flowchart applied successfully')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue