dify/api/services/snippet_service.py
2026-04-28 19:10:47 +08:00

685 lines
22 KiB
Python

import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, NodeType
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import WorkflowRunTriggeredFrom
from models.snippet import CustomizedSnippet, SnippetType
from models.workflow import (
Workflow,
WorkflowKind,
WorkflowNodeExecutionModel,
WorkflowRun,
WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.workflow_restore import apply_published_workflow_snapshot_to_draft
logger = logging.getLogger(__name__)

# Node types that must never appear in a snippet workflow graph. In this module
# the set is checked on draft sync, restore and publish; the original note also
# lists DSL import as a caller.
SNIPPET_FORBIDDEN_NODE_TYPES: frozenset[str] = frozenset(
    (
        BuiltinNodeTypes.START,
        BuiltinNodeTypes.HUMAN_INPUT,
        BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
    )
)
class SnippetService:
"""Service for managing customized snippets."""
def __init__(self, session_maker: sessionmaker | None = None):
    """
    Wire up the repositories used for workflow-run and node-execution lookups.

    :param session_maker: Session factory handed to both repositories; when
        omitted, one bound to ``db.engine`` with ``expire_on_commit=False`` is built
    """
    maker = (
        session_maker
        if session_maker is not None
        else sessionmaker(bind=db.engine, expire_on_commit=False)
    )
    factory = DifyAPIRepositoryFactory
    self._node_execution_service_repo = factory.create_api_workflow_node_execution_repository(maker)
    self._workflow_run_repo = factory.create_api_workflow_run_repository(maker)
@staticmethod
def _snippet_kind_filter():
    """
    Build the SQL criterion that restricts a ``Workflow`` query to snippets.

    :return: Comparison of ``Workflow.kind`` against ``WorkflowKind.SNIPPET.value``
    """
    snippet_kind = WorkflowKind.SNIPPET.value
    return Workflow.kind == snippet_kind
@staticmethod
def validate_snippet_graph_forbidden_nodes(graph: Mapping[str, Any]) -> None:
    """
    Reject graphs that contain node types not allowed in snippets.

    Nodes that are not mappings, that have no mapping under ``data``, or whose
    ``data["type"]`` is not a string are ignored rather than reported.

    :param graph: Workflow graph; only its ``nodes`` entry is inspected
    :raises ValueError: If at least one node type is in ``SNIPPET_FORBIDDEN_NODE_TYPES``;
        the message lists every offending ``node_id:node_type`` pair
    """
    disallowed: list[tuple[str, str]] = []
    for node in graph.get("nodes") or []:
        # ``graph`` is typed as a Mapping, so accept any mapping-shaped node; a
        # plain ``dict`` check would let other Mapping implementations skip validation.
        if not isinstance(node, Mapping):
            continue
        node_data = node.get("data")
        # Missing or malformed ``data`` (e.g. a list or string) cannot carry a type.
        if not isinstance(node_data, Mapping):
            continue
        node_type = node_data.get("type")
        if not isinstance(node_type, str) or node_type not in SNIPPET_FORBIDDEN_NODE_TYPES:
            continue
        node_id = node.get("id")
        # "?" stands in for nodes that have no id so they still show up in the report.
        disallowed.append(("?" if node_id is None else str(node_id), node_type))

    if not disallowed:
        return

    detail = ", ".join(f"{nid}:{t}" for nid, t in disallowed)
    raise ValueError(
        "Snippet workflow cannot contain start, human-input, or knowledge-retrieval nodes. "
        f"Found: {detail}"
    )
# --- CRUD Operations ---
@staticmethod
def get_snippets(
    *,
    tenant_id: str,
    page: int = 1,
    limit: int = 20,
    keyword: str | None = None,
    is_published: bool | None = None,
    creators: list[str] | None = None,
) -> tuple[Sequence[CustomizedSnippet], int, bool]:
    """
    Get a paginated list of snippets with optional search and filters.

    Results are ordered by ``created_at`` descending.

    :param tenant_id: Tenant ID
    :param page: Page number (1-indexed)
    :param limit: Number of items per page
    :param keyword: Optional search keyword, matched case-insensitively and
        literally (LIKE wildcards are escaped) against name and description
    :param is_published: Optional filter by published status (True/False/None for all)
    :param creators: Optional filter by creator account IDs
    :return: Tuple of (snippets list, total count, has_more flag)
    """
    filtered = select(CustomizedSnippet).where(CustomizedSnippet.tenant_id == tenant_id)

    if keyword:
        # Escape LIKE metacharacters so input such as "100%" or "a_b" is matched
        # literally instead of acting as a wildcard pattern. The backslash is
        # escaped first because it is the escape character itself.
        escaped = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        pattern = f"%{escaped}%"
        filtered = filtered.where(
            CustomizedSnippet.name.ilike(pattern, escape="\\")
            | CustomizedSnippet.description.ilike(pattern, escape="\\")
        )
    if is_published is not None:
        filtered = filtered.where(CustomizedSnippet.is_published == is_published)
    if creators:
        filtered = filtered.where(CustomizedSnippet.created_by.in_(creators))

    # Count on the unordered statement: ordering is irrelevant to COUNT(*).
    count_stmt = select(func.count()).select_from(filtered.subquery())
    total = db.session.scalar(count_stmt) or 0

    # Fetch one extra row so has_more can be derived without a second query.
    page_stmt = (
        filtered.order_by(CustomizedSnippet.created_at.desc())
        .limit(limit + 1)
        .offset((page - 1) * limit)
    )
    rows = list(db.session.scalars(page_stmt).all())
    has_more = len(rows) > limit
    return rows[:limit], total, has_more
@staticmethod
def get_snippet_by_id(
    *,
    snippet_id: str,
    tenant_id: str,
) -> CustomizedSnippet | None:
    """
    Look up a single snippet, scoped to its owning tenant.

    :param snippet_id: Snippet ID
    :param tenant_id: Tenant ID the snippet must belong to
    :return: The matching CustomizedSnippet, or None when no row matches both IDs
    """
    criteria = (
        CustomizedSnippet.id == snippet_id,
        CustomizedSnippet.tenant_id == tenant_id,
    )
    return db.session.query(CustomizedSnippet).where(*criteria).first()
@staticmethod
def create_snippet(
    *,
    tenant_id: str,
    name: str,
    description: str | None,
    snippet_type: SnippetType,
    icon_info: dict | None,
    input_fields: list[dict] | None,
    account: Account,
) -> CustomizedSnippet:
    """
    Persist a new snippet and commit it on ``db.session``.

    :param tenant_id: Tenant ID
    :param name: Snippet name; rejected when the tenant already has a snippet with this name
    :param description: Snippet description; a falsy value is stored as ""
    :param snippet_type: Type of snippet; its ``.value`` is stored
    :param icon_info: Icon information
    :param input_fields: Input field definitions, stored as JSON; a falsy value is stored as None
    :param account: Creator account; its ``id`` is recorded as ``created_by``
    :return: Created CustomizedSnippet
    :raises ValueError: If name already exists
    """
    # Name uniqueness is enforced per tenant before anything is written.
    same_name = (
        CustomizedSnippet.tenant_id == tenant_id,
        CustomizedSnippet.name == name,
    )
    if db.session.query(CustomizedSnippet).where(*same_name).first():
        raise ValueError(f"Snippet with name '{name}' already exists")

    attributes = {
        "tenant_id": tenant_id,
        "name": name,
        "description": description or "",
        "type": snippet_type.value,
        "icon_info": icon_info,
        "input_fields": json.dumps(input_fields) if input_fields else None,
        "created_by": account.id,
    }
    created = CustomizedSnippet(**attributes)
    db.session.add(created)
    db.session.commit()
    return created
@staticmethod
def update_snippet(
*,
session: Session,
snippet: CustomizedSnippet,
account_id: str,
data: dict,
) -> CustomizedSnippet:
"""
Update snippet attributes.
:param session: Database session
:param snippet: Snippet to update
:param account_id: ID of account making the update
:param data: Dictionary of fields to update
:return: Updated CustomizedSnippet
"""
if "name" in data:
# Check if new name already exists for this tenant
existing = (
session.query(CustomizedSnippet)
.where(
CustomizedSnippet.tenant_id == snippet.tenant_id,
CustomizedSnippet.name == data["name"],
CustomizedSnippet.id != snippet.id,
)
.first()
)
if existing:
raise ValueError(f"Snippet with name '{data['name']}' already exists")
snippet.name = data["name"]
if "description" in data:
snippet.description = data["description"]
if "icon_info" in data:
snippet.icon_info = data["icon_info"]
snippet.updated_by = account_id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
session.add(snippet)
return snippet
@staticmethod
def delete_snippet(
*,
session: Session,
snippet: CustomizedSnippet,
) -> bool:
"""
Delete a snippet.
:param session: Database session
:param snippet: Snippet to delete
:return: True if deleted successfully
"""
session.delete(snippet)
return True
# --- Workflow Operations ---
def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
    """
    Fetch the snippet's draft workflow via ``db.session``.

    :param snippet: CustomizedSnippet instance; its ``tenant_id`` and ``id`` scope the query
    :return: The first snippet-kind Workflow whose version is "draft", or None
    """
    draft_criteria = (
        Workflow.tenant_id == snippet.tenant_id,
        Workflow.app_id == snippet.id,
        self._snippet_kind_filter(),
        Workflow.version == "draft",
    )
    return db.session.query(Workflow).where(*draft_criteria).first()
def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
    """
    Fetch the workflow the snippet currently points at as published.

    :param snippet: CustomizedSnippet instance
    :return: The Workflow referenced by ``snippet.workflow_id``, or None when the
        snippet has no ``workflow_id`` or no matching row exists
    """
    published_id = snippet.workflow_id
    if not published_id:
        # Without a workflow_id there is nothing to look up.
        return None

    lookup = db.session.query(Workflow).where(
        Workflow.tenant_id == snippet.tenant_id,
        Workflow.app_id == snippet.id,
        self._snippet_kind_filter(),
        Workflow.id == published_id,
    )
    return lookup.first()
def get_published_workflow_by_id(self, snippet: CustomizedSnippet, workflow_id: str) -> Workflow | None:
    """
    Fetch one published workflow snapshot of the snippet, e.g. for history restore.

    :param snippet: CustomizedSnippet instance; its ``tenant_id`` and ``id`` scope the query
    :param workflow_id: Workflow ID
    :return: Published Workflow, or None when no such workflow belongs to the snippet
    :raises IsDraftWorkflowError: If the workflow ID points to a draft workflow
    """
    lookup = db.session.query(Workflow).where(
        Workflow.tenant_id == snippet.tenant_id,
        Workflow.app_id == snippet.id,
        self._snippet_kind_filter(),
        Workflow.id == workflow_id,
    )
    candidate = lookup.first()
    # A draft is not an acceptable source: only published snapshots may be returned.
    if candidate and candidate.version == Workflow.VERSION_DRAFT:
        raise IsDraftWorkflowError("source workflow must be published")
    return candidate if candidate else None
def sync_draft_workflow(
    self,
    *,
    snippet: CustomizedSnippet,
    graph: dict,
    unique_hash: str | None,
    account: Account,
    input_fields: list[dict] | None = None,
) -> Workflow:
    """
    Create or update the snippet's draft workflow and commit on ``db.session``.

    Snippet workflows do not persist environment variables or conversation
    variables; both are always written as empty lists.

    :param snippet: CustomizedSnippet instance
    :param graph: Workflow graph configuration
    :param unique_hash: Hash the caller last saw; compared with the existing
        draft's ``unique_hash`` for conflict detection
    :param account: Account making the change
    :param input_fields: Input fields for snippet; ``None`` leaves the snippet untouched
    :return: Synced Workflow
    :raises ValueError: If the graph contains node types forbidden in snippets
    :raises WorkflowHashNotEqualError: If a draft exists and its hash differs from ``unique_hash``
    """
    SnippetService.validate_snippet_graph_forbidden_nodes(graph)

    workflow = self.get_draft_workflow(snippet=snippet)
    if workflow and workflow.unique_hash != unique_hash:
        raise WorkflowHashNotEqualError()

    # One timestamp for the whole sync so the workflow and snippet rows agree.
    now = datetime.now(UTC).replace(tzinfo=None)
    graph_json = json.dumps(graph)
    # Always store the enum's value, matching the create path and the kind filter.
    snippet_kind = WorkflowKind.SNIPPET.value

    if not workflow:
        # Create draft workflow if not found
        workflow = Workflow(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            features="{}",
            type=WorkflowType.WORKFLOW.value,
            kind=snippet_kind,
            version="draft",
            graph=graph_json,
            created_by=account.id,
            environment_variables=[],
            conversation_variables=[],
        )
        db.session.add(workflow)
        db.session.flush()
    else:
        # Update existing draft workflow
        workflow.graph = graph_json
        workflow.type = WorkflowType.WORKFLOW.value
        workflow.kind = snippet_kind
        workflow.updated_by = account.id
        workflow.updated_at = now
        workflow.environment_variables = []
        workflow.conversation_variables = []

    # Update snippet's input_fields if provided (an empty list is still written).
    if input_fields is not None:
        snippet.input_fields = json.dumps(input_fields)
        snippet.updated_by = account.id
        snippet.updated_at = now

    db.session.commit()
    return workflow
def restore_published_workflow_to_draft(
    self,
    *,
    snippet: CustomizedSnippet,
    workflow_id: str,
    account: Account,
) -> Workflow:
    """
    Copy a published snippet workflow snapshot back into the draft and commit.

    The copy itself is delegated to ``apply_published_workflow_snapshot_to_draft``;
    this method validates the source, adds a newly created draft to ``db.session``
    and commits.

    :param snippet: CustomizedSnippet instance
    :param workflow_id: Published workflow ID
    :param account: Account making the change
    :return: Restored draft Workflow
    :raises WorkflowNotFoundError: If the source workflow does not exist
    :raises IsDraftWorkflowError: If the source workflow is a draft
    :raises ValueError: If the restored graph is invalid for snippets
    """
    source = self.get_published_workflow_by_id(snippet=snippet, workflow_id=workflow_id)
    if not source:
        raise WorkflowNotFoundError("Workflow not found.")

    # The snapshot must still satisfy the snippet node restrictions before it is restored.
    SnippetService.validate_snippet_graph_forbidden_nodes(source.graph_dict)

    def _naive_utc_now() -> datetime:
        """Return the current UTC time as a naive datetime."""
        return datetime.now(UTC).replace(tzinfo=None)

    existing_draft = self.get_draft_workflow(snippet=snippet)
    restored, created = apply_published_workflow_snapshot_to_draft(
        tenant_id=snippet.tenant_id,
        app_id=snippet.id,
        source_workflow=source,
        draft_workflow=existing_draft,
        account=account,
        updated_at_factory=_naive_utc_now,
    )
    if created:
        db.session.add(restored)
    db.session.commit()
    return restored
def publish_workflow(
    self,
    *,
    session: Session,
    snippet: CustomizedSnippet,
    account: Account,
) -> Workflow:
    """
    Publish the draft workflow as a new version.

    The new workflow and the updated snippet are added to ``session``; this
    method does not commit or flush.

    :param session: Database session
    :param snippet: CustomizedSnippet instance
    :param account: Account making the change
    :return: Published Workflow
    :raises ValueError: If no draft workflow exists, or the draft graph contains
        node types forbidden in snippets
    """
    draft_workflow_stmt = select(Workflow).where(
        Workflow.tenant_id == snippet.tenant_id,
        Workflow.app_id == snippet.id,
        self._snippet_kind_filter(),
        Workflow.version == "draft",
    )
    draft_workflow = session.scalar(draft_workflow_stmt)
    if not draft_workflow:
        raise ValueError("No valid workflow found.")

    SnippetService.validate_snippet_graph_forbidden_nodes(draft_workflow.graph_dict)

    # One naive-UTC timestamp serves as the version label and the snippet's updated_at.
    published_at = datetime.now(UTC).replace(tzinfo=None)

    # Create new published workflow; environment and conversation variables are
    # never carried over for snippets.
    workflow = Workflow.new(
        tenant_id=snippet.tenant_id,
        app_id=snippet.id,
        type=WorkflowType.WORKFLOW.value,
        version=str(published_at),
        graph=draft_workflow.graph,
        features=draft_workflow.features,
        created_by=account.id,
        environment_variables=[],
        conversation_variables=[],
        rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
        kind=WorkflowKind.SNIPPET.value,
        marked_name="",
        marked_comment="",
    )
    session.add(workflow)

    # Update snippet version
    snippet.version += 1
    snippet.is_published = True
    # NOTE(review): relies on Workflow.new() populating ``id`` before flush — confirm.
    snippet.workflow_id = workflow.id
    snippet.updated_by = account.id
    # Keep updated_at in step with updated_by, as the other snippet mutators do.
    snippet.updated_at = published_at
    session.add(snippet)
    return workflow
def get_all_published_workflows(
    self,
    *,
    session: Session,
    snippet: CustomizedSnippet,
    page: int,
    limit: int,
) -> tuple[Sequence[Workflow], bool]:
    """
    Get one page of the snippet's published workflow versions.

    Results are ordered by ``Workflow.version`` descending.

    :param session: Database session
    :param snippet: CustomizedSnippet instance
    :param page: Page number (1-indexed)
    :param limit: Items per page
    :return: Tuple of (workflows list, has_more flag); ``([], False)`` when the
        snippet has no ``workflow_id``
    """
    if not snippet.workflow_id:
        return [], False

    stmt = (
        select(Workflow)
        .where(
            # Tenant scoping, consistent with every other workflow lookup in this service.
            Workflow.tenant_id == snippet.tenant_id,
            Workflow.app_id == snippet.id,
            self._snippet_kind_filter(),
            Workflow.version != "draft",
        )
        .order_by(Workflow.version.desc())
        # Fetch one extra row so has_more can be derived without a count query.
        .limit(limit + 1)
        .offset((page - 1) * limit)
    )
    workflows = list(session.scalars(stmt).all())
    has_more = len(workflows) > limit
    return workflows[:limit], has_more
# --- Default Block Configs ---
def get_default_block_configs(self) -> list[dict]:
    """
    Collect the default configuration of every registered node type.

    For each entry of ``NODE_TYPE_CLASSES_MAPPING`` the ``LATEST_VERSION`` class
    is asked for its default config; falsy configs are skipped and the rest are
    copied into plain dicts.

    :return: List of default configurations
    """
    return [
        dict(config)
        for versions in NODE_TYPE_CLASSES_MAPPING.values()
        if (config := versions[LATEST_VERSION].get_default_config())
    ]
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
    """
    Get the default config of the latest class registered for one node type.

    :param node_type: Node type string, converted with ``NodeType(node_type)``
    :param filters: Optional filters forwarded to the node class's ``get_default_config``
    :return: Default configuration, or None when the type is not registered or
        its default config is falsy
    """
    # NOTE(review): how NodeType(...) treats unknown strings depends on its definition — confirm.
    type_key = NodeType(node_type)
    if type_key in NODE_TYPE_CLASSES_MAPPING:
        latest_class = NODE_TYPE_CLASSES_MAPPING[type_key][LATEST_VERSION]
        return latest_class.get_default_config(filters=filters) or None
    return None
# --- Workflow Run Operations ---
def get_snippet_workflow_runs(
    self,
    *,
    snippet: CustomizedSnippet,
    args: dict,
) -> InfiniteScrollPagination:
    """
    Page through the snippet's workflow runs that were triggered from debugging.

    :param snippet: CustomizedSnippet instance
    :param args: Request arguments; ``limit`` (defaults to 20, coerced with ``int``)
        and ``last_id`` are read
    :return: InfiniteScrollPagination result from the workflow-run repository
    """
    page_size = int(args.get("limit", 20))
    cursor = args.get("last_id")
    # Only debugging-triggered runs are requested for snippets.
    sources = [WorkflowRunTriggeredFrom.DEBUGGING]
    repo = self._workflow_run_repo
    return repo.get_paginated_workflow_runs(
        tenant_id=snippet.tenant_id,
        app_id=snippet.id,
        triggered_from=sources,
        limit=page_size,
        last_id=cursor,
    )
def get_snippet_workflow_run(
    self,
    *,
    snippet: CustomizedSnippet,
    run_id: str,
) -> WorkflowRun | None:
    """
    Fetch one workflow run of the snippet from the workflow-run repository.

    :param snippet: CustomizedSnippet instance; its ``tenant_id`` and ``id`` scope the lookup
    :param run_id: Workflow run ID
    :return: WorkflowRun or None, as returned by the repository
    """
    lookup = {
        "tenant_id": snippet.tenant_id,
        "app_id": snippet.id,
        "run_id": run_id,
    }
    return self._workflow_run_repo.get_workflow_run_by_id(**lookup)
def get_snippet_workflow_run_node_executions(
    self,
    *,
    snippet: CustomizedSnippet,
    run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
    """
    List the node executions recorded for one workflow run of the snippet.

    The run is first resolved through ``get_snippet_workflow_run``; executions
    are only fetched when that run exists.

    :param snippet: CustomizedSnippet instance
    :param run_id: Workflow run ID
    :return: Node executions of the run, or an empty list when the run is not found
    """
    run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
    if run:
        return self._node_execution_service_repo.get_executions_by_workflow_run(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            workflow_run_id=run.id,
        )
    return []
# --- Node Execution Operations ---
def get_snippet_node_last_run(
    self,
    *,
    snippet: CustomizedSnippet,
    workflow: Workflow,
    node_id: str,
) -> WorkflowNodeExecutionModel | None:
    """
    Ask the node-execution repository for a node's last execution.

    :param snippet: CustomizedSnippet instance; its ``tenant_id`` and ``id`` scope the lookup
    :param workflow: Workflow instance; only its ``id`` is used
    :param node_id: Node identifier
    :return: WorkflowNodeExecutionModel or None, as returned by the repository
    """
    lookup = {
        "tenant_id": snippet.tenant_id,
        "app_id": snippet.id,
        "workflow_id": workflow.id,
        "node_id": node_id,
    }
    return self._node_execution_service_repo.get_node_last_execution(**lookup)
# --- Use Count ---
@staticmethod
def increment_use_count(
*,
session: Session,
snippet: CustomizedSnippet,
) -> None:
"""
Increment the use_count when snippet is used.
:param session: Database session
:param snippet: CustomizedSnippet instance
"""
snippet.use_count += 1
session.add(snippet)