Source code for maxtext.models.llama4

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Llama4 decoder layer definition."""
# pylint: disable=arguments-differ, disable=no-name-in-module, missing-function-docstring

import math

from flax import linen as nn
from flax import nnx
import jax
from jax import lax
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Array, AttentionType, Config, MODEL_MODE_TRAIN
from maxtext.common.common_types import MODEL_MODE_PREFILL
from maxtext.inference import page_manager
from maxtext.layers import initializers
from maxtext.layers import linears
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
from maxtext.layers.linears import Dropout
from maxtext.layers.linears import MlpBlock
from maxtext.layers.moe import RoutedAndSharedMoE
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.utils import max_utils

#### Multi modal model implementation


class Llama4UnfoldConvolution(nnx.Module):
  """Patch extraction ("unfold") followed by a linear projection for the Llama4 vision tower.

  The input image batch is cut into non-overlapping square patches of side
  `config.patch_size_for_vit`; each patch is flattened and passed through a
  bias-free dense layer whose output size is `config.hidden_size_for_vit`.

  Attributes:
    config: Config containing model parameters
    rngs: The `nnx.Rngs` handed to the constructor.
    vit_unfold_linear: Dense projection applied to every flattened patch.
  """

  def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
    self.config = config
    self.rngs = rngs

    patch = config.patch_size_for_vit
    # Length of one flattened patch: channels * patch_height * patch_width.
    flattened_patch_dim = config.num_channels_for_vit * patch * patch
    self.vit_unfold_linear = linears.DenseGeneral(
        in_features_shape=flattened_patch_dim,
        out_features_shape=config.hidden_size_for_vit,
        dtype=config.dtype_mm,
        use_bias=False,
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )

  def __call__(self, inputs: Array) -> Array:
    """Turns a batch of images into a sequence of projected patch embeddings.

    Args:
      inputs: 4-D array laid out as NCHW (the layout passed to
        `lax.conv_general_dilated_patches`). Only the height is used to derive
        the patch count, so the image is treated as square.

    Returns:
      The result of `vit_unfold_linear` applied to an array of shape
      [batch, num_patches, channels * patch_size**2].
    """
    cfg = self.config
    patch = cfg.patch_size_for_vit
    batch_size, num_channels, img, _ = inputs.shape
    patches_per_side = img // patch

    # Stride == filter size with VALID padding, i.e. non-overlapping windows.
    unfolded = lax.conv_general_dilated_patches(
        inputs,
        filter_shape=[patch, patch],
        window_strides=[patch, patch],
        padding="VALID",
        dimension_numbers=("NCHW", "HWIO", "NCHW"),
        precision=lax.Precision(cfg.matmul_precision),
        preferred_element_type=cfg.dtype_mm,
    )
    # Collapse the spatial grid into a single patch axis, then put that axis
    # ahead of the per-patch features.
    flattened = unfolded.reshape(batch_size, num_channels * patch * patch, patches_per_side**2)
    patch_sequence = flattened.transpose(0, 2, 1)
    return self.vit_unfold_linear(patch_sequence)
def pixel_shuffle(input_tensor: Array, shuffle_ratio: float) -> Array:
  """Rearranges a square grid of patch features with a pixel-shuffle.

  The patch axis is interpreted as a `side x side` grid. Each spatial axis is
  scaled by `shuffle_ratio` in turn while the channel axis is scaled by the
  inverse, so the total number of elements is preserved.

  Args:
    input_tensor: Array of shape [batch_size, num_patches, channels]. The
      reshape into a grid only succeeds when `num_patches` is a perfect square.
    shuffle_ratio: Scale applied to both grid axes (e.g. 0.5 halves each side
      and quadruples the channel count).

  Returns:
    Array of shape
    [batch_size, num_patches * shuffle_ratio**2, channels / shuffle_ratio**2].
  """
  batch, token_count, _ = input_tensor.shape
  side = int(math.sqrt(token_count))

  # [batch, side, side, channels]
  grid = input_tensor.reshape(batch, side, side, -1)
  batch, rows, cols, depth = grid.shape
  scaled_rows = int(rows * shuffle_ratio)
  scaled_cols = int(cols * shuffle_ratio)

  # Fold along the width: [batch, rows, cols * r, depth / r], then swap the
  # two spatial axes so the height can be folded next.
  folded = grid.reshape(batch, rows, scaled_cols, int(depth / shuffle_ratio))
  folded = folded.transpose(0, 2, 1, 3)

  # Fold along the height: [batch, rows * r, cols * r, depth / r**2], then
  # swap the spatial axes back.
  folded = folded.reshape(batch, scaled_rows, scaled_cols, int(depth / (shuffle_ratio**2)))
  folded = folded.transpose(0, 2, 1, 3)

  # Flatten the grid back into a single patch axis.
  return folded.reshape(batch, -1, folded.shape[-1])
class Llama4VisionMLP(nnx.Module):
  """Feed-forward block used inside Llama4VisionEncoderLayer.

  Two biased dense layers with an exact (non-approximate) GELU in between:
  `hidden_size_for_vit -> intermediate_size_for_vit -> hidden_size_for_vit`.

  Attributes:
    config: Config containing model parameters
    rngs: The `nnx.Rngs` handed to the constructor.
  """

  def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
    self.config = config
    self.rngs = rngs

    hidden_dim = config.hidden_size_for_vit
    inner_dim = config.intermediate_size_for_vit
    # Settings shared by both projections.
    dense_kwargs = {
        "dtype": config.dtype_mm,
        "use_bias": True,
        "matmul_precision": config.matmul_precision,
        "rngs": rngs,
    }
    self.vit_encoder_layer_mlp_fc1 = linears.DenseGeneral(
        in_features_shape=hidden_dim,
        out_features_shape=inner_dim,
        **dense_kwargs,
    )
    self.vit_encoder_layer_mlp_fc2 = linears.DenseGeneral(
        in_features_shape=inner_dim,
        out_features_shape=hidden_dim,
        **dense_kwargs,
    )

  def __call__(self, hidden_states: Array) -> Array:
    """Applies fc1 -> GELU -> fc2 to `hidden_states`."""
    expanded = self.vit_encoder_layer_mlp_fc1(hidden_states)
    activated = nnx.gelu(expanded, approximate=False)
    return self.vit_encoder_layer_mlp_fc2(activated)
class Llama4VisionMLP2(nnx.Module):
  """Bias-free MLP used by Llama4VisionPixelShuffleMLP.

  Pipeline: fc1 -> GELU -> dropout -> fc2 -> GELU, with feature sizes
  `intermediate_size_for_vit -> projector_input_dim_for_vit ->
  projector_output_dim_for_vit`. Both GELUs use the exact form.

  Attributes:
    config: Config containing model parameters
    rngs: The `nnx.Rngs` handed to the constructor.
  """

  def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
    self.config = config
    self.rngs = rngs

    projector_dim = config.projector_input_dim_for_vit
    # Settings shared by both projections.
    dense_kwargs = {
        "dtype": config.dtype_mm,
        "use_bias": False,
        "matmul_precision": config.matmul_precision,
        "rngs": rngs,
    }
    self.vit_pixel_shuffle_mlp_fc1 = linears.DenseGeneral(
        in_features_shape=config.intermediate_size_for_vit,
        out_features_shape=projector_dim,
        **dense_kwargs,
    )
    self.vit_pixel_shuffle_mlp_fc2 = linears.DenseGeneral(
        in_features_shape=projector_dim,
        out_features_shape=config.projector_output_dim_for_vit,
        **dense_kwargs,
    )
    self.dropout = linears.Dropout(rate=config.projector_dropout_for_vit, rngs=rngs)

  def __call__(self, hidden_states: Array, deterministic: bool = False) -> Array:
    """Runs the projector MLP.

    Args:
      hidden_states: Features whose last axis matches fc1's input size.
      deterministic: Forwarded to the dropout layer between fc1 and fc2.

    Returns:
      The GELU-activated output of fc2.
    """
    projected = nnx.gelu(self.vit_pixel_shuffle_mlp_fc1(hidden_states), approximate=False)
    projected = self.dropout(projected, deterministic=deterministic)
    return nnx.gelu(self.vit_pixel_shuffle_mlp_fc2(projected), approximate=False)
class Llama4VisionPixelShuffleMLP(nnx.Module):
  """Pixel shuffle followed by an MLP, applied to encoded vision patches.

  Attributes:
    config: Config containing model parameters
    rngs: The `nnx.Rngs` handed to the constructor.
    pixel_shuffle_ratio: Copy of `config.pixel_shuffle_ratio_for_vit`.
    pixel_shuffle_mlp: The `Llama4VisionMLP2` applied after the shuffle.
  """

  def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
    self.config = config
    self.rngs = rngs
    self.pixel_shuffle_ratio = config.pixel_shuffle_ratio_for_vit
    self.pixel_shuffle_mlp = Llama4VisionMLP2(config=config, rngs=rngs)

  def __call__(self, encoded_patches: Array, deterministic: bool = False) -> Array:
    """Shuffles `encoded_patches` and feeds the result through the MLP.

    Args:
      encoded_patches: Patch features accepted by `pixel_shuffle`.
      deterministic: Forwarded to the MLP (controls its dropout).

    Returns:
      Output of `pixel_shuffle_mlp` on the shuffled patches.
    """
    shuffled = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
    return self.pixel_shuffle_mlp(shuffled, deterministic=deterministic)
class Llama4MultiModalProjector(nnx.Module):
  """Projects vision features into the text model's embedding width.

  A single bias-free dense layer mapping `vision_output_dim_for_vit` to
  `base_emb_dim`.

  Attributes:
    config: Config containing model parameters
    mesh: The device mesh handed to the constructor.
    rngs: The `nnx.Rngs` handed to the constructor.
  """

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs
    self.vit_multi_modal_projector = linears.DenseGeneral(
        in_features_shape=config.vision_output_dim_for_vit,
        out_features_shape=config.base_emb_dim,
        dtype=config.dtype_mm,
        use_bias=False,
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )

  def __call__(self, image_features: Array) -> Array:
    """Project image features to text hidden dimension.

    Args:
      image_features: 4-D input tensor; per the original documentation its
        shape is [batch_size, num_patches, (pixel_shuffle_ratio**2), vision_output_dim].

    Returns:
      4-D tensor with the same two leading axes as the input and the trailing
      axes produced by the dense projection.
    """
    lead, tiles, tokens, features = image_features.shape
    # Merge the two leading axes so the dense layer sees a 3-D batch.
    merged = image_features.reshape(lead * tiles, tokens, features)
    projected = self.vit_multi_modal_projector(merged)
    _, out_tokens, out_features = projected.shape
    # Restore the two leading axes.
    return projected.reshape(lead, tiles, out_tokens, out_features)
def llama4multimodalprojector_as_linen(config: Config, mesh: Mesh):
  """Builds `Llama4MultiModalProjector` through `nnx_wrappers.to_linen`.

  Args:
    config: Model configuration forwarded to the projector.
    mesh: Device mesh forwarded to the projector.

  Returns:
    Whatever `nnx_wrappers.to_linen` returns for the projector, registered
    under the name "Llama4MultiModalProjector_0".
  """
  wrapper_options = {
      "name": "Llama4MultiModalProjector_0",
      "abstract_init": False,
      "metadata_fn": initializers.variable_to_logically_partitioned,
  }
  return nnx_wrappers.to_linen(Llama4MultiModalProjector, config=config, mesh=mesh, **wrapper_options)
def determine_is_nope_layer(layer_id: int, nope_layer_interval: int) -> bool:
  """Tells whether the layer at `layer_id` is a NoPE layer (i.e. skips RoPE).

  Every `nope_layer_interval`-th layer, counting from 1, is a NoPE layer. A
  missing (`None`) or non-positive interval disables NoPE layers entirely.

  Args:
    layer_id: The 0-based index of the layer.
    nope_layer_interval: The interval at which layers should use NoPE.

  Returns:
    True if the layer should use NoPE, False otherwise.
  """
  if nope_layer_interval is None:
    return False
  if nope_layer_interval > 0:
    # Convert to 1-based position before testing divisibility.
    return (layer_id + 1) % nope_layer_interval == 0
  return False
def determine_is_moe_layer(layer_id: int, interleave_moe_layer_step: int) -> bool:
  """Tells whether the layer at `layer_id` is a MoE layer.

  MoE layers are placed with a fixed stride. For example:
    - If `interleave_moe_layer_step` is 1, all layers are MoE layers.
    - If `interleave_moe_layer_step` is 2, layers with index 1, 3, 5, ... are MoE layers.
  A missing (`None`) or non-positive step means no layer is a MoE layer.

  Args:
    layer_id: The 0-based index of the layer being checked.
    interleave_moe_layer_step: The interval or stride for placing MoE layers.

  Returns:
    True if the layer is MoE layer, False otherwise.
  """
  if interleave_moe_layer_step is None:
    return False
  if interleave_moe_layer_step > 0:
    # Convert to 1-based position before testing divisibility.
    return (layer_id + 1) % interleave_moe_layer_step == 0
  return False
# -----------------------------------------
# The Decoder Layer specific for Llama4
# -----------------------------------------
class Llama4DecoderLayer(nnx.Module):
  """Transformer decoder layer for Llama4.

  The layer is pre-norm with two residual branches: RMSNorm -> self-attention,
  added back to the input, followed by RMSNorm -> feed-forward (either a
  routed-and-shared MoE block or a dense MLP, chosen by `is_moe_layer`), added
  back again, and finally dropout.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      rngs: nnx.Rngs,
      quant: None | Quant = None,
      is_nope_layer: bool = False,
      is_moe_layer: bool = False,
  ):
    """Initializes the Llama4 decoder layer.

    Args:
      config: The main model configuration object.
      mesh: The device mesh used for sharding parameters and activations.
      model_mode: One of MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, or MODEL_MODE_AUTOREGRESSIVE.
      rngs: An `nnx.Rngs` object to provide random numbers.
      quant: An optional configuration for quantization. Defaults to None.
      is_nope_layer: If True, this layer will be configured as No Position Embeddings layer. Defaults to False.
      is_moe_layer: If True, this layer will use a MoE block. Defaults to False as Dense.
    """
    self.config = config
    self.mesh = mesh
    self.quant = quant
    self.rngs = rngs
    self.is_nope_layer = is_nope_layer
    self.is_moe_layer = is_moe_layer

    # Shape handed to Attention as `inputs_q_shape` / `inputs_kv_shape`; the batch and
    # sequence sizes come from the config for the requested model mode.
    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

    self.pre_self_attention_layer_norm = RMSNorm(
        num_features=config.emb_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )

    # Instead of scaling the query values in the checkpoint conversion (`llama_or_mistral_ckpt`)
    # we'll do it dynamically in the forward pass of Attention
    query_pre_attn_scalar = config.head_dim**-0.5
    self.self_attention = Attention(
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        inputs_q_shape=dummy_inputs_shape,
        inputs_kv_shape=dummy_inputs_shape,
        mesh=mesh,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        dropout_rate=config.dropout_rate,
        float32_qk_product=config.float32_qk_product,
        float32_logits=config.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(config),
        # The three axis orders are stored in the config as comma-separated strings.
        prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))),
        ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))),
        compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))),
        reshape_q=config.reshape_q,
        use_ragged_attention=config.use_ragged_attention,
        ragged_block_size=config.ragged_block_size,
        is_nope_layer=self.is_nope_layer,
        use_qk_norm=config.use_qk_norm,
        query_pre_attn_scalar=query_pre_attn_scalar,
        temperature_tuning=config.temperature_tuning,
        # NOTE(review): the two temperature-tuning constants below are hard-coded here
        # rather than read from the config.
        temperature_tuning_scale=0.1,
        temperature_tuning_floor_scale=8192.0,
        # note: chunk_attn_window_size is set in the config
        # NoPE layers attend globally; all other layers use chunked attention.
        attention_type=AttentionType.GLOBAL if self.is_nope_layer else AttentionType.CHUNK,
        model_mode=model_mode,
        rngs=rngs,
    )

    self.post_self_attention_layer_norm = RMSNorm(
        num_features=config.emb_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
        rngs=self.rngs,
    )

    # Exactly one of `Llama4MoEBlock_0` / `mlp` is created, depending on `is_moe_layer`.
    if self.is_moe_layer:
      # NOTE: the name Llama4MoEBlock_0 is to ensure reverse compatibility with
      # existing checkpoints for MoE block.
      self.Llama4MoEBlock_0 = RoutedAndSharedMoE(
          config=config,
          mesh=self.mesh,
          kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
          kernel_axes=("embed", None),
          dtype=config.dtype,
          weight_dtype=config.weight_dtype,
          quant=self.quant,
          rngs=self.rngs,
      )
    else:
      self.mlp = MlpBlock(
          mesh=self.mesh,
          in_features=config.emb_dim,
          intermediate_dim=config.mlp_dim,
          activations=config.mlp_activations,
          intermediate_dropout_rate=config.dropout_rate,
          dtype=config.dtype,
          weight_dtype=config.weight_dtype,
          config=config,
          quant=self.quant,
          model_mode=model_mode,
          rngs=self.rngs,
      )

    self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

    # Logical axis names used for every activation sharding constraint in `__call__`;
    # prefill mode uses a different name for the sequence axis.
    if model_mode == MODEL_MODE_PREFILL:
      self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
    else:
      self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

  @property
  def moe_block(self):
    """Alias for `Llama4MoEBlock_0`, which is only created when `is_moe_layer` is True."""
    return self.Llama4MoEBlock_0

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Runs the decoder layer.

    Args:
      inputs: Either the hidden states, a 3-tuple
        `(hidden_states, stacked_kv_cache, layer_idx)` carry, or any other
        tuple whose first element is the hidden states.
      decoder_segment_ids: Forwarded to the attention module.
      decoder_positions: Forwarded to the attention module.
      deterministic: Forwarded to attention, the dense MLP and dropout.
      model_mode: Forwarded to the attention module.
      previous_chunk: Forwarded to the attention module.
      slot: Forwarded to the attention module.
      page_state: Forwarded to the attention module.
      kv_cache: Forwarded to the attention module; replaced by
        `stacked_kv_cache[layer_idx]` when `inputs` is a 3-tuple carry.
      attention_metadata: Forwarded to the attention module.

    Returns:
      A 2-tuple whose form depends on how the layer was invoked:
        - 3-tuple carry input: `((layer_output, stacked_kv_cache, layer_idx + 1), None)`.
        - otherwise, with `config.scan_layers` set: `(layer_output, None)`.
        - otherwise: `(layer_output, kv_cache)`, with the cache returned by attention.
    """
    cfg = self.config
    # NOTE(review): the check is `>= 1` while the message says `> 1` — confirm which is intended.
    assert cfg.num_experts >= 1, "Expected the Llama4 config to have `num_experts > 1`."

    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
    is_scan_carry = False
    if isinstance(inputs, tuple) and len(inputs) == 3:
      # Carry form: select this layer's entry from the stacked cache.
      hidden_states, stacked_kv_cache, layer_idx = inputs
      kv_cache = stacked_kv_cache[layer_idx]
      inputs = hidden_states
      is_scan_carry = True
    elif isinstance(inputs, tuple):
      inputs = inputs[0]

    inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    lnx = self.pre_self_attention_layer_norm(inputs)
    lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

    # Self-attention block
    attention_lnx, kv_cache = self.self_attention(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        slot=slot,
        page_state=page_state,
        previous_chunk=previous_chunk,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
    # First residual connection.
    intermediate_inputs = inputs + attention_lnx

    # Fully Connected
    hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
    hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)

    load_balance_loss = None
    if self.is_moe_layer:
      # The MoE block returns three values; the third is unused here.
      mlp_lnx, load_balance_loss, _ = self.moe_block(hidden_states)
    else:
      mlp_lnx = self.mlp(hidden_states, deterministic=deterministic)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)

    # Second residual connection, then dropout.
    layer_output = mlp_lnx + intermediate_inputs
    layer_output = self.dropout(layer_output, deterministic=deterministic)
    layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

    # Record the MoE load-balance loss only when it is weighted and was produced.
    if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
      self.sow("intermediates", "moe_lb_loss", load_balance_loss)

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    if is_scan_carry:

      def update_cache(cache, val):
        # Write this layer's slice back into the stacked cache; empty leaves are left untouched.
        if jnp.size(val) > 0:
          return cache.at[layer_idx].set(val)
        return cache

      stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
      return (layer_output, stacked_kv_cache, layer_idx + 1), None
    elif cfg.scan_layers:
      return layer_output, None
    else:
      return layer_output, kv_cache
# Class obtained by passing `Llama4DecoderLayer` through `nnx_wrappers.to_linen_class`,
# with `initializers.variable_to_logically_partitioned` as the base metadata function.
# Presumably a Flax Linen-compatible wrapper, judging by the helper's name — confirm
# against `nnx_wrappers`.
Llama4DecoderLayerToLinen = nnx_wrappers.to_linen_class(
    Llama4DecoderLayer,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)
class Llama4ScannableBlock(nnx.Module):
  """A repeatable group of decoder layers.

  The block holds `config.inhomogeneous_layer_cycle_interval` decoder layers,
  stored as attributes `layers_0`, `layers_1`, ...; whether each one is a NoPE
  and/or MoE layer follows from `nope_layer_interval` and
  `interleave_moe_layer_step` applied to the layer's index inside the block.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      rngs: nnx.Rngs,
      quant: None | Quant = None,
      nope_layer_interval: int = 1,
      interleave_moe_layer_step: int = 1,
  ):
    """Initializes the scannable block.

    Args:
      config: The main model configuration object.
      mesh: The device mesh used for sharding parameters and activations.
      model_mode: One of MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, or MODEL_MODE_AUTOREGRESSIVE.
      rngs: An `nnx.Rngs` object to provide random numbers for initialization.
      quant: An optional configuration for quantization. Defaults to None.
      nope_layer_interval: Specifies the interval for inserting a NoPE layer.
      interleave_moe_layer_step: Specifies the interval for inserting a MoE layer.
    """
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.quant = quant
    self.rngs = rngs
    self.nope_layer_interval = nope_layer_interval
    self.interleave_moe_layer_step = interleave_moe_layer_step

    for layer_id in range(config.inhomogeneous_layer_cycle_interval):
      setattr(
          self,
          f"layers_{layer_id}",
          Llama4DecoderLayer(
              config=config,
              mesh=mesh,
              model_mode=model_mode,
              rngs=rngs,
              quant=quant,
              is_nope_layer=determine_is_nope_layer(layer_id, nope_layer_interval),
              is_moe_layer=determine_is_moe_layer(layer_id, interleave_moe_layer_step),
          ),
      )

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Feeds `inputs` through every layer of the block in order.

    All arguments other than `inputs` are handed unchanged to each layer.

    Returns:
      `(y, None)` when `config.scan_layers` is set, where `y` is the first
      element of the last layer's return value. Otherwise the last layer's
      return value is returned as-is (and each intermediate layer receives the
      previous layer's full return value as its input).
    """
    cfg = self.config
    scanned = cfg.scan_layers

    constrained = nn.with_logical_constraint(
        inputs, ("activation_batch", "activation_norm_length", "activation_embed")
    )
    y = checkpoint_name(constrained, "decoder_layer_input")

    for layer_id in range(cfg.inhomogeneous_layer_cycle_interval):
      layer = getattr(self, f"layers_{layer_id}")
      y = layer(
          y,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          previous_chunk=previous_chunk,
          page_state=page_state,
          slot=slot,
          kv_cache=kv_cache,
          attention_metadata=attention_metadata,
      )
      if scanned:
        # Keep only the layer output; drop the second element of the pair.
        y = y[0]

    return (y, None) if scanned else y
# Class obtained by passing `Llama4ScannableBlock` through `nnx_wrappers.to_linen_class`,
# with `initializers.variable_to_logically_partitioned` as the base metadata function.
# Presumably a Flax Linen-compatible wrapper, judging by the helper's name — confirm
# against `nnx_wrappers`.
Llama4ScannableBlockToLinen = nnx_wrappers.to_linen_class(
    Llama4ScannableBlock,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)
class Llama4VisionEncoderLayer(nnx.Module):
  """Transformer encoder layer for Llama4 vision model.

  Pre-norm layout with two residual branches: LayerNorm -> self-attention and
  LayerNorm -> `Llama4VisionMLP`.
  """

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs

    hidden_size = config.hidden_size_for_vit
    num_heads = config.num_attention_heads_for_vit
    head_dim = hidden_size // num_heads
    # Patch grid size plus one extra position (presumably the class token — confirm).
    seq_len = (config.image_size_for_vit // config.patch_size_for_vit) ** 2 + 1
    token_shape = (config.per_device_batch_size, seq_len, hidden_size)
    self.hidden_states_shape = token_shape

    self.input_layer_norm = nnx.LayerNorm(
        num_features=hidden_size,
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )
    self.self_attention_vision = Attention(
        config=config,
        num_query_heads=num_heads,
        num_kv_heads=num_heads,
        head_dim=head_dim,
        max_target_length=seq_len,
        attention_kernel="dot_product",
        inputs_q_shape=token_shape,
        inputs_kv_shape=token_shape,
        float32_qk_product=config.float32_qk_product,
        float32_logits=config.float32_logits,
        dtype=config.dtype_mm,
        weight_dtype=config.weight_dtype,
        mesh=mesh,
        dropout_rate=0,
        name="self_attention_vision",
        attention_type=AttentionType.FULL,
        is_nope_layer=False,
        use_bias_in_projections=True,
        is_vision=True,
        use_qk_norm=False,
        query_pre_attn_scalar=1 / math.sqrt(head_dim),
        # An image is encoded in one forward pass, so the vision encoder has no
        # "prefill" / "autoregressive" phases and needs no KV cache for its
        # self-attention; it is therefore always built in train mode.
        model_mode=MODEL_MODE_TRAIN,
        rngs=rngs,
    )
    self.post_attention_layer_norm = nnx.LayerNorm(
        num_features=hidden_size,
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )
    self.Llama4VisionMLP_0 = Llama4VisionMLP(config=config, rngs=rngs)

  def __call__(
      self,
      hidden_states: Array,
      deterministic: bool = False,
  ):
    """Applies the attention and MLP sub-blocks, each with a residual connection.

    Args:
      hidden_states: Input features for the layer.
      deterministic: Forwarded to the self-attention module.

    Returns:
      The features after both residual branches.
    """
    # Attention branch.
    normed = self.input_layer_norm(hidden_states)
    attention_out, _ = self.self_attention_vision(
        inputs_q=normed,
        inputs_kv=normed,
        deterministic=deterministic,
    )
    after_attention = hidden_states + attention_out

    # Feed-forward branch.
    mlp_out = self.Llama4VisionMLP_0(self.post_attention_layer_norm(after_attention))
    return after_attention + mlp_out
class Llama4VisionEncoder(nnx.Module):
  """Stack of `config.num_hidden_layers_for_vit` Llama4VisionEncoderLayer modules.

  The layers are stored as attributes `layers_0`, `layers_1`, ... and applied
  in index order. Per the original documentation, this encoder is based on the
  PyTorch reference implementation.

  Attributes:
    config: Config containing model parameters
    mesh: Mesh, JAX device mesh (used for sharding)
    rngs: The `nnx.Rngs` handed to the constructor.
  """

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs
    for index in range(config.num_hidden_layers_for_vit):
      encoder_layer = Llama4VisionEncoderLayer(config=config, mesh=mesh, rngs=rngs)
      setattr(self, f"layers_{index}", encoder_layer)

  def __call__(self, hidden_states: Array, deterministic: bool = False):
    """Runs `hidden_states` through every encoder layer in order.

    Args:
      hidden_states: Input features for the first layer.
      deterministic: Forwarded to every layer.

    Returns:
      Output of the last layer.
    """
    for index in range(self.config.num_hidden_layers_for_vit):
      encoder_layer = getattr(self, f"layers_{index}")
      hidden_states = encoder_layer(hidden_states, deterministic=deterministic)
    return hidden_states
class Llama4VisionModel(nnx.Module):
  """Llama4 vision model for processing image inputs.

  This model extracts patches from input image tiles and processes them through
  Llama4VisionEncoder and other vision-specific layers: patch unfolding, a
  learned class embedding appended to the patch sequence, learned positional
  embeddings, pre/post LayerNorm around the encoder, and a pixel-shuffle MLP.

  Attributes:
    config: Config containing model parameters
    mesh: Mesh, JAX device mesh (used for sharding)
    rngs: The `nnx.Rngs` handed to the constructor.
    scale: `hidden_size_for_vit ** -0.5`, the stddev of the embedding initializer.
    num_patches: Patches per tile plus one position for the class embedding.
  """

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs
    self.scale = self.config.hidden_size_for_vit**-0.5
    # +1 reserves a position for the class embedding appended in `__call__`.
    self.num_patches = (self.config.tile_size_for_vit // self.config.patch_size_for_vit) ** 2 + 1
    self.initializer = nnx.initializers.normal(self.scale)
    self.class_embedding = nnx.Param(
        self.initializer(self.rngs.params(), (self.config.hidden_size_for_vit,), self.config.dtype_mm)
    )
    self.positional_embedding_vlm = nnx.Param(
        self.initializer(self.rngs.params(), (self.num_patches, self.config.hidden_size_for_vit), self.config.dtype_mm)
    )
    self.layernorm_pre = nnx.LayerNorm(
        num_features=self.config.hidden_size_for_vit,
        epsilon=self.config.normalization_layer_epsilon,
        dtype=self.config.dtype_mm,
        rngs=self.rngs,
    )
    self.layernorm_post = nnx.LayerNorm(
        num_features=self.config.hidden_size_for_vit,
        epsilon=self.config.normalization_layer_epsilon,
        dtype=self.config.dtype_mm,
        rngs=self.rngs,
    )
    self.Llama4UnfoldConvolution_0 = Llama4UnfoldConvolution(config=self.config, rngs=self.rngs)
    self.Llama4VisionEncoder_0 = Llama4VisionEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs)
    self.Llama4VisionPixelShuffleMLP_0 = Llama4VisionPixelShuffleMLP(config=self.config, rngs=self.rngs)

  def __call__(
      self,
      pixel_values: Array,
      output_attentions: None | bool = None,
      output_hidden_states: None | bool = None,
      return_dict: None | bool = None,
      deterministic: None | bool = False,
  ) -> Array:
    """Forward pass of the Llama4 vision model.

    Args:
      pixel_values: Input tensor of shape:
        [batch_size * num_images, num_tiles, num_channels_for_vit, tile_size_for_vit, tile_size_for_vit]
      output_attentions: Accepted for interface compatibility; not used.
      output_hidden_states: Accepted for interface compatibility; not used.
      return_dict: Accepted for interface compatibility; not used.
      deterministic: Whether to use deterministic mode (disables dropout). It is
        forwarded to the encoder and to the pixel-shuffle MLP; `None` is treated
        as `False`.

    Returns:
      Final hidden states of shape
      [batch_size * num_images, num_tiles, patch_num, patch_dim], where the two
      trailing sizes are those produced by the pixel-shuffle MLP.
    """
    # The sub-modules previously always ran with their own default (`False`)
    # because this argument was never forwarded; keep that meaning for `None`.
    if deterministic is None:
      deterministic = False

    # Reshape pixel values to combine batch and num_tiles dimensions
    b, t, c, h, w = pixel_values.shape
    pixel_values = jnp.reshape(pixel_values, [b * t, c, h, w])

    hidden_states = self.Llama4UnfoldConvolution_0(pixel_values)

    # Append the class embedding to the end of the sequence
    class_embedding_expanded = jnp.expand_dims(jnp.expand_dims(self.class_embedding, axis=0), axis=0)
    class_embedding = jnp.broadcast_to(
        class_embedding_expanded, (hidden_states.shape[0], 1, self.config.hidden_size_for_vit)
    )
    hidden_states = jnp.concatenate([hidden_states, class_embedding], axis=1)

    # Add positional embedding
    hidden_states += self.positional_embedding_vlm

    # Transformation layers
    hidden_states = self.layernorm_pre(hidden_states)
    hidden_states = self.Llama4VisionEncoder_0(hidden_states, deterministic=deterministic)
    hidden_states = self.layernorm_post(hidden_states)
    # Drop the trailing class-embedding position before the pixel shuffle.
    hidden_states = hidden_states[:, :-1, :]

    hidden_states = self.Llama4VisionPixelShuffleMLP_0(hidden_states, deterministic=deterministic)

    # Reshape hidden states to split batch and num_tiles again
    _, patch_num, patch_dim = hidden_states.shape
    hidden_states = jnp.reshape(hidden_states, [b, t, patch_num, patch_dim])

    return hidden_states
def llama4visionmodel_as_linen(config: Config, mesh: Mesh) -> nn.Module:
  """Builds `Llama4VisionModel` through `nnx_wrappers.to_linen`.

  Args:
    config: Model configuration forwarded to the vision model.
    mesh: Device mesh forwarded to the vision model.

  Returns:
    Whatever `nnx_wrappers.to_linen` returns for the vision model, registered
    under the name "Llama4VisionModel_0".
  """
  wrapper_options = {
      "name": "Llama4VisionModel_0",
      "abstract_init": False,
      "metadata_fn": initializers.variable_to_logically_partitioned,
  }
  return nnx_wrappers.to_linen(Llama4VisionModel, config=config, mesh=mesh, **wrapper_options)