maxtext.input_pipeline.olmo_data_grain module

OLMo numpy fixed-seq-length dataset on top of Grain.

A Grain RandomAccessDataSource over the AI2 OLMo virtual token stream plus a deterministic global-shuffle sampler. See docs/guides/data_input_pipeline/olmo_grain.md for an overview.

class maxtext.input_pipeline.olmo_data_grain.OlmoNpyDataSource(*args, **kwargs)

Bases: RandomAccessDataSource

Random-access view of an OLMo numpy mix as a stream of token windows.

Files are opened lazily and cached per worker as np.memmap objects. Open mmaps are reference-counted by _MmapCache so that iterating over the full 950-file mix does not exceed the open-file limit (ulimit -n).

The data source is process-safe: every Grain worker subprocess builds its own _MmapCache after the fork, so there is no shared mutable state.
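A bounded per-worker cache of open files can be sketched as below. This is a minimal stand-in, not _MmapCache itself: it assumes plain LRU eviction (as described for max_open_files) and omits the reference counting _MmapCache performs. LruHandleCache, opener, and closer are hypothetical names; with np.memmap the opener would be something like lambda p: np.memmap(p, dtype=..., mode="r"), where the dtype depends on the mix.

```python
from collections import OrderedDict
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class LruHandleCache(Generic[T]):
    """Hypothetical stand-in for _MmapCache: keeps at most `max_open` handles open."""

    def __init__(self, opener: Callable[[str], T], closer: Callable[[T], None], max_open: int):
        self._opener = opener
        self._closer = closer
        self._max_open = max_open
        self._handles: "OrderedDict[str, T]" = OrderedDict()

    def get(self, path: str) -> T:
        if path in self._handles:
            # Cache hit: mark this path as most recently used.
            self._handles.move_to_end(path)
            return self._handles[path]
        # Cache miss: evict least recently used handles until there is room.
        while len(self._handles) >= self._max_open:
            _, oldest = self._handles.popitem(last=False)
            self._closer(oldest)
        handle = self._opener(path)
        self._handles[path] = handle
        return handle

    def __len__(self) -> int:
        return len(self._handles)
```

Because each Grain worker builds its own instance after the fork, no locking is needed in this sketch.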

Parameters:
  • index (OlmoNpyIndex) – The OlmoNpyIndex describing the mix. Path strings must be reachable from the data-loading host (typically a GCSFUSE mount path like /mnt/<your-mount>/...).

  • path_remap (Optional[Dict[str, str]]) – Optional dict to rewrite index.files[i].path. Useful when the index was built with gs:// paths and you want to read from a gcsfuse mount, or vice versa. A path is rewritten if it starts with any key in this dict.

  • max_open_files (int) – Soft cap on the number of mmaps held open in the per-worker cache. The cache is LRU.
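A minimal sketch of the path_remap rule, assuming the matched key is replaced by its value as a prefix and that the longest matching key wins when several match (both are assumptions; the text only says a path is rewritten if it starts with a key). remap_path and the bucket and mount names are hypothetical.

```python
from typing import Dict, Optional


def remap_path(path: str, path_remap: Optional[Dict[str, str]]) -> str:
    """Hypothetical helper: rewrite `path` if it starts with a key of `path_remap`."""
    if not path_remap:
        return path
    # Assumption: try longer prefixes first so nested mounts resolve predictably.
    for prefix in sorted(path_remap, key=len, reverse=True):
        if path.startswith(prefix):
            return path_remap[prefix] + path[len(prefix):]
    return path


# Read a gs://-built index through a gcsfuse mount (hypothetical names).
print(remap_path("gs://my-bucket/olmo/part-0.npy", {"gs://my-bucket/": "/mnt/my-bucket/"}))
# /mnt/my-bucket/olmo/part-0.npy
```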

class maxtext.input_pipeline.olmo_data_grain.OlmoIndexSampler(*, total_instances, seed, shard_index=0, shard_count=1, shuffle=True, initial_step=0)

Bases: object

Global-shuffle sampler over an OLMo numpy mix.

Mirrors OLMo-core’s NumpyDataLoaderBase shuffle math: a single Fisher-Yates shuffle of [0, total_instances) seeded by hash(seed, epoch), whose result is then partitioned across shard_count hosts.
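The shuffle-then-partition step can be sketched with the standard library. This is an illustration under stated assumptions, not the sampler's code: the per-epoch seed derivation (SHA-256 of "seed:epoch") and the strided host partition are assumptions, and epoch_seed, fisher_yates, and shard_of_epoch are hypothetical names. The trailing remainder is dropped, as num_local_instances_per_epoch describes.

```python
import hashlib
import random
from typing import List


def epoch_seed(seed: int, epoch: int) -> int:
    # Assumption: any stable hash of (seed, epoch); the real sampler's hash may differ.
    digest = hashlib.sha256(f"{seed}:{epoch}".encode()).digest()
    return int.from_bytes(digest[:8], "little")


def fisher_yates(n: int, rng: random.Random) -> List[int]:
    perm = list(range(n))
    # Walk backwards, swapping slot i with a uniformly chosen slot j in [0, i].
    for i in range(n - 1, 0, -1):
        j = rng.randint(0, i)
        perm[i], perm[j] = perm[j], perm[i]
    return perm


def shard_of_epoch(*, total_instances: int, seed: int, epoch: int,
                   shard_index: int, shard_count: int) -> List[int]:
    # One global permutation per epoch, identical on every host.
    perm = fisher_yates(total_instances, random.Random(epoch_seed(seed, epoch)))
    per_host = total_instances // shard_count  # trailing remainder is dropped
    usable = perm[: per_host * shard_count]
    # Assumption: strided partition, host k takes positions k, k + shard_count, ...
    return usable[shard_index::shard_count]
```

Every host computes the same global permutation independently, so no coordination is needed to keep the shards disjoint.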

Implements Grain’s Sampler protocol, i.e. a __getitem__ that returns grain.python.RecordMetadata. Grain calls sampler[index] for each global step, and the sampler is responsible for mapping that index to the actual record_key passed to data_source[record_key].

Indexing semantics:

  • index here is a per-host (per-data-loader) global step counter starting at 0 and advancing without bound (we support infinite epochs).

  • epoch = index // num_local_instances_per_epoch selects which permutation to use; in_epoch = index % num_local_instances_per_epoch selects the position within this host’s shard of that permutation.

Checkpointing is trivial: the only mutable state is which epoch’s permutation is currently cached, and that is purely a performance optimization. The user-visible position is just the index passed to __getitem__.
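The indexing semantics above can be sketched as a toy sampler. Record is a hypothetical stand-in for grain.python.RecordMetadata carrying only index and record_key; ToySampler, its string seed "seed:epoch", and the strided partition are assumptions for illustration. random.Random.shuffle performs a Fisher-Yates shuffle. Note the single piece of mutable state: the cached shard for the current epoch.

```python
import dataclasses
import random
from typing import List, Optional, Tuple


@dataclasses.dataclass(frozen=True)
class Record:
    """Hypothetical stand-in for grain.python.RecordMetadata."""
    index: int
    record_key: int


class ToySampler:
    def __init__(self, *, total_instances: int, seed: int, shard_index: int = 0,
                 shard_count: int = 1, initial_step: int = 0):
        self.total_instances = total_instances
        self.seed = seed
        self.shard_index = shard_index
        self.shard_count = shard_count
        self.initial_step = initial_step
        self.per_host = total_instances // shard_count  # num_local_instances_per_epoch
        # Only mutable state: (epoch, this host's shard) for the cached epoch.
        self._cached: Optional[Tuple[int, List[int]]] = None

    def _shard(self, epoch: int) -> List[int]:
        if self._cached is None or self._cached[0] != epoch:
            perm = list(range(self.total_instances))
            # Assumption: string seed "seed:epoch" stands in for hash(seed, epoch).
            random.Random(f"{self.seed}:{epoch}").shuffle(perm)
            usable = perm[: self.per_host * self.shard_count]
            self._cached = (epoch, usable[self.shard_index::self.shard_count])
        return self._cached[1]

    def __getitem__(self, local_idx: int) -> Record:
        step = local_idx + self.initial_step           # absolute per-host position
        epoch, in_epoch = divmod(step, self.per_host)  # which permutation, where in it
        return Record(index=local_idx, record_key=self._shard(epoch)[in_epoch])
```

Because __getitem__ depends only on (seed, shard, position), dropping the cache changes speed, not results.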

Parameters:
  • total_instances (int) – index.total_instances from the OLMo index.

  • seed (int) – Base seed for the shuffle.

  • shard_index (int) – Zero-based index of this data-loading host. Typically jax.process_index().

  • shard_count (int) – Number of data-loading hosts. Typically jax.process_count().

  • shuffle (bool) – If False, instances are emitted in linear order — useful for debugging.

  • initial_step (int) – Per-host record offset (trainer step × per-host batch size) at which the training run should resume. __getitem__(local_idx) returns the record at absolute position local_idx + initial_step. Use this to resume a run from a saved trainer step without saving Grain’s iterator state: the sampler is a pure function of its inputs, so the (seed, shard, absolute position) tuple fully determines the next record.

property num_instances: int
property num_local_instances_per_epoch: int

Instances assigned to this host per epoch (drops trailing remainder).

shuffled_global_indices(*, seed, epoch)

Build the full shuffled list for (seed, epoch).

For the production 724 M-instance mix this allocates ~5.8 GB at 8 bytes per index (64-bit integers, as returned by NumPy’s permutation). This in-memory approach is sized for unit tests and the initial smoke training run; for production it should be replaced with an on-disk memmap scheme like OLMo-core’s build_and_save_global_indices.
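The ~5.8 GB figure follows from one 8-byte integer per instance. The same arithmetic also shows that 724 M indices would fit in a 4-byte unsigned dtype, which would halve the footprint; that is an observation about sizes, not a claim about what the module or OLMo-core stores.

```python
num_instances = 724_000_000  # "724 M-instance" mix, rounded as in the text

bytes_64bit = num_instances * 8  # one int64/uint64 index per instance
bytes_32bit = num_instances * 4  # one uint32 index per instance

print(f"64-bit indices: {bytes_64bit / 1e9:.1f} GB")  # 64-bit indices: 5.8 GB
print(f"32-bit indices: {bytes_32bit / 1e9:.1f} GB")  # 32-bit indices: 2.9 GB
assert num_instances < 2**32  # every index is representable as uint32
```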

Parameters:
  • seed (int)

  • epoch (int)

Return type:

ndarray

shard_indices(*, seed, epoch)

Slice the global shuffled order down to this host’s share.

Parameters:
  • seed (int)

  • epoch (int)

Return type:

ndarray

class maxtext.input_pipeline.olmo_data_grain.NgramFilterTransform(*args, **kwargs)

Bases: Map

Add an instance_mask field per OLMo-core’s repetition filter.

instance_mask is True if the instance is “clean” (kept fully in the loss) and False if it contains overly repetitive periodic spans (its loss is zeroed out at loss time). Flagged instances are not dropped, because dropping them would disrupt sharding; this matches OLMo-core’s behavior.
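A simplified periodic-repetition check illustrates what the filter looks for. This is not OLMo-core’s algorithm and ignores mask_value: it assumes an instance is flagged when some p-gram with min_period ≤ p ≤ max_period repeats back to back at least max_count times, and that the element stores its tokens under a "tokens" key. is_clean and add_instance_mask are hypothetical names.

```python
from typing import Any, Dict, Sequence


def is_clean(tokens: Sequence[int], *, min_period: int, max_period: int, max_count: int) -> bool:
    """Hypothetical check: False if a p-gram repeats >= max_count times consecutively."""
    n = len(tokens)
    for p in range(max(1, min_period), max_period + 1):
        run = 0  # consecutive positions i with tokens[i] == tokens[i + p]
        for i in range(n - p):
            if tokens[i] == tokens[i + p]:
                run += 1
                # `run` matches span run + p tokens, i.e. (run + p) // p full copies.
                if (run + p) // p >= max_count:
                    return False
            else:
                run = 0
    return True


def add_instance_mask(element: Dict[str, Any], **filter_kwargs: int) -> Dict[str, Any]:
    # Keep the instance; only flag it so its loss can be zeroed later.
    return {**element, "instance_mask": is_clean(element["tokens"], **filter_kwargs)}
```

The scan is O(len(tokens) × max_period), which is acceptable for a sketch on single instances.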

Parameters:
  • max_period (int)

  • min_period (int)

  • max_count (int)

  • mask_value (int)

map(element)

Add instance_mask to element based on the n-gram filter.

Parameters:

element (Dict[str, Any])

Return type:

Dict[str, Any]

class maxtext.input_pipeline.olmo_data_grain.ShiftToInputsTargets(*args, **kwargs)

Bases: Map

Convert a tokens array into the keys MaxText’s pretrain trainer expects.

Produces, for a single instance of length L = sequence_length:

  • inputs: tokens.astype(int32), shape (L,)

  • targets: tokens shifted left by one, padded with 0 at position L-1, shape (L,)

  • inputs_position: [0, 1, ..., L-1] int32

  • inputs_segmentation: int32 ones, shape (L,) — single segment

  • targets_segmentation: int32 ones, shape (L,) with the last position zeroed (loss masked at the padded position); the entire row is zero if instance_mask is False (n-gram filter flagged the instance).

Outputs keep the full L tokens (not L-1) because the TPU splash attention kernel requires q_seq_len to be divisible by 512; producing length L-1 would break that invariant for the typical OLMo L=8192.

The OLMo dataset has no document boundaries inside an instance — sequences span doc boundaries with no special masking — so segmentation and position are trivially uniform within an instance.
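The output layout listed above can be sketched with NumPy. This mirrors the bullet list rather than the transform’s actual code; shift_to_inputs_targets is a hypothetical name, and taking tokens and instance_mask as separate arguments is an assumption.

```python
import numpy as np


def shift_to_inputs_targets(tokens: np.ndarray, instance_mask: bool = True) -> dict:
    seq_len = tokens.shape[0]
    inputs = tokens.astype(np.int32)
    targets = np.zeros(seq_len, dtype=np.int32)
    targets[:-1] = inputs[1:]  # shift left by one; position L-1 stays 0 (padding)
    segmentation = np.ones(seq_len, dtype=np.int32)  # one segment per instance
    targets_segmentation = segmentation.copy()
    targets_segmentation[-1] = 0  # no real target at the padded last position
    if not instance_mask:
        targets_segmentation[:] = 0  # flagged by the n-gram filter: no loss at all
    return {
        "inputs": inputs,
        "targets": targets,
        "inputs_position": np.arange(seq_len, dtype=np.int32),
        "inputs_segmentation": segmentation,
        "targets_segmentation": targets_segmentation,
    }
```

For example, tokens [10, 11, 12, 13] give targets [11, 12, 13, 0] and targets_segmentation [1, 1, 1, 0], while every output keeps length L.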

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

map(element)

Convert tokens into inputs / targets / segmentation tensors.

Parameters:

element (Dict[str, Any])

Return type:

Dict[str, Any]

maxtext.input_pipeline.olmo_data_grain.make_olmo_grain_data_loader(index, *, seed, batch_size, shard_index, shard_count, apply_ngram_filter=True, shift_to_inputs_targets=True, path_remap=None, grain_worker_count=0, grain_worker_buffer_size=1, initial_step=0)

Build a Grain DataLoader for OLMo-style fixed-seq-length training.

Parameters:
  • index (OlmoNpyIndex) – Loaded OlmoNpyIndex.

  • seed (int) – Shuffle seed. It is combined with the epoch implied by each step (epoch = step // n_per_host, where n_per_host is the sampler’s num_local_instances_per_epoch) to select that epoch’s permutation.

  • batch_size (int) – Per-host batch size (i.e. global_batch / shard_count).

  • shard_index (int) – This host’s data-loading rank.

  • shard_count (int) – Total data-loading hosts.

  • apply_ngram_filter (bool) – Add NgramFilterTransform (recommended).

  • shift_to_inputs_targets (bool) – Add ShiftToInputsTargets so the loader yields the inputs/targets shape MaxText’s trainer expects.

  • path_remap (Dict[str, str] | None) – Pass-through to OlmoNpyDataSource.

  • grain_worker_count (int) – 0 runs in-process; otherwise Grain forks workers.

  • grain_worker_buffer_size (int) – Number of batches each worker prefetches.

  • initial_step (int) – Absolute per-host record position at which to start the underlying sampler. The Grain DataLoader still iterates from its own 0, but every record lookup is shifted by initial_step. On resume, set this to train_step * batch_size (the records this host has already consumed) to pick up the data stream where it left off without needing Grain’s iterator-state checkpointing.
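The resume arithmetic for initial_step can be checked in a few lines. This sketch assumes train_step counts steps already completed before the checkpoint and that each step consumes batch_size consecutive sampler positions on this host; the helper names and the 512-over-8-hosts configuration are hypothetical.

```python
def resume_initial_step(train_step: int, per_host_batch_size: int) -> int:
    # Records this host consumed in the train_step completed steps.
    return train_step * per_host_batch_size


def sampler_positions_for_batch(step: int, per_host_batch_size: int, initial_step: int = 0) -> list:
    # Absolute sampler positions behind local batch `step` of a loader started at initial_step.
    start = initial_step + step * per_host_batch_size
    return list(range(start, start + per_host_batch_size))


# Hypothetical configuration: global batch 512 over 8 data-loading hosts.
batch_size = 512 // 8                                  # per-host batch size: 64
initial_step = resume_initial_step(1_000, batch_size)  # 64_000

# Batch 0 of the resumed loader reads what batch 1_000 of the original run would have.
assert sampler_positions_for_batch(0, batch_size, initial_step) == sampler_positions_for_batch(1_000, batch_size)
```

Passing the trainer step itself (rather than train_step * batch_size) would replay most of the already-seen data, which is why the offset is expressed in records.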

Returns:

A grain.DataLoader.