Source code for maxtext.experimental.agent.ckpt_conversion_agent.ground_truth.llama3

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Lllama3 ckpt conversation agent ground truth hook functions."""

import jax

import numpy as np


def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for converting between MaxText and HuggingFace formats.

  This function generates a mapping of transformation functions that handle
  the necessary conversions between MaxText and HuggingFace parameter
  formats, including operations like reshaping.

  Args:
    config: Mapping that provides ``"num_hidden_layers"`` and ``"head_dim"``.
    scan_layers: When true, a single set of ``params-decoder-layers-*`` keys
      is emitted; otherwise one set per layer index
      (``params-decoder-layers_{i}-*``) is emitted.
    saving_to_hf: Direction switch. True selects the MaxText -> HuggingFace
      transforms, False selects the HuggingFace -> MaxText transforms.

  Returns:
    A dict keyed by MaxText parameter path. Each value is either a single
    hook ``fn(input_tensor, target_shape)`` or a list of such hooks (query
    and key kernels). NOTE(review): the lists are presumably applied in
    order by the caller -- confirm against the conversion utility.
  """
  # Read eagerly so a missing key fails here, whatever `scan_layers` is.
  nlayers = config["num_hidden_layers"]

  def scale_query_layer(input_tensor, target_shape):
    """Scales by sqrt(head_dim) when saving to HF, by 1/sqrt(head_dim) otherwise.

    `target_shape` is unused; it is accepted so all hooks share a signature.
    """
    head_dim_sqrt = np.sqrt(config["head_dim"])
    factor = head_dim_sqrt if saving_to_hf else 1 / head_dim_sqrt
    depth_scale = np.dtype("float32").type(factor)
    # Multiply in float32, then cast back to the dtype the tensor came in with.
    original_dtype = input_tensor.dtype
    output_tensor = input_tensor.astype(np.float32) * depth_scale
    return output_tensor.astype(original_dtype)

  def adjust_rope(input_tensor, target_shape):
    """Permutes the last axis between interleaved and concatenated layouts.

    `target_shape` is unused; it is accepted so all hooks share a signature.
    """
    arr = input_tensor
    if saving_to_hf:
      # Convert from MaxText's interleaved layout to HF's concatenated layout.
      evens = arr[..., ::2]
      odds = arr[..., 1::2]
      return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)
    # Convert from HF's concatenated layout to MaxText's interleaved layout.
    half_dim = arr.shape[-1] // 2
    first_half = arr[..., :half_dim]
    second_half = arr[..., half_dim:]
    return jax.numpy.stack([first_half, second_half], axis=-1).reshape(arr.shape)

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes/transposes a kernel; the result has shape `target_shape`."""
    if saving_to_hf:
      # Reshape to the reversed target shape, then reverse all axes.
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).transpose()
    return input_tensor.transpose().reshape(target_shape)

  query_hooks = [reshape_kernel, adjust_rope, scale_query_layer]
  key_hooks = [reshape_kernel, adjust_rope]
  if not saving_to_hf:
    # The HF -> MaxText direction uses the same hooks in the opposite order.
    query_hooks.reverse()
    key_hooks.reverse()

  # Per-layer parameter suffix -> hook(s); shared by scanned and unscanned layouts.
  layer_hooks = {
      "self_attention-query-kernel": query_hooks,
      "self_attention-key-kernel": key_hooks,
      "self_attention-value-kernel": reshape_kernel,
      "self_attention-out-kernel": reshape_kernel,
      "mlp-wi_0-kernel": reshape_kernel,
      "mlp-wi_1-kernel": reshape_kernel,
      "mlp-wo-kernel": reshape_kernel,
  }

  if scan_layers:
    layer_prefixes = ["params-decoder-layers"]
  else:
    layer_prefixes = [f"params-decoder-layers_{layer_idx}" for layer_idx in range(nlayers)]

  hook_fns = {"params-decoder-logits_dense-kernel": reshape_kernel}
  for prefix in layer_prefixes:
    for suffix, hook in layer_hooks.items():
      hook_fns[f"{prefix}-{suffix}"] = hook
  return hook_fns