From 18a589003e2505626d140e0ee50a5ffb7c58a1e4 Mon Sep 17 00:00:00 2001 From: Harry Date: Tue, 20 Jan 2026 19:44:20 +0800 Subject: [PATCH] feat(sandbox): enhance sandbox initialization with draft support and asset management - Introduced DraftAppAssetsInitializer for handling draft assets. - Updated SandboxLayer to conditionally set sandbox ID and storage based on workflow version. - Improved asset initialization logging and error handling. - Refactored ArchiveSandboxStorage to support exclusion patterns during archiving. - Modified command and LLM nodes to retrieve sandbox from workflow context, supporting draft workflows. --- api/core/app/layers/sandbox_layer.py | 25 +- api/core/sandbox/bash/session.py | 16 +- .../initializer/app_assets_initializer.py | 33 +- api/core/sandbox/storage/archive_storage.py | 29 +- api/core/workflow/nodes/command/node.py | 10 +- api/core/workflow/nodes/llm/node.py | 32 +- api/services/workflow_service.py | 14 +- .../core/app/layers/test_sandbox_layer.py | 324 ------------------ 8 files changed, 114 insertions(+), 369 deletions(-) delete mode 100644 api/tests/unit_tests/core/app/layers/test_sandbox_layer.py diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py index 4ad8229cc6..c1e2b27a8b 100644 --- a/api/core/app/layers/sandbox_layer.py +++ b/api/core/app/layers/sandbox_layer.py @@ -1,6 +1,8 @@ import logging from core.sandbox import AppAssetsInitializer, DifyCliInitializer, SandboxManager +from core.sandbox.constants import APP_ASSETS_PATH +from core.sandbox.initializer.app_assets_initializer import DraftAppAssetsInitializer from core.sandbox.storage.archive_storage import ArchiveSandboxStorage from core.sandbox.vm import SandboxBuilder from core.workflow.graph_engine.layers.base import GraphEngineLayer @@ -33,12 +35,11 @@ class SandboxLayer(GraphEngineLayer): self._user_id = user_id self._workflow_version = workflow_version self._workflow_execution_id = workflow_execution_id - self._sandbox_id = ( - self._workflow_execution_id - if self._workflow_version == Workflow.VERSION_DRAFT - else SandboxBuilder.draft_id(self._user_id) + is_draft = self._workflow_version == Workflow.VERSION_DRAFT + self._sandbox_id = SandboxBuilder.draft_id(self._user_id) if is_draft else self._workflow_execution_id + self._sandbox_storage = ArchiveSandboxStorage( + self._tenant_id, self._sandbox_id, exclude_patterns=[APP_ASSETS_PATH] if is_draft else None ) - self._sandbox_storage = ArchiveSandboxStorage(self._tenant_id, self._sandbox_id) def on_graph_start(self) -> None: try: @@ -61,9 +62,15 @@ class SandboxLayer(GraphEngineLayer): ) AppAssetService.build_assets(self._tenant_id, self._app_id, assets) + assets_initializer = ( + DraftAppAssetsInitializer(self._tenant_id, self._app_id, assets.id) + if is_draft + else AppAssetsInitializer(self._tenant_id, self._app_id, assets.id) + ) + builder = ( SandboxProviderService.create_sandbox_builder(self._tenant_id) - .initializer(AppAssetsInitializer(self._tenant_id, self._app_id, assets.id)) + .initializer(assets_initializer) .initializer(DifyCliInitializer(self._tenant_id, self._user_id, self._app_id, assets.id)) ) try: @@ -78,12 +85,6 @@ class SandboxLayer(GraphEngineLayer): raise SandboxInitializationError(f"Failed to build sandbox: {e}") from e SandboxManager.register(self._sandbox_id, sandbox) - logger.info( - "Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s", - self._sandbox_id, - sandbox.metadata.id, - sandbox.metadata.arch, - ) # mount sandbox files from storage mounted = self._sandbox_storage.mount(sandbox) diff --git a/api/core/sandbox/bash/session.py b/api/core/sandbox/bash/session.py index e37959b1b8..eb836ba6a2 100644 --- a/api/core/sandbox/bash/session.py +++ b/api/core/sandbox/bash/session.py @@ -18,7 +18,6 @@ from ..constants import ( DIFY_CLI_PATH, DIFY_CLI_TOOLS_ROOT, ) -from ..manager import SandboxManager from .bash_tool import SandboxBashTool logger = logging.getLogger(__name__) @@ -28,7 +27,7 @@ class SandboxBashSession: def __init__( self, *, - workflow_execution_id: str, + sandbox: VirtualEnvironment, tenant_id: str, user_id: str, node_id: str, @@ -36,7 +35,7 @@ class SandboxBashSession: assets_id: str, allow_tools: list[tuple[str, str]] | None, ) -> None: - self._workflow_execution_id = workflow_execution_id + self._sandbox = sandbox self._tenant_id = tenant_id self._user_id = user_id self._node_id = node_id @@ -46,25 +45,18 @@ class SandboxBashSession: self._assets_id = assets_id self._allow_tools = allow_tools - self._sandbox = None self._bash_tool = None self._session_id = None def __enter__(self) -> SandboxBashSession: - sandbox = SandboxManager.get(self._workflow_execution_id) - if sandbox is None: - raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}") - - self._sandbox = sandbox - if self._allow_tools is not None: if self._node_id is None: raise ValueError("node_id is required when allow_tools is specified") - tools_path = self._setup_node_tools_directory(sandbox, self._node_id, self._allow_tools) + tools_path = self._setup_node_tools_directory(self._sandbox, self._node_id, self._allow_tools) else: tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH - self._bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id, tools_path=tools_path) + self._bash_tool = SandboxBashTool(sandbox=self._sandbox, tenant_id=self._tenant_id, tools_path=tools_path) return self def _setup_node_tools_directory( diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py index f738df11d5..dec580b20d 100644 --- a/api/core/sandbox/initializer/app_assets_initializer.py +++ b/api/core/sandbox/initializer/app_assets_initializer.py @@ -6,7 +6,7 @@ from core.virtual_environment.__base.virtual_environment import VirtualEnvironme from extensions.ext_storage import storage from extensions.storage.file_presign_storage import FilePresignStorage -from ..constants import APP_ASSETS_ZIP_PATH +from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH from .base import SandboxInitializer logger = logging.getLogger(__name__) @@ -33,7 +33,36 @@ class AppAssetsInitializer(SandboxInitializer): ["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} 2>/dev/null || [ $? -eq 1 ]"], error_message="Failed to unzip assets", ) - .add(["rm", "-f", APP_ASSETS_ZIP_PATH], error_message="Failed to cleanup temp zip file") + .execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True) + ) + + logger.info( + "App assets initialized for app_id=%s, published_id=%s", + self._app_id, + self._assets_id, + ) + + +class DraftAppAssetsInitializer(SandboxInitializer): + def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None: + self._tenant_id = tenant_id + self._app_id = app_id + self._assets_id = assets_id + + def initialize(self, env: VirtualEnvironment) -> None: + zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id) + download_url = FilePresignStorage(storage.storage_runner).get_download_url(zip_key) + + ( + pipeline(env) + .add(["rm", "-rf", APP_ASSETS_PATH]) + .add(["wget", "-q", download_url, "-O", APP_ASSETS_ZIP_PATH], error_message="Failed to download assets zip") + # unzip with silent error and return 1 if the zip is empty + # FIXME(Mairuis): should use a more robust way to check if the zip is empty + .add( + ["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} 2>/dev/null || [ $? -eq 1 ]"], + error_message="Failed to unzip assets", + ) .execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True) ) diff --git a/api/core/sandbox/storage/archive_storage.py b/api/core/sandbox/storage/archive_storage.py index 4f7ed247bb..b0c124a2dd 100644 --- a/api/core/sandbox/storage/archive_storage.py +++ b/api/core/sandbox/storage/archive_storage.py @@ -18,10 +18,19 @@ ARCHIVE_DOWNLOAD_TIMEOUT = 60 * 5 ARCHIVE_UPLOAD_TIMEOUT = 60 * 5 +def build_tar_exclude_args(patterns: list[str]) -> list[str]: + return [f"--exclude={p}" for p in patterns] + + class ArchiveSandboxStorage(SandboxStorage): - def __init__(self, tenant_id: str, sandbox_id: str): + _tenant_id: str + _sandbox_id: str + _exclude_patterns: list[str] + + def __init__(self, tenant_id: str, sandbox_id: str, exclude_patterns: list[str] | None = None): self._tenant_id = tenant_id self._sandbox_id = sandbox_id + self._exclude_patterns = exclude_patterns or [] @property def _storage_key(self) -> str: @@ -36,7 +45,7 @@ class ArchiveSandboxStorage(SandboxStorage): try: ( pipeline(sandbox) - .add(["wget", download_url, "-O", ARCHIVE_NAME], error_message="Failed to download archive") + .add(["wget", "-q", download_url, "-O", ARCHIVE_NAME], error_message="Failed to download archive") .add(["tar", "-xzf", ARCHIVE_NAME], error_message="Failed to extract archive") .add(["rm", ARCHIVE_NAME], error_message="Failed to cleanup archive") .execute(timeout=ARCHIVE_DOWNLOAD_TIMEOUT, raise_on_error=True) @@ -53,10 +62,22 @@ class ArchiveSandboxStorage(SandboxStorage): ( pipeline(sandbox) .add( - ["tar", "-czf", ARCHIVE_PATH, "--warning=no-file-changed", "-C", WORKSPACE_DIR, "."], + [ + "tar", + "-czf", + ARCHIVE_PATH, + "--warning=no-file-changed", + *build_tar_exclude_args(self._exclude_patterns), + "-C", + WORKSPACE_DIR, + ".", + ], error_message="Failed to create archive", ) - .add(["wget", upload_url, "-O", ARCHIVE_PATH], error_message="Failed to upload archive") + .add( + ["curl", "-s", "-f", "-X", "PUT", "-T", ARCHIVE_PATH, upload_url], + error_message="Failed to upload archive", + ) .execute(timeout=ARCHIVE_UPLOAD_TIMEOUT, raise_on_error=True) ) logger.info("Unmounted archive for sandbox %s", self._sandbox_id) diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py index 39c07a1b1b..35e1a1340b 100644 --- a/api/core/workflow/nodes/command/node.py +++ b/api/core/workflow/nodes/command/node.py @@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any from core.sandbox import SandboxManager, sandbox_debug +from core.sandbox.vm import SandboxBuilder from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError from core.virtual_environment.__base.helpers import submit_command, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment @@ -24,11 +25,18 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60 class CommandNode(Node[CommandNodeData]): node_type = NodeType.COMMAND + # FIXME(Mairuis): should read sandbox from workflow run context... def _get_sandbox(self) -> VirtualEnvironment | None: workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id if not workflow_execution_id: return None - return SandboxManager.get(workflow_execution_id) + sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id) + if sandbox_by_workflow_run_id is not None: + return sandbox_by_workflow_run_id + sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id)) + if sandbox_by_draft_id is not None: + return sandbox_by_draft_id + return None def _render_template(self, template: str) -> str: parser = VariableTemplateParser(template=template) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c404ba4318..4eb9431230 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -51,6 +51,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.sandbox import SandboxBashSession, SandboxManager +from core.sandbox.vm import SandboxBuilder from core.tools.__base.tool import Tool from core.tools.signature import sign_upload_file from core.tools.tool_manager import ToolManager @@ -63,6 +64,7 @@ from core.variables import ( ObjectSegment, StringSegment, ) +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus from core.workflow.entities.tool_entities import ToolCallResult @@ -172,6 +174,19 @@ class LLMNode(Node[LLMNodeData]): def version(cls) -> str: return "1" + # FIXME(Mairuis): should read sandbox from workflow run context... + def _get_sandbox(self) -> VirtualEnvironment | None: + workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id + if not workflow_execution_id: + return None + sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id) + if sandbox_by_workflow_run_id is not None: + return sandbox_by_workflow_run_id + sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id)) + if sandbox_by_draft_id is not None: + return sandbox_by_draft_id + return None + def _run(self) -> Generator: node_inputs: dict[str, Any] = {} process_data: dict[str, Any] = {} @@ -287,13 +302,11 @@ class LLMNode(Node[LLMNodeData]): structured_output: LLMStructuredOutput | None = None if self.tool_call_enabled: - workflow_execution_id = variable_pool.system_variables.workflow_execution_id - is_sandbox_runtime = workflow_execution_id is not None and SandboxManager.is_sandbox_runtime( - workflow_execution_id - ) - - if is_sandbox_runtime: + # FIXME(Mairuis): should read sandbox from workflow run context... + sandbox = self._get_sandbox() + if sandbox: generator = self._invoke_llm_with_sandbox( + sandbox=sandbox, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -1827,21 +1840,18 @@ class LLMNode(Node[LLMNodeData]): def _invoke_llm_with_sandbox( self, + sandbox: VirtualEnvironment, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None, variable_pool: VariablePool, ) -> Generator[NodeEventBase, None, LLMGenerationData]: - workflow_execution_id = variable_pool.system_variables.workflow_execution_id - if not workflow_execution_id: - raise LLMNodeError("workflow_execution_id is required for sandbox runtime mode") - allow_tools = self._get_allow_tools_list() result: LLMGenerationData | None = None with SandboxBashSession( - workflow_execution_id=workflow_execution_id, + sandbox=sandbox, tenant_id=self.tenant_id, user_id=self.user_id, node_id=self.id, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 8b13de703b..a509eccad4 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -15,8 +15,9 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.sandbox import SandboxManager +from core.sandbox.constants import APP_ASSETS_PATH from core.sandbox.storage.archive_storage import ArchiveSandboxStorage -from core.sandbox.storage.sandbox_storage import SandboxStorage +from core.sandbox.vm import SandboxBuilder from core.variables import Variable, VariableBase from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -706,20 +707,24 @@ class WorkflowService: from core.sandbox import AppAssetsInitializer, DifyCliInitializer from services.app_asset_service import AppAssetService - assets = AppAssetService.get_or_create_assets(draft_workflow.tenant_id, app_model.id, is_draft=True) + assets = AppAssetService.get_or_create_assets(session, app_model, account.id) if not assets: raise ValueError(f"No assets found for tid={draft_workflow.tenant_id}, app_id={app_model.id}") # FIXME(Mairuis): single step execution AppAssetService.build_assets(draft_workflow.tenant_id, app_model.id, assets) + sandbox_id = SandboxBuilder.draft_id(account.id) + sandbox_storage = ArchiveSandboxStorage( + draft_workflow.tenant_id, sandbox_id, exclude_patterns=[APP_ASSETS_PATH] + ) sandbox = ( SandboxProviderService.create_sandbox_builder(draft_workflow.tenant_id) .initializer(DifyCliInitializer(draft_workflow.tenant_id, account.id, app_model.id, assets.id)) .initializer(AppAssetsInitializer(draft_workflow.tenant_id, app_model.id, assets.id)) - .storage(ArchiveSandboxStorage(draft_workflow.tenant_id, SandboxStorage.draft_id(account.id))) .build() ) + sandbox_storage.mount(sandbox) single_step_execution_id = f"single-step-{uuid.uuid4()}" SandboxManager.register(single_step_execution_id, sandbox) @@ -742,6 +747,9 @@ class WorkflowService: start_at=start_at, node_id=node_id, ) + # FIXME(Mairuis): fidn a better way to handle this + if sandbox is not None: + sandbox_storage.unmount(sandbox) finally: if single_step_execution_id: sandbox = SandboxManager.unregister(single_step_execution_id) diff --git a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py deleted file mode 100644 index 3929c64896..0000000000 --- a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py +++ /dev/null @@ -1,324 +0,0 @@ -from typing import Any -from unittest.mock import MagicMock, patch - -import pytest - -from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer -from core.sandbox import SandboxManager -from core.sandbox.storage.sandbox_storage import SandboxStorage -from core.virtual_environment.__base.entities import Arch -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from core.workflow.graph_events.graph import ( - GraphRunFailedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from models.app_asset import AppAssets - - -class MockMetadata: - def __init__(self, sandbox_id: str = "test-sandbox-id", arch: Arch = Arch.AMD64): - self.id = sandbox_id - self.arch = arch - - -class MockVirtualEnvironment: - def __init__(self, sandbox_id: str = "test-sandbox-id"): - self.metadata = MockMetadata(sandbox_id=sandbox_id) - self._released = False - - def release_environment(self) -> None: - self._released = True - - -class MockVMBuilder: - _sandbox: VirtualEnvironment - - def __init__(self, sandbox: VirtualEnvironment) -> None: - self._sandbox = sandbox - - def environments(self, _: object) -> "MockVMBuilder": - return self - - def initializer(self, _: object) -> "MockVMBuilder": - return self - - def build(self) -> VirtualEnvironment: - return self._sandbox - - -@pytest.fixture(autouse=True) -def clean_sandbox_manager(): - SandboxManager.clear() - yield - SandboxManager.clear() - - -@pytest.fixture -def mock_sandbox_storage() -> MagicMock: - mock_storage = MagicMock(spec=SandboxStorage) - mock_storage.mount.return_value = False - mock_storage.unmount.return_value = True - return mock_storage - - -def create_mock_builder(sandbox: Any) -> MockVMBuilder: - return MockVMBuilder(sandbox) - - -def create_layer( - tenant_id: str = "test-tenant", - app_id: str = "test-app", - workflow_version: str = AppAssets.VERSION_DRAFT, - sandbox_id: str = "test-sandbox", - sandbox_storage: Any = None, -) -> SandboxLayer: - if sandbox_storage is None: - sandbox_storage = MagicMock(spec=SandboxStorage) - sandbox_storage.mount.return_value = False - sandbox_storage.unmount.return_value = True - return SandboxLayer( - tenant_id=tenant_id, - app_id=app_id, - workflow_version=workflow_version, - sandbox_id=sandbox_id, - sandbox_storage=sandbox_storage, - ) - - -class TestSandboxLayer: - def test_init_with_parameters(self, mock_sandbox_storage: MagicMock) -> None: - layer = create_layer( - tenant_id="test-tenant", - app_id="test-app", - sandbox_id="test-sandbox", - sandbox_storage=mock_sandbox_storage, - ) - - assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage] - assert layer._app_id == "test-app" # pyright: ignore[reportPrivateUsage] - assert layer._sandbox_id == "test-sandbox" # pyright: ignore[reportPrivateUsage] - - def test_sandbox_property_raises_when_not_initialized(self, mock_sandbox_storage: MagicMock) -> None: - layer = create_layer(sandbox_storage=mock_sandbox_storage) - - with pytest.raises(RuntimeError) as exc_info: - _ = layer.sandbox - - assert "Sandbox not found" in str(exc_info.value) - - def test_sandbox_property_returns_sandbox_after_initialization(self, mock_sandbox_storage: MagicMock) -> None: - sandbox_id = "test-exec-id" - layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) - mock_sandbox = MockVirtualEnvironment() - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ), - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - layer.on_graph_start() - - assert layer.sandbox is mock_sandbox - - def test_on_graph_start_creates_sandbox_and_registers_with_manager(self, mock_sandbox_storage: MagicMock) -> None: - sandbox_id = "test-exec-123" - layer = create_layer( - tenant_id="test-tenant-123", - app_id="test-app-123", - sandbox_id=sandbox_id, - sandbox_storage=mock_sandbox_storage, - ) - mock_sandbox = MockVirtualEnvironment() - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ) as mock_create, - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - layer.on_graph_start() - mock_create.assert_called_once_with("test-tenant-123") - - assert SandboxManager.get(sandbox_id) is mock_sandbox - - def test_on_graph_start_raises_sandbox_initialization_error_on_failure( - self, mock_sandbox_storage: MagicMock - ) -> None: - layer = create_layer(sandbox_storage=mock_sandbox_storage) - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - side_effect=Exception("Sandbox provider not available"), - ), - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - with pytest.raises(SandboxInitializationError) as exc_info: - layer.on_graph_start() - - assert "Failed to initialize sandbox" in str(exc_info.value) - assert "Sandbox provider not available" in str(exc_info.value) - - def test_on_event_is_noop(self, mock_sandbox_storage: MagicMock) -> None: - layer = create_layer(sandbox_storage=mock_sandbox_storage) - - layer.on_event(GraphRunStartedEvent()) - layer.on_event(GraphRunSucceededEvent(outputs={})) - layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1)) - - def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_sandbox_storage: MagicMock) -> None: - sandbox_id = "test-exec-456" - layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) - mock_sandbox = MagicMock(spec=VirtualEnvironment) - mock_sandbox.metadata = MockMetadata() - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ), - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - layer.on_graph_start() - - assert SandboxManager.has(sandbox_id) - - layer.on_graph_end(error=None) - - mock_sandbox.release_environment.assert_called_once() - assert not SandboxManager.has(sandbox_id) - - def test_on_graph_end_releases_sandbox_even_on_error(self, mock_sandbox_storage: MagicMock) -> None: - sandbox_id = "test-exec-789" - layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) - mock_sandbox = MagicMock(spec=VirtualEnvironment) - mock_sandbox.metadata = MockMetadata() - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ), - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - layer.on_graph_start() - - layer.on_graph_end(error=Exception("Workflow failed")) - - mock_sandbox.release_environment.assert_called_once() - assert not SandboxManager.has(sandbox_id) - - def test_on_graph_end_handles_release_failure_gracefully(self, mock_sandbox_storage: MagicMock) -> None: - sandbox_id = "test-exec-fail" - layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) - mock_sandbox = MagicMock(spec=VirtualEnvironment) - mock_sandbox.metadata = MockMetadata() - mock_sandbox.release_environment.side_effect = Exception("Container already removed") - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ), - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - layer.on_graph_start() - - layer.on_graph_end(error=None) - - mock_sandbox.release_environment.assert_called_once() - - def test_on_graph_end_noop_when_sandbox_not_registered(self, mock_sandbox_storage: MagicMock) -> None: - layer = create_layer(sandbox_id="nonexistent-sandbox", sandbox_storage=mock_sandbox_storage) - - layer.on_graph_end(error=None) - - def test_on_graph_end_is_idempotent(self, mock_sandbox_storage: MagicMock) -> None: - sandbox_id = "test-exec-idempotent" - layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) - mock_sandbox = MagicMock(spec=VirtualEnvironment) - mock_sandbox.metadata = MockMetadata() - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ), - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - layer.on_graph_start() - - layer.on_graph_end(error=None) - layer.on_graph_end(error=None) - - mock_sandbox.release_environment.assert_called_once() - - def test_layer_inherits_from_graph_engine_layer(self, mock_sandbox_storage: MagicMock) -> None: - layer = create_layer(sandbox_storage=mock_sandbox_storage) - - with pytest.raises(GraphEngineLayerNotInitializedError): - _ = layer.graph_runtime_state - - assert layer.command_channel is None - - -class TestSandboxLayerIntegration: - def test_full_lifecycle_with_mocked_provider(self, mock_sandbox_storage: MagicMock) -> None: - sandbox_id = "integration-test-exec" - layer = create_layer( - tenant_id="integration-tenant", - app_id="integration-app", - sandbox_id=sandbox_id, - sandbox_storage=mock_sandbox_storage, - ) - mock_sandbox = MagicMock(spec=VirtualEnvironment) - mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox") - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ), - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - layer.on_graph_start() - - assert layer.sandbox is mock_sandbox - assert SandboxManager.get(sandbox_id) is mock_sandbox - - layer.on_graph_end(error=None) - - assert not SandboxManager.has(sandbox_id) - mock_sandbox.release_environment.assert_called_once() - - def test_lifecycle_with_workflow_error(self, mock_sandbox_storage: MagicMock) -> None: - sandbox_id = "integration-error-test" - layer = create_layer( - tenant_id="error-tenant", - app_id="error-app", - sandbox_id=sandbox_id, - sandbox_storage=mock_sandbox_storage, - ) - mock_sandbox = MagicMock(spec=VirtualEnvironment) - mock_sandbox.metadata = MockMetadata() - - with ( - patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ), - patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), - ): - layer.on_graph_start() - - assert layer.sandbox.metadata.id is not None - - layer.on_graph_end(error=Exception("Workflow execution failed")) - - assert not SandboxManager.has(sandbox_id) - mock_sandbox.release_environment.assert_called_once()