from __future__ import annotations

from dataclasses import dataclass, field
from enum import StrEnum
from typing import Protocol

from graphon.runtime import VariablePool
from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker

from configs import dify_config
from core.workflow.human_input_compat import (
    DeliveryChannelConfig,
    EmailDeliveryConfig,
    EmailDeliveryMethod,
    ExternalRecipient,
    MemberRecipient,
)
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_template_renderer import render_email_template
from models import Account, TenantAccountJoin
from services.feature_service import FeatureService


class DeliveryTestStatus(StrEnum):
    """Outcome reported by a delivery test send."""

    OK = "ok"
    FAILED = "failed"


@dataclass(frozen=True)
class DeliveryTestEmailRecipient:
    """Pairs a recipient address with the form token to embed in that recipient's message."""

    email: str
    form_token: str


@dataclass(frozen=True)
class DeliveryTestContext:
    """Everything a handler needs to render and send a test message for one node."""

    tenant_id: str
    app_id: str
    node_id: str
    node_title: str | None
    # Exposed to templates as the "form_content" substitution by the email handler.
    rendered_content: str
    # Extra template substitutions; None values are skipped when merged.
    template_vars: dict[str, str] = field(default_factory=dict)
    # Per-address form tokens; matched against each resolved recipient by exact email.
    recipients: list[DeliveryTestEmailRecipient] = field(default_factory=list)
    # Forwarded to the body-template renderer when present.
    variable_pool: VariablePool | None = None


@dataclass(frozen=True)
class DeliveryTestResult:
    """Result of a test send: overall status plus the addresses that were sent to."""

    status: DeliveryTestStatus
    delivered_to: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)


class DeliveryTestError(Exception):
    """Raised when a delivery test cannot be performed."""


class DeliveryTestUnsupportedError(DeliveryTestError):
    """Raised when no handler is able to test-send the given delivery method."""


def _build_form_link(token: str | None) -> str | None:
    """Build an absolute form URL for *token*.

    Returns None when *token* is empty/None, or when ``dify_config.APP_WEB_URL``
    is not configured. Any trailing slashes on the base URL are stripped before
    ``/form/<token>`` is appended.
    """
    if token:
        base_url = dify_config.APP_WEB_URL
        if base_url:
            trimmed_base = base_url.rstrip("/")
            return f"{trimmed_base}/form/{token}"
    return None


class DeliveryTestHandler(Protocol):
    """Structural interface implemented by per-channel test-send handlers."""

    def supports(self, method: DeliveryChannelConfig) -> bool:
        """Return True if this handler can test-send *method*."""

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Perform a test send of *method* using *context*."""
class DeliveryTestRegistry:
    """Ordered list of handlers; a test send is routed to the first one that supports the method."""

    def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None:
        # Copy so later register() calls never mutate the caller's list.
        self._handlers = list(handlers) if handlers else []

    def register(self, handler: DeliveryTestHandler) -> None:
        """Append *handler*; it is consulted after all previously registered handlers."""
        self._handlers.append(handler)

    def dispatch(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Send a test for *method* through the first handler whose ``supports`` returns True.

        Raises:
            DeliveryTestUnsupportedError: if no registered handler supports *method*.
        """
        chosen = next((candidate for candidate in self._handlers if candidate.supports(method)), None)
        if chosen is None:
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
        return chosen.send_test(context=context, method=method)

    @classmethod
    def default(cls) -> DeliveryTestRegistry:
        """Build a registry containing only the email handler."""
        return cls([EmailDeliveryTestHandler()])


class HumanInputDeliveryTestService:
    """Thin facade that forwards test-send requests to a DeliveryTestRegistry."""

    def __init__(self, registry: DeliveryTestRegistry | None = None) -> None:
        if not registry:
            registry = DeliveryTestRegistry.default()
        self._registry = registry

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Delegate to the registry's ``dispatch``."""
        return self._registry.dispatch(context=context, method=method)


class EmailDeliveryTestHandler:
    """Test-send handler for EmailDeliveryMethod: renders and mails one message per recipient."""

    def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
        # Accept a ready sessionmaker, a bare Engine (wrapped), or nothing (bind to db.engine).
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory)
        elif session_factory is None:
            self._session_factory = sessionmaker(bind=db.engine)
        else:
            self._session_factory = session_factory

    def supports(self, method: DeliveryChannelConfig) -> bool:
        """Return True only for EmailDeliveryMethod instances."""
        return isinstance(method, EmailDeliveryMethod)

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Render and send one test email to every resolved recipient.

        Raises:
            DeliveryTestUnsupportedError: if *method* is not an EmailDeliveryMethod.
            DeliveryTestError: if the tenant's plan disables email delivery, the mail
                client is not initialized, or no recipients resolve.

        Exceptions raised by ``mail.send`` are not caught here, so a failure stops
        the remaining recipients from being sent.
        """
        if not isinstance(method, EmailDeliveryMethod):
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")

        plan_features = FeatureService.get_features(context.tenant_id)
        if not plan_features.human_input_email_delivery_enabled:
            raise DeliveryTestError("Email delivery is not available for current plan.")
        if not mail.is_inited():
            raise DeliveryTestError("Mail client is not initialized.")

        addresses = self._resolve_recipients(tenant_id=context.tenant_id, method=method)
        if not addresses:
            raise DeliveryTestError("No recipients configured for delivery method.")

        sent_to: list[str] = []
        for address in addresses:
            subject, html = self._compose_email(context=context, method=method, recipient_email=address)
            mail.send(to=address, subject=subject, html=html)
            sent_to.append(address)
        return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=sent_to)

    def _compose_email(
        self,
        *,
        context: DeliveryTestContext,
        method: EmailDeliveryMethod,
        recipient_email: str,
    ) -> tuple[str, str]:
        """Return the ``(subject, html_body)`` pair rendered for one recipient."""
        values = self._build_substitutions(context=context, recipient_email=recipient_email)
        email_config = method.config

        # Subject: fill placeholders, then sanitize.
        subject = EmailDeliveryConfig.sanitize_subject(render_email_template(email_config.subject, values))

        # Body: node-level template rendering first, then placeholder fill, then markdown -> HTML.
        body_with_placeholders = EmailDeliveryConfig.render_body_template(
            body=email_config.body,
            url=values.get("form_link"),
            variable_pool=context.variable_pool,
        )
        html = EmailDeliveryConfig.render_markdown_body(render_email_template(body_with_placeholders, values))
        return subject, html

    def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]:
        """Collect recipient addresses: external emails first, then workspace members.

        When ``include_bound_group`` is set, every member of the tenant is added and
        explicit member references are ignored. Otherwise only the referenced members
        are added, in reference order. Empty values are dropped, and duplicates are
        removed while keeping first-seen order.
        """
        recipient_config = method.config.recipients
        collected: list[str] = []
        member_ids: list[str] = []

        for entry in recipient_config.items:
            if isinstance(entry, MemberRecipient):
                member_ids.append(entry.reference_id)
            elif isinstance(entry, ExternalRecipient) and entry.email:
                collected.append(entry.email)

        if recipient_config.include_bound_group:
            everyone = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None)
            collected.extend(everyone.values())
        elif member_ids:
            by_id = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_ids)
            collected.extend(filter(None, (by_id.get(member_id) for member_id in member_ids)))

        return list(dict.fromkeys(filter(None, collected)))

    def _query_workspace_member_emails(
        self,
        *,
        tenant_id: str,
        user_ids: list[str] | None,
    ) -> dict[str, str]:
        """Map account id -> email for accounts joined to *tenant_id*.

        ``user_ids=None`` means no id filter (all tenant members). A list of ids
        with no non-empty entries short-circuits to ``{}`` without querying.
        """
        id_filter: set[str] | None = None
        if user_ids is not None:
            id_filter = {member_id for member_id in user_ids if member_id}
            if not id_filter:
                return {}

        stmt = (
            select(Account.id, Account.email)
            .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
            .where(TenantAccountJoin.tenant_id == tenant_id)
        )
        if id_filter is not None:
            stmt = stmt.where(Account.id.in_(id_filter))

        with self._session_factory() as session:
            return dict(session.execute(stmt).tuples().all())

    @staticmethod
    def _build_substitutions(
        *,
        context: DeliveryTestContext,
        recipient_email: str,
    ) -> dict[str, str]:
        """Build the template placeholder values for one recipient.

        Form identifiers are blank because this is a test send. ``template_vars``
        entries (except None values) override the defaults. If *recipient_email*
        has a token in ``context.recipients``, ``form_token`` and ``form_link`` are
        filled in. ``form_link`` falls back to a relative ``/form/<token>`` when no
        absolute link can be built.
        """
        substitutions: dict[str, str] = {
            "form_id": "",
            "node_title": context.node_title or "",
            "workflow_run_id": "",
            "form_token": "",
            "form_link": "",
            "form_content": context.rendered_content or "",
            "recipient_email": recipient_email or "",
        }

        for key, value in (context.template_vars or {}).items():
            if value is not None:
                substitutions[key] = value

        token = None
        for entry in context.recipients:
            if entry.email == recipient_email:
                token = entry.form_token
                break

        if token:
            substitutions["form_token"] = token
            absolute_link = _build_form_link(token)
            substitutions["form_link"] = f"/form/{token}" if absolute_link is None else absolute_link
        return substitutions