maxtext.input_pipeline.olmo_grain_data_processing module
MaxText trainer adapter for the OLMo numpy fixed-seq-length pipeline.
The trainer expects dataset_type to map to two factory functions
(make_<type>_train_iterator, make_<type>_eval_iterator) that take
(config, mesh, process_indices) and return a
MultiHostDataLoadIterator.
This module provides those for dataset_type=olmo_grain. The hard work
lives in maxtext.input_pipeline.olmo_data_grain (data source +
sampler + transforms); here we just wire it to MaxText’s config + the
multihost dataloading wrapper.
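The naming convention above can be sketched as a lookup by name. This is a minimal illustration, not the trainer's actual dispatch code: the helper resolve_iterator_factories, the FACTORIES dict, and the stand-in factory bodies are all hypothetical, and the real factories return a MultiHostDataLoadIterator rather than a tuple.

```python
from typing import Any, Callable, Dict, Tuple


def resolve_iterator_factories(
    dataset_type: str, namespace: Dict[str, Any]
) -> Tuple[Callable, Callable]:
    """Find make_<type>_train_iterator and make_<type>_eval_iterator by name.

    Hypothetical helper: the source only states the naming convention,
    not how the trainer performs the lookup.
    """
    train_name = f"make_{dataset_type}_train_iterator"
    eval_name = f"make_{dataset_type}_eval_iterator"
    missing = [name for name in (train_name, eval_name) if name not in namespace]
    if missing:
        raise ValueError(f"dataset_type={dataset_type!r} is missing factories: {missing}")
    return namespace[train_name], namespace[eval_name]


# Stand-in factories with the documented (config, global_mesh, process_indices)
# signature. They return tuples here purely so the sketch runs on its own.
def make_olmo_grain_train_iterator(config, global_mesh, process_indices):
    return ("train", config, process_indices)


def make_olmo_grain_eval_iterator(config, global_mesh, process_indices):
    return ("eval", config, process_indices)


FACTORIES = {
    "make_olmo_grain_train_iterator": make_olmo_grain_train_iterator,
    "make_olmo_grain_eval_iterator": make_olmo_grain_eval_iterator,
}

train_fn, eval_fn = resolve_iterator_factories("olmo_grain", FACTORIES)
```

An unknown dataset_type fails early with a ValueError that lists the missing factory names, which is easier to diagnose than an attribute error deep inside the trainer.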
Notes
- Sequence length match: config.max_target_length must match the sequence_length recorded in the index JSON. Mismatches raise at load time.
- Path remap: AI2’s index typically holds gs:// URIs. For training, we read via a GCSFUSE mount on each TPU host. The olmo_path_remap_from / olmo_path_remap_to config pair rewrites the prefix at runtime.
- Sharding: each data-loading host is assigned a non-overlapping shard of the global instance space via OlmoIndexSampler. We use process_indices.index(jax.process_index()) as the local shard index (matches the pattern in grain_data_processing).
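The three notes above can be illustrated with a small, dependency-free sketch. The helper names (check_sequence_length, remap_path, local_shard_index), the exception type, and the bucket and mount paths are assumptions made for illustration; the actual logic lives in olmo_data_grain and may differ. The process index is passed in explicitly here, where the module itself uses jax.process_index().

```python
from typing import Sequence


def check_sequence_length(max_target_length: int, index_sequence_length: int) -> None:
    """Fail fast when the config and the index JSON disagree (hypothetical helper)."""
    if max_target_length != index_sequence_length:
        raise ValueError(
            f"max_target_length={max_target_length} does not match "
            f"index sequence_length={index_sequence_length}"
        )


def remap_path(path: str, remap_from: str, remap_to: str) -> str:
    """Rewrite a leading prefix, e.g. a gs:// URI to a local GCSFUSE mount.

    Paths that do not start with remap_from are returned unchanged.
    """
    if remap_from and path.startswith(remap_from):
        return remap_to + path[len(remap_from):]
    return path


def local_shard_index(process_indices: Sequence[int], process_index: int) -> int:
    """Position of this host among the data-loading hosts.

    In the module this is process_indices.index(jax.process_index()).
    """
    return list(process_indices).index(process_index)


# Hypothetical bucket and mount point, for illustration only.
local_path = remap_path(
    "gs://example-bucket/olmo/part-0.npy", "gs://example-bucket/", "/mnt/gcsfuse/"
)
shard = local_shard_index([0, 2, 4], 2)
```

Because the shard index is a position in process_indices rather than the raw process index, hosts that do not load data leave no gaps in the shard numbering.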
- maxtext.input_pipeline.olmo_grain_data_processing.make_olmo_grain_train_iterator(config, global_mesh, process_indices)
Train iterator for
dataset_type=olmo_grain.
- maxtext.input_pipeline.olmo_grain_data_processing.make_olmo_grain_eval_iterator(config, global_mesh, process_indices)
Eval iterator for
dataset_type=olmo_grain. Currently reuses the train data with a different seed: the OLMo mix is a pretraining corpus with no canonical eval partition, so eval here means “deterministic held-out shuffle” rather than “held-out documents”. For a real eval split, point a future
config.eval_olmo_index_path at a separate index built over different files; the rest of this function is unchanged.
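To make the "same data, different seed" idea concrete, here is a minimal sketch using only the standard library. It is an illustration of the concept, not the shuffling that OlmoIndexSampler actually performs; the function name, the seed values, and the instance count are hypothetical.

```python
import random
from typing import List


def shuffled_instance_order(num_instances: int, seed: int) -> List[int]:
    """Return a deterministic permutation of instance ids for a given seed."""
    order = list(range(num_instances))
    random.Random(seed).shuffle(order)  # private RNG, so global state is untouched
    return order


TRAIN_SEED = 0  # hypothetical seed values
EVAL_SEED = 1

train_order = shuffled_instance_order(16, TRAIN_SEED)
eval_order = shuffled_instance_order(16, EVAL_SEED)
```

Both orders cover exactly the same instances, so eval loss computed this way measures behavior on data the model may already have trained on. It is reproducible across runs, but it is not a generalization estimate.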