maxtext.checkpoint_conversion.utils.utils module
Checkpoint conversion utility functions.
- maxtext.checkpoint_conversion.utils.utils.validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys)
Validates param_mapping coverage and filters out unused keys; used by both to_maxtext and to_huggingface.
Preprocesses MaxText keys for transformation. It ensures that every MaxText checkpoint key (maxtext_state_keys) is covered by the flattened param_mapping. Keys in the param_mapping that are not present in the checkpoint (common for multi-variant maps such as gemma3, qwen3, and deepseek) are skipped.
- Parameters:
param_map_keys – MaxText keys from the PARAM_MAPPING. Each key is either an atomic_mt_key (a single string naming one MaxText parameter that maps to HF parameter(s)) or a composite_mt_key (a tuple of strings naming multiple MaxText parameters that map to HF parameter(s)).
maxtext_state_keys – Set of MaxText keys loaded from the Orbax checkpoint.
- Returns:
A list of ‘filtered’ mapping keys (strings or tuples) that are fully present and valid based on maxtext_state_keys.
- Raises:
ValueError – If maxtext_state_keys is NOT a subset of the flattened param_map_keys.
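The documented contract can be illustrated with a small, self-contained sketch. It is not the MaxText implementation: the helper names are hypothetical, and it assumes a composite key counts as present only when every one of its component keys is in the checkpoint.

```python
from typing import Iterable, Union

# Hypothetical alias: a mapping key is an atomic (str) or composite (tuple) key.
MappingKey = Union[str, tuple[str, ...]]


def flatten_mapping_keys(param_map_keys: Iterable[MappingKey]) -> set[str]:
    """Expands composite (tuple) keys so every individual MaxText key is listed."""
    flat: set[str] = set()
    for key in param_map_keys:
        if isinstance(key, tuple):
            flat.update(key)
        else:
            flat.add(key)
    return flat


def validate_and_filter_sketch(param_map_keys, maxtext_state_keys):
    """Hypothetical re-implementation of the documented validate-then-filter contract."""
    param_map_keys = list(param_map_keys)
    # Every checkpoint key must be covered by the flattened mapping.
    missing = set(maxtext_state_keys) - flatten_mapping_keys(param_map_keys)
    if missing:
        raise ValueError(f"Checkpoint keys not covered by param_mapping: {sorted(missing)}")
    filtered = []
    for key in param_map_keys:
        parts = key if isinstance(key, tuple) else (key,)
        # Keep a mapping key only if all of its MaxText parameters are in the checkpoint.
        if all(part in maxtext_state_keys for part in parts):
            filtered.append(key)
    return filtered


mapping_keys = ["params-a", ("params-b", "params-c"), "params-unused"]
state_keys = {"params-a", "params-b", "params-c"}
print(validate_and_filter_sketch(mapping_keys, state_keys))
```

Here "params-unused" is dropped because the checkpoint does not contain it, mirroring how multi-variant maps carry keys that a given model does not use.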
- maxtext.checkpoint_conversion.utils.utils.apply_hook_fns(weight, target_shape, hook_fns)
Applies hook functions to a weight; used by both to_maxtext and to_huggingface.
- maxtext.checkpoint_conversion.utils.utils.convert_jax_weight_to_numpy(weight, dtype_str=None)
Converts a JAX array to a NumPy array with the specified dtype, used in to_huggingface.
- Parameters:
weight (Array) – The input JAX array, potentially sharded across devices.
dtype_str (None | str) – The target NumPy dtype as a string (e.g., ‘float32’, ‘bfloat16’). If None, the dtype of the input JAX array is preserved. Defaults to None.
- Returns:
A NumPy array containing the data from weight, cast to dtype_str if provided.
- Return type:
ndarray
- maxtext.checkpoint_conversion.utils.utils.process_maxtext_param(maxtext_param_key, maxtext_param_weight, param_map, hook_fn_map, hf_shape_map, maxtext_config)
Processes a single MaxText parameter (or a group of parameters) for conversion; used in to_huggingface.
This function takes a MaxText parameter and transforms it into one or more Hugging Face compatible parameters. It handles several scenarios, depending on the MaxText key form (atomic_mt_key or composite_mt_key) and the Hugging Face value form (an unscanned string, a scanned list of strings, unscanned with expert stacking, or scanned with expert stacking).
Note: composite_mt_key is assumed to occur only with unscanned or scanned HF keys, not with expert-stacked ones.
- Parameters:
maxtext_param_key (str | tuple[str, ...]) – The key identifying the MaxText parameter(s). Can be an atomic_mt_key (str) or a composite_mt_key (tuple of str) mapping to HF parameter(s).
maxtext_param_weight (Array | list[Array]) – The actual weight(s) of the MaxText parameter(s). This can be a single jax.Array for an atomic_mt_key or a list of jax.Array for a composite_mt_key.
param_map (dict[str, Any]) – A dictionary mapping MaxText parameter keys to their corresponding Hugging Face target path(s).
hook_fn_map (dict[str, Any]) – A dictionary mapping MaxText parameter keys to transformation functions (hooks) that should be applied to the weights.
hf_shape_map (dict[str, Any]) – A dictionary mapping Hugging Face parameter paths to their expected shapes.
maxtext_config (Any) – The MaxText configuration object, used to determine details like param_scan_axis and base_num_decoder_layers.
- Returns:
A list of tuples, where each tuple contains:
hf_path (str): The Hugging Face parameter path.
hf_weight (np.ndarray): The transformed Hugging Face compatible weight.
- Return type:
list[tuple[str, np.ndarray]]
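For the scanned case, the HF value is a list of per-layer paths, while the MaxText weight presumably stacks all layers along maxtext_config.param_scan_axis. The sketch below shows only that pairing step for the unscanned and scanned cases. It uses a caller-supplied take_layer function (a hypothetical helper) in place of real array slicing; with NumPy arrays, take_layer could be `lambda w, i: np.take(w, i, axis=param_scan_axis)`. Hooks, expert stacking, and composite keys are deliberately left out.

```python
from __future__ import annotations

from typing import Any, Callable, Sequence


def pair_hf_paths_sketch(
    hf_paths: str | Sequence[str],
    weight: Any,
    take_layer: Callable[[Any, int], Any],
) -> list[tuple[str, Any]]:
    """Hypothetical helper: pairs HF paths with (sliced) MaxText weights."""
    if isinstance(hf_paths, str):
        # Unscanned: one MaxText weight maps to exactly one HF parameter.
        return [(hf_paths, weight)]
    # Scanned: layer i of the stacked weight becomes hf_paths[i].
    return [(path, take_layer(weight, i)) for i, path in enumerate(hf_paths)]


stacked = [[1, 2], [3, 4], [5, 6]]  # 3 layers stacked along axis 0
paths = [f"model.layers.{i}.mlp.down_proj.weight" for i in range(3)]
print(pair_hf_paths_sketch(paths, stacked, lambda w, i: w[i]))
```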
- maxtext.checkpoint_conversion.utils.utils.create_huggingface_hub_repo_if_not_exist(repo_id, repo_type)
- maxtext.checkpoint_conversion.utils.utils.save_config_file(config, local_path_to_save_to, output_dir_final, file_name, remove_local_copy_after_upload=False)
Saves the model configuration file (config.json).
- Parameters:
local_path_to_save_to (str)
output_dir_final (str)
file_name (str)
remove_local_copy_after_upload (bool)
- maxtext.checkpoint_conversion.utils.utils.shard_checkpoint(weights_dict, max_shard_size=3221225472, weights_name='model.safetensors')
Shards a model checkpoint into smaller pieces based on size constraints.
- Parameters:
weights_dict (dict[str, Array]) – Model weights dictionary to shard
max_shard_size (int) – Maximum size in bytes for each shard
weights_name (str) – Base filename for the shards
- Returns:
A tuple of (sharded weights dict, optional index dict). The index contains metadata and weight-mapping information.
- Return type:
tuple[dict[str, dict[str, Array]], None | dict]
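A minimal sketch of size-based sharding, assuming a greedy strategy that fills shards in dictionary order, and assuming the Hugging Face conventions for shard file names (model-00001-of-00002.safetensors) and for the index layout ({"metadata": {"total_size": ...}, "weight_map": {...}}). The actual MaxText strategy may differ. Byte sizes come from a caller-supplied nbytes function (a hypothetical parameter) so the example runs on plain bytes objects; for arrays it could be `lambda a: a.nbytes`.

```python
def shard_by_size_sketch(weights_dict, max_shard_size, weights_name="model.safetensors", nbytes=len):
    """Greedily groups weights into shards of at most max_shard_size bytes."""
    shards, current, current_size = [], {}, 0
    for name, weight in weights_dict.items():
        size = nbytes(weight)
        # Start a new shard when adding this weight would exceed the limit.
        # A single weight larger than the limit still gets a shard of its own.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = weight
        current_size += size
    if current:
        shards.append(current)

    # One shard: no index file is needed.
    if len(shards) <= 1:
        return {weights_name: shards[0] if shards else {}}, None

    stem, ext = weights_name.rsplit(".", 1)
    sharded, weight_map, total_size = {}, {}, 0
    for i, shard in enumerate(shards, start=1):
        file_name = f"{stem}-{i:05d}-of-{len(shards):05d}.{ext}"
        sharded[file_name] = shard
        for name, weight in shard.items():
            weight_map[name] = file_name
            total_size += nbytes(weight)
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    return sharded, index


weights = {"a": b"x" * 6, "b": b"x" * 6, "c": b"x" * 3}
shards, index = shard_by_size_sketch(weights, max_shard_size=10)
print(sorted(shards), index["weight_map"])
```

With a 10-byte limit, "a" fills the first shard, and "b" and "c" (9 bytes together) share the second.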
- maxtext.checkpoint_conversion.utils.utils.save_safetensor_file(state_dict, local_dir_to_save_to, output_dir_final, file_name)
Saves a single safetensors file. When uploading, the file is written from memory directly to the remote destination.
- Parameters:
local_dir_to_save_to (str)
output_dir_final (str)
file_name (str)
- maxtext.checkpoint_conversion.utils.utils.save_index_file(index, local_dir_to_save_to, output_dir_final, file_name, remove_local_copy_after_upload=False)
Saves the model index json file (model.safetensors.index.json).
- Parameters:
index (dict)
local_dir_to_save_to (str)
output_dir_final (str)
file_name (str)
remove_local_copy_after_upload (bool)
- maxtext.checkpoint_conversion.utils.utils.save_weight_files(shards, index, local_dir_to_save_to, output_dir_final, parallel_threads=8, remove_local_copy_after_upload=False)
Saves the weight files and, if needed, the index file.
Requires the local system to have at least parallel_threads * DEFAULT_MAX_SHARD_SIZE of free disk space, because each thread maintains a local cache of its shard during processing.
- Parameters:
local_dir_to_save_to (str)
output_dir_final (str)
remove_local_copy_after_upload (bool)
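The free-space requirement is simple arithmetic. Assuming DEFAULT_MAX_SHARD_SIZE equals the shard_checkpoint default of 3221225472 bytes (3 GiB), the default of 8 threads needs about 24 GiB of local disk. The helper names below are hypothetical; shutil.disk_usage is the standard-library call for checking free space.

```python
import shutil

# Assumption: DEFAULT_MAX_SHARD_SIZE matches shard_checkpoint's default max_shard_size.
ASSUMED_MAX_SHARD_SIZE = 3221225472  # 3 GiB
GIB = 1024**3


def required_local_bytes(parallel_threads=8, max_shard_size=ASSUMED_MAX_SHARD_SIZE):
    """Each thread may hold one full shard on local disk at the same time."""
    return parallel_threads * max_shard_size


def has_enough_disk(path, parallel_threads=8, max_shard_size=ASSUMED_MAX_SHARD_SIZE):
    """Compares free space at `path` against the worst-case requirement."""
    return shutil.disk_usage(path).free >= required_local_bytes(parallel_threads, max_shard_size)


print(required_local_bytes() / GIB)  # 24.0
```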
- maxtext.checkpoint_conversion.utils.utils.get_local_save_path_manager(output_dir)
Context manager that provides a local path for saving files. If output_dir is remote (GCS/HF), a temporary local directory is created; if output_dir is local, it is used directly.
- Parameters:
output_dir (str)
- Yields:
tuple – (path_to_use_for_saving: str, is_temporary: bool)
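A minimal sketch of such a context manager using the standard library. It assumes remote destinations are recognized by a URI prefix; the REMOTE_PREFIXES tuple is a hypothetical stand-in, since the rule MaxText uses to detect GCS or HF destinations is not documented here.

```python
import contextlib
import os
import tempfile

# Hypothetical: prefixes treated as remote destinations in this sketch only.
REMOTE_PREFIXES = ("gs://", "hf://")


@contextlib.contextmanager
def local_save_path_sketch(output_dir: str):
    """Yields (path_to_use_for_saving, is_temporary)."""
    if output_dir.startswith(REMOTE_PREFIXES):
        # Remote target: stage files in a temp dir that is deleted on exit.
        with tempfile.TemporaryDirectory() as tmp_dir:
            yield tmp_dir, True
    else:
        # Local target: write straight into output_dir.
        os.makedirs(output_dir, exist_ok=True)
        yield output_dir, False


with local_save_path_sketch("gs://my-bucket/converted") as (path, is_temporary):
    print(path, is_temporary)
```

Using TemporaryDirectory inside the generator means the staging directory is cleaned up even if saving raises.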
- maxtext.checkpoint_conversion.utils.utils.save_model_files(weight_arrays, config, tokenizer, processor, output_dir, parallel_threads=8)
Saves model files (config and weights) to the specified directory. When uploading to GCS or the HF Hub, *.safetensors files are uploaded directly from memory, so no local disk space is used for them.
- Parameters:
weight_arrays (dict)
tokenizer (None | Any)
output_dir (str)
- maxtext.checkpoint_conversion.utils.utils.upload_state_dict_to_gcs(state_dict, gs_bucket_path)
Uploads a state_dict from memory to Google Cloud Storage.
- Parameters:
state_dict (dict) – A PyTorch model’s state_dict.
gs_bucket_path (str) – GCS destination (e.g., “gs://my-bucket/models/model.pt”).
- maxtext.checkpoint_conversion.utils.utils.upload_file_to_gcs(local_file, gs_bucket_path, remove_local_file_after_upload=False)
Uploads a single file to Google Cloud Storage.
- Parameters:
local_file (str) – Path to local file
gs_bucket_path (str) – GCS destination (e.g. “gs://my-bucket/path/file.txt” or “my-bucket/path/file.txt”)
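Both accepted forms of gs_bucket_path reduce to a bucket name and an object name. The hypothetical helper below shows that normalization only; with the google-cloud-storage client, the resulting pair would typically be passed to `storage.Client().bucket(bucket).blob(blob_name).upload_from_filename(local_file)`.

```python
def split_gcs_path(gs_bucket_path: str) -> tuple[str, str]:
    """Splits 'gs://bucket/a/b' or 'bucket/a/b' into ('bucket', 'a/b')."""
    # Strip the optional scheme so both accepted forms are handled alike.
    path = gs_bucket_path.removeprefix("gs://")
    bucket, _, blob_name = path.partition("/")
    if not bucket or not blob_name:
        raise ValueError(f"Expected '[gs://]bucket/object', got {gs_bucket_path!r}")
    return bucket, blob_name


print(split_gcs_path("gs://my-bucket/path/file.txt"))
print(split_gcs_path("my-bucket/path/file.txt"))
```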
- maxtext.checkpoint_conversion.utils.utils.upload_folder_to_gcs(local_folder, gs_bucket_path, num_workers=4)
Uploads all files from a local folder to Google Cloud Storage.
- Parameters:
local_folder (str) – Path to local folder (e.g. “data/images”)
gs_bucket_path (str) – GCS destination (e.g. “gs://my-bucket/images” or “my-bucket/images”)
num_workers (int) – Number of parallel upload workers
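A sketch of the parallel pattern, assuming each file keeps its path relative to local_folder under the destination prefix. The upload_fn parameter is hypothetical and stands in for a real single-file upload; the sketch only plans destinations and fans them out over a thread pool of num_workers.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def plan_folder_upload(local_folder: str, gs_bucket_path: str) -> list[tuple[str, str]]:
    """Pairs every file under local_folder with a destination under gs_bucket_path."""
    dest_root = gs_bucket_path.rstrip("/")
    if not dest_root.startswith("gs://"):
        dest_root = "gs://" + dest_root  # accept both documented path forms
    plan = []
    for root, _dirs, files in os.walk(local_folder):
        for name in sorted(files):
            local_path = os.path.join(root, name)
            # Object names always use '/', regardless of the local OS separator.
            rel = os.path.relpath(local_path, local_folder).replace(os.sep, "/")
            plan.append((local_path, f"{dest_root}/{rel}"))
    return plan


def upload_folder_sketch(local_folder, gs_bucket_path, upload_fn, num_workers=4):
    """Runs upload_fn(local_path, destination) for each planned file in parallel."""
    plan = plan_folder_upload(local_folder, gs_bucket_path)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # list() waits for all uploads and re-raises the first error, if any.
        list(pool.map(lambda pair: upload_fn(*pair), plan))
    return len(plan)
```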
- class maxtext.checkpoint_conversion.utils.utils.MemoryMonitorTqdm(*_, **__)
Bases: tqdm
Custom tqdm class that displays memory usage in the progress bar.
- maxtext.checkpoint_conversion.utils.utils.load_orbax_checkpoint(config)
Loads a full Orbax checkpoint from disk with unsharded arrays.
- Parameters:
config – MaxText config containing checkpoint storage settings
- Returns:
Dictionary containing the full checkpoint structure
- Return type:
dict
- maxtext.checkpoint_conversion.utils.utils.extract_nnx_weights(weights_dict)
Extract weights from NNX checkpoint structure.
NNX checkpoints have the structure {‘decoder’: {‘decoder_norm’: {‘scale’: {‘value’: array}}}}. This function flattens it to {‘params-decoder-decoder_norm-scale’: array}.
- Parameters:
weights_dict (dict) – NNX checkpoint weights dictionary
- Returns:
Dictionary mapping parameter names to weight arrays
- Return type:
dict[str, ndarray]
- maxtext.checkpoint_conversion.utils.utils.extract_linen_weights(weights_dict)
Extract weights from Linen checkpoint structure.
Linen checkpoints have the structure {‘params’: {‘decoder’: {‘decoder_norm’: {‘scale’: array}}}}. This function flattens it to {‘params-decoder-decoder_norm-scale’: array}.
- Parameters:
weights_dict (dict) – Linen checkpoint weights dictionary
- Returns:
Dictionary mapping parameter names to weight arrays
- Return type:
dict[str, ndarray]
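Both layouts flatten to the same hyphen-joined naming. The sketch below is a hypothetical helper, not the MaxText code: Linen trees are flattened as-is, while NNX trees get a "params" prefix and have their {‘value’: array} leaves unwrapped.

```python
def flatten_weight_tree(tree: dict, prefix: tuple = (), unwrap_value: bool = False) -> dict:
    """Flattens nested dicts into {'params-decoder-...-scale': leaf} style keys."""
    flat = {}
    for key, sub in tree.items():
        path = prefix + (key,)
        if isinstance(sub, dict):
            if unwrap_value and set(sub) == {"value"}:
                # NNX leaf: {'value': array} holds the actual weight.
                flat["-".join(path)] = sub["value"]
            else:
                flat.update(flatten_weight_tree(sub, path, unwrap_value))
        else:
            # Linen leaf: the array itself.
            flat["-".join(path)] = sub
    return flat


linen = {"params": {"decoder": {"decoder_norm": {"scale": 1.0}}}}
nnx = {"decoder": {"decoder_norm": {"scale": {"value": 1.0}}}}
print(flatten_weight_tree(linen))
print(flatten_weight_tree(nnx, prefix=("params",), unwrap_value=True))
```

Both calls produce the key "params-decoder-decoder_norm-scale", which is what allows one param_mapping to serve both formats.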
- maxtext.checkpoint_conversion.utils.utils.detect_and_extract_checkpoint(checkpoint_dict)
Detect checkpoint type (Linen vs NNX) and extract weights.
Handles the Linen format as well as multiple NNX checkpoint variants:
Linen: {‘params’: {‘params’: {‘decoder’: {…}, ‘token_embedder’: … {WEIGHT_ARRAY}}}}
NNX-SFT: {‘decoder’: {…}, ‘token_embedder’: … {‘value’: WEIGHT_ARRAY}}
NNX-RL: {‘base’: {‘decoder’: {…}, ‘token_embedder’: … {‘value’: WEIGHT_ARRAY}}}
Currently, we align all extracted weights to MaxText-Linen naming convention like “params-decoder-decoder_norm-scale”. This allows reusing the same param_mapping for both Linen and NNX checkpoints.
- Parameters:
checkpoint_dict (dict) – Raw checkpoint dictionary from Orbax
- Returns:
Dictionary mapping MaxText parameter names to weight arrays
- Return type:
dict[str, ndarray]
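The three layouts listed above can be told apart by their top-level keys. The sketch below, with a hypothetical function name, only performs detection and returns the subtree that would then be flattened (as extract_linen_weights or extract_nnx_weights do); the exact checks MaxText performs may differ.

```python
def select_weight_subtree(checkpoint_dict: dict) -> tuple[str, dict]:
    """Returns (kind, subtree) where subtree is what the matching extractor would flatten."""
    params = checkpoint_dict.get("params")
    if isinstance(params, dict) and "params" in params:
        # Linen: {'params': {'params': {...}}}; keep one 'params' level for naming.
        return "linen", params
    if "base" in checkpoint_dict:
        # NNX-RL wraps the model state under 'base'.
        return "nnx-rl", checkpoint_dict["base"]
    if "decoder" in checkpoint_dict:
        # NNX-SFT: model modules sit at the top level.
        return "nnx-sft", checkpoint_dict
    raise ValueError(f"Unrecognized checkpoint layout, top-level keys: {sorted(checkpoint_dict)}")


print(select_weight_subtree({"base": {"decoder": {}}})[0])
```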
- maxtext.checkpoint_conversion.utils.utils.load_hf_dict_from_transformers(model_id, token, revision=None, dtype='auto')
Loads the Hugging Face model identified by model_id (eager mode only); used in to_maxtext.
- Parameters:
model_id (str)
token (str)
revision (str | None)
dtype (str)
- maxtext.checkpoint_conversion.utils.utils.load_hf_dict_from_safetensors(model_id_or_path, token, revision, framework='pt')
If the safetensors files contain more HF keys than the MaxText model uses, the extra keys are loaded but ignored during conversion. For example, if the MaxText model is deepseek3 with mtp=false, the safetensors weights with the prefix model.layers.61 are such extra keys.
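Ignoring the extra keys amounts to filtering the loaded tensors against the HF keys that the parameter mapping expects. A minimal sketch with hypothetical names; the example key names are illustrative only.

```python
def split_mapped_hf_keys(hf_state: dict, expected_hf_keys: set) -> tuple[dict, list]:
    """Keeps only HF tensors the mapping expects and reports the rest as ignored."""
    kept = {k: v for k, v in hf_state.items() if k in expected_hf_keys}
    ignored = sorted(k for k in hf_state if k not in expected_hf_keys)
    return kept, ignored


# Illustrative keys: layer 61 stands in for weights the MaxText model does not use.
hf_state = {"model.layers.60.example.weight": 1, "model.layers.61.example.weight": 2}
kept, ignored = split_mapped_hf_keys(hf_state, {"model.layers.60.example.weight"})
print(ignored)
```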