mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 03:36:36 +08:00
feat: grace ful close the connection (#30039)
This commit is contained in:
parent
a3d4f4f3bd
commit
b321511518
@ -61,6 +61,7 @@ class SSETransport:
|
|||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.sse_read_timeout = sse_read_timeout
|
self.sse_read_timeout = sse_read_timeout
|
||||||
self.endpoint_url: str | None = None
|
self.endpoint_url: str | None = None
|
||||||
|
self.event_source: EventSource | None = None
|
||||||
|
|
||||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||||
"""Validate that the endpoint URL matches the connection origin.
|
"""Validate that the endpoint URL matches the connection origin.
|
||||||
@ -237,6 +238,9 @@ class SSETransport:
|
|||||||
write_queue: WriteQueue = queue.Queue()
|
write_queue: WriteQueue = queue.Queue()
|
||||||
status_queue: StatusQueue = queue.Queue()
|
status_queue: StatusQueue = queue.Queue()
|
||||||
|
|
||||||
|
# Store event_source for graceful shutdown
|
||||||
|
self.event_source = event_source
|
||||||
|
|
||||||
# Start SSE reader thread
|
# Start SSE reader thread
|
||||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||||
|
|
||||||
@ -296,6 +300,13 @@ def sse_client(
|
|||||||
logger.exception("Error connecting to SSE endpoint")
|
logger.exception("Error connecting to SSE endpoint")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
# Close the SSE connection to unblock the reader thread
|
||||||
|
if transport.event_source is not None:
|
||||||
|
try:
|
||||||
|
transport.event_source.response.close()
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Clean up queues
|
# Clean up queues
|
||||||
if read_queue:
|
if read_queue:
|
||||||
read_queue.put(None)
|
read_queue.put(None)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ and session management.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
|
import threading
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
@ -103,6 +104,9 @@ class StreamableHTTPTransport:
|
|||||||
CONTENT_TYPE: JSON,
|
CONTENT_TYPE: JSON,
|
||||||
**self.headers,
|
**self.headers,
|
||||||
}
|
}
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
self._active_responses: list[httpx.Response] = []
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||||
"""Update headers with session ID if available."""
|
"""Update headers with session ID if available."""
|
||||||
@ -111,6 +115,30 @@ class StreamableHTTPTransport:
|
|||||||
headers[MCP_SESSION_ID] = self.session_id
|
headers[MCP_SESSION_ID] = self.session_id
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
def _register_response(self, response: httpx.Response):
|
||||||
|
"""Register a response for cleanup on shutdown."""
|
||||||
|
with self._lock:
|
||||||
|
self._active_responses.append(response)
|
||||||
|
|
||||||
|
def _unregister_response(self, response: httpx.Response):
|
||||||
|
"""Unregister a response after it's closed."""
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
self._active_responses.remove(response)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.debug("Ignoring error during response unregister: %s", e)
|
||||||
|
|
||||||
|
def close_active_responses(self):
|
||||||
|
"""Close all active SSE connections to unblock threads."""
|
||||||
|
with self._lock:
|
||||||
|
responses_to_close = list(self._active_responses)
|
||||||
|
self._active_responses.clear()
|
||||||
|
for response in responses_to_close:
|
||||||
|
try:
|
||||||
|
response.close()
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.debug("Ignoring error during active response close: %s", e)
|
||||||
|
|
||||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||||
"""Check if the message is an initialization request."""
|
"""Check if the message is an initialization request."""
|
||||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||||
@ -195,11 +223,21 @@ class StreamableHTTPTransport:
|
|||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("GET SSE connection established")
|
logger.debug("GET SSE connection established")
|
||||||
|
|
||||||
for sse in event_source.iter_sse():
|
# Register response for cleanup
|
||||||
self._handle_sse_event(sse, server_to_client_queue)
|
self._register_response(event_source.response)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for sse in event_source.iter_sse():
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
logger.debug("GET stream received stop signal")
|
||||||
|
break
|
||||||
|
self._handle_sse_event(sse, server_to_client_queue)
|
||||||
|
finally:
|
||||||
|
self._unregister_response(event_source.response)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("GET stream error (non-fatal): %s", exc)
|
if not self.stop_event.is_set():
|
||||||
|
logger.debug("GET stream error (non-fatal): %s", exc)
|
||||||
|
|
||||||
def _handle_resumption_request(self, ctx: RequestContext):
|
def _handle_resumption_request(self, ctx: RequestContext):
|
||||||
"""Handle a resumption request using GET with SSE."""
|
"""Handle a resumption request using GET with SSE."""
|
||||||
@ -224,15 +262,24 @@ class StreamableHTTPTransport:
|
|||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("Resumption GET SSE connection established")
|
logger.debug("Resumption GET SSE connection established")
|
||||||
|
|
||||||
for sse in event_source.iter_sse():
|
# Register response for cleanup
|
||||||
is_complete = self._handle_sse_event(
|
self._register_response(event_source.response)
|
||||||
sse,
|
|
||||||
ctx.server_to_client_queue,
|
try:
|
||||||
original_request_id,
|
for sse in event_source.iter_sse():
|
||||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
if self.stop_event.is_set():
|
||||||
)
|
logger.debug("Resumption stream received stop signal")
|
||||||
if is_complete:
|
break
|
||||||
break
|
is_complete = self._handle_sse_event(
|
||||||
|
sse,
|
||||||
|
ctx.server_to_client_queue,
|
||||||
|
original_request_id,
|
||||||
|
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||||
|
)
|
||||||
|
if is_complete:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self._unregister_response(event_source.response)
|
||||||
|
|
||||||
def _handle_post_request(self, ctx: RequestContext):
|
def _handle_post_request(self, ctx: RequestContext):
|
||||||
"""Handle a POST request with response processing."""
|
"""Handle a POST request with response processing."""
|
||||||
@ -295,17 +342,27 @@ class StreamableHTTPTransport:
|
|||||||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
|
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
|
||||||
"""Handle SSE response from the server."""
|
"""Handle SSE response from the server."""
|
||||||
try:
|
try:
|
||||||
|
# Register response for cleanup
|
||||||
|
self._register_response(response)
|
||||||
|
|
||||||
event_source = EventSource(response)
|
event_source = EventSource(response)
|
||||||
for sse in event_source.iter_sse():
|
try:
|
||||||
is_complete = self._handle_sse_event(
|
for sse in event_source.iter_sse():
|
||||||
sse,
|
if self.stop_event.is_set():
|
||||||
ctx.server_to_client_queue,
|
logger.debug("SSE response stream received stop signal")
|
||||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
break
|
||||||
)
|
is_complete = self._handle_sse_event(
|
||||||
if is_complete:
|
sse,
|
||||||
break
|
ctx.server_to_client_queue,
|
||||||
|
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||||
|
)
|
||||||
|
if is_complete:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self._unregister_response(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ctx.server_to_client_queue.put(e)
|
if not self.stop_event.is_set():
|
||||||
|
ctx.server_to_client_queue.put(e)
|
||||||
|
|
||||||
def _handle_unexpected_content_type(
|
def _handle_unexpected_content_type(
|
||||||
self,
|
self,
|
||||||
@ -345,6 +402,11 @@ class StreamableHTTPTransport:
|
|||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
# Check if we should stop
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
logger.debug("Post writer received stop signal")
|
||||||
|
break
|
||||||
|
|
||||||
# Read message from client queue with timeout to check stop_event periodically
|
# Read message from client queue with timeout to check stop_event periodically
|
||||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||||
if session_message is None:
|
if session_message is None:
|
||||||
@ -381,7 +443,8 @@ class StreamableHTTPTransport:
|
|||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
server_to_client_queue.put(exc)
|
if not self.stop_event.is_set():
|
||||||
|
server_to_client_queue.put(exc)
|
||||||
|
|
||||||
def terminate_session(self, client: httpx.Client):
|
def terminate_session(self, client: httpx.Client):
|
||||||
"""Terminate the session by sending a DELETE request."""
|
"""Terminate the session by sending a DELETE request."""
|
||||||
@ -465,6 +528,12 @@ def streamablehttp_client(
|
|||||||
transport.get_session_id,
|
transport.get_session_id,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
# Set stop event to signal all threads to stop
|
||||||
|
transport.stop_event.set()
|
||||||
|
|
||||||
|
# Close all active SSE connections to unblock threads
|
||||||
|
transport.close_active_responses()
|
||||||
|
|
||||||
if transport.session_id and terminate_on_close:
|
if transport.session_id and terminate_on_close:
|
||||||
transport.terminate_session(client)
|
transport.terminate_session(client)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user