Source code for maxtext.models.simple_layer

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple decoder layers for testing and debugging purposes."""

from jax import numpy as jnp
from jax.sharding import Mesh

from flax import nnx
from maxtext.common.common_types import Config, ShardMode
from maxtext.layers import quantizations, nnx_wrappers
from maxtext.layers.initializers import variable_to_logically_partitioned
from maxtext.utils.sharding import create_sharding


from typing import Optional
# pytype: disable=attribute-error


class SimpleDecoderLayer(nnx.Module):
  """Minimal decoder layer: one matmul against a square [emb_dim, emb_dim] weight.

  Intended for testing and debugging; the layer holds a single parameter,
  `weights`, and its forward pass is a single `jnp.dot`.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      rngs: nnx.Rngs,
      quant: Optional[quantizations.AqtQuantization] = None,
  ) -> None:
    """Stores the constructor arguments and creates the weight parameter.

    Args:
      config: Model configuration; `emb_dim`, `shard_mode` and (at call time)
        `scan_layers` are read from it.
      mesh: Device mesh used for parameter partitioning and, in explicit
        shard mode, for the output sharding.
      model_mode: Stored on the instance; not otherwise used by this layer.
      rngs: NNX RNG container; its `params` stream seeds the weight init.
      quant: Optional quantization object. Stored on the instance; not
        otherwise used by this layer.
    """
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.rngs = rngs
    self.quant = quant

    # LeCun-normal initializer annotated with the ("embed", "mlp") partitioning.
    kernel_init = nnx.with_partitioning(
        nnx.initializers.lecun_normal(),
        sharding=("embed", "mlp"),
        mesh=mesh,
    )
    kernel_shape = (config.emb_dim, config.emb_dim)
    self.weights = nnx.Param(kernel_init(rngs.params(), kernel_shape))

    # An output sharding is only built when the config asks for explicit sharding.
    if config.shard_mode == ShardMode.EXPLICIT:
      self.out_sharding = create_sharding(
          mesh,
          ("activation_batch", "activation_norm_length", "activation_embed"),
      )
    else:
      self.out_sharding = None

  def __call__(
      self,
      inputs: jnp.ndarray,
      positions,
      segmentation,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state=None,
      slot=None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Multiplies `inputs` by the layer weight.

    Only `inputs` is used; every other argument is accepted and ignored
    (presumably to match the call signature of other decoder layers).

    Args:
      inputs: Input activations, or a tuple whose first element is the
        activations.

    Returns:
      The matmul result, or `(result, None)` when `config.scan_layers` is set.
    """
    # A previous layer may hand over (hidden_states, kv_cache); keep only the hidden states.
    if isinstance(inputs, tuple):
      inputs = inputs[0]

    kernel = self.weights.astype(inputs.dtype)
    projected = jnp.dot(inputs, kernel, out_sharding=self.out_sharding)

    if self.config.scan_layers:
      return projected, None
    return projected
# Linen-facing class generated from the NNX SimpleDecoderLayer via
# nnx_wrappers.to_linen_class, with variable_to_logically_partitioned
# supplied as the base metadata function.
SimpleDecoderLayerToLinen = nnx_wrappers.to_linen_class(
    SimpleDecoderLayer,
    base_metadata_fn=variable_to_logically_partitioned,
)
class SimpleMlpDecoderLayer(nnx.Module):
  """Minimal MLP decoder layer: an [emb_dim, mlp_dim] matmul, then an [mlp_dim, emb_dim] matmul.

  Intended for testing and debugging; no activation function is applied
  between the two projections.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      rngs: nnx.Rngs,
      quant: Optional[quantizations.AqtQuantization] = None,
  ) -> None:
    """Stores the constructor arguments and creates both weight parameters.

    Args:
      config: Model configuration; `emb_dim`, `mlp_dim`, `shard_mode` and (at
        call time) `scan_layers` are read from it.
      mesh: Device mesh used for parameter partitioning and, in explicit
        shard mode, for the activation shardings.
      model_mode: Stored on the instance; not otherwise used by this layer.
      rngs: NNX RNG container; its `params` stream seeds both weight inits.
      quant: Optional quantization object. Stored on the instance; not
        otherwise used by this layer.
    """
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.rngs = rngs
    self.quant = quant

    # Up-projection kernel. It is created first so the draws from the params
    # RNG stream happen in the order ff_1, then ff_2.
    up_init = nnx.with_partitioning(
        nnx.initializers.lecun_normal(),
        sharding=("embed", "mlp"),
        mesh=mesh,
    )
    up_shape = (config.emb_dim, config.mlp_dim)
    self.ff_1 = nnx.Param(up_init(rngs.params(), up_shape))

    # Down-projection kernel.
    down_init = nnx.with_partitioning(
        nnx.initializers.lecun_normal(),
        sharding=("mlp", "embed"),
        mesh=mesh,
    )
    down_shape = (config.mlp_dim, config.emb_dim)
    self.ff_2 = nnx.Param(down_init(rngs.params(), down_shape))

    # Activation shardings are only built when the config asks for explicit sharding.
    if config.shard_mode == ShardMode.EXPLICIT:
      self.activation_sharding = create_sharding(
          mesh,
          ("activation_batch", "activation_norm_length", "activation_embed"),
      )
      self.mlp_sharding = create_sharding(
          mesh,
          ("activation_batch", "activation_norm_length", "activation_mlp"),
      )
    else:
      self.activation_sharding = None
      self.mlp_sharding = None

  def __call__(
      self,
      inputs: jnp.ndarray,
      positions,
      segmentation,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state=None,
      slot=0,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Applies the two projections to `inputs` in sequence.

    Only `inputs` is used; every other argument is accepted and ignored
    (presumably to match the call signature of other decoder layers).

    Args:
      inputs: Input activations, or a tuple whose first element is the
        activations.

    Returns:
      The result of the second matmul, or `(result, None)` when
      `config.scan_layers` is set.
    """
    # A previous layer may hand over (hidden_states, kv_cache); keep only the hidden states.
    if isinstance(inputs, tuple):
      inputs = inputs[0]

    compute_dtype = inputs.dtype
    hidden = jnp.dot(inputs, self.ff_1.astype(compute_dtype), out_sharding=self.mlp_sharding)
    result = jnp.dot(hidden, self.ff_2.astype(compute_dtype), out_sharding=self.activation_sharding)

    return (result, None) if self.config.scan_layers else result
# Linen-facing class generated from the NNX SimpleMlpDecoderLayer via
# nnx_wrappers.to_linen_class, with variable_to_logically_partitioned
# supplied as the base metadata function.
SimpleMlpDecoderLayerToLinen = nnx_wrappers.to_linen_class(
    SimpleMlpDecoderLayer,
    base_metadata_fn=variable_to_logically_partitioned,
)