# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Verify that a converted safetensors checkpoint (GCS or local) matches a HuggingFace reference checkpoint.
Usage, comparing the converted safetensors checkpoint with the remote HF reference:
JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.compare_hf_ckpt src/maxtext/configs/base.yml \
model_name=<maxtext_model_name> \
hf_access_token=<your_hf_token> \
hardware=cpu \
--candidate_path=<gcs_bucket_path or local_path> \
--atol=1e-2 --rtol=1e-2 --max_workers=12
Usage, comparing the converted safetensors checkpoint with an HF reference stored on GCS or local disk (model_name and hf_access_token are not needed in this mode):
JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.compare_hf_ckpt src/maxtext/configs/base.yml \
hardware=cpu \
--candidate_path=<gcs_bucket_path or local_path> \
--reference_path=<gcs_bucket_path or local_path> \
--atol=1e-2 --rtol=1e-2 --max_workers=12
"""
import argparse
import glob
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Sequence, Dict

import gcsfs
import jax
import numpy as np
import torch
from absl import logging
from safetensors import safe_open
from safetensors.torch import load as load_safetensors
from tqdm import tqdm

from maxtext.configs import pyconfig
from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, load_hf_dict_from_transformers
from maxtext.utils import max_logging
from maxtext.utils.globals import HF_IDS
# Pin JAX to the CPU backend at import time, before any JAX computation can pick
# an accelerator. Tensors in this script are compared with numpy on the host;
# the usage examples in the module docstring also set JAX_PLATFORMS=cpu.
# NOTE(review): newer JAX releases prefer the `jax_platforms` option over
# `jax_platform_name` — confirm against the pinned JAX version.
jax.config.update("jax_platform_name", "cpu")
def _load_gcs_shard(gcs_path: str, fs: gcsfs.GCSFileSystem) -> Dict[str, np.ndarray]:
  """Read one ``.safetensors`` shard from GCS and return its tensors as numpy arrays.

  The whole shard is read into memory as bytes and parsed with
  ``safetensors.torch.load``. ``bfloat16`` tensors are upcast to ``float32``
  before conversion: numpy has no bfloat16 dtype, so ``torch.Tensor.numpy()``
  cannot convert them directly. The upcast is exact.

  Args:
    gcs_path: Fully-qualified ``gs://`` path of the shard.
    fs: GCS filesystem handle used to open the shard.

  Returns:
    Mapping from tensor name to numpy array.
  """
  max_logging.log(f"Processing GCS shard: {gcs_path}")
  with fs.open(gcs_path, "rb") as f:
    file_bytes = f.read()
  loaded_tensors = load_safetensors(file_bytes)

  shard_dict = {}
  for key, tensor in loaded_tensors.items():
    if tensor.dtype == torch.bfloat16:
      tensor = tensor.float()
    shard_dict[key] = tensor.numpy()
  return shard_dict
def _load_local_shard(file_path: str) -> Dict[str, np.ndarray]:
  """Read one local ``.safetensors`` shard and return its tensors as numpy arrays.

  Tensors are read one at a time through ``safe_open`` (PyTorch framework, CPU).
  ``bfloat16`` tensors are upcast to ``float32`` first: numpy has no bfloat16
  dtype, so ``torch.Tensor.numpy()`` cannot convert them directly. The upcast
  is exact.

  Args:
    file_path: Path of the shard on local disk.

  Returns:
    Mapping from tensor name to numpy array.
  """
  max_logging.log(f"Processing local shard: {file_path}")
  shard_dict = {}
  with safe_open(file_path, framework="pt", device="cpu") as f:
    for key in f.keys():
      tensor = f.get_tensor(key)
      if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
      shard_dict[key] = tensor.numpy()
  return shard_dict
def load_safetensors_generic(path: str, max_workers: int) -> Dict[str, np.ndarray]:
  """Load all ``*.safetensors`` shards in a GCS or local directory in parallel.

  Args:
    path: Directory holding the shards. A ``gs://`` prefix selects GCS; anything
      else is treated as a local directory (only its top level is searched).
    max_workers: Maximum number of threads loading shards concurrently.

  Returns:
    Mapping from tensor name to numpy array, merged across all shards. If a
    tensor name appears in more than one shard, a warning is logged and the
    shard that finishes loading last wins.

  Raises:
    FileNotFoundError: If ``path`` contains no ``*.safetensors`` files. Without
      this check a mistyped path yields an empty dict, and two empty dicts would
      later be reported as matching.
    Exception: Whatever a shard loader raises is logged with the shard path and
      re-raised.
  """
  final_tensor_dict = {}
  is_gcs = path.startswith("gs://")
  print_ram_usage(f"Start {'GCS' if is_gcs else 'Local'} Load")

  if is_gcs:
    fs = gcsfs.GCSFileSystem()
    search_pattern = f"{path.rstrip('/')}/*.safetensors"
    # gcsfs returns paths without the scheme; restore it for logging and opening.
    safetensor_files = sorted(f"gs://{f}" for f in fs.glob(search_pattern))
    max_logging.log(f"Found {len(safetensor_files)} files in GCS. Loading...")
  else:
    # Escape the directory so glob metacharacters in it ("[", "*", "?") match literally.
    search_pattern = os.path.join(glob.escape(path), "*.safetensors")
    safetensor_files = sorted(glob.glob(search_pattern))
    max_logging.log(f"Found {len(safetensor_files)} files locally. Loading...")

  if not safetensor_files:
    raise FileNotFoundError(f"No .safetensors files found in: {path}")

  with ThreadPoolExecutor(max_workers=max_workers) as executor:
    if is_gcs:
      futures_map = {executor.submit(_load_gcs_shard, f, fs): f for f in safetensor_files}
    else:
      futures_map = {executor.submit(_load_local_shard, f): f for f in safetensor_files}

    for future in as_completed(futures_map):
      shard_path = futures_map[future]
      try:
        shard_data = future.result()
      except Exception as e:
        max_logging.log(f"ERROR: Exception while loading shard {shard_path}: {e}")
        # Cancel shards that have not started yet; leaving the `with` block still
        # waits for the ones already running.
        for pending in futures_map:
          pending.cancel()
        raise
      duplicate_keys = final_tensor_dict.keys() & shard_data.keys()
      if duplicate_keys:
        max_logging.log(f"WARNING: tensors present in more than one shard (overwritten): {sorted(duplicate_keys)}")
      final_tensor_dict.update(shard_data)

  print_ram_usage("End Load")
  return final_tensor_dict
def get_hf_model_state_dict(model_id: str, token: str) -> Dict[str, np.ndarray]:
  """Load a HuggingFace model's state dict and convert every tensor to numpy.

  Args:
    model_id: HuggingFace model identifier passed to
      ``load_hf_dict_from_transformers``.
    token: HuggingFace access token forwarded to the loader.

  Returns:
    Mapping from parameter name to numpy array. ``bfloat16`` tensors are upcast
    (exactly) to ``float32`` because numpy has no bfloat16 dtype and
    ``torch.Tensor.numpy()`` cannot convert them directly.
  """
  max_logging.log(f"Loading reference model from HuggingFace: {model_id}...")
  state_dict = load_hf_dict_from_transformers(model_id, token)
  return {k: (v.float() if v.dtype == torch.bfloat16 else v).numpy() for k, v in state_dict.items()}
def verify_dictionaries(
    ref_dict: Dict[str, np.ndarray], cand_dict: Dict[str, np.ndarray], rtol: float, atol: float
) -> bool:
  """Compare two tensor dictionaries by key set, then per-tensor shape and values.

  Stops at the first mismatch. Tensors are visited in sorted key order, so
  repeated runs check and report them in the same order.

  Args:
    ref_dict: Reference tensors, keyed by name.
    cand_dict: Candidate tensors, keyed by name.
    rtol: Relative tolerance passed to ``np.allclose``.
    atol: Absolute tolerance passed to ``np.allclose``.

  Returns:
    True if both dicts have the same keys and every pair of arrays has the same
    shape and satisfies ``np.allclose(ref, cand, rtol=rtol, atol=atol)``;
    False otherwise.
  """
  max_logging.log(f"Verifying with rtol={rtol}, atol={atol}")

  # 1. Compare Keys
  ref_keys = set(ref_dict)
  cand_keys = set(cand_dict)
  if ref_keys != cand_keys:
    max_logging.log("❌ KEYS DO NOT MATCH")
    max_logging.log(f"Missing in Candidate: {sorted(ref_keys - cand_keys)}")
    max_logging.log(f"Extra in Candidate: {sorted(cand_keys - ref_keys)}")
    return False
  max_logging.log("✅ Keys match. Verifying tensor values...")

  # 2. Compare Values (Early Return)
  for key in tqdm(sorted(ref_keys), desc="Verifying tensors"):
    arr_ref = ref_dict[key]
    arr_cand = cand_dict[key]

    # Check Shape
    if arr_ref.shape != arr_cand.shape:
      max_logging.log(f"❌ SHAPE MISMATCH found for '{key}'")
      max_logging.log(f"Reference: {arr_ref.shape} vs Candidate: {arr_cand.shape}")
      return False
    max_logging.log(f"✅ Key: {key} shape match.")

    # Check Values
    if not np.allclose(arr_ref, arr_cand, rtol=rtol, atol=atol):
      # Subtract in float64: bool arrays cannot be subtracted in numpy, and
      # unsigned ints wrap around, which would misreport the difference.
      diff = np.abs(np.asarray(arr_ref, dtype=np.float64) - np.asarray(arr_cand, dtype=np.float64))
      max_diff = np.max(diff)
      max_logging.log(f"❌ VALUE MISMATCH found for '{key}'")
      max_logging.log(f"Max difference: {max_diff} (exceeds rtol={rtol}, atol={atol})")
      return False
    max_logging.log(f"✅ Key: {key} value match.")

  max_logging.log("✅ All values match!")
  return True
def main(args: Sequence[str], test_args: argparse.Namespace) -> None:
  """Load the reference and candidate checkpoints, compare them, and exit(1) on mismatch.

  Args:
    args: argv-style list forwarded to ``pyconfig.initialize``. It is only used
      when ``test_args.reference_path`` is empty; the reference is then loaded
      from HuggingFace using ``model_name`` and ``hf_access_token`` from that config.
    test_args: Script-specific flags: ``candidate_path``, ``reference_path``,
      ``max_workers``, ``rtol`` and ``atol``.

  Raises:
    ValueError: If the configured ``model_name`` has no entry in ``HF_IDS``.
    SystemExit: With status 1 when the checkpoints do not match.
  """
  # 1. Load Reference (safetensors at reference_path if given, otherwise HuggingFace)
  t0 = time.perf_counter()
  if test_args.reference_path:
    hf_state_dict = load_safetensors_generic(test_args.reference_path, test_args.max_workers)
  else:
    config = pyconfig.initialize(args)
    model_name = config.model_name
    if model_name not in HF_IDS:
      raise ValueError(f"Unsupported model name: {model_name}. Supported: {list(HF_IDS.keys())}")
    hf_state_dict = get_hf_model_state_dict(HF_IDS[model_name], config.hf_access_token)
  t1 = time.perf_counter()
  max_logging.log(f"⏱️ HuggingFace model loaded in {(t1 - t0) / 60:.2f} minutes")

  # 2. Load Candidate (GCS or Local)
  t0 = time.perf_counter()
  cand_state_dict = load_safetensors_generic(test_args.candidate_path, test_args.max_workers)
  t1 = time.perf_counter()
  max_logging.log(f"⏱️ Safetensors checkpoint loaded in {(t1 - t0) / 60:.2f} minutes")

  # 3. Compare
  success = verify_dictionaries(hf_state_dict, cand_state_dict, rtol=test_args.rtol, atol=test_args.atol)
  if not success:
    sys.exit(1)
if __name__ == "__main__":
  # Suppress TF/XLA C++ INFO and WARNING logs while keeping errors visible
  # ("0", the default, shows every message).
  # NOTE(review): jax is already imported at module level, so this only affects
  # C++ logging that reads the variable afterwards — confirm it takes effect.
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

  # Parse script-specific arguments; everything unrecognized is forwarded to pyconfig.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--candidate_path",
      type=str,
      required=True,
      help="The path to the converted safetensors checkpoint (e.g. gs://bucket/path or /local/path)",
  )
  parser.add_argument(
      "--reference_path",
      type=str,
      default="",
      required=False,
      help="The path to the reference safetensors checkpoint (e.g. gs://bucket/path or /local/path)",
  )
  parser.add_argument(
      "--max_workers",
      type=int,
      default=12,
      required=False,
      help="The max workers for loading safetensors",
  )
  parser.add_argument(
      "--rtol",
      type=float,
      default=1e-2,
      required=False,
      help="Relative tolerance for numpy.allclose",
  )
  parser.add_argument(
      "--atol",
      type=float,
      default=1e-2,
      required=False,
      help="Absolute tolerance for numpy.allclose",
  )
  logging.set_verbosity(logging.INFO)
  local_args, remaining_args = parser.parse_known_args()
  # pyconfig.initialize takes an argv-style list, so keep the program name first.
  model_args = [sys.argv[0]] + remaining_args
  main(model_args, local_args)