maxtext.layers.train_state_nnx module


The NNX Unified TrainState.

class maxtext.layers.train_state_nnx.TrainStateNNX(*args, **kwargs)

Bases: Module

A unified container for NNX models and optimizers. This replaces Linen’s TrainState for checkpointing.

Linen TrainState pytree:

{"params": {…}, "opt_state": {…}, …}

TrainStateNNX state pytree:

{"model": {…}, "optimizer": {"opt_state": {…}}}
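To make the difference between the two layouts concrete, here is a minimal sketch of two hypothetical helpers, linen_to_nnx_layout and nnx_to_linen_layout, that only re-nest the top-level keys. Plain nested dicts stand in for the real pytrees (actual NNX state is held in nnx.State objects, which this sketch does not model), and keys other than "params" and "opt_state" are dropped, so it should not be read as MaxText's own checkpoint-conversion logic.

```python
from typing import Any, Dict


def linen_to_nnx_layout(linen_state: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: re-nest a Linen-style state dict into the
    {"model": ..., "optimizer": {"opt_state": ...}} layout.

    Assumption: only "params" and "opt_state" are carried over; any other
    top-level keys (for example a step counter) are ignored in this sketch.
    """
    return {
        "model": linen_state["params"],
        "optimizer": {"opt_state": linen_state["opt_state"]},
    }


def nnx_to_linen_layout(nnx_state: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical inverse of linen_to_nnx_layout for the same two keys."""
    return {
        "params": nnx_state["model"],
        "opt_state": nnx_state["optimizer"]["opt_state"],
    }
```

Calling linen_to_nnx_layout({"params": p, "opt_state": s}) yields {"model": p, "optimizer": {"opt_state": s}}, and nnx_to_linen_layout reverses the re-nesting.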

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

apply_gradients(grads)

Mimics Linen's TrainState.apply_gradients method: it updates the optimizer state, applies the resulting updates to the model parameters, and increments the step counter.

Parameters:

grads (Any)
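The three steps performed by apply_gradients can be illustrated with a toy, pure-Python sketch. ToyTrainState and its fields are hypothetical, SGD with momentum stands in for whatever optimizer the real state wraps, and mutating the state in place is an assumption chosen to match NNX's stateful style (Linen's TrainState.apply_gradients instead returns a new state).

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ToyTrainState:
    """Hypothetical stand-in for a unified train state: flat float params,
    a momentum buffer as the optimizer state, and a step counter."""
    params: Dict[str, float]
    learning_rate: float = 0.1
    momentum: float = 0.9
    opt_state: Dict[str, float] = field(default_factory=dict)
    step: int = 0

    def __post_init__(self) -> None:
        # Create one zero-initialised momentum slot per parameter if none was given.
        if not self.opt_state:
            self.opt_state = {k: 0.0 for k in self.params}

    def apply_gradients(self, grads: Dict[str, float]) -> None:
        # 1. Update the optimizer state (momentum accumulation).
        self.opt_state = {
            k: self.momentum * self.opt_state[k] + grads[k] for k in self.params
        }
        # 2. Turn the new optimizer state into updates and apply them to the params.
        updates = {k: -self.learning_rate * self.opt_state[k] for k in self.params}
        self.params = {k: self.params[k] + updates[k] for k in self.params}
        # 3. Increment the step counter.
        self.step += 1
```

For example, starting from params={"w": 1.0} with a gradient of 2.0, one call sets the momentum slot to 2.0, moves w to roughly 0.8, and advances step to 1.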