diff --git a/api/core/virtual_environment/providers/ssh_sandbox.py b/api/core/virtual_environment/providers/ssh_sandbox.py index fa3b90754b..c8e81645d9 100644 --- a/api/core/virtual_environment/providers/ssh_sandbox.py +++ b/api/core/virtual_environment/providers/ssh_sandbox.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import logging import shlex import stat import threading @@ -27,6 +28,8 @@ from core.virtual_environment.channel.exec import TransportEOFError from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser from core.virtual_environment.channel.transport import TransportWriteCloser +logger = logging.getLogger(__name__) + class _SSHStdinTransport(TransportWriteCloser): def __init__(self, channel: Any): @@ -52,6 +55,10 @@ class SSHSandboxEnvironment(VirtualEnvironment): _DEFAULT_SSH_HOST = "agentbox" _DEFAULT_SSH_PORT = 22 _DEFAULT_BASE_WORKING_PATH = "/workspace/sandboxes" + _DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS = 10 + _DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS = 30 + _DEFAULT_COMMAND_MAX_RUNTIME_SECONDS = 60 * 60 + _COMMAND_TIMEOUT_EXIT_CODE = 124 class OptionsKey(StrEnum): SSH_HOST = "ssh_host" @@ -141,6 +148,7 @@ class SSHSandboxEnvironment(VirtualEnvironment): raise RuntimeError("SSH transport is not available") channel = transport.open_session() + channel.settimeout(self._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS) channel.set_combine_stderr(False) execution_command = self._build_exec_command(command, environments, cwd) @@ -156,7 +164,7 @@ class SSHSandboxEnvironment(VirtualEnvironment): threading.Thread( target=self._consume_channel_output, - args=(pid, channel, stdout_transport, stderr_transport), + args=(pid, channel, stdout_transport, stderr_transport, self._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS), daemon=True, ).start() @@ -174,6 +182,7 @@ class SSHSandboxEnvironment(VirtualEnvironment): with self._client() as client: sftp = client.open_sftp() try: + self._set_sftp_operation_timeout(sftp) self._sftp_mkdirs(sftp, str(PurePosixPath(destination_path).parent)) with sftp.file(destination_path, "wb") as remote_file: remote_file.write(content.getvalue()) @@ -185,6 +194,7 @@ class SSHSandboxEnvironment(VirtualEnvironment): with self._client() as client: sftp = client.open_sftp() try: + self._set_sftp_operation_timeout(sftp) with sftp.file(source_path, "rb") as remote_file: return BytesIO(remote_file.read()) finally: @@ -200,6 +210,7 @@ class SSHSandboxEnvironment(VirtualEnvironment): with self._client() as client: sftp = client.open_sftp() try: + self._set_sftp_operation_timeout(sftp) pending = [root_directory] while pending and len(files) < limit: current_directory = pending.pop(0) @@ -261,8 +272,13 @@ class SSHSandboxEnvironment(VirtualEnvironment): password=password, look_for_keys=False, allow_agent=False, - timeout=10, + timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS, + banner_timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS, + auth_timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS, ) + transport = client.get_transport() + if transport is not None: + transport.set_keepalive(30) except Exception as e: with contextlib.suppress(Exception): client.close() @@ -341,9 +357,19 @@ class SSHSandboxEnvironment(VirtualEnvironment): command_body += shlex.join(command) return f"sh -lc {shlex.quote(command_body)}" - @staticmethod - def _run_command(client: Any, command: str) -> bytes: - _, stdout, stderr = client.exec_command(command) + @classmethod + def _run_command(cls, client: Any, command: str) -> bytes: + _, stdout, stderr = client.exec_command(command, timeout=cls._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS) + stdout.channel.settimeout(cls._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS) + + deadline = time.monotonic() + cls._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS + while not stdout.channel.exit_status_ready(): + if time.monotonic() >= deadline: + with contextlib.suppress(Exception): + stdout.channel.close() + raise TimeoutError(f"SSH command timed out after {cls._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS}s") + time.sleep(0.05) + exit_code = stdout.channel.recv_exit_status() stdout_data = stdout.read() stderr_data = stderr.read() @@ -360,13 +386,20 @@ class SSHSandboxEnvironment(VirtualEnvironment): channel: Any, stdout_transport: QueueTransportReadCloser, stderr_transport: QueueTransportReadCloser, + max_runtime_seconds: int, ) -> None: stdout_writer = stdout_transport.get_write_handler() stderr_writer = stderr_transport.get_write_handler() exit_code: int | None = None + started_at = time.monotonic() try: while True: + if time.monotonic() - started_at >= max_runtime_seconds: + exit_code = self._COMMAND_TIMEOUT_EXIT_CODE + stderr_writer.write(f"Command timed out after {max_runtime_seconds}s".encode()) + break + if channel.recv_ready(): stdout_writer.write(channel.recv(4096)) if channel.recv_stderr_ready(): @@ -377,6 +410,9 @@ class SSHSandboxEnvironment(VirtualEnvironment): break time.sleep(0.05) + except TimeoutError: + logger.warning("SSH channel read timed out for command %s", pid) + exit_code = self._COMMAND_TIMEOUT_EXIT_CODE finally: with contextlib.suppress(Exception): stdout_transport.close() @@ -388,6 +424,10 @@ class SSHSandboxEnvironment(VirtualEnvironment): with self._lock: self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code) + def _set_sftp_operation_timeout(self, sftp: Any) -> None: + with contextlib.suppress(Exception): + sftp.get_channel().settimeout(self._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS) + @staticmethod def _parse_arch(raw_arch: str) -> Arch: arch = raw_arch.lower()