# MaxText and the JAX ecosystem

MaxText is built on a curated stack of JAX libraries, each chosen for a specific purpose. This document provides an opinionated view on *why* MaxText uses the following key components of the JAX ecosystem:

- **Flax (NNX)**: For ergonomic and functional model definition.
- **Optax**: For composable optimization.
- **Orbax**: For robust checkpointing.
- **Grain**: For deterministic, multi-host data loading.
- **Qwix**: For native JAX quantization.
- **Tunix**: For modular fine-tuning.

This stack isn't just a random collection of tools; it represents a design philosophy centered around **explicitness, composability, and performance at scale**.

This document provides an opinionated view on *why* MaxText uses these specific libraries, explaining the design decisions that make them ideal for building and training large-scale models.

## Flax: For functional model definition

**What is it?** Flax is a high-performance neural network library for JAX that is designed to be flexible, explicit, and easy to use.

With its latest generation API, NNX, Flax provides a modern, object-oriented (OOP) approach that makes defining and managing models more intuitive and Pythonic.

1. **Explicit State Management**: Unlike stateful frameworks where parameters are hidden attributes of an object, Flax treats model parameters (`params`, `batch_stats`, etc.) as explicit arguments to its functions. This transparency is crucial for debugging and managing distributed state.
2. **Deep JAX Integration**: Flax's NNX is designed from the ground up to work seamlessly with JAX's powerful transformations like `jax.jit` and `jax.grad`. This enables high performance and scalability without sacrificing ease of use.
3. **Flexibility through PyTrees**: All model state is stored in standard JAX PyTrees (nested dictionaries), making it trivial to inspect and manipulate any part of the model.

For more information on using Flax, please refer to https://github.com/google/flax

## Optax: For composable optimization

**What is it?** Optax is a gradient processing and optimization library for JAX. It reimagines the optimizer as a series of composable functional transformations.

1. **Decoupling Optimization from Parameters**: Optax completely separates the optimizer's state from the model's parameters, treating the update step as a pure function.
2. **The Power of `optax.chain`**: The core design pattern in Optax is chaining gradient transformations. This makes it easy to build custom optimizers by combining building blocks like gradient clipping, weight decay, and a learning rate schedule.
3. **Rich Library of Pre-Built Optimizers**: While Optax is ideal for building custom optimizers, it also comes with a wide range of popular optimizers like `optax.adamw`, `optax.adam`, and `optax.sgd` ready to use out-of-the-box. This provides the flexibility to start with a standard optimizer and only customize when needed.

For more information on using Optax, please refer to https://github.com/google-deepmind/optax

## Orbax: For robust checkpointing

**What is it?** Orbax is a library for checkpointing JAX programs, designed for large-scale, potentially unreliable environments.

**Why does MaxText use it?**

For massive models, saving and loading state is a critical part of the training infrastructure.

1. **Asynchronous Checkpointing**: It writes large checkpoints to storage in the background without stalling the expensive TPU/GPU accelerators, maximizing hardware utilization.
2. **Checkpoint Management**: Orbax provides a `CheckpointManager` that handles the entire lifecycle of checkpoints, including versioning, keeping the N most recent saves, and ensuring atomic writes to prevent corruption.
3. **Handling Scanned and Unscanned Formats**: For performance, MaxText uses `jax.lax.scan` over its transformer layers. This results in an efficient "scanned" checkpoint format where layer parameters are stacked along a single array axis. For interoperability with other frameworks and inference, a more standard "unscanned" (layer-by-layer) format is often required. Orbax is used to reliably save, load, and convert between both formats, enabling both efficient training, inference, and easy model sharing.
4. **Facilitating Checkpoint Conversion**: When importing models from other ecosystems like Hugging Face, Orbax provides the final, critical step. Conversion scripts first load external weights (e.g., from `.safetensors` files) and map them to a JAX PyTree. Orbax is then used to save this PyTree as a native, MaxText-compatible checkpoint, providing a robust and standardized endpoint for the conversion pipeline.

For more information on using Orbax, please refer to https://github.com/google/orbax

## Grain: For deterministic, multi-host data loading

**What is it?** Grain is a high-performance data loading library designed for deterministic, global shuffle and multi-host data loading.

1. **Deterministic by Design**: Grain allows storing data loader states, provides strong guarantees about data ordering and sharding even with preemptions, which is critical for reproducibility.
2. **Global Shuffle**: Prevents local overfitting.
3. **Built for Multi-Host Training**: The using random access file format streamlines [data loading in the multi-host environments](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/data_input_pipeline.html#multihost-dataloading-best-practice).

Its APIs are explicitly designed for the multi-host paradigm, simplifying the process of ensuring that each host loads a unique shard of the global batch.

For more information on using Grain, please refer to https://github.com/google/grain and the grain guide in maxtext located at https://maxtext.readthedocs.io/en/latest/guides/data_input_pipeline/data_input_grain.html

## Qwix: For native JAX quantization

**What is it?** Qwix is a Jax quantization library supporting Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ)

1. **Enables State-of-the-Art Techniques**: It provides core quantization formats (e.g., int8 & fp8) and functions to `quantize` and `dequantize` tensors, which are essential for modern efficient training methods.
2. **JAX-Native Integration**: Its operations and data types are designed to work seamlessly with JAX's transformations (`jit`, `pmap`) and PyTree data structures.

We chose Qwix because it provides the necessary primitives **natively within the JAX ecosystem**. Using a non-native library would require inefficient "boundary crossing" to move data in and out of JAX's control. Qwix's functions are just another JAX operation, allowing them to be composed and JIT-compiled along with the rest of the model.

For more information on how to quantize your model using Qwix, please refer to https://github.com/google/qwix

## Tunix: For comprehensive post-training

**What is it?** Tunix is a JAX-based library designed for a wide range of post-training tasks, including Supervised Fine-Tuning (SFT), Reinforcement Learning (RL), and Parameter-Efficient Fine-Tuning (PEFT) methods like LoRA.

**Why does MaxText use it?**

MaxText leverages Tunix as its core library for post-training, offering a unified and high-performance platform for adapting base models.

1. **Unified Post-Training Framework**: Tunix provides a consistent API and infrastructure for various post-training techniques, reducing the need for separate implementations for SFT, RL, and PEFT.
2. **State-of-the-Art RL Integration**: Tunix integrates with vLLM for efficient RL sampling, enabling advanced algorithms like Group Relative Policy Optimization (GRPO). This allows for fine-tuning models based on complex reward signals.
3. **NNX Compatibility**: Tunix is designed to work with NNX, the latest generation of Flax, allowing it to leverage the newest JAX features and a more modern API.
4. **Modularity for PEFT**: While offering full fine-tuning, Tunix also maintains strong support for PEFT methods. Techniques like LoRA are implemented as composable Flax/NNX modules, allowing easy application to existing models without altering their core structure.

We chose Tunix because it provides a **comprehensive, performant, and JAX-native solution for the entire post-training lifecycle**. Its integration with libraries like vLLM and its alignment with the NNX ecosystem make it a powerful tool for both full model adaptation and parameter-efficient tuning.

For more information on using Tunix, please refer to https://github.com/google/tunix
