diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 3decf4d116..6d290a46ee 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -32,6 +32,7 @@ for module_name in RESOURCE_MODULES: # Ensure resource modules are imported so route decorators are evaluated. # Import other controllers +# Sandbox file browser from . import ( admin, apikey, @@ -39,6 +40,7 @@ from . import ( feature, init_validate, ping, + sandbox_files, setup, spec, version, @@ -199,6 +201,7 @@ __all__ = [ "rag_pipeline_import", "rag_pipeline_workflow", "recommended_app", + "sandbox_files", "sandbox_providers", "saved_message", "setup", diff --git a/api/controllers/console/sandbox_files.py b/api/controllers/console/sandbox_files.py new file mode 100644 index 0000000000..488d34ca7a --- /dev/null +++ b/api/controllers/console/sandbox_files.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field + +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, setup_required +from libs.login import current_account_with_tenant, login_required +from services.sandbox.sandbox_file_service import SandboxFileService + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class SandboxFileListQuery(BaseModel): + path: str | None = Field(default=None, description="Workspace relative path") + recursive: bool = Field(default=False, description="List recursively") + + +class SandboxFileDownloadRequest(BaseModel): + path: str = Field(..., description="Workspace relative file path") + + +console_ns.schema_model( + SandboxFileListQuery.__name__, + SandboxFileListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + SandboxFileDownloadRequest.__name__, + SandboxFileDownloadRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + + +SANDBOX_FILE_NODE_FIELDS = { + "path": fields.String, + "is_dir": fields.Boolean, + "size": fields.Raw, + "mtime": fields.Raw, +} + + +SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS = { + "download_url": fields.String, + "expires_in": fields.Integer, + "export_id": fields.String, +} + + +sandbox_file_node_model = console_ns.model("SandboxFileNode", SANDBOX_FILE_NODE_FIELDS) +sandbox_file_download_ticket_model = console_ns.model( + "SandboxFileDownloadTicket", SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS +) + + +@console_ns.route("/sandboxes//files") +class SandboxFilesApi(Resource): + @setup_required + @login_required + @account_initialization_required + @console_ns.expect(console_ns.models[SandboxFileListQuery.__name__]) + @console_ns.marshal_list_with(sandbox_file_node_model) + def get(self, sandbox_id: str): + args = SandboxFileListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore[arg-type] + _, tenant_id = current_account_with_tenant() + return [ + e.__dict__ + for e in SandboxFileService.list_files( + tenant_id=tenant_id, + sandbox_id=sandbox_id, + path=args.path, + recursive=args.recursive, + ) + ] + + +@console_ns.route("/sandboxes//files/download") +class SandboxFileDownloadApi(Resource): + @setup_required + @login_required + @account_initialization_required + @console_ns.expect(console_ns.models[SandboxFileDownloadRequest.__name__]) + @console_ns.marshal_with(sandbox_file_download_ticket_model) + def post(self, sandbox_id: str): + payload = SandboxFileDownloadRequest.model_validate(console_ns.payload or {}) + _, tenant_id = current_account_with_tenant() + res = SandboxFileService.download_file(tenant_id=tenant_id, sandbox_id=sandbox_id, path=payload.path) + return res.__dict__ diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index abf9026b9c..77eb012c7c 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -19,6 +19,7 @@ from . import ( app_assets_upload, image_preview, sandbox_archive, + sandbox_file_downloads, storage_download, tool_files, upload, @@ -34,6 +35,7 @@ __all__ = [ "files_ns", "image_preview", "sandbox_archive", + "sandbox_file_downloads", "storage_download", "tool_files", "upload", diff --git a/api/controllers/files/sandbox_file_downloads.py b/api/controllers/files/sandbox_file_downloads.py new file mode 100644 index 0000000000..7f021d4493 --- /dev/null +++ b/api/controllers/files/sandbox_file_downloads.py @@ -0,0 +1,96 @@ +from urllib.parse import quote +from uuid import UUID + +from flask import Response, request +from flask_restx import Resource +from pydantic import BaseModel, Field +from werkzeug.exceptions import Forbidden, NotFound + +from controllers.files import files_ns +from core.sandbox.security.sandbox_file_signer import SandboxFileDownloadPath, SandboxFileSigner +from extensions.ext_storage import storage + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class SandboxFileDownloadQuery(BaseModel): + expires_at: int = Field(..., description="Unix timestamp when the link expires") + nonce: str = Field(..., description="Random string for signature") + sign: str = Field(..., description="HMAC signature") + + +files_ns.schema_model( + SandboxFileDownloadQuery.__name__, + SandboxFileDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + + +@files_ns.route( + "/sandbox-file-downloads/////download" +) +class SandboxFileDownloadDownloadApi(Resource): + def get(self, tenant_id: str, sandbox_id: str, export_id: str, filename: str): + args = SandboxFileDownloadQuery.model_validate(request.args.to_dict(flat=True)) + + try: + export_path = SandboxFileDownloadPath( + tenant_id=UUID(tenant_id), + sandbox_id=UUID(sandbox_id), + export_id=export_id, + filename=filename, + ) + except ValueError as exc: + raise Forbidden(str(exc)) from exc + + if not SandboxFileSigner.verify_download_signature( + export_path=export_path, + expires_at=args.expires_at, + nonce=args.nonce, + sign=args.sign, + ): + raise Forbidden("Invalid or expired download link") + + try: + generator = storage.load_stream(export_path.get_storage_key()) + except FileNotFoundError as exc: + raise NotFound("File not found") from exc + + encoded_filename = quote(filename.split("/")[-1]) + + return Response( + generator, + mimetype="application/octet-stream", + direct_passthrough=True, + headers={ + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}", + }, + ) + + +@files_ns.route( + "/sandbox-file-downloads/////upload" +) +class SandboxFileDownloadUploadApi(Resource): + def put(self, tenant_id: str, sandbox_id: str, export_id: str, filename: str): + args = SandboxFileDownloadQuery.model_validate(request.args.to_dict(flat=True)) + + try: + export_path = SandboxFileDownloadPath( + tenant_id=UUID(tenant_id), + sandbox_id=UUID(sandbox_id), + export_id=export_id, + filename=filename, + ) + except ValueError as exc: + raise Forbidden(str(exc)) from exc + + if not SandboxFileSigner.verify_upload_signature( + export_path=export_path, + expires_at=args.expires_at, + nonce=args.nonce, + sign=args.sign, + ): + raise Forbidden("Invalid or expired upload link") + + storage.save(export_path.get_storage_key(), request.get_data()) + return Response(status=204) diff --git a/api/core/sandbox/entities/__init__.py b/api/core/sandbox/entities/__init__.py index b5c3d57342..be562f4e5b 100644 --- a/api/core/sandbox/entities/__init__.py +++ b/api/core/sandbox/entities/__init__.py @@ -1,10 +1,13 @@ from .config import AppAssets, DifyCli +from .files import SandboxFileDownloadTicket, SandboxFileNode from .providers import SandboxProviderApiEntity from .sandbox_type import SandboxType __all__ = [ "AppAssets", "DifyCli", + "SandboxFileDownloadTicket", + "SandboxFileNode", "SandboxProviderApiEntity", "SandboxType", ] diff --git a/api/core/sandbox/entities/files.py b/api/core/sandbox/entities/files.py new file mode 100644 index 0000000000..c5b04787dd --- /dev/null +++ b/api/core/sandbox/entities/files.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class SandboxFileNode: + path: str + is_dir: bool + size: int | None + mtime: int | None + + +@dataclass(frozen=True) +class SandboxFileDownloadTicket: + download_url: str + expires_in: int + export_id: str diff --git a/api/core/sandbox/inspector.py b/api/core/sandbox/inspector.py new file mode 100644 index 0000000000..166b7ca434 --- /dev/null +++ b/api/core/sandbox/inspector.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import abc +import json +import logging +import os +import tempfile +from pathlib import Path, PurePosixPath +from uuid import UUID, uuid4 + +from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode +from core.sandbox.manager import SandboxManager +from core.sandbox.security.archive_signer import SandboxArchivePath +from core.sandbox.security.sandbox_file_signer import SandboxFileDownloadPath, SandboxFileSigner +from core.virtual_environment.__base.exec import CommandExecutionError +from core.virtual_environment.__base.helpers import execute +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from extensions.ext_storage import storage + +logger = logging.getLogger(__name__) + + +class SandboxFileSource(abc.ABC): + _LIST_TIMEOUT_SECONDS = 30 + _UPLOAD_TIMEOUT_SECONDS = 60 + _EXPORT_EXPIRES_IN_SECONDS = 60 * 5 + + def __init__(self, *, tenant_id: str, sandbox_id: str): + self._tenant_id = tenant_id + self._sandbox_id = sandbox_id + + @abc.abstractmethod + def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]: + raise NotImplementedError + + @abc.abstractmethod + def download_file(self, *, path: str) -> SandboxFileDownloadTicket: + raise NotImplementedError + + +class SandboxFileRuntimeSource(SandboxFileSource): + def __init__(self, *, tenant_id: str, sandbox_id: str, runtime: VirtualEnvironment): + super().__init__(tenant_id=tenant_id, sandbox_id=sandbox_id) + self._runtime = runtime + + def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]: + script = r""" +import json +import os +import sys + +path = sys.argv[1] +recursive = sys.argv[2] == "1" + +def norm(rel: str) -> str: + rel = rel.replace("\\\\", "/") + rel = rel.lstrip("./") + return rel or "." + +def stat_entry(full_path: str, rel_path: str) -> dict: + st = os.stat(full_path) + is_dir = os.path.isdir(full_path) + return { + "path": norm(rel_path), + "is_dir": is_dir, + "size": None if is_dir else int(st.st_size), + "mtime": int(st.st_mtime), + } + +entries = [] +if recursive: + for root, dirs, files in os.walk(path): + for d in dirs: + fp = os.path.join(root, d) + rp = os.path.relpath(fp, ".") + entries.append(stat_entry(fp, rp)) + for f in files: + fp = os.path.join(root, f) + rp = os.path.relpath(fp, ".") + entries.append(stat_entry(fp, rp)) +else: + if os.path.isfile(path): + rel_path = os.path.relpath(path, ".") + entries.append(stat_entry(path, rel_path)) + else: + for item in os.scandir(path): + rel_path = os.path.relpath(item.path, ".") + entries.append(stat_entry(item.path, rel_path)) + +print(json.dumps(entries)) +""" + + try: + result = execute( + self._runtime, + [ + "sh", + "-c", + 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"', + script, + path, + "1" if recursive else "0", + ], + timeout=self._LIST_TIMEOUT_SECONDS, + error_message="Failed to list sandbox files", + ) + except CommandExecutionError as exc: + raise RuntimeError(str(exc)) from exc + + try: + raw = json.loads(result.stdout.decode("utf-8")) + except Exception as exc: + raise RuntimeError("Malformed sandbox file list output") from exc + + entries: list[SandboxFileNode] = [] + for item in raw: + entries.append( + SandboxFileNode( + path=str(item.get("path")), + is_dir=bool(item.get("is_dir")), + size=item.get("size"), + mtime=item.get("mtime"), + ) + ) + return entries + + def download_file(self, *, path: str) -> SandboxFileDownloadTicket: + kind = self._detect_path_kind(path) + + export_name = os.path.basename(path.rstrip("/")) or "workspace" + filename = f"{export_name}.tar.gz" if kind == "dir" else (os.path.basename(path) or "file") + export_id = uuid4().hex + export_path = SandboxFileDownloadPath( + tenant_id=UUID(self._tenant_id), + sandbox_id=UUID(self._sandbox_id), + export_id=export_id, + filename=filename, + ) + + upload_url = SandboxFileSigner.build_signed_url( + export_path=export_path, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + action=SandboxFileSigner.OPERATION_UPLOAD, + ) + + if kind == "dir": + archive_path = f"/tmp/{export_id}.tar.gz" + try: + execute( + self._runtime, + ["tar", "-czf", archive_path, "-C", ".", path], + timeout=self._UPLOAD_TIMEOUT_SECONDS, + error_message="Failed to archive directory in sandbox", + ) + execute( + self._runtime, + ["curl", "-s", "-f", "-X", "PUT", "-T", archive_path, upload_url], + timeout=self._UPLOAD_TIMEOUT_SECONDS, + error_message="Failed to upload directory archive from sandbox", + ) + except CommandExecutionError as exc: + raise RuntimeError(str(exc)) from exc + finally: + try: + execute( + self._runtime, + ["rm", "-f", archive_path], + timeout=self._LIST_TIMEOUT_SECONDS, + error_message="Failed to cleanup temp archive", + ) + except Exception as exc: + # Best-effort cleanup; do not fail the download on cleanup issues. + logger.debug("Failed to cleanup temp archive %s: %s", archive_path, exc) + else: + try: + execute( + self._runtime, + ["curl", "-s", "-f", "-X", "PUT", "-T", path, upload_url], + timeout=self._UPLOAD_TIMEOUT_SECONDS, + error_message="Failed to upload file from sandbox", + ) + except CommandExecutionError as exc: + raise RuntimeError(str(exc)) from exc + + download_url = SandboxFileSigner.build_signed_url( + export_path=export_path, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + action=SandboxFileSigner.OPERATION_DOWNLOAD, + ) + return SandboxFileDownloadTicket( + download_url=download_url, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + export_id=export_id, + ) + + def _detect_path_kind(self, path: str) -> str: + script = r""" +import os +import sys + +p = sys.argv[1] +if os.path.isdir(p): + print("dir") + raise SystemExit(0) +if os.path.isfile(p): + print("file") + raise SystemExit(0) +print("none") +raise SystemExit(2) +""" + + try: + result = execute( + self._runtime, + [ + "sh", + "-c", + 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"', + script, + path, + ], + timeout=self._LIST_TIMEOUT_SECONDS, + error_message="Failed to check path in sandbox", + ) + except CommandExecutionError as exc: + raise ValueError(str(exc)) from exc + + kind = result.stdout.decode("utf-8", errors="replace").strip() + if kind not in ("dir", "file"): + raise ValueError("File not found in sandbox") + return kind + + +class SandboxFileArchiveSource(SandboxFileSource): + def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]: + import tarfile + + archive_path = SandboxArchivePath(tenant_id=UUID(self._tenant_id), sandbox_id=UUID(self._sandbox_id)) + storage_key = archive_path.get_storage_key() + if not storage.exists(storage_key): + raise ValueError("Sandbox archive not found") + + with tempfile.TemporaryDirectory(prefix="dify-sandbox-archive-") as tmpdir: + local_archive = os.path.join(tmpdir, "workspace.tar.gz") + storage.download(storage_key, local_archive) + + entries_by_path: dict[str, SandboxFileNode] = {} + + def add_dir(dir_path: str) -> None: + if dir_path in ("", "."): + return + if dir_path not in entries_by_path: + entries_by_path[dir_path] = SandboxFileNode(path=dir_path, is_dir=True, size=None, mtime=None) + + def clean(member_name: str) -> str: + name = member_name.lstrip("./") + return name.rstrip("/") + + target_prefix = "" if path in (".", "") else f"{path}/" + + with tarfile.open(local_archive, mode="r:gz") as tf: + for m in tf.getmembers(): + mp = clean(m.name) + if mp in ("", "."): + continue + + if not recursive: + if path in (".", ""): + if "/" in mp: + add_dir(mp.split("/", 1)[0]) + continue + else: + if not mp.startswith(target_prefix): + continue + rest = mp[len(target_prefix) :] + if rest == "": + continue + if "/" in rest: + add_dir(f"{path}/{rest.split('/', 1)[0]}") + continue + else: + if path not in (".", "") and not (mp == path or mp.startswith(target_prefix)): + continue + + parent = os.path.dirname(mp) + while parent not in ("", "."): + if path not in (".", "") and parent == path: + break + add_dir(parent) + parent = os.path.dirname(parent) + + is_dir = m.isdir() + entries_by_path[mp] = SandboxFileNode( + path=mp, + is_dir=is_dir, + size=None if is_dir else int(m.size), + mtime=int(m.mtime) if m.mtime else None, + ) + + return sorted(entries_by_path.values(), key=lambda e: e.path) + + def download_file(self, *, path: str) -> SandboxFileDownloadTicket: + import tarfile + + archive_path = SandboxArchivePath(tenant_id=UUID(self._tenant_id), sandbox_id=UUID(self._sandbox_id)) + storage_key = archive_path.get_storage_key() + if not storage.exists(storage_key): + raise ValueError("Sandbox archive not found") + + export_name = os.path.basename(path.rstrip("/")) or "workspace" + export_id = uuid4().hex + + # Decide file vs directory inside archive. + is_dir_request = path in (".", "") + + with tempfile.TemporaryDirectory(prefix="dify-sandbox-archive-") as tmpdir: + local_archive = os.path.join(tmpdir, "workspace.tar.gz") + storage.download(storage_key, local_archive) + + with tarfile.open(local_archive, mode="r:gz") as tf: + member_name = path.lstrip("./").rstrip("/") + if not is_dir_request: + # If it is an explicit file in archive, treat as file download. + member = None + try: + member = tf.getmember(member_name) + except KeyError: + try: + member = tf.getmember(f"./{member_name}") + except KeyError: + member = None + + if member is not None and not member.isdir(): + export_path = SandboxFileDownloadPath( + tenant_id=UUID(self._tenant_id), + sandbox_id=UUID(self._sandbox_id), + export_id=export_id, + filename=os.path.basename(member_name) or "file", + ) + extracted = tf.extractfile(member) + if extracted is None: + raise ValueError("File not found in sandbox archive") + storage.save(export_path.get_storage_key(), extracted.read()) + + download_url = SandboxFileSigner.build_signed_url( + export_path=export_path, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + action=SandboxFileSigner.OPERATION_DOWNLOAD, + ) + return SandboxFileDownloadTicket( + download_url=download_url, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + export_id=export_id, + ) + + # Otherwise treat as directory (implied dir is common in tar). + is_dir_request = True + + if is_dir_request: + export_path = SandboxFileDownloadPath( + tenant_id=UUID(self._tenant_id), + sandbox_id=UUID(self._sandbox_id), + export_id=export_id, + filename=f"{export_name}.tar.gz", + ) + export_local = os.path.join(tmpdir, "export.tar.gz") + + prefix = "" if member_name in (".", "") else f"{member_name}/" + found_any = False + for m in tf.getmembers(): + src_name = m.name.lstrip("./").rstrip("/") + if member_name not in (".", ""): + if src_name != member_name and not src_name.startswith(prefix): + continue + found_any = True + break + + if not found_any: + raise ValueError("File not found in sandbox archive") + + with tarfile.open(export_local, mode="w:gz") as out: + if member_name not in (".", ""): + dir_info = tarfile.TarInfo(name=member_name) + dir_info.type = tarfile.DIRTYPE + dir_info.size = 0 + out.addfile(dir_info) + + for m in tf.getmembers(): + src_name = m.name.lstrip("./") + if member_name not in (".", ""): + if src_name != member_name and not src_name.startswith(prefix): + continue + ti = tarfile.TarInfo(name=src_name.rstrip("/")) + ti.mode = m.mode + ti.mtime = m.mtime + ti.uid = m.uid + ti.gid = m.gid + ti.uname = m.uname + ti.gname = m.gname + if m.isdir(): + ti.type = tarfile.DIRTYPE + ti.size = 0 + out.addfile(ti) + continue + extracted = tf.extractfile(m) + if extracted is None: + continue + ti.size = int(m.size) + out.addfile(ti, fileobj=extracted) + + storage.save(export_path.get_storage_key(), Path(export_local).read_bytes()) + + download_url = SandboxFileSigner.build_signed_url( + export_path=export_path, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + action=SandboxFileSigner.OPERATION_DOWNLOAD, + ) + return SandboxFileDownloadTicket( + download_url=download_url, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + export_id=export_id, + ) + + raise ValueError("File not found in sandbox archive") + + +class SandboxFileBrowser: + def __init__(self, *, tenant_id: str, sandbox_id: str): + self._tenant_id = tenant_id + self._sandbox_id = sandbox_id + + @staticmethod + def _normalize_workspace_path(path: str | None) -> str: + raw = (path or ".").strip() + if raw == "": + raw = "." + + p = PurePosixPath(raw) + if p.is_absolute(): + raise ValueError("path must be relative") + if any(part == ".." for part in p.parts): + raise ValueError("path must not contain '..'") + + normalized = str(p) + return "." if normalized in (".", "") else normalized + + def _backend(self) -> SandboxFileSource: + runtime = SandboxManager.get(self._sandbox_id) + if runtime is not None: + return SandboxFileRuntimeSource(tenant_id=self._tenant_id, sandbox_id=self._sandbox_id, runtime=runtime) + return SandboxFileArchiveSource(tenant_id=self._tenant_id, sandbox_id=self._sandbox_id) + + def list_files(self, *, path: str | None = None, recursive: bool = False) -> list[SandboxFileNode]: + workspace_path = self._normalize_workspace_path(path) + return self._backend().list_files(path=workspace_path, recursive=recursive) + + def download_file(self, *, path: str) -> SandboxFileDownloadTicket: + workspace_path = self._normalize_workspace_path(path) + return self._backend().download_file(path=workspace_path) diff --git a/api/core/sandbox/security/sandbox_file_signer.py b/api/core/sandbox/security/sandbox_file_signer.py new file mode 100644 index 0000000000..dd59023ba9 --- /dev/null +++ b/api/core/sandbox/security/sandbox_file_signer.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +import os +import time +import urllib.parse +from dataclasses import dataclass +from uuid import UUID + +from configs import dify_config +from libs import rsa + + +@dataclass(frozen=True) +class SandboxFileDownloadPath: + tenant_id: UUID + sandbox_id: UUID + export_id: str + filename: str + + def get_storage_key(self) -> str: + return f"sandbox_file_downloads/{self.tenant_id}/{self.sandbox_id}/{self.export_id}/{self.filename}" + + def proxy_path(self) -> str: + encoded_parts = [ + urllib.parse.quote(str(self.tenant_id), safe=""), + urllib.parse.quote(str(self.sandbox_id), safe=""), + urllib.parse.quote(self.export_id, safe=""), + urllib.parse.quote(self.filename, safe=""), + ] + return "/".join(encoded_parts) + + +class SandboxFileSigner: + SIGNATURE_PREFIX = "sandbox-file-download" + SIGNATURE_VERSION = "v1" + OPERATION_DOWNLOAD = "download" + OPERATION_UPLOAD = "upload" + + @classmethod + def build_signed_url( + cls, + *, + export_path: SandboxFileDownloadPath, + expires_in: int, + action: str, + ) -> str: + expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT) + expires_at = int(time.time()) + max(expires_in, 1) + nonce = os.urandom(16).hex() + sign = cls._create_signature( + export_path=export_path, + operation=action, + expires_at=expires_at, + nonce=nonce, + ) + + base_url = dify_config.FILES_URL + url = f"{base_url}/files/sandbox-file-downloads/{export_path.proxy_path()}/{action}" + query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign}) + return f"{url}?{query}" + + @classmethod + def verify_download_signature( + cls, + *, + export_path: SandboxFileDownloadPath, + expires_at: int, + nonce: str, + sign: str, + ) -> bool: + return cls._verify_signature( + export_path=export_path, + operation=cls.OPERATION_DOWNLOAD, + expires_at=expires_at, + nonce=nonce, + sign=sign, + ) + + @classmethod + def verify_upload_signature( + cls, + *, + export_path: SandboxFileDownloadPath, + expires_at: int, + nonce: str, + sign: str, + ) -> bool: + return cls._verify_signature( + export_path=export_path, + operation=cls.OPERATION_UPLOAD, + expires_at=expires_at, + nonce=nonce, + sign=sign, + ) + + @classmethod + def _verify_signature( + cls, + *, + export_path: SandboxFileDownloadPath, + operation: str, + expires_at: int, + nonce: str, + sign: str, + ) -> bool: + if expires_at <= 0: + return False + + expected_sign = cls._create_signature( + export_path=export_path, + operation=operation, + expires_at=expires_at, + nonce=nonce, + ) + if not hmac.compare_digest(sign, expected_sign): + return False + + current_time = int(time.time()) + if expires_at < current_time: + return False + + if expires_at - current_time > dify_config.FILES_ACCESS_TIMEOUT: + return False + + return True + + @classmethod + def _create_signature( + cls, + *, + export_path: SandboxFileDownloadPath, + operation: str, + expires_at: int, + nonce: str, + ) -> str: + key = cls._tenant_key(str(export_path.tenant_id)) + message = ( + f"{cls.SIGNATURE_PREFIX}|{cls.SIGNATURE_VERSION}|{operation}|" + f"{export_path.tenant_id}|{export_path.sandbox_id}|{export_path.export_id}|{export_path.filename}|" + f"{expires_at}|{nonce}" + ) + digest = hmac.new(key, message.encode(), hashlib.sha256).digest() + return base64.urlsafe_b64encode(digest).decode() + + @classmethod + def _tenant_key(cls, tenant_id: str) -> bytes: + try: + rsa_key, _ = rsa.get_decrypt_decoding(tenant_id) + except rsa.PrivkeyNotFoundError as exc: + raise ValueError(f"Tenant private key missing for tenant_id={tenant_id}") from exc + private_key = rsa_key.export_key() + return hashlib.sha256(private_key).digest() diff --git a/api/services/sandbox/sandbox_file_service.py b/api/services/sandbox/sandbox_file_service.py new file mode 100644 index 0000000000..95b8d04040 --- /dev/null +++ b/api/services/sandbox/sandbox_file_service.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode +from core.sandbox.inspector import SandboxFileBrowser + + +class SandboxFileService: + @classmethod + def list_files( + cls, + *, + tenant_id: str, + sandbox_id: str, + path: str | None = None, + recursive: bool = False, + ) -> list[SandboxFileNode]: + browser = SandboxFileBrowser(tenant_id=tenant_id, sandbox_id=sandbox_id) + return browser.list_files(path=path, recursive=recursive) + + @classmethod + def download_file(cls, *, tenant_id: str, sandbox_id: str, path: str) -> SandboxFileDownloadTicket: + browser = SandboxFileBrowser(tenant_id=tenant_id, sandbox_id=sandbox_id) + return browser.download_file(path=path)