feat(api): Human Input Node (backend part) (#31646)

The backend part of the human in the loop (HITL) feature and relevant architecture / workflow engine changes.

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: 盐粒 Yanli <yanli@dify.ai>
Co-authored-by: CrabSAMA <40541269+CrabSAMA@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
QuantumGhost 2026-01-30 10:18:49 +08:00 committed by GitHub
parent fedd097f63
commit 03e3acfc71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
207 changed files with 19006 additions and 373 deletions

View File

@ -8,6 +8,7 @@ on:
- "build/**"
- "release/e-*"
- "hotfix/**"
- "feat/hitl-backend"
tags:
- "*"

View File

@ -717,3 +717,28 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
# Redis URL used for PubSub between API and
# celery worker
# defaults to url constructed from `REDIS_*`
# configurations
PUBSUB_REDIS_URL=
# Pub/sub channel type for streaming events.
# valid options are:
#
# - pubsub: for normal Pub/Sub
# - sharded: for sharded Pub/Sub
#
# It's highly recommended to use sharded Pub/Sub AND redis cluster
# for large deployments.
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while running
# PubSub.
# It's highly recommended to enable this for large deployments.
PUBSUB_REDIS_USE_CLUSTERS=false
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
# Human input timeout check interval in minutes
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1

View File

@ -36,6 +36,8 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
# TODO(QuantumGhost): fix the import violation later
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
@ -58,6 +60,8 @@ ignore_imports =
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
@ -145,6 +149,7 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.agent.entities
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
@ -248,6 +253,7 @@ ignore_imports =
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
core.workflow.nodes.http_request.node -> core.variables.segments
core.workflow.nodes.human_input.entities -> core.variables.consts
core.workflow.nodes.iteration.iteration_node -> core.variables
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
@ -294,6 +300,8 @@ ignore_imports =
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums

View File

@ -1,3 +1,4 @@
from datetime import timedelta
from enum import StrEnum
from typing import Literal
@ -48,6 +49,16 @@ class SecurityConfig(BaseSettings):
default=5,
)
WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field(
description="Maximum number of web form submissions allowed per IP within the rate limit window",
default=30,
)
WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field(
description="Time window in seconds for web form submission rate limiting",
default=60,
)
LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
default=False,
@ -82,6 +93,12 @@ class AppExecutionConfig(BaseSettings):
default=0,
)
HITL_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field(
description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.",
default=int(timedelta(days=3).total_seconds()),
ge=1,
)
class CodeExecutionSandboxConfig(BaseSettings):
"""
@ -1134,6 +1151,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable queue monitor task",
default=False,
)
ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field(
description="Enable human input timeout check task",
default=True,
)
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field(
description="Human input timeout check interval in minutes",
default=1,
)
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
description="Enable check upgradable plugin task",
default=True,

View File

@ -6,6 +6,7 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
from .cache.redis_pubsub_config import RedisPubSubConfig
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
@ -317,6 +318,7 @@ class MiddlewareConfig(
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
KeywordStoreConfig,
RedisConfig,
RedisPubSubConfig,
# configs of storage and storage providers
StorageConfig,
AliyunOSSStorageConfig,

View File

@ -0,0 +1,96 @@
from typing import Literal, Protocol
from urllib.parse import quote_plus, urlunparse
from pydantic import Field
from pydantic_settings import BaseSettings
class RedisConfigDefaults(Protocol):
REDIS_HOST: str
REDIS_PORT: int
REDIS_USERNAME: str | None
REDIS_PASSWORD: str | None
REDIS_DB: int
REDIS_USE_SSL: bool
REDIS_USE_SENTINEL: bool | None
REDIS_USE_CLUSTERS: bool
class RedisConfigDefaultsMixin:
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
return self
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
"""
Configuration settings for Redis pub/sub streaming.
"""
PUBSUB_REDIS_URL: str | None = Field(
alias="PUBSUB_REDIS_URL",
description=(
"Redis connection URL for pub/sub streaming events between API "
"and celery worker, defaults to url constructed from "
"`REDIS_*` configurations"
),
default=None,
)
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
description=(
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
"recommended to enable this for large deployments."
),
default=False,
)
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
description=(
"Pub/sub channel type for streaming events. "
"Valid options are:\n"
"\n"
" - pubsub: for normal Pub/Sub\n"
" - sharded: for sharded Pub/Sub\n"
"\n"
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
"for large deployments."
),
default="pubsub",
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
scheme = "rediss" if defaults.REDIS_USE_SSL else "redis"
username = defaults.REDIS_USERNAME or None
password = defaults.REDIS_PASSWORD or None
userinfo = ""
if username:
userinfo = quote_plus(username)
if password:
password_part = quote_plus(password)
userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}"
if userinfo:
userinfo = f"{userinfo}@"
host = defaults.REDIS_HOST
port = defaults.REDIS_PORT
db = defaults.REDIS_DB
netloc = f"{userinfo}{host}:{port}"
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
@property
def normalized_pubsub_redis_url(self) -> str:
pubsub_redis_url = self.PUBSUB_REDIS_URL
if pubsub_redis_url:
cleaned = pubsub_redis_url.strip()
pubsub_redis_url = cleaned or None
if pubsub_redis_url:
return pubsub_redis_url
return self._build_default_pubsub_url()

View File

@ -37,6 +37,7 @@ from . import (
apikey,
extension,
feature,
human_input_form,
init_validate,
ping,
setup,
@ -171,6 +172,7 @@ __all__ = [
"forgot_password",
"generator",
"hit_testing",
"human_input_form",
"init_validate",
"installed_app",
"load_balancing_config",

View File

@ -89,6 +89,7 @@ status_count_model = console_ns.model(
"success": fields.Integer,
"failed": fields.Integer,
"partial_success": fields.Integer,
"paused": fields.Integer,
},
)

View File

@ -32,7 +32,7 @@ from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
from services.message_service import MessageService, attach_message_extra_contents
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -198,6 +198,7 @@ message_detail_model = console_ns.model(
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"extra_contents": fields.List(fields.Raw),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
@ -290,6 +291,7 @@ class ChatMessageListApi(Resource):
has_more = False
history_messages = list(reversed(history_messages))
attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@ -474,4 +476,5 @@ class MessageApi(Resource):
if not message:
raise NotFound("Message Not Exists.")
attach_message_extra_contents([message])
return message

View File

@ -507,6 +507,179 @@ class WorkflowDraftRunLoopNodeApi(Resource):
raise InternalServerError()
class HumanInputFormPreviewPayload(BaseModel):
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Values used to fill missing upstream variables referenced in form_content",
)
class HumanInputFormSubmitPayload(BaseModel):
form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields")
inputs: dict[str, Any] = Field(
...,
description="Values used to fill missing upstream variables referenced in form_content",
)
action: str = Field(..., description="Selected action ID")
class HumanInputDeliveryTestPayload(BaseModel):
delivery_method_id: str = Field(..., description="Delivery method ID")
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Values used to fill missing upstream variables referenced in form_content",
)
reg(HumanInputFormPreviewPayload)
reg(HumanInputFormSubmitPayload)
reg(HumanInputDeliveryTestPayload)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
@console_ns.doc("get_advanced_chat_draft_human_input_form")
@console_ns.doc(description="Get human input form preview for advanced chat workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Preview human input form content and placeholders
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
inputs = args.inputs
workflow_service = WorkflowService()
preview = workflow_service.get_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
inputs=inputs,
)
return jsonable_encoder(preview)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/run")
class AdvancedChatDraftHumanInputFormRunApi(Resource):
@console_ns.doc("submit_advanced_chat_draft_human_input_form")
@console_ns.doc(description="Submit human input form preview for advanced chat workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
result = workflow_service.submit_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
form_inputs=args.form_inputs,
inputs=args.inputs,
action=args.action,
)
return jsonable_encoder(result)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
class WorkflowDraftHumanInputFormPreviewApi(Resource):
@console_ns.doc("get_workflow_draft_human_input_form")
@console_ns.doc(description="Get human input form preview for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Preview human input form content and placeholders
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
inputs = args.inputs
workflow_service = WorkflowService()
preview = workflow_service.get_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
inputs=inputs,
)
return jsonable_encoder(preview)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/run")
class WorkflowDraftHumanInputFormRunApi(Resource):
@console_ns.doc("submit_workflow_draft_human_input_form")
@console_ns.doc(description="Submit human input form preview for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
result = workflow_service.submit_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
form_inputs=args.form_inputs,
inputs=args.inputs,
action=args.action,
)
return jsonable_encoder(result)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/delivery-test")
class WorkflowDraftHumanInputDeliveryTestApi(Resource):
@console_ns.doc("test_workflow_draft_human_input_delivery")
@console_ns.doc(description="Test human input delivery for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Test human input delivery
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
workflow_service.test_human_input_delivery(
app_model=app_model,
account=current_user,
node_id=node_id,
delivery_method_id=args.delivery_method_id,
inputs=args.inputs,
)
return jsonable_encoder({})
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow")

View File

@ -5,10 +5,15 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@ -27,9 +32,21 @@ from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
def _build_backstage_input_url(form_token: str | None) -> str | None:
if not form_token:
return None
base_url = dify_config.APP_WEB_URL
if not base_url:
return None
return f"{base_url.rstrip('/')}/form/{form_token}"
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
@ -440,3 +457,63 @@ class WorkflowRunNodeExecutionListApi(Resource):
)
return {"data": node_executions}
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
class ConsoleWorkflowPauseDetailsApi(Resource):
"""Console API for getting workflow pause details."""
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
"""
Get workflow pause details.
GET /console/api/workflow/<workflow_run_id>/pause-details
Returns information about why and where the workflow is paused.
"""
# Query WorkflowRun to determine if workflow is suspended
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
if not workflow_run:
raise NotFoundError("Workflow run not found")
# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:
return {
"paused_at": None,
"paused_nodes": [],
}, 200
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
# Build response
paused_at = pause_entity.paused_at if pause_entity else None
paused_nodes = []
response = {
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
"paused_nodes": paused_nodes,
}
for reason in pause_reasons:
if isinstance(reason, HumanInputRequired):
paused_nodes.append(
{
"node_id": reason.node_id,
"node_title": reason.node_title,
"pause_type": {
"type": "human_input",
"form_id": reason.form_id,
"backstage_input_url": _build_backstage_input_url(reason.form_token),
},
}
)
else:
raise AssertionError("unimplemented.")
return response, 200

View File

@ -0,0 +1,217 @@
"""
Console/Studio Human Input Form APIs.
"""
import json
import logging
from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.human_input_service import Form, HumanInputService
from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
@console_ns.route("/form/human_input/<string:form_token>")
class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@staticmethod
def _ensure_console_access(form: Form):
_, current_tenant_id = current_account_with_tenant()
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@setup_required
@login_required
@account_initialization_required
def get(self, form_token: str):
"""
Get human input form definition by form token.
GET /console/api/form/human_input/<form_token>
"""
service = HumanInputService(db.engine)
form = service.get_form_definition_by_token_for_console(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
return _jsonify_form_definition(form)
@account_initialization_required
@login_required
def post(self, form_token: str):
"""
Submit human input form by form token.
POST /console/api/form/human_input/<form_token>
Request body:
{
"inputs": {
"content": "User input content"
},
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enought to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
submission_user_id=current_user.id,
)
return jsonify({})
@console_ns.route("/workflow/<string:workflow_run_id>/events")
class ConsoleWorkflowEventsApi(Resource):
"""Console API for getting workflow execution events after resume."""
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
"""
Get workflow execution events stream after resume.
GET /console/api/workflow/<workflow_run_id>/events
Returns Server-Sent Events stream.
"""
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=tenant_id,
run_id=workflow_run_id,
)
if workflow_run is None:
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
if workflow_run.created_by != user.id:
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
with Session(expire_on_commit=False, bind=db.engine) as session:
app = _retrieve_app_for_workflow_run(session, workflow_run)
if workflow_run.finished_at is not None:
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run.id,
workflow_run=workflow_run,
creator_user=user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=AppMode(app.mode),
workflow_run=workflow_run,
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
session_maker=session_maker,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
query = select(App).where(
App.id == workflow_run.app_id,
App.tenant_id == workflow_run.tenant_id,
)
app = session.scalars(query).first()
if app is None:
raise AssertionError(
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
)
return app

View File

@ -33,8 +33,9 @@ from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import TimestampField
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
@ -63,17 +64,32 @@ class WorkflowLogQuery(BaseModel):
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
class WorkflowRunStatusField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
return obj.status.value
class WorkflowRunOutputsField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
if obj.status == WorkflowExecutionStatus.PAUSED:
return {}
outputs = obj.outputs_dict
return outputs or {}
workflow_run_fields = {
"id": fields.String,
"workflow_id": fields.String,
"status": fields.String,
"status": WorkflowRunStatusField,
"inputs": fields.Raw,
"outputs": fields.Raw,
"outputs": WorkflowRunOutputsField,
"error": fields.String,
"total_steps": fields.Integer,
"total_tokens": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField,
"finished_at": OptionalTimestampField,
"elapsed_time": fields.Float,
}

View File

@ -23,6 +23,7 @@ from . import (
feature,
files,
forgot_password,
human_input_form,
login,
message,
passport,
@ -30,6 +31,7 @@ from . import (
saved_message,
site,
workflow,
workflow_events,
)
api.add_namespace(web_ns)
@ -44,6 +46,7 @@ __all__ = [
"feature",
"files",
"forgot_password",
"human_input_form",
"login",
"message",
"passport",
@ -52,4 +55,5 @@ __all__ = [
"site",
"web_ns",
"workflow",
"workflow_events",
]

View File

@ -117,6 +117,12 @@ class InvokeRateLimitError(BaseHTTPException):
code = 429
class WebFormRateLimitExceededError(BaseHTTPException):
error_code = "web_form_rate_limit_exceeded"
description = "Too many form requests. Please try again later."
code = 429
class NotFoundError(BaseHTTPException):
error_code = "not_found"
code = 404

View File

@ -0,0 +1,164 @@
"""
Web App Human Input Form APIs.
"""
import json
import logging
from datetime import datetime
from flask import Response, request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
from extensions.ext_database import db
from libs.helper import RateLimiter, extract_remote_ip
from models.account import TenantStatus
from models.model import App, Site
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
_FORM_ACCESS_RATE_LIMITER = RateLimiter(
prefix="web_form_access_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
}
if site_payload is not None:
payload["site"] = site_payload
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
# TODO(QuantumGhost): disable authorization for web app
# form api temporarily
@web_ns.route("/form/human_input/<string:form_token>")
# class HumanInputFormApi(WebApiResource):
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
def get(self, form_token: str):
"""
Get human input form definition by token.
GET /api/form/human_input/<form_token>
"""
ip_address = extract_remote_ip(request)
if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address):
raise WebFormRateLimitExceededError()
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
service = HumanInputService(db.engine)
# TODO(QuantumGhost): forbid submision for form tokens
# that are only for console.
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError("Form not found")
service.ensure_form_active(form)
app_model, site = _get_app_site_from_form(form)
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
def post(self, form_token: str):
"""
Submit human input form by token.
POST /api/form/human_input/<form_token>
Request body:
{
"inputs": {
"content": "User input content"
},
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
raise WebFormRateLimitExceededError()
_FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address)
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError("Form not found")
if (recipient_type := form.recipient_type) is None:
logger.warning("Recipient type is None for form, form_id=%", form.id)
raise AssertionError("Recipient type is None")
try:
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
submission_end_user_id=None,
# submission_end_user_id=_end_user.id,
)
except FormNotFoundError:
raise NotFoundError("Form not found")
return {}, 200
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
"""Resolve App/Site for the form's app and validate tenant status."""
app_model = db.session.query(App).where(App.id == form.app_id).first()
if app_model is None or app_model.tenant_id != form.tenant_id:
raise NotFoundError("Form not found")
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if site is None:
raise Forbidden()
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return app_model, site

View File

@ -1,4 +1,6 @@
from flask_restx import fields, marshal_with
from typing import cast
from flask_restx import fields, marshal, marshal_with
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -7,7 +9,7 @@ from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from libs.helper import AppIconUrlField
from models.account import TenantStatus
from models.model import Site
from models.model import App, Site
from services.feature_service import FeatureService
@ -108,3 +110,14 @@ class AppSiteInfo:
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
def serialize_site(site: Site) -> dict:
"""Serialize Site model using the same schema as AppSiteApi."""
return cast(dict, marshal(site, AppSiteApi.site_fields))
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))

View File

@ -0,0 +1,112 @@
"""
Web App Workflow Resume APIs.
"""
import json
from collections.abc import Generator
from flask import Response, request
from sqlalchemy.orm import sessionmaker
from controllers.web import api
from controllers.web.error import InvalidArgumentError, NotFoundError
from controllers.web.wraps import WebApiResource
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
class WorkflowEventsApi(WebApiResource):
"""API for getting workflow execution events after resume."""
def get(self, app_model: App, end_user: EndUser, task_id: str):
"""
Get workflow execution events stream after resume.
GET /api/workflow/<task_id>/events
Returns Server-Sent Events stream.
"""
workflow_run_id = task_id
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=workflow_run_id,
)
if workflow_run is None:
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
if workflow_run.app_id != app_model.id:
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
if workflow_run.created_by_role != CreatorUserRole.END_USER:
raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}")
if workflow_run.created_by != end_user.id:
raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}")
if workflow_run.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run.id,
workflow_run=workflow_run,
creator_user=end_user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
app_mode = AppMode.value_of(app_model.mode)
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(app_mode, workflow_run.id),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
# Register the APIs
api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events")

View File

@ -4,8 +4,8 @@ import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -29,21 +29,25 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@ -65,7 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: Literal[False],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@ -74,9 +80,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: Literal[True],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@ -85,9 +93,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: bool,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
@ -95,9 +105,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: bool = True,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.
@ -161,7 +173,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
@ -179,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -216,6 +227,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
conversation: Conversation,
message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_runtime_state: GraphRuntimeState,
pause_state_config: PauseStateLayerConfig | None = None,
):
"""
Resume a paused advanced chat execution.
"""
return self._generate(
workflow=workflow,
user=user,
invoke_from=application_generate_entity.invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
message=message,
stream=application_generate_entity.stream,
pause_state_config=pause_state_config,
graph_runtime_state=graph_runtime_state,
)
def single_iteration_generate(
@ -396,8 +439,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Conversation | None = None,
message: Message | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -411,12 +458,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
is_first_conversation = conversation is None
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
@ -439,6 +486,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id,
)
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@ -454,14 +511,25 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
# message_ = session.get(Message, message.id)
# assert message_ is not None
# message = message_
# db.session.refresh(workflow)
# db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
@ -490,6 +558,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
"""
Generate worker in a new thread.
@ -547,6 +617,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:
@ -614,3 +686,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
_T = TypeVar("_T", bound=Base)
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model

View File

@ -66,6 +66,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@ -82,6 +83,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@ -110,7 +112,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
invoke_from=invoke_from,
user_from=user_from,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,

View File

@ -24,6 +24,8 @@ from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -42,6 +44,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@ -63,6 +66,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@ -71,7 +76,8 @@ from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -128,6 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._task_state = WorkflowTaskState()
self._seed_task_state_from_message(message)
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
@ -135,6 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow_tenant_id = workflow.tenant_id
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
@ -144,8 +152,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._message_saved_on_pause = False
self._seed_graph_runtime_state_from_queue_manager()
def _seed_task_state_from_message(self, message: Message) -> None:
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
@ -308,6 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
reason=event.reason,
)
yield workflow_start_resp
@ -525,6 +539,35 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
self,
event: QueueWorkflowPausedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
yield from responses
resolved_state: GraphRuntimeState | None = None
try:
resolved_state = self._ensure_graph_runtime_initialized()
except ValueError:
resolved_state = None
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
message = self._get_message(session=session)
if message is not None:
message.status = MessageStatus.PAUSED
self._message_saved_on_pause = True
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(
@ -614,9 +657,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
# Save message unless it has already been persisted on pause.
if not self._message_saved_on_pause:
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@ -642,6 +686,65 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""Handle message replace events."""
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
def _handle_human_input_form_filled_event(
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
self._persist_human_input_extra_content(node_id=event.node_id)
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _handle_human_input_form_timeout_event(
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form timeout events."""
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
if not self._workflow_run_id or not self._message_id:
return
if form_id is None:
if node_id is None:
return
form_id = self._load_human_input_form_id(node_id=node_id)
if form_id is None:
logger.warning(
"HumanInput form not found for workflow run %s node %s",
self._workflow_run_id,
node_id,
)
return
with self._database_session() as session:
exists_stmt = select(HumanInputContent).where(
HumanInputContent.workflow_run_id == self._workflow_run_id,
HumanInputContent.message_id == self._message_id,
HumanInputContent.form_id == form_id,
)
if session.scalar(exists_stmt) is not None:
return
content = HumanInputContent(
workflow_run_id=self._workflow_run_id,
message_id=self._message_id,
form_id=form_id,
)
session.add(content)
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self._workflow_tenant_id,
)
form = form_repository.get_form(self._workflow_run_id, node_id)
if form is None:
return None
return form.id
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
@ -659,6 +762,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
@ -680,6 +784,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueMessageReplaceEvent: self._handle_message_replace_event,
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
QueueAgentLogEvent: self._handle_agent_log_event,
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
}
def _dispatch_event(
@ -747,6 +853,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
@ -772,6 +881,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
if message is None:
return
if message.status == MessageStatus.PAUSED:
message.status = MessageStatus.NORMAL
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer

View File

@ -5,9 +5,14 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -19,9 +24,13 @@ from core.app.entities.queue_entities import (
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueWorkflowPausedEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
HumanInputFormFilledResponse,
HumanInputFormTimeoutResponse,
HumanInputRequiredResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
@ -31,7 +40,9 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
StreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
@ -40,6 +51,8 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@ -51,8 +64,11 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.human_input import HumanInputForm
from models.workflow import WorkflowRun
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@ -191,6 +207,7 @@ class WorkflowResponseConverter:
task_id: str,
workflow_run_id: str,
workflow_id: str,
reason: WorkflowStartReason,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
@ -204,6 +221,7 @@ class WorkflowResponseConverter:
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
reason=reason,
),
)
@ -264,6 +282,160 @@ class WorkflowResponseConverter:
),
)
def workflow_pause_to_stream_response(
self,
*,
event: QueueWorkflowPausedEvent,
task_id: str,
graph_runtime_state: GraphRuntimeState,
) -> list[StreamResponse]:
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
)
paused_at = naive_utc_now()
elapsed_time = (paused_at - started_at).total_seconds()
encoded_outputs = self._encode_outputs(event.outputs) or {}
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
encoded_outputs = {}
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
expiration_times_by_form_id: dict[str, datetime] = {}
if human_input_form_ids:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
HumanInputForm.id.in_(human_input_form_ids)
)
with Session(bind=db.engine) as session:
for form_id, expiration_time in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = expiration_time
responses: list[StreamResponse] = []
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
expiration_time = expiration_times_by_form_id.get(reason.form_id)
if expiration_time is None:
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
responses.append(
HumanInputRequiredResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputRequiredResponse.Data(
form_id=reason.form_id,
node_id=reason.node_id,
node_title=reason.node_title,
form_content=reason.form_content,
inputs=reason.inputs,
actions=reason.actions,
display_in_ui=reason.display_in_ui,
form_token=reason.form_token,
resolved_default_values=reason.resolved_default_values,
expiration_time=int(expiration_time.timestamp()),
),
)
)
responses.append(
WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowPauseStreamResponse.Data(
workflow_run_id=run_id,
paused_nodes=list(event.paused_nodes),
outputs=encoded_outputs,
reasons=pause_reasons,
status=WorkflowExecutionStatus.PAUSED.value,
created_at=int(started_at.timestamp()),
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
),
)
)
return responses
def human_input_form_filled_to_stream_response(
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
) -> HumanInputFormFilledResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormFilledResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
),
)
def human_input_form_timeout_to_stream_response(
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
) -> HumanInputFormTimeoutResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormTimeoutResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormTimeoutResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
expiration_time=int(event.expiration_time.timestamp()),
),
)
@classmethod
def workflow_run_result_to_finish_response(
cls,
*,
task_id: str,
workflow_run: WorkflowRun,
creator_user: Account | EndUser,
) -> WorkflowFinishStreamResponse:
run_id = workflow_run.id
elapsed_time = workflow_run.elapsed_time
encoded_outputs = workflow_run.outputs_dict
finished_at = workflow_run.finished_at
assert finished_at is not None
created_by: Mapping[str, object]
user = creator_user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status.value,
outputs=encoded_outputs,
error=workflow_run.error,
elapsed_time=elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=cls.fetch_files_from_node_outputs(encoded_outputs),
exceptions_count=workflow_run.exceptions_count,
),
)
def workflow_node_start_to_stream_response(
self,
*,
@ -592,7 +764,8 @@ class WorkflowResponseConverter:
),
)
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
@classmethod
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@ -601,7 +774,7 @@ class WorkflowResponseConverter:
if not outputs_dict:
return []
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list

View File

@ -1,6 +1,6 @@
import json
import logging
from collections.abc import Generator
from collections.abc import Callable, Generator, Mapping
from typing import Union, cast
from sqlalchemy import select
@ -10,12 +10,14 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
AppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
ConversationAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import (
@ -27,6 +29,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
@ -156,6 +160,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query
created_new_conversation = conversation is None
try:
if not conversation:
conversation = Conversation(
@ -232,6 +237,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add_all(message_files)
db.session.commit()
if isinstance(application_generate_entity, ConversationAppGenerateEntity):
application_generate_entity.conversation_id = conversation.id
application_generate_entity.is_new_conversation = created_new_conversation
return conversation, message
except Exception:
db.session.rollback()
@ -284,3 +293,29 @@ class MessageBasedAppGenerator(BaseAppGenerator):
raise MessageNotExistsError("Message not exists")
return message
@staticmethod
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
return f"channel:{app_mode}:{workflow_run_id}"
@classmethod
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
key = cls._make_channel_key(app_mode, workflow_run_id)
channel = get_pubsub_broadcast_channel()
topic = channel.topic(key)
return topic
@classmethod
def retrieve_events(
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
topic=topic,
idle_timeout=idle_timeout,
on_subscribe=on_subscribe,
)

View File

@ -0,0 +1,36 @@
from collections.abc import Callable, Generator, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
class MessageGenerator:
@staticmethod
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
return f"channel:{app_mode}:{str(workflow_run_id)}"
@classmethod
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
key = cls._make_channel_key(app_mode, workflow_run_id)
channel = get_pubsub_broadcast_channel()
topic = channel.topic(key)
return topic
@classmethod
def retrieve_events(
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
topic=topic,
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
)

View File

@ -0,0 +1,70 @@
from __future__ import annotations
import json
import time
from collections.abc import Callable, Generator, Iterable, Mapping
from typing import Any
from core.app.entities.task_entities import StreamEvent
from libs.broadcast_channel.channel import Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
def stream_topic_events(
*,
topic: Topic,
idle_timeout: float,
ping_interval: float | None = None,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]:
# send a PING event immediately to prevent the connection staying in pending state for a long time.
#
# This simplify the debugging process as the DevTools in Chrome does not
# provide complete curl command for pending connections.
yield StreamEvent.PING.value
terminal_values = _normalize_terminal_events(terminal_events)
last_msg_time = time.time()
last_ping_time = last_msg_time
with topic.subscribe() as sub:
# on_subscribe fires only after the Redis subscription is active.
# This is used to gate task start and reduce pub/sub race for the first event.
if on_subscribe is not None:
on_subscribe()
while True:
try:
msg = sub.receive(timeout=0.1)
except SubscriptionClosedError:
return
if msg is None:
current_time = time.time()
if current_time - last_msg_time > idle_timeout:
return
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
yield StreamEvent.PING.value
last_ping_time = current_time
continue
last_msg_time = time.time()
last_ping_time = last_msg_time
event = json.loads(msg)
yield event
if not isinstance(event, dict):
continue
event_type = event.get("event")
if event_type in terminal_values:
return
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:
if isinstance(item, StreamEvent):
values.add(item.value)
else:
values.add(str(item))
return values

View File

@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
@ -34,12 +35,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.account import Account
from models.enums import WorkflowRunTriggeredFrom
from models.model import App, EndUser
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
if TYPE_CHECKING:
@ -66,9 +70,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
@ -82,9 +88,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@ -98,9 +106,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
@ -113,9 +123,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@ -150,7 +162,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
@ -216,13 +228,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
pause_state_config=pause_state_config,
)
def resume(self, *, workflow_run_id: str) -> None:
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
@TBD
Resume a paused workflow execution using the persisted runtime state.
"""
pass
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=application_generate_entity.invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=application_generate_entity.stream,
variable_loader=variable_loader,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
pause_state_config=pause_state_config,
)
def _generate(
self,
@ -238,6 +277,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -251,6 +292,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
"""
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@ -259,6 +302,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_mode=app_model.mode,
)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@ -276,7 +328,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": graph_engine_layers,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
@ -378,6 +431,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def single_loop_generate(
@ -459,6 +513,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def _generate_worker(
@ -472,6 +527,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
) -> None:
"""
Generate worker in a new thread.
@ -517,6 +573,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:

View File

@ -42,6 +42,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@ -55,6 +56,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@ -63,23 +65,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
invoke_from = self.application_generate_entity.invoke_from
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
@ -89,7 +96,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
inputs = self.application_generate_entity.inputs
# Create a variable pool.
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@ -98,8 +112,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,

View File

@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException
class WorkflowPausedInBlockingModeError(BaseHTTPException):
error_code = "workflow_paused_in_blocking_mode"
description = "Workflow execution paused for human input; blocking response mode is not supported."
code = 400

View File

@ -16,6 +16,8 @@ from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -32,6 +34,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@ -46,11 +49,13 @@ from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.runtime import GraphRuntimeState
@ -132,6 +137,25 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
outputs=stream_response.data.outputs or {},
error=None,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@ -146,7 +170,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at),
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
),
)
@ -259,13 +283,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
if event.reason == WorkflowStartReason.INITIAL:
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
reason=event.reason,
)
yield start_resp
@ -440,6 +466,21 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
self,
event: QueueWorkflowPausedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
yield from responses
def _handle_workflow_failed_and_stop_events(
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
@ -495,6 +536,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id, event=event
)
def _handle_human_input_form_filled_event(
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _handle_human_input_form_timeout_event(
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form timeout events."""
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _get_event_handlers(self) -> dict[type, Callable]:
"""Get mapping of event types to their handlers using fluent pattern."""
return {
@ -506,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
@ -520,6 +578,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueLoopCompletedEvent: self._handle_loop_completed_event,
# Agent events
QueueAgentLogEvent: self._handle_agent_log_event,
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
}
def _dispatch_event(
@ -602,6 +662,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)

View File

@ -1,3 +1,4 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
@ -7,6 +8,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -22,22 +25,27 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
@ -61,6 +69,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader,
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
class WorkflowBasedAppRunner:
@ -327,7 +338,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(QueueWorkflowStartedEvent())
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
@ -338,6 +349,38 @@ class WorkflowBasedAppRunner:
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, GraphRunPausedEvent):
runtime_state = workflow_entry.graph_engine.graph_runtime_state
paused_nodes = runtime_state.get_paused_nodes()
self._enqueue_human_input_notifications(event.reasons)
self._publish_event(
QueueWorkflowPausedEvent(
reasons=event.reasons,
outputs=event.outputs,
paused_nodes=paused_nodes,
)
)
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
self._publish_event(
QueueHumanInputFormFilledEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
)
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
self._publish_event(
QueueHumanInputFormTimeoutEvent(
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
)
)
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
@ -544,5 +587,19 @@ class WorkflowBasedAppRunner:
)
)
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
for reason in reasons:
if not isinstance(reason, HumanInputRequired):
continue
if not reason.form_id:
continue
try:
dispatch_human_input_email_task.apply_async(
kwargs={"form_id": reason.form_id, "node_title": reason.node_title},
queue="mail",
)
except Exception: # pragma: no cover - defensive logging
logger.exception("Failed to enqueue human input email task for form %s", reason.form_id)
def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

View File

@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional["TraceQueueManager"] = None
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
@ -156,6 +156,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
"""
conversation_id: str | None = None
is_new_conversation: bool = False
parent_message_id: str | None = Field(
default=None,
description=(

View File

@ -8,6 +8,8 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@ -46,6 +48,9 @@ class QueueEvent(StrEnum):
PING = "ping"
STOP = "stop"
RETRY = "retry"
PAUSE = "pause"
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
class AppQueueEvent(BaseModel):
@ -261,6 +266,8 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
"""QueueWorkflowStartedEvent entity."""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
# Always present; mirrors GraphRunStartedEvent.reason for downstream consumers.
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
class QueueWorkflowSucceededEvent(AppQueueEvent):
@ -484,6 +491,35 @@ class QueueStopEvent(AppQueueEvent):
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
class QueueHumanInputFormFilledEvent(AppQueueEvent):
"""
QueueHumanInputFormFilledEvent entity
"""
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
rendered_content: str
action_id: str
action_text: str
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
"""
QueueHumanInputFormTimeoutEvent entity
"""
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT
node_id: str
node_type: NodeType
node_title: str
expiration_time: datetime
class QueueMessage(BaseModel):
"""
QueueMessage abstract entity
@ -509,3 +545,14 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueWorkflowPausedEvent(AppQueueEvent):
"""
QueueWorkflowPausedEvent entity
"""
event: QueueEvent = QueueEvent.PAUSE
reasons: Sequence[PauseReason] = Field(default_factory=list)
outputs: Mapping[str, object] = Field(default_factory=dict)
paused_nodes: Sequence[str] = Field(default_factory=list)

View File

@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class AnnotationReplyAccount(BaseModel):
@ -69,6 +71,7 @@ class StreamEvent(StrEnum):
AGENT_THOUGHT = "agent_thought"
AGENT_MESSAGE = "agent_message"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_PAUSED = "workflow_paused"
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
@ -82,6 +85,9 @@ class StreamEvent(StrEnum):
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
HUMAN_INPUT_REQUIRED = "human_input_required"
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
class StreamResponse(BaseModel):
@ -205,6 +211,8 @@ class WorkflowStartStreamResponse(StreamResponse):
workflow_id: str
inputs: Mapping[str, Any]
created_at: int
# Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients.
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
workflow_run_id: str
@ -231,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
total_steps: int
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int
finished_at: int | None
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []
@ -240,6 +248,85 @@ class WorkflowFinishStreamResponse(StreamResponse):
data: Data
class WorkflowPauseStreamResponse(StreamResponse):
"""
WorkflowPauseStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
workflow_run_id: str
paused_nodes: Sequence[str] = Field(default_factory=list)
outputs: Mapping[str, Any] = Field(default_factory=dict)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
status: str
created_at: int
elapsed_time: float
total_tokens: int
total_steps: int
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
workflow_run_id: str
data: Data
class HumanInputRequiredResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int = Field(..., description="Unix timestamp in seconds")
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
workflow_run_id: str
data: Data
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
node_id: str
node_title: str
rendered_content: str
action_id: str
action_text: str
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
workflow_run_id: str
data: Data
class HumanInputFormTimeoutResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
node_id: str
node_title: str
expiration_time: int
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT
workflow_run_id: str
data: Data
class NodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@ -726,7 +813,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
total_tokens: int
total_steps: int
created_at: int
finished_at: int
finished_at: int | None
workflow_run_id: str
data: Data

View File

@ -1,3 +1,4 @@
import contextlib
import logging
import time
import uuid
@ -103,6 +104,14 @@ class RateLimit:
)
@contextlib.contextmanager
def rate_limit_context(rate_limit: RateLimit, request_id: str | None):
request_id = rate_limit.enter(request_id)
yield
if request_id is not None:
rate_limit.exit(request_id)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
self.rate_limit = rate_limit

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel):
return self.generate_entity.entity
@dataclass(frozen=True)
class PauseStateLayerConfig:
"""Configuration container for instantiating pause persistence layers."""
session_factory: Engine | sessionmaker[Session]
state_owner_user_id: str
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(
self,

View File

@ -82,10 +82,11 @@ class MessageCycleManager:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
is_first_message = self._application_generate_entity.is_new_conversation
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
thread: Thread | None = None
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep not block other logic
@ -101,9 +102,10 @@ class MessageCycleManager:
thread.daemon = True
thread.start()
return thread
if is_first_message:
self._application_generate_entity.is_new_conversation = False
return None
return thread
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():

View File

@ -0,0 +1,54 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Field
from core.workflow.nodes.human_input.entities import FormInput, UserAction
from models.execution_extra_content import ExecutionContentType
class HumanInputFormDefinition(BaseModel):
model_config = ConfigDict(frozen=True)
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
class HumanInputFormSubmissionData(BaseModel):
model_config = ConfigDict(frozen=True)
node_id: str
node_title: str
rendered_content: str
action_id: str
action_text: str
class HumanInputContent(BaseModel):
model_config = ConfigDict(frozen=True)
workflow_run_id: str
submitted: bool
form_definition: HumanInputFormDefinition | None = None
form_submission_data: HumanInputFormSubmissionData | None = None
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
__all__ = [
"ExecutionExtraContentDomainModel",
"HumanInputContent",
"HumanInputFormDefinition",
"HumanInputFormSubmissionData",
]

View File

@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.provider import (
LoadBalancingModelConfig,
Provider,

View File

@ -15,10 +15,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
TracingProviderEnum,
)
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@ -31,8 +28,8 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.engine import db
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
from tasks.ops_trace_task import process_trace_tasks
@ -469,6 +466,8 @@ class TraceTask:
@classmethod
def _get_workflow_run_repo(cls):
from repositories.factory import DifyAPIRepositoryFactory
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:

View File

@ -5,7 +5,7 @@ from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db
from models.engine import db
from models.model import Message

View File

@ -1,3 +1,4 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Union
@ -112,6 +113,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
)
elif app.mode == AppMode.AGENT_CHAT:

View File

@ -1,19 +1,18 @@
"""
Repository implementations for data access.
"""Repository implementations for data access."""
This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from __future__ import annotations
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowExecutionRepository",
"SQLAlchemyWorkflowNodeExecutionRepository",
]

View File

@ -0,0 +1,553 @@
import dataclasses
import json
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormDefinition,
HumanInputNodeData,
MemberRecipient,
WebAppDeliveryMethod,
)
from core.workflow.nodes.human_input.enums import (
DeliveryMethodType,
HumanInputFormKind,
HumanInputFormStatus,
)
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
FormNotFoundError,
HumanInputFormEntity,
HumanInputFormRecipientEntity,
)
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.account import Account, TenantAccountJoin
from models.human_input import (
BackstageRecipientPayload,
ConsoleDeliveryPayload,
ConsoleRecipientPayload,
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
StandaloneWebAppRecipientPayload,
)
@dataclasses.dataclass(frozen=True)
class _DeliveryAndRecipients:
delivery: HumanInputDelivery
recipients: Sequence[HumanInputFormRecipient]
@dataclasses.dataclass(frozen=True)
class _WorkspaceMemberInfo:
user_id: str
email: str
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
def __init__(self, recipient_model: HumanInputFormRecipient):
self._recipient_model = recipient_model
@property
def id(self) -> str:
return self._recipient_model.id
@property
def token(self) -> str:
if self._recipient_model.access_token is None:
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
return self._recipient_model.access_token
class _HumanInputFormEntityImpl(HumanInputFormEntity):
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
self._form_model = form_model
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
self._web_app_recipient = next(
(
recipient
for recipient in recipient_models
if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP
),
None,
)
self._console_recipient = next(
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE),
None,
)
self._submitted_data: Mapping[str, Any] | None = (
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
)
@property
def id(self) -> str:
return self._form_model.id
@property
def web_app_token(self):
if self._console_recipient is not None:
return self._console_recipient.access_token
if self._web_app_recipient is None:
return None
return self._web_app_recipient.access_token
@property
def recipients(self) -> list[HumanInputFormRecipientEntity]:
return list(self._recipients)
@property
def rendered_content(self) -> str:
return self._form_model.rendered_content
@property
def selected_action_id(self) -> str | None:
return self._form_model.selected_action_id
@property
def submitted_data(self) -> Mapping[str, Any] | None:
return self._submitted_data
@property
def submitted(self) -> bool:
return self._form_model.submitted_at is not None
@property
def status(self) -> HumanInputFormStatus:
return self._form_model.status
@property
def expiration_time(self) -> datetime:
return self._form_model.expiration_time
@dataclasses.dataclass(frozen=True)
class HumanInputFormRecord:
form_id: str
workflow_run_id: str | None
node_id: str
tenant_id: str
app_id: str
form_kind: HumanInputFormKind
definition: FormDefinition
rendered_content: str
created_at: datetime
expiration_time: datetime
status: HumanInputFormStatus
selected_action_id: str | None
submitted_data: Mapping[str, Any] | None
submitted_at: datetime | None
submission_user_id: str | None
submission_end_user_id: str | None
completed_by_recipient_id: str | None
recipient_id: str | None
recipient_type: RecipientType | None
access_token: str | None
@property
def submitted(self) -> bool:
return self.submitted_at is not None
@classmethod
def from_models(
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
) -> "HumanInputFormRecord":
definition_payload = json.loads(form_model.form_definition)
if "expiration_time" not in definition_payload:
definition_payload["expiration_time"] = form_model.expiration_time
return cls(
form_id=form_model.id,
workflow_run_id=form_model.workflow_run_id,
node_id=form_model.node_id,
tenant_id=form_model.tenant_id,
app_id=form_model.app_id,
form_kind=form_model.form_kind,
definition=FormDefinition.model_validate(definition_payload),
rendered_content=form_model.rendered_content,
created_at=form_model.created_at,
expiration_time=form_model.expiration_time,
status=form_model.status,
selected_action_id=form_model.selected_action_id,
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
submitted_at=form_model.submitted_at,
submission_user_id=form_model.submission_user_id,
submission_end_user_id=form_model.submission_end_user_id,
completed_by_recipient_id=form_model.completed_by_recipient_id,
recipient_id=recipient_model.id if recipient_model else None,
recipient_type=recipient_model.recipient_type if recipient_model else None,
access_token=recipient_model.access_token if recipient_model else None,
)
class _InvalidTimeoutStatusError(ValueError):
pass
class HumanInputFormRepositoryImpl:
def __init__(
self,
session_factory: sessionmaker | Engine,
tenant_id: str,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._tenant_id = tenant_id
def _delivery_method_to_model(
self,
session: Session,
form_id: str,
delivery_method: DeliveryChannelConfig,
) -> _DeliveryAndRecipients:
delivery_id = str(uuidv7())
delivery_model = HumanInputDelivery(
id=delivery_id,
form_id=form_id,
delivery_method_type=delivery_method.type,
delivery_config_id=delivery_method.id,
channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputFormRecipient] = []
if isinstance(delivery_method, WebAppDeliveryMethod):
recipient_model = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
)
recipients.append(recipient_model)
elif isinstance(delivery_method, EmailDeliveryMethod):
email_recipients_config = delivery_method.config.recipients
recipients.extend(
self._build_email_recipients(
session=session,
form_id=form_id,
delivery_id=delivery_id,
recipients_config=email_recipients_config,
)
)
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
def _build_email_recipients(
self,
session: Session,
form_id: str,
delivery_id: str,
recipients_config: EmailRecipients,
) -> list[HumanInputFormRecipient]:
member_user_ids = [
recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
]
external_emails = [
recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
]
if recipients_config.whole_workspace:
members = self._query_all_workspace_members(session=session)
else:
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
return self._create_email_recipients_from_resolved(
form_id=form_id,
delivery_id=delivery_id,
members=members,
external_emails=external_emails,
)
@staticmethod
def _create_email_recipients_from_resolved(
*,
form_id: str,
delivery_id: str,
members: Sequence[_WorkspaceMemberInfo],
external_emails: Sequence[str],
) -> list[HumanInputFormRecipient]:
recipient_models: list[HumanInputFormRecipient] = []
seen_emails: set[str] = set()
for member in members:
if not member.email:
continue
if member.email in seen_emails:
continue
seen_emails.add(member.email)
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
recipient_models.append(
HumanInputFormRecipient.new(
form_id=form_id,
delivery_id=delivery_id,
payload=payload,
)
)
for email in external_emails:
if not email:
continue
if email in seen_emails:
continue
seen_emails.add(email)
recipient_models.append(
HumanInputFormRecipient.new(
form_id=form_id,
delivery_id=delivery_id,
payload=EmailExternalRecipientPayload(email=email),
)
)
return recipient_models
def _query_all_workspace_members(
self,
session: Session,
) -> list[_WorkspaceMemberInfo]:
stmt = (
select(Account.id, Account.email)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == self._tenant_id)
)
rows = session.execute(stmt).all()
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
def _query_workspace_members_by_ids(
self,
session: Session,
restrict_to_user_ids: Sequence[str],
) -> list[_WorkspaceMemberInfo]:
unique_ids = {user_id for user_id in restrict_to_user_ids if user_id}
if not unique_ids:
return []
stmt = (
select(Account.id, Account.email)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == self._tenant_id)
)
stmt = stmt.where(Account.id.in_(unique_ids))
rows = session.execute(stmt).all()
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
with self._session_factory(expire_on_commit=False) as session, session.begin():
# Generate unique form ID
form_id = str(uuidv7())
start_time = naive_utc_now()
node_expiration = form_config.expiration_time(start_time)
form_definition = FormDefinition(
form_content=form_config.form_content,
inputs=form_config.inputs,
user_actions=form_config.user_actions,
rendered_content=params.rendered_content,
expiration_time=node_expiration,
default_values=dict(params.resolved_default_values),
display_in_ui=params.display_in_ui,
node_title=form_config.title,
)
form_model = HumanInputForm(
id=form_id,
tenant_id=self._tenant_id,
app_id=params.app_id,
workflow_run_id=params.workflow_execution_id,
form_kind=params.form_kind,
node_id=params.node_id,
form_definition=form_definition.model_dump_json(),
rendered_content=params.rendered_content,
expiration_time=node_expiration,
created_at=start_time,
)
session.add(form_model)
recipient_models: list[HumanInputFormRecipient] = []
for delivery in params.delivery_methods:
delivery_and_recipients = self._delivery_method_to_model(
session=session,
form_id=form_id,
delivery_method=delivery,
)
session.add(delivery_and_recipients.delivery)
session.add_all(delivery_and_recipients.recipients)
recipient_models.extend(delivery_and_recipients.recipients)
if params.console_recipient_required and not any(
recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
):
console_delivery_id = str(uuidv7())
console_delivery = HumanInputDelivery(
id=console_delivery_id,
form_id=form_id,
delivery_method_type=DeliveryMethodType.WEBAPP,
delivery_config_id=None,
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
)
console_recipient = HumanInputFormRecipient(
form_id=form_id,
delivery_id=console_delivery_id,
recipient_type=RecipientType.CONSOLE,
recipient_payload=ConsoleRecipientPayload(
account_id=params.console_creator_account_id,
).model_dump_json(),
)
session.add(console_delivery)
session.add(console_recipient)
recipient_models.append(console_recipient)
if params.backstage_recipient_required and not any(
recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
):
backstage_delivery_id = str(uuidv7())
backstage_delivery = HumanInputDelivery(
id=backstage_delivery_id,
form_id=form_id,
delivery_method_type=DeliveryMethodType.WEBAPP,
delivery_config_id=None,
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
)
backstage_recipient = HumanInputFormRecipient(
form_id=form_id,
delivery_id=backstage_delivery_id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload(
account_id=params.console_creator_account_id,
).model_dump_json(),
)
session.add(backstage_delivery)
session.add(backstage_recipient)
recipient_models.append(backstage_recipient)
session.flush()
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
form_query = select(HumanInputForm).where(
HumanInputForm.workflow_run_id == workflow_execution_id,
HumanInputForm.node_id == node_id,
HumanInputForm.tenant_id == self._tenant_id,
)
with self._session_factory(expire_on_commit=False) as session:
form_model: HumanInputForm | None = session.scalars(form_query).first()
if form_model is None:
return None
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
recipient_models = session.scalars(recipient_query).all()
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
class HumanInputFormSubmissionRepository:
"""Repository for fetching and submitting human input forms."""
def __init__(self, session_factory: sessionmaker | Engine):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(HumanInputFormRecipient.access_token == form_token)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
def get_by_form_id_and_recipient_type(
self,
form_id: str,
recipient_type: RecipientType,
) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(
HumanInputFormRecipient.form_id == form_id,
HumanInputFormRecipient.recipient_type == recipient_type,
)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
def mark_submitted(
self,
*,
form_id: str,
recipient_id: str | None,
selected_action_id: str,
form_data: Mapping[str, Any],
submission_user_id: str | None,
submission_end_user_id: str | None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.status = HumanInputFormStatus.SUBMITTED
form_model.submission_user_id = submission_user_id
form_model.submission_end_user_id = submission_end_user_id
form_model.completed_by_recipient_id = recipient_id
session.add(form_model)
session.flush()
session.refresh(form_model)
if recipient_model is not None:
session.refresh(recipient_model)
return HumanInputFormRecord.from_models(form_model, recipient_model)
def mark_timeout(
self,
*,
form_id: str,
timeout_status: HumanInputFormStatus,
reason: str | None = None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}")
# already handled or submitted
if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
return HumanInputFormRecord.from_models(form_model, None)
if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED:
raise FormNotFoundError(f"form already submitted, id={form_id}")
form_model.status = timeout_status
form_model.selected_action_id = None
form_model.submitted_data = None
form_model.submission_user_id = None
form_model.submission_end_user_id = None
form_model.completed_by_recipient_id = None
# Reason is recorded in status/error downstream; not stored on form.
session.add(form_model)
session.flush()
session.refresh(form_model)
return HumanInputFormRecord.from_models(form_model, None)

View File

@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.triggered_from == triggered_from,
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
)
if self._app_id:

View File

@ -1,4 +1,5 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
from libs.exception import BaseHTTPException
class ToolProviderNotFoundError(ValueError):
@ -37,6 +38,12 @@ class ToolCredentialPolicyViolationError(ValueError):
pass
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
error_code = "workflow_tool_human_input_not_supported"
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
code = 400
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta

View File

@ -3,6 +3,8 @@ from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import OutputVariableEntity
@ -50,6 +52,13 @@ class WorkflowToolConfigurationUtils:
return [outputs_by_variable[variable] for variable in variable_order]
@classmethod
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
raise WorkflowToolHumanInputNotSupportedError()
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]

View File

@ -2,10 +2,12 @@ from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowStartReason",
]

View File

@ -5,6 +5,16 @@ from pydantic import BaseModel, Field
class GraphInitParams(BaseModel):
"""GraphInitParams encapsulates the configurations and contextual information
that remain constant throughout a single execution of the graph engine.
A single execution is defined as follows: as long as the execution has not reached
its conclusion, it is considered one execution. For instance, if a workflow is suspended
and later resumed, it is still regarded as a single execution, not two.
For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
"""
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")

View File

@ -1,8 +1,11 @@
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Annotated, Literal, TypeAlias
from typing import Annotated, Any, Literal, TypeAlias
from pydantic import BaseModel, Field
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
@ -11,10 +14,31 @@ class PauseReasonType(StrEnum):
class HumanInputRequired(BaseModel):
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
# The identifier of the human input node causing the pause.
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
actions: list[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
node_id: str
node_title: str
# The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from
# `output_variable_name` to their resolved values.
#
# For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its
# selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable
# `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The
# `resolved_default_values` is `{"name": "John"}`.
#
# Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
# The `form_token` is the token used to submit the form via UI surfaces. It corresponds to
# `HumanInputFormRecipient.access_token`.
#
# This field is `None` if webapp delivery is not set and not
# in orchestrating mode.
form_token: str | None = None
class SchedulingPause(BaseModel):

View File

@ -0,0 +1,8 @@
from enum import StrEnum
class WorkflowStartReason(StrEnum):
"""Reason for workflow start events across graph/queue/SSE layers."""
INITIAL = "initial" # First start of a workflow run.
RESUMPTION = "resumption" # Start triggered after resuming a paused run.

View File

@ -0,0 +1,15 @@
import time
def get_timestamp() -> float:
"""Retrieve a timestamp as a float point numer representing the number of seconds
since the Unix epoch.
This function is primarily used to measure the execution time of the workflow engine.
Since workflow execution may be paused and resumed on a different machine,
`time.perf_counter` cannot be used as it is inconsistent across machines.
To address this, the function uses the wall clock as the time source.
However, it assumes that the clocks of all servers are properly synchronized.
"""
return round(time.time())

View File

@ -2,12 +2,14 @@
GraphEngine configuration models.
"""
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
class GraphEngineConfig(BaseModel):
"""Configuration for GraphEngine worker pool scaling."""
model_config = ConfigDict(frozen=True)
min_workers: int = 1
max_workers: int = 5
scale_up_threshold: int = 3

View File

@ -192,9 +192,13 @@ class EventHandler:
self._event_collector.collect(edge_event)
# Enqueue ready nodes
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
if self._graph_execution.is_paused:
for node_id in ready_nodes:
self._graph_runtime_state.register_deferred_node(node_id)
else:
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._state_manager.finish_execution(event.node_id)

View File

@ -14,6 +14,7 @@ from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from core.workflow.context import capture_current_context
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@ -56,6 +57,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
_DEFAULT_CONFIG = GraphEngineConfig()
@final
class GraphEngine:
"""
@ -71,7 +75,7 @@ class GraphEngine:
graph: Graph,
graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel,
config: GraphEngineConfig,
config: GraphEngineConfig = _DEFAULT_CONFIG,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
@ -235,7 +239,9 @@ class GraphEngine:
self._graph_execution.paused = False
self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent()
start_event = GraphRunStartedEvent(
reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
)
self._event_manager.notify_layers(start_event)
yield start_event
@ -304,15 +310,17 @@ class GraphEngine:
for layer in self._layers:
try:
layer.on_graph_start()
except Exception as e:
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
@ -328,7 +336,11 @@ class GraphEngine:
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
else:
for node_id in paused_nodes:
seen_nodes: set[str] = set()
for node_id in paused_nodes + deferred_nodes:
if node_id in seen_nodes:
continue
seen_nodes.add(node_id)
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
@ -346,8 +358,8 @@ class GraphEngine:
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)
except Exception as e:
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
# Public property accessors for attributes that need external access
@property

View File

@ -224,6 +224,8 @@ class GraphStateManager:
Returns:
Number of executing nodes
"""
# This count is a best-effort snapshot and can change concurrently.
# Only use it for pause-drain checks where scheduling is already frozen.
with self._lock:
return len(self._executing_nodes)

View File

@ -83,12 +83,12 @@ class Dispatcher:
"""Main dispatcher loop."""
try:
self._process_commands()
paused = False
while not self._stop_event.is_set():
if (
self._execution_coordinator.aborted
or self._execution_coordinator.paused
or self._execution_coordinator.execution_complete
):
if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete:
break
if self._execution_coordinator.paused:
paused = True
break
self._execution_coordinator.check_scaling()
@ -101,13 +101,10 @@ class Dispatcher:
time.sleep(0.1)
self._process_commands()
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
if paused:
self._drain_events_until_idle()
else:
self._drain_event_queue()
except Exception as e:
logger.exception("Dispatcher error")
@ -122,3 +119,24 @@ class Dispatcher:
def _process_commands(self, event: GraphNodeEventBase | None = None):
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
self._execution_coordinator.process_commands()
def _drain_event_queue(self) -> None:
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
def _drain_events_until_idle(self) -> None:
while not self._stop_event.is_set():
try:
event = self._event_queue.get(timeout=0.1)
self._event_handler.dispatch(event)
self._event_queue.task_done()
self._process_commands(event)
except queue.Empty:
if not self._execution_coordinator.has_executing_nodes():
break
self._drain_event_queue()

View File

@ -94,3 +94,11 @@ class ExecutionCoordinator:
self._worker_pool.stop()
self._state_manager.clear_executing()
def has_executing_nodes(self) -> bool:
"""Return True if any nodes are currently marked as executing."""
# This check is only safe once execution has already paused.
# Before pause, executing state can change concurrently, which makes the result unreliable.
if not self._graph_execution.is_paused:
raise AssertionError("has_executing_nodes should only be called after execution is paused")
return self._state_manager.get_executing_count() > 0

View File

@ -38,6 +38,8 @@ from .loop import (
from .node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
@ -60,6 +62,8 @@ __all__ = [
"NodeRunAgentLogEvent",
"NodeRunExceptionEvent",
"NodeRunFailedEvent",
"NodeRunHumanInputFormFilledEvent",
"NodeRunHumanInputFormTimeoutEvent",
"NodeRunIterationFailedEvent",
"NodeRunIterationNextEvent",
"NodeRunIterationStartedEvent",

View File

@ -1,11 +1,16 @@
from pydantic import Field
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph_events import BaseGraphEvent
class GraphRunStartedEvent(BaseGraphEvent):
pass
# Reason is emitted for workflow start events and is always set.
reason: WorkflowStartReason = Field(
default=WorkflowStartReason.INITIAL,
description="reason for workflow start",
)
class GraphRunSucceededEvent(BaseGraphEvent):

View File

@ -54,6 +54,22 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
retry_index: int = Field(..., description="which retry attempt is about to be performed")
class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
"""Emitted when a HumanInput form is submitted and before the node finishes."""
node_title: str = Field(..., description="HumanInput node title")
rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
action_id: str = Field(..., description="User action identifier chosen in the form.")
action_text: str = Field(..., description="Display text of the chosen action button.")
class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
"""Emitted when a HumanInput form times out."""
node_title: str = Field(..., description="HumanInput node title")
expiration_time: datetime = Field(..., description="Form expiration time")
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: PauseReason = Field(..., description="pause reason")

View File

@ -13,6 +13,8 @@ from .loop import (
LoopSucceededEvent,
)
from .node import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
@ -23,6 +25,8 @@ from .node import (
__all__ = [
"AgentLogEvent",
"HumanInputFormFilledEvent",
"HumanInputFormTimeoutEvent",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",

View File

@ -47,3 +47,19 @@ class StreamCompletedEvent(NodeEventBase):
class PauseRequestedEvent(NodeEventBase):
reason: PauseReason = Field(..., description="pause reason")
class HumanInputFormFilledEvent(NodeEventBase):
"""Event emitted when a human input form is submitted."""
node_title: str
rendered_content: str
action_id: str
action_text: str
class HumanInputFormTimeoutEvent(NodeEventBase):
"""Event emitted when a human input form times out."""
node_title: str
expiration_time: datetime

View File

@ -18,6 +18,8 @@ from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
@ -34,6 +36,8 @@ from core.workflow.graph_events import (
)
from core.workflow.node_events import (
AgentLogEvent,
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
@ -61,6 +65,15 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
"""BaseNode serves as the foundational class for all node implementations.
Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
attribute to track files generated by the LLM). However, these states are not persisted
when the workflow is suspended or resumed. If a node needs its state to be preserved
across workflow suspension and resumption, it should include the relevant state data
in its output.
"""
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@ -251,10 +264,33 @@ class Node(Generic[NodeDataT]):
return self._node_execution_id
def ensure_execution_id(self) -> str:
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
if self._node_execution_id:
return self._node_execution_id
resumed_execution_id = self._restore_execution_id_from_runtime_state()
if resumed_execution_id:
self._node_execution_id = resumed_execution_id
return self._node_execution_id
self._node_execution_id = str(uuid4())
return self._node_execution_id
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
execution_id = node_execution.execution_id
if not execution_id:
return None
return str(execution_id)
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@ -620,6 +656,28 @@ class Node(Generic[NodeDataT]):
metadata=event.metadata,
)
@_dispatch.register
def _(self, event: HumanInputFormFilledEvent):
return NodeRunHumanInputFormFilledEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
@_dispatch.register
def _(self, event: HumanInputFormTimeoutEvent):
return NodeRunHumanInputFormTimeoutEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
)
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(

View File

@ -1,3 +1,3 @@
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode"]
"""
Human Input node implementation.
"""

View File

@ -1,10 +1,350 @@
from pydantic import Field
"""
Human Input node entities.
"""
import re
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
class _WebAppDeliveryConfig(BaseModel):
"""Configuration for webapp delivery method."""
pass # Empty for webapp delivery
class MemberRecipient(BaseModel):
"""Member recipient for email delivery."""
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
user_id: str
class ExternalRecipient(BaseModel):
"""External recipient for email delivery."""
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
email: str
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
class EmailRecipients(BaseModel):
"""Email recipients configuration."""
# When true, recipients are the union of all workspace members and external items.
# Member items are ignored because they are already covered by the workspace scope.
# De-duplication is applied by email, with member recipients taking precedence.
whole_workspace: bool = False
items: list[EmailRecipient] = Field(default_factory=list)
class EmailDeliveryConfig(BaseModel):
"""Configuration for email delivery method."""
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
recipients: EmailRecipients
# the subject of email
subject: str
# Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which
# represent the url to submit the form.
#
# It may also reference the output variable of the previous node with the syntax
# `{{#<node_id>.<field_name>#}}`.
body: str
debug_mode: bool = False
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
if not user_id:
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
return self.model_copy(update={"recipients": debug_recipients})
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
return self.model_copy(update={"recipients": debug_recipients})
@classmethod
def replace_url_placeholder(cls, body: str, url: str | None) -> str:
"""Replace the url placeholder with provided value."""
return body.replace(cls.URL_PLACEHOLDER, url or "")
@classmethod
def render_body_template(
cls,
*,
body: str,
url: str | None,
variable_pool: VariablePool | None = None,
) -> str:
"""Render email body by replacing placeholders with runtime values."""
templated_body = cls.replace_url_placeholder(body, url)
if variable_pool is None:
return templated_body
return variable_pool.convert_template(templated_body).text
class _DeliveryMethodBase(BaseModel):
"""Base delivery method configuration."""
enabled: bool = True
id: uuid.UUID = Field(default_factory=uuid.uuid4)
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
return ()
class WebAppDeliveryMethod(_DeliveryMethodBase):
"""Webapp delivery method configuration."""
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
# The config field is not used currently.
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
class EmailDeliveryMethod(_DeliveryMethodBase):
"""Email delivery method configuration."""
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
config: EmailDeliveryConfig
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
variable_template_parser = VariableTemplateParser(template=self.config.body)
selectors: list[Sequence[str]] = []
for variable_selector in variable_template_parser.extract_variable_selectors():
value_selector = list(variable_selector.value_selector)
if len(value_selector) < SELECTORS_LENGTH:
continue
selectors.append(value_selector[:SELECTORS_LENGTH])
return selectors
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
def apply_debug_email_recipient(
method: DeliveryChannelConfig,
*,
enabled: bool,
user_id: str,
) -> DeliveryChannelConfig:
if not enabled:
return method
if not isinstance(method, EmailDeliveryMethod):
return method
if not method.config.debug_mode:
return method
debug_config = method.config.with_debug_recipient(user_id or "")
return method.model_copy(update={"config": debug_config})
class FormInputDefault(BaseModel):
"""Default configuration for form inputs."""
# NOTE: Ideally, a discriminated union would be used to model
# FormInputDefault. However, the UI requires preserving the previous
# value when switching between `VARIABLE` and `CONSTANT` types. This
# necessitates retaining all fields, making a discriminated union unsuitable.
type: PlaceholderType
# The selector of default variable, used when `type` is `VARIABLE`.
selector: Sequence[str] = Field(default_factory=tuple) #
# The value of the default, used when `type` is `CONSTANT`.
# TODO: How should we express JSON values?
value: str = ""
@model_validator(mode="after")
def _validate_selector(self) -> Self:
if self.type == PlaceholderType.CONSTANT:
return self
if len(self.selector) < SELECTORS_LENGTH:
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
return self
class FormInput(BaseModel):
"""Form input definition."""
type: FormInputType
output_variable_name: str
default: FormInputDefault | None = None
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
class UserAction(BaseModel):
"""User action configuration."""
# id is the identifier for this action.
# It also serves as the identifiers of output handle.
#
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
id: str = Field(max_length=20)
title: str = Field(max_length=20)
button_style: ButtonStyle = ButtonStyle.DEFAULT
@field_validator("id")
@classmethod
def _validate_id(cls, value: str) -> str:
if not _IDENTIFIER_PATTERN.match(value):
raise ValueError(
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
f"and contain only letters, numbers, or underscores."
)
return value
class HumanInputNodeData(BaseNodeData):
"""Configuration schema for the HumanInput node."""
"""Human Input node data."""
required_variables: list[str] = Field(default_factory=list)
pause_reason: str | None = Field(default=None)
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
form_content: str = ""
inputs: list[FormInput] = Field(default_factory=list)
user_actions: list[UserAction] = Field(default_factory=list)
timeout: int = 36
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
@field_validator("inputs")
@classmethod
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
seen_names: set[str] = set()
for form_input in inputs:
name = form_input.output_variable_name
if name in seen_names:
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
seen_names.add(name)
return inputs
@field_validator("user_actions")
@classmethod
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
seen_ids: set[str] = set()
for action in user_actions:
action_id = action.id
if action_id in seen_ids:
raise ValueError(f"duplicated user action id '{action_id}'")
seen_ids.add(action_id)
return user_actions
def is_webapp_enabled(self) -> bool:
for dm in self.delivery_methods:
if not dm.enabled:
continue
if dm.type == DeliveryMethodType.WEBAPP:
return True
return False
def expiration_time(self, start_time: datetime) -> datetime:
if self.timeout_unit == TimeoutUnit.HOUR:
return start_time + timedelta(hours=self.timeout)
elif self.timeout_unit == TimeoutUnit.DAY:
return start_time + timedelta(days=self.timeout)
else:
raise AssertionError("unknown timeout unit.")
def outputs_field_names(self) -> Sequence[str]:
field_names = []
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
field_names.append(match.group("field_name"))
return field_names
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
variable_mappings: dict[str, Sequence[str]] = {}
def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
for selector in selectors:
if len(selector) < SELECTORS_LENGTH:
continue
qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
form_template_parser = VariableTemplateParser(template=self.form_content)
_add_variable_selectors(
[selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
)
for delivery_method in self.delivery_methods:
if not delivery_method.enabled:
continue
_add_variable_selectors(delivery_method.extract_variable_selectors())
for input in self.inputs:
default_value = input.default
if default_value is None:
continue
if default_value.type == PlaceholderType.CONSTANT:
continue
default_value_key = ".".join(default_value.selector)
qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
variable_mappings[qualified_variable_mapping_key] = default_value.selector
return variable_mappings
def find_action_text(self, action_id: str) -> str:
"""
Resolve action display text by id.
"""
for action in self.user_actions:
if action.id == action_id:
return action.title
return action_id
class FormDefinition(BaseModel):
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
user_actions: list[UserAction] = Field(default_factory=list)
rendered_content: str
expiration_time: datetime
# this is used to store the resolved default values
default_values: dict[str, Any] = Field(default_factory=dict)
# node_title records the title of the HumanInput node.
node_title: str | None = None
# display_in_ui controls whether the form should be displayed in UI surfaces.
display_in_ui: bool | None = None
class HumanInputSubmissionValidationError(ValueError):
pass
def validate_human_input_submission(
*,
inputs: Sequence[FormInput],
user_actions: Sequence[UserAction],
selected_action_id: str,
form_data: Mapping[str, Any],
) -> None:
available_actions = {action.id for action in user_actions}
if selected_action_id not in available_actions:
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
provided_inputs = set(form_data.keys())
missing_inputs = [
form_input.output_variable_name
for form_input in inputs
if form_input.output_variable_name not in provided_inputs
]
if missing_inputs:
missing_list = ", ".join(missing_inputs)
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")

View File

@ -0,0 +1,72 @@
import enum
class HumanInputFormStatus(enum.StrEnum):
"""Status of a human input form."""
# Awaiting submission from any recipient. Forms stay in this state until
# submitted or a timeout rule applies.
WAITING = enum.auto()
# Global timeout reached. The workflow run is stopped and will not resume.
# This is distinct from node-level timeout.
EXPIRED = enum.auto()
# Submitted by a recipient; form data is available and execution resumes
# along the selected action edge.
SUBMITTED = enum.auto()
# Node-level timeout reached. The human input node should emit a timeout
# event and the workflow should resume along the timeout edge.
TIMEOUT = enum.auto()
class HumanInputFormKind(enum.StrEnum):
"""Kind of a human input form."""
RUNTIME = enum.auto() # Form created during workflow execution.
DELIVERY_TEST = enum.auto() # Form created for delivery tests.
class DeliveryMethodType(enum.StrEnum):
"""Delivery method types for human input forms."""
# WEBAPP controls whether the form is delivered to the web app. It not only controls
# the standalone web app, but also controls the installed apps in the console.
WEBAPP = enum.auto()
EMAIL = enum.auto()
class ButtonStyle(enum.StrEnum):
"""Button styles for user actions."""
PRIMARY = enum.auto()
DEFAULT = enum.auto()
ACCENT = enum.auto()
GHOST = enum.auto()
class TimeoutUnit(enum.StrEnum):
"""Timeout unit for form expiration."""
HOUR = enum.auto()
DAY = enum.auto()
class FormInputType(enum.StrEnum):
"""Form input types."""
TEXT_INPUT = enum.auto()
PARAGRAPH = enum.auto()
class PlaceholderType(enum.StrEnum):
"""Default value types for form inputs."""
VARIABLE = enum.auto()
CONSTANT = enum.auto()
class EmailRecipientType(enum.StrEnum):
"""Email recipient types."""
MEMBER = enum.auto()
EXTERNAL = enum.auto()

View File

@ -1,12 +1,42 @@
from collections.abc import Mapping
from typing import Any
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.node_events import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
NodeRunResult,
PauseRequestedEvent,
)
from core.workflow.node_events.base import NodeEventBase
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from .entities import HumanInputNodeData
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
if TYPE_CHECKING:
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
_SELECTED_BRANCH_KEY = "selected_branch"
logger = logging.getLogger(__name__)
class HumanInputNode(Node[HumanInputNodeData]):
@ -17,7 +47,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
"edge_source_handle",
"edgeSourceHandle",
"source_handle",
"selected_branch",
_SELECTED_BRANCH_KEY,
"selectedBranch",
"branch",
"branch_id",
@ -25,43 +55,37 @@ class HumanInputNode(Node[HumanInputNodeData]):
"handle",
)
_node_data: HumanInputNodeData
_form_repository: HumanInputFormRepository
_OUTPUT_FIELD_ACTION_ID = "__action_id"
_OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
_TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if form_repository is None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self.tenant_id,
)
self._form_repository = form_repository
@classmethod
def version(cls) -> str:
return "1"
def _run(self): # type: ignore[override]
if self._is_completion_ready():
branch_handle = self._resolve_branch_selection()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={},
edge_source_handle=branch_handle or "source",
)
return self._pause_generator()
def _pause_generator(self):
# TODO(QuantumGhost): yield a real form id.
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""
if not self.node_data.required_variables:
return False
variable_pool = self.graph_runtime_state.variable_pool
for selector_str in self.node_data.required_variables:
parts = selector_str.split(".")
if len(parts) != 2:
return False
segment = variable_pool.get(parts)
if segment is None:
return False
return True
def _resolve_branch_selection(self) -> str | None:
"""Determine the branch handle selected by human input if available."""
@ -108,3 +132,224 @@ class HumanInputNode(Node[HumanInputNodeData]):
return candidate
return None
@property
def _workflow_execution_id(self) -> str:
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
assert workflow_exec_id is not None
return workflow_exec_id
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
required_event = self._human_input_required_event(form_entity)
pause_requested_event = PauseRequestedEvent(reason=required_event)
return pause_requested_event
def resolve_default_values(self) -> Mapping[str, Any]:
variable_pool = self.graph_runtime_state.variable_pool
resolved_defaults: dict[str, Any] = {}
for input in self._node_data.inputs:
if (default_value := input.default) is None:
continue
if default_value.type == PlaceholderType.CONSTANT:
continue
resolved_value = variable_pool.get(default_value.selector)
if resolved_value is None:
# TODO: How should we handle this?
continue
resolved_defaults[input.output_variable_name] = (
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
)
return resolved_defaults
def _should_require_console_recipient(self) -> bool:
if self.invoke_from == InvokeFrom.DEBUGGER:
return True
if self.invoke_from == InvokeFrom.EXPLORE:
return self._node_data.is_webapp_enabled()
return False
def _display_in_ui(self) -> bool:
if self.invoke_from == InvokeFrom.DEBUGGER:
return True
return self._node_data.is_webapp_enabled()
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
return [
apply_debug_email_recipient(
method,
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
user_id=self.user_id or "",
)
for method in enabled_methods
]
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
node_data = self._node_data
resolved_default_values = self.resolve_default_values()
display_in_ui = self._display_in_ui()
form_token = form_entity.web_app_token
if display_in_ui and form_token is None:
raise AssertionError("Form token should be available for UI execution.")
return HumanInputRequired(
form_id=form_entity.id,
form_content=form_entity.rendered_content,
inputs=node_data.inputs,
actions=node_data.user_actions,
display_in_ui=display_in_ui,
node_id=self.id,
node_title=node_data.title,
form_token=form_token,
resolved_default_values=resolved_default_values,
)
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Execute the human input node.
This method will:
1. Generate a unique form ID
2. Create form content with variable substitution
3. Create form in database
4. Send form via configured delivery methods
5. Suspend workflow execution
6. Wait for form submission to resume
"""
repo = self._form_repository
form = repo.get_form(self._workflow_execution_id, self.id)
if form is None:
display_in_ui = self._display_in_ui()
params = FormCreateParams(
app_id=self.app_id,
workflow_execution_id=self._workflow_execution_id,
node_id=self.id,
form_config=self._node_data,
rendered_content=self.render_form_content_before_submission(),
delivery_methods=self._effective_delivery_methods(),
display_in_ui=display_in_ui,
resolved_default_values=self.resolve_default_values(),
console_recipient_required=self._should_require_console_recipient(),
console_creator_account_id=(
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
),
backstage_recipient_required=True,
)
form_entity = self._form_repository.create_form(params)
# Create human input required event
logger.info(
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
self.id,
form_entity.id,
)
yield self._form_to_pause_event(form_entity)
return
if (
form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
or form.expiration_time <= naive_utc_now()
):
yield HumanInputFormTimeoutEvent(
node_title=self._node_data.title,
expiration_time=form.expiration_time,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
edge_source_handle=self._TIMEOUT_HANDLE,
)
)
return
if not form.submitted:
yield self._form_to_pause_event(form)
return
selected_action_id = form.selected_action_id
if selected_action_id is None:
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
submitted_data = form.submitted_data or {}
outputs: dict[str, Any] = dict(submitted_data)
outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
rendered_content = self.render_form_content_with_outputs(
form.rendered_content,
outputs,
self._node_data.outputs_field_names(),
)
outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
action_text = self._node_data.find_action_text(selected_action_id)
yield HumanInputFormFilledEvent(
node_title=self._node_data.title,
rendered_content=rendered_content,
action_id=selected_action_id,
action_text=action_text,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
edge_source_handle=selected_action_id,
)
)
def render_form_content_before_submission(self) -> str:
"""
Process form content by substituting variables.
This method should:
1. Parse the form_content markdown
2. Substitute {{#node_name.var_name#}} with actual values
3. Keep {{#$output.field_name#}} placeholders for form inputs
"""
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
self._node_data.form_content,
)
return rendered_form_content.markdown
@staticmethod
def render_form_content_with_outputs(
form_content: str,
outputs: Mapping[str, Any],
field_names: Sequence[str],
) -> str:
"""
Replace {{#$output.xxx#}} placeholders with submitted values.
"""
rendered_content = form_content
for field_name in field_names:
placeholder = "{{#$output." + field_name + "#}}"
value = outputs.get(field_name)
if value is None:
replacement = ""
elif isinstance(value, (dict, list)):
replacement = json.dumps(value, ensure_ascii=False)
else:
replacement = str(value)
rendered_content = rendered_content.replace(placeholder, replacement)
return rendered_content
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selectors referenced in form content and input default values.
This method should parse:
1. Variables referenced in form_content ({{#node_name.var_name#}})
2. Variables referenced in input default values
"""
validated_node_data = HumanInputNodeData.model_validate(node_data)
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)

View File

@ -0,0 +1,152 @@
import abc
import dataclasses
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Protocol
from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
class HumanInputError(Exception):
pass
class FormNotFoundError(HumanInputError):
pass
@dataclasses.dataclass
class FormCreateParams:
# app_id is the identifier for the app that the form belongs to.
# It is a string with uuid format.
app_id: str
# None when creating a delivery test form; set for runtime forms.
workflow_execution_id: str | None
# node_id is the identifier for a specific
# node in the graph.
#
# TODO: for node inside loop / iteration, this would
# cause problems, as a single node may be executed multiple times.
node_id: str
form_config: HumanInputNodeData
rendered_content: str
# Delivery methods already filtered by runtime context (invoke_from).
delivery_methods: Sequence[DeliveryChannelConfig]
# UI display flag computed by runtime context.
display_in_ui: bool
# resolved_default_values saves the values for defaults with
# type = VARIABLE.
#
# For type = CONSTANT, the value is not stored inside `resolved_default_values`
resolved_default_values: Mapping[str, Any]
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
# Force creating a console-only recipient for submission in Console.
console_recipient_required: bool = False
console_creator_account_id: str | None = None
# Force creating a backstage recipient for submission in Console.
backstage_recipient_required: bool = False
class HumanInputFormEntity(abc.ABC):
@property
@abc.abstractmethod
def id(self) -> str:
"""id returns the identifer of the form."""
pass
@property
@abc.abstractmethod
def web_app_token(self) -> str | None:
"""web_app_token returns the token for submission inside webapp.
For console/debug execution, this may point to the console submission token
if the form is configured to require console delivery.
"""
# TODO: what if the users are allowed to add multiple
# webapp delivery?
pass
@property
@abc.abstractmethod
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
@property
@abc.abstractmethod
def rendered_content(self) -> str:
"""Rendered markdown content associated with the form."""
...
@property
@abc.abstractmethod
def selected_action_id(self) -> str | None:
"""Identifier of the selected user action if the form has been submitted."""
...
@property
@abc.abstractmethod
def submitted_data(self) -> Mapping[str, Any] | None:
"""Submitted form data if available."""
...
@property
@abc.abstractmethod
def submitted(self) -> bool:
"""Whether the form has been submitted."""
...
@property
@abc.abstractmethod
def status(self) -> HumanInputFormStatus:
"""Current status of the form."""
...
@property
@abc.abstractmethod
def expiration_time(self) -> datetime:
"""When the form expires."""
...
class HumanInputFormRecipientEntity(abc.ABC):
@property
@abc.abstractmethod
def id(self) -> str:
"""id returns the identifer of this recipient."""
...
@property
@abc.abstractmethod
def token(self) -> str:
"""token returns a random string used to submit form"""
...
class HumanInputFormRepository(Protocol):
"""
Repository interface for HumanInputForm.
This interface defines the contract for accessing and manipulating
HumanInputForm data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and other implementation details should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
"""Get the form created for a given human input node in a workflow execution. Returns
`None` if the form has not been created yet."""
...
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
"""
Create a human input form from form definition.
"""
...

View File

@ -6,14 +6,18 @@ import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from pydantic import BaseModel, Field
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeState
from core.workflow.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.entities.pause_reason import PauseReason
class ReadyQueueProtocol(Protocol):
"""Structural interface required from ready queue implementations."""
@ -60,7 +64,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
pause_reasons: Sequence[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""
@ -103,14 +107,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
...
class NodeProtocol(Protocol):
"""Structural interface for graph nodes."""
id: str
state: NodeState
class EdgeProtocol(Protocol):
id: str
state: NodeState
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
nodes: Mapping[str, object]
edges: Mapping[str, object]
root_node: object
nodes: Mapping[str, NodeProtocol]
edges: Mapping[str, EdgeProtocol]
root_node: NodeProtocol
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
class _GraphStateSnapshot(BaseModel):
"""Serializable graph state snapshot for node/edge states."""
nodes: dict[str, NodeState] = Field(default_factory=dict)
edges: dict[str, NodeState] = Field(default_factory=dict)
@dataclass(slots=True)
@ -128,10 +151,20 @@ class _GraphRuntimeStateSnapshot:
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
deferred_nodes: tuple[str, ...]
graph_node_states: dict[str, NodeState]
graph_edge_states: dict[str, NodeState]
class GraphRuntimeState:
"""Mutable runtime state shared across graph execution components."""
"""Mutable runtime state shared across graph execution components.
`GraphRuntimeState` encapsulates the runtime state of workflow execution,
including scheduling details, variable values, and timing information.
Values that are initialized prior to workflow execution and remain constant
throughout the execution should be part of `GraphInitParams` instead.
"""
def __init__(
self,
@ -169,6 +202,16 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
# Node and edges states needed to be restored into
# graph object.
#
# These two fields are non-None only when resuming from a snapshot.
# Once the graph is attached, these two fields will be set to None.
self._pending_graph_node_states: dict[str, NodeState] | None = None
self._pending_graph_edge_states: dict[str, NodeState] | None = None
self.stop_event: threading.Event = threading.Event()
if graph is not None:
@ -190,6 +233,7 @@ class GraphRuntimeState:
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
self._response_coordinator.loads(self._pending_response_coordinator_dump)
self._pending_response_coordinator_dump = None
self._apply_pending_graph_state()
def configure(self, *, graph: GraphProtocol | None = None) -> None:
"""Ensure core collaborators are initialized with the provided context."""
@ -311,8 +355,13 @@ class GraphRuntimeState:
"ready_queue": self.ready_queue.dumps(),
"graph_execution": self.graph_execution.dumps(),
"paused_nodes": list(self._paused_nodes),
"deferred_nodes": list(self._deferred_nodes),
}
graph_state = self._snapshot_graph_state()
if graph_state is not None:
snapshot["graph_state"] = graph_state
if self._response_coordinator is not None and self._graph is not None:
snapshot["response_coordinator"] = self._response_coordinator.dumps()
@ -346,6 +395,11 @@ class GraphRuntimeState:
self._paused_nodes.add(node_id)
def get_paused_nodes(self) -> list[str]:
"""Retrieve the list of paused nodes without mutating internal state."""
return list(self._paused_nodes)
def consume_paused_nodes(self) -> list[str]:
"""Retrieve and clear the list of paused nodes awaiting resume."""
@ -353,6 +407,23 @@ class GraphRuntimeState:
self._paused_nodes.clear()
return nodes
def register_deferred_node(self, node_id: str) -> None:
"""Record a node that became ready during pause and should resume later."""
self._deferred_nodes.add(node_id)
def get_deferred_nodes(self) -> list[str]:
"""Retrieve deferred nodes without mutating internal state."""
return list(self._deferred_nodes)
def consume_deferred_nodes(self) -> list[str]:
"""Retrieve and clear deferred nodes awaiting resume."""
nodes = list(self._deferred_nodes)
self._deferred_nodes.clear()
return nodes
# ------------------------------------------------------------------
# Builders
# ------------------------------------------------------------------
@ -414,6 +485,10 @@ class GraphRuntimeState:
graph_execution_payload = payload.get("graph_execution")
response_payload = payload.get("response_coordinator")
paused_nodes_payload = payload.get("paused_nodes", [])
deferred_nodes_payload = payload.get("deferred_nodes", [])
graph_state_payload = payload.get("graph_state", {}) or {}
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
return _GraphRuntimeStateSnapshot(
start_at=start_at,
@ -427,6 +502,9 @@ class GraphRuntimeState:
graph_execution_dump=graph_execution_payload,
response_coordinator_dump=response_payload,
paused_nodes=tuple(map(str, paused_nodes_payload)),
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
graph_node_states=graph_node_states,
graph_edge_states=graph_edge_states,
)
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
@ -442,6 +520,10 @@ class GraphRuntimeState:
self._restore_graph_execution(snapshot.graph_execution_dump)
self._restore_response_coordinator(snapshot.response_coordinator_dump)
self._paused_nodes = set(snapshot.paused_nodes)
self._deferred_nodes = set(snapshot.deferred_nodes)
self._pending_graph_node_states = snapshot.graph_node_states or None
self._pending_graph_edge_states = snapshot.graph_edge_states or None
self._apply_pending_graph_state()
def _restore_ready_queue(self, payload: str | None) -> None:
if payload is not None:
@ -478,3 +560,68 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump = payload
self._response_coordinator = None
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
graph = self._graph
if graph is None:
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
return _GraphStateSnapshot()
return _GraphStateSnapshot(
nodes=self._pending_graph_node_states or {},
edges=self._pending_graph_edge_states or {},
)
nodes = graph.nodes
edges = graph.edges
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
return _GraphStateSnapshot()
node_states = {}
for node_id, node in nodes.items():
if not isinstance(node_id, str):
continue
node_states[node_id] = node.state
edge_states = {}
for edge_id, edge in edges.items():
if not isinstance(edge_id, str):
continue
edge_states[edge_id] = edge.state
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
def _apply_pending_graph_state(self) -> None:
if self._graph is None:
return
if self._pending_graph_node_states:
for node_id, state in self._pending_graph_node_states.items():
node = self._graph.nodes.get(node_id)
if node is None:
continue
node.state = state
if self._pending_graph_edge_states:
for edge_id, state in self._pending_graph_edge_states.items():
edge = self._graph.edges.get(edge_id)
if edge is None:
continue
edge.state = state
self._pending_graph_node_states = None
self._pending_graph_edge_states = None
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
if not isinstance(payload, Mapping):
return {}
raw_map = payload.get(key, {})
if not isinstance(raw_map, Mapping):
return {}
result: dict[str, NodeState] = {}
for node_id, raw_state in raw_map.items():
if not isinstance(node_id, str):
continue
try:
result[node_id] = NodeState(str(raw_state))
except ValueError:
continue
return result

View File

@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: None) -> None: ...
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
"""Convert runtime values to JSON-serializable structures."""
result = self.value_to_json_encodable_recursive(value)
if isinstance(result, Mapping) or result is None:
return result
return {}
def _to_json_encodable_recursive(self, value: Any):
def value_to_json_encodable_recursive(self, value: Any):
if value is None:
return value
if isinstance(value, (bool, int, str, float)):
@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter:
# Convert Decimal to float for JSON serialization
return float(value)
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
return self.value_to_json_encodable_recursive(value.value)
if isinstance(value, File):
return value.to_dict()
if isinstance(value, BaseModel):
@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter:
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = self._to_json_encodable_recursive(v)
res[k] = self.value_to_json_encodable_recursive(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(self._to_json_encodable_recursive(item))
res_list.append(self.value_to_json_encodable_recursive(item))
return res_list
return value

View File

@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"
@ -102,7 +102,7 @@ elif [[ "${MODE}" == "job" ]]; then
fi
echo "Running Flask job command: flask $*"
# Temporarily disable exit on error to capture exit code
set +e
flask "$@"

View File

@ -151,6 +151,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
}
if dify_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK:
imports.append("tasks.human_input_timeout_tasks")
beat_schedule["human_input_form_timeout"] = {
"task": "human_input_form_timeout.check_and_resume",
"schedule": timedelta(minutes=dify_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL),
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
imports.append("schedule.check_upgradable_plugin_task")
imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")

View File

@ -8,12 +8,16 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
import redis
from redis import RedisError
from redis.cache import CacheConfig
from redis.client import PubSub
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel
from configs import dify_config
from dify_app import DifyApp
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
if TYPE_CHECKING:
from redis.lock import Lock
@ -106,6 +110,7 @@ class RedisClientWrapper:
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def pubsub(self) -> PubSub: ...
def __getattr__(self, item: str) -> Any:
if self._client is None:
@ -114,6 +119,7 @@ class RedisClientWrapper:
redis_client: RedisClientWrapper = RedisClientWrapper()
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
@ -226,6 +232,12 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
return client
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
if use_clusters:
return RedisCluster.from_url(pubsub_url)
return redis.Redis.from_url(pubsub_url)
def init_app(app: DifyApp):
"""Initialize Redis client and attach it to the app."""
global redis_client
@ -244,6 +256,24 @@ def init_app(app: DifyApp):
redis_client.initialize(client)
app.extensions["redis"] = redis_client
pubsub_client = client
if dify_config.normalized_pubsub_redis_url:
pubsub_client = _create_pubsub_client(
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
)
pubsub_redis_client.initialize(pubsub_client)
def get_pubsub_redis_client() -> RedisClientWrapper:
return pubsub_redis_client
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
redis_conn = get_pubsub_redis_client()
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
P = ParamSpec("P")
R = TypeVar("R")

View File

@ -13,6 +13,7 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
@ -207,8 +208,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
reverse=True,
)
if deduplicated_results:
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
for row in deduplicated_results:
model = _dict_to_workflow_node_execution_model(row)
if model.status != WorkflowNodeExecutionStatus.PAUSED:
return model
return None
@ -309,6 +312,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
if model and model.id: # Ensure model is valid
models.append(model)
models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED]
# Sort by index DESC for trace visualization
models.sort(key=lambda x: x.index, reverse=True)

View File

@ -192,6 +192,7 @@ class StatusCount(ResponseModel):
success: int
failed: int
partial_success: int
paused: int
class ModelConfig(ResponseModel):

View File

@ -6,6 +6,7 @@ from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from core.file import File
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
@ -61,6 +62,7 @@ class MessageListItem(ResponseModel):
message_files: list[MessageFile]
status: str
error: str | None = None
extra_contents: list[ExecutionExtraContentDomainModel]
@field_validator("inputs", mode="before")
@classmethod

View File

@ -162,7 +162,7 @@ class RedisSubscriptionBase(Subscription):
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
def receive(self, timeout: float | None = 0.1) -> bytes | None:
"""Receive the next message from the subscription."""
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")

View File

@ -61,7 +61,14 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined]
# NOTE(QuantumGhost): this is an issue in
# upstream code. If Sharded PubSub is used with Cluster, the
# `ClusterPubSub.get_sharded_message` will return `None` regardless of
# message['type'].
#
# Since we have already filtered at the caller's site, we can safely set
# `ignore_subscribe_messages=False`.
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
def _get_message_type(self) -> str:
return "smessage"

View File

@ -0,0 +1,49 @@
"""
Email template rendering helpers with configurable safety modes.
"""
import time
from collections.abc import Mapping
from typing import Any
from flask import render_template_string
from jinja2.runtime import Context
from jinja2.sandbox import ImmutableSandboxedEnvironment
from configs import dify_config
from configs.feature import TemplateMode
class SandboxedEnvironment(ImmutableSandboxedEnvironment):
"""Sandboxed environment with execution timeout."""
def __init__(self, timeout: int, *args: Any, **kwargs: Any):
self._deadline = time.time() + timeout if timeout else None
super().__init__(*args, **kwargs)
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
if self._deadline is not None and time.time() > self._deadline:
raise TimeoutError("Template rendering timeout")
return super().call(context, obj, *args, **kwargs)
def render_email_template(template: str, substitutions: Mapping[str, str]) -> str:
"""
Render email template content according to the configured template mode.
In unsafe mode, Jinja expressions are evaluated directly.
In sandbox mode, a sandboxed environment with timeout is used.
In disabled mode, the template is returned without rendering.
"""
mode = dify_config.MAIL_TEMPLATING_MODE
timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
if mode == TemplateMode.UNSAFE:
return render_template_string(template, **substitutions)
if mode == TemplateMode.SANDBOX:
env = SandboxedEnvironment(timeout=timeout)
tmpl = env.from_string(template)
return tmpl.render(substitutions)
if mode == TemplateMode.DISABLED:
return template
raise ValueError(f"Unsupported mail templating mode: {mode}")

View File

@ -1,12 +1,15 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar
from typing import TYPE_CHECKING, TypeVar
from flask import Flask, g
T = TypeVar("T")
if TYPE_CHECKING:
from models import Account, EndUser
@contextmanager
def preserve_flask_contexts(
@ -64,3 +67,7 @@ def preserve_flask_contexts(
finally:
# Any cleanup can be added here if needed
pass
def set_login_user(user: "Account | EndUser"):
g._login_user = user

View File

@ -7,10 +7,10 @@ import struct
import subprocess
import time
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Callable, Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
from uuid import UUID
from zoneinfo import available_timezones
@ -126,6 +126,13 @@ class TimestampField(fields.Raw):
return int(value.timestamp())
class OptionalTimestampField(fields.Raw):
def format(self, value) -> int | None:
if value is None:
return None
return int(value.timestamp())
def email(email):
# Define a regex pattern for email addresses
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
@ -237,6 +244,26 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"):
def generate_string(n):
"""
Generates a cryptographically secure random string of the specified length.
This function uses a cryptographically secure pseudorandom number generator (CSPRNG)
to create a string composed of ASCII letters (both uppercase and lowercase) and digits.
Each character in the generated string provides approximately 5.95 bits of entropy
(log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the
length of the string (`n`) should be at least 22 characters.
Args:
n (int): The length of the random string to generate. For secure usage,
`n` should be 22 or greater.
Returns:
str: A random string of length `n` composed of ASCII letters and digits.
Note:
This function is suitable for generating credentials or other secure tokens.
"""
letters_digits = string.ascii_letters + string.digits
result = ""
for _ in range(n):
@ -405,11 +432,35 @@ class TokenManager:
return f"{token_type}:account:{account_id}"
class _RateLimiterRedisClient(Protocol):
def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ...
def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ...
def zcard(self, name: str | bytes) -> int: ...
def expire(self, name: str | bytes, time: int) -> bool: ...
def _default_rate_limit_member_factory() -> str:
current_time = int(time.time())
return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}"
class RateLimiter:
def __init__(self, prefix: str, max_attempts: int, time_window: int):
def __init__(
self,
prefix: str,
max_attempts: int,
time_window: int,
member_factory: Callable[[], str] = _default_rate_limit_member_factory,
redis_client: _RateLimiterRedisClient = redis_client,
):
self.prefix = prefix
self.max_attempts = max_attempts
self.time_window = time_window
self._member_factory = member_factory
self._redis_client = redis_client
def _get_key(self, email: str) -> str:
return f"{self.prefix}:{email}"
@ -419,8 +470,8 @@ class RateLimiter:
current_time = int(time.time())
window_start_time = current_time - self.time_window
redis_client.zremrangebyscore(key, "-inf", window_start_time)
attempts = redis_client.zcard(key)
self._redis_client.zremrangebyscore(key, "-inf", window_start_time)
attempts = self._redis_client.zcard(key)
if attempts and int(attempts) >= self.max_attempts:
return True
@ -428,7 +479,8 @@ class RateLimiter:
def increment_rate_limit(self, email: str):
key = self._get_key(email)
member = self._member_factory()
current_time = int(time.time())
redis_client.zadd(key, {current_time: current_time})
redis_client.expire(key, self.time_window * 2)
self._redis_client.zadd(key, {member: current_time})
self._redis_client.expire(key, self.time_window * 2)

View File

@ -0,0 +1,99 @@
"""Add human input related db models
Revision ID: e8c3b3c46151
Revises: 788d3099ae3a
Create Date: 2026-01-29 14:15:23.081903
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e8c3b3c46151"
down_revision = "788d3099ae3a"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"execution_extra_contents",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("type", sa.String(length=30), nullable=False),
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
sa.Column("message_id", models.types.StringUUID(), nullable=True),
sa.Column("form_id", models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("execution_extra_contents_pkey")),
)
with op.batch_alter_table("execution_extra_contents", schema=None) as batch_op:
batch_op.create_index(batch_op.f("execution_extra_contents_message_id_idx"), ["message_id"], unique=False)
batch_op.create_index(
batch_op.f("execution_extra_contents_workflow_run_id_idx"), ["workflow_run_id"], unique=False
)
op.create_table(
"human_input_form_deliveries",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("form_id", models.types.StringUUID(), nullable=False),
sa.Column("delivery_method_type", sa.String(length=20), nullable=False),
sa.Column("delivery_config_id", models.types.StringUUID(), nullable=True),
sa.Column("channel_payload", sa.Text(), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_deliveries_pkey")),
)
op.create_table(
"human_input_form_recipients",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("form_id", models.types.StringUUID(), nullable=False),
sa.Column("delivery_id", models.types.StringUUID(), nullable=False),
sa.Column("recipient_type", sa.String(length=20), nullable=False),
sa.Column("recipient_payload", sa.Text(), nullable=False),
sa.Column("access_token", sa.VARCHAR(length=32), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_recipients_pkey")),
)
with op.batch_alter_table('human_input_form_recipients', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('human_input_form_recipients_access_token_key'), ['access_token'])
op.create_table(
"human_input_forms",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
sa.Column("form_kind", sa.String(length=20), nullable=False),
sa.Column("node_id", sa.String(length=60), nullable=False),
sa.Column("form_definition", sa.Text(), nullable=False),
sa.Column("rendered_content", sa.Text(), nullable=False),
sa.Column("status", sa.String(length=20), nullable=False),
sa.Column("expiration_time", sa.DateTime(), nullable=False),
sa.Column("selected_action_id", sa.String(length=200), nullable=True),
sa.Column("submitted_data", sa.Text(), nullable=True),
sa.Column("submitted_at", sa.DateTime(), nullable=True),
sa.Column("submission_user_id", models.types.StringUUID(), nullable=True),
sa.Column("submission_end_user_id", models.types.StringUUID(), nullable=True),
sa.Column("completed_by_recipient_id", models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("human_input_forms_pkey")),
)
def downgrade():
op.drop_table("human_input_forms")
op.drop_table("human_input_form_recipients")
op.drop_table("human_input_form_deliveries")
op.drop_table("execution_extra_contents")

View File

@ -34,6 +34,8 @@ from .enums import (
WorkflowRunTriggeredFrom,
WorkflowTriggerStatus,
)
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
from .human_input import HumanInputForm
from .model import (
AccountTrialAppRecord,
ApiRequest,
@ -155,9 +157,12 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"ExecutionExtraContent",
"ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"HumanInputContent",
"HumanInputForm",
"IconType",
"InstalledApp",
"InvitationCode",

View File

@ -41,7 +41,7 @@ class DefaultFieldsMixin:
)
updated_at: Mapped[datetime] = mapped_column(
__name_pos=DateTime,
DateTime,
nullable=False,
default=naive_utc_now,
server_default=func.current_timestamp(),

View File

@ -36,6 +36,7 @@ class MessageStatus(StrEnum):
"""
NORMAL = "normal"
PAUSED = "paused"
ERROR = "error"

View File

@ -0,0 +1,78 @@
from enum import StrEnum, auto
from typing import TYPE_CHECKING
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base, DefaultFieldsMixin
from .types import EnumText, StringUUID
if TYPE_CHECKING:
from .human_input import HumanInputForm
class ExecutionContentType(StrEnum):
HUMAN_INPUT = auto()
class ExecutionExtraContent(DefaultFieldsMixin, Base):
"""ExecutionExtraContent stores extra contents produced during workflow / chatflow execution."""
# The `ExecutionExtraContent` uses single table inheritance to model different
# kinds of contents produced during message generation.
#
# See: https://docs.sqlalchemy.org/en/20/orm/inheritance.html#single-table-inheritance
__tablename__ = "execution_extra_contents"
__mapper_args__ = {
"polymorphic_abstract": True,
"polymorphic_on": "type",
"with_polymorphic": "*",
}
# type records the type of the content. It serves as the `discriminator` for the
# single table inheritance.
type: Mapped[ExecutionContentType] = mapped_column(
EnumText(ExecutionContentType, length=30),
nullable=False,
)
# `workflow_run_id` records the workflow execution which generates this content, correspond to
# `WorkflowRun.id`.
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
# `message_id` records the messages generated by the execution associated with this `ExecutionExtraContent`.
# It references to `Message.id`.
#
# For workflow execution, this field is `None`.
#
# For chatflow execution, `message_id`` is not None, and the following condition holds:
#
# The message referenced by `message_id` has `message.workflow_run_id == execution_extra_content.workflow_run_id`
#
message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, index=True)
class HumanInputContent(ExecutionExtraContent):
"""HumanInputContent is a concrete class that represents human input content.
It should only be initialized with the `new` class method."""
__mapper_args__ = {
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT,
}
# A relation to HumanInputForm table.
#
# While the form_id column is nullable in database (due to the nature of single table inheritance),
# the form_id field should not be null for a given `HumanInputContent` instance.
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
@classmethod
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
return cls(form_id=form_id, message_id=message_id)
form: Mapped["HumanInputForm"] = relationship(
"HumanInputForm",
foreign_keys=[form_id],
uselist=False,
lazy="raise",
primaryjoin="foreign(HumanInputContent.form_id) == HumanInputForm.id",
)

237
api/models/human_input.py Normal file
View File

@ -0,0 +1,237 @@
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Literal, Self, final
import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
from core.workflow.nodes.human_input.enums import (
DeliveryMethodType,
HumanInputFormKind,
HumanInputFormStatus,
)
from libs.helper import generate_string
from .base import Base, DefaultFieldsMixin
from .types import EnumText, StringUUID
_token_length = 22
# A 32-character string can store a base64-encoded value with 192 bits of entropy
# or a base62-encoded value with over 180 bits of entropy, providing sufficient
# uniqueness for most use cases.
_token_field_length = 32
_email_field_length = 330
def _generate_token() -> str:
return generate_string(_token_length)
class HumanInputForm(DefaultFieldsMixin, Base):
__tablename__ = "human_input_forms"
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
form_kind: Mapped[HumanInputFormKind] = mapped_column(
EnumText(HumanInputFormKind),
nullable=False,
default=HumanInputFormKind.RUNTIME,
)
# The human input node the current form corresponds to.
node_id: Mapped[str] = mapped_column(sa.String(60), nullable=False)
form_definition: Mapped[str] = mapped_column(sa.Text, nullable=False)
rendered_content: Mapped[str] = mapped_column(sa.Text, nullable=False)
status: Mapped[HumanInputFormStatus] = mapped_column(
EnumText(HumanInputFormStatus),
nullable=False,
default=HumanInputFormStatus.WAITING,
)
expiration_time: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
)
# Submission-related fields (nullable until a submission happens).
selected_action_id: Mapped[str | None] = mapped_column(sa.String(200), nullable=True)
submitted_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
submitted_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
submission_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
submission_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
completed_by_recipient_id: Mapped[str | None] = mapped_column(
StringUUID,
nullable=True,
)
deliveries: Mapped[list["HumanInputDelivery"]] = relationship(
"HumanInputDelivery",
primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)",
uselist=True,
back_populates="form",
lazy="raise",
)
completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship(
"HumanInputFormRecipient",
primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)",
lazy="raise",
viewonly=True,
)
class HumanInputDelivery(DefaultFieldsMixin, Base):
__tablename__ = "human_input_form_deliveries"
form_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
delivery_method_type: Mapped[DeliveryMethodType] = mapped_column(
EnumText(DeliveryMethodType),
nullable=False,
)
delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
channel_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
form: Mapped[HumanInputForm] = relationship(
"HumanInputForm",
uselist=False,
foreign_keys=[form_id],
primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id",
back_populates="deliveries",
lazy="raise",
)
recipients: Mapped[list["HumanInputFormRecipient"]] = relationship(
"HumanInputFormRecipient",
primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)",
uselist=True,
back_populates="delivery",
# Require explicit preloading
lazy="raise",
)
class RecipientType(StrEnum):
# EMAIL_MEMBER member means that the
EMAIL_MEMBER = "email_member"
EMAIL_EXTERNAL = "email_external"
# STANDALONE_WEB_APP is used by the standalone web app.
#
# It's not used while running workflows / chatflows containing HumanInput
# node inside console.
STANDALONE_WEB_APP = "standalone_web_app"
# CONSOLE is used while running workflows / chatflows containing HumanInput
# node inside console. (E.G. running installed apps or debugging workflows / chatflows)
CONSOLE = "console"
# BACKSTAGE is used for backstage input inside console.
BACKSTAGE = "backstage"
@final
class EmailMemberRecipientPayload(BaseModel):
TYPE: Literal[RecipientType.EMAIL_MEMBER] = RecipientType.EMAIL_MEMBER
user_id: str
# The `email` field here is only used for mail sending.
email: str
@final
class EmailExternalRecipientPayload(BaseModel):
TYPE: Literal[RecipientType.EMAIL_EXTERNAL] = RecipientType.EMAIL_EXTERNAL
email: str
@final
class StandaloneWebAppRecipientPayload(BaseModel):
TYPE: Literal[RecipientType.STANDALONE_WEB_APP] = RecipientType.STANDALONE_WEB_APP
@final
class ConsoleRecipientPayload(BaseModel):
TYPE: Literal[RecipientType.CONSOLE] = RecipientType.CONSOLE
account_id: str | None = None
@final
class BackstageRecipientPayload(BaseModel):
TYPE: Literal[RecipientType.BACKSTAGE] = RecipientType.BACKSTAGE
account_id: str | None = None
@final
class ConsoleDeliveryPayload(BaseModel):
type: Literal["console"] = "console"
internal: bool = True
RecipientPayload = Annotated[
EmailMemberRecipientPayload
| EmailExternalRecipientPayload
| StandaloneWebAppRecipientPayload
| ConsoleRecipientPayload
| BackstageRecipientPayload,
Field(discriminator="TYPE"),
]
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
__tablename__ = "human_input_form_recipients"
form_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
delivery_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
recipient_type: Mapped["RecipientType"] = mapped_column(EnumText(RecipientType), nullable=False)
recipient_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
# Token primarily used for authenticated resume links (email, etc.).
access_token: Mapped[str | None] = mapped_column(
sa.VARCHAR(_token_field_length),
nullable=False,
default=_generate_token,
unique=True,
)
delivery: Mapped[HumanInputDelivery] = relationship(
"HumanInputDelivery",
uselist=False,
foreign_keys=[delivery_id],
back_populates="recipients",
primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id",
# Require explicit preloading
lazy="raise",
)
form: Mapped[HumanInputForm] = relationship(
"HumanInputForm",
uselist=False,
foreign_keys=[form_id],
primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id",
# Require explicit preloading
lazy="raise",
)
@classmethod
def new(
cls,
form_id: str,
delivery_id: str,
payload: RecipientPayload,
) -> Self:
recipient_model = cls(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=payload.TYPE,
recipient_payload=payload.model_dump_json(),
access_token=_generate_token(),
)
return recipient_model

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json
import re
import uuid
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
@ -943,6 +943,7 @@ class Conversation(Base):
WorkflowExecutionStatus.FAILED: 0,
WorkflowExecutionStatus.STOPPED: 0,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0,
WorkflowExecutionStatus.PAUSED: 0,
}
for message in messages:
@ -963,6 +964,7 @@ class Conversation(Base):
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
"failed": status_counts[WorkflowExecutionStatus.FAILED],
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
"paused": status_counts[WorkflowExecutionStatus.PAUSED],
}
@property
@ -1345,6 +1347,14 @@ class Message(Base):
db.session.commit()
return result
# TODO(QuantumGhost): dirty hacks, fix this later.
def set_extra_contents(self, contents: Sequence[dict[str, Any]]) -> None:
self._extra_contents = list(contents)
@property
def extra_contents(self) -> list[dict[str, Any]]:
return getattr(self, "_extra_contents", [])
@property
def workflow_run(self):
if self.workflow_run_id:

View File

@ -20,6 +20,7 @@ from sqlalchemy import (
select,
)
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from typing_extensions import deprecated
from core.file.constants import maybe_file_object
from core.file.models import File
@ -30,7 +31,7 @@ from core.workflow.constants import (
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from core.workflow.enums import NodeType
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
@ -405,6 +406,11 @@ class Workflow(Base): # bug
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
@property
@deprecated(
"This property is not accurate for determining if a workflow is published as a tool."
"It only checks if there's a WorkflowToolProvider for the app, "
"not if this specific workflow version is the one being used by the tool."
)
def tool_published(self) -> bool:
"""
DEPRECATED: This property is not accurate for determining if a workflow is published as a tool.
@ -607,13 +613,16 @@ class WorkflowRun(Base):
version: Mapped[str] = mapped_column(String(255))
graph: Mapped[str | None] = mapped_column(LongText)
inputs: Mapped[str | None] = mapped_column(LongText)
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
status: Mapped[WorkflowExecutionStatus] = mapped_column(
EnumText(WorkflowExecutionStatus, length=255),
nullable=False,
)
outputs: Mapped[str | None] = mapped_column(LongText, default="{}")
error: Mapped[str | None] = mapped_column(LongText)
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) # account, end_user
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
@ -629,11 +638,13 @@ class WorkflowRun(Base):
)
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def created_by_account(self):
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def created_by_end_user(self):
from .model import EndUser
@ -653,6 +664,7 @@ class WorkflowRun(Base):
return json.loads(self.outputs) if self.outputs else {}
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def message(self):
from .model import Message
@ -661,6 +673,7 @@ class WorkflowRun(Base):
)
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def workflow(self):
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
@ -1861,7 +1874,12 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
def to_entity(self) -> PauseReason:
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
return HumanInputRequired(form_id=self.form_id, node_id=self.node_id)
return HumanInputRequired(
form_id=self.form_id,
form_content="",
node_id=self.node_id,
node_title="",
)
elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
return SchedulingPause(message=self.message)
else:

View File

@ -10,6 +10,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m
"""
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Protocol
@ -19,6 +20,27 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
@dataclass(frozen=True)
class WorkflowNodeExecutionSnapshot:
"""
Minimal snapshot of workflow node execution for stream recovery.
Only includes fields required by snapshot events.
"""
execution_id: str # Unique execution identifier (node_execution_id or row id).
node_id: str # Workflow graph node id.
node_type: str # Workflow graph node type (e.g. "human-input").
title: str # Human-friendly node title.
index: int # Execution order index within the workflow run.
status: str # Execution status (running/succeeded/failed/paused).
elapsed_time: float # Execution elapsed time in seconds.
created_at: datetime # Execution created timestamp.
finished_at: datetime | None # Execution finished timestamp.
iteration_id: str | None = None # Iteration id from execution metadata, if any.
loop_id: str | None = None # Loop id from execution metadata, if any.
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
"""
Protocol for service-layer operations on WorkflowNodeExecutionModel.
@ -79,6 +101,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
Args:
tenant_id: The tenant identifier
app_id: The application identifier
workflow_id: The workflow identifier
triggered_from: The workflow trigger source
workflow_run_id: The workflow run identifier
Returns:
@ -86,6 +110,27 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
"""
...
def get_execution_snapshots_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
triggered_from: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionSnapshot]:
"""
Get minimal snapshots for node executions in a workflow run.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
workflow_run_id: The workflow run identifier
Returns:
A sequence of WorkflowNodeExecutionSnapshot ordered by creation time
"""
...
def get_execution_by_id(
self,
execution_id: str,

View File

@ -432,6 +432,13 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
# while creating pause.
...
def get_workflow_pause(self, workflow_run_id: str) -> WorkflowPauseEntity | None:
"""Retrieve the current pause for a workflow execution.
If there is no current pause, this method would return `None`.
"""
...
def resume_workflow_pause(
self,
workflow_run_id: str,
@ -627,3 +634,19 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
[{"date": "2024-01-01", "interactions": 2.5}, ...]
"""
...
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
"""
Get a specific workflow run by its id and the associated tenant id.
This function does not apply application isolation. It should only be used when
the application identifier is not available.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
run_id: Workflow run identifier
Returns:
WorkflowRun object if found, None otherwise
"""
...

View File

@ -63,6 +63,12 @@ class WorkflowPauseEntity(ABC):
"""
pass
@property
@abstractmethod
def paused_at(self) -> datetime:
"""`paused_at` returns the creation time of the pause."""
pass
@abstractmethod
def get_pause_reasons(self) -> Sequence[PauseReason]:
"""
@ -70,7 +76,5 @@ class WorkflowPauseEntity(ABC):
Returns a sequence of `PauseReason` objects describing the specific nodes and
reasons for which the workflow execution was paused.
This information is related to, but distinct from, the `PauseReason` type
defined in `api/core/workflow/entities/pause_reason.py`.
"""
...

View File

@ -0,0 +1,13 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
class ExecutionExtraContentRepository(Protocol):
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ...
__all__ = ["ExecutionExtraContentRepository"]

View File

@ -5,6 +5,7 @@ This module provides a concrete implementation of the service repository protoco
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
"""
import json
from collections.abc import Sequence
from datetime import datetime
from typing import cast
@ -13,11 +14,12 @@ from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from models.workflow import (
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload,
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
from repositories.api_workflow_node_execution_repository import (
DifyAPIWorkflowNodeExecutionRepository,
WorkflowNodeExecutionSnapshot,
)
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
@ -79,6 +81,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_id == workflow_id,
WorkflowNodeExecutionModel.node_id == node_id,
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
)
.order_by(desc(WorkflowNodeExecutionModel.created_at))
.limit(1)
@ -117,6 +120,80 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
with self._session_maker() as session:
return session.execute(stmt).scalars().all()
def get_execution_snapshots_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
triggered_from: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionSnapshot]:
stmt = (
select(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.node_execution_id,
WorkflowNodeExecutionModel.node_id,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.index,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.finished_at,
WorkflowNodeExecutionModel.execution_metadata,
)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_id == workflow_id,
WorkflowNodeExecutionModel.triggered_from == triggered_from,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
.order_by(
asc(WorkflowNodeExecutionModel.created_at),
asc(WorkflowNodeExecutionModel.index),
)
)
with self._session_maker() as session:
rows = session.execute(stmt).all()
return [self._row_to_snapshot(row) for row in rows]
@staticmethod
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
metadata: dict[str, object] = {}
execution_metadata = getattr(row, "execution_metadata", None)
if execution_metadata:
try:
metadata = json.loads(execution_metadata)
except json.JSONDecodeError:
metadata = {}
iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value)
loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value)
execution_id = getattr(row, "node_execution_id", None) or row.id
elapsed_time = getattr(row, "elapsed_time", None)
created_at = row.created_at
finished_at = getattr(row, "finished_at", None)
if elapsed_time is None:
if finished_at is not None and created_at is not None:
elapsed_time = (finished_at - created_at).total_seconds()
else:
elapsed_time = 0.0
return WorkflowNodeExecutionSnapshot(
execution_id=str(execution_id),
node_id=row.node_id,
node_type=row.node_type,
title=row.title,
index=row.index,
status=row.status,
elapsed_time=float(elapsed_time),
created_at=created_at,
finished_at=finished_at,
iteration_id=str(iteration_id) if iteration_id else None,
loop_id=str(loop_id) if loop_id else None,
)
def get_execution_by_id(
self,
execution_id: str,

View File

@ -19,6 +19,7 @@ Implementation Notes:
- Maintains data consistency with proper transaction handling
"""
import json
import logging
import uuid
from collections.abc import Callable, Sequence
@ -27,12 +28,14 @@ from decimal import Decimal
from typing import Any, cast
import sqlalchemy as sa
from pydantic import ValidationError
from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.nodes.human_input.entities import FormDefinition
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import convert_datetime_to_date
@ -40,6 +43,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
@ -57,6 +61,67 @@ class _WorkflowRunError(Exception):
pass
def _select_recipient_token(
recipients: Sequence[HumanInputFormRecipient],
recipient_type: RecipientType,
) -> str | None:
for recipient in recipients:
if recipient.recipient_type == recipient_type and recipient.access_token:
return recipient.access_token
return None
def _build_human_input_required_reason(
reason_model: WorkflowPauseReason,
form_model: HumanInputForm | None,
recipients: Sequence[HumanInputFormRecipient],
) -> HumanInputRequired:
form_content = ""
inputs = []
actions = []
display_in_ui = False
resolved_default_values: dict[str, Any] = {}
node_title = "Human Input"
form_id = reason_model.form_id
node_id = reason_model.node_id
if form_model is not None:
form_id = form_model.id
node_id = form_model.node_id or node_id
try:
definition_payload = json.loads(form_model.form_definition)
if "expiration_time" not in definition_payload:
definition_payload["expiration_time"] = form_model.expiration_time
definition = FormDefinition.model_validate(definition_payload)
except ValidationError:
definition = None
if definition is not None:
form_content = definition.form_content
inputs = list(definition.inputs)
actions = list(definition.user_actions)
display_in_ui = bool(definition.display_in_ui)
resolved_default_values = dict(definition.default_values)
node_title = definition.node_title or node_title
form_token = (
_select_recipient_token(recipients, RecipientType.BACKSTAGE)
or _select_recipient_token(recipients, RecipientType.CONSOLE)
or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP)
)
return HumanInputRequired(
form_id=form_id,
form_content=form_content,
inputs=inputs,
actions=actions,
display_in_ui=display_in_ui,
node_id=node_id,
node_title=node_title,
form_token=form_token,
resolved_default_values=resolved_default_values,
)
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
"""
SQLAlchemy implementation of APIWorkflowRunRepository.
@ -676,9 +741,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
# Check if workflow is in RUNNING status
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
# TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status`
# happens before the execution of GraphLayer
if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}:
raise _WorkflowRunError(
f"Only WorkflowRun with RUNNING status can be paused, "
f"Only WorkflowRun with RUNNING or PAUSED status can be paused, "
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
)
#
@ -729,13 +796,48 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
return _PrivateWorkflowPauseEntity(
pause_model=pause_model,
reason_models=pause_reason_models,
pause_reasons=pause_reasons,
)
def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
pause_reason_models = session.scalars(reason_stmt).all()
return pause_reason_models
def _hydrate_pause_reasons(
self,
session: Session,
pause_reason_models: Sequence[WorkflowPauseReason],
) -> list[PauseReason]:
form_ids = [
reason.form_id
for reason in pause_reason_models
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id
]
form_models: dict[str, HumanInputForm] = {}
recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {}
if form_ids:
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
for form in session.scalars(form_stmt).all():
form_models[form.id] = form
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(recipient_stmt).all():
recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient)
pause_reasons: list[PauseReason] = []
for reason in pause_reason_models:
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_model = form_models.get(reason.form_id)
recipients = recipient_models_by_form.get(reason.form_id, [])
pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients))
else:
pause_reasons.append(reason.to_entity())
return pause_reasons
def get_workflow_pause(
self,
workflow_run_id: str,
@ -767,14 +869,12 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
if pause_model is None:
return None
pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
human_input_form: list[Any] = []
# TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
pause_reasons = self._hydrate_pause_reasons(session, pause_reason_models)
return _PrivateWorkflowPauseEntity(
pause_model=pause_model,
reason_models=pause_reason_models,
human_input_form=human_input_form,
pause_reasons=pause_reasons,
)
def resume_workflow_pause(
@ -828,10 +928,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
hydrated_pause_reasons = self._hydrate_pause_reasons(session, pause_reasons)
# Mark as resumed
pause_model.resumed_at = naive_utc_now()
workflow_run.pause_id = None # type: ignore
workflow_run.status = WorkflowExecutionStatus.RUNNING
session.add(pause_model)
@ -839,7 +939,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)
return _PrivateWorkflowPauseEntity(
pause_model=pause_model,
reason_models=pause_reasons,
pause_reasons=hydrated_pause_reasons,
)
def delete_workflow_pause(
self,
@ -1165,6 +1269,15 @@ GROUP BY
return cast(list[AverageInteractionStats], response_data)
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
"""Get a specific workflow run by its id and the associated tenant id."""
with self._session_maker() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.id == run_id,
)
return session.scalar(stmt)
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
"""
@ -1179,10 +1292,12 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
*,
pause_model: WorkflowPause,
reason_models: Sequence[WorkflowPauseReason],
pause_reasons: Sequence[PauseReason] | None = None,
human_input_form: Sequence = (),
) -> None:
self._pause_model = pause_model
self._reason_models = reason_models
self._pause_reasons = pause_reasons
self._cached_state: bytes | None = None
self._human_input_form = human_input_form
@ -1219,4 +1334,10 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
return self._pause_model.resumed_at
def get_pause_reasons(self) -> Sequence[PauseReason]:
if self._pause_reasons is not None:
return list(self._pause_reasons)
return [reason.to_entity() for reason in self._reason_models]
@property
def paused_at(self) -> datetime:
return self._pause_model.created_at

View File

@ -0,0 +1,200 @@
from __future__ import annotations
import json
import logging
import re
from collections import defaultdict
from collections.abc import Sequence
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.entities.execution_extra_content import (
ExecutionExtraContentDomainModel,
HumanInputFormDefinition,
HumanInputFormSubmissionData,
)
from core.entities.execution_extra_content import (
HumanInputContent as HumanInputContentDomainModel,
)
from core.workflow.nodes.human_input.entities import FormDefinition
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from models.execution_extra_content import (
ExecutionExtraContent as ExecutionExtraContentModel,
)
from models.execution_extra_content import (
HumanInputContent as HumanInputContentModel,
)
from models.human_input import HumanInputFormRecipient, RecipientType
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
logger = logging.getLogger(__name__)
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
def _extract_output_field_names(form_content: str) -> list[str]:
if not form_content:
return []
return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)]
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
def __init__(self, session_maker: sessionmaker[Session]):
self._session_maker = session_maker
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
if not message_ids:
return []
grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = {
message_id: [] for message_id in message_ids
}
stmt = (
select(ExecutionExtraContentModel)
.where(ExecutionExtraContentModel.message_id.in_(message_ids))
.options(selectinload(HumanInputContentModel.form))
.order_by(ExecutionExtraContentModel.created_at.asc())
)
with self._session_maker() as session:
results = session.scalars(stmt).all()
form_ids = {
content.form_id
for content in results
if isinstance(content, HumanInputContentModel) and content.form_id is not None
}
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list)
if form_ids:
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
recipients = session.scalars(recipient_stmt).all()
for recipient in recipients:
recipients_by_form_id[recipient.form_id].append(recipient)
else:
recipients_by_form_id = {}
for content in results:
message_id = content.message_id
if not message_id or message_id not in grouped_contents:
continue
domain_model = self._map_model_to_domain(content, recipients_by_form_id)
if domain_model is None:
continue
grouped_contents[message_id].append(domain_model)
return [grouped_contents[message_id] for message_id in message_ids]
def _map_model_to_domain(
self,
model: ExecutionExtraContentModel,
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
) -> ExecutionExtraContentDomainModel | None:
if isinstance(model, HumanInputContentModel):
return self._map_human_input_content(model, recipients_by_form_id)
logger.debug("Unsupported execution extra content type encountered: %s", model.type)
return None
def _map_human_input_content(
self,
model: HumanInputContentModel,
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
) -> HumanInputContentDomainModel | None:
form = model.form
if form is None:
logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
return None
try:
definition_payload = json.loads(form.form_definition)
if "expiration_time" not in definition_payload:
definition_payload["expiration_time"] = form.expiration_time
form_definition = FormDefinition.model_validate(definition_payload)
except ValueError:
logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
return None
node_title = form_definition.node_title or form.node_id
display_in_ui = bool(form_definition.display_in_ui)
submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
if not submitted:
form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, []))
return HumanInputContentDomainModel(
workflow_run_id=model.workflow_run_id,
submitted=False,
form_definition=HumanInputFormDefinition(
form_id=form.id,
node_id=form.node_id,
node_title=node_title,
form_content=form.rendered_content,
inputs=form_definition.inputs,
actions=form_definition.user_actions,
display_in_ui=display_in_ui,
form_token=form_token,
resolved_default_values=form_definition.default_values,
expiration_time=int(form.expiration_time.timestamp()),
),
)
selected_action_id = form.selected_action_id
if not selected_action_id:
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
return None
action_text = next(
(action.title for action in form_definition.user_actions if action.id == selected_action_id),
selected_action_id,
)
submitted_data: dict[str, Any] = {}
if form.submitted_data:
try:
submitted_data = json.loads(form.submitted_data)
except ValueError:
logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id)
return None
rendered_content = HumanInputNode.render_form_content_with_outputs(
form.rendered_content,
submitted_data,
_extract_output_field_names(form_definition.form_content),
)
return HumanInputContentDomainModel(
workflow_run_id=model.workflow_run_id,
submitted=True,
form_submission_data=HumanInputFormSubmissionData(
node_id=form.node_id,
node_title=node_title,
rendered_content=rendered_content,
action_id=selected_action_id,
action_text=action_text,
),
)
@staticmethod
def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None:
console_recipient = next(
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE),
None,
)
if console_recipient and console_recipient.access_token:
return console_recipient.access_token
web_app_recipient = next(
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP),
None,
)
if web_app_recipient and web_app_recipient.access_token:
return web_app_recipient.access_token
return None
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]

View File

@ -92,6 +92,16 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
return list(self.session.scalars(query).all())
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
"""Get the trigger log associated with a workflow run."""
query = (
select(WorkflowTriggerLog)
.where(WorkflowTriggerLog.workflow_run_id == workflow_run_id)
.order_by(WorkflowTriggerLog.created_at.desc())
.limit(1)
)
return self.session.scalar(query)
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
"""
Delete trigger logs associated with the given workflow run ids.

View File

@ -110,6 +110,18 @@ class WorkflowTriggerLogRepository(Protocol):
"""
...
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
"""
Retrieve a trigger log associated with a specific workflow run.
Args:
workflow_run_id: Identifier of the workflow run
Returns:
The matching WorkflowTriggerLog if present, None otherwise
"""
...
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
"""
Delete trigger logs for workflow run IDs.

View File

@ -44,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.5.0"
CURRENT_DSL_VERSION = "0.6.0"
class ImportMode(StrEnum):

Some files were not shown because too many files have changed in this diff Show More