Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 163 additions & 29 deletions jupyter_server/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import datetime
import json
import os
import time
from queue import Empty, Queue
from threading import Thread
from time import monotonic
from typing import TYPE_CHECKING, Any, Optional, cast
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit

import websocket
from jupyter_client.asynchronous.client import AsyncKernelClient
Expand All @@ -21,7 +23,7 @@
from jupyter_core.utils import ensure_async
from tornado import web
from tornado.escape import json_decode, json_encode, url_escape, utf8
from traitlets import DottedObjectName, Instance, Type, default
from traitlets import DottedObjectName, Instance, Type, Unicode, default

from .._tz import UTC, utcnow
from ..services.kernels.kernelmanager import (
Expand Down Expand Up @@ -622,9 +624,10 @@ async def _async_get(self, timeout=None):
except Empty:
if self.response_router_finished:
msg = "Response router had finished"
# TODO throw dedicated Exception for the caller to react on it.
raise RuntimeError(msg) from None
if monotonic() > end_time:
raise
raise TimeoutError(f"{self.channel_name} async_get timeout") from None
await asyncio.sleep(0)

async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
Expand Down Expand Up @@ -720,33 +723,58 @@ class GatewayKernelClient(AsyncKernelClient):
_iopub_channel: Optional[ChannelQueue]
_shell_channel: Optional[ChannelQueue]

def __init__(self, kernel_id, **kwargs):
ws_url = Unicode(
default_value=None,
allow_none=True,
config=True,
help="""The websocket url of the Kernel or Kernel Gateway server. If not provided, this value
will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var)
""",
)

@property
def session_id(self):
    """Return the ``session`` attribute of this client's ``session`` object."""
    active_session = self.session
    return active_session.session

def __init__(self, **kwargs):
"""Initialize a gateway kernel client."""
super().__init__(**kwargs)
self.kernel_id = kernel_id
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: kernel_id is only used for constructing the gateway websocket URL.

self.channel_socket: Optional[websocket.WebSocket] = None
self.response_router: Optional[Thread] = None
self._channels_stopped = False
self._channel_queues = {}

def add_session_param(self, ws_url):
    """Return ``ws_url`` with a ``session`` query parameter set to ``self.session_id``.

    Parameters
    ----------
    ws_url : str
        The websocket URL to decorate. Any existing query parameters are
        preserved; an existing ``session`` parameter is replaced.

    Returns
    -------
    str
        ``ws_url`` unchanged when ``self.session_id`` is falsy, otherwise the
        URL rebuilt with ``session=<session_id>`` in its query string.
    """
    if not self.session_id:
        return ws_url

    scheme, netloc, path, query_string, fragment = urlsplit(ws_url)
    # keep_blank_values=True: parse_qs drops parameters with empty values by
    # default, which would silently strip e.g. "?token=" from the rebuilt URL.
    query_params = parse_qs(query_string, keep_blank_values=True)
    # Overwrite (rather than append to) any session value already present.
    query_params["session"] = [self.session_id]
    # doseq=True expands each value list back into repeated key=value pairs.
    new_query_string = urlencode(query_params, doseq=True)
    return urlunsplit((scheme, netloc, path, new_query_string, fragment))

# --------------------------------------------------------------------------
# Channel management methods
# --------------------------------------------------------------------------

async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True):
"""Starts the channels for this kernel.

Please set property ws_url before calling this method.
For this class, we establish a websocket connection to the destination
and set up the channel-based queues on which applicable messages will
be posted.
"""
if self.ws_url is None:
msg = "ws_url is None. set it before call start_channels"
raise RuntimeError(msg)

ws_url = url_path_join(
GatewayClient.instance().ws_url or "",
GatewayClient.instance().kernels_endpoint,
url_escape(self.kernel_id),
"channels",
)
Copy link
Author

@vincentye38 vincentye38 Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I propose moving the ws_url initialization to the kernel manager and making ws_url a property of the kernel client. The kernel manager is responsible for starting the kernel process and instantiating a kernel client to connect to that process, so it already knows how to connect to the kernel process.
The kernel manager will set ws_url on the kernel client after it creates the kernel client.

# Gather cert info in case where ssl is desired...
ssl_options = {
"ca_certs": GatewayClient.instance().ca_certs,
Expand All @@ -755,7 +783,7 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
}

self.channel_socket = websocket.create_connection(
ws_url,
self.add_session_param(self.ws_url),
timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT,
enable_multithread=True,
sslopt=ssl_options,
Expand Down Expand Up @@ -844,38 +872,144 @@ def control_channel(self):
self._channel_queues["control"] = self._control_channel
return self._control_channel

def _should_reconnect(self) -> bool:
"""Determine if we should attempt to reconnect the WebSocket.

Returns True if the kernel manager has a provisioner with is_alive=True,
or if the kernel manager's is_alive check passes.
"""
# Don't reconnect if intentionally stopped
if self._channels_stopped:
return False

# Don't reconnect if kernel is not alive
return asyncio.run(self.is_alive())

def _reconnect_socket(self) -> bool:
    """Open a new WebSocket and install it as ``self.channel_socket``.

    Any existing socket is closed first (errors from that close are ignored).
    The new connection targets ``self.ws_url`` with the session parameter
    added and uses the certificate settings and launch timeout read from
    ``GatewayClient.instance()``. Every entry in ``self._channel_queues`` is
    then pointed at the new socket.

    Returns
    -------
    bool
        True when the new connection was established, False if any step
        raised (the error is logged).
    """
    try:
        stale_socket = self.channel_socket
        if stale_socket:
            # Best-effort close: the old socket may already be broken.
            try:
                stale_socket.close()
            except Exception:
                pass

        # SSL material for the case where a secure connection is desired.
        gateway = GatewayClient.instance()
        tls_settings = {
            "ca_certs": gateway.ca_certs,
            "certfile": gateway.client_cert,
            "keyfile": gateway.client_key,
        }

        fresh_socket = websocket.create_connection(
            self.add_session_param(self.ws_url),
            timeout=gateway.KERNEL_LAUNCH_TIMEOUT,
            enable_multithread=True,
            sslopt=tls_settings,
        )
        self.channel_socket = fresh_socket

        # Each channel queue keeps its own reference to the socket it sends
        # on, so every queue has to be re-pointed at the replacement.
        if self._channel_queues:
            for channel_queue in self._channel_queues.values():
                channel_queue.channel_socket = fresh_socket

    except Exception as e:
        self.log.error(f"Failed to reconnect WebSocket: {e}")
        return False

    return True

def _mark_queues_finished(self) -> None:
"""Mark all channel queues as finished.

This notifies consumers that no more messages will be received.
"""
if self._channel_queues:
for channel_queue in self._channel_queues.values():
channel_queue.response_router_finished = True
self.log.debug("Response router thread exiting...")

def _route_responses(self):
"""
Reads responses from the websocket and routes each to the appropriate channel queue based
on the message's channel. It does this for the duration of the class's lifetime until the
channels are stopped, at which time the socket is closed (unblocking the router) and
the thread terminates. If shutdown happens to occur while processing a response (unlikely),
termination takes place via the loop control boolean.
on the message's channel. Implements automatic reconnection with exponential backoff when
the WebSocket disconnects but the kernel is still alive.

The router continues running until channels are stopped, handling disconnections gracefully
and attempting to reconnect to maintain the message flow.
"""
try:
while not self._channels_stopped:
max_reconnect_attempts = 5
base_delay = 1.0 # seconds
max_delay = 30.0 # max delay between reconnection attempts
attempt = 0

while not self._channels_stopped:
try:
# Check if socket needs reconnection
if self.channel_socket is None:
if not self._should_reconnect():
self.log.info(
"WebSocket disconnected and Kernel is terminated. Exiting router."
)
break

# Calculate exponential backoff delay
delay = min(base_delay * (2**attempt), max_delay)
self.log.warning(
f"WebSocket disconnected. Attempting reconnection "
f"(attempt {attempt + 1}/{max_reconnect_attempts}) "
f"after {delay:.1f}s delay..."
)
time.sleep(delay)

# Attempt reconnection
if not self._reconnect_socket():
attempt += 1
if attempt >= max_reconnect_attempts:
self.log.error(
f"Max reconnection attempts ({max_reconnect_attempts}) reached. "
"Giving up on WebSocket reconnection."
)
break
continue

# Reset attempt counter on successful reconnection
attempt = 0
self.log.info("WebSocket reconnected successfully")

# Normal message routing
assert self.channel_socket is not None
raw_message = self.channel_socket.recv()
if not raw_message:
break
# Empty message, socket might be closing
self.log.debug("Received empty message from WebSocket")
self.channel_socket = None
continue

response_message = json_decode(utf8(raw_message))
channel = response_message["channel"]
assert self._channel_queues is not None
self._channel_queues[channel].put_nowait(response_message)

except websocket.WebSocketConnectionClosedException:
pass # websocket closure most likely due to shut down
except websocket.WebSocketConnectionClosedException:
self.log.warning("WebSocket connection closed unexpectedly")
self.channel_socket = None
# Loop will retry connection on next iteration if not stopped

except BaseException as be:
if not self._channels_stopped:
self.log.warning(f"Unexpected exception encountered ({be})")
except Exception as e:
if not self._channels_stopped:
self.log.exception(f"Error in response router: {e}")
self.channel_socket = None
# Brief pause before retry to avoid tight error loop
time.sleep(1)

# Notify channel queues that this thread had finished and no more messages are being received
assert self._channel_queues is not None
for channel_queue in self._channel_queues.values():
channel_queue.response_router_finished = True

self.log.debug("Response router thread exiting...")
# Final cleanup: mark all queues as finished
self._mark_queues_finished()


KernelClientABC.register(GatewayKernelClient)
1 change: 1 addition & 0 deletions jupyter_server/services/kernels/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jupyter_server.auth.decorator import ws_authenticated
from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.base.websocket import WebSocketMixin
from jupyter_server.saturn.connections import SaturnGatewayWebSocketConnection

AUTH_RESOURCE = "kernels"

Expand Down
Loading