# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NNX module overrides and utility methods for learn-to-init (LTI) distillation."""
import jax
from flax import nnx
from maxtext.layers import linears, initializers
from maxtext.common.common_types import Config
from jax.sharding import Mesh, NamedSharding
import jax.numpy as jnp
from typing import Iterable, Optional
from maxtext.common.common_types import DType, ShardMode, Array
from maxtext.layers.nnx_wrappers import ToNNX
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.layers.initializers import NdInitializer, nd_dense_init
from maxtext.utils import max_logging, max_utils
[docs]
class LearnToInitDecoderLayer(nnx.Module):
  """Wraps a decoder layer and swaps its attention projections for learn-to-init modules.

  `base_layer_cls` is instantiated as usual (e.g. a LlamaDecoderLayer). Then, inside
  its attention sub-module, every `linears.DenseGeneral` projection named "query",
  "key", "value" or "out" is replaced by a `LearnToInitDense` whose teacher tensor
  `C` is a placeholder shaped after `config.teacher_config`. Calling the wrapper
  simply calls the wrapped layer.

  Attributes:
    learn_to_init_wrapper: The wrapped base decoder layer; its NNX graph is edited in place.
    config: Model configuration (reads `lti_use_general_linear_map` and `teacher_config`).
    rngs: RNG streams shared by the base layer and the replacement projections.
    self_attention_module_name: Attribute name of the attention sub-module to patch.
  """

  def __init__(
      self,
      base_layer_cls,
      config: Config,
      model_mode: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      quant=None,
      **kwargs,
  ):
    self.learn_to_init_wrapper = base_layer_cls(
        config=config,
        model_mode=model_mode,
        mesh=mesh,
        rngs=rngs,
        quant=quant,
        **kwargs,
    )
    self.config = config
    self.rngs = rngs
    self.self_attention_module_name = "self_attention"
    # Patch the freshly built layer so its attention projections use LTI kernels.
    self._customize_attention_modules(self.learn_to_init_wrapper)

  def _customize_attention_modules(self, module: nnx.Module):
    """Replaces `module`'s attention DenseGeneral projections with LearnToInitDense.

    Returns early when `module` has no attribute named
    `self.self_attention_module_name`; projections that are absent or are not
    `linears.DenseGeneral` instances are skipped.

    Args:
      module: The decoder layer whose attention sub-module is patched in place.
    """
    attention = getattr(module, self.self_attention_module_name, None)
    if attention is None:
      return

    use_general_linear_map = self.config.lti_use_general_linear_map
    teacher = self.config.teacher_config

    for proj_name in ["query", "key", "value", "out"]:
      dense = getattr(attention, proj_name, None)
      if not isinstance(dense, linears.DenseGeneral):
        continue

      student_shape = dense.kernel.shape
      assert len(student_shape) == 3

      # The placeholder C keeps the student's shared dimension and swaps in the
      # teacher's head count / head_dim:
      #   query/key/value: (student_shape[0], teacher heads, teacher head_dim)
      #   out:             (teacher query heads, teacher head_dim, student_shape[2])
      is_output = proj_name == "out"
      if proj_name in ("query", "key", "value"):
        num_heads = teacher.base_num_query_heads if proj_name == "query" else teacher.base_num_kv_heads
        c_shape = (student_shape[0], num_heads, teacher.head_dim)
      elif is_output:
        c_shape = (teacher.base_num_query_heads, teacher.head_dim, student_shape[2])
      else:
        max_logging.warning(f"Non handled LTI projection type {proj_name}")
        continue

      # NOTE(review): `parameter_memory_host_offload` is not forwarded here, so the
      # replacement always uses LearnToInitDense's default (False) — confirm intended.
      replacement = LearnToInitDense(
          in_features_shape=dense.in_features_shape,
          out_features_shape=dense.out_features_shape,
          C=jnp.empty(c_shape),  # placeholder for the teacher weights
          axis=dense.axis,
          weight_dtype=dense.weight_dtype,
          dtype=dense.dtype,
          kernel_init=dense.kernel_init,
          kernel_axes=dense.kernel_axes,
          quant=dense.quant,
          use_bias=dense.use_bias,
          shard_mode=dense.shard_mode,
          matmul_precision=dense.matmul_precision,
          is_output_projection=is_output,
          use_general_linear_map=use_general_linear_map,
          rngs=self.rngs,  # same RNG streams as the wrapped layer
      )
      # Swap the projection in the mutable NNX graph.
      setattr(attention, proj_name, replacement)

  def __call__(self, *args, **kwargs):
    """Runs the wrapped decoder layer with the given arguments, unchanged."""
    return self.learn_to_init_wrapper(*args, **kwargs)
[docs]
class LearnToInitDense(nnx.Module):
  """Attention projection whose kernel is rebuilt from a teacher tensor during learn-to-init.

  This module stands in for a `DenseGeneral` attention projection. It holds no
  kernel of its own: on every forward pass the kernel is computed by
  `_calc_attn_weight` from the teacher tensor `C` and learnable student
  parameters. Those parameters are either a factorized pair (A, B) or a single
  general map W. The index layout depends on `is_output_projection`:

    query/key/value: kernel (b, u, v) from C (b, x, y), A (x, u), B (y, v) or W (x, y, u, v)
    output:          kernel (u, v, b) from C (x, y, b), A (x, u), B (v, y) or W (x, y, u, v)

  where b is the dimension shared by student and teacher, and (x, y) / (u, v) are
  the teacher / student sizes of the two remaining kernel dimensions.

  Attributes:
    C: Teacher tensor (3-D), stored as an `nnx.Param` with `kernel_axes` sharding;
      keeping it frozen is up to the training setup.
    A: First learnable factor; present only when `use_general_linear_map` is False.
    B: Second learnable factor; present only when `use_general_linear_map` is False.
    W: Learnable general linear map; present only when `use_general_linear_map` is True.
    bias: Learnable bias over `out_features_shape`, or None when `use_bias` is False.
  """

  TENSOR_A = "A"
  TENSOR_B = "B"
  TENSOR_C = "C"
  TENSOR_W = "W"

  def __init__(
      self,
      in_features_shape: Iterable[int] | int,
      out_features_shape: Iterable[int] | int,
      C: Optional[jax.Array] = None,  # C is the teacher tensor; required despite the None default
      axis: Iterable[int] | int = -1,
      weight_dtype: DType = jnp.float32,
      is_output_projection: bool = False,
      use_general_linear_map: bool = False,
      dtype: DType = jnp.float32,
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
      kernel_axes: tuple[None | str, ...] = (),
      quant: None | Quant = None,
      use_bias: bool = False,
      shard_mode: ShardMode = ShardMode.AUTO,
      matmul_precision: str = "default",
      parameter_memory_host_offload: bool = False,
      *,  # Following arguments are keyword-only
      rngs: nnx.Rngs = None,
  ):
    """Creates the teacher tensor parameter and the learnable student parameters.

    Args:
      in_features_shape: Input feature dimension(s) contracted by the projection.
      out_features_shape: Output feature dimension(s); together with the input
        features they must form a 3-D kernel shape.
      C: Teacher tensor (3-D). Required; None is rejected.
      axis: Input axes contracted against `in_features_shape`.
      weight_dtype: dtype used to initialize A/B/W and the bias.
      is_output_projection: Selects the output-projection index layout.
      use_general_linear_map: Use a single map W instead of the pair (A, B).
      dtype: Computation dtype for inputs, kernel and bias.
      kernel_init: Stored only; A/B/W are initialized with lecun_normal.
      kernel_axes: Sharding annotation for C (the bias uses its trailing entries).
      quant: Stored only; not read by this module.
      use_bias: Whether to create and add a bias.
      shard_mode: When not EXPLICIT, `out_sharding` is ignored in `__call__`.
      matmul_precision: Precision for the kernel einsums and the dot_general.
      parameter_memory_host_offload: Move the computed kernel to device memory in `__call__`.
      rngs: RNG streams; `rngs.params()` seeds A/B/W and the bias. Required.

    Raises:
      ValueError: if `C` or `rngs` is None.
      AssertionError: if the kernel or C is not 3-D, or the shared dimension differs.
    """
    # Fail with a clear message instead of an AttributeError further down.
    if C is None:
      raise ValueError("LearnToInitDense requires the teacher tensor `C`, got None.")
    if rngs is None:
      raise ValueError("LearnToInitDense requires `rngs` to initialize its learnable parameters.")

    self.in_features_shape = linears.canonicalize_tuple(in_features_shape)
    self.out_features_shape = linears.canonicalize_tuple(out_features_shape)
    self.axis = linears.canonicalize_tuple(axis)
    self.weight_dtype = weight_dtype
    self.is_output_projection = is_output_projection
    self.dtype = dtype
    self.kernel_init = kernel_init
    self.kernel_axes = kernel_axes
    self.quant = quant
    self.use_bias = use_bias
    self.shard_mode = shard_mode
    self.matmul_precision = matmul_precision
    self.parameter_memory_host_offload = parameter_memory_host_offload
    self.use_general_linear_map = use_general_linear_map
    self.C = nnx.Param(C, sharding=self.kernel_axes)

    kernel_shape = self.in_features_shape + self.out_features_shape
    assert len(kernel_shape) == 3, "LearnToInitDense currently only supports 3D kernels for attention."
    assert len(self.C.value.shape) == 3, "The teacher tensor C must be 3D."

    if self.is_output_projection:
      # For output projection: student(u,v,b_s), teacher(x,y,b_t)
      u, v, b_s = kernel_shape
      x, y, b_t = self.C.value.shape
      assert b_s == b_t, f"Embedding dimension mismatch for output projection: {b_s} != {b_t}"
      b_factor_shape = (v, y)
    else:
      # For Q,K,V projections: student(b_s,u,v), teacher(b_t,x,y)
      b_s, u, v = kernel_shape
      b_t, x, y = self.C.value.shape
      assert b_s == b_t, f"Dimension mismatch for QKV projection: {b_s} != {b_t}"
      b_factor_shape = (y, v)

    # NOTE(review): lecun_normal uses its default axes (in_axis=-2, out_axis=-1), so
    # fan-in comes from the second-to-last dimension. For the output projection's
    # B (v, y) that is v, although the kernel contracts over y. For W (x, y, u, v)
    # it is x*y*u, although the kernel contracts over (x, y). Confirm the resulting
    # init scale is intended.
    init_fn = nnx.initializers.lecun_normal()
    if self.use_general_linear_map:
      self.W = nnx.Param(
          init_fn(rngs.params(), (x, y, u, v), self.weight_dtype),
          sharding=(None, None, None, None),
      )
    else:
      # A is drawn before B, matching the RNG consumption order of both layouts.
      self.A = nnx.Param(init_fn(rngs.params(), (x, u), self.weight_dtype), sharding=(None, None))
      self.B = nnx.Param(init_fn(rngs.params(), b_factor_shape, self.weight_dtype), sharding=(None, None))

    if self.use_bias:
      bias_axes = self.kernel_axes[-len(self.out_features_shape) :]
      bias_shape = self.out_features_shape
      self.bias = nnx.Param(
          initializers.default_bias_init(rngs.params(), bias_shape, self.weight_dtype),
          sharding=bias_axes,
      )
    else:
      self.bias = None

  def __call__(self, inputs: Array, _initializing: bool = False, out_sharding: NamedSharding | None = None) -> Array:
    """Applies the projection using the kernel rebuilt from C and the learnable map(s).

    Args:
      inputs: Input array; its sizes at `self.axis` must equal `in_features_shape`.
      _initializing: Forwarded unchanged to `linears._compute_dot_general_nnx`.
      out_sharding: Output sharding; replaced by None unless `shard_mode` is EXPLICIT.

    Returns:
      The projected output, with the bias added when `use_bias` is set.

    Raises:
      ValueError: if an input feature size does not match, or C has a zero-sized
        leading dimension (treated as "teacher weights not injected").
    """
    inputs = jnp.asarray(inputs, self.dtype)
    norm_axis = linears.normalize_axes(self.axis, inputs.ndim)
    for i, ax in enumerate(norm_axis):
      if inputs.shape[ax] != self.in_features_shape[i]:
        raise ValueError(
            f"Input dimension {inputs.shape[ax]} at axis {ax} "
            f"does not match expected input feature size {self.in_features_shape[i]}"
        )
    if self.C.value.shape[0] == 0:
      raise ValueError(
          "The 'C' tensor in LearnToInitDense has not been initialized. "
          "Please inject the teacher weights before training."
      )

    # Exactly one parameterization exists: W (general map) or the pair (A, B).
    if self.use_general_linear_map:
      factor_a, factor_b, general_map = None, None, self.W
    else:
      factor_a, factor_b, general_map = self.A, self.B, None
    kernel = _calc_attn_weight(
        factor_a,
        factor_b,
        self.C,
        general_map=general_map,
        is_output_projection=self.is_output_projection,
        matmul_precision=self.matmul_precision,
    )

    if self.parameter_memory_host_offload:
      # NOTE(review): this moves the *computed* kernel, not the A/B/C/W parameters it
      # was derived from — confirm that matches the intended offload semantics.
      max_logging.log("LearnToInitDense: moving computed LTI kernel to device")
      kernel = jax.device_put(kernel, max_utils.device_space())
    kernel = jnp.asarray(kernel, self.dtype)

    # out_sharding should be None for auto mesh axis
    if self.shard_mode != ShardMode.EXPLICIT:
      out_sharding = None

    # The kernel's leading dimensions are the contracted input features.
    contract_ind = tuple(range(len(self.axis)))
    # NOTE(review): `self.quant` is stored but never read here; confirm that
    # quantization is intentionally not applied to this projection.
    output = linears._compute_dot_general_nnx(
        inputs,
        kernel,
        norm_axis,
        contract_ind,
        self.matmul_precision,
        None,
        _initializing,
        out_sharding,
    )
    if self.bias is not None:
      bias = jnp.asarray(self.bias[...], self.dtype)
      output += bias
    return output
def _calc_attn_weight(
    A: jax.Array | nnx.Param | None,
    B: jax.Array | nnx.Param | None,
    C: jax.Array | nnx.Param | None,
    general_map: Optional[jax.Array | nnx.Param] = None,
    is_output_projection: bool = False,
    matmul_precision: str = "default",
    scan_dim: str = "",
):
  """Contracts the teacher tensor C with the learnable map(s) into an effective kernel.

  Index layout (with `scan_dim`, e.g. "l", inserted as the second index of every operand):
    QKV:    C[b, x, y], A[x, u], B[y, v], W[x, y, u, v]  ->  kernel[b, u, v]
    output: C[x, y, b], A[x, u], B[v, y], W[x, y, u, v]  ->  kernel[u, v, b]

  With `general_map` set, A and B are ignored and C is contracted with it over
  (x, y) in a single einsum. Otherwise C is contracted with A over x, and the
  result with B over y.

  Args:
    A: First factor (unused when `general_map` is given).
    B: Second factor (unused when `general_map` is given).
    C: Teacher tensor.
    general_map: Optional general linear map W.
    is_output_projection: Selects the output-projection index layout.
    matmul_precision: `precision` passed to every `jnp.einsum`.
    scan_dim: Einsum label for a stacked-layer axis, or "" when there is none.

  Returns:
    The effective kernel.
  """
  s = scan_dim

  if general_map is None:
    if is_output_projection:
      partial = jnp.einsum(f"x{s}yb,x{s}u->y{s}ub", C, A, precision=matmul_precision)
      return jnp.einsum(f"y{s}ub,v{s}y->u{s}vb", partial, B, precision=matmul_precision)
    partial = jnp.einsum(f"b{s}xy,x{s}u->b{s}uy", C, A, precision=matmul_precision)
    return jnp.einsum(f"b{s}uy,y{s}v->b{s}uv", partial, B, precision=matmul_precision)

  if is_output_projection:
    spec = f"x{s}yb,x{s}yuv->u{s}vb"
  else:
    spec = f"b{s}xy,x{s}yuv->b{s}uv"
  return jnp.einsum(spec, C, general_map, precision=matmul_precision)
[docs]
def calculate_attn_weight(
    A: jax.Array | None,
    B: jax.Array | None,
    C: jax.Array,
    general_map: Optional[jax.Array] = None,
    is_output_projection: bool = False,
    matmul_precision: str = "default",
) -> jax.Array:
  """Computes the effective attention kernel from the teacher tensor and learned map(s).

  Contracts the teacher tensor C with the learnable student representation using
  `jnp.einsum`: either the factorized pair (A, B) or, when `general_map` is
  given, a single general linear map W. The contraction layout depends on
  whether this is the output projection or a Q/K/V projection.

  Scanned vs. unscanned layers are detected from `C.ndim`. A 4-D C is treated as
  carrying a stacked-layer axis at position 1 (einsum label "l", also expected at
  position 1 of the other operands). A 3-D C is treated as a single layer.

  Args:
    A: First learned factor (ignored when `general_map` is given).
    B: Second learned factor (ignored when `general_map` is given).
    C: Teacher tensor, 3-D (single layer) or 4-D (layers stacked on axis 1).
    general_map: Optional general learnable map used instead of A and B.
    is_output_projection: Whether this computes the output-projection kernel.
    matmul_precision: Precision for the einsum contractions.

  Returns:
    The computed effective kernel tensor.

  Raises:
    ValueError: if C is neither 3-D nor 4-D.
  """
  if C.ndim not in (3, 4):
    raise ValueError(f"calculate_attn_weight expects a 3-D or 4-D teacher tensor C, got ndim={C.ndim}.")
  # In scan mode every tensor has an extra layer axis in position 1; "l" labels it in the einsum specs.
  scan_dim = "l" if C.ndim == 4 else ""
  return _calc_attn_weight(
      A,
      B,
      C,
      general_map=general_map,
      is_output_projection=is_output_projection,
      matmul_precision=matmul_precision,
      scan_dim=scan_dim,
  )
[docs]
def apply_lti_model_update(student_model, student_config):
  """Folds trained learn-to-init parameters into plain kernels, modifying `student_model` in place.

  For each attention projection ("query", "key", "value", "out") of the
  LTI-wrapped decoder layers, the effective kernel is computed with
  `calculate_attn_weight` from C and either W (general map) or A/B. The result is
  written into the C variable, which is then stored under the standard "kernel"
  key. The LTI-only entries (W, A, B, C) are removed. Finally the contents of the
  "learn_to_init_wrapper" node are moved up one level, so the layer state matches
  the un-wrapped decoder architecture.

  NOTE: only supports a `ToNNX` decoder whose layers are reached through a single
  "learn_to_init_wrapper" node (layer-scan mode).

  Args:
    student_model: The trained student model; updated in place.
    student_config: Student configuration; only `matmul_precision` is read.

  Raises:
    AssertionError: if `student_model.decoder` is not a `ToNNX` instance.
    ValueError: if a projection, or its C tensor, is missing.
  """
  assert isinstance(student_model.decoder, ToNNX), "LTI now only supports ToNNX as the student_model's decoder type"
  # Access the nested ToNNX dictionary directly.
  lti_wrapped_node = student_model.decoder.layers["learn_to_init_wrapper"]
  attn_state_dict = lti_wrapped_node["self_attention"]

  for proj_name in ["query", "key", "value", "out"]:
    if proj_name not in attn_state_dict:
      raise ValueError("Unsupported structure of LTI-augmented Attention module.")
    proj_params = attn_state_dict[proj_name]
    is_output_proj = proj_name == "out"

    C_param = proj_params.get(LearnToInitDense.TENSOR_C)
    if C_param is None:
      raise ValueError("Attention LTI-augmented module has no C parameter.")

    # Pick the parameterization that was trained: general map W, or the pair (A, B).
    if LearnToInitDense.TENSOR_W in proj_params:
      max_logging.log(f"Computing final learn-to-init weight (general map) for: {proj_name}")
      A_param, B_param = None, None
      W_param = proj_params[LearnToInitDense.TENSOR_W]
    elif LearnToInitDense.TENSOR_A in proj_params and LearnToInitDense.TENSOR_B in proj_params:
      max_logging.log(f"Computing final learn-to-init weight for: {proj_name}")
      A_param = proj_params[LearnToInitDense.TENSOR_A]
      B_param = proj_params[LearnToInitDense.TENSOR_B]
      W_param = None
    else:
      # Nothing to fold. Report it rather than skipping silently: this projection
      # keeps its C entry and gets no "kernel" key.
      max_logging.log(
          f"WARNING: LTI projection '{proj_name}' has a C tensor but neither W nor both A and B; "
          "leaving it unconverted."
      )
      continue

    final_kernel = calculate_attn_weight(
        A=A_param,
        B=B_param,
        C=C_param,
        general_map=W_param,
        is_output_projection=is_output_proj,
        matmul_precision=student_config.matmul_precision,
    )

    # Overwrite C with the final kernel and register that same variable under the standard "kernel" key.
    C_param.set_value(final_kernel)
    proj_params["kernel"] = C_param
    # Remove the LTI-specific entries; pop(key, None) tolerates entries that are already absent.
    for lti_key in (
        LearnToInitDense.TENSOR_W,
        LearnToInitDense.TENSOR_A,
        LearnToInitDense.TENSOR_B,
        LearnToInitDense.TENSOR_C,
    ):
      proj_params.pop(lti_key, None)

  # Unpack the learn_to_init_wrapper so the layer state matches the standard model structure.
  del student_model.decoder.layers["learn_to_init_wrapper"]
  student_model.decoder.layers.update(lti_wrapped_node)