feat/trigger universal entry (#24358)

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Yeuoly 2025-08-23 20:18:08 +08:00 committed by GitHub
parent 8e93a8a2e2
commit 6aed7e3ff4
16 changed files with 1730 additions and 9 deletions

View File

@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion"
"dataset,generation,mail,ops_trace,app_deletion,workflow"
]
}
]

View File

@ -54,6 +54,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@ -68,6 +69,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
) -> Mapping[str, Any]: ...
@overload
@ -82,6 +84,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
@ -95,6 +98,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@ -159,7 +163,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution (aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
if triggered_from is not None:
# Use explicitly provided triggered_from (for async triggers)
workflow_triggered_from = triggered_from
elif invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
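A minimal sketch of how an async caller might use the new parameter; when triggered_from is supplied it takes precedence over the InvokeFrom-based default shown above. app_model, workflow, and user are assumed to be loaded elsewhere, and the inputs are placeholders.

from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import WorkflowRunTriggeredFrom

generator = WorkflowAppGenerator()
result = generator.generate(
    app_model=app_model,                 # assumed: App loaded from the database
    workflow=workflow,                   # assumed: published Workflow for that app
    user=user,                           # assumed: Account or EndUser
    args={"inputs": {"query": "hello"}, "files": []},
    invoke_from=InvokeFrom.SERVICE_API,
    streaming=False,
    call_depth=0,
    workflow_thread_pool_id=None,
    triggered_from=WorkflowRunTriggeredFrom.WEBHOOK,  # explicit async trigger source
)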

View File

@ -30,9 +30,41 @@ if [[ "${MODE}" == "worker" ]]; then
CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
fi
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
# Configure queues based on edition if not explicitly set
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow_professional,workflow_team,workflow_sandbox"
else
# Community edition (SELF_HOSTED): dataset and workflow have separate queues
DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"
fi
# Support for Kubernetes deployment with specific queue workers
# Environment variables that can be set:
# - CELERY_WORKER_QUEUES: Comma-separated list of queues (overrides CELERY_QUEUES)
# - CELERY_WORKER_CONCURRENCY: Number of worker processes (overrides CELERY_WORKER_AMOUNT)
# - CELERY_WORKER_POOL: Pool implementation (overrides CELERY_WORKER_CLASS)
if [[ -n "${CELERY_WORKER_QUEUES}" ]]; then
DEFAULT_QUEUES="${CELERY_WORKER_QUEUES}"
echo "Using CELERY_WORKER_QUEUES: ${DEFAULT_QUEUES}"
fi
if [[ -n "${CELERY_WORKER_CONCURRENCY}" ]]; then
CONCURRENCY_OPTION="-c ${CELERY_WORKER_CONCURRENCY}"
echo "Using CELERY_WORKER_CONCURRENCY: ${CELERY_WORKER_CONCURRENCY}"
fi
WORKER_POOL="${CELERY_WORKER_POOL:-${CELERY_WORKER_CLASS:-gevent}}"
echo "Starting Celery worker with queues: ${DEFAULT_QUEUES}"
exec celery -A app.celery worker -P ${WORKER_POOL} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}
-Q ${DEFAULT_QUEUES}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

View File

@ -96,7 +96,9 @@ def init_app(app: DifyApp) -> Celery:
celery_app.set_default()
app.extensions["celery"] = celery_app
imports = []
imports = [
"tasks.async_workflow_tasks", # trigger workers
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
# if you add a new task, please add the switch to CeleryScheduleTasksConfig

View File

@ -0,0 +1,66 @@
"""empty message
Revision ID: 994bdf7197ab
Revises: fa8b0fa6f407
Create Date: 2025-08-23 20:06:35.995973
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '994bdf7197ab'
down_revision = 'fa8b0fa6f407'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('workflow_trigger_logs',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
sa.Column('workflow_run_id', models.types.StringUUID(), nullable=True),
sa.Column('trigger_type', sa.String(length=50), nullable=False),
sa.Column('trigger_data', sa.Text(), nullable=False),
sa.Column('inputs', sa.Text(), nullable=False),
sa.Column('outputs', sa.Text(), nullable=True),
sa.Column('status', sa.String(length=50), nullable=False),
sa.Column('error', sa.Text(), nullable=True),
sa.Column('queue_name', sa.String(length=100), nullable=False),
sa.Column('celery_task_id', sa.String(length=255), nullable=True),
sa.Column('retry_count', sa.Integer(), nullable=False),
sa.Column('elapsed_time', sa.Float(), nullable=True),
sa.Column('total_tokens', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('created_by_role', sa.String(length=255), nullable=False),
sa.Column('created_by', sa.String(length=255), nullable=False),
sa.Column('triggered_at', sa.DateTime(), nullable=True),
sa.Column('finished_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
)
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
batch_op.create_index('workflow_trigger_log_tenant_app_idx', ['tenant_id', 'app_id'], unique=False)
batch_op.create_index('workflow_trigger_log_workflow_id_idx', ['workflow_id'], unique=False)
batch_op.create_index('workflow_trigger_log_workflow_run_idx', ['workflow_run_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
batch_op.drop_index('workflow_trigger_log_workflow_run_idx')
batch_op.drop_index('workflow_trigger_log_workflow_id_idx')
batch_op.drop_index('workflow_trigger_log_tenant_app_idx')
batch_op.drop_index('workflow_trigger_log_status_idx')
batch_op.drop_index('workflow_trigger_log_created_at_idx')
op.drop_table('workflow_trigger_logs')
# ### end Alembic commands ###

View File

@ -13,7 +13,10 @@ class UserFrom(StrEnum):
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
APP_RUN = "app-run"
APP_RUN = "app-run" # webapp / service api
WEBHOOK = "webhook"
SCHEDULE = "schedule"
PLUGIN = "plugin"
class DraftVariableType(StrEnum):

View File

@ -1262,3 +1262,122 @@ class WorkflowDraftVariable(Base):
def is_system_variable_editable(name: str) -> bool:
return name in _EDITABLE_SYSTEM_VARIABLE
class WorkflowTriggerStatus(StrEnum):
"""Workflow Trigger Execution Status"""
PENDING = "pending"
QUEUED = "queued"
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
RATE_LIMITED = "rate_limited"
RETRYING = "retrying"
class WorkflowTriggerLog(Base):
"""
Workflow Trigger Log
Track async trigger workflow runs with re-invocation capability
Attributes:
- id (uuid) Trigger Log ID (used as workflow_trigger_log_id)
- tenant_id (uuid) Workspace ID
- app_id (uuid) App ID
- workflow_id (uuid) Workflow ID
- workflow_run_id (uuid) Optional - Associated workflow run ID when execution starts
- trigger_type (string) Type of trigger: webhook, schedule, plugin
- trigger_data (text) Full trigger data including inputs (JSON)
- inputs (text) Input parameters (JSON)
- outputs (text) Optional - Output content (JSON)
- status (string) Execution status
- error (text) Optional - Error message if failed
- queue_name (string) Celery queue used
- celery_task_id (string) Optional - Celery task ID for tracking
- retry_count (int) Number of retry attempts
- elapsed_time (float) Optional - Time consumption in seconds
- total_tokens (int) Optional - Total tokens used
- created_by_role (string) Creator role: account, end_user
- created_by (string) Creator ID
- created_at (timestamp) Creation time
- triggered_at (timestamp) Optional - When actually triggered
- finished_at (timestamp) Optional - Completion time
"""
__tablename__ = "workflow_trigger_logs"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_trigger_log_pkey"),
sa.Index("workflow_trigger_log_tenant_app_idx", "tenant_id", "app_id"),
sa.Index("workflow_trigger_log_status_idx", "status"),
sa.Index("workflow_trigger_log_created_at_idx", "created_at"),
sa.Index("workflow_trigger_log_workflow_run_idx", "workflow_run_id"),
sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
trigger_type: Mapped[str] = mapped_column(String(50), nullable=False)
trigger_data: Mapped[str] = mapped_column(sa.Text, nullable=False) # Full TriggerData as JSON
inputs: Mapped[str] = mapped_column(sa.Text, nullable=False) # Just inputs for easy viewing
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
status: Mapped[str] = mapped_column(String(50), nullable=False, default=WorkflowTriggerStatus.PENDING)
error: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
queue_name: Mapped[str] = mapped_column(String(100), nullable=False)
celery_task_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
elapsed_time: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
total_tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(String(255), nullable=False)
triggered_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
@property
def created_by_account(self):
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
def to_dict(self) -> dict:
"""Convert to dictionary for API responses"""
return {
"id": self.id,
"tenant_id": self.tenant_id,
"app_id": self.app_id,
"workflow_id": self.workflow_id,
"workflow_run_id": self.workflow_run_id,
"trigger_type": self.trigger_type,
"trigger_data": json.loads(self.trigger_data),
"inputs": json.loads(self.inputs),
"outputs": json.loads(self.outputs) if self.outputs else None,
"status": self.status,
"error": self.error,
"queue_name": self.queue_name,
"celery_task_id": self.celery_task_id,
"retry_count": self.retry_count,
"elapsed_time": self.elapsed_time,
"total_tokens": self.total_tokens,
"created_by_role": self.created_by_role,
"created_by": self.created_by,
"created_at": self.created_at.isoformat() if self.created_at else None,
"triggered_at": self.triggered_at.isoformat() if self.triggered_at else None,
"finished_at": self.finished_at.isoformat() if self.finished_at else None,
}
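A small, illustrative read path under the assumption that a trigger log ID is already known; to_dict() produces the JSON-friendly payload described in the docstring above.

from sqlalchemy import select
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.workflow import WorkflowTriggerLog

trigger_log_id = "00000000-0000-0000-0000-000000000000"  # placeholder ID

with Session(db.engine) as session:
    log = session.scalar(select(WorkflowTriggerLog).where(WorkflowTriggerLog.id == trigger_log_id))
    payload = log.to_dict() if log else None  # serialized trigger_data/inputs/outputs plus timestamps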

View File

@ -0,0 +1,198 @@
"""
SQLAlchemy implementation of WorkflowTriggerLogRepository.
"""
from collections.abc import Sequence
from datetime import datetime, timedelta
from typing import Any, Optional
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.orm import Session
from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus
from repositories.workflow_trigger_log_repository import TriggerLogOrderBy, WorkflowTriggerLogRepository
class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
"""
SQLAlchemy implementation of WorkflowTriggerLogRepository.
Optimized for large table operations with proper indexing and batch processing.
"""
def __init__(self, session: Session):
self.session = session
def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
"""Create a new trigger log entry."""
self.session.add(trigger_log)
self.session.flush()
return trigger_log
def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
"""Update an existing trigger log entry."""
self.session.merge(trigger_log)
self.session.flush()
return trigger_log
def get_by_id(self, trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[WorkflowTriggerLog]:
"""Get a trigger log by its ID."""
query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.id == trigger_log_id)
if tenant_id:
query = query.where(WorkflowTriggerLog.tenant_id == tenant_id)
return self.session.scalar(query)
def get_by_status(
self,
tenant_id: str,
app_id: str,
status: WorkflowTriggerStatus,
limit: int = 100,
offset: int = 0,
order_by: TriggerLogOrderBy = TriggerLogOrderBy.CREATED_AT,
order_desc: bool = True,
) -> Sequence[WorkflowTriggerLog]:
"""Get trigger logs by status with pagination."""
query = select(WorkflowTriggerLog).where(
and_(
WorkflowTriggerLog.tenant_id == tenant_id,
WorkflowTriggerLog.app_id == app_id,
WorkflowTriggerLog.status == status,
)
)
# Apply ordering
order_column = getattr(WorkflowTriggerLog, order_by.value)
if order_desc:
query = query.order_by(order_column.desc())
else:
query = query.order_by(order_column.asc())
# Apply pagination
query = query.limit(limit).offset(offset)
return list(self.session.scalars(query).all())
def get_failed_for_retry(
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> Sequence[WorkflowTriggerLog]:
"""Get failed trigger logs eligible for retry."""
query = (
select(WorkflowTriggerLog)
.where(
and_(
WorkflowTriggerLog.tenant_id == tenant_id,
WorkflowTriggerLog.status.in_([WorkflowTriggerStatus.FAILED, WorkflowTriggerStatus.RATE_LIMITED]),
WorkflowTriggerLog.retry_count < max_retry_count,
)
)
.order_by(WorkflowTriggerLog.created_at.asc())
.limit(limit)
)
return list(self.session.scalars(query).all())
def get_recent_logs(
self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
) -> Sequence[WorkflowTriggerLog]:
"""Get recent trigger logs within specified hours."""
since = datetime.utcnow() - timedelta(hours=hours)
query = (
select(WorkflowTriggerLog)
.where(
and_(
WorkflowTriggerLog.tenant_id == tenant_id,
WorkflowTriggerLog.app_id == app_id,
WorkflowTriggerLog.created_at >= since,
)
)
.order_by(WorkflowTriggerLog.created_at.desc())
.limit(limit)
.offset(offset)
)
return list(self.session.scalars(query).all())
def count_by_status(
self,
tenant_id: str,
app_id: str,
status: Optional[WorkflowTriggerStatus] = None,
since: Optional[datetime] = None,
) -> int:
"""Count trigger logs by status."""
query = select(func.count(WorkflowTriggerLog.id)).where(
and_(WorkflowTriggerLog.tenant_id == tenant_id, WorkflowTriggerLog.app_id == app_id)
)
if status:
query = query.where(WorkflowTriggerLog.status == status)
if since:
query = query.where(WorkflowTriggerLog.created_at >= since)
return self.session.scalar(query) or 0
def delete_expired_logs(self, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int:
"""Delete expired trigger logs in batches."""
total_deleted = 0
while True:
# Get batch of IDs to delete
subquery = (
select(WorkflowTriggerLog.id)
.where(and_(WorkflowTriggerLog.tenant_id == tenant_id, WorkflowTriggerLog.created_at < before_date))
.limit(batch_size)
)
# Delete the batch
result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.id.in_(subquery)))
deleted = result.rowcount
total_deleted += deleted
if deleted < batch_size:
break
self.session.commit()
return total_deleted
def archive_completed_logs(
self, tenant_id: str, before_date: datetime, batch_size: int = 1000
) -> Sequence[WorkflowTriggerLog]:
"""Get completed logs for archival."""
query = (
select(WorkflowTriggerLog)
.where(
and_(
WorkflowTriggerLog.tenant_id == tenant_id,
WorkflowTriggerLog.status == WorkflowTriggerStatus.SUCCEEDED,
WorkflowTriggerLog.finished_at < before_date,
)
)
.limit(batch_size)
)
return list(self.session.scalars(query).all())
def update_status_batch(
self, trigger_log_ids: Sequence[str], new_status: WorkflowTriggerStatus, error_message: Optional[str] = None
) -> int:
"""Update status for multiple trigger logs."""
update_data: dict[str, Any] = {"status": new_status}
if error_message is not None:
update_data["error"] = error_message
if new_status in [WorkflowTriggerStatus.SUCCEEDED, WorkflowTriggerStatus.FAILED]:
update_data["finished_at"] = datetime.utcnow()
result = self.session.execute(
update(WorkflowTriggerLog).where(WorkflowTriggerLog.id.in_(trigger_log_ids)).values(**update_data)
)
return result.rowcount
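An illustrative wiring of the repository to a session for two of the maintenance paths above; the tenant and app IDs are placeholders, and the 30-day retention window is an arbitrary example.

from datetime import datetime, timedelta

from sqlalchemy.orm import Session

from extensions.ext_database import db
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository

with Session(db.engine) as session:
    repo = SQLAlchemyWorkflowTriggerLogRepository(session)
    # Recent activity for one app (last 24 hours, first 50 rows).
    recent = repo.get_recent_logs(tenant_id="tenant-id", app_id="app-id", hours=24, limit=50)
    # Batched purge of logs older than 30 days.
    deleted = repo.delete_expired_logs("tenant-id", before_date=datetime.utcnow() - timedelta(days=30))
    session.commit()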

View File

@ -0,0 +1,206 @@
"""
Repository protocol for WorkflowTriggerLog operations.
This module provides a protocol interface for operations on WorkflowTriggerLog,
designed to efficiently handle a potentially large volume of trigger logs with
proper indexing and batch operations.
"""
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from typing import Optional, Protocol
from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus
class TriggerLogOrderBy(StrEnum):
"""Fields available for ordering trigger logs"""
CREATED_AT = "created_at"
TRIGGERED_AT = "triggered_at"
FINISHED_AT = "finished_at"
STATUS = "status"
class WorkflowTriggerLogRepository(Protocol):
"""
Protocol for operations on WorkflowTriggerLog.
This repository provides efficient access patterns for the trigger log table,
which is expected to grow large over time. It includes:
- Batch operations for cleanup
- Efficient queries with proper indexing
- Pagination support
- Status-based filtering
Implementation notes:
- Leverage database indexes on (tenant_id, app_id), status, and created_at
- Use batch operations for deletions to avoid locking
- Support pagination for large result sets
"""
def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
"""
Create a new trigger log entry.
Args:
trigger_log: The WorkflowTriggerLog instance to create
Returns:
The created WorkflowTriggerLog with generated ID
"""
...
def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
"""
Update an existing trigger log entry.
Args:
trigger_log: The WorkflowTriggerLog instance to update
Returns:
The updated WorkflowTriggerLog
"""
...
def get_by_id(self, trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[WorkflowTriggerLog]:
"""
Get a trigger log by its ID.
Args:
trigger_log_id: The trigger log identifier
tenant_id: Optional tenant identifier for additional security
Returns:
The WorkflowTriggerLog if found, None otherwise
"""
...
def get_by_status(
self,
tenant_id: str,
app_id: str,
status: WorkflowTriggerStatus,
limit: int = 100,
offset: int = 0,
order_by: TriggerLogOrderBy = TriggerLogOrderBy.CREATED_AT,
order_desc: bool = True,
) -> Sequence[WorkflowTriggerLog]:
"""
Get trigger logs by status with pagination.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
status: The workflow trigger status to filter by
limit: Maximum number of results
offset: Number of results to skip
order_by: Field to order results by
order_desc: Whether to order descending (True) or ascending (False)
Returns:
A sequence of WorkflowTriggerLog instances
"""
...
def get_failed_for_retry(
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> Sequence[WorkflowTriggerLog]:
"""
Get failed trigger logs that are eligible for retry.
Args:
tenant_id: The tenant identifier
max_retry_count: Maximum retry count to consider
limit: Maximum number of results
Returns:
A sequence of WorkflowTriggerLog instances eligible for retry
"""
...
def get_recent_logs(
self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
) -> Sequence[WorkflowTriggerLog]:
"""
Get recent trigger logs within specified hours.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
hours: Number of hours to look back
limit: Maximum number of results
offset: Number of results to skip
Returns:
A sequence of recent WorkflowTriggerLog instances
"""
...
def count_by_status(
self,
tenant_id: str,
app_id: str,
status: Optional[WorkflowTriggerStatus] = None,
since: Optional[datetime] = None,
) -> int:
"""
Count trigger logs by status.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
status: Optional status filter
since: Optional datetime to count from
Returns:
Count of matching trigger logs
"""
...
def delete_expired_logs(self, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int:
"""
Delete expired trigger logs in batches.
Args:
tenant_id: The tenant identifier
before_date: Delete logs created before this date
batch_size: Number of logs to delete per batch
Returns:
Total number of logs deleted
"""
...
def archive_completed_logs(
self, tenant_id: str, before_date: datetime, batch_size: int = 1000
) -> Sequence[WorkflowTriggerLog]:
"""
Get completed logs for archival before deletion.
Args:
tenant_id: The tenant identifier
before_date: Get logs completed before this date
batch_size: Number of logs to retrieve
Returns:
A sequence of WorkflowTriggerLog instances for archival
"""
...
def update_status_batch(
self, trigger_log_ids: Sequence[str], new_status: WorkflowTriggerStatus, error_message: Optional[str] = None
) -> int:
"""
Update status for multiple trigger logs at once.
Args:
trigger_log_ids: List of trigger log IDs to update
new_status: The new status to set
error_message: Optional error message to set
Returns:
Number of logs updated
"""
...
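Because this is a Protocol, callers can be typed against the interface rather than the SQLAlchemy implementation; a hypothetical helper illustrating that, with the 24-hour window chosen arbitrarily.

from datetime import datetime, timedelta

from models.workflow import WorkflowTriggerStatus
from repositories.workflow_trigger_log_repository import WorkflowTriggerLogRepository


def count_recent_failures(repo: WorkflowTriggerLogRepository, tenant_id: str, app_id: str) -> int:
    # Any object satisfying the protocol (e.g. SQLAlchemyWorkflowTriggerLogRepository) works here.
    return repo.count_by_status(
        tenant_id=tenant_id,
        app_id=app_id,
        status=WorkflowTriggerStatus.FAILED,
        since=datetime.utcnow() - timedelta(hours=24),
    )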

View File

@ -0,0 +1,320 @@
"""
Universal async workflow execution service.
This service provides a centralized entry point for triggering workflows asynchronously
with support for different subscription tiers, rate limiting, and execution tracking.
"""
import json
from datetime import UTC, datetime
from typing import Optional, Union
from celery.result import AsyncResult
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.enums import CreatorUserRole
from models.model import App, EndUser
from models.workflow import Workflow, WorkflowTriggerLog, WorkflowTriggerStatus
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.llm import InvokeRateLimitError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow.rate_limiter import TenantDailyRateLimiter
from services.workflow_service import WorkflowService
from tasks.async_workflow_tasks import (
execute_workflow_professional,
execute_workflow_sandbox,
execute_workflow_team,
)
class AsyncWorkflowService:
"""
Universal entry point for async workflow execution - ALL METHODS ARE NON-BLOCKING
This service handles:
- Trigger data validation and processing
- Queue routing based on subscription tier
- Daily rate limiting with timezone support
- Execution tracking and logging
- Retry mechanisms for failed executions
Important: All trigger methods return immediately after queuing tasks.
Actual workflow execution happens asynchronously in background Celery workers.
Use trigger log IDs to monitor execution status and results.
"""
@classmethod
def trigger_workflow_async(
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
) -> AsyncTriggerResponse:
"""
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
Creates a trigger log and dispatches to appropriate queue based on subscription tier.
The workflow execution happens asynchronously in the background via Celery workers.
This method returns immediately after queuing the task, not after execution completion.
Args:
session: Database session to use for operations
user: User (Account or EndUser) who initiated the workflow trigger
trigger_data: Validated Pydantic model containing trigger information
Returns:
AsyncTriggerResponse with workflow_trigger_log_id, task_id, status="queued", and queue
Note: The actual workflow execution status must be checked separately via workflow_trigger_log_id
Raises:
ValueError: If app or workflow not found
InvokeRateLimitError: If daily rate limit exceeded
Behavior:
- Non-blocking: Returns immediately after queuing
- Asynchronous: Actual execution happens in background Celery workers
- Status tracking: Use workflow_trigger_log_id to monitor progress
- Queue-based: Routes to different queues based on subscription tier
"""
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
dispatcher_manager = QueueDispatcherManager()
workflow_service = WorkflowService()
rate_limiter = TenantDailyRateLimiter(redis_client)
# 1. Validate app exists
app_model = session.scalar(select(App).where(App.id == trigger_data.app_id))
if not app_model:
raise ValueError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
# 4. Get tenant owner timezone for rate limiting
tenant_owner_tz = rate_limiter._get_tenant_owner_timezone(trigger_data.tenant_id)
# 5. Determine user role and ID
if isinstance(user, Account):
created_by_role = CreatorUserRole.ACCOUNT
created_by = user.id
else: # EndUser
created_by_role = CreatorUserRole.END_USER
created_by = user.id
# 6. Create trigger log entry first (for tracking)
trigger_log = WorkflowTriggerLog(
tenant_id=trigger_data.tenant_id,
app_id=trigger_data.app_id,
workflow_id=workflow.id,
trigger_type=trigger_data.trigger_type,
trigger_data=trigger_data.model_dump_json(),
inputs=json.dumps(dict(trigger_data.inputs)),
status=WorkflowTriggerStatus.PENDING,
queue_name=dispatcher.get_queue_name(),
retry_count=0,
created_by_role=created_by_role,
created_by=created_by,
)
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume daily quota
if not dispatcher.consume_quota(trigger_data.tenant_id, tenant_owner_tz):
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
trigger_log.error = f"Daily limit reached for {dispatcher.get_queue_name()}"
trigger_log_repo.update(trigger_log)
session.commit()
remaining = rate_limiter.get_remaining_quota(
trigger_data.tenant_id, dispatcher.get_daily_limit(), tenant_owner_tz
)
reset_time = rate_limiter.get_quota_reset_time(trigger_data.tenant_id, tenant_owner_tz)
raise InvokeRateLimitError(
f"Daily workflow execution limit reached. "
f"Limit resets at {reset_time.strftime('%Y-%m-%d %H:%M:%S %Z')}. "
f"Remaining quota: {remaining}"
)
# 8. Create task data
queue_name = dispatcher.get_queue_name()
task_data = WorkflowTaskData(workflow_trigger_log_id=trigger_log.id)
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict) # type: ignore
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict) # type: ignore
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore
if not task:
raise ValueError(f"Failed to queue task for queue: {queue_name}")
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
trigger_log.celery_task_id = task.id
trigger_log.triggered_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
return AsyncTriggerResponse(
workflow_trigger_log_id=trigger_log.id,
task_id=task.id, # type: ignore
status="queued",
queue=queue_name,
)
@classmethod
def reinvoke_trigger(
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
) -> AsyncTriggerResponse:
"""
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
Updates the existing trigger log to retry status and creates a new async execution.
Returns immediately after queuing the retry, not after execution completion.
Args:
session: Database session to use for operations
user: User (Account or EndUser) who initiated the retry
workflow_trigger_log_id: ID of the trigger log to re-invoke
Returns:
AsyncTriggerResponse with new execution information (status="queued")
Note: This creates a new trigger log entry for the retry attempt
Raises:
ValueError: If trigger log not found
Behavior:
- Non-blocking: Returns immediately after queuing retry
- Creates new trigger log: Original log marked as retrying, new log for execution
- Preserves original trigger data: Uses same inputs and configuration
"""
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id)
if not trigger_log:
raise ValueError(f"Trigger log not found: {workflow_trigger_log_id}")
# Reconstruct trigger data from log
trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
# Reset log for retry
trigger_log.status = WorkflowTriggerStatus.RETRYING
trigger_log.retry_count += 1
trigger_log.error = None
trigger_log.triggered_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
# Re-trigger workflow (this will create a new trigger log)
return cls.trigger_workflow_async(session, user, trigger_data)
@classmethod
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[dict]:
"""
Get trigger log by ID
Args:
workflow_trigger_log_id: ID of the trigger log
tenant_id: Optional tenant ID for security check
Returns:
Trigger log as dictionary or None if not found
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
if not trigger_log:
return None
return trigger_log.to_dict()
@classmethod
def get_recent_logs(
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
) -> list[dict]:
"""
Get recent trigger logs
Args:
tenant_id: Tenant ID
app_id: Application ID
hours: Number of hours to look back
limit: Maximum number of results
offset: Number of results to skip
Returns:
List of trigger logs as dictionaries
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_recent_logs(
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
)
return [log.to_dict() for log in logs]
@classmethod
def get_failed_logs_for_retry(cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100) -> list[dict]:
"""
Get failed logs eligible for retry
Args:
tenant_id: Tenant ID
max_retry_count: Maximum retry count
limit: Maximum number of results
Returns:
List of failed trigger logs as dictionaries
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_failed_for_retry(
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
)
return [log.to_dict() for log in logs]
@staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: Optional[str] = None) -> Workflow:
"""
Get workflow for the app
Args:
workflow_service: WorkflowService used to look up published workflows
app_model: App model instance
workflow_id: Optional specific workflow ID
Returns:
Workflow instance
Raises:
ValueError: If workflow not found
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
if not workflow:
raise ValueError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model)
if not workflow:
raise ValueError(f"No published workflow found for app: {app_model.id}")
return workflow
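A hedged sketch of the non-blocking entry point from calling code; the service's module path is not shown in this diff and is assumed here, and the app/tenant IDs, inputs, and webhook URL are placeholders.

from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.account import Account
from services.workflow.entities import AsyncTriggerResponse, WebhookTriggerData


def queue_webhook_run(account: Account, app_id: str, tenant_id: str) -> AsyncTriggerResponse:
    # Module path assumed; adjust to wherever AsyncWorkflowService lives in this change set.
    from services.async_workflow_service import AsyncWorkflowService

    trigger = WebhookTriggerData(
        app_id=app_id,
        tenant_id=tenant_id,
        inputs={"query": "hello"},                     # placeholder inputs
        webhook_url="https://example.com/hooks/dify",  # placeholder URL
    )
    with Session(db.engine) as session:
        # Returns as soon as the Celery task is queued; execution happens in a worker,
        # and the returned workflow_trigger_log_id can be polled for the outcome.
        return AsyncWorkflowService.trigger_workflow_async(session, account, trigger)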

View File

@ -0,0 +1,113 @@
"""
Pydantic models for async workflow trigger system.
"""
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field
from models.enums import WorkflowRunTriggeredFrom
class AsyncTriggerStatus(StrEnum):
"""Async trigger execution status"""
COMPLETED = "completed"
FAILED = "failed"
TIMEOUT = "timeout"
class TriggerData(BaseModel):
"""Base trigger data model for async workflow execution"""
app_id: str
tenant_id: str
workflow_id: Optional[str] = None
inputs: Mapping[str, Any]
files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
trigger_type: WorkflowRunTriggeredFrom
model_config = ConfigDict(use_enum_values=True)
class WebhookTriggerData(TriggerData):
"""Webhook-specific trigger data"""
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK
webhook_url: str
headers: Mapping[str, str] = Field(default_factory=dict)
method: str = "POST"
class ScheduleTriggerData(TriggerData):
"""Schedule-specific trigger data"""
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE
schedule_id: str
cron_expression: str
class PluginTriggerData(TriggerData):
"""Plugin webhook trigger data"""
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN
plugin_id: str
webhook_url: str
class WorkflowTaskData(BaseModel):
"""Lightweight data structure for Celery workflow tasks"""
workflow_trigger_log_id: str # Primary tracking ID - all other data can be fetched from DB
model_config = ConfigDict(arbitrary_types_allowed=True)
class AsyncTriggerExecutionResult(BaseModel):
"""Result from async trigger-based workflow execution"""
execution_id: str
status: AsyncTriggerStatus
result: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
elapsed_time: Optional[float] = None
total_tokens: Optional[int] = None
model_config = ConfigDict(use_enum_values=True)
class AsyncTriggerResponse(BaseModel):
"""Response from triggering an async workflow"""
workflow_trigger_log_id: str
task_id: str
status: str
queue: str
model_config = ConfigDict(use_enum_values=True)
class TriggerLogResponse(BaseModel):
"""Response model for trigger log data"""
id: str
tenant_id: str
app_id: str
workflow_id: str
trigger_type: WorkflowRunTriggeredFrom
status: str
queue_name: str
retry_count: int
celery_task_id: Optional[str] = None
workflow_run_id: Optional[str] = None
error: Optional[str] = None
outputs: Optional[str] = None
elapsed_time: Optional[float] = None
total_tokens: Optional[int] = None
created_at: Optional[str] = None
triggered_at: Optional[str] = None
finished_at: Optional[str] = None
model_config = ConfigDict(use_enum_values=True)
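The service stores the full trigger payload as JSON (trigger_data) and the Celery task reconstructs it later; a minimal round-trip sketch with placeholder IDs.

from services.workflow.entities import ScheduleTriggerData, TriggerData

data = ScheduleTriggerData(
    app_id="app-id",          # placeholder
    tenant_id="tenant-id",    # placeholder
    inputs={"query": "hello"},
    schedule_id="schedule-1",
    cron_expression="0 9 * * *",
)
raw = data.model_dump_json()                     # persisted in WorkflowTriggerLog.trigger_data
restored = TriggerData.model_validate_json(raw)  # what the worker reads back before executing
assert restored.inputs == {"query": "hello"}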

View File

@ -0,0 +1,158 @@
"""
Queue dispatcher system for async workflow execution.
Implements an ABC-based pattern for handling different subscription tiers
with appropriate queue routing and rate limiting.
"""
import os
from abc import ABC, abstractmethod
from enum import StrEnum
from configs import dify_config
from extensions.ext_redis import redis_client
from services.billing_service import BillingService
from services.workflow.rate_limiter import TenantDailyRateLimiter
class QueuePriority(StrEnum):
"""Queue priorities for different subscription tiers"""
PROFESSIONAL = "workflow_professional" # Highest priority
TEAM = "workflow_team"
SANDBOX = "workflow_sandbox" # Free tier
class BaseQueueDispatcher(ABC):
"""Abstract base class for queue dispatchers"""
def __init__(self):
self.rate_limiter = TenantDailyRateLimiter(redis_client)
@abstractmethod
def get_queue_name(self) -> str:
"""Get the queue name for this dispatcher"""
pass
@abstractmethod
def get_daily_limit(self) -> int:
"""Get daily execution limit"""
pass
@abstractmethod
def get_priority(self) -> int:
"""Get task priority level"""
pass
def check_daily_quota(self, tenant_id: str, tenant_owner_tz: str) -> bool:
"""
Check if tenant has remaining daily quota
Args:
tenant_id: The tenant identifier
tenant_owner_tz: Tenant owner's timezone
Returns:
True if quota available, False otherwise
"""
# Check without consuming
remaining = self.rate_limiter.get_remaining_quota(
tenant_id=tenant_id, max_daily_limit=self.get_daily_limit(), timezone_str=tenant_owner_tz
)
return remaining > 0
def consume_quota(self, tenant_id: str, tenant_owner_tz: str) -> bool:
"""
Consume one execution from daily quota
Args:
tenant_id: The tenant identifier
tenant_owner_tz: Tenant owner's timezone
Returns:
True if quota consumed successfully, False if limit reached
"""
return self.rate_limiter.check_and_consume(
tenant_id=tenant_id, max_daily_limit=self.get_daily_limit(), timezone_str=tenant_owner_tz
)
class ProfessionalQueueDispatcher(BaseQueueDispatcher):
"""Dispatcher for professional tier"""
def get_queue_name(self) -> str:
return QueuePriority.PROFESSIONAL
def get_daily_limit(self) -> int:
return int(os.getenv("PROFESSIONAL_DAILY_LIMIT", "1000"))
def get_priority(self) -> int:
return 100
class TeamQueueDispatcher(BaseQueueDispatcher):
"""Dispatcher for team tier"""
def get_queue_name(self) -> str:
return QueuePriority.TEAM
def get_daily_limit(self) -> int:
return int(os.getenv("TEAM_DAILY_LIMIT", "100"))
def get_priority(self) -> int:
return 50
class SandboxQueueDispatcher(BaseQueueDispatcher):
"""Dispatcher for free/sandbox tier"""
def get_queue_name(self) -> str:
return QueuePriority.SANDBOX
def get_daily_limit(self) -> int:
return int(os.getenv("SANDBOX_DAILY_LIMIT", "10"))
def get_priority(self) -> int:
return 10
class QueueDispatcherManager:
"""Factory for creating appropriate dispatcher based on tenant subscription"""
# Mapping of billing plans to dispatchers
PLAN_DISPATCHER_MAP = {
"professional": ProfessionalQueueDispatcher,
"team": TeamQueueDispatcher,
"sandbox": SandboxQueueDispatcher,
# Add new tiers here as they're created
# For any unknown plan, default to sandbox
}
@classmethod
def get_dispatcher(cls, tenant_id: str) -> BaseQueueDispatcher:
"""
Get dispatcher based on tenant's subscription plan
Args:
tenant_id: The tenant identifier
Returns:
Appropriate queue dispatcher instance
"""
if dify_config.BILLING_ENABLED:
try:
billing_info = BillingService.get_info(tenant_id)
plan = billing_info.get("subscription", {}).get("plan", "sandbox")
except Exception:
# If billing service fails, default to sandbox
plan = "sandbox"
else:
# If billing is disabled, use team tier as default
plan = "team"
dispatcher_class = cls.PLAN_DISPATCHER_MAP.get(
plan,
SandboxQueueDispatcher, # Default to sandbox for unknown plans
)
return dispatcher_class()
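A short sketch of how the service layer resolves a dispatcher and spends quota before queuing; the tenant ID and timezone are placeholders.

from services.workflow.queue_dispatcher import QueueDispatcherManager

tenant_id = "tenant-id"  # placeholder

dispatcher = QueueDispatcherManager.get_dispatcher(tenant_id)
queue = dispatcher.get_queue_name()    # e.g. "workflow_team" when billing is disabled
limit = dispatcher.get_daily_limit()   # tier-specific daily cap
# consume_quota() counts this run against today's quota and reports whether it may proceed.
allowed = dispatcher.consume_quota(tenant_id, tenant_owner_tz="UTC")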

View File

@ -0,0 +1,206 @@
"""
Day-based rate limiter for workflow executions.
Implements timezone-aware daily quotas that reset at midnight in the tenant owner's timezone.
"""
from datetime import datetime, time, timedelta
from typing import Optional, Union
import pytz
from redis import Redis
from sqlalchemy import select
from extensions.ext_database import db
from extensions.ext_redis import RedisClientWrapper
from models.account import Account, TenantAccountJoin, TenantAccountRole
class TenantDailyRateLimiter:
"""
Day-based rate limiter that resets at midnight in tenant owner's timezone
This class provides Redis-based rate limiting with the following features:
- Daily quotas that reset at midnight in tenant owner's timezone
- Atomic check-and-consume operations
- Automatic cleanup of stale counters
- Support for timezone changes without duplicate limits
"""
def __init__(self, redis_client: Union[Redis, RedisClientWrapper]):
self.redis = redis_client
def _get_tenant_owner_timezone(self, tenant_id: str) -> str:
"""
Get timezone of tenant owner
Args:
tenant_id: The tenant identifier
Returns:
Timezone string (e.g., 'America/New_York', 'UTC')
"""
# Query to get tenant owner's timezone using scalar and select
owner = db.session.scalar(
select(Account)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == TenantAccountRole.OWNER)
)
if not owner:
return "UTC"
return owner.timezone or "UTC"
def _get_day_key(self, tenant_id: str, timezone_str: str) -> str:
"""
Get Redis key for current day in tenant's timezone
Args:
tenant_id: The tenant identifier
timezone_str: Timezone string
Returns:
Redis key for the current day
"""
tz = pytz.timezone(timezone_str)
now = datetime.now(tz)
date_str = now.strftime("%Y-%m-%d")
return f"workflow:daily_limit:{tenant_id}:{date_str}:{timezone_str}"
def _get_ttl_seconds(self, timezone_str: str) -> int:
"""
Calculate seconds until midnight in given timezone
Args:
timezone_str: Timezone string
Returns:
Number of seconds until midnight
"""
tz = pytz.timezone(timezone_str)
now = datetime.now(tz)
# Get next midnight in the timezone
midnight = tz.localize(datetime.combine(now.date() + timedelta(days=1), time.min))
return int((midnight - now).total_seconds())
def check_and_consume(self, tenant_id: str, max_daily_limit: int, timezone_str: Optional[str] = None) -> bool:
"""
Check if quota available and consume one execution
Args:
tenant_id: The tenant identifier
max_daily_limit: Maximum daily limit
timezone_str: Optional timezone string (will be fetched if not provided)
Returns:
True if quota consumed successfully, False if limit reached
"""
if not timezone_str:
timezone_str = self._get_tenant_owner_timezone(tenant_id)
key = self._get_day_key(tenant_id, timezone_str)
ttl = self._get_ttl_seconds(timezone_str)
# Check current usage
current = self.redis.get(key)
if current is None:
# First execution of the day - set to 1
self.redis.setex(key, ttl, 1)
return True
current_count = int(current)
if current_count < max_daily_limit:
# Within limit, increment
new_count = self.redis.incr(key)
# Update TTL in case timezone changed
self.redis.expire(key, ttl)
# Double-check in case of race condition
if new_count <= max_daily_limit:
return True
else:
# Race condition occurred, decrement back
self.redis.decr(key)
return False
else:
# Limit exceeded
return False
def get_remaining_quota(self, tenant_id: str, max_daily_limit: int, timezone_str: Optional[str] = None) -> int:
"""
Get remaining quota for the day
Args:
tenant_id: The tenant identifier
max_daily_limit: Maximum daily limit
timezone_str: Optional timezone string (will be fetched if not provided)
Returns:
Number of remaining executions for the day
"""
if not timezone_str:
timezone_str = self._get_tenant_owner_timezone(tenant_id)
key = self._get_day_key(tenant_id, timezone_str)
used = int(self.redis.get(key) or 0)
return max(0, max_daily_limit - used)
def get_current_usage(self, tenant_id: str, timezone_str: Optional[str] = None) -> int:
"""
Get current usage for the day
Args:
tenant_id: The tenant identifier
timezone_str: Optional timezone string (will be fetched if not provided)
Returns:
Number of executions used today
"""
if not timezone_str:
timezone_str = self._get_tenant_owner_timezone(tenant_id)
key = self._get_day_key(tenant_id, timezone_str)
return int(self.redis.get(key) or 0)
def reset_quota(self, tenant_id: str, timezone_str: Optional[str] = None) -> bool:
"""
Reset quota for testing purposes
Args:
tenant_id: The tenant identifier
timezone_str: Optional timezone string (will be fetched if not provided)
Returns:
True if key was deleted, False if key didn't exist
"""
if not timezone_str:
timezone_str = self._get_tenant_owner_timezone(tenant_id)
key = self._get_day_key(tenant_id, timezone_str)
return bool(self.redis.delete(key))
def get_quota_reset_time(self, tenant_id: str, timezone_str: Optional[str] = None) -> datetime:
"""
Get the time when quota will reset (midnight in tenant's timezone)
Args:
tenant_id: The tenant identifier
timezone_str: Optional timezone string (will be fetched if not provided)
Returns:
Datetime when quota resets
"""
if not timezone_str:
timezone_str = self._get_tenant_owner_timezone(tenant_id)
tz = pytz.timezone(timezone_str)
now = datetime.now(tz)
# Get next midnight in the timezone
midnight = tz.localize(datetime.combine(now.date() + timedelta(days=1), time.min))
return midnight
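Direct use of the limiter outside the dispatcher, assuming the shared redis_client extension and a placeholder tenant; the timezone is passed explicitly here to skip the owner lookup.

from extensions.ext_redis import redis_client
from services.workflow.rate_limiter import TenantDailyRateLimiter

limiter = TenantDailyRateLimiter(redis_client)
tenant_id = "tenant-id"  # placeholder

if limiter.check_and_consume(tenant_id, max_daily_limit=100, timezone_str="America/New_York"):
    remaining = limiter.get_remaining_quota(tenant_id, max_daily_limit=100, timezone_str="America/New_York")
else:
    reset_at = limiter.get_quota_reset_time(tenant_id, timezone_str="America/New_York")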

View File

@ -0,0 +1,201 @@
"""
Celery tasks for async workflow execution.
These tasks handle workflow execution for different subscription tiers
with appropriate retry policies and error handling.
"""
import json
from datetime import UTC, datetime
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole
from models.model import App, EndUser, Tenant
from models.workflow import Workflow, WorkflowTriggerLog, WorkflowTriggerStatus
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.workflow.entities import AsyncTriggerExecutionResult, AsyncTriggerStatus, TriggerData, WorkflowTaskData
# Determine queue names based on edition
if dify_config.EDITION == "CLOUD":
# Cloud edition: separate queues for different tiers
PROFESSIONAL_QUEUE = "workflow_professional"
TEAM_QUEUE = "workflow_team"
SANDBOX_QUEUE = "workflow_sandbox"
else:
# Community edition: all tiers share a single workflow queue (separate from the dataset queue)
PROFESSIONAL_QUEUE = "workflow"
TEAM_QUEUE = "workflow"
SANDBOX_QUEUE = "workflow"
@shared_task(queue=PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict) -> dict:
"""Execute workflow for professional tier with highest priority"""
task_data = WorkflowTaskData.model_validate(task_data_dict)
return _execute_workflow_common(task_data).model_dump()
@shared_task(queue=TEAM_QUEUE)
def execute_workflow_team(task_data_dict: dict) -> dict:
"""Execute workflow for team tier"""
task_data = WorkflowTaskData.model_validate(task_data_dict)
return _execute_workflow_common(task_data).model_dump()
@shared_task(queue=SANDBOX_QUEUE)
def execute_workflow_sandbox(task_data_dict: dict) -> dict:
"""Execute workflow for free tier with lower retry limit"""
task_data = WorkflowTaskData.model_validate(task_data_dict)
return _execute_workflow_common(task_data).model_dump()
def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult:
"""
Common workflow execution logic with trigger log updates
Args:
task_data: Validated Pydantic model with task data
Returns:
AsyncTriggerExecutionResult: Pydantic model with execution results
"""
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
# Get trigger log
trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)
if not trigger_log:
# This should not happen, but handle gracefully
return AsyncTriggerExecutionResult(
execution_id=task_data.workflow_trigger_log_id,
status=AsyncTriggerStatus.FAILED,
error=f"Trigger log not found: {task_data.workflow_trigger_log_id}",
)
# Reconstruct execution data from trigger log
trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
# Update status to running
trigger_log.status = WorkflowTriggerStatus.RUNNING
trigger_log_repo.update(trigger_log)
session.commit()
start_time = datetime.now(UTC)
try:
# Get app and workflow models
app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))
if not app_model:
raise ValueError(f"App not found: {trigger_log.app_id}")
workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {trigger_log.workflow_id}")
user = _get_user(session, trigger_log)
# Execute workflow using WorkflowAppGenerator
generator = WorkflowAppGenerator()
# Prepare args matching AppGenerateService.generate format
args = {"inputs": dict(trigger_data.inputs), "files": list(trigger_data.files)}
# If workflow_id was specified, add it to args
if trigger_data.workflow_id:
args["workflow_id"] = trigger_data.workflow_id
# Execute the workflow with the trigger type
result = generator.generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
workflow_thread_pool_id=None,
triggered_from=trigger_data.trigger_type,
)
# Calculate elapsed time
elapsed_time = (datetime.now(UTC) - start_time).total_seconds()
# Extract relevant data from result
if isinstance(result, dict):
workflow_run_id = result.get("workflow_run_id")
total_tokens = result.get("total_tokens")
outputs = result
else:
# Streaming (generator) result: not collected here, record a placeholder marker
workflow_run_id = None
total_tokens = None
outputs = {"data": "streaming_result"}
# Update trigger log with success
trigger_log.status = WorkflowTriggerStatus.SUCCEEDED
trigger_log.workflow_run_id = workflow_run_id
trigger_log.outputs = json.dumps(outputs)
trigger_log.elapsed_time = elapsed_time
trigger_log.total_tokens = total_tokens
trigger_log.finished_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
return AsyncTriggerExecutionResult(
execution_id=trigger_log.id,
status=AsyncTriggerStatus.COMPLETED,
result=outputs,
elapsed_time=elapsed_time,
total_tokens=total_tokens,
)
except Exception as e:
# Calculate elapsed time for failed execution
elapsed_time = (datetime.now(UTC) - start_time).total_seconds()
# Update trigger log with failure
trigger_log.status = WorkflowTriggerStatus.FAILED
trigger_log.error = str(e)
trigger_log.finished_at = datetime.now(UTC)
trigger_log.elapsed_time = elapsed_time
trigger_log_repo.update(trigger_log)
# Final failure - no retry logic (simplified like RAG tasks)
session.commit()
return AsyncTriggerExecutionResult(
execution_id=trigger_log.id, status=AsyncTriggerStatus.FAILED, error=str(e), elapsed_time=elapsed_time
)
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
"""Compose user from trigger log"""
tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
if not tenant:
raise ValueError(f"Tenant not found: {trigger_log.tenant_id}")
# Get user from trigger log
if trigger_log.created_by_role == CreatorUserRole.ACCOUNT:
user = session.scalar(select(Account).where(Account.id == trigger_log.created_by))
if user:
user.current_tenant = tenant
else: # CreatorUserRole.END_USER
user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by))
if not user:
raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})")
return user
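How a trigger log ID reaches a tier queue and how the Celery result can be inspected, under the assumption that a trigger log row already exists for the placeholder ID; in practice the trigger log, not the task result, is the source of truth.

from celery.result import AsyncResult

from services.workflow.entities import WorkflowTaskData
from tasks.async_workflow_tasks import execute_workflow_sandbox

task_data = WorkflowTaskData(workflow_trigger_log_id="trigger-log-id")  # placeholder ID
async_result: AsyncResult = execute_workflow_sandbox.delay(task_data.model_dump(mode="json"))

# Blocks until the worker finishes; the dict is AsyncTriggerExecutionResult.model_dump().
outcome = async_result.get(timeout=300)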

View File

@ -2,10 +2,100 @@
set -x
# Help function
show_help() {
echo "Usage: $0 [OPTIONS]"
echo ""
echo "Options:"
echo " -q, --queues QUEUES Comma-separated list of queues to process"
echo " -c, --concurrency NUM Number of worker processes (default: 1)"
echo " -P, --pool POOL Pool implementation (default: gevent)"
echo " --loglevel LEVEL Log level (default: INFO)"
echo " -h, --help Show this help message"
echo ""
echo "Examples:"
echo " $0 --queues dataset,workflow"
echo " $0 --queues workflow_professional,workflow_team --concurrency 4"
echo " $0 --queues dataset --concurrency 2 --pool prefork"
echo ""
echo "Available queues:"
echo " dataset - RAG indexing and document processing"
echo " workflow - Workflow triggers (community edition)"
echo " workflow_professional - Professional tier workflows (cloud edition)"
echo " workflow_team - Team tier workflows (cloud edition)"
echo " workflow_sandbox - Sandbox tier workflows (cloud edition)"
echo " generation - Content generation tasks"
echo " mail - Email notifications"
echo " ops_trace - Operations tracing"
echo " app_deletion - Application cleanup"
echo " plugin - Plugin operations"
echo " workflow_storage - Workflow storage tasks"
}
# Parse command line arguments
QUEUES=""
CONCURRENCY=1
POOL="gevent"
LOGLEVEL="INFO"
while [[ $# -gt 0 ]]; do
case $1 in
-q|--queues)
QUEUES="$2"
shift 2
;;
-c|--concurrency)
CONCURRENCY="$2"
shift 2
;;
-P|--pool)
POOL="$2"
shift 2
;;
--loglevel)
LOGLEVEL="$2"
shift 2
;;
-h|--help)
show_help
exit 0
;;
*)
echo "Unknown option: $1"
show_help
exit 1
;;
esac
done
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/.."
# If no queues specified, use edition-based defaults
if [[ -z "${QUEUES}" ]]; then
# Get EDITION from environment, default to SELF_HOSTED (community edition)
EDITION=${EDITION:-"SELF_HOSTED"}
# Configure queues based on edition
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
QUEUES="dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow_professional,workflow_team,workflow_sandbox"
else
# Community edition (SELF_HOSTED): dataset and workflow have separate queues
QUEUES="dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow"
fi
echo "No queues specified, using edition-based defaults: ${QUEUES}"
else
echo "Using specified queues: ${QUEUES}"
fi
echo "Starting Celery worker with:"
echo " Queues: ${QUEUES}"
echo " Concurrency: ${CONCURRENCY}"
echo " Pool: ${POOL}"
echo " Log Level: ${LOGLEVEL}"
uv --directory api run \
celery -A app.celery worker \
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
-P ${POOL} -c ${CONCURRENCY} --loglevel ${LOGLEVEL} -Q ${QUEUES}

View File

@ -29,14 +29,14 @@ services:
- default
# worker service
# The Celery worker for processing the queue.
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.7.2
restart: always
environment:
# Use the shared environment variables.
<<: *shared-api-worker-env
# Startup mode, 'worker' starts the Celery worker for processing the queue.
# Startup mode, 'worker' starts the Celery worker for processing all queues.
MODE: worker
SENTRY_DSN: ${API_SENTRY_DSN:-}
SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}