feat: grace ful close the connection (#30039)

This commit is contained in:
wangxiaolei 2025-12-23 18:56:38 +08:00 committed by GitHub
parent a3d4f4f3bd
commit b321511518
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 102 additions and 22 deletions

View File

@ -61,6 +61,7 @@ class SSETransport:
self.timeout = timeout self.timeout = timeout
self.sse_read_timeout = sse_read_timeout self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None self.endpoint_url: str | None = None
self.event_source: EventSource | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool: def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin. """Validate that the endpoint URL matches the connection origin.
@ -237,6 +238,9 @@ class SSETransport:
write_queue: WriteQueue = queue.Queue() write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue() status_queue: StatusQueue = queue.Queue()
# Store event_source for graceful shutdown
self.event_source = event_source
# Start SSE reader thread # Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue) executor.submit(self.sse_reader, event_source, read_queue, status_queue)
@ -296,6 +300,13 @@ def sse_client(
logger.exception("Error connecting to SSE endpoint") logger.exception("Error connecting to SSE endpoint")
raise raise
finally: finally:
# Close the SSE connection to unblock the reader thread
if transport.event_source is not None:
try:
transport.event_source.response.close()
except RuntimeError:
pass
# Clean up queues # Clean up queues
if read_queue: if read_queue:
read_queue.put(None) read_queue.put(None)

View File

@ -8,6 +8,7 @@ and session management.
import logging import logging
import queue import queue
import threading
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
@ -103,6 +104,9 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON, CONTENT_TYPE: JSON,
**self.headers, **self.headers,
} }
self.stop_event = threading.Event()
self._active_responses: list[httpx.Response] = []
self._lock = threading.Lock()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available.""" """Update headers with session ID if available."""
@ -111,6 +115,30 @@ class StreamableHTTPTransport:
headers[MCP_SESSION_ID] = self.session_id headers[MCP_SESSION_ID] = self.session_id
return headers return headers
def _register_response(self, response: httpx.Response):
"""Register a response for cleanup on shutdown."""
with self._lock:
self._active_responses.append(response)
def _unregister_response(self, response: httpx.Response):
"""Unregister a response after it's closed."""
with self._lock:
try:
self._active_responses.remove(response)
except ValueError as e:
logger.debug("Ignoring error during response unregister: %s", e)
def close_active_responses(self):
"""Close all active SSE connections to unblock threads."""
with self._lock:
responses_to_close = list(self._active_responses)
self._active_responses.clear()
for response in responses_to_close:
try:
response.close()
except RuntimeError as e:
logger.debug("Ignoring error during active response close: %s", e)
def _is_initialization_request(self, message: JSONRPCMessage) -> bool: def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request.""" """Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@ -195,11 +223,21 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("GET SSE connection established") logger.debug("GET SSE connection established")
for sse in event_source.iter_sse(): # Register response for cleanup
self._handle_sse_event(sse, server_to_client_queue) self._register_response(event_source.response)
try:
for sse in event_source.iter_sse():
if self.stop_event.is_set():
logger.debug("GET stream received stop signal")
break
self._handle_sse_event(sse, server_to_client_queue)
finally:
self._unregister_response(event_source.response)
except Exception as exc: except Exception as exc:
logger.debug("GET stream error (non-fatal): %s", exc) if not self.stop_event.is_set():
logger.debug("GET stream error (non-fatal): %s", exc)
def _handle_resumption_request(self, ctx: RequestContext): def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE.""" """Handle a resumption request using GET with SSE."""
@ -224,15 +262,24 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established") logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse(): # Register response for cleanup
is_complete = self._handle_sse_event( self._register_response(event_source.response)
sse,
ctx.server_to_client_queue, try:
original_request_id, for sse in event_source.iter_sse():
ctx.metadata.on_resumption_token_update if ctx.metadata else None, if self.stop_event.is_set():
) logger.debug("Resumption stream received stop signal")
if is_complete: break
break is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
finally:
self._unregister_response(event_source.response)
def _handle_post_request(self, ctx: RequestContext): def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing.""" """Handle a POST request with response processing."""
@ -295,17 +342,27 @@ class StreamableHTTPTransport:
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext): def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server.""" """Handle SSE response from the server."""
try: try:
# Register response for cleanup
self._register_response(response)
event_source = EventSource(response) event_source = EventSource(response)
for sse in event_source.iter_sse(): try:
is_complete = self._handle_sse_event( for sse in event_source.iter_sse():
sse, if self.stop_event.is_set():
ctx.server_to_client_queue, logger.debug("SSE response stream received stop signal")
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), break
) is_complete = self._handle_sse_event(
if is_complete: sse,
break ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
finally:
self._unregister_response(response)
except Exception as e: except Exception as e:
ctx.server_to_client_queue.put(e) if not self.stop_event.is_set():
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type( def _handle_unexpected_content_type(
self, self,
@ -345,6 +402,11 @@ class StreamableHTTPTransport:
""" """
while True: while True:
try: try:
# Check if we should stop
if self.stop_event.is_set():
logger.debug("Post writer received stop signal")
break
# Read message from client queue with timeout to check stop_event periodically # Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None: if session_message is None:
@ -381,7 +443,8 @@ class StreamableHTTPTransport:
except queue.Empty: except queue.Empty:
continue continue
except Exception as exc: except Exception as exc:
server_to_client_queue.put(exc) if not self.stop_event.is_set():
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client): def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request.""" """Terminate the session by sending a DELETE request."""
@ -465,6 +528,12 @@ def streamablehttp_client(
transport.get_session_id, transport.get_session_id,
) )
finally: finally:
# Set stop event to signal all threads to stop
transport.stop_event.set()
# Close all active SSE connections to unblock threads
transport.close_active_responses()
if transport.session_id and terminate_on_close: if transport.session_id and terminate_on_close:
transport.terminate_session(client) transport.terminate_session(client)