import base64
import hashlib
import os
import uuid
from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal
from zipfile import ZIP_DEFLATED, ZipFile

from graphon.file import helpers as file_helpers
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import NotFound

from configs import dify_config
from constants import (
    AUDIO_EXTENSIONS,
    DOCUMENT_EXTENSIONS,
    IMAGE_EXTENSIONS,
    VIDEO_EXTENSIONS,
)
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile

from .errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError

# Maximum number of characters (not words, despite the name) returned by
# `FileService.get_file_preview`.
PREVIEW_WORDS_LIMIT = 3000


class FileService:
    """
    Upload, read, preview, zip and delete `UploadFile` rows together with the blobs
    they reference in `storage`.
    """

    _session_maker: sessionmaker[Session]

    def __init__(self, session_factory: sessionmaker | Engine | None = None):
        """
        Bind the service to a SQLAlchemy session factory.

        :param session_factory: a `sessionmaker` (used as-is) or an `Engine` (wrapped in a
            new `sessionmaker`). Despite the `None` default, omitting it is rejected.
        :raises AssertionError: if `session_factory` is neither a `sessionmaker` nor an `Engine`.
        """
        if isinstance(session_factory, Engine):
            self._session_maker = sessionmaker(bind=session_factory)
        elif isinstance(session_factory, sessionmaker):
            self._session_maker = session_factory
        else:
            raise AssertionError("must be a sessionmaker or an Engine.")

    @staticmethod
    def _truncate_filename(filename: str, extension: str, max_length: int = 200) -> str:
        """
        Shorten `filename` to at most `max_length` characters while keeping its extension.

        The stem (everything before the LAST dot) is kept up to the remaining budget, and no
        trailing dot is appended when `extension` is empty. If the extension suffix alone is
        longer than `max_length`, only the suffix is returned.
        """
        stem = os.path.splitext(filename)[0]
        suffix = f".{extension}" if extension else ""
        return stem[: max(max_length - len(suffix), 0)] + suffix

    def upload_file(
        self,
        *,
        filename: str,
        content: bytes,
        mimetype: str,
        user: Account | EndUser,
        source: Literal["datasets"] | None = None,
        source_url: str = "",
    ) -> UploadFile:
        """
        Validate an upload, save its bytes to storage and persist an `UploadFile` row.

        :param filename: client-supplied filename; stored as metadata only (the storage key
            is UUID-based). Names longer than 200 characters are shortened.
        :param content: raw file bytes.
        :param mimetype: MIME type recorded on the row.
        :param user: uploader; determines the tenant (via `extract_tenant_id`) and `created_by_role`.
        :param source: when `"datasets"`, only extensions in `DOCUMENT_EXTENSIONS` are accepted.
        :param source_url: URL stored on the row; when empty, a signed file URL is generated.
        :returns: the committed `UploadFile` (attributes remain loaded after commit).
        :raises ValueError: if the filename contains `/` or `\\`.
        :raises BlockedFileExtensionError: if the extension is in `UPLOAD_FILE_EXTENSION_BLACKLIST`.
        :raises UnsupportedFileTypeError: if `source == "datasets"` and the file is not a document.
        :raises FileTooLargeError: if the size exceeds the limit for the file's category.
        """
        # get file extension
        extension = os.path.splitext(filename)[1].lstrip(".").lower()

        # Only reject path separators here. The original filename is stored as metadata,
        # while the storage key is UUID-based.
        if any(c in filename for c in ["/", "\\"]):
            raise ValueError("Filename contains invalid characters")

        if len(filename) > 200:
            # Split on the last dot so multi-dot names keep their stem, and never add a
            # bare "." to extensionless names.
            filename = FileService._truncate_filename(filename, extension, 200)

        # check if extension is in blacklist
        if extension and extension in dify_config.UPLOAD_FILE_EXTENSION_BLACKLIST:
            raise BlockedFileExtensionError(f"File extension '.{extension}' is not allowed for security reasons")

        if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
            raise UnsupportedFileTypeError()

        # get file size
        file_size = len(content)

        # check if the file size is exceeded
        if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size):
            raise FileTooLargeError

        # generate file key
        file_uuid = str(uuid.uuid4())

        current_tenant_id = extract_tenant_id(user)

        file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension

        # save file to storage
        storage.save(file_key, content)

        # save file to db
        upload_file = UploadFile(
            tenant_id=current_tenant_id or "",
            storage_type=StorageType(dify_config.STORAGE_TYPE),
            key=file_key,
            name=filename,
            size=file_size,
            extension=extension,
            mime_type=mimetype,
            created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
            created_by=user.id,
            created_at=naive_utc_now(),
            used=False,
            hash=hashlib.sha3_256(content).hexdigest(),
            source_url=source_url,
        )
        # The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
        # We can directly generate the `source_url` here before committing.
        if not upload_file.source_url:
            upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)

        with self._session_maker(expire_on_commit=False) as session:
            session.add(upload_file)
            session.commit()

        return upload_file

    @staticmethod
    def is_file_size_within_limit(*, extension: str, file_size: int) -> bool:
        """
        Return whether `file_size` (bytes) is within the configured limit for `extension`.

        Image, video and audio extensions have dedicated limits; everything else uses
        `UPLOAD_FILE_SIZE_LIMIT`. Configured limits are in MiB (multiplied by 1024 * 1024).
        """
        if extension in IMAGE_EXTENSIONS:
            file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
        elif extension in VIDEO_EXTENSIONS:
            file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024
        elif extension in AUDIO_EXTENSIONS:
            file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024
        else:
            file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024

        return file_size <= file_size_limit

    def get_file_base64(self, file_id: str) -> str:
        """
        Return the full content of file `file_id` as a base64-encoded ASCII string.

        :raises NotFound: if no `UploadFile` row has this id.
        """
        # Use the session as a context manager so it is closed after the lookup instead of
        # being left open (previously the session object was created and never closed).
        with self._session_maker(expire_on_commit=False) as session:
            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
            if not upload_file:
                raise NotFound("File not found")

            blob = storage.load_once(upload_file.key)

        return base64.b64encode(blob).decode()

    def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
        """
        Store `text` as a UTF-8 `.txt` file and persist an `UploadFile` row already marked used.

        :param text: text content to store.
        :param text_name: display name; truncated to 200 characters.
        :param user_id: account id recorded as both creator and user of the file.
        :param tenant_id: owning tenant; also part of the storage key.
        :returns: the committed `UploadFile`.
        """
        if len(text_name) > 200:
            text_name = text_name[:200]
        # random UUID as the storage file name
        file_uuid = str(uuid.uuid4())
        file_key = "upload_files/" + tenant_id + "/" + file_uuid + ".txt"

        # Encode once so the stored bytes and the recorded size agree: `len(text)` counts
        # characters and undercounts multi-byte UTF-8 content.
        content = text.encode("utf-8")

        # save file to storage
        storage.save(file_key, content)

        # save file to db; one timestamp so `created_at` and `used_at` match exactly
        now = naive_utc_now()
        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type=StorageType(dify_config.STORAGE_TYPE),
            key=file_key,
            name=text_name,
            size=len(content),
            extension="txt",
            mime_type="text/plain",
            created_by=user_id,
            created_by_role=CreatorUserRole.ACCOUNT,
            created_at=now,
            used=True,
            used_by=user_id,
            used_at=now,
        )

        with self._session_maker(expire_on_commit=False) as session:
            session.add(upload_file)
            session.commit()

        return upload_file

    def get_file_preview(self, file_id: str):
        """
        Return a short text preview extracted from a document file.

        The extracted text is cut to the first `PREVIEW_WORDS_LIMIT` characters; an empty
        string is returned when extraction yields nothing.

        :raises NotFound: if no `UploadFile` row has this id.
        :raises UnsupportedFileTypeError: if the extension is not in `DOCUMENT_EXTENSIONS`.
        """
        with self._session_maker(expire_on_commit=False) as session:
            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

            if not upload_file:
                raise NotFound("File not found")

            # only documents can be previewed as text
            extension = upload_file.extension
            if extension.lower() not in DOCUMENT_EXTENSIONS:
                raise UnsupportedFileTypeError()

            text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
            text = text[0:PREVIEW_WORDS_LIMIT] if text else ""

        return text

    def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str):
        """
        Verify a signed image request and return `(stream, mime_type)` for the image.

        :returns: a streaming iterator from `storage.load(..., stream=True)` and the stored MIME type.
        :raises NotFound: if the signature check fails or the row does not exist.
        :raises UnsupportedFileTypeError: if the extension is not in `IMAGE_EXTENSIONS`.
        """
        result = file_helpers.verify_image_signature(
            upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
        )
        if not result:
            raise NotFound("File not found or signature is invalid")

        with self._session_maker(expire_on_commit=False) as session:
            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

            if not upload_file:
                raise NotFound("File not found or signature is invalid")

            # only image files can be served through the image preview endpoint
            extension = upload_file.extension
            if extension.lower() not in IMAGE_EXTENSIONS:
                raise UnsupportedFileTypeError()

            generator = storage.load(upload_file.key, stream=True)

            return generator, upload_file.mime_type

    def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str):
        """
        Verify a signed file request and return `(stream, upload_file)`.

        No file-type restriction is applied here.

        :raises NotFound: if the signature check fails or the row does not exist.
        """
        result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
        if not result:
            raise NotFound("File not found or signature is invalid")

        with self._session_maker(expire_on_commit=False) as session:
            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

            if not upload_file:
                raise NotFound("File not found or signature is invalid")

            generator = storage.load(upload_file.key, stream=True)

            return generator, upload_file

    def get_public_image_preview(self, file_id: str):
        """
        Return `(content, mime_type)` for an image without any signature check.

        Unlike `get_image_preview`, storage is read without `stream=True`, so the first
        element is the loaded content rather than a stream.

        :raises NotFound: if the row does not exist.
        :raises UnsupportedFileTypeError: if the extension is not in `IMAGE_EXTENSIONS`.
        """
        with self._session_maker(expire_on_commit=False) as session:
            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

            if not upload_file:
                raise NotFound("File not found or signature is invalid")

            # only image files can be served through the image preview endpoint
            extension = upload_file.extension
            if extension.lower() not in IMAGE_EXTENSIONS:
                raise UnsupportedFileTypeError()

            generator = storage.load(upload_file.key)

            return generator, upload_file.mime_type

    def get_file_content(self, file_id: str) -> str:
        """
        Return the full content of file `file_id` decoded as UTF-8.

        :raises NotFound: if no `UploadFile` row has this id.
        :raises UnicodeDecodeError: if the stored bytes are not valid UTF-8.
        """
        with self._session_maker(expire_on_commit=False) as session:
            upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

            if not upload_file:
                raise NotFound("File not found")
            content = storage.load(upload_file.key)

        return content.decode("utf-8")

    def delete_file(self, file_id: str):
        """
        Delete the stored blob and the `UploadFile` row for `file_id`; a no-op if the row is missing.

        NOTE(review): the blob is deleted before the row deletion is committed, so a failed
        commit leaves a row pointing at a missing blob — confirm this ordering is intended.
        """
        with self._session_maker() as session, session.begin():
            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id))

            if not upload_file:
                return

            storage.delete(upload_file.key)
            session.delete(upload_file)

    @staticmethod
    def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
        """
        Fetch `UploadFile` rows for a tenant in a single batch query.

        This is a generic `UploadFile` lookup helper (not dataset/document specific), so it lives in `FileService`.
        It uses the global `db.session` rather than an instance session factory, because it is a staticmethod.

        :returns: mapping of `str(id)` -> `UploadFile`; ids not found for the tenant are absent.
        """
        if not upload_file_ids:
            return {}

        # Normalize and deduplicate ids before using them in the IN clause.
        upload_file_id_list: list[str] = [str(upload_file_id) for upload_file_id in upload_file_ids]
        unique_upload_file_ids: list[str] = list(set(upload_file_id_list))

        # Fetch upload files in one query for efficient batch access.
        upload_files: Sequence[UploadFile] = db.session.scalars(
            select(UploadFile).where(
                UploadFile.tenant_id == tenant_id,
                UploadFile.id.in_(unique_upload_file_ids),
            )
        ).all()
        return {str(upload_file.id): upload_file for upload_file in upload_files}

    @staticmethod
    def _sanitize_zip_entry_name(name: str) -> str:
        """
        Sanitize a ZIP entry name to avoid path traversal and weird separators.

        We keep this conservative: the upload flow already rejects `/` and `\\`, but older rows
        (or imported data) could still contain unsafe names. The upload flow does NOT reject
        `.` or `..`, so those (and empty names) are replaced with `file`.
        """
        # Drop any directory components.
        base = os.path.basename(name).strip()
        # ZIP uses forward slashes as separators; remove any residual separator characters.
        base = base.replace("/", "_").replace("\\", "_")
        # Empty or relative-directory names are meaningless (or dangerous) as entry names.
        if base in {"", ".", ".."}:
            return "file"
        return base

    @staticmethod
    def _dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str:
        """
        Return a unique ZIP entry name, inserting suffixes before the extension.

        `used_names` is only read here; the caller is responsible for adding the result to it.
        """
        # Keep the original name when it's not already used.
        if original_name not in used_names:
            return original_name

        # Insert suffixes before the extension (e.g., "doc.txt" -> "doc (1).txt").
        stem, extension = os.path.splitext(original_name)
        suffix = 1
        while True:
            candidate = f"{stem} ({suffix}){extension}"
            if candidate not in used_names:
                return candidate
            suffix += 1

    @staticmethod
    @contextmanager
    def build_upload_files_zip_tempfile(
        *,
        upload_files: Sequence[UploadFile],
    ) -> Iterator[str]:
        """
        Build a ZIP from `UploadFile`s and yield a tempfile path.

        We yield a path (rather than an open file handle) to avoid "read of closed file" issues when
        Flask/Werkzeug streams responses. The caller is expected to keep this context open until the
        response is fully sent, then close it (e.g., via `response.call_on_close(...)`) to delete the tempfile.
        """
        used_names: set[str] = set()

        # Build a ZIP in a temp file and keep it on disk until the caller finishes streaming it.
        tmp_path: str | None = None
        try:
            with NamedTemporaryFile(mode="w+b", suffix=".zip", delete=False) as tmp:
                tmp_path = tmp.name
                with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf:
                    for upload_file in upload_files:
                        # Ensure the entry name is safe and unique.
                        safe_name = FileService._sanitize_zip_entry_name(upload_file.name)
                        arcname = FileService._dedupe_zip_entry_name(safe_name, used_names)
                        used_names.add(arcname)

                        # Stream file bytes from storage into the ZIP entry.
                        with zf.open(arcname, "w") as entry:
                            for chunk in storage.load(upload_file.key, stream=True):
                                entry.write(chunk)

                # Flush so `send_file(path, ...)` can re-open it safely on all platforms.
                tmp.flush()

            assert tmp_path is not None
            yield tmp_path
        finally:
            # Remove the temp file when the context is closed (typically after the response finishes streaming).
            if tmp_path is not None:
                with suppress(FileNotFoundError):
                    os.remove(tmp_path)