mirror of
https://github.com/langgenius/dify.git
synced 2026-05-11 06:37:13 +08:00
feat: introduce TransportEOFError for handling closed transport scenarios and update transport classes to raise it
This commit is contained in:
parent
180fdffab1
commit
2673fe05a5
4
api/core/virtual_environment/channel/exec.py
Normal file
4
api/core/virtual_environment/channel/exec.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
class TransportEOFError(Exception):
|
||||||
|
"""Exception raised when attempting to read from a closed transport."""
|
||||||
|
|
||||||
|
pass
|
||||||
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from core.virtual_environment.channel.exec import TransportEOFError
|
||||||
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
|
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
|
||||||
|
|
||||||
|
|
||||||
@ -18,10 +19,16 @@ class PipeTransport(Transport):
|
|||||||
self.w_fd = w_fd
|
self.w_fd = w_fd
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
os.write(self.w_fd, data)
|
try:
|
||||||
|
os.write(self.w_fd, data)
|
||||||
|
except OSError:
|
||||||
|
raise TransportEOFError("Pipe write error, maybe the read end is closed")
|
||||||
|
|
||||||
def read(self, n: int) -> bytes:
|
def read(self, n: int) -> bytes:
|
||||||
return os.read(self.r_fd, n)
|
data = os.read(self.r_fd, n)
|
||||||
|
if data == b"":
|
||||||
|
raise TransportEOFError("End of Pipe reached")
|
||||||
|
return data
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
os.close(self.r_fd)
|
os.close(self.r_fd)
|
||||||
@ -37,7 +44,11 @@ class PipeReadCloser(TransportReadCloser):
|
|||||||
self.r_fd = r_fd
|
self.r_fd = r_fd
|
||||||
|
|
||||||
def read(self, n: int) -> bytes:
|
def read(self, n: int) -> bytes:
|
||||||
return os.read(self.r_fd, n)
|
data = os.read(self.r_fd, n)
|
||||||
|
if data == b"":
|
||||||
|
raise TransportEOFError("End of Pipe reached")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
os.close(self.r_fd)
|
os.close(self.r_fd)
|
||||||
@ -52,7 +63,10 @@ class PipeWriteCloser(TransportWriteCloser):
|
|||||||
self.w_fd = w_fd
|
self.w_fd = w_fd
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
os.write(self.w_fd, data)
|
try:
|
||||||
|
os.write(self.w_fd, data)
|
||||||
|
except OSError:
|
||||||
|
raise TransportEOFError("Pipe write error, maybe the read end is closed")
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
os.close(self.w_fd)
|
os.close(self.w_fd)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
|
||||||
|
from core.virtual_environment.channel.exec import TransportEOFError
|
||||||
from core.virtual_environment.channel.transport import TransportReadCloser
|
from core.virtual_environment.channel.transport import TransportReadCloser
|
||||||
|
|
||||||
|
|
||||||
@ -39,6 +40,9 @@ class QueueTransportReadCloser(TransportReadCloser):
|
|||||||
Initialize the QueueTransportReadCloser with write function.
|
Initialize the QueueTransportReadCloser with write function.
|
||||||
"""
|
"""
|
||||||
self.q = Queue[bytes | None]()
|
self.q = Queue[bytes | None]()
|
||||||
|
self._read_buffer = bytearray()
|
||||||
|
self._closed = False
|
||||||
|
self._write_channel_closed = False
|
||||||
|
|
||||||
def get_write_handler(self) -> WriteHandler:
|
def get_write_handler(self) -> WriteHandler:
|
||||||
"""
|
"""
|
||||||
@ -50,17 +54,47 @@ class QueueTransportReadCloser(TransportReadCloser):
|
|||||||
"""
|
"""
|
||||||
Close the transport by putting a sentinel value in the queue.
|
Close the transport by putting a sentinel value in the queue.
|
||||||
"""
|
"""
|
||||||
|
if self._write_channel_closed:
|
||||||
|
raise TransportEOFError("Write channel already closed")
|
||||||
|
|
||||||
|
self._write_channel_closed = True
|
||||||
self.q.put(None)
|
self.q.put(None)
|
||||||
|
|
||||||
def read(self, n: int) -> bytes:
|
def read(self, n: int) -> bytes:
|
||||||
"""
|
"""
|
||||||
Read up to n bytes from the queue.
|
Read up to n bytes from the queue.
|
||||||
|
|
||||||
|
NEVER USE IT IN A MULTI-THREADED CONTEXT WITHOUT PROPER SYNCHRONIZATION.
|
||||||
"""
|
"""
|
||||||
data = bytearray()
|
if n <= 0:
|
||||||
while len(data) < n:
|
return b""
|
||||||
|
|
||||||
|
if self._closed:
|
||||||
|
raise TransportEOFError("Transport is closed")
|
||||||
|
|
||||||
|
to_return = self._drain_buffer(n)
|
||||||
|
while len(to_return) < n and not self._closed:
|
||||||
chunk = self.q.get()
|
chunk = self.q.get()
|
||||||
if chunk is None:
|
if chunk is None:
|
||||||
break
|
self._closed = True
|
||||||
data.extend(chunk)
|
raise TransportEOFError("Transport is closed")
|
||||||
|
|
||||||
return bytes(data)
|
self._read_buffer.extend(chunk)
|
||||||
|
|
||||||
|
if n - len(to_return) > 0:
|
||||||
|
# Drain the buffer if we still need more data
|
||||||
|
to_return += self._drain_buffer(n - len(to_return))
|
||||||
|
else:
|
||||||
|
# No more data needed, break
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.q.qsize() == 0:
|
||||||
|
# If no more data is available, break to return what we have
|
||||||
|
break
|
||||||
|
|
||||||
|
return to_return
|
||||||
|
|
||||||
|
def _drain_buffer(self, n: int) -> bytes:
|
||||||
|
data = bytes(self._read_buffer[:n])
|
||||||
|
del self._read_buffer[:n]
|
||||||
|
return data
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import socket
|
import socket
|
||||||
|
|
||||||
|
from core.virtual_environment.channel.exec import TransportEOFError
|
||||||
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
|
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
|
||||||
|
|
||||||
|
|
||||||
@ -12,10 +13,19 @@ class SocketTransport(Transport):
|
|||||||
self.sock = sock
|
self.sock = sock
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
self.sock.write(data)
|
try:
|
||||||
|
self.sock.write(data)
|
||||||
|
except (ConnectionResetError, BrokenPipeError):
|
||||||
|
raise TransportEOFError("Socket write error, maybe the read end is closed")
|
||||||
|
|
||||||
def read(self, n: int) -> bytes:
|
def read(self, n: int) -> bytes:
|
||||||
return self.sock.read(n)
|
try:
|
||||||
|
data = self.sock.read(n)
|
||||||
|
if data == b"":
|
||||||
|
raise TransportEOFError("End of Socket reached")
|
||||||
|
except (ConnectionResetError, BrokenPipeError):
|
||||||
|
raise TransportEOFError("Socket connection reset")
|
||||||
|
return data
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.sock.close()
|
self.sock.close()
|
||||||
@ -30,7 +40,13 @@ class SocketReadCloser(TransportReadCloser):
|
|||||||
self.sock = sock
|
self.sock = sock
|
||||||
|
|
||||||
def read(self, n: int) -> bytes:
|
def read(self, n: int) -> bytes:
|
||||||
return self.sock.read(n)
|
try:
|
||||||
|
data = self.sock.read(n)
|
||||||
|
if data == b"":
|
||||||
|
raise TransportEOFError("End of Socket reached")
|
||||||
|
return data
|
||||||
|
except (ConnectionResetError, BrokenPipeError):
|
||||||
|
raise TransportEOFError("Socket connection reset")
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.sock.close()
|
self.sock.close()
|
||||||
@ -45,7 +61,10 @@ class SocketWriteCloser(TransportWriteCloser):
|
|||||||
self.sock = sock
|
self.sock = sock
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
self.sock.write(data)
|
try:
|
||||||
|
self.sock.write(data)
|
||||||
|
except (ConnectionResetError, BrokenPipeError):
|
||||||
|
raise TransportEOFError("Socket write error, maybe the read end is closed")
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.sock.close()
|
self.sock.close()
|
||||||
|
|||||||
@ -23,6 +23,8 @@ class TransportWriter(Protocol):
|
|||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
"""
|
"""
|
||||||
Write data to the transport.
|
Write data to the transport.
|
||||||
|
|
||||||
|
Raises TransportEOFError if the transport is closed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -35,6 +37,8 @@ class TransportReader(Protocol):
|
|||||||
def read(self, n: int) -> bytes:
|
def read(self, n: int) -> bytes:
|
||||||
"""
|
"""
|
||||||
Read up to n bytes from the transport.
|
Read up to n bytes from the transport.
|
||||||
|
|
||||||
|
Raises TransportEOFError if the end of the transport is reached.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -50,8 +50,16 @@ pid, transport_stdout, transport_stderr, transport_stdin = environment.execute_c
|
|||||||
print(f"Executed command with PID: {pid}")
|
print(f"Executed command with PID: {pid}")
|
||||||
|
|
||||||
# consume stdout
|
# consume stdout
|
||||||
output = transport_stdout.read(1024)
|
# consume stdout
|
||||||
print(f"Command output: {output.decode().strip()}")
|
while True:
|
||||||
|
try:
|
||||||
|
output = transport_stdout.read(1024)
|
||||||
|
except TransportEOFError:
|
||||||
|
logger.info("End of stdout reached")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("Command output: %s", output.decode().strip())
|
||||||
|
|
||||||
|
|
||||||
environment.release_connection(connection_handle)
|
environment.release_connection(connection_handle)
|
||||||
environment.release_environment()
|
environment.release_environment()
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import shlex
|
||||||
import threading
|
import threading
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
@ -50,8 +51,16 @@ pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_c
|
|||||||
logger.info("Executed command with PID: %s", pid)
|
logger.info("Executed command with PID: %s", pid)
|
||||||
|
|
||||||
# consume stdout
|
# consume stdout
|
||||||
output = transport_stdout.read(1024)
|
# consume stdout
|
||||||
logger.info("Command output: %s", output.decode().strip())
|
while True:
|
||||||
|
try:
|
||||||
|
output = transport_stdout.read(1024)
|
||||||
|
except TransportEOFError:
|
||||||
|
logger.info("End of stdout reached")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("Command output: %s", output.decode().strip())
|
||||||
|
|
||||||
|
|
||||||
environment.release_connection(connection_handle)
|
environment.release_connection(connection_handle)
|
||||||
environment.release_environment()
|
environment.release_environment()
|
||||||
@ -204,17 +213,19 @@ class E2BEnvironment(VirtualEnvironment):
|
|||||||
""" """
|
""" """
|
||||||
stdout_stream_write_handler = stdout_stream.get_write_handler()
|
stdout_stream_write_handler = stdout_stream.get_write_handler()
|
||||||
stderr_stream_write_handler = stderr_stream.get_write_handler()
|
stderr_stream_write_handler = stderr_stream.get_write_handler()
|
||||||
sandbox.commands.run(
|
|
||||||
cmd=" ".join(command),
|
|
||||||
envs=dict(environments or {}),
|
|
||||||
# stdin=True,
|
|
||||||
on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
|
|
||||||
on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Close the write handlers to signal EOF
|
try:
|
||||||
stdout_stream.close()
|
sandbox.commands.run(
|
||||||
stderr_stream.close()
|
cmd=shlex.join(command),
|
||||||
|
envs=dict(environments or {}),
|
||||||
|
# stdin=True,
|
||||||
|
on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
|
||||||
|
on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# Close the write handlers to signal EOF
|
||||||
|
stdout_stream.close()
|
||||||
|
stderr_stream.close()
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def api_key(self) -> str:
|
def api_key(self) -> str:
|
||||||
|
|||||||
@ -14,6 +14,48 @@ from core.virtual_environment.__base.virtual_environment import VirtualEnvironme
|
|||||||
from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeWriteCloser
|
from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeWriteCloser
|
||||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||||
|
|
||||||
|
"""
|
||||||
|
USAGE:
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.virtual_environment.channel.exec import TransportEOFError
|
||||||
|
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||||
|
|
||||||
|
options: Mapping[str, Any] = {}
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
environment = LocalVirtualEnvironment(options=options)
|
||||||
|
|
||||||
|
connection_handle = environment.establish_connection()
|
||||||
|
|
||||||
|
pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command(
|
||||||
|
connection_handle,
|
||||||
|
["sh", "-lc", "for i in 1 2 3 4 5; do date '+%F %T'; sleep 1; done"],
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Executed command with PID: %s", pid)
|
||||||
|
|
||||||
|
# consume stdout
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
output = transport_stdout.read(1024)
|
||||||
|
except TransportEOFError:
|
||||||
|
logger.info("End of stdout reached")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("Command output: %s", output.decode().strip())
|
||||||
|
|
||||||
|
|
||||||
|
environment.release_connection(connection_handle)
|
||||||
|
environment.release_environment()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class LocalVirtualEnvironment(VirtualEnvironment):
|
class LocalVirtualEnvironment(VirtualEnvironment):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user