maxtext.checkpoint_conversion.compare_hf_ckpt module

Verify the converted safetensor checkpoint (GCS or local) matches the HuggingFace checkpoint reference.

Usage to compare converted safetensor with remote HF reference:

  JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.compare_hf_ckpt src/maxtext/configs/base.yml model_name=<maxtext_model_name> hf_access_token=<your_hf_token> hardware=cpu --candidate_path=<gcs_bucket_path or local_path> --atol=1e-2 --rtol=1e-2 --max_workers=12

Usage to compare converted safetensor with GCS/Local HF reference:

  JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.compare_hf_ckpt src/maxtext/configs/base.yml hardware=cpu --candidate_path=<gcs_bucket_path or local_path> --reference_path=<gcs_bucket_path or local_path> --atol=1e-2 --rtol=1e-2 --max_workers=12

maxtext.checkpoint_conversion.compare_hf_ckpt.load_safetensors_generic(path, max_workers)

Downloads and loads all .safetensors files from a GCS or local path in parallel.

Parameters:
  • path (str)

  • max_workers (int)

Return type:

Dict[str, ndarray]
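
To illustrate the parallel-loading idea, here is a minimal sketch for a local directory only. It is not the module's implementation: the name load_safetensors_dir and the load_one parameter are hypothetical, GCS download is not covered, and the default loader assumes the safetensors package is installed.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def load_safetensors_dir(path, max_workers, load_one=None):
    """Load every .safetensors file under a local directory into one dict.

    Hypothetical helper for illustration; `load_one` maps a file path to a
    dict of tensors and defaults to safetensors.numpy.load_file.
    """
    if load_one is None:
        # Imported lazily so the sketch can be exercised with a stand-in loader.
        from safetensors.numpy import load_file as load_one

    files = sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if name.endswith(".safetensors")
    )

    merged = {}
    # Threads suit this workload because loading shards is mostly I/O-bound.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for shard in pool.map(load_one, files):
            merged.update(shard)
    return merged
```

If two shards contained the same tensor name, the later file would silently overwrite the earlier one in this sketch; a stricter version could raise on duplicates.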

maxtext.checkpoint_conversion.compare_hf_ckpt.get_hf_model_state_dict(model_id, token)

Loads the HuggingFace model state dict and converts it to NumPy arrays.

Parameters:
  • model_id (str)

  • token (str)

Return type:

Dict[str, ndarray]
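
One plausible way to obtain such a dictionary is sketched below. This is an assumption, not the module's code: the helper names are hypothetical, the choice of AutoModelForCausalLM is a guess at a suitable model class, and the sketch requires the transformers and torch packages at call time.

```python
def state_dict_to_numpy(state_dict):
    """Convert a mapping of torch-like tensors to NumPy arrays.

    Hypothetical helper. Casting to float32 first avoids failures on dtypes
    such as bfloat16 that NumPy does not support natively.
    """
    return {
        name: tensor.detach().cpu().float().numpy()
        for name, tensor in state_dict.items()
    }


def load_hf_state_dict(model_id, token):
    """Download a HuggingFace model and return its weights as NumPy arrays."""
    # Imported lazily; assumes `transformers` (and torch) are installed.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
    return state_dict_to_numpy(model.state_dict())
```

Note that the float32 cast changes the dtype of lower-precision weights, so any comparison downstream is a comparison of values, not of storage formats.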

maxtext.checkpoint_conversion.compare_hf_ckpt.verify_dictionaries(ref_dict, cand_dict, rtol, atol)

Compares two dictionaries of numpy arrays.

Parameters:
  • ref_dict (Dict[str, ndarray])

  • cand_dict (Dict[str, ndarray])

  • rtol (float)

  • atol (float)

Return type:

bool
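
A comparison of this kind can be sketched as follows. This is a minimal illustration under assumptions, not the module's implementation: the name compare_array_dicts is hypothetical, and it assumes the check is the usual NumPy tolerance test, |cand - ref| <= atol + rtol * |ref|, applied per tensor after confirming that keys and shapes agree.

```python
import numpy as np


def compare_array_dicts(ref_dict, cand_dict, rtol, atol):
    """Return True only if keys, shapes, and values all match within tolerance."""
    ok = True

    missing = sorted(set(ref_dict) - set(cand_dict))
    extra = sorted(set(cand_dict) - set(ref_dict))
    if missing or extra:
        print(f"key mismatch: missing={missing} extra={extra}")
        ok = False

    for name in sorted(set(ref_dict) & set(cand_dict)):
        ref, cand = ref_dict[name], cand_dict[name]
        if ref.shape != cand.shape:
            print(f"{name}: shape {cand.shape} != reference {ref.shape}")
            ok = False
            continue
        # np.allclose(a, b) tests |a - b| <= atol + rtol * |b| elementwise,
        # so the reference goes second to scale the tolerance by its magnitude.
        if not np.allclose(cand, ref, rtol=rtol, atol=atol):
            diff = np.abs(cand.astype(np.float64) - ref.astype(np.float64))
            print(f"{name}: max abs diff {float(diff.max()):.3e}")
            ok = False

    return ok
```

Reporting every failing tensor before returning, instead of stopping at the first mismatch, makes it easier to see whether a conversion error is isolated to one layer or systematic.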

maxtext.checkpoint_conversion.compare_hf_ckpt.main(args, test_args)

Parameters:
  • args (Sequence[str])

  • test_args (Namespace)

Return type:

None