refactor(api): type workflow generator args dict with TypedDict (#34876)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
dataCenter430 2026-04-10 01:27:32 -07:00 committed by GitHub
parent e224c77920
commit c9f525a3b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,15 +7,16 @@ with appropriate retry policies and error handling.
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any, NotRequired
from celery import shared_task from celery import shared_task
from graphon.runtime import GraphRuntimeState from graphon.runtime import GraphRuntimeState
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.timeslice_layer import TimeSliceLayer
@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowGeneratorArgsDict(TypedDict):
inputs: dict[str, Any]
files: list[Any]
_skip_prepare_user_inputs: bool
workflow_id: NotRequired[str]
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]): def execute_workflow_professional(task_data_dict: dict[str, Any]):
"""Execute workflow for professional tier with highest priority""" """Execute workflow for professional tier with highest priority"""
@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
) )
def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]: def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict:
"""Build args passed into WorkflowAppGenerator.generate for Celery executions.""" """Build args passed into WorkflowAppGenerator.generate for Celery executions."""
return {
args: dict[str, Any] = {
"inputs": dict(trigger_data.inputs), "inputs": dict(trigger_data.inputs),
"files": list(trigger_data.files), "files": list(trigger_data.files),
SKIP_PREPARE_USER_INPUTS_KEY: True, "_skip_prepare_user_inputs": True,
} }
return args
def _execute_workflow_common( def _execute_workflow_common(