maxtext.input_pipeline.olmo_grain_data_processing module

MaxText trainer adapter for the OLMo numpy fixed-seq-length pipeline.

The trainer expects each dataset_type to map to two factory functions, make_<type>_train_iterator and make_<type>_eval_iterator, that take (config, global_mesh, process_indices) and return a MultiHostDataLoadIterator.

This module provides those functions for dataset_type=olmo_grain. The substantive work lives in maxtext.input_pipeline.olmo_data_grain (data source, sampler, and transforms); this module only wires it to MaxText’s config and the multi-host data-loading wrapper.
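The factory contract can be pictured as a lookup from dataset_type to a (train, eval) pair of callables. The sketch below is a minimal illustration under stated assumptions: the dict registry ITERATOR_FACTORIES, the helper create_iterators, and the stub factories are all hypothetical, a plain dict stands in for the MaxText config, and MaxText's real dispatch code may be organized differently.

```python
from typing import Any, Callable, Dict, Sequence, Tuple

# Both factories share the (config, global_mesh, process_indices) signature.
IteratorFactory = Callable[[Dict[str, Any], Any, Sequence[int]], Any]


def stub_olmo_train_iterator(config, global_mesh, process_indices):
    # Hypothetical stand-in: the real factory returns a MultiHostDataLoadIterator.
    return ("train", config["dataset_type"], tuple(process_indices))


def stub_olmo_eval_iterator(config, global_mesh, process_indices):
    # Hypothetical stand-in for the eval factory.
    return ("eval", config["dataset_type"], tuple(process_indices))


# Hypothetical registry: dataset_type -> (train factory, eval factory).
ITERATOR_FACTORIES: Dict[str, Tuple[IteratorFactory, IteratorFactory]] = {
    "olmo_grain": (stub_olmo_train_iterator, stub_olmo_eval_iterator),
}


def create_iterators(config, global_mesh, process_indices):
    """Resolve both factories for config["dataset_type"] and call them."""
    dataset_type = config["dataset_type"]
    if dataset_type not in ITERATOR_FACTORIES:
        raise ValueError(f"unknown dataset_type: {dataset_type!r}")
    train_fn, eval_fn = ITERATOR_FACTORIES[dataset_type]
    return (
        train_fn(config, global_mesh, process_indices),
        eval_fn(config, global_mesh, process_indices),
    )


train_it, eval_it = create_iterators({"dataset_type": "olmo_grain"}, None, [0, 1])
```

Keeping both factories behind one key means the trainer never needs dataset-specific branches; adding a new dataset_type is a matter of registering another pair.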

Notes

  • Sequence length match: config.max_target_length must equal the sequence_length recorded in the index JSON. A mismatch raises an error at load time.

  • Path remap: AI2’s index typically holds gs:// URIs, but training reads the files through a GCSFUSE mount on each TPU host. The olmo_path_remap_from / olmo_path_remap_to config pair rewrites the matching prefix at runtime so that index paths resolve to the mounted copies.

  • Sharding: OlmoIndexSampler assigns each data-loading host a non-overlapping shard of the global instance space. The local shard index is process_indices.index(jax.process_index()), the same pattern used in grain_data_processing.
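The three notes above can be sketched as small pure functions. This is an illustrative sketch, not the module's code: the helper names are hypothetical, the index JSON is assumed to carry sequence_length as a top-level key, the bucket and mount paths are made up, the process index is passed explicitly instead of calling jax.process_index(), and the strided split is only one possible non-overlapping partition (OlmoIndexSampler may partition differently, for example into contiguous blocks).

```python
import json
from typing import Optional, Sequence, Tuple


def check_sequence_length(index_json: str, max_target_length: int) -> int:
    """Raise if the index was built for a different sequence length."""
    # Assumes sequence_length is a top-level key; the real index schema may differ.
    seq_len = json.loads(index_json)["sequence_length"]
    if seq_len != max_target_length:
        raise ValueError(
            f"max_target_length={max_target_length} does not match "
            f"index sequence_length={seq_len}"
        )
    return seq_len


def remap_path(path: str, remap_from: Optional[str], remap_to: Optional[str]) -> str:
    """Swap a gs:// prefix for a local GCSFUSE mount prefix, if configured."""
    if remap_from and remap_to is not None and path.startswith(remap_from):
        return remap_to + path[len(remap_from):]
    return path  # no remap configured, or the prefix does not match


def local_shard(process_indices: Sequence[int], process_index: int) -> Tuple[int, int]:
    """Return (shard_index, num_shards) for this host."""
    # list.index raises ValueError if this host is not a data-loading host.
    return list(process_indices).index(process_index), len(process_indices)


def shard_instances(num_instances: int, shard_index: int, num_shards: int) -> range:
    """Strided, non-overlapping split: shard k gets k, k + n, k + 2n, ..."""
    return range(shard_index, num_instances, num_shards)


check_sequence_length('{"sequence_length": 4096}', 4096)
print(remap_path("gs://example-bucket/olmo/part-000.npy", "gs://example-bucket/", "/mnt/gcsfuse/"))
shard_index, num_shards = local_shard([0, 2, 4, 6], process_index=4)
print(shard_index, num_shards, list(shard_instances(10, shard_index, num_shards)))
```

Validating the sequence length before building any iterator surfaces a config mistake immediately instead of as shape errors deep inside the training step.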

maxtext.input_pipeline.olmo_grain_data_processing.make_olmo_grain_train_iterator(config, global_mesh, process_indices)

Train iterator for dataset_type=olmo_grain.

maxtext.input_pipeline.olmo_grain_data_processing.make_olmo_grain_eval_iterator(config, global_mesh, process_indices)

Eval iterator for dataset_type=olmo_grain.

Currently reuses the train data with a different seed. The OLMo mix is a pretraining corpus with no canonical eval partition, so eval here means “a deterministic reshuffle of the training data” rather than “held-out documents”. For a real eval split, a future config.eval_olmo_index_path option (not yet present) would point at a separate index built over different files; the rest of this function would stay unchanged.
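The consequence of reusing the train data with another seed can be illustrated with a seeded permutation. In this sketch the stdlib random module stands in for whatever shuffling OlmoIndexSampler actually performs, and both the seed value and the "train seed + 1" rule for the eval seed are hypothetical.

```python
import random
from typing import List


def shuffled_order(num_instances: int, seed: int) -> List[int]:
    """Deterministic permutation of instance ids for a given seed."""
    order = list(range(num_instances))
    random.Random(seed).shuffle(order)
    return order


train_seed = 1234            # hypothetical training seed
eval_seed = train_seed + 1   # hypothetical rule for deriving the eval seed

train_order = shuffled_order(1000, train_seed)
eval_order = shuffled_order(1000, eval_seed)

# The same seed reproduces the same order on every run and every host.
print(eval_order == shuffled_order(1000, eval_seed))
# The visiting order differs from training...
print(train_order[:5], eval_order[:5])
# ...but the instance set is identical, so nothing is actually held out.
print(set(train_order) == set(eval_order))
```

Determinism makes eval numbers comparable across runs, but because both orders cover the same instances, the eval loss measures fit on training data rather than generalization.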