diff --git a/api/services/snippet_generate_service.py b/api/services/snippet_generate_service.py index b21a2cbd7a..250f605c0c 100644 --- a/api/services/snippet_generate_service.py +++ b/api/services/snippet_generate_service.py @@ -18,10 +18,13 @@ Supported execution modes: - Single loop run (generate_single_loop): SSE stream for loop container nodes. """ +import json import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, Union +from sqlalchemy.orm import make_transient + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom @@ -74,6 +77,9 @@ class SnippetGenerateService: complex workflow execution pipeline. """ + # Specific ID for the injected virtual Start node so it can be recognised + _VIRTUAL_START_NODE_ID = "__snippet_virtual_start__" + @classmethod def generate( cls, @@ -89,6 +95,11 @@ class SnippetGenerateService: Retrieves the draft workflow, adapts the snippet to an App-like proxy, then delegates execution to WorkflowAppGenerator. + If the workflow graph has no Start node, a virtual Start node is injected + in-memory so that: + 1. Graph validation passes (root node must have execution_type=ROOT). + 2. User inputs are processed into the variable pool by the StartNode logic. + :param snippet: CustomizedSnippet instance :param user: Account or EndUser initiating the run :param args: Workflow inputs (must include "inputs" key) @@ -102,6 +113,9 @@ class SnippetGenerateService: if not workflow: raise ValueError("Workflow not initialized") + # Inject a virtual Start node when the graph doesn't have one. + workflow = cls._ensure_start_node(workflow, snippet) + # Adapt snippet to App-like interface for WorkflowAppGenerator app_proxy = _SnippetAsApp(snippet) @@ -117,6 +131,102 @@ class SnippetGenerateService: ) ) + @classmethod + def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow: + """ + Return *workflow* with a Start node. + + If the graph already contains a Start node, the original workflow is + returned unchanged. Otherwise a virtual Start node is injected and the + workflow object is detached from the SQLAlchemy session so the in-memory + change is never flushed to the database. + """ + graph_dict = workflow.graph_dict + nodes: list[dict[str, Any]] = graph_dict.get("nodes", []) + + has_start = any(node.get("data", {}).get("type") == "start" for node in nodes) + if has_start: + return workflow + + modified_graph = cls._inject_virtual_start_node( + graph_dict=graph_dict, + input_fields=snippet.input_fields_list, + ) + + # Detach from session to prevent accidental DB persistence of the + # modified graph. All attributes remain accessible for read. + make_transient(workflow) + workflow.graph = json.dumps(modified_graph) + return workflow + + @classmethod + def _inject_virtual_start_node( + cls, + graph_dict: Mapping[str, Any], + input_fields: list[dict[str, Any]], + ) -> dict[str, Any]: + """ + Build a new graph dict with a virtual Start node prepended. + + The virtual Start node is wired to every existing node that has no + incoming edges (i.e. the current root candidates). This guarantees: + + :param graph_dict: Original graph configuration. + :param input_fields: Snippet input field definitions from + ``CustomizedSnippet.input_fields_list``. + :return: New graph dict containing the virtual Start node and edges. + """ + nodes: list[dict[str, Any]] = list(graph_dict.get("nodes", [])) + edges: list[dict[str, Any]] = list(graph_dict.get("edges", [])) + + # Identify nodes with no incoming edges. + nodes_with_incoming: set[str] = set() + for edge in edges: + target = edge.get("target") + if isinstance(target, str): + nodes_with_incoming.add(target) + root_candidate_ids = [n["id"] for n in nodes if n["id"] not in nodes_with_incoming] + + # Build Start node ``variables`` from snippet input fields. + start_variables: list[dict[str, Any]] = [] + for field in input_fields: + var: dict[str, Any] = { + "variable": field.get("variable", ""), + "label": field.get("label", field.get("variable", "")), + "type": field.get("type", "text-input"), + "required": field.get("required", False), + "options": field.get("options", []), + } + if field.get("max_length") is not None: + var["max_length"] = field["max_length"] + start_variables.append(var) + + virtual_start_node: dict[str, Any] = { + "id": cls._VIRTUAL_START_NODE_ID, + "data": { + "type": "start", + "title": "Start", + "variables": start_variables, + }, + } + + # Create edges from virtual Start to each root candidate. + new_edges: list[dict[str, Any]] = [ + { + "source": cls._VIRTUAL_START_NODE_ID, + "sourceHandle": "source", + "target": root_id, + "targetHandle": "target", + } + for root_id in root_candidate_ids + ] + + return { + **graph_dict, + "nodes": [virtual_start_node, *nodes], + "edges": [*edges, *new_edges], + } + @classmethod def run_draft_node( cls,