Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from ._import_hook import install_import_hook as install_import_hook
from ._ipython_extension import load_ipython_extension as load_ipython_extension
from ._ipython_extension import unload_ipython_extension as unload_ipython_extension
from ._storage import print_bindings as print_bindings


Expand Down
43 changes: 43 additions & 0 deletions jaxtyping/_ipython_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from ._config import config
from ._import_hook import JaxtypingTransformer, Typechecker


Expand Down Expand Up @@ -54,3 +55,45 @@ def load_ipython_extension(ipython):
raise RuntimeError("Failed to define jaxtyping.typechecker magic") from e

ipython.register_magics(ChooseTypecheckerMagics)


def unload_ipython_extension(ipython):
"""
Support `%unload_ext jaxtyping` to remove the jaxtyping AST transformer
and unregister the `%jaxtyping.typechecker` magic.
"""
if ipython is None:
return

# Disable runtime typechecking globally (covers already-decorated functions).
try:
config.jaxtyping_disable = True
except Exception:
pass

# 1) Remove any JaxtypingTransformer from the AST transformers.
try:
ipython.ast_transformers = [
t for t in getattr(ipython, "ast_transformers", [])
if not isinstance(t, JaxtypingTransformer)
]
except Exception:
# Be permissive: if IPython internals change, don't hard-fail.
pass

# 2) Unregister the `%jaxtyping.typechecker` magic.
try:
mm = getattr(ipython, "magics_manager", None)
if mm is not None:
for kind in ("line", "cell", "line_cell"):
d = mm.magics.get(kind, {})
# Names registered via @line_magic use the explicit string we provided.
for name in ("jaxtyping.typechecker",):
if name in d:
try:
del d[name]
except Exception:
pass
except Exception:
# Also permissive here.
pass
17 changes: 17 additions & 0 deletions test/test_ipython_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,23 @@ def g(x: Float[Array, "1"]):
ip.run_cell(raw_cell='g("string")').raise_error()


def test_unload_extension_disables_typechecking(ip):
ip.run_cell(
raw_cell="""
from jaxtyping import Float, Array
import jax

def g(x: Float[Array, "1"]):
return x + 1

int_arr = jax.numpy.array([1])
"""
).raise_error()

ip.run_cell(raw_cell="%unload_ext jaxtyping").raise_error()
ip.run_cell(raw_cell="g(int_arr)").raise_error()


def test_function_jaxtyped_and_jitted(ip):
ip.run_cell(
raw_cell="""
Expand Down