diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py
index 24ca59ee45..1de1d5a073 100644
--- a/api/core/mcp/client/sse_client.py
+++ b/api/core/mcp/client/sse_client.py
@@ -61,6 +61,7 @@ class SSETransport:
         self.timeout = timeout
         self.sse_read_timeout = sse_read_timeout
         self.endpoint_url: str | None = None
+        self.event_source: EventSource | None = None
 
     def _validate_endpoint_url(self, endpoint_url: str) -> bool:
         """Validate that the endpoint URL matches the connection origin.
@@ -237,6 +238,9 @@ class SSETransport:
         write_queue: WriteQueue = queue.Queue()
         status_queue: StatusQueue = queue.Queue()
 
+        # Store event_source for graceful shutdown
+        self.event_source = event_source
+
         # Start SSE reader thread
         executor.submit(self.sse_reader, event_source, read_queue, status_queue)
 
@@ -296,6 +300,13 @@ def sse_client(
         logger.exception("Error connecting to SSE endpoint")
         raise
     finally:
+        # Close the SSE connection to unblock the reader thread
+        if transport.event_source is not None:
+            try:
+                transport.event_source.response.close()
+            except RuntimeError as exc:
+                logger.debug("Ignoring error while closing SSE connection: %s", exc)
+
         # Clean up queues
         if read_queue:
             read_queue.put(None)
diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py
index 805c16c838..f81e7cead8 100644
--- a/api/core/mcp/client/streamable_client.py
+++ b/api/core/mcp/client/streamable_client.py
@@ -8,6 +8,7 @@ and session management.
 
 import logging
 import queue
+import threading
 from collections.abc import Callable, Generator
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
@@ -103,6 +104,9 @@ class StreamableHTTPTransport:
             CONTENT_TYPE: JSON,
             **self.headers,
         }
+        self.stop_event = threading.Event()
+        self._active_responses: list[httpx.Response] = []
+        self._lock = threading.Lock()
 
     def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
         """Update headers with session ID if available."""
@@ -111,6 +115,51 @@ class StreamableHTTPTransport:
             headers[MCP_SESSION_ID] = self.session_id
         return headers
 
+    def _register_response(self, response: httpx.Response):
+        """Register a response so shutdown can close it.
+
+        If shutdown has already started, the response is closed immediately
+        instead: ``close_active_responses`` may already have drained the
+        registry, so a late registration would otherwise never be closed and
+        the thread reading from it would stay blocked.
+        """
+        with self._lock:
+            if not self.stop_event.is_set():
+                self._active_responses.append(response)
+                return
+        try:
+            response.close()
+        except Exception as e:
+            logger.debug("Ignoring error during active response close: %s", e)
+
+    def _unregister_response(self, response: httpx.Response):
+        """Remove a response from the registry once its stream has finished.
+
+        A response that is no longer registered (for example because
+        ``close_active_responses`` already cleared the registry) is ignored.
+        """
+        with self._lock:
+            try:
+                self._active_responses.remove(response)
+            except ValueError as e:
+                logger.debug("Ignoring error during response unregister: %s", e)
+
+    def close_active_responses(self):
+        """Close every registered response to unblock threads reading from them.
+
+        The registry is snapshotted and cleared under the lock; the responses
+        are closed outside it. Close errors are logged and skipped so that one
+        failure cannot prevent the remaining responses from being closed.
+        """
+        with self._lock:
+            responses_to_close = list(self._active_responses)
+            self._active_responses.clear()
+        for response in responses_to_close:
+            try:
+                response.close()
+            except Exception as e:
+                logger.debug("Ignoring error during active response close: %s", e)
+
     def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
         """Check if the message is an initialization request."""
         return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@@ -195,11 +244,21 @@ class StreamableHTTPTransport:
                 event_source.response.raise_for_status()
                 logger.debug("GET SSE connection established")
 
-                for sse in event_source.iter_sse():
-                    self._handle_sse_event(sse, server_to_client_queue)
+                # Register response for cleanup
+                self._register_response(event_source.response)
+
+                try:
+                    for sse in event_source.iter_sse():
+                        if self.stop_event.is_set():
+                            logger.debug("GET stream received stop signal")
+                            break
+                        self._handle_sse_event(sse, server_to_client_queue)
+                finally:
+                    self._unregister_response(event_source.response)
 
         except Exception as exc:
-            logger.debug("GET stream error (non-fatal): %s", exc)
+            if not self.stop_event.is_set():
+                logger.debug("GET stream error (non-fatal): %s", exc)
 
     def _handle_resumption_request(self, ctx: RequestContext):
         """Handle a resumption request using GET with SSE."""
@@ -224,15 +283,24 @@ class StreamableHTTPTransport:
             event_source.response.raise_for_status()
             logger.debug("Resumption GET SSE connection established")
 
-            for sse in event_source.iter_sse():
-                is_complete = self._handle_sse_event(
-                    sse,
-                    ctx.server_to_client_queue,
-                    original_request_id,
-                    ctx.metadata.on_resumption_token_update if ctx.metadata else None,
-                )
-                if is_complete:
-                    break
+            # Register response for cleanup
+            self._register_response(event_source.response)
+
+            try:
+                for sse in event_source.iter_sse():
+                    if self.stop_event.is_set():
+                        logger.debug("Resumption stream received stop signal")
+                        break
+                    is_complete = self._handle_sse_event(
+                        sse,
+                        ctx.server_to_client_queue,
+                        original_request_id,
+                        ctx.metadata.on_resumption_token_update if ctx.metadata else None,
+                    )
+                    if is_complete:
+                        break
+            finally:
+                self._unregister_response(event_source.response)
 
     def _handle_post_request(self, ctx: RequestContext):
         """Handle a POST request with response processing."""
@@ -295,17 +363,27 @@ class StreamableHTTPTransport:
     def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
         """Handle SSE response from the server."""
         try:
+            # Register response for cleanup
+            self._register_response(response)
+
             event_source = EventSource(response)
-            for sse in event_source.iter_sse():
-                is_complete = self._handle_sse_event(
-                    sse,
-                    ctx.server_to_client_queue,
-                    resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
-                )
-                if is_complete:
-                    break
+            try:
+                for sse in event_source.iter_sse():
+                    if self.stop_event.is_set():
+                        logger.debug("SSE response stream received stop signal")
+                        break
+                    is_complete = self._handle_sse_event(
+                        sse,
+                        ctx.server_to_client_queue,
+                        resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
+                    )
+                    if is_complete:
+                        break
+            finally:
+                self._unregister_response(response)
         except Exception as e:
-            ctx.server_to_client_queue.put(e)
+            if not self.stop_event.is_set():
+                ctx.server_to_client_queue.put(e)
 
     def _handle_unexpected_content_type(
         self,
@@ -345,6 +423,11 @@ class StreamableHTTPTransport:
         """
         while True:
             try:
+                # Check if we should stop
+                if self.stop_event.is_set():
+                    logger.debug("Post writer received stop signal")
+                    break
+
                 # Read message from client queue with timeout to check stop_event periodically
                 session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
                 if session_message is None:
@@ -381,7 +464,8 @@ class StreamableHTTPTransport:
             except queue.Empty:
                 continue
             except Exception as exc:
-                server_to_client_queue.put(exc)
+                if not self.stop_event.is_set():
+                    server_to_client_queue.put(exc)
 
     def terminate_session(self, client: httpx.Client):
         """Terminate the session by sending a DELETE request."""
@@ -465,5 +549,11 @@ def streamablehttp_client(
                     transport.get_session_id,
                 )
             finally:
+                # Set stop event to signal all threads to stop
+                transport.stop_event.set()
+
+                # Close all active SSE connections to unblock threads
+                transport.close_active_responses()
+
                 if transport.session_id and terminate_on_close:
                     transport.terminate_session(client)