mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
feat(api): introduce file upload apis for human input page
This commit is contained in:
parent
d6f607f6e7
commit
51e181c588
@ -23,6 +23,7 @@ from . import (
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
human_input_file_upload,
|
||||
human_input_form,
|
||||
login,
|
||||
message,
|
||||
@ -46,6 +47,7 @@ __all__ = [
|
||||
"feature",
|
||||
"files",
|
||||
"forgot_password",
|
||||
"human_input_file_upload",
|
||||
"human_input_form",
|
||||
"login",
|
||||
"message",
|
||||
|
||||
181
api/controllers/web/human_input_file_upload.py
Normal file
181
api/controllers/web/human_input_file_upload.py
Normal file
@ -0,0 +1,181 @@
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
import services
|
||||
from controllers.common import helpers
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
RemoteFileUploadError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, FileWithSignedUrl
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.exception import BaseHTTPException
|
||||
from services.file_service import FileService
|
||||
from services.human_input_file_upload_service import (
|
||||
HITL_UPLOAD_TOKEN_PREFIX,
|
||||
HumanInputFileUploadService,
|
||||
InvalidUploadTokenError,
|
||||
)
|
||||
|
||||
|
||||
class InvalidUploadTokenBadRequestError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Invalid upload token."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidUploadTokenUnauthorizedError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Upload token is required."
|
||||
code = 401
|
||||
|
||||
|
||||
class InvalidUploadTokenForbiddenError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Upload token is invalid or expired."
|
||||
code = 403
|
||||
|
||||
|
||||
class HumanInputRemoteFileUploadPayload(BaseModel):
|
||||
url: HttpUrl = Field(description="Remote file URL")
|
||||
|
||||
|
||||
register_schema_models(web_ns, HumanInputRemoteFileUploadPayload, FileResponse, FileWithSignedUrl)
|
||||
|
||||
|
||||
def _extract_hitl_upload_token() -> str:
|
||||
"""Read HITL upload token from Authorization without invoking other bearer auth chains."""
|
||||
|
||||
authorization = request.headers.get("Authorization")
|
||||
if authorization is None:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
|
||||
parts = authorization.split()
|
||||
if len(parts) != 2:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
|
||||
scheme, token = parts
|
||||
if scheme.lower() != "bearer":
|
||||
raise InvalidUploadTokenBadRequestError()
|
||||
if not token:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
if not token.startswith(HITL_UPLOAD_TOKEN_PREFIX):
|
||||
raise InvalidUploadTokenBadRequestError()
|
||||
return token
|
||||
|
||||
|
||||
def _validate_context(service: HumanInputFileUploadService, token: str):
|
||||
try:
|
||||
return service.validate_upload_token(token)
|
||||
except InvalidUploadTokenError as exc:
|
||||
raise InvalidUploadTokenForbiddenError() from exc
|
||||
|
||||
|
||||
def _parse_local_upload_file():
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
|
||||
raise FilenameNotExistsError()
|
||||
|
||||
return file
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/files/upload")
|
||||
class HumanInputFileUploadApi(Resource):
|
||||
def post(self):
|
||||
"""Upload one local file for a HITL human input form."""
|
||||
|
||||
token = _extract_hitl_upload_token()
|
||||
upload_service = HumanInputFileUploadService(db.engine)
|
||||
context = _validate_context(upload_service, token)
|
||||
file = _parse_local_upload_file()
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename or "",
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=context.end_user,
|
||||
source=None,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError() from exc
|
||||
|
||||
upload_service.record_upload_file(context=context, file_id=upload_file.id)
|
||||
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
return response.model_dump(mode="json"), 201
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/files/remote-upload")
|
||||
class HumanInputRemoteFileUploadApi(Resource):
|
||||
def post(self):
|
||||
"""Upload one remote URL file for a HITL human input form."""
|
||||
|
||||
token = _extract_hitl_upload_token()
|
||||
upload_service = HumanInputFileUploadService(db.engine)
|
||||
context = _validate_context(upload_service, token)
|
||||
payload = HumanInputRemoteFileUploadPayload.model_validate(request.get_json(silent=True) or {})
|
||||
url = str(payload.url)
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
except httpx.RequestError as exc:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(exc)}")
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(resp)
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
raise FileTooLargeError()
|
||||
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=context.end_user,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError() from exc
|
||||
|
||||
upload_service.record_upload_file(context=context, file_id=upload_file.id)
|
||||
payload1 = FileWithSignedUrl(
|
||||
id=upload_file.id,
|
||||
name=upload_file.name,
|
||||
size=upload_file.size,
|
||||
extension=upload_file.extension,
|
||||
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
mime_type=upload_file.mime_type,
|
||||
created_by=upload_file.created_by,
|
||||
created_at=int(upload_file.created_at.timestamp()),
|
||||
)
|
||||
return payload1.model_dump(mode="json"), 201
|
||||
@ -14,6 +14,7 @@ from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
@ -21,6 +22,7 @@ from extensions.ext_database import db
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from services.human_input_file_upload_service import HumanInputFileUploadService
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -31,6 +33,14 @@ class HumanInputFormSubmitPayload(BaseModel):
|
||||
action: str
|
||||
|
||||
|
||||
class HumanInputUploadTokenResponse(BaseModel):
|
||||
upload_token: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
register_schema_models(web_ns, HumanInputUploadTokenResponse)
|
||||
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
@ -41,6 +51,11 @@ _FORM_ACCESS_RATE_LIMITER = RateLimiter(
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
_FORM_UPLOAD_TOKEN_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_upload_token_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
@ -83,6 +98,33 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>/upload-token")
|
||||
class HumanInputFormUploadTokenApi(Resource):
|
||||
"""API for issuing HITL upload tokens for active human input forms."""
|
||||
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Issue an upload token for a human input form.
|
||||
|
||||
POST /api/form/human_input/<form_token>/upload-token
|
||||
"""
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_UPLOAD_TOKEN_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_UPLOAD_TOKEN_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
try:
|
||||
token = HumanInputFileUploadService(db.engine).issue_upload_token(form_token)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
response = HumanInputUploadTokenResponse(
|
||||
upload_token=token.upload_token,
|
||||
expires_at=_to_timestamp(token.expires_at),
|
||||
)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
class HumanInputFormApi(Resource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
|
||||
@ -0,0 +1,66 @@
|
||||
"""Add human input upload token and file association tables
|
||||
|
||||
Revision ID: 8d4c2a1b9f03
|
||||
Revises: 227822d22895
|
||||
Create Date: 2026-05-06 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8d4c2a1b9f03"
|
||||
down_revision = "227822d22895"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"human_input_form_upload_tokens",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("form_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("recipient_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("end_user_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("token", sa.String(length=255), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="human_input_form_upload_tokens_pkey"),
|
||||
sa.UniqueConstraint("token", name="human_input_form_upload_tokens_token_key"),
|
||||
)
|
||||
with op.batch_alter_table("human_input_form_upload_tokens", schema=None) as batch_op:
|
||||
batch_op.create_index("human_input_form_upload_tokens_form_id_idx", ["form_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"human_input_form_upload_files",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("form_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("upload_file_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("upload_token_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("end_user_id", models.types.StringUUID(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="human_input_form_upload_files_pkey"),
|
||||
sa.UniqueConstraint("upload_file_id", name="human_input_form_upload_files_upload_file_id_key"),
|
||||
)
|
||||
with op.batch_alter_table("human_input_form_upload_files", schema=None) as batch_op:
|
||||
batch_op.create_index("human_input_form_upload_files_form_id_idx", ["form_id"], unique=False)
|
||||
batch_op.create_index("human_input_form_upload_files_upload_token_id_idx", ["upload_token_id"], unique=False)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("human_input_form_upload_files", schema=None) as batch_op:
|
||||
batch_op.drop_index("human_input_form_upload_files_upload_token_id_idx")
|
||||
batch_op.drop_index("human_input_form_upload_files_form_id_idx")
|
||||
op.drop_table("human_input_form_upload_files")
|
||||
|
||||
with op.batch_alter_table("human_input_form_upload_tokens", schema=None) as batch_op:
|
||||
batch_op.drop_index("human_input_form_upload_tokens_form_id_idx")
|
||||
op.drop_table("human_input_form_upload_tokens")
|
||||
@ -39,7 +39,7 @@ from .enums import (
|
||||
WorkflowTriggerStatus,
|
||||
)
|
||||
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
|
||||
from .human_input import HumanInputForm
|
||||
from .human_input import HumanInputForm, HumanInputFormUploadFile, HumanInputFormUploadToken
|
||||
from .model import (
|
||||
AccountTrialAppRecord,
|
||||
ApiRequest,
|
||||
@ -167,6 +167,8 @@ __all__ = [
|
||||
"ExternalKnowledgeBindings",
|
||||
"HumanInputContent",
|
||||
"HumanInputForm",
|
||||
"HumanInputFormUploadFile",
|
||||
"HumanInputFormUploadToken",
|
||||
"IconType",
|
||||
"InstalledApp",
|
||||
"InvitationCode",
|
||||
|
||||
@ -251,3 +251,50 @@ class HumanInputFormRecipient(DefaultFieldsMixin, Base):
|
||||
access_token=_generate_token(),
|
||||
)
|
||||
return recipient_model
|
||||
|
||||
|
||||
class HumanInputFormUploadToken(DefaultFieldsMixin, Base):
|
||||
"""Upload authorization token bound to one human input form recipient.
|
||||
|
||||
HITL upload tokens are intentionally separate from app/service bearer tokens.
|
||||
The token is stored as an opaque random value so upload endpoints can perform
|
||||
a direct lookup without entering the normal Web App authentication chain.
|
||||
"""
|
||||
|
||||
__tablename__ = "human_input_form_upload_tokens"
|
||||
__table_args__ = (
|
||||
sa.UniqueConstraint("token", name="human_input_form_upload_tokens_token_key"),
|
||||
sa.Index("human_input_form_upload_tokens_form_id_idx", "form_id"),
|
||||
)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
form_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
recipient_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
end_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
token: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
|
||||
form: Mapped[HumanInputForm] = relationship(
|
||||
"HumanInputForm",
|
||||
uselist=False,
|
||||
foreign_keys=[form_id],
|
||||
primaryjoin="foreign(HumanInputFormUploadToken.form_id) == HumanInputForm.id",
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormUploadFile(DefaultFieldsMixin, Base):
|
||||
"""Association between a human input form and a file uploaded through its token."""
|
||||
|
||||
__tablename__ = "human_input_form_upload_files"
|
||||
__table_args__ = (
|
||||
sa.UniqueConstraint("upload_file_id", name="human_input_form_upload_files_upload_file_id_key"),
|
||||
sa.Index("human_input_form_upload_files_form_id_idx", "form_id"),
|
||||
sa.Index("human_input_form_upload_files_upload_token_id_idx", "upload_token_id"),
|
||||
)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
form_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
upload_token_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
end_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
201
api/services/human_input_file_upload_service.py
Normal file
201
api/services/human_input_file_upload_service.py
Normal file
@ -0,0 +1,201 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from graphon.nodes.human_input.enums import HumanInputFormStatus
|
||||
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
|
||||
from models.human_input import (
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
HumanInputFormUploadFile,
|
||||
HumanInputFormUploadToken,
|
||||
)
|
||||
from models.model import EndUser
|
||||
from services.human_input_service import FormExpiredError, FormNotFoundError, FormSubmittedError
|
||||
|
||||
HITL_UPLOAD_TOKEN_PREFIX = "hitl_upload_"
|
||||
HUMAN_INPUT_END_USER_TYPE = "human-input"
|
||||
HUMAN_INPUT_END_USER_SESSION_PREFIX = "hitl:recipient:"
|
||||
_TOKEN_RANDOM_BYTES = 32
|
||||
_TOKEN_GENERATION_ATTEMPTS = 10
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HumanInputUploadToken:
|
||||
upload_token: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HumanInputUploadContext:
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
form_id: str
|
||||
recipient_id: str
|
||||
upload_token_id: str
|
||||
end_user: EndUser
|
||||
|
||||
|
||||
class InvalidUploadTokenError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputFileUploadService:
|
||||
"""Coordinates HITL upload tokens, technical EndUsers, and form-file links."""
|
||||
|
||||
_session_maker: sessionmaker[Session]
|
||||
|
||||
def __init__(self, session_factory: sessionmaker[Session] | Engine):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_maker = session_factory
|
||||
|
||||
def issue_upload_token(self, form_token: str) -> HumanInputUploadToken:
|
||||
"""Create an upload token for an active human input recipient token."""
|
||||
|
||||
with self._session_maker(expire_on_commit=False) as session, session.begin():
|
||||
recipient_model = session.scalar(
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(HumanInputFormRecipient.access_token == form_token)
|
||||
.limit(1)
|
||||
)
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
raise FormNotFoundError()
|
||||
|
||||
form = recipient_model.form
|
||||
self._ensure_form_model_active(form)
|
||||
end_user = self._get_or_create_human_input_end_user(
|
||||
session=session,
|
||||
tenant_id=form.tenant_id,
|
||||
app_id=form.app_id,
|
||||
recipient_id=recipient_model.id,
|
||||
)
|
||||
upload_token = self._generate_unique_upload_token(session)
|
||||
token_model = HumanInputFormUploadToken(
|
||||
tenant_id=form.tenant_id,
|
||||
app_id=form.app_id,
|
||||
form_id=form.id,
|
||||
recipient_id=recipient_model.id,
|
||||
end_user_id=end_user.id,
|
||||
token=upload_token,
|
||||
)
|
||||
session.add(token_model)
|
||||
|
||||
return HumanInputUploadToken(upload_token=upload_token, expires_at=form.expiration_time)
|
||||
|
||||
def validate_upload_token(self, upload_token: str) -> HumanInputUploadContext:
|
||||
"""Resolve an upload token and ensure the bound form is still active."""
|
||||
|
||||
query = (
|
||||
select(HumanInputFormUploadToken)
|
||||
.options(selectinload(HumanInputFormUploadToken.form))
|
||||
.where(HumanInputFormUploadToken.token == upload_token)
|
||||
.limit(1)
|
||||
)
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
token_model = session.scalars(query).first()
|
||||
if token_model is None:
|
||||
raise InvalidUploadTokenError()
|
||||
|
||||
form_model = token_model.form
|
||||
if form_model is None:
|
||||
raise InvalidUploadTokenError()
|
||||
self._ensure_form_model_active(form_model)
|
||||
|
||||
end_user = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.id == token_model.end_user_id,
|
||||
EndUser.tenant_id == token_model.tenant_id,
|
||||
EndUser.app_id == token_model.app_id,
|
||||
EndUser.type == HUMAN_INPUT_END_USER_TYPE,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if end_user is None:
|
||||
raise InvalidUploadTokenError()
|
||||
|
||||
return HumanInputUploadContext(
|
||||
tenant_id=token_model.tenant_id,
|
||||
app_id=token_model.app_id,
|
||||
form_id=token_model.form_id,
|
||||
recipient_id=token_model.recipient_id,
|
||||
upload_token_id=token_model.id,
|
||||
end_user=end_user,
|
||||
)
|
||||
|
||||
def record_upload_file(self, *, context: HumanInputUploadContext, file_id: str) -> None:
|
||||
"""Record that a file was uploaded through a specific form upload token."""
|
||||
|
||||
with self._session_maker(expire_on_commit=False) as session, session.begin():
|
||||
session.add(
|
||||
HumanInputFormUploadFile(
|
||||
tenant_id=context.tenant_id,
|
||||
form_id=context.form_id,
|
||||
upload_file_id=file_id,
|
||||
upload_token_id=context.upload_token_id,
|
||||
end_user_id=context.end_user.id,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_unique_upload_token(self, session: Session) -> str:
|
||||
return f"{HITL_UPLOAD_TOKEN_PREFIX}{secrets.token_urlsafe(_TOKEN_RANDOM_BYTES)}"
|
||||
|
||||
@staticmethod
|
||||
def _get_or_create_human_input_end_user(
|
||||
*,
|
||||
session: Session,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
recipient_id: str,
|
||||
) -> EndUser:
|
||||
session_id = f"{HUMAN_INPUT_END_USER_SESSION_PREFIX}{recipient_id}"
|
||||
end_user = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.tenant_id == tenant_id,
|
||||
EndUser.app_id == app_id,
|
||||
EndUser.session_id == session_id,
|
||||
EndUser.type == HUMAN_INPUT_END_USER_TYPE,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if end_user is not None:
|
||||
return end_user
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type=HUMAN_INPUT_END_USER_TYPE,
|
||||
is_anonymous=True,
|
||||
session_id=session_id,
|
||||
external_user_id=session_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.flush()
|
||||
return end_user
|
||||
|
||||
@staticmethod
|
||||
def _ensure_form_model_active(form: HumanInputForm) -> None:
|
||||
if form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED:
|
||||
raise FormSubmittedError(form.id)
|
||||
if form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
||||
raise FormExpiredError(form.id)
|
||||
|
||||
now = naive_utc_now()
|
||||
if ensure_naive_utc(form.expiration_time) <= now:
|
||||
raise FormExpiredError(form.id)
|
||||
|
||||
global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS
|
||||
if global_timeout_seconds <= 0 or form.workflow_run_id is None:
|
||||
return
|
||||
global_deadline = ensure_naive_utc(form.created_at) + timedelta(seconds=global_timeout_seconds)
|
||||
if global_deadline <= now:
|
||||
raise FormExpiredError(form.id)
|
||||
@ -1,22 +1,33 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormSubmissionRepository,
|
||||
)
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from graphon.file import FileUploadConfig
|
||||
from graphon.nodes.human_input.entities import (
|
||||
FileInputConfig,
|
||||
FileListInputConfig,
|
||||
FormDefinition,
|
||||
FormInputConfig,
|
||||
HumanInputSubmissionValidationError,
|
||||
validate_human_input_submission,
|
||||
SelectInputConfig,
|
||||
UserActionConfig,
|
||||
)
|
||||
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from graphon.nodes.human_input.entities import (
|
||||
validate_human_input_submission as graphon_validate_human_input_submission,
|
||||
)
|
||||
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus, ValueSourceType
|
||||
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
|
||||
from libs.exception import BaseHTTPException
|
||||
from models.human_input import RecipientType
|
||||
@ -24,6 +35,8 @@ from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from tasks.app_generate.workflow_execute_task import resume_app_execution
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class Form:
|
||||
def __init__(self, record: HumanInputFormRecord):
|
||||
@ -82,7 +95,7 @@ class HumanInputError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FormSubmittedError(HumanInputError, BaseHTTPException):
|
||||
class FormSubmittedError(BaseHTTPException, HumanInputError):
|
||||
error_code = "human_input_form_submitted"
|
||||
description = "This form has already been submitted by another user, form_id={form_id}"
|
||||
code = 412
|
||||
@ -90,37 +103,48 @@ class FormSubmittedError(HumanInputError, BaseHTTPException):
|
||||
def __init__(self, form_id: str):
|
||||
template = self.description or "This form has already been submitted by another user, form_id={form_id}"
|
||||
description = template.format(form_id=form_id)
|
||||
super().__init__(description=description)
|
||||
BaseHTTPException.__init__(self, description=description)
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError, BaseHTTPException):
|
||||
class FormNotFoundError(BaseHTTPException, HumanInputError):
|
||||
error_code = "human_input_form_not_found"
|
||||
code = 404
|
||||
|
||||
|
||||
class InvalidFormDataError(HumanInputError, BaseHTTPException):
|
||||
class InvalidFormDataError(BaseHTTPException, HumanInputError):
|
||||
error_code = "invalid_form_data"
|
||||
code = 400
|
||||
|
||||
def __init__(self, description: str):
|
||||
super().__init__(description=description)
|
||||
BaseHTTPException.__init__(self, description=description)
|
||||
|
||||
|
||||
class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
|
||||
pass
|
||||
|
||||
|
||||
class FormExpiredError(HumanInputError, BaseHTTPException):
|
||||
class FormExpiredError(BaseHTTPException, HumanInputError):
|
||||
error_code = "human_input_form_expired"
|
||||
code = 412
|
||||
|
||||
def __init__(self, form_id: str):
|
||||
super().__init__(description=f"This form has expired, form_id={form_id}")
|
||||
BaseHTTPException.__init__(
|
||||
self,
|
||||
description=f"This form has expired, form_id={form_id}",
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FormDefinitionProtocol(Protocol):
|
||||
@property
|
||||
def inputs(self) -> Sequence[FormInputConfig]: ...
|
||||
|
||||
@property
|
||||
def user_actions(self) -> Sequence[UserActionConfig]: ...
|
||||
|
||||
|
||||
class HumanInputService:
|
||||
def __init__(
|
||||
self,
|
||||
@ -157,7 +181,7 @@ class HumanInputService:
|
||||
recipient_type: RecipientType,
|
||||
form_token: str,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
form_data: Mapping[str, JsonValue],
|
||||
submission_end_user_id: str | None = None,
|
||||
submission_user_id: str | None = None,
|
||||
):
|
||||
@ -166,13 +190,17 @@ class HumanInputService:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
self.ensure_form_active(form)
|
||||
self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)
|
||||
normalized_form_data = self._validate_submission(
|
||||
form=form,
|
||||
selected_action_id=selected_action_id,
|
||||
form_data=form_data,
|
||||
)
|
||||
|
||||
result = self._form_repository.mark_submitted(
|
||||
form_id=form.id,
|
||||
recipient_id=form.recipient_id,
|
||||
selected_action_id=selected_action_id,
|
||||
form_data=form_data,
|
||||
form_data=normalized_form_data,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
@ -198,12 +226,17 @@ class HumanInputService:
|
||||
if form.submitted:
|
||||
raise FormSubmittedError(form.id)
|
||||
|
||||
def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
|
||||
def _validate_submission(
|
||||
self,
|
||||
form: Form,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
) -> dict[str, JsonValue]:
|
||||
definition = form.get_definition()
|
||||
try:
|
||||
validate_human_input_submission(
|
||||
inputs=definition.inputs,
|
||||
user_actions=definition.user_actions,
|
||||
return self.validate_and_normalize_submission(
|
||||
tenant_id=form.tenant_id,
|
||||
form_definition=definition,
|
||||
selected_action_id=selected_action_id,
|
||||
form_data=form_data,
|
||||
)
|
||||
@ -247,3 +280,184 @@ class HumanInputService:
|
||||
created_at = ensure_naive_utc(form.created_at)
|
||||
global_deadline = created_at + timedelta(seconds=global_timeout_seconds)
|
||||
return global_deadline <= current
|
||||
|
||||
@staticmethod
|
||||
def validate_human_input_submission(
|
||||
*,
|
||||
form_definition: FormDefinitionProtocol,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
) -> None:
|
||||
graphon_validate_human_input_submission(
|
||||
inputs=form_definition.inputs,
|
||||
user_actions=form_definition.user_actions,
|
||||
selected_action_id=selected_action_id,
|
||||
form_data=form_data,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_normalize_submission(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
form_definition: FormDefinitionProtocol,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
) -> dict[str, JsonValue]:
|
||||
"""
|
||||
Normalize Dify-owned runtime payloads before delegating shape validation to graphon.
|
||||
|
||||
graphon owns the form schema and validation rules, while Dify owns tenant-aware file
|
||||
reconstruction and persistence compatibility for submitted payloads.
|
||||
"""
|
||||
normalized_form_data = cls.normalize_submission_data(
|
||||
tenant_id=tenant_id,
|
||||
form_definition=form_definition,
|
||||
form_data=form_data,
|
||||
)
|
||||
graphon_validate_human_input_submission(
|
||||
inputs=form_definition.inputs,
|
||||
user_actions=form_definition.user_actions,
|
||||
selected_action_id=selected_action_id,
|
||||
form_data=normalized_form_data,
|
||||
)
|
||||
return normalized_form_data
|
||||
|
||||
@classmethod
|
||||
def normalize_submission_data(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
form_definition: FormDefinitionProtocol,
|
||||
form_data: Mapping[str, Any],
|
||||
) -> dict[str, JsonValue]:
|
||||
normalized_form_data: dict[str, JsonValue] = {key: cast(JsonValue, value) for key, value in form_data.items()}
|
||||
inputs_by_name = {form_input.output_variable_name: form_input for form_input in form_definition.inputs}
|
||||
for name, form_input in inputs_by_name.items():
|
||||
if name not in form_data:
|
||||
continue
|
||||
normalized_form_data[name] = cls._normalize_input_value(
|
||||
tenant_id=tenant_id,
|
||||
form_input=form_input,
|
||||
value=form_data[name],
|
||||
)
|
||||
|
||||
return normalized_form_data
|
||||
|
||||
@classmethod
|
||||
def _normalize_input_value(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
form_input: FormInputConfig,
|
||||
value: Any,
|
||||
) -> JsonValue:
|
||||
if isinstance(form_input, SelectInputConfig):
|
||||
return cls._normalize_select_value(form_input=form_input, value=value)
|
||||
if isinstance(form_input, FileInputConfig):
|
||||
return cls._normalize_file_value(
|
||||
tenant_id=tenant_id,
|
||||
form_input=form_input,
|
||||
value=value,
|
||||
)
|
||||
if isinstance(form_input, FileListInputConfig):
|
||||
return cls._normalize_file_list_value(
|
||||
tenant_id=tenant_id,
|
||||
form_input=form_input,
|
||||
value=value,
|
||||
)
|
||||
return cast(JsonValue, value)
|
||||
|
||||
@classmethod
|
||||
def _normalize_select_value(
|
||||
cls,
|
||||
*,
|
||||
form_input: SelectInputConfig,
|
||||
value: Any,
|
||||
) -> JsonValue:
|
||||
if not isinstance(value, str):
|
||||
raise HumanInputSubmissionValidationError(
|
||||
f"Invalid value for select input '{form_input.output_variable_name}': expected string"
|
||||
)
|
||||
option_source = form_input.option_source
|
||||
if option_source.type == ValueSourceType.CONSTANT and value not in option_source.value:
|
||||
raise HumanInputSubmissionValidationError(
|
||||
f"Invalid value for select input '{form_input.output_variable_name}': {value}"
|
||||
)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _normalize_file_value(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
form_input: FileInputConfig,
|
||||
value: Any,
|
||||
) -> JsonValue:
|
||||
if not isinstance(value, Mapping):
|
||||
raise HumanInputSubmissionValidationError(
|
||||
f"Invalid value for file input '{form_input.output_variable_name}': expected mapping"
|
||||
)
|
||||
upload_config = cls._build_file_upload_config(form_input=form_input, number_limits=1)
|
||||
try:
|
||||
# `build_from_mapping` enforces tenant ownership for persisted upload references.
|
||||
file = build_from_mapping(
|
||||
mapping=value,
|
||||
tenant_id=tenant_id,
|
||||
config=upload_config,
|
||||
strict_type_validation=True,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HumanInputSubmissionValidationError(
|
||||
f"Invalid value for file input '{form_input.output_variable_name}': {exc}"
|
||||
) from exc
|
||||
return cast(JsonValue, file.to_dict())
|
||||
|
||||
@classmethod
|
||||
def _normalize_file_list_value(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
form_input: FileListInputConfig,
|
||||
value: Any,
|
||||
) -> JsonValue:
|
||||
if not isinstance(value, list):
|
||||
raise HumanInputSubmissionValidationError(
|
||||
f"Invalid value for file list input '{form_input.output_variable_name}': expected list"
|
||||
)
|
||||
if any(not isinstance(item, Mapping) for item in value):
|
||||
raise HumanInputSubmissionValidationError(
|
||||
f"Invalid value for file list input '{form_input.output_variable_name}': expected list of mappings"
|
||||
)
|
||||
upload_config = cls._build_file_upload_config(
|
||||
form_input=form_input,
|
||||
number_limits=form_input.number_limits,
|
||||
)
|
||||
try:
|
||||
# `build_from_mappings` performs the same tenant-aware ownership validation in batch.
|
||||
files = build_from_mappings(
|
||||
mappings=cast(Sequence[Mapping[str, Any]], value),
|
||||
tenant_id=tenant_id,
|
||||
config=upload_config,
|
||||
strict_type_validation=True,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HumanInputSubmissionValidationError(
|
||||
f"Invalid value for file list input '{form_input.output_variable_name}': {exc}"
|
||||
) from exc
|
||||
return cast(JsonValue, [file.to_dict() for file in files])
|
||||
|
||||
@staticmethod
|
||||
def _build_file_upload_config(
|
||||
*,
|
||||
form_input: FileInputConfig | FileListInputConfig,
|
||||
number_limits: int,
|
||||
) -> FileUploadConfig:
|
||||
return FileUploadConfig(
|
||||
allowed_file_types=list(form_input.allowed_file_types),
|
||||
allowed_file_extensions=list(form_input.allowed_file_extensions),
|
||||
allowed_file_upload_methods=list(form_input.allowed_file_upload_methods),
|
||||
number_limits=number_limits,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user