Merge branch 'fix-drive-skill-archive-cleanup' into feat/agent-v2

# Conflicts:
#	dify-agent/tests/local/dify_agent/layers/drive/test_layer.py
This commit is contained in:
Yanli 盐粒 2026-06-25 22:08:18 +08:00
commit 0ed76850d4
5 changed files with 65 additions and 24 deletions

View File

@ -13,7 +13,7 @@ import stat
from dataclasses import dataclass
from pathlib import Path, PurePosixPath
from tempfile import TemporaryDirectory
from typing import Collection, Final, Mapping
from typing import Final
from uuid import uuid4
from zipfile import BadZipFile, ZipFile, ZipInfo
@ -42,14 +42,14 @@ def materialize_drive_downloads(
*,
base_path: Path,
downloads: list[DriveDownloadPayload],
archive_skip_entry_names_by_dir: Mapping[str, Collection[str]] | None = None,
) -> list[Path]:
"""Write downloaded drive payloads under one local base and extract skills.
The helper preserves caller-provided order in the returned list of written
paths. Skill archives are extracted only after every payload has been
The helper preserves caller-provided order in the returned list of paths.
Skill archives are extracted and deleted only after every payload has been
written successfully so partial extraction cannot outlive a later failure in
the same batch.
the same batch. The returned path for an archive is the path where it was
downloaded before successful extraction.
"""
resolved_base_path = base_path.expanduser().resolve()
@ -60,7 +60,6 @@ def materialize_drive_downloads(
written_paths: list[Path] = []
archive_paths: list[Path] = []
skip_entry_names_by_dir = archive_skip_entry_names_by_dir or {}
for download in downloads:
if download.size is not None and len(download.payload) != download.size:
raise DriveMaterializationTransferError(f"downloaded drive file size mismatch for {download.key}")
@ -77,11 +76,8 @@ def materialize_drive_downloads(
archive_paths.append(destination)
for archive_path in sorted(archive_paths):
archive_skill_dir = archive_path.parent.relative_to(resolved_base_path).as_posix()
extract_skill_archive(
archive_path,
skip_entry_names=frozenset(skip_entry_names_by_dir.get(archive_skill_dir, ())),
)
extract_skill_archive(archive_path)
_delete_extracted_archive(archive_path)
return written_paths
@ -96,18 +92,15 @@ def resolve_drive_destination(base_path: Path, drive_key: str) -> Path:
return destination
def extract_skill_archive(archive_path: Path, *, skip_entry_names: Collection[str] = ()) -> None:
def extract_skill_archive(archive_path: Path) -> None:
"""Safely extract one downloaded skill archive into its containing directory."""
target_dir = archive_path.parent.resolve()
normalized_skip_entry_names = {entry_name.replace("\\", "/").rstrip("/") for entry_name in skip_entry_names}
try:
with TemporaryDirectory(dir=target_dir, prefix=".dify-skill-extract-") as staging_dir_name:
staging_dir = Path(staging_dir_name).resolve()
with ZipFile(archive_path) as archive:
for zip_info in archive.infolist():
if zip_info.filename.replace("\\", "/").rstrip("/") in normalized_skip_entry_names:
continue
destination = _resolve_zip_entry_destination(staging_dir, zip_info.filename)
if _is_zip_symlink(zip_info):
raise DriveMaterializationValidationError(
@ -156,6 +149,15 @@ def _is_zip_symlink(zip_info: ZipInfo) -> bool:
return stat.S_ISLNK(file_mode)
def _delete_extracted_archive(archive_path: Path) -> None:
    """Remove a skill archive file from disk.

    An archive that is already absent is not an error. Any other ``OSError``
    raised while unlinking is re-raised as ``DriveMaterializationTransferError``
    with the original exception chained as its cause.
    """
    try:
        archive_path.unlink(missing_ok=True)
    except OSError as unlink_error:
        # Only the file name is reported, not the full local path.
        message = f"failed to delete extracted skill archive: {archive_path.name}"
        raise DriveMaterializationTransferError(message) from unlink_error
__all__ = [
"DriveDownloadPayload",
"DriveMaterializationTransferError",

View File

@ -126,7 +126,8 @@ def pull_drive_from_environment(
``.DIFY-SKILL-FULL.zip`` archives into their containing skill
directory with the same path-safety checks. Archive extraction is staged
under a temporary directory and only moved into place after the full
archive validates successfully.
archive validates successfully. Successfully extracted skill archives
are deleted from disk.
Extracted files are materialized on disk but are not added to the
returned item list.

View File

@ -241,7 +241,6 @@ class DifyDriveLayer(PlainLayer[DifyDriveDeps, DifyDriveLayerConfig, EmptyRuntim
async def _download_items(self, items: list[AgentStubDriveItem]) -> dict[str, str]:
base_path = Path(agent_stub_drive_base_for_ref(self.config.drive_ref))
semaphore = asyncio.Semaphore(_DOWNLOAD_CONCURRENCY)
canonical_skill_dirs = {item.key.rsplit("/", 1)[0] for item in items if item.key.endswith("/SKILL.md")}
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True, trust_env=False) as client:
@ -264,7 +263,6 @@ class DifyDriveLayer(PlainLayer[DifyDriveDeps, DifyDriveLayerConfig, EmptyRuntim
written_paths = materialize_drive_downloads(
base_path=base_path,
downloads=downloads,
archive_skip_entry_names_by_dir={skill_dir: {"SKILL.md"} for skill_dir in canonical_skill_dirs},
)
except (DriveMaterializationValidationError, DriveMaterializationTransferError) as exc:
raise DifyDriveLayerError(str(exc)) from exc

View File

@ -179,7 +179,7 @@ def test_pull_drive_from_environment_auto_extracts_skill_archive(
assert result.model_dump() == {
"items": [{"key": "skills/foo/.DIFY-SKILL-FULL.zip", "local_path": str(archive_path)}]
}
assert archive_path.read_bytes() == archive_bytes
assert not archive_path.exists()
assert (tmp_path / "skills" / "foo" / "SKILL.md").read_text(encoding="utf-8") == "# Example\n"
assert (tmp_path / "skills" / "foo" / "nested" / "helper.py").read_text(encoding="utf-8") == "print('x')\n"

View File

@ -2,8 +2,10 @@
from __future__ import annotations
from io import BytesIO
from pathlib import Path
from typing import ClassVar
from zipfile import ZipFile
import pytest
@ -314,12 +316,9 @@ async def test_download_items_hands_validated_downloads_to_materialization(
)
captured: dict[str, object] = {}
def fake_materialize_drive_downloads(
*, base_path: Path, downloads: list[DriveDownloadPayload], archive_skip_entry_names_by_dir
):
def fake_materialize_drive_downloads(*, base_path: Path, downloads: list[DriveDownloadPayload]):
captured["base_path"] = base_path
captured["downloads"] = downloads
captured["archive_skip_entry_names_by_dir"] = archive_skip_entry_names_by_dir
return [tmp_path / "tender-analyzer" / "SKILL.md", tmp_path / "files" / "report.pdf"]
monkeypatch.setattr(
@ -340,13 +339,54 @@ async def test_download_items_hands_validated_downloads_to_materialization(
DriveDownloadPayload(key="tender-analyzer/SKILL.md", payload=b"skill-md", size=8),
DriveDownloadPayload(key="files/report.pdf", payload=b"pdf", size=3),
]
assert captured["archive_skip_entry_names_by_dir"] == {"tender-analyzer": {"SKILL.md"}}
assert result == {
"tender-analyzer/SKILL.md": str(tmp_path / "tender-analyzer" / "SKILL.md"),
"files/report.pdf": str(tmp_path / "files" / "report.pdf"),
}
@pytest.mark.anyio
async def test_download_items_extracts_skill_archive_over_skill_md_and_deletes_archive(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """Check that archive contents win over the separately downloaded SKILL.md.

    The drive listing contains both ``tender-analyzer/SKILL.md`` and the
    skill's ``.DIFY-SKILL-FULL.zip``. After ``_download_items`` finishes, the
    archive file must be gone, its reported local path must still be the
    download location, and the files on disk must carry the archive contents.
    """
    layer = _build_layer(tmp_path)

    # Build an in-memory skill archive whose SKILL.md differs from the manifest copy.
    zip_buffer = BytesIO()
    with ZipFile(zip_buffer, mode="w") as zip_file:
        for entry_name, entry_text in (
            ("SKILL.md", "# From archive\n"),
            ("helper.py", "print('archive')\n"),
        ):
            zip_file.writestr(entry_name, entry_text)
    archive_bytes = zip_buffer.getvalue()

    skill_md_key = "tender-analyzer/SKILL.md"
    archive_key = "tender-analyzer/.DIFY-SKILL-FULL.zip"

    # Canned HTTP responses keyed the way the fake client looks them up.
    fake_responses = {
        "download:https://files/skill-md": _FakeAsyncResponse(content=b"# From manifest\n"),
        "download:https://files/archive": _FakeAsyncResponse(content=archive_bytes),
    }
    monkeypatch.setattr(
        "dify_agent.layers.drive.layer.httpx.AsyncClient",
        lambda **_kwargs: _FakeAsyncClient(fake_responses),
    )
    # Redirect the drive base directory into the pytest temp dir.
    monkeypatch.setattr(
        "dify_agent.layers.drive.layer.agent_stub_drive_base_for_ref",
        lambda _drive_ref: str(tmp_path),
    )

    drive_items = [
        AgentStubDriveItem(key=skill_md_key, download_url="https://files/skill-md", size=16),
        AgentStubDriveItem(
            key=archive_key,
            download_url="https://files/archive",
            size=len(archive_bytes),
        ),
    ]
    result = await layer._download_items(drive_items)

    skill_dir = tmp_path / "tender-analyzer"
    archive_path = skill_dir / ".DIFY-SKILL-FULL.zip"

    # The archive is still reported at its download path even though it was removed.
    assert result[archive_key] == str(archive_path)
    assert not archive_path.exists()
    # Extracted files reflect the archive contents, not the manifest download.
    assert (skill_dir / "SKILL.md").read_text(encoding="utf-8") == "# From archive\n"
    assert (skill_dir / "helper.py").read_text(encoding="utf-8") == "print('archive')\n"
@pytest.mark.anyio
async def test_on_context_resume_raises_when_mentioned_targets_are_missing(
monkeypatch: pytest.MonkeyPatch,