maxtext.checkpoint_conversion.compare_hf_ckpt module
Verifies that a converted safetensors checkpoint (on GCS or a local path) matches the reference HuggingFace checkpoint.
Usage, comparing the converted safetensors checkpoint against the remote HF reference:
JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.compare_hf_ckpt src/maxtext/configs/base.yml model_name=<maxtext_model_name> hf_access_token=<your_hf_token> hardware=cpu --candidate_path=<gcs_bucket_path or local_path> --atol=1e-2 --rtol=1e-2 --max_workers=12
Usage, comparing the converted safetensors checkpoint against an HF reference stored on GCS or locally:
JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.compare_hf_ckpt src/maxtext/configs/base.yml hardware=cpu --candidate_path=<gcs_bucket_path or local_path> --reference_path=<gcs_bucket_path or local_path> --atol=1e-2 --rtol=1e-2 --max_workers=12
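The `--atol`/`--rtol` flags suggest an element-wise tolerance check. The following is a minimal sketch of such a comparison. It assumes both checkpoints are already loaded as dicts mapping tensor names to NumPy arrays. `compare_state_dicts` is a hypothetical helper, not this module's API. Its use of `np.allclose` is also an assumption about how the tolerances are applied; that function passes when `|candidate - reference| <= atol + rtol * |reference|`.

```python
from typing import Dict, List, Tuple

import numpy as np


def compare_state_dicts(
    reference: Dict[str, np.ndarray],
    candidate: Dict[str, np.ndarray],
    atol: float = 1e-2,
    rtol: float = 1e-2,
) -> Dict[str, List]:
    """Hypothetical helper: report missing, unexpected, and mismatched tensors."""
    missing = sorted(set(reference) - set(candidate))
    unexpected = sorted(set(candidate) - set(reference))
    mismatched: List[Tuple[str, str]] = []
    for name in sorted(set(reference) & set(candidate)):
        # Compare in float32 so low-precision dtypes (e.g. float16) are handled uniformly.
        ref = np.asarray(reference[name], dtype=np.float32)
        cand = np.asarray(candidate[name], dtype=np.float32)
        if ref.shape != cand.shape:
            mismatched.append((name, f"shape {cand.shape} != {ref.shape}"))
        elif not np.allclose(cand, ref, atol=atol, rtol=rtol):
            max_abs = float(np.max(np.abs(cand - ref)))
            mismatched.append((name, f"max abs diff {max_abs:.3g}"))
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}
```

Reporting missing and unexpected keys separately from numeric mismatches makes it easier to tell a tensor-renaming bug in the conversion apart from a precision difference such as bfloat16 rounding.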
- maxtext.checkpoint_conversion.compare_hf_ckpt.load_safetensors_generic(path, max_workers)
Downloads and loads all .safetensors files from a GCS or local path in parallel.
- Parameters:
path (str)
max_workers (int)
- Return type:
Dict[str, ndarray]
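The following is a minimal, local-only sketch of parallel shard loading. It assumes the `safetensors` package is installed and that every shard sits directly under `path`. `load_local_safetensors` and its `load_fn` hook are hypothetical. GCS paths, which the documented function also accepts, are not handled here.

```python
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional

import numpy as np


def _default_loader(file_path: str) -> Dict[str, np.ndarray]:
    # Assumption: the `safetensors` package is installed; imported lazily.
    from safetensors.numpy import load_file

    return load_file(file_path)


def load_local_safetensors(
    path: str,
    max_workers: int = 8,
    load_fn: Optional[Callable[[str], Dict[str, np.ndarray]]] = None,
) -> Dict[str, np.ndarray]:
    """Hypothetical helper: load every *.safetensors shard in `path` with a thread pool."""
    load_fn = load_fn or _default_loader
    files = sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if name.endswith(".safetensors")
    )
    if not files:
        raise FileNotFoundError(f"no .safetensors files under {path!r}")
    state_dict: Dict[str, np.ndarray] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in input order, so the merge is deterministic.
        for shard in pool.map(load_fn, files):
            overlap = state_dict.keys() & shard.keys()
            if overlap:
                raise ValueError(f"tensors found in more than one shard: {sorted(overlap)}")
            state_dict.update(shard)
    return state_dict
```

Threads are a reasonable fit because shard loading is dominated by disk or network I/O. The `load_fn` parameter exists only so the merge logic can be exercised without real shards.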
- maxtext.checkpoint_conversion.compare_hf_ckpt.get_hf_model_state_dict(model_id, token)
Loads the HuggingFace model state dict and converts its tensors to NumPy arrays.
- Parameters:
model_id (str)
token (str)
- Return type:
Dict[str, ndarray]
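The following is a minimal sketch of the same idea. It assumes `transformers` and PyTorch are installed and that the checkpoint loads as a causal LM through `AutoModelForCausalLM`. Both helper names are hypothetical. Floating-point tensors are upcast to float32 because NumPy has no bfloat16 dtype.

```python
from typing import Any, Dict, Mapping, Optional

import numpy as np


def to_numpy_state_dict(state_dict: Mapping[str, Any]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: convert torch tensors (or array-likes) to NumPy arrays."""
    out: Dict[str, np.ndarray] = {}
    for name, tensor in state_dict.items():
        if hasattr(tensor, "detach"):  # duck-typed torch.Tensor
            tensor = tensor.detach().cpu()
            if tensor.is_floating_point():
                # NumPy cannot represent bfloat16, so upcast all floats to float32.
                tensor = tensor.float()
            tensor = tensor.numpy()
        out[name] = np.asarray(tensor)
    return out


def load_hf_state_dict(model_id: str, token: Optional[str] = None) -> Dict[str, np.ndarray]:
    """Hypothetical helper; assumes the model is loadable as a causal LM."""
    from transformers import AutoModelForCausalLM  # imported lazily

    model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
    return to_numpy_state_dict(model.state_dict())
```

Keeping the conversion separate from the download lets the conversion step be checked without network access or model weights.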