# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OLMo numpy fixed-seq-length dataset on top of Grain.
A Grain ``RandomAccessDataSource`` over the AI2 OLMo virtual token stream
plus a deterministic global-shuffle sampler. See
``docs/guides/data_input_pipeline/olmo_grain.md`` for an overview.
"""
from __future__ import annotations
import hashlib
import threading
from typing import Any, Dict, List, Optional
import numpy as np
import grain
from maxtext.input_pipeline.olmo_data import (
OlmoNpyIndex,
global_to_local,
is_clean_instance,
)
[docs]
class OlmoNpyDataSource(grain.sources.RandomAccessDataSource):
"""Random-access view of an OLMo numpy mix as a stream of token windows.
Files are opened lazily and cached as ``np.memmap`` per worker. Open mmaps
are reference-counted by :class:`_MmapCache` so we don't blow past
``ulimit -n`` when iterating over the full 950-file mix.
The data source is **process-safe**: every Grain worker subprocess builds
its own ``_MmapCache`` after the fork. No shared mutable state.
Args:
index: The :class:`OlmoNpyIndex` describing the mix. Path strings must be
reachable from the data-loading host (typically a GCSFUSE mount path
like ``/mnt/<your-mount>/...``).
path_remap: Optional dict to rewrite ``index.files[i].path``. Useful when
the index was built with ``gs://`` paths and you want to read from a
gcsfuse mount, or vice versa. A path is rewritten if it starts with
any key in this dict.
max_open_files: Soft cap on the number of mmaps held open in the
per-worker cache. The cache is LRU.
"""
def __init__(
self,
index: OlmoNpyIndex,
*,
path_remap: Optional[Dict[str, str]] = None,
max_open_files: int = 64,
):
self._index = index
self._dtype = np.dtype(index.dtype)
self._sequence_length = index.sequence_length
self._path_remap = dict(path_remap or {})
self._mmaps = _MmapCache(max_open_files=max_open_files)
# ---- Grain's RandomAccessDataSource interface --------------------------- #
def __len__(self) -> int:
return self._index.total_instances
def __getitem__(self, instance_id: int) -> Dict[str, Any]:
file_idx, token_offset = global_to_local(self._index, instance_id)
file_entry = self._index.files[file_idx]
arr = self._mmaps.get(self._resolve_path(file_entry.path), self._dtype)
# Always copy: the memmap is opened read-only (mode="r"), and we need to
# hand a writable, picklable array back to Grain so transforms can
# mutate it freely and worker processes can serialize it without
# dragging the memmap object along.
tokens = np.array(arr[token_offset : token_offset + self._sequence_length], copy=True)
return {
"tokens": tokens,
"instance_id": int(instance_id),
"file_id": int(file_idx),
}
# ---- Helpers ------------------------------------------------------------ #
def _resolve_path(self, path: str) -> str:
for prefix, replacement in self._path_remap.items():
if path.startswith(prefix):
return replacement + path[len(prefix) :]
return path
def __getstate__(self):
# Mmap caches don't survive pickling; rebuild after unpickle.
state = self.__dict__.copy()
state["_mmaps"] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._mmaps = _MmapCache(max_open_files=64)
def __repr__(self) -> str:
"""Stable repr — Grain compares ``repr(data_source)`` between the
checkpoint and the live source on resume. The default repr embeds the
object id, which breaks resume across separate construction. We hash
the index fingerprint + seq + path-remap instead so equivalent sources
compare equal as strings."""
return (
f"OlmoNpyDataSource(fingerprint={self._index.fingerprint!r}, "
f"seq={self._sequence_length}, dtype={self._dtype.str!r}, "
f"remap={sorted(self._path_remap.items())!r})"
)
class _MmapCache:
"""Tiny LRU-ish cache of open ``np.memmap`` handles."""
def __init__(self, max_open_files: int = 64):
self._max = max_open_files
self._mmaps: Dict[str, np.memmap] = {}
self._lock = threading.Lock()
def get(self, path: str, dtype: np.dtype) -> np.memmap:
"""Return a cached ``np.memmap`` for ``path``, opening it lazily."""
with self._lock:
arr = self._mmaps.get(path)
if arr is not None:
# touch: re-insert at end for LRU ordering.
del self._mmaps[path]
self._mmaps[path] = arr
return arr
# Open the file as 1-D memmap; for raw .npy (no header) we need to know
# length, but np.memmap can derive it from file size when shape=(-1,).
arr = np.memmap(path, dtype=dtype, mode="r")
self._mmaps[path] = arr
while len(self._mmaps) > self._max:
# Evict oldest.
oldest = next(iter(self._mmaps))
del self._mmaps[oldest]
return arr
[docs]
class OlmoIndexSampler:
"""Global-shuffle sampler over an OLMo numpy mix.
Mirrors OLMo-core's :class:`NumpyDataLoaderBase` shuffle math: a single
Fisher-Yates over ``[0, total_instances)`` keyed by ``hash(seed, epoch)``,
then partitioned across ``shard_count`` hosts.
Implements Grain's ``Sampler`` protocol — i.e. ``__getitem__`` returning
:class:`grain.python.RecordMetadata`. Grain calls
``sampler[index]`` for each global step; the sampler is responsible for
mapping that to the actual record_key fed to ``data_source[record_key]``.
Indexing semantics:
* ``index`` here is a *per-host* (per-data-loader) global step counter
starting at 0 and advancing without bound (we support infinite epochs).
* ``epoch = index // num_local_instances_per_epoch`` selects which
permutation to use; ``in_epoch = index % num_local_instances_per_epoch``
selects the position within this host's shard of that permutation.
Checkpointing is trivial: the only mutable state is "which epoch's
permutation is currently cached" (a perf optimization). The user-visible
position is just the index passed to ``__getitem__``.
Args:
total_instances: ``index.total_instances`` from the OLMo index.
seed: Base seed for the shuffle.
shard_index: Zero-based index of this data-loading host. Typically
``jax.process_index()``.
shard_count: Number of data-loading hosts. Typically
``jax.process_count()``.
shuffle: If ``False``, instances are emitted in linear order — useful
for debugging.
initial_step: Per-host batch step at which the *training run* should
resume. ``__getitem__(local_idx)`` returns the record at absolute
position ``local_idx + initial_step``. Use this to resume a run from
a saved trainer step without saving Grain's iterator state — our
sampler is a pure function of its inputs, so the (seed, shard,
absolute step) tuple fully determines the next record.
"""
def __init__(
self,
*,
total_instances: int,
seed: int,
shard_index: int = 0,
shard_count: int = 1,
shuffle: bool = True,
initial_step: int = 0,
):
if shard_count <= 0 or shard_index < 0 or shard_index >= shard_count:
raise ValueError(f"Invalid shard config: shard_index={shard_index} of {shard_count}")
if total_instances <= 0:
raise ValueError(f"total_instances must be positive, got {total_instances}")
if initial_step < 0:
raise ValueError(f"initial_step must be non-negative, got {initial_step}")
self._total = int(total_instances)
self._seed = int(seed)
self._shard_index = int(shard_index)
self._shard_count = int(shard_count)
self._shuffle = bool(shuffle)
self._initial_step = int(initial_step)
# Cache the shuffled-and-sharded indices for the most-recently-touched
# epoch. Cheap to recompute on epoch boundaries; expensive to keep many
# epochs resident at once for the full 724 M-instance mix.
self._cached_epoch: Optional[int] = None
self._cached_shard_indices: Optional[np.ndarray] = None
self._cache_lock = threading.Lock()
# ---- Public API --------------------------------------------------------- #
@property
def num_instances(self) -> int:
return self._total
@property
def num_local_instances_per_epoch(self) -> int:
"""Instances assigned to *this* host per epoch (drops trailing remainder)."""
return self._total // self._shard_count
[docs]
def shuffled_global_indices(self, *, seed: int, epoch: int) -> np.ndarray:
"""Build the full shuffled list for ``(seed, epoch)``.
For the production 724 M-instance mix this allocates ~5.8 GB at uint64
(numpy's default for ``permutation``). For production we should swap to
an on-disk memmap scheme like olmo-core's
``build_and_save_global_indices``. Sized for unit tests + the initial
smoke training run for now.
"""
if not self._shuffle:
return np.arange(self._total, dtype=np.uint64)
rng = np.random.default_rng(_combine_seed_epoch(seed, epoch))
order = rng.permutation(self._total)
return order.astype(np.uint64, copy=False)
[docs]
def shard_indices(self, *, seed: int, epoch: int) -> np.ndarray:
"""Slice the global shuffled order down to this host's share."""
full = self.shuffled_global_indices(seed=seed, epoch=epoch)
n_per = self.num_local_instances_per_epoch
start = self._shard_index * n_per
end = start + n_per
return full[start:end]
def _shard_indices_for_epoch(self, epoch: int) -> np.ndarray:
with self._cache_lock:
if self._cached_epoch == epoch and self._cached_shard_indices is not None:
return self._cached_shard_indices
shard = self.shard_indices(seed=self._seed, epoch=epoch)
self._cached_epoch = epoch
self._cached_shard_indices = shard
return shard
def __getstate__(self):
# threading.Lock can't be pickled, and the per-epoch cache is a pure perf
# optimization — drop both before serialization to forked Grain workers.
state = self.__dict__.copy()
state["_cache_lock"] = None
state["_cached_epoch"] = None
state["_cached_shard_indices"] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._cache_lock = threading.Lock()
# ---- Sampler protocol --------------------------------------------------- #
def __getitem__(self, index: int) -> grain.RecordMetadata:
"""Map a per-host global step ``index`` to the next record to fetch.
The lookup applies ``initial_step`` as a transparent offset: the caller
sees a fresh stream starting at index 0, but the underlying record
pointer is at absolute position ``index + initial_step``. That's the
mechanism that lets resume work without persisting any iterator state.
"""
if index < 0:
raise IndexError(f"sampler index must be non-negative, got {index}")
n_per = self.num_local_instances_per_epoch
if n_per == 0:
raise IndexError(
f"No instances assigned to shard {self._shard_index}/{self._shard_count} " f"(total_instances={self._total})"
)
absolute = index + self._initial_step
epoch = absolute // n_per
in_epoch = absolute % n_per
shard = self._shard_indices_for_epoch(epoch)
record_key = int(shard[in_epoch])
return grain.RecordMetadata(index=index, record_key=record_key)
# Grain >=0.2.16 expects either a finite ``__len__`` or that the sampler
# raises ``IndexError`` on out-of-bounds. We support infinite training and
# never raise IndexError for non-negative indices, so we omit ``__len__``.
def __repr__(self) -> str:
"""Stable repr — Grain compares ``repr(sampler)`` between the checkpoint
and the live sampler to validate the sampler is unchanged on resume.
We deliberately **exclude** ``initial_step`` from the repr: a sampler
rebuilt with a different ``initial_step`` produces a different absolute
position via offset arithmetic, but it's still the *same logical sampler*
over the same data. Including the step here would break interop with
Grain's iterator-state checkpointing path (different reprs reject each
other). The repr captures only the immutable config that defines the
sample space; the offset is just a starting cursor.
"""
return (
f"OlmoIndexSampler(total_instances={self._total}, seed={self._seed}, "
f"shard_index={self._shard_index}, shard_count={self._shard_count}, "
f"shuffle={self._shuffle})"
)
def _combine_seed_epoch(seed: int, epoch: int) -> int:
"""Stable 64-bit mix of (seed, epoch) for the per-epoch shuffle RNG.
Uses SHA-256 truncated to 64 bits — no fixed points (unlike a raw multiply
by a constant when seed=epoch=0), and avoids the numpy uint64 multiplication
overflow warnings that dog SplitMix-style mixers in pure numpy.
"""
digest = hashlib.sha256(f"olmo-shuffle:{int(seed)}:{int(epoch)}".encode("utf-8")).digest()
return int.from_bytes(digest[:8], "little")
[docs]
def make_olmo_grain_data_loader(
index: OlmoNpyIndex,
*,
seed: int,
batch_size: int,
shard_index: int,
shard_count: int,
apply_ngram_filter: bool = True,
shift_to_inputs_targets: bool = True,
path_remap: Optional[Dict[str, str]] = None,
grain_worker_count: int = 0,
grain_worker_buffer_size: int = 1,
initial_step: int = 0,
):
"""Build a Grain ``DataLoader`` for OLMo-style fixed-seq-length training.
Args:
index: Loaded :class:`OlmoNpyIndex`.
seed: Shuffle seed (paired with the implicit per-step ``epoch =
step // n_per_host`` to drive the per-epoch permutation).
batch_size: Per-host batch size (i.e. global_batch / shard_count).
shard_index: This host's data-loading rank.
shard_count: Total data-loading hosts.
apply_ngram_filter: Add :class:`NgramFilterTransform` (recommended).
shift_to_inputs_targets: Add :class:`ShiftToInputsTargets` so the loader
yields the ``inputs``/``targets`` shape MaxText's trainer expects.
path_remap: Pass-through to :class:`OlmoNpyDataSource`.
grain_worker_count: ``0`` runs in-process; otherwise Grain forks workers.
grain_worker_buffer_size: Per-worker batch prefetch.
initial_step: Start the *underlying sampler* at this absolute step.
The Grain DataLoader still iterates from its own 0, but every record
lookup is shifted by ``initial_step``. Set this to ``train_step *
batch_size`` on resume to pick up the data stream where it left off
*without* needing Grain's iterator-state checkpointing.
Returns:
A ``grain.DataLoader``.
"""
source = OlmoNpyDataSource(index, path_remap=path_remap)
sampler = OlmoIndexSampler(
total_instances=index.total_instances,
seed=seed,
shard_index=shard_index,
shard_count=shard_count,
initial_step=initial_step,
)
ops: List[Any] = []
if apply_ngram_filter:
ops.append(NgramFilterTransform())
if shift_to_inputs_targets:
ops.append(ShiftToInputsTargets())
ops.append(grain.transforms.Batch(batch_size=batch_size, drop_remainder=True))
# Grain expects ``shard_options`` on the DataLoader (sharding used to live
# on the Sampler). Our sampler already does the shard-by-rank slicing, but
# Grain still requires this object to validate checkpoint compatibility.
shard_options = grain.sharding.ShardOptions(shard_index=shard_index, shard_count=shard_count, drop_remainder=True)
return grain.DataLoader(
data_source=source,
sampler=sampler,
operations=ops,
shard_options=shard_options,
worker_count=grain_worker_count,
worker_buffer_size=grain_worker_buffer_size,
)