mirror of https://github.com/langgenius/dify.git
feat(api): Human Input Node (backend part) (#31646)
The backend part of the human in the loop (HITL) feature and relevant architecture / workflow engine changes. Signed-off-by: yihong0618 <zouzou0208@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: 盐粒 Yanli <yanli@dify.ai> Co-authored-by: CrabSAMA <40541269+CrabSAMA@users.noreply.github.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
parent
fedd097f63
commit
03e3acfc71
|
|
@ -8,6 +8,7 @@ on:
|
|||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
- "feat/hitl-backend"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
|
|
|||
|
|
@ -717,3 +717,28 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
|||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
|
||||
|
||||
|
||||
# Redis URL used for PubSub between API and
|
||||
# celery worker
|
||||
# defaults to url constructed from `REDIS_*`
|
||||
# configurations
|
||||
PUBSUB_REDIS_URL=
|
||||
# Pub/sub channel type for streaming events.
|
||||
# valid options are:
|
||||
#
|
||||
# - pubsub: for normal Pub/Sub
|
||||
# - sharded: for sharded Pub/Sub
|
||||
#
|
||||
# It's highly recommended to use sharded Pub/Sub AND redis cluster
|
||||
# for large deployments.
|
||||
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while running
|
||||
# PubSub.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
PUBSUB_REDIS_USE_CLUSTERS=false
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
# Human input timeout check interval in minutes
|
||||
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ ignore_imports =
|
|||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
# TODO(QuantumGhost): fix the import violation later
|
||||
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
|
||||
|
||||
[importlinter:contract:workflow-infrastructure-dependencies]
|
||||
name = Workflow Infrastructure Dependencies
|
||||
|
|
@ -58,6 +60,8 @@ ignore_imports =
|
|||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
|
||||
[importlinter:contract:workflow-external-imports]
|
||||
name = Workflow External Imports
|
||||
|
|
@ -145,6 +149,7 @@ ignore_imports =
|
|||
core.workflow.nodes.agent.agent_node -> core.agent.entities
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
|
||||
|
|
@ -248,6 +253,7 @@ ignore_imports =
|
|||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||
core.workflow.nodes.http_request.node -> core.variables.segments
|
||||
core.workflow.nodes.human_input.entities -> core.variables.consts
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
|
||||
|
|
@ -294,6 +300,8 @@ ignore_imports =
|
|||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from datetime import timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
|
|
@ -48,6 +49,16 @@ class SecurityConfig(BaseSettings):
|
|||
default=5,
|
||||
)
|
||||
|
||||
WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field(
|
||||
description="Maximum number of web form submissions allowed per IP within the rate limit window",
|
||||
default=30,
|
||||
)
|
||||
|
||||
WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field(
|
||||
description="Time window in seconds for web form submission rate limiting",
|
||||
default=60,
|
||||
)
|
||||
|
||||
LOGIN_DISABLED: bool = Field(
|
||||
description="Whether to disable login checks",
|
||||
default=False,
|
||||
|
|
@ -82,6 +93,12 @@ class AppExecutionConfig(BaseSettings):
|
|||
default=0,
|
||||
)
|
||||
|
||||
HITL_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field(
|
||||
description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.",
|
||||
default=int(timedelta(days=3).total_seconds()),
|
||||
ge=1,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
"""
|
||||
|
|
@ -1134,6 +1151,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
|||
description="Enable queue monitor task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field(
|
||||
description="Enable human input timeout check task",
|
||||
default=True,
|
||||
)
|
||||
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field(
|
||||
description="Human input timeout check interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos
|
|||
from pydantic_settings import BaseSettings
|
||||
|
||||
from .cache.redis_config import RedisConfig
|
||||
from .cache.redis_pubsub_config import RedisPubSubConfig
|
||||
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
||||
from .storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
|
|
@ -317,6 +318,7 @@ class MiddlewareConfig(
|
|||
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
||||
KeywordStoreConfig,
|
||||
RedisConfig,
|
||||
RedisPubSubConfig,
|
||||
# configs of storage and storage providers
|
||||
StorageConfig,
|
||||
AliyunOSSStorageConfig,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,96 @@
|
|||
from typing import Literal, Protocol
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class RedisConfigDefaults(Protocol):
|
||||
REDIS_HOST: str
|
||||
REDIS_PORT: int
|
||||
REDIS_USERNAME: str | None
|
||||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
"""
|
||||
Configuration settings for Redis pub/sub streaming.
|
||||
"""
|
||||
|
||||
PUBSUB_REDIS_URL: str | None = Field(
|
||||
alias="PUBSUB_REDIS_URL",
|
||||
description=(
|
||||
"Redis connection URL for pub/sub streaming events between API "
|
||||
"and celery worker, defaults to url constructed from "
|
||||
"`REDIS_*` configurations"
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
||||
description=(
|
||||
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
|
||||
"recommended to enable this for large deployments."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
|
||||
description=(
|
||||
"Pub/sub channel type for streaming events. "
|
||||
"Valid options are:\n"
|
||||
"\n"
|
||||
" - pubsub: for normal Pub/Sub\n"
|
||||
" - sharded: for sharded Pub/Sub\n"
|
||||
"\n"
|
||||
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
|
||||
"for large deployments."
|
||||
),
|
||||
default="pubsub",
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = self._redis_defaults()
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
scheme = "rediss" if defaults.REDIS_USE_SSL else "redis"
|
||||
username = defaults.REDIS_USERNAME or None
|
||||
password = defaults.REDIS_PASSWORD or None
|
||||
|
||||
userinfo = ""
|
||||
if username:
|
||||
userinfo = quote_plus(username)
|
||||
if password:
|
||||
password_part = quote_plus(password)
|
||||
userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}"
|
||||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
def normalized_pubsub_redis_url(self) -> str:
|
||||
pubsub_redis_url = self.PUBSUB_REDIS_URL
|
||||
if pubsub_redis_url:
|
||||
cleaned = pubsub_redis_url.strip()
|
||||
pubsub_redis_url = cleaned or None
|
||||
|
||||
if pubsub_redis_url:
|
||||
return pubsub_redis_url
|
||||
|
||||
return self._build_default_pubsub_url()
|
||||
|
|
@ -37,6 +37,7 @@ from . import (
|
|||
apikey,
|
||||
extension,
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
ping,
|
||||
setup,
|
||||
|
|
@ -171,6 +172,7 @@ __all__ = [
|
|||
"forgot_password",
|
||||
"generator",
|
||||
"hit_testing",
|
||||
"human_input_form",
|
||||
"init_validate",
|
||||
"installed_app",
|
||||
"load_balancing_config",
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ status_count_model = console_ns.model(
|
|||
"success": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"partial_success": fields.Integer,
|
||||
"paused": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from libs.login import current_account_with_tenant, login_required
|
|||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
|
@ -198,6 +198,7 @@ message_detail_model = console_ns.model(
|
|||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"extra_contents": fields.List(fields.Raw),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
|
|
@ -290,6 +291,7 @@ class ChatMessageListApi(Resource):
|
|||
has_more = False
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
attach_message_extra_contents(history_messages)
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
|
||||
|
|
@ -474,4 +476,5 @@ class MessageApi(Resource):
|
|||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
attach_message_extra_contents([message])
|
||||
return message
|
||||
|
|
|
|||
|
|
@ -507,6 +507,179 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
class HumanInputFormPreviewPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields")
|
||||
inputs: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
action: str = Field(..., description="Selected action ID")
|
||||
|
||||
|
||||
class HumanInputDeliveryTestPayload(BaseModel):
|
||||
delivery_method_id: str = Field(..., description="Delivery method ID")
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
|
||||
|
||||
reg(HumanInputFormPreviewPayload)
|
||||
reg(HumanInputFormSubmitPayload)
|
||||
reg(HumanInputDeliveryTestPayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
||||
class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_draft_human_input_form")
|
||||
@console_ns.doc(description="Get human input form preview for advanced chat workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
preview = workflow_service.get_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
return jsonable_encoder(preview)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/run")
|
||||
class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@console_ns.doc("submit_advanced_chat_draft_human_input_form")
|
||||
@console_ns.doc(description="Submit human input form preview for advanced chat workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service = WorkflowService()
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
form_inputs=args.form_inputs,
|
||||
inputs=args.inputs,
|
||||
action=args.action,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
||||
class WorkflowDraftHumanInputFormPreviewApi(Resource):
|
||||
@console_ns.doc("get_workflow_draft_human_input_form")
|
||||
@console_ns.doc(description="Get human input form preview for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
preview = workflow_service.get_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
return jsonable_encoder(preview)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/run")
|
||||
class WorkflowDraftHumanInputFormRunApi(Resource):
|
||||
@console_ns.doc("submit_workflow_draft_human_input_form")
|
||||
@console_ns.doc(description="Submit human input form preview for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
form_inputs=args.form_inputs,
|
||||
inputs=args.inputs,
|
||||
action=args.action,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/delivery-test")
|
||||
class WorkflowDraftHumanInputDeliveryTestApi(Resource):
|
||||
@console_ns.doc("test_workflow_draft_human_input_delivery")
|
||||
@console_ns.doc(description="Test human input delivery for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Test human input delivery
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
delivery_method_id=args.delivery_method_id,
|
||||
inputs=args.inputs,
|
||||
)
|
||||
return jsonable_encoder({})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
|
||||
class DraftWorkflowRunApi(Resource):
|
||||
@console_ns.doc("run_draft_workflow")
|
||||
|
|
|
|||
|
|
@ -5,10 +5,15 @@ from flask import request
|
|||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
|
|
@ -27,9 +32,21 @@ from libs.custom_inputs import time_duration
|
|||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
def _build_backstage_input_url(form_token: str | None) -> str | None:
|
||||
if not form_token:
|
||||
return None
|
||||
base_url = dify_config.APP_WEB_URL
|
||||
if not base_url:
|
||||
return None
|
||||
return f"{base_url.rstrip('/')}/form/{form_token}"
|
||||
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
|
||||
|
|
@ -440,3 +457,63 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
|||
)
|
||||
|
||||
return {"data": node_executions}
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
|
||||
class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
"""Console API for getting workflow pause details."""
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow pause details.
|
||||
|
||||
GET /console/api/workflow/<workflow_run_id>/pause-details
|
||||
|
||||
Returns information about why and where the workflow is paused.
|
||||
"""
|
||||
|
||||
# Query WorkflowRun to determine if workflow is suspended
|
||||
session_maker = sessionmaker(bind=db.engine)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
|
||||
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
|
||||
if not workflow_run:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
# Check if workflow is suspended
|
||||
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
if not is_paused:
|
||||
return {
|
||||
"paused_at": None,
|
||||
"paused_nodes": [],
|
||||
}, 200
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
|
||||
|
||||
# Build response
|
||||
paused_at = pause_entity.paused_at if pause_entity else None
|
||||
paused_nodes = []
|
||||
response = {
|
||||
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
|
||||
"paused_nodes": paused_nodes,
|
||||
}
|
||||
|
||||
for reason in pause_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
paused_nodes.append(
|
||||
{
|
||||
"node_id": reason.node_id,
|
||||
"node_title": reason.node_title,
|
||||
"pause_type": {
|
||||
"type": "human_input",
|
||||
"form_id": reason.form_id,
|
||||
"backstage_input_url": _build_backstage_input_url(reason.form_token),
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise AssertionError("unimplemented.")
|
||||
|
||||
return response, 200
|
||||
|
|
|
|||
|
|
@ -0,0 +1,217 @@
|
|||
"""
|
||||
Console/Studio Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.human_input_service import Form, HumanInputService
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@console_ns.route("/form/human_input/<string:form_token>")
|
||||
class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_access(form: Form):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_token: str):
|
||||
"""
|
||||
Get human input form definition by form token.
|
||||
|
||||
GET /console/api/form/human_input/<form_token>
|
||||
"""
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_definition_by_token_for_console(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
POST /console/api/form/human_input/<form_token>
|
||||
|
||||
Request body:
|
||||
{
|
||||
"inputs": {
|
||||
"content": "User input content"
|
||||
},
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
# So we need to assert it manually.
|
||||
assert recipient_type is not None, "recipient_type cannot be None here."
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
submission_user_id=current_user.id,
|
||||
)
|
||||
|
||||
return jsonify({})
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/events")
|
||||
class ConsoleWorkflowEventsApi(Resource):
|
||||
"""Console API for getting workflow execution events after resume."""
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
GET /console/api/workflow/<workflow_run_id>/events
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=tenant_id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
if workflow_run is None:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
|
||||
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by != user.id:
|
||||
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
|
||||
|
||||
with Session(expire_on_commit=False, bind=db.engine) as session:
|
||||
app = _retrieve_app_for_workflow_run(session, workflow_run)
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
creator_user=user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode(app.mode),
|
||||
workflow_run=workflow_run,
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
|
||||
query = select(App).where(
|
||||
App.id == workflow_run.app_id,
|
||||
App.tenant_id == workflow_run.tenant_id,
|
||||
)
|
||||
app = session.scalars(query).first()
|
||||
if app is None:
|
||||
raise AssertionError(
|
||||
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
|
||||
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
|
||||
)
|
||||
|
||||
return app
|
||||
|
|
@ -33,8 +33,9 @@ from core.workflow.graph_engine.manager import GraphEngineManager
|
|||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
|
|
@ -63,17 +64,32 @@ class WorkflowLogQuery(BaseModel):
|
|||
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return obj.status.value
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
if obj.status == WorkflowExecutionStatus.PAUSED:
|
||||
return {}
|
||||
|
||||
outputs = obj.outputs_dict
|
||||
return outputs or {}
|
||||
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_id": fields.String,
|
||||
"status": fields.String,
|
||||
"status": WorkflowRunStatusField,
|
||||
"inputs": fields.Raw,
|
||||
"outputs": fields.Raw,
|
||||
"outputs": WorkflowRunOutputsField,
|
||||
"error": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"total_tokens": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
"finished_at": TimestampField,
|
||||
"finished_at": OptionalTimestampField,
|
||||
"elapsed_time": fields.Float,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from . import (
|
|||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
human_input_form,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
|
|
@ -30,6 +31,7 @@ from . import (
|
|||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
workflow_events,
|
||||
)
|
||||
|
||||
api.add_namespace(web_ns)
|
||||
|
|
@ -44,6 +46,7 @@ __all__ = [
|
|||
"feature",
|
||||
"files",
|
||||
"forgot_password",
|
||||
"human_input_form",
|
||||
"login",
|
||||
"message",
|
||||
"passport",
|
||||
|
|
@ -52,4 +55,5 @@ __all__ = [
|
|||
"site",
|
||||
"web_ns",
|
||||
"workflow",
|
||||
"workflow_events",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -117,6 +117,12 @@ class InvokeRateLimitError(BaseHTTPException):
|
|||
code = 429
|
||||
|
||||
|
||||
class WebFormRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "web_form_rate_limit_exceeded"
|
||||
description = "Too many form requests. Please try again later."
|
||||
code = 429
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "not_found"
|
||||
code = 404
|
||||
|
|
|
|||
|
|
@ -0,0 +1,164 @@
|
|||
"""
|
||||
Web App Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
_FORM_ACCESS_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_access_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": _to_timestamp(form.expiration_time),
|
||||
}
|
||||
if site_payload is not None:
|
||||
payload["site"] = site_payload
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
# TODO(QuantumGhost): disable authorization for web app
|
||||
# form api temporarily
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
# class HumanInputFormApi(WebApiResource):
|
||||
class HumanInputFormApi(Resource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
|
||||
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def get(self, form_token: str):
|
||||
"""
|
||||
Get human input form definition by token.
|
||||
|
||||
GET /api/form/human_input/<form_token>
|
||||
"""
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
# TODO(QuantumGhost): forbid submision for form tokens
|
||||
# that are only for console.
|
||||
form = service.get_form_by_token(form_token)
|
||||
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
service.ensure_form_active(form)
|
||||
app_model, site = _get_app_site_from_form(form)
|
||||
|
||||
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
|
||||
|
||||
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Submit human input form by token.
|
||||
|
||||
POST /api/form/human_input/<form_token>
|
||||
|
||||
Request body:
|
||||
{
|
||||
"inputs": {
|
||||
"content": "User input content"
|
||||
},
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
if (recipient_type := form.recipient_type) is None:
|
||||
logger.warning("Recipient type is None for form, form_id=%", form.id)
|
||||
raise AssertionError("Recipient type is None")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
submission_end_user_id=None,
|
||||
# submission_end_user_id=_end_user.id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
return {}, 200
|
||||
|
||||
|
||||
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
|
||||
"""Resolve App/Site for the form's app and validate tenant status."""
|
||||
app_model = db.session.query(App).where(App.id == form.app_id).first()
|
||||
if app_model is None or app_model.tenant_id != form.tenant_id:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if site is None:
|
||||
raise Forbidden()
|
||||
|
||||
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return app_model, site
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
from flask_restx import fields, marshal_with
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -7,7 +9,7 @@ from controllers.web.wraps import WebApiResource
|
|||
from extensions.ext_database import db
|
||||
from libs.helper import AppIconUrlField
|
||||
from models.account import TenantStatus
|
||||
from models.model import Site
|
||||
from models.model import App, Site
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
|
|
@ -108,3 +110,14 @@ class AppSiteInfo:
|
|||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,112 @@
|
|||
"""
|
||||
Web App Workflow Resume APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
class WorkflowEventsApi(WebApiResource):
|
||||
"""API for getting workflow execution events after resume."""
|
||||
|
||||
def get(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
GET /api/workflow/<task_id>/events
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
workflow_run_id = task_id
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER:
|
||||
raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by != end_user.id:
|
||||
raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
creator_user=end_user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app_mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(app_mode, workflow_run.id),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Register the APIs
|
||||
api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events")
|
||||
|
|
@ -4,8 +4,8 @@ import contextvars
|
|||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -29,21 +29,25 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
|||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.base import Base
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import (
|
||||
|
|
@ -65,7 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: Literal[False],
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -74,9 +80,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: Literal[True],
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -85,9 +93,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: bool,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
|
||||
|
||||
def generate(
|
||||
|
|
@ -95,9 +105,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: bool = True,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
|
@ -161,7 +173,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
|
|
@ -179,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
|
@ -216,6 +227,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=conversation,
|
||||
stream=streaming,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Resume a paused advanced chat execution.
|
||||
"""
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
stream=application_generate_entity.stream,
|
||||
pause_state_config=pause_state_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def single_iteration_generate(
|
||||
|
|
@ -396,8 +439,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
conversation: Conversation | None = None,
|
||||
message: Message | None = None,
|
||||
stream: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
|
@ -411,12 +458,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
"""
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
is_first_conversation = conversation is None
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
if conversation is not None and message is not None:
|
||||
pass
|
||||
else:
|
||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
|
|
@ -439,6 +486,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
message_id=message.id,
|
||||
)
|
||||
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
|
|
@ -454,14 +511,25 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(message)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
# workflow_ = session.get(Workflow, workflow.id)
|
||||
# assert workflow_ is not None
|
||||
# workflow = workflow_
|
||||
# message_ = session.get(Message, message.id)
|
||||
# assert message_ is not None
|
||||
# message = message_
|
||||
# db.session.refresh(workflow)
|
||||
# db.session.refresh(message)
|
||||
# db.session.refresh(user)
|
||||
db.session.close()
|
||||
|
||||
|
|
@ -490,6 +558,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
|
|
@ -547,6 +617,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
app=app,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -614,3 +686,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Base)
|
||||
|
||||
|
||||
def _refresh_model(session, model: _T) -> _T:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
detach_model = session.get(type(model), model.id)
|
||||
assert detach_model is not None
|
||||
return detach_model
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
|
|
@ -82,6 +83,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
self._app = app
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._resume_graph_runtime_state = graph_runtime_state
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
|
|
@ -110,7 +112,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
invoke_from = InvokeFrom.DEBUGGER
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
resume_state = self._resume_graph_runtime_state
|
||||
|
||||
if resume_state is not None:
|
||||
graph_runtime_state = resume_state
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
invoke_from=invoke_from,
|
||||
user_from=user_from,
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ from core.app.entities.queue_entities import (
|
|||
QueueAgentLogEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
|
|
@ -42,6 +44,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
|
|
@ -63,6 +66,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
|||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
|
|
@ -71,7 +76,8 @@ from core.workflow.system_variable import SystemVariable
|
|||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -128,6 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._seed_task_state_from_message(message)
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||
)
|
||||
|
|
@ -135,6 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._workflow_tenant_id = workflow.tenant_id
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
self._message_id = message.id
|
||||
|
|
@ -144,8 +152,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
self._message_saved_on_pause = False
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def _seed_task_state_from_message(self, message: Message) -> None:
|
||||
if message.status == MessageStatus.PAUSED and message.answer:
|
||||
self._task_state.answer = message.answer
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
|
|
@ -308,6 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow_id,
|
||||
reason=event.reason,
|
||||
)
|
||||
|
||||
yield workflow_start_resp
|
||||
|
|
@ -525,6 +539,35 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
self,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
|
||||
yield from responses
|
||||
resolved_state: GraphRuntimeState | None = None
|
||||
try:
|
||||
resolved_state = self._ensure_graph_runtime_initialized()
|
||||
except ValueError:
|
||||
resolved_state = None
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
message = self._get_message(session=session)
|
||||
if message is not None:
|
||||
message.status = MessageStatus.PAUSED
|
||||
self._message_saved_on_pause = True
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_failed_event(
|
||||
|
|
@ -614,9 +657,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
# Save message unless it has already been persisted on pause.
|
||||
if not self._message_saved_on_pause:
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
|
|
@ -642,6 +686,65 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
"""Handle message replace events."""
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
|
||||
|
||||
def _handle_human_input_form_filled_event(
|
||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form filled events."""
|
||||
self._persist_human_input_extra_content(node_id=event.node_id)
|
||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _handle_human_input_form_timeout_event(
|
||||
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form timeout events."""
|
||||
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
|
||||
if not self._workflow_run_id or not self._message_id:
|
||||
return
|
||||
|
||||
if form_id is None:
|
||||
if node_id is None:
|
||||
return
|
||||
form_id = self._load_human_input_form_id(node_id=node_id)
|
||||
if form_id is None:
|
||||
logger.warning(
|
||||
"HumanInput form not found for workflow run %s node %s",
|
||||
self._workflow_run_id,
|
||||
node_id,
|
||||
)
|
||||
return
|
||||
|
||||
with self._database_session() as session:
|
||||
exists_stmt = select(HumanInputContent).where(
|
||||
HumanInputContent.workflow_run_id == self._workflow_run_id,
|
||||
HumanInputContent.message_id == self._message_id,
|
||||
HumanInputContent.form_id == form_id,
|
||||
)
|
||||
if session.scalar(exists_stmt) is not None:
|
||||
return
|
||||
|
||||
content = HumanInputContent(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
message_id=self._message_id,
|
||||
form_id=form_id,
|
||||
)
|
||||
session.add(content)
|
||||
|
||||
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self._workflow_tenant_id,
|
||||
)
|
||||
form = form_repository.get_form(self._workflow_run_id, node_id)
|
||||
if form is None:
|
||||
return None
|
||||
return form.id
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle agent log events."""
|
||||
yield self._workflow_response_converter.handle_agent_log(
|
||||
|
|
@ -659,6 +762,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
||||
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
|
||||
# Node events
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
|
|
@ -680,6 +784,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
QueueMessageReplaceEvent: self._handle_message_replace_event,
|
||||
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
|
||||
QueueAgentLogEvent: self._handle_agent_log_event,
|
||||
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
|
||||
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
|
||||
}
|
||||
|
||||
def _dispatch_event(
|
||||
|
|
@ -747,6 +853,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
case QueueWorkflowPausedEvent():
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||
|
|
@ -772,6 +881,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
message = self._get_message(session=session)
|
||||
if message is None:
|
||||
return
|
||||
|
||||
if message.status == MessageStatus.PAUSED:
|
||||
message.status = MessageStatus.NORMAL
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
|
|
|
|||
|
|
@ -5,9 +5,14 @@ from dataclasses import dataclass
|
|||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
|
|
@ -19,9 +24,13 @@ from core.app.entities.queue_entities import (
|
|||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
HumanInputFormFilledResponse,
|
||||
HumanInputFormTimeoutResponse,
|
||||
HumanInputRequiredResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
|
|
@ -31,7 +40,9 @@ from core.app.entities.task_entities import (
|
|||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
|
|
@ -40,6 +51,8 @@ from core.tools.entities.tool_entities import ToolProviderType
|
|||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
|
|
@ -51,8 +64,11 @@ from core.workflow.runtime import GraphRuntimeState
|
|||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.human_input import HumanInputForm
|
||||
from models.workflow import WorkflowRun
|
||||
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||
|
||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
|
|
@ -191,6 +207,7 @@ class WorkflowResponseConverter:
|
|||
task_id: str,
|
||||
workflow_run_id: str,
|
||||
workflow_id: str,
|
||||
reason: WorkflowStartReason,
|
||||
) -> WorkflowStartStreamResponse:
|
||||
run_id = self._ensure_workflow_run_id(workflow_run_id)
|
||||
started_at = naive_utc_now()
|
||||
|
|
@ -204,6 +221,7 @@ class WorkflowResponseConverter:
|
|||
workflow_id=workflow_id,
|
||||
inputs=self._workflow_inputs,
|
||||
created_at=int(started_at.timestamp()),
|
||||
reason=reason,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -264,6 +282,160 @@ class WorkflowResponseConverter:
|
|||
),
|
||||
)
|
||||
|
||||
def workflow_pause_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
task_id: str,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> list[StreamResponse]:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
started_at = self._workflow_started_at
|
||||
if started_at is None:
|
||||
raise ValueError(
|
||||
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
|
||||
)
|
||||
paused_at = naive_utc_now()
|
||||
elapsed_time = (paused_at - started_at).total_seconds()
|
||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||
encoded_outputs = {}
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
||||
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
|
||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
||||
if human_input_form_ids:
|
||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
|
||||
HumanInputForm.id.in_(human_input_form_ids)
|
||||
)
|
||||
with Session(bind=db.engine) as session:
|
||||
for form_id, expiration_time in session.execute(stmt):
|
||||
expiration_times_by_form_id[str(form_id)] = expiration_time
|
||||
|
||||
responses: list[StreamResponse] = []
|
||||
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
expiration_time = expiration_times_by_form_id.get(reason.form_id)
|
||||
if expiration_time is None:
|
||||
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
|
||||
responses.append(
|
||||
HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id=reason.form_id,
|
||||
node_id=reason.node_id,
|
||||
node_title=reason.node_title,
|
||||
form_content=reason.form_content,
|
||||
inputs=reason.inputs,
|
||||
actions=reason.actions,
|
||||
display_in_ui=reason.display_in_ui,
|
||||
form_token=reason.form_token,
|
||||
resolved_default_values=reason.resolved_default_values,
|
||||
expiration_time=int(expiration_time.timestamp()),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
responses.append(
|
||||
WorkflowPauseStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id=run_id,
|
||||
paused_nodes=list(event.paused_nodes),
|
||||
outputs=encoded_outputs,
|
||||
reasons=pause_reasons,
|
||||
status=WorkflowExecutionStatus.PAUSED.value,
|
||||
created_at=int(started_at.timestamp()),
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return responses
|
||||
|
||||
def human_input_form_filled_to_stream_response(
|
||||
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
|
||||
) -> HumanInputFormFilledResponse:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
return HumanInputFormFilledResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputFormFilledResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
),
|
||||
)
|
||||
|
||||
def human_input_form_timeout_to_stream_response(
|
||||
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
|
||||
) -> HumanInputFormTimeoutResponse:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
return HumanInputFormTimeoutResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputFormTimeoutResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
expiration_time=int(event.expiration_time.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def workflow_run_result_to_finish_response(
|
||||
cls,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
creator_user: Account | EndUser,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
run_id = workflow_run.id
|
||||
elapsed_time = workflow_run.elapsed_time
|
||||
|
||||
encoded_outputs = workflow_run.outputs_dict
|
||||
finished_at = workflow_run.finished_at
|
||||
assert finished_at is not None
|
||||
|
||||
created_by: Mapping[str, object]
|
||||
user = creator_user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
status=workflow_run.status.value,
|
||||
outputs=encoded_outputs,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(finished_at.timestamp()),
|
||||
files=cls.fetch_files_from_node_outputs(encoded_outputs),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_node_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -592,7 +764,8 @@ class WorkflowResponseConverter:
|
|||
),
|
||||
)
|
||||
|
||||
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||
@classmethod
|
||||
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
|
|
@ -601,7 +774,7 @@ class WorkflowResponseConverter:
|
|||
if not outputs_dict:
|
||||
return []
|
||||
|
||||
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
# Remove None
|
||||
files = [file for file in files if file]
|
||||
# Flatten list
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
|
@ -10,12 +10,14 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
|
|||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.streaming_utils import stream_topic_events
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
ConversationAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
|
|
@ -27,6 +29,8 @@ from core.app.entities.task_entities import (
|
|||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
|
|
@ -156,6 +160,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
query = application_generate_entity.query or "New conversation"
|
||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||
|
||||
created_new_conversation = conversation is None
|
||||
try:
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
|
|
@ -232,6 +237,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
db.session.add_all(message_files)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if isinstance(application_generate_entity, ConversationAppGenerateEntity):
|
||||
application_generate_entity.conversation_id = conversation.id
|
||||
application_generate_entity.is_new_conversation = created_new_conversation
|
||||
return conversation, message
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
|
|
@ -284,3 +293,29 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
|
||||
return f"channel:{app_mode}:{workflow_run_id}"
|
||||
|
||||
@classmethod
|
||||
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
|
||||
key = cls._make_channel_key(app_mode, workflow_run_id)
|
||||
channel = get_pubsub_broadcast_channel()
|
||||
topic = channel.topic(key)
|
||||
return topic
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(
|
||||
cls,
|
||||
app_mode: AppMode,
|
||||
workflow_run_id: str,
|
||||
idle_timeout=300,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=idle_timeout,
|
||||
on_subscribe=on_subscribe,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
from collections.abc import Callable, Generator, Mapping
|
||||
|
||||
from core.app.apps.streaming_utils import stream_topic_events
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class MessageGenerator:
|
||||
@staticmethod
|
||||
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
|
||||
return f"channel:{app_mode}:{str(workflow_run_id)}"
|
||||
|
||||
@classmethod
|
||||
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
|
||||
key = cls._make_channel_key(app_mode, workflow_run_id)
|
||||
channel = get_pubsub_broadcast_channel()
|
||||
topic = channel.topic(key)
|
||||
return topic
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(
|
||||
cls,
|
||||
app_mode: AppMode,
|
||||
workflow_run_id: str,
|
||||
idle_timeout=300,
|
||||
ping_interval: float = 10.0,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=idle_timeout,
|
||||
ping_interval=ping_interval,
|
||||
on_subscribe=on_subscribe,
|
||||
)
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
|
||||
|
||||
def stream_topic_events(
|
||||
*,
|
||||
topic: Topic,
|
||||
idle_timeout: float,
|
||||
ping_interval: float | None = None,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
terminal_events: Iterable[str | StreamEvent] | None = None,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]:
|
||||
# send a PING event immediately to prevent the connection staying in pending state for a long time.
|
||||
#
|
||||
# This simplify the debugging process as the DevTools in Chrome does not
|
||||
# provide complete curl command for pending connections.
|
||||
yield StreamEvent.PING.value
|
||||
|
||||
terminal_values = _normalize_terminal_events(terminal_events)
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
with topic.subscribe() as sub:
|
||||
# on_subscribe fires only after the Redis subscription is active.
|
||||
# This is used to gate task start and reduce pub/sub race for the first event.
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
while True:
|
||||
try:
|
||||
msg = sub.receive(timeout=0.1)
|
||||
except SubscriptionClosedError:
|
||||
return
|
||||
if msg is None:
|
||||
current_time = time.time()
|
||||
if current_time - last_msg_time > idle_timeout:
|
||||
return
|
||||
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
|
||||
yield StreamEvent.PING.value
|
||||
last_ping_time = current_time
|
||||
continue
|
||||
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
event = json.loads(msg)
|
||||
yield event
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
|
||||
event_type = event.get("event")
|
||||
if event_type in terminal_values:
|
||||
return
|
||||
|
||||
|
||||
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
|
||||
if not terminal_events:
|
||||
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
|
||||
values: set[str] = set()
|
||||
for item in terminal_events:
|
||||
if isinstance(item, StreamEvent):
|
||||
values.add(item.value)
|
||||
else:
|
||||
values.add(str(item))
|
||||
return values
|
||||
|
|
@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
|||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
|
@ -34,12 +35,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
|||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.account import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -66,9 +70,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -82,9 +88,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -98,9 +106,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
|
|
@ -113,9 +123,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
|
|
@ -150,7 +162,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
|
|
@ -216,13 +228,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(self, *, workflow_run_id: str) -> None:
|
||||
def resume(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
@TBD
|
||||
Resume a paused workflow execution using the persisted runtime state.
|
||||
"""
|
||||
pass
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=application_generate_entity.stream,
|
||||
variable_loader=variable_loader,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
|
@ -238,6 +277,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
|
@ -251,6 +292,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
"""
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
|
|
@ -259,6 +302,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
app_mode=app_model.mode,
|
||||
)
|
||||
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
|
|
@ -276,7 +328,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
"root_node_id": root_node_id,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": graph_engine_layers,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -378,6 +431,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
pause_state_config=None,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
|
|
@ -459,6 +513,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
pause_state_config=None,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
|
|
@ -472,6 +527,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
|
|
@ -517,6 +573,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
|
|
@ -55,6 +56,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
self._root_node_id = root_node_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._resume_graph_runtime_state = graph_runtime_state
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
|
|
@ -63,23 +65,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
|
||||
invoke_from = self.application_generate_entity.invoke_from
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
invoke_from = InvokeFrom.DEBUGGER
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
resume_state = self._resume_graph_runtime_state
|
||||
|
||||
if resume_state is not None:
|
||||
graph_runtime_state = resume_state
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
root_node_id=self._root_node_id,
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
|
|
@ -89,7 +96,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
# Create a variable pool.
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
|
|
@ -98,8 +112,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class WorkflowPausedInBlockingModeError(BaseHTTPException):
|
||||
error_code = "workflow_paused_in_blocking_mode"
|
||||
description = "Workflow execution paused for human input; blocking response mode is not supported."
|
||||
code = 400
|
||||
|
|
@ -16,6 +16,8 @@ from core.app.entities.queue_entities import (
|
|||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
|
|
@ -32,6 +34,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
|
|
@ -46,11 +49,13 @@ from core.app.entities.task_entities import (
|
|||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
|
@ -132,6 +137,25 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=stream_response.data.workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs or {},
|
||||
error=None,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=stream_response.data.created_at,
|
||||
finished_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
|
@ -146,7 +170,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at),
|
||||
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -259,13 +283,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
run_id = self._extract_workflow_run_id(runtime_state)
|
||||
self._workflow_execution_id = run_id
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
if event.reason == WorkflowStartReason.INITIAL:
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
reason=event.reason,
|
||||
)
|
||||
yield start_resp
|
||||
|
||||
|
|
@ -440,6 +466,21 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
)
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
self,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
yield from responses
|
||||
|
||||
def _handle_workflow_failed_and_stop_events(
|
||||
self,
|
||||
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
|
||||
|
|
@ -495,6 +536,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
|
||||
def _handle_human_input_form_filled_event(
|
||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form filled events."""
|
||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _handle_human_input_form_timeout_event(
|
||||
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form timeout events."""
|
||||
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _get_event_handlers(self) -> dict[type, Callable]:
|
||||
"""Get mapping of event types to their handlers using fluent pattern."""
|
||||
return {
|
||||
|
|
@ -506,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
||||
# Node events
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
|
|
@ -520,6 +578,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
QueueLoopCompletedEvent: self._handle_loop_completed_event,
|
||||
# Agent events
|
||||
QueueAgentLogEvent: self._handle_agent_log_event,
|
||||
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
|
||||
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
|
||||
}
|
||||
|
||||
def _dispatch_event(
|
||||
|
|
@ -602,6 +662,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
break
|
||||
case QueueWorkflowPausedEvent():
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
|
@ -7,6 +8,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
|
|
@ -22,22 +25,27 @@ from core.app.entities.queue_entities import (
|
|||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
|
|
@ -61,6 +69,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader,
|
|||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner:
|
||||
|
|
@ -327,7 +338,7 @@ class WorkflowBasedAppRunner:
|
|||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(QueueWorkflowStartedEvent())
|
||||
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
|
|
@ -338,6 +349,38 @@ class WorkflowBasedAppRunner:
|
|||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, GraphRunPausedEvent):
|
||||
runtime_state = workflow_entry.graph_engine.graph_runtime_state
|
||||
paused_nodes = runtime_state.get_paused_nodes()
|
||||
self._enqueue_human_input_notifications(event.reasons)
|
||||
self._publish_event(
|
||||
QueueWorkflowPausedEvent(
|
||||
reasons=event.reasons,
|
||||
outputs=event.outputs,
|
||||
paused_nodes=paused_nodes,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
|
||||
self._publish_event(
|
||||
QueueHumanInputFormFilledEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
|
||||
self._publish_event(
|
||||
QueueHumanInputFormTimeoutEvent(
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
expiration_time=event.expiration_time,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
|
|
@ -544,5 +587,19 @@ class WorkflowBasedAppRunner:
|
|||
)
|
||||
)
|
||||
|
||||
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
|
||||
for reason in reasons:
|
||||
if not isinstance(reason, HumanInputRequired):
|
||||
continue
|
||||
if not reason.form_id:
|
||||
continue
|
||||
try:
|
||||
dispatch_human_input_email_task.apply_async(
|
||||
kwargs={"form_id": reason.form_id, "node_title": reason.node_title},
|
||||
queue="mail",
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive logging
|
||||
logger.exception("Failed to enqueue human input email task for form %s", reason.form_id)
|
||||
|
||||
def _publish_event(self, event: AppQueueEvent):
|
||||
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel):
|
|||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional["TraceQueueManager"] = None
|
||||
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
|
|
@ -156,6 +156,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
|
|||
"""
|
||||
|
||||
conversation_id: str | None = None
|
||||
is_new_conversation: bool = False
|
||||
parent_message_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
|
|
@ -46,6 +48,9 @@ class QueueEvent(StrEnum):
|
|||
PING = "ping"
|
||||
STOP = "stop"
|
||||
RETRY = "retry"
|
||||
PAUSE = "pause"
|
||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
||||
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
|
||||
|
||||
|
||||
class AppQueueEvent(BaseModel):
|
||||
|
|
@ -261,6 +266,8 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
|
|||
"""QueueWorkflowStartedEvent entity."""
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||
# Always present; mirrors GraphRunStartedEvent.reason for downstream consumers.
|
||||
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
|
||||
|
||||
|
||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
|
|
@ -484,6 +491,35 @@ class QueueStopEvent(AppQueueEvent):
|
|||
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
|
||||
|
||||
|
||||
class QueueHumanInputFormFilledEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueHumanInputFormFilledEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueHumanInputFormTimeoutEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT
|
||||
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
expiration_time: datetime
|
||||
|
||||
|
||||
class QueueMessage(BaseModel):
|
||||
"""
|
||||
QueueMessage abstract entity
|
||||
|
|
@ -509,3 +545,14 @@ class WorkflowQueueMessage(QueueMessage):
|
|||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueueWorkflowPausedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowPausedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PAUSE
|
||||
reasons: Sequence[PauseReason] = Field(default_factory=list)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
|
|
@ -69,6 +71,7 @@ class StreamEvent(StrEnum):
|
|||
AGENT_THOUGHT = "agent_thought"
|
||||
AGENT_MESSAGE = "agent_message"
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_PAUSED = "workflow_paused"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
|
|
@ -82,6 +85,9 @@ class StreamEvent(StrEnum):
|
|||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
HUMAN_INPUT_REQUIRED = "human_input_required"
|
||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
||||
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
|
|
@ -205,6 +211,8 @@ class WorkflowStartStreamResponse(StreamResponse):
|
|||
workflow_id: str
|
||||
inputs: Mapping[str, Any]
|
||||
created_at: int
|
||||
# Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients.
|
||||
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
|
||||
workflow_run_id: str
|
||||
|
|
@ -231,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||
total_steps: int
|
||||
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
finished_at: int
|
||||
finished_at: int | None
|
||||
exceptions_count: int | None = 0
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
|
||||
|
|
@ -240,6 +248,85 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||
data: Data
|
||||
|
||||
|
||||
class WorkflowPauseStreamResponse(StreamResponse):
|
||||
"""
|
||||
WorkflowPauseStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
workflow_run_id: str
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
outputs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
status: str
|
||||
created_at: int
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputRequiredResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int = Field(..., description="Unix timestamp in seconds")
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputFormFilledResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputFormTimeoutResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
expiration_time: int
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class NodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
|
|
@ -726,7 +813,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
|||
total_tokens: int
|
||||
total_steps: int
|
||||
created_at: int
|
||||
finished_at: int
|
||||
finished_at: int | None
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
|
@ -103,6 +104,14 @@ class RateLimit:
|
|||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def rate_limit_context(rate_limit: RateLimit, request_id: str | None):
|
||||
request_id = rate_limit.enter(request_id)
|
||||
yield
|
||||
if request_id is not None:
|
||||
rate_limit.exit(request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
|
||||
self.rate_limit = rate_limit
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel):
|
|||
return self.generate_entity.entity
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PauseStateLayerConfig:
|
||||
"""Configuration container for instantiating pause persistence layers."""
|
||||
|
||||
session_factory: Engine | sessionmaker[Session]
|
||||
state_owner_user_id: str
|
||||
|
||||
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -82,10 +82,11 @@ class MessageCycleManager:
|
|||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
return None
|
||||
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
is_first_message = self._application_generate_entity.is_new_conversation
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
||||
|
||||
thread: Thread | None = None
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
|
|
@ -101,9 +102,10 @@ class MessageCycleManager:
|
|||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
if is_first_message:
|
||||
self._application_generate_entity.is_new_conversation = False
|
||||
|
||||
return None
|
||||
return thread
|
||||
|
||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||
with flask_app.app_context():
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
|
||||
|
||||
class HumanInputFormDefinition(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
|
||||
class HumanInputFormSubmissionData(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
class HumanInputContent(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
workflow_run_id: str
|
||||
submitted: bool
|
||||
form_definition: HumanInputFormDefinition | None = None
|
||||
form_submission_data: HumanInputFormSubmissionData | None = None
|
||||
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
||||
|
||||
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
||||
|
||||
__all__ = [
|
||||
"ExecutionExtraContentDomainModel",
|
||||
"HumanInputContent",
|
||||
"HumanInputFormDefinition",
|
||||
"HumanInputFormSubmissionData",
|
||||
]
|
||||
|
|
@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
|
|||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
|
|
|
|||
|
|
@ -15,10 +15,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
TracingProviderEnum,
|
||||
)
|
||||
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
|
|
@ -31,8 +28,8 @@ from core.ops.entities.trace_entity import (
|
|||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.engine import db
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.workflow import WorkflowAppLog
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
|
@ -469,6 +466,8 @@ class TraceTask:
|
|||
|
||||
@classmethod
|
||||
def _get_workflow_run_repo(cls):
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
if cls._workflow_run_repo is None:
|
||||
with cls._repo_lock:
|
||||
if cls._workflow_run_repo is None:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
|||
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.engine import db
|
||||
from models.model import Message
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Union
|
||||
|
||||
|
|
@ -112,6 +113,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
|
|
|
|||
|
|
@ -1,19 +1,18 @@
|
|||
"""
|
||||
Repository implementations for data access.
|
||||
"""Repository implementations for data access."""
|
||||
|
||||
This package contains concrete implementations of the repository interfaces
|
||||
defined in the core.workflow.repository package.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"CeleryWorkflowExecutionRepository",
|
||||
"CeleryWorkflowNodeExecutionRepository",
|
||||
"DifyCoreRepositoryFactory",
|
||||
"RepositoryImportError",
|
||||
"SQLAlchemyWorkflowExecutionRepository",
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,553 @@
|
|||
import dataclasses
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
DeliveryMethodType,
|
||||
HumanInputFormKind,
|
||||
HumanInputFormStatus,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
FormNotFoundError,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.human_input import (
|
||||
BackstageRecipientPayload,
|
||||
ConsoleDeliveryPayload,
|
||||
ConsoleRecipientPayload,
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
StandaloneWebAppRecipientPayload,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeliveryAndRecipients:
|
||||
delivery: HumanInputDelivery
|
||||
recipients: Sequence[HumanInputFormRecipient]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _WorkspaceMemberInfo:
|
||||
user_id: str
|
||||
email: str
|
||||
|
||||
|
||||
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
|
||||
def __init__(self, recipient_model: HumanInputFormRecipient):
|
||||
self._recipient_model = recipient_model
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._recipient_model.id
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
if self._recipient_model.access_token is None:
|
||||
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
|
||||
return self._recipient_model.access_token
|
||||
|
||||
|
||||
class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
|
||||
self._form_model = form_model
|
||||
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
|
||||
self._web_app_recipient = next(
|
||||
(
|
||||
recipient
|
||||
for recipient in recipient_models
|
||||
if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
),
|
||||
None,
|
||||
)
|
||||
self._console_recipient = next(
|
||||
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE),
|
||||
None,
|
||||
)
|
||||
self._submitted_data: Mapping[str, Any] | None = (
|
||||
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._form_model.id
|
||||
|
||||
@property
|
||||
def web_app_token(self):
|
||||
if self._console_recipient is not None:
|
||||
return self._console_recipient.access_token
|
||||
if self._web_app_recipient is None:
|
||||
return None
|
||||
return self._web_app_recipient.access_token
|
||||
|
||||
@property
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return list(self._recipients)
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self._form_model.rendered_content
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self._form_model.selected_action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self._submitted_data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self._form_model.submitted_at is not None
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self._form_model.status
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self._form_model.expiration_time
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class HumanInputFormRecord:
|
||||
form_id: str
|
||||
workflow_run_id: str | None
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
form_kind: HumanInputFormKind
|
||||
definition: FormDefinition
|
||||
rendered_content: str
|
||||
created_at: datetime
|
||||
expiration_time: datetime
|
||||
status: HumanInputFormStatus
|
||||
selected_action_id: str | None
|
||||
submitted_data: Mapping[str, Any] | None
|
||||
submitted_at: datetime | None
|
||||
submission_user_id: str | None
|
||||
submission_end_user_id: str | None
|
||||
completed_by_recipient_id: str | None
|
||||
recipient_id: str | None
|
||||
recipient_type: RecipientType | None
|
||||
access_token: str | None
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.submitted_at is not None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
|
||||
) -> "HumanInputFormRecord":
|
||||
definition_payload = json.loads(form_model.form_definition)
|
||||
if "expiration_time" not in definition_payload:
|
||||
definition_payload["expiration_time"] = form_model.expiration_time
|
||||
return cls(
|
||||
form_id=form_model.id,
|
||||
workflow_run_id=form_model.workflow_run_id,
|
||||
node_id=form_model.node_id,
|
||||
tenant_id=form_model.tenant_id,
|
||||
app_id=form_model.app_id,
|
||||
form_kind=form_model.form_kind,
|
||||
definition=FormDefinition.model_validate(definition_payload),
|
||||
rendered_content=form_model.rendered_content,
|
||||
created_at=form_model.created_at,
|
||||
expiration_time=form_model.expiration_time,
|
||||
status=form_model.status,
|
||||
selected_action_id=form_model.selected_action_id,
|
||||
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
|
||||
submitted_at=form_model.submitted_at,
|
||||
submission_user_id=form_model.submission_user_id,
|
||||
submission_end_user_id=form_model.submission_end_user_id,
|
||||
completed_by_recipient_id=form_model.completed_by_recipient_id,
|
||||
recipient_id=recipient_model.id if recipient_model else None,
|
||||
recipient_type=recipient_model.recipient_type if recipient_model else None,
|
||||
access_token=recipient_model.access_token if recipient_model else None,
|
||||
)
|
||||
|
||||
|
||||
class _InvalidTimeoutStatusError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputFormRepositoryImpl:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
tenant_id: str,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _delivery_method_to_model(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_method: DeliveryChannelConfig,
|
||||
) -> _DeliveryAndRecipients:
|
||||
delivery_id = str(uuidv7())
|
||||
delivery_model = HumanInputDelivery(
|
||||
id=delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=delivery_method.type,
|
||||
delivery_config_id=delivery_method.id,
|
||||
channel_payload=delivery_method.model_dump_json(),
|
||||
)
|
||||
recipients: list[HumanInputFormRecipient] = []
|
||||
if isinstance(delivery_method, WebAppDeliveryMethod):
|
||||
recipient_model = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
|
||||
)
|
||||
recipients.append(recipient_model)
|
||||
elif isinstance(delivery_method, EmailDeliveryMethod):
|
||||
email_recipients_config = delivery_method.config.recipients
|
||||
recipients.extend(
|
||||
self._build_email_recipients(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipients_config=email_recipients_config,
|
||||
)
|
||||
)
|
||||
|
||||
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
|
||||
|
||||
def _build_email_recipients(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
recipients_config: EmailRecipients,
|
||||
) -> list[HumanInputFormRecipient]:
|
||||
member_user_ids = [
|
||||
recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
|
||||
]
|
||||
external_emails = [
|
||||
recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
|
||||
]
|
||||
if recipients_config.whole_workspace:
|
||||
members = self._query_all_workspace_members(session=session)
|
||||
else:
|
||||
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
|
||||
|
||||
return self._create_email_recipients_from_resolved(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
members=members,
|
||||
external_emails=external_emails,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_email_recipients_from_resolved(
|
||||
*,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
members: Sequence[_WorkspaceMemberInfo],
|
||||
external_emails: Sequence[str],
|
||||
) -> list[HumanInputFormRecipient]:
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for member in members:
|
||||
if not member.email:
|
||||
continue
|
||||
if member.email in seen_emails:
|
||||
continue
|
||||
seen_emails.add(member.email)
|
||||
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
|
||||
recipient_models.append(
|
||||
HumanInputFormRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
|
||||
for email in external_emails:
|
||||
if not email:
|
||||
continue
|
||||
if email in seen_emails:
|
||||
continue
|
||||
seen_emails.add(email)
|
||||
recipient_models.append(
|
||||
HumanInputFormRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=EmailExternalRecipientPayload(email=email),
|
||||
)
|
||||
)
|
||||
|
||||
return recipient_models
|
||||
|
||||
def _query_all_workspace_members(
|
||||
self,
|
||||
session: Session,
|
||||
) -> list[_WorkspaceMemberInfo]:
|
||||
stmt = (
|
||||
select(Account.id, Account.email)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
||||
)
|
||||
rows = session.execute(stmt).all()
|
||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
||||
|
||||
def _query_workspace_members_by_ids(
|
||||
self,
|
||||
session: Session,
|
||||
restrict_to_user_ids: Sequence[str],
|
||||
) -> list[_WorkspaceMemberInfo]:
|
||||
unique_ids = {user_id for user_id in restrict_to_user_ids if user_id}
|
||||
if not unique_ids:
|
||||
return []
|
||||
|
||||
stmt = (
|
||||
select(Account.id, Account.email)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
||||
)
|
||||
stmt = stmt.where(Account.id.in_(unique_ids))
|
||||
|
||||
rows = session.execute(stmt).all()
|
||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
form_config: HumanInputNodeData = params.form_config
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
# Generate unique form ID
|
||||
form_id = str(uuidv7())
|
||||
start_time = naive_utc_now()
|
||||
node_expiration = form_config.expiration_time(start_time)
|
||||
form_definition = FormDefinition(
|
||||
form_content=form_config.form_content,
|
||||
inputs=form_config.inputs,
|
||||
user_actions=form_config.user_actions,
|
||||
rendered_content=params.rendered_content,
|
||||
expiration_time=node_expiration,
|
||||
default_values=dict(params.resolved_default_values),
|
||||
display_in_ui=params.display_in_ui,
|
||||
node_title=form_config.title,
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id=form_id,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=params.app_id,
|
||||
workflow_run_id=params.workflow_execution_id,
|
||||
form_kind=params.form_kind,
|
||||
node_id=params.node_id,
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content=params.rendered_content,
|
||||
expiration_time=node_expiration,
|
||||
created_at=start_time,
|
||||
)
|
||||
session.add(form_model)
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
for delivery in params.delivery_methods:
|
||||
delivery_and_recipients = self._delivery_method_to_model(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_method=delivery,
|
||||
)
|
||||
session.add(delivery_and_recipients.delivery)
|
||||
session.add_all(delivery_and_recipients.recipients)
|
||||
recipient_models.extend(delivery_and_recipients.recipients)
|
||||
if params.console_recipient_required and not any(
|
||||
recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
|
||||
):
|
||||
console_delivery_id = str(uuidv7())
|
||||
console_delivery = HumanInputDelivery(
|
||||
id=console_delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
delivery_config_id=None,
|
||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
||||
)
|
||||
console_recipient = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=console_delivery_id,
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload=ConsoleRecipientPayload(
|
||||
account_id=params.console_creator_account_id,
|
||||
).model_dump_json(),
|
||||
)
|
||||
session.add(console_delivery)
|
||||
session.add(console_recipient)
|
||||
recipient_models.append(console_recipient)
|
||||
if params.backstage_recipient_required and not any(
|
||||
recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
|
||||
):
|
||||
backstage_delivery_id = str(uuidv7())
|
||||
backstage_delivery = HumanInputDelivery(
|
||||
id=backstage_delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
delivery_config_id=None,
|
||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
||||
)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=backstage_delivery_id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload(
|
||||
account_id=params.console_creator_account_id,
|
||||
).model_dump_json(),
|
||||
)
|
||||
session.add(backstage_delivery)
|
||||
session.add(backstage_recipient)
|
||||
recipient_models.append(backstage_recipient)
|
||||
session.flush()
|
||||
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
form_query = select(HumanInputForm).where(
|
||||
HumanInputForm.workflow_run_id == workflow_execution_id,
|
||||
HumanInputForm.node_id == node_id,
|
||||
HumanInputForm.tenant_id == self._tenant_id,
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model: HumanInputForm | None = session.scalars(form_query).first()
|
||||
if form_model is None:
|
||||
return None
|
||||
|
||||
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
|
||||
recipient_models = session.scalars(recipient_query).all()
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
|
||||
class HumanInputFormSubmissionRepository:
|
||||
"""Repository for fetching and submitting human input forms."""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(HumanInputFormRecipient.access_token == form_token)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
||||
|
||||
def get_by_form_id_and_recipient_type(
|
||||
self,
|
||||
form_id: str,
|
||||
recipient_type: RecipientType,
|
||||
) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(
|
||||
HumanInputFormRecipient.form_id == form_id,
|
||||
HumanInputFormRecipient.recipient_type == recipient_type,
|
||||
)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
||||
|
||||
def mark_submitted(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
recipient_id: str | None,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
submission_user_id: str | None,
|
||||
submission_end_user_id: str | None,
|
||||
) -> HumanInputFormRecord:
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
|
||||
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.status = HumanInputFormStatus.SUBMITTED
|
||||
form_model.submission_user_id = submission_user_id
|
||||
form_model.submission_end_user_id = submission_end_user_id
|
||||
form_model.completed_by_recipient_id = recipient_id
|
||||
|
||||
session.add(form_model)
|
||||
session.flush()
|
||||
session.refresh(form_model)
|
||||
if recipient_model is not None:
|
||||
session.refresh(recipient_model)
|
||||
|
||||
return HumanInputFormRecord.from_models(form_model, recipient_model)
|
||||
|
||||
def mark_timeout(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
timeout_status: HumanInputFormStatus,
|
||||
reason: str | None = None,
|
||||
) -> HumanInputFormRecord:
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
||||
raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}")
|
||||
|
||||
# already handled or submitted
|
||||
if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
||||
return HumanInputFormRecord.from_models(form_model, None)
|
||||
|
||||
if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED:
|
||||
raise FormNotFoundError(f"form already submitted, id={form_id}")
|
||||
|
||||
form_model.status = timeout_status
|
||||
form_model.selected_action_id = None
|
||||
form_model.submitted_data = None
|
||||
form_model.submission_user_id = None
|
||||
form_model.submission_end_user_id = None
|
||||
form_model.completed_by_recipient_id = None
|
||||
# Reason is recorded in status/error downstream; not stored on form.
|
||||
session.add(form_model)
|
||||
session.flush()
|
||||
session.refresh(form_model)
|
||||
|
||||
return HumanInputFormRecord.from_models(form_model, None)
|
||||
|
|
@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
|
|
@ -37,6 +38,12 @@ class ToolCredentialPolicyViolationError(ValueError):
|
|||
pass
|
||||
|
||||
|
||||
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
|
||||
error_code = "workflow_tool_human_input_not_supported"
|
||||
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
|
||||
code = 400
|
||||
|
||||
|
||||
class ToolEngineInvokeError(Exception):
|
||||
meta: ToolInvokeMeta
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Any
|
|||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
|
|
@ -50,6 +52,13 @@ class WorkflowToolConfigurationUtils:
|
|||
|
||||
return [outputs_by_variable[variable] for variable in variable_order]
|
||||
|
||||
@classmethod
|
||||
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
|
||||
nodes = graph.get("nodes", [])
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
|
||||
raise WorkflowToolHumanInputNotSupportedError()
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ from .agent import AgentNodeStrategyInit
|
|||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_start_reason import WorkflowStartReason
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"WorkflowStartReason",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,16 @@ from pydantic import BaseModel, Field
|
|||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
"""GraphInitParams encapsulates the configurations and contextual information
|
||||
that remain constant throughout a single execution of the graph engine.
|
||||
|
||||
A single execution is defined as follows: as long as the execution has not reached
|
||||
its conclusion, it is considered one execution. For instance, if a workflow is suspended
|
||||
and later resumed, it is still regarded as a single execution, not two.
|
||||
|
||||
For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
|
||||
"""
|
||||
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class PauseReasonType(StrEnum):
|
||||
HUMAN_INPUT_REQUIRED = auto()
|
||||
|
|
@ -11,10 +14,31 @@ class PauseReasonType(StrEnum):
|
|||
|
||||
class HumanInputRequired(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
form_id: str
|
||||
# The identifier of the human input node causing the pause.
|
||||
form_content: str
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
actions: list[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
node_id: str
|
||||
node_title: str
|
||||
|
||||
# The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from
|
||||
# `output_variable_name` to their resolved values.
|
||||
#
|
||||
# For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its
|
||||
# selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable
|
||||
# `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The
|
||||
# `resolved_default_values` is `{"name": "John"}`.
|
||||
#
|
||||
# Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# The `form_token` is the token used to submit the form via UI surfaces. It corresponds to
|
||||
# `HumanInputFormRecipient.access_token`.
|
||||
#
|
||||
# This field is `None` if webapp delivery is not set and not
|
||||
# in orchestrating mode.
|
||||
form_token: str | None = None
|
||||
|
||||
|
||||
class SchedulingPause(BaseModel):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
from enum import StrEnum
|
||||
|
||||
|
||||
class WorkflowStartReason(StrEnum):
|
||||
"""Reason for workflow start events across graph/queue/SSE layers."""
|
||||
|
||||
INITIAL = "initial" # First start of a workflow run.
|
||||
RESUMPTION = "resumption" # Start triggered after resuming a paused run.
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import time
|
||||
|
||||
|
||||
def get_timestamp() -> float:
|
||||
"""Retrieve a timestamp as a float point numer representing the number of seconds
|
||||
since the Unix epoch.
|
||||
|
||||
This function is primarily used to measure the execution time of the workflow engine.
|
||||
Since workflow execution may be paused and resumed on a different machine,
|
||||
`time.perf_counter` cannot be used as it is inconsistent across machines.
|
||||
|
||||
To address this, the function uses the wall clock as the time source.
|
||||
However, it assumes that the clocks of all servers are properly synchronized.
|
||||
"""
|
||||
return round(time.time())
|
||||
|
|
@ -2,12 +2,14 @@
|
|||
GraphEngine configuration models.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class GraphEngineConfig(BaseModel):
|
||||
"""Configuration for GraphEngine worker pool scaling."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
min_workers: int = 1
|
||||
max_workers: int = 5
|
||||
scale_up_threshold: int = 3
|
||||
|
|
|
|||
|
|
@ -192,9 +192,13 @@ class EventHandler:
|
|||
self._event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
if self._graph_execution.is_paused:
|
||||
for node_id in ready_nodes:
|
||||
self._graph_runtime_state.register_deferred_node(node_id)
|
||||
else:
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from collections.abc import Generator
|
|||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from core.workflow.context import capture_current_context
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
|
|
@ -56,6 +57,9 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DEFAULT_CONFIG = GraphEngineConfig()
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngine:
|
||||
"""
|
||||
|
|
@ -71,7 +75,7 @@ class GraphEngine:
|
|||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
command_channel: CommandChannel,
|
||||
config: GraphEngineConfig,
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
|
|
@ -235,7 +239,9 @@ class GraphEngine:
|
|||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reasons = []
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
start_event = GraphRunStartedEvent(
|
||||
reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
|
||||
)
|
||||
self._event_manager.notify_layers(start_event)
|
||||
yield start_event
|
||||
|
||||
|
|
@ -304,15 +310,17 @@ class GraphEngine:
|
|||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
|
||||
except Exception:
|
||||
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
deferred_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
|
||||
|
||||
# Start worker pool (it calculates initial workers internally)
|
||||
self._worker_pool.start()
|
||||
|
|
@ -328,7 +336,11 @@ class GraphEngine:
|
|||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
else:
|
||||
for node_id in paused_nodes:
|
||||
seen_nodes: set[str] = set()
|
||||
for node_id in paused_nodes + deferred_nodes:
|
||||
if node_id in seen_nodes:
|
||||
continue
|
||||
seen_nodes.add(node_id)
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
|
|
@ -346,8 +358,8 @@ class GraphEngine:
|
|||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
|
||||
except Exception:
|
||||
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
|
||||
|
||||
# Public property accessors for attributes that need external access
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -224,6 +224,8 @@ class GraphStateManager:
|
|||
Returns:
|
||||
Number of executing nodes
|
||||
"""
|
||||
# This count is a best-effort snapshot and can change concurrently.
|
||||
# Only use it for pause-drain checks where scheduling is already frozen.
|
||||
with self._lock:
|
||||
return len(self._executing_nodes)
|
||||
|
||||
|
|
|
|||
|
|
@ -83,12 +83,12 @@ class Dispatcher:
|
|||
"""Main dispatcher loop."""
|
||||
try:
|
||||
self._process_commands()
|
||||
paused = False
|
||||
while not self._stop_event.is_set():
|
||||
if (
|
||||
self._execution_coordinator.aborted
|
||||
or self._execution_coordinator.paused
|
||||
or self._execution_coordinator.execution_complete
|
||||
):
|
||||
if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete:
|
||||
break
|
||||
if self._execution_coordinator.paused:
|
||||
paused = True
|
||||
break
|
||||
|
||||
self._execution_coordinator.check_scaling()
|
||||
|
|
@ -101,13 +101,10 @@ class Dispatcher:
|
|||
time.sleep(0.1)
|
||||
|
||||
self._process_commands()
|
||||
while True:
|
||||
try:
|
||||
event = self._event_queue.get(block=False)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
if paused:
|
||||
self._drain_events_until_idle()
|
||||
else:
|
||||
self._drain_event_queue()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
|
|
@ -122,3 +119,24 @@ class Dispatcher:
|
|||
def _process_commands(self, event: GraphNodeEventBase | None = None):
|
||||
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
|
||||
self._execution_coordinator.process_commands()
|
||||
|
||||
def _drain_event_queue(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
event = self._event_queue.get(block=False)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
def _drain_events_until_idle(self) -> None:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
self._process_commands(event)
|
||||
except queue.Empty:
|
||||
if not self._execution_coordinator.has_executing_nodes():
|
||||
break
|
||||
self._drain_event_queue()
|
||||
|
|
|
|||
|
|
@ -94,3 +94,11 @@ class ExecutionCoordinator:
|
|||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
def has_executing_nodes(self) -> bool:
|
||||
"""Return True if any nodes are currently marked as executing."""
|
||||
# This check is only safe once execution has already paused.
|
||||
# Before pause, executing state can change concurrently, which makes the result unreliable.
|
||||
if not self._graph_execution.is_paused:
|
||||
raise AssertionError("has_executing_nodes should only be called after execution is paused")
|
||||
return self._state_manager.get_executing_count() > 0
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ from .loop import (
|
|||
from .node import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
|
|
@ -60,6 +62,8 @@ __all__ = [
|
|||
"NodeRunAgentLogEvent",
|
||||
"NodeRunExceptionEvent",
|
||||
"NodeRunFailedEvent",
|
||||
"NodeRunHumanInputFormFilledEvent",
|
||||
"NodeRunHumanInputFormTimeoutEvent",
|
||||
"NodeRunIterationFailedEvent",
|
||||
"NodeRunIterationNextEvent",
|
||||
"NodeRunIterationStartedEvent",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,16 @@
|
|||
from pydantic import Field
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph_events import BaseGraphEvent
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
# Reason is emitted for workflow start events and is always set.
|
||||
reason: WorkflowStartReason = Field(
|
||||
default=WorkflowStartReason.INITIAL,
|
||||
description="reason for workflow start",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
|
|
|
|||
|
|
@ -54,6 +54,22 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
|
|||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
|
||||
|
||||
class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
|
||||
"""Emitted when a HumanInput form is submitted and before the node finishes."""
|
||||
|
||||
node_title: str = Field(..., description="HumanInput node title")
|
||||
rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
|
||||
action_id: str = Field(..., description="User action identifier chosen in the form.")
|
||||
action_text: str = Field(..., description="Display text of the chosen action button.")
|
||||
|
||||
|
||||
class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
|
||||
"""Emitted when a HumanInput form times out."""
|
||||
|
||||
node_title: str = Field(..., description="HumanInput node title")
|
||||
expiration_time: datetime = Field(..., description="Form expiration time")
|
||||
|
||||
|
||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from .loop import (
|
|||
LoopSucceededEvent,
|
||||
)
|
||||
from .node import (
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
ModelInvokeCompletedEvent,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
|
|
@ -23,6 +25,8 @@ from .node import (
|
|||
|
||||
__all__ = [
|
||||
"AgentLogEvent",
|
||||
"HumanInputFormFilledEvent",
|
||||
"HumanInputFormTimeoutEvent",
|
||||
"IterationFailedEvent",
|
||||
"IterationNextEvent",
|
||||
"IterationStartedEvent",
|
||||
|
|
|
|||
|
|
@ -47,3 +47,19 @@ class StreamCompletedEvent(NodeEventBase):
|
|||
|
||||
class PauseRequestedEvent(NodeEventBase):
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
|
||||
|
||||
class HumanInputFormFilledEvent(NodeEventBase):
|
||||
"""Event emitted when a human input form is submitted."""
|
||||
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
class HumanInputFormTimeoutEvent(NodeEventBase):
|
||||
"""Event emitted when a human input form times out."""
|
||||
|
||||
node_title: str
|
||||
expiration_time: datetime
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ from core.workflow.graph_events import (
|
|||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
|
|
@ -34,6 +36,8 @@ from core.workflow.graph_events import (
|
|||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
|
|
@ -61,6 +65,15 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Node(Generic[NodeDataT]):
|
||||
"""BaseNode serves as the foundational class for all node implementations.
|
||||
|
||||
Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
|
||||
attribute to track files generated by the LLM). However, these states are not persisted
|
||||
when the workflow is suspended or resumed. If a node needs its state to be preserved
|
||||
across workflow suspension and resumption, it should include the relevant state data
|
||||
in its output.
|
||||
"""
|
||||
|
||||
node_type: ClassVar[NodeType]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
|
|
@ -251,10 +264,33 @@ class Node(Generic[NodeDataT]):
|
|||
return self._node_execution_id
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
if self._node_execution_id:
|
||||
return self._node_execution_id
|
||||
|
||||
resumed_execution_id = self._restore_execution_id_from_runtime_state()
|
||||
if resumed_execution_id:
|
||||
self._node_execution_id = resumed_execution_id
|
||||
return self._node_execution_id
|
||||
|
||||
self._node_execution_id = str(uuid4())
|
||||
return self._node_execution_id
|
||||
|
||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
||||
graph_execution = self.graph_runtime_state.graph_execution
|
||||
try:
|
||||
node_executions = graph_execution.node_executions
|
||||
except AttributeError:
|
||||
return None
|
||||
if not isinstance(node_executions, dict):
|
||||
return None
|
||||
node_execution = node_executions.get(self._node_id)
|
||||
if node_execution is None:
|
||||
return None
|
||||
execution_id = node_execution.execution_id
|
||||
if not execution_id:
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
|
|
@ -620,6 +656,28 @@ class Node(Generic[NodeDataT]):
|
|||
metadata=event.metadata,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: HumanInputFormFilledEvent):
|
||||
return NodeRunHumanInputFormFilledEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: HumanInputFormTimeoutEvent):
|
||||
return NodeRunHumanInputFormTimeoutEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=event.node_title,
|
||||
expiration_time=event.expiration_time,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
from .human_input_node import HumanInputNode
|
||||
|
||||
__all__ = ["HumanInputNode"]
|
||||
"""
|
||||
Human Input node implementation.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,10 +1,350 @@
|
|||
from pydantic import Field
|
||||
"""
|
||||
Human Input node entities.
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Any, ClassVar, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
|
||||
|
||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
||||
|
||||
|
||||
class _WebAppDeliveryConfig(BaseModel):
|
||||
"""Configuration for webapp delivery method."""
|
||||
|
||||
pass # Empty for webapp delivery
|
||||
|
||||
|
||||
class MemberRecipient(BaseModel):
|
||||
"""Member recipient for email delivery."""
|
||||
|
||||
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
|
||||
user_id: str
|
||||
|
||||
|
||||
class ExternalRecipient(BaseModel):
|
||||
"""External recipient for email delivery."""
|
||||
|
||||
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
|
||||
email: str
|
||||
|
||||
|
||||
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
|
||||
|
||||
|
||||
class EmailRecipients(BaseModel):
|
||||
"""Email recipients configuration."""
|
||||
|
||||
# When true, recipients are the union of all workspace members and external items.
|
||||
# Member items are ignored because they are already covered by the workspace scope.
|
||||
# De-duplication is applied by email, with member recipients taking precedence.
|
||||
whole_workspace: bool = False
|
||||
items: list[EmailRecipient] = Field(default_factory=list)
|
||||
|
||||
|
||||
class EmailDeliveryConfig(BaseModel):
|
||||
"""Configuration for email delivery method."""
|
||||
|
||||
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
|
||||
|
||||
recipients: EmailRecipients
|
||||
|
||||
# the subject of email
|
||||
subject: str
|
||||
|
||||
# Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which
|
||||
# represent the url to submit the form.
|
||||
#
|
||||
# It may also reference the output variable of the previous node with the syntax
|
||||
# `{{#<node_id>.<field_name>#}}`.
|
||||
body: str
|
||||
debug_mode: bool = False
|
||||
|
||||
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
|
||||
if not user_id:
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
|
||||
return self.model_copy(update={"recipients": debug_recipients})
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
|
||||
return self.model_copy(update={"recipients": debug_recipients})
|
||||
|
||||
@classmethod
|
||||
def replace_url_placeholder(cls, body: str, url: str | None) -> str:
|
||||
"""Replace the url placeholder with provided value."""
|
||||
return body.replace(cls.URL_PLACEHOLDER, url or "")
|
||||
|
||||
@classmethod
|
||||
def render_body_template(
|
||||
cls,
|
||||
*,
|
||||
body: str,
|
||||
url: str | None,
|
||||
variable_pool: VariablePool | None = None,
|
||||
) -> str:
|
||||
"""Render email body by replacing placeholders with runtime values."""
|
||||
templated_body = cls.replace_url_placeholder(body, url)
|
||||
if variable_pool is None:
|
||||
return templated_body
|
||||
return variable_pool.convert_template(templated_body).text
|
||||
|
||||
|
||||
class _DeliveryMethodBase(BaseModel):
|
||||
"""Base delivery method configuration."""
|
||||
|
||||
enabled: bool = True
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||
|
||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
||||
return ()
|
||||
|
||||
|
||||
class WebAppDeliveryMethod(_DeliveryMethodBase):
|
||||
"""Webapp delivery method configuration."""
|
||||
|
||||
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
|
||||
# The config field is not used currently.
|
||||
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
|
||||
|
||||
|
||||
class EmailDeliveryMethod(_DeliveryMethodBase):
|
||||
"""Email delivery method configuration."""
|
||||
|
||||
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
|
||||
config: EmailDeliveryConfig
|
||||
|
||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
||||
variable_template_parser = VariableTemplateParser(template=self.config.body)
|
||||
selectors: list[Sequence[str]] = []
|
||||
for variable_selector in variable_template_parser.extract_variable_selectors():
|
||||
value_selector = list(variable_selector.value_selector)
|
||||
if len(value_selector) < SELECTORS_LENGTH:
|
||||
continue
|
||||
selectors.append(value_selector[:SELECTORS_LENGTH])
|
||||
return selectors
|
||||
|
||||
|
||||
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
|
||||
|
||||
|
||||
def apply_debug_email_recipient(
|
||||
method: DeliveryChannelConfig,
|
||||
*,
|
||||
enabled: bool,
|
||||
user_id: str,
|
||||
) -> DeliveryChannelConfig:
|
||||
if not enabled:
|
||||
return method
|
||||
if not isinstance(method, EmailDeliveryMethod):
|
||||
return method
|
||||
if not method.config.debug_mode:
|
||||
return method
|
||||
debug_config = method.config.with_debug_recipient(user_id or "")
|
||||
return method.model_copy(update={"config": debug_config})
|
||||
|
||||
|
||||
class FormInputDefault(BaseModel):
|
||||
"""Default configuration for form inputs."""
|
||||
|
||||
# NOTE: Ideally, a discriminated union would be used to model
|
||||
# FormInputDefault. However, the UI requires preserving the previous
|
||||
# value when switching between `VARIABLE` and `CONSTANT` types. This
|
||||
# necessitates retaining all fields, making a discriminated union unsuitable.
|
||||
|
||||
type: PlaceholderType
|
||||
|
||||
# The selector of default variable, used when `type` is `VARIABLE`.
|
||||
selector: Sequence[str] = Field(default_factory=tuple) #
|
||||
|
||||
# The value of the default, used when `type` is `CONSTANT`.
|
||||
# TODO: How should we express JSON values?
|
||||
value: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_selector(self) -> Self:
|
||||
if self.type == PlaceholderType.CONSTANT:
|
||||
return self
|
||||
if len(self.selector) < SELECTORS_LENGTH:
|
||||
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
|
||||
return self
|
||||
|
||||
|
||||
class FormInput(BaseModel):
|
||||
"""Form input definition."""
|
||||
|
||||
type: FormInputType
|
||||
output_variable_name: str
|
||||
default: FormInputDefault | None = None
|
||||
|
||||
|
||||
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
class UserAction(BaseModel):
|
||||
"""User action configuration."""
|
||||
|
||||
# id is the identifier for this action.
|
||||
# It also serves as the identifiers of output handle.
|
||||
#
|
||||
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
|
||||
id: str = Field(max_length=20)
|
||||
title: str = Field(max_length=20)
|
||||
button_style: ButtonStyle = ButtonStyle.DEFAULT
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def _validate_id(cls, value: str) -> str:
|
||||
if not _IDENTIFIER_PATTERN.match(value):
|
||||
raise ValueError(
|
||||
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
|
||||
f"and contain only letters, numbers, or underscores."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Configuration schema for the HumanInput node."""
|
||||
"""Human Input node data."""
|
||||
|
||||
required_variables: list[str] = Field(default_factory=list)
|
||||
pause_reason: str | None = Field(default=None)
|
||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
||||
form_content: str = ""
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
user_actions: list[UserAction] = Field(default_factory=list)
|
||||
timeout: int = 36
|
||||
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
|
||||
|
||||
@field_validator("inputs")
|
||||
@classmethod
|
||||
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
|
||||
seen_names: set[str] = set()
|
||||
for form_input in inputs:
|
||||
name = form_input.output_variable_name
|
||||
if name in seen_names:
|
||||
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
|
||||
seen_names.add(name)
|
||||
return inputs
|
||||
|
||||
@field_validator("user_actions")
|
||||
@classmethod
|
||||
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
|
||||
seen_ids: set[str] = set()
|
||||
for action in user_actions:
|
||||
action_id = action.id
|
||||
if action_id in seen_ids:
|
||||
raise ValueError(f"duplicated user action id '{action_id}'")
|
||||
seen_ids.add(action_id)
|
||||
return user_actions
|
||||
|
||||
def is_webapp_enabled(self) -> bool:
|
||||
for dm in self.delivery_methods:
|
||||
if not dm.enabled:
|
||||
continue
|
||||
if dm.type == DeliveryMethodType.WEBAPP:
|
||||
return True
|
||||
return False
|
||||
|
||||
def expiration_time(self, start_time: datetime) -> datetime:
|
||||
if self.timeout_unit == TimeoutUnit.HOUR:
|
||||
return start_time + timedelta(hours=self.timeout)
|
||||
elif self.timeout_unit == TimeoutUnit.DAY:
|
||||
return start_time + timedelta(days=self.timeout)
|
||||
else:
|
||||
raise AssertionError("unknown timeout unit.")
|
||||
|
||||
def outputs_field_names(self) -> Sequence[str]:
|
||||
field_names = []
|
||||
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
|
||||
field_names.append(match.group("field_name"))
|
||||
return field_names
|
||||
|
||||
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
|
||||
variable_mappings: dict[str, Sequence[str]] = {}
|
||||
|
||||
def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
|
||||
for selector in selectors:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
continue
|
||||
qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
|
||||
variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
|
||||
|
||||
form_template_parser = VariableTemplateParser(template=self.form_content)
|
||||
_add_variable_selectors(
|
||||
[selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
|
||||
)
|
||||
for delivery_method in self.delivery_methods:
|
||||
if not delivery_method.enabled:
|
||||
continue
|
||||
_add_variable_selectors(delivery_method.extract_variable_selectors())
|
||||
|
||||
for input in self.inputs:
|
||||
default_value = input.default
|
||||
if default_value is None:
|
||||
continue
|
||||
if default_value.type == PlaceholderType.CONSTANT:
|
||||
continue
|
||||
default_value_key = ".".join(default_value.selector)
|
||||
qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
|
||||
variable_mappings[qualified_variable_mapping_key] = default_value.selector
|
||||
|
||||
return variable_mappings
|
||||
|
||||
def find_action_text(self, action_id: str) -> str:
|
||||
"""
|
||||
Resolve action display text by id.
|
||||
"""
|
||||
for action in self.user_actions:
|
||||
if action.id == action_id:
|
||||
return action.title
|
||||
return action_id
|
||||
|
||||
|
||||
class FormDefinition(BaseModel):
|
||||
form_content: str
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
user_actions: list[UserAction] = Field(default_factory=list)
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
|
||||
# this is used to store the resolved default values
|
||||
default_values: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# node_title records the title of the HumanInput node.
|
||||
node_title: str | None = None
|
||||
|
||||
# display_in_ui controls whether the form should be displayed in UI surfaces.
|
||||
display_in_ui: bool | None = None
|
||||
|
||||
|
||||
class HumanInputSubmissionValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def validate_human_input_submission(
|
||||
*,
|
||||
inputs: Sequence[FormInput],
|
||||
user_actions: Sequence[UserAction],
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
) -> None:
|
||||
available_actions = {action.id for action in user_actions}
|
||||
if selected_action_id not in available_actions:
|
||||
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
|
||||
|
||||
provided_inputs = set(form_data.keys())
|
||||
missing_inputs = [
|
||||
form_input.output_variable_name
|
||||
for form_input in inputs
|
||||
if form_input.output_variable_name not in provided_inputs
|
||||
]
|
||||
|
||||
if missing_inputs:
|
||||
missing_list = ", ".join(missing_inputs)
|
||||
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,72 @@
|
|||
import enum
|
||||
|
||||
|
||||
class HumanInputFormStatus(enum.StrEnum):
|
||||
"""Status of a human input form."""
|
||||
|
||||
# Awaiting submission from any recipient. Forms stay in this state until
|
||||
# submitted or a timeout rule applies.
|
||||
WAITING = enum.auto()
|
||||
# Global timeout reached. The workflow run is stopped and will not resume.
|
||||
# This is distinct from node-level timeout.
|
||||
EXPIRED = enum.auto()
|
||||
# Submitted by a recipient; form data is available and execution resumes
|
||||
# along the selected action edge.
|
||||
SUBMITTED = enum.auto()
|
||||
# Node-level timeout reached. The human input node should emit a timeout
|
||||
# event and the workflow should resume along the timeout edge.
|
||||
TIMEOUT = enum.auto()
|
||||
|
||||
|
||||
class HumanInputFormKind(enum.StrEnum):
|
||||
"""Kind of a human input form."""
|
||||
|
||||
RUNTIME = enum.auto() # Form created during workflow execution.
|
||||
DELIVERY_TEST = enum.auto() # Form created for delivery tests.
|
||||
|
||||
|
||||
class DeliveryMethodType(enum.StrEnum):
|
||||
"""Delivery method types for human input forms."""
|
||||
|
||||
# WEBAPP controls whether the form is delivered to the web app. It not only controls
|
||||
# the standalone web app, but also controls the installed apps in the console.
|
||||
WEBAPP = enum.auto()
|
||||
|
||||
EMAIL = enum.auto()
|
||||
|
||||
|
||||
class ButtonStyle(enum.StrEnum):
|
||||
"""Button styles for user actions."""
|
||||
|
||||
PRIMARY = enum.auto()
|
||||
DEFAULT = enum.auto()
|
||||
ACCENT = enum.auto()
|
||||
GHOST = enum.auto()
|
||||
|
||||
|
||||
class TimeoutUnit(enum.StrEnum):
|
||||
"""Timeout unit for form expiration."""
|
||||
|
||||
HOUR = enum.auto()
|
||||
DAY = enum.auto()
|
||||
|
||||
|
||||
class FormInputType(enum.StrEnum):
|
||||
"""Form input types."""
|
||||
|
||||
TEXT_INPUT = enum.auto()
|
||||
PARAGRAPH = enum.auto()
|
||||
|
||||
|
||||
class PlaceholderType(enum.StrEnum):
|
||||
"""Default value types for form inputs."""
|
||||
|
||||
VARIABLE = enum.auto()
|
||||
CONSTANT = enum.auto()
|
||||
|
||||
|
||||
class EmailRecipientType(enum.StrEnum):
|
||||
"""Email recipient types."""
|
||||
|
||||
MEMBER = enum.auto()
|
||||
EXTERNAL = enum.auto()
|
||||
|
|
@ -1,12 +1,42 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events import (
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
NodeRunResult,
|
||||
PauseRequestedEvent,
|
||||
)
|
||||
from core.workflow.node_events.base import NodeEventBase
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
|
||||
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
|
||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputNode(Node[HumanInputNodeData]):
|
||||
|
|
@ -17,7 +47,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||
"edge_source_handle",
|
||||
"edgeSourceHandle",
|
||||
"source_handle",
|
||||
"selected_branch",
|
||||
_SELECTED_BRANCH_KEY,
|
||||
"selectedBranch",
|
||||
"branch",
|
||||
"branch_id",
|
||||
|
|
@ -25,43 +55,37 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||
"handle",
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
_form_repository: HumanInputFormRepository
|
||||
_OUTPUT_FIELD_ACTION_ID = "__action_id"
|
||||
_OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
|
||||
_TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if form_repository is None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
self._form_repository = form_repository
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self): # type: ignore[override]
|
||||
if self._is_completion_ready():
|
||||
branch_handle = self._resolve_branch_selection()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
edge_source_handle=branch_handle or "source",
|
||||
)
|
||||
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
# TODO(QuantumGhost): yield a real form id.
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
if not self.node_data.required_variables:
|
||||
return False
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for selector_str in self.node_data.required_variables:
|
||||
parts = selector_str.split(".")
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
segment = variable_pool.get(parts)
|
||||
if segment is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_branch_selection(self) -> str | None:
|
||||
"""Determine the branch handle selected by human input if available."""
|
||||
|
||||
|
|
@ -108,3 +132,224 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def _workflow_execution_id(self) -> str:
|
||||
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
assert workflow_exec_id is not None
|
||||
return workflow_exec_id
|
||||
|
||||
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
|
||||
required_event = self._human_input_required_event(form_entity)
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return pause_requested_event
|
||||
|
||||
def resolve_default_values(self) -> Mapping[str, Any]:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
resolved_defaults: dict[str, Any] = {}
|
||||
for input in self._node_data.inputs:
|
||||
if (default_value := input.default) is None:
|
||||
continue
|
||||
if default_value.type == PlaceholderType.CONSTANT:
|
||||
continue
|
||||
resolved_value = variable_pool.get(default_value.selector)
|
||||
if resolved_value is None:
|
||||
# TODO: How should we handle this?
|
||||
continue
|
||||
resolved_defaults[input.output_variable_name] = (
|
||||
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
|
||||
)
|
||||
|
||||
return resolved_defaults
|
||||
|
||||
def _should_require_console_recipient(self) -> bool:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
return True
|
||||
if self.invoke_from == InvokeFrom.EXPLORE:
|
||||
return self._node_data.is_webapp_enabled()
|
||||
return False
|
||||
|
||||
def _display_in_ui(self) -> bool:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
return True
|
||||
return self._node_data.is_webapp_enabled()
|
||||
|
||||
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
|
||||
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
|
||||
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
||||
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
|
||||
return [
|
||||
apply_debug_email_recipient(
|
||||
method,
|
||||
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
|
||||
user_id=self.user_id or "",
|
||||
)
|
||||
for method in enabled_methods
|
||||
]
|
||||
|
||||
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
|
||||
node_data = self._node_data
|
||||
resolved_default_values = self.resolve_default_values()
|
||||
display_in_ui = self._display_in_ui()
|
||||
form_token = form_entity.web_app_token
|
||||
if display_in_ui and form_token is None:
|
||||
raise AssertionError("Form token should be available for UI execution.")
|
||||
return HumanInputRequired(
|
||||
form_id=form_entity.id,
|
||||
form_content=form_entity.rendered_content,
|
||||
inputs=node_data.inputs,
|
||||
actions=node_data.user_actions,
|
||||
display_in_ui=display_in_ui,
|
||||
node_id=self.id,
|
||||
node_title=node_data.title,
|
||||
form_token=form_token,
|
||||
resolved_default_values=resolved_default_values,
|
||||
)
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Execute the human input node.
|
||||
|
||||
This method will:
|
||||
1. Generate a unique form ID
|
||||
2. Create form content with variable substitution
|
||||
3. Create form in database
|
||||
4. Send form via configured delivery methods
|
||||
5. Suspend workflow execution
|
||||
6. Wait for form submission to resume
|
||||
"""
|
||||
repo = self._form_repository
|
||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
||||
if form is None:
|
||||
display_in_ui = self._display_in_ui()
|
||||
params = FormCreateParams(
|
||||
app_id=self.app_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self.render_form_content_before_submission(),
|
||||
delivery_methods=self._effective_delivery_methods(),
|
||||
display_in_ui=display_in_ui,
|
||||
resolved_default_values=self.resolve_default_values(),
|
||||
console_recipient_required=self._should_require_console_recipient(),
|
||||
console_creator_account_id=(
|
||||
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
|
||||
),
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
form_entity = self._form_repository.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
form_entity.id,
|
||||
)
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
return
|
||||
|
||||
if (
|
||||
form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
|
||||
or form.expiration_time <= naive_utc_now()
|
||||
):
|
||||
yield HumanInputFormTimeoutEvent(
|
||||
node_title=self._node_data.title,
|
||||
expiration_time=form.expiration_time,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
|
||||
edge_source_handle=self._TIMEOUT_HANDLE,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not form.submitted:
|
||||
yield self._form_to_pause_event(form)
|
||||
return
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if selected_action_id is None:
|
||||
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
|
||||
submitted_data = form.submitted_data or {}
|
||||
outputs: dict[str, Any] = dict(submitted_data)
|
||||
outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
|
||||
rendered_content = self.render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
outputs,
|
||||
self._node_data.outputs_field_names(),
|
||||
)
|
||||
outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
|
||||
|
||||
action_text = self._node_data.find_action_text(selected_action_id)
|
||||
|
||||
yield HumanInputFormFilledEvent(
|
||||
node_title=self._node_data.title,
|
||||
rendered_content=rendered_content,
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
edge_source_handle=selected_action_id,
|
||||
)
|
||||
)
|
||||
|
||||
def render_form_content_before_submission(self) -> str:
|
||||
"""
|
||||
Process form content by substituting variables.
|
||||
|
||||
This method should:
|
||||
1. Parse the form_content markdown
|
||||
2. Substitute {{#node_name.var_name#}} with actual values
|
||||
3. Keep {{#$output.field_name#}} placeholders for form inputs
|
||||
"""
|
||||
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
|
||||
self._node_data.form_content,
|
||||
)
|
||||
return rendered_form_content.markdown
|
||||
|
||||
@staticmethod
|
||||
def render_form_content_with_outputs(
|
||||
form_content: str,
|
||||
outputs: Mapping[str, Any],
|
||||
field_names: Sequence[str],
|
||||
) -> str:
|
||||
"""
|
||||
Replace {{#$output.xxx#}} placeholders with submitted values.
|
||||
"""
|
||||
rendered_content = form_content
|
||||
for field_name in field_names:
|
||||
placeholder = "{{#$output." + field_name + "#}}"
|
||||
value = outputs.get(field_name)
|
||||
if value is None:
|
||||
replacement = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
replacement = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
replacement = str(value)
|
||||
rendered_content = rendered_content.replace(placeholder, replacement)
|
||||
return rendered_content
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input default values.
|
||||
|
||||
This method should parse:
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input default values
|
||||
"""
|
||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,152 @@
|
|||
import abc
|
||||
import dataclasses
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
|
||||
|
||||
class HumanInputError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FormCreateParams:
|
||||
# app_id is the identifier for the app that the form belongs to.
|
||||
# It is a string with uuid format.
|
||||
app_id: str
|
||||
# None when creating a delivery test form; set for runtime forms.
|
||||
workflow_execution_id: str | None
|
||||
|
||||
# node_id is the identifier for a specific
|
||||
# node in the graph.
|
||||
#
|
||||
# TODO: for node inside loop / iteration, this would
|
||||
# cause problems, as a single node may be executed multiple times.
|
||||
node_id: str
|
||||
|
||||
form_config: HumanInputNodeData
|
||||
rendered_content: str
|
||||
# Delivery methods already filtered by runtime context (invoke_from).
|
||||
delivery_methods: Sequence[DeliveryChannelConfig]
|
||||
# UI display flag computed by runtime context.
|
||||
display_in_ui: bool
|
||||
|
||||
# resolved_default_values saves the values for defaults with
|
||||
# type = VARIABLE.
|
||||
#
|
||||
# For type = CONSTANT, the value is not stored inside `resolved_default_values`
|
||||
resolved_default_values: Mapping[str, Any]
|
||||
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
|
||||
|
||||
# Force creating a console-only recipient for submission in Console.
|
||||
console_recipient_required: bool = False
|
||||
console_creator_account_id: str | None = None
|
||||
# Force creating a backstage recipient for submission in Console.
|
||||
backstage_recipient_required: bool = False
|
||||
|
||||
|
||||
class HumanInputFormEntity(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> str:
|
||||
"""id returns the identifer of the form."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def web_app_token(self) -> str | None:
|
||||
"""web_app_token returns the token for submission inside webapp.
|
||||
|
||||
For console/debug execution, this may point to the console submission token
|
||||
if the form is configured to require console delivery.
|
||||
"""
|
||||
|
||||
# TODO: what if the users are allowed to add multiple
|
||||
# webapp delivery?
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def rendered_content(self) -> str:
|
||||
"""Rendered markdown content associated with the form."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def selected_action_id(self) -> str | None:
|
||||
"""Identifier of the selected user action if the form has been submitted."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
"""Submitted form data if available."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def submitted(self) -> bool:
|
||||
"""Whether the form has been submitted."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
"""Current status of the form."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def expiration_time(self) -> datetime:
|
||||
"""When the form expires."""
|
||||
...
|
||||
|
||||
|
||||
class HumanInputFormRecipientEntity(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> str:
|
||||
"""id returns the identifer of this recipient."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def token(self) -> str:
|
||||
"""token returns a random string used to submit form"""
|
||||
...
|
||||
|
||||
|
||||
class HumanInputFormRepository(Protocol):
|
||||
"""
|
||||
Repository interface for HumanInputForm.
|
||||
|
||||
This interface defines the contract for accessing and manipulating
|
||||
HumanInputForm data, regardless of the underlying storage mechanism.
|
||||
|
||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||
and other implementation details should be handled at the implementation level, not in
|
||||
the core interface. This keeps the core domain model clean and independent of specific
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
"""Get the form created for a given human input node in a workflow execution. Returns
|
||||
`None` if the form has not been created yet."""
|
||||
...
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
"""
|
||||
Create a human input form from form definition.
|
||||
"""
|
||||
...
|
||||
|
|
@ -6,14 +6,18 @@ import threading
|
|||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
class ReadyQueueProtocol(Protocol):
|
||||
"""Structural interface required from ready queue implementations."""
|
||||
|
|
@ -60,7 +64,7 @@ class GraphExecutionProtocol(Protocol):
|
|||
aborted: bool
|
||||
error: Exception | None
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
pause_reasons: Sequence[PauseReason]
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
|
|
@ -103,14 +107,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
|
|||
...
|
||||
|
||||
|
||||
class NodeProtocol(Protocol):
|
||||
"""Structural interface for graph nodes."""
|
||||
|
||||
id: str
|
||||
state: NodeState
|
||||
|
||||
|
||||
class EdgeProtocol(Protocol):
|
||||
id: str
|
||||
state: NodeState
|
||||
|
||||
|
||||
class GraphProtocol(Protocol):
|
||||
"""Structural interface required from graph instances attached to the runtime state."""
|
||||
|
||||
nodes: Mapping[str, object]
|
||||
edges: Mapping[str, object]
|
||||
root_node: object
|
||||
nodes: Mapping[str, NodeProtocol]
|
||||
edges: Mapping[str, EdgeProtocol]
|
||||
root_node: NodeProtocol
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
||||
|
||||
|
||||
class _GraphStateSnapshot(BaseModel):
|
||||
"""Serializable graph state snapshot for node/edge states."""
|
||||
|
||||
nodes: dict[str, NodeState] = Field(default_factory=dict)
|
||||
edges: dict[str, NodeState] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
|
@ -128,10 +151,20 @@ class _GraphRuntimeStateSnapshot:
|
|||
graph_execution_dump: str | None
|
||||
response_coordinator_dump: str | None
|
||||
paused_nodes: tuple[str, ...]
|
||||
deferred_nodes: tuple[str, ...]
|
||||
graph_node_states: dict[str, NodeState]
|
||||
graph_edge_states: dict[str, NodeState]
|
||||
|
||||
|
||||
class GraphRuntimeState:
|
||||
"""Mutable runtime state shared across graph execution components."""
|
||||
"""Mutable runtime state shared across graph execution components.
|
||||
|
||||
`GraphRuntimeState` encapsulates the runtime state of workflow execution,
|
||||
including scheduling details, variable values, and timing information.
|
||||
|
||||
Values that are initialized prior to workflow execution and remain constant
|
||||
throughout the execution should be part of `GraphInitParams` instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -169,6 +202,16 @@ class GraphRuntimeState:
|
|||
self._pending_response_coordinator_dump: str | None = None
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self._deferred_nodes: set[str] = set()
|
||||
|
||||
# Node and edges states needed to be restored into
|
||||
# graph object.
|
||||
#
|
||||
# These two fields are non-None only when resuming from a snapshot.
|
||||
# Once the graph is attached, these two fields will be set to None.
|
||||
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
||||
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
||||
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
if graph is not None:
|
||||
|
|
@ -190,6 +233,7 @@ class GraphRuntimeState:
|
|||
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
|
||||
self._response_coordinator.loads(self._pending_response_coordinator_dump)
|
||||
self._pending_response_coordinator_dump = None
|
||||
self._apply_pending_graph_state()
|
||||
|
||||
def configure(self, *, graph: GraphProtocol | None = None) -> None:
|
||||
"""Ensure core collaborators are initialized with the provided context."""
|
||||
|
|
@ -311,8 +355,13 @@ class GraphRuntimeState:
|
|||
"ready_queue": self.ready_queue.dumps(),
|
||||
"graph_execution": self.graph_execution.dumps(),
|
||||
"paused_nodes": list(self._paused_nodes),
|
||||
"deferred_nodes": list(self._deferred_nodes),
|
||||
}
|
||||
|
||||
graph_state = self._snapshot_graph_state()
|
||||
if graph_state is not None:
|
||||
snapshot["graph_state"] = graph_state
|
||||
|
||||
if self._response_coordinator is not None and self._graph is not None:
|
||||
snapshot["response_coordinator"] = self._response_coordinator.dumps()
|
||||
|
||||
|
|
@ -346,6 +395,11 @@ class GraphRuntimeState:
|
|||
|
||||
self._paused_nodes.add(node_id)
|
||||
|
||||
def get_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve the list of paused nodes without mutating internal state."""
|
||||
|
||||
return list(self._paused_nodes)
|
||||
|
||||
def consume_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve and clear the list of paused nodes awaiting resume."""
|
||||
|
||||
|
|
@ -353,6 +407,23 @@ class GraphRuntimeState:
|
|||
self._paused_nodes.clear()
|
||||
return nodes
|
||||
|
||||
def register_deferred_node(self, node_id: str) -> None:
|
||||
"""Record a node that became ready during pause and should resume later."""
|
||||
|
||||
self._deferred_nodes.add(node_id)
|
||||
|
||||
def get_deferred_nodes(self) -> list[str]:
|
||||
"""Retrieve deferred nodes without mutating internal state."""
|
||||
|
||||
return list(self._deferred_nodes)
|
||||
|
||||
def consume_deferred_nodes(self) -> list[str]:
|
||||
"""Retrieve and clear deferred nodes awaiting resume."""
|
||||
|
||||
nodes = list(self._deferred_nodes)
|
||||
self._deferred_nodes.clear()
|
||||
return nodes
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Builders
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -414,6 +485,10 @@ class GraphRuntimeState:
|
|||
graph_execution_payload = payload.get("graph_execution")
|
||||
response_payload = payload.get("response_coordinator")
|
||||
paused_nodes_payload = payload.get("paused_nodes", [])
|
||||
deferred_nodes_payload = payload.get("deferred_nodes", [])
|
||||
graph_state_payload = payload.get("graph_state", {}) or {}
|
||||
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
|
||||
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
|
||||
|
||||
return _GraphRuntimeStateSnapshot(
|
||||
start_at=start_at,
|
||||
|
|
@ -427,6 +502,9 @@ class GraphRuntimeState:
|
|||
graph_execution_dump=graph_execution_payload,
|
||||
response_coordinator_dump=response_payload,
|
||||
paused_nodes=tuple(map(str, paused_nodes_payload)),
|
||||
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
|
||||
graph_node_states=graph_node_states,
|
||||
graph_edge_states=graph_edge_states,
|
||||
)
|
||||
|
||||
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
|
||||
|
|
@ -442,6 +520,10 @@ class GraphRuntimeState:
|
|||
self._restore_graph_execution(snapshot.graph_execution_dump)
|
||||
self._restore_response_coordinator(snapshot.response_coordinator_dump)
|
||||
self._paused_nodes = set(snapshot.paused_nodes)
|
||||
self._deferred_nodes = set(snapshot.deferred_nodes)
|
||||
self._pending_graph_node_states = snapshot.graph_node_states or None
|
||||
self._pending_graph_edge_states = snapshot.graph_edge_states or None
|
||||
self._apply_pending_graph_state()
|
||||
|
||||
def _restore_ready_queue(self, payload: str | None) -> None:
|
||||
if payload is not None:
|
||||
|
|
@ -478,3 +560,68 @@ class GraphRuntimeState:
|
|||
|
||||
self._pending_response_coordinator_dump = payload
|
||||
self._response_coordinator = None
|
||||
|
||||
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
|
||||
graph = self._graph
|
||||
if graph is None:
|
||||
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
|
||||
return _GraphStateSnapshot()
|
||||
return _GraphStateSnapshot(
|
||||
nodes=self._pending_graph_node_states or {},
|
||||
edges=self._pending_graph_edge_states or {},
|
||||
)
|
||||
|
||||
nodes = graph.nodes
|
||||
edges = graph.edges
|
||||
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
|
||||
return _GraphStateSnapshot()
|
||||
|
||||
node_states = {}
|
||||
for node_id, node in nodes.items():
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
node_states[node_id] = node.state
|
||||
|
||||
edge_states = {}
|
||||
for edge_id, edge in edges.items():
|
||||
if not isinstance(edge_id, str):
|
||||
continue
|
||||
edge_states[edge_id] = edge.state
|
||||
|
||||
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
|
||||
|
||||
def _apply_pending_graph_state(self) -> None:
|
||||
if self._graph is None:
|
||||
return
|
||||
if self._pending_graph_node_states:
|
||||
for node_id, state in self._pending_graph_node_states.items():
|
||||
node = self._graph.nodes.get(node_id)
|
||||
if node is None:
|
||||
continue
|
||||
node.state = state
|
||||
if self._pending_graph_edge_states:
|
||||
for edge_id, state in self._pending_graph_edge_states.items():
|
||||
edge = self._graph.edges.get(edge_id)
|
||||
if edge is None:
|
||||
continue
|
||||
edge.state = state
|
||||
|
||||
self._pending_graph_node_states = None
|
||||
self._pending_graph_edge_states = None
|
||||
|
||||
|
||||
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
|
||||
if not isinstance(payload, Mapping):
|
||||
return {}
|
||||
raw_map = payload.get(key, {})
|
||||
if not isinstance(raw_map, Mapping):
|
||||
return {}
|
||||
result: dict[str, NodeState] = {}
|
||||
for node_id, raw_state in raw_map.items():
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
try:
|
||||
result[node_id] = NodeState(str(raw_state))
|
||||
except ValueError:
|
||||
continue
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter:
|
|||
def to_json_encodable(self, value: None) -> None: ...
|
||||
|
||||
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
result = self._to_json_encodable_recursive(value)
|
||||
"""Convert runtime values to JSON-serializable structures."""
|
||||
|
||||
result = self.value_to_json_encodable_recursive(value)
|
||||
if isinstance(result, Mapping) or result is None:
|
||||
return result
|
||||
return {}
|
||||
|
||||
def _to_json_encodable_recursive(self, value: Any):
|
||||
def value_to_json_encodable_recursive(self, value: Any):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, (bool, int, str, float)):
|
||||
|
|
@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter:
|
|||
# Convert Decimal to float for JSON serialization
|
||||
return float(value)
|
||||
if isinstance(value, Segment):
|
||||
return self._to_json_encodable_recursive(value.value)
|
||||
return self.value_to_json_encodable_recursive(value.value)
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
if isinstance(value, BaseModel):
|
||||
|
|
@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter:
|
|||
if isinstance(value, dict):
|
||||
res = {}
|
||||
for k, v in value.items():
|
||||
res[k] = self._to_json_encodable_recursive(v)
|
||||
res[k] = self.value_to_json_encodable_recursive(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
for item in value:
|
||||
res_list.append(self._to_json_encodable_recursive(item))
|
||||
res_list.append(self.value_to_json_encodable_recursive(item))
|
||||
return res_list
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
fi
|
||||
else
|
||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||
|
|
@ -102,7 +102,7 @@ elif [[ "${MODE}" == "job" ]]; then
|
|||
fi
|
||||
|
||||
echo "Running Flask job command: flask $*"
|
||||
|
||||
|
||||
# Temporarily disable exit on error to capture exit code
|
||||
set +e
|
||||
flask "$@"
|
||||
|
|
|
|||
|
|
@ -151,6 +151,12 @@ def init_app(app: DifyApp) -> Celery:
|
|||
"task": "schedule.queue_monitor_task.queue_monitor_task",
|
||||
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
|
||||
}
|
||||
if dify_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK:
|
||||
imports.append("tasks.human_input_timeout_tasks")
|
||||
beat_schedule["human_input_form_timeout"] = {
|
||||
"task": "human_input_form_timeout.check_and_resume",
|
||||
"schedule": timedelta(minutes=dify_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL),
|
||||
}
|
||||
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
|
||||
imports.append("schedule.check_upgradable_plugin_task")
|
||||
imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")
|
||||
|
|
|
|||
|
|
@ -8,12 +8,16 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
|
|||
import redis
|
||||
from redis import RedisError
|
||||
from redis.cache import CacheConfig
|
||||
from redis.client import PubSub
|
||||
from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
from redis.sentinel import Sentinel
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.lock import Lock
|
||||
|
|
@ -106,6 +110,7 @@ class RedisClientWrapper:
|
|||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
|
||||
def zcard(self, name: str | bytes) -> Any: ...
|
||||
def getdel(self, name: str | bytes) -> Any: ...
|
||||
def pubsub(self) -> PubSub: ...
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
if self._client is None:
|
||||
|
|
@ -114,6 +119,7 @@ class RedisClientWrapper:
|
|||
|
||||
|
||||
redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||
|
||||
|
||||
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
||||
|
|
@ -226,6 +232,12 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
|||
return client
|
||||
|
||||
|
||||
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
|
||||
if use_clusters:
|
||||
return RedisCluster.from_url(pubsub_url)
|
||||
return redis.Redis.from_url(pubsub_url)
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
"""Initialize Redis client and attach it to the app."""
|
||||
global redis_client
|
||||
|
|
@ -244,6 +256,24 @@ def init_app(app: DifyApp):
|
|||
redis_client.initialize(client)
|
||||
app.extensions["redis"] = redis_client
|
||||
|
||||
pubsub_client = client
|
||||
if dify_config.normalized_pubsub_redis_url:
|
||||
pubsub_client = _create_pubsub_client(
|
||||
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
|
||||
)
|
||||
pubsub_redis_client.initialize(pubsub_client)
|
||||
|
||||
|
||||
def get_pubsub_redis_client() -> RedisClientWrapper:
|
||||
return pubsub_redis_client
|
||||
|
||||
|
||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||
redis_conn = get_pubsub_redis_client()
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
|
||||
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from typing import Any
|
|||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from extensions.logstore.repositories import safe_float, safe_int
|
||||
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
|
||||
|
|
@ -207,8 +208,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
reverse=True,
|
||||
)
|
||||
|
||||
if deduplicated_results:
|
||||
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
|
||||
for row in deduplicated_results:
|
||||
model = _dict_to_workflow_node_execution_model(row)
|
||||
if model.status != WorkflowNodeExecutionStatus.PAUSED:
|
||||
return model
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -309,6 +312,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
if model and model.id: # Ensure model is valid
|
||||
models.append(model)
|
||||
|
||||
models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED]
|
||||
|
||||
# Sort by index DESC for trace visualization
|
||||
models.sort(key=lambda x: x.index, reverse=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -192,6 +192,7 @@ class StatusCount(ResponseModel):
|
|||
success: int
|
||||
failed: int
|
||||
partial_success: int
|
||||
paused: int
|
||||
|
||||
|
||||
class ModelConfig(ResponseModel):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from uuid import uuid4
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
from core.file import File
|
||||
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
||||
|
||||
|
|
@ -61,6 +62,7 @@ class MessageListItem(ResponseModel):
|
|||
message_files: list[MessageFile]
|
||||
status: str
|
||||
error: str | None = None
|
||||
extra_contents: list[ExecutionExtraContentDomainModel]
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ class RedisSubscriptionBase(Subscription):
|
|||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
"""Receive the next message from the subscription."""
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
||||
|
|
|
|||
|
|
@ -61,7 +61,14 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
|||
|
||||
def _get_message(self) -> dict | None:
|
||||
assert self._pubsub is not None
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined]
|
||||
# NOTE(QuantumGhost): this is an issue in
|
||||
# upstream code. If Sharded PubSub is used with Cluster, the
|
||||
# `ClusterPubSub.get_sharded_message` will return `None` regardless of
|
||||
# message['type'].
|
||||
#
|
||||
# Since we have already filtered at the caller's site, we can safely set
|
||||
# `ignore_subscribe_messages=False`.
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
|
||||
|
||||
def _get_message_type(self) -> str:
|
||||
return "smessage"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
"""
|
||||
Email template rendering helpers with configurable safety modes.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import render_template_string
|
||||
from jinja2.runtime import Context
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
|
||||
from configs import dify_config
|
||||
from configs.feature import TemplateMode
|
||||
|
||||
|
||||
class SandboxedEnvironment(ImmutableSandboxedEnvironment):
|
||||
"""Sandboxed environment with execution timeout."""
|
||||
|
||||
def __init__(self, timeout: int, *args: Any, **kwargs: Any):
|
||||
self._deadline = time.time() + timeout if timeout else None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if self._deadline is not None and time.time() > self._deadline:
|
||||
raise TimeoutError("Template rendering timeout")
|
||||
return super().call(context, obj, *args, **kwargs)
|
||||
|
||||
|
||||
def render_email_template(template: str, substitutions: Mapping[str, str]) -> str:
|
||||
"""
|
||||
Render email template content according to the configured template mode.
|
||||
|
||||
In unsafe mode, Jinja expressions are evaluated directly.
|
||||
In sandbox mode, a sandboxed environment with timeout is used.
|
||||
In disabled mode, the template is returned without rendering.
|
||||
"""
|
||||
mode = dify_config.MAIL_TEMPLATING_MODE
|
||||
timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
|
||||
|
||||
if mode == TemplateMode.UNSAFE:
|
||||
return render_template_string(template, **substitutions)
|
||||
if mode == TemplateMode.SANDBOX:
|
||||
env = SandboxedEnvironment(timeout=timeout)
|
||||
tmpl = env.from_string(template)
|
||||
return tmpl.render(substitutions)
|
||||
if mode == TemplateMode.DISABLED:
|
||||
return template
|
||||
raise ValueError(f"Unsupported mail templating mode: {mode}")
|
||||
|
|
@ -1,12 +1,15 @@
|
|||
import contextvars
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeVar
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from flask import Flask, g
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import Account, EndUser
|
||||
|
||||
|
||||
@contextmanager
|
||||
def preserve_flask_contexts(
|
||||
|
|
@ -64,3 +67,7 @@ def preserve_flask_contexts(
|
|||
finally:
|
||||
# Any cleanup can be added here if needed
|
||||
pass
|
||||
|
||||
|
||||
def set_login_user(user: "Account | EndUser"):
|
||||
g._login_user = user
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ import struct
|
|||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
|
||||
from uuid import UUID
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
|
|
@ -126,6 +126,13 @@ class TimestampField(fields.Raw):
|
|||
return int(value.timestamp())
|
||||
|
||||
|
||||
class OptionalTimestampField(fields.Raw):
|
||||
def format(self, value) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def email(email):
|
||||
# Define a regex pattern for email addresses
|
||||
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
|
||||
|
|
@ -237,6 +244,26 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"):
|
|||
|
||||
|
||||
def generate_string(n):
|
||||
"""
|
||||
Generates a cryptographically secure random string of the specified length.
|
||||
|
||||
This function uses a cryptographically secure pseudorandom number generator (CSPRNG)
|
||||
to create a string composed of ASCII letters (both uppercase and lowercase) and digits.
|
||||
|
||||
Each character in the generated string provides approximately 5.95 bits of entropy
|
||||
(log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the
|
||||
length of the string (`n`) should be at least 22 characters.
|
||||
|
||||
Args:
|
||||
n (int): The length of the random string to generate. For secure usage,
|
||||
`n` should be 22 or greater.
|
||||
|
||||
Returns:
|
||||
str: A random string of length `n` composed of ASCII letters and digits.
|
||||
|
||||
Note:
|
||||
This function is suitable for generating credentials or other secure tokens.
|
||||
"""
|
||||
letters_digits = string.ascii_letters + string.digits
|
||||
result = ""
|
||||
for _ in range(n):
|
||||
|
|
@ -405,11 +432,35 @@ class TokenManager:
|
|||
return f"{token_type}:account:{account_id}"
|
||||
|
||||
|
||||
class _RateLimiterRedisClient(Protocol):
|
||||
def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ...
|
||||
|
||||
def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ...
|
||||
|
||||
def zcard(self, name: str | bytes) -> int: ...
|
||||
|
||||
def expire(self, name: str | bytes, time: int) -> bool: ...
|
||||
|
||||
|
||||
def _default_rate_limit_member_factory() -> str:
|
||||
current_time = int(time.time())
|
||||
return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}"
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, prefix: str, max_attempts: int, time_window: int):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
max_attempts: int,
|
||||
time_window: int,
|
||||
member_factory: Callable[[], str] = _default_rate_limit_member_factory,
|
||||
redis_client: _RateLimiterRedisClient = redis_client,
|
||||
):
|
||||
self.prefix = prefix
|
||||
self.max_attempts = max_attempts
|
||||
self.time_window = time_window
|
||||
self._member_factory = member_factory
|
||||
self._redis_client = redis_client
|
||||
|
||||
def _get_key(self, email: str) -> str:
|
||||
return f"{self.prefix}:{email}"
|
||||
|
|
@ -419,8 +470,8 @@ class RateLimiter:
|
|||
current_time = int(time.time())
|
||||
window_start_time = current_time - self.time_window
|
||||
|
||||
redis_client.zremrangebyscore(key, "-inf", window_start_time)
|
||||
attempts = redis_client.zcard(key)
|
||||
self._redis_client.zremrangebyscore(key, "-inf", window_start_time)
|
||||
attempts = self._redis_client.zcard(key)
|
||||
|
||||
if attempts and int(attempts) >= self.max_attempts:
|
||||
return True
|
||||
|
|
@ -428,7 +479,8 @@ class RateLimiter:
|
|||
|
||||
def increment_rate_limit(self, email: str):
|
||||
key = self._get_key(email)
|
||||
member = self._member_factory()
|
||||
current_time = int(time.time())
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.expire(key, self.time_window * 2)
|
||||
self._redis_client.zadd(key, {member: current_time})
|
||||
self._redis_client.expire(key, self.time_window * 2)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
"""Add human input related db models
|
||||
|
||||
Revision ID: e8c3b3c46151
|
||||
Revises: 788d3099ae3a
|
||||
Create Date: 2026-01-29 14:15:23.081903
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e8c3b3c46151"
|
||||
down_revision = "788d3099ae3a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"execution_extra_contents",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
|
||||
sa.Column("type", sa.String(length=30), nullable=False),
|
||||
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("message_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("form_id", models.types.StringUUID(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("execution_extra_contents_pkey")),
|
||||
)
|
||||
with op.batch_alter_table("execution_extra_contents", schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f("execution_extra_contents_message_id_idx"), ["message_id"], unique=False)
|
||||
batch_op.create_index(
|
||||
batch_op.f("execution_extra_contents_workflow_run_id_idx"), ["workflow_run_id"], unique=False
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"human_input_form_deliveries",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
|
||||
sa.Column("form_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("delivery_method_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("delivery_config_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("channel_payload", sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_deliveries_pkey")),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"human_input_form_recipients",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
|
||||
sa.Column("form_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("delivery_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("recipient_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("recipient_payload", sa.Text(), nullable=False),
|
||||
sa.Column("access_token", sa.VARCHAR(length=32), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_recipients_pkey")),
|
||||
)
|
||||
with op.batch_alter_table('human_input_form_recipients', schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(batch_op.f('human_input_form_recipients_access_token_key'), ['access_token'])
|
||||
|
||||
op.create_table(
|
||||
"human_input_forms",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("form_kind", sa.String(length=20), nullable=False),
|
||||
sa.Column("node_id", sa.String(length=60), nullable=False),
|
||||
sa.Column("form_definition", sa.Text(), nullable=False),
|
||||
sa.Column("rendered_content", sa.Text(), nullable=False),
|
||||
sa.Column("status", sa.String(length=20), nullable=False),
|
||||
sa.Column("expiration_time", sa.DateTime(), nullable=False),
|
||||
sa.Column("selected_action_id", sa.String(length=200), nullable=True),
|
||||
sa.Column("submitted_data", sa.Text(), nullable=True),
|
||||
sa.Column("submitted_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("submission_user_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("submission_end_user_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("completed_by_recipient_id", models.types.StringUUID(), nullable=True),
|
||||
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_forms_pkey")),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("human_input_forms")
|
||||
op.drop_table("human_input_form_recipients")
|
||||
op.drop_table("human_input_form_deliveries")
|
||||
op.drop_table("execution_extra_contents")
|
||||
|
|
@ -34,6 +34,8 @@ from .enums import (
|
|||
WorkflowRunTriggeredFrom,
|
||||
WorkflowTriggerStatus,
|
||||
)
|
||||
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
|
||||
from .human_input import HumanInputForm
|
||||
from .model import (
|
||||
AccountTrialAppRecord,
|
||||
ApiRequest,
|
||||
|
|
@ -155,9 +157,12 @@ __all__ = [
|
|||
"DocumentSegment",
|
||||
"Embedding",
|
||||
"EndUser",
|
||||
"ExecutionExtraContent",
|
||||
"ExporleBanner",
|
||||
"ExternalKnowledgeApis",
|
||||
"ExternalKnowledgeBindings",
|
||||
"HumanInputContent",
|
||||
"HumanInputForm",
|
||||
"IconType",
|
||||
"InstalledApp",
|
||||
"InvitationCode",
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class DefaultFieldsMixin:
|
|||
)
|
||||
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
__name_pos=DateTime,
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=naive_utc_now,
|
||||
server_default=func.current_timestamp(),
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ class MessageStatus(StrEnum):
|
|||
"""
|
||||
|
||||
NORMAL = "normal"
|
||||
PAUSED = "paused"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .base import Base, DefaultFieldsMixin
|
||||
from .types import EnumText, StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .human_input import HumanInputForm
|
||||
|
||||
|
||||
class ExecutionContentType(StrEnum):
|
||||
HUMAN_INPUT = auto()
|
||||
|
||||
|
||||
class ExecutionExtraContent(DefaultFieldsMixin, Base):
|
||||
"""ExecutionExtraContent stores extra contents produced during workflow / chatflow execution."""
|
||||
|
||||
# The `ExecutionExtraContent` uses single table inheritance to model different
|
||||
# kinds of contents produced during message generation.
|
||||
#
|
||||
# See: https://docs.sqlalchemy.org/en/20/orm/inheritance.html#single-table-inheritance
|
||||
|
||||
__tablename__ = "execution_extra_contents"
|
||||
__mapper_args__ = {
|
||||
"polymorphic_abstract": True,
|
||||
"polymorphic_on": "type",
|
||||
"with_polymorphic": "*",
|
||||
}
|
||||
# type records the type of the content. It serves as the `discriminator` for the
|
||||
# single table inheritance.
|
||||
type: Mapped[ExecutionContentType] = mapped_column(
|
||||
EnumText(ExecutionContentType, length=30),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# `workflow_run_id` records the workflow execution which generates this content, correspond to
|
||||
# `WorkflowRun.id`.
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
|
||||
|
||||
# `message_id` records the messages generated by the execution associated with this `ExecutionExtraContent`.
|
||||
# It references to `Message.id`.
|
||||
#
|
||||
# For workflow execution, this field is `None`.
|
||||
#
|
||||
# For chatflow execution, `message_id`` is not None, and the following condition holds:
|
||||
#
|
||||
# The message referenced by `message_id` has `message.workflow_run_id == execution_extra_content.workflow_run_id`
|
||||
#
|
||||
message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, index=True)
|
||||
|
||||
|
||||
class HumanInputContent(ExecutionExtraContent):
|
||||
"""HumanInputContent is a concrete class that represents human input content.
|
||||
It should only be initialized with the `new` class method."""
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT,
|
||||
}
|
||||
|
||||
# A relation to HumanInputForm table.
|
||||
#
|
||||
# While the form_id column is nullable in database (due to the nature of single table inheritance),
|
||||
# the form_id field should not be null for a given `HumanInputContent` instance.
|
||||
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
|
||||
return cls(form_id=form_id, message_id=message_id)
|
||||
|
||||
form: Mapped["HumanInputForm"] = relationship(
|
||||
"HumanInputForm",
|
||||
foreign_keys=[form_id],
|
||||
uselist=False,
|
||||
lazy="raise",
|
||||
primaryjoin="foreign(HumanInputContent.form_id) == HumanInputForm.id",
|
||||
)
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Self, final
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
DeliveryMethodType,
|
||||
HumanInputFormKind,
|
||||
HumanInputFormStatus,
|
||||
)
|
||||
from libs.helper import generate_string
|
||||
|
||||
from .base import Base, DefaultFieldsMixin
|
||||
from .types import EnumText, StringUUID
|
||||
|
||||
_token_length = 22
|
||||
# A 32-character string can store a base64-encoded value with 192 bits of entropy
|
||||
# or a base62-encoded value with over 180 bits of entropy, providing sufficient
|
||||
# uniqueness for most use cases.
|
||||
_token_field_length = 32
|
||||
_email_field_length = 330
|
||||
|
||||
|
||||
def _generate_token() -> str:
|
||||
return generate_string(_token_length)
|
||||
|
||||
|
||||
class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_forms"
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
form_kind: Mapped[HumanInputFormKind] = mapped_column(
|
||||
EnumText(HumanInputFormKind),
|
||||
nullable=False,
|
||||
default=HumanInputFormKind.RUNTIME,
|
||||
)
|
||||
|
||||
# The human input node the current form corresponds to.
|
||||
node_id: Mapped[str] = mapped_column(sa.String(60), nullable=False)
|
||||
form_definition: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
rendered_content: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
status: Mapped[HumanInputFormStatus] = mapped_column(
|
||||
EnumText(HumanInputFormStatus),
|
||||
nullable=False,
|
||||
default=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
|
||||
expiration_time: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Submission-related fields (nullable until a submission happens).
|
||||
selected_action_id: Mapped[str | None] = mapped_column(sa.String(200), nullable=True)
|
||||
submitted_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
|
||||
submitted_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
|
||||
submission_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
submission_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
completed_by_recipient_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
deliveries: Mapped[list["HumanInputDelivery"]] = relationship(
|
||||
"HumanInputDelivery",
|
||||
primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)",
|
||||
uselist=True,
|
||||
back_populates="form",
|
||||
lazy="raise",
|
||||
)
|
||||
completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship(
|
||||
"HumanInputFormRecipient",
|
||||
primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)",
|
||||
lazy="raise",
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
|
||||
class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_form_deliveries"
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=False,
|
||||
)
|
||||
delivery_method_type: Mapped[DeliveryMethodType] = mapped_column(
|
||||
EnumText(DeliveryMethodType),
|
||||
nullable=False,
|
||||
)
|
||||
delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
channel_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
form: Mapped[HumanInputForm] = relationship(
|
||||
"HumanInputForm",
|
||||
uselist=False,
|
||||
foreign_keys=[form_id],
|
||||
primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id",
|
||||
back_populates="deliveries",
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
recipients: Mapped[list["HumanInputFormRecipient"]] = relationship(
|
||||
"HumanInputFormRecipient",
|
||||
primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)",
|
||||
uselist=True,
|
||||
back_populates="delivery",
|
||||
# Require explicit preloading
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
|
||||
class RecipientType(StrEnum):
|
||||
# EMAIL_MEMBER member means that the
|
||||
EMAIL_MEMBER = "email_member"
|
||||
EMAIL_EXTERNAL = "email_external"
|
||||
# STANDALONE_WEB_APP is used by the standalone web app.
|
||||
#
|
||||
# It's not used while running workflows / chatflows containing HumanInput
|
||||
# node inside console.
|
||||
STANDALONE_WEB_APP = "standalone_web_app"
|
||||
# CONSOLE is used while running workflows / chatflows containing HumanInput
|
||||
# node inside console. (E.G. running installed apps or debugging workflows / chatflows)
|
||||
CONSOLE = "console"
|
||||
# BACKSTAGE is used for backstage input inside console.
|
||||
BACKSTAGE = "backstage"
|
||||
|
||||
|
||||
@final
|
||||
class EmailMemberRecipientPayload(BaseModel):
|
||||
TYPE: Literal[RecipientType.EMAIL_MEMBER] = RecipientType.EMAIL_MEMBER
|
||||
user_id: str
|
||||
|
||||
# The `email` field here is only used for mail sending.
|
||||
email: str
|
||||
|
||||
|
||||
@final
|
||||
class EmailExternalRecipientPayload(BaseModel):
|
||||
TYPE: Literal[RecipientType.EMAIL_EXTERNAL] = RecipientType.EMAIL_EXTERNAL
|
||||
email: str
|
||||
|
||||
|
||||
@final
|
||||
class StandaloneWebAppRecipientPayload(BaseModel):
|
||||
TYPE: Literal[RecipientType.STANDALONE_WEB_APP] = RecipientType.STANDALONE_WEB_APP
|
||||
|
||||
|
||||
@final
|
||||
class ConsoleRecipientPayload(BaseModel):
|
||||
TYPE: Literal[RecipientType.CONSOLE] = RecipientType.CONSOLE
|
||||
account_id: str | None = None
|
||||
|
||||
|
||||
@final
|
||||
class BackstageRecipientPayload(BaseModel):
|
||||
TYPE: Literal[RecipientType.BACKSTAGE] = RecipientType.BACKSTAGE
|
||||
account_id: str | None = None
|
||||
|
||||
|
||||
@final
|
||||
class ConsoleDeliveryPayload(BaseModel):
|
||||
type: Literal["console"] = "console"
|
||||
internal: bool = True
|
||||
|
||||
|
||||
RecipientPayload = Annotated[
|
||||
EmailMemberRecipientPayload
|
||||
| EmailExternalRecipientPayload
|
||||
| StandaloneWebAppRecipientPayload
|
||||
| ConsoleRecipientPayload
|
||||
| BackstageRecipientPayload,
|
||||
Field(discriminator="TYPE"),
|
||||
]
|
||||
|
||||
|
||||
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_form_recipients"
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=False,
|
||||
)
|
||||
delivery_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=False,
|
||||
)
|
||||
recipient_type: Mapped["RecipientType"] = mapped_column(EnumText(RecipientType), nullable=False)
|
||||
recipient_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
# Token primarily used for authenticated resume links (email, etc.).
|
||||
access_token: Mapped[str | None] = mapped_column(
|
||||
sa.VARCHAR(_token_field_length),
|
||||
nullable=False,
|
||||
default=_generate_token,
|
||||
unique=True,
|
||||
)
|
||||
|
||||
delivery: Mapped[HumanInputDelivery] = relationship(
|
||||
"HumanInputDelivery",
|
||||
uselist=False,
|
||||
foreign_keys=[delivery_id],
|
||||
back_populates="recipients",
|
||||
primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id",
|
||||
# Require explicit preloading
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
form: Mapped[HumanInputForm] = relationship(
|
||||
"HumanInputForm",
|
||||
uselist=False,
|
||||
foreign_keys=[form_id],
|
||||
primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id",
|
||||
# Require explicit preloading
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
payload: RecipientPayload,
|
||||
) -> Self:
|
||||
recipient_model = cls(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=payload.TYPE,
|
||||
recipient_payload=payload.model_dump_json(),
|
||||
access_token=_generate_token(),
|
||||
)
|
||||
return recipient_model
|
||||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
|
|
@ -943,6 +943,7 @@ class Conversation(Base):
|
|||
WorkflowExecutionStatus.FAILED: 0,
|
||||
WorkflowExecutionStatus.STOPPED: 0,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0,
|
||||
WorkflowExecutionStatus.PAUSED: 0,
|
||||
}
|
||||
|
||||
for message in messages:
|
||||
|
|
@ -963,6 +964,7 @@ class Conversation(Base):
|
|||
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
|
||||
"failed": status_counts[WorkflowExecutionStatus.FAILED],
|
||||
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
|
||||
"paused": status_counts[WorkflowExecutionStatus.PAUSED],
|
||||
}
|
||||
|
||||
@property
|
||||
|
|
@ -1345,6 +1347,14 @@ class Message(Base):
|
|||
db.session.commit()
|
||||
return result
|
||||
|
||||
# TODO(QuantumGhost): dirty hacks, fix this later.
|
||||
def set_extra_contents(self, contents: Sequence[dict[str, Any]]) -> None:
|
||||
self._extra_contents = list(contents)
|
||||
|
||||
@property
|
||||
def extra_contents(self) -> list[dict[str, Any]]:
|
||||
return getattr(self, "_extra_contents", [])
|
||||
|
||||
@property
|
||||
def workflow_run(self):
|
||||
if self.workflow_run_id:
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from sqlalchemy import (
|
|||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.file.constants import maybe_file_object
|
||||
from core.file.models import File
|
||||
|
|
@ -30,7 +31,7 @@ from core.workflow.constants import (
|
|||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.enums import NodeType, WorkflowExecutionStatus
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -405,6 +406,11 @@ class Workflow(Base): # bug
|
|||
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"This property is not accurate for determining if a workflow is published as a tool."
|
||||
"It only checks if there's a WorkflowToolProvider for the app, "
|
||||
"not if this specific workflow version is the one being used by the tool."
|
||||
)
|
||||
def tool_published(self) -> bool:
|
||||
"""
|
||||
DEPRECATED: This property is not accurate for determining if a workflow is published as a tool.
|
||||
|
|
@ -607,13 +613,16 @@ class WorkflowRun(Base):
|
|||
version: Mapped[str] = mapped_column(String(255))
|
||||
graph: Mapped[str | None] = mapped_column(LongText)
|
||||
inputs: Mapped[str | None] = mapped_column(LongText)
|
||||
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
||||
status: Mapped[WorkflowExecutionStatus] = mapped_column(
|
||||
EnumText(WorkflowExecutionStatus, length=255),
|
||||
nullable=False,
|
||||
)
|
||||
outputs: Mapped[str | None] = mapped_column(LongText, default="{}")
|
||||
error: Mapped[str | None] = mapped_column(LongText)
|
||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) # account, end_user
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
|
|
@ -629,11 +638,13 @@ class WorkflowRun(Base):
|
|||
)
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def created_by_end_user(self):
|
||||
from .model import EndUser
|
||||
|
||||
|
|
@ -653,6 +664,7 @@ class WorkflowRun(Base):
|
|||
return json.loads(self.outputs) if self.outputs else {}
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def message(self):
|
||||
from .model import Message
|
||||
|
||||
|
|
@ -661,6 +673,7 @@ class WorkflowRun(Base):
|
|||
)
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def workflow(self):
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
|
||||
|
|
@ -1861,7 +1874,12 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
|
|||
|
||||
def to_entity(self) -> PauseReason:
|
||||
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
return HumanInputRequired(form_id=self.form_id, node_id=self.node_id)
|
||||
return HumanInputRequired(
|
||||
form_id=self.form_id,
|
||||
form_content="",
|
||||
node_id=self.node_id,
|
||||
node_title="",
|
||||
)
|
||||
elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
|
||||
return SchedulingPause(message=self.message)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m
|
|||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
|
||||
|
|
@ -19,6 +20,27 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
|
|||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WorkflowNodeExecutionSnapshot:
|
||||
"""
|
||||
Minimal snapshot of workflow node execution for stream recovery.
|
||||
|
||||
Only includes fields required by snapshot events.
|
||||
"""
|
||||
|
||||
execution_id: str # Unique execution identifier (node_execution_id or row id).
|
||||
node_id: str # Workflow graph node id.
|
||||
node_type: str # Workflow graph node type (e.g. "human-input").
|
||||
title: str # Human-friendly node title.
|
||||
index: int # Execution order index within the workflow run.
|
||||
status: str # Execution status (running/succeeded/failed/paused).
|
||||
elapsed_time: float # Execution elapsed time in seconds.
|
||||
created_at: datetime # Execution created timestamp.
|
||||
finished_at: datetime | None # Execution finished timestamp.
|
||||
iteration_id: str | None = None # Iteration id from execution metadata, if any.
|
||||
loop_id: str | None = None # Loop id from execution metadata, if any.
|
||||
|
||||
|
||||
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
|
||||
"""
|
||||
Protocol for service-layer operations on WorkflowNodeExecutionModel.
|
||||
|
|
@ -79,6 +101,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
|
|||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_id: The workflow identifier
|
||||
triggered_from: The workflow trigger source
|
||||
workflow_run_id: The workflow run identifier
|
||||
|
||||
Returns:
|
||||
|
|
@ -86,6 +110,27 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
|
|||
"""
|
||||
...
|
||||
|
||||
def get_execution_snapshots_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
triggered_from: str,
|
||||
workflow_run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionSnapshot]:
|
||||
"""
|
||||
Get minimal snapshots for node executions in a workflow run.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_run_id: The workflow run identifier
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionSnapshot ordered by creation time
|
||||
"""
|
||||
...
|
||||
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
|
|
|
|||
|
|
@ -432,6 +432,13 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||
# while creating pause.
|
||||
...
|
||||
|
||||
def get_workflow_pause(self, workflow_run_id: str) -> WorkflowPauseEntity | None:
|
||||
"""Retrieve the current pause for a workflow execution.
|
||||
|
||||
If there is no current pause, this method would return `None`.
|
||||
"""
|
||||
...
|
||||
|
||||
def resume_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
|
@ -627,3 +634,19 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||
[{"date": "2024-01-01", "interactions": 2.5}, ...]
|
||||
"""
|
||||
...
|
||||
|
||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
||||
"""
|
||||
Get a specific workflow run by its id and the associated tenant id.
|
||||
|
||||
This function does not apply application isolation. It should only be used when
|
||||
the application identifier is not available.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
run_id: Workflow run identifier
|
||||
|
||||
Returns:
|
||||
WorkflowRun object if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -63,6 +63,12 @@ class WorkflowPauseEntity(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def paused_at(self) -> datetime:
|
||||
"""`paused_at` returns the creation time of the pause."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||
"""
|
||||
|
|
@ -70,7 +76,5 @@ class WorkflowPauseEntity(ABC):
|
|||
|
||||
Returns a sequence of `PauseReason` objects describing the specific nodes and
|
||||
reasons for which the workflow execution was paused.
|
||||
This information is related to, but distinct from, the `PauseReason` type
|
||||
defined in `api/core/workflow/entities/pause_reason.py`.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
|
||||
|
||||
class ExecutionExtraContentRepository(Protocol):
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ...
|
||||
|
||||
|
||||
__all__ = ["ExecutionExtraContentRepository"]
|
||||
|
|
@ -5,6 +5,7 @@ This module provides a concrete implementation of the service repository protoco
|
|||
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
|
@ -13,11 +14,12 @@ from sqlalchemy import asc, delete, desc, func, select
|
|||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionOffload,
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||
from repositories.api_workflow_node_execution_repository import (
|
||||
DifyAPIWorkflowNodeExecutionRepository,
|
||||
WorkflowNodeExecutionSnapshot,
|
||||
)
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||
|
|
@ -79,6 +81,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
||||
WorkflowNodeExecutionModel.node_id == node_id,
|
||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
||||
)
|
||||
.order_by(desc(WorkflowNodeExecutionModel.created_at))
|
||||
.limit(1)
|
||||
|
|
@ -117,6 +120,80 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||
with self._session_maker() as session:
|
||||
return session.execute(stmt).scalars().all()
|
||||
|
||||
def get_execution_snapshots_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
triggered_from: str,
|
||||
workflow_run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionSnapshot]:
|
||||
stmt = (
|
||||
select(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.node_execution_id,
|
||||
WorkflowNodeExecutionModel.node_id,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.index,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.finished_at,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
)
|
||||
.order_by(
|
||||
asc(WorkflowNodeExecutionModel.created_at),
|
||||
asc(WorkflowNodeExecutionModel.index),
|
||||
)
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
return [self._row_to_snapshot(row) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
|
||||
metadata: dict[str, object] = {}
|
||||
execution_metadata = getattr(row, "execution_metadata", None)
|
||||
if execution_metadata:
|
||||
try:
|
||||
metadata = json.loads(execution_metadata)
|
||||
except json.JSONDecodeError:
|
||||
metadata = {}
|
||||
iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value)
|
||||
loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value)
|
||||
execution_id = getattr(row, "node_execution_id", None) or row.id
|
||||
elapsed_time = getattr(row, "elapsed_time", None)
|
||||
created_at = row.created_at
|
||||
finished_at = getattr(row, "finished_at", None)
|
||||
if elapsed_time is None:
|
||||
if finished_at is not None and created_at is not None:
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
else:
|
||||
elapsed_time = 0.0
|
||||
return WorkflowNodeExecutionSnapshot(
|
||||
execution_id=str(execution_id),
|
||||
node_id=row.node_id,
|
||||
node_type=row.node_type,
|
||||
title=row.title,
|
||||
index=row.index,
|
||||
status=row.status,
|
||||
elapsed_time=float(elapsed_time),
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
iteration_id=str(iteration_id) if iteration_id else None,
|
||||
loop_id=str(loop_id) if loop_id else None,
|
||||
)
|
||||
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ Implementation Notes:
|
|||
- Maintains data consistency with proper transaction handling
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable, Sequence
|
||||
|
|
@ -27,12 +28,14 @@ from decimal import Decimal
|
|||
from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import and_, delete, func, null, or_, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import convert_datetime_to_date
|
||||
|
|
@ -40,6 +43,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
|||
from libs.time_parser import get_time_threshold
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
|
|
@ -57,6 +61,67 @@ class _WorkflowRunError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def _select_recipient_token(
|
||||
recipients: Sequence[HumanInputFormRecipient],
|
||||
recipient_type: RecipientType,
|
||||
) -> str | None:
|
||||
for recipient in recipients:
|
||||
if recipient.recipient_type == recipient_type and recipient.access_token:
|
||||
return recipient.access_token
|
||||
return None
|
||||
|
||||
|
||||
def _build_human_input_required_reason(
|
||||
reason_model: WorkflowPauseReason,
|
||||
form_model: HumanInputForm | None,
|
||||
recipients: Sequence[HumanInputFormRecipient],
|
||||
) -> HumanInputRequired:
|
||||
form_content = ""
|
||||
inputs = []
|
||||
actions = []
|
||||
display_in_ui = False
|
||||
resolved_default_values: dict[str, Any] = {}
|
||||
node_title = "Human Input"
|
||||
form_id = reason_model.form_id
|
||||
node_id = reason_model.node_id
|
||||
if form_model is not None:
|
||||
form_id = form_model.id
|
||||
node_id = form_model.node_id or node_id
|
||||
try:
|
||||
definition_payload = json.loads(form_model.form_definition)
|
||||
if "expiration_time" not in definition_payload:
|
||||
definition_payload["expiration_time"] = form_model.expiration_time
|
||||
definition = FormDefinition.model_validate(definition_payload)
|
||||
except ValidationError:
|
||||
definition = None
|
||||
|
||||
if definition is not None:
|
||||
form_content = definition.form_content
|
||||
inputs = list(definition.inputs)
|
||||
actions = list(definition.user_actions)
|
||||
display_in_ui = bool(definition.display_in_ui)
|
||||
resolved_default_values = dict(definition.default_values)
|
||||
node_title = definition.node_title or node_title
|
||||
|
||||
form_token = (
|
||||
_select_recipient_token(recipients, RecipientType.BACKSTAGE)
|
||||
or _select_recipient_token(recipients, RecipientType.CONSOLE)
|
||||
or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP)
|
||||
)
|
||||
|
||||
return HumanInputRequired(
|
||||
form_id=form_id,
|
||||
form_content=form_content,
|
||||
inputs=inputs,
|
||||
actions=actions,
|
||||
display_in_ui=display_in_ui,
|
||||
node_id=node_id,
|
||||
node_title=node_title,
|
||||
form_token=form_token,
|
||||
resolved_default_values=resolved_default_values,
|
||||
)
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of APIWorkflowRunRepository.
|
||||
|
|
@ -676,9 +741,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||
|
||||
# Check if workflow is in RUNNING status
|
||||
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||
# TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status`
|
||||
# happens before the execution of GraphLayer
|
||||
if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}:
|
||||
raise _WorkflowRunError(
|
||||
f"Only WorkflowRun with RUNNING status can be paused, "
|
||||
f"Only WorkflowRun with RUNNING or PAUSED status can be paused, "
|
||||
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
|
||||
)
|
||||
#
|
||||
|
|
@ -729,13 +796,48 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
|
||||
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
|
||||
return _PrivateWorkflowPauseEntity(
|
||||
pause_model=pause_model,
|
||||
reason_models=pause_reason_models,
|
||||
pause_reasons=pause_reasons,
|
||||
)
|
||||
|
||||
def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
|
||||
reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
|
||||
pause_reason_models = session.scalars(reason_stmt).all()
|
||||
return pause_reason_models
|
||||
|
||||
def _hydrate_pause_reasons(
|
||||
self,
|
||||
session: Session,
|
||||
pause_reason_models: Sequence[WorkflowPauseReason],
|
||||
) -> list[PauseReason]:
|
||||
form_ids = [
|
||||
reason.form_id
|
||||
for reason in pause_reason_models
|
||||
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id
|
||||
]
|
||||
form_models: dict[str, HumanInputForm] = {}
|
||||
recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {}
|
||||
if form_ids:
|
||||
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
|
||||
for form in session.scalars(form_stmt).all():
|
||||
form_models[form.id] = form
|
||||
|
||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(recipient_stmt).all():
|
||||
recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient)
|
||||
|
||||
pause_reasons: list[PauseReason] = []
|
||||
for reason in pause_reason_models:
|
||||
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_model = form_models.get(reason.form_id)
|
||||
recipients = recipient_models_by_form.get(reason.form_id, [])
|
||||
pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients))
|
||||
else:
|
||||
pause_reasons.append(reason.to_entity())
|
||||
return pause_reasons
|
||||
|
||||
def get_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
|
@ -767,14 +869,12 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
if pause_model is None:
|
||||
return None
|
||||
pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
|
||||
|
||||
human_input_form: list[Any] = []
|
||||
# TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
|
||||
pause_reasons = self._hydrate_pause_reasons(session, pause_reason_models)
|
||||
|
||||
return _PrivateWorkflowPauseEntity(
|
||||
pause_model=pause_model,
|
||||
reason_models=pause_reason_models,
|
||||
human_input_form=human_input_form,
|
||||
pause_reasons=pause_reasons,
|
||||
)
|
||||
|
||||
def resume_workflow_pause(
|
||||
|
|
@ -828,10 +928,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
|
||||
|
||||
pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
|
||||
hydrated_pause_reasons = self._hydrate_pause_reasons(session, pause_reasons)
|
||||
|
||||
# Mark as resumed
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
workflow_run.pause_id = None # type: ignore
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
|
||||
session.add(pause_model)
|
||||
|
|
@ -839,7 +939,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
|
||||
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)
|
||||
return _PrivateWorkflowPauseEntity(
|
||||
pause_model=pause_model,
|
||||
reason_models=pause_reasons,
|
||||
pause_reasons=hydrated_pause_reasons,
|
||||
)
|
||||
|
||||
def delete_workflow_pause(
|
||||
self,
|
||||
|
|
@ -1165,6 +1269,15 @@ GROUP BY
|
|||
|
||||
return cast(list[AverageInteractionStats], response_data)
|
||||
|
||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
||||
"""Get a specific workflow run by its id and the associated tenant id."""
|
||||
with self._session_maker() as session:
|
||||
stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.id == run_id,
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
|
||||
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
"""
|
||||
|
|
@ -1179,10 +1292,12 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
|||
*,
|
||||
pause_model: WorkflowPause,
|
||||
reason_models: Sequence[WorkflowPauseReason],
|
||||
pause_reasons: Sequence[PauseReason] | None = None,
|
||||
human_input_form: Sequence = (),
|
||||
) -> None:
|
||||
self._pause_model = pause_model
|
||||
self._reason_models = reason_models
|
||||
self._pause_reasons = pause_reasons
|
||||
self._cached_state: bytes | None = None
|
||||
self._human_input_form = human_input_form
|
||||
|
||||
|
|
@ -1219,4 +1334,10 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
|||
return self._pause_model.resumed_at
|
||||
|
||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||
if self._pause_reasons is not None:
|
||||
return list(self._pause_reasons)
|
||||
return [reason.to_entity() for reason in self._reason_models]
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return self._pause_model.created_at
|
||||
|
|
|
|||
|
|
@ -0,0 +1,200 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.entities.execution_extra_content import (
|
||||
ExecutionExtraContentDomainModel,
|
||||
HumanInputFormDefinition,
|
||||
HumanInputFormSubmissionData,
|
||||
)
|
||||
from core.entities.execution_extra_content import (
|
||||
HumanInputContent as HumanInputContentDomainModel,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from models.execution_extra_content import (
|
||||
ExecutionExtraContent as ExecutionExtraContentModel,
|
||||
)
|
||||
from models.execution_extra_content import (
|
||||
HumanInputContent as HumanInputContentModel,
|
||||
)
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
||||
|
||||
|
||||
def _extract_output_field_names(form_content: str) -> list[str]:
|
||||
if not form_content:
|
||||
return []
|
||||
return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)]
|
||||
|
||||
|
||||
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
self._session_maker = session_maker
|
||||
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
|
||||
if not message_ids:
|
||||
return []
|
||||
|
||||
grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = {
|
||||
message_id: [] for message_id in message_ids
|
||||
}
|
||||
|
||||
stmt = (
|
||||
select(ExecutionExtraContentModel)
|
||||
.where(ExecutionExtraContentModel.message_id.in_(message_ids))
|
||||
.options(selectinload(HumanInputContentModel.form))
|
||||
.order_by(ExecutionExtraContentModel.created_at.asc())
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
results = session.scalars(stmt).all()
|
||||
|
||||
form_ids = {
|
||||
content.form_id
|
||||
for content in results
|
||||
if isinstance(content, HumanInputContentModel) and content.form_id is not None
|
||||
}
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list)
|
||||
if form_ids:
|
||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
recipients = session.scalars(recipient_stmt).all()
|
||||
for recipient in recipients:
|
||||
recipients_by_form_id[recipient.form_id].append(recipient)
|
||||
else:
|
||||
recipients_by_form_id = {}
|
||||
|
||||
for content in results:
|
||||
message_id = content.message_id
|
||||
if not message_id or message_id not in grouped_contents:
|
||||
continue
|
||||
|
||||
domain_model = self._map_model_to_domain(content, recipients_by_form_id)
|
||||
if domain_model is None:
|
||||
continue
|
||||
|
||||
grouped_contents[message_id].append(domain_model)
|
||||
|
||||
return [grouped_contents[message_id] for message_id in message_ids]
|
||||
|
||||
def _map_model_to_domain(
|
||||
self,
|
||||
model: ExecutionExtraContentModel,
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
|
||||
) -> ExecutionExtraContentDomainModel | None:
|
||||
if isinstance(model, HumanInputContentModel):
|
||||
return self._map_human_input_content(model, recipients_by_form_id)
|
||||
|
||||
logger.debug("Unsupported execution extra content type encountered: %s", model.type)
|
||||
return None
|
||||
|
||||
def _map_human_input_content(
|
||||
self,
|
||||
model: HumanInputContentModel,
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
|
||||
) -> HumanInputContentDomainModel | None:
|
||||
form = model.form
|
||||
if form is None:
|
||||
logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
|
||||
return None
|
||||
|
||||
try:
|
||||
definition_payload = json.loads(form.form_definition)
|
||||
if "expiration_time" not in definition_payload:
|
||||
definition_payload["expiration_time"] = form.expiration_time
|
||||
form_definition = FormDefinition.model_validate(definition_payload)
|
||||
except ValueError:
|
||||
logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
|
||||
return None
|
||||
node_title = form_definition.node_title or form.node_id
|
||||
display_in_ui = bool(form_definition.display_in_ui)
|
||||
|
||||
submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
|
||||
if not submitted:
|
||||
form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, []))
|
||||
return HumanInputContentDomainModel(
|
||||
workflow_run_id=model.workflow_run_id,
|
||||
submitted=False,
|
||||
form_definition=HumanInputFormDefinition(
|
||||
form_id=form.id,
|
||||
node_id=form.node_id,
|
||||
node_title=node_title,
|
||||
form_content=form.rendered_content,
|
||||
inputs=form_definition.inputs,
|
||||
actions=form_definition.user_actions,
|
||||
display_in_ui=display_in_ui,
|
||||
form_token=form_token,
|
||||
resolved_default_values=form_definition.default_values,
|
||||
expiration_time=int(form.expiration_time.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if not selected_action_id:
|
||||
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
|
||||
return None
|
||||
|
||||
action_text = next(
|
||||
(action.title for action in form_definition.user_actions if action.id == selected_action_id),
|
||||
selected_action_id,
|
||||
)
|
||||
|
||||
submitted_data: dict[str, Any] = {}
|
||||
if form.submitted_data:
|
||||
try:
|
||||
submitted_data = json.loads(form.submitted_data)
|
||||
except ValueError:
|
||||
logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id)
|
||||
return None
|
||||
|
||||
rendered_content = HumanInputNode.render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
submitted_data,
|
||||
_extract_output_field_names(form_definition.form_content),
|
||||
)
|
||||
|
||||
return HumanInputContentDomainModel(
|
||||
workflow_run_id=model.workflow_run_id,
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id=form.node_id,
|
||||
node_title=node_title,
|
||||
rendered_content=rendered_content,
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None:
|
||||
console_recipient = next(
|
||||
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE),
|
||||
None,
|
||||
)
|
||||
if console_recipient and console_recipient.access_token:
|
||||
return console_recipient.access_token
|
||||
|
||||
web_app_recipient = next(
|
||||
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP),
|
||||
None,
|
||||
)
|
||||
if web_app_recipient and web_app_recipient.access_token:
|
||||
return web_app_recipient.access_token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]
|
||||
|
|
@ -92,6 +92,16 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
|||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
|
||||
"""Get the trigger log associated with a workflow run."""
|
||||
query = (
|
||||
select(WorkflowTriggerLog)
|
||||
.where(WorkflowTriggerLog.workflow_run_id == workflow_run_id)
|
||||
.order_by(WorkflowTriggerLog.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return self.session.scalar(query)
|
||||
|
||||
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||
"""
|
||||
Delete trigger logs associated with the given workflow run ids.
|
||||
|
|
|
|||
|
|
@ -110,6 +110,18 @@ class WorkflowTriggerLogRepository(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
|
||||
"""
|
||||
Retrieve a trigger log associated with a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: Identifier of the workflow run
|
||||
|
||||
Returns:
|
||||
The matching WorkflowTriggerLog if present, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||
"""
|
||||
Delete trigger logs for workflow run IDs.
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
|
|||
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
|
||||
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.5.0"
|
||||
CURRENT_DSL_VERSION = "0.6.0"
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue