"""
|
|
Celery tasks for async workflow execution.
|
|
|
|
These tasks handle workflow execution for different subscription tiers
|
|
with appropriate retry policies and error handling.
|
|
"""

import json
from datetime import UTC, datetime
from typing import Any

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole
from models.model import App, EndUser, Tenant
from models.workflow import Workflow, WorkflowTriggerLog, WorkflowTriggerStatus
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import WorkflowNotFoundError
from services.workflow.entities import AsyncTriggerExecutionResult, AsyncTriggerStatus, TriggerData, WorkflowTaskData

# Determine queue names based on edition
if dify_config.EDITION == "CLOUD":
    # Cloud edition: separate queues for different tiers
    _professional_queue = "workflow_professional"
    _team_queue = "workflow_team"
    _sandbox_queue = "workflow_sandbox"
else:
    # Community edition: a single shared workflow queue (not the dataset queue)
    _professional_queue = "workflow"
    _team_queue = "workflow"
    _sandbox_queue = "workflow"

# Queue name constants
PROFESSIONAL_QUEUE = _professional_queue
TEAM_QUEUE = _team_queue
SANDBOX_QUEUE = _sandbox_queue
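
# For reference, workers consume these queues via Celery's -Q flag, e.g.
# (the Celery app module path below is an assumption, not taken from this file):
#   celery -A app.celery worker -Q workflow_professional,workflow_team,workflow_sandbox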


@shared_task(queue=PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict) -> dict:
    """Execute a workflow for the professional tier (highest-priority queue)."""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    return _execute_workflow_common(task_data).model_dump()


@shared_task(queue=TEAM_QUEUE)
def execute_workflow_team(task_data_dict: dict) -> dict:
    """Execute a workflow for the team tier."""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    return _execute_workflow_common(task_data).model_dump()


@shared_task(queue=SANDBOX_QUEUE)
def execute_workflow_sandbox(task_data_dict: dict) -> dict:
    """Execute a workflow for the sandbox (free) tier."""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    return _execute_workflow_common(task_data).model_dump()
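

# Illustrative sketch only: how a caller might route a payload to the
# tier-appropriate task. The tier strings and this helper are assumptions;
# the real call sites live in the trigger/service layer, not in this module.
_TIER_TASKS = {
    "professional": execute_workflow_professional,
    "team": execute_workflow_team,
    "sandbox": execute_workflow_sandbox,
}


def _dispatch_workflow_task_example(tier: str, task_data: WorkflowTaskData):
    """Enqueue the payload on the queue for the given tier (sandbox fallback)."""
    task = _TIER_TASKS.get(tier, execute_workflow_sandbox)
    return task.delay(task_data.model_dump())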


def _execute_workflow_common(task_data: WorkflowTaskData) -> AsyncTriggerExecutionResult:
    """
    Common workflow execution logic with trigger log updates.

    Args:
        task_data: Validated Pydantic model with the task payload

    Returns:
        AsyncTriggerExecutionResult: Pydantic model with execution results
    """
    # Create a new session for this task; expire_on_commit=False keeps the
    # loaded ORM objects usable after the intermediate commits below.
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

    with session_factory() as session:
        trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)

        # Get the trigger log
        trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)

        if not trigger_log:
            # This should not happen, but handle it gracefully
            return AsyncTriggerExecutionResult(
                execution_id=task_data.workflow_trigger_log_id,
                status=AsyncTriggerStatus.FAILED,
                error=f"Trigger log not found: {task_data.workflow_trigger_log_id}",
            )

        # Reconstruct execution data from the trigger log
        trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)

        # Mark the trigger log as running before execution starts
        trigger_log.status = WorkflowTriggerStatus.RUNNING
        trigger_log_repo.update(trigger_log)
        session.commit()

        start_time = datetime.now(UTC)

        try:
            # Get app and workflow models
            app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))
            if not app_model:
                raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}")

            workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
            if not workflow:
                raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}")

            user = _get_user(session, trigger_log)

            # Execute workflow using WorkflowAppGenerator
            generator = WorkflowAppGenerator()

            # Prepare args matching the AppGenerateService.generate format
            args: dict[str, Any] = {"inputs": dict(trigger_data.inputs), "files": list(trigger_data.files)}

            # If a specific workflow_id was requested, pass it through
            if trigger_data.workflow_id:
                args["workflow_id"] = str(trigger_data.workflow_id)

            # Execute the workflow with the trigger type
            result = generator.generate(
                app_model=app_model,
                workflow=workflow,
                user=user,
                args=args,
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=False,
                call_depth=0,
                triggered_from=trigger_data.trigger_type,
                root_node_id=trigger_data.root_node_id,
            )

            # Calculate elapsed time
            elapsed_time = (datetime.now(UTC) - start_time).total_seconds()

            # Extract relevant data from the result
            if isinstance(result, dict):
                workflow_run_id = result.get("workflow_run_id")
                total_tokens = result.get("total_tokens")
                outputs = result
            else:
                # Generator (streaming) result: no run metadata is collected here,
                # so record a placeholder instead
                workflow_run_id = None
                total_tokens = None
                outputs = {"data": "streaming_result"}

            # Update the trigger log with success
            trigger_log.status = WorkflowTriggerStatus.SUCCEEDED
            trigger_log.workflow_run_id = workflow_run_id
            trigger_log.outputs = json.dumps(outputs)
            trigger_log.elapsed_time = elapsed_time
            trigger_log.total_tokens = total_tokens
            trigger_log.finished_at = datetime.now(UTC)
            trigger_log_repo.update(trigger_log)
            session.commit()

            return AsyncTriggerExecutionResult(
                execution_id=trigger_log.id,
                status=AsyncTriggerStatus.COMPLETED,
                result=outputs,
                elapsed_time=elapsed_time,
                total_tokens=total_tokens,
            )

        except Exception as e:
            # Calculate elapsed time for the failed execution
            elapsed_time = (datetime.now(UTC) - start_time).total_seconds()

            # Update the trigger log with the failure
            trigger_log.status = WorkflowTriggerStatus.FAILED
            trigger_log.error = str(e)
            trigger_log.finished_at = datetime.now(UTC)
            trigger_log.elapsed_time = elapsed_time
            trigger_log_repo.update(trigger_log)

            # Final failure - no retry logic (simplified like the RAG tasks)
            session.commit()

            return AsyncTriggerExecutionResult(
                execution_id=trigger_log.id,
                status=AsyncTriggerStatus.FAILED,
                error=str(e),
                elapsed_time=elapsed_time,
            )
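
# A minimal sketch of exercising this path eagerly (e.g. in a test), assuming
# an existing trigger log row; "..." is a placeholder id, and apply() runs the
# task in-process instead of enqueueing it:
#
#   eager = execute_workflow_sandbox.apply(args=[{"workflow_trigger_log_id": "..."}])
#   print(eager.get())  # serialized AsyncTriggerExecutionResult dict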


def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
    """Resolve the triggering user (Account or EndUser) from the trigger log."""
    tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
    if not tenant:
        raise ValueError(f"Tenant not found: {trigger_log.tenant_id}")

    # Get the user referenced by the trigger log
    if trigger_log.created_by_role == CreatorUserRole.ACCOUNT:
        user = session.scalar(select(Account).where(Account.id == trigger_log.created_by))
        if user:
            user.current_tenant = tenant
    else:  # CreatorUserRole.END_USER
        user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by))

    if not user:
        raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})")

    return user