mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 00:33:37 +08:00
feat: future interface for easy way to use VM.execute_command
This commit is contained in:
parent
888be71639
commit
05c3344554
170
api/core/virtual_environment/__base/command_future.py
Normal file
170
api/core/virtual_environment/__base/command_future.py
Normal file
@ -0,0 +1,170 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from core.virtual_environment.__base.entities import CommandResult, CommandStatus
|
||||
from core.virtual_environment.__base.exec import NotSupportedOperationError
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommandTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CommandCancelledError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CommandFuture:
|
||||
"""
|
||||
Lightweight future for command execution.
|
||||
Mirrors concurrent.futures.Future API with 4 essential methods:
|
||||
result(), done(), cancel(), cancelled().
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pid: str,
|
||||
stdin_transport: TransportWriteCloser,
|
||||
stdout_transport: TransportReadCloser,
|
||||
stderr_transport: TransportReadCloser,
|
||||
poll_status: Callable[[], CommandStatus],
|
||||
poll_interval: float = 0.1,
|
||||
):
|
||||
self._pid = pid
|
||||
self._stdin_transport = stdin_transport
|
||||
self._stdout_transport = stdout_transport
|
||||
self._stderr_transport = stderr_transport
|
||||
self._poll_status = poll_status
|
||||
self._poll_interval = poll_interval
|
||||
|
||||
self._done_event = threading.Event()
|
||||
self._lock = threading.Lock()
|
||||
self._result: CommandResult | None = None
|
||||
self._cancelled = False
|
||||
self._started = False
|
||||
|
||||
def result(self, timeout: float | None = None) -> CommandResult:
|
||||
"""
|
||||
Block until command completes and return result.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait. None means wait forever.
|
||||
|
||||
Raises:
|
||||
CommandTimeoutError: If timeout exceeded.
|
||||
CommandCancelledError: If command was cancelled.
|
||||
"""
|
||||
self._ensure_started()
|
||||
|
||||
if not self._done_event.wait(timeout):
|
||||
raise CommandTimeoutError(f"Command timed out after {timeout}s")
|
||||
|
||||
if self._cancelled:
|
||||
raise CommandCancelledError("Command was cancelled")
|
||||
|
||||
assert self._result is not None
|
||||
return self._result
|
||||
|
||||
def done(self) -> bool:
|
||||
self._ensure_started()
|
||||
return self._done_event.is_set()
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""
|
||||
Attempt to cancel command by closing transports.
|
||||
Returns True if cancelled, False if already completed.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._done_event.is_set():
|
||||
return False
|
||||
self._cancelled = True
|
||||
self._close_transports()
|
||||
self._done_event.set()
|
||||
return True
|
||||
|
||||
def cancelled(self) -> bool:
|
||||
return self._cancelled
|
||||
|
||||
def _ensure_started(self) -> None:
|
||||
with self._lock:
|
||||
if not self._started:
|
||||
self._started = True
|
||||
thread = threading.Thread(target=self._execute, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _execute(self) -> None:
|
||||
stdout_buf = bytearray()
|
||||
stderr_buf = bytearray()
|
||||
is_combined_stream = self._stdout_transport is self._stderr_transport
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
stdout_future = executor.submit(self._drain_transport, self._stdout_transport, stdout_buf)
|
||||
stderr_future = None
|
||||
if not is_combined_stream:
|
||||
stderr_future = executor.submit(self._drain_transport, self._stderr_transport, stderr_buf)
|
||||
|
||||
exit_code = self._wait_for_completion()
|
||||
|
||||
stdout_future.result()
|
||||
if stderr_future:
|
||||
stderr_future.result()
|
||||
|
||||
with self._lock:
|
||||
if not self._cancelled:
|
||||
self._result = CommandResult(
|
||||
stdout=bytes(stdout_buf),
|
||||
stderr=b"" if is_combined_stream else bytes(stderr_buf),
|
||||
exit_code=exit_code,
|
||||
pid=self._pid,
|
||||
)
|
||||
self._done_event.set()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Command execution failed for pid %s", self._pid)
|
||||
with self._lock:
|
||||
if not self._cancelled:
|
||||
self._result = CommandResult(
|
||||
stdout=bytes(stdout_buf),
|
||||
stderr=b"" if is_combined_stream else bytes(stderr_buf),
|
||||
exit_code=None,
|
||||
pid=self._pid,
|
||||
)
|
||||
self._done_event.set()
|
||||
finally:
|
||||
self._close_transports()
|
||||
|
||||
def _wait_for_completion(self) -> int | None:
|
||||
while not self._cancelled:
|
||||
try:
|
||||
status = self._poll_status()
|
||||
except NotSupportedOperationError:
|
||||
return None
|
||||
|
||||
if status.status == CommandStatus.Status.COMPLETED:
|
||||
return status.exit_code
|
||||
|
||||
time.sleep(self._poll_interval)
|
||||
|
||||
return None
|
||||
|
||||
def _drain_transport(self, transport: TransportReadCloser, buffer: bytearray) -> None:
|
||||
try:
|
||||
while True:
|
||||
buffer.extend(transport.read(4096))
|
||||
except TransportEOFError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Failed reading transport")
|
||||
|
||||
def _close_transports(self) -> None:
|
||||
for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
|
||||
with contextlib.suppress(Exception):
|
||||
transport.close()
|
||||
@ -56,3 +56,14 @@ class FileState(BaseModel):
|
||||
path: str = Field(description="The path of the file in the virtual environment.")
|
||||
created_at: int = Field(description="The creation timestamp of the file.")
|
||||
updated_at: int = Field(description="The last modified timestamp of the file.")
|
||||
|
||||
|
||||
class CommandResult(BaseModel):
|
||||
"""
|
||||
Result of a synchronous command execution.
|
||||
"""
|
||||
|
||||
stdout: bytes = Field(description="Standard output content.")
|
||||
stderr: bytes = Field(description="Standard error content.")
|
||||
exit_code: int | None = Field(description="Exit code of the command. None if unavailable.")
|
||||
pid: str = Field(description="Process ID of the executed command.")
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.command_future import CommandFuture
|
||||
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||
|
||||
@ -144,3 +146,38 @@ class VirtualEnvironment(ABC):
|
||||
Returns:
|
||||
CommandStatus: The status of the command execution.
|
||||
"""
|
||||
|
||||
def run_command(
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
) -> CommandFuture:
|
||||
"""
|
||||
Execute a command and return a Future for the result.
|
||||
|
||||
High-level interface that handles IO draining internally.
|
||||
For streaming output, use execute_command() instead.
|
||||
|
||||
Args:
|
||||
connection_handle: The connection handle.
|
||||
command: Command as list of strings.
|
||||
environments: Environment variables.
|
||||
|
||||
Returns:
|
||||
CommandFuture that can be used to get result with timeout or cancel.
|
||||
|
||||
Example:
|
||||
result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
|
||||
"""
|
||||
pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command(
|
||||
connection_handle, command, environments
|
||||
)
|
||||
|
||||
return CommandFuture(
|
||||
pid=pid,
|
||||
stdin_transport=stdin_transport,
|
||||
stdout_transport=stdout_transport,
|
||||
stderr_transport=stderr_transport,
|
||||
poll_status=partial(self.get_command_status, connection_handle, pid),
|
||||
)
|
||||
|
||||
@ -1,15 +1,11 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import shlex
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.exec import NotSupportedOperationError
|
||||
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
@ -17,35 +13,22 @@ from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.command.entities import CommandNodeData
|
||||
from core.workflow.nodes.command.exc import CommandExecutionError, CommandTimeoutError
|
||||
from core.workflow.nodes.command.exc import CommandExecutionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMMAND_NODE_TIMEOUT_SECONDS = 60
|
||||
|
||||
|
||||
def _drain_transport(transport: TransportReadCloser, buffer: bytearray) -> None:
|
||||
try:
|
||||
while True:
|
||||
buffer.extend(transport.read(4096))
|
||||
except TransportEOFError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Failed reading transport")
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
transport.close()
|
||||
|
||||
|
||||
class CommandNode(Node[CommandNodeData]):
|
||||
"""Command Node - execute shell commands in a VirtualEnvironment."""
|
||||
|
||||
# FIXME: This is a temporary solution for sandbox injection from SandboxLayer.
|
||||
# The sandbox is dynamically attached by SandboxLayer.on_node_run_start() before
|
||||
# node execution and cleared by on_node_run_end(). A cleaner approach would be
|
||||
# to pass sandbox through GraphRuntimeState or use a proper dependency injection pattern.
|
||||
sandbox: VirtualEnvironment | None = None
|
||||
|
||||
node_type = NodeType.COMMAND
|
||||
|
||||
def _render_template(self, template: str) -> str:
|
||||
parser = VariableTemplateParser(template=template)
|
||||
selectors = parser.extract_variable_selectors()
|
||||
@ -59,11 +42,8 @@ class CommandNode(Node[CommandNodeData]):
|
||||
|
||||
return parser.format(inputs)
|
||||
|
||||
node_type = NodeType.COMMAND
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""Get default config of node."""
|
||||
return {
|
||||
"type": "command",
|
||||
"config": {
|
||||
@ -91,7 +71,6 @@ class CommandNode(Node[CommandNodeData]):
|
||||
raw_command = self._render_template(raw_command).strip()
|
||||
|
||||
working_directory = working_directory or None
|
||||
timeout_seconds = COMMAND_NODE_TIMEOUT_SECONDS
|
||||
|
||||
if not raw_command:
|
||||
return NodeRunResult(
|
||||
@ -105,136 +84,52 @@ class CommandNode(Node[CommandNodeData]):
|
||||
shell_command = f"cd {shlex.quote(working_directory)} && {raw_command}"
|
||||
|
||||
command = ["sh", "-lc", shell_command]
|
||||
|
||||
# 0 or negative means no timeout
|
||||
deadline = None
|
||||
if timeout_seconds > 0:
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
|
||||
|
||||
connection_handle = self.sandbox.establish_connection()
|
||||
|
||||
pid = ""
|
||||
stdin_transport = None
|
||||
stdout_transport = None
|
||||
stderr_transport = None
|
||||
threads: list[threading.Thread] = []
|
||||
stdout_buf = bytearray()
|
||||
stderr_buf = bytearray()
|
||||
|
||||
try:
|
||||
pid, stdin_transport, stdout_transport, stderr_transport = self.sandbox.execute_command(
|
||||
connection_handle, command
|
||||
)
|
||||
|
||||
is_combined_stream = stdout_transport is stderr_transport
|
||||
|
||||
stdout_thread = threading.Thread(
|
||||
target=_drain_transport,
|
||||
args=(stdout_transport, stdout_buf),
|
||||
daemon=True,
|
||||
)
|
||||
threads.append(stdout_thread)
|
||||
stdout_thread.start()
|
||||
|
||||
if not is_combined_stream:
|
||||
stderr_thread = threading.Thread(
|
||||
target=_drain_transport,
|
||||
args=(stderr_transport, stderr_buf),
|
||||
daemon=True,
|
||||
)
|
||||
threads.append(stderr_thread)
|
||||
stderr_thread.start()
|
||||
|
||||
exit_code: int | None = None
|
||||
|
||||
while True:
|
||||
if deadline is not None and time.monotonic() > deadline:
|
||||
raise CommandTimeoutError(f"Command timed out after {timeout_seconds}s")
|
||||
|
||||
try:
|
||||
status = self.sandbox.get_command_status(connection_handle, pid)
|
||||
except NotSupportedOperationError:
|
||||
break
|
||||
|
||||
if status.status == status.Status.COMPLETED:
|
||||
exit_code = status.exit_code
|
||||
break
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
# Ensure transports are fully drained.
|
||||
def _join_all() -> bool:
|
||||
for t in threads:
|
||||
remaining = None
|
||||
if deadline is not None:
|
||||
remaining = max(0.0, deadline - time.monotonic())
|
||||
t.join(timeout=remaining)
|
||||
if t.is_alive():
|
||||
return False
|
||||
return True
|
||||
|
||||
if not _join_all():
|
||||
raise CommandTimeoutError(f"Command output not drained within {timeout_seconds}s")
|
||||
|
||||
stdout_text = stdout_buf.decode("utf-8", errors="replace")
|
||||
stderr_text = "" if is_combined_stream else stderr_buf.decode("utf-8", errors="replace")
|
||||
future = self.sandbox.run_command(connection_handle, command)
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
outputs: dict[str, Any] = {
|
||||
"stdout": stdout_text,
|
||||
"stderr": stderr_text,
|
||||
"exit_code": exit_code,
|
||||
"pid": pid,
|
||||
"stdout": result.stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": result.stderr.decode("utf-8", errors="replace"),
|
||||
"exit_code": result.exit_code,
|
||||
"pid": result.pid,
|
||||
}
|
||||
process_data = {"command": command, "working_directory": working_directory}
|
||||
|
||||
if exit_code not in (None, 0):
|
||||
if result.exit_code not in (None, 0):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs=outputs,
|
||||
process_data={"command": command, "working_directory": working_directory},
|
||||
error=f"Command exited with code {exit_code}",
|
||||
process_data=process_data,
|
||||
error=f"Command exited with code {result.exit_code}",
|
||||
error_type=CommandExecutionError.__name__,
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
process_data={"command": command, "working_directory": working_directory},
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
except (CommandExecutionError, CommandTimeoutError) as e:
|
||||
if isinstance(e, CommandTimeoutError) and stdout_transport is not None:
|
||||
for transport in (stdout_transport, stderr_transport):
|
||||
if transport is None:
|
||||
continue
|
||||
with contextlib.suppress(Exception):
|
||||
transport.close()
|
||||
|
||||
for t in threads:
|
||||
t.join(timeout=0.2)
|
||||
|
||||
except CommandTimeoutError:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs={
|
||||
"stdout": stdout_buf.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr_buf.decode("utf-8", errors="replace"),
|
||||
"exit_code": None,
|
||||
"pid": pid,
|
||||
},
|
||||
process_data={"command": command, "working_directory": working_directory},
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
error=f"Command timed out after {COMMAND_NODE_TIMEOUT_SECONDS}s",
|
||||
error_type=CommandTimeoutError.__name__,
|
||||
)
|
||||
except CommandCancelledError:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error="Command was cancelled",
|
||||
error_type=CommandCancelledError.__name__,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Command node %s failed", self.id)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs={
|
||||
"stdout": stdout_buf.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr_buf.decode("utf-8", errors="replace"),
|
||||
"exit_code": None,
|
||||
"pid": pid,
|
||||
},
|
||||
process_data={"command": command, "working_directory": working_directory},
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
@ -250,8 +145,7 @@ class CommandNode(Node[CommandNodeData]):
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extract variable mappings from node data."""
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
_ = graph_config
|
||||
|
||||
typed_node_data = CommandNodeData.model_validate(node_data)
|
||||
|
||||
|
||||
@ -0,0 +1,123 @@
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.__base.command_future import (
|
||||
CommandCancelledError,
|
||||
CommandFuture,
|
||||
CommandTimeoutError,
|
||||
)
|
||||
from core.virtual_environment.__base.entities import CommandStatus
|
||||
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
|
||||
from core.virtual_environment.channel.transport import NopTransportWriteCloser
|
||||
|
||||
|
||||
def _make_future(
|
||||
stdout: bytes = b"",
|
||||
stderr: bytes = b"",
|
||||
exit_code: int = 0,
|
||||
delay_completion: float = 0,
|
||||
close_streams: bool = True,
|
||||
) -> CommandFuture:
|
||||
stdout_transport = QueueTransportReadCloser()
|
||||
stderr_transport = QueueTransportReadCloser()
|
||||
|
||||
if stdout:
|
||||
stdout_transport.get_write_handler().write(stdout)
|
||||
if stderr:
|
||||
stderr_transport.get_write_handler().write(stderr)
|
||||
|
||||
if close_streams:
|
||||
stdout_transport.close()
|
||||
stderr_transport.close()
|
||||
|
||||
completion_event = threading.Event()
|
||||
if delay_completion == 0:
|
||||
completion_event.set()
|
||||
else:
|
||||
threading.Timer(delay_completion, completion_event.set).start()
|
||||
|
||||
def poll_status() -> CommandStatus:
|
||||
if completion_event.is_set():
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
||||
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
||||
|
||||
return CommandFuture(
|
||||
pid="test-pid",
|
||||
stdin_transport=NopTransportWriteCloser(),
|
||||
stdout_transport=stdout_transport,
|
||||
stderr_transport=stderr_transport,
|
||||
poll_status=poll_status,
|
||||
poll_interval=0.05,
|
||||
)
|
||||
|
||||
|
||||
def test_result_returns_command_output():
|
||||
future = _make_future(stdout=b"hello\n", stderr=b"world\n", exit_code=0)
|
||||
|
||||
result = future.result()
|
||||
|
||||
assert result.stdout == b"hello\n"
|
||||
assert result.stderr == b"world\n"
|
||||
assert result.exit_code == 0
|
||||
assert result.pid == "test-pid"
|
||||
|
||||
|
||||
def test_result_with_timeout_succeeds_when_command_completes_in_time():
|
||||
future = _make_future(stdout=b"ok", delay_completion=0.1)
|
||||
|
||||
result = future.result(timeout=5.0)
|
||||
|
||||
assert result.stdout == b"ok"
|
||||
|
||||
|
||||
def test_result_raises_timeout_error_when_exceeded():
|
||||
future = _make_future(delay_completion=10.0, close_streams=False)
|
||||
|
||||
with pytest.raises(CommandTimeoutError):
|
||||
future.result(timeout=0.2)
|
||||
|
||||
|
||||
def test_done_returns_false_while_running():
|
||||
future = _make_future(delay_completion=10.0, close_streams=False)
|
||||
|
||||
assert future.done() is False
|
||||
|
||||
|
||||
def test_done_returns_true_after_completion():
|
||||
future = _make_future(stdout=b"done")
|
||||
|
||||
future.result()
|
||||
|
||||
assert future.done() is True
|
||||
|
||||
|
||||
def test_cancel_returns_true_and_sets_cancelled():
|
||||
future = _make_future(delay_completion=10.0, close_streams=False)
|
||||
|
||||
assert future.cancel() is True
|
||||
assert future.cancelled() is True
|
||||
|
||||
|
||||
def test_cancel_returns_false_if_already_completed():
|
||||
future = _make_future(stdout=b"done")
|
||||
future.result()
|
||||
|
||||
assert future.cancel() is False
|
||||
assert future.cancelled() is False
|
||||
|
||||
|
||||
def test_result_raises_cancelled_error_after_cancel():
|
||||
future = _make_future(delay_completion=10.0, close_streams=False)
|
||||
future.cancel()
|
||||
|
||||
with pytest.raises(CommandCancelledError):
|
||||
future.result()
|
||||
|
||||
|
||||
def test_nonzero_exit_code_is_returned():
|
||||
future = _make_future(stdout=b"err", exit_code=42)
|
||||
|
||||
result = future.result()
|
||||
|
||||
assert result.exit_code == 42
|
||||
Loading…
Reference in New Issue
Block a user