diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6fbe19a3b2..822019e170 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config @@ -55,6 +55,25 @@ logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): + @staticmethod + def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow: + """Re-apply snippet virtual Start injection after worker reloads workflow from DB.""" + if workflow.type != "snippet": + return workflow + + from models.snippet import CustomizedSnippet + from services.snippet_generate_service import SnippetGenerateService + + snippet = session.scalar( + select(CustomizedSnippet).where( + CustomizedSnippet.id == workflow.app_id, + CustomizedSnippet.tenant_id == workflow.tenant_id, + ) + ) + if snippet is None: + return workflow + return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet) + @staticmethod def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool: return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY)) @@ -551,6 +570,8 @@ class WorkflowAppGenerator(BaseAppGenerator): if workflow is None: raise ValueError("Workflow not found") + workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow) + # Determine system_user_id based on invocation source is_external_api_call = application_generate_entity.invoke_from in { InvokeFrom.WEB_APP, diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py index cfb135cbb0..f0333b3e1c 100644 --- a/api/dify_graph/enums.py +++ b/api/dify_graph/enums.py @@ -129,6 +129,7 @@ class WorkflowType(StrEnum): WORKFLOW = "workflow" CHAT = "chat" RAG_PIPELINE = "rag-pipeline" + SNIPPET = "snippet" class WorkflowExecutionStatus(StrEnum): diff --git a/api/services/snippet_generate_service.py b/api/services/snippet_generate_service.py index 6f59570f11..e8648c7e6a 100644 --- a/api/services/snippet_generate_service.py +++ b/api/services/snippet_generate_service.py @@ -173,6 +173,11 @@ class SnippetGenerateService: ) return response + @classmethod + def ensure_start_node_for_worker(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow: + """Public wrapper for worker-thread start-node injection.""" + return cls._ensure_start_node(workflow, snippet) + @classmethod def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow: """