"""Human Input node entities.

The graph package owns the workflow-facing form schema and keeps it
transportable across runtimes. Dify-specific delivery surfaces and recipient
translation stay outside `dify_graph`.
"""

import re
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Any, Self

from pydantic import BaseModel, Field, field_validator, model_validator

from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.variables.consts import SELECTORS_LENGTH

from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit

# Matches output placeholders of the form `{{#$output.<name>#}}` inside the form
# content. The named group `field_name` (an identifier of at most 30 characters)
# is read by `HumanInputNodeData.outputs_field_names`.
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")


class FormInputDefault(BaseModel):
    """Default configuration for form inputs."""

    # NOTE: Ideally, a discriminated union would be used to model
    # FormInputDefault. However, the UI requires preserving the previous
    # value when switching between `VARIABLE` and `CONSTANT` types. This
    # necessitates retaining all fields, making a discriminated union unsuitable.

    type: PlaceholderType

    # The selector of the default variable, used when `type` is `VARIABLE`.
    selector: Sequence[str] = Field(default_factory=tuple)

    # The value of the default, used when `type` is `CONSTANT`.
    # TODO: How should we express JSON values?
    value: str = ""

    @model_validator(mode="after")
    def _validate_selector(self) -> Self:
        """Require at least `SELECTORS_LENGTH` selector parts unless the default is a constant.

        Raises:
            ValueError: if a non-constant default has a selector that is too short.
        """
        if self.type == PlaceholderType.CONSTANT:
            return self
        if len(self.selector) < SELECTORS_LENGTH:
            raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
        return self


class FormInput(BaseModel):
    """Form input definition."""

    type: FormInputType
    output_variable_name: str
    default: FormInputDefault | None = None


# Valid identifier: starts with a letter or underscore, followed by letters,
# digits, or underscores. Use `fullmatch` with this pattern: under `match`, the
# trailing `$` also matches just before a final newline.
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


class UserAction(BaseModel):
    """User action configuration."""

    # id is the identifier for this action.
    # It also serves as the identifier of the corresponding output handle.
    #
    # The id must be a valid identifier (i.e. satisfy `_IDENTIFIER_PATTERN` above).
    id: str = Field(max_length=20)

    title: str = Field(max_length=20)
    button_style: ButtonStyle = ButtonStyle.DEFAULT

    @field_validator("id")
    @classmethod
    def _validate_id(cls, value: str) -> str:
        """Reject ids that are not plain identifiers.

        Raises:
            ValueError: if `value` does not fully match `_IDENTIFIER_PATTERN`.
        """
        # `fullmatch` (not `match`) so that ids such as "approve\n" are rejected.
        if not _IDENTIFIER_PATTERN.fullmatch(value):
            raise ValueError(
                f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
                f"and contain only letters, numbers, or underscores."
            )
        return value


class HumanInputNodeData(BaseNodeData):
    """Human Input node data."""

    type: NodeType = BuiltinNodeTypes.HUMAN_INPUT
    form_content: str = ""
    inputs: list[FormInput] = Field(default_factory=list)
    user_actions: list[UserAction] = Field(default_factory=list)
    timeout: int = 36
    timeout_unit: TimeoutUnit = TimeoutUnit.HOUR

    @field_validator("inputs")
    @classmethod
    def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
        """Reject inputs that share an `output_variable_name`.

        Raises:
            ValueError: on the first duplicated name.
        """
        seen_names: set[str] = set()
        for form_input in inputs:
            name = form_input.output_variable_name
            if name in seen_names:
                raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
            seen_names.add(name)
        return inputs

    @field_validator("user_actions")
    @classmethod
    def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
        """Reject user actions that share an `id`.

        Raises:
            ValueError: on the first duplicated id.
        """
        seen_ids: set[str] = set()
        for action in user_actions:
            action_id = action.id
            if action_id in seen_ids:
                raise ValueError(f"duplicated user action id '{action_id}'")
            seen_ids.add(action_id)
        return user_actions

    def expiration_time(self, start_time: datetime) -> datetime:
        """Return `start_time` shifted by `timeout` hours or days, depending on `timeout_unit`.

        Raises:
            AssertionError: if `timeout_unit` is neither HOUR nor DAY.
        """
        if self.timeout_unit == TimeoutUnit.HOUR:
            return start_time + timedelta(hours=self.timeout)
        elif self.timeout_unit == TimeoutUnit.DAY:
            return start_time + timedelta(days=self.timeout)
        else:
            raise AssertionError("unknown timeout unit.")

    def outputs_field_names(self) -> Sequence[str]:
        """Return the field names of `{{#$output.<name>#}}` placeholders in `form_content`, in order of appearance."""
        return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content)]

    def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
        """Collect the variable selectors this node reads, keyed as `"{node_id}.#a.b#"`.

        Two sources are scanned:
        * variable references in `form_content`. These are truncated to their
          first `SELECTORS_LENGTH` parts; shorter selectors are skipped.
        * `VARIABLE`-type input defaults. These use their full selector.

        Each value in the returned mapping is a fresh list, so mutating it does
        not affect this model.
        """
        variable_mappings: dict[str, Sequence[str]] = {}

        def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
            for selector in selectors:
                if len(selector) < SELECTORS_LENGTH:
                    continue
                qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
                variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])

        form_template_parser = VariableTemplateParser(template=self.form_content)
        _add_variable_selectors(
            [selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
        )

        for form_input in self.inputs:
            default_value = form_input.default
            if default_value is None:
                continue
            if default_value.type == PlaceholderType.CONSTANT:
                continue
            default_value_key = ".".join(default_value.selector)
            qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
            # Copy so callers cannot mutate the model's selector through the mapping.
            variable_mappings[qualified_variable_mapping_key] = list(default_value.selector)

        return variable_mappings

    def find_action_text(self, action_id: str) -> str:
        """Resolve action display text by id.

        Returns the title of the first user action whose id equals `action_id`,
        or `action_id` itself when no such action exists.
        """
        for action in self.user_actions:
            if action.id == action_id:
                return action.title
        return action_id


class FormDefinition(BaseModel):
    """Form schema together with its rendered content and expiration time."""

    form_content: str
    inputs: list[FormInput] = Field(default_factory=list)
    user_actions: list[UserAction] = Field(default_factory=list)
    rendered_content: str
    expiration_time: datetime

    # This is used to store the resolved default values.
    default_values: dict[str, Any] = Field(default_factory=dict)

    # node_title records the title of the HumanInput node.
    node_title: str | None = None

    # display_in_ui controls whether the form should be displayed in UI surfaces.
    display_in_ui: bool | None = None


class HumanInputSubmissionValidationError(ValueError):
    """Raised by `validate_human_input_submission` when a submission is rejected."""

    pass


def validate_human_input_submission(
    *,
    inputs: Sequence[FormInput],
    user_actions: Sequence[UserAction],
    selected_action_id: str,
    form_data: Mapping[str, Any],
) -> None:
    """Check that a form submission selects a known action and provides every input.

    Only the presence of keys in `form_data` is checked; submitted values are
    not validated here.

    Raises:
        HumanInputSubmissionValidationError: if `selected_action_id` is not the
            id of one of `user_actions`, or if any input's
            `output_variable_name` is missing from `form_data`.
    """
    available_actions = {action.id for action in user_actions}
    if selected_action_id not in available_actions:
        raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")

    provided_inputs = set(form_data.keys())
    missing_inputs = [
        form_input.output_variable_name
        for form_input in inputs
        if form_input.output_variable_name not in provided_inputs
    ]
    if missing_inputs:
        missing_list = ", ".join(missing_inputs)
        raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")