From 2e3ac6ca4e64e3d9800f4599792484adcca5f774 Mon Sep 17 00:00:00 2001 From: Vincent Ye Date: Wed, 11 Feb 2026 14:51:18 -0800 Subject: [PATCH 1/2] remove kernel_id para from ctor. extracted gateway url generation to the creator. --- jupyter_server/gateway/managers.py | 29 ++++++++++++-------- jupyter_server/services/kernels/websocket.py | 1 + 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 270001f30..179e7b169 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -21,7 +21,7 @@ from jupyter_core.utils import ensure_async from tornado import web from tornado.escape import json_decode, json_encode, url_escape, utf8 -from traitlets import DottedObjectName, Instance, Type, default +from traitlets import DottedObjectName, Instance, Type, Unicode, default from .._tz import UTC, utcnow from ..services.kernels.kernelmanager import ( @@ -622,9 +622,10 @@ async def _async_get(self, timeout=None): except Empty: if self.response_router_finished: msg = "Response router had finished" + # TODO throw dedicated Exception for the caller to react on it. raise RuntimeError(msg) from None if monotonic() > end_time: - raise + raise TimeoutError(f"{self.channel_name} async_get timeout") from None await asyncio.sleep(0) async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]: @@ -720,10 +721,18 @@ class GatewayKernelClient(AsyncKernelClient): _iopub_channel: Optional[ChannelQueue] _shell_channel: Optional[ChannelQueue] - def __init__(self, kernel_id, **kwargs): + ws_url = Unicode( + default_value=None, + allow_none=True, + config=True, + help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value +will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) + """, + ) + + def __init__(self, **kwargs): """Initialize a gateway kernel client.""" super().__init__(**kwargs) - self.kernel_id = kernel_id self.channel_socket: Optional[websocket.WebSocket] = None self.response_router: Optional[Thread] = None self._channels_stopped = False @@ -736,17 +745,15 @@ def __init__(self, kernel_id, **kwargs): async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True): """Starts the channels for this kernel. + Please set property ws_url before calling this method. For this class, we establish a websocket connection to the destination and set up the channel-based queues on which applicable messages will be posted. """ + if self.ws_url is None: + msg = "ws_url is None. set it before call start_channels" + raise RuntimeError(msg) - ws_url = url_path_join( - GatewayClient.instance().ws_url or "", - GatewayClient.instance().kernels_endpoint, - url_escape(self.kernel_id), - "channels", - ) # Gather cert info in case where ssl is desired... ssl_options = { "ca_certs": GatewayClient.instance().ca_certs, @@ -755,7 +762,7 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont } self.channel_socket = websocket.create_connection( - ws_url, + self.ws_url, timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT, enable_multithread=True, sslopt=ssl_options, diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index a24b0539f..94834604f 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -9,6 +9,7 @@ from jupyter_server.auth.decorator import ws_authenticated from jupyter_server.base.handlers import JupyterHandler from jupyter_server.base.websocket import WebSocketMixin +from jupyter_server.saturn.connections import SaturnGatewayWebSocketConnection AUTH_RESOURCE = "kernels" From 4dfa785d19faf6b22c42c53a19926483875e68a9 Mon Sep 17 00:00:00 2001 From: Vincent Ye Date: Mon, 16 Feb 2026 13:29:39 -0800 Subject: [PATCH 2/2] reconnect to gateway add session_id to gateway url --- jupyter_server/gateway/managers.py | 167 +++++++++++++++++++++++++---- 1 file changed, 147 insertions(+), 20 deletions(-) diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 179e7b169..75d96aeeb 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -8,10 +8,12 @@ import datetime import json import os +import time from queue import Empty, Queue from threading import Thread from time import monotonic from typing import TYPE_CHECKING, Any, Optional, cast +from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit import websocket from jupyter_client.asynchronous.client import AsyncKernelClient @@ -725,11 +727,16 @@ class GatewayKernelClient(AsyncKernelClient): default_value=None, allow_none=True, config=True, - help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value + help="""The websocket url of the Kernel or Kernel Gateway server. If not provided, this value will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) """, ) + @property + def session_id(self): + """The session id.""" + return self.session.session + def __init__(self, **kwargs): """Initialize a gateway kernel client.""" super().__init__(**kwargs) @@ -738,6 +745,20 @@ def __init__(self, **kwargs): self._channels_stopped = False self._channel_queues = {} + def add_session_param(self, ws_url): + if not self.session_id: + return ws_url + + scheme, netloc, path, query_string, fragment = urlsplit(ws_url) + # Parse the existing query string into a dictionary + query_params = parse_qs(query_string) + # Update the dictionary with new parameters + query_params.update({"session": self.session_id}) + # Encode the updated dictionary back into a query string + new_query_string = urlencode(query_params, doseq=True) + # Reconstruct the URL + return urlunsplit((scheme, netloc, path, new_query_string, fragment)) + # -------------------------------------------------------------------------- # Channel management methods # -------------------------------------------------------------------------- @@ -762,7 +783,7 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont } self.channel_socket = websocket.create_connection( - self.ws_url, + self.add_session_param(self.ws_url), timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT, enable_multithread=True, sslopt=ssl_options, @@ -851,38 +872,144 @@ def control_channel(self): self._channel_queues["control"] = self._control_channel return self._control_channel + def _should_reconnect(self) -> bool: + """Determine if we should attempt to reconnect the WebSocket. + + Returns True if the kernel manager has a provisioner with is_alive=True, + or if the kernel manager's is_alive check passes. + """ + # Don't reconnect if intentionally stopped + if self._channels_stopped: + return False + + # Don't reconnect if kernel is not alive + return asyncio.run(self.is_alive()) + + def _reconnect_socket(self) -> bool: + """Attempt to reconnect the WebSocket. + + Returns True on successful reconnection, False otherwise. + """ + try: + # Close old socket if exists + if self.channel_socket: + try: + self.channel_socket.close() + except Exception: + pass + + # Create new WebSocket connection + ssl_options = { + "ca_certs": GatewayClient.instance().ca_certs, + "certfile": GatewayClient.instance().client_cert, + "keyfile": GatewayClient.instance().client_key, + } + + self.channel_socket = websocket.create_connection( + self.add_session_param(self.ws_url), + timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT, + enable_multithread=True, + sslopt=ssl_options, + ) + + # Update channel queues with new socket + # Each ChannelQueue holds a reference to channel_socket for sending messages + if self._channel_queues: + for queue in self._channel_queues.values(): + queue.channel_socket = self.channel_socket + + return True + + except Exception as e: + self.log.error(f"Failed to reconnect WebSocket: {e}") + return False + + def _mark_queues_finished(self) -> None: + """Mark all channel queues as finished. + + This notifies consumers that no more messages will be received. + """ + if self._channel_queues: + for channel_queue in self._channel_queues.values(): + channel_queue.response_router_finished = True + self.log.debug("Response router thread exiting...") + def _route_responses(self): """ Reads responses from the websocket and routes each to the appropriate channel queue based - on the message's channel. It does this for the duration of the class's lifetime until the - channels are stopped, at which time the socket is closed (unblocking the router) and - the thread terminates. If shutdown happens to occur while processing a response (unlikely), - termination takes place via the loop control boolean. + on the message's channel. Implements automatic reconnection with exponential backoff when + the WebSocket disconnects but the kernel is still alive. + + The router continues running until channels are stopped, handling disconnections gracefully + and attempting to reconnect to maintain the message flow. """ - try: - while not self._channels_stopped: + max_reconnect_attempts = 5 + base_delay = 1.0 # seconds + max_delay = 30.0 # max delay between reconnection attempts + attempt = 0 + + while not self._channels_stopped: + try: + # Check if socket needs reconnection + if self.channel_socket is None: + if not self._should_reconnect(): + self.log.info( + "WebSocket disconnected and Kernel is terminated. Exiting router." + ) + break + + # Calculate exponential backoff delay + delay = min(base_delay * (2**attempt), max_delay) + self.log.warning( + f"WebSocket disconnected. Attempting reconnection " + f"(attempt {attempt + 1}/{max_reconnect_attempts}) " + f"after {delay:.1f}s delay..." + ) + time.sleep(delay) + + # Attempt reconnection + if not self._reconnect_socket(): + attempt += 1 + if attempt >= max_reconnect_attempts: + self.log.error( + f"Max reconnection attempts ({max_reconnect_attempts}) reached. " + "Giving up on WebSocket reconnection." + ) + break + continue + + # Reset attempt counter on successful reconnection + attempt = 0 + self.log.info("WebSocket reconnected successfully") + + # Normal message routing assert self.channel_socket is not None raw_message = self.channel_socket.recv() if not raw_message: - break + # Empty message, socket might be closing + self.log.debug("Received empty message from WebSocket") + self.channel_socket = None + continue + response_message = json_decode(utf8(raw_message)) channel = response_message["channel"] assert self._channel_queues is not None self._channel_queues[channel].put_nowait(response_message) - except websocket.WebSocketConnectionClosedException: - pass # websocket closure most likely due to shut down + except websocket.WebSocketConnectionClosedException: + self.log.warning("WebSocket connection closed unexpectedly") + self.channel_socket = None + # Loop will retry connection on next iteration if not stopped - except BaseException as be: - if not self._channels_stopped: - self.log.warning(f"Unexpected exception encountered ({be})") + except Exception as e: + if not self._channels_stopped: + self.log.exception(f"Error in response router: {e}") + self.channel_socket = None + # Brief pause before retry to avoid tight error loop + time.sleep(1) - # Notify channel queues that this thread had finished and no more messages are being received - assert self._channel_queues is not None - for channel_queue in self._channel_queues.values(): - channel_queue.response_router_finished = True - - self.log.debug("Response router thread exiting...") + # Final cleanup: mark all queues as finished + self._mark_queues_finished() KernelClientABC.register(GatewayKernelClient)