# Source code for maxtext.layers.engram

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models
  <https://arxiv.org/pdf/2601.07372>`_, 2026

Reference implementation: https://github.com/deepseek-ai/Engram/blob/main/engram_demo_v1.py
"""

from typing import List, Optional
from flax import nnx
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Array, Config, MODEL_MODE_TRAIN
from maxtext.input_pipeline.tokenizer import HFTokenizer
from maxtext.layers.embeddings import Embed
from maxtext.layers.initializers import NdInitializer, nd_dense_init
from maxtext.layers.linears import DenseGeneral
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
import numpy as np
import sympy
import tokenizers
from tokenizers import normalizers


class CompressedTokenizer:
  """
  A canonicalizing wrapper that reduces vocabulary sparsity for n-gram lookup.

  This class maps semantically equivalent tokens (e.g., "Apple", " apple", "APPLE")
  to a single unified ID. This many-to-one mapping significantly reduces the
  combinatorial size of the n-gram space.

  Attributes:
    lookup_table_np: Host-side NumPy int64 array mapping `original_id` -> `compressed_id`
      (safe to index with Python ints outside of traced code).
    lookup_table: JAX int32 array mapping `original_id` -> `compressed_id`.
    num_new_token: Size of the compressed vocabulary.
  """

  def __init__(self, tokenizer: HFTokenizer):
    normalizer = self._build_normalizer()
    self.lookup_table_np, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
    self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int32)

  def __len__(self) -> int:
    return self.num_new_token

  def _build_normalizer(self) -> normalizers.Sequence:
    """
    Builds the normalization pipeline for text processing.

    Returns:
      A `normalizers.Sequence` that folds Unicode compatibility forms, strips
      accents, lowercases, collapses whitespace and trims the result, while keeping
      a token that is exactly one space intact.
    """
    # Private use Unicode character to protect single spaces during stripping
    SENTINEL = "\uE000"

    # Normalization pipeline: ensures variations like "Café" and "cafe" map to the same ID
    normalizer = normalizers.Sequence(
        [
            # Unicode compatibility normalization (e.g., "ﬁ" -> "fi", full-width "Ａ" -> "A")
            normalizers.NFKC(),
            # Canonical decomposition (e.g., é -> e + combining acute accent)
            normalizers.NFD(),
            # Strip the combining marks produced above (e.g., e + ´ -> e)
            normalizers.StripAccents(),
            # Lowercase conversion ("The" -> "the")
            normalizers.Lowercase(),
            # Collapse runs of spaces, tabs, carriage returns and newlines to a single space
            normalizers.Replace(tokenizers.Regex(r"[ \t\r\n]+"), " "),
            # If the whole token is exactly one space, protect it from the Strip() below
            normalizers.Replace(tokenizers.Regex(r"^ $"), SENTINEL),
            # Remove leading/trailing whitespace
            normalizers.Strip(),
            # Restore protected spaces
            normalizers.Replace(SENTINEL, " "),
        ]
    )
    return normalizer

  def _build_lookup_table(self, tokenizer: HFTokenizer, normalizer: normalizers.Sequence) -> tuple[np.ndarray, int]:
    """
    Builds the mapping from the original vocabulary to the compressed vocabulary.

    Args:
      tokenizer: Tokenizer whose full vocabulary (`range(len(tokenizer))`) is decoded.
      normalizer: Pipeline used to canonicalize each decoded token.

    Returns:
      A tuple `(old2new, num_keys)`: an int64 array of length `len(tokenizer)` mapping
      each original ID to its compressed ID, and the number of distinct compressed IDs.
    """
    vocab_size = len(tokenizer)
    # Mapping: original_tid -> compressed_nid (Many-to-One)
    old2new = np.empty(vocab_size, dtype=np.int64)
    # Mapping: normalized_string -> compressed_nid (One-to-One)
    key2new: dict[str, int] = {}

    # Batch decode every single-token sequence to raw text
    texts = tokenizer.batch_decode([[tid] for tid in range(vocab_size)], skip_special_tokens=False)

    for tid, text in enumerate(texts):
      if "\ufffd" in text:
        # Handle invalid UTF-8 (replacement char �). Use raw token instead.
        key = tokenizer.convert_ids_to_tokens(tid)
      else:
        # Normalize text (e.g., " APPLE" -> "apple")
        normalized_text = normalizer.normalize_str(text)
        # Fallback to raw text if normalization creates an empty string
        key = normalized_text if normalized_text else text

      # Assign compressed IDs in first-seen order
      nid = key2new.get(key)
      if nid is None:
        nid = len(key2new)
        key2new[key] = nid
      old2new[tid] = nid

    return old2new, len(key2new)

  def __call__(self, input_ids) -> Array:
    """
    Maps original token IDs to compressed IDs.

    Negative IDs are treated as padding and returned unchanged; all other IDs are
    looked up in `lookup_table`.

    NOTE(review): IDs >= the original vocab size are not checked here; JAX indexing
    does not raise on out-of-bounds indices.

    Args:
      input_ids: Integer array-like of original token IDs (any shape).

    Returns:
      An int32 array of the same shape with compressed IDs.
    """
    input_ids = jnp.asarray(input_ids, dtype=jnp.int32)
    # Map negative IDs to 0 for lookup, then mask output back.
    safe_ids = jnp.where(input_ids < 0, 0, input_ids)
    mapped_ids = self.lookup_table[safe_ids]
    # Restore negative IDs (padding)
    output_ids = jnp.where(input_ids < 0, input_ids, mapped_ids)
    return output_ids
class NgramHashMapping:
  """
  Deterministically maps token indices to n-gram hash indices for embedding lookups.

  This class implements Multi-Head Hashing to bypass the combinatorial memory
  requirements of explicit n-gram vocabularies. Specifically, it applies
  multiplicative-XOR hashing to each n-gram window.

  Key Mechanisms for Collision Mitigation:
    - Multi-Head Factorization: Uses K distinct hash heads per n-gram order to
      increase effective capacity within fixed memory constraints.
    - Unique Prime Moduli: Assigns a unique prime vocabulary size to each head
      to minimize simultaneous collisions.
  """

  def __init__(
      self,
      engram_vocab_bases: List[int],
      max_ngram_size: int,
      engram_num_heads: int,
      layer_ids: List[int],
      tokenizer: HFTokenizer,
      pad_id: int,
      seed: int,
  ):
    """
    Args:
      engram_vocab_bases: List of minimum head vocab sizes for each n-gram order
        (index 0 is the 2-gram order).
      max_ngram_size: Max n-gram size to track (e.g., 3 tracks 2-grams and 3-grams).
      engram_num_heads: Number of parallel heads per n-gram order.
      layer_ids: List of layer indices using Engram.
      tokenizer: Base Hugging Face tokenizer.
      pad_id: Padding token ID in the original vocabulary. A negative value is kept
        as-is, matching how `CompressedTokenizer` passes negative IDs through.
      seed: Random seed for hash multiplier generation.

    Raises:
      ValueError: If `pad_id` is None.
    """
    # Validate before building the (expensive) compressed tokenizer.
    if pad_id is None:
      raise ValueError("The `pad_id` must be provided and cannot be None.")

    self.min_head_vocab_size_per_ngram = engram_vocab_bases
    self.max_ngram_size = max_ngram_size
    self.n_head_per_ngram = engram_num_heads
    self.layer_ids = layer_ids

    # Initialize compressed tokenizer
    self.compressed_tokenizer = CompressedTokenizer(tokenizer)
    self.tokenizer_vocab_size = len(self.compressed_tokenizer)

    if pad_id < 0:
      # Negative IDs pass through CompressedTokenizer unchanged, so the pad value used
      # for shifting must stay negative too; indexing the table with it would wrap
      # around to the last vocabulary entry.
      self.pad_id = int(pad_id)
    else:
      # Pre-calculate pad_id on CPU using the numpy table to avoid ConcretizationTypeError
      self.pad_id = int(self.compressed_tokenizer.lookup_table_np[pad_id])

    # Pre-calculate odd multipliers for hashing: {layer_id: int32 array of max_ngram_size}
    self.layer_multipliers = {
        layer_id: jnp.array(mults, dtype=jnp.int32)
        for layer_id, mults in self._calculate_multipliers_across_layers(seed).items()
    }

    # Pre-calculate unique prime vocab sizes for every head
    # Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}
    self.vocab_size_across_layers = self._calculate_vocab_size_across_layers()

  def _calculate_multipliers_across_layers(self, seed: int) -> dict[int, np.ndarray]:
    """
    Pre-calculates random odd multipliers for each layer and n-gram position.

    Multipliers are bounded so that `multiplier * compressed_id` fits in int32.

    Returns:
      A dictionary mapping layer_id to an int32 array of `max_ngram_size` odd multipliers.
    """
    # Pre-calculate bounds for random generation using int32 to avoid overflow
    max_int = np.iinfo(np.int32).max
    m_max = int(max_int // self.tokenizer_vocab_size)
    half_bound = max(1, m_max // 2)
    # Hard-code prime number to align with reference
    LAYER_PRIME_OFFSET = 10007

    layer_multipliers = {}
    for layer_id in self.layer_ids:
      # Offset seed to decorrelate layers
      layer_seed = int(seed + LAYER_PRIME_OFFSET * int(layer_id))
      np_rng = np.random.default_rng(layer_seed)
      # Generate random odd integers: 2 * r + 1 with r in [0, half_bound)
      random_value = np_rng.integers(low=0, high=half_bound, size=(self.max_ngram_size,), dtype=np.int32)
      multipliers = random_value * 2 + 1
      layer_multipliers[layer_id] = multipliers
    return layer_multipliers

  def _calculate_vocab_size_across_layers(self) -> dict[int, List[List[int]]]:
    """
    Calculates unique prime vocabulary sizes for every head in every layer.

    Primes are unique across all heads of all layers; using unique primes minimizes
    the probability of simultaneous collisions across heads.
    """

    def find_next_unseen_prime(start: int, seen_primes: set) -> int:
      # Smallest prime strictly greater than `start` that has not been used yet.
      candidate = start + 1
      while candidate in seen_primes or not sympy.isprime(candidate):
        candidate += 1
      return candidate

    seen_primes = set()
    vocab_size_across_layers = {}

    for layer_id in self.layer_ids:
      all_ngram_vocab_sizes = []
      for n in range(2, self.max_ngram_size + 1):
        current_ngram_heads_sizes = []
        # Start search from the configured minimum size (inclusive)
        vocab_size = self.min_head_vocab_size_per_ngram[n - 2]
        current_prime_search_start = vocab_size - 1

        # Find unique primes for each head
        num_heads = self.n_head_per_ngram
        for _ in range(num_heads):
          found_prime = find_next_unseen_prime(current_prime_search_start, seen_primes)
          seen_primes.add(found_prime)
          current_ngram_heads_sizes.append(found_prime)
          current_prime_search_start = found_prime

        all_ngram_vocab_sizes.append(current_ngram_heads_sizes)
      vocab_size_across_layers[layer_id] = all_ngram_vocab_sizes

    return vocab_size_across_layers

  def get_vocab_sizes(self, layer_id: int) -> List[int]:
    """
    Returns a flattened list of prime vocabulary sizes for a specific layer.

    Order: all heads of the 2-gram order, then all heads of the 3-gram order, etc.
    """
    return [head_size for ngram_size in self.vocab_size_across_layers[layer_id] for head_size in ngram_size]

  def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array:
    """
    Computes hash indices for all n-grams in the input batch.

    Missing history (positions with fewer than n-1 preceding tokens) is filled with
    `self.pad_id`, so every position gets a hash for every n-gram order, including
    sequences shorter than `max_ngram_size - 1` (e.g. S == 1).

    Args:
      compressed_ids: [B, S] compressed token IDs.
      layer_id: engram layer id.

    Returns:
      hash_ids: [B, S, H_total] where H_total = H * num_ngram_orders
    """
    x = jnp.asarray(compressed_ids, dtype=jnp.int32)
    batch_size, seq_len = x.shape

    # 1. Create sliding windows via right-shifting:
    #    shifted_inputs[k][:, t] == x[:, t - k], or pad_id when t - k < 0.
    #    e.g., k=1, [PAD, The, cat]
    #          k=2, [PAD, PAD, The]
    shifted_inputs = [x]
    for k in range(1, self.max_ngram_size):
      # Clamp to seq_len: when k >= S the whole row is padding. Without the clamp the
      # result would be [B, k] instead of [B, S].
      pad_len = min(k, seq_len)
      padding = jnp.full((batch_size, pad_len), self.pad_id, dtype=jnp.int32)
      shifted_inputs.append(jnp.concatenate([padding, x[:, : seq_len - pad_len]], axis=1))

    # 2. Retrieve layer-specific hash multipliers
    multipliers = self.layer_multipliers[layer_id]

    # 3. Compute Hashes: multiplicative bitwise XOR
    # Implements hash: H_n = (shift_0 * m_0) ^ ... ^ (shift_k * m_k)
    # e.g., (The * m_0) ^ (PAD * m_1) ^ (PAD * m_2)
    #       (cat * m_0) ^ (The * m_1) ^ (PAD * m_2)
    #       (sat * m_0) ^ (cat * m_1) ^ (The * m_2)
    all_hashes = []
    # Initialize with unigrams, shape: [B, S]
    ngram_hash = shifted_inputs[0] * multipliers[0]
    # Pre-fetch vocab sizes for modulo
    vocab_sizes = self.vocab_size_across_layers[layer_id]

    for n in range(2, self.max_ngram_size + 1):
      # Update hash with next history token
      ngram_hash = jnp.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])
      # Retrieve prime vocab sizes for all heads of this n-gram order
      mods = jnp.array(vocab_sizes[n - 2], dtype=jnp.int32)
      # Broadcast Modulo: Map hash to valid table indices
      # [B, S, 1] % [H] -> [B, S, H]
      all_hashes.append(ngram_hash[..., None] % mods)

    # Concatenate all heads: [B, S, H_total] where H_total = H * num_ngram_orders
    return jnp.concatenate(all_hashes, axis=2)

  def __call__(self, input_ids) -> dict[int, Array]:
    """
    Hashes raw token IDs for every Engram layer.

    Args:
      input_ids: [B, S] token IDs from the standard (uncompressed) tokenizer.

    Returns:
      {layer_id: [B, S, H_total] hash indices}.
    """
    compressed_ids = self.compressed_tokenizer(input_ids)
    hash_ids_for_all_layers = {}
    for layer_id in self.layer_ids:
      hash_ids = self._get_ngram_hashes(compressed_ids, layer_id=layer_id)
      hash_ids_for_all_layers[layer_id] = hash_ids
    return hash_ids_for_all_layers
class StaticWrapper:
  """
  Holds a value behind a plain Python attribute.

  Used to prevent nnx from treating the wrapped value as a variable; the value is
  read back through `.val`.

  Args:
    val: The value to wrap; stored verbatim.
  """

  def __init__(self, val) -> None:
    # No copying or conversion: callers get back exactly the object they passed in.
    self.val = val
class MultiHeadEmbedding(nnx.Module):
  """
  A flattened table representation for multi-head embedding spaces across n-gram orders.

  All heads share one `Embed` table with `sum(vocab_sizes)` rows; head `h` owns the
  contiguous row range starting at `offsets[h]`.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      vocab_sizes: List[int],
      head_dim: int,
      rngs: nnx.Rngs = None,
  ):
    """
    Args:
      config: The model configuration.
      mesh: Device mesh for partitioning.
      vocab_sizes: Flattened sequence of prime vocabulary sizes for all heads across
        all n-gram orders. Example: [2gram_Head1, 2gram_Head2, 3gram_Head1, ...].
        Lists, tuples and 1-D NumPy arrays are accepted.
      head_dim: Embedding dimension for a single head.
      rngs: Random number generators for initialization.
    """
    self.num_heads = len(vocab_sizes)

    # Compute starting index for each head's segment in the flattened table
    # (exclusive prefix sum). Offsets serve as the "base address" for each head.
    # Unpacking works for any sequence, whereas `[0] + vocab_sizes[:-1]` fails for
    # tuples and silently broadcast-adds for NumPy arrays.
    offsets = np.cumsum([0, *vocab_sizes[:-1]])
    self.offsets = StaticWrapper(np.array(offsets, dtype=np.int64))

    # The total embedding size is the sum of all individual head vocabularies.
    self.embedding = Embed(num_embeddings=sum(vocab_sizes), num_features=head_dim, config=config, mesh=mesh, rngs=rngs)

  def __call__(self, input_ids: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
    """
    Retrieves embeddings for multi-head indices.

    Args:
      input_ids: Hashed per-head indices. Shape [B, S, H_total], where H_total is the
        total number of heads.
      model_mode: The model's operational mode (e.g., 'train', 'prefill').

    Returns:
      embeddings: Shape [B, S, H_total, D_head].
    """
    # Broadcasting Add: [B, S, H] + [H] -> [B, S, H]
    # Shifts local indices (0..Prime-1) to global table positions.
    shifted_ids = input_ids + self.offsets.val

    # Embedding lookup: [B, S, H_total] -> [B, S, H_total, D_head]
    return self.embedding(shifted_ids, model_mode=model_mode)
class ShortConv(nnx.Module):
  """
  Depthwise causal 1D convolution over the sequence axis, with multi-branch integration.

  Each branch is first normalized by its own RMSNorm. A single depthwise, causal and
  dilated convolution then mixes time steps within each channel, and SiLU is applied.
  The convolution kernel starts at zero, so the output is all zeros at initialization.
  """

  def __init__(
      self,
      config: Config,
      hidden_size: int,
      kernel_size: int,
      dilation: int,
      mhc_expansion_rate: int,
      rngs: nnx.Rngs = None,
  ):
    """
    Args:
      config: The model configuration.
      hidden_size (D): Dimension of a single branch.
      kernel_size: Temporal window size.
      dilation: Dilation rate for the convolution.
      mhc_expansion_rate (G): Number of branches.
      rngs: RNG state for initialization.
    """
    self.mhc_expansion_rate = mhc_expansion_rate

    # One RMSNorm per branch, created in a single vectorized call: the RNGs are split
    # G ways, the factory is vmapped over that leading axis, and the resulting
    # parameters are stacked along axis 0 (shape [G, D]).
    @nnx.split_rngs(splits=mhc_expansion_rate)
    @nnx.vmap(in_axes=0, out_axes=0)
    def _make_branch_norm(branch_rngs):
      return RMSNorm(
          num_features=hidden_size,
          dtype=config.dtype,
          weight_dtype=config.weight_dtype,
          epsilon=config.normalization_layer_epsilon,
          kernel_axes=("norm",),
          rngs=branch_rngs,
      )

    self.norm = _make_branch_norm(rngs)

    # All branches are flattened into the channel axis (G * D channels). Setting
    # feature_group_count equal to the channel count makes the convolution depthwise,
    # so no channel or branch mixes with another; CAUSAL padding keeps output t
    # dependent only on inputs <= t.
    # Kernel shape: [kernel_size, channels // feature_group_count, channels].
    channels = mhc_expansion_rate * hidden_size
    self.conv = nnx.Conv(
        in_features=channels,
        out_features=channels,
        kernel_size=(kernel_size,),
        feature_group_count=channels,
        kernel_dilation=(dilation,),
        padding="CAUSAL",
        use_bias=False,
        # Zero init so the layer starts as an identity-preserving no-op on the residual path.
        kernel_init=nnx.initializers.zeros,
        dtype=config.dtype,
        param_dtype=config.weight_dtype,
        precision=config.matmul_precision,
        rngs=rngs,
    )

  def __call__(self, x: Array) -> Array:
    """
    Compute y^i = SiLU(Conv1D(RMSNorm^i(x^i))) for each branch i.

    Args:
      x: Input of shape [B, S, G, D] (batch, sequence, branches, hidden size).

    Returns:
      Tensor of shape [B, S, G, D].
    """
    batch, seq_len, branches, width = x.shape

    # Pair stacked norm i (axis 0 of the params) with branch i (axis 2 of x) and
    # restack the results on axis 2.
    @nnx.vmap(in_axes=(0, 2), out_axes=2)
    def _normalize_branch(norm, branch_x):
      return norm(branch_x)

    normed = _normalize_branch(self.norm, x)

    # [B, S, G, D] -> [B, S, G * D]; the depthwise conv only mixes along S.
    mixed = self.conv(normed.reshape(batch, seq_len, branches * width))
    activated = jax.nn.silu(mixed)

    # [B, S, G * D] -> [B, S, G, D]
    return activated.reshape(batch, seq_len, branches, width)
class Engram(nnx.Module):
  """
  Engram Memory Layer with n-gram embedding, with multi-branch integration.

  Main components:
    - Context-independent Retrieval: Fetch static n-gram embeddings via Multi-Head Hashing.
    - Context-aware Gating: Compute similarity between memory (Key) and context (Query)
      to determine relevance.
    - Mix: Apply local temporal smoothing via convolution.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      vocab_sizes: List[int],
      engram_num_heads: int,
      engram_head_dim: int,
      engram_max_ngram_size: int,
      engram_kernel_size: int,
      mhc_expansion_rate: int,
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
      quant: Optional[Quant] = None,
      rngs: nnx.Rngs = None,
  ):
    """
    Args:
      config: The model configuration.
      mesh: Partitioning mesh.
      vocab_sizes: Flattened list of prime vocabulary sizes for all heads across all
        n-gram orders. Example: [2gram_Head1, 2gram_Head2, 3gram_Head1, ...]
      engram_num_heads (H): Heads per n-gram order.
      engram_head_dim (D_head): Dimension per head.
      engram_max_ngram_size: Max n-gram order (e.g., 3 covers 2-grams and 3-grams).
      engram_kernel_size: convolution kernel size.
      mhc_expansion_rate (G): Number of branches.
      kernel_init: Weight initializer.
      quant: Quantization config.
      rngs: RNG state for initialization

    Raises:
      ValueError: If `len(vocab_sizes)` != engram_num_heads * (engram_max_ngram_size - 1).
    """
    self.config = config
    self.mesh = mesh
    self.dtype = self.config.dtype
    # Parameter storage dtype, matching the weight_dtype used for the RMSNorm and
    # ShortConv parameters (the activation dtype is `self.dtype`).
    self.weight_dtype = self.config.weight_dtype
    self.kernel_init = kernel_init
    self.quant = quant
    self.rngs = rngs
    self.mhc_expansion_rate = mhc_expansion_rate

    # Hierarchy: Engram -> n-gram Order -> h-th Head
    self.max_ngram_size = engram_max_ngram_size
    self.conv_kernel_size = engram_kernel_size
    num_ngram_orders = self.max_ngram_size - 1

    # The retrieved memory width is len(vocab_sizes) * D_head. It must equal D_en below,
    # or the key/value projections cannot consume it.
    expected_num_heads = engram_num_heads * num_ngram_orders
    if len(vocab_sizes) != expected_num_heads:
      raise ValueError(
          f"Expected {expected_num_heads} vocab sizes (engram_num_heads={engram_num_heads} x "
          f"{num_ngram_orders} n-gram orders), got {len(vocab_sizes)}."
      )

    # D_en: Final concatenated size of the retrieved memory
    self.engram_dim = engram_head_dim * engram_num_heads * num_ngram_orders

    # Embedding: one flattened table to store all n-gram heads across orders
    self.multi_head_embedding = MultiHeadEmbedding(
        config=config, mesh=mesh, vocab_sizes=vocab_sizes, head_dim=engram_head_dim, rngs=rngs
    )

    # Key Projection (Batch over branch)
    # retrieved n-gram memory -> Key, from D_en to [G, D]
    self.key_proj = DenseGeneral(
        in_features_shape=self.engram_dim,
        out_features_shape=(mhc_expansion_rate, config.emb_dim),
        axis=-1,
        kernel_init=self.kernel_init,
        kernel_axes=("engram_dim", "mhc", "embed"),
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        quant=self.quant,
        matmul_precision=self.config.matmul_precision,
        shard_mode=self.config.shard_mode,
        use_bias=True,
        rngs=rngs,
    )

    # Norms
    # Vectorized Init: Independent weights per branch
    # rngs: split G ways, vectorize over G, `in_axes=0`
    # Stack weights at axis 0 to get [G, D], `out_axes=0`
    @nnx.split_rngs(splits=mhc_expansion_rate)
    @nnx.vmap(in_axes=0, out_axes=0)
    def create_norms(rngs):
      return RMSNorm(
          num_features=config.emb_dim,
          dtype=config.dtype,
          weight_dtype=config.weight_dtype,
          epsilon=config.normalization_layer_epsilon,
          kernel_axes=("norm",),
          rngs=rngs,
      )

    # Key Normalization
    self.k_norm = create_norms(rngs)
    # Query Normalization
    self.q_norm = create_norms(rngs)

    # Value Projection (Shared across branches): Retrieved memory -> Value
    self.value_proj = DenseGeneral(
        in_features_shape=self.engram_dim,
        out_features_shape=config.emb_dim,
        axis=-1,
        kernel_init=self.kernel_init,
        kernel_axes=("engram_dim", "embed"),
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        quant=self.quant,
        matmul_precision=self.config.matmul_precision,
        shard_mode=self.config.shard_mode,
        use_bias=True,
        rngs=rngs,
    )

    # Short Convolution (Vectorized Internally)
    # Applies depthwise causal convolution to smooth the retrieved memory over time.
    self.short_conv = ShortConv(
        config=config,
        hidden_size=config.emb_dim,
        kernel_size=self.conv_kernel_size,
        dilation=self.max_ngram_size,
        mhc_expansion_rate=mhc_expansion_rate,
        rngs=rngs,
    )

  def __call__(self, hidden_states: Array, hash_input_ids: Array) -> Array:
    """
    Computes the Engram output by retrieving, gating, and smoothing n-gram memory.

    Args:
      hidden_states: current transformer state. Shape: [B, S, G, D].
      hash_input_ids: Hashed token IDs. Shape: [B, S, H_total].
        Produced by `hash_mapping(input_ids)[layer_id]`, where `hash_mapping` is an
        `NgramHashMapping`.

    Returns:
      Shape: [B, S, G, D]. The caller adds the residual connection to `hidden_states`.

    Shape annotation:
      B: Batch Size
      S: Sequence Length
      G: mhc_expansion_rate, Number of Branches
      H_total: Total number of heads across n-grams. num_head * num_ngrams
      D: emb_dim
      D_head: Dimension of a single head embedding
      D_en: Dimension of flattened embedding across heads and n-grams
    """
    B, S, _, D = hidden_states.shape

    # 1. Retrieve Memory from Embedding
    # [B, S, H_total] -> [B, S, H_total, D_head]
    embeddings = self.multi_head_embedding(hash_input_ids)
    # [B, S, H_total, D_head] -> [B, S, D_en]
    embeddings = embeddings.reshape(B, S, -1)

    # 2. Static Memory as Key
    # [B, S, D_en] -> [B, S, G, D]
    key = self.key_proj(embeddings)

    # 3. Compute Norms
    # Vectorized Apply
    # norms: [G, D], vectorize over G, `in_axes=0`
    # x: [B, S, G, D], vectorize over G, `in_axes=2`
    # Stack results at axis 2 to get [B, S, G, D], `out_axes=2`
    @nnx.vmap(in_axes=(0, 2), out_axes=2)
    def apply_norms(norms, x):
      return norms(x)

    # [B, S, G, D] shape stays
    key = apply_norms(self.k_norm, key)

    # 4. Dynamic Context as Query
    # [B, S, G, D] shape stays
    query = apply_norms(self.q_norm, hidden_states)

    # 5. QK product as Gates
    # Compute similarity of memory (Key) and current state (Query), scaled by sqrt(D)
    qk_product = jnp.einsum("bsgd,bsgd->bsg", query, key, precision=self.config.matmul_precision)
    gate = qk_product / jnp.sqrt(D)
    # Range Compression: signed square-root (magnitude floored at 1e-6) to prevent sigmoid saturation
    gate = jnp.sqrt(jnp.maximum(jnp.abs(gate), 1e-6)) * jnp.sign(gate)
    # Sigmoid activation to get gating probability [0, 1]
    gate = jax.nn.sigmoid(gate)  # [B, S, G]

    # 6. Static Memory as Value
    # [B, S, D_en] -> [B, S, D]
    value = self.value_proj(embeddings)

    # 7. Apply Gates to Value
    # [B, S, G, 1] * [B, S, 1, D] -> [B, S, G, D]
    gated_value = gate[:, :, :, None] * value[:, :, None, :]

    # 8. ShortConv as Temporal Smoothing
    # [B, S, G, D] shape stays; depthwise conv mixes along S only
    conv_output = self.short_conv(gated_value)
    # residual connection for conv component
    output = gated_value + conv_output

    # Note: residual connection for hidden_states will be added by the caller
    return output