diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 270001f30..75d96aeeb 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -8,10 +8,12 @@ import datetime import json import os +import time from queue import Empty, Queue from threading import Thread from time import monotonic from typing import TYPE_CHECKING, Any, Optional, cast +from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit import websocket from jupyter_client.asynchronous.client import AsyncKernelClient @@ -21,7 +23,7 @@ from jupyter_core.utils import ensure_async from tornado import web from tornado.escape import json_decode, json_encode, url_escape, utf8 -from traitlets import DottedObjectName, Instance, Type, default +from traitlets import DottedObjectName, Instance, Type, Unicode, default from .._tz import UTC, utcnow from ..services.kernels.kernelmanager import ( @@ -622,9 +624,10 @@ async def _async_get(self, timeout=None): except Empty: if self.response_router_finished: msg = "Response router had finished" + # TODO throw dedicated Exception for the caller to react on it. raise RuntimeError(msg) from None if monotonic() > end_time: - raise + raise TimeoutError(f"{self.channel_name} async_get timeout") from None await asyncio.sleep(0) async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]: @@ -720,15 +723,42 @@ class GatewayKernelClient(AsyncKernelClient): _iopub_channel: Optional[ChannelQueue] _shell_channel: Optional[ChannelQueue] - def __init__(self, kernel_id, **kwargs): + ws_url = Unicode( + default_value=None, + allow_none=True, + config=True, + help="""The websocket url of the Kernel or Kernel Gateway server. If not provided, this value +will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) + """, + ) + + @property + def session_id(self): + """The session id.""" + return self.session.session + + def __init__(self, **kwargs): """Initialize a gateway kernel client.""" super().__init__(**kwargs) - self.kernel_id = kernel_id self.channel_socket: Optional[websocket.WebSocket] = None self.response_router: Optional[Thread] = None self._channels_stopped = False self._channel_queues = {} + def add_session_param(self, ws_url): + if not self.session_id: + return ws_url + + scheme, netloc, path, query_string, fragment = urlsplit(ws_url) + # Parse the existing query string into a dictionary + query_params = parse_qs(query_string) + # Update the dictionary with new parameters + query_params.update({"session": self.session_id}) + # Encode the updated dictionary back into a query string + new_query_string = urlencode(query_params, doseq=True) + # Reconstruct the URL + return urlunsplit((scheme, netloc, path, new_query_string, fragment)) + # -------------------------------------------------------------------------- # Channel management methods # -------------------------------------------------------------------------- @@ -736,17 +766,15 @@ def __init__(self, kernel_id, **kwargs): async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True): """Starts the channels for this kernel. + Please set property ws_url before calling this method. For this class, we establish a websocket connection to the destination and set up the channel-based queues on which applicable messages will be posted. """ + if self.ws_url is None: + msg = "ws_url is None. set it before call start_channels" + raise RuntimeError(msg) - ws_url = url_path_join( - GatewayClient.instance().ws_url or "", - GatewayClient.instance().kernels_endpoint, - url_escape(self.kernel_id), - "channels", - ) # Gather cert info in case where ssl is desired... ssl_options = { "ca_certs": GatewayClient.instance().ca_certs, @@ -755,7 +783,7 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont } self.channel_socket = websocket.create_connection( - ws_url, + self.add_session_param(self.ws_url), timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT, enable_multithread=True, sslopt=ssl_options, @@ -844,38 +872,144 @@ def control_channel(self): self._channel_queues["control"] = self._control_channel return self._control_channel + def _should_reconnect(self) -> bool: + """Determine if we should attempt to reconnect the WebSocket. + + Returns True if the kernel manager has a provisioner with is_alive=True, + or if the kernel manager's is_alive check passes. + """ + # Don't reconnect if intentionally stopped + if self._channels_stopped: + return False + + # Don't reconnect if kernel is not alive + return asyncio.run(self.is_alive()) + + def _reconnect_socket(self) -> bool: + """Attempt to reconnect the WebSocket. + + Returns True on successful reconnection, False otherwise. + """ + try: + # Close old socket if exists + if self.channel_socket: + try: + self.channel_socket.close() + except Exception: + pass + + # Create new WebSocket connection + ssl_options = { + "ca_certs": GatewayClient.instance().ca_certs, + "certfile": GatewayClient.instance().client_cert, + "keyfile": GatewayClient.instance().client_key, + } + + self.channel_socket = websocket.create_connection( + self.add_session_param(self.ws_url), + timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT, + enable_multithread=True, + sslopt=ssl_options, + ) + + # Update channel queues with new socket + # Each ChannelQueue holds a reference to channel_socket for sending messages + if self._channel_queues: + for queue in self._channel_queues.values(): + queue.channel_socket = self.channel_socket + + return True + + except Exception as e: + self.log.error(f"Failed to reconnect WebSocket: {e}") + return False + + def _mark_queues_finished(self) -> None: + """Mark all channel queues as finished. + + This notifies consumers that no more messages will be received. + """ + if self._channel_queues: + for channel_queue in self._channel_queues.values(): + channel_queue.response_router_finished = True + self.log.debug("Response router thread exiting...") + def _route_responses(self): """ Reads responses from the websocket and routes each to the appropriate channel queue based - on the message's channel. It does this for the duration of the class's lifetime until the - channels are stopped, at which time the socket is closed (unblocking the router) and - the thread terminates. If shutdown happens to occur while processing a response (unlikely), - termination takes place via the loop control boolean. + on the message's channel. Implements automatic reconnection with exponential backoff when + the WebSocket disconnects but the kernel is still alive. + + The router continues running until channels are stopped, handling disconnections gracefully + and attempting to reconnect to maintain the message flow. """ - try: - while not self._channels_stopped: + max_reconnect_attempts = 5 + base_delay = 1.0 # seconds + max_delay = 30.0 # max delay between reconnection attempts + attempt = 0 + + while not self._channels_stopped: + try: + # Check if socket needs reconnection + if self.channel_socket is None: + if not self._should_reconnect(): + self.log.info( + "WebSocket disconnected and Kernel is terminated. Exiting router." + ) + break + + # Calculate exponential backoff delay + delay = min(base_delay * (2**attempt), max_delay) + self.log.warning( + f"WebSocket disconnected. Attempting reconnection " + f"(attempt {attempt + 1}/{max_reconnect_attempts}) " + f"after {delay:.1f}s delay..." + ) + time.sleep(delay) + + # Attempt reconnection + if not self._reconnect_socket(): + attempt += 1 + if attempt >= max_reconnect_attempts: + self.log.error( + f"Max reconnection attempts ({max_reconnect_attempts}) reached. " + "Giving up on WebSocket reconnection." + ) + break + continue + + # Reset attempt counter on successful reconnection + attempt = 0 + self.log.info("WebSocket reconnected successfully") + + # Normal message routing assert self.channel_socket is not None raw_message = self.channel_socket.recv() if not raw_message: - break + # Empty message, socket might be closing + self.log.debug("Received empty message from WebSocket") + self.channel_socket = None + continue + response_message = json_decode(utf8(raw_message)) channel = response_message["channel"] assert self._channel_queues is not None self._channel_queues[channel].put_nowait(response_message) - except websocket.WebSocketConnectionClosedException: - pass # websocket closure most likely due to shut down + except websocket.WebSocketConnectionClosedException: + self.log.warning("WebSocket connection closed unexpectedly") + self.channel_socket = None + # Loop will retry connection on next iteration if not stopped - except BaseException as be: - if not self._channels_stopped: - self.log.warning(f"Unexpected exception encountered ({be})") + except Exception as e: + if not self._channels_stopped: + self.log.exception(f"Error in response router: {e}") + self.channel_socket = None + # Brief pause before retry to avoid tight error loop + time.sleep(1) - # Notify channel queues that this thread had finished and no more messages are being received - assert self._channel_queues is not None - for channel_queue in self._channel_queues.values(): - channel_queue.response_router_finished = True - - self.log.debug("Response router thread exiting...") + # Final cleanup: mark all queues as finished + self._mark_queues_finished() KernelClientABC.register(GatewayKernelClient) diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index a24b0539f..94834604f 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -9,6 +9,7 @@ from jupyter_server.auth.decorator import ws_authenticated from jupyter_server.base.handlers import JupyterHandler from jupyter_server.base.websocket import WebSocketMixin +from jupyter_server.saturn.connections import SaturnGatewayWebSocketConnection AUTH_RESOURCE = "kernels"