maxtext.layers.learn_to_init_layer module

NNX module overrides and utility methods for learn-to-init (LTI) distillation.

class maxtext.layers.learn_to_init_layer.LearnToInitDecoderLayer(*args, **kwargs)

Bases: Module

A generic wrapper that initializes a base decoder layer and dynamically swaps its DenseGeneral modules for learn-to-init distillation.

This class instantiates a standard base decoder layer (e.g., LlamaDecoderLayer) and replaces specific attention projection sub-modules (“query”, “key”, “value”, “out”) with customized LearnToInitDense modules.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

learn_to_init_wrapper

The instantiated base decoder layer containing the mutable NNX graph.

config

The model configuration parameters.

rngs

The random number generator state used for initialization.

self_attention_module_name

The target name of the attention module to customize.
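
The swapping behaviour described for LearnToInitDecoderLayer can be illustrated with a minimal, framework-free sketch. The classes FakeDenseGeneral, FakeLearnToInitDense, FakeAttention and FakeDecoderLayer, the attribute name "self_attention", and the helper swap_attention_projections are all hypothetical stand-ins invented for this example; they are not the MaxText implementation, which operates on a mutable NNX graph.

```python
# Hypothetical stand-ins; none of these classes are part of MaxText.
class FakeDenseGeneral:
    def __init__(self, name):
        self.name = name


class FakeLearnToInitDense:
    def __init__(self, name):
        self.name = name


class FakeAttention:
    def __init__(self):
        # The four attention projections named in the class description.
        for name in ("query", "key", "value", "out"):
            setattr(self, name, FakeDenseGeneral(name))


class FakeDecoderLayer:
    def __init__(self):
        self.self_attention = FakeAttention()
        # Not an attention projection, so it should be left untouched.
        self.mlp = FakeDenseGeneral("mlp")


TARGETS = ("query", "key", "value", "out")


def swap_attention_projections(layer, attention_name="self_attention"):
    """Replace the targeted projections on the named attention module in place."""
    attention = getattr(layer, attention_name)
    for name in TARGETS:
        if isinstance(getattr(attention, name, None), FakeDenseGeneral):
            setattr(attention, name, FakeLearnToInitDense(name))
    return layer


layer = swap_attention_projections(FakeDecoderLayer())
swapped = [type(getattr(layer.self_attention, n)).__name__ for n in TARGETS]
print(swapped)
```

Only the four named projections are replaced; every other sub-module of the base layer keeps its original type.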

class maxtext.layers.learn_to_init_layer.LearnToInitDense(*args, **kwargs)

Bases: Module

A customized Dense layer used exclusively during the learn-to-init phase of distillation.

This module replaces standard DenseGeneral projections within the attention mechanism. Instead of a single standard kernel, it computes the effective projection weights dynamically during the forward pass by combining learnable student parameters (either A and B matrices, or a general linear map W) with frozen teacher weights (C).

The projection math adapts automatically based on whether the layer is used for Q/K/V projections or the final output projection.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

C

The frozen, pre-trained teacher tensor.

A

The first learnable projection matrix (used if use_general_linear_map is False).

B

The second learnable projection matrix (used if use_general_linear_map is False).

W

A single, general learnable linear map (used if use_general_linear_map is True).

bias

An optional learnable bias parameter.

TENSOR_A = 'A'
TENSOR_B = 'B'
TENSOR_C = 'C'
TENSOR_W = 'W'
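
To make "computes the effective projection weights dynamically during the forward pass" concrete, the sketch below rebuilds the kernel from A, B and C on every call. It uses NumPy in place of JAX so that it runs without an accelerator stack. The class name SketchLearnToInitDense, all tensor shapes, and the einsum subscripts are assumptions chosen for illustration; the page does not specify the real layouts.

```python
import numpy as np


class SketchLearnToInitDense:
    """Illustrative only: shapes and contraction pattern are assumptions."""

    def __init__(self, A, B, C, bias=None):
        self.A = A        # learnable, assumed shape (student_embed, teacher_embed)
        self.B = B        # learnable, assumed shape (teacher_heads, student_heads)
        self.C = C        # frozen teacher, assumed shape (teacher_embed, teacher_heads, head_dim)
        self.bias = bias  # optional, assumed shape (student_heads, head_dim)

    def __call__(self, x):
        # The kernel is never stored; it is recomputed from A, B and C each call.
        kernel = np.einsum("ek,khd,hs->esd", self.A, self.C, self.B)
        y = np.einsum("bte,esd->btsd", x, kernel)
        if self.bias is not None:
            y = y + self.bias
        return y


rng = np.random.default_rng(0)
dense = SketchLearnToInitDense(
    A=rng.normal(size=(6, 8)),
    B=rng.normal(size=(4, 2)),
    C=rng.normal(size=(8, 4, 3)),
    bias=np.ones((2, 3)),
)
x = rng.normal(size=(1, 5, 6))   # (batch, tokens, student_embed)

y_before = dense(x)
dense.A = dense.A * 2.0          # updating a student parameter...
y_after = dense(x)               # ...changes the next forward pass directly
```

Because the kernel is linear in A under these assumptions, doubling A doubles the pre-bias output, while C is left untouched.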
maxtext.layers.learn_to_init_layer.calculate_attn_weight(A, B, C, general_map=None, is_output_projection=False, matmul_precision='default')

Helper function to dynamically compute the effective attention weights using jnp.einsum.

Computes the kernel by contracting the frozen teacher tensor (C) with the learnable student representations. It handles both factorized maps (A and B) and general linear maps (general_map/W), adjusting the tensor contractions based on whether the module is an output projection or a Q/K/V projection.

Parameters:
  • A (Array | None) – The first learned factorized matrix.

  • B (Array | None) – The second learned factorized matrix.

  • C (Array) – The frozen teacher tensor.

  • general_map (Array | None) – An optional unified learnable projection tensor used instead of A and B.

  • is_output_projection (bool) – Boolean flag indicating if this computes the output projection weight.

  • matmul_precision (str) – The precision for the einsum matrix multiplications.

  • scan_dim – A string naming the scan dimension for einsum (e.g., “l” for scanned layers, or “” for none). This parameter does not appear in the signature shown above.

Returns:

The computed effective kernel tensor.

Return type:

Array
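
The branching described for calculate_attn_weight (factorized versus general map, Q/K/V versus output projection) can be sketched as follows. The function sketch_attn_weight, every tensor layout, and every einsum subscript string are assumptions made for this example and may differ from the real function. NumPy is used here; its einsum has no precision argument, so matmul_precision is omitted, and scan_dim is left out as well.

```python
import numpy as np


def sketch_attn_weight(A, B, C, general_map=None, is_output_projection=False):
    """Contract a frozen teacher tensor C with student maps (assumed layouts)."""
    if is_output_projection:
        # Assumed: C is (teacher_heads, head_dim, teacher_embed),
        #          B is (student_heads, teacher_heads),
        #          A is (teacher_embed, student_embed).
        if general_map is not None:
            return np.einsum("shke,hdk->sde", general_map, C)
        return np.einsum("sh,hdk,ke->sde", B, C, A)
    # Assumed: C is (teacher_embed, teacher_heads, head_dim),
    #          A is (student_embed, teacher_embed),
    #          B is (teacher_heads, student_heads).
    if general_map is not None:
        return np.einsum("eskh,khd->esd", general_map, C)
    return np.einsum("ek,khd,hs->esd", A, C, B)


rng = np.random.default_rng(0)

# Q/K/V case: 4-dim teacher embedding, 3 teacher heads, head_dim 2.
C_qkv = rng.normal(size=(4, 3, 2))
A = rng.normal(size=(5, 4))
B = rng.normal(size=(3, 2))
factorized = sketch_attn_weight(A, B, C_qkv)

# A general map built as the outer product of A and B reproduces the
# factorized result, so the general map is the more expressive option.
W = np.einsum("ek,hs->eskh", A, B)
general = sketch_attn_weight(None, None, C_qkv, general_map=W)

# Output-projection case with identity maps returns the teacher tensor itself.
C_out = rng.normal(size=(3, 2, 4))
identity_out = sketch_attn_weight(
    np.eye(4), np.eye(3), C_out, is_output_projection=True
)
```

Under these assumed layouts, identity student maps recover the teacher weights exactly, which gives a simple sanity check for the contraction pattern.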

maxtext.layers.learn_to_init_layer.apply_lti_model_update(student_model, student_config)

Applies the finalized learn-to-init weights to the student model and cleans up the NNX graph.

This function iterates over the LearnToInitDense layers in the trained student model, calculates their final, static effective kernels using calculate_attn_weight, and replaces the dynamically-computed LTI modules with standard kernel representations. It effectively collapses the learn-to-init parameterization back into a standard decoder architecture, modifying the student_model in-place.

NOTE: This function works only for ToNNX decoder models in layer-scan mode.

Parameters:
  • student_model – The trained student model to be updated in-place.

  • student_config – The configuration of the student model containing parameters like matmul_precision.
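
The collapse step that apply_lti_model_update performs can be pictured with a dictionary-based sketch: each learn-to-init entry is replaced in place by a single static kernel that produces the same outputs. The helpers collapse_lti_params and project, the nested-dict parameter layout, and the einsum subscripts are hypothetical; the real function works on an NNX model and is not reproduced here.

```python
import numpy as np

QKV_SUBSCRIPTS = "ek,khd,hs->esd"  # assumed contraction, not taken from MaxText


def project(entry, x):
    """Apply either a collapsed kernel or an on-the-fly learn-to-init kernel."""
    if "kernel" in entry:
        kernel = entry["kernel"]
    else:
        kernel = np.einsum(QKV_SUBSCRIPTS, entry["A"], entry["C"], entry["B"])
    return np.einsum("bte,esd->btsd", x, kernel)


def collapse_lti_params(params):
    """Replace every {A, B, C} entry with a static 'kernel' entry, in place."""
    for entry in params.values():
        if {"A", "B", "C"} <= entry.keys():
            kernel = np.einsum(QKV_SUBSCRIPTS, entry["A"], entry["C"], entry["B"])
            entry.clear()
            entry["kernel"] = kernel
    return params


rng = np.random.default_rng(0)
params = {
    "query": {
        "A": rng.normal(size=(6, 8)),
        "B": rng.normal(size=(4, 2)),
        "C": rng.normal(size=(8, 4, 3)),
    },
    # A standard entry that is already a plain kernel and must not change.
    "mlp": {"kernel": rng.normal(size=(6, 6))},
}
mlp_before = params["mlp"]["kernel"].copy()
x = rng.normal(size=(2, 5, 6))

out_before = project(params["query"], x)
collapse_lti_params(params)
out_after = project(params["query"], x)
```

After the collapse, the learn-to-init tensors are gone and the entry behaves like an ordinary dense kernel, which is the sense in which the parameterization is folded back into a standard decoder architecture.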