import json import logging from collections.abc import Mapping, Sequence from datetime import UTC, datetime from typing import Any from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from extensions.ext_database import db from graphon.enums import BuiltinNodeTypes, NodeType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import WorkflowRunTriggeredFrom from models.snippet import CustomizedSnippet, SnippetType from models.workflow import ( Workflow, WorkflowKind, WorkflowNodeExecutionModel, WorkflowRun, WorkflowType, ) from repositories.factory import DifyAPIRepositoryFactory from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) # Node types not allowed in snippet workflows (sync, publish, DSL import). 
SNIPPET_FORBIDDEN_NODE_TYPES: frozenset[str] = frozenset(
    {
        BuiltinNodeTypes.START,
        BuiltinNodeTypes.HUMAN_INPUT,
        BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
    }
)


class SnippetService:
    """Service for managing customized snippets.

    Snippet workflows are stored in the ``Workflow`` table: rows whose ``app_id``
    is the snippet ID and whose ``kind`` is ``WorkflowKind.SNIPPET``.
    """

    def __init__(self, session_maker: sessionmaker | None = None):
        """Initialize SnippetService with repository dependencies.

        :param session_maker: Session factory for the run/node-execution
            repositories; defaults to one bound to ``db.engine`` with
            ``expire_on_commit=False``.
        """
        if session_maker is None:
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker
        )
        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

    @staticmethod
    def _snippet_kind_filter():
        """Match snippet workflows by business kind."""
        return Workflow.kind == WorkflowKind.SNIPPET.value

    @staticmethod
    def _escape_like_pattern(value: str) -> str:
        """Escape LIKE/ILIKE metacharacters so *value* matches literally.

        Uses backslash as the escape character; callers must pass ``escape="\\\\"``
        to ``ilike``. The backslash itself is escaped first so later escapes
        are not doubled.
        """
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def validate_snippet_graph_forbidden_nodes(graph: Mapping[str, Any]) -> None:
        """Reject graphs that contain node types not allowed in snippets.

        Nodes that are not mappings, nodes whose ``data`` is not a mapping, and
        nodes whose ``data.type`` is not a string are skipped, not rejected.

        :param graph: Workflow graph; nodes are read from ``graph["nodes"]``
        :raises ValueError: If any node type is in ``SNIPPET_FORBIDDEN_NODE_TYPES``;
            the message lists every offending ``node_id:type`` pair
        """
        nodes = graph.get("nodes") or []
        disallowed: list[tuple[str, str]] = []
        for node in nodes:
            # Any mapping is accepted, not only dict, so read-only graph views validate too.
            if not isinstance(node, Mapping):
                continue
            node_data = node.get("data") or {}
            # A non-mapping ``data`` would otherwise raise AttributeError on ``.get``.
            if not isinstance(node_data, Mapping):
                continue
            node_type = node_data.get("type")
            if not isinstance(node_type, str):
                continue
            if node_type in SNIPPET_FORBIDDEN_NODE_TYPES:
                node_id = node.get("id")
                disallowed.append((str(node_id) if node_id is not None else "?", node_type))
        if not disallowed:
            return
        detail = ", ".join(f"{nid}:{t}" for nid, t in disallowed)
        raise ValueError(
            "Snippet workflow cannot contain start, human-input, or knowledge-retrieval nodes. "
            f"Found: {detail}"
        )

    # --- CRUD Operations ---

    @staticmethod
    def get_snippets(
        *,
        tenant_id: str,
        page: int = 1,
        limit: int = 20,
        keyword: str | None = None,
        is_published: bool | None = None,
        creators: list[str] | None = None,
    ) -> tuple[Sequence[CustomizedSnippet], int, bool]:
        """
        Get paginated list of snippets with optional search.

        Results are ordered by ``created_at`` descending.

        :param tenant_id: Tenant ID
        :param page: Page number (1-indexed; values below 1 are treated as 1)
        :param limit: Number of items per page
        :param keyword: Optional search keyword, matched literally and case-insensitively
            against name/description
        :param is_published: Optional filter by published status (True/False/None for all)
        :param creators: Optional filter by creator account IDs
        :return: Tuple of (snippets list, total count, has_more flag)
        """
        stmt = (
            select(CustomizedSnippet)
            .where(CustomizedSnippet.tenant_id == tenant_id)
            .order_by(CustomizedSnippet.created_at.desc())
        )

        if keyword:
            # Escape LIKE wildcards so "%" / "_" typed by the user are not treated as patterns.
            pattern = f"%{SnippetService._escape_like_pattern(keyword)}%"
            stmt = stmt.where(
                CustomizedSnippet.name.ilike(pattern, escape="\\")
                | CustomizedSnippet.description.ilike(pattern, escape="\\")
            )

        if is_published is not None:
            stmt = stmt.where(CustomizedSnippet.is_published == is_published)

        if creators:
            stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))

        # Total count over the filtered (unpaginated) query.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total = db.session.scalar(count_stmt) or 0

        # Fetch one extra row to detect whether another page exists; a negative
        # OFFSET is rejected by the database, so clamp it.
        offset = max(page - 1, 0) * limit
        stmt = stmt.limit(limit + 1).offset(offset)
        snippets = list(db.session.scalars(stmt).all())

        has_more = len(snippets) > limit
        if has_more:
            snippets = snippets[:-1]

        return snippets, total, has_more

    @staticmethod
    def get_snippet_by_id(
        *,
        snippet_id: str,
        tenant_id: str,
    ) -> CustomizedSnippet | None:
        """
        Get snippet by ID with tenant isolation.

        :param snippet_id: Snippet ID
        :param tenant_id: Tenant ID
        :return: CustomizedSnippet or None
        """
        return (
            db.session.query(CustomizedSnippet)
            .where(
                CustomizedSnippet.id == snippet_id,
                CustomizedSnippet.tenant_id == tenant_id,
            )
            .first()
        )

    @staticmethod
    def create_snippet(
        *,
        tenant_id: str,
        name: str,
        description: str | None,
        snippet_type: SnippetType,
        icon_info: dict | None,
        input_fields: list[dict] | None,
        account: Account,
    ) -> CustomizedSnippet:
        """
        Create a new snippet and commit it.

        :param tenant_id: Tenant ID
        :param name: Snippet name (must be unique per tenant)
        :param description: Snippet description (``None`` is stored as "")
        :param snippet_type: Type of snippet (node or group)
        :param icon_info: Icon information
        :param input_fields: Input field definitions (stored as JSON; empty/None stored as None)
        :param account: Creator account
        :return: Created CustomizedSnippet
        :raises ValueError: If name already exists
        """
        # Check if name already exists for this tenant
        existing = (
            db.session.query(CustomizedSnippet)
            .where(
                CustomizedSnippet.tenant_id == tenant_id,
                CustomizedSnippet.name == name,
            )
            .first()
        )
        if existing:
            raise ValueError(f"Snippet with name '{name}' already exists")

        snippet = CustomizedSnippet(
            tenant_id=tenant_id,
            name=name,
            description=description or "",
            type=snippet_type.value,
            icon_info=icon_info,
            input_fields=json.dumps(input_fields) if input_fields else None,
            created_by=account.id,
        )
        db.session.add(snippet)
        db.session.commit()

        return snippet

    @staticmethod
    def update_snippet(
        *,
        session: Session,
        snippet: CustomizedSnippet,
        account_id: str,
        data: dict,
    ) -> CustomizedSnippet:
        """
        Update snippet attributes. The caller is responsible for committing.

        :param session: Database session
        :param snippet: Snippet to update
        :param account_id: ID of account making the update
        :param data: Dictionary of fields to update (``name``, ``description``, ``icon_info``)
        :return: Updated CustomizedSnippet
        :raises ValueError: If the new name is already used by another snippet of the tenant
        """
        if "name" in data:
            # Check if new name already exists for this tenant
            existing = (
                session.query(CustomizedSnippet)
                .where(
                    CustomizedSnippet.tenant_id == snippet.tenant_id,
                    CustomizedSnippet.name == data["name"],
                    CustomizedSnippet.id != snippet.id,
                )
                .first()
            )
            if existing:
                raise ValueError(f"Snippet with name '{data['name']}' already exists")
            snippet.name = data["name"]

        if "description" in data:
            snippet.description = data["description"]

        if "icon_info" in data:
            snippet.icon_info = data["icon_info"]

        snippet.updated_by = account_id
        snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)

        session.add(snippet)
        return snippet

    @staticmethod
    def delete_snippet(
        *,
        session: Session,
        snippet: CustomizedSnippet,
    ) -> bool:
        """
        Delete a snippet. The caller is responsible for committing.

        :param session: Database session
        :param snippet: Snippet to delete
        :return: True if deleted successfully
        """
        session.delete(snippet)
        return True

    # --- Workflow Operations ---

    def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
        """
        Get draft workflow for snippet.

        :param snippet: CustomizedSnippet instance
        :return: Draft Workflow or None
        """
        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == snippet.tenant_id,
                Workflow.app_id == snippet.id,
                self._snippet_kind_filter(),
                Workflow.version == "draft",
            )
            .first()
        )
        return workflow

    def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
        """
        Get the currently published workflow (``snippet.workflow_id``) for snippet.

        :param snippet: CustomizedSnippet instance
        :return: Published Workflow or None
        """
        if not snippet.workflow_id:
            return None

        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == snippet.tenant_id,
                Workflow.app_id == snippet.id,
                self._snippet_kind_filter(),
                Workflow.id == snippet.workflow_id,
            )
            .first()
        )
        return workflow

    def get_published_workflow_by_id(self, snippet: CustomizedSnippet, workflow_id: str) -> Workflow | None:
        """
        Get a published workflow snapshot by ID for snippet history restore.

        :param snippet: CustomizedSnippet instance
        :param workflow_id: Workflow ID
        :return: Published Workflow or None
        :raises IsDraftWorkflowError: If the workflow ID points to a draft workflow
        """
        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == snippet.tenant_id,
                Workflow.app_id == snippet.id,
                self._snippet_kind_filter(),
                Workflow.id == workflow_id,
            )
            .first()
        )
        if not workflow:
            return None
        if workflow.version == Workflow.VERSION_DRAFT:
            raise IsDraftWorkflowError("source workflow must be published")
        return workflow

    def sync_draft_workflow(
        self,
        *,
        snippet: CustomizedSnippet,
        graph: dict,
        unique_hash: str | None,
        account: Account,
        input_fields: list[dict] | None = None,
    ) -> Workflow:
        """
        Sync draft workflow for snippet, creating it if absent, and commit.

        Snippet workflows do not persist environment variables (always empty)
        or conversation variables (always empty).

        :param snippet: CustomizedSnippet instance
        :param graph: Workflow graph configuration
        :param unique_hash: Hash for conflict detection (compared to the existing draft's hash)
        :param account: Account making the change
        :param input_fields: Input fields for snippet; ``None`` leaves them unchanged
        :return: Synced Workflow
        :raises ValueError: If the graph contains forbidden node types
        :raises WorkflowHashNotEqualError: If hash mismatch
        """
        SnippetService.validate_snippet_graph_forbidden_nodes(graph)

        workflow = self.get_draft_workflow(snippet=snippet)

        # Optimistic concurrency: reject edits based on a stale draft.
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()

        # Create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=snippet.tenant_id,
                app_id=snippet.id,
                features="{}",
                type=WorkflowType.WORKFLOW.value,
                kind=WorkflowKind.SNIPPET.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=[],
                conversation_variables=[],
            )
            db.session.add(workflow)
            db.session.flush()
        else:
            # Update existing draft workflow
            workflow.graph = json.dumps(graph)
            workflow.type = WorkflowType.WORKFLOW.value
            # Store the raw value, consistent with creation and with _snippet_kind_filter.
            workflow.kind = WorkflowKind.SNIPPET.value
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = []
            workflow.conversation_variables = []

        # Update snippet's input_fields if provided
        if input_fields is not None:
            snippet.input_fields = json.dumps(input_fields)
            snippet.updated_by = account.id
            snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)

        db.session.commit()
        return workflow

    def restore_published_workflow_to_draft(
        self,
        *,
        snippet: CustomizedSnippet,
        workflow_id: str,
        account: Account,
    ) -> Workflow:
        """
        Restore a published snippet workflow snapshot into the draft workflow and commit.

        :param snippet: CustomizedSnippet instance
        :param workflow_id: Published workflow ID
        :param account: Account making the change
        :return: Restored draft Workflow
        :raises WorkflowNotFoundError: If the source workflow does not exist
        :raises IsDraftWorkflowError: If the source workflow is a draft
        :raises ValueError: If the restored graph is invalid for snippets
        """
        source_workflow = self.get_published_workflow_by_id(snippet=snippet, workflow_id=workflow_id)
        if not source_workflow:
            raise WorkflowNotFoundError("Workflow not found.")

        SnippetService.validate_snippet_graph_forbidden_nodes(source_workflow.graph_dict)

        draft_workflow = self.get_draft_workflow(snippet=snippet)
        draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            source_workflow=source_workflow,
            draft_workflow=draft_workflow,
            account=account,
            updated_at_factory=lambda: datetime.now(UTC).replace(tzinfo=None),
        )
        if is_new_draft:
            db.session.add(draft_workflow)

        db.session.commit()
        return draft_workflow

    def publish_workflow(
        self,
        *,
        session: Session,
        snippet: CustomizedSnippet,
        account: Account,
    ) -> Workflow:
        """
        Publish the draft workflow as a new version. The caller is responsible for committing.

        :param session: Database session
        :param snippet: CustomizedSnippet instance
        :param account: Account making the change
        :return: Published Workflow
        :raises ValueError: If no draft workflow exists, or the draft graph
            contains forbidden node types
        """
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == snippet.tenant_id,
            Workflow.app_id == snippet.id,
            self._snippet_kind_filter(),
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")

        SnippetService.validate_snippet_graph_forbidden_nodes(draft_workflow.graph_dict)

        # Create new published workflow; the version label is the current naive-UTC timestamp.
        workflow = Workflow.new(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            type=WorkflowType.WORKFLOW.value,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=[],
            conversation_variables=[],
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            kind=WorkflowKind.SNIPPET.value,
            marked_name="",
            marked_comment="",
        )
        session.add(workflow)
        # Flush so workflow.id is populated even if it is generated at flush time;
        # otherwise snippet.workflow_id below could be set to None.
        session.flush()

        # Point the snippet at the new version
        snippet.version += 1
        snippet.is_published = True
        snippet.workflow_id = workflow.id
        snippet.updated_by = account.id
        snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
        session.add(snippet)

        return workflow

    def get_all_published_workflows(
        self,
        *,
        session: Session,
        snippet: CustomizedSnippet,
        page: int,
        limit: int,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get all published workflow versions for snippet, newest version first.

        :param session: Database session
        :param snippet: CustomizedSnippet instance
        :param page: Page number (1-indexed; values below 1 are treated as 1)
        :param limit: Items per page
        :return: Tuple of (workflows list, has_more flag); empty if the snippet was never published
        """
        if not snippet.workflow_id:
            return [], False

        stmt = (
            select(Workflow)
            .where(
                # Tenant isolation, consistent with every other workflow query here.
                Workflow.tenant_id == snippet.tenant_id,
                Workflow.app_id == snippet.id,
                self._snippet_kind_filter(),
                Workflow.version != "draft",
            )
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset(max(page - 1, 0) * limit)
        )

        workflows = list(session.scalars(stmt).all())
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]

        return workflows, has_more

    # --- Default Block Configs ---

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configurations for all node types (latest node version).

        Node classes whose default config is empty are omitted.

        :return: List of default configurations
        """
        default_block_configs: list[dict[str, Any]] = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(dict(default_config))

        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
        """
        Get default config for specific node type.

        :param node_type: Node type string
        :param filters: Optional filters passed to the node class's ``get_default_config``
        :return: Default configuration, or None if the type is unknown or has no default
        """
        node_type_enum = NodeType(node_type)

        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None

        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None

        return default_config

    # --- Workflow Run Operations ---

    def get_snippet_workflow_runs(
        self,
        *,
        snippet: CustomizedSnippet,
        args: dict,
    ) -> InfiniteScrollPagination:
        """
        Get paginated debugging workflow runs for snippet.

        :param snippet: CustomizedSnippet instance
        :param args: Request arguments (``last_id``, ``limit`` — default 20)
        :return: InfiniteScrollPagination result
        """
        limit = int(args.get("limit", 20))
        last_id = args.get("last_id")

        # Only debugging runs are listed for snippets.
        triggered_from_values = [
            WorkflowRunTriggeredFrom.DEBUGGING,
        ]

        return self._workflow_run_repo.get_paginated_workflow_runs(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            triggered_from=triggered_from_values,
            limit=limit,
            last_id=last_id,
        )

    def get_snippet_workflow_run(
        self,
        *,
        snippet: CustomizedSnippet,
        run_id: str,
    ) -> WorkflowRun | None:
        """
        Get workflow run details.

        :param snippet: CustomizedSnippet instance
        :param run_id: Workflow run ID
        :return: WorkflowRun or None
        """
        return self._workflow_run_repo.get_workflow_run_by_id(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            run_id=run_id,
        )

    def get_snippet_workflow_run_node_executions(
        self,
        *,
        snippet: CustomizedSnippet,
        run_id: str,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list.

        :param snippet: CustomizedSnippet instance
        :param run_id: Workflow run ID
        :return: List of WorkflowNodeExecutionModel (empty if the run is not found)
        """
        workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
        if not workflow_run:
            return []

        node_executions = self._node_execution_service_repo.get_executions_by_workflow_run(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            workflow_run_id=workflow_run.id,
        )

        return node_executions

    # --- Node Execution Operations ---

    def get_snippet_node_last_run(
        self,
        *,
        snippet: CustomizedSnippet,
        workflow: Workflow,
        node_id: str,
    ) -> WorkflowNodeExecutionModel | None:
        """
        Get the most recent execution for a specific node in a snippet workflow.

        :param snippet: CustomizedSnippet instance
        :param workflow: Workflow instance
        :param node_id: Node identifier
        :return: WorkflowNodeExecutionModel or None
        """
        return self._node_execution_service_repo.get_node_last_execution(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            workflow_id=workflow.id,
            node_id=node_id,
        )

    # --- Use Count ---

    @staticmethod
    def increment_use_count(
        *,
        session: Session,
        snippet: CustomizedSnippet,
    ) -> None:
        """
        Increment the use_count when snippet is used. The caller is responsible for committing.

        NOTE(review): this is an in-Python read-modify-write; concurrent callers
        may lose increments.

        :param session: Database session
        :param snippet: CustomizedSnippet instance
        """
        snippet.use_count += 1
        session.add(snippet)