mirror of https://github.com/langgenius/dify.git
feat(trigger): add suspend/timeslice layers and workflow CFS scheduler
- add suspend, timeslice, and trigger post engine layers
- introduce CFS workflow scheduler tasks and supporting entities
- update async workflow, trigger, and webhook services to wire in the new scheduling flow
This commit is contained in:
parent
55bf9196dc
commit
3d5e2c5ca1
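For orientation, a minimal sketch of how the pieces introduced in this commit fit together when a triggered workflow run is executed. All module, class, and enum names are taken from the diff below; the surrounding variables (start_time, trigger_log_id) and the final call site are illustrative assumptions, not part of the commit.

    # Sketch only: compose the new engine layers for an async (triggered) workflow run.
    from datetime import UTC, datetime

    from core.app.engine_layers.timeslice_layer import TimesliceLayer
    from core.app.engine_layers.trigger_post_layer import TriggerPostLayer
    from services.workflow.entities import WorkflowScheduleCFSPlanEntity
    from tasks.workflow_cfs_scheduler.cfs_scheduler import TriggerCFSPlanScheduler, TriggerWorkflowCFSPlanEntity
    from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue

    start_time = datetime.now(UTC)          # assumption: captured when the Celery task starts
    trigger_log_id = "example-trigger-log"  # assumption: id of the WorkflowTriggerLog row for this run

    plan = TriggerWorkflowCFSPlanEntity(
        queue=AsyncWorkflowQueue.SANDBOX_QUEUE,
        schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
    )
    layers = [
        # Periodically asks the CFS scheduler whether the run may keep its timeslice;
        # when the resource limit is reached it sends a PAUSE command over the command channel.
        TimesliceLayer(TriggerCFSPlanScheduler(plan=plan)),
        # Writes the trigger log (status, outputs, tokens, elapsed time) once the graph finishes.
        TriggerPostLayer(plan, start_time, trigger_log_id),
    ]
    # These layers are passed through WorkflowAppGenerator.generate(..., layers=layers),
    # which hands them to WorkflowAppRunner; the runner attaches them to the graph engine
    # next to the built-in persistence and suspend layers.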
@@ -185,6 +185,22 @@ class TriggerConfig(BaseSettings):
     )
 
 
+class AsyncWorkflowConfig(BaseSettings):
+    """
+    Configuration for async workflow
+    """
+
+    ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
+        description="Granularity, in seconds, for the async workflow scheduler. "
+        "Occasionally a few users can block the queue with time-consuming tasks; "
+        "to avoid this, workflows can be suspended when needed. To achieve this, "
+        "a time-based checker is required: every `granularity` seconds the checker "
+        "inspects the workflow queue and suspends workflows as needed.",
+        default=1,
+        ge=1,
+    )
+
+
 class PluginConfig(BaseSettings):
     """
     Plugin configs
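For reference, the new setting behaves like any other Dify config field (a hedged sketch, assuming the usual pydantic-settings mapping where each field can be overridden by an environment variable of the same name):

    # Sketch only: tune how often the timeslice checker fires.
    import os

    os.environ["ASYNC_WORKFLOW_SCHEDULER_GRANULARITY"] = "5"  # check every 5 seconds instead of the default 1

    from configs import dify_config  # imported after the override so the value is picked up here

    assert dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY >= 1  # ge=1 is enforced by the Field definition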
@@ -1165,6 +1181,7 @@ class FeatureConfig(
     BillingConfig,
     CodeExecutionSandboxConfig,
     TriggerConfig,
+    AsyncWorkflowConfig,
     PluginConfig,
     MarketplaceConfig,
     DataSetConfig,
@@ -27,6 +27,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.repositories import DifyCoreRepositoryFactory
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -55,7 +56,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         call_depth: int,
         triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
         root_node_id: Optional[str] = None,
-    ) -> Generator[Mapping | str, None, None]: ...
+        layers: Optional[Sequence[GraphEngineLayer]] = None,
+    ) -> Generator[Mapping[str, Any] | str, None, None]: ...
 
     @overload
     def generate(
@@ -70,6 +72,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         call_depth: int,
         triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
         root_node_id: Optional[str] = None,
+        layers: Optional[Sequence[GraphEngineLayer]] = None,
     ) -> Mapping[str, Any]: ...
 
     @overload
@@ -85,7 +88,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         call_depth: int,
         triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
         root_node_id: Optional[str] = None,
-    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
+        layers: Optional[Sequence[GraphEngineLayer]] = None,
+    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
 
     def generate(
         self,
@@ -99,7 +103,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         call_depth: int = 0,
         triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
         root_node_id: Optional[str] = None,
-    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
+        layers: Optional[Sequence[GraphEngineLayer]] = None,
+    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
         files: Sequence[Mapping[str, Any]] = args.get("files") or []
 
         # parse files
@@ -197,8 +202,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow_node_execution_repository=workflow_node_execution_repository,
             streaming=streaming,
             root_node_id=root_node_id,
+            layers=layers,
         )
 
+    def resume(self, *, workflow_run_id: str) -> None:
+        """
+        @TBD
+        """
+        pass
+
     def _generate(
         self,
         *,
@@ -212,6 +224,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         streaming: bool = True,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
         root_node_id: Optional[str] = None,
+        layers: Optional[Sequence[GraphEngineLayer]] = None,
     ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
         Generate App response.
@@ -250,6 +263,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 "root_node_id": root_node_id,
                 "workflow_execution_repository": workflow_execution_repository,
                 "workflow_node_execution_repository": workflow_node_execution_repository,
+                "layers": layers,
             },
         )
 
@@ -444,6 +458,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         root_node_id: Optional[str] = None,
+        layers: Optional[Sequence[GraphEngineLayer]] = None,
     ) -> None:
         """
         Generate worker in a new thread.
@@ -488,6 +503,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow_execution_repository=workflow_execution_repository,
             workflow_node_execution_repository=workflow_node_execution_repository,
             root_node_id=root_node_id,
+            layers=layers,
         )
 
         try:
@@ -1,13 +1,16 @@
 import logging
 import time
+from collections.abc import Sequence
 from typing import Optional, cast
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
+from core.app.engine_layers.suspend_layer import SuspendLayer
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
 from core.workflow.enums import WorkflowType
 from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -38,6 +41,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         root_node_id: Optional[str] = None,
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+        layers: Optional[Sequence[GraphEngineLayer]] = None,
     ):
         super().__init__(
             queue_manager=queue_manager,
@@ -50,6 +54,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         self._root_node_id = root_node_id
         self._workflow_execution_repository = workflow_execution_repository
         self._workflow_node_execution_repository = workflow_node_execution_repository
+        self._layers = layers or []
 
     def run(self):
         """
@@ -137,7 +142,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             trace_manager=self.application_generate_entity.trace_manager,
         )
 
+        suspend_layer = SuspendLayer()
+
         workflow_entry.graph_engine.layer(persistence_layer)
+        workflow_entry.graph_engine.layer(suspend_layer)
+
+        for layer in self._layers:
+            workflow_entry.graph_engine.layer(layer)
 
         generator = workflow_entry.run()
 
@@ -0,0 +1,15 @@
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_events.base import GraphEngineEvent
+
+
+class SuspendLayer(GraphEngineLayer):
+    """Stub layer reserved for suspending a running workflow."""
+
+    def on_graph_start(self):
+        pass
+
+    def on_event(self, event: GraphEngineEvent):
+        pass
+
+    def on_graph_end(self, error: Exception | None):
+        pass
@@ -0,0 +1,81 @@
+import logging
+import uuid
+from typing import ClassVar
+
+from apscheduler.schedulers.background import BackgroundScheduler  # type: ignore
+
+from configs import dify_config
+from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_events.base import GraphEngineEvent
+from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
+
+logger = logging.getLogger(__name__)
+
+
+class TimesliceLayer(GraphEngineLayer):
+    """
+    CFS plan scheduler layer that controls the timeslice of the workflow.
+    """
+
+    scheduler: ClassVar[BackgroundScheduler] = BackgroundScheduler()
+
+    def __init__(self, cfs_plan_scheduler: CFSPlanScheduler) -> None:
+        """
+        The CFS plan scheduler controls the timeslice of the workflow.
+        """
+
+        if not TimesliceLayer.scheduler.running:
+            TimesliceLayer.scheduler.start()
+
+        super().__init__()
+        self.cfs_plan_scheduler = cfs_plan_scheduler
+        self.stopped = False
+
+    def on_graph_start(self):
+        """
+        Start a timer that checks whether the workflow needs to be suspended.
+        """
+
+        schedule_id = uuid.uuid4().hex
+
+        def runner():
+            """
+            While the workflow is running, keep checking whether it needs to be suspended.
+            Otherwise, return directly.
+            """
+            try:
+                if self.stopped:
+                    self.scheduler.remove_job(schedule_id)
+                    return
+
+                if self.cfs_plan_scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED:
+                    # remove the job
+                    self.scheduler.remove_job(schedule_id)
+
+                    if not self.command_channel:
+                        logger.error("No command channel to stop the workflow")
+                        return
+
+                    # send command to pause the workflow
+                    self.command_channel.send_command(
+                        GraphEngineCommand(
+                            command_type=CommandType.PAUSE,
+                            payload={
+                                "reason": SchedulerCommand.RESOURCE_LIMIT_REACHED,
+                            },
+                        )
+                    )
+
+            except Exception:
+                logger.exception("Scheduler error while checking whether the workflow needs to be suspended")
+
+        self.scheduler.add_job(
+            runner, "interval", seconds=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY, id=schedule_id
+        )
+
+    def on_event(self, event: GraphEngineEvent):
+        pass
+
+    def on_graph_end(self, error: Exception | None) -> None:
+        self.stopped = True
@@ -0,0 +1,80 @@
+import logging
+from datetime import UTC, datetime
+from typing import Any
+
+from pydantic import TypeAdapter
+from sqlalchemy.orm import Session
+
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_events.base import GraphEngineEvent
+from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
+from models.engine import db
+from models.enums import WorkflowTriggerStatus
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+from tasks.workflow_cfs_scheduler.cfs_scheduler import TriggerWorkflowCFSPlanEntity
+
+logger = logging.getLogger(__name__)
+
+
+class TriggerPostLayer(GraphEngineLayer):
+    """
+    Trigger post layer.
+    """
+
+    def __init__(
+        self,
+        cfs_plan_scheduler_entity: TriggerWorkflowCFSPlanEntity,
+        start_time: datetime,
+        trigger_log_id: str,
+    ):
+        self.trigger_log_id = trigger_log_id
+        self.start_time = start_time
+        self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
+
+    def on_graph_start(self):
+        pass
+
+    def on_event(self, event: GraphEngineEvent):
+        """
+        Update the trigger log with success or failure.
+        """
+        if isinstance(event, GraphRunSucceededEvent | GraphRunFailedEvent):
+            with Session(db.engine) as session:
+                repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+                trigger_log = repo.get_by_id(self.trigger_log_id)
+                if not trigger_log:
+                    logger.error("Trigger log not found: %s", self.trigger_log_id)
+                    return
+
+                # Calculate elapsed time
+                elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
+
+                # Extract relevant data from the run result
+                if not self.graph_runtime_state:
+                    logger.error("Graph runtime state is not set")
+                    return
+
+                outputs = self.graph_runtime_state.outputs
+
+                workflow_run_id = outputs.get("workflow_run_id")
+                total_tokens = self.graph_runtime_state.total_tokens
+
+                # Update trigger log with the final status
+                trigger_log.status = (
+                    WorkflowTriggerStatus.SUCCEEDED
+                    if isinstance(event, GraphRunSucceededEvent)
+                    else WorkflowTriggerStatus.FAILED
+                )
+                trigger_log.workflow_run_id = workflow_run_id
+                trigger_log.outputs = TypeAdapter(dict[str, Any]).dump_json(outputs).decode()
+                trigger_log.elapsed_time = elapsed_time
+                trigger_log.total_tokens = total_tokens
+                trigger_log.finished_at = datetime.now(UTC)
+                repo.update(trigger_log)
+                session.commit()
+        elif isinstance(event, GraphRunPausedEvent):
+            # FIXME: handle the paused event
+            pass
+
+    def on_graph_end(self, error: Exception | None) -> None:
+        pass
@@ -88,6 +88,7 @@ dependencies = [
     "packaging~=23.2",
     "croniter>=6.0.0",
     "weaviate-client==4.17.0",
+    "apscheduler>=3.11.0",
 ]
 # Before adding new dependency, consider place it in
 # alphabet order (a-z) and suitable group.
@@ -160,9 +160,6 @@ class AsyncWorkflowService:
         else:  # SANDBOX
             task = execute_workflow_sandbox.delay(task_data_dict)  # type: ignore
 
-        if not task:
-            raise ValueError(f"Failed to queue task for queue: {queue_name}")
-
         # 10. Update trigger log with task info
         trigger_log.status = WorkflowTriggerStatus.QUEUED
         trigger_log.celery_task_id = task.id
@@ -21,13 +21,13 @@ from core.workflow.enums import NodeType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from factories import file_factory
-from models.enums import AppTriggerStatus, AppTriggerType, WorkflowRunTriggeredFrom
+from models.enums import AppTriggerStatus, AppTriggerType
 from models.model import App
 from models.trigger import AppTrigger, WorkflowWebhookTrigger
 from models.workflow import Workflow
 from services.async_workflow_service import AsyncWorkflowService
 from services.end_user_service import EndUserService
-from services.workflow.entities import TriggerData
+from services.workflow.entities import WebhookTriggerData
 
 logger = logging.getLogger(__name__)
 
@@ -714,11 +714,10 @@ class WebhookService:
         workflow_inputs = cls.build_workflow_inputs(webhook_data)
 
         # Create trigger data
-        trigger_data = TriggerData(
+        trigger_data = WebhookTriggerData(
             app_id=webhook_trigger.app_id,
             workflow_id=workflow.id,
             root_node_id=webhook_trigger.node_id,  # Start from the webhook node
-            trigger_type=WorkflowRunTriggeredFrom.WEBHOOK,
             inputs=workflow_inputs,
             tenant_id=webhook_trigger.tenant_id,
         )
@@ -8,7 +8,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from models.enums import WorkflowRunTriggeredFrom
+from models.enums import AppTriggerType, WorkflowRunTriggeredFrom
 
 
 class AsyncTriggerStatus(StrEnum):
@@ -28,7 +28,8 @@ class TriggerData(BaseModel):
     root_node_id: str
     inputs: Mapping[str, Any]
     files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
-    trigger_type: WorkflowRunTriggeredFrom
+    trigger_type: AppTriggerType
+    trigger_from: WorkflowRunTriggeredFrom
 
     model_config = ConfigDict(use_enum_values=True)
 
@@ -36,24 +37,22 @@ class TriggerData(BaseModel):
 class WebhookTriggerData(TriggerData):
     """Webhook-specific trigger data"""
 
-    trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK
-    webhook_url: str
-    headers: Mapping[str, str] = Field(default_factory=dict)
-    method: str = "POST"
+    trigger_type: AppTriggerType = AppTriggerType.TRIGGER_WEBHOOK
+    trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK
 
 
 class ScheduleTriggerData(TriggerData):
     """Schedule-specific trigger data"""
 
-    trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE
-    schedule_id: str
-    cron_expression: str
+    trigger_type: AppTriggerType = AppTriggerType.TRIGGER_SCHEDULE
+    trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE
 
 
 class PluginTriggerData(TriggerData):
     """Plugin webhook trigger data"""
 
-    trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN
+    trigger_type: AppTriggerType = AppTriggerType.TRIGGER_PLUGIN
+    trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN
     plugin_id: str
     endpoint_id: str
 
@@ -125,3 +124,21 @@ class TriggerLogResponse(BaseModel):
     finished_at: Optional[str] = None
 
     model_config = ConfigDict(use_enum_values=True)
+
+
+class WorkflowScheduleCFSPlanEntity(BaseModel):
+    """
+    CFS plan entity.
+    Ensures each workflow run inside Dify is associated with a CFS (Completely Fair Scheduler) plan.
+
+    """
+
+    class Strategy(StrEnum):
+        """
+        CFS plan strategy.
+        """
+
+        TimeSlice = "time-slice"  # time-slice based plan
+
+    schedule_strategy: Strategy
+    granularity: int = Field(default=-1)  # -1 means infinite
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from enum import StrEnum
+
+from services.workflow.entities import WorkflowScheduleCFSPlanEntity
+
+
+class SchedulerCommand(StrEnum):
+    """
+    Scheduler command.
+    """
+
+    RESOURCE_LIMIT_REACHED = "resource_limit_reached"
+    NONE = "none"
+
+
+class CFSPlanScheduler(ABC):
+    """
+    CFS plan scheduler.
+    """
+
+    def __init__(self, plan: WorkflowScheduleCFSPlanEntity):
+        """
+        Initialize the CFS plan scheduler.
+
+        Args:
+            plan: The CFS plan.
+        """
+        self.plan = plan
+
+    @abstractmethod
+    def can_schedule(self) -> SchedulerCommand:
+        """
+        Whether a workflow run can be scheduled.
+        """
@@ -5,7 +5,6 @@ These tasks handle workflow execution for different subscription tiers
 with appropriate retry policies and error handling.
 """
 
-import json
 from datetime import UTC, datetime
 from typing import Any
 
@@ -13,8 +12,9 @@ from celery import shared_task
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
-from configs import dify_config
 from core.app.apps.workflow.app_generator import WorkflowAppGenerator
+from core.app.engine_layers.timeslice_layer import TimesliceLayer
+from core.app.engine_layers.trigger_post_layer import TriggerPostLayer
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from models.account import Account
@@ -24,57 +24,64 @@ from models.trigger import WorkflowTriggerLog
 from models.workflow import Workflow
 from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
 from services.errors.app import WorkflowNotFoundError
-from services.workflow.entities import AsyncTriggerExecutionResult, AsyncTriggerStatus, TriggerData, WorkflowTaskData
-
-# Determine queue names based on edition
-if dify_config.EDITION == "CLOUD":
-    # Cloud edition: separate queues for different tiers
-    _professional_queue = "workflow_professional"
-    _team_queue = "workflow_team"
-    _sandbox_queue = "workflow_sandbox"
-else:
-    # Community edition: single workflow queue (not dataset)
-    _professional_queue = "workflow"
-    _team_queue = "workflow"
-    _sandbox_queue = "workflow"
-
-# Define constants
-PROFESSIONAL_QUEUE = _professional_queue
-TEAM_QUEUE = _team_queue
-SANDBOX_QUEUE = _sandbox_queue
+from services.workflow.entities import (
+    TriggerData,
+    WorkflowScheduleCFSPlanEntity,
+    WorkflowTaskData,
+)
+from tasks.workflow_cfs_scheduler.cfs_scheduler import TriggerCFSPlanScheduler, TriggerWorkflowCFSPlanEntity
+from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue
 
 
-@shared_task(queue=PROFESSIONAL_QUEUE)
-def execute_workflow_professional(task_data_dict: dict[str, Any]) -> dict[str, Any]:
+@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
+def execute_workflow_professional(task_data_dict: dict[str, Any]):
     """Execute workflow for professional tier with highest priority"""
     task_data = WorkflowTaskData.model_validate(task_data_dict)
-    return _execute_workflow_common(task_data).model_dump()
+    cfs_plan_scheduler_entity = TriggerWorkflowCFSPlanEntity(
+        queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE, schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice
+    )
+    _execute_workflow_common(
+        task_data,
+        TriggerCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
+        cfs_plan_scheduler_entity,
+    )
 
 
-@shared_task(queue=TEAM_QUEUE)
-def execute_workflow_team(task_data_dict: dict[str, Any]) -> dict[str, Any]:
+@shared_task(queue=AsyncWorkflowQueue.TEAM_QUEUE)
+def execute_workflow_team(task_data_dict: dict[str, Any]):
     """Execute workflow for team tier"""
     task_data = WorkflowTaskData.model_validate(task_data_dict)
-    return _execute_workflow_common(task_data).model_dump()
+    cfs_plan_scheduler_entity = TriggerWorkflowCFSPlanEntity(
+        queue=AsyncWorkflowQueue.TEAM_QUEUE, schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice
+    )
+    _execute_workflow_common(
+        task_data,
+        TriggerCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
+        cfs_plan_scheduler_entity,
+    )
 
 
-@shared_task(queue=SANDBOX_QUEUE)
-def execute_workflow_sandbox(task_data_dict: dict[str, Any]) -> dict[str, Any]:
+@shared_task(queue=AsyncWorkflowQueue.SANDBOX_QUEUE)
def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
     """Execute workflow for free tier with lower retry limit"""
     task_data = WorkflowTaskData.model_validate(task_data_dict)
-    return _execute_workflow_common(task_data).model_dump()
+    cfs_plan_scheduler_entity = TriggerWorkflowCFSPlanEntity(
+        queue=AsyncWorkflowQueue.SANDBOX_QUEUE, schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice
+    )
+    _execute_workflow_common(
+        task_data,
+        TriggerCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
+        cfs_plan_scheduler_entity,
+    )
 
 
-def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult:
-    """
-    Common workflow execution logic with trigger log updates
-
-    Args:
-        task_data: Validated Pydantic model with task data
-
-    Returns:
-        AsyncTriggerExecutionResult: Pydantic model with execution results
-    """
+def _execute_workflow_common(
+    task_data: WorkflowTaskData,
+    cfs_plan_scheduler: TriggerCFSPlanScheduler,
+    cfs_plan_scheduler_entity: TriggerWorkflowCFSPlanEntity,
+):
+    """Execute workflow with common logic and trigger log updates."""
+
     # Create a new session for this task
     session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
 
@@ -86,11 +93,7 @@ def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult:
 
         if not trigger_log:
             # This should not happen, but handle gracefully
-            return AsyncTriggerExecutionResult(
-                execution_id=task_data.workflow_trigger_log_id,
-                status=AsyncTriggerStatus.FAILED,
-                error=f"Trigger log not found: {task_data.workflow_trigger_log_id}",
-            )
+            return
 
         # Reconstruct execution data from trigger log
         trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
@@ -126,7 +129,7 @@ def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult:
         args["workflow_id"] = str(trigger_data.workflow_id)
 
         # Execute the workflow with the trigger type
-        result = generator.generate(
+        generator.generate(
             app_model=app_model,
             workflow=workflow,
             user=user,
@@ -136,38 +139,10 @@ def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult:
             call_depth=0,
             triggered_from=trigger_data.trigger_type,
            root_node_id=trigger_data.root_node_id,
-        )
-
-        # Calculate elapsed time
-        elapsed_time = (datetime.now(UTC) - start_time).total_seconds()
-
-        # Extract relevant data from result
-        if isinstance(result, dict):
-            workflow_run_id = result.get("workflow_run_id")
-            total_tokens = result.get("total_tokens")
-            outputs = result
-        else:
-            # Handle generator result - collect all data
-            workflow_run_id = None
-            total_tokens = None
-            outputs = {"data": "streaming_result"}
-
-        # Update trigger log with success
-        trigger_log.status = WorkflowTriggerStatus.SUCCEEDED
-        trigger_log.workflow_run_id = workflow_run_id
-        trigger_log.outputs = json.dumps(outputs)
-        trigger_log.elapsed_time = elapsed_time
-        trigger_log.total_tokens = total_tokens
-        trigger_log.finished_at = datetime.now(UTC)
-        trigger_log_repo.update(trigger_log)
-        session.commit()
-
-        return AsyncTriggerExecutionResult(
-            execution_id=trigger_log.id,
-            status=AsyncTriggerStatus.COMPLETED,
-            result=outputs,
-            elapsed_time=elapsed_time,
-            total_tokens=total_tokens,
+            layers=[
+                TimesliceLayer(cfs_plan_scheduler),
+                TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
+            ],
         )
 
     except Exception as e:
@@ -184,10 +159,6 @@ def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult:
             # Final failure - no retry logic (simplified like RAG tasks)
             session.commit()
 
-            return AsyncTriggerExecutionResult(
-                execution_id=trigger_log.id, status=AsyncTriggerStatus.FAILED, error=str(e), elapsed_time=elapsed_time
-            )
-
 
 def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
     """Compose user from trigger log"""
@@ -23,7 +23,6 @@ from core.trigger.trigger_manager import TriggerManager
 from core.workflow.enums import NodeType
 from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
 from extensions.ext_database import db
-from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.provider_ids import TriggerProviderID
 from models.trigger import TriggerSubscription, WorkflowPluginTrigger
@@ -199,7 +198,6 @@ def dispatch_triggered_workflow(
             tenant_id=subscription.tenant_id,
             workflow_id=workflow.id,
             root_node_id=plugin_trigger.node_id,
-            trigger_type=WorkflowRunTriggeredFrom.PLUGIN,
             plugin_id=subscription.provider_id,
             endpoint_id=subscription.endpoint_id,
             inputs=invoke_response.variables,
@@ -0,0 +1,33 @@
+from services.workflow.entities import WorkflowScheduleCFSPlanEntity
+from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
+from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue
+
+
+class TriggerWorkflowCFSPlanEntity(WorkflowScheduleCFSPlanEntity):
+    """
+    Trigger workflow CFS plan entity.
+    """
+
+    queue: AsyncWorkflowQueue
+
+
+class TriggerCFSPlanScheduler(CFSPlanScheduler):
+    """
+    Trigger workflow CFS plan scheduler.
+    """
+
+    def can_schedule(self) -> SchedulerCommand:
+        """
+        Check if the workflow can be scheduled.
+        """
+        assert isinstance(self.plan, TriggerWorkflowCFSPlanEntity)
+
+        if self.plan.queue in [AsyncWorkflowQueue.PROFESSIONAL_QUEUE, AsyncWorkflowQueue.TEAM_QUEUE]:
+            """
+            All paid users are permitted to schedule the workflow.
+            """
+            return SchedulerCommand.NONE
+
+        # FIXME: avoid sandbox users' workflows staying in a running state forever
+
+        return SchedulerCommand.NONE
@@ -0,0 +1,22 @@
+from enum import StrEnum
+
+from configs import dify_config
+
+# Determine queue names based on edition
+if dify_config.EDITION == "CLOUD":
+    # Cloud edition: separate queues for different tiers
+    _professional_queue = "workflow_professional"
+    _team_queue = "workflow_team"
+    _sandbox_queue = "workflow_sandbox"
+else:
+    # Community edition: single workflow queue (not dataset)
+    _professional_queue = "workflow"
+    _team_queue = "workflow"
+    _sandbox_queue = "workflow"
+
+
+class AsyncWorkflowQueue(StrEnum):
+    # Define constants
+    PROFESSIONAL_QUEUE = _professional_queue
+    TEAM_QUEUE = _team_queue
+    SANDBOX_QUEUE = _sandbox_queue
@@ -14,11 +14,10 @@ from core.workflow.nodes.trigger_schedule.exc import (
     TenantOwnerNotFoundError,
 )
 from extensions.ext_database import db
-from models.enums import WorkflowRunTriggeredFrom
 from models.trigger import WorkflowSchedulePlan
 from services.async_workflow_service import AsyncWorkflowService
 from services.trigger.schedule_service import ScheduleService
-from services.workflow.entities import TriggerData
+from services.workflow.entities import ScheduleTriggerData
 
 logger = logging.getLogger(__name__)
 
@@ -57,10 +56,9 @@ def run_schedule_trigger(schedule_id: str) -> None:
         response = AsyncWorkflowService.trigger_workflow_async(
             session=session,
             user=tenant_owner,
-            trigger_data=TriggerData(
+            trigger_data=ScheduleTriggerData(
                 app_id=schedule.app_id,
                 root_node_id=schedule.node_id,
-                trigger_type=WorkflowRunTriggeredFrom.SCHEDULE,
                 inputs=inputs,
                 tenant_id=schedule.tenant_id,
             ),
api/uv.lock: 4603 lines changed (diff suppressed because it is too large)