diff --git a/api/core/virtual_environment/__base/command_future.py b/api/core/virtual_environment/__base/command_future.py new file mode 100644 index 0000000000..2c9cd5b6ea --- /dev/null +++ b/api/core/virtual_environment/__base/command_future.py @@ -0,0 +1,170 @@ +import contextlib +import logging +import threading +import time +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor + +from core.virtual_environment.__base.entities import CommandResult, CommandStatus +from core.virtual_environment.__base.exec import NotSupportedOperationError +from core.virtual_environment.channel.exec import TransportEOFError +from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser + +logger = logging.getLogger(__name__) + + +class CommandTimeoutError(Exception): + pass + + +class CommandCancelledError(Exception): + pass + + +class CommandFuture: + """ + Lightweight future for command execution. + Mirrors concurrent.futures.Future API with 4 essential methods: + result(), done(), cancel(), cancelled(). + """ + + def __init__( + self, + pid: str, + stdin_transport: TransportWriteCloser, + stdout_transport: TransportReadCloser, + stderr_transport: TransportReadCloser, + poll_status: Callable[[], CommandStatus], + poll_interval: float = 0.1, + ): + self._pid = pid + self._stdin_transport = stdin_transport + self._stdout_transport = stdout_transport + self._stderr_transport = stderr_transport + self._poll_status = poll_status + self._poll_interval = poll_interval + + self._done_event = threading.Event() + self._lock = threading.Lock() + self._result: CommandResult | None = None + self._cancelled = False + self._started = False + + def result(self, timeout: float | None = None) -> CommandResult: + """ + Block until command completes and return result. + + Args: + timeout: Maximum seconds to wait. None means wait forever. + + Raises: + CommandTimeoutError: If timeout exceeded. + CommandCancelledError: If command was cancelled. + """ + self._ensure_started() + + if not self._done_event.wait(timeout): + raise CommandTimeoutError(f"Command timed out after {timeout}s") + + if self._cancelled: + raise CommandCancelledError("Command was cancelled") + + assert self._result is not None + return self._result + + def done(self) -> bool: + self._ensure_started() + return self._done_event.is_set() + + def cancel(self) -> bool: + """ + Attempt to cancel command by closing transports. + Returns True if cancelled, False if already completed. + """ + with self._lock: + if self._done_event.is_set(): + return False + self._cancelled = True + self._close_transports() + self._done_event.set() + return True + + def cancelled(self) -> bool: + return self._cancelled + + def _ensure_started(self) -> None: + with self._lock: + if not self._started: + self._started = True + thread = threading.Thread(target=self._execute, daemon=True) + thread.start() + + def _execute(self) -> None: + stdout_buf = bytearray() + stderr_buf = bytearray() + is_combined_stream = self._stdout_transport is self._stderr_transport + + try: + with ThreadPoolExecutor(max_workers=2) as executor: + stdout_future = executor.submit(self._drain_transport, self._stdout_transport, stdout_buf) + stderr_future = None + if not is_combined_stream: + stderr_future = executor.submit(self._drain_transport, self._stderr_transport, stderr_buf) + + exit_code = self._wait_for_completion() + + stdout_future.result() + if stderr_future: + stderr_future.result() + + with self._lock: + if not self._cancelled: + self._result = CommandResult( + stdout=bytes(stdout_buf), + stderr=b"" if is_combined_stream else bytes(stderr_buf), + exit_code=exit_code, + pid=self._pid, + ) + self._done_event.set() + + except Exception: + logger.exception("Command execution failed for pid %s", self._pid) + with self._lock: + if not self._cancelled: + self._result = CommandResult( + stdout=bytes(stdout_buf), + stderr=b"" if is_combined_stream else bytes(stderr_buf), + exit_code=None, + pid=self._pid, + ) + self._done_event.set() + finally: + self._close_transports() + + def _wait_for_completion(self) -> int | None: + while not self._cancelled: + try: + status = self._poll_status() + except NotSupportedOperationError: + return None + + if status.status == CommandStatus.Status.COMPLETED: + return status.exit_code + + time.sleep(self._poll_interval) + + return None + + def _drain_transport(self, transport: TransportReadCloser, buffer: bytearray) -> None: + try: + while True: + buffer.extend(transport.read(4096)) + except TransportEOFError: + pass + except Exception: + logger.exception("Failed reading transport") + + def _close_transports(self) -> None: + for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport): + with contextlib.suppress(Exception): + transport.close() diff --git a/api/core/virtual_environment/__base/entities.py b/api/core/virtual_environment/__base/entities.py index 2fc1292036..253d63d1fa 100644 --- a/api/core/virtual_environment/__base/entities.py +++ b/api/core/virtual_environment/__base/entities.py @@ -56,3 +56,14 @@ class FileState(BaseModel): path: str = Field(description="The path of the file in the virtual environment.") created_at: int = Field(description="The creation timestamp of the file.") updated_at: int = Field(description="The last modified timestamp of the file.") + + +class CommandResult(BaseModel): + """ + Result of a synchronous command execution. + """ + + stdout: bytes = Field(description="Standard output content.") + stderr: bytes = Field(description="Standard error content.") + exit_code: int | None = Field(description="Exit code of the command. None if unavailable.") + pid: str = Field(description="Process ID of the executed command.") diff --git a/api/core/virtual_environment/__base/virtual_environment.py b/api/core/virtual_environment/__base/virtual_environment.py index f006e296d6..9f6c163cc7 100644 --- a/api/core/virtual_environment/__base/virtual_environment.py +++ b/api/core/virtual_environment/__base/virtual_environment.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence +from functools import partial from io import BytesIO from typing import Any +from core.virtual_environment.__base.command_future import CommandFuture from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser @@ -144,3 +146,38 @@ class VirtualEnvironment(ABC): Returns: CommandStatus: The status of the command execution. """ + + def run_command( + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + ) -> CommandFuture: + """ + Execute a command and return a Future for the result. + + High-level interface that handles IO draining internally. + For streaming output, use execute_command() instead. + + Args: + connection_handle: The connection handle. + command: Command as list of strings. + environments: Environment variables. + + Returns: + CommandFuture that can be used to get result with timeout or cancel. + + Example: + result = env.run_command(handle, ["ls", "-la"]).result(timeout=30) + """ + pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command( + connection_handle, command, environments + ) + + return CommandFuture( + pid=pid, + stdin_transport=stdin_transport, + stdout_transport=stdout_transport, + stderr_transport=stderr_transport, + poll_status=partial(self.get_command_status, connection_handle, pid), + ) diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py index e03ee18c90..1a2c9e02a2 100644 --- a/api/core/workflow/nodes/command/node.py +++ b/api/core/workflow/nodes/command/node.py @@ -1,15 +1,11 @@ import contextlib import logging import shlex -import threading -import time from collections.abc import Mapping, Sequence from typing import Any -from core.virtual_environment.__base.exec import NotSupportedOperationError +from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -from core.virtual_environment.channel.exec import TransportEOFError -from core.virtual_environment.channel.transport import TransportReadCloser from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser @@ -17,35 +13,22 @@ from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.nodes.command.entities import CommandNodeData -from core.workflow.nodes.command.exc import CommandExecutionError, CommandTimeoutError +from core.workflow.nodes.command.exc import CommandExecutionError logger = logging.getLogger(__name__) COMMAND_NODE_TIMEOUT_SECONDS = 60 -def _drain_transport(transport: TransportReadCloser, buffer: bytearray) -> None: - try: - while True: - buffer.extend(transport.read(4096)) - except TransportEOFError: - pass - except Exception: - logger.exception("Failed reading transport") - finally: - with contextlib.suppress(Exception): - transport.close() - - class CommandNode(Node[CommandNodeData]): - """Command Node - execute shell commands in a VirtualEnvironment.""" - # FIXME: This is a temporary solution for sandbox injection from SandboxLayer. # The sandbox is dynamically attached by SandboxLayer.on_node_run_start() before # node execution and cleared by on_node_run_end(). A cleaner approach would be # to pass sandbox through GraphRuntimeState or use a proper dependency injection pattern. sandbox: VirtualEnvironment | None = None + node_type = NodeType.COMMAND + def _render_template(self, template: str) -> str: parser = VariableTemplateParser(template=template) selectors = parser.extract_variable_selectors() @@ -59,11 +42,8 @@ class CommandNode(Node[CommandNodeData]): return parser.format(inputs) - node_type = NodeType.COMMAND - @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """Get default config of node.""" return { "type": "command", "config": { @@ -91,7 +71,6 @@ class CommandNode(Node[CommandNodeData]): raw_command = self._render_template(raw_command).strip() working_directory = working_directory or None - timeout_seconds = COMMAND_NODE_TIMEOUT_SECONDS if not raw_command: return NodeRunResult( @@ -105,136 +84,52 @@ class CommandNode(Node[CommandNodeData]): shell_command = f"cd {shlex.quote(working_directory)} && {raw_command}" command = ["sh", "-lc", shell_command] - - # 0 or negative means no timeout - deadline = None - if timeout_seconds > 0: - deadline = time.monotonic() + timeout_seconds + timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None connection_handle = self.sandbox.establish_connection() - - pid = "" - stdin_transport = None - stdout_transport = None - stderr_transport = None - threads: list[threading.Thread] = [] - stdout_buf = bytearray() - stderr_buf = bytearray() - try: - pid, stdin_transport, stdout_transport, stderr_transport = self.sandbox.execute_command( - connection_handle, command - ) - - is_combined_stream = stdout_transport is stderr_transport - - stdout_thread = threading.Thread( - target=_drain_transport, - args=(stdout_transport, stdout_buf), - daemon=True, - ) - threads.append(stdout_thread) - stdout_thread.start() - - if not is_combined_stream: - stderr_thread = threading.Thread( - target=_drain_transport, - args=(stderr_transport, stderr_buf), - daemon=True, - ) - threads.append(stderr_thread) - stderr_thread.start() - - exit_code: int | None = None - - while True: - if deadline is not None and time.monotonic() > deadline: - raise CommandTimeoutError(f"Command timed out after {timeout_seconds}s") - - try: - status = self.sandbox.get_command_status(connection_handle, pid) - except NotSupportedOperationError: - break - - if status.status == status.Status.COMPLETED: - exit_code = status.exit_code - break - - time.sleep(0.1) - - # Ensure transports are fully drained. - def _join_all() -> bool: - for t in threads: - remaining = None - if deadline is not None: - remaining = max(0.0, deadline - time.monotonic()) - t.join(timeout=remaining) - if t.is_alive(): - return False - return True - - if not _join_all(): - raise CommandTimeoutError(f"Command output not drained within {timeout_seconds}s") - - stdout_text = stdout_buf.decode("utf-8", errors="replace") - stderr_text = "" if is_combined_stream else stderr_buf.decode("utf-8", errors="replace") + future = self.sandbox.run_command(connection_handle, command) + result = future.result(timeout=timeout) outputs: dict[str, Any] = { - "stdout": stdout_text, - "stderr": stderr_text, - "exit_code": exit_code, - "pid": pid, + "stdout": result.stdout.decode("utf-8", errors="replace"), + "stderr": result.stderr.decode("utf-8", errors="replace"), + "exit_code": result.exit_code, + "pid": result.pid, } + process_data = {"command": command, "working_directory": working_directory} - if exit_code not in (None, 0): + if result.exit_code not in (None, 0): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, outputs=outputs, - process_data={"command": command, "working_directory": working_directory}, - error=f"Command exited with code {exit_code}", + process_data=process_data, + error=f"Command exited with code {result.exit_code}", error_type=CommandExecutionError.__name__, ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, - process_data={"command": command, "working_directory": working_directory}, + process_data=process_data, ) - except (CommandExecutionError, CommandTimeoutError) as e: - if isinstance(e, CommandTimeoutError) and stdout_transport is not None: - for transport in (stdout_transport, stderr_transport): - if transport is None: - continue - with contextlib.suppress(Exception): - transport.close() - - for t in threads: - t.join(timeout=0.2) - + except CommandTimeoutError: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, - outputs={ - "stdout": stdout_buf.decode("utf-8", errors="replace"), - "stderr": stderr_buf.decode("utf-8", errors="replace"), - "exit_code": None, - "pid": pid, - }, - process_data={"command": command, "working_directory": working_directory}, - error=str(e), - error_type=type(e).__name__, + error=f"Command timed out after {COMMAND_NODE_TIMEOUT_SECONDS}s", + error_type=CommandTimeoutError.__name__, + ) + except CommandCancelledError: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="Command was cancelled", + error_type=CommandCancelledError.__name__, ) except Exception as e: logger.exception("Command node %s failed", self.id) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, - outputs={ - "stdout": stdout_buf.decode("utf-8", errors="replace"), - "stderr": stderr_buf.decode("utf-8", errors="replace"), - "exit_code": None, - "pid": pid, - }, - process_data={"command": command, "working_directory": working_directory}, error=str(e), error_type=type(e).__name__, ) @@ -250,8 +145,7 @@ class CommandNode(Node[CommandNodeData]): node_id: str, node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """Extract variable mappings from node data.""" - _ = graph_config # Explicitly mark as unused + _ = graph_config typed_node_data = CommandNodeData.model_validate(node_data) diff --git a/api/tests/unit_tests/core/virtual_environment/__base/test_command_future.py b/api/tests/unit_tests/core/virtual_environment/__base/test_command_future.py new file mode 100644 index 0000000000..455a64a9af --- /dev/null +++ b/api/tests/unit_tests/core/virtual_environment/__base/test_command_future.py @@ -0,0 +1,123 @@ +import threading + +import pytest + +from core.virtual_environment.__base.command_future import ( + CommandCancelledError, + CommandFuture, + CommandTimeoutError, +) +from core.virtual_environment.__base.entities import CommandStatus +from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser +from core.virtual_environment.channel.transport import NopTransportWriteCloser + + +def _make_future( + stdout: bytes = b"", + stderr: bytes = b"", + exit_code: int = 0, + delay_completion: float = 0, + close_streams: bool = True, +) -> CommandFuture: + stdout_transport = QueueTransportReadCloser() + stderr_transport = QueueTransportReadCloser() + + if stdout: + stdout_transport.get_write_handler().write(stdout) + if stderr: + stderr_transport.get_write_handler().write(stderr) + + if close_streams: + stdout_transport.close() + stderr_transport.close() + + completion_event = threading.Event() + if delay_completion == 0: + completion_event.set() + else: + threading.Timer(delay_completion, completion_event.set).start() + + def poll_status() -> CommandStatus: + if completion_event.is_set(): + return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code) + return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None) + + return CommandFuture( + pid="test-pid", + stdin_transport=NopTransportWriteCloser(), + stdout_transport=stdout_transport, + stderr_transport=stderr_transport, + poll_status=poll_status, + poll_interval=0.05, + ) + + +def test_result_returns_command_output(): + future = _make_future(stdout=b"hello\n", stderr=b"world\n", exit_code=0) + + result = future.result() + + assert result.stdout == b"hello\n" + assert result.stderr == b"world\n" + assert result.exit_code == 0 + assert result.pid == "test-pid" + + +def test_result_with_timeout_succeeds_when_command_completes_in_time(): + future = _make_future(stdout=b"ok", delay_completion=0.1) + + result = future.result(timeout=5.0) + + assert result.stdout == b"ok" + + +def test_result_raises_timeout_error_when_exceeded(): + future = _make_future(delay_completion=10.0, close_streams=False) + + with pytest.raises(CommandTimeoutError): + future.result(timeout=0.2) + + +def test_done_returns_false_while_running(): + future = _make_future(delay_completion=10.0, close_streams=False) + + assert future.done() is False + + +def test_done_returns_true_after_completion(): + future = _make_future(stdout=b"done") + + future.result() + + assert future.done() is True + + +def test_cancel_returns_true_and_sets_cancelled(): + future = _make_future(delay_completion=10.0, close_streams=False) + + assert future.cancel() is True + assert future.cancelled() is True + + +def test_cancel_returns_false_if_already_completed(): + future = _make_future(stdout=b"done") + future.result() + + assert future.cancel() is False + assert future.cancelled() is False + + +def test_result_raises_cancelled_error_after_cancel(): + future = _make_future(delay_completion=10.0, close_streams=False) + future.cancel() + + with pytest.raises(CommandCancelledError): + future.result() + + +def test_nonzero_exit_code_is_returned(): + future = _make_future(stdout=b"err", exit_code=42) + + result = future.result() + + assert result.exit_code == 42