import logging
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Any

from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.repositories.human_input_repository import (
    HumanInputFormRecord,
    HumanInputFormSubmissionRepository,
)
from core.workflow.nodes.human_input.entities import (
    FormDefinition,
    HumanInputSubmissionValidationError,
    validate_human_input_submission,
)
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from libs.exception import BaseHTTPException
from models.human_input import RecipientType
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution

logger = logging.getLogger(__name__)


class Form:
    """Read-only view over a ``HumanInputFormRecord``.

    Every accessor simply forwards to the wrapped record; the wrapper holds
    no state of its own.
    """

    def __init__(self, record: HumanInputFormRecord):
        self._record = record

    def get_definition(self) -> FormDefinition:
        """Return the form definition stored on the underlying record."""
        return self._record.definition

    @property
    def submitted(self) -> bool:
        """Whether the underlying record is flagged as submitted."""
        return self._record.submitted

    @property
    def id(self) -> str:
        """The record's ``form_id``."""
        return self._record.form_id

    @property
    def workflow_run_id(self) -> str | None:
        """Workflow run id for runtime forms; None for delivery tests."""
        return self._record.workflow_run_id

    @property
    def tenant_id(self) -> str:
        return self._record.tenant_id

    @property
    def app_id(self) -> str:
        return self._record.app_id

    @property
    def recipient_id(self) -> str | None:
        return self._record.recipient_id

    @property
    def recipient_type(self) -> RecipientType | None:
        return self._record.recipient_type

    @property
    def status(self) -> HumanInputFormStatus:
        return self._record.status

    @property
    def form_kind(self) -> HumanInputFormKind:
        return self._record.form_kind

    @property
    def created_at(self) -> "datetime":
        return self._record.created_at

    @property
    def expiration_time(self) -> "datetime":
        return self._record.expiration_time


class HumanInputError(Exception):
    """Root of the exception hierarchy raised by this module."""


class FormSubmittedError(HumanInputError, BaseHTTPException):
    """Raised when a form has already been submitted (HTTP 412)."""

    error_code = "human_input_form_submitted"
    description = "This form has already been submitted by another user, form_id={form_id}"
    code = 412

    def __init__(self, form_id: str):
        # The class-level description doubles as the message template; the
        # literal fallback is used only if that attribute is falsy.
        message_template = self.description or "This form has already been submitted by another user, form_id={form_id}"
        super().__init__(description=message_template.format(form_id=form_id))


class FormNotFoundError(HumanInputError, BaseHTTPException):
    """Form lookup failure (HTTP 404)."""

    error_code = "human_input_form_not_found"
    code = 404


class InvalidFormDataError(HumanInputError, BaseHTTPException):
    """Submitted payload failed validation (HTTP 400)."""

    error_code = "invalid_form_data"
    code = 400

    def __init__(self, description: str):
        super().__init__(description=description)


# NOTE(review): unlike its siblings this derives from BaseException rather than
# BaseHTTPException -- confirm that is intentional before changing it.
class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
    """Raised by ``submit_form_by_token`` when no matching form is available."""


class FormExpiredError(HumanInputError, BaseHTTPException):
    """Raised when a form is no longer accepting submissions (HTTP 412)."""

    error_code = "human_input_form_expired"
    code = 412

    def __init__(self, form_id: str):
        super().__init__(description=f"This form has expired, form_id={form_id}")


class HumanInputService:
    """Look up, validate and submit human-input forms, then resume the workflow."""

    def __init__(
        self,
        session_factory: sessionmaker[Session] | Engine,
        form_repository: HumanInputFormSubmissionRepository | None = None,
    ):
        """Store the session factory and form repository.

        Args:
            session_factory: A ``sessionmaker`` or an ``Engine``; an engine is
                wrapped in a new ``sessionmaker`` bound to it.
            form_repository: Repository to use; when falsy, a
                ``HumanInputFormSubmissionRepository`` is built from the
                (possibly wrapped) session factory.
        """
        factory = sessionmaker(bind=session_factory) if isinstance(session_factory, Engine) else session_factory
        self._session_factory = factory
        self._form_repository = form_repository or HumanInputFormSubmissionRepository(factory)

    def get_form_by_token(self, form_token: str) -> Form | None:
        """Return the form identified by ``form_token``, or None if the repository has no record."""
        record = self._form_repository.get_by_token(form_token)
        return None if record is None else Form(record)

    def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None:
        """Return the not-yet-submitted form for ``form_token`` and ``recipient_type``.

        Returns:
            The form, or None when no form exists or its recipient type differs.

        Raises:
            FormSubmittedError: If the form has already been submitted.
        """
        form = self.get_form_by_token(form_token)
        if form is None:
            return None
        if form.recipient_type != recipient_type:
            return None

        self._ensure_not_submitted(form)
        return form

    def get_form_definition_by_token_for_console(self, form_token: str) -> Form | None:
        """Return the not-yet-submitted form for ``form_token`` if it targets console or backstage.

        Returns:
            The form, or None when no form exists or its recipient type is
            neither ``CONSOLE`` nor ``BACKSTAGE``.

        Raises:
            FormSubmittedError: If the form has already been submitted.
        """
        form = self.get_form_by_token(form_token)
        if form is None:
            return None
        if form.recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
            return None

        self._ensure_not_submitted(form)
        return form

    def submit_form_by_token(
        self,
        recipient_type: RecipientType,
        form_token: str,
        selected_action_id: str,
        form_data: Mapping[str, Any],
        submission_end_user_id: str | None = None,
        submission_user_id: str | None = None,
    ):
        """Validate and persist a submission, then resume the workflow for runtime forms.

        Raises:
            WebAppDeliveryNotEnabledError: No form for the token, or its
                recipient type does not match ``recipient_type``.
            FormSubmittedError: The form was already submitted.
            FormExpiredError: The form is expired or timed out.
            InvalidFormDataError: The payload failed validation.
        """
        form = self.get_form_by_token(form_token)
        if form is None or form.recipient_type != recipient_type:
            raise WebAppDeliveryNotEnabledError()

        self.ensure_form_active(form)
        self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)

        submitted = self._form_repository.mark_submitted(
            form_id=form.id,
            recipient_id=form.recipient_id,
            selected_action_id=selected_action_id,
            form_data=form_data,
            submission_user_id=submission_user_id,
            submission_end_user_id=submission_end_user_id,
        )

        # Only runtime forms that carry a workflow run id trigger a resume.
        is_runtime = submitted.form_kind == HumanInputFormKind.RUNTIME
        if is_runtime and submitted.workflow_run_id is not None:
            self.enqueue_resume(submitted.workflow_run_id)

    def ensure_form_active(self, form: Form) -> None:
        """Raise unless ``form`` can still accept a submission.

        Raises:
            FormSubmittedError: The form is already submitted.
            FormExpiredError: The status is ``TIMEOUT``/``EXPIRED``, the form's
                own expiration time has been reached, or the global timeout
                has elapsed.
        """
        if form.submitted:
            raise FormSubmittedError(form.id)

        expired_statuses = {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
        if form.status in expired_statuses:
            raise FormExpiredError(form.id)

        # Take a single timestamp so both deadline checks compare against the same instant.
        now = naive_utc_now()
        past_own_deadline = ensure_naive_utc(form.expiration_time) <= now
        if past_own_deadline or self._is_globally_expired(form, now=now):
            raise FormExpiredError(form.id)

    def _ensure_not_submitted(self, form: Form) -> None:
        """Raise ``FormSubmittedError`` if ``form`` has already been submitted."""
        if form.submitted:
            raise FormSubmittedError(form.id)

    def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
        """Validate the payload against the form definition.

        Raises:
            InvalidFormDataError: Wraps any ``HumanInputSubmissionValidationError``.
        """
        form_definition = form.get_definition()
        try:
            validate_human_input_submission(
                inputs=form_definition.inputs,
                user_actions=form_definition.user_actions,
                selected_action_id=selected_action_id,
                form_data=form_data,
            )
        except HumanInputSubmissionValidationError as validation_error:
            raise InvalidFormDataError(str(validation_error)) from validation_error

    def enqueue_resume(self, workflow_run_id: str) -> None:
        """Queue the resume task for ``workflow_run_id`` if its app mode supports it.

        A missing app or an unsupported app mode is logged and ignored; a
        failure while enqueuing is logged and swallowed.

        Raises:
            AssertionError: If no workflow run exists for ``workflow_run_id``.
        """
        run_repository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
        workflow_run = run_repository.get_workflow_run_by_id_without_tenant(workflow_run_id)
        if workflow_run is None:
            raise AssertionError(f"WorkflowRun not found, id={workflow_run_id}")

        with self._session_factory(expire_on_commit=False) as session:
            app = session.execute(select(App).where(App.id == workflow_run.app_id)).scalar_one_or_none()

            if app is None:
                logger.error(
                    "App not found for WorkflowRun, workflow_run_id=%s, app_id=%s", workflow_run_id, workflow_run.app_id
                )
                return

            if app.mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
                logger.warning("App mode %s does not support resume for workflow run %s", app.mode, workflow_run_id)
                return

            try:
                resume_app_execution.apply_async(
                    kwargs={"payload": {"workflow_run_id": workflow_run_id}},
                    queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE,
                )
            except Exception:  # pragma: no cover
                logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)

    def _is_globally_expired(self, form: Form, *, now: datetime | None = None) -> bool:
        """Return True when the global HITL timeout, counted from creation, has elapsed.

        The check is skipped (returns False) when
        ``dify_config.HITL_GLOBAL_TIMEOUT_SECONDS`` is not positive or when
        the form has no workflow run id.

        Args:
            form: The form to check.
            now: Reference time; ``naive_utc_now()`` is used when falsy.
        """
        timeout_seconds = dify_config.HITL_GLOBAL_TIMEOUT_SECONDS
        if timeout_seconds <= 0 or form.workflow_run_id is None:
            return False

        reference_time = now or naive_utc_now()
        deadline = ensure_naive_utc(form.created_at) + timedelta(seconds=timeout_seconds)
        return deadline <= reference_time