maxtext.input_pipeline.olmo_data_grain module
OLMo numpy fixed-seq-length dataset on top of Grain.
A Grain RandomAccessDataSource over the AI2 OLMo virtual token stream
plus a deterministic global-shuffle sampler. See
docs/guides/data_input_pipeline/olmo_grain.md for an overview.
- class maxtext.input_pipeline.olmo_data_grain.OlmoNpyDataSource(*args, **kwargs)
Bases: RandomAccessDataSource
Random-access view of an OLMo numpy mix as a stream of token windows.
Files are opened lazily and cached as np.memmap per worker. Open mmaps are reference-counted by _MmapCache so we don’t blow past ulimit -n when iterating over the full 950-file mix.
The data source is process-safe: every Grain worker subprocess builds its own _MmapCache after the fork. No shared mutable state.
- Parameters:
index (OlmoNpyIndex) – The OlmoNpyIndex describing the mix. Path strings must be reachable from the data-loading host (typically a GCSFUSE mount path like /mnt/<your-mount>/...).
path_remap (Optional[Dict[str, str]]) – Optional dict to rewrite index.files[i].path. Useful when the index was built with gs:// paths and you want to read from a gcsfuse mount, or vice versa. A path is rewritten if it starts with any key in this dict.
max_open_files (int) – Soft cap on the number of mmaps held open in the per-worker cache. The cache is LRU.
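The prefix rewrite and the capped per-worker cache described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the module's code: remap_path and LruHandleCache are hypothetical names, the bucket and mount paths in the usage comment are placeholders, the opener is injected so that np.memmap (or any other open call) can be plugged in, and the reference counting that _MmapCache performs is left out.

```python
from collections import OrderedDict
from typing import Callable, Dict, Generic, List, Optional, TypeVar

T = TypeVar("T")


def remap_path(path: str, path_remap: Optional[Dict[str, str]] = None) -> str:
    """Rewrite `path` when it starts with a key of `path_remap` (first match wins)."""
    for prefix, replacement in (path_remap or {}).items():
        if path.startswith(prefix):
            return replacement + path[len(prefix):]
    return path


class LruHandleCache(Generic[T]):
    """Hypothetical LRU cache of open handles keyed by path (no ref-counting)."""

    def __init__(self, opener: Callable[[str], T], max_open_files: int):
        if max_open_files < 1:
            raise ValueError("max_open_files must be at least 1")
        self._opener = opener
        self._max_open_files = max_open_files
        self._handles: "OrderedDict[str, T]" = OrderedDict()

    def get(self, path: str) -> T:
        if path in self._handles:
            self._handles.move_to_end(path)  # mark as most recently used
            return self._handles[path]
        handle = self._opener(path)  # opened lazily, on first access
        self._handles[path] = handle
        while len(self._handles) > self._max_open_files:
            self._handles.popitem(last=False)  # evict the least recently used
        return handle

    def open_paths(self) -> List[str]:
        """Paths currently held open, least recently used first."""
        return list(self._handles)


# Placeholder paths, for illustration only:
# remap_path("gs://bucket/part-0.npy", {"gs://bucket/": "/mnt/data/"})
```

In a worker, the opener would typically wrap np.memmap in read-only mode with the token dtype of the mix. Because each Grain worker builds its own cache after the fork, nothing here needs locking.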
- class maxtext.input_pipeline.olmo_data_grain.OlmoIndexSampler(*, total_instances, seed, shard_index=0, shard_count=1, shuffle=True, initial_step=0)
Bases: object
Global-shuffle sampler over an OLMo numpy mix.
Mirrors OLMo-core’s NumpyDataLoaderBase shuffle math: a single Fisher-Yates over [0, total_instances) keyed by hash(seed, epoch), then partitioned across shard_count hosts.
Implements Grain’s Sampler protocol, i.e. __getitem__ returning grain.python.RecordMetadata. Grain calls sampler[index] for each global step; the sampler is responsible for mapping that to the actual record_key fed to data_source[record_key].
Indexing semantics:
- index here is a per-host (per-data-loader) global step counter starting at 0 and advancing without bound (we support infinite epochs).
- epoch = index // num_local_instances_per_epoch selects which permutation to use.
- in_epoch = index % num_local_instances_per_epoch selects the position within this host’s shard of that permutation.
Checkpointing is trivial: the only mutable state is “which epoch’s permutation is currently cached” (a perf optimization). The user-visible position is just the index passed to __getitem__.
- Parameters:
total_instances (int) – index.total_instances from the OLMo index.
seed (int) – Base seed for the shuffle.
shard_index (int) – Zero-based index of this data-loading host. Typically jax.process_index().
shard_count (int) – Number of data-loading hosts. Typically jax.process_count().
shuffle (bool) – If False, instances are emitted in linear order, which is useful for debugging.
initial_step (int) – Per-host position at which the training run should resume, in the same units as the index passed to __getitem__. __getitem__(local_idx) returns the record at absolute position local_idx + initial_step. Use this to resume a run from a saved trainer step without saving Grain’s iterator state: the sampler is a pure function of its inputs, so the (seed, shard, absolute step) tuple fully determines the next record.
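The indexing semantics above can be sketched in plain Python. This is an illustration, not the sampler's implementation: local_record_key is a hypothetical name, the string-seeded random.Random stands in for whatever hash(seed, epoch) the module uses, and giving each host a contiguous slice of the permutation is an assumption (the source only says the permutation is partitioned across hosts). The sketch also rebuilds the permutation on every call, whereas the sampler caches the current epoch's permutation.

```python
import random
from typing import List


def local_record_key(
    index: int,
    *,
    total_instances: int,
    seed: int,
    shard_index: int = 0,
    shard_count: int = 1,
    shuffle: bool = True,
    initial_step: int = 0,
) -> int:
    """Map a per-host position to a global instance id (illustrative only)."""
    per_host = total_instances // shard_count  # trailing remainder is dropped
    position = index + initial_step  # absolute per-host position
    epoch, in_epoch = divmod(position, per_host)

    order: List[int] = list(range(total_instances))
    if shuffle:
        # Assumption: any deterministic function of (seed, epoch) will do here.
        # random.shuffle is a Fisher-Yates shuffle.
        random.Random(f"{seed}:{epoch}").shuffle(order)

    # Assumption: host k owns the k-th contiguous slice of the permutation.
    return order[shard_index * per_host + in_epoch]
```

Because the result depends only on the arguments, calling the function with index 3 and initial_step 4 yields the same record as index 7 with initial_step 0, which is the property that makes resuming from a saved trainer step possible.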
- property num_instances: int
- property num_local_instances_per_epoch: int
Instances assigned to this host per epoch (drops trailing remainder).
- shuffled_global_indices(*, seed, epoch)
Build the full shuffled list for (seed, epoch).
For the production 724 M-instance mix this allocates ~5.8 GB as 64-bit integers (8 bytes per index, numpy’s default for permutation). For production we should swap to an on-disk memmap scheme like olmo-core’s build_and_save_global_indices. The current implementation is sized for unit tests and the initial smoke training run.
- Parameters:
seed (int)
epoch (int)
- Return type:
ndarray
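The memory figure quoted for shuffled_global_indices follows from simple arithmetic, worked through below. The instance count is the rounded 724 M from the text, so the results are approximate, and index_array_bytes is a hypothetical helper name.

```python
def index_array_bytes(num_instances: int, bytes_per_index: int) -> int:
    """Size in bytes of a dense array holding one index per instance."""
    return num_instances * bytes_per_index


NUM_INSTANCES = 724_000_000  # rounded figure from the text, not an exact count

bytes_64bit = index_array_bytes(NUM_INSTANCES, 8)  # 64-bit indices
bytes_32bit = index_array_bytes(NUM_INSTANCES, 4)  # 32-bit indices
fits_in_uint32 = NUM_INSTANCES <= 2**32            # every index is below 2**32

print(f"64-bit: {bytes_64bit / 1e9:.2f} GB, 32-bit: {bytes_32bit / 1e9:.2f} GB")
```

At this scale every index fits in 32 bits, so a narrower dtype would halve the allocation. That is a possible mitigation to weigh alongside the on-disk memmap scheme mentioned above, not something the module is stated to do.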
- class maxtext.input_pipeline.olmo_data_grain.NgramFilterTransform(*args, **kwargs)
Bases: Map
Add an instance_mask field per OLMo-core’s repetition filter.
instance_mask = True if the instance is “clean” (kept fully in the loss); False if it has too-repetitive periodic spans (zero-out at loss time). We don’t drop the instance, because that would mess with sharding; this matches OLMo-core’s behavior.
- Parameters:
max_period (int)
min_period (int)
max_count (int)
mask_value (int)
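To make the idea of a "too-repetitive periodic span" concrete, here is a simplified detector. It is a sketch under stated assumptions and not OLMo-core's algorithm: is_clean_instance is a hypothetical name, the threshold is assumed to be "at least max_count repetitions of a pattern whose period lies in [min_period, max_period]", and mask_value is not modeled at all.

```python
from typing import Sequence


def is_clean_instance(
    tokens: Sequence[int],
    *,
    max_period: int,
    min_period: int = 1,
    max_count: int,
) -> bool:
    """Return False if some short pattern repeats back-to-back too many times."""
    n = len(tokens)
    for period in range(min_period, max_period + 1):
        run = 0  # consecutive positions i with tokens[i] == tokens[i + period]
        for i in range(n - period):
            if tokens[i] == tokens[i + period]:
                run += 1
                # `run` matches cover a periodic span of run + period tokens,
                # i.e. (run + period) // period full repetitions of the pattern.
                if (run + period) // period >= max_count:
                    return False
            else:
                run = 0
    return True
```

The returned boolean plays the role of instance_mask: the instance stays in the batch either way, and a False value only zeroes its contribution to the loss downstream.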
- class maxtext.input_pipeline.olmo_data_grain.ShiftToInputsTargets(*args, **kwargs)
Bases: Map
Convert a tokens array into the keys MaxText’s pretrain trainer expects.
Produces, for a single instance of length L = sequence_length:
- inputs: tokens.astype(int32), shape (L,)
- targets: tokens shifted left by one, padded with 0 at position L-1, shape (L,)
- inputs_position: [0, 1, ..., L-1] int32
- inputs_segmentation: int32 ones, shape (L,) (single segment)
- targets_segmentation: int32 ones, shape (L,) with the last position zeroed (loss masked at the padded position); the entire row is zero if instance_mask is False (n-gram filter flagged the instance).
Outputs are the full L tokens (not L-1) because the TPU splash attention kernel requires q_seq_len divisible by 512; producing length L-1 would break that invariant for typical OLMo L=8192.
The OLMo dataset has no document boundaries inside an instance (sequences span doc boundaries with no special masking), so segmentation and position are trivially uniform within an instance.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
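The field layout listed for ShiftToInputsTargets can be reproduced with a few lines of NumPy. This is a sketch of the documented output, not the transform's source: shift_to_inputs_targets is written here as a free function, taking instance_mask as a plain argument is an assumption about how the flag arrives, and int32 for targets is assumed because the text only states the dtype of the other fields.

```python
import numpy as np


def shift_to_inputs_targets(tokens: np.ndarray, instance_mask: bool = True) -> dict:
    """Build the trainer-facing fields for one instance of length L."""
    length = tokens.shape[0]
    inputs = tokens.astype(np.int32)

    targets = np.zeros(length, dtype=np.int32)
    targets[:-1] = inputs[1:]  # shift left by one; position L-1 stays 0

    targets_segmentation = np.ones(length, dtype=np.int32)
    targets_segmentation[-1] = 0  # no loss at the padded position
    if not instance_mask:
        targets_segmentation[:] = 0  # flagged instance contributes no loss

    return {
        "inputs": inputs,
        "targets": targets,
        "inputs_position": np.arange(length, dtype=np.int32),
        "inputs_segmentation": np.ones(length, dtype=np.int32),
        "targets_segmentation": targets_segmentation,
    }
```

Every returned array has shape (L,), so a sequence length of 8192 stays divisible by 512 as the splash attention kernel requires.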
- maxtext.input_pipeline.olmo_data_grain.make_olmo_grain_data_loader(index, *, seed, batch_size, shard_index, shard_count, apply_ngram_filter=True, shift_to_inputs_targets=True, path_remap=None, grain_worker_count=0, grain_worker_buffer_size=1, initial_step=0)
Build a Grain DataLoader for OLMo-style fixed-seq-length training.
- Parameters:
index (OlmoNpyIndex) – Loaded OlmoNpyIndex.
seed (int) – Shuffle seed (paired with the implicit per-step epoch = step // n_per_host to drive the per-epoch permutation).
batch_size (int) – Per-host batch size (i.e. global_batch / shard_count).
shard_index (int) – This host’s data-loading rank.
shard_count (int) – Total data-loading hosts.
apply_ngram_filter (bool) – Add NgramFilterTransform (recommended).
shift_to_inputs_targets (bool) – Add ShiftToInputsTargets so the loader yields the inputs/targets shape MaxText’s trainer expects.
path_remap (Dict[str, str] | None) – Pass-through to OlmoNpyDataSource.
grain_worker_count (int) – 0 runs in-process; otherwise Grain forks workers.
grain_worker_buffer_size (int) – Per-worker batch prefetch.
initial_step (int) – Start the underlying sampler at this absolute step. The Grain DataLoader still iterates from its own 0, but every record lookup is shifted by initial_step. Set this to train_step * batch_size on resume to pick up the data stream where it left off without needing Grain’s iterator-state checkpointing.
- Returns:
A grain.DataLoader.
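The resume arithmetic described for batch_size and initial_step can be written out as a small helper. resume_settings is a hypothetical name and the numbers in the usage note are made up for illustration; the commented call only restates the signature documented above and is not executed here.

```python
def resume_settings(*, global_batch_size: int, shard_count: int, train_step: int) -> dict:
    """Derive the per-host batch size and sampler offset for a resumed run."""
    if global_batch_size % shard_count != 0:
        raise ValueError("global_batch_size must be divisible by shard_count")
    batch_size = global_batch_size // shard_count  # per-host batch size
    # Each completed trainer step consumed `batch_size` records on this host.
    return {"batch_size": batch_size, "initial_step": train_step * batch_size}


# Sketch of how the values would be passed on (not run in this example):
# settings = resume_settings(global_batch_size=..., shard_count=jax.process_count(),
#                            train_step=...)
# loader = make_olmo_grain_data_loader(
#     index, seed=..., shard_index=jax.process_index(),
#     shard_count=jax.process_count(), **settings)
```

For example, a global batch of 1024 over 8 hosts resumed at trainer step 1000 gives a per-host batch size of 128 and an initial_step of 128000.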