maxtext.checkpoint_conversion.utils.hf_utils module

Utility functions to support the HF checkpoint conversion and verification process in test_hf.py.

maxtext.checkpoint_conversion.utils.hf_utils.convert_jax_weight_to_torch(weight, dtype=None)
Parameters:
  • weight (Array)

  • dtype (None | str)

Return type:

torch.Tensor

maxtext.checkpoint_conversion.utils.hf_utils.check_arrays_match(arrayA, arrayB, atol=0.01, rtol=1e-05)

Compare two sets of arrays for equality within the specified absolute and relative tolerances.

This function handles both PyTorch tensors and JAX arrays, automatically converting between the two if necessary. If the arrays don’t match within the specified tolerance, it prints detailed information about the mismatches.

Parameters:
  • arrayA (torch.Tensor | jax.Array) – First set of arrays to compare

  • arrayB (torch.Tensor | jax.Array) – Second set of arrays to compare

  • atol (float, optional) – Absolute tolerance for comparison. Defaults to 0.01.

  • rtol (float, optional) – Relative tolerance for comparison. Defaults to 1e-5.

Returns:

True if the arrays match within the specified tolerances, False otherwise.

Return type:

bool
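
The tolerance check described above can be sketched with NumPy alone. This is a minimal illustration, not the module's implementation: the helper name arrays_match_sketch is hypothetical, inputs are assumed to be already converted to NumPy arrays, and the criterion |a - b| <= atol + rtol * |b| is the NumPy convention, assumed here because the page does not state the exact formula.

```python
import numpy as np

def arrays_match_sketch(array_a, array_b, atol=0.01, rtol=1e-5):
    """Hypothetical stand-in for check_arrays_match (NumPy inputs only)."""
    a = np.asarray(array_a, dtype=np.float64)
    b = np.asarray(array_b, dtype=np.float64)
    # Assumed criterion (NumPy convention): |a - b| <= atol + rtol * |b|
    close = np.isclose(a, b, atol=atol, rtol=rtol)
    if close.all():
        return True
    # Report how many elements differ and the largest absolute difference.
    num_bad = int((~close).sum())
    max_diff = float(np.abs(a - b).max())
    print(f"{num_bad} of {a.size} elements differ; max abs diff = {max_diff:.6g}")
    return False

reference = np.array([1.0, 2.0, 3.0])
within = arrays_match_sketch(reference, reference + 0.005)  # inside atol=0.01
outside = arrays_match_sketch(reference, reference + 0.5)   # well outside
```

With the default atol of 0.01, an offset of 0.005 passes and an offset of 0.5 fails and prints a short mismatch summary.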

maxtext.checkpoint_conversion.utils.hf_utils.check_predicted_tokens_match(logits_a, logits_b, tolerance=0.1)

Compares the top predicted tokens from the two sets of logits and checks that their disagreement rate does not exceed the tolerance threshold. Raises an AssertionError if the disagreement rate exceeds the tolerance.

Parameters:
  • logits_a (jax.Array | torch.Tensor | np.ndarray) – First set of model output logits

  • logits_b (jax.Array | torch.Tensor | np.ndarray) – Second set of model output logits to compare against logits_a

  • tolerance (float, optional) – Maximum allowed fraction of token prediction disagreements. Must be between 0.0 and 1.0. Defaults to 0.1 (10%), as given in the signature.
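
As a rough sketch of the disagreement rate this check is based on, the snippet below takes the argmax over the last (vocabulary) axis of each set of logits and measures how often the two disagree. The helper name top1_disagreement_rate is hypothetical, NumPy inputs are assumed, and treating a rate exactly equal to the tolerance as passing is also an assumption.

```python
import numpy as np

def top1_disagreement_rate(logits_a, logits_b):
    """Fraction of positions whose argmax token differs (hypothetical helper)."""
    tokens_a = np.argmax(np.asarray(logits_a), axis=-1)
    tokens_b = np.argmax(np.asarray(logits_b), axis=-1)
    return float(np.mean(tokens_a != tokens_b))

# Four positions over a vocabulary of three tokens.
logits_a = np.array([[2.0, 1.0, 0.0],
                     [0.0, 3.0, 1.0],
                     [1.0, 0.0, 4.0],
                     [5.0, 1.0, 0.0]])
# Same top token in the first three positions, a different one in the last.
logits_b = np.array([[2.1, 1.0, 0.0],
                     [0.0, 2.9, 1.0],
                     [1.0, 0.0, 4.2],
                     [0.0, 6.0, 1.0]])

rate = top1_disagreement_rate(logits_a, logits_b)
tolerance = 0.3
# Assumed pass condition: the rate may equal but not exceed the tolerance.
assert rate <= tolerance, f"disagreement {rate:.2%} exceeds {tolerance:.2%}"
```

Here one of four positions disagrees, so the rate is 0.25 and the assertion passes for a tolerance of 0.3; lowering the tolerance to 0.2 would raise an AssertionError.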

Examples

>>> logits1 = get_model_output(input1)
>>> logits2 = get_model_output(input2)
>>> check_predicted_tokens_match(logits1, logits2, tolerance=0.03)  # Allows 3% disagreement

maxtext.checkpoint_conversion.utils.hf_utils.get_logits_comparison_metrics(logitsA, logitsB)

Calculate various comparison metrics between two sets of logits.

This function computes several metrics to compare the similarity and differences between two sets of logits, including KL divergence, absolute differences, and agreement in top-k predictions.

Parameters:
  • logitsA (jax.Array | torch.Tensor | np.ndarray) – First set of logits to compare

  • logitsB (jax.Array | torch.Tensor | np.ndarray) – Second set of logits to compare

Returns:

A dictionary containing the following metrics:
  • max_kl_div: Maximum KL divergence between probability distributions

  • abs_diff: Maximum absolute difference between probabilities

  • disagreement_top5: Proportion of positions where top-5 predictions differ

  • disagreement_top1: Proportion of positions where top-1 predictions differ

Return type:

dict

Notes

The function also prints a formatted table of the metrics using tabulate.
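
The four metrics listed above can be approximated with NumPy as follows. This is a sketch under stated assumptions rather than the module's code: all function names are hypothetical, the KL divergence is taken as KL(p || q) with p from the first set of logits, top-k disagreement is interpreted as the top-k token sets differing regardless of order, and the tabulate printout is omitted.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the per-position maximum for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def topk_sets_differ(logits_a, logits_b, k):
    # Indices of the k largest logits per position, sorted so order is ignored.
    top_a = np.sort(np.argsort(logits_a, axis=-1)[..., -k:], axis=-1)
    top_b = np.sort(np.argsort(logits_b, axis=-1)[..., -k:], axis=-1)
    return np.any(top_a != top_b, axis=-1)

def comparison_metrics_sketch(logits_a, logits_b):
    """Hypothetical stand-in for get_logits_comparison_metrics."""
    a = np.asarray(logits_a, dtype=np.float64)
    b = np.asarray(logits_b, dtype=np.float64)
    log_p, log_q = log_softmax(a), log_softmax(b)
    p, q = np.exp(log_p), np.exp(log_q)
    # Assumed direction: KL(p || q), computed per position.
    kl = np.sum(p * (log_p - log_q), axis=-1)
    return {
        "max_kl_div": float(kl.max()),
        "abs_diff": float(np.abs(p - q).max()),
        "disagreement_top5": float(topk_sets_differ(a, b, 5).mean()),
        "disagreement_top1": float(topk_sets_differ(a, b, 1).mean()),
    }

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 16))  # 4 positions, vocabulary of 16 tokens
same = comparison_metrics_sketch(logits, logits)
mirrored = comparison_metrics_sketch(logits, logits[:, ::-1])
```

Comparing a set of logits with itself gives zero for every metric. Reversing the vocabulary axis moves every top token to a different index, so both disagreement rates become 1.0 in this example.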