maxtext.checkpoint_conversion.to_maxtext module
This script converts a HuggingFace model checkpoint to a MaxText-compatible Orbax checkpoint.
- Key Parameters (to be set in the config file or as command-line overrides):
- model_name: (Required) The name of the model to convert (e.g., “gemma2-2b”).
Must be a key in maxtext.utils.globals.HF_IDS.
- base_output_directory: (Optional) The directory where the converted MaxText
checkpoint will be saved. Can be a local path, a GCS path (gs://…), or a HuggingFace Hub repo ID (hf://…). Defaults to “./mt_output/”.
- scan_layers: (bool) Whether the MaxText model was trained with scanned layers.
This must match the training configuration of the checkpoint.
- --lazy_load_tensors: (bool) If True, uses an on-demand loading strategy to minimize RAM
usage during conversion. Recommended if 2 * model_size (GB) >= system RAM. Defaults to False.
- --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
If unspecified, the default Hugging Face repository ID is used (e.g., openai/gpt-oss-20b; see HF_IDS[model_name] in maxtext.utils.globals). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
- --save_dtype: (Optional) Specifies the data type of saved model weights.
Defaults to bfloat16 to save memory.
- Environment Variables:
- HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
download models from HuggingFace Hub.
- Example Usage:
To convert a gemma2-2b model and save it to a specific directory:
python -m maxtext.checkpoint_conversion.to_maxtext maxtext/configs/base.yml model_name="gemma2-2b" base_output_directory="/path/to/your/output/directory" hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True scan_layers=False
For models with scanned layers (e.g., some custom architectures), you might need to set scan_layers=True and param_scan_axis accordingly.
To convert a 70B model with minimal RAM usage:
python -m maxtext.checkpoint_conversion.to_maxtext maxtext/configs/base.yml model_name="llama3.1-70b" base_output_directory="gs://my-bucket/maxtext-checkpoints" hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True --lazy_load_tensors=True
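The rule of thumb for lazy_load_tensors (2 * model_size >= system RAM) can be checked before launching a conversion. The helpers below are a hypothetical sketch and not part of MaxText: the function names are illustrative, and the size estimate assumes bfloat16 weights at 2 bytes per parameter.

```python
def estimate_model_size_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Rough checkpoint size in GB, assuming bf16 weights (2 bytes each)."""
    return num_params * bytes_per_param / 1e9


def should_lazy_load(model_size_gb: float, system_ram_gb: float) -> bool:
    """Apply the rule of thumb: lazy-load when 2 * model_size >= system RAM."""
    return 2 * model_size_gb >= system_ram_gb


# A 70B-parameter model is roughly 140 GB in bf16, so 2 * 140 >= 256.
print(should_lazy_load(estimate_model_size_gb(70e9), system_ram_gb=256))  # True
# A 2B-parameter model is roughly 4 GB in bf16, so 2 * 4 < 64.
print(should_lazy_load(estimate_model_size_gb(2e9), system_ram_gb=64))  # False
```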
- class maxtext.checkpoint_conversion.to_maxtext.LazyHFLoader(model_id, token, revision=None)
Bases: object
Loads Hugging Face weights on-demand to minimize RAM usage.
This class is the core of the “lazy loading” feature. Instead of loading the entire model into memory at once, it reads the model’s index file (e.g., model.safetensors.index.json) to understand the mapping between tensor names and the shard files they belong to.
When a specific tensor is requested via get_tensor, this class:
1. Identifies the correct shard file.
2. Downloads the shard file if not already cached by huggingface_hub.
3. Opens the shard and extracts only the requested tensor into memory.
This approach is highly memory-efficient, especially for safetensors, as it avoids loading entire multi-gigabyte shard files when only a small piece is needed. A threading lock (_ram_lock) is used to ensure that memory-intensive file-opening operations are serialized to prevent RAM spikes, while downloads can still occur in parallel.
- get_tensor(key)
Retrieves a specific tensor by name, lazily loading its shard if necessary.
This is the main entry point for accessing model weights. It determines which shard file contains the tensor, ensures it’s downloaded, and then reads the tensor data.
For safetensors, this is extremely efficient as it memory-maps the file and reads only the required tensor’s data from disk.
- Parameters:
key (str)
- Return type:
ndarray
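The index-then-shard lookup described for LazyHFLoader can be illustrated with a small local sketch. This is not the MaxText implementation: ToyLazyLoader is a hypothetical class, the shards are local .npz files standing in for safetensors shards on the Hub, and index.json is an assumed file name that mimics the "weight_map" layout of model.safetensors.index.json.

```python
import json
import os
import tempfile
import threading

import numpy as np


class ToyLazyLoader:
    """Hypothetical stand-in for LazyHFLoader; reads local .npz shards, not the Hub."""

    def __init__(self, model_dir: str):
        self.model_dir = model_dir
        # The index maps each tensor name to the shard file that stores it.
        with open(os.path.join(model_dir, "index.json")) as f:
            self.weight_map = json.load(f)["weight_map"]
        # Serialize shard reads so concurrent callers do not spike RAM.
        self._ram_lock = threading.Lock()

    def get_tensor(self, key: str) -> np.ndarray:
        shard_path = os.path.join(self.model_dir, self.weight_map[key])
        with self._ram_lock:
            # An .npz archive is read member by member, so only `key` is loaded.
            with np.load(shard_path) as shard:
                return shard[key]


# Build two tiny shards plus an index, then fetch one tensor on demand.
model_dir = tempfile.mkdtemp()
np.savez(os.path.join(model_dir, "shard-0.npz"), **{"embed.weight": np.ones((4, 2))})
np.savez(os.path.join(model_dir, "shard-1.npz"), **{"lm_head.weight": np.zeros((2, 4))})
with open(os.path.join(model_dir, "index.json"), "w") as f:
    json.dump(
        {"weight_map": {"embed.weight": "shard-0.npz", "lm_head.weight": "shard-1.npz"}}, f
    )

loader = ToyLazyLoader(model_dir)
print(loader.get_tensor("lm_head.weight").shape)  # (2, 4)
```

Only the shard named in the index is opened for a given key, so the other shard is never read in this run.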
- class maxtext.checkpoint_conversion.to_maxtext.LazyTensor(load_fn, shape, dtype, name='unknown')
Bases: object
A proxy object that looks like a NumPy array but delays actual loading and transformation until __array__ is called (e.g., by Orbax during save).
- Parameters:
load_fn (Callable[[], ndarray])
shape (tuple)
name (str)
- property size
Total number of elements in the tensor.
- property nbytes
Returns the estimated number of bytes so Orbax does not need to load the real array to find out.
- property itemsize
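The proxy pattern behind LazyTensor can be sketched in a few lines of NumPy. ToyLazyTensor below is a hypothetical class that mirrors the documented constructor and properties; how the real class computes nbytes or handles dtype conversion is an assumption here.

```python
import numpy as np


class ToyLazyTensor:
    """Hypothetical proxy mirroring the documented LazyTensor interface."""

    def __init__(self, load_fn, shape, dtype, name="unknown"):
        self._load_fn = load_fn
        self.shape = tuple(shape)
        self.dtype = np.dtype(dtype)
        self.name = name

    @property
    def size(self):
        # Element count derived from the declared shape, without loading data.
        return int(np.prod(self.shape))

    @property
    def itemsize(self):
        return self.dtype.itemsize

    @property
    def nbytes(self):
        return self.size * self.itemsize

    def __array__(self, dtype=None, copy=None):
        # The real load happens only here, when NumPy asks for the data.
        arr = np.asarray(self._load_fn())
        return arr if dtype is None else arr.astype(dtype)


calls = []


def load():
    calls.append("loaded")
    return np.arange(6, dtype=np.float32).reshape(2, 3)


t = ToyLazyTensor(load, shape=(2, 3), dtype=np.float32, name="demo")
print(t.nbytes, len(calls))  # 24 0  (metadata is available before any load)
arr = np.asarray(t)          # triggers __array__, which calls load()
print(arr.shape, bool(calls))  # (2, 3) True
```

Because size, itemsize, and nbytes come from the declared shape and dtype, a checkpointing library can plan its writes without materializing the tensor.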
- class maxtext.checkpoint_conversion.to_maxtext.LazyTensorHandler(metadata_key=None, ocdbt_process_id=None)
Bases: NumpyHandler
Custom Orbax handler for LazyTensor.
It masquerades as a standard NumpyHandler so that the resulting checkpoint has the standard ‘array_metadatas’ structure and can be loaded by standard MaxText instances.
- Parameters:
metadata_key (Optional[str])
ocdbt_process_id (str | None)
- maxtext.checkpoint_conversion.to_maxtext.get_maxtext_model_info(config)
Initializes the abstract MaxText model and returns parameter mapping information.
- Parameters:
config – The MaxText configuration object.
- Returns:
maxtext_abstract_dict: A dictionary mapping MaxText parameter keys to a tuple (index, target_shape), where ‘index’ is the position of the parameter in the flattened parameter list.
abstract_params_treedef: The tree structure definition of the abstract model parameters.
- Return type:
tuple (maxtext_abstract_dict, abstract_params_treedef)
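The shape of the two return values can be illustrated with JAX pytree utilities. This is a sketch under assumptions, not the function's actual body: the parameter tree is made up, and joining path keys with "-" is an assumed key format chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical abstract parameter tree: shapes and dtypes only, no real weights.
abstract_params = {
    "params": {
        "decoder": {"logits_dense": {"kernel": jax.ShapeDtypeStruct((8, 32), jnp.bfloat16)}},
        "token_embedder": {"embedding": jax.ShapeDtypeStruct((32, 8), jnp.bfloat16)},
    }
}

# Flatten into (path, leaf) pairs plus a treedef that records the structure.
leaves_with_path, abstract_params_treedef = jax.tree_util.tree_flatten_with_path(
    abstract_params
)

maxtext_abstract_dict = {}
for index, (path, leaf) in enumerate(leaves_with_path):
    # Assumed key format: dict keys along the path joined with "-".
    key = "-".join(str(entry.key) for entry in path)
    maxtext_abstract_dict[key] = (index, leaf.shape)

print(maxtext_abstract_dict["params-token_embedder-embedding"])  # (1, (32, 8))
print(abstract_params_treedef.num_leaves)  # 2
```

With the index from the dictionary, converted weights can be written into a flat list and later restored to the original structure with jax.tree_util.tree_unflatten(abstract_params_treedef, flat_list).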
- maxtext.checkpoint_conversion.to_maxtext.main(args, lazy_load_tensors=False, eager_load_method='transformers', hf_model_path=None, revision=None, save_dtype='bfloat16', simulated_cpu_devices_count=16)
- Parameters:
args (Sequence[str])
lazy_load_tensors (bool)
eager_load_method (str)
hf_model_path (str | None)
revision (str | None)
save_dtype (str)
simulated_cpu_devices_count (int)
- Return type:
None