Source code for maxtext.layers.train_state_nnx

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" The NNX Unified TrainState. """

from typing import Any

from flax import nnx


[docs] class TrainStateNNX(nnx.Module): """ A unified container for NNX models and optimizers. This replaces Linen's TrainState for checkpointing. Linen TrainState pytree: {“params”: {...}, “opt_state”: {}...} TrainStateNNX state pytree: {“model”: {...}, “optimizer”: {“opt_state”: {...}} """ def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): self.model = model self.optimizer = optimizer
[docs] def apply_gradients(self, grads: Any): """ Mimics the Linen apply_gradients function. Updates the optimizer state, applies updates to parameters, and increments the step counter. """ if self.optimizer is None: raise RuntimeError( "Cannot call apply_gradients on a TrainStateNNX initialized without an optimizer. " "This usually happens when the state was created for inference only." ) self.optimizer.update(self.model, grads)