mirror of https://github.com/langgenius/dify.git

feat/trigger universal entry (#24358)

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>

parent 8e93a8a2e2
commit 6aed7e3ff4
@@ -54,7 +54,7 @@
                "--loglevel",
                "DEBUG",
                "-Q",
                "dataset,generation,mail,ops_trace,app_deletion"
                "dataset,generation,mail,ops_trace,app_deletion,workflow"
            ]
        }
    ]
@@ -54,6 +54,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: Literal[True],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
    ) -> Generator[Mapping | str, None, None]: ...

    @overload

@@ -68,6 +69,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: Literal[False],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
    ) -> Mapping[str, Any]: ...

    @overload

@@ -82,6 +84,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: bool,
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(

@@ -95,6 +98,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
        triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
        files: Sequence[Mapping[str, Any]] = args.get("files") or []

@@ -159,7 +163,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
        # Create session factory
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        # Create workflow execution(aka workflow run) repository
        if invoke_from == InvokeFrom.DEBUGGER:
        if triggered_from is not None:
            # Use explicitly provided triggered_from (for async triggers)
            workflow_triggered_from = triggered_from
        elif invoke_from == InvokeFrom.DEBUGGER:
            workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
        else:
            workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
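A minimal sketch of the precedence the generator now applies when resolving the run source; resolve_triggered_from is a hypothetical helper written for illustration, not a function added by this commit:

    from typing import Optional

    from core.app.entities.app_invoke_entities import InvokeFrom
    from models.enums import WorkflowRunTriggeredFrom


    def resolve_triggered_from(
        triggered_from: Optional[WorkflowRunTriggeredFrom],
        invoke_from: InvokeFrom,
    ) -> WorkflowRunTriggeredFrom:
        # An explicit value wins (async triggers pass WEBHOOK/SCHEDULE/PLUGIN);
        # otherwise fall back to the previous debugger/app-run inference.
        if triggered_from is not None:
            return triggered_from
        if invoke_from == InvokeFrom.DEBUGGER:
            return WorkflowRunTriggeredFrom.DEBUGGING
        return WorkflowRunTriggeredFrom.APP_RUN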
@@ -30,9 +30,41 @@ if [[ "${MODE}" == "worker" ]]; then
    CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
  fi

  exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
  # Configure queues based on edition if not explicitly set
  if [[ -z "${CELERY_QUEUES}" ]]; then
    if [[ "${EDITION}" == "CLOUD" ]]; then
      # Cloud edition: separate queues for dataset and trigger tasks
      DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow_professional,workflow_team,workflow_sandbox"
    else
      # Community edition (SELF_HOSTED): dataset and workflow have separate queues
      DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow"
    fi
  else
    DEFAULT_QUEUES="${CELERY_QUEUES}"
  fi

  # Support for Kubernetes deployment with specific queue workers
  # Environment variables that can be set:
  # - CELERY_WORKER_QUEUES: Comma-separated list of queues (overrides CELERY_QUEUES)
  # - CELERY_WORKER_CONCURRENCY: Number of worker processes (overrides CELERY_WORKER_AMOUNT)
  # - CELERY_WORKER_POOL: Pool implementation (overrides CELERY_WORKER_CLASS)

  if [[ -n "${CELERY_WORKER_QUEUES}" ]]; then
    DEFAULT_QUEUES="${CELERY_WORKER_QUEUES}"
    echo "Using CELERY_WORKER_QUEUES: ${DEFAULT_QUEUES}"
  fi

  if [[ -n "${CELERY_WORKER_CONCURRENCY}" ]]; then
    CONCURRENCY_OPTION="-c ${CELERY_WORKER_CONCURRENCY}"
    echo "Using CELERY_WORKER_CONCURRENCY: ${CELERY_WORKER_CONCURRENCY}"
  fi

  WORKER_POOL="${CELERY_WORKER_POOL:-${CELERY_WORKER_CLASS:-gevent}}"
  echo "Starting Celery worker with queues: ${DEFAULT_QUEUES}"

  exec celery -A app.celery worker -P ${WORKER_POOL} $CONCURRENCY_OPTION \
    --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
    -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}
    -Q ${DEFAULT_QUEUES}

elif [[ "${MODE}" == "beat" ]]; then
  exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
@@ -96,7 +96,9 @@ def init_app(app: DifyApp) -> Celery:
    celery_app.set_default()
    app.extensions["celery"] = celery_app

    imports = []
    imports = [
        "tasks.async_workflow_tasks",  # trigger workers
    ]
    day = dify_config.CELERY_BEAT_SCHEDULER_TIME

    # if you add a new task, please add the switch to CeleryScheduleTasksConfig
@@ -0,0 +1,66 @@
"""empty message

Revision ID: 994bdf7197ab
Revises: fa8b0fa6f407
Create Date: 2025-08-23 20:06:35.995973

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '994bdf7197ab'
down_revision = 'fa8b0fa6f407'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('workflow_trigger_logs',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('app_id', models.types.StringUUID(), nullable=False),
    sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
    sa.Column('workflow_run_id', models.types.StringUUID(), nullable=True),
    sa.Column('trigger_type', sa.String(length=50), nullable=False),
    sa.Column('trigger_data', sa.Text(), nullable=False),
    sa.Column('inputs', sa.Text(), nullable=False),
    sa.Column('outputs', sa.Text(), nullable=True),
    sa.Column('status', sa.String(length=50), nullable=False),
    sa.Column('error', sa.Text(), nullable=True),
    sa.Column('queue_name', sa.String(length=100), nullable=False),
    sa.Column('celery_task_id', sa.String(length=255), nullable=True),
    sa.Column('retry_count', sa.Integer(), nullable=False),
    sa.Column('elapsed_time', sa.Float(), nullable=True),
    sa.Column('total_tokens', sa.Integer(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('created_by_role', sa.String(length=255), nullable=False),
    sa.Column('created_by', sa.String(length=255), nullable=False),
    sa.Column('triggered_at', sa.DateTime(), nullable=True),
    sa.Column('finished_at', sa.DateTime(), nullable=True),
    sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
    )
    with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
        batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
        batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
        batch_op.create_index('workflow_trigger_log_tenant_app_idx', ['tenant_id', 'app_id'], unique=False)
        batch_op.create_index('workflow_trigger_log_workflow_id_idx', ['workflow_id'], unique=False)
        batch_op.create_index('workflow_trigger_log_workflow_run_idx', ['workflow_run_id'], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
        batch_op.drop_index('workflow_trigger_log_workflow_run_idx')
        batch_op.drop_index('workflow_trigger_log_workflow_id_idx')
        batch_op.drop_index('workflow_trigger_log_tenant_app_idx')
        batch_op.drop_index('workflow_trigger_log_status_idx')
        batch_op.drop_index('workflow_trigger_log_created_at_idx')

    op.drop_table('workflow_trigger_logs')
    # ### end Alembic commands ###
@@ -13,7 +13,10 @@ class UserFrom(StrEnum):

class WorkflowRunTriggeredFrom(StrEnum):
    DEBUGGING = "debugging"
    APP_RUN = "app-run"
    APP_RUN = "app-run"  # webapp / service api
    WEBHOOK = "webhook"
    SCHEDULE = "schedule"
    PLUGIN = "plugin"


class DraftVariableType(StrEnum):
@ -1262,3 +1262,122 @@ class WorkflowDraftVariable(Base):
|
|||
|
||||
def is_system_variable_editable(name: str) -> bool:
|
||||
return name in _EDITABLE_SYSTEM_VARIABLE
|
||||
|
||||
|
||||
class WorkflowTriggerStatus(StrEnum):
|
||||
"""Workflow Trigger Execution Status"""
|
||||
|
||||
PENDING = "pending"
|
||||
QUEUED = "queued"
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
RATE_LIMITED = "rate_limited"
|
||||
RETRYING = "retrying"
|
||||
|
||||
|
||||
class WorkflowTriggerLog(Base):
|
||||
"""
|
||||
Workflow Trigger Log
|
||||
|
||||
Track async trigger workflow runs with re-invocation capability
|
||||
|
||||
Attributes:
|
||||
- id (uuid) Trigger Log ID (used as workflow_trigger_log_id)
|
||||
- tenant_id (uuid) Workspace ID
|
||||
- app_id (uuid) App ID
|
||||
- workflow_id (uuid) Workflow ID
|
||||
- workflow_run_id (uuid) Optional - Associated workflow run ID when execution starts
|
||||
- trigger_type (string) Type of trigger: webhook, schedule, plugin
|
||||
- trigger_data (text) Full trigger data including inputs (JSON)
|
||||
- inputs (text) Input parameters (JSON)
|
||||
- outputs (text) Optional - Output content (JSON)
|
||||
- status (string) Execution status
|
||||
- error (text) Optional - Error message if failed
|
||||
- queue_name (string) Celery queue used
|
||||
- celery_task_id (string) Optional - Celery task ID for tracking
|
||||
- retry_count (int) Number of retry attempts
|
||||
- elapsed_time (float) Optional - Execution time in seconds
|
||||
- total_tokens (int) Optional - Total tokens used
|
||||
- created_by_role (string) Creator role: account, end_user
|
||||
- created_by (string) Creator ID
|
||||
- created_at (timestamp) Creation time
|
||||
- triggered_at (timestamp) Optional - When actually triggered
|
||||
- finished_at (timestamp) Optional - Completion time
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_trigger_logs"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_trigger_log_pkey"),
|
||||
sa.Index("workflow_trigger_log_tenant_app_idx", "tenant_id", "app_id"),
|
||||
sa.Index("workflow_trigger_log_status_idx", "status"),
|
||||
sa.Index("workflow_trigger_log_created_at_idx", "created_at"),
|
||||
sa.Index("workflow_trigger_log_workflow_run_idx", "workflow_run_id"),
|
||||
sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
trigger_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
trigger_data: Mapped[str] = mapped_column(sa.Text, nullable=False) # Full TriggerData as JSON
|
||||
inputs: Mapped[str] = mapped_column(sa.Text, nullable=False) # Just inputs for easy viewing
|
||||
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default=WorkflowTriggerStatus.PENDING)
|
||||
error: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
|
||||
queue_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
celery_task_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
|
||||
elapsed_time: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
|
||||
total_tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
triggered_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from models.model import EndUser
|
||||
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for API responses"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"app_id": self.app_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"workflow_run_id": self.workflow_run_id,
|
||||
"trigger_type": self.trigger_type,
|
||||
"trigger_data": json.loads(self.trigger_data),
|
||||
"inputs": json.loads(self.inputs),
|
||||
"outputs": json.loads(self.outputs) if self.outputs else None,
|
||||
"status": self.status,
|
||||
"error": self.error,
|
||||
"queue_name": self.queue_name,
|
||||
"celery_task_id": self.celery_task_id,
|
||||
"retry_count": self.retry_count,
|
||||
"elapsed_time": self.elapsed_time,
|
||||
"total_tokens": self.total_tokens,
|
||||
"created_by_role": self.created_by_role,
|
||||
"created_by": self.created_by,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"triggered_at": self.triggered_at.isoformat() if self.triggered_at else None,
|
||||
"finished_at": self.finished_at.isoformat() if self.finished_at else None,
|
||||
}
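The status column moves through the WorkflowTriggerStatus values as a trigger goes from creation to completion; a small illustrative helper, assuming only the model above (is_finished is hypothetical, not part of this commit):

    from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus

    # SUCCEEDED and FAILED are the states where finished_at gets set; FAILED and
    # RATE_LIMITED rows can still be re-queued via AsyncWorkflowService.reinvoke_trigger().
    TERMINAL_STATUSES = {WorkflowTriggerStatus.SUCCEEDED, WorkflowTriggerStatus.FAILED}


    def is_finished(log: WorkflowTriggerLog) -> bool:
        return WorkflowTriggerStatus(log.status) in TERMINAL_STATUSES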
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
"""
|
||||
SQLAlchemy implementation of WorkflowTriggerLogRepository.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus
|
||||
from repositories.workflow_trigger_log_repository import TriggerLogOrderBy, WorkflowTriggerLogRepository
|
||||
|
||||
|
||||
class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of WorkflowTriggerLogRepository.
|
||||
|
||||
Optimized for large table operations with proper indexing and batch processing.
|
||||
"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||
"""Create a new trigger log entry."""
|
||||
self.session.add(trigger_log)
|
||||
self.session.flush()
|
||||
return trigger_log
|
||||
|
||||
def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||
"""Update an existing trigger log entry."""
|
||||
self.session.merge(trigger_log)
|
||||
self.session.flush()
|
||||
return trigger_log
|
||||
|
||||
def get_by_id(self, trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[WorkflowTriggerLog]:
|
||||
"""Get a trigger log by its ID."""
|
||||
query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.id == trigger_log_id)
|
||||
|
||||
if tenant_id:
|
||||
query = query.where(WorkflowTriggerLog.tenant_id == tenant_id)
|
||||
|
||||
return self.session.scalar(query)
|
||||
|
||||
def get_by_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
status: WorkflowTriggerStatus,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order_by: TriggerLogOrderBy = TriggerLogOrderBy.CREATED_AT,
|
||||
order_desc: bool = True,
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
"""Get trigger logs by status with pagination."""
|
||||
query = select(WorkflowTriggerLog).where(
|
||||
and_(
|
||||
WorkflowTriggerLog.tenant_id == tenant_id,
|
||||
WorkflowTriggerLog.app_id == app_id,
|
||||
WorkflowTriggerLog.status == status,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply ordering
|
||||
order_column = getattr(WorkflowTriggerLog, order_by.value)
|
||||
if order_desc:
|
||||
query = query.order_by(order_column.desc())
|
||||
else:
|
||||
query = query.order_by(order_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
def get_failed_for_retry(
|
||||
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
"""Get failed trigger logs eligible for retry."""
|
||||
query = (
|
||||
select(WorkflowTriggerLog)
|
||||
.where(
|
||||
and_(
|
||||
WorkflowTriggerLog.tenant_id == tenant_id,
|
||||
WorkflowTriggerLog.status.in_([WorkflowTriggerStatus.FAILED, WorkflowTriggerStatus.RATE_LIMITED]),
|
||||
WorkflowTriggerLog.retry_count < max_retry_count,
|
||||
)
|
||||
)
|
||||
.order_by(WorkflowTriggerLog.created_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
def get_recent_logs(
|
||||
self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
"""Get recent trigger logs within specified hours."""
|
||||
since = datetime.utcnow() - timedelta(hours=hours)
|
||||
|
||||
query = (
|
||||
select(WorkflowTriggerLog)
|
||||
.where(
|
||||
and_(
|
||||
WorkflowTriggerLog.tenant_id == tenant_id,
|
||||
WorkflowTriggerLog.app_id == app_id,
|
||||
WorkflowTriggerLog.created_at >= since,
|
||||
)
|
||||
)
|
||||
.order_by(WorkflowTriggerLog.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
def count_by_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
status: Optional[WorkflowTriggerStatus] = None,
|
||||
since: Optional[datetime] = None,
|
||||
) -> int:
|
||||
"""Count trigger logs by status."""
|
||||
query = select(func.count(WorkflowTriggerLog.id)).where(
|
||||
and_(WorkflowTriggerLog.tenant_id == tenant_id, WorkflowTriggerLog.app_id == app_id)
|
||||
)
|
||||
|
||||
if status:
|
||||
query = query.where(WorkflowTriggerLog.status == status)
|
||||
|
||||
if since:
|
||||
query = query.where(WorkflowTriggerLog.created_at >= since)
|
||||
|
||||
return self.session.scalar(query) or 0
|
||||
|
||||
def delete_expired_logs(self, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int:
|
||||
"""Delete expired trigger logs in batches."""
|
||||
total_deleted = 0
|
||||
|
||||
while True:
|
||||
# Get batch of IDs to delete
|
||||
subquery = (
|
||||
select(WorkflowTriggerLog.id)
|
||||
.where(and_(WorkflowTriggerLog.tenant_id == tenant_id, WorkflowTriggerLog.created_at < before_date))
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
# Delete the batch
|
||||
result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.id.in_(subquery)))
|
||||
|
||||
deleted = result.rowcount
|
||||
total_deleted += deleted
|
||||
|
||||
if deleted < batch_size:
|
||||
break
|
||||
|
||||
self.session.commit()
|
||||
|
||||
return total_deleted
|
||||
|
||||
def archive_completed_logs(
|
||||
self, tenant_id: str, before_date: datetime, batch_size: int = 1000
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
"""Get completed logs for archival."""
|
||||
query = (
|
||||
select(WorkflowTriggerLog)
|
||||
.where(
|
||||
and_(
|
||||
WorkflowTriggerLog.tenant_id == tenant_id,
|
||||
WorkflowTriggerLog.status == WorkflowTriggerStatus.SUCCEEDED,
|
||||
WorkflowTriggerLog.finished_at < before_date,
|
||||
)
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
def update_status_batch(
|
||||
self, trigger_log_ids: Sequence[str], new_status: WorkflowTriggerStatus, error_message: Optional[str] = None
|
||||
) -> int:
|
||||
"""Update status for multiple trigger logs."""
|
||||
update_data: dict[str, Any] = {"status": new_status}
|
||||
|
||||
if error_message is not None:
|
||||
update_data["error"] = error_message
|
||||
|
||||
if new_status in [WorkflowTriggerStatus.SUCCEEDED, WorkflowTriggerStatus.FAILED]:
|
||||
update_data["finished_at"] = datetime.utcnow()
|
||||
|
||||
result = self.session.execute(
|
||||
update(WorkflowTriggerLog).where(WorkflowTriggerLog.id.in_(trigger_log_ids)).values(**update_data)
|
||||
)
|
||||
|
||||
return result.rowcount
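A usage sketch for the repository above; the session handling mirrors how the service layer uses it, and the IDs are placeholders:

    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from models.workflow import WorkflowTriggerStatus
    from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository

    with Session(db.engine) as session:
        repo = SQLAlchemyWorkflowTriggerLogRepository(session)

        # How many runs failed for this app, and what are the newest failures?
        failed_count = repo.count_by_status(
            tenant_id="<tenant-uuid>", app_id="<app-uuid>", status=WorkflowTriggerStatus.FAILED
        )
        newest_failures = repo.get_by_status(
            tenant_id="<tenant-uuid>",
            app_id="<app-uuid>",
            status=WorkflowTriggerStatus.FAILED,
            limit=20,
        )
        # create()/update() only flush; committing is left to the caller.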
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
"""
|
||||
Repository protocol for WorkflowTriggerLog operations.
|
||||
|
||||
This module provides a protocol interface for operations on WorkflowTriggerLog,
|
||||
designed to efficiently handle a potentially large volume of trigger logs with
|
||||
proper indexing and batch operations.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Optional, Protocol
|
||||
|
||||
from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus
|
||||
|
||||
|
||||
class TriggerLogOrderBy(StrEnum):
|
||||
"""Fields available for ordering trigger logs"""
|
||||
|
||||
CREATED_AT = "created_at"
|
||||
TRIGGERED_AT = "triggered_at"
|
||||
FINISHED_AT = "finished_at"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class WorkflowTriggerLogRepository(Protocol):
|
||||
"""
|
||||
Protocol for operations on WorkflowTriggerLog.
|
||||
|
||||
This repository provides efficient access patterns for the trigger log table,
|
||||
which is expected to grow large over time. It includes:
|
||||
- Batch operations for cleanup
|
||||
- Efficient queries with proper indexing
|
||||
- Pagination support
|
||||
- Status-based filtering
|
||||
|
||||
Implementation notes:
|
||||
- Leverage database indexes on (tenant_id, app_id), status, and created_at
|
||||
- Use batch operations for deletions to avoid locking
|
||||
- Support pagination for large result sets
|
||||
"""
|
||||
|
||||
def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||
"""
|
||||
Create a new trigger log entry.
|
||||
|
||||
Args:
|
||||
trigger_log: The WorkflowTriggerLog instance to create
|
||||
|
||||
Returns:
|
||||
The created WorkflowTriggerLog with generated ID
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||
"""
|
||||
Update an existing trigger log entry.
|
||||
|
||||
Args:
|
||||
trigger_log: The WorkflowTriggerLog instance to update
|
||||
|
||||
Returns:
|
||||
The updated WorkflowTriggerLog
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_id(self, trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[WorkflowTriggerLog]:
|
||||
"""
|
||||
Get a trigger log by its ID.
|
||||
|
||||
Args:
|
||||
trigger_log_id: The trigger log identifier
|
||||
tenant_id: Optional tenant identifier for additional security
|
||||
|
||||
Returns:
|
||||
The WorkflowTriggerLog if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
status: WorkflowTriggerStatus,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order_by: TriggerLogOrderBy = TriggerLogOrderBy.CREATED_AT,
|
||||
order_desc: bool = True,
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
"""
|
||||
Get trigger logs by status with pagination.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
status: The workflow trigger status to filter by
|
||||
limit: Maximum number of results
|
||||
offset: Number of results to skip
|
||||
order_by: Field to order results by
|
||||
order_desc: Whether to order descending (True) or ascending (False)
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowTriggerLog instances
|
||||
"""
|
||||
...
|
||||
|
||||
def get_failed_for_retry(
|
||||
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
"""
|
||||
Get failed trigger logs that are eligible for retry.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
max_retry_count: Maximum retry count to consider
|
||||
limit: Maximum number of results
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowTriggerLog instances eligible for retry
|
||||
"""
|
||||
...
|
||||
|
||||
def get_recent_logs(
|
||||
self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
"""
|
||||
Get recent trigger logs within specified hours.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
hours: Number of hours to look back
|
||||
limit: Maximum number of results
|
||||
offset: Number of results to skip
|
||||
|
||||
Returns:
|
||||
A sequence of recent WorkflowTriggerLog instances
|
||||
"""
|
||||
...
|
||||
|
||||
def count_by_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
status: Optional[WorkflowTriggerStatus] = None,
|
||||
since: Optional[datetime] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count trigger logs by status.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
status: Optional status filter
|
||||
since: Optional datetime to count from
|
||||
|
||||
Returns:
|
||||
Count of matching trigger logs
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_expired_logs(self, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int:
|
||||
"""
|
||||
Delete expired trigger logs in batches.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Delete logs created before this date
|
||||
batch_size: Number of logs to delete per batch
|
||||
|
||||
Returns:
|
||||
Total number of logs deleted
|
||||
"""
|
||||
...
|
||||
|
||||
def archive_completed_logs(
|
||||
self, tenant_id: str, before_date: datetime, batch_size: int = 1000
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
"""
|
||||
Get completed logs for archival before deletion.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Get logs completed before this date
|
||||
batch_size: Number of logs to retrieve
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowTriggerLog instances for archival
|
||||
"""
|
||||
...
|
||||
|
||||
def update_status_batch(
|
||||
self, trigger_log_ids: Sequence[str], new_status: WorkflowTriggerStatus, error_message: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
Update status for multiple trigger logs at once.
|
||||
|
||||
Args:
|
||||
trigger_log_ids: List of trigger log IDs to update
|
||||
new_status: The new status to set
|
||||
error_message: Optional error message to set
|
||||
|
||||
Returns:
|
||||
Number of logs updated
|
||||
"""
|
||||
...
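Because this is a typing.Protocol, any object exposing these methods satisfies it structurally; a short sketch of code written against the protocol rather than a concrete class (purge_old_trigger_logs and the 30-day retention are illustrative, not part of this commit):

    from datetime import datetime, timedelta

    from repositories.workflow_trigger_log_repository import WorkflowTriggerLogRepository


    def purge_old_trigger_logs(repo: WorkflowTriggerLogRepository, tenant_id: str, days: int = 30) -> int:
        """Delete trigger logs older than the retention window, in batches."""
        cutoff = datetime.utcnow() - timedelta(days=days)
        return repo.delete_expired_logs(tenant_id=tenant_id, before_date=cutoff, batch_size=1000)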
|
||||
|
|
@ -0,0 +1,320 @@
|
|||
"""
|
||||
Universal async workflow execution service.
|
||||
|
||||
This service provides a centralized entry point for triggering workflows asynchronously
|
||||
with support for different subscription tiers, rate limiting, and execution tracking.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow, WorkflowTriggerLog, WorkflowTriggerStatus
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
|
||||
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
|
||||
from services.workflow.rate_limiter import TenantDailyRateLimiter
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.async_workflow_tasks import (
|
||||
execute_workflow_professional,
|
||||
execute_workflow_sandbox,
|
||||
execute_workflow_team,
|
||||
)
|
||||
|
||||
|
||||
class AsyncWorkflowService:
|
||||
"""
|
||||
Universal entry point for async workflow execution - ALL METHODS ARE NON-BLOCKING
|
||||
|
||||
This service handles:
|
||||
- Trigger data validation and processing
|
||||
- Queue routing based on subscription tier
|
||||
- Daily rate limiting with timezone support
|
||||
- Execution tracking and logging
|
||||
- Retry mechanisms for failed executions
|
||||
|
||||
Important: All trigger methods return immediately after queuing tasks.
|
||||
Actual workflow execution happens asynchronously in background Celery workers.
|
||||
Use trigger log IDs to monitor execution status and results.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def trigger_workflow_async(
|
||||
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
|
||||
) -> AsyncTriggerResponse:
|
||||
"""
|
||||
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
|
||||
|
||||
Creates a trigger log and dispatches to appropriate queue based on subscription tier.
|
||||
The workflow execution happens asynchronously in the background via Celery workers.
|
||||
This method returns immediately after queuing the task, not after execution completion.
|
||||
|
||||
Args:
|
||||
session: Database session to use for operations
|
||||
user: User (Account or EndUser) who initiated the workflow trigger
|
||||
trigger_data: Validated Pydantic model containing trigger information
|
||||
|
||||
Returns:
|
||||
AsyncTriggerResponse with workflow_trigger_log_id, task_id, status="queued", and queue
|
||||
Note: The actual workflow execution status must be checked separately via workflow_trigger_log_id
|
||||
|
||||
Raises:
|
||||
ValueError: If app or workflow not found
|
||||
InvokeRateLimitError: If daily rate limit exceeded
|
||||
|
||||
Behavior:
|
||||
- Non-blocking: Returns immediately after queuing
|
||||
- Asynchronous: Actual execution happens in background Celery workers
|
||||
- Status tracking: Use workflow_trigger_log_id to monitor progress
|
||||
- Queue-based: Routes to different queues based on subscription tier
|
||||
"""
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
dispatcher_manager = QueueDispatcherManager()
|
||||
workflow_service = WorkflowService()
|
||||
rate_limiter = TenantDailyRateLimiter(redis_client)
|
||||
|
||||
# 1. Validate app exists
|
||||
app_model = session.scalar(select(App).where(App.id == trigger_data.app_id))
|
||||
if not app_model:
|
||||
raise ValueError(f"App not found: {trigger_data.app_id}")
|
||||
|
||||
# 2. Get workflow
|
||||
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
|
||||
|
||||
# 3. Get dispatcher based on tenant subscription
|
||||
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
|
||||
|
||||
# 4. Get tenant owner timezone for rate limiting
|
||||
tenant_owner_tz = rate_limiter._get_tenant_owner_timezone(trigger_data.tenant_id)
|
||||
|
||||
# 5. Determine user role and ID
|
||||
if isinstance(user, Account):
|
||||
created_by_role = CreatorUserRole.ACCOUNT
|
||||
created_by = user.id
|
||||
else: # EndUser
|
||||
created_by_role = CreatorUserRole.END_USER
|
||||
created_by = user.id
|
||||
|
||||
# 6. Create trigger log entry first (for tracking)
|
||||
trigger_log = WorkflowTriggerLog(
|
||||
tenant_id=trigger_data.tenant_id,
|
||||
app_id=trigger_data.app_id,
|
||||
workflow_id=workflow.id,
|
||||
trigger_type=trigger_data.trigger_type,
|
||||
trigger_data=trigger_data.model_dump_json(),
|
||||
inputs=json.dumps(dict(trigger_data.inputs)),
|
||||
status=WorkflowTriggerStatus.PENDING,
|
||||
queue_name=dispatcher.get_queue_name(),
|
||||
retry_count=0,
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
trigger_log = trigger_log_repo.create(trigger_log)
|
||||
session.commit()
|
||||
|
||||
# 7. Check and consume daily quota
|
||||
if not dispatcher.consume_quota(trigger_data.tenant_id, tenant_owner_tz):
|
||||
# Update trigger log status
|
||||
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
|
||||
trigger_log.error = f"Daily limit reached for {dispatcher.get_queue_name()}"
|
||||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
remaining = rate_limiter.get_remaining_quota(
|
||||
trigger_data.tenant_id, dispatcher.get_daily_limit(), tenant_owner_tz
|
||||
)
|
||||
|
||||
reset_time = rate_limiter.get_quota_reset_time(trigger_data.tenant_id, tenant_owner_tz)
|
||||
|
||||
raise InvokeRateLimitError(
|
||||
f"Daily workflow execution limit reached. "
|
||||
f"Limit resets at {reset_time.strftime('%Y-%m-%d %H:%M:%S %Z')}. "
|
||||
f"Remaining quota: {remaining}"
|
||||
)
|
||||
|
||||
# 8. Create task data
|
||||
queue_name = dispatcher.get_queue_name()
|
||||
|
||||
task_data = WorkflowTaskData(workflow_trigger_log_id=trigger_log.id)
|
||||
|
||||
# 9. Dispatch to appropriate queue
|
||||
task_data_dict = task_data.model_dump(mode="json")
|
||||
|
||||
task: AsyncResult | None = None
|
||||
if queue_name == QueuePriority.PROFESSIONAL:
|
||||
task = execute_workflow_professional.delay(task_data_dict) # type: ignore
|
||||
elif queue_name == QueuePriority.TEAM:
|
||||
task = execute_workflow_team.delay(task_data_dict) # type: ignore
|
||||
else: # SANDBOX
|
||||
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore
|
||||
|
||||
if not task:
|
||||
raise ValueError(f"Failed to queue task for queue: {queue_name}")
|
||||
|
||||
# 10. Update trigger log with task info
|
||||
trigger_log.status = WorkflowTriggerStatus.QUEUED
|
||||
trigger_log.celery_task_id = task.id
|
||||
trigger_log.triggered_at = datetime.now(UTC)
|
||||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
return AsyncTriggerResponse(
|
||||
workflow_trigger_log_id=trigger_log.id,
|
||||
task_id=task.id, # type: ignore
|
||||
status="queued",
|
||||
queue=queue_name,
|
||||
)
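A hedged sketch of calling the entry point above from application code; the IDs, webhook fields, and the AsyncWorkflowService import path are assumptions for illustration, and current_account stands in for the authenticated user:

    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from services.async_workflow_service import AsyncWorkflowService  # assumed module path
    from services.workflow.entities import WebhookTriggerData

    trigger = WebhookTriggerData(
        app_id="<app-uuid>",
        tenant_id="<tenant-uuid>",
        inputs={"query": "hello"},
        webhook_url="https://example.com/hooks/dify",  # placeholder
    )

    with Session(db.engine) as session:
        # current_account: the Account (or EndUser) on whose behalf the run is queued
        response = AsyncWorkflowService.trigger_workflow_async(session, current_account, trigger)

    # Returns as soon as the Celery task is queued; poll the log for progress.
    log = AsyncWorkflowService.get_trigger_log(response.workflow_trigger_log_id)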
|
||||
|
||||
@classmethod
|
||||
def reinvoke_trigger(
|
||||
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
|
||||
) -> AsyncTriggerResponse:
|
||||
"""
|
||||
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
|
||||
|
||||
Updates the existing trigger log to retry status and creates a new async execution.
|
||||
Returns immediately after queuing the retry, not after execution completion.
|
||||
|
||||
Args:
|
||||
session: Database session to use for operations
|
||||
user: User (Account or EndUser) who initiated the retry
|
||||
workflow_trigger_log_id: ID of the trigger log to re-invoke
|
||||
|
||||
Returns:
|
||||
AsyncTriggerResponse with new execution information (status="queued")
|
||||
Note: This creates a new trigger log entry for the retry attempt
|
||||
|
||||
Raises:
|
||||
ValueError: If trigger log not found
|
||||
|
||||
Behavior:
|
||||
- Non-blocking: Returns immediately after queuing retry
|
||||
- Creates new trigger log: Original log marked as retrying, new log for execution
|
||||
- Preserves original trigger data: Uses same inputs and configuration
|
||||
"""
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
|
||||
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id)
|
||||
|
||||
if not trigger_log:
|
||||
raise ValueError(f"Trigger log not found: {workflow_trigger_log_id}")
|
||||
|
||||
# Reconstruct trigger data from log
|
||||
trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
|
||||
|
||||
# Reset log for retry
|
||||
trigger_log.status = WorkflowTriggerStatus.RETRYING
|
||||
trigger_log.retry_count += 1
|
||||
trigger_log.error = None
|
||||
trigger_log.triggered_at = datetime.now(UTC)
|
||||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
# Re-trigger workflow (this will create a new trigger log)
|
||||
return cls.trigger_workflow_async(session, user, trigger_data)
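Building on the two methods above, a sketch of a retry sweep that re-queues eligible failures; the function, loop, and limits are illustrative only:

    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from models.account import Account
    from services.async_workflow_service import AsyncWorkflowService  # assumed module path


    def retry_failed_triggers(tenant_id: str, operator: Account, limit: int = 50) -> int:
        """Re-queue failed or rate-limited trigger logs that still have retries left."""
        retried = 0
        for log in AsyncWorkflowService.get_failed_logs_for_retry(tenant_id, max_retry_count=3, limit=limit):
            with Session(db.engine) as session:
                AsyncWorkflowService.reinvoke_trigger(session, operator, log["id"])
            retried += 1
        return retried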
|
||||
|
||||
@classmethod
|
||||
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[dict]:
|
||||
"""
|
||||
Get trigger log by ID
|
||||
|
||||
Args:
|
||||
workflow_trigger_log_id: ID of the trigger log
|
||||
tenant_id: Optional tenant ID for security check
|
||||
|
||||
Returns:
|
||||
Trigger log as dictionary or None if not found
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
|
||||
|
||||
if not trigger_log:
|
||||
return None
|
||||
|
||||
return trigger_log.to_dict()
|
||||
|
||||
@classmethod
|
||||
def get_recent_logs(
|
||||
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get recent trigger logs
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
app_id: Application ID
|
||||
hours: Number of hours to look back
|
||||
limit: Maximum number of results
|
||||
offset: Number of results to skip
|
||||
|
||||
Returns:
|
||||
List of trigger logs as dictionaries
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
logs = trigger_log_repo.get_recent_logs(
|
||||
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
|
||||
)
|
||||
|
||||
return [log.to_dict() for log in logs]
|
||||
|
||||
@classmethod
|
||||
def get_failed_logs_for_retry(cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100) -> list[dict]:
|
||||
"""
|
||||
Get failed logs eligible for retry
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
max_retry_count: Maximum retry count
|
||||
limit: Maximum number of results
|
||||
|
||||
Returns:
|
||||
List of failed trigger logs as dictionaries
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
logs = trigger_log_repo.get_failed_for_retry(
|
||||
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
|
||||
)
|
||||
|
||||
return [log.to_dict() for log in logs]
|
||||
|
||||
@staticmethod
|
||||
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: Optional[str] = None) -> Workflow:
|
||||
"""
|
||||
Get workflow for the app
|
||||
|
||||
Args:
|
||||
app_model: App model instance
|
||||
workflow_id: Optional specific workflow ID
|
||||
|
||||
Returns:
|
||||
Workflow instance
|
||||
|
||||
Raises:
|
||||
ValueError: If workflow not found
|
||||
"""
|
||||
if workflow_id:
|
||||
# Get specific published workflow
|
||||
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError(f"Published workflow not found: {workflow_id}")
|
||||
else:
|
||||
# Get default published workflow
|
||||
workflow = workflow_service.get_published_workflow(app_model)
|
||||
if not workflow:
|
||||
raise ValueError(f"No published workflow found for app: {app_model.id}")
|
||||
|
||||
return workflow
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
"""
|
||||
Pydantic models for async workflow trigger system.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
class AsyncTriggerStatus(StrEnum):
|
||||
"""Async trigger execution status"""
|
||||
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class TriggerData(BaseModel):
|
||||
"""Base trigger data model for async workflow execution"""
|
||||
|
||||
app_id: str
|
||||
tenant_id: str
|
||||
workflow_id: Optional[str] = None
|
||||
inputs: Mapping[str, Any]
|
||||
files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
trigger_type: WorkflowRunTriggeredFrom
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class WebhookTriggerData(TriggerData):
|
||||
"""Webhook-specific trigger data"""
|
||||
|
||||
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK
|
||||
webhook_url: str
|
||||
headers: Mapping[str, str] = Field(default_factory=dict)
|
||||
method: str = "POST"
|
||||
|
||||
|
||||
class ScheduleTriggerData(TriggerData):
|
||||
"""Schedule-specific trigger data"""
|
||||
|
||||
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE
|
||||
schedule_id: str
|
||||
cron_expression: str
|
||||
|
||||
|
||||
class PluginTriggerData(TriggerData):
|
||||
"""Plugin webhook trigger data"""
|
||||
|
||||
trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN
|
||||
plugin_id: str
|
||||
webhook_url: str
|
||||
|
||||
|
||||
class WorkflowTaskData(BaseModel):
|
||||
"""Lightweight data structure for Celery workflow tasks"""
|
||||
|
||||
workflow_trigger_log_id: str # Primary tracking ID - all other data can be fetched from DB
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class AsyncTriggerExecutionResult(BaseModel):
|
||||
"""Result from async trigger-based workflow execution"""
|
||||
|
||||
execution_id: str
|
||||
status: AsyncTriggerStatus
|
||||
result: Optional[Mapping[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
elapsed_time: Optional[float] = None
|
||||
total_tokens: Optional[int] = None
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class AsyncTriggerResponse(BaseModel):
|
||||
"""Response from triggering an async workflow"""
|
||||
|
||||
workflow_trigger_log_id: str
|
||||
task_id: str
|
||||
status: str
|
||||
queue: str
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class TriggerLogResponse(BaseModel):
|
||||
"""Response model for trigger log data"""
|
||||
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
trigger_type: WorkflowRunTriggeredFrom
|
||||
status: str
|
||||
queue_name: str
|
||||
retry_count: int
|
||||
celery_task_id: Optional[str] = None
|
||||
workflow_run_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
outputs: Optional[str] = None
|
||||
elapsed_time: Optional[float] = None
|
||||
total_tokens: Optional[int] = None
|
||||
created_at: Optional[str] = None
|
||||
triggered_at: Optional[str] = None
|
||||
finished_at: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
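Since these models set use_enum_values=True, trigger_type round-trips as its plain string value, which is what trigger_workflow_async stores in trigger_data; a quick illustration with placeholder IDs:

    from services.workflow.entities import ScheduleTriggerData

    data = ScheduleTriggerData(
        app_id="<app-uuid>",
        tenant_id="<tenant-uuid>",
        inputs={},
        schedule_id="<schedule-uuid>",
        cron_expression="0 9 * * 1-5",
    )

    payload = data.model_dump_json()            # what trigger_workflow_async persists in trigger_data
    restored = ScheduleTriggerData.model_validate_json(payload)
    assert restored.trigger_type == "schedule"  # enum stored as its string value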
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
"""
|
||||
Queue dispatcher system for async workflow execution.
|
||||
|
||||
Implements an ABC-based pattern for handling different subscription tiers
|
||||
with appropriate queue routing and rate limiting.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import StrEnum
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.billing_service import BillingService
|
||||
from services.workflow.rate_limiter import TenantDailyRateLimiter
|
||||
|
||||
|
||||
class QueuePriority(StrEnum):
|
||||
"""Queue priorities for different subscription tiers"""
|
||||
|
||||
PROFESSIONAL = "workflow_professional" # Highest priority
|
||||
TEAM = "workflow_team"
|
||||
SANDBOX = "workflow_sandbox" # Free tier
|
||||
|
||||
|
||||
class BaseQueueDispatcher(ABC):
|
||||
"""Abstract base class for queue dispatchers"""
|
||||
|
||||
def __init__(self):
|
||||
self.rate_limiter = TenantDailyRateLimiter(redis_client)
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_name(self) -> str:
|
||||
"""Get the queue name for this dispatcher"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_daily_limit(self) -> int:
|
||||
"""Get daily execution limit"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_priority(self) -> int:
|
||||
"""Get task priority level"""
|
||||
pass
|
||||
|
||||
def check_daily_quota(self, tenant_id: str, tenant_owner_tz: str) -> bool:
|
||||
"""
|
||||
Check if tenant has remaining daily quota
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
tenant_owner_tz: Tenant owner's timezone
|
||||
|
||||
Returns:
|
||||
True if quota available, False otherwise
|
||||
"""
|
||||
# Check without consuming
|
||||
remaining = self.rate_limiter.get_remaining_quota(
|
||||
tenant_id=tenant_id, max_daily_limit=self.get_daily_limit(), timezone_str=tenant_owner_tz
|
||||
)
|
||||
return remaining > 0
|
||||
|
||||
def consume_quota(self, tenant_id: str, tenant_owner_tz: str) -> bool:
|
||||
"""
|
||||
Consume one execution from daily quota
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
tenant_owner_tz: Tenant owner's timezone
|
||||
|
||||
Returns:
|
||||
True if quota consumed successfully, False if limit reached
|
||||
"""
|
||||
return self.rate_limiter.check_and_consume(
|
||||
tenant_id=tenant_id, max_daily_limit=self.get_daily_limit(), timezone_str=tenant_owner_tz
|
||||
)
|
||||
|
||||
|
||||
class ProfessionalQueueDispatcher(BaseQueueDispatcher):
|
||||
"""Dispatcher for professional tier"""
|
||||
|
||||
def get_queue_name(self) -> str:
|
||||
return QueuePriority.PROFESSIONAL
|
||||
|
||||
def get_daily_limit(self) -> int:
|
||||
return int(os.getenv("PROFESSIONAL_DAILY_LIMIT", "1000"))
|
||||
|
||||
def get_priority(self) -> int:
|
||||
return 100
|
||||
|
||||
|
||||
class TeamQueueDispatcher(BaseQueueDispatcher):
|
||||
"""Dispatcher for team tier"""
|
||||
|
||||
def get_queue_name(self) -> str:
|
||||
return QueuePriority.TEAM
|
||||
|
||||
def get_daily_limit(self) -> int:
|
||||
return int(os.getenv("TEAM_DAILY_LIMIT", "100"))
|
||||
|
||||
def get_priority(self) -> int:
|
||||
return 50
|
||||
|
||||
|
||||
class SandboxQueueDispatcher(BaseQueueDispatcher):
|
||||
"""Dispatcher for free/sandbox tier"""
|
||||
|
||||
def get_queue_name(self) -> str:
|
||||
return QueuePriority.SANDBOX
|
||||
|
||||
def get_daily_limit(self) -> int:
|
||||
return int(os.getenv("SANDBOX_DAILY_LIMIT", "10"))
|
||||
|
||||
def get_priority(self) -> int:
|
||||
return 10
|
||||
|
||||
|
||||
class QueueDispatcherManager:
|
||||
"""Factory for creating appropriate dispatcher based on tenant subscription"""
|
||||
|
||||
# Mapping of billing plans to dispatchers
|
||||
PLAN_DISPATCHER_MAP = {
|
||||
"professional": ProfessionalQueueDispatcher,
|
||||
"team": TeamQueueDispatcher,
|
||||
"sandbox": SandboxQueueDispatcher,
|
||||
# Add new tiers here as they're created
|
||||
# For any unknown plan, default to sandbox
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_dispatcher(cls, tenant_id: str) -> BaseQueueDispatcher:
|
||||
"""
|
||||
Get dispatcher based on tenant's subscription plan
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
|
||||
Returns:
|
||||
Appropriate queue dispatcher instance
|
||||
"""
|
||||
if dify_config.BILLING_ENABLED:
|
||||
try:
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
plan = billing_info.get("subscription", {}).get("plan", "sandbox")
|
||||
except Exception:
|
||||
# If billing service fails, default to sandbox
|
||||
plan = "sandbox"
|
||||
else:
|
||||
# If billing is disabled, use team tier as default
|
||||
plan = "team"
|
||||
|
||||
dispatcher_class = cls.PLAN_DISPATCHER_MAP.get(
|
||||
plan,
|
||||
SandboxQueueDispatcher, # Default to sandbox for unknown plans
|
||||
)
|
||||
|
||||
return dispatcher_class()
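A short sketch of the dispatcher/limiter flow exercised by AsyncWorkflowService; the tenant ID and timezone are placeholders:

    from services.workflow.queue_dispatcher import QueueDispatcherManager

    dispatcher = QueueDispatcherManager.get_dispatcher("<tenant-uuid>")

    queue = dispatcher.get_queue_name()   # e.g. "workflow_team" when billing is disabled
    limit = dispatcher.get_daily_limit()  # e.g. TEAM_DAILY_LIMIT, default 100

    # Consume one execution from today's quota before dispatching the Celery task.
    if dispatcher.consume_quota("<tenant-uuid>", tenant_owner_tz="UTC"):
        print(f"dispatching to {queue} (daily limit {limit})")
    else:
        print("daily limit reached; surface InvokeRateLimitError to the caller")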
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
"""
|
||||
Day-based rate limiter for workflow executions.
|
||||
|
||||
Implements timezone-aware daily quotas that reset at midnight in the tenant owner's timezone.
|
||||
"""
|
||||
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytz
|
||||
from redis import Redis
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
class TenantDailyRateLimiter:
|
||||
"""
|
||||
Day-based rate limiter that resets at midnight in tenant owner's timezone
|
||||
|
||||
This class provides Redis-based rate limiting with the following features:
|
||||
- Daily quotas that reset at midnight in tenant owner's timezone
|
||||
- Atomic check-and-consume operations
|
||||
- Automatic cleanup of stale counters
|
||||
- Support for timezone changes without duplicate limits
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Union[Redis, RedisClientWrapper]):
|
||||
self.redis = redis_client
|
||||
|
||||
def _get_tenant_owner_timezone(self, tenant_id: str) -> str:
|
||||
"""
|
||||
Get timezone of tenant owner
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
|
||||
Returns:
|
||||
Timezone string (e.g., 'America/New_York', 'UTC')
|
||||
"""
|
||||
# Query to get tenant owner's timezone using scalar and select
|
||||
owner = db.session.scalar(
|
||||
select(Account)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == TenantAccountRole.OWNER)
|
||||
)
|
||||
|
||||
if not owner:
|
||||
return "UTC"
|
||||
|
||||
return owner.timezone or "UTC"
|
||||
|
||||
def _get_day_key(self, tenant_id: str, timezone_str: str) -> str:
|
||||
"""
|
||||
Get Redis key for current day in tenant's timezone
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
timezone_str: Timezone string
|
||||
|
||||
Returns:
|
||||
Redis key for the current day
|
||||
"""
|
||||
tz = pytz.timezone(timezone_str)
|
||||
now = datetime.now(tz)
|
||||
date_str = now.strftime("%Y-%m-%d")
|
||||
return f"workflow:daily_limit:{tenant_id}:{date_str}:{timezone_str}"
|
||||
|
||||
def _get_ttl_seconds(self, timezone_str: str) -> int:
|
||||
"""
|
||||
Calculate seconds until midnight in given timezone
|
||||
|
||||
Args:
|
||||
timezone_str: Timezone string
|
||||
|
||||
Returns:
|
||||
Number of seconds until midnight
|
||||
"""
|
||||
tz = pytz.timezone(timezone_str)
|
||||
now = datetime.now(tz)
|
||||
|
||||
# Get next midnight in the timezone
|
||||
midnight = tz.localize(datetime.combine(now.date() + timedelta(days=1), time.min))
|
||||
|
||||
return int((midnight - now).total_seconds())
|
||||
|
||||
def check_and_consume(self, tenant_id: str, max_daily_limit: int, timezone_str: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Check if quota available and consume one execution
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
max_daily_limit: Maximum daily limit
|
||||
timezone_str: Optional timezone string (will be fetched if not provided)
|
||||
|
||||
Returns:
|
||||
True if quota consumed successfully, False if limit reached
|
||||
"""
|
||||
if not timezone_str:
|
||||
timezone_str = self._get_tenant_owner_timezone(tenant_id)
|
||||
|
||||
key = self._get_day_key(tenant_id, timezone_str)
|
||||
ttl = self._get_ttl_seconds(timezone_str)
|
||||
|
||||
# Check current usage
|
||||
current = self.redis.get(key)
|
||||
|
||||
if current is None:
|
||||
# First execution of the day - set to 1
|
||||
self.redis.setex(key, ttl, 1)
|
||||
return True
|
||||
|
||||
current_count = int(current)
|
||||
if current_count < max_daily_limit:
|
||||
# Within limit, increment
|
||||
new_count = self.redis.incr(key)
|
||||
# Update TTL in case timezone changed
|
||||
self.redis.expire(key, ttl)
|
||||
|
||||
# Double-check in case of race condition
|
||||
if new_count <= max_daily_limit:
|
||||
return True
|
||||
else:
|
||||
# Race condition occurred, decrement back
|
||||
self.redis.decr(key)
|
||||
return False
|
||||
else:
|
||||
# Limit exceeded
|
||||
return False
|
||||
|
||||
def get_remaining_quota(self, tenant_id: str, max_daily_limit: int, timezone_str: Optional[str] = None) -> int:
|
||||
"""
|
||||
Get remaining quota for the day
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
max_daily_limit: Maximum daily limit
|
||||
timezone_str: Optional timezone string (will be fetched if not provided)
|
||||
|
||||
Returns:
|
||||
Number of remaining executions for the day
|
||||
"""
|
||||
if not timezone_str:
|
||||
timezone_str = self._get_tenant_owner_timezone(tenant_id)
|
||||
|
||||
key = self._get_day_key(tenant_id, timezone_str)
|
||||
used = int(self.redis.get(key) or 0)
|
||||
return max(0, max_daily_limit - used)
|
||||
|
||||
def get_current_usage(self, tenant_id: str, timezone_str: Optional[str] = None) -> int:
|
||||
"""
|
||||
Get current usage for the day
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
timezone_str: Optional timezone string (will be fetched if not provided)
|
||||
|
||||
Returns:
|
||||
Number of executions used today
|
||||
"""
|
||||
if not timezone_str:
|
||||
timezone_str = self._get_tenant_owner_timezone(tenant_id)
|
||||
|
||||
key = self._get_day_key(tenant_id, timezone_str)
|
||||
return int(self.redis.get(key) or 0)
|
||||
|
||||
def reset_quota(self, tenant_id: str, timezone_str: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Reset quota for testing purposes
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
timezone_str: Optional timezone string (will be fetched if not provided)
|
||||
|
||||
Returns:
|
||||
True if key was deleted, False if key didn't exist
|
||||
"""
|
||||
if not timezone_str:
|
||||
timezone_str = self._get_tenant_owner_timezone(tenant_id)
|
||||
|
||||
key = self._get_day_key(tenant_id, timezone_str)
|
||||
return bool(self.redis.delete(key))
|
||||
|
||||
def get_quota_reset_time(self, tenant_id: str, timezone_str: Optional[str] = None) -> datetime:
|
||||
"""
|
||||
Get the time when quota will reset (midnight in tenant's timezone)
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
timezone_str: Optional timezone string (will be fetched if not provided)
|
||||
|
||||
Returns:
|
||||
Datetime when quota resets
|
||||
"""
|
||||
if not timezone_str:
|
||||
timezone_str = self._get_tenant_owner_timezone(tenant_id)
|
||||
|
||||
tz = pytz.timezone(timezone_str)
|
||||
now = datetime.now(tz)
|
||||
|
||||
# Get next midnight in the timezone
|
||||
midnight = tz.localize(datetime.combine(now.date() + timedelta(days=1), time.min))
|
||||
|
||||
return midnight
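A usage sketch for the limiter (tenant ID and limit are placeholders). Note that check_and_consume reads the counter and then increments it rather than using a single atomic step, which is why it double-checks the INCR result and decrements on overshoot:

    from extensions.ext_redis import redis_client
    from services.workflow.rate_limiter import TenantDailyRateLimiter

    limiter = TenantDailyRateLimiter(redis_client)

    tenant_id = "<tenant-uuid>"
    daily_limit = 100

    if limiter.check_and_consume(tenant_id, daily_limit):
        remaining = limiter.get_remaining_quota(tenant_id, daily_limit)
        print(f"run allowed, {remaining} executions left today")
    else:
        reset_at = limiter.get_quota_reset_time(tenant_id)
        print(f"daily limit reached, resets at {reset_at.isoformat()}")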
|
||||
|
|
@@ -0,0 +1,201 @@
"""
Celery tasks for async workflow execution.

These tasks handle workflow execution for different subscription tiers
with appropriate retry policies and error handling.
"""

import json
from datetime import UTC, datetime

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole
from models.model import App, EndUser, Tenant
from models.workflow import Workflow, WorkflowTriggerLog, WorkflowTriggerStatus
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.workflow.entities import AsyncTriggerExecutionResult, AsyncTriggerStatus, TriggerData, WorkflowTaskData

# Determine queue names based on edition
if dify_config.EDITION == "CLOUD":
    # Cloud edition: separate queues for different tiers
    PROFESSIONAL_QUEUE = "workflow_professional"
    TEAM_QUEUE = "workflow_team"
    SANDBOX_QUEUE = "workflow_sandbox"
else:
    # Community edition: single workflow queue (not dataset)
    PROFESSIONAL_QUEUE = "workflow"
    TEAM_QUEUE = "workflow"
    SANDBOX_QUEUE = "workflow"


@shared_task(queue=PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict) -> dict:
    """Execute workflow for professional tier with highest priority"""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    return _execute_workflow_common(task_data).model_dump()


@shared_task(queue=TEAM_QUEUE)
def execute_workflow_team(task_data_dict: dict) -> dict:
    """Execute workflow for team tier"""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    return _execute_workflow_common(task_data).model_dump()


@shared_task(queue=SANDBOX_QUEUE)
def execute_workflow_sandbox(task_data_dict: dict) -> dict:
    """Execute workflow for free tier with lower retry limit"""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    return _execute_workflow_common(task_data).model_dump()


def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult:
    """
    Common workflow execution logic with trigger log updates

    Args:
        task_data: Validated Pydantic model with task data

    Returns:
        AsyncTriggerExecutionResult: Pydantic model with execution results
    """
    # Create a new session for this task
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

    with session_factory() as session:
        trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)

        # Get trigger log
        trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)

        if not trigger_log:
            # This should not happen, but handle gracefully
            return AsyncTriggerExecutionResult(
                execution_id=task_data.workflow_trigger_log_id,
                status=AsyncTriggerStatus.FAILED,
                error=f"Trigger log not found: {task_data.workflow_trigger_log_id}",
            )

        # Reconstruct execution data from trigger log
        trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)

        # Update status to running
        trigger_log.status = WorkflowTriggerStatus.RUNNING
        trigger_log_repo.update(trigger_log)
        session.commit()

        start_time = datetime.now(UTC)

        try:
            # Get app and workflow models
            app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))

            if not app_model:
                raise ValueError(f"App not found: {trigger_log.app_id}")

            workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
            if not workflow:
                raise ValueError(f"Workflow not found: {trigger_log.workflow_id}")

            user = _get_user(session, trigger_log)

            # Execute workflow using WorkflowAppGenerator
            generator = WorkflowAppGenerator()

            # Prepare args matching AppGenerateService.generate format
            args = {"inputs": dict(trigger_data.inputs), "files": list(trigger_data.files)}

            # If workflow_id was specified, add it to args
            if trigger_data.workflow_id:
                args["workflow_id"] = trigger_data.workflow_id

            # Execute the workflow with the trigger type
            result = generator.generate(
                app_model=app_model,
                workflow=workflow,
                user=user,
                args=args,
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=False,
                call_depth=0,
                workflow_thread_pool_id=None,
                triggered_from=trigger_data.trigger_type,
            )

            # Calculate elapsed time
            elapsed_time = (datetime.now(UTC) - start_time).total_seconds()

            # Extract relevant data from result
            if isinstance(result, dict):
                workflow_run_id = result.get("workflow_run_id")
                total_tokens = result.get("total_tokens")
                outputs = result
            else:
                # Handle generator result - collect all data
                workflow_run_id = None
                total_tokens = None
                outputs = {"data": "streaming_result"}

            # Update trigger log with success
            trigger_log.status = WorkflowTriggerStatus.SUCCEEDED
            trigger_log.workflow_run_id = workflow_run_id
            trigger_log.outputs = json.dumps(outputs)
            trigger_log.elapsed_time = elapsed_time
            trigger_log.total_tokens = total_tokens
            trigger_log.finished_at = datetime.now(UTC)
            trigger_log_repo.update(trigger_log)
            session.commit()

            return AsyncTriggerExecutionResult(
                execution_id=trigger_log.id,
                status=AsyncTriggerStatus.COMPLETED,
                result=outputs,
                elapsed_time=elapsed_time,
                total_tokens=total_tokens,
            )

        except Exception as e:
            # Calculate elapsed time for failed execution
            elapsed_time = (datetime.now(UTC) - start_time).total_seconds()

            # Update trigger log with failure
            trigger_log.status = WorkflowTriggerStatus.FAILED
            trigger_log.error = str(e)
            trigger_log.finished_at = datetime.now(UTC)
            trigger_log.elapsed_time = elapsed_time
            trigger_log_repo.update(trigger_log)

            # Final failure - no retry logic (simplified like RAG tasks)
            session.commit()

            return AsyncTriggerExecutionResult(
                execution_id=trigger_log.id, status=AsyncTriggerStatus.FAILED, error=str(e), elapsed_time=elapsed_time
            )


def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
    """Compose user from trigger log"""
    tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
    if not tenant:
        raise ValueError(f"Tenant not found: {trigger_log.tenant_id}")

    # Get user from trigger log
    if trigger_log.created_by_role == CreatorUserRole.ACCOUNT:
        user = session.scalar(select(Account).where(Account.id == trigger_log.created_by))
        if user:
            user.current_tenant = tenant
    else:  # CreatorUserRole.END_USER
        user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by))

    if not user:
        raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})")

    return user
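As a usage sketch, this is roughly how a dispatcher could hand a serialized WorkflowTaskData payload to the tier-specific task. The module path, the tier labels, and the fallback to the sandbox task are assumptions for illustration; only the three task entry points and their dict payload come from the file above.

```python
# Hypothetical dispatcher; only the three Celery tasks are defined in the file above.
from tasks.async_workflow_execution_tasks import (  # assumed module path for the new file
    execute_workflow_professional,
    execute_workflow_sandbox,
    execute_workflow_team,
)

# Assumed tier labels mapped to the tier-specific tasks (and therefore to their queues).
_TIER_TASKS = {
    "professional": execute_workflow_professional,
    "team": execute_workflow_team,
    "sandbox": execute_workflow_sandbox,
}


def dispatch_triggered_workflow(tier: str, task_data_dict: dict) -> str:
    """Enqueue the serialized WorkflowTaskData on the queue bound to the tenant's tier."""
    task = _TIER_TASKS.get(tier, execute_workflow_sandbox)  # unknown tiers fall back to the free tier
    async_result = task.delay(task_data_dict)  # Celery routes to the queue declared in @shared_task
    return async_result.id
```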
@@ -2,10 +2,100 @@

set -x

# Help function
show_help() {
    echo "Usage: $0 [OPTIONS]"
    echo ""
    echo "Options:"
    echo "  -q, --queues QUEUES      Comma-separated list of queues to process"
    echo "  -c, --concurrency NUM    Number of worker processes (default: 1)"
    echo "  -P, --pool POOL          Pool implementation (default: gevent)"
    echo "  --loglevel LEVEL         Log level (default: INFO)"
    echo "  -h, --help               Show this help message"
    echo ""
    echo "Examples:"
    echo "  $0 --queues dataset,workflow"
    echo "  $0 --queues workflow_professional,workflow_team --concurrency 4"
    echo "  $0 --queues dataset --concurrency 2 --pool prefork"
    echo ""
    echo "Available queues:"
    echo "  dataset               - RAG indexing and document processing"
    echo "  workflow              - Workflow triggers (community edition)"
    echo "  workflow_professional - Professional tier workflows (cloud edition)"
    echo "  workflow_team         - Team tier workflows (cloud edition)"
    echo "  workflow_sandbox      - Sandbox tier workflows (cloud edition)"
    echo "  generation            - Content generation tasks"
    echo "  mail                  - Email notifications"
    echo "  ops_trace             - Operations tracing"
    echo "  app_deletion          - Application cleanup"
    echo "  plugin                - Plugin operations"
    echo "  workflow_storage      - Workflow storage tasks"
}

# Parse command line arguments
QUEUES=""
CONCURRENCY=1
POOL="gevent"
LOGLEVEL="INFO"

while [[ $# -gt 0 ]]; do
    case $1 in
        -q|--queues)
            QUEUES="$2"
            shift 2
            ;;
        -c|--concurrency)
            CONCURRENCY="$2"
            shift 2
            ;;
        -P|--pool)
            POOL="$2"
            shift 2
            ;;
        --loglevel)
            LOGLEVEL="$2"
            shift 2
            ;;
        -h|--help)
            show_help
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            show_help
            exit 1
            ;;
    esac
done

SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/.."

# If no queues specified, use edition-based defaults
if [[ -z "${QUEUES}" ]]; then
    # Get EDITION from environment, default to SELF_HOSTED (community edition)
    EDITION=${EDITION:-"SELF_HOSTED"}

    # Configure queues based on edition
    if [[ "${EDITION}" == "CLOUD" ]]; then
        # Cloud edition: separate queues for dataset and trigger tasks
        QUEUES="dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow_professional,workflow_team,workflow_sandbox"
    else
        # Community edition (SELF_HOSTED): dataset and workflow have separate queues
        QUEUES="dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow"
    fi

    echo "No queues specified, using edition-based defaults: ${QUEUES}"
else
    echo "Using specified queues: ${QUEUES}"
fi

echo "Starting Celery worker with:"
echo "  Queues: ${QUEUES}"
echo "  Concurrency: ${CONCURRENCY}"
echo "  Pool: ${POOL}"
echo "  Log Level: ${LOGLEVEL}"

uv --directory api run \
  celery -A app.celery worker \
  -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
  -P ${POOL} -c ${CONCURRENCY} --loglevel ${LOGLEVEL} -Q ${QUEUES}
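To sanity-check which queues a running worker actually consumes, whether it was started by this script, the Docker entrypoint, or a Kubernetes deployment, Celery's broadcast inspection API can be used. This is a sketch only: the `from app import celery` import mirrors the `-A app.celery` flag above but is an assumption about the module layout.

```python
# Sketch: list the queues each reachable Celery worker is consuming.
from app import celery  # assumed to expose the Celery app referenced by "-A app.celery"


def list_worker_queues() -> dict[str, list[str]]:
    """Return {worker_name: [queue names]} for all workers that answer the broadcast."""
    replies = celery.control.inspect(timeout=2.0).active_queues() or {}
    return {worker: [queue["name"] for queue in queues] for worker, queues in replies.items()}


if __name__ == "__main__":
    for worker, queues in list_worker_queues().items():
        print(worker, "->", ", ".join(queues))
```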
@@ -29,14 +29,14 @@ services:
      - default

  # worker service
  # The Celery worker for processing the queue.
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
    image: langgenius/dify-api:1.7.2
    restart: always
    environment:
      # Use the shared environment variables.
      <<: *shared-api-worker-env
      # Startup mode, 'worker' starts the Celery worker for processing the queue.
      # Startup mode, 'worker' starts the Celery worker for processing all queues.
      MODE: worker
      SENTRY_DSN: ${API_SENTRY_DSN:-}
      SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}