From 6aed7e3ff45341359839060b3f1936cf6e70371a Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Sat, 23 Aug 2025 20:18:08 +0800 Subject: [PATCH] feat/trigger universal entry (#24358) Co-authored-by: Claude Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/.vscode/launch.json.example | 2 +- api/core/app/apps/workflow/app_generator.py | 9 +- api/docker/entrypoint.sh | 36 +- api/extensions/ext_celery.py | 4 +- ...-994bdf7197ab_add_workflow_trigger_logs.py | 66 ++++ api/models/enums.py | 5 +- api/models/workflow.py | 119 +++++++ ...alchemy_workflow_trigger_log_repository.py | 198 +++++++++++ .../workflow_trigger_log_repository.py | 206 +++++++++++ api/services/async_workflow_service.py | 320 ++++++++++++++++++ api/services/workflow/entities.py | 113 +++++++ api/services/workflow/queue_dispatcher.py | 158 +++++++++ api/services/workflow/rate_limiter.py | 206 +++++++++++ api/tasks/async_workflow_tasks.py | 201 +++++++++++ dev/start-worker | 92 ++++- docker/docker-compose-template.yaml | 4 +- 16 files changed, 1730 insertions(+), 9 deletions(-) create mode 100644 api/migrations/versions/2025_08_23_2006-994bdf7197ab_add_workflow_trigger_logs.py create mode 100644 api/repositories/sqlalchemy_workflow_trigger_log_repository.py create mode 100644 api/repositories/workflow_trigger_log_repository.py create mode 100644 api/services/async_workflow_service.py create mode 100644 api/services/workflow/entities.py create mode 100644 api/services/workflow/queue_dispatcher.py create mode 100644 api/services/workflow/rate_limiter.py create mode 100644 api/tasks/async_workflow_tasks.py diff --git a/api/.vscode/launch.json.example b/api/.vscode/launch.json.example index b9e32e2511..a52eca63d9 100644 --- a/api/.vscode/launch.json.example +++ b/api/.vscode/launch.json.example @@ -54,7 +54,7 @@ "--loglevel", "DEBUG", "-Q", - "dataset,generation,mail,ops_trace,app_deletion" + "dataset,generation,mail,ops_trace,app_deletion,workflow" ] } ] diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 22b0234604..77fb8c1975 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -54,6 +54,7 @@ class WorkflowAppGenerator(BaseAppGenerator): streaming: Literal[True], call_depth: int, workflow_thread_pool_id: Optional[str], + triggered_from: Optional[WorkflowRunTriggeredFrom] = None, ) -> Generator[Mapping | str, None, None]: ... @overload @@ -68,6 +69,7 @@ class WorkflowAppGenerator(BaseAppGenerator): streaming: Literal[False], call_depth: int, workflow_thread_pool_id: Optional[str], + triggered_from: Optional[WorkflowRunTriggeredFrom] = None, ) -> Mapping[str, Any]: ... @overload @@ -82,6 +84,7 @@ class WorkflowAppGenerator(BaseAppGenerator): streaming: bool, call_depth: int, workflow_thread_pool_id: Optional[str], + triggered_from: Optional[WorkflowRunTriggeredFrom] = None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... 
def generate( @@ -95,6 +98,7 @@ class WorkflowAppGenerator(BaseAppGenerator): streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, + triggered_from: Optional[WorkflowRunTriggeredFrom] = None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: files: Sequence[Mapping[str, Any]] = args.get("files") or [] @@ -159,7 +163,10 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - if invoke_from == InvokeFrom.DEBUGGER: + if triggered_from is not None: + # Use explicitly provided triggered_from (for async triggers) + workflow_triggered_from = triggered_from + elif invoke_from == InvokeFrom.DEBUGGER: workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index da147fe895..c6b1afc3bd 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -30,9 +30,41 @@ if [[ "${MODE}" == "worker" ]]; then CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}" fi - exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ + # Configure queues based on edition if not explicitly set + if [[ -z "${CELERY_QUEUES}" ]]; then + if [[ "${EDITION}" == "CLOUD" ]]; then + # Cloud edition: separate queues for dataset and trigger tasks + DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow_professional,workflow_team,workflow_sandbox" + else + # Community edition (SELF_HOSTED): dataset and workflow have separate queues + DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow" + fi + else + DEFAULT_QUEUES="${CELERY_QUEUES}" + fi + + # Support for Kubernetes deployment with specific queue workers + # Environment variables that can be set: + # - CELERY_WORKER_QUEUES: Comma-separated list of queues (overrides CELERY_QUEUES) + # - CELERY_WORKER_CONCURRENCY: Number of worker processes (overrides CELERY_WORKER_AMOUNT) + # - CELERY_WORKER_POOL: Pool implementation (overrides CELERY_WORKER_CLASS) + + if [[ -n "${CELERY_WORKER_QUEUES}" ]]; then + DEFAULT_QUEUES="${CELERY_WORKER_QUEUES}" + echo "Using CELERY_WORKER_QUEUES: ${DEFAULT_QUEUES}" + fi + + if [[ -n "${CELERY_WORKER_CONCURRENCY}" ]]; then + CONCURRENCY_OPTION="-c ${CELERY_WORKER_CONCURRENCY}" + echo "Using CELERY_WORKER_CONCURRENCY: ${CELERY_WORKER_CONCURRENCY}" + fi + + WORKER_POOL="${CELERY_WORKER_POOL:-${CELERY_WORKER_CLASS:-gevent}}" + echo "Starting Celery worker with queues: ${DEFAULT_QUEUES}" + + exec celery -A app.celery worker -P ${WORKER_POOL} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage} + -Q ${DEFAULT_QUEUES} elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 00e0bd9a16..bcda75aee2 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -96,7 +96,9 @@ def init_app(app: DifyApp) -> Celery: celery_app.set_default() app.extensions["celery"] = celery_app - imports = [] + imports = [ + "tasks.async_workflow_tasks", # trigger workers + ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME # if you add a new task, please add the switch to 
CeleryScheduleTasksConfig diff --git a/api/migrations/versions/2025_08_23_2006-994bdf7197ab_add_workflow_trigger_logs.py b/api/migrations/versions/2025_08_23_2006-994bdf7197ab_add_workflow_trigger_logs.py new file mode 100644 index 0000000000..20760983b6 --- /dev/null +++ b/api/migrations/versions/2025_08_23_2006-994bdf7197ab_add_workflow_trigger_logs.py @@ -0,0 +1,66 @@ +"""empty message + +Revision ID: 994bdf7197ab +Revises: fa8b0fa6f407 +Create Date: 2025-08-23 20:06:35.995973 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '994bdf7197ab' +down_revision = 'fa8b0fa6f407' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('workflow_trigger_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=True), + sa.Column('trigger_type', sa.String(length=50), nullable=False), + sa.Column('trigger_data', sa.Text(), nullable=False), + sa.Column('inputs', sa.Text(), nullable=False), + sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('status', sa.String(length=50), nullable=False), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('queue_name', sa.String(length=100), nullable=False), + sa.Column('celery_task_id', sa.String(length=255), nullable=True), + sa.Column('retry_count', sa.Integer(), nullable=False), + sa.Column('elapsed_time', sa.Float(), nullable=True), + sa.Column('total_tokens', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', sa.String(length=255), nullable=False), + sa.Column('triggered_at', sa.DateTime(), nullable=True), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey') + ) + with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op: + batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False) + batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False) + batch_op.create_index('workflow_trigger_log_tenant_app_idx', ['tenant_id', 'app_id'], unique=False) + batch_op.create_index('workflow_trigger_log_workflow_id_idx', ['workflow_id'], unique=False) + batch_op.create_index('workflow_trigger_log_workflow_run_idx', ['workflow_run_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_trigger_log_workflow_run_idx') + batch_op.drop_index('workflow_trigger_log_workflow_id_idx') + batch_op.drop_index('workflow_trigger_log_tenant_app_idx') + batch_op.drop_index('workflow_trigger_log_status_idx') + batch_op.drop_index('workflow_trigger_log_created_at_idx') + + op.drop_table('workflow_trigger_logs') + # ### end Alembic commands ### diff --git a/api/models/enums.py b/api/models/enums.py index cc9f28a7bb..a2693317b0 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -13,7 +13,10 @@ class UserFrom(StrEnum): class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" - APP_RUN = "app-run" + APP_RUN = "app-run" # webapp / service api + WEBHOOK = "webhook" + SCHEDULE = "schedule" + PLUGIN = "plugin" class DraftVariableType(StrEnum): diff --git a/api/models/workflow.py b/api/models/workflow.py index 2fea3fcd78..f00c0030dd 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1262,3 +1262,122 @@ class WorkflowDraftVariable(Base): def is_system_variable_editable(name: str) -> bool: return name in _EDITABLE_SYSTEM_VARIABLE + + +class WorkflowTriggerStatus(StrEnum): + """Workflow Trigger Execution Status""" + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + RATE_LIMITED = "rate_limited" + RETRYING = "retrying" + + +class WorkflowTriggerLog(Base): + """ + Workflow Trigger Log + + Track async trigger workflow runs with re-invocation capability + + Attributes: + - id (uuid) Trigger Log ID (used as workflow_trigger_log_id) + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - workflow_id (uuid) Workflow ID + - workflow_run_id (uuid) Optional - Associated workflow run ID when execution starts + - trigger_type (string) Type of trigger: webhook, schedule, plugin + - trigger_data (text) Full trigger data including inputs (JSON) + - inputs (text) Input parameters (JSON) + - outputs (text) Optional - Output content (JSON) + - status (string) Execution status + - error (text) Optional - Error message if failed + - queue_name (string) Celery queue used + - celery_task_id (string) Optional - Celery task ID for tracking + - retry_count (int) Number of retry attempts + - elapsed_time (float) Optional - Time consumption in seconds + - total_tokens (int) Optional - Total tokens used + - created_by_role (string) Creator role: account, end_user + - created_by (string) Creator ID + - created_at (timestamp) Creation time + - triggered_at (timestamp) Optional - When actually triggered + - finished_at (timestamp) Optional - Completion time + """ + + __tablename__ = "workflow_trigger_logs" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_trigger_log_pkey"), + sa.Index("workflow_trigger_log_tenant_app_idx", "tenant_id", "app_id"), + sa.Index("workflow_trigger_log_status_idx", "status"), + sa.Index("workflow_trigger_log_created_at_idx", "created_at"), + sa.Index("workflow_trigger_log_workflow_run_idx", "workflow_run_id"), + sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + + 
trigger_type: Mapped[str] = mapped_column(String(50), nullable=False) + trigger_data: Mapped[str] = mapped_column(sa.Text, nullable=False) # Full TriggerData as JSON + inputs: Mapped[str] = mapped_column(sa.Text, nullable=False) # Just inputs for easy viewing + outputs: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + + status: Mapped[str] = mapped_column(String(50), nullable=False, default=WorkflowTriggerStatus.PENDING) + error: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + + queue_name: Mapped[str] = mapped_column(String(100), nullable=False) + celery_task_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + + elapsed_time: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + total_tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by: Mapped[str] = mapped_column(String(255), nullable=False) + + triggered_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + + @property + def created_by_account(self): + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None + + @property + def created_by_end_user(self): + from models.model import EndUser + + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None + + def to_dict(self) -> dict: + """Convert to dictionary for API responses""" + return { + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "workflow_id": self.workflow_id, + "workflow_run_id": self.workflow_run_id, + "trigger_type": self.trigger_type, + "trigger_data": json.loads(self.trigger_data), + "inputs": json.loads(self.inputs), + "outputs": json.loads(self.outputs) if self.outputs else None, + "status": self.status, + "error": self.error, + "queue_name": self.queue_name, + "celery_task_id": self.celery_task_id, + "retry_count": self.retry_count, + "elapsed_time": self.elapsed_time, + "total_tokens": self.total_tokens, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at.isoformat() if self.created_at else None, + "triggered_at": self.triggered_at.isoformat() if self.triggered_at else None, + "finished_at": self.finished_at.isoformat() if self.finished_at else None, + } diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py new file mode 100644 index 0000000000..1276686cd8 --- /dev/null +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -0,0 +1,198 @@ +""" +SQLAlchemy implementation of WorkflowTriggerLogRepository. 
+""" + +from collections.abc import Sequence +from datetime import datetime, timedelta +from typing import Any, Optional + +from sqlalchemy import and_, delete, func, select, update +from sqlalchemy.orm import Session + +from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus +from repositories.workflow_trigger_log_repository import TriggerLogOrderBy, WorkflowTriggerLogRepository + + +class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): + """ + SQLAlchemy implementation of WorkflowTriggerLogRepository. + + Optimized for large table operations with proper indexing and batch processing. + """ + + def __init__(self, session: Session): + self.session = session + + def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog: + """Create a new trigger log entry.""" + self.session.add(trigger_log) + self.session.flush() + return trigger_log + + def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog: + """Update an existing trigger log entry.""" + self.session.merge(trigger_log) + self.session.flush() + return trigger_log + + def get_by_id(self, trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[WorkflowTriggerLog]: + """Get a trigger log by its ID.""" + query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.id == trigger_log_id) + + if tenant_id: + query = query.where(WorkflowTriggerLog.tenant_id == tenant_id) + + return self.session.scalar(query) + + def get_by_status( + self, + tenant_id: str, + app_id: str, + status: WorkflowTriggerStatus, + limit: int = 100, + offset: int = 0, + order_by: TriggerLogOrderBy = TriggerLogOrderBy.CREATED_AT, + order_desc: bool = True, + ) -> Sequence[WorkflowTriggerLog]: + """Get trigger logs by status with pagination.""" + query = select(WorkflowTriggerLog).where( + and_( + WorkflowTriggerLog.tenant_id == tenant_id, + WorkflowTriggerLog.app_id == app_id, + WorkflowTriggerLog.status == status, + ) + ) + + # Apply ordering + order_column = getattr(WorkflowTriggerLog, order_by.value) + if order_desc: + query = query.order_by(order_column.desc()) + else: + query = query.order_by(order_column.asc()) + + # Apply pagination + query = query.limit(limit).offset(offset) + + return list(self.session.scalars(query).all()) + + def get_failed_for_retry( + self, tenant_id: str, max_retry_count: int = 3, limit: int = 100 + ) -> Sequence[WorkflowTriggerLog]: + """Get failed trigger logs eligible for retry.""" + query = ( + select(WorkflowTriggerLog) + .where( + and_( + WorkflowTriggerLog.tenant_id == tenant_id, + WorkflowTriggerLog.status.in_([WorkflowTriggerStatus.FAILED, WorkflowTriggerStatus.RATE_LIMITED]), + WorkflowTriggerLog.retry_count < max_retry_count, + ) + ) + .order_by(WorkflowTriggerLog.created_at.asc()) + .limit(limit) + ) + + return list(self.session.scalars(query).all()) + + def get_recent_logs( + self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0 + ) -> Sequence[WorkflowTriggerLog]: + """Get recent trigger logs within specified hours.""" + since = datetime.utcnow() - timedelta(hours=hours) + + query = ( + select(WorkflowTriggerLog) + .where( + and_( + WorkflowTriggerLog.tenant_id == tenant_id, + WorkflowTriggerLog.app_id == app_id, + WorkflowTriggerLog.created_at >= since, + ) + ) + .order_by(WorkflowTriggerLog.created_at.desc()) + .limit(limit) + .offset(offset) + ) + + return list(self.session.scalars(query).all()) + + def count_by_status( + self, + tenant_id: str, + app_id: str, + status: Optional[WorkflowTriggerStatus] = None, + since: 
Optional[datetime] = None, + ) -> int: + """Count trigger logs by status.""" + query = select(func.count(WorkflowTriggerLog.id)).where( + and_(WorkflowTriggerLog.tenant_id == tenant_id, WorkflowTriggerLog.app_id == app_id) + ) + + if status: + query = query.where(WorkflowTriggerLog.status == status) + + if since: + query = query.where(WorkflowTriggerLog.created_at >= since) + + return self.session.scalar(query) or 0 + + def delete_expired_logs(self, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int: + """Delete expired trigger logs in batches.""" + total_deleted = 0 + + while True: + # Get batch of IDs to delete + subquery = ( + select(WorkflowTriggerLog.id) + .where(and_(WorkflowTriggerLog.tenant_id == tenant_id, WorkflowTriggerLog.created_at < before_date)) + .limit(batch_size) + ) + + # Delete the batch + result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.id.in_(subquery))) + + deleted = result.rowcount + total_deleted += deleted + + if deleted < batch_size: + break + + self.session.commit() + + return total_deleted + + def archive_completed_logs( + self, tenant_id: str, before_date: datetime, batch_size: int = 1000 + ) -> Sequence[WorkflowTriggerLog]: + """Get completed logs for archival.""" + query = ( + select(WorkflowTriggerLog) + .where( + and_( + WorkflowTriggerLog.tenant_id == tenant_id, + WorkflowTriggerLog.status == WorkflowTriggerStatus.SUCCEEDED, + WorkflowTriggerLog.finished_at < before_date, + ) + ) + .limit(batch_size) + ) + + return list(self.session.scalars(query).all()) + + def update_status_batch( + self, trigger_log_ids: Sequence[str], new_status: WorkflowTriggerStatus, error_message: Optional[str] = None + ) -> int: + """Update status for multiple trigger logs.""" + update_data: dict[str, Any] = {"status": new_status} + + if error_message is not None: + update_data["error"] = error_message + + if new_status in [WorkflowTriggerStatus.SUCCEEDED, WorkflowTriggerStatus.FAILED]: + update_data["finished_at"] = datetime.utcnow() + + result = self.session.execute( + update(WorkflowTriggerLog).where(WorkflowTriggerLog.id.in_(trigger_log_ids)).values(**update_data) + ) + + return result.rowcount diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py new file mode 100644 index 0000000000..46e945b892 --- /dev/null +++ b/api/repositories/workflow_trigger_log_repository.py @@ -0,0 +1,206 @@ +""" +Repository protocol for WorkflowTriggerLog operations. + +This module provides a protocol interface for operations on WorkflowTriggerLog, +designed to efficiently handle a potentially large volume of trigger logs with +proper indexing and batch operations. +""" + +from collections.abc import Sequence +from datetime import datetime +from enum import StrEnum +from typing import Optional, Protocol + +from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus + + +class TriggerLogOrderBy(StrEnum): + """Fields available for ordering trigger logs""" + + CREATED_AT = "created_at" + TRIGGERED_AT = "triggered_at" + FINISHED_AT = "finished_at" + STATUS = "status" + + +class WorkflowTriggerLogRepository(Protocol): + """ + Protocol for operations on WorkflowTriggerLog. + + This repository provides efficient access patterns for the trigger log table, + which is expected to grow large over time. 
It includes: + - Batch operations for cleanup + - Efficient queries with proper indexing + - Pagination support + - Status-based filtering + + Implementation notes: + - Leverage database indexes on (tenant_id, app_id), status, and created_at + - Use batch operations for deletions to avoid locking + - Support pagination for large result sets + """ + + def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog: + """ + Create a new trigger log entry. + + Args: + trigger_log: The WorkflowTriggerLog instance to create + + Returns: + The created WorkflowTriggerLog with generated ID + """ + ... + + def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog: + """ + Update an existing trigger log entry. + + Args: + trigger_log: The WorkflowTriggerLog instance to update + + Returns: + The updated WorkflowTriggerLog + """ + ... + + def get_by_id(self, trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[WorkflowTriggerLog]: + """ + Get a trigger log by its ID. + + Args: + trigger_log_id: The trigger log identifier + tenant_id: Optional tenant identifier for additional security + + Returns: + The WorkflowTriggerLog if found, None otherwise + """ + ... + + def get_by_status( + self, + tenant_id: str, + app_id: str, + status: WorkflowTriggerStatus, + limit: int = 100, + offset: int = 0, + order_by: TriggerLogOrderBy = TriggerLogOrderBy.CREATED_AT, + order_desc: bool = True, + ) -> Sequence[WorkflowTriggerLog]: + """ + Get trigger logs by status with pagination. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + status: The workflow trigger status to filter by + limit: Maximum number of results + offset: Number of results to skip + order_by: Field to order results by + order_desc: Whether to order descending (True) or ascending (False) + + Returns: + A sequence of WorkflowTriggerLog instances + """ + ... + + def get_failed_for_retry( + self, tenant_id: str, max_retry_count: int = 3, limit: int = 100 + ) -> Sequence[WorkflowTriggerLog]: + """ + Get failed trigger logs that are eligible for retry. + + Args: + tenant_id: The tenant identifier + max_retry_count: Maximum retry count to consider + limit: Maximum number of results + + Returns: + A sequence of WorkflowTriggerLog instances eligible for retry + """ + ... + + def get_recent_logs( + self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0 + ) -> Sequence[WorkflowTriggerLog]: + """ + Get recent trigger logs within specified hours. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + hours: Number of hours to look back + limit: Maximum number of results + offset: Number of results to skip + + Returns: + A sequence of recent WorkflowTriggerLog instances + """ + ... + + def count_by_status( + self, + tenant_id: str, + app_id: str, + status: Optional[WorkflowTriggerStatus] = None, + since: Optional[datetime] = None, + ) -> int: + """ + Count trigger logs by status. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + status: Optional status filter + since: Optional datetime to count from + + Returns: + Count of matching trigger logs + """ + ... + + def delete_expired_logs(self, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int: + """ + Delete expired trigger logs in batches. + + Args: + tenant_id: The tenant identifier + before_date: Delete logs created before this date + batch_size: Number of logs to delete per batch + + Returns: + Total number of logs deleted + """ + ... 
+ + def archive_completed_logs( + self, tenant_id: str, before_date: datetime, batch_size: int = 1000 + ) -> Sequence[WorkflowTriggerLog]: + """ + Get completed logs for archival before deletion. + + Args: + tenant_id: The tenant identifier + before_date: Get logs completed before this date + batch_size: Number of logs to retrieve + + Returns: + A sequence of WorkflowTriggerLog instances for archival + """ + ... + + def update_status_batch( + self, trigger_log_ids: Sequence[str], new_status: WorkflowTriggerStatus, error_message: Optional[str] = None + ) -> int: + """ + Update status for multiple trigger logs at once. + + Args: + trigger_log_ids: List of trigger log IDs to update + new_status: The new status to set + error_message: Optional error message to set + + Returns: + Number of logs updated + """ + ... diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py new file mode 100644 index 0000000000..448f6b2f63 --- /dev/null +++ b/api/services/async_workflow_service.py @@ -0,0 +1,320 @@ +""" +Universal async workflow execution service. + +This service provides a centralized entry point for triggering workflows asynchronously +with support for different subscription tiers, rate limiting, and execution tracking. +""" + +import json +from datetime import UTC, datetime +from typing import Optional, Union + +from celery.result import AsyncResult +from sqlalchemy import select +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account +from models.enums import CreatorUserRole +from models.model import App, EndUser +from models.workflow import Workflow, WorkflowTriggerLog, WorkflowTriggerStatus +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.errors.llm import InvokeRateLimitError +from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData +from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority +from services.workflow.rate_limiter import TenantDailyRateLimiter +from services.workflow_service import WorkflowService +from tasks.async_workflow_tasks import ( + execute_workflow_professional, + execute_workflow_sandbox, + execute_workflow_team, +) + + +class AsyncWorkflowService: + """ + Universal entry point for async workflow execution - ALL METHODS ARE NON-BLOCKING + + This service handles: + - Trigger data validation and processing + - Queue routing based on subscription tier + - Daily rate limiting with timezone support + - Execution tracking and logging + - Retry mechanisms for failed executions + + Important: All trigger methods return immediately after queuing tasks. + Actual workflow execution happens asynchronously in background Celery workers. + Use trigger log IDs to monitor execution status and results. + """ + + @classmethod + def trigger_workflow_async( + cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData + ) -> AsyncTriggerResponse: + """ + Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK + + Creates a trigger log and dispatches to appropriate queue based on subscription tier. + The workflow execution happens asynchronously in the background via Celery workers. + This method returns immediately after queuing the task, not after execution completion. 
+ + Args: + session: Database session to use for operations + user: User (Account or EndUser) who initiated the workflow trigger + trigger_data: Validated Pydantic model containing trigger information + + Returns: + AsyncTriggerResponse with workflow_trigger_log_id, task_id, status="queued", and queue + Note: The actual workflow execution status must be checked separately via workflow_trigger_log_id + + Raises: + ValueError: If app or workflow not found + InvokeRateLimitError: If daily rate limit exceeded + + Behavior: + - Non-blocking: Returns immediately after queuing + - Asynchronous: Actual execution happens in background Celery workers + - Status tracking: Use workflow_trigger_log_id to monitor progress + - Queue-based: Routes to different queues based on subscription tier + """ + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + dispatcher_manager = QueueDispatcherManager() + workflow_service = WorkflowService() + rate_limiter = TenantDailyRateLimiter(redis_client) + + # 1. Validate app exists + app_model = session.scalar(select(App).where(App.id == trigger_data.app_id)) + if not app_model: + raise ValueError(f"App not found: {trigger_data.app_id}") + + # 2. Get workflow + workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id) + + # 3. Get dispatcher based on tenant subscription + dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id) + + # 4. Get tenant owner timezone for rate limiting + tenant_owner_tz = rate_limiter._get_tenant_owner_timezone(trigger_data.tenant_id) + + # 5. Determine user role and ID + if isinstance(user, Account): + created_by_role = CreatorUserRole.ACCOUNT + created_by = user.id + else: # EndUser + created_by_role = CreatorUserRole.END_USER + created_by = user.id + + # 6. Create trigger log entry first (for tracking) + trigger_log = WorkflowTriggerLog( + tenant_id=trigger_data.tenant_id, + app_id=trigger_data.app_id, + workflow_id=workflow.id, + trigger_type=trigger_data.trigger_type, + trigger_data=trigger_data.model_dump_json(), + inputs=json.dumps(dict(trigger_data.inputs)), + status=WorkflowTriggerStatus.PENDING, + queue_name=dispatcher.get_queue_name(), + retry_count=0, + created_by_role=created_by_role, + created_by=created_by, + ) + + trigger_log = trigger_log_repo.create(trigger_log) + session.commit() + + # 7. Check and consume daily quota + if not dispatcher.consume_quota(trigger_data.tenant_id, tenant_owner_tz): + # Update trigger log status + trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED + trigger_log.error = f"Daily limit reached for {dispatcher.get_queue_name()}" + trigger_log_repo.update(trigger_log) + session.commit() + + remaining = rate_limiter.get_remaining_quota( + trigger_data.tenant_id, dispatcher.get_daily_limit(), tenant_owner_tz + ) + + reset_time = rate_limiter.get_quota_reset_time(trigger_data.tenant_id, tenant_owner_tz) + + raise InvokeRateLimitError( + f"Daily workflow execution limit reached. " + f"Limit resets at {reset_time.strftime('%Y-%m-%d %H:%M:%S %Z')}. " + f"Remaining quota: {remaining}" + ) + + # 8. Create task data + queue_name = dispatcher.get_queue_name() + + task_data = WorkflowTaskData(workflow_trigger_log_id=trigger_log.id) + + # 9. 
Dispatch to appropriate queue + task_data_dict = task_data.model_dump(mode="json") + + task: AsyncResult | None = None + if queue_name == QueuePriority.PROFESSIONAL: + task = execute_workflow_professional.delay(task_data_dict) # type: ignore + elif queue_name == QueuePriority.TEAM: + task = execute_workflow_team.delay(task_data_dict) # type: ignore + else: # SANDBOX + task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore + + if not task: + raise ValueError(f"Failed to queue task for queue: {queue_name}") + + # 10. Update trigger log with task info + trigger_log.status = WorkflowTriggerStatus.QUEUED + trigger_log.celery_task_id = task.id + trigger_log.triggered_at = datetime.now(UTC) + trigger_log_repo.update(trigger_log) + session.commit() + + return AsyncTriggerResponse( + workflow_trigger_log_id=trigger_log.id, + task_id=task.id, # type: ignore + status="queued", + queue=queue_name, + ) + + @classmethod + def reinvoke_trigger( + cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str + ) -> AsyncTriggerResponse: + """ + Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK + + Updates the existing trigger log to retry status and creates a new async execution. + Returns immediately after queuing the retry, not after execution completion. + + Args: + session: Database session to use for operations + user: User (Account or EndUser) who initiated the retry + workflow_trigger_log_id: ID of the trigger log to re-invoke + + Returns: + AsyncTriggerResponse with new execution information (status="queued") + Note: This creates a new trigger log entry for the retry attempt + + Raises: + ValueError: If trigger log not found + + Behavior: + - Non-blocking: Returns immediately after queuing retry + - Creates new trigger log: Original log marked as retrying, new log for execution + - Preserves original trigger data: Uses same inputs and configuration + """ + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id) + + if not trigger_log: + raise ValueError(f"Trigger log not found: {workflow_trigger_log_id}") + + # Reconstruct trigger data from log + trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data) + + # Reset log for retry + trigger_log.status = WorkflowTriggerStatus.RETRYING + trigger_log.retry_count += 1 + trigger_log.error = None + trigger_log.triggered_at = datetime.now(UTC) + trigger_log_repo.update(trigger_log) + session.commit() + + # Re-trigger workflow (this will create a new trigger log) + return cls.trigger_workflow_async(session, user, trigger_data) + + @classmethod + def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[dict]: + """ + Get trigger log by ID + + Args: + workflow_trigger_log_id: ID of the trigger log + tenant_id: Optional tenant ID for security check + + Returns: + Trigger log as dictionary or None if not found + """ + with Session(db.engine) as session: + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id) + + if not trigger_log: + return None + + return trigger_log.to_dict() + + @classmethod + def get_recent_logs( + cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0 + ) -> list[dict]: + """ + Get recent trigger logs + + Args: + tenant_id: Tenant ID + app_id: Application ID + hours: Number of hours to look back + limit: Maximum 
number of results + offset: Number of results to skip + + Returns: + List of trigger logs as dictionaries + """ + with Session(db.engine) as session: + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + logs = trigger_log_repo.get_recent_logs( + tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset + ) + + return [log.to_dict() for log in logs] + + @classmethod + def get_failed_logs_for_retry(cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100) -> list[dict]: + """ + Get failed logs eligible for retry + + Args: + tenant_id: Tenant ID + max_retry_count: Maximum retry count + limit: Maximum number of results + + Returns: + List of failed trigger logs as dictionaries + """ + with Session(db.engine) as session: + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + logs = trigger_log_repo.get_failed_for_retry( + tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit + ) + + return [log.to_dict() for log in logs] + + @staticmethod + def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: Optional[str] = None) -> Workflow: + """ + Get workflow for the app + + Args: + app_model: App model instance + workflow_id: Optional specific workflow ID + + Returns: + Workflow instance + + Raises: + ValueError: If workflow not found + """ + if workflow_id: + # Get specific published workflow + workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id) + if not workflow: + raise ValueError(f"Published workflow not found: {workflow_id}") + else: + # Get default published workflow + workflow = workflow_service.get_published_workflow(app_model) + if not workflow: + raise ValueError(f"No published workflow found for app: {app_model.id}") + + return workflow diff --git a/api/services/workflow/entities.py b/api/services/workflow/entities.py new file mode 100644 index 0000000000..cfefa021b6 --- /dev/null +++ b/api/services/workflow/entities.py @@ -0,0 +1,113 @@ +""" +Pydantic models for async workflow trigger system. 
+""" + +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict, Field + +from models.enums import WorkflowRunTriggeredFrom + + +class AsyncTriggerStatus(StrEnum): + """Async trigger execution status""" + + COMPLETED = "completed" + FAILED = "failed" + TIMEOUT = "timeout" + + +class TriggerData(BaseModel): + """Base trigger data model for async workflow execution""" + + app_id: str + tenant_id: str + workflow_id: Optional[str] = None + inputs: Mapping[str, Any] + files: Sequence[Mapping[str, Any]] = Field(default_factory=list) + trigger_type: WorkflowRunTriggeredFrom + + model_config = ConfigDict(use_enum_values=True) + + +class WebhookTriggerData(TriggerData): + """Webhook-specific trigger data""" + + trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK + webhook_url: str + headers: Mapping[str, str] = Field(default_factory=dict) + method: str = "POST" + + +class ScheduleTriggerData(TriggerData): + """Schedule-specific trigger data""" + + trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE + schedule_id: str + cron_expression: str + + +class PluginTriggerData(TriggerData): + """Plugin webhook trigger data""" + + trigger_type: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN + plugin_id: str + webhook_url: str + + +class WorkflowTaskData(BaseModel): + """Lightweight data structure for Celery workflow tasks""" + + workflow_trigger_log_id: str # Primary tracking ID - all other data can be fetched from DB + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class AsyncTriggerExecutionResult(BaseModel): + """Result from async trigger-based workflow execution""" + + execution_id: str + status: AsyncTriggerStatus + result: Optional[Mapping[str, Any]] = None + error: Optional[str] = None + elapsed_time: Optional[float] = None + total_tokens: Optional[int] = None + + model_config = ConfigDict(use_enum_values=True) + + +class AsyncTriggerResponse(BaseModel): + """Response from triggering an async workflow""" + + workflow_trigger_log_id: str + task_id: str + status: str + queue: str + + model_config = ConfigDict(use_enum_values=True) + + +class TriggerLogResponse(BaseModel): + """Response model for trigger log data""" + + id: str + tenant_id: str + app_id: str + workflow_id: str + trigger_type: WorkflowRunTriggeredFrom + status: str + queue_name: str + retry_count: int + celery_task_id: Optional[str] = None + workflow_run_id: Optional[str] = None + error: Optional[str] = None + outputs: Optional[str] = None + elapsed_time: Optional[float] = None + total_tokens: Optional[int] = None + created_at: Optional[str] = None + triggered_at: Optional[str] = None + finished_at: Optional[str] = None + + model_config = ConfigDict(use_enum_values=True) diff --git a/api/services/workflow/queue_dispatcher.py b/api/services/workflow/queue_dispatcher.py new file mode 100644 index 0000000000..f1e5db9073 --- /dev/null +++ b/api/services/workflow/queue_dispatcher.py @@ -0,0 +1,158 @@ +""" +Queue dispatcher system for async workflow execution. + +Implements an ABC-based pattern for handling different subscription tiers +with appropriate queue routing and rate limiting. 
+""" + +import os +from abc import ABC, abstractmethod +from enum import StrEnum + +from configs import dify_config +from extensions.ext_redis import redis_client +from services.billing_service import BillingService +from services.workflow.rate_limiter import TenantDailyRateLimiter + + +class QueuePriority(StrEnum): + """Queue priorities for different subscription tiers""" + + PROFESSIONAL = "workflow_professional" # Highest priority + TEAM = "workflow_team" + SANDBOX = "workflow_sandbox" # Free tier + + +class BaseQueueDispatcher(ABC): + """Abstract base class for queue dispatchers""" + + def __init__(self): + self.rate_limiter = TenantDailyRateLimiter(redis_client) + + @abstractmethod + def get_queue_name(self) -> str: + """Get the queue name for this dispatcher""" + pass + + @abstractmethod + def get_daily_limit(self) -> int: + """Get daily execution limit""" + pass + + @abstractmethod + def get_priority(self) -> int: + """Get task priority level""" + pass + + def check_daily_quota(self, tenant_id: str, tenant_owner_tz: str) -> bool: + """ + Check if tenant has remaining daily quota + + Args: + tenant_id: The tenant identifier + tenant_owner_tz: Tenant owner's timezone + + Returns: + True if quota available, False otherwise + """ + # Check without consuming + remaining = self.rate_limiter.get_remaining_quota( + tenant_id=tenant_id, max_daily_limit=self.get_daily_limit(), timezone_str=tenant_owner_tz + ) + return remaining > 0 + + def consume_quota(self, tenant_id: str, tenant_owner_tz: str) -> bool: + """ + Consume one execution from daily quota + + Args: + tenant_id: The tenant identifier + tenant_owner_tz: Tenant owner's timezone + + Returns: + True if quota consumed successfully, False if limit reached + """ + return self.rate_limiter.check_and_consume( + tenant_id=tenant_id, max_daily_limit=self.get_daily_limit(), timezone_str=tenant_owner_tz + ) + + +class ProfessionalQueueDispatcher(BaseQueueDispatcher): + """Dispatcher for professional tier""" + + def get_queue_name(self) -> str: + return QueuePriority.PROFESSIONAL + + def get_daily_limit(self) -> int: + return int(os.getenv("PROFESSIONAL_DAILY_LIMIT", "1000")) + + def get_priority(self) -> int: + return 100 + + +class TeamQueueDispatcher(BaseQueueDispatcher): + """Dispatcher for team tier""" + + def get_queue_name(self) -> str: + return QueuePriority.TEAM + + def get_daily_limit(self) -> int: + return int(os.getenv("TEAM_DAILY_LIMIT", "100")) + + def get_priority(self) -> int: + return 50 + + +class SandboxQueueDispatcher(BaseQueueDispatcher): + """Dispatcher for free/sandbox tier""" + + def get_queue_name(self) -> str: + return QueuePriority.SANDBOX + + def get_daily_limit(self) -> int: + return int(os.getenv("SANDBOX_DAILY_LIMIT", "10")) + + def get_priority(self) -> int: + return 10 + + +class QueueDispatcherManager: + """Factory for creating appropriate dispatcher based on tenant subscription""" + + # Mapping of billing plans to dispatchers + PLAN_DISPATCHER_MAP = { + "professional": ProfessionalQueueDispatcher, + "team": TeamQueueDispatcher, + "sandbox": SandboxQueueDispatcher, + # Add new tiers here as they're created + # For any unknown plan, default to sandbox + } + + @classmethod + def get_dispatcher(cls, tenant_id: str) -> BaseQueueDispatcher: + """ + Get dispatcher based on tenant's subscription plan + + Args: + tenant_id: The tenant identifier + + Returns: + Appropriate queue dispatcher instance + """ + if dify_config.BILLING_ENABLED: + try: + billing_info = BillingService.get_info(tenant_id) + plan = 
billing_info.get("subscription", {}).get("plan", "sandbox") + except Exception: + # If billing service fails, default to sandbox + plan = "sandbox" + else: + # If billing is disabled, use team tier as default + plan = "team" + + dispatcher_class = cls.PLAN_DISPATCHER_MAP.get( + plan, + SandboxQueueDispatcher, # Default to sandbox for unknown plans + ) + + return dispatcher_class() diff --git a/api/services/workflow/rate_limiter.py b/api/services/workflow/rate_limiter.py new file mode 100644 index 0000000000..49f936d253 --- /dev/null +++ b/api/services/workflow/rate_limiter.py @@ -0,0 +1,206 @@ +""" +Day-based rate limiter for workflow executions. + +Implements timezone-aware daily quotas that reset at midnight in the tenant owner's timezone. +""" + +from datetime import datetime, time, timedelta +from typing import Optional, Union + +import pytz +from redis import Redis +from sqlalchemy import select + +from extensions.ext_database import db +from extensions.ext_redis import RedisClientWrapper +from models.account import Account, TenantAccountJoin, TenantAccountRole + + +class TenantDailyRateLimiter: + """ + Day-based rate limiter that resets at midnight in tenant owner's timezone + + This class provides Redis-based rate limiting with the following features: + - Daily quotas that reset at midnight in tenant owner's timezone + - Atomic check-and-consume operations + - Automatic cleanup of stale counters + - Support for timezone changes without duplicate limits + """ + + def __init__(self, redis_client: Union[Redis, RedisClientWrapper]): + self.redis = redis_client + + def _get_tenant_owner_timezone(self, tenant_id: str) -> str: + """ + Get timezone of tenant owner + + Args: + tenant_id: The tenant identifier + + Returns: + Timezone string (e.g., 'America/New_York', 'UTC') + """ + # Query to get tenant owner's timezone using scalar and select + owner = db.session.scalar( + select(Account) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == TenantAccountRole.OWNER) + ) + + if not owner: + return "UTC" + + return owner.timezone or "UTC" + + def _get_day_key(self, tenant_id: str, timezone_str: str) -> str: + """ + Get Redis key for current day in tenant's timezone + + Args: + tenant_id: The tenant identifier + timezone_str: Timezone string + + Returns: + Redis key for the current day + """ + tz = pytz.timezone(timezone_str) + now = datetime.now(tz) + date_str = now.strftime("%Y-%m-%d") + return f"workflow:daily_limit:{tenant_id}:{date_str}:{timezone_str}" + + def _get_ttl_seconds(self, timezone_str: str) -> int: + """ + Calculate seconds until midnight in given timezone + + Args: + timezone_str: Timezone string + + Returns: + Number of seconds until midnight + """ + tz = pytz.timezone(timezone_str) + now = datetime.now(tz) + + # Get next midnight in the timezone + midnight = tz.localize(datetime.combine(now.date() + timedelta(days=1), time.min)) + + return int((midnight - now).total_seconds()) + + def check_and_consume(self, tenant_id: str, max_daily_limit: int, timezone_str: Optional[str] = None) -> bool: + """ + Check if quota available and consume one execution + + Args: + tenant_id: The tenant identifier + max_daily_limit: Maximum daily limit + timezone_str: Optional timezone string (will be fetched if not provided) + + Returns: + True if quota consumed successfully, False if limit reached + """ + if not timezone_str: + timezone_str = self._get_tenant_owner_timezone(tenant_id) + + key = 
self._get_day_key(tenant_id, timezone_str) + ttl = self._get_ttl_seconds(timezone_str) + + # Check current usage + current = self.redis.get(key) + + if current is None: + # First execution of the day - set to 1 + self.redis.setex(key, ttl, 1) + return True + + current_count = int(current) + if current_count < max_daily_limit: + # Within limit, increment + new_count = self.redis.incr(key) + # Update TTL in case timezone changed + self.redis.expire(key, ttl) + + # Double-check in case of race condition + if new_count <= max_daily_limit: + return True + else: + # Race condition occurred, decrement back + self.redis.decr(key) + return False + else: + # Limit exceeded + return False + + def get_remaining_quota(self, tenant_id: str, max_daily_limit: int, timezone_str: Optional[str] = None) -> int: + """ + Get remaining quota for the day + + Args: + tenant_id: The tenant identifier + max_daily_limit: Maximum daily limit + timezone_str: Optional timezone string (will be fetched if not provided) + + Returns: + Number of remaining executions for the day + """ + if not timezone_str: + timezone_str = self._get_tenant_owner_timezone(tenant_id) + + key = self._get_day_key(tenant_id, timezone_str) + used = int(self.redis.get(key) or 0) + return max(0, max_daily_limit - used) + + def get_current_usage(self, tenant_id: str, timezone_str: Optional[str] = None) -> int: + """ + Get current usage for the day + + Args: + tenant_id: The tenant identifier + timezone_str: Optional timezone string (will be fetched if not provided) + + Returns: + Number of executions used today + """ + if not timezone_str: + timezone_str = self._get_tenant_owner_timezone(tenant_id) + + key = self._get_day_key(tenant_id, timezone_str) + return int(self.redis.get(key) or 0) + + def reset_quota(self, tenant_id: str, timezone_str: Optional[str] = None) -> bool: + """ + Reset quota for testing purposes + + Args: + tenant_id: The tenant identifier + timezone_str: Optional timezone string (will be fetched if not provided) + + Returns: + True if key was deleted, False if key didn't exist + """ + if not timezone_str: + timezone_str = self._get_tenant_owner_timezone(tenant_id) + + key = self._get_day_key(tenant_id, timezone_str) + return bool(self.redis.delete(key)) + + def get_quota_reset_time(self, tenant_id: str, timezone_str: Optional[str] = None) -> datetime: + """ + Get the time when quota will reset (midnight in tenant's timezone) + + Args: + tenant_id: The tenant identifier + timezone_str: Optional timezone string (will be fetched if not provided) + + Returns: + Datetime when quota resets + """ + if not timezone_str: + timezone_str = self._get_tenant_owner_timezone(tenant_id) + + tz = pytz.timezone(timezone_str) + now = datetime.now(tz) + + # Get next midnight in the timezone + midnight = tz.localize(datetime.combine(now.date() + timedelta(days=1), time.min)) + + return midnight diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py new file mode 100644 index 0000000000..0612deb806 --- /dev/null +++ b/api/tasks/async_workflow_tasks.py @@ -0,0 +1,201 @@ +""" +Celery tasks for async workflow execution. + +These tasks handle workflow execution for different subscription tiers +with appropriate retry policies and error handling. 
+""" + +import json +from datetime import UTC, datetime + +from celery import shared_task +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from models.account import Account +from models.enums import CreatorUserRole +from models.model import App, EndUser, Tenant +from models.workflow import Workflow, WorkflowTriggerLog, WorkflowTriggerStatus +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.workflow.entities import AsyncTriggerExecutionResult, AsyncTriggerStatus, TriggerData, WorkflowTaskData + +# Determine queue names based on edition +if dify_config.EDITION == "CLOUD": + # Cloud edition: separate queues for different tiers + PROFESSIONAL_QUEUE = "workflow_professional" + TEAM_QUEUE = "workflow_team" + SANDBOX_QUEUE = "workflow_sandbox" +else: + # Community edition: single workflow queue (not dataset) + PROFESSIONAL_QUEUE = "workflow" + TEAM_QUEUE = "workflow" + SANDBOX_QUEUE = "workflow" + + +@shared_task(queue=PROFESSIONAL_QUEUE) +def execute_workflow_professional(task_data_dict: dict) -> dict: + """Execute workflow for professional tier with highest priority""" + task_data = WorkflowTaskData.model_validate(task_data_dict) + return _execute_workflow_common(task_data).model_dump() + + +@shared_task(queue=TEAM_QUEUE) +def execute_workflow_team(task_data_dict: dict) -> dict: + """Execute workflow for team tier""" + task_data = WorkflowTaskData.model_validate(task_data_dict) + return _execute_workflow_common(task_data).model_dump() + + +@shared_task(queue=SANDBOX_QUEUE) +def execute_workflow_sandbox(task_data_dict: dict) -> dict: + """Execute workflow for free tier with lower retry limit""" + task_data = WorkflowTaskData.model_validate(task_data_dict) + return _execute_workflow_common(task_data).model_dump() + + +def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult: + """ + Common workflow execution logic with trigger log updates + + Args: + task_data: Validated Pydantic model with task data + + Returns: + AsyncTriggerExecutionResult: Pydantic model with execution results + """ + # Create a new session for this task + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + with session_factory() as session: + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + # Get trigger log + trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id) + + if not trigger_log: + # This should not happen, but handle gracefully + return AsyncTriggerExecutionResult( + execution_id=task_data.workflow_trigger_log_id, + status=AsyncTriggerStatus.FAILED, + error=f"Trigger log not found: {task_data.workflow_trigger_log_id}", + ) + + # Reconstruct execution data from trigger log + trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data) + + # Update status to running + trigger_log.status = WorkflowTriggerStatus.RUNNING + trigger_log_repo.update(trigger_log) + session.commit() + + start_time = datetime.now(UTC) + + try: + # Get app and workflow models + app_model = session.scalar(select(App).where(App.id == trigger_log.app_id)) + + if not app_model: + raise ValueError(f"App not found: {trigger_log.app_id}") + + workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id)) + 
if not workflow: + raise ValueError(f"Workflow not found: {trigger_log.workflow_id}") + + user = _get_user(session, trigger_log) + + # Execute workflow using WorkflowAppGenerator + generator = WorkflowAppGenerator() + + # Prepare args matching AppGenerateService.generate format + args = {"inputs": dict(trigger_data.inputs), "files": list(trigger_data.files)} + + # If workflow_id was specified, add it to args + if trigger_data.workflow_id: + args["workflow_id"] = trigger_data.workflow_id + + # Execute the workflow with the trigger type + result = generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + call_depth=0, + workflow_thread_pool_id=None, + triggered_from=trigger_data.trigger_type, + ) + + # Calculate elapsed time + elapsed_time = (datetime.now(UTC) - start_time).total_seconds() + + # Extract relevant data from result + if isinstance(result, dict): + workflow_run_id = result.get("workflow_run_id") + total_tokens = result.get("total_tokens") + outputs = result + else: + # Handle generator result - collect all data + workflow_run_id = None + total_tokens = None + outputs = {"data": "streaming_result"} + + # Update trigger log with success + trigger_log.status = WorkflowTriggerStatus.SUCCEEDED + trigger_log.workflow_run_id = workflow_run_id + trigger_log.outputs = json.dumps(outputs) + trigger_log.elapsed_time = elapsed_time + trigger_log.total_tokens = total_tokens + trigger_log.finished_at = datetime.now(UTC) + trigger_log_repo.update(trigger_log) + session.commit() + + return AsyncTriggerExecutionResult( + execution_id=trigger_log.id, + status=AsyncTriggerStatus.COMPLETED, + result=outputs, + elapsed_time=elapsed_time, + total_tokens=total_tokens, + ) + + except Exception as e: + # Calculate elapsed time for failed execution + elapsed_time = (datetime.now(UTC) - start_time).total_seconds() + + # Update trigger log with failure + trigger_log.status = WorkflowTriggerStatus.FAILED + trigger_log.error = str(e) + trigger_log.finished_at = datetime.now(UTC) + trigger_log.elapsed_time = elapsed_time + trigger_log_repo.update(trigger_log) + + # Final failure - no retry logic (simplified like RAG tasks) + session.commit() + + return AsyncTriggerExecutionResult( + execution_id=trigger_log.id, status=AsyncTriggerStatus.FAILED, error=str(e), elapsed_time=elapsed_time + ) + + +def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser: + """Compose user from trigger log""" + tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id)) + if not tenant: + raise ValueError(f"Tenant not found: {trigger_log.tenant_id}") + + # Get user from trigger log + if trigger_log.created_by_role == CreatorUserRole.ACCOUNT: + user = session.scalar(select(Account).where(Account.id == trigger_log.created_by)) + if user: + user.current_tenant = tenant + else: # CreatorUserRole.END_USER + user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by)) + + if not user: + raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})") + + return user diff --git a/dev/start-worker b/dev/start-worker index 66e446c831..a05d3bcc48 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -2,10 +2,100 @@ set -x +# Help function +show_help() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " -q, --queues QUEUES Comma-separated list of queues to process" + echo " -c, --concurrency NUM Number of worker processes 
(default: 1)" + echo " -P, --pool POOL Pool implementation (default: gevent)" + echo " --loglevel LEVEL Log level (default: INFO)" + echo " -h, --help Show this help message" + echo "" + echo "Examples:" + echo " $0 --queues dataset,workflow" + echo " $0 --queues workflow_professional,workflow_team --concurrency 4" + echo " $0 --queues dataset --concurrency 2 --pool prefork" + echo "" + echo "Available queues:" + echo " dataset - RAG indexing and document processing" + echo " workflow - Workflow triggers (community edition)" + echo " workflow_professional - Professional tier workflows (cloud edition)" + echo " workflow_team - Team tier workflows (cloud edition)" + echo " workflow_sandbox - Sandbox tier workflows (cloud edition)" + echo " generation - Content generation tasks" + echo " mail - Email notifications" + echo " ops_trace - Operations tracing" + echo " app_deletion - Application cleanup" + echo " plugin - Plugin operations" + echo " workflow_storage - Workflow storage tasks" +} + +# Parse command line arguments +QUEUES="" +CONCURRENCY=1 +POOL="gevent" +LOGLEVEL="INFO" + +while [[ $# -gt 0 ]]; do + case $1 in + -q|--queues) + QUEUES="$2" + shift 2 + ;; + -c|--concurrency) + CONCURRENCY="$2" + shift 2 + ;; + -P|--pool) + POOL="$2" + shift 2 + ;; + --loglevel) + LOGLEVEL="$2" + shift 2 + ;; + -h|--help) + show_help + exit 0 + ;; + *) + echo "Unknown option: $1" + show_help + exit 1 + ;; + esac +done + SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/.." +# If no queues specified, use edition-based defaults +if [[ -z "${QUEUES}" ]]; then + # Get EDITION from environment, default to SELF_HOSTED (community edition) + EDITION=${EDITION:-"SELF_HOSTED"} + + # Configure queues based on edition + if [[ "${EDITION}" == "CLOUD" ]]; then + # Cloud edition: separate queues for dataset and trigger tasks + QUEUES="dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow_professional,workflow_team,workflow_sandbox" + else + # Community edition (SELF_HOSTED): dataset and workflow have separate queues + QUEUES="dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow" + fi + + echo "No queues specified, using edition-based defaults: ${QUEUES}" +else + echo "Using specified queues: ${QUEUES}" +fi + +echo "Starting Celery worker with:" +echo " Queues: ${QUEUES}" +echo " Concurrency: ${CONCURRENCY}" +echo " Pool: ${POOL}" +echo " Log Level: ${LOGLEVEL}" uv --directory api run \ celery -A app.celery worker \ - -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage + -P ${POOL} -c ${CONCURRENCY} --loglevel ${LOGLEVEL} -Q ${QUEUES} diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 04981f6b7f..cea0a04b8f 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -29,14 +29,14 @@ services: - default # worker service - # The Celery worker for processing the queue. + # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: image: langgenius/dify-api:1.7.2 restart: always environment: # Use the shared environment variables. <<: *shared-api-worker-env - # Startup mode, 'worker' starts the Celery worker for processing the queue. + # Startup mode, 'worker' starts the Celery worker for processing all queues. MODE: worker SENTRY_DSN: ${API_SENTRY_DSN:-} SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}