From 63b3e71909bda8b7fdd6ee3fb01b3c05745e6e68 Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 15 Jan 2026 13:22:46 +0800 Subject: [PATCH] refactor(sandbox): redesign sandbox_layer & reorganize import paths --- api/core/app/apps/workflow/app_generator.py | 6 +- api/core/app/layers/sandbox_layer.py | 55 ++---- api/core/sandbox/__init__.py | 21 ++- api/core/sandbox/bash/__init__.py | 2 +- api/core/sandbox/bash/bash_tool.py | 3 +- api/core/sandbox/bash/dify_cli.py | 3 +- api/core/sandbox/factory.py | 2 +- api/core/sandbox/initializer/__init__.py | 6 +- .../initializer/app_assets_initializer.py | 5 +- .../initializer/dify_cli_initializer.py | 7 +- api/core/sandbox/session.py | 19 ++- api/core/sandbox/storage/__init__.py | 4 +- api/core/sandbox/storage/archive_storage.py | 3 +- api/core/workflow/nodes/command/node.py | 3 +- api/core/workflow/nodes/llm/node.py | 3 +- .../sandbox/sandbox_provider_service.py | 3 +- api/services/workflow_service.py | 4 +- .../core/app/layers/test_sandbox_layer.py | 158 ++++++------------ .../core/virtual_environment/test_factory.py | 2 +- .../test_sandbox_manager.py | 2 +- .../nodes/command/test_command_node.py | 2 +- 21 files changed, 134 insertions(+), 179 deletions(-) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 5f881c2e39..9da4e71ba3 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -492,7 +492,11 @@ class WorkflowAppGenerator(BaseAppGenerator): if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled: graph_engine_layers = ( *graph_engine_layers, - SandboxLayer(tenant_id=application_generate_entity.app_config.tenant_id), + SandboxLayer( + tenant_id=application_generate_entity.app_config.tenant_id, + app_id=application_generate_entity.app_config.app_id, + sandbox_id=application_generate_entity.workflow_execution_id, + ), ) # Determine system_user_id based on invocation source diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py index aa7d8005f7..b9cebb4c8d 100644 --- a/api/core/app/layers/sandbox_layer.py +++ b/api/core/app/layers/sandbox_layer.py @@ -1,7 +1,6 @@ import logging -from core.sandbox.manager import SandboxManager -from core.sandbox.storage import ArchiveSandboxStorage +from core.sandbox import ArchiveSandboxStorage, SandboxManager from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent @@ -15,38 +14,22 @@ class SandboxInitializationError(Exception): class SandboxLayer(GraphEngineLayer): - def __init__(self, tenant_id: str) -> None: + def __init__(self, tenant_id: str, app_id: str, sandbox_id: str) -> None: super().__init__() self._tenant_id = tenant_id - self._workflow_execution_id: str | None = None - self._app_id: str | None = None - - def _get_workflow_execution_id(self) -> str: - workflow_execution_id = self.graph_runtime_state.system_variable.workflow_execution_id - if not workflow_execution_id: - raise RuntimeError("workflow_execution_id is not set in system variables") - return workflow_execution_id - - def _get_app_id(self) -> str: - app_id = self.graph_runtime_state.system_variable.app_id - if not app_id: - raise RuntimeError("app_id is not set in system variables") - return app_id + self._app_id = app_id + self._sandbox_id = sandbox_id @property def sandbox(self) -> VirtualEnvironment: - if self._workflow_execution_id is None: - raise RuntimeError("Sandbox not initialized. Ensure on_graph_start() has been called.") - sandbox = SandboxManager.get(self._workflow_execution_id) + sandbox = SandboxManager.get(self._sandbox_id) if sandbox is None: - raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}") + raise RuntimeError(f"Sandbox not found or not initialized for sandbox_id={self._sandbox_id}") return sandbox def on_graph_start(self) -> None: - self._workflow_execution_id = self._get_workflow_execution_id() - self._app_id = self._get_app_id() try: - from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer + from core.sandbox import AppAssetsInitializer, DifyCliInitializer from services.sandbox.sandbox_provider_service import SandboxProviderService logger.info("Initializing sandbox for tenant_id=%s, app_id=%s", self._tenant_id, self._app_id) @@ -58,10 +41,10 @@ class SandboxLayer(GraphEngineLayer): ) sandbox = builder.build() - SandboxManager.register(self._workflow_execution_id, sandbox) + SandboxManager.register(self._sandbox_id, sandbox) logger.info( "Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s", - self._workflow_execution_id, + self._sandbox_id, sandbox.metadata.id, sandbox.metadata.arch, ) @@ -69,10 +52,10 @@ class SandboxLayer(GraphEngineLayer): sandbox_storage = ArchiveSandboxStorage( storage=storage, tenant_id=self._tenant_id, - sandbox_id=self._workflow_execution_id, + sandbox_id=self._sandbox_id, ) if sandbox_storage.mount(sandbox): - logger.info("Sandbox files restored, workflow_execution_id=%s", self._workflow_execution_id) + logger.info("Sandbox files restored, sandbox_id=%s", self._sandbox_id) except Exception as e: logger.exception("Failed to initialize sandbox") raise SandboxInitializationError(f"Failed to initialize sandbox: {e}") from e @@ -81,19 +64,19 @@ class SandboxLayer(GraphEngineLayer): pass def on_graph_end(self, error: Exception | None) -> None: - if self._workflow_execution_id is None: + if self._sandbox_id is None: logger.debug("No workflow_execution_id set, nothing to release") return - sandbox = SandboxManager.unregister(self._workflow_execution_id) + sandbox = SandboxManager.unregister(self._sandbox_id) if sandbox is None: - logger.debug("No sandbox to release for workflow_execution_id=%s", self._workflow_execution_id) + logger.debug("No sandbox to release for sandbox_id=%s", self._sandbox_id) return sandbox_id = sandbox.metadata.id logger.info( "Releasing sandbox, workflow_execution_id=%s, sandbox_id=%s", - self._workflow_execution_id, + self._sandbox_id, sandbox_id, ) @@ -101,17 +84,15 @@ class SandboxLayer(GraphEngineLayer): sandbox_storage = ArchiveSandboxStorage( storage=storage, tenant_id=self._tenant_id, - sandbox_id=self._workflow_execution_id, + sandbox_id=self._sandbox_id, ) sandbox_storage.unmount(sandbox) - logger.info("Sandbox files persisted, workflow_execution_id=%s", self._workflow_execution_id) + logger.info("Sandbox files persisted, sandbox_id=%s", self._sandbox_id) except Exception: - logger.exception("Failed to persist sandbox files, workflow_execution_id=%s", self._workflow_execution_id) + logger.exception("Failed to persist sandbox files, sandbox_id=%s", self._sandbox_id) try: sandbox.release_environment() logger.info("Sandbox released, sandbox_id=%s", sandbox_id) except Exception: logger.exception("Failed to release sandbox, sandbox_id=%s", sandbox_id) - finally: - self._workflow_execution_id = None diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py index 3e9d1a649d..7fd6e4b309 100644 --- a/api/core/sandbox/__init__.py +++ b/api/core/sandbox/__init__.py @@ -1,18 +1,24 @@ -from core.sandbox.bash.dify_cli import ( +from .bash.dify_cli import ( DifyCliBinary, DifyCliConfig, DifyCliEnvConfig, DifyCliLocator, DifyCliToolConfig, ) -from core.sandbox.constants import ( +from .constants import ( APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH, DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH, DIFY_CLI_PATH_PATTERN, ) -from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer +from .factory import VMBuilder, VMType +from .initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer +from .manager import SandboxManager +from .session import SandboxSession +from .storage import ArchiveSandboxStorage, SandboxStorage +from .utils.debug import sandbox_debug +from .utils.encryption import create_sandbox_config_encrypter, masked_config __all__ = [ "APP_ASSETS_PATH", @@ -21,6 +27,7 @@ __all__ = [ "DIFY_CLI_PATH", "DIFY_CLI_PATH_PATTERN", "AppAssetsInitializer", + "ArchiveSandboxStorage", "DifyCliBinary", "DifyCliConfig", "DifyCliEnvConfig", @@ -28,4 +35,12 @@ __all__ = [ "DifyCliLocator", "DifyCliToolConfig", "SandboxInitializer", + "SandboxManager", + "SandboxSession", + "SandboxStorage", + "VMBuilder", + "VMType", + "create_sandbox_config_encrypter", + "masked_config", + "sandbox_debug", ] diff --git a/api/core/sandbox/bash/__init__.py b/api/core/sandbox/bash/__init__.py index e0e85c4224..fd69e39833 100644 --- a/api/core/sandbox/bash/__init__.py +++ b/api/core/sandbox/bash/__init__.py @@ -1,4 +1,4 @@ -from core.sandbox.bash.dify_cli import ( +from .dify_cli import ( DifyCliBinary, DifyCliConfig, DifyCliEnvConfig, diff --git a/api/core/sandbox/bash/bash_tool.py b/api/core/sandbox/bash/bash_tool.py index 431ebb59f7..cb86c2f3f8 100644 --- a/api/core/sandbox/bash/bash_tool.py +++ b/api/core/sandbox/bash/bash_tool.py @@ -1,7 +1,6 @@ from collections.abc import Generator from typing import Any -from core.sandbox.utils.debug import sandbox_debug from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject @@ -16,6 +15,8 @@ from core.tools.entities.tool_entities import ( from core.virtual_environment.__base.helpers import submit_command, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from ..utils.debug import sandbox_debug + COMMAND_TIMEOUT_SECONDS = 60 diff --git a/api/core/sandbox/bash/dify_cli.py b/api/core/sandbox/bash/dify_cli.py index 4e43d58664..8a6add3a43 100644 --- a/api/core/sandbox/bash/dify_cli.py +++ b/api/core/sandbox/bash/dify_cli.py @@ -6,11 +6,12 @@ from typing import TYPE_CHECKING, Any from pydantic import BaseModel, Field from core.model_runtime.utils.encoders import jsonable_encoder -from core.sandbox.constants import DIFY_CLI_PATH_PATTERN from core.session.cli_api import CliApiSession from core.tools.entities.tool_entities import ToolParameter, ToolProviderType from core.virtual_environment.__base.entities import Arch, OperatingSystem +from ..constants import DIFY_CLI_PATH_PATTERN + if TYPE_CHECKING: from core.tools.__base.tool import Tool diff --git a/api/core/sandbox/factory.py b/api/core/sandbox/factory.py index 4489e65231..bd248bb6c8 100644 --- a/api/core/sandbox/factory.py +++ b/api/core/sandbox/factory.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any from core.virtual_environment.__base.virtual_environment import VirtualEnvironment if TYPE_CHECKING: - from core.sandbox.initializer import SandboxInitializer + from .initializer import SandboxInitializer class VMType(StrEnum): diff --git a/api/core/sandbox/initializer/__init__.py b/api/core/sandbox/initializer/__init__.py index 65f6d51c0a..ffaa810644 100644 --- a/api/core/sandbox/initializer/__init__.py +++ b/api/core/sandbox/initializer/__init__.py @@ -1,6 +1,6 @@ -from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer -from core.sandbox.initializer.base import SandboxInitializer -from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer +from .app_assets_initializer import AppAssetsInitializer +from .base import SandboxInitializer +from .dify_cli_initializer import DifyCliInitializer __all__ = [ "AppAssetsInitializer", diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py index 0606e95177..c07b073627 100644 --- a/api/core/sandbox/initializer/app_assets_initializer.py +++ b/api/core/sandbox/initializer/app_assets_initializer.py @@ -3,14 +3,15 @@ from io import BytesIO from sqlalchemy.orm import Session -from core.sandbox.constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH -from core.sandbox.initializer.base import SandboxInitializer from core.virtual_environment.__base.helpers import execute, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from extensions.ext_database import db from extensions.ext_storage import storage from models.app_asset import AppAssetDraft +from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH +from .base import SandboxInitializer + logger = logging.getLogger(__name__) diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py index 29f8982616..243d925322 100644 --- a/api/core/sandbox/initializer/dify_cli_initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -2,12 +2,13 @@ import logging from io import BytesIO from pathlib import Path -from core.sandbox.bash.dify_cli import DifyCliLocator -from core.sandbox.constants import DIFY_CLI_PATH -from core.sandbox.initializer.base import SandboxInitializer from core.virtual_environment.__base.helpers import execute from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from ..bash.dify_cli import DifyCliLocator +from ..constants import DIFY_CLI_PATH +from .base import SandboxInitializer + logger = logging.getLogger(__name__) diff --git a/api/core/sandbox/session.py b/api/core/sandbox/session.py index 4a4fe0a8b8..bd6a416888 100644 --- a/api/core/sandbox/session.py +++ b/api/core/sandbox/session.py @@ -4,17 +4,22 @@ import json import logging from io import BytesIO from types import TracebackType +from typing import TYPE_CHECKING -from core.sandbox.bash.bash_tool import SandboxBashTool -from core.sandbox.bash.dify_cli import DifyCliConfig -from core.sandbox.constants import DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH -from core.sandbox.manager import SandboxManager -from core.sandbox.utils.debug import sandbox_debug from core.session.cli_api import CliApiSessionManager -from core.tools.__base.tool import Tool from core.virtual_environment.__base.helpers import execute from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from .bash.dify_cli import DifyCliConfig +from .constants import DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH +from .manager import SandboxManager +from .utils.debug import sandbox_debug + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + from .bash.bash_tool import SandboxBashTool + logger = logging.getLogger(__name__) @@ -63,6 +68,8 @@ class SandboxSession: self._session_id = None raise + from .bash.bash_tool import SandboxBashTool + self._sandbox = sandbox self._bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id) return self diff --git a/api/core/sandbox/storage/__init__.py b/api/core/sandbox/storage/__init__.py index 57fe3b836a..7206370c10 100644 --- a/api/core/sandbox/storage/__init__.py +++ b/api/core/sandbox/storage/__init__.py @@ -1,4 +1,4 @@ -from core.sandbox.storage.archive_storage import ArchiveSandboxStorage -from core.sandbox.storage.sandbox_storage import SandboxStorage +from .archive_storage import ArchiveSandboxStorage +from .sandbox_storage import SandboxStorage __all__ = ["ArchiveSandboxStorage", "SandboxStorage"] diff --git a/api/core/sandbox/storage/archive_storage.py b/api/core/sandbox/storage/archive_storage.py index 34ddc18fef..61d5c10643 100644 --- a/api/core/sandbox/storage/archive_storage.py +++ b/api/core/sandbox/storage/archive_storage.py @@ -1,11 +1,12 @@ import logging from io import BytesIO -from core.sandbox.storage.sandbox_storage import SandboxStorage from core.virtual_environment.__base.helpers import try_execute from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from extensions.ext_storage import Storage +from .sandbox_storage import SandboxStorage + logger = logging.getLogger(__name__) ARCHIVE_NAME = "workspace.tar.gz" diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py index 39327aab65..39c07a1b1b 100644 --- a/api/core/workflow/nodes/command/node.py +++ b/api/core/workflow/nodes/command/node.py @@ -3,8 +3,7 @@ import shlex from collections.abc import Mapping, Sequence from typing import Any -from core.sandbox.manager import SandboxManager -from core.sandbox.utils.debug import sandbox_debug +from core.sandbox import SandboxManager, sandbox_debug from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError from core.virtual_environment.__base.helpers import submit_command, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index d04fbd3e42..b7d79ee68a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -50,8 +50,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.sandbox import SandboxSession -from core.sandbox.manager import SandboxManager +from core.sandbox import SandboxManager, SandboxSession from core.tools.__base.tool import Tool from core.tools.signature import sign_upload_file from core.tools.tool_manager import ToolManager diff --git a/api/services/sandbox/sandbox_provider_service.py b/api/services/sandbox/sandbox_provider_service.py index c1860e02c2..09eaeb2be0 100644 --- a/api/services/sandbox/sandbox_provider_service.py +++ b/api/services/sandbox/sandbox_provider_service.py @@ -19,8 +19,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants import HIDDEN_VALUE from core.entities.provider_entities import BasicProviderConfig -from core.sandbox.factory import VMBuilder, VMType -from core.sandbox.utils.encryption import create_sandbox_config_encrypter, masked_config +from core.sandbox import VMBuilder, VMType, create_sandbox_config_encrypter, masked_config from core.tools.utils.system_encryption import ( decrypt_system_params, ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0005de826f..4b4bbe9632 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -14,7 +14,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory -from core.sandbox.manager import SandboxManager +from core.sandbox import SandboxManager from core.variables import Variable, VariableBase from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -701,7 +701,7 @@ class WorkflowService: sandbox = None single_step_execution_id: str | None = None if draft_workflow.get_feature(WorkflowFeatures.SANDBOX).enabled: - from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer + from core.sandbox import AppAssetsInitializer, DifyCliInitializer sandbox = ( SandboxProviderService.create_sandbox_builder(draft_workflow.tenant_id) diff --git a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py index 89c1d03f40..eab45cdb21 100644 --- a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer -from core.sandbox.manager import SandboxManager +from core.sandbox import SandboxManager from core.virtual_environment.__base.entities import Arch from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -29,37 +29,6 @@ class MockVirtualEnvironment: self._released = True -class MockSystemVariableView: - def __init__( - self, - workflow_execution_id: str | None = "test-workflow-exec-id", - app_id: str | None = "test-app-id", - ): - self._workflow_execution_id = workflow_execution_id - self._app_id = app_id - - @property - def workflow_execution_id(self) -> str | None: - return self._workflow_execution_id - - @property - def app_id(self) -> str | None: - return self._app_id - - -class MockReadOnlyGraphRuntimeStateWrapper: - def __init__( - self, - workflow_execution_id: str | None = "test-workflow-exec-id", - app_id: str | None = "test-app-id", - ): - self._system_variable = MockSystemVariableView(workflow_execution_id, app_id) - - @property - def system_variable(self) -> MockSystemVariableView: - return self._system_variable - - class MockVMBuilder: def __init__(self, sandbox: VirtualEnvironment): self._sandbox = sandbox @@ -81,30 +50,40 @@ def clean_sandbox_manager(): SandboxManager.clear() +@pytest.fixture +def mock_archive_storage(): + with patch("core.app.layers.sandbox_layer.ArchiveSandboxStorage") as mock_class: + mock_instance = MagicMock() + mock_instance.mount.return_value = False + mock_instance.unmount.return_value = True + mock_class.return_value = mock_instance + yield mock_instance + + def create_mock_builder(sandbox): return MockVMBuilder(sandbox) class TestSandboxLayer: def test_init_with_parameters(self): - layer = SandboxLayer(tenant_id="test-tenant") + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage] - assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] + assert layer._app_id == "test-app" # pyright: ignore[reportPrivateUsage] + assert layer._sandbox_id == "test-sandbox" # pyright: ignore[reportPrivateUsage] def test_sandbox_property_raises_when_not_initialized(self): - layer = SandboxLayer(tenant_id="test-tenant") + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") with pytest.raises(RuntimeError) as exc_info: _ = layer.sandbox - assert "Sandbox not initialized" in str(exc_info.value) + assert "Sandbox not found" in str(exc_info.value) - def test_sandbox_property_returns_sandbox_after_initialization(self): - layer = SandboxLayer(tenant_id="test-tenant") + def test_sandbox_property_returns_sandbox_after_initialization(self, mock_archive_storage): + sandbox_id = "test-exec-id" + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) mock_sandbox = MockVirtualEnvironment() - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id") - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", @@ -114,11 +93,10 @@ class TestSandboxLayer: assert layer.sandbox is mock_sandbox - def test_on_graph_start_creates_sandbox_and_registers_with_manager(self): - layer = SandboxLayer(tenant_id="test-tenant-123") + def test_on_graph_start_creates_sandbox_and_registers_with_manager(self, mock_archive_storage): + sandbox_id = "test-exec-123" + layer = SandboxLayer(tenant_id="test-tenant-123", app_id="test-app-123", sandbox_id=sandbox_id) mock_sandbox = MockVirtualEnvironment() - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123", "test-app-123") - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", @@ -127,12 +105,10 @@ class TestSandboxLayer: layer.on_graph_start() mock_create.assert_called_once_with("test-tenant-123") - assert SandboxManager.get("test-exec-123") is mock_sandbox + assert SandboxManager.get(sandbox_id) is mock_sandbox def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self): - layer = SandboxLayer(tenant_id="test-tenant") - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id") - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") with patch( "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", @@ -144,30 +120,18 @@ class TestSandboxLayer: assert "Failed to initialize sandbox" in str(exc_info.value) assert "Sandbox provider not available" in str(exc_info.value) - def test_on_graph_start_raises_when_workflow_execution_id_not_set(self): - layer = SandboxLayer(tenant_id="test-tenant") - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id=None) - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] - - with pytest.raises(RuntimeError) as exc_info: - layer.on_graph_start() - - assert "workflow_execution_id is not set" in str(exc_info.value) - def test_on_event_is_noop(self): - layer = SandboxLayer(tenant_id="test-tenant") + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") layer.on_event(GraphRunStartedEvent()) layer.on_event(GraphRunSucceededEvent(outputs={})) layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1)) - def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self): - layer = SandboxLayer(tenant_id="test-tenant") + def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_archive_storage): + sandbox_id = "test-exec-456" + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() - workflow_execution_id = "test-exec-456" - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", @@ -175,21 +139,18 @@ class TestSandboxLayer: ): layer.on_graph_start() - assert SandboxManager.has(workflow_execution_id) + assert SandboxManager.has(sandbox_id) layer.on_graph_end(error=None) mock_sandbox.release_environment.assert_called_once() - assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] - assert not SandboxManager.has(workflow_execution_id) + assert not SandboxManager.has(sandbox_id) - def test_on_graph_end_releases_sandbox_even_on_error(self): - layer = SandboxLayer(tenant_id="test-tenant") + def test_on_graph_end_releases_sandbox_even_on_error(self, mock_archive_storage): + sandbox_id = "test-exec-789" + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() - workflow_execution_id = "test-exec-789" - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", @@ -200,17 +161,14 @@ class TestSandboxLayer: layer.on_graph_end(error=Exception("Workflow failed")) mock_sandbox.release_environment.assert_called_once() - assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] - assert not SandboxManager.has(workflow_execution_id) + assert not SandboxManager.has(sandbox_id) - def test_on_graph_end_handles_release_failure_gracefully(self): - layer = SandboxLayer(tenant_id="test-tenant") + def test_on_graph_end_handles_release_failure_gracefully(self, mock_archive_storage): + sandbox_id = "test-exec-fail" + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() mock_sandbox.release_environment.side_effect = Exception("Container already removed") - workflow_execution_id = "test-exec-fail" - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", @@ -221,22 +179,17 @@ class TestSandboxLayer: layer.on_graph_end(error=None) mock_sandbox.release_environment.assert_called_once() - assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] - def test_on_graph_end_noop_when_sandbox_not_initialized(self): - layer = SandboxLayer(tenant_id="test-tenant") + def test_on_graph_end_noop_when_sandbox_not_registered(self): + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="nonexistent-sandbox") layer.on_graph_end(error=None) - assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] - - def test_on_graph_end_is_idempotent(self): - layer = SandboxLayer(tenant_id="test-tenant") + def test_on_graph_end_is_idempotent(self, mock_archive_storage): + sandbox_id = "test-exec-idempotent" + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() - workflow_execution_id = "test-exec-idempotent" - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", @@ -250,7 +203,7 @@ class TestSandboxLayer: mock_sandbox.release_environment.assert_called_once() def test_layer_inherits_from_graph_engine_layer(self): - layer = SandboxLayer(tenant_id="test-tenant") + layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") with pytest.raises(GraphEngineLayerNotInitializedError): _ = layer.graph_runtime_state @@ -259,11 +212,9 @@ class TestSandboxLayer: class TestSandboxLayerIntegration: - def test_full_lifecycle_with_mocked_provider(self): - layer = SandboxLayer(tenant_id="integration-tenant") - workflow_execution_id = "integration-test-exec" - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] + def test_full_lifecycle_with_mocked_provider(self, mock_archive_storage): + sandbox_id = "integration-test-exec" + layer = SandboxLayer(tenant_id="integration-tenant", app_id="integration-app", sandbox_id=sandbox_id) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox") @@ -273,21 +224,17 @@ class TestSandboxLayerIntegration: ): layer.on_graph_start() - assert layer._workflow_execution_id == workflow_execution_id # pyright: ignore[reportPrivateUsage] assert layer.sandbox is mock_sandbox - assert SandboxManager.get(workflow_execution_id) is mock_sandbox + assert SandboxManager.get(sandbox_id) is mock_sandbox layer.on_graph_end(error=None) - assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] - assert not SandboxManager.has(workflow_execution_id) + assert not SandboxManager.has(sandbox_id) mock_sandbox.release_environment.assert_called_once() - def test_lifecycle_with_workflow_error(self): - layer = SandboxLayer(tenant_id="error-tenant") - workflow_execution_id = "integration-error-test" - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) - layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] + def test_lifecycle_with_workflow_error(self, mock_archive_storage): + sandbox_id = "integration-error-test" + layer = SandboxLayer(tenant_id="error-tenant", app_id="error-app", sandbox_id=sandbox_id) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() @@ -301,6 +248,5 @@ class TestSandboxLayerIntegration: layer.on_graph_end(error=Exception("Workflow execution failed")) - assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] - assert not SandboxManager.has(workflow_execution_id) + assert not SandboxManager.has(sandbox_id) mock_sandbox.release_environment.assert_called_once() diff --git a/api/tests/unit_tests/core/virtual_environment/test_factory.py b/api/tests/unit_tests/core/virtual_environment/test_factory.py index 7f6b900656..6a68f3b454 100644 --- a/api/tests/unit_tests/core/virtual_environment/test_factory.py +++ b/api/tests/unit_tests/core/virtual_environment/test_factory.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.sandbox.factory import VMBuilder, VMType +from core.sandbox import VMBuilder, VMType from core.virtual_environment.__base.virtual_environment import VirtualEnvironment diff --git a/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py b/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py index f00049baf8..b1674478ad 100644 --- a/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py +++ b/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py @@ -5,7 +5,7 @@ from typing import Any import pytest -from core.sandbox.manager import SandboxManager +from core.sandbox import SandboxManager from core.virtual_environment.__base.entities import ( Arch, CommandStatus, diff --git a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py index 7253275366..863ee85df5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py @@ -5,7 +5,7 @@ from typing import Any import pytest -from core.sandbox.manager import SandboxManager +from core.sandbox import SandboxManager from core.virtual_environment.__base.entities import ( Arch, CommandStatus,