From 94ff904a04129c831321020a9f1d5f123da94e17 Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 15 Jan 2026 00:13:29 +0800 Subject: [PATCH] feat(sandbox): add AppAssetsInitializer and refactor VMFactory to VMBuilder - Add AppAssetsInitializer to load published app assets into sandbox - Refactor VMFactory.create() to VMBuilder with builder pattern - Extract SandboxInitializer base class and DifyCliInitializer - Simplify SandboxLayer constructor (remove options/environments params) - Fix circular import in sandbox module by removing eager SandboxBashTool export - Update SandboxProviderService to return VMBuilder instead of VirtualEnvironment --- api/core/app/layers/sandbox_layer.py | 24 +-- api/core/sandbox/__init__.py | 11 +- api/core/sandbox/bash/__init__.py | 2 - api/core/sandbox/constants.py | 4 + api/core/sandbox/factory.py | 98 ++++++----- api/core/sandbox/initializer/__init__.py | 5 +- .../initializer/app_assets_initializer.py | 86 +++++++++ api/core/sandbox/initializer/base.py | 8 + ...initializer.py => dify_cli_initializer.py} | 7 +- api/core/sandbox/utils/debug.py | 8 +- .../sandbox/sandbox_provider_service.py | 21 +-- api/services/workflow_service.py | 9 +- .../core/app/layers/test_sandbox_layer.py | 92 ++++++---- .../core/virtual_environment/test_factory.py | 164 +++++++----------- 14 files changed, 316 insertions(+), 223 deletions(-) create mode 100644 api/core/sandbox/initializer/app_assets_initializer.py create mode 100644 api/core/sandbox/initializer/base.py rename api/core/sandbox/initializer/{initializer.py => dify_cli_initializer.py} (86%) diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py index a1db5b45f4..fa4ec2dc70 100644 --- a/api/core/app/layers/sandbox_layer.py +++ b/api/core/app/layers/sandbox_layer.py @@ -1,6 +1,4 @@ import logging -from collections.abc import Mapping -from typing import Any from core.sandbox.manager import SandboxManager from core.virtual_environment.__base.virtual_environment import VirtualEnvironment @@ -15,16 +13,9 @@ class SandboxInitializationError(Exception): class SandboxLayer(GraphEngineLayer): - def __init__( - self, - tenant_id: str, - options: Mapping[str, Any] | None = None, - environments: Mapping[str, str] | None = None, - ) -> None: + def __init__(self, tenant_id: str) -> None: super().__init__() self._tenant_id = tenant_id - self._options: Mapping[str, Any] = options or {} - self._environments: Mapping[str, str] = environments or {} self._workflow_execution_id: str | None = None def _get_workflow_execution_id(self) -> str: @@ -46,13 +37,16 @@ class SandboxLayer(GraphEngineLayer): self._workflow_execution_id = self._get_workflow_execution_id() try: + from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer from services.sandbox.sandbox_provider_service import SandboxProviderService - logger.info("Initializing sandbox for tenant_id=%s", self._tenant_id) - sandbox = SandboxProviderService.create_sandbox( - tenant_id=self._tenant_id, - environments=self._environments, - ) + app_id = self.graph_runtime_state.system_variable.app_id + logger.info("Initializing sandbox for tenant_id=%s, app_id=%s", self._tenant_id, app_id) + + builder = SandboxProviderService.create_sandbox_builder(self._tenant_id).initializer(DifyCliInitializer()) + if app_id: + builder.initializer(AppAssetsInitializer(self._tenant_id, app_id)) + sandbox = builder.build() SandboxManager.register(self._workflow_execution_id, sandbox) logger.info( diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py index 5f3f14fcb0..3e9d1a649d 100644 --- a/api/core/sandbox/__init__.py +++ b/api/core/sandbox/__init__.py @@ -1,4 +1,3 @@ -from core.sandbox.bash.bash_tool import SandboxBashTool from core.sandbox.bash.dify_cli import ( DifyCliBinary, DifyCliConfig, @@ -7,24 +6,26 @@ from core.sandbox.bash.dify_cli import ( DifyCliToolConfig, ) from core.sandbox.constants import ( + APP_ASSETS_PATH, + APP_ASSETS_ZIP_PATH, DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH, DIFY_CLI_PATH_PATTERN, ) -from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer -from core.sandbox.session import SandboxSession +from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer __all__ = [ + "APP_ASSETS_PATH", + "APP_ASSETS_ZIP_PATH", "DIFY_CLI_CONFIG_PATH", "DIFY_CLI_PATH", "DIFY_CLI_PATH_PATTERN", + "AppAssetsInitializer", "DifyCliBinary", "DifyCliConfig", "DifyCliEnvConfig", "DifyCliInitializer", "DifyCliLocator", "DifyCliToolConfig", - "SandboxBashTool", "SandboxInitializer", - "SandboxSession", ] diff --git a/api/core/sandbox/bash/__init__.py b/api/core/sandbox/bash/__init__.py index 3e0f59e1bf..e0e85c4224 100644 --- a/api/core/sandbox/bash/__init__.py +++ b/api/core/sandbox/bash/__init__.py @@ -1,4 +1,3 @@ -from core.sandbox.bash.bash_tool import SandboxBashTool from core.sandbox.bash.dify_cli import ( DifyCliBinary, DifyCliConfig, @@ -13,5 +12,4 @@ __all__ = [ "DifyCliEnvConfig", "DifyCliLocator", "DifyCliToolConfig", - "SandboxBashTool", ] diff --git a/api/core/sandbox/constants.py b/api/core/sandbox/constants.py index c880849e09..35a43e850d 100644 --- a/api/core/sandbox/constants.py +++ b/api/core/sandbox/constants.py @@ -5,3 +5,7 @@ DIFY_CLI_PATH: Final[str] = ".dify/bin/dify" DIFY_CLI_PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}" DIFY_CLI_CONFIG_PATH: Final[str] = ".dify_cli.json" + +# App Assets +APP_ASSETS_PATH: Final[str] = "assets" +APP_ASSETS_ZIP_PATH: Final[str] = ".dify/tmp/assets.zip" diff --git a/api/core/sandbox/factory.py b/api/core/sandbox/factory.py index fbd159da84..4489e65231 100644 --- a/api/core/sandbox/factory.py +++ b/api/core/sandbox/factory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping, Sequence from enum import StrEnum from typing import TYPE_CHECKING, Any @@ -14,48 +16,66 @@ class VMType(StrEnum): LOCAL = "local" -class VMFactory: - @classmethod - def create( - cls, - tenant_id: str, - vm_type: VMType, - options: Mapping[str, Any] | None = None, - environments: Mapping[str, str] | None = None, - user_id: str | None = None, - initializers: Sequence["SandboxInitializer"] | None = None, - ) -> VirtualEnvironment: - options = options or {} - environments = environments or {} +def _get_vm_class(vm_type: VMType) -> type[VirtualEnvironment]: + match vm_type: + case VMType.DOCKER: + from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment - vm_class = cls._get_vm_class(vm_type) - vm = vm_class(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id) + return DockerDaemonEnvironment + case VMType.E2B: + from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment - if initializers: - for initializer in initializers: - initializer.initialize(vm) + return E2BEnvironment + case VMType.LOCAL: + from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment + return LocalVirtualEnvironment + case _: + raise ValueError(f"Unsupported VM type: {vm_type}") + + +class VMBuilder: + def __init__(self, tenant_id: str, vm_type: VMType) -> None: + self._tenant_id = tenant_id + self._vm_type = vm_type + self._user_id: str | None = None + self._options: dict[str, Any] = {} + self._environments: dict[str, str] = {} + self._initializers: list[SandboxInitializer] = [] + + def user(self, user_id: str) -> VMBuilder: + self._user_id = user_id + return self + + def options(self, options: Mapping[str, Any]) -> VMBuilder: + self._options = dict(options) + return self + + def environments(self, environments: Mapping[str, str]) -> VMBuilder: + self._environments = dict(environments) + return self + + def initializer(self, initializer: SandboxInitializer) -> VMBuilder: + self._initializers.append(initializer) + return self + + def initializers(self, initializers: Sequence[SandboxInitializer]) -> VMBuilder: + self._initializers.extend(initializers) + return self + + def build(self) -> VirtualEnvironment: + vm_class = _get_vm_class(self._vm_type) + vm = vm_class( + tenant_id=self._tenant_id, + options=self._options, + environments=self._environments, + user_id=self._user_id, + ) + for init in self._initializers: + init.initialize(vm) return vm - @classmethod - def _get_vm_class(cls, vm_type: VMType) -> type[VirtualEnvironment]: - match vm_type: - case VMType.DOCKER: - from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment - - return DockerDaemonEnvironment - case VMType.E2B: - from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment - - return E2BEnvironment - case VMType.LOCAL: - from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment - - return LocalVirtualEnvironment - case _: - raise ValueError(f"Unsupported VM type: {vm_type}") - - @classmethod - def validate(cls, vm_type: VMType, options: Mapping[str, Any]) -> None: - vm_class = cls._get_vm_class(vm_type) + @staticmethod + def validate(vm_type: VMType, options: Mapping[str, Any]) -> None: + vm_class = _get_vm_class(vm_type) vm_class.validate(options) diff --git a/api/core/sandbox/initializer/__init__.py b/api/core/sandbox/initializer/__init__.py index 258d7fafc5..65f6d51c0a 100644 --- a/api/core/sandbox/initializer/__init__.py +++ b/api/core/sandbox/initializer/__init__.py @@ -1,6 +1,9 @@ -from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer +from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer +from core.sandbox.initializer.base import SandboxInitializer +from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer __all__ = [ + "AppAssetsInitializer", "DifyCliInitializer", "SandboxInitializer", ] diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py new file mode 100644 index 0000000000..0606e95177 --- /dev/null +++ b/api/core/sandbox/initializer/app_assets_initializer.py @@ -0,0 +1,86 @@ +import logging +from io import BytesIO + +from sqlalchemy.orm import Session + +from core.sandbox.constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH +from core.sandbox.initializer.base import SandboxInitializer +from core.virtual_environment.__base.helpers import execute, with_connection +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.app_asset import AppAssetDraft + +logger = logging.getLogger(__name__) + + +class AppAssetsInitializer(SandboxInitializer): + def __init__(self, tenant_id: str, app_id: str) -> None: + self._tenant_id = tenant_id + self._app_id = app_id + + def initialize(self, env: VirtualEnvironment) -> None: + published = self._get_latest_published() + if not published: + logger.debug("No published assets for app_id=%s, skipping", self._app_id) + return + + zip_key = AppAssetDraft.get_published_storage_key(self._tenant_id, self._app_id, published.id) + try: + zip_data = storage.load_once(zip_key) + except Exception: + logger.warning( + "Failed to load assets zip for app_id=%s, key=%s", + self._app_id, + zip_key, + exc_info=True, + ) + return + + env.upload_file(APP_ASSETS_ZIP_PATH, BytesIO(zip_data)) + + with with_connection(env) as conn: + execute( + env, + ["mkdir", "-p", ".dify/tmp"], + connection=conn, + error_message="Failed to create temp directory", + ) + execute( + env, + ["mkdir", "-p", APP_ASSETS_PATH], + connection=conn, + error_message="Failed to create assets directory", + ) + execute( + env, + ["unzip", "-o", APP_ASSETS_ZIP_PATH, "-d", APP_ASSETS_PATH], + connection=conn, + timeout=60, + error_message="Failed to unzip assets", + ) + execute( + env, + ["rm", "-f", APP_ASSETS_ZIP_PATH], + connection=conn, + error_message="Failed to cleanup temp zip file", + ) + + logger.info( + "App assets initialized for app_id=%s, published_id=%s", + self._app_id, + published.id, + ) + + def _get_latest_published(self) -> AppAssetDraft | None: + with Session(db.engine) as session: + return ( + session.query(AppAssetDraft) + .filter( + AppAssetDraft.tenant_id == self._tenant_id, + AppAssetDraft.app_id == self._app_id, + AppAssetDraft.version != AppAssetDraft.VERSION_DRAFT, + ) + .order_by(AppAssetDraft.created_at.desc()) + .first() + ) diff --git a/api/core/sandbox/initializer/base.py b/api/core/sandbox/initializer/base.py new file mode 100644 index 0000000000..937b09c2dc --- /dev/null +++ b/api/core/sandbox/initializer/base.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod + +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + + +class SandboxInitializer(ABC): + @abstractmethod + def initialize(self, env: VirtualEnvironment) -> None: ... diff --git a/api/core/sandbox/initializer/initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py similarity index 86% rename from api/core/sandbox/initializer/initializer.py rename to api/core/sandbox/initializer/dify_cli_initializer.py index 50535faac7..29f8982616 100644 --- a/api/core/sandbox/initializer/initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -1,21 +1,16 @@ import logging -from abc import ABC, abstractmethod from io import BytesIO from pathlib import Path from core.sandbox.bash.dify_cli import DifyCliLocator from core.sandbox.constants import DIFY_CLI_PATH +from core.sandbox.initializer.base import SandboxInitializer from core.virtual_environment.__base.helpers import execute from core.virtual_environment.__base.virtual_environment import VirtualEnvironment logger = logging.getLogger(__name__) -class SandboxInitializer(ABC): - @abstractmethod - def initialize(self, env: VirtualEnvironment) -> None: ... - - class DifyCliInitializer(SandboxInitializer): def __init__(self, cli_root: str | Path | None = None) -> None: self._locator = DifyCliLocator(root=cli_root) diff --git a/api/core/sandbox/utils/debug.py b/api/core/sandbox/utils/debug.py index e541bc435e..397d0ee322 100644 --- a/api/core/sandbox/utils/debug.py +++ b/api/core/sandbox/utils/debug.py @@ -1,8 +1,9 @@ """Sandbox debug utilities. TODO: Remove this module when sandbox debugging is complete.""" -from typing import Any +from typing import TYPE_CHECKING, Any -from core.callback_handler.agent_tool_callback_handler import print_text +if TYPE_CHECKING: + pass SANDBOX_DEBUG_ENABLED = True @@ -11,6 +12,9 @@ def sandbox_debug(tag: str, message: str, data: Any = None) -> None: if not SANDBOX_DEBUG_ENABLED: return + # Lazy import to avoid circular dependency + from core.callback_handler.agent_tool_callback_handler import print_text + print_text(f"\n[{tag}]\n", color="blue") if data is not None: print_text(f"{message}: {data}\n", color="blue") diff --git a/api/services/sandbox/sandbox_provider_service.py b/api/services/sandbox/sandbox_provider_service.py index f88a53137a..c1860e02c2 100644 --- a/api/services/sandbox/sandbox_provider_service.py +++ b/api/services/sandbox/sandbox_provider_service.py @@ -19,13 +19,11 @@ from sqlalchemy.orm import Session from configs import dify_config from constants import HIDDEN_VALUE from core.entities.provider_entities import BasicProviderConfig -from core.sandbox.factory import VMFactory, VMType -from core.sandbox.initializer import DifyCliInitializer +from core.sandbox.factory import VMBuilder, VMType from core.sandbox.utils.encryption import create_sandbox_config_encrypter, masked_config from core.tools.utils.system_encryption import ( decrypt_system_params, ) -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from extensions.ext_database import db from models.sandbox import SandboxProvider, SandboxProviderSystemConfig @@ -175,7 +173,7 @@ class SandboxProviderService: if model_class: model_class.model_validate(config) - VMFactory.validate(VMType(provider_type), config) + VMBuilder.validate(VMType(provider_type), config) @classmethod def save_config( @@ -306,13 +304,8 @@ class SandboxProviderService: return config.provider_type if config else None @classmethod - def create_sandbox( - cls, - tenant_id: str, - environments: Mapping[str, str] | None = None, - ) -> VirtualEnvironment: + def create_sandbox_builder(cls, tenant_id: str) -> VMBuilder: with Session(db.engine, expire_on_commit=False) as session: - # Get config: tenant config > system default > raise error tenant_config = ( session.query(SandboxProvider) .filter( @@ -337,10 +330,4 @@ class SandboxProviderService: if not config or not provider_type: raise ValueError(f"No active sandbox provider for tenant {tenant_id} or system default") - return VMFactory.create( - tenant_id=tenant_id, - vm_type=VMType(provider_type), - options=dict(config), - environments=environments or {}, - initializers=[DifyCliInitializer()], - ) + return VMBuilder(tenant_id, VMType(provider_type)).options(config) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index a706e9bbdc..0005de826f 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -701,7 +701,14 @@ class WorkflowService: sandbox = None single_step_execution_id: str | None = None if draft_workflow.get_feature(WorkflowFeatures.SANDBOX).enabled: - sandbox = SandboxProviderService.create_sandbox(tenant_id=draft_workflow.tenant_id) + from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer + + sandbox = ( + SandboxProviderService.create_sandbox_builder(draft_workflow.tenant_id) + .initializer(DifyCliInitializer()) + .initializer(AppAssetsInitializer(draft_workflow.tenant_id, app_model.id)) + .build() + ) single_step_execution_id = f"single-step-{uuid.uuid4()}" SandboxManager.register(single_step_execution_id, sandbox) diff --git a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py index b3c3822646..89c1d03f40 100644 --- a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py @@ -30,23 +30,50 @@ class MockVirtualEnvironment: class MockSystemVariableView: - def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"): + def __init__( + self, + workflow_execution_id: str | None = "test-workflow-exec-id", + app_id: str | None = "test-app-id", + ): self._workflow_execution_id = workflow_execution_id + self._app_id = app_id @property def workflow_execution_id(self) -> str | None: return self._workflow_execution_id + @property + def app_id(self) -> str | None: + return self._app_id + class MockReadOnlyGraphRuntimeStateWrapper: - def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"): - self._system_variable = MockSystemVariableView(workflow_execution_id) + def __init__( + self, + workflow_execution_id: str | None = "test-workflow-exec-id", + app_id: str | None = "test-app-id", + ): + self._system_variable = MockSystemVariableView(workflow_execution_id, app_id) @property def system_variable(self) -> MockSystemVariableView: return self._system_variable +class MockVMBuilder: + def __init__(self, sandbox: VirtualEnvironment): + self._sandbox = sandbox + + def environments(self, _): + return self + + def initializer(self, _): + return self + + def build(self): + return self._sandbox + + @pytest.fixture(autouse=True) def clean_sandbox_manager(): SandboxManager.clear() @@ -54,17 +81,15 @@ def clean_sandbox_manager(): SandboxManager.clear() +def create_mock_builder(sandbox): + return MockVMBuilder(sandbox) + + class TestSandboxLayer: def test_init_with_parameters(self): - layer = SandboxLayer( - tenant_id="test-tenant", - options={"base_working_path": "/tmp/sandbox"}, - environments={"PYTHONUNBUFFERED": "1"}, - ) + layer = SandboxLayer(tenant_id="test-tenant") assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage] - assert layer._options == {"base_working_path": "/tmp/sandbox"} # pyright: ignore[reportPrivateUsage] - assert layer._environments == {"PYTHONUNBUFFERED": "1"} # pyright: ignore[reportPrivateUsage] assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] def test_sandbox_property_raises_when_not_initialized(self): @@ -82,32 +107,25 @@ class TestSandboxLayer: layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", - return_value=mock_sandbox, + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), ): layer.on_graph_start() assert layer.sandbox is mock_sandbox def test_on_graph_start_creates_sandbox_and_registers_with_manager(self): - layer = SandboxLayer( - tenant_id="test-tenant-123", - environments={"PATH": "/usr/bin"}, - ) + layer = SandboxLayer(tenant_id="test-tenant-123") mock_sandbox = MockVirtualEnvironment() - mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123") + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123", "test-app-123") layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", - return_value=mock_sandbox, + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), ) as mock_create: layer.on_graph_start() - - mock_create.assert_called_once_with( - tenant_id="test-tenant-123", - environments={"PATH": "/usr/bin"}, - ) + mock_create.assert_called_once_with("test-tenant-123") assert SandboxManager.get("test-exec-123") is mock_sandbox @@ -117,7 +135,7 @@ class TestSandboxLayer: layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", side_effect=Exception("Sandbox provider not available"), ): with pytest.raises(SandboxInitializationError) as exc_info: @@ -152,8 +170,8 @@ class TestSandboxLayer: layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", - return_value=mock_sandbox, + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), ): layer.on_graph_start() @@ -174,8 +192,8 @@ class TestSandboxLayer: layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", - return_value=mock_sandbox, + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), ): layer.on_graph_start() @@ -195,8 +213,8 @@ class TestSandboxLayer: layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", - return_value=mock_sandbox, + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), ): layer.on_graph_start() @@ -221,8 +239,8 @@ class TestSandboxLayer: layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", - return_value=mock_sandbox, + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), ): layer.on_graph_start() @@ -250,8 +268,8 @@ class TestSandboxLayerIntegration: mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox") with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", - return_value=mock_sandbox, + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), ): layer.on_graph_start() @@ -274,8 +292,8 @@ class TestSandboxLayerIntegration: mock_sandbox.metadata = MockMetadata() with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", - return_value=mock_sandbox, + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), ): layer.on_graph_start() diff --git a/api/tests/unit_tests/core/virtual_environment/test_factory.py b/api/tests/unit_tests/core/virtual_environment/test_factory.py index 88e1f3f0a7..7f6b900656 100644 --- a/api/tests/unit_tests/core/virtual_environment/test_factory.py +++ b/api/tests/unit_tests/core/virtual_environment/test_factory.py @@ -1,148 +1,116 @@ -""" -Unit tests for the SandboxFactory. - -This module tests the factory pattern implementation for creating VirtualEnvironment instances -based on sandbox type, including error handling for unsupported types. -""" - from pathlib import Path from unittest.mock import MagicMock, patch import pytest -from core.sandbox.factory import VMFactory, VMType +from core.sandbox.factory import VMBuilder, VMType from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -class TestSandboxType: - """Test cases for SandboxType enum.""" - - def test_sandbox_type_values(self): - """Test that SandboxType enum has expected values.""" +class TestVMType: + def test_values(self): assert VMType.DOCKER == "docker" assert VMType.E2B == "e2b" assert VMType.LOCAL == "local" - def test_sandbox_type_is_string_enum(self): - """Test that SandboxType values are strings.""" + def test_is_string_enum(self): assert isinstance(VMType.DOCKER.value, str) assert isinstance(VMType.E2B.value, str) assert isinstance(VMType.LOCAL.value, str) -class TestSandboxFactory: - """Test cases for SandboxFactory.""" +class TestVMBuilder: + def test_build_docker(self): + mock_instance = MagicMock(spec=VirtualEnvironment) + mock_class = MagicMock(return_value=mock_instance) - def test_create_docker_sandbox_success(self): - """Test successful Docker sandbox creation.""" - mock_sandbox_instance = MagicMock(spec=VirtualEnvironment) - mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance) - - with patch.object(VMFactory, "_get_sandbox_class", return_value=mock_sandbox_class): - result = VMFactory.create( - tenant_id="test-tenant", - vm_type=VMType.DOCKER, - options={"docker_image": "python:3.11-slim"}, - environments={"PYTHONUNBUFFERED": "1"}, + with patch( + "core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment", + mock_class, + ): + result = ( + VMBuilder("test-tenant", VMType.DOCKER) + .options({"docker_image": "python:3.11-slim"}) + .environments({"PYTHONUNBUFFERED": "1"}) + .build() ) - mock_sandbox_class.assert_called_once_with( + mock_class.assert_called_once_with( tenant_id="test-tenant", options={"docker_image": "python:3.11-slim"}, environments={"PYTHONUNBUFFERED": "1"}, user_id=None, ) - assert result is mock_sandbox_instance + assert result is mock_instance - def test_create_with_none_options_uses_empty_dict(self): - """Test that None options are converted to empty dict.""" - mock_sandbox_instance = MagicMock(spec=VirtualEnvironment) - mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance) - - with patch.object(VMFactory, "_get_sandbox_class", return_value=mock_sandbox_class): - VMFactory.create(tenant_id="test-tenant", vm_type=VMType.DOCKER, options=None, environments=None) - - mock_sandbox_class.assert_called_once_with( - tenant_id="test-tenant", options={}, environments={}, user_id=None - ) - - def test_create_with_default_parameters(self): - """Test sandbox creation with default parameters.""" - mock_sandbox_instance = MagicMock(spec=VirtualEnvironment) - mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance) - - with patch.object(VMFactory, "_get_sandbox_class", return_value=mock_sandbox_class): - result = VMFactory.create(tenant_id="test-tenant", vm_type=VMType.DOCKER) - - mock_sandbox_class.assert_called_once_with( - tenant_id="test-tenant", options={}, environments={}, user_id=None - ) - assert result is mock_sandbox_instance - - def test_get_sandbox_class_docker_returns_correct_class(self): - """Test that DOCKER type returns DockerDaemonEnvironment class.""" - # Test by creating with mock to verify the class lookup works + def test_build_with_user(self): mock_instance = MagicMock(spec=VirtualEnvironment) + mock_class = MagicMock(return_value=mock_instance) with patch( "core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment", - return_value=mock_instance, - ) as mock_docker_class: - VMFactory.create(tenant_id="test-tenant", vm_type=VMType.DOCKER) - mock_docker_class.assert_called_once() + mock_class, + ): + VMBuilder("test-tenant", VMType.DOCKER).user("user-123").build() - def test_get_sandbox_class_local_returns_correct_class(self): - """Test that LOCAL type returns LocalVirtualEnvironment class.""" + mock_class.assert_called_once_with( + tenant_id="test-tenant", + options={}, + environments={}, + user_id="user-123", + ) + + def test_build_with_initializers(self): + mock_instance = MagicMock(spec=VirtualEnvironment) + mock_class = MagicMock(return_value=mock_instance) + mock_initializer = MagicMock() + + with patch( + "core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment", + mock_class, + ): + VMBuilder("test-tenant", VMType.DOCKER).initializer(mock_initializer).build() + + mock_initializer.initialize.assert_called_once_with(mock_instance) + + def test_build_local(self): mock_instance = MagicMock(spec=VirtualEnvironment) with patch( "core.virtual_environment.providers.local_without_isolation.LocalVirtualEnvironment", return_value=mock_instance, - ) as mock_local_class: - VMFactory.create(tenant_id="test-tenant", vm_type=VMType.LOCAL) - mock_local_class.assert_called_once() + ) as mock_class: + VMBuilder("test-tenant", VMType.LOCAL).build() + mock_class.assert_called_once() - def test_get_sandbox_class_e2b_returns_correct_class(self): - """Test that E2B type returns E2BEnvironment class.""" + def test_build_e2b(self): mock_instance = MagicMock(spec=VirtualEnvironment) with patch( "core.virtual_environment.providers.e2b_sandbox.E2BEnvironment", return_value=mock_instance, - ) as mock_e2b_class: - VMFactory.create(tenant_id="test-tenant", vm_type=VMType.E2B) - mock_e2b_class.assert_called_once() + ) as mock_class: + VMBuilder("test-tenant", VMType.E2B).build() + mock_class.assert_called_once() - def test_create_with_unsupported_type_raises_value_error(self): - """Test that unsupported sandbox type raises ValueError.""" - with pytest.raises(ValueError) as exc_info: - VMFactory.create(tenant_id="test-tenant", vm_type="unsupported_type") # type: ignore[arg-type] + def test_build_unsupported_type_raises(self): + with pytest.raises(ValueError, match="Unsupported VM type"): + VMBuilder("test-tenant", "unsupported").build() # type: ignore[arg-type] - assert "Unsupported sandbox type: unsupported_type" in str(exc_info.value) + def test_validate(self): + mock_class = MagicMock() - def test_create_propagates_instantiation_error(self): - """Test that sandbox instantiation errors are propagated.""" - mock_sandbox_class = MagicMock() - mock_sandbox_class.side_effect = Exception("Docker daemon not available") - - with patch.object(VMFactory, "_get_sandbox_class", return_value=mock_sandbox_class): - with pytest.raises(Exception) as exc_info: - VMFactory.create(tenant_id="test-tenant", vm_type=VMType.DOCKER) - - assert "Docker daemon not available" in str(exc_info.value) + with patch( + "core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment", + mock_class, + ): + VMBuilder.validate(VMType.DOCKER, {"key": "value"}) + mock_class.validate.assert_called_once_with({"key": "value"}) -class TestSandboxFactoryIntegration: - """Integration tests for SandboxFactory with real providers (using LOCAL type).""" - - def test_create_local_sandbox_integration(self, tmp_path: Path): - """Test creating a real local sandbox.""" - sandbox = VMFactory.create( - tenant_id="test-tenant", - vm_type=VMType.LOCAL, - options={"base_working_path": str(tmp_path)}, - environments={}, - ) +class TestVMBuilderIntegration: + def test_local_sandbox(self, tmp_path: Path): + sandbox = VMBuilder("test-tenant", VMType.LOCAL).options({"base_working_path": str(tmp_path)}).build() try: assert sandbox is not None