diff --git a/api/core/workflow/generator/prompts/builder_prompts.py b/api/core/workflow/generator/prompts/builder_prompts.py index a86d34030c..d6030bcbf3 100644 --- a/api/core/workflow/generator/prompts/builder_prompts.py +++ b/api/core/workflow/generator/prompts/builder_prompts.py @@ -1,3 +1,279 @@ +# ============================================================================= +# NEW FORMAT: depends_on based prompt (for use with GraphBuilder) +# ============================================================================= + +BUILDER_SYSTEM_PROMPT_V2 = """ +You are a Workflow Configuration Engineer. +Your goal is to generate workflow node configurations with dependency declarations. +The graph structure (edges, start/end nodes) will be automatically built from your output. + + + +- Detect the language of the user's request automatically (e.g., English, Chinese, Japanese, etc.). +- Generate ALL node titles, descriptions, and user-facing text in the SAME language as the user's input. +- If the input language is ambiguous or cannot be determined (e.g. code-only input), + use {preferred_language} as the target language. + + + + +{plan_context} + + + +{tool_schemas} + + + +{builtin_node_specs} + + + +{available_models} + + + + +{existing_nodes_context} + + +{selected_nodes_context} + + + + + +1. **DO NOT generate start or end nodes** - they are automatically added +2. **DO NOT generate edges** - they are automatically built from depends_on +3. **Use depends_on array** to declare which nodes must run before this one +4. **Leave depends_on empty []** for nodes that should start immediately (connect to start) + + + +1. **Configuration**: + - You MUST fill ALL required parameters for every node. + - Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields. + +2. 
**Dependency Declaration**: + - Each node has a `depends_on` array listing node IDs that must complete before it runs + - Empty depends_on `[]` means the node runs immediately after start + - Example: `"depends_on": ["fetch_data"]` means this node waits for fetch_data to complete + +3. **Variable References**: + - For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}` + - Dependencies will be auto-inferred from variable references if not explicitly declared + +4. **Tools**: + - ONLY use the tools listed in ``. + - If a planned tool is missing from schemas, fallback to `http-request` or `code`. + +5. **Model Selection** (CRITICAL): + - For LLM, question-classifier, and parameter-extractor nodes, you MUST include a "model" config. + - You MUST use ONLY models from the `` section above. + - Copy the EXACT provider and name values from available_models. + - NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4, or any other models unless they appear in available_models. + - If available_models is empty or shows "No models configured", omit the model config entirely. + +6. **if-else Branching**: + - Add `true_branch` and `false_branch` in config to specify target node IDs + - Example: `"config": {{"cases": [...], "true_branch": "success_node", "false_branch": "fallback_node"}}` + +7. **question-classifier Branching**: + - Add `target` field to each class in the classes array + - Example: `"classes": [{{"id": "tech", "name": "Tech", "target": "tech_handler"}}, ...]` + +8. **Node Specifics**: + - For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`). + + + +Return ONLY a JSON object with a `nodes` array. 
Each node has: +- id: unique identifier +- type: node type +- title: display name +- config: node configuration +- depends_on: array of node IDs this depends on + +```json +{{{{ + "nodes": [ + {{{{ + "id": "fetch_data", + "type": "http-request", + "title": "Fetch Data", + "config": {{"url": "{{{{#start.url#}}}}", "method": "GET"}}, + "depends_on": [] + }}}}, + {{{{ + "id": "analyze", + "type": "llm", + "title": "Analyze", + "config": {{"prompt_template": [{{"role": "user", "text": "Analyze: {{{{#fetch_data.body#}}}}"}}]}}, + "depends_on": ["fetch_data"] + }}}} + ] +}}}} +``` + + + + +```json +{{{{ + "nodes": [ + {{{{ + "id": "llm", + "type": "llm", + "title": "Generate Response", + "config": {{{{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}] + }}}}, + "depends_on": [] + }}}} + ] +}}}} +``` + + + +```json +{{{{ + "nodes": [ + {{{{ + "id": "api1", + "type": "http-request", + "title": "Fetch API 1", + "config": {{"url": "https://api1.example.com", "method": "GET"}}, + "depends_on": [] + }}}}, + {{{{ + "id": "api2", + "type": "http-request", + "title": "Fetch API 2", + "config": {{"url": "https://api2.example.com", "method": "GET"}}, + "depends_on": [] + }}}}, + {{{{ + "id": "merge", + "type": "llm", + "title": "Merge Results", + "config": {{{{ + "prompt_template": [{{"role": "user", "text": "Combine: {{{{#api1.body#}}}} and {{{{#api2.body#}}}}"}}] + }}}}, + "depends_on": ["api1", "api2"] + }}}} + ] +}}}} +``` + + + +```json +{{{{ + "nodes": [ + {{{{ + "id": "check", + "type": "if-else", + "title": "Check Condition", + "config": {{{{ + "cases": [{{{{ + "case_id": "case_1", + "logical_operator": "and", + "conditions": [{{{{ + "variable_selector": ["start", "score"], + "comparison_operator": "≥", + "value": "60" + }}}}] + }}}}], + "true_branch": "pass_handler", + "false_branch": "fail_handler" + }}}}, + "depends_on": [] + }}}}, + {{{{ + "id": "pass_handler", + 
"type": "llm", + "title": "Pass Response", + "config": {{"prompt_template": [{{"role": "user", "text": "Congratulations!"}}]}}, + "depends_on": [] + }}}}, + {{{{ + "id": "fail_handler", + "type": "llm", + "title": "Fail Response", + "config": {{"prompt_template": [{{"role": "user", "text": "Try again."}}]}}, + "depends_on": [] + }}}} + ] +}}}} +``` +Note: pass_handler and fail_handler have empty depends_on because their connections come from if-else branches. + + + +```json +{{{{ + "nodes": [ + {{{{ + "id": "classifier", + "type": "question-classifier", + "title": "Classify Intent", + "config": {{{{ + "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}}, + "query_variable_selector": ["start", "user_input"], + "classes": [ + {{"id": "tech", "name": "Technical", "target": "tech_handler"}}, + {{"id": "billing", "name": "Billing", "target": "billing_handler"}}, + {{"id": "other", "name": "Other", "target": "other_handler"}} + ] + }}}}, + "depends_on": [] + }}}}, + {{{{ + "id": "tech_handler", + "type": "llm", + "title": "Tech Support", + "config": {{"prompt_template": [{{"role": "user", "text": "Help with tech: {{{{#start.user_input#}}}}"}}]}}, + "depends_on": [] + }}}}, + {{{{ + "id": "billing_handler", + "type": "llm", + "title": "Billing Support", + "config": {{"prompt_template": [{{"role": "user", "text": "Help with billing: {{{{#start.user_input#}}}}"}}]}}, + "depends_on": [] + }}}}, + {{{{ + "id": "other_handler", + "type": "llm", + "title": "General Support", + "config": {{"prompt_template": [{{"role": "user", "text": "General help: {{{{#start.user_input#}}}}"}}]}}, + "depends_on": [] + }}}} + ] +}}}} +``` +Note: Handler nodes have empty depends_on because their connections come from classifier branches. + + +""" + +BUILDER_USER_PROMPT_V2 = """ +{instruction} + + +Generate the workflow nodes configuration. Remember: +1. Do NOT generate start or end nodes +2. Do NOT generate edges - use depends_on instead +3. 
For if-else: add true_branch/false_branch in config +4. For question-classifier: add target to each class +""" + +# ============================================================================= +# LEGACY FORMAT: edges-based prompt (backward compatible) +# ============================================================================= + BUILDER_SYSTEM_PROMPT = """ You are a Workflow Configuration Engineer. Your goal is to implement the Architect's plan by generating a precise, runnable Dify Workflow JSON configuration. diff --git a/api/core/workflow/generator/runner.py b/api/core/workflow/generator/runner.py index a6c8eb6fad..24092c2276 100644 --- a/api/core/workflow/generator/runner.py +++ b/api/core/workflow/generator/runner.py @@ -10,7 +10,9 @@ from core.model_runtime.entities.message_entities import SystemPromptMessage, Us from core.model_runtime.entities.model_entities import ModelType from core.workflow.generator.prompts.builder_prompts import ( BUILDER_SYSTEM_PROMPT, + BUILDER_SYSTEM_PROMPT_V2, BUILDER_USER_PROMPT, + BUILDER_USER_PROMPT_V2, format_existing_edges, format_existing_nodes, format_selected_nodes, @@ -26,6 +28,7 @@ from core.workflow.generator.prompts.vibe_prompts import ( format_available_tools, parse_vibe_response, ) +from core.workflow.generator.utils.graph_builder import CyclicDependencyError, GraphBuilder from core.workflow.generator.utils.mermaid_generator import generate_mermaid from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator @@ -53,6 +56,7 @@ class WorkflowGenerator: regenerate_mode: bool = False, preferred_language: str | None = None, available_models: Sequence[dict[str, object]] | None = None, + use_graph_builder: bool = False, ): """ Generates a Dify Workflow Flowchart from natural language instruction. 
@@ -173,17 +177,30 @@ class WorkflowGenerator: retry_context += "\nPlease fix these specific issues while keeping everything else UNCHANGED.\n" retry_context += "\n" - builder_system = BUILDER_SYSTEM_PROMPT.format( - plan_context=json.dumps(plan_data.get("steps", []), indent=2), - tool_schemas=tool_schemas, - builtin_node_specs=node_specs, - available_models=format_available_models(list(available_models or [])), - preferred_language=preferred_language or "English", - existing_nodes_context=existing_nodes_context, - existing_edges_context=existing_edges_context, - selected_nodes_context=selected_nodes_context, - ) - builder_user = BUILDER_USER_PROMPT.format(instruction=instruction) + retry_context + # Select prompt version based on use_graph_builder flag + if use_graph_builder: + builder_system = BUILDER_SYSTEM_PROMPT_V2.format( + plan_context=json.dumps(plan_data.get("steps", []), indent=2), + tool_schemas=tool_schemas, + builtin_node_specs=node_specs, + available_models=format_available_models(list(available_models or [])), + preferred_language=preferred_language or "English", + existing_nodes_context=existing_nodes_context, + selected_nodes_context=selected_nodes_context, + ) + builder_user = BUILDER_USER_PROMPT_V2.format(instruction=instruction) + retry_context + else: + builder_system = BUILDER_SYSTEM_PROMPT.format( + plan_context=json.dumps(plan_data.get("steps", []), indent=2), + tool_schemas=tool_schemas, + builtin_node_specs=node_specs, + available_models=format_available_models(list(available_models or [])), + preferred_language=preferred_language or "English", + existing_nodes_context=existing_nodes_context, + existing_edges_context=existing_edges_context, + selected_nodes_context=selected_nodes_context, + ) + builder_user = BUILDER_USER_PROMPT.format(instruction=instruction) + retry_context try: build_res = model_instance.invoke_llm( @@ -204,8 +221,53 @@ class WorkflowGenerator: if "nodes" not in workflow_data: workflow_data["nodes"] = [] - if "edges" 
not in workflow_data: - workflow_data["edges"] = [] + + # --- GraphBuilder Mode: Build graph from depends_on --- + if use_graph_builder: + try: + # Extract nodes from LLM output (without start/end) + llm_nodes = workflow_data.get("nodes", []) + + # Build complete graph with start/end and edges + complete_nodes, edges = GraphBuilder.build_graph(llm_nodes) + + workflow_data["nodes"] = complete_nodes + workflow_data["edges"] = edges + + logger.info( + "GraphBuilder: built %d nodes, %d edges from %d LLM nodes", + len(complete_nodes), + len(edges), + len(llm_nodes), + ) + + except CyclicDependencyError as e: + logger.warning("GraphBuilder: cyclic dependency detected: %s", e) + # Add to validation hints for retry + validation_hints.append( + ValidationHint( + node_id="", + field="depends_on", + message=f"Cyclic dependency detected: {e}. Please fix the dependency chain.", + severity="error", + ) + ) + if attempt == MAX_GLOBAL_RETRIES - 1: + return { + "intent": "error", + "error": "Failed to build workflow: cyclic dependency detected.", + } + continue # Retry with error feedback + + except Exception as e: + logger.exception("GraphBuilder failed on attempt %d", attempt + 1) + if attempt == MAX_GLOBAL_RETRIES - 1: + return {"intent": "error", "error": f"Graph building failed: {str(e)}"} + continue + else: + # Legacy mode: edges from LLM output + if "edges" not in workflow_data: + workflow_data["edges"] = [] except Exception as e: logger.exception("Builder failed on attempt %d", attempt + 1) diff --git a/api/core/workflow/generator/utils/graph_builder.py b/api/core/workflow/generator/utils/graph_builder.py new file mode 100644 index 0000000000..9f3fde3b4a --- /dev/null +++ b/api/core/workflow/generator/utils/graph_builder.py @@ -0,0 +1,621 @@ +""" +GraphBuilder: Automatic workflow graph construction from node list. + +This module implements the core logic for building complete workflow graphs +from LLM-generated node lists with dependency declarations. 
+ +Key features: +- Automatic start/end node generation +- Dependency inference from variable references +- Topological sorting with cycle detection +- Special handling for branching nodes (if-else, question-classifier) +- Silent error recovery where possible +""" + +import json +import logging +import re +import uuid +from collections import defaultdict +from typing import Any + +logger = logging.getLogger(__name__) + +# Pattern to match variable references like {{#node_id.field#}} +VAR_PATTERN = re.compile(r"\{\{#([^.#]+)\.[^#]+#\}\}") + +# System variable prefixes to exclude from dependency inference +SYSTEM_VAR_PREFIXES = {"sys", "start", "env"} + +# Node types that have special branching behavior +BRANCHING_NODE_TYPES = {"if-else", "question-classifier"} + +# Container node types (iteration, loop) - these have internal subgraphs +# but behave as single-input-single-output nodes in the external graph +CONTAINER_NODE_TYPES = {"iteration", "loop"} + + +class GraphBuildError(Exception): + """Raised when graph cannot be built due to unrecoverable errors.""" + + pass + + +class CyclicDependencyError(GraphBuildError): + """Raised when cyclic dependencies are detected.""" + + pass + + +class GraphBuilder: + """ + Builds complete workflow graphs from LLM-generated node lists. + + This class handles the conversion from a simplified node list format + (with depends_on declarations) to a full workflow graph with nodes and edges. 
+ + The LLM only needs to generate: + - Node configurations with depends_on arrays + - Branch targets in config for branching nodes + + The GraphBuilder automatically: + - Adds start and end nodes + - Generates all edges from dependencies + - Infers implicit dependencies from variable references + - Handles branching nodes (if-else, question-classifier) + - Validates graph structure (no cycles, proper connectivity) + """ + + @classmethod + def build_graph( + cls, + nodes: list[dict[str, Any]], + start_config: dict[str, Any] | None = None, + end_config: dict[str, Any] | None = None, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """ + Build a complete workflow graph from a node list. + + Args: + nodes: LLM-generated nodes (without start/end) + start_config: Optional configuration for start node + end_config: Optional configuration for end node + + Returns: + Tuple of (complete_nodes, edges) where: + - complete_nodes includes start, user nodes, and end + - edges contains all connections + + Raises: + CyclicDependencyError: If cyclic dependencies are detected + GraphBuildError: If graph cannot be built + """ + if not nodes: + # Empty node list - create minimal workflow + start_node = cls._create_start_node([], start_config) + end_node = cls._create_end_node([], end_config) + edge = cls._create_edge("start", "end") + return [start_node, end_node], [edge] + + # Build node index for quick lookup + node_map = {node["id"]: node for node in nodes} + + # Step 1: Extract explicit dependencies from depends_on + dependencies = cls._extract_explicit_dependencies(nodes) + + # Step 2: Infer implicit dependencies from variable references + dependencies = cls._infer_dependencies_from_variables(nodes, dependencies, node_map) + + # Step 3: Validate and fix dependencies (remove invalid references) + dependencies = cls._validate_dependencies(dependencies, node_map) + + # Step 4: Topological sort (detects cycles) + sorted_node_ids = cls._topological_sort(nodes, dependencies) 
+ + # Step 5: Generate start node + start_node = cls._create_start_node(nodes, start_config) + + # Step 6: Generate edges + edges = cls._generate_edges(nodes, sorted_node_ids, dependencies, node_map) + + # Step 7: Find terminal nodes and generate end node + terminal_nodes = cls._find_terminal_nodes(nodes, dependencies, node_map) + end_node = cls._create_end_node(terminal_nodes, end_config) + + # Step 8: Add edges from terminal nodes to end + for terminal_id in terminal_nodes: + edges.append(cls._create_edge(terminal_id, "end")) + + # Step 9: Assemble complete node list + all_nodes = [start_node, *nodes, end_node] + + return all_nodes, edges + + @classmethod + def _extract_explicit_dependencies( + cls, + nodes: list[dict[str, Any]], + ) -> dict[str, list[str]]: + """ + Extract explicit dependencies from depends_on field. + + Args: + nodes: List of nodes with optional depends_on field + + Returns: + Dictionary mapping node_id -> list of dependency node_ids + """ + dependencies: dict[str, list[str]] = {} + + for node in nodes: + node_id = node.get("id", "") + depends_on = node.get("depends_on", []) + + # Ensure depends_on is a list + if isinstance(depends_on, str): + depends_on = [depends_on] if depends_on else [] + elif not isinstance(depends_on, list): + depends_on = [] + + dependencies[node_id] = list(depends_on) + + return dependencies + + @classmethod + def _infer_dependencies_from_variables( + cls, + nodes: list[dict[str, Any]], + explicit_deps: dict[str, list[str]], + node_map: dict[str, dict[str, Any]], + ) -> dict[str, list[str]]: + """ + Infer implicit dependencies from variable references in config. + + Scans node configurations for patterns like {{#node_id.field#}} + and adds those as dependencies if not already declared. 
+ + Args: + nodes: List of nodes + explicit_deps: Already extracted explicit dependencies + node_map: Map of node_id -> node for validation + + Returns: + Updated dependencies dictionary + """ + for node in nodes: + node_id = node.get("id", "") + config = node.get("config", {}) + + # Serialize config to search for variable references + try: + config_str = json.dumps(config, ensure_ascii=False) + except (TypeError, ValueError): + continue + + # Find all variable references + referenced_nodes = set(VAR_PATTERN.findall(config_str)) + + # Filter out system variables + referenced_nodes -= SYSTEM_VAR_PREFIXES + + # Ensure node_id exists in dependencies + if node_id not in explicit_deps: + explicit_deps[node_id] = [] + + # Add inferred dependencies + for ref in referenced_nodes: + # Skip self-references (e.g., loop nodes referencing their own outputs) + if ref == node_id: + logger.debug( + "Skipping self-reference: %s -> %s", + node_id, + ref, + ) + continue + + if ref in node_map and ref not in explicit_deps[node_id]: + explicit_deps[node_id].append(ref) + logger.debug( + "Inferred dependency: %s -> %s (from variable reference)", + node_id, + ref, + ) + + return explicit_deps + + @classmethod + def _validate_dependencies( + cls, + dependencies: dict[str, list[str]], + node_map: dict[str, dict[str, Any]], + ) -> dict[str, list[str]]: + """ + Validate dependencies and remove invalid references. + + Silent fix: References to non-existent nodes are removed. 
+ + Args: + dependencies: Dependencies to validate + node_map: Map of valid node IDs + + Returns: + Validated dependencies + """ + valid_deps: dict[str, list[str]] = {} + + for node_id, deps in dependencies.items(): + valid_deps[node_id] = [] + for dep in deps: + if dep in node_map: + valid_deps[node_id].append(dep) + else: + logger.warning( + "Removed invalid dependency: %s -> %s (node does not exist)", + node_id, + dep, + ) + + return valid_deps + + @classmethod + def _topological_sort( + cls, + nodes: list[dict[str, Any]], + dependencies: dict[str, list[str]], + ) -> list[str]: + """ + Perform topological sort on nodes based on dependencies. + + Uses Kahn's algorithm for cycle detection. + + Args: + nodes: List of nodes + dependencies: Dependency graph + + Returns: + List of node IDs in topological order + + Raises: + CyclicDependencyError: If cyclic dependencies are detected + """ + # Build in-degree map + in_degree: dict[str, int] = defaultdict(int) + reverse_deps: dict[str, list[str]] = defaultdict(list) + + node_ids = {node["id"] for node in nodes} + + for node_id in node_ids: + in_degree[node_id] = 0 + + for node_id, deps in dependencies.items(): + for dep in deps: + if dep in node_ids: + in_degree[node_id] += 1 + reverse_deps[dep].append(node_id) + + # Start with nodes that have no dependencies + queue = [nid for nid in node_ids if in_degree[nid] == 0] + sorted_ids: list[str] = [] + + while queue: + current = queue.pop(0) + sorted_ids.append(current) + + for dependent in reverse_deps[current]: + in_degree[dependent] -= 1 + if in_degree[dependent] == 0: + queue.append(dependent) + + # Check for cycles + if len(sorted_ids) != len(node_ids): + remaining = node_ids - set(sorted_ids) + raise CyclicDependencyError( + f"Cyclic dependency detected involving nodes: {remaining}" + ) + + return sorted_ids + + @classmethod + def _generate_edges( + cls, + nodes: list[dict[str, Any]], + sorted_node_ids: list[str], + dependencies: dict[str, list[str]], + node_map: 
dict[str, dict[str, Any]], + ) -> list[dict[str, Any]]: + """ + Generate all edges based on dependencies and special node handling. + + Args: + nodes: List of nodes + sorted_node_ids: Topologically sorted node IDs + dependencies: Dependency graph + node_map: Map of node_id -> node + + Returns: + List of edge dictionaries + """ + edges: list[dict[str, Any]] = [] + nodes_with_incoming: set[str] = set() + + # Track which nodes have outgoing edges from branching + branching_sources: set[str] = set() + + # First pass: Handle branching nodes + for node in nodes: + node_id = node.get("id", "") + node_type = node.get("type", "") + + if node_type == "if-else": + branch_edges = cls._handle_if_else_node(node) + edges.extend(branch_edges) + branching_sources.add(node_id) + nodes_with_incoming.update(edge["target"] for edge in branch_edges) + + elif node_type == "question-classifier": + branch_edges = cls._handle_question_classifier_node(node) + edges.extend(branch_edges) + branching_sources.add(node_id) + nodes_with_incoming.update(edge["target"] for edge in branch_edges) + + # Second pass: Generate edges from dependencies + for node_id in sorted_node_ids: + deps = dependencies.get(node_id, []) + + if deps: + # Connect from each dependency + for dep_id in deps: + dep_node = node_map.get(dep_id, {}) + dep_type = dep_node.get("type", "") + + # Skip if dependency is a branching node (edges handled above) + if dep_type in BRANCHING_NODE_TYPES: + continue + + edges.append(cls._create_edge(dep_id, node_id)) + nodes_with_incoming.add(node_id) + else: + # No dependencies - connect from start + # But skip if this node receives edges from branching nodes + if node_id not in nodes_with_incoming: + edges.append(cls._create_edge("start", node_id)) + nodes_with_incoming.add(node_id) + + return edges + + @classmethod + def _handle_if_else_node( + cls, + node: dict[str, Any], + ) -> list[dict[str, Any]]: + """ + Handle if-else node branching. 
+ + Expects config to contain true_branch and/or false_branch. + + Args: + node: If-else node + + Returns: + List of branch edges + """ + edges: list[dict[str, Any]] = [] + node_id = node.get("id", "") + config = node.get("config", {}) + + true_branch = config.get("true_branch") + false_branch = config.get("false_branch") + + if true_branch: + edges.append(cls._create_edge(node_id, true_branch, source_handle="true")) + + if false_branch: + edges.append(cls._create_edge(node_id, false_branch, source_handle="false")) + + # If no branches specified, log warning + if not true_branch and not false_branch: + logger.warning( + "if-else node %s has no branch targets specified", + node_id, + ) + + return edges + + @classmethod + def _handle_question_classifier_node( + cls, + node: dict[str, Any], + ) -> list[dict[str, Any]]: + """ + Handle question-classifier node branching. + + Expects config.classes to contain class definitions with target fields. + + Args: + node: Question-classifier node + + Returns: + List of branch edges + """ + edges: list[dict[str, Any]] = [] + node_id = node.get("id", "") + config = node.get("config", {}) + classes = config.get("classes", []) + + if not classes: + logger.warning( + "question-classifier node %s has no classes defined", + node_id, + ) + return edges + + for cls_def in classes: + class_id = cls_def.get("id", "") + target = cls_def.get("target") + + if target: + edges.append(cls._create_edge(node_id, target, source_handle=class_id)) + else: + # Silent fix: Connect to end if no target specified + edges.append(cls._create_edge(node_id, "end", source_handle=class_id)) + logger.debug( + "question-classifier class %s has no target, connecting to end", + class_id, + ) + + return edges + + @classmethod + def _find_terminal_nodes( + cls, + nodes: list[dict[str, Any]], + dependencies: dict[str, list[str]], + node_map: dict[str, dict[str, Any]], + ) -> list[str]: + """ + Find nodes that should connect to the end node. 
+ + Terminal nodes are those that: + - Are not dependencies of any other node + - Are not branching nodes (those connect to their branches) + + Args: + nodes: List of nodes + dependencies: Dependency graph + node_map: Map of node_id -> node + + Returns: + List of terminal node IDs + """ + # Build set of all nodes that are depended upon + depended_upon: set[str] = set() + for deps in dependencies.values(): + depended_upon.update(deps) + + # Also track nodes that are branch targets + branch_targets: set[str] = set() + branching_nodes: set[str] = set() + + for node in nodes: + node_id = node.get("id", "") + node_type = node.get("type", "") + config = node.get("config", {}) + + if node_type == "if-else": + branching_nodes.add(node_id) + if config.get("true_branch"): + branch_targets.add(config["true_branch"]) + if config.get("false_branch"): + branch_targets.add(config["false_branch"]) + + elif node_type == "question-classifier": + branching_nodes.add(node_id) + for cls_def in config.get("classes", []): + if cls_def.get("target"): + branch_targets.add(cls_def["target"]) + + # Find terminal nodes + terminal_nodes: list[str] = [] + for node in nodes: + node_id = node.get("id", "") + node_type = node.get("type", "") + + # Skip branching nodes - they don't connect to end directly + if node_type in BRANCHING_NODE_TYPES: + continue + + # Terminal if not depended upon and not a branch target that leads elsewhere + if node_id not in depended_upon: + terminal_nodes.append(node_id) + + # If no terminal nodes found (shouldn't happen), use all non-branching nodes + if not terminal_nodes: + terminal_nodes = [ + node["id"] + for node in nodes + if node.get("type") not in BRANCHING_NODE_TYPES + ] + logger.warning("No terminal nodes found, using all non-branching nodes") + + return terminal_nodes + + @classmethod + def _create_start_node( + cls, + nodes: list[dict[str, Any]], + config: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Create a start node. 
+ + Args: + nodes: User nodes (for potential config inference) + config: Optional start node configuration + + Returns: + Start node dictionary + """ + return { + "id": "start", + "type": "start", + "title": "Start", + "config": config or {}, + "data": {}, + } + + @classmethod + def _create_end_node( + cls, + terminal_nodes: list[str], + config: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Create an end node. + + Args: + terminal_nodes: Nodes that will connect to end + config: Optional end node configuration + + Returns: + End node dictionary + """ + return { + "id": "end", + "type": "end", + "title": "End", + "config": config or {}, + "data": {}, + } + + @classmethod + def _create_edge( + cls, + source: str, + target: str, + source_handle: str | None = None, + ) -> dict[str, Any]: + """ + Create an edge dictionary. + + Args: + source: Source node ID + target: Target node ID + source_handle: Optional handle for branching (e.g., "true", "false", class_id) + + Returns: + Edge dictionary + """ + edge: dict[str, Any] = { + "id": f"{source}-{target}-{uuid.uuid4().hex[:8]}", + "source": source, + "target": target, + } + + if source_handle: + edge["sourceHandle"] = source_handle + else: + edge["sourceHandle"] = "source" + + edge["targetHandle"] = "target" + + return edge diff --git a/api/tests/unit_tests/core/llm_generator/test_graph_builder.py b/api/tests/unit_tests/core/llm_generator/test_graph_builder.py new file mode 100644 index 0000000000..e28e418690 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_graph_builder.py @@ -0,0 +1,418 @@ +""" +Unit tests for GraphBuilder. + +Tests the automatic graph construction from node lists with dependency declarations. 
import pytest

from core.workflow.generator.utils.graph_builder import (
    CyclicDependencyError,
    GraphBuilder,
)


def _edge_pairs(edges):
    """Return the set of (source, target) pairs for a list of edge dicts."""
    return {(edge["source"], edge["target"]) for edge in edges}


class TestGraphBuilderBasic:
    """Basic functionality tests."""

    def test_empty_nodes_creates_minimal_workflow(self):
        """An empty node list yields the minimal start -> end workflow."""
        built_nodes, built_edges = GraphBuilder.build_graph([])

        assert len(built_nodes) == 2
        assert [n["type"] for n in built_nodes] == ["start", "end"]
        assert len(built_edges) == 1
        only_edge = built_edges[0]
        assert (only_edge["source"], only_edge["target"]) == ("start", "end")

    def test_simple_linear_workflow(self):
        """Linear chain: start -> fetch -> process -> end."""
        node_specs = [
            {"id": "fetch", "type": "http-request", "depends_on": []},
            {"id": "process", "type": "llm", "depends_on": ["fetch"]},
        ]
        built_nodes, built_edges = GraphBuilder.build_graph(node_specs)

        # start + the two user nodes + end
        assert len(built_nodes) == 4
        assert built_nodes[0]["type"] == "start"
        assert built_nodes[-1]["type"] == "end"

        # start->fetch, fetch->process, process->end
        assert len(built_edges) == 3
        assert _edge_pairs(built_edges) >= {
            ("start", "fetch"),
            ("fetch", "process"),
            ("process", "end"),
        }


class TestParallelWorkflow:
    """Tests for parallel node handling."""

    def test_parallel_workflow(self):
        """Two fan-out nodes from start that merge into a single node."""
        node_specs = [
            {"id": "api1", "type": "http-request", "depends_on": []},
            {"id": "api2", "type": "http-request", "depends_on": []},
            {"id": "merge", "type": "llm", "depends_on": ["api1", "api2"]},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        # start fans out to both api nodes
        from_start = [e for e in built_edges if e["source"] == "start"]
        assert len(from_start) == 2
        assert {e["target"] for e in from_start} == {"api1", "api2"}

        # both api nodes feed the merge node
        into_merge = [e for e in built_edges if e["target"] == "merge"]
        assert len(into_merge) == 2

    def test_multiple_terminal_nodes(self):
        """Every terminal node gets its own edge into end."""
        node_specs = [
            {"id": "branch1", "type": "llm", "depends_on": []},
            {"id": "branch2", "type": "llm", "depends_on": []},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        into_end = [e for e in built_edges if e["target"] == "end"]
        assert len(into_end) == 2


class TestIfElseWorkflow:
    """Tests for if-else branching."""

    def test_if_else_workflow(self):
        """Conditional branching routes true/false handles to their targets."""
        node_specs = [
            {
                "id": "check",
                "type": "if-else",
                "config": {"true_branch": "success", "false_branch": "fallback"},
                "depends_on": [],
            },
            {"id": "success", "type": "llm", "depends_on": []},
            {"id": "fallback", "type": "code", "depends_on": []},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        outgoing = [e for e in built_edges if e["source"] == "check"]
        assert len(outgoing) == 2

        # one edge per handle, each pointing at the configured branch target
        by_handle = {e.get("sourceHandle"): e for e in outgoing}
        assert set(by_handle) == {"true", "false"}
        assert by_handle["true"]["target"] == "success"
        assert by_handle["false"]["target"] == "fallback"

    def test_if_else_missing_branch_no_error(self):
        """An if-else with only a true branch builds fine (warning only)."""
        node_specs = [
            {
                "id": "check",
                "type": "if-else",
                "config": {"true_branch": "success"},
                "depends_on": [],
            },
            {"id": "success", "type": "llm", "depends_on": []},
        ]
        # must not raise
        _, built_edges = GraphBuilder.build_graph(node_specs)

        outgoing = [e for e in built_edges if e["source"] == "check"]
        assert len(outgoing) == 1
        assert outgoing[0].get("sourceHandle") == "true"


class TestQuestionClassifierWorkflow:
    """Tests for question-classifier branching."""

    def test_question_classifier_workflow(self):
        """Each class routes to its target via its id as sourceHandle."""
        node_specs = [
            {
                "id": "classifier",
                "type": "question-classifier",
                "config": {
                    "query": ["start", "user_input"],
                    "classes": [
                        {"id": "tech", "name": "技术问题", "target": "tech_handler"},
                        {"id": "sales", "name": "销售咨询", "target": "sales_handler"},
                        {"id": "other", "name": "其他问题", "target": "other_handler"},
                    ],
                },
                "depends_on": [],
            },
            {"id": "tech_handler", "type": "llm", "depends_on": []},
            {"id": "sales_handler", "type": "llm", "depends_on": []},
            {"id": "other_handler", "type": "llm", "depends_on": []},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        classifier_out = [e for e in built_edges if e["source"] == "classifier"]
        assert len(classifier_out) == 3

        routing = {e.get("sourceHandle"): e["target"] for e in classifier_out}
        assert routing == {
            "tech": "tech_handler",
            "sales": "sales_handler",
            "other": "other_handler",
        }

    def test_question_classifier_missing_target(self):
        """A class without a target falls through to end."""
        node_specs = [
            {
                "id": "classifier",
                "type": "question-classifier",
                "config": {
                    "classes": [
                        {"id": "known", "name": "已知问题", "target": "handler"},
                        {"id": "unknown", "name": "未知问题"},  # no target on purpose
                    ],
                },
                "depends_on": [],
            },
            {"id": "handler", "type": "llm", "depends_on": []},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        classifier_out = [e for e in built_edges if e["source"] == "classifier"]
        handle_targets = {(e.get("sourceHandle"), e["target"]) for e in classifier_out}
        assert ("unknown", "end") in handle_targets


class TestVariableDependencyInference:
    """Tests for automatic dependency inference from variables."""

    def test_variable_dependency_inference(self):
        """A {{#node.field#}} reference implies an edge without depends_on."""
        node_specs = [
            {"id": "fetch", "type": "http-request", "depends_on": []},
            {
                "id": "process",
                "type": "llm",
                "config": {"prompt_template": [{"text": "{{#fetch.body#}}"}]},
                # no explicit depends_on; the reference to fetch must be inferred
            },
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        assert ("fetch", "process") in _edge_pairs(built_edges)

    def test_system_variable_not_inferred(self):
        """References to sys/start never become node dependencies."""
        node_specs = [
            {
                "id": "process",
                "type": "llm",
                "config": {"prompt_template": [{"text": "{{#sys.query#}} {{#start.input#}}"}]},
                "depends_on": [],
            },
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        sources = {e["source"] for e in built_edges}
        assert "sys" not in sources
        assert "start" in sources


class TestCycleDetection:
    """Tests for cyclic dependency detection."""

    def test_cyclic_dependency_detected(self):
        """A three-node cycle raises CyclicDependencyError."""
        node_specs = [
            {"id": "a", "type": "llm", "depends_on": ["c"]},
            {"id": "b", "type": "llm", "depends_on": ["a"]},
            {"id": "c", "type": "llm", "depends_on": ["b"]},
        ]

        with pytest.raises(CyclicDependencyError):
            GraphBuilder.build_graph(node_specs)

    def test_self_dependency_detected(self):
        """A node depending on itself raises CyclicDependencyError."""
        node_specs = [
            {"id": "a", "type": "llm", "depends_on": ["a"]},
        ]

        with pytest.raises(CyclicDependencyError):
            GraphBuilder.build_graph(node_specs)


class TestErrorRecovery:
    """Tests for silent error recovery."""

    def test_invalid_dependency_removed(self):
        """References to unknown nodes are dropped without raising."""
        node_specs = [
            {"id": "process", "type": "llm", "depends_on": ["nonexistent"]},
        ]
        # must not raise; the dangling dependency is silently discarded
        _, built_edges = GraphBuilder.build_graph(node_specs)

        # with the bad dependency gone, the node hangs off start
        assert ("start", "process") in _edge_pairs(built_edges)

    def test_depends_on_as_string(self):
        """A bare-string depends_on is treated as a one-element list."""
        node_specs = [
            {"id": "fetch", "type": "http-request", "depends_on": []},
            {"id": "process", "type": "llm", "depends_on": "fetch"},  # string, not list
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        assert ("fetch", "process") in _edge_pairs(built_edges)


class TestContainerNodes:
    """Tests for container nodes (iteration, loop)."""

    def test_iteration_node_as_regular_node(self):
        """Iteration nodes wire up like ordinary single-in-single-out nodes."""
        node_specs = [
            {"id": "prepare", "type": "code", "depends_on": []},
            {
                "id": "loop",
                "type": "iteration",
                "config": {"iterator_selector": ["prepare", "items"]},
                "depends_on": ["prepare"],
            },
            {"id": "process_result", "type": "llm", "depends_on": ["loop"]},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        # expected chain: start -> prepare -> loop -> process_result -> end
        assert _edge_pairs(built_edges) >= {
            ("start", "prepare"),
            ("prepare", "loop"),
            ("loop", "process_result"),
            ("process_result", "end"),
        }

    def test_loop_node_as_regular_node(self):
        """Loop nodes wire up like ordinary single-in-single-out nodes."""
        node_specs = [
            {"id": "init", "type": "code", "depends_on": []},
            {
                "id": "repeat",
                "type": "loop",
                "config": {"loop_count": 5},
                "depends_on": ["init"],
            },
            {"id": "finish", "type": "llm", "depends_on": ["repeat"]},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        assert _edge_pairs(built_edges) >= {
            ("init", "repeat"),
            ("repeat", "finish"),
        }

    def test_iteration_with_variable_inference(self):
        """An iteration node without depends_on still gets wired somewhere sane."""
        node_specs = [
            {"id": "data_source", "type": "http-request", "depends_on": []},
            {
                "id": "process_each",
                "type": "iteration",
                "config": {
                    "iterator_selector": ["data_source", "items"],
                },
                # no explicit depends_on, but references data_source
            },
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        # iterator_selector uses a list form, not the {{#...#}} syntax, so the
        # builder may either infer the dependency or fall back to start; both
        # wirings are acceptable here.
        pairs = _edge_pairs(built_edges)
        assert {("start", "process_each"), ("data_source", "process_each")} & pairs

    def test_loop_node_self_reference_not_cycle(self):
        """A loop node reading its own output must not trip cycle detection."""
        node_specs = [
            {"id": "init", "type": "code", "depends_on": []},
            {
                "id": "my_loop",
                "type": "loop",
                "config": {
                    "loop_count": 5,
                    # a loop referencing its own output is a common pattern
                    "prompt": "Previous: {{#my_loop.output#}}, continue...",
                },
                "depends_on": ["init"],
            },
            {"id": "finish", "type": "llm", "depends_on": ["my_loop"]},
        ]
        # must NOT raise CyclicDependencyError
        built_nodes, built_edges = GraphBuilder.build_graph(node_specs)

        # start + the three user nodes + end
        assert len(built_nodes) == 5
        assert _edge_pairs(built_edges) >= {
            ("init", "my_loop"),
            ("my_loop", "finish"),
        }


class TestEdgeStructure:
    """Tests for edge structure correctness."""

    def test_edge_has_required_fields(self):
        """Every generated edge carries the full field set."""
        node_specs = [
            {"id": "node1", "type": "llm", "depends_on": []},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        required_fields = ("id", "source", "target", "sourceHandle", "targetHandle")
        for edge in built_edges:
            for field in required_fields:
                assert field in edge

    def test_edge_id_unique(self):
        """Edge ids never collide."""
        node_specs = [
            {"id": "a", "type": "llm", "depends_on": []},
            {"id": "b", "type": "llm", "depends_on": []},
            {"id": "c", "type": "llm", "depends_on": ["a", "b"]},
        ]
        _, built_edges = GraphBuilder.build_graph(node_specs)

        edge_ids = [e["id"] for e in built_edges]
        assert len(edge_ids) == len(set(edge_ids))
copy from 'copy-to-clipboard' -import { useCallback, useEffect, useMemo, useState } from 'react' +import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' +import { z } from 'zod' import ResPlaceholder from '@/app/components/app/configuration/config/automatic/res-placeholder' import VersionSelector from '@/app/components/app/configuration/config/automatic/version-selector' import Button from '@/app/components/base/button' @@ -23,6 +24,23 @@ import { VIBE_APPLY_EVENT, VIBE_COMMAND_EVENT } from '../../constants' import { useStore, useWorkflowStore } from '../../store' import WorkflowPreview from '../../workflow-preview' +const CompletionParamsSchema = z.object({ + max_tokens: z.number(), + temperature: z.number(), + top_p: z.number(), + echo: z.boolean(), + stop: z.array(z.string()), + presence_penalty: z.number(), + frequency_penalty: z.number(), +}) + +const ModelSchema = z.object({ + provider: z.string(), + name: z.string(), + mode: z.nativeEnum(ModelModeType), + completion_params: CompletionParamsSchema, +}) + const VibePanel: FC = () => { const { t } = useTranslation() const workflowStore = useWorkflowStore() @@ -48,54 +66,68 @@ const VibePanel: FC = () => { const vibePanelSuggestions = useStore(s => s.vibePanelSuggestions) const setVibePanelSuggestions = useStore(s => s.setVibePanelSuggestions) - const localModel = localStorage.getItem('auto-gen-model') - ? JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model - : null - const [model, setModel] = useState(localModel || { - name: '', - provider: '', - mode: ModelModeType.chat, - completion_params: {} as CompletionParams, - }) const { defaultModel } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) - useEffect(() => { - if (defaultModel) { - const localModel = localStorage.getItem('auto-gen-model') - ? 
JSON.parse(localStorage.getItem('auto-gen-model') || '') - : null - if (localModel) { - setModel(localModel) - } - else { - setModel(prev => ({ - ...prev, - name: defaultModel.model, - provider: defaultModel.provider.provider, - })) + // Track user's explicit model selection (from localStorage) + const [userModel, setUserModel] = useState(() => { + try { + const stored = localStorage.getItem('auto-gen-model') + if (stored) { + const parsed = JSON.parse(stored) + const result = ModelSchema.safeParse(parsed) + if (result.success) + return result.data + + // If validation fails, clear the invalid data + localStorage.removeItem('auto-gen-model') } } - }, [defaultModel]) + catch { + // ignore parse errors + } + return null + }) + + // Derive the actual model from user selection or default + const model: Model = useMemo(() => { + if (userModel) + return userModel + if (defaultModel) { + return { + name: defaultModel.model, + provider: defaultModel.provider.provider, + mode: ModelModeType.chat, + completion_params: {} as CompletionParams, + } + } + return { + name: '', + provider: '', + mode: ModelModeType.chat, + completion_params: {} as CompletionParams, + } + }, [userModel, defaultModel]) + + const setModel = useCallback((newModel: Model) => { + setUserModel(newModel) + localStorage.setItem('auto-gen-model', JSON.stringify(newModel)) + }, []) const handleModelChange = useCallback((newValue: { modelId: string, provider: string, mode?: string, features?: string[] }) => { - const newModel = { + setModel({ ...model, provider: newValue.provider, name: newValue.modelId, mode: newValue.mode as ModelModeType, - } - setModel(newModel) - localStorage.setItem('auto-gen-model', JSON.stringify(newModel)) - }, [model]) + }) + }, [model, setModel]) const handleCompletionParamsChange = useCallback((newParams: FormValue) => { - const newModel = { + setModel({ ...model, completion_params: newParams as CompletionParams, - } - setModel(newModel) - localStorage.setItem('auto-gen-model', 
JSON.stringify(newModel)) - }, [model]) + }) + }, [model, setModel]) const handleInstructionChange = useCallback((e: React.ChangeEvent) => { workflowStore.setState(state => ({ @@ -161,28 +193,25 @@ const VibePanel: FC = () => { ) const renderOffTopic = ( -
+
-
- -
-
+
{t('vibe.offTopicTitle', { ns: 'workflow' })}
-
+
{vibePanelMessage || t('vibe.offTopicDefault', { ns: 'workflow' })}
{vibePanelSuggestions.length > 0 && ( -
-
+
+
{t('vibe.trySuggestion', { ns: 'workflow' })}
- {vibePanelSuggestions.map((suggestion, index) => ( + {vibePanelSuggestions.map(suggestion => (