maxtext.layers.train_state_nnx module
The NNX Unified TrainState.
- class maxtext.layers.train_state_nnx.TrainStateNNX(*args, **kwargs)
Bases: Module
A unified container for NNX models and optimizers. This replaces Linen’s TrainState for checkpointing.
- Linen TrainState pytree:
{"params": {…}, "opt_state": {}…}
- TrainStateNNX state pytree:
{"model": {…}, "optimizer": {"opt_state": {…}}}
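To make the difference between the two layouts concrete, here is a minimal sketch that re-nests a Linen-style state dict into the TrainStateNNX layout and back. It uses plain Python dicts as stand-ins for the real pytrees. The helper names linen_to_nnx_layout and nnx_to_linen_layout are hypothetical and are not part of MaxText. Only the "params" and "opt_state" entries are moved; any other Linen TrainState fields (the "…" above, such as Flax's step counter) are ignored in this sketch.

```python
from typing import Any, Dict


def linen_to_nnx_layout(linen_state: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper (not a MaxText API): re-nest a Linen-style
    TrainState dict into the TrainStateNNX layout."""
    # Model weights move from the top-level "params" key to "model";
    # optimizer state moves under "optimizer" -> "opt_state".
    # Any other top-level keys are dropped in this sketch.
    return {
        "model": linen_state["params"],
        "optimizer": {"opt_state": linen_state["opt_state"]},
    }


def nnx_to_linen_layout(nnx_state: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical inverse of linen_to_nnx_layout."""
    return {
        "params": nnx_state["model"],
        "opt_state": nnx_state["optimizer"]["opt_state"],
    }


if __name__ == "__main__":
    # Toy stand-in for a Linen TrainState pytree (values are placeholders).
    linen_state = {
        "params": {"dense": {"kernel": [[0.1, 0.2]], "bias": [0.0]}},
        "opt_state": {"mu": {"dense": {"kernel": [[0.0, 0.0]], "bias": [0.0]}}},
    }
    nnx_state = linen_to_nnx_layout(linen_state)
    print(sorted(nnx_state))               # ['model', 'optimizer']
    print(sorted(nnx_state["optimizer"]))  # ['opt_state']
    # Round-tripping recovers the original two entries.
    assert nnx_to_linen_layout(nnx_state) == linen_state
```

A real migration would also need to carry over the extra Linen fields; the sketch only shows the top-level re-nesting between the two layouts.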
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any