diff --git a/docs/source/api/jupyter_server.services.kernels.rst b/docs/source/api/jupyter_server.services.kernels.rst index 683e340a7..9fa4d1c4e 100644 --- a/docs/source/api/jupyter_server.services.kernels.rst +++ b/docs/source/api/jupyter_server.services.kernels.rst @@ -25,6 +25,12 @@ Submodules :undoc-members: +.. automodule:: jupyter_server.services.kernels.routing + :members: + :show-inheritance: + :undoc-members: + + .. automodule:: jupyter_server.services.kernels.websocket :members: :show-inheritance: diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 270001f30..a5ab4012b 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -8,6 +8,7 @@ import datetime import json import os +import warnings from queue import Empty, Queue from threading import Thread from time import monotonic @@ -30,7 +31,9 @@ emit_kernel_action_event, ) from ..services.sessions.sessionmanager import SessionManager +from ..transutils import _i18n from ..utils import url_path_join +from .connections import GatewayWebSocketConnection from .gateway_client import GatewayClient, gateway_request if TYPE_CHECKING: @@ -211,6 +214,13 @@ async def cull_kernels(self): await self.list_kernels() await super().cull_kernels() + @property + def info(self): + return ( + _i18n("\nKernels will be managed by the Gateway server running at:\n%s") + % self.kernels_url + ) + class GatewayKernelSpecManager(KernelSpecManager): """A gateway kernel spec manager.""" @@ -359,22 +369,13 @@ class GatewaySessionManager(SessionManager): kernel_manager = Instance("jupyter_server.gateway.managers.GatewayMappingKernelManager") - async def kernel_culled(self, kernel_id: str) -> bool: # typing: ignore - """Checks if the kernel is still considered alive and returns true if it's not found.""" - km: Optional[GatewayKernelManager] = None - try: - # Since we keep the models up-to-date via client polling, use that state to determine - # if this kernel no longer exists on the gateway server rather than perform a redundant - # fetch operation - especially since this is called at approximately the same interval. - # This has the effect of reducing GET /api/kernels requests against the gateway server - # by 50%! - # Note that should the redundant polling be consolidated, or replaced with an event-based - # notification model, this will need to be revisited. - km = self.kernel_manager.get_kernel(kernel_id) - except Exception: - # Let exceptions here reflect culled kernel - pass - return km is None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + "The GatewaySessionManager class is deprecated and will not be supported in Jupyter Server 3.0", + DeprecationWarning, + stacklevel=2, + ) class GatewayKernelManager(ServerKernelManager): @@ -406,6 +407,7 @@ def has_kernel(self): client_class = DottedObjectName("jupyter_server.gateway.managers.GatewayKernelClient") client_factory = Type(klass="jupyter_server.gateway.managers.GatewayKernelClient") + websocket_connection_class = GatewayWebSocketConnection # -------------------------------------------------------------------------- # create a Client connected to our Kernel diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index ed2bd1361..8d91e0c34 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -103,13 +103,7 @@ from jupyter_server.extension.config import ExtensionConfigManager from jupyter_server.extension.manager import ExtensionManager from jupyter_server.extension.serverextension import ServerExtensionApp -from jupyter_server.gateway.connections import GatewayWebSocketConnection from jupyter_server.gateway.gateway_client import GatewayClient -from jupyter_server.gateway.managers import ( - GatewayKernelSpecManager, - GatewayMappingKernelManager, - GatewaySessionManager, -) from jupyter_server.log import log_request from jupyter_server.prometheus.metrics import ( ACTIVE_DURATION, @@ -131,6 +125,10 @@ AsyncMappingKernelManager, MappingKernelManager, ) +from jupyter_server.services.kernels.routing import ( + RoutingKernelSpecManager, + RoutingMappingKernelManager, +) from jupyter_server.services.sessions.sessionmanager import SessionManager from jupyter_server.utils import ( JupyterServerAuthWarning, @@ -893,14 +891,12 @@ class ServerApp(JupyterApp): AsyncContentsManager, AsyncFileContentsManager, NotebookNotary, - GatewayMappingKernelManager, - GatewayKernelSpecManager, - GatewaySessionManager, - GatewayWebSocketConnection, GatewayClient, Authorizer, EventLogger, ZMQChannelsWebsocketConnection, + RoutingKernelSpecManager, + RoutingMappingKernelManager, ] subcommands: dict[str, t.Any] = { @@ -1621,9 +1617,7 @@ def template_file_path(self) -> list[str]: @default("kernel_manager_class") def _default_kernel_manager_class(self) -> t.Union[str, type[AsyncMappingKernelManager]]: - if self.gateway_config.gateway_enabled: - return "jupyter_server.gateway.managers.GatewayMappingKernelManager" - return AsyncMappingKernelManager + return RoutingMappingKernelManager session_manager_class = Type( config=True, @@ -1632,8 +1626,6 @@ def _default_kernel_manager_class(self) -> t.Union[str, type[AsyncMappingKernelM @default("session_manager_class") def _default_session_manager_class(self) -> t.Union[str, type[SessionManager]]: - if self.gateway_config.gateway_enabled: - return "jupyter_server.gateway.managers.GatewaySessionManager" return SessionManager kernel_websocket_connection_class = Type( @@ -1646,8 +1638,11 @@ def _default_session_manager_class(self) -> t.Union[str, type[SessionManager]]: def _default_kernel_websocket_connection_class( self, ) -> t.Union[str, type[ZMQChannelsWebsocketConnection]]: - if self.gateway_config.gateway_enabled: - return "jupyter_server.gateway.connections.GatewayWebSocketConnection" + if issubclass( + self.kernel_manager_class, + RoutingMappingKernelManager, + ): + return "jupyter_server.services.kernels.routing.RoutingKernelManagerWebsocketConnection" return ZMQChannelsWebsocketConnection websocket_ping_interval = Integer( @@ -1697,8 +1692,11 @@ def _default_kernel_websocket_connection_class( @default("kernel_spec_manager_class") def _default_kernel_spec_manager_class(self) -> t.Union[str, type[KernelSpecManager]]: - if self.gateway_config.gateway_enabled: - return "jupyter_server.gateway.managers.GatewayKernelSpecManager" + if issubclass( + self.kernel_manager_class, + RoutingMappingKernelManager, + ): + return RoutingKernelSpecManager return KernelSpecManager login_handler_class = Type( @@ -2877,11 +2875,8 @@ def running_server_info(self, kernel_count: bool = True) -> str: info += _i18n("Jupyter Server {version} is running at:\n{url}").format( version=ServerApp.version, url=self.display_url ) - if self.gateway_config.gateway_enabled: - info += ( - _i18n("\nKernels will be managed by the Gateway server running at:\n%s") - % self.gateway_config.url - ) + if hasattr(self.kernel_manager, "info"): + info += self.kernel_manager.info return info def server_info(self) -> dict[str, t.Any]: diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py index 5a64917c1..ed1775a70 100644 --- a/jupyter_server/services/kernels/kernelmanager.py +++ b/jupyter_server/services/kernels/kernelmanager.py @@ -43,6 +43,7 @@ Integer, List, TraitError, + Type, Unicode, default, validate, @@ -51,6 +52,8 @@ from jupyter_server import DEFAULT_EVENTS_SCHEMA_PATH from jupyter_server._tz import isoformat, utcnow from jupyter_server.prometheus.metrics import KERNEL_CURRENTLY_RUNNING_TOTAL +from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection +from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection from jupyter_server.utils import ApiPath, import_item, to_os_path @@ -899,6 +902,14 @@ def _default_event_logger(self): pass return logger + websocket_connection_class = Type( + default_value=ZMQChannelsWebsocketConnection, + klass=BaseKernelWebsocketConnection, + help=""" + The websocket connection class to use for this manager's kernels. + """, + ).tag(config=True) + def emit(self, schema_id, data): """Emit an event from the kernel manager.""" self.event_logger.emit(schema_id=schema_id, data=data) diff --git a/jupyter_server/services/kernels/routing.py b/jupyter_server/services/kernels/routing.py new file mode 100644 index 000000000..1ed229e8e --- /dev/null +++ b/jupyter_server/services/kernels/routing.py @@ -0,0 +1,378 @@ +import copy +import typing as t + +from jupyter_client.kernelspec import KernelSpecManager +from jupyter_client.manager import in_pending_state +from jupyter_client.managerabc import KernelManagerABC +from jupyter_core.utils import ensure_async, run_sync +from traitlets import ( + Dict, + Instance, + List, + Type, + Unicode, + default, + observe, +) +from traitlets.config import LoggingConfigurable + +from jupyter_server.gateway.gateway_client import GatewayClient +from jupyter_server.gateway.managers import GatewayKernelSpecManager, GatewayMappingKernelManager +from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection +from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection +from jupyter_server.services.kernels.kernelmanager import ( + AsyncMappingKernelManager, + ServerKernelManager, +) +from jupyter_server.transutils import _i18n + + +class RoutingProvider(LoggingConfigurable): + connection_dir = Unicode("") + + primary_manager = Instance(AsyncMappingKernelManager) + + additional_managers = List(trait=Instance(AsyncMappingKernelManager)) + + @default("primary_manager") + def _default_primary_manager(self): + ksm = KernelSpecManager(parent=self.parent) + return AsyncMappingKernelManager( + parent=self.parent, + log=self.log, + connection_dir=self.connection_dir, + kernel_spec_manager=ksm, + ) + + info = Unicode("") + + @default("info") + def _default_info(self): + if hasattr(self.primary_manager, "info"): + return self.primary_manager.info + return "" + + +class RemoteOnlyRoutingProvider(RoutingProvider): + @default("primary_manager") + def _default_primary_manager(self): + ksm = GatewayKernelSpecManager(parent=self.parent) + return GatewayMappingKernelManager( + parent=self.parent, + log=self.log, + connection_dir=self.connection_dir, + kernel_spec_manager=ksm, + ) + + +class SideBySideRoutingProvider(RoutingProvider): + @default("additional_managers") + def _default_additional_managers(self): + ksm = GatewayKernelSpecManager(parent=self.parent) + return [ + GatewayMappingKernelManager( + parent=self.parent, + log=self.log, + connection_dir=self.connection_dir, + kernel_spec_manager=ksm, + ) + ] + + +class AsyncRoutingKernelSpecManager(KernelSpecManager): + """KernelSpecManager that routes to multiple nested kernel spec managers. + + This async version of the wrapper exists because the base KernelSpecManager + class only has synchronous methods, but some child classes (in particular, + GatewayKernelManager) change those methods to be async. + + In order to support both versions, we first implement the routing in this async + class, but then make it synchronous in the child, RoutingKernelSpecManager class. + """ + + @property + def primary_manager(self) -> AsyncMappingKernelManager: + # This kernelspec manager can only be used when the corresponding kernel + # manager can tell us how to route requests to the nested managers. + assert self.parent is not None + assert hasattr(self.parent.kernel_manager, "routing_provider") + assert isinstance(self.parent.kernel_manager.routing_provider, RoutingProvider) + + km = self.parent.kernel_manager.routing_provider.primary_manager + + # On the odd chance that an administrator explicitly configured a routing + # provider with a nested routing kernelspec manager, all attempts to list + # or get kernelspecs will result in an infinite loop. + # + # Accordingly, we use an assert to catch this early. + assert not isinstance(km.kernel_spec_manager, AsyncRoutingKernelSpecManager) + + return km + + @property + def additional_managers(self): + # This kernelspec manager can only be used when the corresponding kernel + # manager can tell us how to route requests to the nested managers. + assert self.parent is not None + assert hasattr(self.parent.kernel_manager, "routing_provider") + + kms = self.parent.kernel_manager.routing_provider.additional_managers + + # Similarly to the `primary_manager` property, we want to ensure that + # none of the nested kernelspec managers are instances of this same class, + # in order to prevent infinite loops and to catch the configuration + # issues that could cause such loops early. + for km in kms: + assert not isinstance(km.kernel_spec_manager, AsyncRoutingKernelSpecManager) + + return kms + + spec_to_manager_map = Dict(key_trait=Unicode(), value_trait=Instance(AsyncMappingKernelManager)) + + async def get_all_specs(self): + ksm = self.primary_manager.kernel_spec_manager + assert ksm is not None + ks = await ensure_async(ksm.get_all_specs()) + for spec_name, _spec in ks.items(): + self.spec_to_manager_map[spec_name] = self.primary_manager + for additional_manager in self.additional_managers: + additional_ks = await ensure_async( + additional_manager.kernel_spec_manager.get_all_specs() + ) + for spec_name, spec in additional_ks.items(): + if spec_name not in ks: + ks[spec_name] = spec + self.spec_to_manager_map[spec_name] = additional_manager + return ks + + def get_mapping_kernel_manager(self, kernel_name: str) -> AsyncMappingKernelManager: + km = self.spec_to_manager_map.get(kernel_name, None) or self.primary_manager + return km + + async def get_kernel_spec(self, kernel_name, **kwargs): + wrapped_manager = self.get_mapping_kernel_manager(kernel_name).kernel_spec_manager + assert wrapped_manager is not None + + return ensure_async(wrapped_manager.get_kernel_spec(kernel_name, **kwargs)) + + async def get_kernel_spec_resource(self, kernel_name, path): + wrapped_manager = self.get_mapping_kernel_manager(kernel_name).kernel_spec_manager + assert wrapped_manager is not None + + if hasattr(wrapped_manager, "get_kernel_spec_resource"): + return await ensure_async(wrapped_manager.get_kernel_spec_resource(kernel_name, path)) + return None + + def is_remote(self, kernel_name): + wrapped_manager = self.get_mapping_kernel_manager(kernel_name).kernel_spec_manager + assert wrapped_manager is not None + + return isinstance(wrapped_manager, GatewayKernelSpecManager) + + +class RoutingKernelSpecManager(AsyncRoutingKernelSpecManager): + """KernelSpecManager that routes to multiple nested kernel spec managers.""" + + def get_all_specs(self): + return run_sync(super().get_all_specs)() + + def get_kernel_spec(self, kernel_name, *args, **kwargs): + return run_sync(super().get_kernel_spec)(kernel_name, *args, **kwargs) + + +class RoutingKernelManagerWebsocketConnection(BaseKernelWebsocketConnection): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + km = self.kernel_manager.wrapped_kernel_manager + wrapped_class = ZMQChannelsWebsocketConnection + if hasattr(km, "websocket_connection_class"): + wrapped_class = km.websocket_connection_class + self.wrapped = wrapped_class( + parent=km, websocket_handler=self.websocket_handler, config=self.config + ) + + async def connect(self): + """Connect the kernel websocket to the kernel ZMQ connections""" + return await self.wrapped.connect() + + # N.B. The disconnect method in the BaseKernelWebsocketConnection is defined + # to be async, but in all of the implementing subclasses it is sync, and + # the Jupyter server does not await the value returned from this method. + def disconnect(self): + """Disconnect the kernel websocket from the kernel ZMQ connections""" + return self.wrapped.disconnect() + + def handle_incoming_message(self, incoming_msg: str) -> None: + """Broker the incoming websocket message to the appropriate ZMQ channel.""" + self.wrapped.handle_incoming_message(incoming_msg) + + def handle_outgoing_message(self, stream: str, outgoing_msg: list[t.Any]) -> None: + """Broker outgoing ZMQ messages to the kernel websocket.""" + self.wrapped.handle_outgoing_message(stream, outgoing_msg) + + async def prepare(self): + if hasattr(self.wrapped, "prepare"): + return await self.wrapped.prepare() + + +class RoutingKernelManager(ServerKernelManager): + kernel_id_map: dict[str, str] = {} + + @property + def is_remote(self): + if not self.kernel_name or not self.kernel_id: + return False + assert self.parent is not None + return self.parent.kernel_spec_manager.is_remote(self.kernel_name) + + @property + def wrapped_multi_kernel_manager(self): + assert self.parent is not None + return self.parent.kernel_spec_manager.get_mapping_kernel_manager(self.kernel_name) + + @property + def wrapped_kernel_manager(self): + if not self.kernel_id: + return None + wrapped_kernel_id = RoutingKernelManager.kernel_id_map.get(self.kernel_id, self.kernel_id) + return self.wrapped_multi_kernel_manager.get_kernel(wrapped_kernel_id) + + @default("websocket_connection_class") + def _default_websocket_connection_class(self): + return RoutingKernelManagerWebsocketConnection + + @property + def has_kernel(self): + if not self.kernel_id: + return False + return self.wrapped_kernel_manager.has_kernel + + async def is_alive(self): + if not self.has_kernel: + return False + return await self.wrapped_kernel_manager.is_alive() + + def client(self, *args, **kwargs): + if not self.kernel_id: + return None + return self.wrapped_kernel_manager.client(*args, **kwargs) + + @in_pending_state + async def start_kernel(self, *args, **kwargs): + kernel_id: t.Optional[str] = kwargs.pop("kernel_id", self.kernel_id) + if kernel_id: + self.kernel_id = kernel_id + + km = self.wrapped_multi_kernel_manager + wrapped_kernel_id: str = await ensure_async( + km.start_kernel(kernel_name=self.kernel_name, **kwargs) + ) + self.kernel_id = self.kernel_id or wrapped_kernel_id + RoutingKernelManager.kernel_id_map[self.kernel_id] = wrapped_kernel_id + self.log.debug( + f"Created kernel {self.kernel_id} corresponding to {wrapped_kernel_id} in {km}" + ) + self.log.debug(RoutingKernelManager.kernel_id_map) + + async def shutdown_kernel(self, now=False, restart=False): + wrapped_kernel_id = RoutingKernelManager.kernel_id_map.get(self.kernel_id, self.kernel_id) + km = self.wrapped_multi_kernel_manager + await ensure_async(km.shutdown_kernel(wrapped_kernel_id, now=now, restart=restart)) + RoutingKernelManager.kernel_id_map.pop(self.kernel_id, None) + + async def restart_kernel(self, now=False): + wrapped_kernel_id = RoutingKernelManager.kernel_id_map.get(self.kernel_id, self.kernel_id) + km = self.wrapped_multi_kernel_manager + return await ensure_async(km.restart_kernel(wrapped_kernel_id, now=now)) + + async def interrupt_kernel(self): + km = self.wrapped_kernel_manager + return await ensure_async(km.interrupt_kernel()) + + async def model(self): + wrapped_kernel_id = RoutingKernelManager.kernel_id_map.get(self.kernel_id, self.kernel_id) + wrapped_model = await ensure_async( + self.wrapped_multi_kernel_manager.kernel_model(wrapped_kernel_id) + ) + model = copy.deepcopy(wrapped_model) + model["id"] = self.kernel_id + return model + + +class RoutingMappingKernelManager(AsyncMappingKernelManager): + @default("kernel_manager_class") + def _default_kernel_manager_class(self): + return "jupyter_server.services.kernels.routing.RoutingKernelManager" + + kernel_spec_manager = Instance( + "jupyter_server.services.kernels.routing.RoutingKernelSpecManager" + ) + + _routing_provider = None + routing_provider_class = Type( + klass=RoutingProvider, + config=True, + help=_i18n( + "The class defining how kernelspec and kernel requests are routed " + + "to the various supported managers." + ), + ) + + @default("routing_provider_class") + def _default_routing_provider_class(self): + gateway_config = GatewayClient.instance(parent=self.parent) + if gateway_config.gateway_enabled: + return RemoteOnlyRoutingProvider + return RoutingProvider + + @property + def routing_provider(self): + if not self._routing_provider: + self._routing_provider = self.routing_provider_class( + parent=self.parent, log=self.log, connection_dir=self.connection_dir + ) + return self._routing_provider + + def has_remote_kernels(self): + for kid in self._kernels: + if self._kernels[kid].is_remote: + return True + return False + + async def list_kernels(self): + if self.has_remote_kernels(): + # We have remote kernels, so we must call `list_kernels` on the + # wrapped Gateway kernel managers to update our kernel models. + try: + await ensure_async(self.routing_provider.primary_manager.list_kernels()) + for wrapped in self.routing_provider.additional_managers: + await ensure_async(wrapped.list_kernels()) + except Exception as ex: + self.log.exception("Failure listing kernels: %s", ex) + # Ignore the exception listing remote kernels, so that local kernels are still usable. + return super().list_kernels() + + def kernel_model(self, kernel_id): + self._check_kernel_id(kernel_id) + kernel = self._kernels[kernel_id] + # Normally, calls to `run_sync` pose a danger of locking up Tornado's + # single-threaded event loop. + # + # However, the call below should be fine because it cannot block for an + # arbitrary amount of time. + # + # This call blocks on the `model` method defined below, which in turn + # blocks on the `GatewayMappingKernelManager`'s `kernel_model` method + # (https://github.com/jupyter-server/jupyter_server/blob/547f7a244d89f79dd09fa7d382322d1c40890a3f/jupyter_server/gateway/managers.py#L94). + # + # That will only take a small, deterministic amount of time to complete + # because that `kernel_model` only operates on existing, in-memory data + # and does not block on any outgoing network requests. + return run_sync(kernel.model)() + + @property + def info(self): + return self.routing_provider.info + + +KernelManagerABC.register(RoutingKernelManager) diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index a24b0539f..2d0ab4974 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -21,6 +21,10 @@ class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler): @property def kernel_websocket_connection_class(self): """The kernel websocket connection class.""" + if self.kernel_manager and self.kernel_id: + kernel = self.kernel_manager.get_kernel(self.kernel_id) + if hasattr(kernel, "websocket_connection_class"): + return kernel.websocket_connection_class return self.settings.get("kernel_websocket_connection_class") def set_default_headers(self): diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index f02e04bc4..b13a05ce3 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -18,6 +18,7 @@ from dataclasses import dataclass, fields +from jupyter_client.manager import KernelManager from jupyter_core.utils import ensure_async from tornado import web from traitlets import Instance, TraitError, Unicode, validate @@ -468,7 +469,13 @@ async def update_session(self, session_id, **kwargs): async def kernel_culled(self, kernel_id: str) -> bool: """Checks if the kernel is still considered alive and returns true if its not found.""" - return kernel_id not in self.kernel_manager + km: Optional[KernelManager] = None + try: + km = self.kernel_manager.get_kernel(kernel_id) + except Exception: + # Let exceptions here reflect culled kernel + pass + return km is None async def row_to_model(self, row, tolerate_culled=False): """Takes sqlite database session row and turns it into a dictionary""" diff --git a/tests/base/test_call_context.py b/tests/base/test_call_context.py index f3e48522f..ed683f869 100644 --- a/tests/base/test_call_context.py +++ b/tests/base/test_call_context.py @@ -7,7 +7,7 @@ from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager -async def test_jupyter_handler_contextvar(jp_fetch, monkeypatch): +async def test_jupyter_handler_contextvar(jp_serverapp, jp_fetch, monkeypatch): # Create some mock kernel Ids kernel1 = "x-x-x-x-x" kernel2 = "y-y-y-y-y" @@ -45,7 +45,7 @@ async def kernel_model(self, kernel_id): context_tracker[kernel_id]["ended"] = current.current_user return {"id": kernel_id, "name": "blah"} - monkeypatch.setattr(AsyncMappingKernelManager, "kernel_model", kernel_model) + monkeypatch.setattr(jp_serverapp.kernel_manager.__class__, "kernel_model", kernel_model) # Make two requests in parallel. await asyncio.gather( diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 6f4fcebc1..d50c2b23c 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -439,7 +439,10 @@ def test_gateway_request_with_expiring_cookies( async def test_gateway_class_mappings(init_gateway, jp_serverapp): # Ensure appropriate class mappings are in place. assert jp_serverapp.kernel_manager_class.__name__ == "GatewayMappingKernelManager" - assert jp_serverapp.session_manager_class.__name__ == "GatewaySessionManager" + assert ( + jp_serverapp.session_manager.kernel_manager.__class__.__name__ + == "GatewayMappingKernelManager" + ) assert jp_serverapp.kernel_spec_manager_class.__name__ == "GatewayKernelSpecManager"