maxtext.layers.nnx_wrappers module

NNX <> Linen interoperability.

maxtext.layers.nnx_wrappers.is_vanilla_variable(vs)

A variable state is vanilla if its metadata is essentially blank.

Returns False only if it has non-empty hooks or any non-built-in attribute.

Parameters:

vs (Variable)

Return type:

bool
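The rule above can be illustrated with a small pure-Python sketch. It is not the module's implementation: it takes a plain dict as a stand-in for a variable's metadata, and it assumes that hook entries are recognizable by a `_hooks` key suffix and that every other key counts as a non-built-in attribute. The name `is_vanilla_metadata` is hypothetical.

```python
from typing import Any, Mapping


def is_vanilla_metadata(metadata: Mapping[str, Any]) -> bool:
  """Return True when `metadata` carries nothing beyond empty hooks.

  Assumption: hook entries use a `_hooks` key suffix, and an empty hook
  collection is an empty tuple or list.
  """
  for key, value in metadata.items():
    if key.endswith("_hooks"):
      if len(value) > 0:  # a non-empty hook makes the state non-vanilla
        return False
    else:
      return False  # any other attribute is treated as custom metadata
  return True


# Empty hooks only: vanilla. Extra attribute (e.g. sharding): not vanilla.
plain = is_vanilla_metadata({"get_value_hooks": ()})
annotated = is_vanilla_metadata({"sharding": ("data", "model")})
```

Under these assumptions, a state with only empty hooks can be treated as a raw value, while one carrying extra attributes would need its metadata preserved when it is converted.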

maxtext.layers.nnx_wrappers.to_linen_var(vs)
Parameters:

vs (Variable)

Return type:

AxisMetadata

maxtext.layers.nnx_wrappers.get_col_name(keypath)

Given the keypath of a Flax variable type, return its Linen collection name.

Parameters:

keypath (Sequence[Any])

Return type:

str

maxtext.layers.nnx_wrappers.to_nnx_var(col, x)

Convert a Linen variable to an NNX variable.

Parameters:
  • col (str)

  • x (AxisMetadata | Any)

Return type:

Variable

maxtext.layers.nnx_wrappers.linen_vars_to_nnx_attrs(variables)

Convert a dict of Linen-style variables to NNX variables.

Parameters:

variables (Mapping[str, Any])

Return type:

dict[str, Any]

maxtext.layers.nnx_wrappers.nnx_attrs_to_linen_vars(nnx_attrs)

Convert a dict of NNX variables (or variable states) to Linen-style variables.

Parameters:

nnx_attrs (dict)

Return type:

dict
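The two conversions above are inverses in spirit: Linen groups values by collection name first, while NNX attaches the collection to each variable through its type. The sketch below illustrates that regrouping in plain Python. It is an assumption-laden analogy, not the module's code: `TaggedVar`, `linen_to_attrs`, and `attrs_to_linen` are hypothetical names, the trees are a single level deep, and attribute names are assumed unique across collections.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class TaggedVar:
  """Hypothetical stand-in for an NNX variable: a value plus its collection."""
  col: str
  value: Any


def linen_to_attrs(variables: dict[str, dict[str, Any]]) -> dict[str, TaggedVar]:
  """Drop the collection level and tag each leaf with its collection name."""
  attrs = {}
  for col, tree in variables.items():
    for name, value in tree.items():
      attrs[name] = TaggedVar(col, value)
  return attrs


def attrs_to_linen(attrs: dict[str, TaggedVar]) -> dict[str, dict[str, Any]]:
  """Regroup tagged leaves under their collection names."""
  variables: dict[str, dict[str, Any]] = {}
  for name, var in attrs.items():
    variables.setdefault(var.col, {})[name] = var.value
  return variables


linen_vars = {
    "params": {"kernel": [[1.0]], "bias": [0.0]},
    "batch_stats": {"mean": [0.0]},
}
attrs = linen_to_attrs(linen_vars)
round_trip = attrs_to_linen(attrs)  # same structure as linen_vars
```

The real functions work on Flax variable types and nested submodules; the sketch only shows why a round trip through both conversions can preserve the original structure.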

maxtext.layers.nnx_wrappers.lazy_init(fn, *args, **kwargs)

Runs an arbitrary nnx.Module method and initializes all the state it needs.

Used here to trigger initialization of all ToNNX module variables.

Parameters:

fn (Module | Callable[[...], Any])

maxtext.layers.nnx_wrappers.current_linen_module()

Get the current Linen module from the Linen context.

Return type:

Module | None

class maxtext.layers.nnx_wrappers.ToNNX(*args, **kwargs)

Bases: Module

A wrapper to turn any Linen module into an NNX module.

The resulting NNX module can be used standalone with all NNX APIs, or as a submodule of another NNX module.

Since Linen module initialization requires a sample input, you need to call lazy_init with an argument to initialize the variables.

Example:

>>> from flax import linen as nn, nnx
>>> import jax
>>> linen_module = nn.Dense(features=64)
>>> x = jax.numpy.ones((1, 32))
>>> # Like Linen init(), initialize with a sample input
>>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
>>> # Like Linen apply(), but using NNX's direct call method
>>> y = model(x)
>>> model.kernel.shape
(32, 64)
Parameters:
  • module – The Linen Module instance.

  • rngs – The nnx.Rngs instance being passed to any NNX module.

  • args (Any)

  • kwargs (Any)

Returns:

A stateful NNX module that behaves the same as the wrapped Linen module.

Return type:

Any

lazy_init(*args, **kwargs)

A shortcut for calling nnx.bridge.lazy_init() on this module.

maxtext.layers.nnx_wrappers.linen_rngs_dict(linen_module, add_default=False)

Given a module, split out one key from each of its active RNG key collections.

Parameters:
  • linen_module (Module)

  • add_default (bool)

class maxtext.layers.nnx_wrappers.ToLinen(nnx_class, args=(), kwargs=FrozenDict({}), skip_rng=False, metadata_fn=<function to_linen_var>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A wrapper to turn any NNX module into a Linen module.

The resulting Linen module can be used standalone with all Linen APIs, or as a submodule of another Linen module.

Since NNX modules are stateful and own their state, we create the module only once, at init time, and track its state and static data as separate variables.

Example:

>>> from flax import linen as nn, nnx
>>> import jax
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> x = jax.numpy.ones((1, 32))
>>> y, variables = model.init_with_output(jax.random.key(0), x)
>>> y.shape
(1, 64)
>>> variables['params']['kernel'].shape
(32, 64)
>>> # The collections present in the returned variables
>>> variables.keys()
dict_keys(['params'])
Parameters:
  • nnx_class (Callable[[...], Module]) – The NNX Module class (not instance!).

  • args (Sequence) – The arguments that normally would be passed in to create the NNX module.

  • kwargs (Mapping[str, Any]) – The keyword arguments that normally would be passed in to create the NNX module.

  • skip_rng (bool) – True if this NNX module doesn’t need rngs arg during initialization (not common).

  • metadata_fn (Callable[[Variable], Any] | None)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

Returns:

A Linen module that behaves the same as the wrapped NNX module.

nnx_class: Callable[[...], Module]
args: Sequence = ()
kwargs: Mapping[str, Any] = FrozenDict({})
skip_rng: bool = False
metadata_fn()
Parameters:

vs (Variable)

Return type:

AxisMetadata

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
maxtext.layers.nnx_wrappers.to_linen(nnx_class, *args, metadata_fn=<function to_linen_var>, name=None, skip_rng=False, abstract_init=True, **kwargs)

Shortcut for nnx.bridge.ToLinen when the user does not change any of its default fields.

Parameters:
  • nnx_class (Callable[[...], Module])

  • metadata_fn (Callable[[Variable], Any] | None)

  • name (str | None)

  • skip_rng (bool)

  • abstract_init (bool)

maxtext.layers.nnx_wrappers.to_linen_class(base_nnx_class, base_metadata_fn=<function to_linen_var>, base_skip_rng=False, **partial_kwargs)

Returns a dynamically created Linen Module class that wraps a specific NNX Module.

This class is not meant to be used directly. Instead, it is created and returned by the to_linen_class function. It acts as a “partially applied” version of the ToLinen wrapper, where the NNX module to be wrapped and its default arguments are pre-configured.

When you instantiate this class, it behaves like a standard Linen module. The arguments you provide during instantiation can override the defaults that were set when this class was created by to_linen_class.

For example:
>>> from flax import linen as nn, nnx
>>> from maxtext.layers import linears
>>> from maxtext.layers.nnx_wrappers import to_linen_class
>>> # Create a specialized Linen wrapper for linears.DenseGeneral
>>> LinenDenseGeneral = to_linen_class(linears.DenseGeneral)
>>> # Now, LinenDenseGeneral can be used like a regular Linen module
>>> class MyModel(nn.Module):
...   def setup(self):
...     # Instantiate the wrapped linears.DenseGeneral with its arguments
...     self.dense = LinenDenseGeneral(
...         in_features_shape=10, out_features_shape=5
...     )
...   def __call__(self, x):
...     return self.dense(x)
Parameters:
  • base_nnx_class (type[M])

  • base_metadata_fn (Callable[[Variable], Any] | None)

  • base_skip_rng (bool)

  • partial_kwargs (Any)

Return type:

type[ToLinen]

(The attributes are dynamically set by the `ToLinen` parent class based on the arguments provided during instantiation.)
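The "partially applied" behavior described for the returned class can be pictured with a plain-Python class factory. This is only an analogy under stated assumptions, not how to_linen_class is implemented: `Wrapper`, `make_wrapper_class`, and the toy `Dense` class are hypothetical names, and no Flax machinery is involved.

```python
from typing import Any


class Wrapper:
  """Hypothetical generic wrapper that stores a target class and its kwargs."""

  def __init__(self, target_class: type, **kwargs: Any):
    self.target_class = target_class
    self.kwargs = kwargs

  def build(self) -> Any:
    # Construct the wrapped object from the stored arguments.
    return self.target_class(**self.kwargs)


def make_wrapper_class(target_class: type, **default_kwargs: Any) -> type:
  """Return a Wrapper subclass with `target_class` and defaults baked in."""

  class Specialized(Wrapper):

    def __init__(self, **kwargs: Any):
      # Instantiation-time arguments override the defaults set at creation.
      super().__init__(target_class, **{**default_kwargs, **kwargs})

  Specialized.__name__ = f"Wrapped{target_class.__name__}"
  return Specialized


class Dense:  # toy target class, not a real layer
  def __init__(self, in_features: int, out_features: int):
    self.in_features = in_features
    self.out_features = out_features


WrappedDense = make_wrapper_class(Dense, in_features=10, out_features=5)
layer = WrappedDense(out_features=8).build()  # overrides only out_features
```

The factory fixes the wrapped class once, so call sites only pass the arguments that differ from the defaults, which mirrors how the specialized wrapper class is meant to be instantiated.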