maxtext.checkpoint_conversion.utils.utils module

Checkpoint conversion utility functions.

maxtext.checkpoint_conversion.utils.utils.validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys)

Validates param_mapping coverage and filters unused keys; used by both to_maxtext and to_huggingface.

Preprocesses MaxText keys for transformation:
  • Ensures every MaxText checkpoint key (maxtext_state_keys) is covered by the flattened param_mapping.

  • Skips keys in the param_mapping that are not present in the checkpoint (common for multi-variant maps such as gemma3, qwen3, and deepseek).

Parameters:
  • param_map_keys – MaxText keys from the PARAM_MAPPING. These can be:
      ◦ atomic_mt_key: a single string representing one MaxText parameter that maps to HF parameter(s).
      ◦ composite_mt_key: a tuple of strings representing multiple MaxText parameters that map to HF parameter(s).

  • maxtext_state_keys – Set of MaxText keys loaded from the Orbax checkpoint.

Returns:

A list of ‘filtered’ mapping keys (strings or tuples) that are fully present and valid based on maxtext_state_keys.

Raises:

ValueError – If maxtext_state_keys is NOT a subset of the flattened param_map_keys.
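The coverage check and filtering described above can be sketched as follows. This is a hypothetical re-implementation for illustration, not the MaxText source: the function name is invented, and the handling of a composite key that is only partially present in the checkpoint (dropped here) is an assumption.

```python
from typing import Iterable, Union

# An atomic_mt_key is a str; a composite_mt_key is a tuple of str.
MappingKey = Union[str, tuple]


def validate_and_filter_keys_sketch(
    param_map_keys: Iterable[MappingKey], maxtext_state_keys: set
) -> list:
    """Hypothetical sketch of validate_and_filter_param_map_keys."""
    param_map_keys = list(param_map_keys)

    # Flatten atomic and composite keys into one set of MaxText parameter names.
    flattened = set()
    for key in param_map_keys:
        flattened.update((key,) if isinstance(key, str) else key)

    # Every checkpoint key must be covered by the mapping.
    missing = set(maxtext_state_keys) - flattened
    if missing:
        raise ValueError(f"Checkpoint keys not covered by param_mapping: {sorted(missing)}")

    # Keep only mapping keys whose MaxText names are all present in the checkpoint;
    # entries for other model variants in a multi-variant map are skipped.
    filtered = []
    for key in param_map_keys:
        names = (key,) if isinstance(key, str) else key
        if all(name in maxtext_state_keys for name in names):
            filtered.append(key)
    return filtered
```

For example, with a mapping that covers two model variants, only the keys of the variant actually present in the checkpoint survive.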

maxtext.checkpoint_conversion.utils.utils.apply_hook_fns(weight, target_shape, hook_fns)

Applies hook functions to a weight; essential for both to_maxtext and to_huggingface.
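A minimal sketch of how a hook chain might be applied, assuming that hook_fns is None, a single callable, or a list of callables, and that each hook is called as hook(weight, target_shape). Both the calling convention and the application order (list order here) are assumptions; the real function may differ. The example hooks are hypothetical.

```python
from typing import Callable, Optional, Sequence, Union

import numpy as np

# Assumed hook contract: hook(weight, target_shape) -> transformed weight.
Hook = Callable[[np.ndarray, tuple], np.ndarray]


def apply_hooks_sketch(
    weight: np.ndarray,
    target_shape: tuple,
    hook_fns: Optional[Union[Hook, Sequence[Hook]]],
) -> np.ndarray:
    """Hypothetical sketch: apply zero, one, or several hooks in list order."""
    if hook_fns is None:
        return weight
    if callable(hook_fns):
        hook_fns = [hook_fns]
    for fn in hook_fns:
        weight = fn(weight, target_shape)
    return weight


# Hypothetical hooks: transpose a 2-D kernel, then reshape it to the target shape.
def transpose_hook(w: np.ndarray, shape: tuple) -> np.ndarray:
    return w.T


def reshape_hook(w: np.ndarray, shape: tuple) -> np.ndarray:
    return w.reshape(shape)
```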

maxtext.checkpoint_conversion.utils.utils.convert_jax_weight_to_numpy(weight, dtype_str=None)

Converts a JAX array to a NumPy array with the specified dtype; used in to_huggingface.

Parameters:
  • weight (Array) – The input JAX array, potentially sharded across devices.

  • dtype_str (None | str) – The target NumPy dtype as a string (e.g., ‘float32’, ‘bfloat16’). If None, the dtype of the input JAX array is preserved. Defaults to None.

Returns:

A NumPy array containing the data from weight, cast to dtype_str if provided.

Return type:

ndarray
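A minimal sketch of the conversion, assuming the input is a NumPy array or a fully addressable jax.Array (one whose shards all live on the current host). Such arrays implement the __array__ protocol, so np.asarray gathers them into host memory. NumPy has no native bfloat16, so the sketch uses ml_dtypes, which JAX depends on. The function name is hypothetical.

```python
from typing import Optional

import numpy as np


def to_numpy_sketch(weight, dtype_str: Optional[str] = None) -> np.ndarray:
    """Hypothetical sketch: gather an array onto the host and optionally cast it."""
    # Works for NumPy arrays and for fully addressable jax.Array values.
    array = np.asarray(weight)
    if dtype_str is None:
        return array  # preserve the input dtype
    if dtype_str == "bfloat16":
        # bfloat16 is provided by ml_dtypes (imported lazily, only when needed).
        import ml_dtypes

        return array.astype(ml_dtypes.bfloat16)
    return array.astype(np.dtype(dtype_str))
```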

maxtext.checkpoint_conversion.utils.utils.process_maxtext_param(maxtext_param_key, maxtext_param_weight, param_map, hook_fn_map, hf_shape_map, maxtext_config)

Processes a single MaxText parameter (or a group of parameters) for conversion; used in to_huggingface.

This function takes a MaxText parameter and transforms it into one or more Hugging Face-compatible parameters. It handles several scenarios based on:
  • the MaxText key form (atomic_mt_key or composite_mt_key), and

  • the Hugging Face value form (unscanned string, scanned list of strings, unscanned with expert stacking, or scanned with expert stacking).

Note: We assume composite_mt_key can only occur for unscanned/scanned HF keys, but not those with expert stacking.

Parameters:
  • maxtext_param_key (str | tuple[str, ...]) – The key identifying the MaxText parameter(s). Can be an atomic_mt_key (str) or a composite_mt_key (tuple of str) mapping to HF parameter(s).

  • maxtext_param_weight (Array | list[Array]) – The actual weight(s) of the MaxText parameter(s). This can be a single jax.Array for an atomic_mt_key or a list of jax.Array for a composite_mt_key.

  • param_map (dict[str, Any]) – A dictionary mapping MaxText parameter keys to their corresponding Hugging Face target path(s).

  • hook_fn_map (dict[str, Any]) – A dictionary mapping MaxText parameter keys to transformation functions (hooks) that should be applied to the weights.

  • hf_shape_map (dict[str, Any]) – A dictionary mapping Hugging Face parameter paths to their expected shapes.

  • maxtext_config (Any) – The MaxText configuration object, used to determine details like param_scan_axis and base_num_decoder_layers.

Returns:

A list of tuples, where each tuple contains:
  • hf_path (str): The Hugging Face parameter path.

  • hf_weight (np.ndarray): The transformed Hugging Face-compatible weight.

Return type:

list[tuple[str, ndarray]]
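The "scanned list of strings" case can be illustrated with a sketch: a scanned MaxText weight stacks all decoder layers along one axis (param_scan_axis in the config), and each slice along that axis becomes one Hugging Face tensor. This is a hypothetical simplification, not the MaxText implementation; the function name, the single optional hook argument, and the HF paths in the example are assumptions.

```python
from typing import Callable, Optional

import numpy as np


def split_scanned_weight_sketch(
    weight,
    hf_paths: list,
    scan_axis: int,
    hook: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> list:
    """Hypothetical sketch: unstack a scanned weight into per-layer HF tensors."""
    weight = np.asarray(weight)
    if weight.shape[scan_axis] != len(hf_paths):
        raise ValueError("scan axis length must match the number of HF paths")
    results = []
    for layer_idx, hf_path in enumerate(hf_paths):
        # Select layer `layer_idx`, dropping the scan axis.
        layer_weight = np.take(weight, layer_idx, axis=scan_axis)
        if hook is not None:
            layer_weight = hook(layer_weight)
        results.append((hf_path, layer_weight))
    return results
```

For a weight of shape (4, 2, 3) scanned on axis 1, this returns two (path, array) tuples, each holding a (4, 3) array.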

maxtext.checkpoint_conversion.utils.utils.create_huggingface_hub_repo_if_not_exist(repo_id, repo_type)

maxtext.checkpoint_conversion.utils.utils.save_config_file(config, local_path_to_save_to, output_dir_final, file_name, remove_local_copy_after_upload=False)

Saves the model configuration file (config.json).

Parameters:
  • local_path_to_save_to (str)

  • output_dir_final (str)

  • file_name (str)

  • remove_local_copy_after_upload (bool)

maxtext.checkpoint_conversion.utils.utils.shard_checkpoint(weights_dict, max_shard_size=3221225472, weights_name='model.safetensors')

Shards a model checkpoint into smaller pieces based on size constraints.

Parameters:
  • weights_dict (dict[str, Array]) – Model weights dictionary to shard.

  • max_shard_size (int) – Maximum size in bytes for each shard (the default, 3221225472, is 3 GiB).

  • weights_name (str) – Base filename for the shards.

Returns:

A tuple of (sharded weights dict, optional index dict). The index contains metadata and weight-mapping information.

Return type:

tuple[dict[str, dict[str, Array]], None | dict]
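A sketch of greedy size-based sharding, assuming the Hugging Face conventions for shard names (model-00001-of-00003.safetensors) and index layout ({"metadata": {"total_size": ...}, "weight_map": {...}}). Whether this function follows exactly these conventions is an assumption; the function name is hypothetical. Weights are packed in insertion order, and a single weight larger than max_shard_size ends up alone in its own (oversized) shard.

```python
import numpy as np


def shard_checkpoint_sketch(weights_dict, max_shard_size=3221225472, weights_name="model.safetensors"):
    """Hypothetical greedy sharding in the Hugging Face style."""
    shard_list = [{}]
    current_size = 0
    total_size = 0
    for key, weight in weights_dict.items():
        size = weight.nbytes  # NumPy and JAX arrays both expose .nbytes
        # Start a new shard when this weight would overflow a non-empty shard.
        if current_size + size > max_shard_size and shard_list[-1]:
            shard_list.append({})
            current_size = 0
        shard_list[-1][key] = weight
        current_size += size
        total_size += size

    # A single shard keeps the base name and needs no index.
    if len(shard_list) == 1:
        return {weights_name: shard_list[0]}, None

    num_shards = len(shard_list)
    stem, ext = weights_name.rsplit(".", 1)
    shards, weight_map = {}, {}
    for idx, shard in enumerate(shard_list, start=1):
        shard_file = f"{stem}-{idx:05d}-of-{num_shards:05d}.{ext}"
        shards[shard_file] = shard
        for key in shard:
            weight_map[key] = shard_file
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    return shards, index
```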

maxtext.checkpoint_conversion.utils.utils.save_safetensor_file(state_dict, local_dir_to_save_to, output_dir_final, file_name)

Saves a single safetensors file. When uploading, the file is written from memory directly to the remote destination.

Parameters:
  • local_dir_to_save_to (str)

  • output_dir_final (str)

  • file_name (str)

maxtext.checkpoint_conversion.utils.utils.save_index_file(index, local_dir_to_save_to, output_dir_final, file_name, remove_local_copy_after_upload=False)

Saves the model index JSON file (model.safetensors.index.json).

Parameters:
  • index (dict)

  • local_dir_to_save_to (str)

  • output_dir_final (str)

  • file_name (str)

  • remove_local_copy_after_upload (bool)

maxtext.checkpoint_conversion.utils.utils.save_weight_files(shards, index, local_dir_to_save_to, output_dir_final, parallel_threads=8, remove_local_copy_after_upload=False)

Saves the weight files and, if needed, the index file.

Requires the local system to have at least parallel_threads * DEFAULT_MAX_SHARD_SIZE of free disk space, because each thread keeps a local copy of its shard while processing.

Parameters:
  • local_dir_to_save_to (str)

  • output_dir_final (str)

  • remove_local_copy_after_upload (bool)
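The disk-space requirement above is simple arithmetic, sketched here as a hypothetical pre-flight check. It assumes DEFAULT_MAX_SHARD_SIZE equals the shard_checkpoint default of 3221225472 bytes (3 GiB); that value is not confirmed by this page. Both helper names are hypothetical.

```python
import shutil

# Assumption: DEFAULT_MAX_SHARD_SIZE matches shard_checkpoint's default (3 GiB).
ASSUMED_DEFAULT_MAX_SHARD_SIZE = 3221225472


def required_scratch_bytes(parallel_threads=8, max_shard_size=ASSUMED_DEFAULT_MAX_SHARD_SIZE):
    """Worst-case local disk needed when every thread caches one full shard."""
    return parallel_threads * max_shard_size


def has_enough_scratch_space(path, parallel_threads=8, max_shard_size=ASSUMED_DEFAULT_MAX_SHARD_SIZE):
    """Hypothetical check to run before save_weight_files."""
    free = shutil.disk_usage(path).free
    return free >= required_scratch_bytes(parallel_threads, max_shard_size)
```

With the defaults, this comes to 8 × 3 GiB = 24 GiB (25769803776 bytes); lowering parallel_threads reduces the requirement proportionally.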

maxtext.checkpoint_conversion.utils.utils.get_local_save_path_manager(output_dir)

Context manager that provides a local path for saving files. If output_dir is remote (GCS/HF), a temporary local directory is created; if output_dir is local, it is used directly.

Parameters:

output_dir (str)

Yields:

tuple – (path_to_use_for_saving: str, is_temporary: bool)
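A minimal sketch of such a context manager built on contextlib and tempfile. The prefixes used to recognize remote destinations ("gs://" and "hf://") are assumptions for illustration; the real function may detect GCS and Hugging Face Hub destinations differently. The function name is hypothetical.

```python
import contextlib
import tempfile

# Assumption: these prefixes mark remote destinations.
REMOTE_PREFIXES = ("gs://", "hf://")


@contextlib.contextmanager
def local_save_path_sketch(output_dir: str):
    """Yield (path_to_use_for_saving, is_temporary)."""
    if output_dir.startswith(REMOTE_PREFIXES):
        # Remote target: hand out a scratch directory that is deleted on exit.
        with tempfile.TemporaryDirectory() as temp_dir:
            yield temp_dir, True
    else:
        # Local target: write straight into output_dir.
        yield output_dir, False
```

Callers can use is_temporary to decide whether files still need to be uploaded before the block exits.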

maxtext.checkpoint_conversion.utils.utils.save_model_files(weight_arrays, config, tokenizer, processor, output_dir, parallel_threads=8)

Saves model files (config and weights) to the specified directory. When uploading to GCS or the HF Hub, *.safetensors files are uploaded directly from memory to the remote destination; no local storage is used, which reduces disk usage.

Parameters:
  • weight_arrays (dict)

  • tokenizer (None | Any)

  • output_dir (str)

maxtext.checkpoint_conversion.utils.utils.upload_state_dict_to_gcs(state_dict, gs_bucket_path)

Uploads a state_dict from memory to Google Cloud Storage.

Parameters:
  • state_dict (dict) – A PyTorch model’s state_dict.

  • gs_bucket_path (str) – GCS destination (e.g., “gs://my-bucket/models/model.pt”).

maxtext.checkpoint_conversion.utils.utils.upload_file_to_gcs(local_file, gs_bucket_path, remove_local_file_after_upload=False)

Uploads a single file to Google Cloud Storage.

Parameters:
  • local_file (str) – Path to local file

  • gs_bucket_path (str) – GCS destination (e.g. “gs://my-bucket/path/file.txt” or “my-bucket/path/file.txt”)
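Because gs_bucket_path accepts paths both with and without the "gs://" scheme, the bucket and object name have to be split out before uploading. A sketch of that parsing plus an upload with the google-cloud-storage client (Client.bucket, Bucket.blob, Blob.upload_from_filename); the helper names are hypothetical and this is not necessarily how MaxText implements it.

```python
import os


def parse_gcs_path(gs_bucket_path: str) -> tuple:
    """Split 'gs://bucket/path/file' or 'bucket/path/file' into (bucket, blob_name)."""
    path = gs_bucket_path.removeprefix("gs://")
    bucket, _, blob_name = path.partition("/")
    return bucket, blob_name


def upload_file_sketch(local_file: str, gs_bucket_path: str, remove_local_file_after_upload: bool = False):
    """Hypothetical upload helper; requires the google-cloud-storage package."""
    from google.cloud import storage

    bucket_name, blob_name = parse_gcs_path(gs_bucket_path)
    storage.Client().bucket(bucket_name).blob(blob_name).upload_from_filename(local_file)
    if remove_local_file_after_upload:
        os.remove(local_file)
```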

maxtext.checkpoint_conversion.utils.utils.upload_folder_to_gcs(local_folder, gs_bucket_path, num_workers=4)

Uploads all files from a local folder to Google Cloud Storage.

Parameters:
  • local_folder (str) – Path to local folder (e.g. “data/images”)

  • gs_bucket_path (str) – GCS destination (e.g. “gs://my-bucket/images” or “my-bucket/images”)

  • num_workers (int) – Number of parallel upload workers
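A sketch of a parallel folder upload using a thread pool with num_workers threads. It assumes each file keeps its path relative to local_folder under the destination prefix; that layout is an assumption, not confirmed by this page. To stay self-contained, the per-file upload is passed in as a hypothetical upload_fn(local_path, bucket, blob_name) callable instead of calling a GCS client directly.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def plan_folder_upload(local_folder: str, gs_bucket_path: str) -> list:
    """Map every file under local_folder to (local_path, bucket, blob_name)."""
    path = gs_bucket_path.removeprefix("gs://")
    bucket, _, prefix = path.partition("/")
    plan = []
    for root, _, files in os.walk(local_folder):
        for name in files:
            local_path = os.path.join(root, name)
            rel = os.path.relpath(local_path, local_folder).replace(os.sep, "/")
            blob_name = f"{prefix.rstrip('/')}/{rel}" if prefix else rel
            plan.append((local_path, bucket, blob_name))
    return sorted(plan)


def upload_folder_sketch(local_folder, gs_bucket_path, upload_fn, num_workers=4):
    """Run upload_fn for each file in parallel; returns the number of files."""
    plan = plan_folder_upload(local_folder, gs_bucket_path)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # list() waits for every upload and re-raises the first error, if any.
        list(pool.map(lambda item: upload_fn(*item), plan))
    return len(plan)
```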

maxtext.checkpoint_conversion.utils.utils.print_ram_usage(stage='')

maxtext.checkpoint_conversion.utils.utils.print_peak_memory()

class maxtext.checkpoint_conversion.utils.utils.MemoryMonitorTqdm(*_, **__)

Bases: tqdm

Custom tqdm class that displays memory usage in the progress bar.

format_meter(n, total, elapsed, postfix=None, **extra_kwargs)

Override to add memory usage info to the postfix.

maxtext.checkpoint_conversion.utils.utils.load_orbax_checkpoint(config)

Loads a full Orbax checkpoint from disk with unsharded arrays.

Parameters:

config – MaxText config containing checkpoint storage settings

Returns:

Dictionary containing the full checkpoint structure

Return type:

dict

maxtext.checkpoint_conversion.utils.utils.extract_nnx_weights(weights_dict)

Extracts weights from an NNX checkpoint structure.

NNX checkpoints have the structure {‘decoder’: {‘decoder_norm’: {‘scale’: {‘value’: array}}}}. This function flattens it to {‘params-decoder-decoder_norm-scale’: array}.

Parameters:

weights_dict (dict) – NNX checkpoint weights dictionary

Returns:

Dictionary mapping parameter names to weight arrays

Return type:

dict[str, ndarray]
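The flattening described above can be sketched as a recursive walk that joins keys with "-", prepends "params" to match the Linen naming, and unwraps NNX's {‘value’: array} leaf wrappers. This is a hypothetical sketch; treating a dict whose only key is "value" as a wrapped leaf is an assumption about how leaves are recognized.

```python
def flatten_nnx_sketch(weights_dict: dict, prefix: str = "params") -> dict:
    """Flatten {'decoder': {... {'value': arr}}} to {'params-decoder-...': arr}."""
    flat = {}

    def walk(node, path):
        if isinstance(node, dict):
            if set(node) == {"value"}:
                # NNX wraps each leaf array in a {'value': array} dict.
                flat["-".join(path)] = node["value"]
                return
            for key, child in node.items():
                walk(child, path + [key])
        else:
            # Unwrapped leaf: store it as-is.
            flat["-".join(path)] = node

    walk(weights_dict, [prefix])
    return flat
```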

maxtext.checkpoint_conversion.utils.utils.extract_linen_weights(weights_dict)

Extracts weights from a Linen checkpoint structure.

Linen checkpoints have the structure {‘params’: {‘decoder’: {‘decoder_norm’: {‘scale’: array}}}}. This function flattens it to {‘params-decoder-decoder_norm-scale’: array}.

Parameters:

weights_dict (dict) – Linen checkpoint weights dictionary

Returns:

Dictionary mapping parameter names to weight arrays

Return type:

dict[str, ndarray]
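For Linen, the leaves are the arrays themselves, so a plain recursive flatten with a "-" separator yields the documented names. A hypothetical sketch (flax.traverse_util.flatten_dict with sep="-" produces the same keys, if Flax is available):

```python
def flatten_linen_sketch(weights_dict: dict) -> dict:
    """Flatten {'params': {'decoder': {... 'scale': arr}}} to {'params-decoder-...-scale': arr}."""
    flat = {}

    def walk(node, path):
        if isinstance(node, dict):
            for key, child in node.items():
                walk(child, path + [key])
        else:
            # Any non-dict node is treated as a weight array.
            flat["-".join(path)] = node

    walk(weights_dict, [])
    return flat
```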

maxtext.checkpoint_conversion.utils.utils.detect_and_extract_checkpoint(checkpoint_dict)

Detects the checkpoint type (Linen vs. NNX) and extracts its weights.

Handles Linen checkpoints and multiple NNX checkpoint variants:
  • Linen: {‘params’: {‘params’: {‘decoder’: {…}, ‘token_embedder’: … {WEIGHT_ARRAY}}}}

  • NNX-SFT: {‘decoder’: {…}, ‘token_embedder’: … {‘value’: WEIGHT_ARRAY}}

  • NNX-RL: {‘base’: {‘decoder’: {…}, ‘token_embedder’: … {‘value’: WEIGHT_ARRAY}}}

Currently, all extracted weights are aligned to the MaxText-Linen naming convention, e.g. “params-decoder-decoder_norm-scale”. This allows the same param_mapping to be reused for both Linen and NNX checkpoints.

Parameters:

checkpoint_dict (dict) – Raw checkpoint dictionary from Orbax

Returns:

Dictionary mapping MaxText parameter names to weight arrays

Return type:

dict[str, ndarray]
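A sketch of dispatching on the three top-level layouts listed above and producing Linen-style names in every case. The detection heuristics (checking for a top-level "params" or "base" key) are assumptions inferred from the listed structures; real checkpoints may carry extra top-level entries that the actual function handles differently. The names are hypothetical.

```python
def _flatten(node, path, out, unwrap_value):
    """Recursively flatten a nested dict into '-'-joined keys."""
    if isinstance(node, dict):
        if unwrap_value and set(node) == {"value"}:
            out["-".join(path)] = node["value"]  # NNX leaf wrapper
            return
        for key, child in node.items():
            _flatten(child, path + [key], out, unwrap_value)
    else:
        out["-".join(path)] = node


def detect_and_extract_sketch(checkpoint_dict: dict) -> dict:
    """Hypothetical sketch: detect Linen / NNX-SFT / NNX-RL and flatten."""
    flat = {}
    if "params" in checkpoint_dict:
        # Linen: {'params': {'params': {...}}}; the inner 'params' starts each name.
        _flatten(checkpoint_dict["params"], [], flat, unwrap_value=False)
    elif "base" in checkpoint_dict:
        # NNX-RL: the model lives under 'base'.
        _flatten(checkpoint_dict["base"], ["params"], flat, unwrap_value=True)
    else:
        # NNX-SFT: model modules sit at the top level.
        _flatten(checkpoint_dict, ["params"], flat, unwrap_value=True)
    return flat
```

All three layouts map to the same key, so one param_mapping can serve every variant.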

maxtext.checkpoint_conversion.utils.utils.load_hf_dict_from_transformers(model_id, token, revision=None, dtype='auto')

Loads the Hugging Face model identified by model_id (eager mode only); used in to_maxtext.

Parameters:
  • model_id (str)

  • token (str)

  • revision (str | None)

  • dtype (str)

maxtext.checkpoint_conversion.utils.utils.load_hf_dict_from_safetensors(model_id_or_path, token, revision, framework='pt')

If the safetensors files contain more HF keys than the MaxText model uses, the extra keys are loaded but ignored during conversion. For example, if MaxText runs deepseek3 with mtp=false, the safetensors weights with the prefix model.layers.61 are the extra keys.
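The "loaded but ignored" behavior amounts to partitioning the loaded HF keys into those referenced by the parameter mapping and the leftovers. A hypothetical sketch; the key names in the example are illustrative only, not actual DeepSeek tensor names.

```python
def split_hf_keys(hf_keys, mapped_hf_keys):
    """Partition loaded HF keys into (used by the mapping, extra keys to ignore)."""
    mapped = set(mapped_hf_keys)
    used = sorted(k for k in hf_keys if k in mapped)
    extra = sorted(k for k in hf_keys if k not in mapped)
    return used, extra


# Illustrative keys: an extra layer (e.g. an MTP layer) that the mapping never references.
loaded = ["model.embed_tokens.weight", "model.layers.60.mlp.w", "model.layers.61.mlp.w"]
mapping_targets = ["model.embed_tokens.weight", "model.layers.60.mlp.w"]
used, extra = split_hf_keys(loaded, mapping_targets)
```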