From 9ed83a808afbad5bd969b5ece4a0fc3a3e33f6bd Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 21 Jan 2026 20:42:19 +0800 Subject: [PATCH] refactor: consolidate sandbox management and initialization - Moved sandbox-related classes and functions into a dedicated module for better organization. - Updated the sandbox initialization process to streamline asset management and environment setup. - Removed deprecated constants and refactored related code to utilize new sandbox entities. - Enhanced the workflow context to support sandbox integration, allowing for improved state management during execution. - Adjusted various components to utilize the new sandbox structure, ensuring compatibility across the application. --- api/commands.py | 2 +- .../app/apps/advanced_chat/app_generator.py | 28 ++- api/core/app/apps/advanced_chat/app_runner.py | 7 + api/core/app/apps/workflow/app_generator.py | 26 ++- api/core/app/apps/workflow/app_runner.py | 6 + api/core/app/layers/sandbox_layer.py | 112 +----------- api/core/sandbox/__init__.py | 26 +-- api/core/sandbox/bash/bash_tool.py | 4 +- api/core/sandbox/bash/dify_cli.py | 4 +- api/core/sandbox/bash/session.py | 84 ++++----- api/core/sandbox/{vm.py => builder.py} | 87 ++++++---- api/core/sandbox/constants.py | 16 -- api/core/sandbox/entities/__init__.py | 9 +- api/core/sandbox/entities/config.py | 19 ++ api/core/sandbox/entities/sandbox_type.py | 16 ++ .../initializer/app_assets_initializer.py | 12 +- .../initializer/dify_cli_initializer.py | 29 ++-- api/core/sandbox/manager.py | 107 ++++++++++-- api/core/sandbox/sandbox.py | 73 ++++++++ api/core/workflow/context/__init__.py | 2 - api/core/workflow/context/models.py | 12 +- api/core/workflow/nodes/command/node.py | 23 +-- api/core/workflow/nodes/llm/node.py | 36 +--- .../workflow/runtime/graph_runtime_state.py | 13 ++ .../runtime/graph_runtime_state_protocol.py | 4 + .../workflow/runtime/read_only_wrappers.py | 4 + api/core/workflow/workflow_entry.py | 5 + .../sandbox/sandbox_provider_service.py | 15 +- api/services/workflow_service.py | 49 ++---- .../test_sandbox_manager.py | 164 ------------------ 30 files changed, 449 insertions(+), 545 deletions(-) rename api/core/sandbox/{vm.py => builder.py} (60%) delete mode 100644 api/core/sandbox/constants.py create mode 100644 api/core/sandbox/entities/config.py create mode 100644 api/core/sandbox/entities/sandbox_type.py create mode 100644 api/core/sandbox/sandbox.py delete mode 100644 api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py diff --git a/api/commands.py b/api/commands.py index 565661321d..d3b8c30f2d 100644 --- a/api/commands.py +++ b/api/commands.py @@ -23,7 +23,7 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.models.document import Document -from core.sandbox.vm import SandboxBuilder, SandboxType +from core.sandbox import SandboxBuilder, SandboxType from core.tools.utils.system_encryption import encrypt_system_params from events.app_event import app_was_created from extensions.ext_database import db diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index f1e4b1ad14..ef15b18c76 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -30,6 +30,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory +from core.sandbox import Sandbox, SandboxManager from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) @@ -43,6 +44,7 @@ from models import Account, App, Conversation, EndUser, Message, Workflow, Workf from models.enums import WorkflowRunTriggeredFrom from models.workflow_features import WorkflowFeatures from services.conversation_service import ConversationService +from services.sandbox.sandbox_provider_service import SandboxProviderService from services.workflow_draft_variable_service import ( DraftVarLoader, WorkflowDraftVariableService, @@ -514,19 +516,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): if workflow is None: raise ValueError("Workflow not found") + sandbox: Sandbox | None = None graph_engine_layers: tuple = () if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled: - if application_generate_entity.workflow_run_id is None: - raise ValueError("workflow_run_id is required when sandbox is enabled") - graph_engine_layers = ( - SandboxLayer( + sandbox_provider = SandboxProviderService.get_sandbox_provider( + application_generate_entity.app_config.tenant_id + ) + if workflow.version == Workflow.VERSION_DRAFT: + sandbox = SandboxManager.create_draft( + tenant_id=application_generate_entity.app_config.tenant_id, + app_id=application_generate_entity.app_config.app_id, + user_id=application_generate_entity.user_id, + sandbox_provider=sandbox_provider, + ) + else: + if application_generate_entity.workflow_run_id is None: + raise ValueError("workflow_run_id is required when sandbox is enabled") + sandbox = SandboxManager.create( tenant_id=application_generate_entity.app_config.tenant_id, app_id=application_generate_entity.app_config.app_id, user_id=application_generate_entity.user_id, - workflow_version=workflow.version, workflow_execution_id=application_generate_entity.workflow_run_id, - ), - ) + sandbox_provider=sandbox_provider, + ) + graph_engine_layers = (SandboxLayer(sandbox=sandbox),) # Determine system_user_id based on invocation source is_external_api_call = application_generate_entity.invoke_from in { @@ -559,6 +572,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, graph_engine_layers=graph_engine_layers, + sandbox=sandbox, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a258144d35..a9e41bffdb 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -24,6 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration +from core.sandbox import Sandbox from core.variables.variables import Variable from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel @@ -66,6 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), + sandbox: Sandbox | None = None, ): super().__init__( queue_manager=queue_manager, @@ -82,6 +84,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self._app = app self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + self._sandbox = sandbox @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -156,6 +159,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # init graph graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) + + if self._sandbox: + graph_runtime_state.set_sandbox(self._sandbox) + graph = self._init_graph( graph_config=self._workflow.graph_dict, graph_runtime_state=graph_runtime_state, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 53e80dcd2b..f9199ceebb 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -29,6 +29,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory +from core.sandbox import Sandbox, SandboxManager from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository @@ -40,6 +41,7 @@ from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom from models.workflow_features import WorkflowFeatures +from services.sandbox.sandbox_provider_service import SandboxProviderService from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs" @@ -490,16 +492,29 @@ class WorkflowAppGenerator(BaseAppGenerator): if workflow is None: raise ValueError("Workflow not found") + sandbox: Sandbox | None = None if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled: - graph_engine_layers = ( - *graph_engine_layers, - SandboxLayer( + sandbox_provider = SandboxProviderService.get_sandbox_provider( + application_generate_entity.app_config.tenant_id + ) + if workflow.version == Workflow.VERSION_DRAFT: + sandbox = SandboxManager.create_draft( + tenant_id=application_generate_entity.app_config.tenant_id, + app_id=application_generate_entity.app_config.app_id, + user_id=application_generate_entity.user_id, + sandbox_provider=sandbox_provider, + ) + else: + sandbox = SandboxManager.create( tenant_id=application_generate_entity.app_config.tenant_id, app_id=application_generate_entity.app_config.app_id, user_id=application_generate_entity.user_id, - workflow_version=workflow.version, workflow_execution_id=application_generate_entity.workflow_execution_id, - ), + sandbox_provider=sandbox_provider, + ) + graph_engine_layers = ( + *graph_engine_layers, + SandboxLayer(sandbox=sandbox), ) # Determine system_user_id based on invocation source @@ -526,6 +541,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, + sandbox=sandbox, ) try: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 8dbdc1d58c..9bc0275f6e 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -7,6 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.sandbox import Sandbox from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.layers.base import GraphEngineLayer @@ -42,6 +43,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), + sandbox: Sandbox | None = None, ): super().__init__( queue_manager=queue_manager, @@ -55,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self._root_node_id = root_node_id self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + self._sandbox = sandbox @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -99,6 +102,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + if self._sandbox: + graph_runtime_state.set_sandbox(self._sandbox) + # init graph graph = self._init_graph( graph_config=self._workflow.graph_dict, diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py index c1e2b27a8b..85ed53c4d6 100644 --- a/api/core/app/layers/sandbox_layer.py +++ b/api/core/app/layers/sandbox_layer.py @@ -1,122 +1,22 @@ import logging -from core.sandbox import AppAssetsInitializer, DifyCliInitializer, SandboxManager -from core.sandbox.constants import APP_ASSETS_PATH -from core.sandbox.initializer.app_assets_initializer import DraftAppAssetsInitializer -from core.sandbox.storage.archive_storage import ArchiveSandboxStorage -from core.sandbox.vm import SandboxBuilder +from core.sandbox import Sandbox from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent -from core.workflow.graph_events.graph import GraphRunPausedEvent -from core.workflow.nodes.base.node import Node -from models.workflow import Workflow -from services.app_asset_service import AppAssetService -from services.sandbox.sandbox_provider_service import SandboxProviderService logger = logging.getLogger(__name__) -class SandboxInitializationError(Exception): - pass - - class SandboxLayer(GraphEngineLayer): - def __init__( - self, - tenant_id: str, - app_id: str, - user_id: str, - workflow_version: str, - workflow_execution_id: str, - ) -> None: + def __init__(self, sandbox: Sandbox) -> None: super().__init__() - self._tenant_id = tenant_id - self._app_id = app_id - self._user_id = user_id - self._workflow_version = workflow_version - self._workflow_execution_id = workflow_execution_id - is_draft = self._workflow_version == Workflow.VERSION_DRAFT - self._sandbox_id = SandboxBuilder.draft_id(self._user_id) if is_draft else self._workflow_execution_id - self._sandbox_storage = ArchiveSandboxStorage( - self._tenant_id, self._sandbox_id, exclude_patterns=[APP_ASSETS_PATH] if is_draft else None - ) + self._sandbox = sandbox def on_graph_start(self) -> None: - try: - is_draft = self._workflow_version == Workflow.VERSION_DRAFT - assets = AppAssetService.get_assets(self._tenant_id, self._app_id, self._user_id, is_draft=is_draft) - if not assets: - raise ValueError( - f"No assets found for tid={self._tenant_id}, app_id={self._app_id}, wf={self._workflow_version}" - ) - - self._assets_id = assets.id - - if is_draft: - logger.info( - "Building draft assets for tenant_id=%s, app_id=%s, workflow_version=%s, assets_id=%s", - self._tenant_id, - self._app_id, - self._workflow_version, - assets.id, - ) - AppAssetService.build_assets(self._tenant_id, self._app_id, assets) - - assets_initializer = ( - DraftAppAssetsInitializer(self._tenant_id, self._app_id, assets.id) - if is_draft - else AppAssetsInitializer(self._tenant_id, self._app_id, assets.id) - ) - - builder = ( - SandboxProviderService.create_sandbox_builder(self._tenant_id) - .initializer(assets_initializer) - .initializer(DifyCliInitializer(self._tenant_id, self._user_id, self._app_id, assets.id)) - ) - try: - sandbox = builder.build() - logger.info( - "Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s", - self._sandbox_id, - sandbox.metadata.id, - sandbox.metadata.arch, - ) - except Exception as e: - raise SandboxInitializationError(f"Failed to build sandbox: {e}") from e - - SandboxManager.register(self._sandbox_id, sandbox) - - # mount sandbox files from storage - mounted = self._sandbox_storage.mount(sandbox) - logger.info("Sandbox files mount status: %s", mounted) - - except Exception as e: - logger.exception("Failed to initialize sandbox") - raise SandboxInitializationError(f"Failed to initialize sandbox: {e}") from e - - def on_node_run_start(self, node: Node) -> None: - # FIXME(Mairuis): should read from workflow run context... - node.assets_id = self._assets_id + pass def on_event(self, event: GraphEngineEvent) -> None: - # TODO: handle graph run paused event - if not isinstance(event, GraphRunPausedEvent): - return + pass def on_graph_end(self, error: Exception | None) -> None: - sandbox = SandboxManager.unregister(self._sandbox_id) - if sandbox is None: - logger.debug("No sandbox to release for sandbox_id=%s", self._sandbox_id) - return - - try: - self._sandbox_storage.unmount(sandbox) - logger.info("Sandbox files persisted, sandbox_id=%s", self._sandbox_id) - except Exception: - logger.exception("Failed to persist sandbox files, sandbox_id=%s", self._sandbox_id) - - try: - sandbox.release_environment() - logger.info("Sandbox released, sandbox_id=%s", self._sandbox_id) - except Exception: - logger.exception("Failed to release sandbox, sandbox_id=%s", self._sandbox_id) + self._sandbox.release() diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py index 4427783d23..33718559e6 100644 --- a/api/core/sandbox/__init__.py +++ b/api/core/sandbox/__init__.py @@ -6,44 +6,32 @@ from .bash.dify_cli import ( DifyCliToolConfig, ) from .bash.session import SandboxBashSession -from .constants import ( - APP_ASSETS_PATH, - APP_ASSETS_ZIP_PATH, - DIFY_CLI_CONFIG_FILENAME, - DIFY_CLI_GLOBAL_TOOLS_PATH, - DIFY_CLI_PATH, - DIFY_CLI_PATH_PATTERN, - DIFY_CLI_ROOT, - DIFY_CLI_TOOLS_ROOT, -) +from .builder import SandboxBuilder, VMConfig +from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType from .initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer from .manager import SandboxManager +from .sandbox import Sandbox from .storage import ArchiveSandboxStorage, SandboxStorage from .utils.debug import sandbox_debug from .utils.encryption import create_sandbox_config_encrypter, masked_config -from .vm import SandboxBuilder, SandboxType, VMConfig __all__ = [ - "APP_ASSETS_PATH", - "APP_ASSETS_ZIP_PATH", - "DIFY_CLI_CONFIG_FILENAME", - "DIFY_CLI_GLOBAL_TOOLS_PATH", - "DIFY_CLI_PATH", - "DIFY_CLI_PATH_PATTERN", - "DIFY_CLI_ROOT", - "DIFY_CLI_TOOLS_ROOT", + "AppAssets", "AppAssetsInitializer", "ArchiveSandboxStorage", + "DifyCli", "DifyCliBinary", "DifyCliConfig", "DifyCliEnvConfig", "DifyCliInitializer", "DifyCliLocator", "DifyCliToolConfig", + "Sandbox", "SandboxBashSession", "SandboxBuilder", "SandboxInitializer", "SandboxManager", + "SandboxProviderApiEntity", "SandboxStorage", "SandboxType", "VMConfig", diff --git a/api/core/sandbox/bash/bash_tool.py b/api/core/sandbox/bash/bash_tool.py index 3d53f6314e..82467b01c2 100644 --- a/api/core/sandbox/bash/bash_tool.py +++ b/api/core/sandbox/bash/bash_tool.py @@ -1,7 +1,7 @@ from collections.abc import Generator from typing import Any -from core.sandbox.constants import DIFY_CLI_CONFIG_FILENAME +from core.sandbox.entities import DifyCli from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject @@ -79,7 +79,7 @@ class SandboxBashTool(Tool): if self._tools_path: environments = { "PATH": f"{self._tools_path}:/usr/local/bin:/usr/bin:/bin", - "DIFY_CLI_CONFIG": self._tools_path + f"/{DIFY_CLI_CONFIG_FILENAME}", + "DIFY_CLI_CONFIG": self._tools_path + f"/{DifyCli.CONFIG_FILENAME}", } future = submit_command( self._sandbox, diff --git a/api/core/sandbox/bash/dify_cli.py b/api/core/sandbox/bash/dify_cli.py index c0b3b0a25f..df6dff992c 100644 --- a/api/core/sandbox/bash/dify_cli.py +++ b/api/core/sandbox/bash/dify_cli.py @@ -14,7 +14,7 @@ from core.tools.entities.tool_entities import ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager from core.virtual_environment.__base.entities import Arch, OperatingSystem -from ..constants import DIFY_CLI_PATH_PATTERN +from ..entities import DifyCli if TYPE_CHECKING: from core.tools.__base.tool import Tool @@ -44,7 +44,7 @@ class DifyCliLocator: self._root = api_root / "bin" def resolve(self, operating_system: OperatingSystem, arch: Arch) -> DifyCliBinary: - filename = DIFY_CLI_PATH_PATTERN.format(os=operating_system.value, arch=arch.value) + filename = DifyCli.PATH_PATTERN.format(os=operating_system.value, arch=arch.value) candidate = self._root / filename if not candidate.is_file(): raise FileNotFoundError( diff --git a/api/core/sandbox/bash/session.py b/api/core/sandbox/bash/session.py index eb836ba6a2..cdca8977b0 100644 --- a/api/core/sandbox/bash/session.py +++ b/api/core/sandbox/bash/session.py @@ -5,19 +5,14 @@ import logging from io import BytesIO from types import TracebackType -from core.session.cli_api import CliApiSessionManager +from core.sandbox.sandbox import Sandbox +from core.session.cli_api import CliApiSession, CliApiSessionManager from core.skill.entities.tool_artifact import ToolArtifact from core.skill.skill_manager import SkillManager from core.virtual_environment.__base.helpers import pipeline -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from ..bash.dify_cli import DifyCliConfig -from ..constants import ( - DIFY_CLI_CONFIG_FILENAME, - DIFY_CLI_GLOBAL_TOOLS_PATH, - DIFY_CLI_PATH, - DIFY_CLI_TOOLS_ROOT, -) +from ..entities import DifyCli from .bash_tool import SandboxBashTool logger = logging.getLogger(__name__) @@ -27,46 +22,46 @@ class SandboxBashSession: def __init__( self, *, - sandbox: VirtualEnvironment, - tenant_id: str, - user_id: str, + sandbox: Sandbox, node_id: str, - app_id: str, - assets_id: str, allow_tools: list[tuple[str, str]] | None, ) -> None: self._sandbox = sandbox - self._tenant_id = tenant_id - self._user_id = user_id self._node_id = node_id - self._app_id = app_id - - # FIXME(Mairuis): should read from workflow run context... - self._assets_id = assets_id self._allow_tools = allow_tools - self._bash_tool = None - self._session_id = None + self._bash_tool: SandboxBashTool | None = None + self._cli_api_session: CliApiSession | None = None + self._tenant_id = sandbox.tenant_id + self._user_id = sandbox.user_id + self._app_id = sandbox.app_id + self._assets_id = sandbox.assets_id def __enter__(self) -> SandboxBashSession: + self._cli_api_session = CliApiSessionManager().create( + tenant_id=self._tenant_id, + user_id=self._user_id, + ) if self._allow_tools is not None: - if self._node_id is None: - raise ValueError("node_id is required when allow_tools is specified") - tools_path = self._setup_node_tools_directory(self._sandbox, self._node_id, self._allow_tools) + tools_path = self._setup_node_tools_directory(self._node_id, self._allow_tools, self._cli_api_session) else: - tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH + tools_path = DifyCli.GLOBAL_TOOLS_PATH - self._bash_tool = SandboxBashTool(sandbox=self._sandbox, tenant_id=self._tenant_id, tools_path=tools_path) + self._bash_tool = SandboxBashTool( + sandbox=self._sandbox.vm, + tenant_id=self._tenant_id, + tools_path=tools_path, + ) return self def _setup_node_tools_directory( self, - sandbox: VirtualEnvironment, node_id: str, allow_tools: list[tuple[str, str]], + cli_api_session: CliApiSession, ) -> str | None: artifact: ToolArtifact | None = SkillManager.load_tool_artifact( - self._tenant_id, + self._sandbox.tenant_id, self._app_id, self._assets_id, ) @@ -80,26 +75,26 @@ class SandboxBashSession: logger.info("No tools found in artifact for assets_id=%s", self._assets_id) return None - self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id) - node_tools_path = f"{DIFY_CLI_TOOLS_ROOT}/{node_id}" + node_tools_path = f"{DifyCli.TOOLS_ROOT}/{node_id}" + vm = self._sandbox.vm ( - pipeline(sandbox) - .add(["mkdir", "-p", DIFY_CLI_GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir") + pipeline(vm) + .add(["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir") .add(["mkdir", "-p", node_tools_path], error_message="Failed to create node tools dir") .execute(raise_on_error=True) ) config_json = json.dumps( - DifyCliConfig.create( - session=self._cli_api_session, tenant_id=self._tenant_id, artifact=artifact - ).model_dump(mode="json"), + DifyCliConfig.create(session=cli_api_session, tenant_id=self._tenant_id, artifact=artifact).model_dump( + mode="json" + ), ensure_ascii=False, ) - sandbox.upload_file(f"{node_tools_path}/{DIFY_CLI_CONFIG_FILENAME}", BytesIO(config_json.encode("utf-8"))) + vm.upload_file(f"{node_tools_path}/{DifyCli.CONFIG_FILENAME}", BytesIO(config_json.encode("utf-8"))) - pipeline(sandbox, cwd=node_tools_path).add( - [DIFY_CLI_PATH, "init"], error_message="Failed to initialize Dify CLI" + pipeline(vm, cwd=node_tools_path).add( + [DifyCli.PATH, "init"], error_message="Failed to initialize Dify CLI" ).execute(raise_on_error=True) logger.info( @@ -114,7 +109,10 @@ class SandboxBashSession: tb: TracebackType | None, ) -> bool: try: - self.cleanup() + if self._session_id is not None: + CliApiSessionManager().delete(self._session_id) + logger.debug("Cleaned up SandboxSession session_id=%s", self._session_id) + self._session_id = None except Exception: logger.exception("Failed to cleanup SandboxSession") return False @@ -124,11 +122,3 @@ class SandboxBashSession: if self._bash_tool is None: raise RuntimeError("SandboxSession is not initialized") return self._bash_tool - - def cleanup(self) -> None: - if self._session_id is None: - return - - CliApiSessionManager().delete(self._session_id) - logger.debug("Cleaned up SandboxSession session_id=%s", self._session_id) - self._session_id = None diff --git a/api/core/sandbox/vm.py b/api/core/sandbox/builder.py similarity index 60% rename from api/core/sandbox/vm.py rename to api/core/sandbox/builder.py index ac2b81ee1b..6113fbc8e6 100644 --- a/api/core/sandbox/vm.py +++ b/api/core/sandbox/builder.py @@ -1,41 +1,17 @@ -""" -Facade module for virtual machine providers. - -Provides unified interfaces to access different VM provider implementations -(E2B, Docker, Local) through VMType, VMBuilder, and VMConfig. -""" - from __future__ import annotations from collections.abc import Mapping, Sequence -from enum import StrEnum -from typing import Any +from typing import TYPE_CHECKING, Any -from configs import dify_config from core.entities.provider_entities import BasicProviderConfig from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from .entities.sandbox_type import SandboxType from .initializer import SandboxInitializer +from .sandbox import Sandbox - -class SandboxType(StrEnum): - """ - Sandbox types. - """ - - DOCKER = "docker" - E2B = "e2b" - LOCAL = "local" - - @classmethod - def get_all(cls) -> list[str]: - """ - Get all available sandbox types. - """ - if dify_config.EDITION == "SELF_HOSTED": - return [p.value for p in cls] - else: - return [p.value for p in cls if p != SandboxType.LOCAL] +if TYPE_CHECKING: + from .storage.sandbox_storage import SandboxStorage def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]: @@ -57,18 +33,35 @@ def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]: class SandboxBuilder: + _tenant_id: str + _sandbox_type: SandboxType + _user_id: str | None + _app_id: str | None + _options: dict[str, Any] + _environments: dict[str, str] + _initializers: list[SandboxInitializer] + _storage: SandboxStorage | None + _assets_id: str | None + def __init__(self, tenant_id: str, sandbox_type: SandboxType) -> None: self._tenant_id = tenant_id self._sandbox_type = sandbox_type - self._user_id: str | None = None - self._options: dict[str, Any] = {} - self._environments: dict[str, str] = {} - self._initializers: list[SandboxInitializer] = [] + self._user_id = None + self._app_id = None + self._options = {} + self._environments = {} + self._initializers = [] + self._storage = None + self._assets_id = None def user(self, user_id: str) -> SandboxBuilder: self._user_id = user_id return self + def app(self, app_id: str) -> SandboxBuilder: + self._app_id = app_id + return self + def options(self, options: Mapping[str, Any]) -> SandboxBuilder: self._options = dict(options) return self @@ -85,7 +78,21 @@ class SandboxBuilder: self._initializers.extend(initializers) return self - def build(self) -> VirtualEnvironment: + def storage(self, storage: SandboxStorage, assets_id: str) -> SandboxBuilder: + self._storage = storage + self._assets_id = assets_id + return self + + def build(self) -> Sandbox: + if self._storage is None: + raise ValueError("storage is required, call .storage() before .build()") + if self._assets_id is None: + raise ValueError("assets_id is required, call .storage() before .build()") + if self._user_id is None: + raise ValueError("user_id is required, call .user() before .build()") + if self._app_id is None: + raise ValueError("app_id is required, call .app() before .build()") + vm_class = _get_sandbox_class(self._sandbox_type) vm = vm_class( tenant_id=self._tenant_id, @@ -95,7 +102,17 @@ class SandboxBuilder: ) for init in self._initializers: init.initialize(vm) - return vm + + sandbox = Sandbox( + vm=vm, + storage=self._storage, + tenant_id=self._tenant_id, + user_id=self._user_id, + app_id=self._app_id, + assets_id=self._assets_id, + ) + sandbox.mount() + return sandbox @staticmethod def validate(vm_type: SandboxType, options: Mapping[str, Any]) -> None: diff --git a/api/core/sandbox/constants.py b/api/core/sandbox/constants.py deleted file mode 100644 index f4431c3e67..0000000000 --- a/api/core/sandbox/constants.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Final - -# Dify CLI (absolute path - hidden in /tmp, not in sandbox workdir) -DIFY_CLI_ROOT: Final[str] = "/tmp/.dify" -DIFY_CLI_PATH: Final[str] = "/tmp/.dify/bin/dify" - -DIFY_CLI_PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}" - -DIFY_CLI_CONFIG_FILENAME: Final[str] = ".dify_cli.json" - -DIFY_CLI_TOOLS_ROOT: Final[str] = "/tmp/.dify/tools" -DIFY_CLI_GLOBAL_TOOLS_PATH: Final[str] = "/tmp/.dify/tools/global" - -# App Assets (relative path - stays in sandbox workdir) -APP_ASSETS_PATH: Final[str] = "skills" -APP_ASSETS_ZIP_PATH: Final[str] = "/tmp/assets.zip" diff --git a/api/core/sandbox/entities/__init__.py b/api/core/sandbox/entities/__init__.py index 6829ca31ba..b5c3d57342 100644 --- a/api/core/sandbox/entities/__init__.py +++ b/api/core/sandbox/entities/__init__.py @@ -1,3 +1,10 @@ +from .config import AppAssets, DifyCli from .providers import SandboxProviderApiEntity +from .sandbox_type import SandboxType -__all__ = ["SandboxProviderApiEntity"] +__all__ = [ + "AppAssets", + "DifyCli", + "SandboxProviderApiEntity", + "SandboxType", +] diff --git a/api/core/sandbox/entities/config.py b/api/core/sandbox/entities/config.py new file mode 100644 index 0000000000..db046097c5 --- /dev/null +++ b/api/core/sandbox/entities/config.py @@ -0,0 +1,19 @@ +from typing import Final + + +class DifyCli: + """Dify CLI constants (absolute path - hidden in /tmp, not in sandbox workdir)""" + + ROOT: Final[str] = "/tmp/.dify" + PATH: Final[str] = "/tmp/.dify/bin/dify" + PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}" + CONFIG_FILENAME: Final[str] = ".dify_cli.json" + TOOLS_ROOT: Final[str] = "/tmp/.dify/tools" + GLOBAL_TOOLS_PATH: Final[str] = "/tmp/.dify/tools/global" + + +class AppAssets: + """App Assets constants (relative path - stays in sandbox workdir)""" + + PATH: Final[str] = "skills" + ZIP_PATH: Final[str] = "/tmp/assets.zip" diff --git a/api/core/sandbox/entities/sandbox_type.py b/api/core/sandbox/entities/sandbox_type.py new file mode 100644 index 0000000000..3ac7e0a94e --- /dev/null +++ b/api/core/sandbox/entities/sandbox_type.py @@ -0,0 +1,16 @@ +from enum import StrEnum + +from configs import dify_config + + +class SandboxType(StrEnum): + DOCKER = "docker" + E2B = "e2b" + LOCAL = "local" + + @classmethod + def get_all(cls) -> list[str]: + if dify_config.EDITION == "SELF_HOSTED": + return [p.value for p in cls] + else: + return [p.value for p in cls if p != SandboxType.LOCAL] diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py index 57102d1a9c..3f990d29ee 100644 --- a/api/core/sandbox/initializer/app_assets_initializer.py +++ b/api/core/sandbox/initializer/app_assets_initializer.py @@ -6,7 +6,7 @@ from core.virtual_environment.__base.virtual_environment import VirtualEnvironme from extensions.ext_storage import storage from extensions.storage.file_presign_storage import FilePresignStorage -from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH +from ..entities import AppAssets from .base import SandboxInitializer logger = logging.getLogger(__name__) @@ -26,11 +26,11 @@ class AppAssetsInitializer(SandboxInitializer): ( pipeline(env) - .add(["wget", "-q", download_url, "-O", APP_ASSETS_ZIP_PATH], error_message="Failed to download assets zip") + .add(["wget", "-q", download_url, "-O", AppAssets.ZIP_PATH], error_message="Failed to download assets zip") # unzip with silent error and return 1 if the zip is empty # FIXME(Mairuis): should use a more robust way to check if the zip is empty .add( - ["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} -d {APP_ASSETS_PATH} 2>/dev/null || [ $? -eq 1 ]"], + ["sh", "-c", f"unzip {AppAssets.ZIP_PATH} -d {AppAssets.PATH} 2>/dev/null || [ $? -eq 1 ]"], error_message="Failed to unzip assets", ) .execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True) @@ -55,12 +55,12 @@ class DraftAppAssetsInitializer(SandboxInitializer): ( pipeline(env) - .add(["rm", "-rf", APP_ASSETS_PATH]) - .add(["wget", "-q", download_url, "-O", APP_ASSETS_ZIP_PATH], error_message="Failed to download assets zip") + .add(["rm", "-rf", AppAssets.PATH]) + .add(["wget", "-q", download_url, "-O", AppAssets.ZIP_PATH], error_message="Failed to download assets zip") # unzip with silent error and return 1 if the zip is empty # FIXME(Mairuis): should use a more robust way to check if the zip is empty .add( - ["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} -d {APP_ASSETS_PATH} 2>/dev/null || [ $? -eq 1 ]"], + ["sh", "-c", f"unzip {AppAssets.ZIP_PATH} -d {AppAssets.PATH} 2>/dev/null || [ $? -eq 1 ]"], error_message="Failed to unzip assets", ) .execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True) diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py index 33c2be633a..fa6f57ce69 100644 --- a/api/core/sandbox/initializer/dify_cli_initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -11,12 +11,7 @@ from core.virtual_environment.__base.helpers import pipeline from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from ..bash.dify_cli import DifyCliConfig, DifyCliLocator -from ..constants import ( - DIFY_CLI_CONFIG_FILENAME, - DIFY_CLI_GLOBAL_TOOLS_PATH, - DIFY_CLI_PATH, - DIFY_CLI_ROOT, -) +from ..entities import DifyCli from .base import SandboxInitializer logger = logging.getLogger(__name__) @@ -44,10 +39,10 @@ class DifyCliInitializer(SandboxInitializer): binary = self._locator.resolve(env.metadata.os, env.metadata.arch) pipeline(env).add( - ["mkdir", "-p", f"{DIFY_CLI_ROOT}/bin"], error_message="Failed to create dify CLI directory" + ["mkdir", "-p", f"{DifyCli.ROOT}/bin"], error_message="Failed to create dify CLI directory" ).execute(raise_on_error=True) - env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes())) + env.upload_file(DifyCli.PATH, BytesIO(binary.path.read_bytes())) # Use 'cp' with mode preservation workaround: copy file to itself to claim ownership, # then use 'install' to set executable permission @@ -55,14 +50,14 @@ class DifyCliInitializer(SandboxInitializer): [ "sh", "-c", - f"cat '{DIFY_CLI_PATH}' > '{DIFY_CLI_PATH}.tmp' && " - f"mv '{DIFY_CLI_PATH}.tmp' '{DIFY_CLI_PATH}' && " - f"chmod +x '{DIFY_CLI_PATH}'", + f"cat '{DifyCli.PATH}' > '{DifyCli.PATH}.tmp' && " + f"mv '{DifyCli.PATH}.tmp' '{DifyCli.PATH}' && " + f"chmod +x '{DifyCli.PATH}'", ], error_message="Failed to mark dify CLI as executable", ).execute(raise_on_error=True) - logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH) + logger.info("Dify CLI uploaded to sandbox, path=%s", DifyCli.PATH) artifact = SkillManager.load_tool_artifact(self._tenant_id, self._app_id, self._assets_id) if artifact is None or not artifact.references: @@ -73,16 +68,16 @@ class DifyCliInitializer(SandboxInitializer): self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id) pipeline(env).add( - ["mkdir", "-p", DIFY_CLI_GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir" + ["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir" ).execute(raise_on_error=True) config = DifyCliConfig.create(self._cli_api_session, self._tenant_id, artifact) config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False) - config_path = f"{DIFY_CLI_GLOBAL_TOOLS_PATH}/{DIFY_CLI_CONFIG_FILENAME}" + config_path = f"{DifyCli.GLOBAL_TOOLS_PATH}/{DifyCli.CONFIG_FILENAME}" env.upload_file(config_path, BytesIO(config_json.encode("utf-8"))) - pipeline(env, cwd=DIFY_CLI_GLOBAL_TOOLS_PATH).add( - [DIFY_CLI_PATH, "init"], error_message="Failed to initialize Dify CLI" + pipeline(env, cwd=DifyCli.GLOBAL_TOOLS_PATH).add( + [DifyCli.PATH, "init"], error_message="Failed to initialize Dify CLI" ).execute(raise_on_error=True) - logger.info("Global tools initialized, path=%s, tool_count=%d", DIFY_CLI_GLOBAL_TOOLS_PATH, len(self._tools)) + logger.info("Global tools initialized, path=%s, tool_count=%d", DifyCli.GLOBAL_TOOLS_PATH, len(self._tools)) diff --git a/api/core/sandbox/manager.py b/api/core/sandbox/manager.py index 3de21e4e83..01a824ad4a 100644 --- a/api/core/sandbox/manager.py +++ b/api/core/sandbox/manager.py @@ -4,23 +4,20 @@ import logging import threading from typing import Final +from core.sandbox.builder import SandboxBuilder +from core.sandbox.entities import AppAssets, SandboxType +from core.sandbox.entities.providers import SandboxProviderEntity +from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer, DraftAppAssetsInitializer +from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer +from core.sandbox.sandbox import Sandbox +from core.sandbox.storage.archive_storage import ArchiveSandboxStorage from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from services.app_asset_service import AppAssetService logger = logging.getLogger(__name__) class SandboxManager: - """Process-local registry for workflow sandboxes. - - Stores `VirtualEnvironment` references keyed by `workflow_execution_id`. - - Concurrency: the registry is split into hash shards and each shard is updated with - copy-on-write under a shard lock. Reads are lock-free (snapshot dict) to reduce - contention in hot paths like `get()`. - """ - - # FIXME:(sandbox) Prefer a workflow-level context on GraphRuntimeState to store workflow-scoped shared objects. - _NUM_SHARDS: Final[int] = 1024 _SHARD_MASK: Final[int] = _NUM_SHARDS - 1 @@ -104,3 +101,91 @@ class SandboxManager: @classmethod def count(cls) -> int: return sum(len(shard) for shard in cls._shards) + + @classmethod + def create( + cls, + tenant_id: str, + app_id: str, + user_id: str, + workflow_execution_id: str, + sandbox_provider: SandboxProviderEntity, + ) -> Sandbox: + assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=False) + if not assets: + raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}") + + storage = ArchiveSandboxStorage(tenant_id, workflow_execution_id) + sandbox = ( + SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type)) + .options(sandbox_provider.config) + .user(user_id) + .app(app_id) + .initializer(AppAssetsInitializer(tenant_id, app_id, assets.id)) + .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) + .storage(storage, assets.id) + .build() + ) + + logger.info("Sandbox created: id=%s, assets=%s", sandbox.vm.metadata.id, sandbox.assets_id) + return sandbox + + @classmethod + def create_draft( + cls, + tenant_id: str, + app_id: str, + user_id: str, + sandbox_provider: SandboxProviderEntity, + ) -> Sandbox: + assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=True) + if not assets: + raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}") + + AppAssetService.build_assets(tenant_id, app_id, assets) + sandbox_id = SandboxBuilder.draft_id(user_id) + storage = ArchiveSandboxStorage(tenant_id, sandbox_id, exclude_patterns=[AppAssets.PATH]) + + sandbox = ( + SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type)) + .options(sandbox_provider.config) + .user(user_id) + .app(app_id) + .initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id)) + .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) + .storage(storage, assets.id) + .build() + ) + + logger.info("Draft sandbox created: id=%s, assets=%s", sandbox.vm.metadata.id, sandbox.assets_id) + return sandbox + + @classmethod + def create_for_single_step( + cls, + tenant_id: str, + app_id: str, + user_id: str, + sandbox_provider: SandboxProviderEntity, + ) -> Sandbox: + assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=True) + if not assets: + raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}") + + AppAssetService.build_assets(tenant_id, app_id, assets) + sandbox_id = SandboxBuilder.draft_id(user_id) + storage = ArchiveSandboxStorage(tenant_id, sandbox_id, exclude_patterns=[AppAssets.PATH]) + + sandbox = ( + SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type)) + .options(sandbox_provider.config) + .user(user_id) + .app(app_id) + .initializer(AppAssetsInitializer(tenant_id, app_id, assets.id)) + .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) + .storage(storage, assets.id) + .build() + ) + + logger.info("Single-step sandbox created: id=%s, assets=%s", sandbox.vm.metadata.id, sandbox.assets_id) + return sandbox diff --git a/api/core/sandbox/sandbox.py b/api/core/sandbox/sandbox.py new file mode 100644 index 0000000000..7032238277 --- /dev/null +++ b/api/core/sandbox/sandbox.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from core.sandbox.storage.sandbox_storage import SandboxStorage + from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + +logger = logging.getLogger(__name__) + + +class Sandbox: + def __init__( + self, + *, + vm: VirtualEnvironment, + storage: SandboxStorage, + tenant_id: str, + user_id: str, + app_id: str, + assets_id: str, + ) -> None: + self._vm = vm + self._storage = storage + self._tenant_id = tenant_id + self._user_id = user_id + self._app_id = app_id + self._assets_id = assets_id + + @property + def vm(self) -> VirtualEnvironment: + return self._vm + + @property + def storage(self) -> SandboxStorage: + return self._storage + + @property + def tenant_id(self) -> str: + return self._tenant_id + + @property + def user_id(self) -> str: + return self._user_id + + @property + def app_id(self) -> str: + return self._app_id + + @property + def assets_id(self) -> str: + return self._assets_id + + def mount(self) -> bool: + return self._storage.mount(self._vm) + + def unmount(self) -> bool: + return self._storage.unmount(self._vm) + + def release(self) -> None: + sandbox_id = self._vm.metadata.id + try: + self._storage.unmount(self._vm) + logger.info("Sandbox storage unmounted: sandbox_id=%s", sandbox_id) + except Exception: + logger.exception("Failed to unmount sandbox storage: sandbox_id=%s", sandbox_id) + + try: + self._vm.release_environment() + logger.info("Sandbox released: sandbox_id=%s", sandbox_id) + except Exception: + logger.exception("Failed to release sandbox: sandbox_id=%s", sandbox_id) diff --git a/api/core/workflow/context/__init__.py b/api/core/workflow/context/__init__.py index 1237d6a017..fd60917617 100644 --- a/api/core/workflow/context/__init__.py +++ b/api/core/workflow/context/__init__.py @@ -17,7 +17,6 @@ from core.workflow.context.execution_context import ( register_context_capturer, reset_context_provider, ) -from core.workflow.context.models import SandboxContext __all__ = [ "AppContext", @@ -25,7 +24,6 @@ __all__ = [ "ExecutionContext", "IExecutionContext", "NullAppContext", - "SandboxContext", "capture_current_context", "read_context", "register_context", diff --git a/api/core/workflow/context/models.py b/api/core/workflow/context/models.py index af5a4b2614..bdec2fc8c4 100644 --- a/api/core/workflow/context/models.py +++ b/api/core/workflow/context/models.py @@ -1,13 +1,3 @@ from __future__ import annotations -from pydantic import AnyHttpUrl, BaseModel - - -class SandboxContext(BaseModel): - """Typed context for sandbox integration. All fields optional by design.""" - - sandbox_url: AnyHttpUrl | None = None - sandbox_token: str | None = None # optional, if later needed for auth - - -__all__ = ["SandboxContext"] +__all__: list[str] = [] diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py index 52df7e9351..9e7d686f4e 100644 --- a/api/core/workflow/nodes/command/node.py +++ b/api/core/workflow/nodes/command/node.py @@ -2,11 +2,9 @@ import logging from collections.abc import Mapping, Sequence from typing import Any -from core.sandbox import SandboxManager, sandbox_debug -from core.sandbox.vm import SandboxBuilder +from core.sandbox import sandbox_debug from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError from core.virtual_environment.__base.helpers import submit_command, with_connection -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser @@ -24,19 +22,6 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60 class CommandNode(Node[CommandNodeData]): node_type = NodeType.COMMAND - # FIXME(Mairuis): should read sandbox from workflow run context... - def _get_sandbox(self) -> VirtualEnvironment | None: - workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - if not workflow_execution_id: - return None - sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id) - if sandbox_by_workflow_run_id is not None: - return sandbox_by_workflow_run_id - sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id)) - if sandbox_by_draft_id is not None: - return sandbox_by_draft_id - return None - def _render_template(self, template: str) -> str: parser = VariableTemplateParser(template=template) selectors = parser.extract_variable_selectors() @@ -65,7 +50,7 @@ class CommandNode(Node[CommandNodeData]): return "1" def _run(self) -> NodeRunResult: - sandbox = self._get_sandbox() + sandbox = self.graph_runtime_state.sandbox if sandbox is None: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -88,12 +73,12 @@ class CommandNode(Node[CommandNodeData]): timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None try: - with with_connection(sandbox) as conn: + with with_connection(sandbox.vm) as conn: command = ["bash", "-c", raw_command] sandbox_debug("command_node", "command", command) - future = submit_command(sandbox, conn, command, cwd=working_directory) + future = submit_command(sandbox.vm, conn, command, cwd=working_directory) result = future.result(timeout=timeout) outputs: dict[str, Any] = { diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 99bf40064f..e7c8036862 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -50,8 +50,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.sandbox import SandboxBashSession, SandboxManager -from core.sandbox.vm import SandboxBuilder +from core.sandbox import Sandbox +from core.sandbox.bash.session import SandboxBashSession from core.tools.__base.tool import Tool from core.tools.signature import sign_upload_file from core.tools.tool_manager import ToolManager @@ -64,7 +64,6 @@ from core.variables import ( ObjectSegment, StringSegment, ) -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus from core.workflow.entities.tool_entities import ToolCallResult @@ -174,19 +173,6 @@ class LLMNode(Node[LLMNodeData]): def version(cls) -> str: return "1" - # FIXME(Mairuis): should read sandbox from workflow run context... - def _get_sandbox(self) -> VirtualEnvironment | None: - workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - if not workflow_execution_id: - return None - sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id) - if sandbox_by_workflow_run_id is not None: - return sandbox_by_workflow_run_id - sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id)) - if sandbox_by_draft_id is not None: - return sandbox_by_draft_id - return None - def _run(self) -> Generator: node_inputs: dict[str, Any] = {} process_data: dict[str, Any] = {} @@ -301,8 +287,7 @@ class LLMNode(Node[LLMNodeData]): generation_data: LLMGenerationData | None = None structured_output: LLMStructuredOutput | None = None - # FIXME(Mairuis): should read sandbox from workflow run context... - sandbox = self._get_sandbox() + sandbox = self.graph_runtime_state.sandbox if sandbox: generator = self._invoke_llm_with_sandbox( sandbox=sandbox, @@ -1839,7 +1824,7 @@ class LLMNode(Node[LLMNodeData]): def _invoke_llm_with_sandbox( self, - sandbox: VirtualEnvironment, + sandbox: Sandbox, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None, @@ -1849,23 +1834,14 @@ class LLMNode(Node[LLMNodeData]): result: LLMGenerationData | None = None - with SandboxBashSession( - sandbox=sandbox, - tenant_id=self.tenant_id, - user_id=self.user_id, - node_id=self.id, - app_id=self.app_id, - # FIXME(Mairuis): should read from workflow run context... - assets_id=getattr(self, "assets_id", ""), - allow_tools=allow_tools, - ) as sandbox_session: + with SandboxBashSession(sandbox=sandbox, node_id=self.id, allow_tools=allow_tools) as session: prompt_files = self._extract_prompt_files(variable_pool) model_features = self._get_model_features(model_instance) strategy = StrategyFactory.create_strategy( model_features=model_features, model_instance=model_instance, - tools=[sandbox_session.bash_tool], + tools=[session.bash_tool], files=prompt_files, max_iterations=self._node_data.max_iterations or 100, agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 401cecc162..86056b139b 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -11,6 +11,7 @@ from typing import Any, Protocol from pydantic.json import pydantic_encoder from core.model_runtime.entities.llm_entities import LLMUsage +from core.sandbox.sandbox import Sandbox from core.workflow.entities.pause_reason import PauseReason from core.workflow.runtime.variable_pool import VariablePool @@ -171,6 +172,8 @@ class GraphRuntimeState: self._paused_nodes: set[str] = set() self.stop_event: threading.Event = threading.Event() + self._sandbox: Sandbox | None = None + if graph is not None: self.attach_graph(graph) @@ -294,6 +297,16 @@ class GraphRuntimeState: raise ValueError("tokens must be non-negative") self._total_tokens += tokens + # ------------------------------------------------------------------ + # Sandbox context (workflow-scoped) + # ------------------------------------------------------------------ + @property + def sandbox(self) -> Sandbox | None: + return self._sandbox + + def set_sandbox(self, sandbox: Sandbox) -> None: + self._sandbox = sandbox + # ------------------------------------------------------------------ # Serialization # ------------------------------------------------------------------ diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py index bfbb5ba704..3f3855d2bd 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py @@ -78,6 +78,10 @@ class ReadOnlyGraphRuntimeState(Protocol): """Get a single output value (returns a copy).""" ... + @property + def sandbox(self) -> Any: + ... + def dumps(self) -> str: """Serialize the runtime state into a JSON snapshot (read-only).""" ... diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py index d3e4c60d9b..301da45d36 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/core/workflow/runtime/read_only_wrappers.py @@ -82,6 +82,10 @@ class ReadOnlyGraphRuntimeStateWrapper: def get_output(self, key: str, default: Any = None) -> Any: return self._state.get_output(key, default) + @property + def sandbox(self) -> Any: + return self._state.sandbox + def dumps(self) -> str: """Serialize the underlying runtime state for external persistence.""" return self._state.dumps() diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index ee37314721..f78486da23 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -8,6 +8,7 @@ from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File +from core.sandbox import Sandbox from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams from core.workflow.errors import WorkflowNodeRunFailedError @@ -128,6 +129,7 @@ class WorkflowEntry: user_inputs: Mapping[str, Any], variable_pool: VariablePool, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, + sandbox: Sandbox | None = None, ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: """ Single step run workflow node @@ -156,6 +158,9 @@ class WorkflowEntry: ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + if sandbox is not None: + graph_runtime_state.set_sandbox(sandbox) + # init workflow run state node_factory = DifyNodeFactory( graph_init_params=graph_init_params, diff --git a/api/services/sandbox/sandbox_provider_service.py b/api/services/sandbox/sandbox_provider_service.py index 56933f43f3..a73570ff7b 100644 --- a/api/services/sandbox/sandbox_provider_service.py +++ b/api/services/sandbox/sandbox_provider_service.py @@ -6,8 +6,14 @@ from typing import Any from sqlalchemy.orm import Session from constants import HIDDEN_VALUE -from core.sandbox import SandboxBuilder, SandboxType, VMConfig, create_sandbox_config_encrypter, masked_config -from core.sandbox.entities import SandboxProviderApiEntity +from core.sandbox import ( + SandboxBuilder, + SandboxProviderApiEntity, + SandboxType, + VMConfig, + create_sandbox_config_encrypter, + masked_config, +) from core.sandbox.entities.providers import SandboxProviderEntity from core.tools.utils.system_encryption import decrypt_system_params from extensions.ext_database import db @@ -206,7 +212,6 @@ class SandboxProviderService: raise ValueError(f"No system default provider configured for tenant {tenant_id}") @classmethod - def create_sandbox_builder(cls, tenant_id: str) -> SandboxBuilder: + def get_sandbox_provider(cls, tenant_id: str) -> SandboxProviderEntity: with Session(db.engine, expire_on_commit=False) as session: - active_config = cls.get_active_sandbox_config(session, tenant_id) - return SandboxBuilder(tenant_id, SandboxType(active_config.provider_type)).options(active_config.config) + return cls.get_active_sandbox_config(session, tenant_id) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index a509eccad4..b9b155ce80 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -14,10 +14,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory -from core.sandbox import SandboxManager -from core.sandbox.constants import APP_ASSETS_PATH -from core.sandbox.storage.archive_storage import ArchiveSandboxStorage -from core.sandbox.vm import SandboxBuilder +from core.sandbox.manager import SandboxManager from core.variables import Variable, VariableBase from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -702,34 +699,15 @@ class WorkflowService: enclosing_node_id = None sandbox = None - single_step_execution_id: str | None = None if draft_workflow.get_feature(WorkflowFeatures.SANDBOX).enabled: - from core.sandbox import AppAssetsInitializer, DifyCliInitializer - from services.app_asset_service import AppAssetService - - assets = AppAssetService.get_or_create_assets(session, app_model, account.id) - if not assets: - raise ValueError(f"No assets found for tid={draft_workflow.tenant_id}, app_id={app_model.id}") - - # FIXME(Mairuis): single step execution - AppAssetService.build_assets(draft_workflow.tenant_id, app_model.id, assets) - sandbox_id = SandboxBuilder.draft_id(account.id) - sandbox_storage = ArchiveSandboxStorage( - draft_workflow.tenant_id, sandbox_id, exclude_patterns=[APP_ASSETS_PATH] + sandbox_provider = SandboxProviderService.get_sandbox_provider(draft_workflow.tenant_id) + sandbox = SandboxManager.create_for_single_step( + tenant_id=draft_workflow.tenant_id, + app_id=app_model.id, + user_id=account.id, + sandbox_provider=sandbox_provider, ) - sandbox = ( - SandboxProviderService.create_sandbox_builder(draft_workflow.tenant_id) - .initializer(DifyCliInitializer(draft_workflow.tenant_id, account.id, app_model.id, assets.id)) - .initializer(AppAssetsInitializer(draft_workflow.tenant_id, app_model.id, assets.id)) - .build() - ) - sandbox_storage.mount(sandbox) - single_step_execution_id = f"single-step-{uuid.uuid4()}" - - SandboxManager.register(single_step_execution_id, sandbox) - variable_pool.system_variables.workflow_execution_id = single_step_execution_id - try: node, generator = WorkflowEntry.single_step_run( workflow=draft_workflow, @@ -738,6 +716,7 @@ class WorkflowService: user_id=account.id, variable_pool=variable_pool, variable_loader=variable_loader, + sandbox=sandbox, ) # Run draft workflow node @@ -747,17 +726,9 @@ class WorkflowService: start_at=start_at, node_id=node_id, ) - # FIXME(Mairuis): fidn a better way to handle this - if sandbox is not None: - sandbox_storage.unmount(sandbox) finally: - if single_step_execution_id: - sandbox = SandboxManager.unregister(single_step_execution_id) - if sandbox: - try: - sandbox.release_environment() - except Exception: - logger.exception("Failed to release sandbox") + if sandbox is not None: + sandbox.release() # Set workflow_id on the NodeExecution node_execution.workflow_id = draft_workflow.id diff --git a/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py b/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py deleted file mode 100644 index b1674478ad..0000000000 --- a/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py +++ /dev/null @@ -1,164 +0,0 @@ -import threading -from collections.abc import Mapping -from io import BytesIO -from typing import Any - -import pytest - -from core.sandbox import SandboxManager -from core.virtual_environment.__base.entities import ( - Arch, - CommandStatus, - ConnectionHandle, - FileState, - Metadata, - OperatingSystem, -) -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment - - -class FakeVirtualEnvironment(VirtualEnvironment): - def __init__(self, sandbox_id: str = "fake-id"): - self._sandbox_id = sandbox_id - super().__init__(tenant_id="test-tenant", options={}, environments={}) - - def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: - return Metadata(id=self._sandbox_id, arch=Arch.AMD64, os=OperatingSystem.LINUX) - - def upload_file(self, path: str, content: BytesIO) -> None: - raise NotImplementedError - - def download_file(self, path: str) -> BytesIO: - raise NotImplementedError - - def list_files(self, directory_path: str, limit: int) -> list[FileState]: - return [] - - def establish_connection(self) -> ConnectionHandle: - return ConnectionHandle(id="conn") - - def release_connection(self, connection_handle: ConnectionHandle) -> None: - pass - - def release_environment(self) -> None: - pass - - def execute_command( - self, - connection_handle: ConnectionHandle, - command: list[str], - environments: Mapping[str, str] | None = None, - cwd: str | None = None, - ) -> tuple[str, Any, Any, Any]: - raise NotImplementedError - - def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus: - return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0) - - @classmethod - def validate(cls, options: Mapping[str, Any]) -> None: - pass - - -@pytest.fixture(autouse=True) -def clean_sandbox_manager(): - SandboxManager.clear() - yield - SandboxManager.clear() - - -class TestSandboxManager: - def test_register_and_get(self): - sandbox = FakeVirtualEnvironment("sandbox-1") - - SandboxManager.register("exec-1", sandbox) - result = SandboxManager.get("exec-1") - - assert result is sandbox - - def test_get_returns_none_for_unknown_id(self): - result = SandboxManager.get("unknown-id") - assert result is None - - def test_register_raises_on_empty_workflow_execution_id(self): - sandbox = FakeVirtualEnvironment() - - with pytest.raises(ValueError, match="workflow_execution_id cannot be empty"): - SandboxManager.register("", sandbox) - - def test_register_raises_on_duplicate(self): - sandbox1 = FakeVirtualEnvironment("sandbox-1") - sandbox2 = FakeVirtualEnvironment("sandbox-2") - - SandboxManager.register("exec-dup", sandbox1) - - with pytest.raises(RuntimeError, match="already registered"): - SandboxManager.register("exec-dup", sandbox2) - - def test_unregister_returns_sandbox(self): - sandbox = FakeVirtualEnvironment("sandbox-to-remove") - SandboxManager.register("exec-remove", sandbox) - - result = SandboxManager.unregister("exec-remove") - - assert result is sandbox - assert SandboxManager.get("exec-remove") is None - - def test_unregister_returns_none_for_unknown(self): - result = SandboxManager.unregister("nonexistent") - assert result is None - - def test_has_returns_true_when_registered(self): - sandbox = FakeVirtualEnvironment() - SandboxManager.register("exec-has", sandbox) - - assert SandboxManager.has("exec-has") is True - - def test_has_returns_false_when_not_registered(self): - assert SandboxManager.has("exec-no") is False - - def test_clear_removes_all_sandboxes(self): - sandbox1 = FakeVirtualEnvironment("s1") - sandbox2 = FakeVirtualEnvironment("s2") - SandboxManager.register("exec-1", sandbox1) - SandboxManager.register("exec-2", sandbox2) - - SandboxManager.clear() - - assert SandboxManager.count() == 0 - assert SandboxManager.get("exec-1") is None - assert SandboxManager.get("exec-2") is None - - def test_count_returns_number_of_sandboxes(self): - assert SandboxManager.count() == 0 - - SandboxManager.register("e1", FakeVirtualEnvironment("s1")) - assert SandboxManager.count() == 1 - - SandboxManager.register("e2", FakeVirtualEnvironment("s2")) - assert SandboxManager.count() == 2 - - SandboxManager.unregister("e1") - assert SandboxManager.count() == 1 - - def test_thread_safety(self): - results: list[bool] = [] - errors: list[Exception] = [] - - def register_sandbox(exec_id: str): - try: - sandbox = FakeVirtualEnvironment(f"sandbox-{exec_id}") - SandboxManager.register(exec_id, sandbox) - results.append(True) - except Exception as e: - errors.append(e) - - threads = [threading.Thread(target=register_sandbox, args=(f"exec-{i}",)) for i in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - - assert len(errors) == 0 - assert len(results) == 10 - assert SandboxManager.count() == 10