maxtext.checkpoint_conversion.utils.hf_utils module
Utility functions to support the HF checkpoint conversion and verification process in test_hf.py.
- maxtext.checkpoint_conversion.utils.hf_utils.convert_jax_weight_to_torch(weight, dtype=None)
- Parameters:
weight (Array)
dtype (None | str)
- Return type:
torch.Tensor
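This page gives no description of the conversion. As a rough illustration only, the sketch below shows one common way to move a JAX array into PyTorch. `jax_to_torch_sketch` is a hypothetical name, and two behaviors are assumptions rather than documented ones: `dtype` is read as a torch dtype name such as "bfloat16", and with `dtype=None` the result stays float32.

```python
import jax.numpy as jnp
import numpy as np
import torch


def jax_to_torch_sketch(weight, dtype=None):
    """Hypothetical stand-in for convert_jax_weight_to_torch; not the MaxText code."""
    # Upcast to float32 on the JAX side: torch.from_numpy does not accept the
    # ml_dtypes bfloat16 arrays that NumPy would otherwise receive from JAX.
    host = np.array(jnp.asarray(weight, dtype=jnp.float32))  # writable host copy
    tensor = torch.from_numpy(host)
    # Assumption: dtype names a torch dtype, e.g. "bfloat16" or "float16".
    if dtype is not None:
        tensor = tensor.to(getattr(torch, dtype))
    return tensor


w = jnp.arange(6, dtype=jnp.bfloat16).reshape(2, 3)
t = jax_to_torch_sketch(w, dtype="bfloat16")
print(t.dtype, tuple(t.shape))  # torch.bfloat16 (2, 3)
```

Going through float32 costs a temporary copy but avoids relying on NumPy support for bfloat16, which neither NumPy nor `torch.from_numpy` provides natively.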
- maxtext.checkpoint_conversion.utils.hf_utils.check_arrays_match(arrayA, arrayB, atol=0.01, rtol=1e-05)
Compare two arrays elementwise for equality within the specified absolute and relative tolerances.
This function handles both PyTorch tensors and JAX arrays, automatically converting between the two if necessary. If the arrays don’t match within the specified tolerance, it prints detailed information about the mismatches.
- Parameters:
arrayA (torch.Tensor | jax.Array) – First array to compare
arrayB (torch.Tensor | jax.Array) – Second array to compare
atol (float, optional) – Absolute tolerance for comparison. Defaults to 0.01.
rtol (float, optional) – Relative tolerance for comparison. Defaults to 1e-5.
- Returns:
True if the arrays match within the specified tolerances, False otherwise.
- Return type:
bool
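The tolerance test presumably follows the usual allclose rule, where every element must satisfy |a - b| <= atol + rtol * |b|. That is the rule `np.allclose` and `torch.allclose` apply, but whether this function calls either one is an assumption. The NumPy-only sketch below (`arrays_match_sketch` is a hypothetical name, and it skips the PyTorch/JAX conversion the real function performs) shows how the defaults behave.

```python
import numpy as np


def arrays_match_sketch(array_a, array_b, atol=0.01, rtol=1e-05):
    """Hypothetical illustration of the tolerance check; not the MaxText implementation."""
    # Only array-likes NumPy can ingest directly; no torch/JAX conversion here.
    a = np.asarray(array_a, dtype=np.float32)
    b = np.asarray(array_b, dtype=np.float32)
    # Elementwise allclose rule: |a - b| <= atol + rtol * |b|
    close = np.abs(a - b) <= atol + rtol * np.abs(b)
    if not close.all():
        bad = np.argwhere(~close)
        first = tuple(int(i) for i in bad[0])
        print(f"{len(bad)} mismatched elements, first at index {first}")
        print("max abs diff:", float(np.abs(a - b).max()))
    return bool(close.all())


ref = np.array([1.0, 2.0, 3.0])
print(arrays_match_sketch(ref, ref + 0.005))  # True
print(arrays_match_sketch(ref, ref + 0.05))   # prints mismatch details, then False
```

With the defaults, rtol * |b| only reaches atol = 0.01 when |b| approaches 1000, so for typical weight magnitudes the check is effectively an absolute one.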
- maxtext.checkpoint_conversion.utils.hf_utils.check_predicted_tokens_match(logits_a, logits_b, tolerance=0.1)
Compares the top predicted tokens from each set of logits and ensures their disagreement rate doesn’t exceed the tolerance threshold. Raises an AssertionError if the disagreement is too high.
- Parameters:
logits_a (jax.Array | torch.Tensor | np.ndarray) – First set of model output logits
logits_b (jax.Array | torch.Tensor | np.ndarray) – Second set of model output logits to compare against logits_a
tolerance (float, optional) – Maximum allowed fraction of token prediction disagreements, must be between 0.0 and 1.0. Defaults to 0.1 (10%).
Examples
>>> logits1 = get_model_output(input1)
>>> logits2 = get_model_output(input2)
>>> check_predicted_tokens_match(logits1, logits2, tolerance=0.03)  # Allows 3% disagreement
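Assuming "top predicted token" means the argmax over the vocabulary (last) axis at each position, the disagreement rate is simply the fraction of positions whose argmax differs. A NumPy sketch under that assumption; `predicted_tokens_match_sketch` and the random logits are hypothetical:

```python
import numpy as np


def predicted_tokens_match_sketch(logits_a, logits_b, tolerance=0.1):
    """Hypothetical illustration; not the MaxText implementation."""
    # Top-1 prediction at every position (vocabulary is the last axis).
    top_a = np.argmax(np.asarray(logits_a), axis=-1)
    top_b = np.argmax(np.asarray(logits_b), axis=-1)
    disagreement = float(np.mean(top_a != top_b))
    assert disagreement <= tolerance, (
        f"Token disagreement {disagreement:.2%} exceeds tolerance {tolerance:.2%}"
    )
    return disagreement


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 25, 128))  # made-up (batch, seq_len, vocab) shape
noisy = logits + rng.normal(scale=1e-3, size=logits.shape)
print(predicted_tokens_match_sketch(logits, noisy))
```

Negating one set of logits turns every argmax into an argmin, so the disagreement rate becomes 1.0 and the assertion fires.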
- maxtext.checkpoint_conversion.utils.hf_utils.get_logits_comparison_metrics(logitsA, logitsB)
Calculate various comparison metrics between two sets of logits.
This function computes several metrics to compare the similarity and differences between two sets of logits, including KL divergence, absolute differences, and agreement in top-k predictions.
- Parameters:
logitsA (jax.Array | torch.Tensor | np.ndarray) – First set of logits to compare
logitsB (jax.Array | torch.Tensor | np.ndarray) – Second set of logits to compare
- Returns:
A dictionary containing the following metrics:
max_kl_div: Maximum KL divergence between probability distributions
abs_diff: Maximum absolute difference between probabilities
disagreement_top5: Proportion of positions where top-5 predictions differ
disagreement_top1: Proportion of positions where top-1 predictions differ
- Return type:
dict
Notes
The function also prints a formatted table of the metrics using tabulate.
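The page does not state which KL direction or which top-k comparison rule is used. The sketch below assumes KL(P_A || P_B) = Σ_v P_A(v) (log P_A(v) - log P_B(v)), computed from softmax probabilities, and counts a top-k disagreement whenever the two top-k token sets differ (order ignored). Both are assumptions, all function names are hypothetical, and the tabulate printout is omitted.

```python
import numpy as np


def log_softmax(x):
    # Numerically stable log-softmax over the vocabulary (last) axis.
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def topk_disagreement(logits_a, logits_b, k):
    # Fraction of positions whose top-k token sets differ (order ignored; an assumption).
    top_a = np.sort(np.argsort(logits_a, axis=-1)[..., -k:], axis=-1)
    top_b = np.sort(np.argsort(logits_b, axis=-1)[..., -k:], axis=-1)
    return float(np.mean(np.any(top_a != top_b, axis=-1)))


def logits_metrics_sketch(logits_a, logits_b):
    """Hypothetical illustration of the four metrics; not the MaxText implementation."""
    a = np.asarray(logits_a, dtype=np.float64)
    b = np.asarray(logits_b, dtype=np.float64)
    log_p, log_q = log_softmax(a), log_softmax(b)
    p, q = np.exp(log_p), np.exp(log_q)
    # KL(P_A || P_B) at each position; the direction is an assumption.
    kl = np.sum(p * (log_p - log_q), axis=-1)
    return {
        "max_kl_div": float(kl.max()),
        "abs_diff": float(np.abs(p - q).max()),
        "disagreement_top5": topk_disagreement(a, b, 5),
        "disagreement_top1": topk_disagreement(a, b, 1),
    }


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 16, 64))  # made-up (batch, seq_len, vocab) shape
print(logits_metrics_sketch(logits, logits + rng.normal(scale=0.01, size=logits.shape)))
```

Working in log space keeps the KL term finite even when some probabilities underflow, and because softmax ignores a constant shift, adding the same offset to every logit leaves all four metrics at (or numerically near) zero.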