feat(trigger): add suspend/timeslice layers and workflow CFS scheduler

- add suspend, timeslice, and trigger post engine layers
- introduce CFS workflow scheduler tasks and supporting entities
- update async workflow, trigger, and webhook services to wire in the new scheduling flow
This commit is contained in:
Yeuoly 2025-10-21 19:20:54 +08:00
parent 55bf9196dc
commit 3d5e2c5ca1
17 changed files with 2698 additions and 2407 deletions

View File

@ -185,6 +185,22 @@ class TriggerConfig(BaseSettings):
) )
class AsyncWorkflowConfig(BaseSettings):
    """
    Configuration for async workflow
    """

    # Interval (seconds) between scheduler checks of the async workflow queue.
    ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
        description="Granularity for async workflow scheduler. "
        "Sometimes a few users could block the queue due to some time-consuming tasks; "
        "to avoid this, a workflow can be suspended if needed. To achieve "
        "this, a time-based checker is required: every granularity seconds, "
        "the checker will check the workflow queue and suspend the workflow",
        default=1,
        ge=1,
    )
class PluginConfig(BaseSettings): class PluginConfig(BaseSettings):
""" """
Plugin configs Plugin configs
@ -1165,6 +1181,7 @@ class FeatureConfig(
BillingConfig, BillingConfig,
CodeExecutionSandboxConfig, CodeExecutionSandboxConfig,
TriggerConfig, TriggerConfig,
AsyncWorkflowConfig,
PluginConfig, PluginConfig,
MarketplaceConfig, MarketplaceConfig,
DataSetConfig, DataSetConfig,

View File

@ -27,6 +27,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -55,7 +56,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
call_depth: int, call_depth: int,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None, triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None, root_node_id: Optional[str] = None,
) -> Generator[Mapping | str, None, None]: ... layers: Optional[Sequence[GraphEngineLayer]] = None,
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload @overload
def generate( def generate(
@ -70,6 +72,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
call_depth: int, call_depth: int,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None, triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None, root_node_id: Optional[str] = None,
layers: Optional[Sequence[GraphEngineLayer]] = None,
) -> Mapping[str, Any]: ... ) -> Mapping[str, Any]: ...
@overload @overload
@ -85,7 +88,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
call_depth: int, call_depth: int,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None, triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None, root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... layers: Optional[Sequence[GraphEngineLayer]] = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate( def generate(
self, self,
@ -99,7 +103,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
call_depth: int = 0, call_depth: int = 0,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None, triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None, root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: layers: Optional[Sequence[GraphEngineLayer]] = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or [] files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files # parse files
@ -197,8 +202,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
root_node_id=root_node_id, root_node_id=root_node_id,
layers=layers,
) )
def resume(self, *, workflow_run_id: str) -> None:
"""
@TBD
"""
pass
def _generate( def _generate(
self, self,
*, *,
@ -212,6 +224,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True, streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: Optional[str] = None, root_node_id: Optional[str] = None,
layers: Optional[Sequence[GraphEngineLayer]] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
""" """
Generate App response. Generate App response.
@ -250,6 +263,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
"root_node_id": root_node_id, "root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository, "workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository, "workflow_node_execution_repository": workflow_node_execution_repository,
"layers": layers,
}, },
) )
@ -444,6 +458,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: Optional[str] = None, root_node_id: Optional[str] = None,
layers: Optional[Sequence[GraphEngineLayer]] = None,
) -> None: ) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
@ -488,6 +503,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository, workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id, root_node_id=root_node_id,
layers=layers,
) )
try: try:

View File

@ -1,13 +1,16 @@
import logging import logging
import time import time
from collections.abc import Sequence
from typing import Optional, cast from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.engine_layers.suspend_layer import SuspendLayer
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.workflow.enums import WorkflowType from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -38,6 +41,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
root_node_id: Optional[str] = None, root_node_id: Optional[str] = None,
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
layers: Optional[Sequence[GraphEngineLayer]] = None,
): ):
super().__init__( super().__init__(
queue_manager=queue_manager, queue_manager=queue_manager,
@ -50,6 +54,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._root_node_id = root_node_id self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository
self._layers = layers or []
def run(self): def run(self):
""" """
@ -137,7 +142,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
trace_manager=self.application_generate_entity.trace_manager, trace_manager=self.application_generate_entity.trace_manager,
) )
suspend_layer = SuspendLayer()
workflow_entry.graph_engine.layer(persistence_layer) workflow_entry.graph_engine.layer(persistence_layer)
workflow_entry.graph_engine.layer(suspend_layer)
for layer in self._layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run() generator = workflow_entry.run()

View File

@ -0,0 +1,15 @@
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
class SuspendLayer(GraphEngineLayer):
    """Graph engine layer reserved for workflow suspension handling.

    All hooks are currently no-ops; the layer is already wired into the
    workflow app runner, and the actual suspension logic is presumably
    intended to be added here later.
    """
    def on_graph_start(self) -> None:
        # No-op: nothing to prepare when the graph starts.
        pass
    def on_event(self, event: GraphEngineEvent) -> None:
        # No-op: engine events are ignored for now.
        pass
    def on_graph_end(self, error: Exception | None) -> None:
        # No-op: no cleanup required yet.
        pass

View File

@ -0,0 +1,81 @@
import logging
import uuid
from typing import ClassVar
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
from configs import dify_config
from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
logger = logging.getLogger(__name__)
class TimesliceLayer(GraphEngineLayer):
    """
    Graph engine layer enforcing the CFS (Completely Fair Scheduler) timeslice.

    A shared background scheduler periodically asks the CFS plan scheduler
    whether the running workflow has exhausted its resource budget; when it
    has, a PAUSE command is sent through the engine's command channel.
    """

    # Single process-wide APScheduler instance shared by all layer instances.
    scheduler: ClassVar[BackgroundScheduler] = BackgroundScheduler()

    def __init__(self, cfs_plan_scheduler: CFSPlanScheduler) -> None:
        """
        Args:
            cfs_plan_scheduler: Scheduler that decides whether the workflow
                may keep running or must be suspended.
        """
        # NOTE(review): check-then-start is not thread-safe under concurrent
        # construction — confirm BackgroundScheduler tolerates a racing start.
        if not TimesliceLayer.scheduler.running:
            TimesliceLayer.scheduler.start()
        super().__init__()
        self.cfs_plan_scheduler = cfs_plan_scheduler
        # Set to True in on_graph_end so the periodic job deregisters itself.
        self.stopped = False

    def on_graph_start(self):
        """
        Start a periodic job that checks whether the workflow must be suspended.
        """
        schedule_id = uuid.uuid4().hex

        def runner():
            """
            Whenever the workflow is running, keep checking if we need to suspend it.
            Otherwise, return directly.
            """
            try:
                if self.stopped:
                    self.scheduler.remove_job(schedule_id)
                    return
                if self.cfs_plan_scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED:
                    # remove the job first so PAUSE is only sent once
                    self.scheduler.remove_job(schedule_id)
                    if not self.command_channel:
                        # Not inside an except block: logger.exception() here
                        # would emit a bogus "NoneType: None" traceback.
                        logger.error("No command channel to stop the workflow")
                        return
                    # send command to pause the workflow
                    self.command_channel.send_command(
                        GraphEngineCommand(
                            command_type=CommandType.PAUSE,
                            payload={
                                "reason": SchedulerCommand.RESOURCE_LIMIT_REACHED,
                            },
                        )
                    )
            except Exception:
                logger.exception("scheduler error during check if the workflow need to be suspended")

        self.scheduler.add_job(
            runner, "interval", seconds=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY, id=schedule_id
        )

    def on_event(self, event: GraphEngineEvent):
        # Events are not needed for timeslicing.
        pass

    def on_graph_end(self, error: Exception | None) -> None:
        # Signal the periodic job to remove itself on its next tick.
        self.stopped = True

View File

@ -0,0 +1,80 @@
import logging
from datetime import UTC, datetime
from typing import Any
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
from models.engine import db
from models.enums import WorkflowTriggerStatus
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from tasks.workflow_cfs_scheduler.cfs_scheduler import TriggerWorkflowCFSPlanEntity
logger = logging.getLogger(__name__)
class TriggerPostLayer(GraphEngineLayer):
    """
    Layer that finalizes the workflow trigger log when a triggered run ends.

    On a terminal graph event it records status, outputs, token usage and
    elapsed time on the corresponding WorkflowTriggerLog row.
    """

    def __init__(
        self,
        cfs_plan_scheduler_entity: TriggerWorkflowCFSPlanEntity,
        start_time: datetime,
        trigger_log_id: str,
    ):
        """
        Args:
            cfs_plan_scheduler_entity: CFS plan the triggered run belongs to.
            start_time: When the trigger fired; used to compute elapsed time.
            trigger_log_id: Id of the WorkflowTriggerLog row to update.
        """
        # Initialize base layer state (e.g. graph_runtime_state, read below),
        # matching the sibling TimesliceLayer which also calls super().__init__().
        super().__init__()
        self.trigger_log_id = trigger_log_id
        self.start_time = start_time
        self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity

    def on_graph_start(self):
        pass

    def on_event(self, event: GraphEngineEvent):
        """
        Update trigger log with success or failure.
        """
        if isinstance(event, GraphRunSucceededEvent | GraphRunFailedEvent):
            with Session(db.engine) as session:
                repo = SQLAlchemyWorkflowTriggerLogRepository(session)
                trigger_log = repo.get_by_id(self.trigger_log_id)
                if not trigger_log:
                    # Not inside an except block: use error(), not exception().
                    logger.error("Trigger log not found: %s", self.trigger_log_id)
                    return

                # Calculate elapsed time
                elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()

                # Extract relevant data from the engine's runtime state
                if not self.graph_runtime_state:
                    logger.error("Graph runtime state is not set")
                    return
                outputs = self.graph_runtime_state.outputs
                workflow_run_id = outputs.get("workflow_run_id")
                total_tokens = self.graph_runtime_state.total_tokens

                # Update trigger log with the terminal status
                trigger_log.status = (
                    WorkflowTriggerStatus.SUCCEEDED
                    if isinstance(event, GraphRunSucceededEvent)
                    else WorkflowTriggerStatus.FAILED
                )
                trigger_log.workflow_run_id = workflow_run_id
                trigger_log.outputs = TypeAdapter(dict[str, Any]).dump_json(outputs).decode()
                trigger_log.elapsed_time = elapsed_time
                trigger_log.total_tokens = total_tokens
                trigger_log.finished_at = datetime.now(UTC)
                repo.update(trigger_log)
                session.commit()
        elif isinstance(event, GraphRunPausedEvent):
            # FIXME: handle the paused event
            pass

    def on_graph_end(self, error: Exception | None) -> None:
        pass

View File

@ -88,6 +88,7 @@ dependencies = [
"packaging~=23.2", "packaging~=23.2",
"croniter>=6.0.0", "croniter>=6.0.0",
"weaviate-client==4.17.0", "weaviate-client==4.17.0",
"apscheduler>=3.11.0",
] ]
# Before adding new dependency, consider place it in # Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group. # alphabet order (a-z) and suitable group.

View File

@ -160,9 +160,6 @@ class AsyncWorkflowService:
else: # SANDBOX else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore
if not task:
raise ValueError(f"Failed to queue task for queue: {queue_name}")
# 10. Update trigger log with task info # 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED trigger_log.status = WorkflowTriggerStatus.QUEUED
trigger_log.celery_task_id = task.id trigger_log.celery_task_id = task.id

View File

@ -21,13 +21,13 @@ from core.workflow.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from factories import file_factory from factories import file_factory
from models.enums import AppTriggerStatus, AppTriggerType, WorkflowRunTriggeredFrom from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App from models.model import App
from models.trigger import AppTrigger, WorkflowWebhookTrigger from models.trigger import AppTrigger, WorkflowWebhookTrigger
from models.workflow import Workflow from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService from services.end_user_service import EndUserService
from services.workflow.entities import TriggerData from services.workflow.entities import WebhookTriggerData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -714,11 +714,10 @@ class WebhookService:
workflow_inputs = cls.build_workflow_inputs(webhook_data) workflow_inputs = cls.build_workflow_inputs(webhook_data)
# Create trigger data # Create trigger data
trigger_data = TriggerData( trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id, app_id=webhook_trigger.app_id,
workflow_id=workflow.id, workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node root_node_id=webhook_trigger.node_id, # Start from the webhook node
trigger_type=WorkflowRunTriggeredFrom.WEBHOOK,
inputs=workflow_inputs, inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id, tenant_id=webhook_trigger.tenant_id,
) )

View File

@ -8,7 +8,7 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from models.enums import WorkflowRunTriggeredFrom from models.enums import AppTriggerType, WorkflowRunTriggeredFrom
class AsyncTriggerStatus(StrEnum): class AsyncTriggerStatus(StrEnum):
@ -28,7 +28,8 @@ class TriggerData(BaseModel):
root_node_id: str root_node_id: str
inputs: Mapping[str, Any] inputs: Mapping[str, Any]
files: Sequence[Mapping[str, Any]] = Field(default_factory=list) files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
trigger_type: WorkflowRunTriggeredFrom trigger_type: AppTriggerType
trigger_from: WorkflowRunTriggeredFrom
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
@ -36,24 +37,22 @@ class TriggerData(BaseModel):
class WebhookTriggerData(TriggerData): class WebhookTriggerData(TriggerData):
"""Webhook-specific trigger data""" """Webhook-specific trigger data"""
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK trigger_type: AppTriggerType = AppTriggerType.TRIGGER_WEBHOOK
webhook_url: str trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK
headers: Mapping[str, str] = Field(default_factory=dict)
method: str = "POST"
class ScheduleTriggerData(TriggerData): class ScheduleTriggerData(TriggerData):
"""Schedule-specific trigger data""" """Schedule-specific trigger data"""
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE trigger_type: AppTriggerType = AppTriggerType.TRIGGER_SCHEDULE
schedule_id: str trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE
cron_expression: str
class PluginTriggerData(TriggerData): class PluginTriggerData(TriggerData):
"""Plugin webhook trigger data""" """Plugin webhook trigger data"""
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN trigger_type: AppTriggerType = AppTriggerType.TRIGGER_PLUGIN
trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN
plugin_id: str plugin_id: str
endpoint_id: str endpoint_id: str
@ -125,3 +124,21 @@ class TriggerLogResponse(BaseModel):
finished_at: Optional[str] = None finished_at: Optional[str] = None
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
class WorkflowScheduleCFSPlanEntity(BaseModel):
"""
CFS plan entity.
Ensure each workflow run inside Dify is associated with a CFS(Completely Fair Scheduler) plan.
"""
class Strategy(StrEnum):
"""
CFS plan strategy.
"""
TimeSlice = "time-slice" # time-slice based plan
schedule_strategy: Strategy
granularity: int = Field(default=-1) # -1 means infinite

View File

@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from enum import StrEnum
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
class SchedulerCommand(StrEnum):
"""
Scheduler command.
"""
RESOURCE_LIMIT_REACHED = "resource_limit_reached"
NONE = "none"
class CFSPlanScheduler(ABC):
    """Abstract scheduler that decides whether a workflow run may proceed under a CFS plan."""

    def __init__(self, plan: WorkflowScheduleCFSPlanEntity):
        """Remember the CFS plan this scheduler enforces.

        Args:
            plan: The CFS plan.
        """
        self.plan = plan

    @abstractmethod
    def can_schedule(self) -> SchedulerCommand:
        """Return the scheduler's verdict for the current workflow run."""

View File

@ -5,7 +5,6 @@ These tasks handle workflow execution for different subscription tiers
with appropriate retry policies and error handling. with appropriate retry policies and error handling.
""" """
import json
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
@ -13,8 +12,9 @@ from celery import shared_task
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.engine_layers.timeslice_layer import TimesliceLayer
from core.app.engine_layers.trigger_post_layer import TriggerPostLayer
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
@ -24,57 +24,64 @@ from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import WorkflowNotFoundError from services.errors.app import WorkflowNotFoundError
from services.workflow.entities import AsyncTriggerExecutionResult, AsyncTriggerStatus, TriggerData, WorkflowTaskData from services.workflow.entities import (
TriggerData,
# Determine queue names based on edition WorkflowScheduleCFSPlanEntity,
if dify_config.EDITION == "CLOUD": WorkflowTaskData,
# Cloud edition: separate queues for different tiers )
_professional_queue = "workflow_professional" from tasks.workflow_cfs_scheduler.cfs_scheduler import TriggerCFSPlanScheduler, TriggerWorkflowCFSPlanEntity
_team_queue = "workflow_team" from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue
_sandbox_queue = "workflow_sandbox"
else:
# Community edition: single workflow queue (not dataset)
_professional_queue = "workflow"
_team_queue = "workflow"
_sandbox_queue = "workflow"
# Define constants
PROFESSIONAL_QUEUE = _professional_queue
TEAM_QUEUE = _team_queue
SANDBOX_QUEUE = _sandbox_queue
@shared_task(queue=PROFESSIONAL_QUEUE) @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]) -> dict[str, Any]: def execute_workflow_professional(task_data_dict: dict[str, Any]):
"""Execute workflow for professional tier with highest priority""" """Execute workflow for professional tier with highest priority"""
task_data = WorkflowTaskData.model_validate(task_data_dict) task_data = WorkflowTaskData.model_validate(task_data_dict)
return _execute_workflow_common(task_data).model_dump() cfs_plan_scheduler_entity = TriggerWorkflowCFSPlanEntity(
queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE, schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice
)
_execute_workflow_common(
task_data,
TriggerCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
cfs_plan_scheduler_entity,
)
@shared_task(queue=TEAM_QUEUE) @shared_task(queue=AsyncWorkflowQueue.TEAM_QUEUE)
def execute_workflow_team(task_data_dict: dict[str, Any]) -> dict[str, Any]: def execute_workflow_team(task_data_dict: dict[str, Any]):
"""Execute workflow for team tier""" """Execute workflow for team tier"""
task_data = WorkflowTaskData.model_validate(task_data_dict) task_data = WorkflowTaskData.model_validate(task_data_dict)
return _execute_workflow_common(task_data).model_dump() cfs_plan_scheduler_entity = TriggerWorkflowCFSPlanEntity(
queue=AsyncWorkflowQueue.TEAM_QUEUE, schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice
)
_execute_workflow_common(
task_data,
TriggerCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
cfs_plan_scheduler_entity,
)
@shared_task(queue=SANDBOX_QUEUE) @shared_task(queue=AsyncWorkflowQueue.SANDBOX_QUEUE)
def execute_workflow_sandbox(task_data_dict: dict[str, Any]) -> dict[str, Any]: def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
"""Execute workflow for free tier with lower retry limit""" """Execute workflow for free tier with lower retry limit"""
task_data = WorkflowTaskData.model_validate(task_data_dict) task_data = WorkflowTaskData.model_validate(task_data_dict)
return _execute_workflow_common(task_data).model_dump() cfs_plan_scheduler_entity = TriggerWorkflowCFSPlanEntity(
queue=AsyncWorkflowQueue.SANDBOX_QUEUE, schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice
)
_execute_workflow_common(
task_data,
TriggerCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
cfs_plan_scheduler_entity,
)
def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult: def _execute_workflow_common(
""" task_data: WorkflowTaskData,
Common workflow execution logic with trigger log updates cfs_plan_scheduler: TriggerCFSPlanScheduler,
cfs_plan_scheduler_entity: TriggerWorkflowCFSPlanEntity,
):
"""Execute workflow with common logic and trigger log updates."""
Args:
task_data: Validated Pydantic model with task data
Returns:
AsyncTriggerExecutionResult: Pydantic model with execution results
"""
# Create a new session for this task # Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
@ -86,11 +93,7 @@ def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecuti
if not trigger_log: if not trigger_log:
# This should not happen, but handle gracefully # This should not happen, but handle gracefully
return AsyncTriggerExecutionResult( return
execution_id=task_data.workflow_trigger_log_id,
status=AsyncTriggerStatus.FAILED,
error=f"Trigger log not found: {task_data.workflow_trigger_log_id}",
)
# Reconstruct execution data from trigger log # Reconstruct execution data from trigger log
trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data) trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
@ -126,7 +129,7 @@ def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecuti
args["workflow_id"] = str(trigger_data.workflow_id) args["workflow_id"] = str(trigger_data.workflow_id)
# Execute the workflow with the trigger type # Execute the workflow with the trigger type
result = generator.generate( generator.generate(
app_model=app_model, app_model=app_model,
workflow=workflow, workflow=workflow,
user=user, user=user,
@ -136,38 +139,10 @@ def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecuti
call_depth=0, call_depth=0,
triggered_from=trigger_data.trigger_type, triggered_from=trigger_data.trigger_type,
root_node_id=trigger_data.root_node_id, root_node_id=trigger_data.root_node_id,
) layers=[
TimesliceLayer(cfs_plan_scheduler),
# Calculate elapsed time TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
elapsed_time = (datetime.now(UTC) - start_time).total_seconds() ],
# Extract relevant data from result
if isinstance(result, dict):
workflow_run_id = result.get("workflow_run_id")
total_tokens = result.get("total_tokens")
outputs = result
else:
# Handle generator result - collect all data
workflow_run_id = None
total_tokens = None
outputs = {"data": "streaming_result"}
# Update trigger log with success
trigger_log.status = WorkflowTriggerStatus.SUCCEEDED
trigger_log.workflow_run_id = workflow_run_id
trigger_log.outputs = json.dumps(outputs)
trigger_log.elapsed_time = elapsed_time
trigger_log.total_tokens = total_tokens
trigger_log.finished_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
return AsyncTriggerExecutionResult(
execution_id=trigger_log.id,
status=AsyncTriggerStatus.COMPLETED,
result=outputs,
elapsed_time=elapsed_time,
total_tokens=total_tokens,
) )
except Exception as e: except Exception as e:
@ -184,10 +159,6 @@ def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecuti
# Final failure - no retry logic (simplified like RAG tasks) # Final failure - no retry logic (simplified like RAG tasks)
session.commit() session.commit()
return AsyncTriggerExecutionResult(
execution_id=trigger_log.id, status=AsyncTriggerStatus.FAILED, error=str(e), elapsed_time=elapsed_time
)
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser: def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
"""Compose user from trigger log""" """Compose user from trigger log"""

View File

@ -23,7 +23,6 @@ from core.trigger.trigger_manager import TriggerManager
from core.workflow.enums import NodeType from core.workflow.enums import NodeType
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
from models.provider_ids import TriggerProviderID from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription, WorkflowPluginTrigger from models.trigger import TriggerSubscription, WorkflowPluginTrigger
@ -199,7 +198,6 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, tenant_id=subscription.tenant_id,
workflow_id=workflow.id, workflow_id=workflow.id,
root_node_id=plugin_trigger.node_id, root_node_id=plugin_trigger.node_id,
trigger_type=WorkflowRunTriggeredFrom.PLUGIN,
plugin_id=subscription.provider_id, plugin_id=subscription.provider_id,
endpoint_id=subscription.endpoint_id, endpoint_id=subscription.endpoint_id,
inputs=invoke_response.variables, inputs=invoke_response.variables,

View File

@ -0,0 +1,33 @@
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue
class TriggerWorkflowCFSPlanEntity(WorkflowScheduleCFSPlanEntity):
    """
    Trigger workflow CFS plan entity.

    Extends the base CFS plan with the queue the triggered workflow was
    dispatched to, so the scheduler can apply per-tier policies.
    """

    # Queue (subscription tier) the triggered workflow run was enqueued on.
    queue: AsyncWorkflowQueue
class TriggerCFSPlanScheduler(CFSPlanScheduler):
    """
    CFS plan scheduler for trigger-driven workflow runs.

    Applies per-queue (subscription tier) policy to decide whether a
    workflow run may keep executing.
    """

    def can_schedule(self) -> SchedulerCommand:
        """
        Check if the workflow can be scheduled.
        """
        # Narrow the plan type: trigger schedulers are always built with a
        # TriggerWorkflowCFSPlanEntity (see tasks.workflow_cfs_scheduler).
        assert isinstance(self.plan, TriggerWorkflowCFSPlanEntity)
        if self.plan.queue in (AsyncWorkflowQueue.PROFESSIONAL_QUEUE, AsyncWorkflowQueue.TEAM_QUEUE):
            # Paid tiers are always permitted to keep running.
            # (The original had this as a stray triple-quoted string — a no-op
            # expression statement — rather than a comment.)
            return SchedulerCommand.NONE

        # FIXME: avoid the sandbox user's workflow at a running state for ever
        return SchedulerCommand.NONE

View File

@ -0,0 +1,22 @@
from enum import StrEnum
from configs import dify_config
# Determine queue names based on edition
# Determine queue names based on edition
if dify_config.EDITION == "CLOUD":
    # Cloud edition: separate queues for different tiers
    _professional_queue = "workflow_professional"
    _team_queue = "workflow_team"
    _sandbox_queue = "workflow_sandbox"
else:
    # Community edition: single workflow queue (not dataset)
    _professional_queue = "workflow"
    _team_queue = "workflow"
    _sandbox_queue = "workflow"
class AsyncWorkflowQueue(StrEnum):
    # Celery queue names for async workflow execution, resolved per edition
    # above (distinct per-tier queues on cloud, one shared queue otherwise).
    PROFESSIONAL_QUEUE = _professional_queue
    TEAM_QUEUE = _team_queue
    SANDBOX_QUEUE = _sandbox_queue

View File

@ -14,11 +14,10 @@ from core.workflow.nodes.trigger_schedule.exc import (
TenantOwnerNotFoundError, TenantOwnerNotFoundError,
) )
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import WorkflowRunTriggeredFrom
from models.trigger import WorkflowSchedulePlan from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService from services.async_workflow_service import AsyncWorkflowService
from services.trigger.schedule_service import ScheduleService from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import TriggerData from services.workflow.entities import ScheduleTriggerData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,10 +56,9 @@ def run_schedule_trigger(schedule_id: str) -> None:
response = AsyncWorkflowService.trigger_workflow_async( response = AsyncWorkflowService.trigger_workflow_async(
session=session, session=session,
user=tenant_owner, user=tenant_owner,
trigger_data=TriggerData( trigger_data=ScheduleTriggerData(
app_id=schedule.app_id, app_id=schedule.app_id,
root_node_id=schedule.node_id, root_node_id=schedule.node_id,
trigger_type=WorkflowRunTriggeredFrom.SCHEDULE,
inputs=inputs, inputs=inputs,
tenant_id=schedule.tenant_id, tenant_id=schedule.tenant_id,
), ),

4603
api/uv.lock generated

File diff suppressed because it is too large Load Diff