From a0c388f283b60f29966d24f88201fd0b57b80a6f Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 14 Jan 2026 23:23:00 +0800 Subject: [PATCH] refactor(sandbox): extract connection helpers and move run_command to helper module - Add helpers.py with connection management utilities: - with_connection: context manager for connection lifecycle - submit_command: execute command and return CommandFuture - execute: run command with auto connection, raise on failure - try_execute: run command with auto connection, return result - Add CommandExecutionError to exec.py for typed error handling with access to exit_code, stderr, and full result - Remove run_command method from VirtualEnvironment base class (now available as submit_command helper) - Update all call sites to use new helper functions: - sandbox/session.py - sandbox/storage/archive_storage.py - sandbox/bash/bash_tool.py - workflow/nodes/command/node.py - Add comprehensive unit tests for helpers with connection reuse --- api/core/sandbox/__init__.py | 16 +- api/core/sandbox/bash/__init__.py | 17 ++ api/core/sandbox/{ => bash}/bash_tool.py | 37 ++- api/core/sandbox/{ => bash}/dify_cli.py | 0 api/core/sandbox/initializer/__init__.py | 6 + .../sandbox/{ => initializer}/initializer.py | 21 +- api/core/sandbox/session.py | 21 +- api/core/sandbox/storage/archive_storage.py | 40 +-- api/core/sandbox/utils/__init__.py | 2 + api/core/sandbox/{ => utils}/debug.py | 0 api/core/sandbox/{ => utils}/encryption.py | 0 api/core/virtual_environment/__base/exec.py | 24 ++ .../virtual_environment/__base/helpers.py | 149 ++++++++++ .../__base/virtual_environment.py | 39 --- api/core/workflow/nodes/command/node.py | 56 ++-- api/models/app_asset.py | 3 +- .../sandbox/sandbox_provider_service.py | 2 +- .../core/virtual_environment/test_helpers.py | 264 ++++++++++++++++++ .../test_local_without_isolation.py | 5 +- 19 files changed, 553 insertions(+), 149 deletions(-) create mode 100644 api/core/sandbox/bash/__init__.py rename api/core/sandbox/{ 
=> bash}/bash_tool.py (69%) rename api/core/sandbox/{ => bash}/dify_cli.py (100%) create mode 100644 api/core/sandbox/initializer/__init__.py rename api/core/sandbox/{ => initializer}/initializer.py (54%) create mode 100644 api/core/sandbox/utils/__init__.py rename api/core/sandbox/{ => utils}/debug.py (100%) rename api/core/sandbox/{ => utils}/encryption.py (100%) create mode 100644 api/core/virtual_environment/__base/helpers.py create mode 100644 api/tests/unit_tests/core/virtual_environment/test_helpers.py diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py index 8718130175..5f3f14fcb0 100644 --- a/api/core/sandbox/__init__.py +++ b/api/core/sandbox/__init__.py @@ -1,17 +1,17 @@ -from core.sandbox.bash_tool import SandboxBashTool -from core.sandbox.constants import ( - DIFY_CLI_CONFIG_PATH, - DIFY_CLI_PATH, - DIFY_CLI_PATH_PATTERN, -) -from core.sandbox.dify_cli import ( +from core.sandbox.bash.bash_tool import SandboxBashTool +from core.sandbox.bash.dify_cli import ( DifyCliBinary, DifyCliConfig, DifyCliEnvConfig, DifyCliLocator, DifyCliToolConfig, ) -from core.sandbox.initializer import DifyCliInitializer, SandboxInitializer +from core.sandbox.constants import ( + DIFY_CLI_CONFIG_PATH, + DIFY_CLI_PATH, + DIFY_CLI_PATH_PATTERN, +) +from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer from core.sandbox.session import SandboxSession __all__ = [ diff --git a/api/core/sandbox/bash/__init__.py b/api/core/sandbox/bash/__init__.py new file mode 100644 index 0000000000..3e0f59e1bf --- /dev/null +++ b/api/core/sandbox/bash/__init__.py @@ -0,0 +1,17 @@ +from core.sandbox.bash.bash_tool import SandboxBashTool +from core.sandbox.bash.dify_cli import ( + DifyCliBinary, + DifyCliConfig, + DifyCliEnvConfig, + DifyCliLocator, + DifyCliToolConfig, +) + +__all__ = [ + "DifyCliBinary", + "DifyCliConfig", + "DifyCliEnvConfig", + "DifyCliLocator", + "DifyCliToolConfig", + "SandboxBashTool", +] diff --git 
a/api/core/sandbox/bash_tool.py b/api/core/sandbox/bash/bash_tool.py similarity index 69% rename from api/core/sandbox/bash_tool.py rename to api/core/sandbox/bash/bash_tool.py index 32c92a22e5..431ebb59f7 100644 --- a/api/core/sandbox/bash_tool.py +++ b/api/core/sandbox/bash/bash_tool.py @@ -1,7 +1,7 @@ from collections.abc import Generator from typing import Any -from core.sandbox.debug import sandbox_debug +from core.sandbox.utils.debug import sandbox_debug from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject @@ -13,6 +13,7 @@ from core.tools.entities.tool_entities import ( ToolParameter, ToolProviderType, ) +from core.virtual_environment.__base.helpers import submit_command, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment COMMAND_TIMEOUT_SECONDS = 60 @@ -66,31 +67,29 @@ class SandboxBashTool(Tool): yield self.create_text_message("Error: No command provided") return - connection_handle = self._sandbox.establish_connection() try: - cmd_list = ["bash", "-c", command] + with with_connection(self._sandbox) as conn: + cmd_list = ["bash", "-c", command] - sandbox_debug("bash_tool", "cmd_list", cmd_list) - future = self._sandbox.run_command(connection_handle, cmd_list) - timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None - result = future.result(timeout=timeout) + sandbox_debug("bash_tool", "cmd_list", cmd_list) + future = submit_command(self._sandbox, conn, cmd_list) + timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None + result = future.result(timeout=timeout) - stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else "" - stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else "" - exit_code = result.exit_code + stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else "" + stderr = 
result.stderr.decode("utf-8", errors="replace") if result.stderr else "" + exit_code = result.exit_code - output_parts: list[str] = [] - if stdout: - output_parts.append(f"\n{stdout}") - if stderr: - output_parts.append(f"\n{stderr}") - output_parts.append(f"\nCommand exited with code {exit_code}") + output_parts: list[str] = [] + if stdout: + output_parts.append(f"\n{stdout}") + if stderr: + output_parts.append(f"\n{stderr}") + output_parts.append(f"\nCommand exited with code {exit_code}") - yield self.create_text_message("\n".join(output_parts)) + yield self.create_text_message("\n".join(output_parts)) except TimeoutError: yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s") except Exception as e: yield self.create_text_message(f"Error: {e!s}") - finally: - self._sandbox.release_connection(connection_handle) diff --git a/api/core/sandbox/dify_cli.py b/api/core/sandbox/bash/dify_cli.py similarity index 100% rename from api/core/sandbox/dify_cli.py rename to api/core/sandbox/bash/dify_cli.py diff --git a/api/core/sandbox/initializer/__init__.py b/api/core/sandbox/initializer/__init__.py new file mode 100644 index 0000000000..258d7fafc5 --- /dev/null +++ b/api/core/sandbox/initializer/__init__.py @@ -0,0 +1,6 @@ +from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer + +__all__ = [ + "DifyCliInitializer", + "SandboxInitializer", +] diff --git a/api/core/sandbox/initializer.py b/api/core/sandbox/initializer/initializer.py similarity index 54% rename from api/core/sandbox/initializer.py rename to api/core/sandbox/initializer/initializer.py index f1a6f4f343..50535faac7 100644 --- a/api/core/sandbox/initializer.py +++ b/api/core/sandbox/initializer/initializer.py @@ -3,8 +3,9 @@ from abc import ABC, abstractmethod from io import BytesIO from pathlib import Path +from core.sandbox.bash.dify_cli import DifyCliLocator from core.sandbox.constants import DIFY_CLI_PATH -from core.sandbox.dify_cli import 
DifyCliLocator +from core.virtual_environment.__base.helpers import execute from core.virtual_environment.__base.virtual_environment import VirtualEnvironment logger = logging.getLogger(__name__) @@ -23,14 +24,10 @@ class DifyCliInitializer(SandboxInitializer): binary = self._locator.resolve(env.metadata.os, env.metadata.arch) env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes())) - connection_handle = env.establish_connection() - try: - future = env.run_command(connection_handle, ["chmod", "+x", DIFY_CLI_PATH]) - result = future.result(timeout=10) - if result.exit_code not in (0, None): - stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else "" - raise RuntimeError(f"Failed to mark dify CLI as executable: {stderr}") - - logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH) - finally: - env.release_connection(connection_handle) + execute( + env, + ["chmod", "+x", DIFY_CLI_PATH], + timeout=10, + error_message="Failed to mark dify CLI as executable", + ) + logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH) diff --git a/api/core/sandbox/session.py b/api/core/sandbox/session.py index 9ed2b0589a..4a4fe0a8b8 100644 --- a/api/core/sandbox/session.py +++ b/api/core/sandbox/session.py @@ -5,13 +5,14 @@ import logging from io import BytesIO from types import TracebackType -from core.sandbox.bash_tool import SandboxBashTool +from core.sandbox.bash.bash_tool import SandboxBashTool +from core.sandbox.bash.dify_cli import DifyCliConfig from core.sandbox.constants import DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH -from core.sandbox.debug import sandbox_debug -from core.sandbox.dify_cli import DifyCliConfig from core.sandbox.manager import SandboxManager +from core.sandbox.utils.debug import sandbox_debug from core.session.cli_api import CliApiSessionManager from core.tools.__base.tool import Tool +from core.virtual_environment.__base.helpers import execute from core.virtual_environment.__base.virtual_environment import 
VirtualEnvironment logger = logging.getLogger(__name__) @@ -50,14 +51,12 @@ class SandboxSession: sandbox_debug("sandbox", "config_json", config_json) sandbox.upload_file(DIFY_CLI_CONFIG_PATH, BytesIO(config_json.encode("utf-8"))) - connection_handle = sandbox.establish_connection() - try: - future = sandbox.run_command(connection_handle, [DIFY_CLI_PATH, "init"]) - result = future.result(timeout=30) - if result.is_error: - raise RuntimeError(f"Failed to initialize Dify CLI in sandbox: {result.error_message}") - finally: - sandbox.release_connection(connection_handle) + execute( + sandbox, + [DIFY_CLI_PATH, "init"], + timeout=30, + error_message="Failed to initialize Dify CLI in sandbox", + ) except Exception: CliApiSessionManager().delete(session.id) diff --git a/api/core/sandbox/storage/archive_storage.py b/api/core/sandbox/storage/archive_storage.py index 1393c8962b..34ddc18fef 100644 --- a/api/core/sandbox/storage/archive_storage.py +++ b/api/core/sandbox/storage/archive_storage.py @@ -2,6 +2,7 @@ import logging from io import BytesIO from core.sandbox.storage.sandbox_storage import SandboxStorage +from core.virtual_environment.__base.helpers import try_execute from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from extensions.ext_storage import Storage @@ -29,38 +30,25 @@ class ArchiveSandboxStorage(SandboxStorage): archive_data = self._storage.load_once(self._storage_key) sandbox.upload_file(ARCHIVE_NAME, BytesIO(archive_data)) - connection = sandbox.establish_connection() - try: - future = sandbox.run_command(connection, ["tar", "-xzf", ARCHIVE_NAME]) - result = future.result(timeout=60) - if result.is_error: - logger.error("Failed to extract archive: %s", result.error_message) - return False - finally: - sandbox.release_connection(connection) + result = try_execute(sandbox, ["tar", "-xzf", ARCHIVE_NAME], timeout=60) + if result.is_error: + logger.error("Failed to extract archive: %s", result.error_message) + return False - 
connection = sandbox.establish_connection() - try: - sandbox.run_command(connection, ["rm", ARCHIVE_NAME]).result(timeout=10) - finally: - sandbox.release_connection(connection) + try_execute(sandbox, ["rm", ARCHIVE_NAME], timeout=10) logger.info("Mounted archive for sandbox %s", self._sandbox_id) return True def unmount(self, sandbox: VirtualEnvironment) -> bool: - connection = sandbox.establish_connection() - try: - future = sandbox.run_command( - connection, - ["tar", "-czf", ARCHIVE_NAME, "-C", WORKSPACE_DIR, "."], - ) - result = future.result(timeout=120) - if result.is_error: - logger.error("Failed to create archive: %s", result.error_message) - return False - finally: - sandbox.release_connection(connection) + result = try_execute( + sandbox, + ["tar", "-czf", ARCHIVE_NAME, "-C", WORKSPACE_DIR, "."], + timeout=120, + ) + if result.is_error: + logger.error("Failed to create archive: %s", result.error_message) + return False archive_content = sandbox.download_file(ARCHIVE_NAME) self._storage.save(self._storage_key, archive_content.getvalue()) diff --git a/api/core/sandbox/utils/__init__.py b/api/core/sandbox/utils/__init__.py new file mode 100644 index 0000000000..c1d71c108c --- /dev/null +++ b/api/core/sandbox/utils/__init__.py @@ -0,0 +1,2 @@ +# Sandbox utilities +# Connection helpers have been moved to core.virtual_environment.helpers diff --git a/api/core/sandbox/debug.py b/api/core/sandbox/utils/debug.py similarity index 100% rename from api/core/sandbox/debug.py rename to api/core/sandbox/utils/debug.py diff --git a/api/core/sandbox/encryption.py b/api/core/sandbox/utils/encryption.py similarity index 100% rename from api/core/sandbox/encryption.py rename to api/core/sandbox/utils/encryption.py diff --git a/api/core/virtual_environment/__base/exec.py b/api/core/virtual_environment/__base/exec.py index 2ec420d84b..d523556e7d 100644 --- a/api/core/virtual_environment/__base/exec.py +++ b/api/core/virtual_environment/__base/exec.py @@ -1,3 +1,11 @@ +from 
__future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from core.virtual_environment.__base.entities import CommandResult + + class ArchNotSupportedError(Exception): """Exception raised when the architecture is not supported.""" @@ -20,3 +28,19 @@ class SandboxConfigValidationError(ValueError): """Exception raised when sandbox configuration validation fails.""" pass + + +class CommandExecutionError(Exception): + """Raised when a command execution fails.""" + + def __init__(self, message: str, result: CommandResult): + super().__init__(message) + self.result = result + + @property + def exit_code(self) -> int | None: + return self.result.exit_code + + @property + def stderr(self) -> bytes: + return self.result.stderr diff --git a/api/core/virtual_environment/__base/helpers.py b/api/core/virtual_environment/__base/helpers.py new file mode 100644 index 0000000000..808f8a594f --- /dev/null +++ b/api/core/virtual_environment/__base/helpers.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import contextlib +from collections.abc import Generator, Mapping +from contextlib import contextmanager +from functools import partial + +from core.virtual_environment.__base.command_future import CommandFuture +from core.virtual_environment.__base.entities import CommandResult, ConnectionHandle +from core.virtual_environment.__base.exec import CommandExecutionError +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + + +@contextmanager +def with_connection(env: VirtualEnvironment) -> Generator[ConnectionHandle, None, None]: + """Context manager for VirtualEnvironment connection lifecycle. + + Automatically establishes and releases connection handles. 
+ + Usage: + with with_connection(env) as conn: + future = submit_command(env, conn, ["echo", "hello"]) + result = future.result(timeout=10) + """ + connection_handle = env.establish_connection() + try: + yield connection_handle + finally: + with contextlib.suppress(Exception): + env.release_connection(connection_handle) + + +def submit_command( + env: VirtualEnvironment, + connection: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + *, + cwd: str | None = None, +) -> CommandFuture: + """Execute a command and return a Future for the result. + + High-level interface that handles IO draining internally. + For streaming output, use env.execute_command() instead. + + Args: + env: The virtual environment to execute the command in. + connection: The connection handle. + command: Command as list of strings. + environments: Environment variables. + cwd: Working directory for the command. If None, uses the provider's default. + + Returns: + CommandFuture that can be used to get result with timeout or cancel. 
+ + Example: + with with_connection(env) as conn: + result = submit_command(env, conn, ["ls", "-la"]).result(timeout=30) + """ + pid, stdin_transport, stdout_transport, stderr_transport = env.execute_command( + connection, command, environments, cwd + ) + + return CommandFuture( + pid=pid, + stdin_transport=stdin_transport, + stdout_transport=stdout_transport, + stderr_transport=stderr_transport, + poll_status=partial(env.get_command_status, connection, pid), + ) + + +def _execute_with_connection( + env: VirtualEnvironment, + conn: ConnectionHandle, + command: list[str], + timeout: float | None, + cwd: str | None, +) -> CommandResult: + """Internal helper to execute command with given connection.""" + future = submit_command(env, conn, command, cwd=cwd) + return future.result(timeout=timeout) + + +def execute( + env: VirtualEnvironment, + command: list[str], + *, + timeout: float | None = 30, + cwd: str | None = None, + error_message: str = "Command failed", + connection: ConnectionHandle | None = None, +) -> CommandResult: + """Execute a command with automatic connection management. + + Raises CommandExecutionError if the command fails (non-zero exit code). + + Args: + env: The virtual environment to execute the command in. + command: The command to execute as a list of strings. + timeout: Maximum time to wait for the command to complete (seconds). + cwd: Working directory for the command. + error_message: Custom error message prefix for failures. + connection: Optional connection handle to reuse. If None, creates and releases a new connection. + + Returns: + CommandResult on success. + + Raises: + CommandExecutionError: If the command fails. 
+ """ + if connection is not None: + result = _execute_with_connection(env, connection, command, timeout, cwd) + else: + with with_connection(env) as conn: + result = _execute_with_connection(env, conn, command, timeout, cwd) + + if result.is_error: + raise CommandExecutionError(f"{error_message}: {result.error_message}", result) + return result + + +def try_execute( + env: VirtualEnvironment, + command: list[str], + *, + timeout: float | None = 30, + cwd: str | None = None, + connection: ConnectionHandle | None = None, +) -> CommandResult: + """Execute a command with automatic connection management. + + Does not raise on failure - returns the result for caller to handle. + + Args: + env: The virtual environment to execute the command in. + command: The command to execute as a list of strings. + timeout: Maximum time to wait for the command to complete (seconds). + cwd: Working directory for the command. + connection: Optional connection handle to reuse. If None, creates and releases a new connection. + + Returns: + CommandResult containing stdout, stderr, and exit_code. 
+ """ + if connection is not None: + return _execute_with_connection(env, connection, command, timeout, cwd) + + with with_connection(env) as conn: + return _execute_with_connection(env, conn, command, timeout, cwd) diff --git a/api/core/virtual_environment/__base/virtual_environment.py b/api/core/virtual_environment/__base/virtual_environment.py index b74230f8c5..6fe4f0f081 100644 --- a/api/core/virtual_environment/__base/virtual_environment.py +++ b/api/core/virtual_environment/__base/virtual_environment.py @@ -1,10 +1,8 @@ from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence -from functools import partial from io import BytesIO from typing import Any -from core.virtual_environment.__base.command_future import CommandFuture from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser @@ -176,40 +174,3 @@ class VirtualEnvironment(ABC): Returns: CommandStatus: The status of the command execution. """ - - def run_command( - self, - connection_handle: ConnectionHandle, - command: list[str], - environments: Mapping[str, str] | None = None, - cwd: str | None = None, - ) -> CommandFuture: - """ - Execute a command and return a Future for the result. - - High-level interface that handles IO draining internally. - For streaming output, use execute_command() instead. - - Args: - connection_handle: The connection handle. - command: Command as list of strings. - environments: Environment variables. - cwd: Working directory for the command. If None, uses the provider's default. - - Returns: - CommandFuture that can be used to get result with timeout or cancel. 
- - Example: - result = env.run_command(handle, ["ls", "-la"]).result(timeout=30) - """ - pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command( - connection_handle, command, environments, cwd - ) - - return CommandFuture( - pid=pid, - stdin_transport=stdin_transport, - stdout_transport=stdout_transport, - stderr_transport=stderr_transport, - poll_status=partial(self.get_command_status, connection_handle, pid), - ) diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py index a5d5347d0b..39327aab65 100644 --- a/api/core/workflow/nodes/command/node.py +++ b/api/core/workflow/nodes/command/node.py @@ -1,12 +1,12 @@ -import contextlib import logging import shlex from collections.abc import Mapping, Sequence from typing import Any -from core.sandbox.debug import sandbox_debug from core.sandbox.manager import SandboxManager +from core.sandbox.utils.debug import sandbox_debug from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError +from core.virtual_environment.__base.helpers import submit_command, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -80,42 +80,41 @@ class CommandNode(Node[CommandNodeData]): ) timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None - connection_handle = sandbox.establish_connection() try: - command = shlex.split(raw_command) + with with_connection(sandbox) as conn: + command = shlex.split(raw_command) - sandbox_debug("command_node", "command", command) + sandbox_debug("command_node", "command", command) - future = sandbox.run_command(connection_handle, command, cwd=working_directory) - result = future.result(timeout=timeout) + future = submit_command(sandbox, conn, command, cwd=working_directory) + result = 
future.result(timeout=timeout) - outputs: dict[str, Any] = { - "stdout": result.stdout.decode("utf-8", errors="replace"), - "stderr": result.stderr.decode("utf-8", errors="replace"), - "exit_code": result.exit_code, - "pid": result.pid, - } - process_data = {"command": command, "working_directory": working_directory} + outputs: dict[str, Any] = { + "stdout": result.stdout.decode("utf-8", errors="replace"), + "stderr": result.stderr.decode("utf-8", errors="replace"), + "exit_code": result.exit_code, + "pid": result.pid, + } + process_data = {"command": command, "working_directory": working_directory} + + if result.exit_code not in (None, 0): + stderr_text = result.stderr.decode("utf-8", errors="replace") + error_message = f"{stderr_text}\n\nCommand exited with code {result.exit_code}" + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + outputs=outputs, + process_data=process_data, + error=error_message, + error_type=CommandExecutionError.__name__, + ) - if result.exit_code not in (None, 0): - error_message = ( - f"{result.stderr.decode('utf-8', errors='replace')}\n\nCommand exited with code {result.exit_code}" - ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, process_data=process_data, - error=error_message, - error_type=CommandExecutionError.__name__, ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - process_data=process_data, - ) - except CommandTimeoutError: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -135,9 +134,6 @@ class CommandNode(Node[CommandNodeData]): error=str(e), error_type=type(e).__name__, ) - finally: - with contextlib.suppress(Exception): - sandbox.release_connection(connection_handle) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/models/app_asset.py b/api/models/app_asset.py index 70a9f095ad..2d66de9ee2 100644 --- a/api/models/app_asset.py 
+++ b/api/models/app_asset.py @@ -5,7 +5,8 @@ import sqlalchemy as sa from sqlalchemy import DateTime, String, func from sqlalchemy.orm import Mapped, mapped_column -from ..core.app.entities.app_asset_entities import AppAssetFileTree +from core.app.entities.app_asset_entities import AppAssetFileTree + from .base import Base from .types import LongText, StringUUID diff --git a/api/services/sandbox/sandbox_provider_service.py b/api/services/sandbox/sandbox_provider_service.py index ac0caac8bf..f88a53137a 100644 --- a/api/services/sandbox/sandbox_provider_service.py +++ b/api/services/sandbox/sandbox_provider_service.py @@ -19,9 +19,9 @@ from sqlalchemy.orm import Session from configs import dify_config from constants import HIDDEN_VALUE from core.entities.provider_entities import BasicProviderConfig -from core.sandbox.encryption import create_sandbox_config_encrypter, masked_config from core.sandbox.factory import VMFactory, VMType from core.sandbox.initializer import DifyCliInitializer +from core.sandbox.utils.encryption import create_sandbox_config_encrypter, masked_config from core.tools.utils.system_encryption import ( decrypt_system_params, ) diff --git a/api/tests/unit_tests/core/virtual_environment/test_helpers.py b/api/tests/unit_tests/core/virtual_environment/test_helpers.py new file mode 100644 index 0000000000..8b91ee1c95 --- /dev/null +++ b/api/tests/unit_tests/core/virtual_environment/test_helpers.py @@ -0,0 +1,264 @@ +from collections.abc import Mapping +from io import BytesIO +from typing import Any + +import pytest + +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) +from core.virtual_environment.__base.exec import CommandExecutionError +from core.virtual_environment.__base.helpers import execute, try_execute, with_connection +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.channel.exec 
import TransportEOFError +from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser + + +class MockReadTransport(TransportReadCloser): + """Mock transport that returns data once then raises EOF.""" + + def __init__(self, data: bytes): + self._data = data + self._read = False + + def read(self, n: int) -> bytes: + if self._read: + raise TransportEOFError() + self._read = True + return self._data[:n] if n < len(self._data) else self._data + + def close(self) -> None: + pass + + +class MockWriteTransport(TransportWriteCloser): + """Mock transport for stdin (no-op).""" + + def write(self, data: bytes) -> None: + pass + + def close(self) -> None: + pass + + +class FakeVirtualEnvironment(VirtualEnvironment): + """Fake virtual environment for testing connection utilities.""" + + def __init__( + self, + *, + exit_code: int | None = 0, + stdout: bytes = b"", + stderr: bytes = b"", + ): + self._exit_code = exit_code + self._stdout = stdout + self._stderr = stderr + self._connection_established = False + self._connection_released = False + self._establish_count = 0 + self._release_count = 0 + super().__init__(tenant_id="test-tenant", options={}, environments={}) + + def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata: + return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX) + + def upload_file(self, _path: str, _content: BytesIO) -> None: + raise NotImplementedError + + def download_file(self, _path: str) -> BytesIO: + raise NotImplementedError + + def list_files(self, _directory_path: str, _limit: int) -> list[FileState]: + return [] + + def establish_connection(self) -> ConnectionHandle: + self._connection_established = True + self._establish_count += 1 + return ConnectionHandle(id=f"test-conn-{self._establish_count}") + + def release_connection(self, _connection_handle: ConnectionHandle) -> None: + self._connection_released = True + self._release_count += 1 + + 
def release_environment(self) -> None: + pass + + def execute_command( + self, + _connection_handle: ConnectionHandle, + _command: list[str], + _environments: Mapping[str, str] | None = None, + _cwd: str | None = None, + ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: + """Return mock transports for testing.""" + return ( + "test-pid", + MockWriteTransport(), + MockReadTransport(self._stdout), + MockReadTransport(self._stderr), + ) + + def get_command_status(self, _connection_handle: ConnectionHandle, _pid: str) -> CommandStatus: + return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=self._exit_code) + + @classmethod + def validate(cls, _options: Mapping[str, Any]) -> None: + pass + + +class TestWithConnection: + def test_connection_established_and_released(self): + env = FakeVirtualEnvironment() + + with with_connection(env) as conn: + assert env._connection_established is True + assert conn.id == "test-conn-1" + + assert env._connection_released is True + + def test_connection_released_on_exception(self): + env = FakeVirtualEnvironment() + + with pytest.raises(ValueError): + with with_connection(env): + raise ValueError("test error") + + assert env._connection_released is True + + +class TestExecute: + def test_execute_success(self): + env = FakeVirtualEnvironment(exit_code=0, stdout=b"hello world") + + result = execute(env, ["echo", "hello"]) + + assert result.stdout == b"hello world" + assert result.exit_code == 0 + assert env._connection_released is True + + def test_execute_raises_on_nonzero_exit_code(self): + env = FakeVirtualEnvironment(exit_code=1, stderr=b"command not found") + + with pytest.raises(CommandExecutionError, match="Command failed: command not found") as exc_info: + execute(env, ["invalid-command"]) + + assert exc_info.value.exit_code == 1 + assert exc_info.value.stderr == b"command not found" + assert env._connection_released is True + + def test_execute_with_custom_error_message(self): + env 
= FakeVirtualEnvironment(exit_code=1, stderr=b"error") + + with pytest.raises(CommandExecutionError, match="Custom error: error"): + execute(env, ["cmd"], error_message="Custom error") + + def test_execute_releases_connection_on_error(self): + env = FakeVirtualEnvironment(exit_code=1, stderr=b"error") + + with pytest.raises(CommandExecutionError): + execute(env, ["cmd"]) + + assert env._connection_released is True + + +class TestTryExecute: + def test_try_execute_success(self): + env = FakeVirtualEnvironment(exit_code=0, stdout=b"output") + + result = try_execute(env, ["echo", "test"]) + + assert result.stdout == b"output" + assert result.exit_code == 0 + assert env._connection_released is True + + def test_try_execute_returns_error_result(self): + env = FakeVirtualEnvironment(exit_code=1, stderr=b"error message") + + result = try_execute(env, ["failing-command"]) + + assert result.exit_code == 1 + assert result.stderr == b"error message" + assert result.is_error is True + assert env._connection_released is True + + def test_try_execute_does_not_raise(self): + env = FakeVirtualEnvironment(exit_code=127, stderr=b"not found") + + result = try_execute(env, ["nonexistent"]) + + assert result.exit_code == 127 + assert env._connection_released is True + + +class TestConnectionReuse: + def test_execute_with_reused_connection(self): + """Test that execute reuses provided connection without creating new one.""" + env = FakeVirtualEnvironment(exit_code=0, stdout=b"output") + + with with_connection(env) as conn: + # Execute with reused connection + result = execute(env, ["cmd1"], connection=conn) + assert result.stdout == b"output" + + # Should have only established one connection (from with_connection) + assert env._establish_count == 1 + assert env._release_count == 0 # Not released yet + + # Now connection should be released + assert env._release_count == 1 + + def test_execute_without_connection_creates_new(self): + """Test that execute without connection creates and 
releases its own.""" + env = FakeVirtualEnvironment(exit_code=0, stdout=b"output") + + execute(env, ["cmd1"]) + + assert env._establish_count == 1 + assert env._release_count == 1 + + def test_multiple_executes_with_same_connection(self): + """Test multiple execute calls reusing the same connection.""" + env = FakeVirtualEnvironment(exit_code=0, stdout=b"output") + + with with_connection(env) as conn: + execute(env, ["cmd1"], connection=conn) + execute(env, ["cmd2"], connection=conn) + execute(env, ["cmd3"], connection=conn) + + # Only one connection established + assert env._establish_count == 1 + assert env._release_count == 0 + + # Released once at the end + assert env._release_count == 1 + + def test_try_execute_with_reused_connection(self): + """Test that try_execute reuses provided connection.""" + env = FakeVirtualEnvironment(exit_code=0, stdout=b"output") + + with with_connection(env) as conn: + result = try_execute(env, ["cmd1"], connection=conn) + assert result.stdout == b"output" + assert env._establish_count == 1 + assert env._release_count == 0 + + assert env._release_count == 1 + + def test_mixed_execute_and_try_execute_reuse(self): + """Test mixing execute and try_execute with same connection.""" + env = FakeVirtualEnvironment(exit_code=0, stdout=b"output") + + with with_connection(env) as conn: + execute(env, ["cmd1"], connection=conn) + try_execute(env, ["cmd2"], connection=conn) + execute(env, ["cmd3"], connection=conn) + + assert env._establish_count == 1 + + assert env._release_count == 1 diff --git a/api/tests/unit_tests/core/virtual_environment/test_local_without_isolation.py b/api/tests/unit_tests/core/virtual_environment/test_local_without_isolation.py index 63438211a8..52f1c986ee 100644 --- a/api/tests/unit_tests/core/virtual_environment/test_local_without_isolation.py +++ b/api/tests/unit_tests/core/virtual_environment/test_local_without_isolation.py @@ -3,6 +3,7 @@ from pathlib import Path import pytest +from 
core.virtual_environment.__base.helpers import submit_command from core.virtual_environment.channel.exec import TransportEOFError from core.virtual_environment.channel.transport import TransportReadCloser from core.virtual_environment.providers import local_without_isolation @@ -99,7 +100,7 @@ def test_run_command_returns_output(local_env: LocalVirtualEnvironment): local_env.upload_file("message.txt", BytesIO(b"hello")) connection = local_env.establish_connection() - result = local_env.run_command(connection, ["/bin/sh", "-c", "cat message.txt"]).result(timeout=10) + result = submit_command(local_env, connection, ["/bin/sh", "-c", "cat message.txt"]).result(timeout=10) assert result.stdout == b"hello" assert result.stderr == b"" @@ -109,7 +110,7 @@ def test_run_command_returns_output(local_env: LocalVirtualEnvironment): def test_run_command_captures_stderr(local_env: LocalVirtualEnvironment): connection = local_env.establish_connection() - result = local_env.run_command(connection, ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]).result(timeout=10) + result = submit_command(local_env, connection, ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]).result(timeout=10) assert b"OUT" in result.stdout assert b"ERR" in result.stderr