feat: future interface for easy way to use VM.execute_command

This commit is contained in:
Harry 2026-01-07 11:57:00 +08:00
parent 888be71639
commit 05c3344554
5 changed files with 367 additions and 132 deletions

View File

@ -0,0 +1,170 @@
import contextlib
import logging
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from core.virtual_environment.__base.entities import CommandResult, CommandStatus
from core.virtual_environment.__base.exec import NotSupportedOperationError
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
logger = logging.getLogger(__name__)
class CommandTimeoutError(Exception):
    """Raised by CommandFuture.result() when the wait exceeds the given timeout."""
class CommandCancelledError(Exception):
    """Raised by CommandFuture.result() when the command was cancelled."""
class CommandFuture:
    """
    Lightweight future for command execution.

    Mirrors concurrent.futures.Future API with 4 essential methods:
    result(), done(), cancel(), cancelled().

    Execution is lazy: the background worker that drains the output transports
    and polls the command status is started by the first call to result() or
    done(), not by the constructor.
    """

    def __init__(
        self,
        pid: str,
        stdin_transport: TransportWriteCloser,
        stdout_transport: TransportReadCloser,
        stderr_transport: TransportReadCloser,
        poll_status: Callable[[], CommandStatus],
        poll_interval: float = 0.1,
    ):
        """
        Args:
            pid: Identifier of the running command; copied into the result.
            stdin_transport: Write side of the command's stdin (only ever closed here).
            stdout_transport: Read side of the command's stdout.
            stderr_transport: Read side of the command's stderr. May be the very
                same object as stdout_transport for a combined stream.
            poll_status: Zero-argument callable returning the current CommandStatus.
            poll_interval: Seconds to wait between two status polls.
        """
        self._pid = pid
        self._stdin_transport = stdin_transport
        self._stdout_transport = stdout_transport
        self._stderr_transport = stderr_transport
        self._poll_status = poll_status
        self._poll_interval = poll_interval
        # Set once the future is finished, either normally, by failure or by cancel().
        self._done_event = threading.Event()
        # Guards _started, _cancelled, _result and the done transition.
        self._lock = threading.Lock()
        self._result: CommandResult | None = None
        self._cancelled = False
        self._started = False

    def result(self, timeout: float | None = None) -> CommandResult:
        """
        Block until command completes and return result.

        Args:
            timeout: Maximum seconds to wait. None means wait forever.

        Returns:
            The collected CommandResult. If execution failed internally the
            result carries the output read so far and exit_code=None.

        Raises:
            CommandTimeoutError: If timeout exceeded. The command is NOT
                cancelled by a timeout; call cancel() to release it.
            CommandCancelledError: If command was cancelled.
        """
        self._ensure_started()
        if not self._done_event.wait(timeout):
            raise CommandTimeoutError(f"Command timed out after {timeout}s")
        if self._cancelled:
            raise CommandCancelledError("Command was cancelled")
        # Invariant: a finished, non-cancelled future always has a result.
        assert self._result is not None
        return self._result

    def done(self) -> bool:
        """Return True if the command finished or was cancelled (starts execution if needed)."""
        self._ensure_started()
        return self._done_event.is_set()

    def cancel(self) -> bool:
        """
        Attempt to cancel command by closing transports.

        Returns True if cancelled, False if already completed.
        """
        with self._lock:
            if self._done_event.is_set():
                return False
            self._cancelled = True
            self._close_transports()
            self._done_event.set()
            return True

    def cancelled(self) -> bool:
        """Return True if cancel() succeeded on this future."""
        return self._cancelled

    def _ensure_started(self) -> None:
        """Start the background worker once, unless the future is already finished."""
        with self._lock:
            # A future cancelled before its first result()/done() call is already
            # done and its transports are closed; starting a worker then would
            # only read from closed transports and log spurious errors.
            if self._started or self._done_event.is_set():
                return
            self._started = True
            thread = threading.Thread(target=self._execute, daemon=True)
            thread.start()

    def _execute(self) -> None:
        """Worker body: drain output, wait for completion, publish the result."""
        stdout_buf = bytearray()
        stderr_buf = bytearray()
        is_combined_stream = self._stdout_transport is self._stderr_transport
        try:
            with ThreadPoolExecutor(max_workers=2) as executor:
                drains = [executor.submit(self._drain_transport, self._stdout_transport, stdout_buf)]
                if not is_combined_stream:
                    drains.append(executor.submit(self._drain_transport, self._stderr_transport, stderr_buf))
                try:
                    exit_code = self._wait_for_completion()
                except BaseException:
                    # Leaving the executor context waits for the drain workers.
                    # Close the transports first so their blocking reads end;
                    # otherwise a polling failure could hang this worker and
                    # the future would never be marked done.
                    self._close_transports()
                    raise
                for drain in drains:
                    drain.result()
            self._publish_result(exit_code, stdout_buf, stderr_buf, is_combined_stream)
        except Exception:
            logger.exception("Command execution failed for pid %s", self._pid)
            # Failures are reported as a result without exit code, keeping
            # whatever output was read before the failure.
            self._publish_result(None, stdout_buf, stderr_buf, is_combined_stream)
        finally:
            self._close_transports()

    def _publish_result(
        self,
        exit_code: int | None,
        stdout_buf: bytearray,
        stderr_buf: bytearray,
        is_combined_stream: bool,
    ) -> None:
        """Store the result (unless cancelled) and mark the future as done."""
        with self._lock:
            if not self._cancelled:
                self._result = CommandResult(
                    stdout=bytes(stdout_buf),
                    stderr=b"" if is_combined_stream else bytes(stderr_buf),
                    exit_code=exit_code,
                    pid=self._pid,
                )
            self._done_event.set()

    def _wait_for_completion(self) -> int | None:
        """
        Poll the command status until it is completed or the future is cancelled.

        Returns:
            The exit code reported with the COMPLETED status, or None when the
            future was cancelled or status polling is not supported.
        """
        while not self._cancelled:
            try:
                status = self._poll_status()
            except NotSupportedOperationError:
                return None
            if status.status == CommandStatus.Status.COMPLETED:
                return status.exit_code
            # While the worker runs, only cancel() sets the done event, so this
            # is a sleep that wakes up immediately on cancellation.
            if self._done_event.wait(self._poll_interval):
                break
        return None

    def _drain_transport(self, transport: TransportReadCloser, buffer: bytearray) -> None:
        """Append everything readable from transport to buffer until EOF or a read error."""
        try:
            while True:
                buffer.extend(transport.read(4096))
        except TransportEOFError:
            pass
        except Exception:
            logger.exception("Failed reading transport")

    def _close_transports(self) -> None:
        """Best-effort close of all transports; each distinct object is closed once per call."""
        seen_ids: set[int] = set()
        for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
            # stdout and stderr may be the same object for a combined stream.
            if id(transport) in seen_ids:
                continue
            seen_ids.add(id(transport))
            with contextlib.suppress(Exception):
                transport.close()

View File

@ -56,3 +56,14 @@ class FileState(BaseModel):
path: str = Field(description="The path of the file in the virtual environment.")
created_at: int = Field(description="The creation timestamp of the file.")
updated_at: int = Field(description="The last modified timestamp of the file.")
class CommandResult(BaseModel):
    """
    Captured outcome of a command: its collected output, exit code and pid.
    """

    # Raw bytes; decoding is left to the caller.
    stdout: bytes = Field(description="Standard output content.")
    stderr: bytes = Field(description="Standard error content.")
    # Required field, but None is an accepted value.
    exit_code: int | None = Field(description="Exit code of the command. None if unavailable.")
    pid: str = Field(description="Process ID of the executed command.")

View File

@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import partial
from io import BytesIO
from typing import Any
from core.virtual_environment.__base.command_future import CommandFuture
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
@ -144,3 +146,38 @@ class VirtualEnvironment(ABC):
Returns:
CommandStatus: The status of the command execution.
"""
def run_command(
    self,
    connection_handle: ConnectionHandle,
    command: list[str],
    environments: Mapping[str, str] | None = None,
) -> CommandFuture:
    """
    Launch a command and wrap it in a CommandFuture.

    Convenience layer over execute_command(): the returned future owns the
    stdin/stdout/stderr transports and polls get_command_status() for the
    exit code. Use execute_command() directly when output must be streamed.

    Args:
        connection_handle: The connection handle.
        command: Command as list of strings.
        environments: Environment variables.

    Returns:
        CommandFuture that can be used to get result with timeout or cancel.

    Example:
        result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
    """
    launched = self.execute_command(connection_handle, command, environments)
    pid, stdin_transport, stdout_transport, stderr_transport = launched
    # Bind handle and pid now so the future can poll without further arguments.
    status_poller = partial(self.get_command_status, connection_handle, pid)
    return CommandFuture(
        pid=pid,
        stdin_transport=stdin_transport,
        stdout_transport=stdout_transport,
        stderr_transport=stderr_transport,
        poll_status=status_poller,
    )

View File

@ -1,15 +1,11 @@
import contextlib
import logging
import shlex
import threading
import time
from collections.abc import Mapping, Sequence
from typing import Any
from core.virtual_environment.__base.exec import NotSupportedOperationError
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@ -17,35 +13,22 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.command.entities import CommandNodeData
from core.workflow.nodes.command.exc import CommandExecutionError, CommandTimeoutError
from core.workflow.nodes.command.exc import CommandExecutionError
logger = logging.getLogger(__name__)
COMMAND_NODE_TIMEOUT_SECONDS = 60
def _drain_transport(transport: TransportReadCloser, buffer: bytearray) -> None:
    """Read transport in 4096-byte requests into buffer until EOF, then close it.

    Read errors other than EOF are logged and end the drain; errors raised by
    close() are ignored.
    """
    try:
        # EOF is the normal way out of the otherwise endless read loop.
        with contextlib.suppress(TransportEOFError):
            while True:
                chunk = transport.read(4096)
                buffer.extend(chunk)
    except Exception:
        logger.exception("Failed reading transport")
    finally:
        try:
            transport.close()
        except Exception:
            pass
class CommandNode(Node[CommandNodeData]):
"""Command Node - execute shell commands in a VirtualEnvironment."""
# FIXME: This is a temporary solution for sandbox injection from SandboxLayer.
# The sandbox is dynamically attached by SandboxLayer.on_node_run_start() before
# node execution and cleared by on_node_run_end(). A cleaner approach would be
# to pass sandbox through GraphRuntimeState or use a proper dependency injection pattern.
sandbox: VirtualEnvironment | None = None
node_type = NodeType.COMMAND
def _render_template(self, template: str) -> str:
parser = VariableTemplateParser(template=template)
selectors = parser.extract_variable_selectors()
@ -59,11 +42,8 @@ class CommandNode(Node[CommandNodeData]):
return parser.format(inputs)
node_type = NodeType.COMMAND
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""Get default config of node."""
return {
"type": "command",
"config": {
@ -91,7 +71,6 @@ class CommandNode(Node[CommandNodeData]):
raw_command = self._render_template(raw_command).strip()
working_directory = working_directory or None
timeout_seconds = COMMAND_NODE_TIMEOUT_SECONDS
if not raw_command:
return NodeRunResult(
@ -105,136 +84,52 @@ class CommandNode(Node[CommandNodeData]):
shell_command = f"cd {shlex.quote(working_directory)} && {raw_command}"
command = ["sh", "-lc", shell_command]
# 0 or negative means no timeout
deadline = None
if timeout_seconds > 0:
deadline = time.monotonic() + timeout_seconds
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
connection_handle = self.sandbox.establish_connection()
pid = ""
stdin_transport = None
stdout_transport = None
stderr_transport = None
threads: list[threading.Thread] = []
stdout_buf = bytearray()
stderr_buf = bytearray()
try:
pid, stdin_transport, stdout_transport, stderr_transport = self.sandbox.execute_command(
connection_handle, command
)
is_combined_stream = stdout_transport is stderr_transport
stdout_thread = threading.Thread(
target=_drain_transport,
args=(stdout_transport, stdout_buf),
daemon=True,
)
threads.append(stdout_thread)
stdout_thread.start()
if not is_combined_stream:
stderr_thread = threading.Thread(
target=_drain_transport,
args=(stderr_transport, stderr_buf),
daemon=True,
)
threads.append(stderr_thread)
stderr_thread.start()
exit_code: int | None = None
while True:
if deadline is not None and time.monotonic() > deadline:
raise CommandTimeoutError(f"Command timed out after {timeout_seconds}s")
try:
status = self.sandbox.get_command_status(connection_handle, pid)
except NotSupportedOperationError:
break
if status.status == status.Status.COMPLETED:
exit_code = status.exit_code
break
time.sleep(0.1)
# Ensure transports are fully drained.
def _join_all() -> bool:
for t in threads:
remaining = None
if deadline is not None:
remaining = max(0.0, deadline - time.monotonic())
t.join(timeout=remaining)
if t.is_alive():
return False
return True
if not _join_all():
raise CommandTimeoutError(f"Command output not drained within {timeout_seconds}s")
stdout_text = stdout_buf.decode("utf-8", errors="replace")
stderr_text = "" if is_combined_stream else stderr_buf.decode("utf-8", errors="replace")
future = self.sandbox.run_command(connection_handle, command)
result = future.result(timeout=timeout)
outputs: dict[str, Any] = {
"stdout": stdout_text,
"stderr": stderr_text,
"exit_code": exit_code,
"pid": pid,
"stdout": result.stdout.decode("utf-8", errors="replace"),
"stderr": result.stderr.decode("utf-8", errors="replace"),
"exit_code": result.exit_code,
"pid": result.pid,
}
process_data = {"command": command, "working_directory": working_directory}
if exit_code not in (None, 0):
if result.exit_code not in (None, 0):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs=outputs,
process_data={"command": command, "working_directory": working_directory},
error=f"Command exited with code {exit_code}",
process_data=process_data,
error=f"Command exited with code {result.exit_code}",
error_type=CommandExecutionError.__name__,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
process_data={"command": command, "working_directory": working_directory},
process_data=process_data,
)
except (CommandExecutionError, CommandTimeoutError) as e:
if isinstance(e, CommandTimeoutError) and stdout_transport is not None:
for transport in (stdout_transport, stderr_transport):
if transport is None:
continue
with contextlib.suppress(Exception):
transport.close()
for t in threads:
t.join(timeout=0.2)
except CommandTimeoutError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
"stdout": stdout_buf.decode("utf-8", errors="replace"),
"stderr": stderr_buf.decode("utf-8", errors="replace"),
"exit_code": None,
"pid": pid,
},
process_data={"command": command, "working_directory": working_directory},
error=str(e),
error_type=type(e).__name__,
error=f"Command timed out after {COMMAND_NODE_TIMEOUT_SECONDS}s",
error_type=CommandTimeoutError.__name__,
)
except CommandCancelledError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="Command was cancelled",
error_type=CommandCancelledError.__name__,
)
except Exception as e:
logger.exception("Command node %s failed", self.id)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
"stdout": stdout_buf.decode("utf-8", errors="replace"),
"stderr": stderr_buf.decode("utf-8", errors="replace"),
"exit_code": None,
"pid": pid,
},
process_data={"command": command, "working_directory": working_directory},
error=str(e),
error_type=type(e).__name__,
)
@ -250,8 +145,7 @@ class CommandNode(Node[CommandNodeData]):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""Extract variable mappings from node data."""
_ = graph_config # Explicitly mark as unused
_ = graph_config
typed_node_data = CommandNodeData.model_validate(node_data)

View File

@ -0,0 +1,123 @@
import threading
import pytest
from core.virtual_environment.__base.command_future import (
CommandCancelledError,
CommandFuture,
CommandTimeoutError,
)
from core.virtual_environment.__base.entities import CommandStatus
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import NopTransportWriteCloser
def _make_future(
    stdout: bytes = b"",
    stderr: bytes = b"",
    exit_code: int = 0,
    delay_completion: float = 0,
    close_streams: bool = True,
) -> CommandFuture:
    """
    Build a CommandFuture backed by in-memory queue transports.

    Args:
        stdout: Bytes written to the stdout transport up front.
        stderr: Bytes written to the stderr transport up front.
        exit_code: Exit code reported once the fake command is completed.
        delay_completion: Seconds until the fake command reports COMPLETED;
            0 (or less) means it is completed from the start.
        close_streams: Close both output transports right after writing.
            Pass False to simulate a command whose output streams stay open.
    """
    stdout_transport = QueueTransportReadCloser()
    stderr_transport = QueueTransportReadCloser()

    if stdout:
        stdout_transport.get_write_handler().write(stdout)
    if stderr:
        stderr_transport.get_write_handler().write(stderr)
    if close_streams:
        stdout_transport.close()
        stderr_transport.close()

    completion_event = threading.Event()
    if delay_completion <= 0:
        completion_event.set()
    else:
        timer = threading.Timer(delay_completion, completion_event.set)
        # Timer threads are non-daemon by default; a pending 10s timer would
        # otherwise keep the interpreter alive after the test session ends.
        timer.daemon = True
        timer.start()

    def poll_status() -> CommandStatus:
        if completion_event.is_set():
            return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
        return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)

    return CommandFuture(
        pid="test-pid",
        stdin_transport=NopTransportWriteCloser(),
        stdout_transport=stdout_transport,
        stderr_transport=stderr_transport,
        poll_status=poll_status,
        poll_interval=0.05,
    )
def test_result_returns_command_output():
    """result() exposes stdout, stderr, exit code and pid of a finished command."""
    outcome = _make_future(stdout=b"hello\n", stderr=b"world\n", exit_code=0).result()

    assert outcome.stdout == b"hello\n"
    assert outcome.stderr == b"world\n"
    assert outcome.exit_code == 0
    assert outcome.pid == "test-pid"
def test_result_with_timeout_succeeds_when_command_completes_in_time():
    """A command finishing well inside the timeout yields its output normally."""
    pending = _make_future(stdout=b"ok", delay_completion=0.1)

    outcome = pending.result(timeout=5.0)

    assert outcome.stdout == b"ok"
def test_result_raises_timeout_error_when_exceeded():
    """result(timeout) raises CommandTimeoutError while the command is still running."""
    future = _make_future(delay_completion=10.0, close_streams=False)
    try:
        with pytest.raises(CommandTimeoutError):
            future.result(timeout=0.2)
        # A timeout only ends the wait; it must not cancel the future.
        assert future.cancelled() is False
    finally:
        # Release the background worker and the still-open transports.
        future.cancel()
def test_done_returns_false_while_running():
    """done() is False while the command has neither completed nor been cancelled."""
    future = _make_future(delay_completion=10.0, close_streams=False)
    try:
        assert future.done() is False
    finally:
        # done() started the background worker; cancel so it does not linger.
        future.cancel()
def test_done_returns_true_after_completion():
    """done() is True once result() has returned."""
    finished = _make_future(stdout=b"done")
    finished.result()

    assert finished.done() is True
def test_cancel_returns_true_and_sets_cancelled():
    """Cancelling a still-running command succeeds and is reflected by cancelled()."""
    running = _make_future(delay_completion=10.0, close_streams=False)

    was_cancelled = running.cancel()

    assert was_cancelled is True
    assert running.cancelled() is True
def test_cancel_returns_false_if_already_completed():
    """cancel() is rejected after completion and leaves cancelled() False."""
    finished = _make_future(stdout=b"done")
    finished.result()

    was_cancelled = finished.cancel()

    assert was_cancelled is False
    assert finished.cancelled() is False
def test_result_raises_cancelled_error_after_cancel():
    """result() on a cancelled future raises CommandCancelledError."""
    cancelled_future = _make_future(delay_completion=10.0, close_streams=False)
    cancelled_future.cancel()

    with pytest.raises(CommandCancelledError):
        cancelled_future.result()
def test_nonzero_exit_code_is_returned():
    """A non-zero exit code is passed through unchanged in the result."""
    outcome = _make_future(stdout=b"err", exit_code=42).result()

    assert outcome.exit_code == 42