-
Notifications
You must be signed in to change notification settings - Fork 388
Enhance GatewayClient #1603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Enhance GatewayClient #1603
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,10 +8,12 @@ | |
| import datetime | ||
| import json | ||
| import os | ||
| import time | ||
| from queue import Empty, Queue | ||
| from threading import Thread | ||
| from time import monotonic | ||
| from typing import TYPE_CHECKING, Any, Optional, cast | ||
| from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit | ||
|
|
||
| import websocket | ||
| from jupyter_client.asynchronous.client import AsyncKernelClient | ||
|
|
@@ -21,7 +23,7 @@ | |
| from jupyter_core.utils import ensure_async | ||
| from tornado import web | ||
| from tornado.escape import json_decode, json_encode, url_escape, utf8 | ||
| from traitlets import DottedObjectName, Instance, Type, default | ||
| from traitlets import DottedObjectName, Instance, Type, Unicode, default | ||
|
|
||
| from .._tz import UTC, utcnow | ||
| from ..services.kernels.kernelmanager import ( | ||
|
|
@@ -622,9 +624,10 @@ async def _async_get(self, timeout=None): | |
| except Empty: | ||
| if self.response_router_finished: | ||
| msg = "Response router had finished" | ||
| # TODO throw dedicated Exception for the caller to react on it. | ||
| raise RuntimeError(msg) from None | ||
| if monotonic() > end_time: | ||
| raise | ||
| raise TimeoutError(f"{self.channel_name} async_get timeout") from None | ||
| await asyncio.sleep(0) | ||
|
|
||
| async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]: | ||
|
|
@@ -720,33 +723,58 @@ class GatewayKernelClient(AsyncKernelClient): | |
| _iopub_channel: Optional[ChannelQueue] | ||
| _shell_channel: Optional[ChannelQueue] | ||
|
|
||
| def __init__(self, kernel_id, **kwargs): | ||
| ws_url = Unicode( | ||
| default_value=None, | ||
| allow_none=True, | ||
| config=True, | ||
| help="""The websocket url of the Kernel or Kernel Gateway server. If not provided, this value | ||
| will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) | ||
| """, | ||
| ) | ||
|
|
||
| @property | ||
| def session_id(self): | ||
| """The session id.""" | ||
| return self.session.session | ||
|
|
||
| def __init__(self, **kwargs): | ||
| """Initialize a gateway kernel client.""" | ||
| super().__init__(**kwargs) | ||
| self.kernel_id = kernel_id | ||
| self.channel_socket: Optional[websocket.WebSocket] = None | ||
| self.response_router: Optional[Thread] = None | ||
| self._channels_stopped = False | ||
| self._channel_queues = {} | ||
|
|
||
| def add_session_param(self, ws_url): | ||
| if not self.session_id: | ||
| return ws_url | ||
|
|
||
| scheme, netloc, path, query_string, fragment = urlsplit(ws_url) | ||
| # Parse the existing query string into a dictionary | ||
| query_params = parse_qs(query_string) | ||
| # Update the dictionary with new parameters | ||
| query_params.update({"session": self.session_id}) | ||
| # Encode the updated dictionary back into a query string | ||
| new_query_string = urlencode(query_params, doseq=True) | ||
| # Reconstruct the URL | ||
| return urlunsplit((scheme, netloc, path, new_query_string, fragment)) | ||
|
|
||
| # -------------------------------------------------------------------------- | ||
| # Channel management methods | ||
| # -------------------------------------------------------------------------- | ||
|
|
||
| async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True): | ||
| """Starts the channels for this kernel. | ||
|
|
||
| Please set property ws_url before calling this method. | ||
| For this class, we establish a websocket connection to the destination | ||
| and set up the channel-based queues on which applicable messages will | ||
| be posted. | ||
| """ | ||
| if self.ws_url is None: | ||
| msg = "ws_url is None. set it before call start_channels" | ||
| raise RuntimeError(msg) | ||
|
|
||
| ws_url = url_path_join( | ||
| GatewayClient.instance().ws_url or "", | ||
| GatewayClient.instance().kernels_endpoint, | ||
| url_escape(self.kernel_id), | ||
| "channels", | ||
| ) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: I proposed to move ws_url initialization to kernel manager, and make ws_url as a property of kernel client. Because kernel manager is responsible to start kernel process, instantiate a kernel client to connect o the process. So it knows information of how to connect to the kernel process. |
||
| # Gather cert info in case where ssl is desired... | ||
| ssl_options = { | ||
| "ca_certs": GatewayClient.instance().ca_certs, | ||
|
|
@@ -755,7 +783,7 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont | |
| } | ||
|
|
||
| self.channel_socket = websocket.create_connection( | ||
| ws_url, | ||
| self.add_session_param(self.ws_url), | ||
| timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT, | ||
| enable_multithread=True, | ||
| sslopt=ssl_options, | ||
|
|
@@ -844,38 +872,144 @@ def control_channel(self): | |
| self._channel_queues["control"] = self._control_channel | ||
| return self._control_channel | ||
|
|
||
| def _should_reconnect(self) -> bool: | ||
| """Determine if we should attempt to reconnect the WebSocket. | ||
|
|
||
| Returns True if the kernel manager has a provisioner with is_alive=True, | ||
| or if the kernel manager's is_alive check passes. | ||
| """ | ||
| # Don't reconnect if intentionally stopped | ||
| if self._channels_stopped: | ||
| return False | ||
|
|
||
| # Don't reconnect if kernel is not alive | ||
| return asyncio.run(self.is_alive()) | ||
|
|
||
| def _reconnect_socket(self) -> bool: | ||
| """Attempt to reconnect the WebSocket. | ||
|
|
||
| Returns True on successful reconnection, False otherwise. | ||
| """ | ||
| try: | ||
| # Close old socket if exists | ||
| if self.channel_socket: | ||
| try: | ||
| self.channel_socket.close() | ||
| except Exception: | ||
| pass | ||
|
|
||
| # Create new WebSocket connection | ||
| ssl_options = { | ||
| "ca_certs": GatewayClient.instance().ca_certs, | ||
| "certfile": GatewayClient.instance().client_cert, | ||
| "keyfile": GatewayClient.instance().client_key, | ||
| } | ||
|
|
||
| self.channel_socket = websocket.create_connection( | ||
| self.add_session_param(self.ws_url), | ||
| timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT, | ||
| enable_multithread=True, | ||
| sslopt=ssl_options, | ||
| ) | ||
|
|
||
| # Update channel queues with new socket | ||
| # Each ChannelQueue holds a reference to channel_socket for sending messages | ||
| if self._channel_queues: | ||
| for queue in self._channel_queues.values(): | ||
| queue.channel_socket = self.channel_socket | ||
|
|
||
| return True | ||
|
|
||
| except Exception as e: | ||
| self.log.error(f"Failed to reconnect WebSocket: {e}") | ||
| return False | ||
|
|
||
| def _mark_queues_finished(self) -> None: | ||
| """Mark all channel queues as finished. | ||
|
|
||
| This notifies consumers that no more messages will be received. | ||
| """ | ||
| if self._channel_queues: | ||
| for channel_queue in self._channel_queues.values(): | ||
| channel_queue.response_router_finished = True | ||
| self.log.debug("Response router thread exiting...") | ||
|
|
||
| def _route_responses(self): | ||
| """ | ||
| Reads responses from the websocket and routes each to the appropriate channel queue based | ||
| on the message's channel. It does this for the duration of the class's lifetime until the | ||
| channels are stopped, at which time the socket is closed (unblocking the router) and | ||
| the thread terminates. If shutdown happens to occur while processing a response (unlikely), | ||
| termination takes place via the loop control boolean. | ||
| on the message's channel. Implements automatic reconnection with exponential backoff when | ||
| the WebSocket disconnects but the kernel is still alive. | ||
|
|
||
| The router continues running until channels are stopped, handling disconnections gracefully | ||
| and attempting to reconnect to maintain the message flow. | ||
| """ | ||
| try: | ||
| while not self._channels_stopped: | ||
| max_reconnect_attempts = 5 | ||
| base_delay = 1.0 # seconds | ||
| max_delay = 30.0 # max delay between reconnection attempts | ||
| attempt = 0 | ||
|
|
||
| while not self._channels_stopped: | ||
| try: | ||
| # Check if socket needs reconnection | ||
| if self.channel_socket is None: | ||
| if not self._should_reconnect(): | ||
| self.log.info( | ||
| "WebSocket disconnected and Kernel is terminated. Exiting router." | ||
| ) | ||
| break | ||
|
|
||
| # Calculate exponential backoff delay | ||
| delay = min(base_delay * (2**attempt), max_delay) | ||
| self.log.warning( | ||
| f"WebSocket disconnected. Attempting reconnection " | ||
| f"(attempt {attempt + 1}/{max_reconnect_attempts}) " | ||
| f"after {delay:.1f}s delay..." | ||
| ) | ||
| time.sleep(delay) | ||
|
|
||
| # Attempt reconnection | ||
| if not self._reconnect_socket(): | ||
| attempt += 1 | ||
| if attempt >= max_reconnect_attempts: | ||
| self.log.error( | ||
| f"Max reconnection attempts ({max_reconnect_attempts}) reached. " | ||
| "Giving up on WebSocket reconnection." | ||
| ) | ||
| break | ||
| continue | ||
|
|
||
| # Reset attempt counter on successful reconnection | ||
| attempt = 0 | ||
| self.log.info("WebSocket reconnected successfully") | ||
|
|
||
| # Normal message routing | ||
| assert self.channel_socket is not None | ||
| raw_message = self.channel_socket.recv() | ||
| if not raw_message: | ||
| break | ||
| # Empty message, socket might be closing | ||
| self.log.debug("Received empty message from WebSocket") | ||
| self.channel_socket = None | ||
| continue | ||
|
|
||
| response_message = json_decode(utf8(raw_message)) | ||
| channel = response_message["channel"] | ||
| assert self._channel_queues is not None | ||
| self._channel_queues[channel].put_nowait(response_message) | ||
|
|
||
| except websocket.WebSocketConnectionClosedException: | ||
| pass # websocket closure most likely due to shut down | ||
| except websocket.WebSocketConnectionClosedException: | ||
| self.log.warning("WebSocket connection closed unexpectedly") | ||
| self.channel_socket = None | ||
| # Loop will retry connection on next iteration if not stopped | ||
|
|
||
| except BaseException as be: | ||
| if not self._channels_stopped: | ||
| self.log.warning(f"Unexpected exception encountered ({be})") | ||
| except Exception as e: | ||
| if not self._channels_stopped: | ||
| self.log.exception(f"Error in response router: {e}") | ||
| self.channel_socket = None | ||
| # Brief pause before retry to avoid tight error loop | ||
| time.sleep(1) | ||
|
|
||
| # Notify channel queues that this thread had finished and no more messages are being received | ||
| assert self._channel_queues is not None | ||
| for channel_queue in self._channel_queues.values(): | ||
| channel_queue.response_router_finished = True | ||
|
|
||
| self.log.debug("Response router thread exiting...") | ||
| # Final cleanup: mark all queues as finished | ||
| self._mark_queues_finished() | ||
|
|
||
|
|
||
| KernelClientABC.register(GatewayKernelClient) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: kernel_id is only used for constructing gateway websocket url.