
v0.11.0


cgarciae released this 29 Jul 21:04

v0.11.0 - Pytrees, MutableArrays, and more!

This version of Flax improves interop with native JAX and adds support for the new jax.experimental.MutableArray (more on this soon!). Aligning with the JAX way of doing things required some breaking changes. Most code should keep working as before; however, the following changes deviate from the current behavior:

  • Rngs in standard layers: standard layers no longer hold a shared reference to the Rngs object passed to their constructor; instead, they keep a forked copy of the Rngs or RngStream objects. This impacts Using Rngs in NNX Transforms and Loading Checkpoints with RNGs.
  • Optimizer Updates: to avoid reference sharing, the Optimizer abstraction no longer holds a reference to the model; instead, the model must be passed as the first argument to update.
  • Modules as Pytrees: Modules are now pytrees! This avoids unnecessary use of split and merge when interacting trivially with raw JAX transforms (state must still be propagated manually if you are not using MutableArrays, and referential transparency is still an issue). This affects code that operates on Pytrees containing NNX Objects with jax.tree.* APIs.

Check out the full NNX 0.10 to NNX 0.11 migration guide.

In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!

What's Changed

New Contributors

Full Changelog: v0.10.7...v0.11.0