dify/api/services/workflow_plugin_trigger_ser...
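"""Service layer for workflow plugin triggers.

Illustrative usage sketch (placeholder values only; these IDs are not taken from
this module, and real callers supply their own identifiers):

    trigger = WorkflowPluginTriggerService.create_plugin_trigger(
        app_id="<app-id>",
        tenant_id="<tenant-id>",
        node_id="<node-id>",
        provider_id="<provider-id>",
        trigger_name="<trigger-name>",
        subscription_id="<subscription-id>",
    )
"""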


import logging

from typing import Optional
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import App
from models.trigger import TriggerSubscription
from models.workflow import Workflow, WorkflowPluginTrigger

logger = logging.getLogger(__name__)


class WorkflowPluginTriggerService:
"""Service for managing workflow plugin triggers"""
__PLUGIN_TRIGGER_NODE_CACHE_KEY__ = "plugin_trigger_nodes"
MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW = 5 # Maximum allowed plugin trigger nodes per workflow
@classmethod
def create_plugin_trigger(
cls,
app_id: str,
tenant_id: str,
node_id: str,
provider_id: str,
trigger_name: str,
subscription_id: str,
) -> WorkflowPluginTrigger:
"""Create a new plugin trigger
Args:
app_id: The app ID
tenant_id: The tenant ID
node_id: The node ID in the workflow
provider_id: The plugin provider ID
trigger_name: The trigger name
subscription_id: The subscription ID
Returns:
The created WorkflowPluginTrigger instance
        Raises:
            ValueError: If a plugin trigger already exists for this app and node
            NotFound: If the subscription does not exist
        """
with Session(db.engine) as session:
# Check if plugin trigger already exists for this app and node
# Based on unique constraint: uniq_app_node
existing_trigger = session.scalar(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.app_id == app_id,
WorkflowPluginTrigger.node_id == node_id,
)
)
if existing_trigger:
raise ValueError("Plugin trigger already exists for this app and node")
# Check if subscription exists
subscription = session.scalar(
select(TriggerSubscription).where(
TriggerSubscription.id == subscription_id,
)
)
if not subscription:
raise NotFound("Subscription not found")
# Create new plugin trigger
plugin_trigger = WorkflowPluginTrigger(
app_id=app_id,
node_id=node_id,
tenant_id=tenant_id,
provider_id=provider_id,
trigger_name=trigger_name,
subscription_id=subscription_id,
)
session.add(plugin_trigger)
session.commit()
session.refresh(plugin_trigger)
return plugin_trigger
@classmethod
def get_plugin_trigger(
cls,
app_id: str,
node_id: str,
) -> WorkflowPluginTrigger:
"""Get a plugin trigger by app_id and node_id
Args:
app_id: The app ID
node_id: The node ID in the workflow
Returns:
The WorkflowPluginTrigger instance
Raises:
NotFound: If plugin trigger not found
"""
with Session(db.engine) as session:
# Find plugin trigger using unique constraint
plugin_trigger = session.scalar(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.app_id == app_id,
WorkflowPluginTrigger.node_id == node_id,
)
)
if not plugin_trigger:
raise NotFound("Plugin trigger not found")
return plugin_trigger
@classmethod
def get_plugin_trigger_by_subscription(
cls,
tenant_id: str,
subscription_id: str,
) -> WorkflowPluginTrigger:
"""Get a plugin trigger by tenant_id and subscription_id
        This is the primary query pattern, optimized with a composite index
Args:
tenant_id: The tenant ID
subscription_id: The subscription ID
Returns:
The WorkflowPluginTrigger instance
Raises:
NotFound: If plugin trigger not found
"""
with Session(db.engine) as session:
# Find plugin trigger using indexed columns
plugin_trigger = session.scalar(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.tenant_id == tenant_id,
WorkflowPluginTrigger.subscription_id == subscription_id,
)
)
if not plugin_trigger:
raise NotFound("Plugin trigger not found")
return plugin_trigger
@classmethod
def list_plugin_triggers_by_tenant(
cls,
tenant_id: str,
) -> list[WorkflowPluginTrigger]:
"""List all plugin triggers for a tenant
Args:
tenant_id: The tenant ID
Returns:
List of WorkflowPluginTrigger instances
"""
with Session(db.engine) as session:
plugin_triggers = session.scalars(
select(WorkflowPluginTrigger)
.where(WorkflowPluginTrigger.tenant_id == tenant_id)
.order_by(WorkflowPluginTrigger.created_at.desc())
).all()
return list(plugin_triggers)
@classmethod
def list_plugin_triggers_by_subscription(
cls,
subscription_id: str,
) -> list[WorkflowPluginTrigger]:
"""List all plugin triggers for a subscription
Args:
subscription_id: The subscription ID
Returns:
List of WorkflowPluginTrigger instances
"""
with Session(db.engine) as session:
plugin_triggers = session.scalars(
select(WorkflowPluginTrigger)
.where(WorkflowPluginTrigger.subscription_id == subscription_id)
.order_by(WorkflowPluginTrigger.created_at.desc())
).all()
return list(plugin_triggers)
@classmethod
def update_plugin_trigger(
cls,
app_id: str,
node_id: str,
subscription_id: str,
) -> WorkflowPluginTrigger:
"""Update a plugin trigger
Args:
app_id: The app ID
node_id: The node ID in the workflow
            subscription_id: The new subscription ID
        Returns:
            The updated WorkflowPluginTrigger instance
        Raises:
            NotFound: If the plugin trigger or subscription is not found
        """
with Session(db.engine) as session:
# Find plugin trigger using unique constraint
plugin_trigger = session.scalar(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.app_id == app_id,
WorkflowPluginTrigger.node_id == node_id,
)
)
if not plugin_trigger:
raise NotFound("Plugin trigger not found")
# Check if subscription exists
subscription = session.scalar(
select(TriggerSubscription).where(
TriggerSubscription.id == subscription_id,
)
)
if not subscription:
raise NotFound("Subscription not found")
# Update subscription ID
plugin_trigger.subscription_id = subscription_id
session.commit()
session.refresh(plugin_trigger)
return plugin_trigger
@classmethod
def update_plugin_trigger_by_subscription(
cls,
tenant_id: str,
subscription_id: str,
provider_id: Optional[str] = None,
trigger_name: Optional[str] = None,
new_subscription_id: Optional[str] = None,
) -> WorkflowPluginTrigger:
"""Update a plugin trigger by tenant_id and subscription_id
Args:
tenant_id: The tenant ID
subscription_id: The current subscription ID
provider_id: The new provider ID (optional)
trigger_name: The new trigger name (optional)
new_subscription_id: The new subscription ID (optional)
Returns:
The updated WorkflowPluginTrigger instance
Raises:
NotFound: If plugin trigger not found
"""
with Session(db.engine) as session:
# Find plugin trigger using indexed columns
plugin_trigger = session.scalar(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.tenant_id == tenant_id,
WorkflowPluginTrigger.subscription_id == subscription_id,
)
)
if not plugin_trigger:
raise NotFound("Plugin trigger not found")
# Update fields if provided
if provider_id:
plugin_trigger.provider_id = provider_id
            if trigger_name:
                # Store the trigger name qualified by the provider ID (updated or existing)
                provider_id = provider_id or plugin_trigger.provider_id
                plugin_trigger.trigger_name = f"{provider_id}:{trigger_name}"
if new_subscription_id:
plugin_trigger.subscription_id = new_subscription_id
session.commit()
session.refresh(plugin_trigger)
return plugin_trigger
@classmethod
def delete_plugin_trigger(
cls,
app_id: str,
node_id: str,
) -> None:
"""Delete a plugin trigger by app_id and node_id
Args:
app_id: The app ID
node_id: The node ID in the workflow
Raises:
NotFound: If plugin trigger not found
"""
with Session(db.engine) as session:
# Find plugin trigger using unique constraint
plugin_trigger = session.scalar(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.app_id == app_id,
WorkflowPluginTrigger.node_id == node_id,
)
)
if not plugin_trigger:
raise NotFound("Plugin trigger not found")
session.delete(plugin_trigger)
session.commit()
@classmethod
def delete_plugin_trigger_by_subscription(
cls,
session: Session,
tenant_id: str,
subscription_id: str,
) -> None:
"""Delete a plugin trigger by tenant_id and subscription_id within an existing session
Args:
session: Database session
tenant_id: The tenant ID
subscription_id: The subscription ID
        Note:
            Does nothing if the plugin trigger does not exist
        """
# Find plugin trigger using indexed columns
plugin_trigger = session.scalar(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.tenant_id == tenant_id,
WorkflowPluginTrigger.subscription_id == subscription_id,
)
)
if not plugin_trigger:
return
session.delete(plugin_trigger)
@classmethod
def delete_all_by_subscription(
cls,
subscription_id: str,
) -> int:
"""Delete all plugin triggers for a subscription
Useful when a subscription is cancelled
Args:
subscription_id: The subscription ID
Returns:
Number of triggers deleted
"""
with Session(db.engine) as session:
# Find all plugin triggers for this subscription
plugin_triggers = session.scalars(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.subscription_id == subscription_id,
)
).all()
count = len(plugin_triggers)
for trigger in plugin_triggers:
session.delete(trigger)
session.commit()
return count
@classmethod
    def sync_plugin_trigger_relationships(cls, app: App, workflow: Workflow) -> None:
        """
        Sync plugin trigger relationships to the database.
        1. Collect the plugin trigger nodes present in the workflow graph
        2. Fetch the app's existing plugin trigger records from the database
        3. Diff the graph nodes against the records and create/update/delete records as needed
        Approach:
            Frequent database reads would be costly, so recently synced nodes are cached in Redis
            and only cache misses are checked against the database for creation.
        Limits:
            - At most MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW (5) plugin trigger nodes per workflow
        """
class Cache(BaseModel):
"""
Cache model for plugin trigger nodes
"""
record_id: str
node_id: str
provider_id: str
trigger_name: str
subscription_id: str
# Walk nodes to find plugin triggers
nodes_in_graph = []
for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN):
# Extract plugin trigger configuration from node
plugin_id = node_config.get("plugin_id", "")
provider_id = node_config.get("provider_id", "")
trigger_name = node_config.get("trigger_name", "")
subscription_id = node_config.get("subscription_id", "")
if not subscription_id:
continue
nodes_in_graph.append(
{
"node_id": node_id,
"plugin_id": plugin_id,
"provider_id": provider_id,
"trigger_name": trigger_name,
"subscription_id": subscription_id,
}
)
# Check plugin trigger node limit
if len(nodes_in_graph) > cls.MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW:
raise ValueError(
f"Workflow exceeds maximum plugin trigger node limit. "
f"Found {len(nodes_in_graph)} plugin trigger nodes, "
f"maximum allowed is {cls.MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW}"
)
        # Collect nodes that are not cached; a cache hit means the node was synced
        # recently, so only cache misses are checked against the database for creation.
        not_found_in_cache: list[dict] = []
        for node_info in nodes_in_graph:
            node_id = node_info["node_id"]
            if not redis_client.get(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}"):
                not_found_in_cache.append(node_info)
with Session(db.engine) as session:
            try:
                # acquire a per-app lock to serialize concurrent plugin trigger syncs
                lock = redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
                lock.acquire(blocking=True)
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.app_id == app.id,
WorkflowPluginTrigger.tenant_id == app.tenant_id,
)
).all()
nodes_id_in_db = {node.node_id: node for node in all_records}
nodes_id_in_graph = {node["node_id"] for node in nodes_in_graph}
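                # Reconcile graph vs. DB: create records missing from both cache and DB,
                # update records whose configuration changed, and delete records no longer in the graph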
                # nodes found in neither the cache nor the database need new records
nodes_not_found = [
node_info for node_info in not_found_in_cache if node_info["node_id"] not in nodes_id_in_db
]
# create new plugin trigger records
for node_info in nodes_not_found:
plugin_trigger = WorkflowPluginTrigger(
app_id=app.id,
tenant_id=app.tenant_id,
node_id=node_info["node_id"],
provider_id=node_info["provider_id"],
trigger_name=node_info["trigger_name"],
subscription_id=node_info["subscription_id"],
)
session.add(plugin_trigger)
session.flush() # Get the ID for caching
cache = Cache(
record_id=plugin_trigger.id,
node_id=node_info["node_id"],
provider_id=node_info["provider_id"],
trigger_name=node_info["trigger_name"],
subscription_id=node_info["subscription_id"],
)
redis_client.set(
f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_info['node_id']}",
cache.model_dump_json(),
ex=60 * 60,
)
session.commit()
                # Update existing records whose subscription, provider, or trigger name changed
for node_info in nodes_in_graph:
node_id = node_info["node_id"]
if node_id in nodes_id_in_db:
existing_record = nodes_id_in_db[node_id]
if (
existing_record.subscription_id != node_info["subscription_id"]
or existing_record.provider_id != node_info["provider_id"]
or existing_record.trigger_name != node_info["trigger_name"]
):
existing_record.subscription_id = node_info["subscription_id"]
existing_record.provider_id = node_info["provider_id"]
existing_record.trigger_name = node_info["trigger_name"]
session.add(existing_record)
# Update cache
cache = Cache(
record_id=existing_record.id,
node_id=node_id,
provider_id=node_info["provider_id"],
trigger_name=node_info["trigger_name"],
subscription_id=node_info["subscription_id"],
)
redis_client.set(
f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}",
cache.model_dump_json(),
ex=60 * 60,
)
session.commit()
# delete the nodes not found in the graph
for node_id in nodes_id_in_db:
if node_id not in nodes_id_in_graph:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}")
session.commit()
            except Exception:
                logger.exception("Failed to sync plugin trigger relationships for app %s", app.id)
raise
finally:
redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock")