from __future__ import annotations import json import logging import re from collections import defaultdict from collections.abc import Sequence from typing import Any from graphon.nodes.human_input.entities import FormDefinition from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputFormDefinition, HumanInputFormSubmissionData, ) from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) from models.execution_extra_content import ( HumanInputContent as HumanInputContentModel, ) from models.human_input import HumanInputFormRecipient, RecipientType from repositories.execution_extra_content_repository import ExecutionExtraContentRepository logger = logging.getLogger(__name__) _OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") def _extract_output_field_names(form_content: str) -> list[str]: if not form_content: return [] return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)] class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository): def __init__(self, session_maker: sessionmaker[Session]): self._session_maker = session_maker def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: if not message_ids: return [] grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = { message_id: [] for message_id in message_ids } stmt = ( select(ExecutionExtraContentModel) .where(ExecutionExtraContentModel.message_id.in_(message_ids)) .options(selectinload(HumanInputContentModel.form)) .order_by(ExecutionExtraContentModel.created_at.asc()) ) with self._session_maker() as session: results = session.scalars(stmt).all() form_ids = { content.form_id for content in results if isinstance(content, HumanInputContentModel) and content.form_id is not None } recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list) if form_ids: recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) recipients = session.scalars(recipient_stmt).all() for recipient in recipients: recipients_by_form_id[recipient.form_id].append(recipient) else: recipients_by_form_id = {} for content in results: message_id = content.message_id if not message_id or message_id not in grouped_contents: continue domain_model = self._map_model_to_domain(content, recipients_by_form_id) if domain_model is None: continue grouped_contents[message_id].append(domain_model) return [grouped_contents[message_id] for message_id in message_ids] def _map_model_to_domain( self, model: ExecutionExtraContentModel, recipients_by_form_id: dict[str, list[HumanInputFormRecipient]], ) -> ExecutionExtraContentDomainModel | None: if isinstance(model, HumanInputContentModel): return self._map_human_input_content(model, recipients_by_form_id) logger.debug("Unsupported execution extra content type encountered: %s", model.type) return None def _map_human_input_content( self, model: HumanInputContentModel, recipients_by_form_id: dict[str, list[HumanInputFormRecipient]], ) -> HumanInputContentDomainModel | None: form = model.form if form is None: logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id) return None try: definition_payload = json.loads(form.form_definition) if "expiration_time" not in definition_payload: definition_payload["expiration_time"] = form.expiration_time form_definition = FormDefinition.model_validate(definition_payload) except ValueError: logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id) return None node_title = form_definition.node_title or form.node_id display_in_ui = bool(form_definition.display_in_ui) submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED if not submitted: form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, [])) return HumanInputContentDomainModel( workflow_run_id=model.workflow_run_id, submitted=False, form_definition=HumanInputFormDefinition( form_id=form.id, node_id=form.node_id, node_title=node_title, form_content=form.rendered_content, inputs=form_definition.inputs, actions=form_definition.user_actions, display_in_ui=display_in_ui, form_token=form_token, resolved_default_values=form_definition.default_values, expiration_time=int(form.expiration_time.timestamp()), ), ) selected_action_id = form.selected_action_id if not selected_action_id: logger.warning("HumanInputContent(id=%s) form has no selected action", model.id) return None action_text = next( (action.title for action in form_definition.user_actions if action.id == selected_action_id), selected_action_id, ) submitted_data: dict[str, Any] = {} if form.submitted_data: try: submitted_data = json.loads(form.submitted_data) except ValueError: logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id) return None rendered_content = HumanInputNode.render_form_content_with_outputs( form.rendered_content, submitted_data, _extract_output_field_names(form_definition.form_content), ) return HumanInputContentDomainModel( workflow_run_id=model.workflow_run_id, submitted=True, form_submission_data=HumanInputFormSubmissionData( node_id=form.node_id, node_title=node_title, rendered_content=rendered_content, action_id=selected_action_id, action_text=action_text, ), ) @staticmethod def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None: console_recipient = next( (recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE), None, ) if console_recipient and console_recipient.access_token: return console_recipient.access_token web_app_recipient = next( (recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP), None, ) if web_app_recipient and web_app_recipient.access_token: return web_app_recipient.access_token return None __all__ = ["SQLAlchemyExecutionExtraContentRepository"]