# Request sub-model carrying the result of an earlier generation attempt.
# FlowchartGeneratePayload.previous_workflow holds an optional instance of it;
# the controller converts it with model_dump() before handing it to
# WorkflowGenerator.generate_workflow_flowchart.
# NOTE(review): pydantic may surface the class docstring and the Field
# descriptions in the generated API schema — confirm before rewording them.
class PreviousWorkflow(BaseModel):
    """Previous workflow attempt for regeneration context."""

    # Each field defaults to a fresh empty list (default_factory), so a partially
    # filled previous attempt is accepted.
    nodes: list[dict[str, Any]] = Field(default_factory=list, description="Previously generated nodes")
    edges: list[dict[str, Any]] = Field(default_factory=list, description="Previously generated edges")
    warnings: list[str] = Field(default_factory=list, description="Warnings from previous generation")
workflow nodes") available_tools: list[dict[str, Any]] = Field(default_factory=list, description="Available tools") + selected_node_ids: list[str] = Field(default_factory=list, description="IDs of selected nodes for context") + previous_workflow: PreviousWorkflow | None = Field(default=None, description="Previous workflow for regeneration") + regenerate_mode: bool = Field(default=False, description="Whether this is a regeneration request") + # Language preference for generated content (node titles, descriptions) + language: str | None = Field(default=None, description="Preferred language for generated content") + # Available models that user has configured (for LLM/question-classifier nodes) + available_models: list[dict[str, Any]] = Field(default_factory=list, description="User's configured models") + # Validate-fix iteration loop configuration + max_fix_iterations: int = Field( + default=2, + ge=0, + le=5, + description="Maximum number of validate-fix iterations (0 to disable auto-fix)", + ) def reg(cls: type[BaseModel]): @@ -267,7 +294,7 @@ class InstructionGenerateApi(Resource): @console_ns.route("/flowchart-generate") class FlowchartGenerateApi(Resource): @console_ns.doc("generate_workflow_flowchart") - @console_ns.doc(description="Generate workflow flowchart using LLM") + @console_ns.doc(description="Generate workflow flowchart using LLM with intent classification") @console_ns.expect(console_ns.models[FlowchartGeneratePayload.__name__]) @console_ns.response(200, "Flowchart generated successfully") @console_ns.response(400, "Invalid request parameters") @@ -280,14 +307,24 @@ class FlowchartGenerateApi(Resource): _, current_tenant_id = current_account_with_tenant() try: - result = LLMGenerator.generate_workflow_flowchart( + # Convert PreviousWorkflow to dict if present + previous_workflow_dict = args.previous_workflow.model_dump() if args.previous_workflow else None + + result = WorkflowGenerator.generate_workflow_flowchart( tenant_id=current_tenant_id, 
instruction=args.instruction, model_config=args.model_config_data, available_nodes=args.available_nodes, existing_nodes=args.existing_nodes, available_tools=args.available_tools, + selected_node_ids=args.selected_node_ids, + previous_workflow=previous_workflow_dict, + regenerate_mode=args.regenerate_mode, + preferred_language=args.language, + available_models=args.available_models, + max_fix_iterations=args.max_fix_iterations, ) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 4cc60a4878..7a5c1d550c 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,6 +1,5 @@ import json import logging -import re from collections.abc import Sequence from typing import Protocol, cast @@ -12,13 +11,10 @@ from core.llm_generator.prompts import ( CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, - LLM_MODIFY_CODE_SYSTEM, - LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, SUGGESTED_QUESTIONS_MAX_TOKENS, SUGGESTED_QUESTIONS_TEMPERATURE, SYSTEM_STRUCTURED_OUTPUT_GENERATE, - WORKFLOW_FLOWCHART_PROMPT_TEMPLATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager @@ -31,6 +27,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.generator import WorkflowGenerator from extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel @@ -295,52 +292,29 @@ class LLMGenerator: available_nodes: Sequence[dict[str, object]] | None = None, existing_nodes: 
Sequence[dict[str, object]] | None = None, available_tools: Sequence[dict[str, object]] | None = None, + selected_node_ids: Sequence[str] | None = None, + previous_workflow: dict[str, object] | None = None, + regenerate_mode: bool = False, + preferred_language: str | None = None, + available_models: Sequence[dict[str, object]] | None = None, + max_fix_iterations: int = 2, ): - model_parameters = model_config.get("completion_params", {}) - prompt_template = PromptTemplateParser(WORKFLOW_FLOWCHART_PROMPT_TEMPLATE) - prompt_generate = prompt_template.format( - inputs={ - "TASK_DESCRIPTION": instruction, - "AVAILABLE_NODES": json.dumps(available_nodes or [], ensure_ascii=False), - "EXISTING_NODES": json.dumps(existing_nodes or [], ensure_ascii=False), - "AVAILABLE_TOOLS": json.dumps(available_tools or [], ensure_ascii=False), - }, - remove_template_variables=False, - ) - prompt_messages = [UserPromptMessage(content=prompt_generate)] - - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( + return WorkflowGenerator.generate_workflow_flowchart( tenant_id=tenant_id, - model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + instruction=instruction, + model_config=model_config, + available_nodes=available_nodes, + existing_nodes=existing_nodes, + available_tools=available_tools, + selected_node_ids=selected_node_ids, + previous_workflow=previous_workflow, + regenerate_mode=regenerate_mode, + preferred_language=preferred_language, + available_models=available_models, + max_fix_iterations=max_fix_iterations, ) - flowchart = "" - error = "" - - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), - model_parameters=model_parameters, - stream=False, - ) - content = response.message.get_text_content() - if not isinstance(content, str): - raise ValueError("Flowchart response is not a string") - - match = re.search(r"```(?:mermaid)?\s*([\s\S]+?)```", 
content, flags=re.IGNORECASE) - flowchart = (match.group(1) if match else content).strip() - except InvokeError as e: - error = str(e) - except Exception as e: - logger.exception("Failed to generate workflow flowchart, model: %s", model_config.get("name")) - error = str(e) - - return {"flowchart": flowchart, "error": error} - @classmethod def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): if code_language == "python": diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index 766ae07231..76d7231b55 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -147,6 +147,8 @@ WORKFLOW_FLOWCHART_PROMPT_TEMPLATE = """ You are an expert workflow designer. Generate a Mermaid flowchart based on the user's request. Constraints: +- Detect the language of the user's request. Generate all node titles in the same language as the user's input. +- If the input language cannot be determined, use {{PREFERRED_LANGUAGE}} as the fallback language. - Use only node types listed in . - Use only tools listed in . When using a tool node, set type=tool and tool=. - Tools may include MCP providers (provider_type=mcp). Tool selection still uses tool_key. diff --git a/api/core/workflow/generator/__init__.py b/api/core/workflow/generator/__init__.py new file mode 100644 index 0000000000..2b722441a9 --- /dev/null +++ b/api/core/workflow/generator/__init__.py @@ -0,0 +1 @@ +from .runner import WorkflowGenerator diff --git a/api/core/workflow/generator/config/__init__.py b/api/core/workflow/generator/config/__init__.py new file mode 100644 index 0000000000..b7f6d1d3e9 --- /dev/null +++ b/api/core/workflow/generator/config/__init__.py @@ -0,0 +1,29 @@ +""" +Vibe Workflow Generator Configuration Module. + +This module centralizes configuration for the Vibe workflow generation feature, +including node schemas, fallback rules, and response templates. 
+""" + +from core.workflow.generator.config.node_schemas import ( + BUILTIN_NODE_SCHEMAS, + FALLBACK_RULES, + FIELD_NAME_CORRECTIONS, + NODE_TYPE_ALIASES, + get_builtin_node_schemas, + get_corrected_field_name, + validate_node_schemas, +) +from core.workflow.generator.config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES + +__all__ = [ + "BUILTIN_NODE_SCHEMAS", + "DEFAULT_SUGGESTIONS", + "FALLBACK_RULES", + "FIELD_NAME_CORRECTIONS", + "NODE_TYPE_ALIASES", + "OFF_TOPIC_RESPONSES", + "get_builtin_node_schemas", + "get_corrected_field_name", + "validate_node_schemas", +] diff --git a/api/core/workflow/generator/config/node_schemas.py b/api/core/workflow/generator/config/node_schemas.py new file mode 100644 index 0000000000..e4980cd4c3 --- /dev/null +++ b/api/core/workflow/generator/config/node_schemas.py @@ -0,0 +1,501 @@ +""" +Unified Node Configuration for Vibe Workflow Generation. + +This module centralizes all node-related configuration: +- Node schemas (parameter definitions) +- Fallback rules (keyword-based node type inference) +- Node type aliases (natural language to canonical type mapping) +- Field name corrections (LLM output normalization) +- Validation utilities + +Note: These definitions are the single source of truth. 
+Frontend has a mirrored copy at web/app/components/workflow/hooks/use-workflow-vibe-config.ts +""" + +from typing import Any + +# ============================================================================= +# NODE SCHEMAS +# ============================================================================= + +# Built-in node schemas with parameter definitions +# These help the model understand what config each node type requires +_HARDCODED_SCHEMAS: dict[str, dict[str, Any]] = { + "http-request": { + "description": "Send HTTP requests to external APIs or fetch web content", + "required": ["url", "method"], + "parameters": { + "url": { + "type": "string", + "description": "Full URL including protocol (https://...)", + "example": "{{#start.url#}} or https://api.example.com/data", + }, + "method": { + "type": "enum", + "options": ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], + "description": "HTTP method", + }, + "headers": { + "type": "string", + "description": "HTTP headers as newline-separated 'Key: Value' pairs", + "example": "Content-Type: application/json\nAuthorization: Bearer {{#start.api_key#}}", + }, + "params": { + "type": "string", + "description": "URL query parameters as newline-separated 'key: value' pairs", + }, + "body": { + "type": "object", + "description": "Request body with type field required", + "example": {"type": "none", "data": []}, + }, + "authorization": { + "type": "object", + "description": "Authorization config", + "example": {"type": "no-auth"}, + }, + "timeout": { + "type": "number", + "description": "Request timeout in seconds", + "default": 60, + }, + }, + "outputs": ["body (response content)", "status_code", "headers"], + }, + "code": { + "description": "Execute Python or JavaScript code for custom logic", + "required": ["code", "language"], + "parameters": { + "code": { + "type": "string", + "description": "Code to execute. 
Must define a main() function that returns a dict.", + }, + "language": { + "type": "enum", + "options": ["python3", "javascript"], + }, + "variables": { + "type": "array", + "description": "Input variables passed to the code", + "item_schema": {"variable": "string", "value_selector": "array"}, + }, + "outputs": { + "type": "object", + "description": "Output variable definitions", + }, + }, + "outputs": ["Variables defined in outputs schema"], + }, + "llm": { + "description": "Call a large language model for text generation/processing", + "required": ["prompt_template"], + "parameters": { + "model": { + "type": "object", + "description": "Model configuration (provider, name, mode)", + }, + "prompt_template": { + "type": "array", + "description": "Messages for the LLM", + "item_schema": { + "role": "enum: system, user, assistant", + "text": "string - message content, can include {{#node_id.field#}} references", + }, + }, + "context": { + "type": "object", + "description": "Optional context settings", + }, + "memory": { + "type": "object", + "description": "Optional memory/conversation settings", + }, + }, + "outputs": ["text (generated response)"], + }, + "if-else": { + "description": "Conditional branching based on conditions", + "required": ["cases"], + "parameters": { + "cases": { + "type": "array", + "description": "List of condition cases. Each case defines when 'true' branch is taken.", + "item_schema": { + "case_id": "string - unique case identifier (e.g., 'case_1')", + "logical_operator": "enum: and, or - how multiple conditions combine", + "conditions": { + "type": "array", + "item_schema": { + "variable_selector": "array of strings - path to variable, e.g. 
['node_id', 'field']", + "comparison_operator": ( + "enum: =, ≠, >, <, ≥, ≤, contains, not contains, is, is not, empty, not empty" + ), + "value": "string or number - value to compare against", + }, + }, + }, + }, + }, + "outputs": ["Branches: true (first case conditions met), false (else/no case matched)"], + }, + "knowledge-retrieval": { + "description": "Query knowledge base for relevant content", + "required": ["query_variable_selector", "dataset_ids"], + "parameters": { + "query_variable_selector": { + "type": "array", + "description": "Path to query variable, e.g. ['start', 'query']", + }, + "dataset_ids": { + "type": "array", + "description": "List of knowledge base IDs to search", + }, + "retrieval_mode": { + "type": "enum", + "options": ["single", "multiple"], + }, + }, + "outputs": ["result (retrieved documents)"], + }, + "template-transform": { + "description": "Transform data using Jinja2 templates", + "required": ["template", "variables"], + "parameters": { + "template": { + "type": "string", + "description": "Jinja2 template string. Use {{ variable_name }} to reference variables.", + }, + "variables": { + "type": "array", + "description": "Input variables defined for the template", + "item_schema": { + "variable": "string - variable name to use in template", + "value_selector": "array - path to source value, e.g. ['start', 'user_input']", + }, + }, + }, + "outputs": ["output (transformed string)"], + }, + "variable-aggregator": { + "description": "Aggregate variables from multiple branches", + "required": ["variables"], + "parameters": { + "variables": { + "type": "array", + "description": "List of variable selectors to aggregate", + "item_schema": "array of strings - path to source variable, e.g. 
['node_id', 'field']", + }, + }, + "outputs": ["output (aggregated value)"], + }, + "iteration": { + "description": "Loop over array items", + "required": ["iterator_selector"], + "parameters": { + "iterator_selector": { + "type": "array", + "description": "Path to array variable to iterate", + }, + }, + "outputs": ["item (current iteration item)", "index (current index)"], + }, + "parameter-extractor": { + "description": "Extract structured parameters from user input using LLM", + "required": ["query", "parameters"], + "parameters": { + "model": { + "type": "object", + "description": "Model configuration (provider, name, mode)", + }, + "query": { + "type": "array", + "description": "Path to input text to extract parameters from, e.g. ['start', 'user_input']", + }, + "parameters": { + "type": "array", + "description": "Parameters to extract from the input", + "item_schema": { + "name": "string - parameter name (required)", + "type": ( + "enum: string, number, boolean, array[string], array[number], array[object], array[boolean]" + ), + "description": "string - description of what to extract (required)", + "required": "boolean - whether this parameter is required (MUST be specified)", + "options": "array of strings (optional) - for enum-like selection", + }, + }, + "instruction": { + "type": "string", + "description": "Additional instructions for extraction", + }, + "reasoning_mode": { + "type": "enum", + "options": ["function_call", "prompt"], + "description": "How to perform extraction (defaults to function_call)", + }, + }, + "outputs": ["Extracted parameters as defined in parameters array", "__is_success", "__reason"], + }, + "question-classifier": { + "description": "Classify user input into predefined categories using LLM", + "required": ["query", "classes"], + "parameters": { + "model": { + "type": "object", + "description": "Model configuration (provider, name, mode)", + }, + "query": { + "type": "array", + "description": "Path to input text to classify, e.g. 
['start', 'user_input']", + }, + "classes": { + "type": "array", + "description": "Classification categories", + "item_schema": { + "id": "string - unique class identifier", + "name": "string - class name/label", + }, + }, + "instruction": { + "type": "string", + "description": "Additional instructions for classification", + }, + }, + "outputs": ["class_name (selected class)"], + }, +} + + +def _get_dynamic_schemas() -> dict[str, dict[str, Any]]: + """ + Dynamically load schemas from node classes. + Uses lazy import to avoid circular dependency. + """ + from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING + + schemas = {} + for node_type, version_map in NODE_TYPE_CLASSES_MAPPING.items(): + # Get the latest version class + node_cls = version_map.get(LATEST_VERSION) + if not node_cls: + continue + + # Get schema from the class + schema = node_cls.get_default_config_schema() + if schema: + schemas[node_type.value] = schema + + return schemas + + +# Cache for built-in schemas (populated on first access) +_builtin_schemas_cache: dict[str, dict[str, Any]] | None = None + + +def get_builtin_node_schemas() -> dict[str, dict[str, Any]]: + """ + Get the complete set of built-in node schemas. + Combines hardcoded schemas with dynamically loaded ones. + Results are cached after first call. 
+ """ + global _builtin_schemas_cache + if _builtin_schemas_cache is None: + _builtin_schemas_cache = {**_HARDCODED_SCHEMAS, **_get_dynamic_schemas()} + return _builtin_schemas_cache + + +# For backward compatibility - but use get_builtin_node_schemas() for lazy loading +BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = _HARDCODED_SCHEMAS.copy() + + +# ============================================================================= +# FALLBACK RULES +# ============================================================================= + +# Keyword rules for smart fallback detection +# Maps node type to keywords that suggest using that node type as a fallback +FALLBACK_RULES: dict[str, list[str]] = { + "http-request": [ + "http", + "url", + "web", + "scrape", + "scraper", + "fetch", + "api", + "request", + "download", + "upload", + "webhook", + "endpoint", + "rest", + "get", + "post", + ], + "code": [ + "code", + "script", + "calculate", + "compute", + "process", + "transform", + "parse", + "convert", + "format", + "filter", + "sort", + "math", + "logic", + ], + "llm": [ + "analyze", + "summarize", + "summary", + "extract", + "classify", + "translate", + "generate", + "write", + "rewrite", + "explain", + "answer", + "chat", + ], +} + + +# ============================================================================= +# NODE TYPE ALIASES +# ============================================================================= + +# Node type aliases for inference from natural language +# Maps common terms to canonical node type names +NODE_TYPE_ALIASES: dict[str, str] = { + # Start node aliases + "start": "start", + "begin": "start", + "input": "start", + # End node aliases + "end": "end", + "finish": "end", + "output": "end", + # LLM node aliases + "llm": "llm", + "ai": "llm", + "gpt": "llm", + "model": "llm", + "chat": "llm", + # Code node aliases + "code": "code", + "script": "code", + "python": "code", + "javascript": "code", + # HTTP request node aliases + "http-request": 
"http-request", + "http": "http-request", + "request": "http-request", + "api": "http-request", + "fetch": "http-request", + "webhook": "http-request", + # Conditional node aliases + "if-else": "if-else", + "condition": "if-else", + "branch": "if-else", + "switch": "if-else", + # Loop node aliases + "iteration": "iteration", + "loop": "loop", + "foreach": "iteration", + # Tool node alias + "tool": "tool", +} + + +# ============================================================================= +# FIELD NAME CORRECTIONS +# ============================================================================= + +# Field name corrections for LLM-generated node configs +# Maps incorrect field names to correct ones for specific node types +FIELD_NAME_CORRECTIONS: dict[str, dict[str, str]] = { + "http-request": { + "text": "body", # LLM might use "text" instead of "body" + "content": "body", + "response": "body", + }, + "code": { + "text": "result", # LLM might use "text" instead of "result" + "output": "result", + }, + "llm": { + "response": "text", + "answer": "text", + }, +} + + +def get_corrected_field_name(node_type: str, field: str) -> str: + """ + Get the corrected field name for a node type. 
def validate_node_schemas() -> list[str]:
    """
    Report registered node types that have no generation schema.

    Each node type registered in NODE_TYPE_CLASSES_MAPPING is looked up in the
    schema set returned by get_builtin_node_schemas(). Types listed in
    _INTERNAL_NODE_TYPES are skipped.

    Returns:
        One warning message per uncovered node type, in registry iteration
        order (empty list when every checked type has a schema)
    """
    # Function-scope import, kept as in the rest of this module.
    from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

    known_schemas = get_builtin_node_schemas()
    registered_values = (registered.value for registered in NODE_TYPE_CLASSES_MAPPING)
    return [
        f"Missing schema for node type: {type_value}"
        for type_value in registered_values
        if not (type_value in _INTERNAL_NODE_TYPES or type_value in known_schemas)
    ]
" + "Would you like to create one using an LLM node?" + ), + "zh": "我不能直接翻译,但我可以帮你构建一个翻译工作流!要创建一个使用LLM节点的翻译流程吗?", + }, + "general_coding": { + "en": ( + "I'm specialized in Dify workflow design rather than general coding help. " + "But if you want to add code logic to your workflow, I can help you configure a Code node!" + ), + "zh": ( + "我专注于Dify工作流设计,而非通用编程帮助。" + "但如果您想在工作流中添加代码逻辑,我可以帮您配置一个代码节点!" + ), + }, + "default": { + "en": ( + "I'm the Dify workflow design assistant. I help create AI automation workflows, " + "but I can't help with general questions. Would you like to create a workflow instead?" + ), + "zh": "我是Dify工作流设计助手。我帮助创建AI自动化工作流,但无法回答一般性问题。您想创建一个工作流吗?", + }, +} + +# Default suggestions for off-topic requests +# These help guide users towards valid workflow requests +DEFAULT_SUGGESTIONS: dict[str, list[str]] = { + "en": [ + "Create a chatbot workflow", + "Build a document summarization pipeline", + "Add email notification to workflow", + ], + "zh": [ + "创建一个聊天机器人工作流", + "构建文档摘要处理流程", + "添加邮件通知到工作流", + ], +} + diff --git a/api/core/workflow/generator/prompts/__init__.py b/api/core/workflow/generator/prompts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/generator/prompts/builder_prompts.py b/api/core/workflow/generator/prompts/builder_prompts.py new file mode 100644 index 0000000000..d6eeaa2ebe --- /dev/null +++ b/api/core/workflow/generator/prompts/builder_prompts.py @@ -0,0 +1,354 @@ +BUILDER_SYSTEM_PROMPT = """ +You are a Workflow Configuration Engineer. +Your goal is to implement the Architect's plan by generating a precise, runnable Dify Workflow JSON configuration. + + + + +{plan_context} + + + +{tool_schemas} + + + +{builtin_node_specs} + + + +{available_models} + + + + +1. **Configuration**: + - You MUST fill ALL required parameters for every node. + - Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields. 
+ - For 'start' node, define all necessary user inputs. + +2. **Variable References**: + - For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}` + - For 'end' node outputs: use `value_selector` array format `["node_id", "field"]` + - Example: to reference 'llm' node's 'text' output in end node, use `["llm", "text"]` + +3. **Tools**: + - ONLY use the tools listed in ``. + - If a planned tool is missing from schemas, fallback to `http-request` or `code`. + +4. **Model Selection** (CRITICAL): + - For LLM, question-classifier, and parameter-extractor nodes, you MUST include a "model" config. + - You MUST use ONLY models from the `` section above. + - Copy the EXACT provider and name values from available_models. + - NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4, or any other models unless they appear in available_models. + - If available_models is empty or shows "No models configured", omit the model config entirely. + +5. **Node Specifics**: + - For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`). + +6. **Output**: + - Return ONLY the JSON object with `nodes` and `edges`. + - Do NOT generate Mermaid diagrams. + - Do NOT generate explanations. + + + +**EDGES ARE CRITICAL** - Every node except 'end' MUST have at least one outgoing edge. + +1. **Linear Flow**: Simple source -> target connection + ``` + {{"source": "node_a", "target": "node_b"}} + ``` + +2. **question-classifier Branching**: Each class MUST have a separate edge with `sourceHandle` = class `id` + - If you define classes: [{{"id": "cls_refund", "name": "Refund"}}, {{"id": "cls_inquiry", "name": "Inquiry"}}] + - You MUST create edges: + - {{"source": "classifier", "sourceHandle": "cls_refund", "target": "refund_handler"}} + - {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "inquiry_handler"}} + +3. 
**if-else Branching**: MUST have exactly TWO edges with sourceHandle "true" and "false" + - {{"source": "condition", "sourceHandle": "true", "target": "true_branch"}} + - {{"source": "condition", "sourceHandle": "false", "target": "false_branch"}} + +4. **Branch Convergence**: Multiple branches can connect to same downstream node + - Both true_branch and false_branch can connect to the same 'end' node + +5. **NEVER leave orphan nodes**: Every node must be connected in the graph + + + + +```json +{{ + "nodes": [ + {{ + "id": "start", + "type": "start", + "title": "Start", + "config": {{ + "variables": [{{"variable": "query", "label": "Query", "type": "text-input"}}] + }} + }}, + {{ + "id": "llm", + "type": "llm", + "title": "Generate Response", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}] + }} + }}, + {{ + "id": "end", + "type": "end", + "title": "End", + "config": {{ + "outputs": [ + {{"variable": "result", "value_selector": ["llm", "text"]}} + ] + }} + }} + ], + "edges": [ + {{"source": "start", "target": "llm"}}, + {{"source": "llm", "target": "end"}} + ] +}} +``` + + + +```json +{{ + "nodes": [ + {{ + "id": "start", + "type": "start", + "title": "Start", + "config": {{ + "variables": [{{"variable": "user_input", "label": "User Message", "type": "text-input", "required": true}}] + }} + }}, + {{ + "id": "classifier", + "type": "question-classifier", + "title": "Classify Intent", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "query_variable_selector": ["start", "user_input"], + "classes": [ + {{"id": "cls_refund", "name": "Refund Request"}}, + {{"id": "cls_inquiry", "name": "Product Inquiry"}}, + {{"id": "cls_complaint", "name": "Complaint"}}, + {{"id": "cls_other", "name": "Other"}} + ], + "instruction": "Classify the user's intent" + }} + }}, + {{ + "id": "handle_refund", + "type": "llm", + "title": 
"Handle Refund", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Extract order number and respond: {{{{#start.user_input#}}}}"}}] + }} + }}, + {{ + "id": "handle_inquiry", + "type": "llm", + "title": "Handle Inquiry", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Answer product question: {{{{#start.user_input#}}}}"}}] + }} + }}, + {{ + "id": "handle_complaint", + "type": "llm", + "title": "Handle Complaint", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Respond with empathy: {{{{#start.user_input#}}}}"}}] + }} + }}, + {{ + "id": "handle_other", + "type": "llm", + "title": "Handle Other", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Provide general response: {{{{#start.user_input#}}}}"}}] + }} + }}, + {{ + "id": "end", + "type": "end", + "title": "End", + "config": {{ + "outputs": [{{"variable": "response", "value_selector": ["handle_refund", "text"]}}] + }} + }} + ], + "edges": [ + {{"source": "start", "target": "classifier"}}, + {{"source": "classifier", "sourceHandle": "cls_refund", "target": "handle_refund"}}, + {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "handle_inquiry"}}, + {{"source": "classifier", "sourceHandle": "cls_complaint", "target": "handle_complaint"}}, + {{"source": "classifier", "sourceHandle": "cls_other", "target": "handle_other"}}, + {{"source": "handle_refund", "target": "end"}}, + {{"source": "handle_inquiry", "target": "end"}}, + {{"source": "handle_complaint", "target": "end"}}, + {{"source": "handle_other", "target": "end"}} + ] +}} +``` +CRITICAL: Notice that each class id (cls_refund, cls_inquiry, etc.) becomes a sourceHandle in the edges! 
+ + + +```json +{{ + "nodes": [ + {{ + "id": "start", + "type": "start", + "title": "Start", + "config": {{ + "variables": [{{"variable": "years", "label": "Years of Experience", "type": "number", "required": true}}] + }} + }}, + {{ + "id": "check_experience", + "type": "if-else", + "title": "Check Experience", + "config": {{ + "cases": [ + {{ + "case_id": "case_1", + "logical_operator": "and", + "conditions": [ + {{ + "variable_selector": ["start", "years"], + "comparison_operator": "≥", + "value": "3" + }} + ] + }} + ] + }} + }}, + {{ + "id": "qualified", + "type": "llm", + "title": "Qualified Response", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Generate qualified candidate response"}}] + }} + }}, + {{ + "id": "not_qualified", + "type": "llm", + "title": "Not Qualified Response", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Generate rejection response"}}] + }} + }}, + {{ + "id": "end", + "type": "end", + "title": "End", + "config": {{ + "outputs": [{{"variable": "result", "value_selector": ["qualified", "text"]}}] + }} + }} + ], + "edges": [ + {{"source": "start", "target": "check_experience"}}, + {{"source": "check_experience", "sourceHandle": "true", "target": "qualified"}}, + {{"source": "check_experience", "sourceHandle": "false", "target": "not_qualified"}}, + {{"source": "qualified", "target": "end"}}, + {{"source": "not_qualified", "target": "end"}} + ] +}} +``` +CRITICAL: if-else MUST have exactly two edges with sourceHandle "true" and "false"! 
+ + + +```json +{{ + "nodes": [ + {{ + "id": "start", + "type": "start", + "title": "Start", + "config": {{ + "variables": [{{"variable": "resume", "label": "Resume Text", "type": "paragraph", "required": true}}] + }} + }}, + {{ + "id": "extract", + "type": "parameter-extractor", + "title": "Extract Info", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "query": ["start", "resume"], + "parameters": [ + {{"name": "name", "type": "string", "description": "Candidate name", "required": true}}, + {{"name": "years", "type": "number", "description": "Years of experience", "required": true}}, + {{"name": "skills", "type": "array[string]", "description": "List of skills", "required": true}} + ], + "instruction": "Extract candidate information from resume" + }} + }}, + {{ + "id": "process", + "type": "llm", + "title": "Process Data", + "config": {{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Name: {{{{#extract.name#}}}}, Years: {{{{#extract.years#}}}}"}}] + }} + }}, + {{ + "id": "end", + "type": "end", + "title": "End", + "config": {{ + "outputs": [{{"variable": "result", "value_selector": ["process", "text"]}}] + }} + }} + ], + "edges": [ + {{"source": "start", "target": "extract"}}, + {{"source": "extract", "target": "process"}}, + {{"source": "process", "target": "end"}} + ] +}} +``` + + + + +Before finalizing, verify: +1. [ ] Every node (except 'end') has at least one outgoing edge +2. [ ] 'start' node has exactly one outgoing edge +3. [ ] 'question-classifier' has one edge per class, each with sourceHandle = class id +4. [ ] 'if-else' has exactly two edges: sourceHandle "true" and sourceHandle "false" +5. [ ] All branches eventually connect to 'end' (directly or through other nodes) +6. [ ] No orphan nodes exist (every node is reachable from 'start') + +""" + +BUILDER_USER_PROMPT = """ +{instruction} + + +Generate the full workflow configuration now. 
Pay special attention to: +1. Creating edges for ALL branches of question-classifier and if-else nodes +2. Using correct sourceHandle values for branching nodes +3. Ensuring every node is connected in the graph +""" diff --git a/api/core/workflow/generator/prompts/planner_prompts.py b/api/core/workflow/generator/prompts/planner_prompts.py new file mode 100644 index 0000000000..ada791bf94 --- /dev/null +++ b/api/core/workflow/generator/prompts/planner_prompts.py @@ -0,0 +1,75 @@ +PLANNER_SYSTEM_PROMPT = """ +You are an expert Workflow Architect. +Your job is to analyze user requests and plan a high-level automation workflow. + + + +1. **Classify Intent**: + - Is the user asking to create an automation/workflow? -> Intent: "generate" + - Is it general chat/weather/jokes? -> Intent: "off_topic" + +2. **Plan Steps** (if intent is "generate"): + - Break down the user's goal into logical steps. + - For each step, identify if a specific capability/tool is needed. + - Select the MOST RELEVANT tools from the available_tools list. + - DO NOT configure parameters yet. Just identify the tool. + +3. **Output Format**: + Return a JSON object. + + + +{tools_summary} + + + +If intent is "generate": +```json +{{ + "intent": "generate", + "plan_thought": "Brief explanation of the plan...", + "steps": [ + {{ "step": 1, "description": "Fetch data from URL", "tool": "http-request" }}, + {{ "step": 2, "description": "Summarize content", "tool": "llm" }}, + {{ "step": 3, "description": "Search for info", "tool": "google_search" }} + ], + "required_tool_keys": ["google_search"] +}} +``` +(Note: 'http-request', 'llm', 'code' are built-in, you don't need to list them in required_tool_keys, +only external tools) + +If intent is "off_topic": +```json +{{ + "intent": "off_topic", + "message": "I can only help you build workflows. 
def _strip_trailing_punctuation(url: str) -> str:
    """Remove sentence punctuation that the URL regex swallowed at the end of *url*."""
    while url:
        last = url[-1]
        if last in ".,;:!?'":
            url = url[:-1]
        elif last == ")" and url.count(")") > url.count("("):
            # An unbalanced ")" means the URL was wrapped in parentheses in prose;
            # balanced pairs (e.g. ".../Foo_(bar)") are part of the URL itself.
            url = url[:-1]
        else:
            break
    return url


def extract_instruction_values(instruction: str) -> dict[str, Any]:
    """
    Extract concrete values from user instruction for auto-fill hints.

    This pre-processes the instruction to find URLs, emails, and other
    concrete values that can be used as defaults in the generated workflow.

    Args:
        instruction: Free-form user request text.

    Returns:
        A dict with the keys ``urls``, ``emails``, ``api_endpoints``,
        ``file_extensions`` and ``json_paths``; each value is a list of strings
        in order of appearance (empty when nothing was found).
    """
    url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+'
    email_pattern = r"[\w.-]+@[\w.-]+\.\w+"

    # URLs usually appear mid-sentence ("fetch https://x.y/z, then ..."), so the
    # greedy match picks up the following comma/period/paren. Trim those, and
    # drop anything that is nothing but a scheme after trimming.
    urls = []
    for raw_url in re.findall(url_pattern, instruction):
        cleaned = _strip_trailing_punctuation(raw_url)
        if cleaned.partition("://")[2]:
            urls.append(cleaned)

    # Dotted JSON paths are searched in a copy with URLs and emails blanked out,
    # otherwise host names such as ".example.com" are reported as paths.
    without_urls = re.sub(url_pattern, " ", instruction)
    path_search_text = re.sub(email_pattern, " ", without_urls)

    return {
        "urls": urls,
        "emails": re.findall(email_pattern, instruction),
        "api_endpoints": [u for u in urls if "/api/" in u or "/v1/" in u or "/v2/" in u],
        # Accept an extension that is followed by closing punctuation before the
        # next whitespace/end ("report.pdf." / "data.csv,"), not only bare ones.
        "file_extensions": re.findall(
            r'\.(json|csv|txt|pdf|docx?)(?=[.,;:!?)"]*(?:\s|$))',
            instruction,
            re.IGNORECASE,
        ),
        "json_paths": re.findall(r"\.[\w]+(?:\.[\w]+)+", path_search_text),  # e.g., .data.results
    }
+ + + +- Generate workflow configurations from natural language descriptions +- Validate tool references against available integrations +- Provide clear, helpful responses +- Reject requests that are not about workflow design + + +{available_nodes_formatted} + +{available_tools_formatted} + +{available_models_formatted} + + + How to reference data from other nodes in your workflow + {{{{#node_id.field_name#}}}} + + {{{{#start.url#}}}} + {{{{#start.query#}}}} + {{{{#llm_node_id.text#}}}} + {{{{#http_node_id.body#}}}} + {{{{#code_node_id.result#}}}} + + + + + + For LLM, question-classifier, parameter-extractor nodes: + - You MUST include a "model" config with provider and name from available_models section + - Copy the EXACT provider and name values from available_models + - NEVER use openai/gpt-4o, openai/gpt-3.5-turbo, openai/gpt-4 unless they appear in available_models + - If available_models is empty or not provided, omit the model config entirely + + + ONLY use tools with status="configured" from available_tools. + NEVER invent tool names like "webscraper", "email_sender", etc. + If no matching tool exists, use http-request or code node as fallback. + + + ALWAYS fill ALL required_params for each node type. + Check the node's params section to know what config is needed. + + + Use {{{{#node_id.field#}}}} syntax to reference outputs from previous nodes. 
+ Start node variables: {{{{#start.variable_name#}}}} + + + + + When user requests capability with NO matching tool in available_tools + + Configure with: url (the URL to fetch), method: GET + + + Configure with: url, method, headers, body as needed + + + Write Python/JavaScript code with main() function + + + Use prompt_template with appropriate system/user messages + + Add warning to response explaining the fallback substitution + + + + +{existing_nodes_formatted} + + +{selected_nodes_formatted} + + + + + +```json +{{{{ + "intent": "generate", + "thinking": "Brief analysis of user request and approach", + "message": "User-friendly explanation of the workflow", + "mermaid": "flowchart TD\\n N1[\\"type=start|title=Start\\"]\\n ...", + "nodes": [ + {{{{ + "id": "node_id", + "type": "node_type", + "title": "Display Name", + "config": {{{{ /* REQUIRED: Fill all required_params from node schema */ }}}} + }}}} + ], + "edges": [{{{{"source": "node1_id", "target": "node2_id"}}}}], + "warnings": ["Any warnings about fallbacks or missing features"] +}}}} +``` + + +```json +{{{{ + "intent": "off_topic", + "message": "Explanation of what you can help with", + "suggestions": ["Workflow suggestion 1", "Suggestion 2"] +}}}} +``` + + + + + Weather queries, math calculations, jokes, general knowledge + Translation requests, general coding help, account/billing questions + Workflow creation, node configuration, automation design + Questions about Dify workflow capabilities + + + + Use `flowchart TD` for top-down flow + Node format: `ID["type=TYPE|title=TITLE"]` or `ID["type=tool|title=TITLE|tool=TOOL_KEY"]` + type= and title= are REQUIRED for EVERY node + Declare all nodes BEFORE edges + Use `-->` for connections, `-->|true|` and `-->|false|` for branches + + N1["type=start|title=Start"] + N2["type=http-request|title=Fetch Data"] + N3["type=tool|title=Search|tool=google/google_search"] + Start[Start] + N1["type=tool|title=Scrape|tool=webscraper"] + + + + + + {{{{ + "id": "fetch", + 
"type": "http-request", + "title": "Fetch Webpage", + "config": {{{{ + "url": "{{{{#start.url#}}}}", + "method": "GET", + "headers": "", + "params": "", + "body": {{{{"type": "none", "data": []}}}}, + "authorization": {{{{"type": "no-auth"}}}} + }}}} + }}}} + + + {{{{ + "id": "analyze", + "type": "llm", + "title": "Analyze Content", + "config": {{{{ + "model": {{{{"provider": "USE_FROM_AVAILABLE_MODELS", "name": "USE_FROM_AVAILABLE_MODELS", "mode": "chat"}}}}, + "prompt_template": [ + {{{{"role": "system", "text": "You are a helpful analyst."}}}}, + {{{{"role": "user", "text": "Analyze this content:\\n\\n{{{{#fetch.body#}}}}"}}}} + ] + }}}} + }}}} + NOTE: Replace "USE_FROM_AVAILABLE_MODELS" with actual values from available_models section! + + + {{{{ + "id": "process", + "type": "code", + "title": "Process Data", + "config": {{{{ + "language": "python3", + "code": "def main(data):\\n return {{\\"result\\": data.upper()}}" + }}}} + }}}} + + + + + CRITICAL: Auto-fill all parameters so workflow runs immediately + + + MUST define input variables for ALL data the workflow needs from user + Use extracted values from user instruction as "default" when available + + URL fetching workflow → add "url" variable with type="text-input" + Text processing workflow → add "content" or "query" variable + API integration → add "api_key" variable if authentication needed + + + {{{{ + "variables": [ + {{{{"variable": "url", "label": "Target URL", "type": "text-input", "required": true, "default": "https://..."}}}} + ] + }}}} + + + + + EVERY node parameter that needs data must reference a source: + - User input from start node → {{{{#start.variable_name#}}}} + - Output from previous node → {{{{#node_id.output_field#}}}} + - Or a concrete hardcoded value extracted from user instruction + NEVER leave parameters empty - always fill with variable reference or concrete value + + + + url: MUST be {{{{#start.url#}}}} OR concrete URL from instruction - NEVER empty + method: Set based on action 
(fetch/get → GET, send/post/create → POST, update → PUT, delete → DELETE) + headers: Include Authorization if API key is available + + + + prompt_template MUST reference previous node output for context + Example: {{{{"role": "user", "text": "Analyze this:\\n\\n{{{{#http_node.body#}}}}"}}}} + Include system message to set AI behavior/role + + + + variables array MUST include inputs from previous nodes + Example: {{{{"variable": "data", "value_selector": ["http_node", "body"]}}}} + code must define main() function that returns dict + + + + outputs MUST use value_selector to reference the final processing node's output + Example: {{{{"variable": "result", "value_selector": ["llm_node", "text"]}}}} + + + +{extracted_values_formatted} +""" + +VIBE_ENHANCED_USER_PROMPT = """ +{instruction} + + + + + Is this request about workflow/automation design? + - If NO (weather, jokes, math, translations, general questions) → return off_topic response + - If YES → proceed to Step 2 + + + + - What is the user trying to achieve? + - What inputs are needed (define in start node)? + - What processing steps are required? + - What outputs should be produced (define in end node)? + + + + - Check available_tools - which tools with status="configured" can be used? 
+ - For each required capability, check if a matching tool exists + - If NO matching tool: use fallback node (http-request, code, or llm) + - NEVER invent tool names - only use exact keys from available_tools + + + + - For EACH node, check its required_params in available_nodes + - Fill ALL required config fields with proper values + - Use {{{{#node_id.field#}}}} syntax to reference previous node outputs + - http-request MUST have: url, method + - code MUST have: code, language + - llm MUST have: prompt_template + + + + - Create mermaid flowchart with correct syntax + - Generate nodes array with complete config for each node + - Generate edges array connecting the nodes + - Add warnings if using fallback nodes + + + +{previous_attempt_formatted} + + +Generate your JSON response now. Remember: +1. Fill ALL required_params for each node type +2. Use variable references like {{{{#start.url#}}}} to connect nodes +3. Never invent tool names - use fallback nodes instead + +""" + + +def format_available_nodes(nodes: list[WorkflowNodeDict] | None) -> str: + """Format available nodes as XML with parameter schemas.""" + lines = [""] + + # First, add built-in nodes with their schemas + for node_type, schema in BUILTIN_NODE_SCHEMAS.items(): + lines.append(f' ') + lines.append(f" {schema.get('description', '')}") + + required = schema.get("required", []) + if required: + lines.append(f" {', '.join(required)}") + + params = schema.get("parameters", {}) + if params: + lines.append(" ") + for param_name, param_info in params.items(): + param_type = param_info.get("type", "string") + is_required = param_name in required + desc = param_info.get("description", "") + + if param_type == "enum": + options = param_info.get("options", []) + lines.append( + f' ' + f"{desc}" + ) + else: + lines.append( + f' {desc}' + ) + + # Add example if present + if "example" in param_info: + example = param_info["example"] + if isinstance(example, dict): + example = json.dumps(example) + lines.append(f" 
") + lines.append(" ") + + outputs = schema.get("outputs", []) + if outputs: + lines.append(f" {', '.join(outputs)}") + + lines.append(" ") + + # Add custom nodes from the provided list (without detailed schemas) + if nodes: + for node in nodes: + node_type = node.get("type", "unknown") + # Skip if already covered by built-in schemas + if node_type in BUILTIN_NODE_SCHEMAS: + continue + description = node.get("description", "No description") + lines.append(f' ') + lines.append(f" {description}") + lines.append(" ") + + lines.append("") + return "\n".join(lines) + + +def format_available_tools(tools: list[AvailableToolDict] | None) -> str: + """Format available tools as XML with parameter schemas.""" + lines = [""] + + if not tools: + lines.append(" ") + lines.append(" ") + lines.append("") + return "\n".join(lines) + + configured_tools: list[AvailableToolDict] = [] + unconfigured_tools: list[AvailableToolDict] = [] + + for tool in tools: + if tool.get("is_team_authorization", False): + configured_tools.append(tool) + else: + unconfigured_tools.append(tool) + + # Configured tools (ready to use) + lines.append(" ") + if configured_tools: + for tool in configured_tools: + tool_key = tool.get("tool_key") or f"{tool.get('provider_id')}/{tool.get('tool_name')}" + description = tool.get("tool_description") or tool.get("description", "") + lines.append(f' ') + lines.append(f" {description}") + + # Add parameter schemas if available + parameters = tool.get("parameters") + if parameters: + lines.append(" ") + for param in parameters: + param_name = param.get("name", "") + param_type = param.get("type", "string") + required = param.get("required", False) + param_desc = param.get("human_description") or param.get("llm_description") or "" + # Handle localized descriptions + if isinstance(param_desc, dict): + param_desc = param_desc.get("en_US") or param_desc.get("zh_Hans") or str(param_desc) + options = param.get("options", []) + + if options: + opt_str = 
def format_existing_nodes(nodes: list[WorkflowNodeDict] | None) -> str:
    """
    Render the nodes already on the canvas as a bullet list for the prompt.

    Each node becomes ``- [id] title (type)``; missing keys fall back to
    ``unknown`` / ``Untitled``. A fixed sentence is returned when there are no
    nodes at all.
    """
    if nodes:
        return "\n".join(
            f"- [{entry.get('id', 'unknown')}] {entry.get('title', 'Untitled')} ({entry.get('type', 'unknown')})"
            for entry in nodes
        )
    return "No existing nodes in workflow (creating from scratch)."


def format_selected_nodes(
    selected_ids: list[str] | None,
    existing_nodes: list[WorkflowNodeDict] | None,
) -> str:
    """
    Render the user-selected nodes as a bullet list for modification context.

    Selected ids are looked up among ``existing_nodes``; ids that cannot be
    resolved are still listed, flagged as not found. A fixed sentence is
    returned when nothing is selected.
    """
    if not selected_ids:
        return "No nodes selected (generating new workflow)."

    # Index the current workflow by node id (a later duplicate id wins).
    by_id = {}
    for candidate in existing_nodes or []:
        by_id[candidate.get("id")] = candidate

    rendered = []
    for wanted_id in selected_ids:
        try:
            found = by_id[wanted_id]
        except KeyError:
            rendered.append(f"- [{wanted_id}] (not found in current workflow)")
            continue
        rendered.append(f"- [{wanted_id}] {found.get('title', 'Untitled')} ({found.get('type', 'unknown')})")
    return "\n".join(rendered)
+ + When regenerating, we pass the previous workflow and warnings so the model + can fix specific issues instead of starting from scratch. + """ + if not regenerate_mode or not previous_workflow: + return "" + + nodes = previous_workflow.get("nodes", []) + edges = previous_workflow.get("edges", []) + warnings = previous_workflow.get("warnings", []) + + parts = [""] + parts.append(" ") + parts.append(" Your previous generation had issues. Please fix them while keeping the good parts.") + parts.append(" ") + + if warnings: + parts.append(" ") + for warning in warnings: + parts.append(f" - {warning}") + parts.append(" ") + + if nodes: + # Summarize nodes without full config to save tokens + parts.append(" ") + for node in nodes: + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + title = node.get("title", "Untitled") + config = node.get("config", {}) + + # Show key config issues for debugging + config_summary = "" + if node_type == "http-request": + url = config.get("url", "") + if not url: + config_summary = " (url: EMPTY - needs fix)" + elif url.startswith("{{#"): + config_summary = f" (url: {url})" + elif node_type == "tool": + tool_name = config.get("tool_name", "") + config_summary = f" (tool: {tool_name})" + + parts.append(f" - [{node_id}] {title} ({node_type}){config_summary}") + parts.append(" ") + + if edges: + parts.append(" ") + for edge in edges: + parts.append(f" - {edge.get('source', '?')} → {edge.get('target', '?')}") + parts.append(" ") + + parts.append(" ") + parts.append(" 1. Keep the workflow structure if it makes sense") + parts.append(" 2. Fix any invalid tool references - use http-request or code as fallback") + parts.append(" 3. Fill ALL required parameters (url, method, prompt_template, etc.)") + parts.append(" 4. Use {{#node_id.field#}} syntax for variable references") + parts.append(" 5. 
Define input variables in the Start node") + parts.append(" ") + parts.append("") + + return "\n".join(parts) + + +def format_available_models(models: list[AvailableModelDict] | None) -> str: + """Format available models as XML for prompt inclusion.""" + if not models: + return "\n \n" + + lines = [""] + for model in models: + provider = model.get("provider", "unknown") + model_name = model.get("model", "unknown") + lines.append(f' ') + lines.append("") + + # Add model selection rule with concrete example + lines.append("") + lines.append("") + lines.append(" CRITICAL: For LLM, question-classifier, and parameter-extractor nodes:") + lines.append(" - You MUST include a 'model' field in the config") + lines.append(" - You MUST use ONLY models from available_models above") + lines.append(" - NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4 unless they appear in available_models") + lines.append("") + + # Provide concrete JSON example to copy + first_model = models[0] + provider = first_model.get("provider", "unknown") + model_name = first_model.get("model", "unknown") + lines.append(" COPY THIS EXACT MODEL CONFIG for all LLM/question-classifier/parameter-extractor nodes:") + lines.append(f' "model": {{"provider": "{provider}", "name": "{model_name}", "mode": "chat"}}') + + if len(models) > 1: + lines.append("") + lines.append(" Alternative models you can use:") + for m in models[1:4]: # Show up to 3 alternatives + p = m.get("provider", "unknown") + n = m.get("model", "unknown") + lines.append(f' - "model": {{"provider": "{p}", "name": "{n}", "mode": "chat"}}') + + lines.append("") + + return "\n".join(lines) + + +def build_vibe_enhanced_prompt( + instruction: str, + available_nodes: list[WorkflowNodeDict] | None = None, + available_tools: list[AvailableToolDict] | None = None, + existing_nodes: list[WorkflowNodeDict] | None = None, + selected_node_ids: list[str] | None = None, + previous_workflow: WorkflowDataDict | None = None, + regenerate_mode: bool = False, + 
def parse_vibe_response(content: str) -> dict[str, Any]:
    """
    Parse LLM response into structured format.

    The JSON payload is taken from the first fenced code block when one is
    present, otherwise from the whole text. If strict parsing fails, one repair
    pass removes trailing commas before ``}`` / ``]`` and parsing is retried.

    Args:
        content: Raw text returned by the model.

    Returns:
        The parsed object with defaults filled in for its intent, or a dict with
        ``intent == "error"`` (plus ``error`` and ``raw_content``) when the text
        is not valid JSON or is JSON that is not an object.
    """
    # Extract JSON from markdown code block if present
    json_match = re.search(r"```(?:json)?\s*([\s\S]+?)```", content)
    if json_match:
        content = json_match.group(1).strip()

    # Try parsing JSON
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        # Attempt simple repair: remove trailing commas
        cleaned = re.sub(r",\s*([}\]])", r"\1", content)
        try:
            data = json.loads(cleaned)
        except json.JSONDecodeError:
            # Return error format
            return {
                "intent": "error",
                "error": "Failed to parse LLM response as JSON",
                "raw_content": content[:500],  # First 500 chars for debugging
            }

    # A syntactically valid array/number/string is still not a usable response;
    # report it the same way as a parse failure instead of failing on item access.
    if not isinstance(data, dict):
        return {
            "intent": "error",
            "error": "LLM response is not a JSON object",
            "raw_content": content[:500],
        }

    # Validate and normalize
    if "intent" not in data:
        data["intent"] = "generate"  # Default assumption

    # Ensure required fields for generate intent
    if data["intent"] == "generate":
        generate_defaults: dict[str, Any] = {
            "mermaid": "",
            "nodes": [],
            "edges": [],
            "message": "",
            "warnings": [],
        }
        for field_name, fallback in generate_defaults.items():
            # Treat an explicit JSON null like a missing field so consumers can
            # iterate nodes/edges/warnings without a None check.
            if data.get(field_name) is None:
                data[field_name] = fallback

    # Ensure required fields for off_topic intent
    if data["intent"] == "off_topic":
        data.setdefault("message", OFF_TOPIC_RESPONSES["default"]["en"])
        data.setdefault("suggestions", DEFAULT_SUGGESTIONS["en"])

    return data
def determine_fallback_type(tool_ref: str, node_title: str) -> str | None:
    """
    Pick a replacement node type for a tool that cannot be used.

    The tool reference and the node title are joined, lower-cased, and checked
    for the keywords listed in ``FALLBACK_RULES``. Rules are consulted in the
    mapping's iteration order and the first rule with any keyword occurring as
    a substring wins.

    Returns:
        The matching key of ``FALLBACK_RULES`` (a node type such as
        "http-request", "code" or "llm"), or None when no keyword matches -
        in which case no fallback is forced.
    """
    haystack = f"{tool_ref} {node_title}".lower()
    return next(
        (
            candidate_type
            for candidate_type, keywords in FALLBACK_RULES.items()
            if any(keyword in haystack for keyword in keywords)
        ),
        None,
    )
"body": body, + "authorization": {"type": "no-auth"}, + }, + } + + +def create_code_fallback(original_node: WorkflowNodeDict) -> WorkflowNodeDict: + """Create code fallback node with placeholder implementation.""" + title = original_node.get("title", "Process") + return { + "id": original_node.get("id", ""), + "type": "code", + "title": f"{title} (fallback)", + "config": { + "language": "python3", + "code": f'def main():\n # TODO: Implement "{title}" logic\n return {{"result": "placeholder"}}', + }, + } + + +def create_llm_fallback(original_node: WorkflowNodeDict) -> WorkflowNodeDict: + """Create LLM fallback node for text analysis tasks.""" + title = original_node.get("title", "Analyze") + return { + "id": original_node.get("id", ""), + "type": "llm", + "title": f"{title} (fallback)", + "config": { + "prompt_template": [ + {"role": "system", "text": "You are a helpful assistant."}, + {"role": "user", "text": f"Please help with: {title}"}, + ], + }, + } + + +def sanitize_tool_nodes( + nodes: list[WorkflowNodeDict], + available_tools: list[AvailableToolDict] | None, +) -> tuple[list[WorkflowNodeDict], list[str]]: + """ + Replace invalid tool nodes with fallback nodes (http-request or code). + + This is a safety net for when the LLM hallucinates tool names despite prompt instructions. 
+ + Returns: + tuple of (sanitized_nodes, warnings) + """ + if not nodes: + return [], [] + + # Build set of valid tool keys + valid_tool_keys: set[str] = set() + if available_tools: + for tool in available_tools: + provider = tool.get("provider_id") or tool.get("provider", "") + tool_key = tool.get("tool_key") or tool.get("tool_name", "") + if provider and tool_key: + valid_tool_keys.add(f"{provider}/{tool_key}") + if tool_key: + valid_tool_keys.add(tool_key) + + sanitized: list[WorkflowNodeDict] = [] + warnings: list[str] = [] + + for node in nodes: + if node.get("type") != "tool": + sanitized.append(node) + continue + + # Check if tool reference is valid + config = node.get("config", {}) + tool_ref = ( + config.get("tool_key") + or config.get("tool_name") + or config.get("provider_id", "") + "/" + config.get("tool_name", "") + ) + + # Normalize and check validity + normalized_refs = [tool_ref] + if "/" in tool_ref: + # Also check just the tool name part + normalized_refs.append(tool_ref.split("/")[-1]) + + is_valid = any(ref in valid_tool_keys for ref in normalized_refs if ref) + + if is_valid: + sanitized.append(node) + else: + # Determine the best fallback type based on tool semantics + node_title = node.get("title", "") + fallback_type = determine_fallback_type(tool_ref, node_title) + + if fallback_type == "http-request": + fallback_node = create_http_request_fallback(node) + sanitized.append(fallback_node) + warnings.append( + f"Tool '{tool_ref}' not found. Replaced with http-request node. " + "Please configure the URL if not set." + ) + elif fallback_type == "code": + fallback_node = create_code_fallback(node) + sanitized.append(fallback_node) + warnings.append( + f"Tool '{tool_ref}' not found. Replaced with code node. " + "Please implement the logic in the code editor." + ) + elif fallback_type == "llm": + fallback_node = create_llm_fallback(node) + sanitized.append(fallback_node) + warnings.append( + f"Tool '{tool_ref}' not found. Replaced with LLM node. 
Please configure the prompt template." + ) + else: + # No appropriate fallback - keep original node and warn + sanitized.append(node) + warnings.append( + f"Tool '{tool_ref}' not found and no suitable fallback determined. " + "Please configure a valid tool or replace this node manually." + ) + + return sanitized, warnings + + +def validate_node_parameters(nodes: list[WorkflowNodeDict]) -> list[str]: + """ + Validate that all required parameters are properly filled in generated nodes. + + Returns a list of warnings for nodes with missing or empty parameters. + """ + warnings: list[str] = [] + + for node in nodes: + node_id = node.get("id", "unknown") + node_type = node.get("type", "") + config = node.get("config", {}) + + if node_type == "http-request": + url = config.get("url", "") + if not url: + warnings.append(f"Node '{node_id}': http-request is missing required 'url' parameter") + elif url == "": + warnings.append(f"Node '{node_id}': http-request has empty 'url' - please configure") + method = config.get("method", "") + if not method: + warnings.append(f"Node '{node_id}': http-request should have 'method' (GET, POST, etc.)") + + elif node_type == "llm": + prompt_template = config.get("prompt_template", []) + if not prompt_template: + warnings.append(f"Node '{node_id}': LLM node is missing 'prompt_template'") + else: + # Check if any prompt references previous node output + has_reference = any("{{#" in p.get("text", "") for p in prompt_template if isinstance(p, dict)) + if not has_reference: + warnings.append( + f"Node '{node_id}': LLM prompt should reference previous node output " + "using {{#node_id.field#}} syntax" + ) + + elif node_type == "code": + code = config.get("code", "") + if not code: + warnings.append(f"Node '{node_id}': code node is missing 'code' parameter") + language = config.get("language", "") + if not language: + warnings.append(f"Node '{node_id}': code node should specify 'language' (python3 or javascript)") + + elif node_type == "start": + 
variables = config.get("variables", []) + if not variables: + warnings.append("Start node should define input variables for user data (e.g., url, query, content)") + + elif node_type == "end": + outputs = config.get("outputs", []) + if not outputs: + warnings.append("End node should define output variables to return workflow results") + + return warnings + + +def extract_mermaid_from_response(data: dict[str, Any]) -> str: + """Extract mermaid flowchart from parsed response.""" + mermaid = data.get("mermaid", "") + + if not mermaid: + return "" + + # Clean up mermaid code + mermaid = mermaid.strip() + # Remove code fence if present + if mermaid.startswith("```"): + match = re.search(r"```(?:mermaid)?\s*([\s\S]+?)```", mermaid) + if match: + mermaid = match.group(1).strip() + + # Sanitize edge labels to remove characters that break Mermaid parsing + # Edge labels in Mermaid are ONLY in the pattern: -->|label| + # We must NOT match |pipe| characters inside node labels like ["type=start|title=开始"] + def sanitize_edge_label(match: re.Match) -> str: + arrow = match.group(1) # --> or --- + label = match.group(2) # the label between pipes + # Remove or replace special characters that break Mermaid + # Parentheses, brackets, braces have special meaning in Mermaid + sanitized = re.sub(r"[(){}\[\]]", "", label) + return f"{arrow}|{sanitized}|" + + # Only match edge labels: --> or --- followed by |label| + # This pattern ensures we only sanitize actual edge labels, not node content + mermaid = re.sub(r"(-->|---)\|([^|]+)\|", sanitize_edge_label, mermaid) + + return mermaid + + +def classify_validation_errors( + nodes: list[dict[str, Any]], + available_models: list[dict[str, Any]] | None = None, + available_tools: list[dict[str, Any]] | None = None, + edges: list[dict[str, Any]] | None = None, +) -> dict[str, list[dict[str, Any]]]: + """ + Classify validation errors into fixable and user-required categories. + + This function uses the declarative rule engine to validate nodes. 
+ The rule engine provides deterministic, testable validation without + relying on LLM judgment. + + Fixable errors can be automatically corrected by the LLM in subsequent + iterations. User-required errors need manual intervention. + + Args: + nodes: List of generated workflow nodes + available_models: List of models the user has configured + available_tools: List of available tools + edges: List of edges connecting nodes + + Returns: + dict with: + - "fixable": errors that LLM can fix automatically + - "user_required": errors that need user intervention + - "all_warnings": combined warning messages for backwards compatibility + - "stats": validation statistics + """ + from core.workflow.generator.validation import ValidationContext, ValidationEngine + + # Build validation context + context = ValidationContext( + nodes=nodes, + edges=edges or [], + available_models=available_models or [], + available_tools=available_tools or [], + ) + + # Run validation through rule engine + engine = ValidationEngine() + result = engine.validate(context) + + # Convert to legacy format for backwards compatibility + fixable: list[dict[str, Any]] = [] + user_required: list[dict[str, Any]] = [] + + for error in result.fixable_errors: + fixable.append( + { + "node_id": error.node_id, + "node_type": error.node_type, + "error_type": error.rule_id, + "message": error.message, + "is_fixable": True, + "fix_hint": error.fix_hint, + "category": error.category.value, + "details": error.details, + } + ) + + for error in result.user_required_errors: + user_required.append( + { + "node_id": error.node_id, + "node_type": error.node_type, + "error_type": error.rule_id, + "message": error.message, + "is_fixable": False, + "fix_hint": error.fix_hint, + "category": error.category.value, + "details": error.details, + } + ) + + # Include warnings in user_required (they're non-blocking but informative) + for error in result.warnings: + user_required.append( + { + "node_id": error.node_id, + "node_type": 
error.node_type, + "error_type": error.rule_id, + "message": error.message, + "is_fixable": error.is_fixable, + "fix_hint": error.fix_hint, + "category": error.category.value, + "severity": "warning", + "details": error.details, + } + ) + + # Generate combined warnings for backwards compatibility + all_warnings = [e["message"] for e in fixable + user_required] + + return { + "fixable": fixable, + "user_required": user_required, + "all_warnings": all_warnings, + "stats": result.stats, + } + + +def build_fix_prompt( + fixable_errors: list[dict[str, Any]], + previous_nodes: list[dict[str, Any]], + available_models: list[dict[str, Any]] | None = None, +) -> str: + """ + Build a prompt for LLM to fix the identified errors. + + This creates a focused instruction that tells the LLM exactly what + to fix in the previous generation. + + Args: + fixable_errors: List of errors that can be automatically fixed + previous_nodes: The nodes from the previous generation attempt + available_models: Available models for model configuration fixes + + Returns: + Formatted prompt string for the fix iteration + """ + if not fixable_errors: + return "" + + parts = [""] + parts.append(" ") + parts.append(" Your previous generation has errors that need fixing.") + parts.append(" Please regenerate with the following corrections:") + parts.append(" ") + + # Group errors by node + errors_by_node: dict[str, list[dict[str, Any]]] = {} + for error in fixable_errors: + node_id = error["node_id"] + if node_id not in errors_by_node: + errors_by_node[node_id] = [] + errors_by_node[node_id].append(error) + + parts.append(" ") + for node_id, node_errors in errors_by_node.items(): + parts.append(f' ') + for error in node_errors: + error_type = error["error_type"] + message = error["message"] + fix_hint = error.get("fix_hint", "") + parts.append(f' ') + parts.append(f" {message}") + if fix_hint: + parts.append(f" {fix_hint}") + parts.append(" ") + parts.append(" ") + parts.append(" ") + + # Add model 
selection help if there are model-related errors + model_errors = [e for e in fixable_errors if "model" in e["error_type"]] + if model_errors and available_models: + parts.append(" ") + parts.append(" Use one of these models for nodes requiring model config:") + for model in available_models[:3]: # Show top 3 + provider = model.get("provider", "unknown") + name = model.get("model", "unknown") + parts.append(f' - {{"provider": "{provider}", "name": "{name}", "mode": "chat"}}') + parts.append(" ") + + # Add previous nodes summary for context + parts.append(" ") + for node in previous_nodes: + node_id = node.get("id", "unknown") + if node_id in errors_by_node: + # Only include nodes that have errors + node_type = node.get("type", "unknown") + title = node.get("title", "Untitled") + config_summary = json.dumps(node.get("config", {}), ensure_ascii=False)[:200] + parts.append(f' ') + parts.append(f" {config_summary}...") + parts.append(" ") + parts.append(" ") + + parts.append(" ") + parts.append(" 1. Keep the workflow structure and logic unchanged") + parts.append(" 2. Fix ONLY the errors listed above") + parts.append(" 3. Ensure all required fields are properly filled") + parts.append(" 4. 
Use variable references {{#node_id.field#}} where appropriate") + parts.append(" ") + parts.append("") + + return "\n".join(parts) diff --git a/api/core/workflow/generator/runner.py b/api/core/workflow/generator/runner.py new file mode 100644 index 0000000000..af00eae13b --- /dev/null +++ b/api/core/workflow/generator/runner.py @@ -0,0 +1,196 @@ +import json +import logging +import re +from collections.abc import Sequence + +import json_repair + +from core.model_manager import ModelManager +from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.workflow.generator.prompts.builder_prompts import BUILDER_SYSTEM_PROMPT, BUILDER_USER_PROMPT +from core.workflow.generator.prompts.planner_prompts import ( + PLANNER_SYSTEM_PROMPT, + PLANNER_USER_PROMPT, + format_tools_for_planner, +) +from core.workflow.generator.prompts.vibe_prompts import ( + format_available_models, + format_available_nodes, + format_available_tools, + parse_vibe_response, +) +from core.workflow.generator.utils.edge_repair import EdgeRepair +from core.workflow.generator.utils.mermaid_generator import generate_mermaid +from core.workflow.generator.utils.node_repair import NodeRepair +from core.workflow.generator.utils.workflow_validator import WorkflowValidator + +logger = logging.getLogger(__name__) + + +class WorkflowGenerator: + """ + Refactored Vibe Workflow Generator (Planner-Builder Architecture). + Extracts Vibe logic from the monolithic LLMGenerator. 
+ """ + + @classmethod + def generate_workflow_flowchart( + cls, + tenant_id: str, + instruction: str, + model_config: dict, + available_nodes: Sequence[dict[str, object]] | None = None, + existing_nodes: Sequence[dict[str, object]] | None = None, + available_tools: Sequence[dict[str, object]] | None = None, + selected_node_ids: Sequence[str] | None = None, + previous_workflow: dict[str, object] | None = None, + regenerate_mode: bool = False, + preferred_language: str | None = None, + available_models: Sequence[dict[str, object]] | None = None, + max_fix_iterations: int = 2, + ): + """ + Generates a Dify Workflow Flowchart from natural language instruction. + + Pipeline: + 1. Planner: Analyze intent & select tools. + 2. Context Filter: Filter relevant tools (reduce tokens). + 3. Builder: Generate node configurations. + 4. Repair: Fix common node/edge issues (NodeRepair, EdgeRepair). + 5. Validator: Check for errors & generate friendly hints. + 6. Renderer: Deterministic Mermaid generation. 
+ """ + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + model_parameters = model_config.get("completion_params", {}) + available_tools_list = list(available_tools) if available_tools else [] + + # --- STEP 1: PLANNER --- + planner_tools_context = format_tools_for_planner(available_tools_list) + planner_system = PLANNER_SYSTEM_PROMPT.format(tools_summary=planner_tools_context) + planner_user = PLANNER_USER_PROMPT.format(instruction=instruction) + + try: + response = model_instance.invoke_llm( + prompt_messages=[SystemPromptMessage(content=planner_system), UserPromptMessage(content=planner_user)], + model_parameters=model_parameters, + stream=False, + ) + plan_content = response.message.content + # Reuse parse_vibe_response logic or simple load + plan_data = parse_vibe_response(plan_content) + except Exception as e: + logger.exception("Planner failed") + return {"intent": "error", "error": f"Planning failed: {str(e)}"} + + if plan_data.get("intent") == "off_topic": + return { + "intent": "off_topic", + "message": plan_data.get("message", "I can only help with workflow creation."), + "suggestions": plan_data.get("suggestions", []), + } + + # --- STEP 2: CONTEXT FILTERING --- + required_tools = plan_data.get("required_tool_keys", []) + + filtered_tools = [] + if required_tools: + # Simple linear search (optimized version would use a map) + for tool in available_tools_list: + t_key = tool.get("tool_key") or tool.get("tool_name") + provider = tool.get("provider_id") or tool.get("provider") + full_key = f"{provider}/{t_key}" if provider else t_key + + # Check if this tool is in required list (match either full key or short name) + if t_key in required_tools or full_key in required_tools: + filtered_tools.append(tool) + else: + # If logic only, no tools needed + filtered_tools = [] + + # --- STEP 3: 
BUILDER --- + # Prepare context + tool_schemas = format_available_tools(filtered_tools) + # We need to construct a fake list structure for builtin nodes formatting if using format_available_nodes + # Actually format_available_nodes takes None to use defaults, or a list to add custom + # But we want to SHOW the builtins. format_available_nodes internally uses BUILTIN_NODE_SCHEMAS. + node_specs = format_available_nodes([]) + + builder_system = BUILDER_SYSTEM_PROMPT.format( + plan_context=json.dumps(plan_data.get("steps", []), indent=2), + tool_schemas=tool_schemas, + builtin_node_specs=node_specs, + available_models=format_available_models(list(available_models or [])), + ) + builder_user = BUILDER_USER_PROMPT.format(instruction=instruction) + + try: + build_res = model_instance.invoke_llm( + prompt_messages=[SystemPromptMessage(content=builder_system), UserPromptMessage(content=builder_user)], + model_parameters=model_parameters, + stream=False, + ) + # Builder output is raw JSON nodes/edges + build_content = build_res.message.content + match = re.search(r"```(?:json)?\s*([\s\S]+?)```", build_content) + if match: + build_content = match.group(1) + + workflow_data = json_repair.loads(build_content) + + if "nodes" not in workflow_data: + workflow_data["nodes"] = [] + if "edges" not in workflow_data: + workflow_data["edges"] = [] + + except Exception as e: + logger.exception("Builder failed") + return {"intent": "error", "error": f"Building failed: {str(e)}"} + + # --- STEP 3.4: NODE REPAIR --- + node_repair_result = NodeRepair.repair(workflow_data["nodes"]) + workflow_data["nodes"] = node_repair_result.nodes + + # --- STEP 3.5: EDGE REPAIR --- + repair_result = EdgeRepair.repair(workflow_data) + workflow_data = { + "nodes": repair_result.nodes, + "edges": repair_result.edges, + } + + # --- STEP 4: VALIDATOR --- + is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools_list) + + # --- STEP 5: RENDERER --- + mermaid_code = 
generate_mermaid(workflow_data) + + # --- FINALIZE --- + # Combine validation hints with repair warnings + all_warnings = [h.message for h in hints] + repair_result.warnings + node_repair_result.warnings + + # Add stability warning (as requested by user) + stability_warning = "The generated workflow may require debugging." + if preferred_language and preferred_language.startswith("zh"): + stability_warning = "生成的 Workflow 可能需要调试。" + all_warnings.append(stability_warning) + + all_fixes = repair_result.repairs_made + node_repair_result.repairs_made + + return { + "intent": "generate", + "flowchart": mermaid_code, + "nodes": workflow_data["nodes"], + "edges": workflow_data["edges"], + "message": plan_data.get("plan_thought", "Generated workflow based on your request."), + "warnings": all_warnings, + "tool_recommendations": [], # Legacy field + "error": "", + "fix_iterations": 0, # Legacy + "fixed_issues": all_fixes, # Track what was auto-fixed + } diff --git a/api/core/workflow/generator/types.py b/api/core/workflow/generator/types.py new file mode 100644 index 0000000000..fd46dc519d --- /dev/null +++ b/api/core/workflow/generator/types.py @@ -0,0 +1,217 @@ +""" +Type definitions for Vibe Workflow Generator. + +This module provides: +- TypedDict classes for lightweight type hints (no runtime overhead) +- Pydantic models for runtime validation where needed + +Usage: + # For type hints only (no runtime validation): + from core.workflow.generator.types import WorkflowNodeDict, WorkflowEdgeDict + + # For runtime validation: + from core.workflow.generator.types import WorkflowNode, WorkflowEdge +""" + +from typing import Any, TypedDict + +from pydantic import BaseModel, Field + +# ============================================================ +# TypedDict definitions (lightweight, for type hints only) +# ============================================================ + + +class WorkflowNodeDict(TypedDict, total=False): + """ + Workflow node structure (TypedDict for hints). 
+ + Attributes: + id: Unique node identifier + type: Node type (e.g., "start", "end", "llm", "if-else", "http-request") + title: Human-readable node title + config: Node-specific configuration + data: Additional node data + """ + + id: str + type: str + title: str + config: dict[str, Any] + data: dict[str, Any] + + +class WorkflowEdgeDict(TypedDict, total=False): + """ + Workflow edge structure (TypedDict for hints). + + Attributes: + source: Source node ID + target: Target node ID + sourceHandle: Branch handle for if-else/question-classifier nodes + """ + + source: str + target: str + sourceHandle: str + + +class AvailableModelDict(TypedDict): + """ + Available model structure. + + Attributes: + provider: Model provider (e.g., "openai", "anthropic") + model: Model name (e.g., "gpt-4", "claude-3") + """ + + provider: str + model: str + + +class ToolParameterDict(TypedDict, total=False): + """ + Tool parameter structure. + + Attributes: + name: Parameter name + type: Parameter type (e.g., "string", "number", "boolean") + required: Whether parameter is required + human_description: Human-readable description + llm_description: LLM-oriented description + options: Available options for enum-type parameters + """ + + name: str + type: str + required: bool + human_description: str | dict[str, str] + llm_description: str + options: list[Any] + + +class AvailableToolDict(TypedDict, total=False): + """ + Available tool structure. 
+ + Attributes: + provider_id: Tool provider ID + provider: Tool provider name (alternative to provider_id) + tool_key: Unique tool key + tool_name: Tool name (alternative to tool_key) + tool_description: Tool description + description: Alternative description field + is_team_authorization: Whether tool is configured/authorized + parameters: List of tool parameters + """ + + provider_id: str + provider: str + tool_key: str + tool_name: str + tool_description: str + description: str + is_team_authorization: bool + parameters: list[ToolParameterDict] + + +class WorkflowDataDict(TypedDict, total=False): + """ + Complete workflow data structure. + + Attributes: + nodes: List of workflow nodes + edges: List of workflow edges + warnings: List of warning messages + """ + + nodes: list[WorkflowNodeDict] + edges: list[WorkflowEdgeDict] + warnings: list[str] + + +# ============================================================ +# Pydantic models (for runtime validation) +# ============================================================ + + +class WorkflowNode(BaseModel): + """ + Workflow node with runtime validation. + + Use this model when you need to validate node data at runtime. + For lightweight type hints without validation, use WorkflowNodeDict. + """ + + id: str + type: str + title: str = "" + config: dict[str, Any] = Field(default_factory=dict) + data: dict[str, Any] = Field(default_factory=dict) + + +class WorkflowEdge(BaseModel): + """ + Workflow edge with runtime validation. + + Use this model when you need to validate edge data at runtime. + For lightweight type hints without validation, use WorkflowEdgeDict. + """ + + source: str + target: str + sourceHandle: str | None = None + + +class AvailableModel(BaseModel): + """ + Available model with runtime validation. + + Use this model when you need to validate model data at runtime. + For lightweight type hints without validation, use AvailableModelDict. 
+ """ + + provider: str + model: str + + +class ToolParameter(BaseModel): + """Tool parameter with runtime validation.""" + + name: str = "" + type: str = "string" + required: bool = False + human_description: str | dict[str, str] = "" + llm_description: str = "" + options: list[Any] = Field(default_factory=list) + + +class AvailableTool(BaseModel): + """ + Available tool with runtime validation. + + Use this model when you need to validate tool data at runtime. + For lightweight type hints without validation, use AvailableToolDict. + """ + + provider_id: str = "" + provider: str = "" + tool_key: str = "" + tool_name: str = "" + tool_description: str = "" + description: str = "" + is_team_authorization: bool = False + parameters: list[ToolParameter] = Field(default_factory=list) + + +class WorkflowData(BaseModel): + """ + Complete workflow data with runtime validation. + + Use this model when you need to validate workflow data at runtime. + For lightweight type hints without validation, use WorkflowDataDict. + """ + + nodes: list[WorkflowNode] = Field(default_factory=list) + edges: list[WorkflowEdge] = Field(default_factory=list) + warnings: list[str] = Field(default_factory=list) diff --git a/api/core/workflow/generator/utils/edge_repair.py b/api/core/workflow/generator/utils/edge_repair.py new file mode 100644 index 0000000000..00f91f8a85 --- /dev/null +++ b/api/core/workflow/generator/utils/edge_repair.py @@ -0,0 +1,359 @@ +""" +Edge Repair Utility for Vibe Workflow Generation. + +This module provides intelligent edge repair capabilities for generated workflows. +It can detect and fix common edge issues: +- Missing edges between sequential nodes +- Incomplete branches for question-classifier and if-else nodes +- Orphaned nodes without connections + +The repair logic is deterministic and doesn't require LLM calls. 
+""" + +import logging +from dataclasses import dataclass, field + +from core.workflow.generator.types import WorkflowDataDict, WorkflowEdgeDict, WorkflowNodeDict + +logger = logging.getLogger(__name__) + + +@dataclass +class RepairResult: + """Result of edge repair operation.""" + + nodes: list[WorkflowNodeDict] + edges: list[WorkflowEdgeDict] + repairs_made: list[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + + @property + def was_repaired(self) -> bool: + """Check if any repairs were made.""" + return len(self.repairs_made) > 0 + + +class EdgeRepair: + """ + Intelligent edge repair for workflow graphs. + + Repairs are applied in order: + 1. Infer linear connections from node order (if no edges exist) + 2. Add missing branch edges for question-classifier + 3. Add missing branch edges for if-else + 4. Connect orphaned nodes + """ + + @classmethod + def repair(cls, workflow_data: WorkflowDataDict) -> RepairResult: + """ + Repair edges in the workflow data. + + Args: + workflow_data: Dict containing 'nodes' and 'edges' + + Returns: + RepairResult with repaired nodes, edges, and repair logs + """ + nodes = list(workflow_data.get("nodes", [])) + edges = list(workflow_data.get("edges", [])) + repairs: list[str] = [] + warnings: list[str] = [] + + # Build node lookup + node_map = {n.get("id"): n for n in nodes if n.get("id")} + node_ids = set(node_map.keys()) + + # 1. If no edges at all, infer linear chain + if not edges and len(nodes) > 1: + edges, inferred_repairs = cls._infer_linear_chain(nodes) + repairs.extend(inferred_repairs) + + # 2. Build edge index for analysis + outgoing_edges: dict[str, list[WorkflowEdgeDict]] = {} + incoming_edges: dict[str, list[WorkflowEdgeDict]] = {} + for edge in edges: + src = edge.get("source") + tgt = edge.get("target") + if src: + outgoing_edges.setdefault(src, []).append(edge) + if tgt: + incoming_edges.setdefault(tgt, []).append(edge) + + # 3. 
Repair question-classifier branches + for node in nodes: + if node.get("type") == "question-classifier": + new_edges, branch_repairs, branch_warnings = cls._repair_classifier_branches( + node, edges, outgoing_edges, node_ids + ) + edges.extend(new_edges) + repairs.extend(branch_repairs) + warnings.extend(branch_warnings) + # Update outgoing index + for edge in new_edges: + outgoing_edges.setdefault(edge.get("source"), []).append(edge) + + # 4. Repair if-else branches + for node in nodes: + if node.get("type") == "if-else": + new_edges, branch_repairs, branch_warnings = cls._repair_if_else_branches( + node, edges, outgoing_edges, node_ids + ) + edges.extend(new_edges) + repairs.extend(branch_repairs) + warnings.extend(branch_warnings) + # Update outgoing index + for edge in new_edges: + outgoing_edges.setdefault(edge.get("source"), []).append(edge) + + # 5. Connect orphaned nodes (nodes with no incoming edge, except start) + new_edges, orphan_repairs = cls._connect_orphaned_nodes(nodes, edges, outgoing_edges, incoming_edges) + edges.extend(new_edges) + repairs.extend(orphan_repairs) + + # 6. Connect nodes with no outgoing edge to 'end' (except end nodes) + new_edges, terminal_repairs = cls._connect_terminal_nodes(nodes, edges, outgoing_edges) + edges.extend(new_edges) + repairs.extend(terminal_repairs) + + return RepairResult( + nodes=nodes, + edges=edges, + repairs_made=repairs, + warnings=warnings, + ) + + @classmethod + def _infer_linear_chain(cls, nodes: list[WorkflowNodeDict]) -> tuple[list[WorkflowEdgeDict], list[str]]: + """ + Infer a linear chain of edges from node order. + + This is used when no edges are provided at all. 
+ """ + edges: list[WorkflowEdgeDict] = [] + repairs: list[str] = [] + + # Filter to get ordered node IDs + node_ids = [n.get("id") for n in nodes if n.get("id")] + + if len(node_ids) < 2: + return edges, repairs + + # Create edges between consecutive nodes + for i in range(len(node_ids) - 1): + src = node_ids[i] + tgt = node_ids[i + 1] + edges.append({"source": src, "target": tgt}) + repairs.append(f"Inferred edge: {src} -> {tgt}") + + return edges, repairs + + @classmethod + def _repair_classifier_branches( + cls, + node: WorkflowNodeDict, + edges: list[WorkflowEdgeDict], + outgoing_edges: dict[str, list[WorkflowEdgeDict]], + valid_node_ids: set[str], + ) -> tuple[list[WorkflowEdgeDict], list[str], list[str]]: + """ + Repair missing branches for question-classifier nodes. + + For each class that doesn't have an edge, create one pointing to 'end'. + """ + new_edges: list[WorkflowEdgeDict] = [] + repairs: list[str] = [] + warnings: list[str] = [] + + node_id = node.get("id") + if not node_id: + return new_edges, repairs, warnings + + config = node.get("config", {}) + classes = config.get("classes", []) + + if not classes: + return new_edges, repairs, warnings + + # Get existing sourceHandles for this node + existing_handles = set() + for edge in outgoing_edges.get(node_id, []): + handle = edge.get("sourceHandle") + if handle: + existing_handles.add(handle) + + # Find 'end' node as default target + end_node_id = "end" + if "end" not in valid_node_ids: + # Try to find an end node + for nid in valid_node_ids: + if "end" in nid.lower(): + end_node_id = nid + break + + # Add missing branches + for cls_def in classes: + if not isinstance(cls_def, dict): + continue + cls_id = cls_def.get("id") + cls_name = cls_def.get("name", cls_id) + + if cls_id and cls_id not in existing_handles: + new_edge = { + "source": node_id, + "sourceHandle": cls_id, + "target": end_node_id, + } + new_edges.append(new_edge) + repairs.append(f"Added missing branch edge for class '{cls_name}' -> 
{end_node_id}") + warnings.append( + f"Auto-connected question-classifier branch '{cls_name}' to '{end_node_id}'. " + "You may want to redirect this to a specific handler node." + ) + + return new_edges, repairs, warnings + + @classmethod + def _repair_if_else_branches( + cls, + node: WorkflowNodeDict, + edges: list[WorkflowEdgeDict], + outgoing_edges: dict[str, list[WorkflowEdgeDict]], + valid_node_ids: set[str], + ) -> tuple[list[WorkflowEdgeDict], list[str], list[str]]: + """ + Repair missing true/false branches for if-else nodes. + """ + new_edges: list[WorkflowEdgeDict] = [] + repairs: list[str] = [] + warnings: list[str] = [] + + node_id = node.get("id") + if not node_id: + return new_edges, repairs, warnings + + # Get existing sourceHandles + existing_handles = set() + for edge in outgoing_edges.get(node_id, []): + handle = edge.get("sourceHandle") + if handle: + existing_handles.add(handle) + + # Find 'end' node as default target + end_node_id = "end" + if "end" not in valid_node_ids: + for nid in valid_node_ids: + if "end" in nid.lower(): + end_node_id = nid + break + + # Add missing branches + required_branches = ["true", "false"] + for branch in required_branches: + if branch not in existing_handles: + new_edge = { + "source": node_id, + "sourceHandle": branch, + "target": end_node_id, + } + new_edges.append(new_edge) + repairs.append(f"Added missing if-else '{branch}' branch -> {end_node_id}") + warnings.append( + f"Auto-connected if-else '{branch}' branch to '{end_node_id}'. " + "You may want to redirect this to a specific handler node." + ) + + return new_edges, repairs, warnings + + @classmethod + def _connect_orphaned_nodes( + cls, + nodes: list[WorkflowNodeDict], + edges: list[WorkflowEdgeDict], + outgoing_edges: dict[str, list[WorkflowEdgeDict]], + incoming_edges: dict[str, list[WorkflowEdgeDict]], + ) -> tuple[list[WorkflowEdgeDict], list[str]]: + """ + Connect orphaned nodes to the previous node in sequence. 
+ + An orphaned node has no incoming edges and is not a 'start' node. + """ + new_edges: list[WorkflowEdgeDict] = [] + repairs: list[str] = [] + + node_ids = [n.get("id") for n in nodes if n.get("id")] + node_types = {n.get("id"): n.get("type") for n in nodes} + + for i, node_id in enumerate(node_ids): + node_type = node_types.get(node_id) + + # Skip start nodes - they don't need incoming edges + if node_type == "start": + continue + + # Check if node has incoming edges + if node_id not in incoming_edges or not incoming_edges[node_id]: + # Find previous node to connect from + if i > 0: + prev_node_id = node_ids[i - 1] + new_edge = {"source": prev_node_id, "target": node_id} + new_edges.append(new_edge) + repairs.append(f"Connected orphaned node: {prev_node_id} -> {node_id}") + + # Update incoming_edges for subsequent checks + incoming_edges.setdefault(node_id, []).append(new_edge) + + return new_edges, repairs + + @classmethod + def _connect_terminal_nodes( + cls, + nodes: list[WorkflowNodeDict], + edges: list[WorkflowEdgeDict], + outgoing_edges: dict[str, list[WorkflowEdgeDict]], + ) -> tuple[list[WorkflowEdgeDict], list[str]]: + """ + Connect terminal nodes (no outgoing edges) to 'end'. + + A terminal node has no outgoing edges and is not an 'end' node. + This ensures all branches eventually reach 'end'. 
def generate_mermaid(workflow_data: WorkflowDataDict) -> str:
    """
    Render workflow nodes and edges as Mermaid flowchart text.

    Each node with a truthy ``id`` becomes ``id["type=TYPE|title=TITLE"]``;
    tool nodes additionally carry ``|tool=TOOL_REF``. Each edge whose source
    and target are both rendered nodes becomes ``source --> target``, labelled
    with the edge's ``sourceHandle`` unless that is missing or ``"source"``.

    Args:
        workflow_data: Mapping with optional 'nodes' and 'edges' lists.

    Returns:
        The Mermaid source, one statement per line, starting with ``flowchart TD``.
    """
    nodes = workflow_data.get("nodes", [])
    edges = workflow_data.get("edges", [])

    lines = ["flowchart TD"]

    # Original node ID -> ID that is safe to emit in Mermaid.
    id_map: dict[str, str] = {}

    def get_safe_id(original_id: str) -> str:
        # 'end' and 'subgraph' are Mermaid keywords; every other ID is trusted as-is.
        if original_id == "end":
            return "end_node"
        if original_id == "subgraph":
            return "subgraph_node"
        return original_id

    def quote_safe(value: object) -> str:
        # Labels are wrapped in double quotes, so embedded ones must not survive.
        # str() keeps non-string values (numbers, etc.) from raising.
        return str(value).replace('"', "'")

    # 1. Node definitions.
    for node in nodes:
        node_id = node.get("id")
        if not node_id:
            continue

        safe_id = get_safe_id(node_id)
        id_map[node_id] = safe_id

        node_type = node.get("type", "unknown")

        # An explicit ``"title": None`` is treated like a missing title.
        title = node.get("title")
        if title is None:
            title = "Untitled"
        safe_title = quote_safe(title)

        if node_type == "tool":
            config = node.get("config") or {}
            # The tool reference may live under several keys; first truthy one wins.
            tool_ref = (
                config.get("tool_key")
                or config.get("tool")
                or config.get("tool_name")
                or node.get("tool_name")
                or "unknown"
            )
            node_def = f'{safe_id}["type={node_type}|title={safe_title}|tool={quote_safe(tool_ref)}"]'
        else:
            node_def = f'{safe_id}["type={node_type}|title={safe_title}"]'

        lines.append(f" {node_def}")

    # 2. Edge definitions. Edges touching an undefined node are dropped so the
    # diagram never references a node it did not declare.
    for edge in edges:
        source = edge.get("source")
        target = edge.get("target")

        if not source or not target:
            continue
        if source not in id_map or target not in id_map:
            continue

        # sourceHandle distinguishes branches (true/false, classifier classes);
        # the plain "source" handle carries no information and gets no label.
        source_handle = edge.get("sourceHandle")
        label = ""
        if source_handle and source_handle != "source":
            label = f"|{quote_safe(source_handle)}|"

        lines.append(f" {id_map[source]} -->{label} {id_map[target]}")

    return "\n".join(lines)
class NodeRepair:
    """
    Repairs common configuration mistakes in generated workflow nodes.

    Currently handled:
    - if-else: ASCII comparison operators are normalized via ``OPERATOR_MAP``.
    - variable-aggregator: ``variables`` entries are converted to selector
      lists and entries that cannot be converted are dropped.
    """

    # ASCII operator spelling -> spelling written back into the condition.
    OPERATOR_MAP = {
        ">=": "≥",
        "<=": "≤",
        "!=": "≠",
        "==": "=",
    }

    @classmethod
    def repair(cls, nodes: list[WorkflowNodeDict]) -> NodeRepairResult:
        """
        Repair node configurations.

        Args:
            nodes: List of node dictionaries. The input is not mutated.

        Returns:
            NodeRepairResult holding the repaired copies and a log of repairs.
        """
        # Deep copy so callers keep their original node dicts untouched.
        nodes = copy.deepcopy(nodes)
        repairs: list[str] = []
        warnings: list[str] = []

        for node in nodes:
            node_type = node.get("type")
            if node_type == "if-else":
                cls._repair_if_else_operators(node, repairs)
            elif node_type == "variable-aggregator":
                cls._repair_variable_aggregator_variables(node, repairs)
            # Add other node type repairs here as needed

        return NodeRepairResult(
            nodes=nodes,
            repairs_made=repairs,
            warnings=warnings,
        )

    @classmethod
    def _repair_if_else_operators(cls, node: WorkflowNodeDict, repairs: list[str]):
        """
        Normalize comparison operators in an if-else node, in place.

        Args:
            node: The if-else node; ``config.cases[*].conditions[*]`` is rewritten.
            repairs: Log list that receives one entry per rewritten operator.
        """
        node_id = node.get("id", "unknown")
        # `or` also covers explicit None values for config / cases / conditions.
        config = node.get("config") or {}

        for case in config.get("cases") or []:
            for condition in case.get("conditions") or []:
                op = condition.get("comparison_operator")
                # Only strings can be map keys we care about; this also avoids
                # a TypeError on unhashable operator values.
                new_op = cls.OPERATOR_MAP.get(op) if isinstance(op, str) else None
                if new_op is None:
                    continue
                condition["comparison_operator"] = new_op
                repairs.append(f"Normalized operator '{op}' to '{new_op}' in node '{node_id}'")

    @classmethod
    def _repair_variable_aggregator_variables(cls, node: WorkflowNodeDict, repairs: list[str]):
        """
        Repair the ``variables`` list of a variable-aggregator node, in place.

        Expected format: ``[["node_id", "field"], ["node_id2", "field2"]]``.
        Dict entries are converted using, in order, ``value_selector`` /
        ``selector`` / ``path`` (a non-empty list) or a ``name`` of the form
        ``"node_id.field"``. Entries that cannot be converted are logged and
        removed.

        Args:
            node: The variable-aggregator node; ``config.variables`` is rewritten.
            repairs: Log list that receives one entry when anything changed.
        """
        node_id = node.get("id", "unknown")
        config = node.get("config") or {}
        variables = config.get("variables") or []

        if not variables:
            return

        # True as soon as any entry is converted OR dropped, so that dropped
        # entries are actually removed from the stored config.
        changed = False
        repaired_variables = []

        for var in variables:
            if isinstance(var, list):
                # Already in the expected format.
                repaired_variables.append(var)
                continue

            changed = True

            if not isinstance(var, dict):
                logger.warning("Variable aggregator node '%s' has unknown variable format: %s", node_id, var)
                continue

            value_selector = var.get("value_selector") or var.get("selector") or var.get("path")
            if isinstance(value_selector, list) and len(value_selector) > 0:
                repaired_variables.append(value_selector)
                continue

            # The LLM may generate {"name": "node_id.field"} instead of a selector.
            name = var.get("name")
            if isinstance(name, str) and "." in name:
                source_id, field_name = name.split(".", 1)
                repaired_variables.append([source_id, field_name])
                continue

            logger.warning(
                "Variable aggregator node '%s' has invalid variable format: %s",
                node_id,
                var,
            )

        if changed:
            config["variables"] = repaired_variables
            repairs.append(f"Repaired variable-aggregator variables format in node '{node_id}'")
class WorkflowValidator:
    """
    Backward-compatible facade over ValidationEngine.

    Runs the engine on generated workflow data (nodes and edges) and reports
    the outcome as legacy ValidationHint objects.
    """

    @classmethod
    def validate(
        cls,
        workflow_data: WorkflowDataDict,
        available_tools: list[AvailableToolDict],
        available_models: list[AvailableModelDict] | None = None,
    ) -> tuple[bool, list[ValidationHint]]:
        """
        Validate workflow data and return validity status and hints.

        Args:
            workflow_data: Dict containing 'nodes' and 'edges'
            available_tools: List of available tool configurations
            available_models: List of available models; None is treated as empty

        Returns:
            ``(is_valid, hints)`` where ``is_valid`` is the engine result's
            ``is_valid`` and ``hints`` has one entry per engine error.
        """
        context = ValidationContext(
            nodes=workflow_data.get("nodes", []),
            edges=workflow_data.get("edges", []),
            available_models=available_models or [],
            available_tools=available_tools or [],
        )
        outcome = ValidationEngine().validate(context)

        # Translate each engine error into the legacy hint shape. The field
        # name is only known when the rule stored it under details["field"].
        hints: list[ValidationHint] = [
            ValidationHint(
                node_id=err.node_id,
                field=err.details.get("field", "unknown"),
                message=err.message,
                severity="error" if err.severity == Severity.ERROR else "warning",
                suggestion=err.fix_hint,
                node_type=err.node_type,
            )
            for err in outcome.all_errors
        ]

        return outcome.is_valid, hints
@dataclass
class ValidationContext:
    """
    Everything a validation rule needs to inspect a generated workflow.

    Holds the nodes and edges under validation plus the external resources
    (models, tools) they may reference. Lookup structures derived from those
    lists are built on first access and then reused.
    """

    # Generated workflow data
    nodes: list[WorkflowNodeDict] = field(default_factory=list)
    edges: list[WorkflowEdgeDict] = field(default_factory=list)

    # Available external resources
    available_models: list[AvailableModelDict] = field(default_factory=list)
    available_tools: list[AvailableToolDict] = field(default_factory=list)

    # Lazily populated caches backing the properties below
    _node_map: dict[str, WorkflowNodeDict] | None = field(default=None, repr=False)
    _model_set: set[tuple[str, str]] | None = field(default=None, repr=False)
    _tool_set: set[str] | None = field(default=None, repr=False)
    _configured_tool_set: set[str] | None = field(default=None, repr=False)

    @staticmethod
    def _tool_keys(tools) -> set[str]:
        """Collect the bare key and, when a provider is known, 'provider/key' for each tool."""
        keys: set[str] = set()
        for tool in tools:
            provider = tool.get("provider_id") or tool.get("provider", "")
            tool_key = tool.get("tool_key") or tool.get("tool_name", "")
            if not tool_key:
                continue
            keys.add(tool_key)
            if provider:
                keys.add(f"{provider}/{tool_key}")
        return keys

    @property
    def node_map(self) -> dict[str, WorkflowNodeDict]:
        """Map of node_id -> node; nodes without an ID are stored under ''."""
        if self._node_map is None:
            self._node_map = {entry.get("id", ""): entry for entry in self.nodes}
        return self._node_map

    @property
    def model_set(self) -> set[tuple[str, str]]:
        """Set of (provider, model_name) tuples for the available models."""
        if self._model_set is None:
            self._model_set = {
                (model.get("provider", ""), model.get("model", "")) for model in self.available_models
            }
        return self._model_set

    @property
    def tool_set(self) -> set[str]:
        """Keys of all available tools, configured or not."""
        if self._tool_set is None:
            self._tool_set = self._tool_keys(self.available_tools)
        return self._tool_set

    @property
    def configured_tool_set(self) -> set[str]:
        """Keys of the tools whose 'is_team_authorization' flag is truthy."""
        if self._configured_tool_set is None:
            authorized = (tool for tool in self.available_tools if tool.get("is_team_authorization", False))
            self._configured_tool_set = self._tool_keys(authorized)
        return self._configured_tool_set

    def has_model(self, provider: str, model_name: str) -> bool:
        """Check if a model is available."""
        return (provider, model_name) in self.model_set

    def has_tool(self, tool_key: str) -> bool:
        """Check if a tool exists (configured or not)."""
        return tool_key in self.tool_set

    def is_tool_configured(self, tool_key: str) -> bool:
        """Check if a tool is configured and ready to use."""
        return tool_key in self.configured_tool_set

    def get_node(self, node_id: str) -> WorkflowNodeDict | None:
        """Return the node with this ID, or None when there is none."""
        return self.node_map.get(node_id)

    def get_node_ids(self) -> set[str]:
        """Return the IDs of all nodes in the workflow."""
        return set(self.node_map)

    def get_upstream_nodes(self, node_id: str) -> list[str]:
        """Return the source IDs of every edge that targets this node."""
        return [link.get("source", "") for link in self.edges if link.get("target") == node_id]

    def get_downstream_nodes(self, node_id: str) -> list[str]:
        """Return the target IDs of every edge that leaves this node."""
        return [link.get("target", "") for link in self.edges if link.get("source") == node_id]
@dataclass
class ValidationResult:
    """
    Outcome of a validation run, with errors bucketed by how they can be resolved.

    Attributes:
        all_errors: Every validation error found, warnings included
        fixable_errors: Errors that LLM can automatically fix
        user_required_errors: Errors that require user intervention
        warnings: Non-blocking warnings
        stats: Validation statistics
    """

    all_errors: list[ValidationError] = field(default_factory=list)
    fixable_errors: list[ValidationError] = field(default_factory=list)
    user_required_errors: list[ValidationError] = field(default_factory=list)
    warnings: list[ValidationError] = field(default_factory=list)
    stats: dict[str, int] = field(default_factory=dict)

    @property
    def has_errors(self) -> bool:
        """True when at least one fixable or user-required error exists (warnings don't count)."""
        return bool(self.fixable_errors) or bool(self.user_required_errors)

    @property
    def has_fixable_errors(self) -> bool:
        """True when at least one fixable error exists."""
        return bool(self.fixable_errors)

    @property
    def is_valid(self) -> bool:
        """True when validation passed; warnings alone do not fail it."""
        return not self.has_errors

    def to_dict(self) -> dict[str, Any]:
        """Serialize the result for an API response."""
        payload: dict[str, Any] = {}
        payload["fixable"] = [err.to_dict() for err in self.fixable_errors]
        payload["user_required"] = [err.to_dict() for err in self.user_required_errors]
        payload["warnings"] = [err.to_dict() for err in self.warnings]
        # Despite the key name, this carries the message of every entry in all_errors.
        payload["all_warnings"] = self.get_error_messages()
        payload["stats"] = self.stats
        return payload

    def get_error_messages(self) -> list[str]:
        """Return the message of every entry in ``all_errors``, in order."""
        return [err.message for err in self.all_errors]

    def get_fixable_by_node(self) -> dict[str, list[ValidationError]]:
        """Group fixable errors by node ID, preserving their order."""
        grouped: dict[str, list[ValidationError]] = {}
        for err in self.fixable_errors:
            grouped.setdefault(err.node_id, []).append(err)
        return grouped
Exception: + logger.exception( + "Rule '%s' failed for node '%s'", + rule.id, + node_id, + ) + # Don't let a rule failure break the entire validation + continue + + # Validate edges separately + edge_errors = self._validate_edges(context) + for error in edge_errors: + result.all_errors.append(error) + stats["total_errors"] += 1 + if error.is_fixable: + result.fixable_errors.append(error) + stats["fixable_count"] += 1 + else: + result.user_required_errors.append(error) + stats["user_required_count"] += 1 + + result.stats = stats + + return result + + def _validate_edges(self, context: ValidationContext) -> list[ValidationError]: + """Validate edge connections.""" + errors: list[ValidationError] = [] + valid_node_ids = context.get_node_ids() + + for edge in context.edges: + source = edge.get("source", "") + target = edge.get("target", "") + + if source and source not in valid_node_ids: + errors.append( + ValidationError( + rule_id="edge.source.invalid", + node_id=source, + node_type="edge", + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Edge source '{source}' does not exist", + fix_hint="Update edge to reference existing node", + ) + ) + + if target and target not in valid_node_ids: + errors.append( + ValidationError( + rule_id="edge.target.invalid", + node_id=target, + node_type="edge", + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Edge target '{target}' does not exist", + fix_hint="Update edge to reference existing node", + ) + ) + + return errors + + def validate_single_node( + self, + node: WorkflowNodeDict, + context: ValidationContext, + ) -> list[ValidationError]: + """ + Validate a single node. + + Useful for incremental validation when a node is added/modified. 
def validate_nodes(
    nodes: list[WorkflowNodeDict],
    edges: list[WorkflowEdgeDict] | None = None,
    available_models: list[AvailableModelDict] | None = None,
    available_tools: list[AvailableToolDict] | None = None,
) -> ValidationResult:
    """
    Validate nodes in one call, without building the engine or context by hand.

    Args:
        nodes: List of workflow nodes to validate
        edges: Optional list of edges; None is treated as no edges
        available_models: Optional list of available models; None is treated as empty
        available_tools: Optional list of available tools; None is treated as empty

    Returns:
        ValidationResult with classified errors
    """
    return ValidationEngine().validate(
        ValidationContext(
            nodes=nodes,
            edges=edges or [],
            available_models=available_models or [],
            available_tools=available_tools or [],
        )
    )
@dataclass
class ValidationError:
    """
    A single problem reported by a validation rule.

    Attributes:
        rule_id: The ID of the rule that generated this error
        node_id: The ID of the node with the error
        node_type: The type of the node
        category: The rule category
        severity: Error severity
        is_fixable: Whether LLM can auto-fix this error
        message: Human-readable error message
        fix_hint: Hint for LLM to fix the error
        details: Additional error details
    """

    rule_id: str
    node_id: str
    node_type: str
    category: RuleCategory
    severity: Severity
    is_fixable: bool
    message: str
    fix_hint: str = ""
    details: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize for an API response; enum members are replaced by their values."""
        exported = (
            "rule_id",
            "node_id",
            "node_type",
            "category",
            "severity",
            "is_fixable",
            "message",
            "fix_hint",
            "details",
        )
        serialized: dict[str, Any] = {name: getattr(self, name) for name in exported}
        # Overwriting an existing key keeps its position, so key order is unchanged.
        serialized["category"] = self.category.value
        serialized["severity"] = self.severity.value
        return serialized
class RuleRegistry:
    """
    Ordered collection of validation rules.

    Rules are kept in registration order and can be looked up by the node
    type they apply to or by their category.
    """

    def __init__(self):
        self._rules: list[ValidationRule] = []

    def register(self, rule: ValidationRule) -> None:
        """Add a rule to the end of the registry."""
        self._rules.append(rule)

    def get_rules_for_node(self, node_type: str) -> list[ValidationRule]:
        """Return the rules whose ``applies_to(node_type)`` is true, in registration order."""
        matching: list[ValidationRule] = []
        for rule in self._rules:
            if rule.applies_to(node_type):
                matching.append(rule)
        return matching

    def get_rules_by_category(self, category: RuleCategory) -> list[ValidationRule]:
        """Return the rules whose category equals ``category``, in registration order."""
        matching: list[ValidationRule] = []
        for rule in self._rules:
            if rule.category == category:
                matching.append(rule)
        return matching

    def get_all_rules(self) -> list[ValidationRule]:
        """Return a copy of the registered rules; mutating it does not affect the registry."""
        return self._rules.copy()
def check_required_field(
    config: dict[str, Any],
    field_name: str,
    node_id: str,
    node_type: str,
    rule_id: str,
    fix_hint: str = "",
) -> ValidationError | None:
    """
    Verify that ``field_name`` is present in ``config`` with a usable value.

    A value counts as missing when it is ``None``, the empty string or an
    empty list. Other falsy values such as ``0``, ``False`` or ``{}`` are
    accepted.

    Args:
        config: The node configuration to inspect.
        field_name: Key that must be present.
        node_id: Id of the node, used in the error message.
        node_type: Type of the node, copied into the error.
        rule_id: Identifier stored on the produced error.
        fix_hint: Optional hint; a generic one is generated when empty.

    Returns:
        A fixable structure error when the field is missing, else ``None``.
    """
    current = config.get(field_name)
    is_empty_list = isinstance(current, list) and len(current) == 0
    is_missing = current is None or current == "" or is_empty_list
    if not is_missing:
        return None

    hint = fix_hint if fix_hint else f"Add '{field_name}' to the node config"
    return ValidationError(
        rule_id=rule_id,
        node_id=node_id,
        node_type=node_type,
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        message=f"Node '{node_id}': missing required field '{field_name}'",
        fix_hint=hint,
    )
def _check_code_node(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """
    Validate that a code node declares both ``code`` and ``language``.

    Args:
        node: The code node to validate.
        ctx: Validation context (unused by this rule).

    Returns:
        Zero, one or two errors; the ``code`` error precedes the
        ``language`` error.
    """
    node_id = node.get("id", "unknown")
    config = node.get("config", {})

    # (field, rule id, fix hint) for each mandatory field, in reporting order.
    required_fields = (
        ("code", "code.code.required", "Add code with a main() function that returns a dict"),
        ("language", "code.language.required", "Add language: python3 or javascript"),
    )

    errors: list[ValidationError] = []
    for field_name, rule_id, hint in required_fields:
        problem = check_required_field(config, field_name, node_id, "code", rule_id, hint)
        if problem:
            errors.append(problem)
    return errors
def _check_parameter_extractor(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """
    Validate the configuration of a parameter-extractor node.

    Three things are checked:

    * ``parameters`` must be present and non-empty (error).
    * When ``parameters`` is a list, each dict entry must carry a
      ``required`` key (error per entry); non-dict entries are ignored.
    * ``instruction`` should be set (warning only).

    Args:
        node: The parameter-extractor node to validate.
        ctx: Validation context (unused by this rule).

    Returns:
        The collected errors and warnings, parameter findings first.
    """
    node_id = node.get("id", "unknown")
    config = node.get("config", {})
    errors: list[ValidationError] = []

    missing_parameters = check_required_field(
        config,
        "parameters",
        node_id,
        "parameter-extractor",
        "extractor.parameters.required",
        "Add parameters array with name, type, description fields",
    )
    if missing_parameters:
        errors.append(missing_parameters)
    else:
        # Per-entry checks only make sense when the list itself is present.
        parameters = config.get("parameters", [])
        entries = enumerate(parameters) if isinstance(parameters, list) else ()
        for index, entry in entries:
            if not isinstance(entry, dict) or "required" in entry:
                continue
            errors.append(
                ValidationError(
                    rule_id="extractor.param.required_field.missing",
                    node_id=node_id,
                    node_type="parameter-extractor",
                    category=RuleCategory.STRUCTURE,
                    severity=Severity.ERROR,
                    is_fixable=True,
                    message=f"Node '{node_id}': parameter[{index}] missing 'required' field",
                    fix_hint=f"Add 'required': True to parameter '{entry.get('name', 'unknown')}'",
                    details={"param_index": index, "param_name": entry.get("name")},
                )
            )

    # A missing instruction is only a recommendation, hence a warning.
    if not config.get("instruction"):
        errors.append(
            ValidationError(
                rule_id="extractor.instruction.recommended",
                node_id=node_id,
                node_type="parameter-extractor",
                category=RuleCategory.STRUCTURE,
                severity=Severity.WARNING,
                is_fixable=True,
                message=f"Node '{node_id}': parameter-extractor should have 'instruction'",
                fix_hint="Add instruction describing what to extract",
            )
        )

    return errors
[] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + dataset_ids = config.get("dataset_ids", []) + if not dataset_ids: + errors.append( + ValidationError( + rule_id="knowledge.dataset.required", + node_id=node_id, + node_type="knowledge-retrieval", + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=False, # User must select knowledge base + message=f"Node '{node_id}': knowledge-retrieval missing 'dataset_ids'", + fix_hint="User must select knowledge bases in the UI", + ) + ) + else: + # Check for placeholder values + for ds_id in dataset_ids: + if is_placeholder(ds_id): + errors.append( + ValidationError( + rule_id="knowledge.dataset.placeholder", + node_id=node_id, + node_type="knowledge-retrieval", + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=False, + message=f"Node '{node_id}': dataset_ids contains placeholder", + fix_hint="User must replace placeholder with actual knowledge base ID", + details={"placeholder_value": ds_id}, + ) + ) + break + + return errors + + +def _check_end_node(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that end node has outputs defined.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + config = node.get("config", {}) + + outputs = config.get("outputs", []) + if not outputs: + errors.append( + ValidationError( + rule_id="end.outputs.recommended", + node_id=node_id, + node_type="end", + category=RuleCategory.STRUCTURE, + severity=Severity.WARNING, + is_fixable=True, + message="End node should define output variables", + fix_hint="Add outputs array with variable and value_selector", + ) + ) + + return errors + + +# ============================================================================= +# Semantic Rules - Variable references, edge connections +# ============================================================================= + + +def _check_variable_references(node: WorkflowNodeDict, ctx: 
"ValidationContext") -> list[ValidationError]: + """Check that variable references point to valid nodes.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + config = node.get("config", {}) + + # Get all valid node IDs (including 'start' which is always valid) + valid_node_ids = ctx.get_node_ids() + valid_node_ids.add("start") + valid_node_ids.add("sys") # System variables + + def check_text_for_refs(text: str, field_path: str) -> None: + if not isinstance(text, str): + return + refs = extract_variable_refs(text) + for ref_node_id, ref_field in refs: + if ref_node_id not in valid_node_ids: + errors.append( + ValidationError( + rule_id="variable.ref.invalid_node", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': references non-existent node '{ref_node_id}'", + fix_hint=f"Change {{{{#{ref_node_id}.{ref_field}#}}}} to reference a valid node", + details={"field_path": field_path, "invalid_ref": ref_node_id}, + ) + ) + + # Check prompt_template for LLM nodes + prompt_template = config.get("prompt_template", []) + if isinstance(prompt_template, list): + for i, msg in enumerate(prompt_template): + if isinstance(msg, dict): + text = msg.get("text", "") + check_text_for_refs(text, f"prompt_template[{i}].text") + + # Check instruction field + instruction = config.get("instruction", "") + check_text_for_refs(instruction, "instruction") + + # Check url for http-request + url = config.get("url", "") + check_text_for_refs(url, "url") + + return errors + + +def _check_node_has_outgoing_edge(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that non-end nodes have at least one outgoing edge.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + # End nodes don't need outgoing edges + if node_type == "end": 
+ return errors + + # Check if this node has any outgoing edges + downstream = ctx.get_downstream_nodes(node_id) + if not downstream: + errors.append( + ValidationError( + rule_id="edge.no_outgoing", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}' has no outgoing edge - workflow is disconnected", + fix_hint=f"Add an edge from '{node_id}' to the next node or to 'end'", + details={"field": "edges"}, + ) + ) + + return errors + + +def _check_node_has_incoming_edge(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that non-start nodes have at least one incoming edge.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + # Start nodes don't need incoming edges + if node_type == "start": + return errors + + # Check if this node has any incoming edges + upstream = ctx.get_upstream_nodes(node_id) + if not upstream: + errors.append( + ValidationError( + rule_id="edge.no_incoming", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}' is orphaned - no incoming edges", + fix_hint=f"Add an edge from a previous node to '{node_id}'", + details={"field": "edges"}, + ) + ) + + return errors + + +def _check_question_classifier_branches(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that question-classifier has edges for all defined classes.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + if node_type != "question-classifier": + return errors + + config = node.get("config", {}) + classes = config.get("classes", []) + + if not classes: + return errors # Already caught by structure validation + + # Get all class IDs + class_ids = set() + for cls in classes: + if 
isinstance(cls, dict) and cls.get("id"): + class_ids.add(cls["id"]) + + # Get all outgoing edges with their sourceHandles + outgoing_handles = set() + for edge in ctx.edges: + if edge.get("source") == node_id: + handle = edge.get("sourceHandle") + if handle: + outgoing_handles.add(handle) + + # Check for missing branches + missing_branches = class_ids - outgoing_handles + if missing_branches: + for branch_id in missing_branches: + # Find the class name for better error message + class_name = branch_id + for cls in classes: + if isinstance(cls, dict) and cls.get("id") == branch_id: + class_name = cls.get("name", branch_id) + break + + errors.append( + ValidationError( + rule_id="edge.classifier_branch.missing", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Question classifier '{node_id}' missing edge for class '{class_name}'", + fix_hint=f"Add edge: {{source: '{node_id}', sourceHandle: '{branch_id}', target: ''}}", + details={"missing_class_id": branch_id, "missing_class_name": class_name, "field": "edges"}, + ) + ) + + return errors + + +def _check_if_else_branches(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that if-else has both true and false branch edges.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + if node_type != "if-else": + return errors + + # Get all outgoing edges with their sourceHandles + outgoing_handles = set() + for edge in ctx.edges: + if edge.get("source") == node_id: + handle = edge.get("sourceHandle") + if handle: + outgoing_handles.add(handle) + + # Check for required branches + required_branches = {"true", "false"} + missing_branches = required_branches - outgoing_handles + + for branch in missing_branches: + errors.append( + ValidationError( + rule_id="edge.if_else_branch.missing", + node_id=node_id, + node_type=node_type, + 
category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"If-else node '{node_id}' missing '{branch}' branch edge", + fix_hint=f"Add edge: {{source: '{node_id}', sourceHandle: '{branch}', target: ''}}", + details={"missing_branch": branch, "field": "edges"}, + ) + ) + + return errors + + return errors + + +def _check_if_else_operators(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that if-else comparison operators are valid.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + if node_type != "if-else": + return errors + + valid_operators = { + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", + "in", + "not in", + "all of", + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", + "exists", + "not exists", + } + + config = node.get("config", {}) + cases = config.get("cases", []) + + for case in cases: + conditions = case.get("conditions", []) + for condition in conditions: + op = condition.get("comparison_operator") + if op and op not in valid_operators: + errors.append( + ValidationError( + rule_id="ifelse.operator.invalid", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Invalid operator '{op}' in if-else node", + fix_hint=f"Use one of: {', '.join(sorted(valid_operators))}", + details={"invalid_operator": op, "field": "config.cases.conditions.comparison_operator"}, + ) + ) + + return errors + + +def _check_edge_targets_exist(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that edge targets reference existing nodes.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + valid_node_ids = ctx.get_node_ids() + + # Check all outgoing edges from this node + for edge in 
ctx.edges: + if edge.get("source") == node_id: + target = edge.get("target") + if target and target not in valid_node_ids: + errors.append( + ValidationError( + rule_id="edge.target.invalid", + node_id=node_id, + node_type=node_type, + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + message=f"Edge from '{node_id}' targets non-existent node '{target}'", + fix_hint=f"Change edge target from '{target}' to an existing node", + details={"invalid_target": target, "field": "edges"}, + ) + ) + + return errors + + +# ============================================================================= +# Reference Rules - External resources (models, tools, datasets) +# ============================================================================= + +# Node types that require model configuration +MODEL_REQUIRED_NODE_TYPES = {"llm", "question-classifier", "parameter-extractor"} + + +def _check_model_config(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that model configuration is valid.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + config = node.get("config", {}) + + if node_type not in MODEL_REQUIRED_NODE_TYPES: + return errors + + model = config.get("model") + + # Check if model config exists + if not model: + if ctx.available_models: + errors.append( + ValidationError( + rule_id="model.required", + node_id=node_id, + node_type=node_type, + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}' ({node_type}): missing required 'model' configuration", + fix_hint="Add model config using one of the available models", + ) + ) + else: + errors.append( + ValidationError( + rule_id="model.no_available", + node_id=node_id, + node_type=node_type, + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=False, + message=f"Node '{node_id}' ({node_type}): needs model but no models 
available", + fix_hint="User must configure a model provider first", + ) + ) + return errors + + # Check if model config is valid + if isinstance(model, dict): + provider = model.get("provider", "") + name = model.get("name", "") + + # Check for placeholder values + if is_placeholder(provider) or is_placeholder(name): + if ctx.available_models: + errors.append( + ValidationError( + rule_id="model.placeholder", + node_id=node_id, + node_type=node_type, + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': model config contains placeholder", + fix_hint="Replace placeholder with actual model from available_models", + ) + ) + return errors + + # Check if model exists in available_models + if ctx.available_models and provider and name: + if not ctx.has_model(provider, name): + errors.append( + ValidationError( + rule_id="model.not_found", + node_id=node_id, + node_type=node_type, + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': model '{provider}/{name}' not in available models", + fix_hint="Replace with a model from available_models", + details={"provider": provider, "model": name}, + ) + ) + + return errors + + +def _check_tool_reference(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]: + """Check that tool references are valid and configured.""" + errors: list[ValidationError] = [] + node_id = node.get("id", "unknown") + node_type = node.get("type", "unknown") + + if node_type != "tool": + return errors + + config = node.get("config", {}) + tool_ref = ( + config.get("tool_key") + or config.get("tool_name") + or config.get("provider_id", "") + "/" + config.get("tool_name", "") + ) + + if not tool_ref: + errors.append( + ValidationError( + rule_id="tool.key.required", + node_id=node_id, + node_type=node_type, + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, + message=f"Node '{node_id}': tool node 
missing tool_key", + fix_hint="Add tool_key from available_tools", + ) + ) + return errors + + # Check if tool exists + if not ctx.has_tool(tool_ref): + errors.append( + ValidationError( + rule_id="tool.not_found", + node_id=node_id, + node_type=node_type, + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, # Can be replaced with http-request fallback + message=f"Node '{node_id}': tool '{tool_ref}' not found", + fix_hint="Use http-request or code node as fallback", + details={"tool_ref": tool_ref}, + ) + ) + elif not ctx.is_tool_configured(tool_ref): + errors.append( + ValidationError( + rule_id="tool.not_configured", + node_id=node_id, + node_type=node_type, + category=RuleCategory.REFERENCE, + severity=Severity.WARNING, + is_fixable=False, # User needs to configure + message=f"Node '{node_id}': tool '{tool_ref}' requires configuration", + fix_hint="Configure the tool in Tools settings", + details={"tool_ref": tool_ref}, + ) + ) + + return errors + + +# ============================================================================= +# Register All Rules +# ============================================================================= + +# Structure Rules +register_rule( + ValidationRule( + id="llm.prompt_template.required", + node_types=["llm"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_llm_prompt_template, + description="LLM node must have prompt_template", + fix_hint="Add prompt_template with system and user messages", + ) +) + +register_rule( + ValidationRule( + id="http.config.required", + node_types=["http-request"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_http_request_url, + description="HTTP request node must have url and method", + fix_hint="Add url and method to config", + ) +) + +register_rule( + ValidationRule( + id="code.config.required", + node_types=["code"], + category=RuleCategory.STRUCTURE, + 
severity=Severity.ERROR, + is_fixable=True, + check=_check_code_node, + description="Code node must have code and language", + fix_hint="Add code with main() function and language", + ) +) + +register_rule( + ValidationRule( + id="classifier.classes.required", + node_types=["question-classifier"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_question_classifier, + description="Question classifier must have classes", + fix_hint="Add classes array with classification options", + ) +) + +register_rule( + ValidationRule( + id="extractor.config.required", + node_types=["parameter-extractor"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_parameter_extractor, + description="Parameter extractor must have parameters", + fix_hint="Add parameters array", + ) +) + +register_rule( + ValidationRule( + id="knowledge.config.required", + node_types=["knowledge-retrieval"], + category=RuleCategory.STRUCTURE, + severity=Severity.ERROR, + is_fixable=False, + check=_check_knowledge_retrieval, + description="Knowledge retrieval must have dataset_ids", + fix_hint="User must select knowledge base", + ) +) + +register_rule( + ValidationRule( + id="end.outputs.check", + node_types=["end"], + category=RuleCategory.STRUCTURE, + severity=Severity.WARNING, + is_fixable=True, + check=_check_end_node, + description="End node should have outputs", + fix_hint="Add outputs array", + ) +) + +# Semantic Rules +register_rule( + ValidationRule( + id="variable.references.valid", + node_types=["*"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_variable_references, + description="Variable references must point to valid nodes", + fix_hint="Fix variable reference to use valid node ID", + ) +) + +# Edge Validation Rules +register_rule( + ValidationRule( + id="edge.outgoing.required", + node_types=["*"], + category=RuleCategory.SEMANTIC, + 
severity=Severity.ERROR, + is_fixable=True, + check=_check_node_has_outgoing_edge, + description="Non-end nodes must have outgoing edges", + fix_hint="Add an edge from this node to the next node", + ) +) + +register_rule( + ValidationRule( + id="edge.incoming.required", + node_types=["*"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_node_has_incoming_edge, + description="Non-start nodes must have incoming edges", + fix_hint="Add an edge from a previous node to this node", + ) +) + +register_rule( + ValidationRule( + id="edge.classifier_branches.complete", + node_types=["question-classifier"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_question_classifier_branches, + description="Question classifier must have edges for all classes", + fix_hint="Add edges with sourceHandle for each class ID", + ) +) + +register_rule( + ValidationRule( + id="edge.if_else_branches.complete", + node_types=["if-else"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_if_else_branches, + description="If-else must have true and false branch edges", + fix_hint="Add edges with sourceHandle 'true' and 'false'", + ) +) + +register_rule( + ValidationRule( + id="edge.targets.valid", + node_types=["*"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_edge_targets_exist, + description="Edge targets must reference existing nodes", + fix_hint="Change edge target to an existing node ID", + ) +) + +# Reference Rules +register_rule( + ValidationRule( + id="model.config.valid", + node_types=["llm", "question-classifier", "parameter-extractor"], + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_model_config, + description="Model configuration must be valid", + fix_hint="Add valid model from available_models", + ) +) + +register_rule( + ValidationRule( + 
id="tool.reference.valid", + node_types=["tool"], + category=RuleCategory.REFERENCE, + severity=Severity.ERROR, + is_fixable=True, + check=_check_tool_reference, + description="Tool reference must be valid and configured", + fix_hint="Use valid tool or fallback node", + ) +) + +register_rule( + ValidationRule( + id="ifelse.operator.valid", + node_types=["if-else"], + category=RuleCategory.SEMANTIC, + severity=Severity.ERROR, + is_fixable=True, + check=_check_if_else_operators, + description="If-else operators must be valid", + fix_hint="Use standard operators like ≥, ≤, =, ≠", + ) +) diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 8ebba3659c..721b746dfc 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -197,6 +197,14 @@ class Node(Generic[NodeDataT]): return None + @classmethod + def get_default_config_schema(cls) -> dict[str, Any] | None: + """ + Get the default configuration schema for the node. + Used for LLM generation. 
+ """ + return None + # Global registry populated via __init_subclass__ _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {} diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 2efcb4f418..299cbb90ad 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,3 +1,5 @@ +from typing import Any + from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node @@ -9,6 +11,24 @@ class EndNode(Node[EndNodeData]): node_type = NodeType.END execution_type = NodeExecutionType.RESPONSE + @classmethod + def get_default_config_schema(cls) -> dict[str, Any] | None: + return { + "description": "Workflow exit point - defines output variables", + "required": ["outputs"], + "parameters": { + "outputs": { + "type": "array", + "description": "Output variables to return", + "item_schema": { + "variable": "string - output variable name", + "type": "enum: string, number, object, array", + "value_selector": "array - path to source value, e.g. 
['node_id', 'field']", + }, + }, + }, + } + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 36fc5078c5..591c716188 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -15,6 +15,27 @@ class StartNode(Node[StartNodeData]): node_type = NodeType.START execution_type = NodeExecutionType.ROOT + @classmethod + def get_default_config_schema(cls) -> dict[str, Any] | None: + return { + "description": "Workflow entry point - defines input variables", + "required": [], + "parameters": { + "variables": { + "type": "array", + "description": "Input variables for the workflow", + "item_schema": { + "variable": "string - variable name", + "label": "string - display label", + "type": "enum: text-input, paragraph, number, select, file, file-list", + "required": "boolean", + "max_length": "number (optional)", + }, + }, + }, + "outputs": ["All defined variables are available as {{#start.variable_name#}}"], + } + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 2e7ec757b4..1ebddeb6ad 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -50,6 +50,19 @@ class ToolNode(Node[ToolNodeData]): def version(cls) -> str: return "1" + @classmethod + def get_default_config_schema(cls) -> dict[str, Any] | None: + return { + "description": "Execute an external tool", + "required": ["provider_id", "tool_id", "tool_parameters"], + "parameters": { + "provider_id": {"type": "string"}, + "provider_type": {"type": "string"}, + "tool_id": {"type": "string"}, + "tool_parameters": {"type": "object"}, + }, + } + def _run(self) -> Generator[NodeEventBase, None, None]: """ Run the tool node diff --git a/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py 
b/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py new file mode 100644 index 0000000000..9dbb486dd9 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_mermaid_generator.py @@ -0,0 +1,288 @@ +""" +Unit tests for the Mermaid Generator. + +Tests cover: +- Basic workflow rendering +- Reserved word handling ('end' → 'end_node') +- Question classifier multi-branch edges +- If-else branch labels +- Edge validation and skipping +- Tool node formatting +""" + + +from core.workflow.generator.utils.mermaid_generator import generate_mermaid + + +class TestBasicWorkflow: + """Tests for basic workflow Mermaid generation.""" + + def test_simple_start_end_workflow(self): + """Test simple Start → End workflow.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [{"source": "start", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + assert "flowchart TD" in result + assert 'start["type=start|title=Start"]' in result + assert 'end_node["type=end|title=End"]' in result + assert "start --> end_node" in result + + def test_start_llm_end_workflow(self): + """Test Start → LLM → End workflow.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "llm", "type": "llm", "title": "Generate"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [ + {"source": "start", "target": "llm"}, + {"source": "llm", "target": "end"}, + ], + } + result = generate_mermaid(workflow_data) + + assert 'llm["type=llm|title=Generate"]' in result + assert "start --> llm" in result + assert "llm --> end_node" in result + + def test_empty_workflow(self): + """Test empty workflow returns minimal output.""" + workflow_data = {"nodes": [], "edges": []} + result = generate_mermaid(workflow_data) + + assert result == "flowchart TD" + + def test_missing_keys_handled(self): + """Test workflow with missing keys 
doesn't crash.""" + workflow_data = {} + result = generate_mermaid(workflow_data) + + assert "flowchart TD" in result + + +class TestReservedWords: + """Tests for reserved word handling in node IDs.""" + + def test_end_node_id_is_replaced(self): + """Test 'end' node ID is replaced with 'end_node'.""" + workflow_data = { + "nodes": [{"id": "end", "type": "end", "title": "End"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Should use end_node instead of end + assert "end_node[" in result + assert '"type=end|title=End"' in result + + def test_subgraph_node_id_is_replaced(self): + """Test 'subgraph' node ID is replaced with 'subgraph_node'.""" + workflow_data = { + "nodes": [{"id": "subgraph", "type": "code", "title": "Process"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "subgraph_node[" in result + + def test_edge_uses_safe_ids(self): + """Test edges correctly reference safe IDs after replacement.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "title": "Start"}, + {"id": "end", "type": "end", "title": "End"}, + ], + "edges": [{"source": "start", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + # Edge should use end_node, not end + assert "start --> end_node" in result + assert "start --> end\n" not in result + + +class TestBranchEdges: + """Tests for branching node edge labels.""" + + def test_question_classifier_source_handles(self): + """Test question-classifier edges with sourceHandle labels.""" + workflow_data = { + "nodes": [ + {"id": "classifier", "type": "question-classifier", "title": "Classify"}, + {"id": "refund", "type": "llm", "title": "Handle Refund"}, + {"id": "inquiry", "type": "llm", "title": "Handle Inquiry"}, + ], + "edges": [ + {"source": "classifier", "target": "refund", "sourceHandle": "refund"}, + {"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"}, + ], + } + result = generate_mermaid(workflow_data) + + assert "classifier 
-->|refund| refund" in result + assert "classifier -->|inquiry| inquiry" in result + + def test_if_else_true_false_handles(self): + """Test if-else edges with true/false labels.""" + workflow_data = { + "nodes": [ + {"id": "ifelse", "type": "if-else", "title": "Check"}, + {"id": "yes_branch", "type": "llm", "title": "Yes"}, + {"id": "no_branch", "type": "llm", "title": "No"}, + ], + "edges": [ + {"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"}, + {"source": "ifelse", "target": "no_branch", "sourceHandle": "false"}, + ], + } + result = generate_mermaid(workflow_data) + + assert "ifelse -->|true| yes_branch" in result + assert "ifelse -->|false| no_branch" in result + + def test_source_handle_source_is_ignored(self): + """Test sourceHandle='source' doesn't add label.""" + workflow_data = { + "nodes": [ + {"id": "llm1", "type": "llm", "title": "LLM 1"}, + {"id": "llm2", "type": "llm", "title": "LLM 2"}, + ], + "edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}], + } + result = generate_mermaid(workflow_data) + + # Should be plain arrow without label + assert "llm1 --> llm2" in result + assert "llm1 -->|source|" not in result + + +class TestEdgeValidation: + """Tests for edge validation and error handling.""" + + def test_edge_with_missing_source_is_skipped(self): + """Test edge with non-existent source node is skipped.""" + workflow_data = { + "nodes": [{"id": "end", "type": "end", "title": "End"}], + "edges": [{"source": "nonexistent", "target": "end"}], + } + result = generate_mermaid(workflow_data) + + # Should not contain the invalid edge + assert "nonexistent" not in result + assert "-->" not in result or "nonexistent" not in result + + def test_edge_with_missing_target_is_skipped(self): + """Test edge with non-existent target node is skipped.""" + workflow_data = { + "nodes": [{"id": "start", "type": "start", "title": "Start"}], + "edges": [{"source": "start", "target": "nonexistent"}], + } + result = 
generate_mermaid(workflow_data) + + # Edge should be skipped + assert "start --> nonexistent" not in result + + def test_edge_without_source_or_target_is_skipped(self): + """Test edge missing source or target is skipped.""" + workflow_data = { + "nodes": [{"id": "start", "type": "start", "title": "Start"}], + "edges": [{"source": "start"}, {"target": "start"}, {}], + } + result = generate_mermaid(workflow_data) + + # No edges should be rendered + assert result.count("-->") == 0 + + +class TestToolNodes: + """Tests for tool node formatting.""" + + def test_tool_node_includes_tool_key(self): + """Test tool node includes tool_key in label.""" + workflow_data = { + "nodes": [ + { + "id": "search", + "type": "tool", + "title": "Search", + "config": {"tool_key": "google/search"}, + } + ], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert 'search["type=tool|title=Search|tool=google/search"]' in result + + def test_tool_node_with_tool_name_fallback(self): + """Test tool node uses tool_name as fallback.""" + workflow_data = { + "nodes": [ + { + "id": "tool1", + "type": "tool", + "title": "My Tool", + "config": {"tool_name": "my_tool"}, + } + ], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "tool=my_tool" in result + + def test_tool_node_missing_tool_key_shows_unknown(self): + """Test tool node without tool_key shows 'unknown'.""" + workflow_data = { + "nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "tool=unknown" in result + + +class TestNodeFormatting: + """Tests for node label formatting.""" + + def test_quotes_in_title_are_escaped(self): + """Test double quotes in title are replaced with single quotes.""" + workflow_data = { + "nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Double quotes should be replaced + assert "Say 'Hello'" in result + 
assert 'Say "Hello"' not in result + + def test_node_without_id_is_skipped(self): + """Test node without id is skipped.""" + workflow_data = { + "nodes": [{"type": "llm", "title": "No ID"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + # Should only have flowchart header + lines = [line for line in result.split("\n") if line.strip()] + assert len(lines) == 1 + + def test_node_default_values(self): + """Test node with missing type/title uses defaults.""" + workflow_data = { + "nodes": [{"id": "node1"}], + "edges": [], + } + result = generate_mermaid(workflow_data) + + assert "type=unknown" in result + assert "title=Untitled" in result diff --git a/api/tests/unit_tests/core/llm_generator/test_node_repair.py b/api/tests/unit_tests/core/llm_generator/test_node_repair.py new file mode 100644 index 0000000000..a92a7d0125 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_node_repair.py @@ -0,0 +1,81 @@ +from core.workflow.generator.utils.node_repair import NodeRepair + + +class TestNodeRepair: + """Tests for NodeRepair utility.""" + + def test_repair_if_else_valid_operators(self): + """Test that valid operators remain unchanged.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + "config": { + "cases": [ + { + "conditions": [ + {"comparison_operator": "≥", "value": "1"}, + {"comparison_operator": "=", "value": "2"}, + ] + } + ] + }, + } + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is False + assert result.nodes == nodes + + def test_repair_if_else_invalid_operators(self): + """Test that invalid operators are normalized.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + "config": { + "cases": [ + { + "conditions": [ + {"comparison_operator": ">=", "value": "1"}, + {"comparison_operator": "<=", "value": "2"}, + {"comparison_operator": "!=", "value": "3"}, + {"comparison_operator": "==", "value": "4"}, + ] + } + ] + }, + } + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is True + 
assert len(result.repairs_made) == 4 + + conditions = result.nodes[0]["config"]["cases"][0]["conditions"] + assert conditions[0]["comparison_operator"] == "≥" + assert conditions[1]["comparison_operator"] == "≤" + assert conditions[2]["comparison_operator"] == "≠" + assert conditions[3]["comparison_operator"] == "=" + + def test_repair_ignores_other_nodes(self): + """Test that other node types are ignored.""" + nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}] + result = NodeRepair.repair(nodes) + assert result.was_repaired is False + assert result.nodes[0]["config"]["some_field"] == ">=" + + def test_repair_handles_missing_config(self): + """Test robustness against missing fields.""" + nodes = [ + { + "id": "node1", + "type": "if-else", + # Missing config + }, + { + "id": "node2", + "type": "if-else", + "config": {}, # Missing cases + }, + ] + result = NodeRepair.repair(nodes) + assert result.was_repaired is False diff --git a/api/tests/unit_tests/core/llm_generator/test_node_schemas_validation.py b/api/tests/unit_tests/core/llm_generator/test_node_schemas_validation.py new file mode 100644 index 0000000000..eccfd93207 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_node_schemas_validation.py @@ -0,0 +1,99 @@ +""" +Tests for node schemas validation. + +Ensures that the node configuration stays in sync with registered node types. +""" + +from core.workflow.generator.config.node_schemas import ( + get_builtin_node_schemas, + validate_node_schemas, +) + + +class TestNodeSchemasValidation: + """Tests for node schema validation utilities.""" + + def test_validate_node_schemas_returns_no_warnings(self): + """Ensure all registered node types have corresponding schemas.""" + warnings = validate_node_schemas() + # If this test fails, it means a new node type was added but + # no schema was defined for it in node_schemas.py + assert len(warnings) == 0, ( + f"Missing schemas for node types: {warnings}. 
" + "Please add schemas for these node types in node_schemas.py " + "or add them to _INTERNAL_NODE_TYPES if they don't need schemas." + ) + + def test_builtin_node_schemas_not_empty(self): + """Ensure BUILTIN_NODE_SCHEMAS contains expected node types.""" + # get_builtin_node_schemas() includes dynamic schemas + all_schemas = get_builtin_node_schemas() + assert len(all_schemas) > 0 + # Core node types should always be present + expected_types = ["llm", "code", "http-request", "if-else"] + for node_type in expected_types: + assert node_type in all_schemas, f"Missing schema for core node type: {node_type}" + + def test_schema_structure(self): + """Ensure each schema has required fields.""" + all_schemas = get_builtin_node_schemas() + for node_type, schema in all_schemas.items(): + assert "description" in schema, f"Missing 'description' in schema for {node_type}" + # 'parameters' is optional but if present should be a dict + if "parameters" in schema: + assert isinstance(schema["parameters"], dict), ( + f"'parameters' in schema for {node_type} should be a dict" + ) + + +class TestNodeSchemasMerged: + """Tests to verify the merged configuration works correctly.""" + + def test_fallback_rules_available(self): + """Ensure FALLBACK_RULES is available from node_schemas.""" + from core.workflow.generator.config.node_schemas import FALLBACK_RULES + + assert len(FALLBACK_RULES) > 0 + assert "http-request" in FALLBACK_RULES + assert "code" in FALLBACK_RULES + assert "llm" in FALLBACK_RULES + + def test_node_type_aliases_available(self): + """Ensure NODE_TYPE_ALIASES is available from node_schemas.""" + from core.workflow.generator.config.node_schemas import NODE_TYPE_ALIASES + + assert len(NODE_TYPE_ALIASES) > 0 + assert NODE_TYPE_ALIASES.get("gpt") == "llm" + assert NODE_TYPE_ALIASES.get("api") == "http-request" + + def test_field_name_corrections_available(self): + """Ensure FIELD_NAME_CORRECTIONS is available from node_schemas.""" + from 
core.workflow.generator.config.node_schemas import ( + FIELD_NAME_CORRECTIONS, + get_corrected_field_name, + ) + + assert len(FIELD_NAME_CORRECTIONS) > 0 + # Test the helper function + assert get_corrected_field_name("http-request", "text") == "body" + assert get_corrected_field_name("llm", "response") == "text" + assert get_corrected_field_name("code", "unknown") == "unknown" + + def test_config_init_exports(self): + """Ensure config __init__.py exports all needed symbols.""" + from core.workflow.generator.config import ( + BUILTIN_NODE_SCHEMAS, + FALLBACK_RULES, + FIELD_NAME_CORRECTIONS, + NODE_TYPE_ALIASES, + get_corrected_field_name, + validate_node_schemas, + ) + + # Just verify imports work + assert BUILTIN_NODE_SCHEMAS is not None + assert FALLBACK_RULES is not None + assert FIELD_NAME_CORRECTIONS is not None + assert NODE_TYPE_ALIASES is not None + assert callable(get_corrected_field_name) + assert callable(validate_node_schemas) diff --git a/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py b/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py new file mode 100644 index 0000000000..a741c30c7a --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_planner_prompts.py @@ -0,0 +1,173 @@ +""" +Unit tests for the Planner Prompts. + +Tests cover: +- Tool formatting for planner context +- Edge cases with missing fields +- Empty tool lists +""" + + +from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner + + +class TestFormatToolsForPlanner: + """Tests for format_tools_for_planner function.""" + + def test_empty_tools_returns_default_message(self): + """Test empty tools list returns default message.""" + result = format_tools_for_planner([]) + + assert result == "No external tools available." + + def test_none_tools_returns_default_message(self): + """Test None tools list returns default message.""" + result = format_tools_for_planner(None) + + assert result == "No external tools available." 
+ + def test_single_tool_formatting(self): + """Test single tool is formatted correctly.""" + tools = [ + { + "provider_id": "google", + "tool_key": "search", + "tool_label": "Google Search", + "tool_description": "Search the web using Google", + } + ] + result = format_tools_for_planner(tools) + + assert "[google/search]" in result + assert "Google Search" in result + assert "Search the web using Google" in result + + def test_multiple_tools_formatting(self): + """Test multiple tools are formatted correctly.""" + tools = [ + { + "provider_id": "google", + "tool_key": "search", + "tool_label": "Search", + "tool_description": "Web search", + }, + { + "provider_id": "slack", + "tool_key": "send_message", + "tool_label": "Send Message", + "tool_description": "Send a Slack message", + }, + ] + result = format_tools_for_planner(tools) + + lines = result.strip().split("\n") + assert len(lines) == 2 + assert "[google/search]" in result + assert "[slack/send_message]" in result + + def test_tool_without_provider_uses_key_only(self): + """Test tool without provider_id uses tool_key only.""" + tools = [ + { + "tool_key": "my_tool", + "tool_label": "My Tool", + "tool_description": "A custom tool", + } + ] + result = format_tools_for_planner(tools) + + # Should format as [my_tool] without provider prefix + assert "[my_tool]" in result + assert "My Tool" in result + + def test_tool_with_tool_name_fallback(self): + """Test tool uses tool_name when tool_key is missing.""" + tools = [ + { + "tool_name": "fallback_tool", + "description": "Fallback description", + } + ] + result = format_tools_for_planner(tools) + + assert "fallback_tool" in result + assert "Fallback description" in result + + def test_tool_with_missing_description(self): + """Test tool with missing description doesn't crash.""" + tools = [ + { + "provider_id": "test", + "tool_key": "tool1", + "tool_label": "Tool 1", + } + ] + result = format_tools_for_planner(tools) + + assert "[test/tool1]" in result + assert 
"Tool 1" in result + + def test_tool_with_all_missing_fields(self): + """Test tool with all fields missing uses defaults.""" + tools = [{}] + result = format_tools_for_planner(tools) + + # Should not crash, may produce minimal output + assert isinstance(result, str) + + def test_tool_uses_provider_fallback(self): + """Test tool uses 'provider' when 'provider_id' is missing.""" + tools = [ + { + "provider": "openai", + "tool_key": "dalle", + "tool_label": "DALL-E", + "tool_description": "Generate images", + } + ] + result = format_tools_for_planner(tools) + + assert "[openai/dalle]" in result + + def test_tool_label_fallback_to_key(self): + """Test tool_label falls back to tool_key when missing.""" + tools = [ + { + "provider_id": "test", + "tool_key": "my_key", + "tool_description": "Description here", + } + ] + result = format_tools_for_planner(tools) + + # Label should fallback to key + assert "my_key" in result + assert "Description here" in result + + +class TestPlannerPromptConstants: + """Tests for planner prompt constant availability.""" + + def test_planner_system_prompt_exists(self): + """Test PLANNER_SYSTEM_PROMPT is defined.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT + + assert PLANNER_SYSTEM_PROMPT is not None + assert len(PLANNER_SYSTEM_PROMPT) > 0 + assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT + + def test_planner_user_prompt_exists(self): + """Test PLANNER_USER_PROMPT is defined.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT + + assert PLANNER_USER_PROMPT is not None + assert "{instruction}" in PLANNER_USER_PROMPT + + def test_planner_system_prompt_has_required_sections(self): + """Test PLANNER_SYSTEM_PROMPT has required XML sections.""" + from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT + + assert "" in PLANNER_SYSTEM_PROMPT + assert "" in PLANNER_SYSTEM_PROMPT + assert "" in PLANNER_SYSTEM_PROMPT + assert "" in 
PLANNER_SYSTEM_PROMPT diff --git a/api/tests/unit_tests/core/llm_generator/test_validation_engine.py b/api/tests/unit_tests/core/llm_generator/test_validation_engine.py new file mode 100644 index 0000000000..477b0cdcf7 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_validation_engine.py @@ -0,0 +1,536 @@ +""" +Unit tests for the Validation Rule Engine. + +Tests cover: +- Structure rules (required fields, types, formats) +- Semantic rules (variable references, edge connections) +- Reference rules (model exists, tool configured, dataset valid) +- ValidationEngine integration +""" + + +from core.workflow.generator.validation import ( + ValidationContext, + ValidationEngine, +) +from core.workflow.generator.validation.rules import ( + extract_variable_refs, + is_placeholder, +) + + +class TestPlaceholderDetection: + """Tests for placeholder detection utility.""" + + def test_detects_please_select(self): + assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True + + def test_detects_your_prefix(self): + assert is_placeholder("YOUR_API_KEY") is True + + def test_detects_todo(self): + assert is_placeholder("TODO: fill this in") is True + + def test_detects_placeholder(self): + assert is_placeholder("PLACEHOLDER_VALUE") is True + + def test_detects_example_prefix(self): + assert is_placeholder("EXAMPLE_URL") is True + + def test_detects_replace_prefix(self): + assert is_placeholder("REPLACE_WITH_ACTUAL") is True + + def test_case_insensitive(self): + assert is_placeholder("please_select") is True + assert is_placeholder("Please_Select") is True + + def test_valid_values_not_detected(self): + assert is_placeholder("https://api.example.com") is False + assert is_placeholder("gpt-4") is False + assert is_placeholder("my_variable") is False + + def test_non_string_returns_false(self): + assert is_placeholder(123) is False + assert is_placeholder(None) is False + assert is_placeholder(["list"]) is False + + +class TestVariableRefExtraction: + """Tests for 
variable reference extraction.""" + + def test_extracts_simple_ref(self): + refs = extract_variable_refs("Hello {{#start.query#}}") + assert refs == [("start", "query")] + + def test_extracts_multiple_refs(self): + refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}") + assert refs == [("node1", "output"), ("node2", "text")] + + def test_extracts_nested_field(self): + refs = extract_variable_refs("{{#http_request.body#}}") + assert refs == [("http_request", "body")] + + def test_no_refs_returns_empty(self): + refs = extract_variable_refs("No references here") + assert refs == [] + + def test_handles_malformed_refs(self): + refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}") + assert refs == [] + + +class TestValidationContext: + """Tests for ValidationContext.""" + + def test_node_map_lookup(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start"}, + {"id": "llm_1", "type": "llm"}, + ] + ) + assert ctx.get_node("start") == {"id": "start", "type": "start"} + assert ctx.get_node("nonexistent") is None + + def test_model_set(self): + ctx = ValidationContext( + available_models=[ + {"provider": "openai", "model": "gpt-4"}, + {"provider": "anthropic", "model": "claude-3"}, + ] + ) + assert ctx.has_model("openai", "gpt-4") is True + assert ctx.has_model("anthropic", "claude-3") is True + assert ctx.has_model("openai", "gpt-3.5") is False + + def test_tool_set(self): + ctx = ValidationContext( + available_tools=[ + {"provider_id": "google", "tool_key": "search", "is_team_authorization": True}, + {"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False}, + ] + ) + assert ctx.has_tool("google/search") is True + assert ctx.has_tool("search") is True + assert ctx.is_tool_configured("google/search") is True + assert ctx.is_tool_configured("slack/send_message") is False + + def test_upstream_downstream_nodes(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start"}, + {"id": 
"llm", "type": "llm"}, + {"id": "end", "type": "end"}, + ], + edges=[ + {"source": "start", "target": "llm"}, + {"source": "llm", "target": "end"}, + ], + ) + assert ctx.get_upstream_nodes("llm") == ["start"] + assert ctx.get_downstream_nodes("llm") == ["end"] + + +class TestStructureRules: + """Tests for structure validation rules.""" + + def test_llm_missing_prompt_template(self): + ctx = ValidationContext( + nodes=[{"id": "llm_1", "type": "llm", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_errors + errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_llm_with_prompt_template_passes(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [ + {"role": "system", "text": "You are helpful"}, + {"role": "user", "text": "Hello"}, + ] + }, + } + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # No prompt_template errors + errors = [e for e in result.all_errors if "prompt_template" in e.rule_id] + assert len(errors) == 0 + + def test_http_request_missing_url(self): + ctx = ValidationContext( + nodes=[{"id": "http_1", "type": "http-request", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "http.url" in e.rule_id] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_http_request_placeholder_url(self): + ctx = ValidationContext( + nodes=[ + { + "id": "http_1", + "type": "http-request", + "config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"}, + } + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "placeholder" in e.rule_id] + assert len(errors) == 1 + + def test_code_node_missing_fields(self): + ctx = ValidationContext( + nodes=[{"id": "code_1", "type": 
"code", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + error_rules = {e.rule_id for e in result.all_errors} + assert "code.code.required" in error_rules + assert "code.language.required" in error_rules + + def test_knowledge_retrieval_missing_dataset(self): + ctx = ValidationContext( + nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id] + assert len(errors) == 1 + assert errors[0].is_fixable is False # User must configure + + +class TestSemanticRules: + """Tests for semantic validation rules.""" + + def test_valid_variable_reference(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [ + {"role": "user", "text": "Process: {{#start.query#}}"} + ] + }, + }, + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # No variable reference errors + errors = [e for e in result.all_errors if "variable.ref" in e.rule_id] + assert len(errors) == 0 + + def test_invalid_variable_reference(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [ + {"role": "user", "text": "Process: {{#nonexistent.field#}}"} + ] + }, + }, + ] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "variable.ref" in e.rule_id] + assert len(errors) == 1 + assert "nonexistent" in errors[0].message + + def test_edge_validation(self): + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + edges=[ + {"source": "start", "target": "end"}, + {"source": "nonexistent", "target": "end"}, + ], + ) + engine = ValidationEngine() + 
result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "edge" in e.rule_id] + assert len(errors) == 1 + assert "nonexistent" in errors[0].message + + +class TestReferenceRules: + """Tests for reference validation rules (models, tools).""" + + def test_llm_missing_model_with_available(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.required"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_llm_missing_model_no_available(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[], # No models available + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.no_available"] + assert len(errors) == 1 + assert errors[0].is_fixable is False + + def test_llm_with_valid_model(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [{"role": "user", "text": "Hi"}], + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if "model" in e.rule_id] + assert len(errors) == 0 + + def test_llm_with_invalid_model(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": { + "prompt_template": [{"role": "user", "text": "Hi"}], + "model": {"provider": "openai", "name": "gpt-99"}, + }, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = 
ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "model.not_found"] + assert len(errors) == 1 + assert errors[0].is_fixable is True + + def test_tool_node_not_found(self): + ctx = ValidationContext( + nodes=[ + { + "id": "tool_1", + "type": "tool", + "config": {"tool_key": "nonexistent/tool"}, + } + ], + available_tools=[], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"] + assert len(errors) == 1 + + def test_tool_node_not_configured(self): + ctx = ValidationContext( + nodes=[ + { + "id": "tool_1", + "type": "tool", + "config": {"tool_key": "google/search"}, + } + ], + available_tools=[ + {"provider_id": "google", "tool_key": "search", "is_team_authorization": False} + ], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"] + assert len(errors) == 1 + assert errors[0].is_fixable is False + + +class TestValidationResult: + """Tests for ValidationResult classification.""" + + def test_has_errors(self): + ctx = ValidationContext( + nodes=[{"id": "llm_1", "type": "llm", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_errors is True + assert result.is_valid is False + + def test_has_fixable_errors(self): + ctx = ValidationContext( + nodes=[ + { + "id": "llm_1", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Hi"}]}, + } + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + assert result.has_fixable_errors is True + assert len(result.fixable_errors) > 0 + + def test_get_fixable_by_node(self): + ctx = ValidationContext( + nodes=[ + {"id": "llm_1", "type": "llm", "config": {}}, + {"id": "http_1", "type": "http-request", "config": {}}, + ] + ) + engine = 
ValidationEngine() + result = engine.validate(ctx) + + by_node = result.get_fixable_by_node() + assert "llm_1" in by_node + assert "http_1" in by_node + + def test_to_dict(self): + ctx = ValidationContext( + nodes=[{"id": "llm_1", "type": "llm", "config": {}}] + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + d = result.to_dict() + assert "fixable" in d + assert "user_required" in d + assert "warnings" in d + assert "all_warnings" in d + assert "stats" in d + + +class TestIntegration: + """Integration tests for the full validation pipeline.""" + + def test_complete_workflow_validation(self): + """Test validation of a complete workflow.""" + ctx = ValidationContext( + nodes=[ + { + "id": "start", + "type": "start", + "config": {"variables": [{"variable": "query", "type": "text-input"}]}, + }, + { + "id": "llm_1", + "type": "llm", + "config": { + "model": {"provider": "openai", "name": "gpt-4"}, + "prompt_template": [{"role": "user", "text": "{{#start.query#}}"}], + }, + }, + { + "id": "end", + "type": "end", + "config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]}, + }, + ], + edges=[ + {"source": "start", "target": "llm_1"}, + {"source": "llm_1", "target": "end"}, + ], + available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # Should have no errors + assert result.is_valid is True + assert len(result.fixable_errors) == 0 + assert len(result.user_required_errors) == 0 + + def test_workflow_with_multiple_errors(self): + """Test workflow with multiple types of errors.""" + ctx = ValidationContext( + nodes=[ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm_1", + "type": "llm", + "config": {}, # Missing prompt_template and model + }, + { + "id": "kb_1", + "type": "knowledge-retrieval", + "config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]}, + }, + {"id": "end", "type": "end", "config": {}}, + ], + 
available_models=[{"provider": "openai", "model": "gpt-4"}], + ) + engine = ValidationEngine() + result = engine.validate(ctx) + + # Should have multiple errors + assert result.has_errors is True + assert len(result.fixable_errors) >= 2 # model, prompt_template + assert len(result.user_required_errors) >= 1 # dataset placeholder + + # Check stats + assert result.stats["total_nodes"] == 4 + assert result.stats["total_errors"] >= 3 + + + diff --git a/api/tests/unit_tests/core/llm_generator/test_workflow_validator_vibe.py b/api/tests/unit_tests/core/llm_generator/test_workflow_validator_vibe.py new file mode 100644 index 0000000000..39e2ba5a0e --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_workflow_validator_vibe.py @@ -0,0 +1,435 @@ +""" +Unit tests for the Vibe Workflow Validator. + +Tests cover: +- Basic validation function +- User-friendly validation hints +- Edge cases and error handling +""" + + +from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator + + +class TestValidationHint: + """Tests for ValidationHint dataclass.""" + + def test_hint_creation(self): + """Test creating a validation hint.""" + hint = ValidationHint( + node_id="llm_1", + field="model", + message="Model is not configured", + severity="error", + ) + assert hint.node_id == "llm_1" + assert hint.field == "model" + assert hint.message == "Model is not configured" + assert hint.severity == "error" + + def test_hint_with_suggestion(self): + """Test hint with suggestion.""" + hint = ValidationHint( + node_id="http_1", + field="url", + message="URL is required", + severity="error", + suggestion="Add a valid URL like https://api.example.com", + ) + assert hint.suggestion is not None + + +class TestWorkflowValidatorBasic: + """Tests for basic validation scenarios.""" + + def test_empty_workflow_is_valid(self): + """Test empty workflow passes validation.""" + workflow_data = {"nodes": [], "edges": []} + is_valid, hints = 
WorkflowValidator.validate(workflow_data, []) + + # Empty but valid structure + assert is_valid is True + assert len(hints) == 0 + + def test_minimal_valid_workflow(self): + """Test minimal Start → End workflow.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [{"source": "start", "target": "end"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + assert is_valid is True + + def test_complete_workflow_with_llm(self): + """Test complete workflow with LLM node.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {"variables": []}}, + { + "id": "llm", + "type": "llm", + "config": { + "model": {"provider": "openai", "name": "gpt-4"}, + "prompt_template": [{"role": "user", "text": "Hello"}], + }, + }, + {"id": "end", "type": "end", "config": {"outputs": []}}, + ], + "edges": [ + {"source": "start", "target": "llm"}, + {"source": "llm", "target": "end"}, + ], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + # Should pass with no critical errors + errors = [h for h in hints if h.severity == "error"] + assert len(errors) == 0 + + +class TestVariableReferenceValidation: + """Tests for variable reference validation.""" + + def test_valid_variable_reference(self): + """Test valid variable reference passes.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "llm", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "Query: {{#start.query#}}"}]}, + }, + ], + "edges": [{"source": "start", "target": "llm"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + ref_errors = [h for h in hints if "reference" in h.message.lower()] + assert len(ref_errors) == 0 + + def test_invalid_variable_reference(self): + """Test invalid variable reference generates hint.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": 
"start", "config": {}}, + { + "id": "llm", + "type": "llm", + "config": {"prompt_template": [{"role": "user", "text": "{{#nonexistent.field#}}"}]}, + }, + ], + "edges": [{"source": "start", "target": "llm"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + # Should have a hint about invalid reference + ref_hints = [h for h in hints if "nonexistent" in h.message or "reference" in h.message.lower()] + assert len(ref_hints) >= 1 + + +class TestEdgeValidation: + """Tests for edge validation.""" + + def test_edge_with_invalid_source(self): + """Test edge with non-existent source generates hint.""" + workflow_data = { + "nodes": [{"id": "end", "type": "end", "config": {}}], + "edges": [{"source": "nonexistent", "target": "end"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + # Should have hint about invalid edge + edge_hints = [h for h in hints if "edge" in h.message.lower() or "source" in h.message.lower()] + assert len(edge_hints) >= 1 + + def test_edge_with_invalid_target(self): + """Test edge with non-existent target generates hint.""" + workflow_data = { + "nodes": [{"id": "start", "type": "start", "config": {}}], + "edges": [{"source": "start", "target": "nonexistent"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + edge_hints = [h for h in hints if "edge" in h.message.lower() or "target" in h.message.lower()] + assert len(edge_hints) >= 1 + + +class TestToolValidation: + """Tests for tool node validation.""" + + def test_tool_node_found_in_available(self): + """Test tool node that exists in available tools.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "tool1", + "type": "tool", + "config": {"tool_key": "google/search"}, + }, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [{"source": "start", "target": "tool1"}, {"source": "tool1", "target": "end"}], + } + available_tools = [{"provider_id": "google", "tool_key": 
"search", "is_team_authorization": True}] + is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools) + + tool_errors = [h for h in hints if h.severity == "error" and "tool" in h.message.lower()] + assert len(tool_errors) == 0 + + def test_tool_node_not_found(self): + """Test tool node not in available tools generates hint.""" + workflow_data = { + "nodes": [ + { + "id": "tool1", + "type": "tool", + "config": {"tool_key": "unknown/tool"}, + } + ], + "edges": [], + } + available_tools = [] + is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools) + + tool_hints = [h for h in hints if "tool" in h.message.lower()] + assert len(tool_hints) >= 1 + + +class TestQuestionClassifierValidation: + """Tests for question-classifier node validation.""" + + def test_question_classifier_with_classes(self): + """Test question-classifier with valid classes.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "classifier", + "type": "question-classifier", + "config": { + "classes": [ + {"id": "class1", "name": "Class 1"}, + {"id": "class2", "name": "Class 2"}, + ], + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}, + }, + }, + {"id": "h1", "type": "llm", "config": {}}, + {"id": "h2", "type": "llm", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [ + {"source": "start", "target": "classifier"}, + {"source": "classifier", "sourceHandle": "class1", "target": "h1"}, + {"source": "classifier", "sourceHandle": "class2", "target": "h2"}, + {"source": "h1", "target": "end"}, + {"source": "h2", "target": "end"}, + ], + } + available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}] + is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models) + + class_errors = [h for h in hints if "class" in h.message.lower() and h.severity == "error"] + assert len(class_errors) == 0 + + def 
test_question_classifier_missing_classes(self): + """Test question-classifier without classes generates hint.""" + workflow_data = { + "nodes": [ + { + "id": "classifier", + "type": "question-classifier", + "config": {"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}}, + } + ], + "edges": [], + } + available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}] + is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models) + + # Should have hint about missing classes + class_hints = [h for h in hints if "class" in h.message.lower()] + assert len(class_hints) >= 1 + + +class TestHttpRequestValidation: + """Tests for HTTP request node validation.""" + + def test_http_request_with_url(self): + """Test HTTP request with valid URL.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "http", + "type": "http-request", + "config": {"url": "https://api.example.com", "method": "GET"}, + }, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [{"source": "start", "target": "http"}, {"source": "http", "target": "end"}], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + url_errors = [h for h in hints if "url" in h.message.lower() and h.severity == "error"] + assert len(url_errors) == 0 + + def test_http_request_missing_url(self): + """Test HTTP request without URL generates hint.""" + workflow_data = { + "nodes": [ + { + "id": "http", + "type": "http-request", + "config": {"method": "GET"}, + } + ], + "edges": [], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + + url_hints = [h for h in hints if "url" in h.message.lower()] + assert len(url_hints) >= 1 + + +class TestParameterExtractorValidation: + """Tests for parameter-extractor node validation.""" + + def test_parameter_extractor_valid_params(self): + """Test parameter-extractor with valid parameters.""" + workflow_data = { + "nodes": [ + {"id": "start", 
"type": "start", "config": {}}, + { + "id": "extractor", + "type": "parameter-extractor", + "config": { + "instruction": "Extract info", + "parameters": [ + { + "name": "name", + "type": "string", + "description": "Name", + "required": True, + } + ], + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}, + }, + }, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [{"source": "start", "target": "extractor"}, {"source": "extractor", "target": "end"}], + } + available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}] + is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models) + + errors = [h for h in hints if h.severity == "error"] + assert len(errors) == 0 + + def test_parameter_extractor_missing_required_field(self): + """Test parameter-extractor missing 'required' field in parameter item.""" + workflow_data = { + "nodes": [ + { + "id": "extractor", + "type": "parameter-extractor", + "config": { + "instruction": "Extract info", + "parameters": [ + { + "name": "name", + "type": "string", + "description": "Name", + # Missing 'required' + } + ], + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}, + }, + } + ], + "edges": [], + } + available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}] + is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models) + + errors = [h for h in hints if "required" in h.message and h.severity == "error"] + assert len(errors) >= 1 + assert "parameter-extractor" in errors[0].node_type + + +class TestIfElseValidation: + """Tests for if-else node validation.""" + + def test_if_else_valid_operators(self): + """Test if-else with valid operators.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "ifelse", + "type": "if-else", + "config": { + "cases": [{"case_id": "c1", "conditions": [{"comparison_operator": "≥", "value": "1"}]}] + }, + }, + 
{"id": "t", "type": "llm", "config": {}}, + {"id": "f", "type": "llm", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [ + {"source": "start", "target": "ifelse"}, + {"source": "ifelse", "sourceHandle": "true", "target": "t"}, + {"source": "ifelse", "sourceHandle": "false", "target": "f"}, + {"source": "t", "target": "end"}, + {"source": "f", "target": "end"}, + ], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + errors = [h for h in hints if h.severity == "error"] + # The branch targets "t" and "f" are LLM nodes with an empty config, and no + # available_models are passed to validate(), so unrelated model-configuration + # errors may be reported for them. + + # Only operator-related errors matter for this test, so filter for those. 
+ operator_errors = [h for h in errors if "operator" in h.message] + assert len(operator_errors) == 0 + + def test_if_else_invalid_operators(self): + """Test if-else with invalid operators.""" + workflow_data = { + "nodes": [ + {"id": "start", "type": "start", "config": {}}, + { + "id": "ifelse", + "type": "if-else", + "config": { + "cases": [{"case_id": "c1", "conditions": [{"comparison_operator": ">=", "value": "1"}]}] + }, + }, + {"id": "t", "type": "llm", "config": {}}, + {"id": "f", "type": "llm", "config": {}}, + {"id": "end", "type": "end", "config": {}}, + ], + "edges": [ + {"source": "start", "target": "ifelse"}, + {"source": "ifelse", "sourceHandle": "true", "target": "t"}, + {"source": "ifelse", "sourceHandle": "false", "target": "f"}, + {"source": "t", "target": "end"}, + {"source": "f", "target": "end"}, + ], + } + is_valid, hints = WorkflowValidator.validate(workflow_data, []) + operator_errors = [h for h in hints if "operator" in h.message and h.severity == "error"] + assert len(operator_errors) > 0 + assert "≥" in operator_errors[0].suggestion diff --git a/web/app/components/goto-anything/actions/banana.spec.tsx b/web/app/components/goto-anything/actions/banana.spec.tsx new file mode 100644 index 0000000000..ec7cd36c8e --- /dev/null +++ b/web/app/components/goto-anything/actions/banana.spec.tsx @@ -0,0 +1,87 @@ +import type { CommandSearchResult, SearchResult } from './types' +import { isInWorkflowPage } from '@/app/components/workflow/constants' +import i18n from '@/i18n-config/i18next-config' +import { bananaAction } from './banana' + +vi.mock('@/i18n-config/i18next-config', () => ({ + default: { + t: vi.fn((key: string, options?: Record) => { + if (!options) + return key + return `${key}:${JSON.stringify(options)}` + }), + }, +})) + +vi.mock('@/app/components/workflow/constants', async () => { + const actual = await vi.importActual( + '@/app/components/workflow/constants', + ) + return { + ...actual, + isInWorkflowPage: vi.fn(), + } +}) + +const 
mockedIsInWorkflowPage = vi.mocked(isInWorkflowPage) +const mockedT = vi.mocked(i18n.t) + +const getCommandResult = (item: SearchResult): CommandSearchResult => { + expect(item.type).toBe('command') + return item as CommandSearchResult +} + +beforeEach(() => { + vi.clearAllMocks() +}) + +// Search behavior for the banana action. +describe('bananaAction', () => { + // Search results depend on workflow context and input content. + describe('search', () => { + it('should return no results when not on workflow page', async () => { + // Arrange + mockedIsInWorkflowPage.mockReturnValue(false) + + // Act + const result = await bananaAction.search('', '', 'en') + + // Assert + expect(result).toEqual([]) + }) + + it('should return hint description when input is blank', async () => { + // Arrange + mockedIsInWorkflowPage.mockReturnValue(true) + + // Act + const result = await bananaAction.search('', ' ', 'en') + + // Assert + expect(result).toHaveLength(1) + const [item] = result + const commandItem = getCommandResult(item) + expect(item.description).toContain('app.gotoAnything.actions.vibeHint') + expect(commandItem.data.args?.dsl).toBe('') + expect(mockedT).toHaveBeenCalledWith( + 'app.gotoAnything.actions.vibeHint', + expect.objectContaining({ prompt: expect.any(String), lng: 'en' }), + ) + }) + + it('should return default description when input is provided', async () => { + // Arrange + mockedIsInWorkflowPage.mockReturnValue(true) + + // Act + const result = await bananaAction.search('', ' build a flow ', 'en') + + // Assert + expect(result).toHaveLength(1) + const [item] = result + const commandItem = getCommandResult(item) + expect(item.description).toContain('app.gotoAnything.actions.vibeDesc') + expect(commandItem.data.args?.dsl).toBe('build a flow') + }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-vibe.test.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-vibe.test.ts new file mode 100644 index 0000000000..11d19ce9d2 
--- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-vibe.test.ts @@ -0,0 +1,82 @@ + +import { describe, it, expect } from 'vitest' +import { replaceVariableReferences } from '../use-workflow-vibe' +import { BlockEnum } from '@/app/components/workflow/types' + +// Mock types needed for the test +interface NodeData { + title: string + [key: string]: any +} + +describe('use-workflow-vibe', () => { + describe('replaceVariableReferences', () => { + it('should replace variable references in strings', () => { + const data = { + title: 'Test Node', + prompt: 'Hello {{#old_id.query#}}', + } + const nodeIdMap = new Map() + nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } }) + + const result = replaceVariableReferences(data, nodeIdMap) as NodeData + expect(result.prompt).toBe('Hello {{#new_uuid.query#}}') + }) + + it('should handle multiple references in one string', () => { + const data = { + title: 'Test Node', + text: '{{#node1.out#}} and {{#node2.out#}}', + } + const nodeIdMap = new Map() + nodeIdMap.set('node1', { id: 'uuid1', data: { type: 'llm' } }) + nodeIdMap.set('node2', { id: 'uuid2', data: { type: 'llm' } }) + + const result = replaceVariableReferences(data, nodeIdMap) as NodeData + expect(result.text).toBe('{{#uuid1.out#}} and {{#uuid2.out#}}') + }) + + it('should replace variable references in value_selector arrays', () => { + const data = { + title: 'End Node', + outputs: [ + { + variable: 'result', + value_selector: ['old_id', 'text'], + }, + ], + } + const nodeIdMap = new Map() + nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } }) + + const result = replaceVariableReferences(data, nodeIdMap) as NodeData + expect(result.outputs[0].value_selector).toEqual(['new_uuid', 'text']) + }) + + it('should handle nested objects recursively', () => { + const data = { + config: { + model: { + prompt: '{{#old_id.text#}}', + }, + }, + } + const nodeIdMap = new Map() + nodeIdMap.set('old_id', { id: 'new_uuid', data: { 
type: 'llm' } }) + + const result = replaceVariableReferences(data, nodeIdMap) as any + expect(result.config.model.prompt).toBe('{{#new_uuid.text#}}') + }) + + it('should ignoring missing node mappings', () => { + const data = { + text: '{{#missing_id.text#}}', + } + const nodeIdMap = new Map() + // missing_id is not in map + + const result = replaceVariableReferences(data, nodeIdMap) as NodeData + expect(result.text).toBe('{{#missing_id.text#}}') + }) + }) +}) diff --git a/web/app/components/workflow/hooks/index.ts b/web/app/components/workflow/hooks/index.ts index df54065dea..f3765b92ed 100644 --- a/web/app/components/workflow/hooks/index.ts +++ b/web/app/components/workflow/hooks/index.ts @@ -25,3 +25,4 @@ export * from './use-workflow-search' export * from './use-workflow-start-run' export * from './use-workflow-variables' export * from './use-workflow-vibe' +export * from './use-workflow-vibe-config' diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 7cead40705..9b8a1ae569 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -159,7 +159,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { } } else { - usedVars = getNodeUsedVars(node).filter(v => v.length > 0) + usedVars = getNodeUsedVars(node).filter(v => v && v.length > 0) } if (node.type === CUSTOM_NODE) { @@ -355,7 +355,7 @@ export const useChecklistBeforePublish = () => { } } else { - usedVars = getNodeUsedVars(node).filter(v => v.length > 0) + usedVars = getNodeUsedVars(node).filter(v => v && v.length > 0) } const checkData = getCheckData(node.data, datasets) const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid) diff --git a/web/app/components/workflow/hooks/use-workflow-vibe-config.ts b/web/app/components/workflow/hooks/use-workflow-vibe-config.ts new file mode 100644 index 
0000000000..a149d15947 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-vibe-config.ts @@ -0,0 +1,99 @@ +/** + * Vibe Workflow Generator Configuration + * + * This module centralizes configuration for the Vibe workflow generation feature, + * including node type aliases and field name corrections. + * + * Note: These definitions are mirrored in the backend at: + * api/core/workflow/generator/config/node_schemas.py + * When updating these values, also update the backend file. + */ + +/** + * Node type aliases for inference from natural language. + * Maps common terms to canonical node type names. + */ +export const NODE_TYPE_ALIASES: Record = { + // Start node aliases + 'start': 'start', + 'begin': 'start', + 'input': 'start', + // End node aliases + 'end': 'end', + 'finish': 'end', + 'output': 'end', + // LLM node aliases + 'llm': 'llm', + 'ai': 'llm', + 'gpt': 'llm', + 'model': 'llm', + 'chat': 'llm', + // Code node aliases + 'code': 'code', + 'script': 'code', + 'python': 'code', + 'javascript': 'code', + // HTTP request node aliases + 'http-request': 'http-request', + 'http': 'http-request', + 'request': 'http-request', + 'api': 'http-request', + 'fetch': 'http-request', + 'webhook': 'http-request', + // Conditional node aliases + 'if-else': 'if-else', + 'condition': 'if-else', + 'branch': 'if-else', + 'switch': 'if-else', + // Loop node aliases + 'iteration': 'iteration', + 'loop': 'loop', + 'foreach': 'iteration', + // Tool node alias + 'tool': 'tool', +} + +/** + * Field name corrections for LLM-generated node configs. + * Maps incorrect field names to correct ones for specific node types. 
+ */ +export const FIELD_NAME_CORRECTIONS: Record> = { + 'http-request': { + text: 'body', // LLM might use "text" instead of "body" + content: 'body', + response: 'body', + }, + 'code': { + text: 'result', // LLM might use "text" instead of "result" + output: 'result', + }, + 'llm': { + response: 'text', + answer: 'text', + }, +} + +/** + * Correct field names based on node type. + * LLM sometimes generates wrong field names (e.g., "text" instead of "body" for HTTP nodes). + * + * @param field - The field name to correct + * @param nodeType - The type of the node + * @returns The corrected field name, or the original if no correction needed + */ +export const correctFieldName = (field: string, nodeType: string): string => { + const corrections = FIELD_NAME_CORRECTIONS[nodeType] + if (corrections && corrections[field]) + return corrections[field] + return field +} + +/** + * Get the canonical node type from an alias. + * + * @param alias - The alias to look up + * @returns The canonical node type, or undefined if not found + */ +export const getCanonicalNodeType = (alias: string): string | undefined => { + return NODE_TYPE_ALIASES[alias.toLowerCase()] +} diff --git a/web/app/components/workflow/hooks/use-workflow-vibe.tsx b/web/app/components/workflow/hooks/use-workflow-vibe.tsx index 010e5ca53c..d4f5edc7ad 100644 --- a/web/app/components/workflow/hooks/use-workflow-vibe.tsx +++ b/web/app/components/workflow/hooks/use-workflow-vibe.tsx @@ -3,6 +3,7 @@ import type { ToolDefaultValue } from '../block-selector/types' import type { Edge, Node, ToolWithProvider } from '../types' import type { Tool } from '@/app/components/tools/types' +import type { BackendEdgeSpec, BackendNodeSpec } from '@/service/debug' import type { Model } from '@/types/app' import { useSessionStorageState } from 'ahooks' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' @@ -38,10 +39,12 @@ import { getNodeCustomTypeByNodeDataType, 
getNodesConnectedSourceOrTargetHandleIdsMap, } from '../utils' +import { initialNodes as initializeNodeData } from '../utils/workflow-init' import { useNodesMetaData } from './use-nodes-meta-data' import { useNodesSyncDraft } from './use-nodes-sync-draft' import { useNodesReadOnly } from './use-workflow' import { useWorkflowHistory, WorkflowHistoryEvent } from './use-workflow-history' +import { correctFieldName, NODE_TYPE_ALIASES } from './use-workflow-vibe-config' type VibeCommandDetail = { dsl?: string @@ -105,6 +108,79 @@ const normalizeProviderIcon = (icon?: ToolWithProvider['icon']) => { return icon } +/** + * Replace variable references in node data using the nodeIdMap. + * Handles: + * - String templates: {{#old_id.field#}} → {{#new_id.field#}} + * - Value selectors: ["old_id", "field"] → ["new_id", "field"] + * - Mixed content objects: {type: "mixed", value: "..."} → normalized to string + * - Field name correction based on node type + */ +export const replaceVariableReferences = ( + data: unknown, + nodeIdMap: Map, + parentKey?: string, +): unknown => { + if (typeof data === 'string') { + // Replace {{#old_id.field#}} patterns and correct field names + return data.replace(/\{\{#([^.#]+)\.([^#]+)#\}\}/g, (match, oldId, field) => { + const newNode = nodeIdMap.get(oldId) + if (newNode) { + const nodeType = newNode.data?.type as string || '' + const correctedField = correctFieldName(field, nodeType) + return `{{#${newNode.id}.${correctedField}#}}` + } + return match // Keep original if no mapping found + }) + } + + if (Array.isArray(data)) { + // Check if this is a value_selector array: ["node_id", "field", ...] 
+ if (data.length >= 2 && typeof data[0] === 'string' && typeof data[1] === 'string') { + const potentialNodeId = data[0] + const newNode = nodeIdMap.get(potentialNodeId) + // #region agent log + if (!newNode && !['sys', 'env', 'conversation'].includes(potentialNodeId)) { + console.warn(`[VIBE DEBUG] replaceVariableReferences: No mapping for "${potentialNodeId}" in selector [${data.join(', ')}]`) + } + // #endregion + if (newNode) { + const nodeType = newNode.data?.type as string || '' + const correctedField = correctFieldName(data[1], nodeType) + // Replace the node ID and correct field name in value_selector + return [newNode.id, correctedField, ...data.slice(2)] + } + } + // Recursively process array elements + return data.map(item => replaceVariableReferences(item, nodeIdMap)) + } + + if (data !== null && typeof data === 'object') { + const obj = data as Record + + // Handle "mixed content" objects like {type: "mixed", value: "{{#...#}}"} + // These should be normalized to plain strings for fields like 'url' + if (obj.type === 'mixed' && typeof obj.value === 'string') { + const processedValue = replaceVariableReferences(obj.value, nodeIdMap) as string + // For certain fields (url, headers, params), return just the string value + if (parentKey && ['url', 'headers', 'params'].includes(parentKey)) { + return processedValue + } + // Otherwise keep the object structure but update the value + return { ...obj, value: processedValue } + } + + // Recursively process object properties + const result: Record = {} + for (const [key, value] of Object.entries(obj)) { + result[key] = replaceVariableReferences(value, nodeIdMap, key) + } + return result + } + + return data // Return primitives as-is +} + const parseNodeLabel = (label: string) => { const tokens = label.split('|').map(token => token.trim()).filter(Boolean) const info: Record = {} @@ -116,8 +192,17 @@ const parseNodeLabel = (label: string) => { info[rawKey.trim().toLowerCase()] = rest.join('=').trim() }) + // 
Fallback: if no type= found, try to infer from label text if (!info.type && tokens.length === 1 && !tokens[0].includes('=')) { - info.type = tokens[0] + const labelLower = tokens[0].toLowerCase() + // Check if label matches a known node type alias + if (NODE_TYPE_ALIASES[labelLower]) { + info.type = NODE_TYPE_ALIASES[labelLower] + info.title = tokens[0] // Use original label as title + } + else { + info.type = tokens[0] + } } if (!info.tool && info.tool_key) @@ -345,6 +430,28 @@ export const useVibeFlowData = ({ storageKey }: UseVibeFlowDataParams) => { } } +const buildEdge = ( + source: Node, + target: Node, + sourceHandle = 'source', + targetHandle = 'target', +): Edge => ({ + id: `${source.id}-${sourceHandle}-${target.id}-${targetHandle}`, + type: CUSTOM_EDGE, + source: source.id, + sourceHandle, + target: target.id, + targetHandle, + data: { + sourceType: source.data.type, + targetType: target.data.type, + isInIteration: false, + isInLoop: false, + _connectedNodeIsSelected: false, + }, + zIndex: 0, +}) + export const useWorkflowVibe = () => { const { t } = useTranslation() const store = useStoreApi() @@ -356,7 +463,7 @@ export const useWorkflowVibe = () => { const { handleSyncWorkflowDraft } = useNodesSyncDraft() const { getNodesReadOnly } = useNodesReadOnly() const { saveStateToHistory } = useWorkflowHistory() - const { defaultModel } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) + const { defaultModel, modelList } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const { data: buildInTools } = useAllBuiltInTools() const { data: customTools } = useAllCustomTools() @@ -476,14 +583,24 @@ export const useWorkflowVibe = () => { const toolLookup = useMemo(() => { const map = new Map() toolOptions.forEach((tool) => { + // Primary key: provider_id/tool_name (e.g., "google/google_search") const primaryKey = normalizeKey(`${tool.provider_id}/${tool.tool_name}`) map.set(primaryKey, tool) + // 
Fallback 1: provider_name/tool_name (e.g., "Google/google_search") const providerNameKey = normalizeKey(`${tool.provider_name}/${tool.tool_name}`) map.set(providerNameKey, tool) + // Fallback 2: tool_label (display name) const labelKey = normalizeKey(tool.tool_label) map.set(labelKey, tool) + + // Fallback 3: tool_name alone (for partial matching when model omits provider) + const toolNameKey = normalizeKey(tool.tool_name) + if (!map.has(toolNameKey)) { + // Only set if not already taken (avoid collisions between providers) + map.set(toolNameKey, tool) + } }) return map }, [toolOptions]) @@ -502,6 +619,409 @@ export const useWorkflowVibe = () => { return map }, [nodesMetaDataMap]) + const createGraphFromBackendNodes = useCallback(async ( + backendNodes: BackendNodeSpec[], + backendEdges: BackendEdgeSpec[], + ): Promise => { + const { getNodes } = store.getState() + const nodes = getNodes() + + if (!nodesMetaDataMap) { + Toast.notify({ type: 'error', message: t('workflow.vibe.nodesUnavailable') }) + return { nodes: [], edges: [] } + } + + const existingStartNode = nodes.find(node => node.data.type === BlockEnum.Start) + const newNodes: Node[] = [] + const nodeIdMap = new Map() + + for (const nodeSpec of backendNodes) { + // Map string type to BlockEnum + const typeKey = normalizeKey(nodeSpec.type) + const nodeType = nodeTypeLookup.get(typeKey) + if (!nodeType) { + // Skip unknown node types + continue + } + + if (nodeType === BlockEnum.Start && existingStartNode) { + // Merge backend variables into existing Start node + const backendVariables = (nodeSpec.config?.variables as Array>) || [] + if (backendVariables.length > 0) { + const existingVariables = (existingStartNode.data.variables as Array>) || [] + // Add new variables that don't already exist + for (const backendVar of backendVariables) { + const varName = backendVar.variable as string + const exists = existingVariables.some(v => v.variable === varName) + if (!exists) { + existingVariables.push(backendVar) + 
} + } + // NOTE: when existingStartNode.data.variables is already an array, + // existingVariables is that same array, so the pushes above mutate the + // live Start node in place; when it is unset, the pushes go to a temporary + // array and the merged variables are dropped. + // TODO: clone the Start node (and assign the merged list) for preview. + // For now the backend Start id is only mapped to the existing node below. + } + + nodeIdMap.set(nodeSpec.id, existingStartNode) + continue + } + + const nodeDefault = nodesMetaDataMap[nodeType] + if (!nodeDefault) + continue + + const defaultValue = nodeDefault.defaultValue || {} + const title = nodeSpec.title?.trim() || nodeDefault.metaData.title || defaultValue.title || nodeSpec.type + + // For tool nodes, try to get tool default value from config + let toolDefaultValue: ToolDefaultValue | undefined + if (nodeType === BlockEnum.Tool && nodeSpec.config) { + const toolName = nodeSpec.config.tool_name as string | undefined + const providerId = nodeSpec.config.provider_id as string | undefined + if (toolName && providerId) { + const toolKey = normalizeKey(`${providerId}/${toolName}`) + toolDefaultValue = toolLookup.get(toolKey) || toolLookup.get(normalizeKey(toolName)) + } + } + + const desc = (toolDefaultValue?.tool_description || (defaultValue as { desc?: string }).desc || '') as string + + // Merge backend config into node data + // Backend provides: { url: "{{#start.url#}}", method: "GET", ... 
} + const backendConfig = nodeSpec.config || {} + + // Deep merge for nested objects (e.g., body, authorization) to preserve required fields + const mergedConfig: Record = { ...backendConfig } + const defaultValueRecord = defaultValue as Record + + // For http-request nodes, ensure body has all required fields + if (nodeType === BlockEnum.HttpRequest) { + const defaultBody = defaultValueRecord.body as Record | undefined + const backendBody = backendConfig.body as Record | undefined + if (defaultBody || backendBody) { + mergedConfig.body = { + type: 'none', + data: [], + ...(defaultBody || {}), + ...(backendBody || {}), + } + // Ensure data is always an array + if (!Array.isArray((mergedConfig.body as Record).data)) { + (mergedConfig.body as Record).data = [] + } + } + + // Ensure authorization has type + const defaultAuth = defaultValueRecord.authorization as Record | undefined + const backendAuth = backendConfig.authorization as Record | undefined + if (defaultAuth || backendAuth) { + mergedConfig.authorization = { + type: 'no-auth', + ...(defaultAuth || {}), + ...(backendAuth || {}), + } + } + } + + // For End nodes, ensure outputs have value_selector format + // New format (preferred): {"outputs": [{"variable": "name", "value_selector": ["nodeId", "field"]}]} + // Legacy format (fallback): {"outputs": [{"variable": "name", "value": "{{#nodeId.field#}}"}]} + if (nodeType === BlockEnum.End && backendConfig.outputs) { + const outputs = backendConfig.outputs as Array<{ variable?: string, value?: string, value_selector?: string[] }> + mergedConfig.outputs = outputs.map((output) => { + // Preferred: value_selector array format (new LLM output format) + if (output.value_selector && Array.isArray(output.value_selector)) { + return output + } + // Parse value like "{{#nodeId.field#}}" into ["nodeId", "field"] + if (output.value) { + const match = output.value.match(/\{\{#([^.]+)\.([^#]+)#\}\}/) + if (match) { + return { + variable: output.variable, + value_selector: 
[match[1], match[2]], + } + } + } + // Fallback: return with empty value_selector to prevent crash + return { + variable: output.variable || 'output', + value_selector: [], + } + }) + } + + // For Parameter Extractor nodes, ensure each parameter has a 'required' field + // Backend may omit this field, but Dify's Pydantic model requires it + if (nodeType === BlockEnum.ParameterExtractor) { + // Fix: If backend returns query as null, use default empty array instead + if (backendConfig.query === null || backendConfig.query === undefined) { + mergedConfig.query = [] + } + if (backendConfig.parameters) { + const parameters = backendConfig.parameters as Array<{ name?: string, type?: string, description?: string, required?: boolean }> + mergedConfig.parameters = parameters.map(param => ({ + ...param, + required: param.required ?? true, // Default to required if not specified + })) + } + } + + // For Question Classifier nodes, ensure query_variable_selector is not null + // Backend may return null, but Dify's Pydantic model requires an array + // Note: question-classifier uses 'query' field in backend config, but 'query_variable_selector' in frontend + if (nodeType === BlockEnum.QuestionClassifier) { + // Fix: If backend returns query as null, use default empty array instead + const backendQuery = backendConfig.query + if (backendQuery === null || backendQuery === undefined) { + mergedConfig.query_variable_selector = [] + } + else if (Array.isArray(backendQuery)) { + // Map backend 'query' field to frontend 'query_variable_selector' field + mergedConfig.query_variable_selector = backendQuery + // Remove the 'query' field to avoid confusion + delete mergedConfig.query + } + } + + // For Variable Aggregator nodes, ensure variables format is correct + // Backend expects list[list[str]], but LLM may generate dict format + if (nodeType === BlockEnum.VariableAggregator && backendConfig.variables) { + const backendVariables = backendConfig.variables as Array + const 
repairedVariables: string[][] = [] + let repaired = false + + for (const varItem of backendVariables) { + if (Array.isArray(varItem)) { + // Already in correct format + repairedVariables.push(varItem) + } + else if (typeof varItem === 'object' && varItem !== null) { + // Convert dict format to array format + const valueSelector = varItem.value_selector || varItem.selector || varItem.path + if (Array.isArray(valueSelector) && valueSelector.length > 0) { + repairedVariables.push(valueSelector) + repaired = true + } + else { + // Try to extract from name field - LLM may generate {"name": "node_id.field"} + const name = varItem.name + if (typeof name === 'string' && name.includes('.')) { + const parts = name.split('.', 2) + if (parts.length === 2) { + repairedVariables.push([parts[0], parts[1]]) + repaired = true + } + } + // If still can't parse, skip this variable (don't add empty array) + } + } + } + + if (repaired || repairedVariables.length !== backendVariables.length) { + mergedConfig.variables = repairedVariables + } + } + + // For any node with model config, ALWAYS use user's configured model + // This prevents "Model not exist" errors when LLM generates models the user doesn't have configured + // Applies to: LLM, QuestionClassifier, ParameterExtractor, and any future model-dependent nodes + if (backendConfig.model) { + // Try to use defaultModel first, fallback to first available model from modelList + const fallbackModel = modelList?.[0]?.models?.[0] + const modelProvider = defaultModel?.provider?.provider || modelList?.[0]?.provider + const modelName = defaultModel?.model || fallbackModel?.model + + if (modelProvider && modelName) { + mergedConfig.model = { + provider: modelProvider, + name: modelName, + mode: 'chat', + } + } + } + + const data = { + ...(defaultValue as Record), + title, + desc, + type: nodeType, + selected: false, + ...(toolDefaultValue || {}), + // Apply backend-generated config (url, method, headers, etc.) 
+ ...mergedConfig, + } + + const newNode = generateNewNode({ + id: uuid4(), + type: getNodeCustomTypeByNodeDataType(nodeType), + data, + position: nodeSpec.position || { x: 0, y: 0 }, + }).newNode + + newNodes.push(newNode) + nodeIdMap.set(nodeSpec.id, newNode) + } + + // Replace variable references in all node configs using the nodeIdMap + // This converts {{#old_id.field#}} to {{#new_uuid.field#}} + + for (const node of newNodes) { + node.data = replaceVariableReferences(node.data, nodeIdMap) as typeof node.data + } + + // Use Dify's standard node initialization to handle all node types generically + // This sets up _targetBranches for question-classifier/if-else, _children for iteration/loop, etc. + const initializedNodes = initializeNodeData(newNodes, []) + + // Update newNodes with initialized data + newNodes.splice(0, newNodes.length, ...initializedNodes) + + if (!newNodes.length) { + Toast.notify({ type: 'error', message: t('workflow.vibe.invalidFlowchart') }) + return { nodes: [], edges: [] } + } + + const newEdges: Edge[] = [] + for (const edgeSpec of backendEdges) { + const sourceNode = nodeIdMap.get(edgeSpec.source) + const targetNode = nodeIdMap.get(edgeSpec.target) + + if (!sourceNode || !targetNode) { + console.warn(`[VIBE] Edge skipped: source=${edgeSpec.source} (found=${!!sourceNode}), target=${edgeSpec.target} (found=${!!targetNode})`) + continue + } + + let sourceHandle = edgeSpec.sourceHandle || 'source' + // Handle IfElse branch handles + if (sourceNode.data.type === BlockEnum.IfElse && !edgeSpec.sourceHandle) { + sourceHandle = 'source' + } + + newEdges.push(buildEdge(sourceNode, targetNode, sourceHandle, edgeSpec.targetHandle || 'target')) + } + + // Layout nodes + const bounds = nodes.reduce( + (acc, node) => { + const width = node.width ?? NODE_WIDTH + acc.maxX = Math.max(acc.maxX, node.position.x + width) + acc.minY = Math.min(acc.minY, node.position.y) + return acc + }, + { maxX: 0, minY: 0 }, + ) + + const baseX = nodes.length ? 
bounds.maxX + NODE_WIDTH_X_OFFSET : 0 + const baseY = Number.isFinite(bounds.minY) ? bounds.minY : 0 + const branchOffset = Math.max(120, NODE_WIDTH_X_OFFSET / 2) + + const layoutNodeIds = new Set(newNodes.map(node => node.id)) + const layoutEdges = newEdges.filter(edge => + layoutNodeIds.has(edge.source) && layoutNodeIds.has(edge.target), + ) + + try { + const layout = await getLayoutByDagre(newNodes, layoutEdges) + const layoutedNodes = newNodes.map((node) => { + const info = layout.nodes.get(node.id) + if (!info) + return node + return { + ...node, + position: { + x: baseX + info.x, + y: baseY + info.y, + }, + } + }) + newNodes.splice(0, newNodes.length, ...layoutedNodes) + } + catch { + newNodes.forEach((node, index) => { + const row = Math.floor(index / 4) + const col = index % 4 + node.position = { + x: baseX + col * NODE_WIDTH_X_OFFSET, + y: baseY + row * branchOffset, + } + }) + } + + return { + nodes: newNodes, + edges: newEdges, + } + }, [ + defaultModel, + nodeTypeLookup, + nodesMetaDataMap, + store, + t, + toolLookup, + ]) + + // Apply backend-provided nodes directly (bypasses mermaid parsing) + const applyBackendNodesToWorkflow = useCallback(async ( + backendNodes: BackendNodeSpec[], + backendEdges: BackendEdgeSpec[], + ) => { + const { getNodes, setNodes, edges, setEdges } = store.getState() + const nodes = getNodes() + const { + setShowVibePanel, + } = workflowStore.getState() + + const { nodes: newNodes, edges: newEdges } = await createGraphFromBackendNodes(backendNodes, backendEdges) + + if (newNodes.length === 0) { + setShowVibePanel(false) + return + } + + const allNodes = [...nodes, ...newNodes] + const nodesConnectedMap = getNodesConnectedSourceOrTargetHandleIdsMap( + newEdges.map(edge => ({ type: 'add', edge })), + allNodes, + ) + + const updatedNodes = allNodes.map((node) => { + const connected = nodesConnectedMap[node.id] + if (!connected) + return node + + return { + ...node, + data: { + ...node.data, + ...connected, + 
_connectedSourceHandleIds: dedupeHandles(connected._connectedSourceHandleIds), + _connectedTargetHandleIds: dedupeHandles(connected._connectedTargetHandleIds), + }, + } + }) + + setNodes(updatedNodes) + setEdges([...edges, ...newEdges]) + saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: newNodes[0].id }) + handleSyncWorkflowDraft() + + workflowStore.setState(state => ({ + ...state, + showVibePanel: false, + vibePanelMermaidCode: '', + })) + }, [ + createGraphFromBackendNodes, + handleSyncWorkflowDraft, + saveStateToHistory, + store, + ]) + const flowchartToWorkflowGraph = useCallback(async (mermaidCode: string): Promise => { const { getNodes } = store.getState() const nodes = getNodes() @@ -585,28 +1105,6 @@ export const useWorkflowVibe = () => { return emptyGraph } - const buildEdge = ( - source: Node, - target: Node, - sourceHandle = 'source', - targetHandle = 'target', - ): Edge => ({ - id: `${source.id}-${sourceHandle}-${target.id}-${targetHandle}`, - type: CUSTOM_EDGE, - source: source.id, - sourceHandle, - target: target.id, - targetHandle, - data: { - sourceType: source.data.type, - targetType: target.data.type, - isInIteration: false, - isInLoop: false, - _connectedNodeIsSelected: false, - }, - zIndex: 0, - }) - const newEdges: Edge[] = [] for (const edgeSpec of parseResultToUse.edges) { const sourceNode = nodeIdMap.get(edgeSpec.sourceId) @@ -699,7 +1197,7 @@ export const useWorkflowVibe = () => { nodes: updatedNodes, edges: newEdges, } - }, [nodeTypeLookup, toolLookup]) + }, [nodeTypeLookup, nodesMetaDataMap, store, t, toolLookup]) const applyFlowchartToWorkflow = useCallback(() => { if (!currentFlowGraph || !currentFlowGraph.nodes || currentFlowGraph.nodes.length === 0) { @@ -724,15 +1222,16 @@ export const useWorkflowVibe = () => { }, [ currentFlowGraph, handleSyncWorkflowDraft, - nodeTypeLookup, - nodesMetaDataMap, saveStateToHistory, store, t, - toolLookup, ]) - const handleVibeCommand = useCallback(async (dsl?: string, skipPanelPreview = 
false) => { + const handleVibeCommand = useCallback(async ( + dsl?: string, + skipPanelPreview = false, + regenerateMode = false, + ) => { if (getNodesReadOnly()) { Toast.notify({ type: 'error', message: t('workflow.vibe.readOnly') }) return @@ -768,6 +1267,9 @@ export const useWorkflowVibe = () => { isVibeGenerating: true, vibePanelMermaidCode: '', vibePanelInstruction: trimmed, + vibePanelIntent: '', + vibePanelMessage: '', + vibePanelSuggestions: [], })) try { @@ -790,6 +1292,11 @@ export const useWorkflowVibe = () => { tool_name: tool.tool_name, tool_label: tool.tool_label, tool_key: `${tool.provider_id}/${tool.tool_name}`, + tool_description: tool.tool_description, + is_team_authorization: tool.is_team_authorization, + // Include parameter schemas so backend can inform model how to use tools + parameters: tool.paramSchemas, + output_schema: tool.output_schema, })) const availableNodesPayload = availableNodesList.map(node => ({ @@ -798,15 +1305,68 @@ export const useWorkflowVibe = () => { description: node.description, })) - let mermaidCode = trimmed + let mermaidCode = '' + let backendNodes: BackendNodeSpec[] | undefined + let backendEdges: BackendEdgeSpec[] | undefined + if (!isMermaidFlowchart(trimmed)) { - const { error, flowchart } = await generateFlowchart({ + // Build previous workflow context if regenerating + const { vibePanelBackendNodes, vibePanelBackendEdges, vibePanelLastWarnings } = workflowStore.getState() + const previousWorkflow = regenerateMode && vibePanelBackendNodes && vibePanelBackendNodes.length > 0 + ? 
{ + nodes: vibePanelBackendNodes, + edges: vibePanelBackendEdges || [], + warnings: vibePanelLastWarnings || [], + } + : undefined + + // Map language code to human-readable language name for LLM + const languageNameMap: Record = { + en_US: 'English', + zh_Hans: 'Chinese', + zh_Hant: 'Traditional Chinese', + ja_JP: 'Japanese', + ko_KR: 'Korean', + pt_BR: 'Portuguese', + es_ES: 'Spanish', + fr_FR: 'French', + de_DE: 'German', + it_IT: 'Italian', + ru_RU: 'Russian', + uk_UA: 'Ukrainian', + vi_VN: 'Vietnamese', + pl_PL: 'Polish', + ro_RO: 'Romanian', + tr_TR: 'Turkish', + fa_IR: 'Persian', + hi_IN: 'Hindi', + } + const preferredLanguage = languageNameMap[language] || 'English' + + // Extract available models from user's configured model providers + const availableModelsPayload = modelList?.flatMap(provider => + provider.models.map(model => ({ + provider: provider.provider, + model: model.model, + })), + ) || [] + + const requestPayload = { instruction: trimmed, model_config: latestModelConfig, available_nodes: availableNodesPayload, existing_nodes: existingNodesPayload, available_tools: toolsPayload, - }) + selected_node_ids: [], + previous_workflow: previousWorkflow, + regenerate_mode: regenerateMode, + language: preferredLanguage, + available_models: availableModelsPayload, + } + + const response = await generateFlowchart(requestPayload) + + const { error, flowchart, nodes, edges, intent, message, warnings, suggestions } = response if (error) { Toast.notify({ type: 'error', message: error }) @@ -814,47 +1374,134 @@ export const useWorkflowVibe = () => { return } + // Handle off_topic intent - show rejection message and suggestions + if (intent === 'off_topic') { + workflowStore.setState(state => ({ + ...state, + vibePanelMermaidCode: '', + vibePanelMessage: message || t('workflow.vibe.offTopicDefault'), + vibePanelSuggestions: suggestions || [], + vibePanelIntent: 'off_topic', + isVibeGenerating: false, + })) + return + } + if (!flowchart) { Toast.notify({ type: 
'error', message: t('workflow.vibe.missingFlowchart') }) setIsVibeGenerating(false) return } + // Show warnings if any (includes tool sanitization warnings) + const responseWarnings = warnings || [] + if (responseWarnings.length > 0) { + responseWarnings.forEach((warning) => { + Toast.notify({ type: 'warning', message: warning }) + }) + } + mermaidCode = flowchart + // Store backend nodes/edges for direct use (bypasses mermaid re-parsing) + backendNodes = nodes + backendEdges = edges + // Store warnings for regeneration context + workflowStore.setState(state => ({ + ...state, + vibePanelLastWarnings: responseWarnings, + })) + + workflowStore.setState(state => ({ + ...state, + vibePanelMermaidCode: mermaidCode, + vibePanelBackendNodes: backendNodes, + vibePanelBackendEdges: backendEdges, + vibePanelMessage: '', + vibePanelSuggestions: [], + vibePanelIntent: 'generate', + isVibeGenerating: false, + })) } - workflowStore.setState(state => ({ - ...state, - vibePanelMermaidCode: mermaidCode, - isVibeGenerating: false, - })) + setIsVibeGenerating(false) - const workflowGraph = await flowchartToWorkflowGraph(mermaidCode) - addVersion(workflowGraph) + // Add version for preview + if (backendNodes && backendNodes.length > 0 && backendEdges) { + const graph = await createGraphFromBackendNodes(backendNodes, backendEdges) + addVersion(graph) + } + else if (mermaidCode) { + const graph = await flowchartToWorkflowGraph(mermaidCode) + addVersion(graph) + } - if (skipPanelPreview) - applyFlowchartToWorkflow() + if (skipPanelPreview) { + // Prefer backend nodes (already sanitized) over mermaid re-parsing + if (backendNodes && backendNodes.length > 0 && backendEdges) { + console.log('[VIBE] Applying backend nodes directly to workflow') + console.log('[VIBE] Backend nodes:', backendNodes.length) + console.log('[VIBE] Backend edges:', backendEdges.length) + await applyBackendNodesToWorkflow(backendNodes, backendEdges) + console.log('[VIBE] Backend nodes applied successfully') + } + 
else { + console.log('[VIBE] Applying mermaid flowchart to workflow') + await applyFlowchartToWorkflow() + console.log('[VIBE] Mermaid flowchart applied successfully') + } + } + } + catch (error: unknown) { + // Handle API errors (e.g., network errors, server errors) + const { setIsVibeGenerating } = workflowStore.getState() + setIsVibeGenerating(false) + + // Extract error message from Response object or Error + let errorMessage = t('workflow.vibe.generateError') + if (error instanceof Response) { + try { + const errorData = await error.json() + errorMessage = errorData?.message || errorMessage + } + catch { + // If we can't parse the response, use the default error message + } + } + else if (error instanceof Error) { + errorMessage = error.message || errorMessage + } + + Toast.notify({ type: 'error', message: errorMessage }) } finally { isGeneratingRef.current = false } }, [ - availableNodesList, + addVersion, + applyBackendNodesToWorkflow, + applyFlowchartToWorkflow, + createGraphFromBackendNodes, + flowchartToWorkflowGraph, + getLatestModelConfig, getNodesReadOnly, - handleSyncWorkflowDraft, nodeTypeLookup, nodesMetaDataMap, - saveStateToHistory, store, t, - toolLookup, toolOptions, - getLatestModelConfig, ]) - const handleAccept = useCallback(() => { - applyFlowchartToWorkflow() - }, [applyFlowchartToWorkflow]) + const handleAccept = useCallback(async () => { + // Prefer backend nodes (already sanitized) over mermaid re-parsing + const { vibePanelBackendNodes, vibePanelBackendEdges } = workflowStore.getState() + if (vibePanelBackendNodes && vibePanelBackendNodes.length > 0 && vibePanelBackendEdges) { + await applyBackendNodesToWorkflow(vibePanelBackendNodes, vibePanelBackendEdges) + } + else { + // Use applyFlowchartToWorkflow which uses currentFlowGraph (populated by addVersion) + applyFlowchartToWorkflow() + } + }, [applyBackendNodesToWorkflow, applyFlowchartToWorkflow]) useEffect(() => { const handler = (event: CustomEvent) => { diff --git 
a/web/app/components/workflow/nodes/_base/components/variable/utils.ts b/web/app/components/workflow/nodes/_base/components/variable/utils.ts index a7dc04e571..b7c7125ca6 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/utils.ts +++ b/web/app/components/workflow/nodes/_base/components/variable/utils.ts @@ -1390,9 +1390,9 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => { payload.url, payload.headers, payload.params, - typeof payload.body.data === 'string' + typeof payload.body?.data === 'string' ? payload.body.data - : payload.body.data.map(d => d.value).join(''), + : (payload.body?.data?.map(d => d.value).join('') ?? ''), ]) break } diff --git a/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts b/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts index 650ae47156..d5a4f3d872 100644 --- a/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts +++ b/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts @@ -5,6 +5,9 @@ import { useCallback, useEffect, useState } from 'react' const UNIQUE_ID_PREFIX = 'key-value-' const strToKeyValueList = (value: string) => { + if (typeof value !== 'string' || !value) + return [] + return value.split('\n').map((item) => { const [key, ...others] = item.split(':') return { @@ -16,7 +19,7 @@ const strToKeyValueList = (value: string) => { } const useKeyValueList = (value: string, onChange: (value: string) => void, noFilter?: boolean) => { - const [list, doSetList] = useState(() => value ? strToKeyValueList(value) : []) + const [list, doSetList] = useState(() => typeof value === 'string' && value ? 
strToKeyValueList(value) : []) const setList = (l: KeyValue[]) => { doSetList(l.map((item) => { return { diff --git a/web/app/components/workflow/nodes/variable-assigner/components/node-group-item.tsx b/web/app/components/workflow/nodes/variable-assigner/components/node-group-item.tsx index fdffecb8f6..b6e5096dfa 100644 --- a/web/app/components/workflow/nodes/variable-assigner/components/node-group-item.tsx +++ b/web/app/components/workflow/nodes/variable-assigner/components/node-group-item.tsx @@ -127,23 +127,30 @@ const NodeGroupItem = ({ !!item.variables.length && (
{ - item.variables.map((variable = [], index) => { - const isSystem = isSystemVar(variable) + item.variables + .map((variable = [], index) => { + // Ensure variable is an array + const safeVariable = Array.isArray(variable) ? variable : [] + if (!safeVariable.length) + return null - const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === variable[0]) - const varName = isSystem ? `sys.${variable[variable.length - 1]}` : variable.slice(1).join('.') - const isException = isExceptionVariable(varName, node?.data.type) + const isSystem = isSystemVar(safeVariable) - return ( - - ) - }) + const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === safeVariable[0]) + const varName = isSystem ? `sys.${safeVariable[safeVariable.length - 1]}` : safeVariable.slice(1).join('.') + const isException = isExceptionVariable(varName, node?.data.type) + + return ( + + ) + }) + .filter(Boolean) }
) diff --git a/web/app/components/workflow/panel/vibe-panel/index.tsx b/web/app/components/workflow/panel/vibe-panel/index.tsx index 966172518c..2f644bd09a 100644 --- a/web/app/components/workflow/panel/vibe-panel/index.tsx +++ b/web/app/components/workflow/panel/vibe-panel/index.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { CompletionParams, Model } from '@/types/app' -import { RiClipboardLine } from '@remixicon/react' +import { RiClipboardLine, RiInformation2Line } from '@remixicon/react' import copy from 'copy-to-clipboard' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -29,8 +29,12 @@ const VibePanel: FC = () => { const { t } = useTranslation() const workflowStore = useWorkflowStore() const showVibePanel = useStore(s => s.showVibePanel) + const setShowVibePanel = useStore(s => s.setShowVibePanel) const isVibeGenerating = useStore(s => s.isVibeGenerating) + const setIsVibeGenerating = useStore(s => s.setIsVibeGenerating) const vibePanelInstruction = useStore(s => s.vibePanelInstruction) + const vibePanelMermaidCode = useStore(s => s.vibePanelMermaidCode) + const setVibePanelMermaidCode = useStore(s => s.setVibePanelMermaidCode) const configsMap = useHooksStore(s => s.configsMap) const { current: currentFlowGraph, versions, currentVersionIndex, setCurrentVersionIndex } = useVibeFlowData({ @@ -40,6 +44,14 @@ const VibePanel: FC = () => { const vibePanelPreviewNodes = currentFlowGraph?.nodes || [] const vibePanelPreviewEdges = currentFlowGraph?.edges || [] + const setVibePanelInstruction = useStore(s => s.setVibePanelInstruction) + const vibePanelIntent = useStore(s => s.vibePanelIntent) + const setVibePanelIntent = useStore(s => s.setVibePanelIntent) + const vibePanelMessage = useStore(s => s.vibePanelMessage) + const setVibePanelMessage = useStore(s => s.setVibePanelMessage) + 
const vibePanelSuggestions = useStore(s => s.vibePanelSuggestions) + const setVibePanelSuggestions = useStore(s => s.setVibePanelSuggestions) + const localModel = localStorage.getItem('auto-gen-model') ? JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model : null @@ -97,13 +109,13 @@ const VibePanel: FC = () => { }, [workflowStore]) const handleClose = useCallback(() => { - workflowStore.setState(state => ({ - ...state, - showVibePanel: false, - vibePanelMermaidCode: '', - isVibeGenerating: false, - })) - }, [workflowStore]) + setShowVibePanel(false) + setVibePanelMermaidCode('') + setIsVibeGenerating(false) + setVibePanelIntent('') + setVibePanelMessage('') + setVibePanelSuggestions([]) + }, [setShowVibePanel, setVibePanelMermaidCode, setIsVibeGenerating, setVibePanelIntent, setVibePanelMessage, setVibePanelSuggestions]) const handleGenerate = useCallback(() => { const event = new CustomEvent(VIBE_COMMAND_EVENT, { @@ -119,10 +131,18 @@ const VibePanel: FC = () => { }, [handleClose]) const handleCopyMermaid = useCallback(() => { - const { vibePanelMermaidCode } = workflowStore.getState() copy(vibePanelMermaidCode) Toast.notify({ type: 'success', message: t('common.actionMsg.copySuccessfully') }) - }, [workflowStore, t]) + }, [vibePanelMermaidCode, t]) + + const handleSuggestionClick = useCallback((suggestion: string) => { + setVibePanelInstruction(suggestion) + // Trigger generation with the suggestion + const event = new CustomEvent(VIBE_COMMAND_EVENT, { + detail: { dsl: suggestion }, + }) + document.dispatchEvent(event) + }, [setVibePanelInstruction]) if (!showVibePanel) return null @@ -134,6 +154,40 @@ const VibePanel: FC = () => { ) + const renderOffTopic = ( +
+
+
+ +
+
+ {t('workflow.vibe.offTopicTitle')} +
+
+ {vibePanelMessage || t('workflow.vibe.offTopicDefault')} +
+ {vibePanelSuggestions.length > 0 && ( +
+
+ {t('workflow.vibe.trySuggestion')} +
+
+ {vibePanelSuggestions.map((suggestion, index) => ( + + ))} +
+
+ )} +
+
+ ) + return ( { - {!isVibeGenerating && vibePanelPreviewNodes.length > 0 && ( + {!isVibeGenerating && vibePanelIntent === 'off_topic' && renderOffTopic} + {!isVibeGenerating && vibePanelIntent !== 'off_topic' && (vibePanelPreviewNodes.length > 0 || vibePanelMermaidCode) && (
@@ -226,7 +281,7 @@ const VibePanel: FC = () => {
)} {isVibeGenerating && renderLoading} - {!isVibeGenerating && vibePanelPreviewNodes.length === 0 && } + {!isVibeGenerating && vibePanelIntent !== 'off_topic' && vibePanelPreviewNodes.length === 0 && !vibePanelMermaidCode && }
) diff --git a/web/app/components/workflow/store/workflow/panel-slice.ts b/web/app/components/workflow/store/workflow/panel-slice.ts index e90418823a..5cb2f8193f 100644 --- a/web/app/components/workflow/store/workflow/panel-slice.ts +++ b/web/app/components/workflow/store/workflow/panel-slice.ts @@ -1,5 +1,8 @@ +import type { BackendEdgeSpec, BackendNodeSpec } from '@/service/debug' import type { StateCreator } from 'zustand' +export type VibeIntent = 'generate' | 'off_topic' | 'error' | '' + export type PanelSliceShape = { panelWidth: number showFeaturesPanel: boolean @@ -26,6 +29,24 @@ export type PanelSliceShape = { setInitShowLastRunTab: (initShowLastRunTab: boolean) => void showVibePanel: boolean setShowVibePanel: (showVibePanel: boolean) => void + vibePanelMermaidCode: string + setVibePanelMermaidCode: (vibePanelMermaidCode: string) => void + vibePanelBackendNodes?: BackendNodeSpec[] + setVibePanelBackendNodes: (nodes?: BackendNodeSpec[]) => void + vibePanelBackendEdges?: BackendEdgeSpec[] + setVibePanelBackendEdges: (edges?: BackendEdgeSpec[]) => void + isVibeGenerating: boolean + setIsVibeGenerating: (isVibeGenerating: boolean) => void + vibePanelInstruction: string + setVibePanelInstruction: (vibePanelInstruction: string) => void + vibePanelIntent: VibeIntent + setVibePanelIntent: (vibePanelIntent: VibeIntent) => void + vibePanelMessage: string + setVibePanelMessage: (vibePanelMessage: string) => void + vibePanelSuggestions: string[] + setVibePanelSuggestions: (vibePanelSuggestions: string[]) => void + vibePanelLastWarnings: string[] + setVibePanelLastWarnings: (vibePanelLastWarnings: string[]) => void } export const createPanelSlice: StateCreator = set => ({ @@ -48,4 +69,22 @@ export const createPanelSlice: StateCreator = set => ({ setInitShowLastRunTab: initShowLastRunTab => set(() => ({ initShowLastRunTab })), showVibePanel: false, setShowVibePanel: showVibePanel => set(() => ({ showVibePanel })), + vibePanelMermaidCode: '', + setVibePanelMermaidCode: 
vibePanelMermaidCode => set(() => ({ vibePanelMermaidCode })), + vibePanelBackendNodes: undefined, + setVibePanelBackendNodes: vibePanelBackendNodes => set(() => ({ vibePanelBackendNodes })), + vibePanelBackendEdges: undefined, + setVibePanelBackendEdges: vibePanelBackendEdges => set(() => ({ vibePanelBackendEdges })), + isVibeGenerating: false, + setIsVibeGenerating: isVibeGenerating => set(() => ({ isVibeGenerating })), + vibePanelInstruction: '', + setVibePanelInstruction: vibePanelInstruction => set(() => ({ vibePanelInstruction })), + vibePanelIntent: '', + setVibePanelIntent: vibePanelIntent => set(() => ({ vibePanelIntent })), + vibePanelMessage: '', + setVibePanelMessage: vibePanelMessage => set(() => ({ vibePanelMessage })), + vibePanelSuggestions: [], + setVibePanelSuggestions: vibePanelSuggestions => set(() => ({ vibePanelSuggestions })), + vibePanelLastWarnings: [], + setVibePanelLastWarnings: vibePanelLastWarnings => set(() => ({ vibePanelLastWarnings })), }) diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index 9d00be30c7..203b3197a4 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -140,6 +140,10 @@ const translation = { regenerate: 'Regenerate', apply: 'Apply', noFlowchart: 'No flowchart provided', + offTopicDefault: 'I\'m the Dify workflow design assistant. I can help you create AI automation workflows, but I can\'t answer general questions. Would you like to create a workflow instead?', + offTopicTitle: 'Off-Topic Request', + trySuggestion: 'Try one of these suggestions:', + generateError: 'Failed to generate workflow. 
Please try again.', }, publishLimit: { startNodeTitlePrefix: 'Upgrade to', diff --git a/web/service/debug.ts b/web/service/debug.ts index 40aa8c2173..7e69fb5e29 100644 --- a/web/service/debug.ts +++ b/web/service/debug.ts @@ -19,8 +19,45 @@ export type GenRes = { error?: string } +export type ToolRecommendation = { + requested_capability: string + unconfigured_tools: Array<{ + provider_id: string + tool_name: string + description: string + }> + configured_alternatives: Array<{ + provider_id: string + tool_name: string + description: string + }> + recommendation: string +} + +export type BackendNodeSpec = { + id: string + type: string + title?: string + config?: Record + position?: { x: number; y: number } +} + +export type BackendEdgeSpec = { + source: string + target: string + sourceHandle?: string + targetHandle?: string +} + export type FlowchartGenRes = { + intent?: 'generate' | 'off_topic' | 'error' flowchart: string + nodes?: BackendNodeSpec[] + edges?: BackendEdgeSpec[] + message?: string + warnings?: string[] + suggestions?: string[] + tool_recommendations?: ToolRecommendation[] error?: string }