Adding a New Model for LoRA Fine-Tuning

This guide explains how to add Low-Rank Adaptation (LoRA) support for a new model architecture in MaxText.

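MaxText uses Tunix and Qwix to support Parameter-Efficient Fine-Tuning (PEFT) on JAX/NNX model definitions. Because LoRA target modules are selected by a regular expression over module paths rather than by editing model code, adding LoRA support for a new model usually takes only a configuration entry, plus a weight mapping if you plan to decode with vLLM.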


1. Step-by-Step Bring-up Guide for NNX LoRA

To enable LoRA support for a new model, follow these two steps:

Step 1.1: Verify Base Model Support

The target model architecture must already be implemented and supported as a base model in MaxText.

  • The JAX/NNX model definition should be located under src/maxtext/models/ (e.g., gemma3.py).

  • The model configurations must be registered and runnable for baseline pre-training or full fine-tuning.

Step 1.2: Define Trainable LoRA Target Modules

Add a recommended target pattern for your model architecture prefix in src/maxtext/configs/post_train/lora_module_path.yml:

your_model_prefix: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
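A quick way to sanity-check a pattern before adding it is to run it through Python's re module against a few module paths. This is a minimal sketch: the paths are illustrative assumptions (including a hypothetical query_norm module), written in the "/"-joined form used by the verification loop in Section 3, and re.search is used on the assumption that matching is substring-based.

```python
import re

# The recommended pattern from lora_module_path.yml (same as the "default" entry).
PATTERN = r"decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
compiled = re.compile(PATTERN)

# Illustrative module paths; real names depend on your model definition.
candidate_paths = [
    "decoder/layers/self_attention/query",
    "decoder/layers/self_attention/out",
    "decoder/layers/mlp/wi_0",
    "decoder/layers/mlp/wo",
    "decoder/layers/self_attention/query_norm",  # hypothetical QK-norm module
    "decoder/layers/pre_self_attention_layer_norm",
    "decoder/decoder_norm",
    "token_embedder",
]

# re.search is unanchored: a path matches if any substring fits the pattern.
matched = [p for p in candidate_paths if compiled.search(p)]
for p in matched:
    print("match:", p)
```

The query_norm hit shows that an unanchored search also matches modules whose names only begin with a target name. Whether such a module actually receives an adapter depends on the LoRA provider, so confirm the result on the real model with the check in Section 3.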

Note: MaxText’s _get_lora_module_path in lora_utils.py automatically handles both scanned (e.g., layers/0/self_attention/...) and unscanned (e.g., layers/self_attention/...) layer formats by injecting an optional layer-index regex, so you only need to define standard, unscanned paths.
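The idea of an optional layer index can be illustrated with a small, hypothetical helper. This is not the code of _get_lora_module_path; it is a sketch that assumes the index appears as a "<digits>/" segment right after "layers/". Because the recommended pattern's ".*" already spans a layer index, the effect is easiest to see on a stricter, wildcard-free pattern.

```python
import re


def add_optional_layer_index(pattern: str, anchor: str = "layers/") -> str:
    """Hypothetical helper: make a '<digits>/' segment after `anchor` optional.

    Illustrates the idea only; this is not MaxText's _get_lora_module_path.
    """
    # Plain string replace (not a regex substitution) on the first occurrence.
    return pattern.replace(anchor, anchor + r"(?:\d+/)?", 1)


# A strict, wildcard-free unscanned pattern.
unscanned = r"decoder/layers/self_attention/(query|key|value|out)"
both = re.compile(add_optional_layer_index(unscanned))

print(bool(both.search("decoder/layers/self_attention/query")))         # True (unscanned)
print(bool(both.search("decoder/layers/0/self_attention/query")))       # True (scanned)
print(bool(re.search(unscanned, "decoder/layers/0/self_attention/query")))  # False
```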

If no prefix matches your model name, MaxText falls back to the default pattern:

default: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
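The prefix lookup with its fallback can be pictured as follows. The dict stands in for the parsed YAML file, and both the function name and the matching rule (model name starts with the key; longest matching prefix wins) are assumptions for illustration; check lora_utils.py for the actual precedence.

```python
# Hypothetical stand-in for the parsed lora_module_path.yml.
DEFAULT = r"decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
PATTERNS = {
    "default": DEFAULT,
    "my_model": r"decoder/layers/.*self_attention/(query|key|value|out)",  # hypothetical
    "my_model_v2": r"decoder/layers/.*mlp/(wi_0|wi_1|wo)",                  # hypothetical
}


def select_lora_pattern(model_name: str, patterns: dict) -> str:
    """Hypothetical lookup: longest key that prefixes model_name wins, else 'default'."""
    candidates = [k for k in patterns if k != "default" and model_name.startswith(k)]
    if not candidates:
        return patterns["default"]
    return patterns[max(candidates, key=len)]


print(select_lora_pattern("my_model-7b", PATTERNS))     # the "my_model" entry
print(select_lora_pattern("my_model_v2-1b", PATTERNS))  # the longer "my_model_v2" entry
print(select_lora_pattern("other-model", PATTERNS))     # falls back to "default"
```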

2. Integrating Custom Weight Mappings (When is it needed?)

Whether you need to implement custom weight mappings depends on your downstream workflow:

Scenario A: SFT Training & Conversion to PEFT (No Mapping Needed)

If you only need to run supervised fine-tuning (SFT) with LoRA and then export the adapter back to Hugging Face format using to_huggingface.py, you do not need to write any custom weight mappings.

  • The conversion utility automatically maps, scales, and formats the LoRA adapter parameters back into standard Hugging Face PEFT format based on the base model’s existing weight mapping.
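You do not need to write this yourself, but a small numerical sketch shows what mapping, scaling, and formatting LoRA factors involves. It assumes the MaxText side stores factors a with shape (in, r) and b with shape (r, out) and applies a scale src_scale; those names, shapes, and the scale are assumptions, not a description of to_huggingface.py. The PEFT side uses the standard lora_A.weight (r, in) and lora_B.weight (out, r) layout with lora_alpha / r scaling.

```python
import numpy as np


def to_peft_factors(a, b, src_scale, lora_alpha, rank):
    """Hypothetical conversion of JAX-layout LoRA factors to PEFT-layout factors.

    Source layer (assumed): y = x @ W + src_scale * (x @ a) @ b,
        a: (in_features, rank), b: (rank, out_features).
    PEFT layer: y = W x + (lora_alpha / rank) * lora_B @ lora_A @ x,
        lora_A: (rank, in_features), lora_B: (out_features, rank).
    """
    peft_scale = lora_alpha / rank
    lora_A = a.T                             # transpose to (rank, in)
    lora_B = b.T * (src_scale / peft_scale)  # transpose to (out, rank) and fold in rescale
    return lora_A, lora_B


rng = np.random.default_rng(0)
in_f, out_f, r = 8, 6, 2
a = rng.normal(size=(in_f, r))
b = rng.normal(size=(r, out_f))
x = rng.normal(size=(3, in_f))  # batch of 3 inputs

src_scale = 0.5    # hypothetical source-side scale
lora_alpha = 16.0  # hypothetical PEFT lora_alpha
lora_A, lora_B = to_peft_factors(a, b, src_scale, lora_alpha, r)

delta_src = src_scale * (x @ a) @ b                        # (batch, out)
delta_peft = ((lora_alpha / r) * lora_B @ lora_A @ x.T).T  # (batch, out)
print(np.allclose(delta_src, delta_peft))  # True
```

Folding the scale ratio into lora_B keeps the adapter's contribution unchanged even when the two sides use different scaling conventions.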

Scenario B: Decoding with the MaxText vLLM Adapter (Mapping is Required)

If you want to perform decoding or run high-performance serving on your adapted model using the MaxText vLLM adapter (e.g., via vllm_decode), you must define and register a custom weight mapping. This allows the vLLM JAX wrapper to dynamically map and feed weights to the vLLM engine.

To add weight mapping for vLLM decode:

  1. Create a Weight Mapping Config: Create a new file in src/maxtext/integration/tunix/weight_mapping/ (e.g., your_model.py) defining a mapping dataclass. You can refer to llama3.py as a template.

    Your class should specify:

    • to_hf_mapping(): Maps MaxText base parameters to Hugging Face parameters and specifies their sharding axes.

    • to_hf_hook_fns(): Custom hook functions for complex weight transformations (e.g., RoPE reordering or query scaling).

    • lora_to_hf_mappings(): Custom mapping for LoRA weights if they require different handling.

  2. Register the Mapping: Register your new class in src/maxtext/integration/tunix/weight_mapping/__init__.py inside the StandaloneVllmWeightMapping class:

    # Inside StandaloneVllmWeightMapping
    if name.startswith("your_model_name"):
        return YOUR_MODEL_VLLM_MAPPING
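To show how the pieces of the two steps fit together, here is a self-contained sketch of a mapping class with the three methods listed above and a prefix-based lookup. Everything here is hypothetical: the class name, the dict formats, the parameter paths, the sharding-axis names, and the query_scale hook are placeholders, and the real formats should be copied from llama3.py and the actual StandaloneVllmWeightMapping code.

```python
import dataclasses
from typing import Callable, Dict, Tuple

import numpy as np

# Hypothetical placeholder types; llama3.py defines the real formats.
ParamMapping = Dict[str, Tuple[str, Tuple]]
HookFns = Dict[str, Callable[[np.ndarray], np.ndarray]]


@dataclasses.dataclass(frozen=True)
class YourModelMapping:
    """Hypothetical skeleton mirroring the three methods described above."""

    query_scale: float = 1.0  # hypothetical, model-specific constant

    def to_hf_mapping(self) -> ParamMapping:
        # MaxText param path -> (HF param path, sharding axes); all names are placeholders.
        return {
            "decoder.layers.self_attention.query.kernel": (
                "model.layers.*.self_attn.q_proj.weight", ("embed", "heads", "head_dim")),
            "decoder.layers.mlp.wo.kernel": (
                "model.layers.*.mlp.down_proj.weight", ("mlp", "embed")),
        }

    def to_hf_hook_fns(self) -> HookFns:
        # A simple transformation hook: rescale the query projection weights.
        def scale_query(x: np.ndarray) -> np.ndarray:
            return x * self.query_scale

        return {"decoder.layers.self_attention.query.kernel": scale_query}

    def lora_to_hf_mappings(self) -> ParamMapping:
        # Only needed if LoRA weights require handling that differs from the base mapping.
        return {}


YOUR_MODEL_VLLM_MAPPING = YourModelMapping(query_scale=0.5)


def lookup_mapping(name: str) -> YourModelMapping:
    """Hypothetical dispatch in the style of StandaloneVllmWeightMapping."""
    if name.startswith("your_model_name"):
        return YOUR_MODEL_VLLM_MAPPING
    raise ValueError(f"No vLLM weight mapping registered for {name!r}")


mapping = lookup_mapping("your_model_name-7b")
hook = mapping.to_hf_hook_fns()["decoder.layers.self_attention.query.kernel"]
print(hook(np.ones((2, 2))))
```

If an existing branch tests a prefix that is also a prefix of your model name, place your more specific check first; otherwise the earlier startswith test wins.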
    

3. Verifying Your Custom LoRA Targets

If you are developing or bringing up your model architecture interactively (e.g., in a Python interpreter or Jupyter notebook), you can verify which layers are wrapped with LoRA adapters by inspecting the model’s module graph:

import re
from flax import nnx
from maxtext.utils import model_creation_utils, lora_utils

# 1. Create the model and mesh (assumes mt_config is an already-loaded MaxText config)
model, mesh = model_creation_utils.from_pretrained(mt_config)

# 2. Build the LoRA module-path regex for this config and compile it
compiled_module_path = re.compile(lora_utils._get_lora_module_path(mt_config))

# 3. Iterate over modules to see exactly which ones matched your pattern
for path, _ in nnx.iter_modules(model):
    module_path = "/".join(str(p) for p in path)
    if compiled_module_path.search(module_path):
        print(f"Matched and wrapped with LoRA: {module_path}")

This check lets you confirm interactively, during development, exactly which modules your pattern will wrap with LoRA adapters.
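The loop above prints what matched but not what was missed. A small pure-Python helper can also report target module names that matched nothing, which catches typos in the pattern. The helper name and the expected_leaves argument are hypothetical, and the example path tuples are illustrative stand-ins for the paths yielded by nnx.iter_modules.

```python
import re
from typing import Iterable, List, Sequence, Tuple, Union

PathKey = Union[str, int]


def check_lora_coverage(
    paths: Iterable[Tuple[PathKey, ...]],
    pattern: str,
    expected_leaves: Sequence[str],
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (matched paths, expected leaf names never matched)."""
    compiled = re.compile(pattern)
    # Join path keys with "/" the same way the loop above does.
    joined = ["/".join(str(k) for k in path) for path in paths]
    matched = [p for p in joined if compiled.search(p)]
    # A leaf counts as covered if some matched path ends with that module name.
    matched_leaves = {p.rsplit("/", 1)[-1] for p in matched}
    missing = [leaf for leaf in expected_leaves if leaf not in matched_leaves]
    return matched, missing


# Illustrative scanned-style paths standing in for nnx.iter_modules output.
paths = [
    ("decoder", "layers", 0, "self_attention", "query"),
    ("decoder", "layers", 0, "self_attention", "out"),
    ("decoder", "layers", 0, "mlp", "wi_0"),
    ("decoder", "layers", 0, "mlp", "wo"),
]
# Deliberate typo: "wi0" instead of "wi_0".
typo_pattern = r"decoder/layers/.*(self_attention/(query|out)|mlp/(wi0|wo))"

matched, missing = check_lora_coverage(paths, typo_pattern, ["query", "out", "wi_0", "wo"])
print("missing:", missing)  # missing: ['wi_0']
```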