Source code for maxtext.input_pipeline.olmo_grain_data_processing

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MaxText trainer adapter for the OLMo numpy fixed-seq-length pipeline.

The trainer expects ``dataset_type`` to map to two factory functions
``(make_<type>_train_iterator, make_<type>_eval_iterator)`` that take
``(config, mesh, process_indices)`` and return a
:class:`MultiHostDataLoadIterator`.

This module provides those for ``dataset_type=olmo_grain``. The hard work
lives in :mod:`maxtext.input_pipeline.olmo_data_grain` (data source +
sampler + transforms); here we just wire it to MaxText's config + the
multihost dataloading wrapper.

Notes
-----

* **Sequence length match**: ``config.max_target_length`` must match the
  ``sequence_length`` recorded in the index JSON. Mismatches raise at load
  time.
* **Path remap**: AI2's index typically holds ``gs://`` URIs. For training,
  we read via a GCSFUSE mount on each TPU host. The
  ``olmo_path_remap_from`` / ``olmo_path_remap_to`` config pair rewrites
  the prefix at runtime.
* **Sharding**: each data-loading host is assigned a non-overlapping shard
  of the global instance space via ``OlmoIndexSampler``. We use
  ``process_indices.index(jax.process_index())`` as the local shard index
  (matches the pattern in :mod:`grain_data_processing`).
"""

from __future__ import annotations

from typing import List

import jax
from etils import epath

from maxtext.input_pipeline import multihost_dataloading
from maxtext.input_pipeline.olmo_data import load_index
from maxtext.input_pipeline.olmo_data_grain import make_olmo_grain_data_loader
from maxtext.utils import max_logging


def _build_path_remap(config) -> dict:
  src = getattr(config, "olmo_path_remap_from", "") or ""
  dst = getattr(config, "olmo_path_remap_to", "") or ""
  if src and dst:
    return {src: dst}
  if src or dst:
    raise ValueError("olmo_path_remap_from and olmo_path_remap_to must both be set or both empty.")
  return {}


def _detect_resumed_step(config) -> int:
  """Return the step number of the latest checkpoint, or 0 for a fresh run.

  Used so the Grain DataLoader can resume reading at the same offset where
  the model checkpoint was saved (``initial_step = step * batch_size``).
  Uses :class:`etils.epath.Path` so the lookup works against both local
  paths and GCS (``gs://...``) — checkpoints commonly land straight in
  GCS, where ``os.path.isdir`` would silently return False.
  """
  if not getattr(config, "enable_checkpointing", False):
    return 0
  ckpt_dir = getattr(config, "checkpoint_dir", "") or ""
  if not ckpt_dir:
    return 0
  path = epath.Path(ckpt_dir)
  if not path.exists() or not path.is_dir():
    return 0
  steps = [int(p.name) for p in path.iterdir() if p.name.isdigit()]
  return max(steps) if steps else 0


def _make_loader_for_host(
    config,
    *,
    process_indices: List[int],
    seed: int,
):
  """Construct an OLMo grain DataLoader for the current data-loading host."""
  index = load_index(config.olmo_index_path)
  if index.sequence_length != config.max_target_length:
    raise ValueError(
        f"OLMo index sequence_length={index.sequence_length} but "
        f"config.max_target_length={config.max_target_length}. Either rebuild "
        f"the index with the matching seq length or update the config."
    )

  this_proc = jax.process_index()
  shard_index = process_indices.index(this_proc)
  shard_count = len(process_indices)

  per_host_batch = config.global_batch_size_to_load // shard_count
  if per_host_batch * shard_count != config.global_batch_size_to_load:
    raise ValueError(
        f"global_batch_size_to_load={config.global_batch_size_to_load} is not " f"divisible by shard_count={shard_count}"
    )

  # Resume = step counter from the latest checkpoint (if any) × per-host
  # batch. Our sampler is stateless, so this single integer is enough to
  # rejoin the stream — no Grain iterator-state serialization needed.
  resumed_step = _detect_resumed_step(config)
  initial_step = resumed_step * per_host_batch

  max_logging.log(
      f"OLMo grain loader: index={config.olmo_index_path} "
      f"total_instances={index.total_instances:,} "
      f"shard={shard_index}/{shard_count} per_host_batch={per_host_batch} "
      f"seq={index.sequence_length} resumed_step={resumed_step} "
      f"initial_step={initial_step}"
  )

  # Worker count and per-worker buffer reuse the standard grain flags. The
  # ``-1`` value of ``grain_worker_count`` is the auto-tuning sentinel for
  # the standard pipeline; we don't auto-tune yet, so treat it as 0
  # (in-process) for safety.
  worker_count = max(int(getattr(config, "grain_worker_count", 0) or 0), 0)
  worker_buffer = int(getattr(config, "grain_per_worker_buffer_size", 1) or 1)

  return make_olmo_grain_data_loader(
      index,
      seed=seed,
      batch_size=per_host_batch,
      shard_index=shard_index,
      shard_count=shard_count,
      apply_ngram_filter=getattr(config, "olmo_apply_ngram_filter", True),
      shift_to_inputs_targets=True,
      path_remap=_build_path_remap(config),
      grain_worker_count=worker_count,
      grain_worker_buffer_size=worker_buffer,
      initial_step=initial_step,
  )


[docs] def make_olmo_grain_train_iterator(config, global_mesh, process_indices): """Train iterator for ``dataset_type=olmo_grain``.""" if not getattr(config, "olmo_index_path", ""): raise ValueError( "When dataset_type=olmo_grain, please set config.olmo_index_path to " "the JSON produced by tools/data_generation/build_olmo_npy_index.py." ) loader = _make_loader_for_host( config, process_indices=process_indices, seed=int(getattr(config, "data_shuffle_seed", 0)), ) return multihost_dataloading.MultiHostDataLoadIterator( loader, global_mesh, config.generate_padding_batch_train, expansion_loading_factor_for_grain=config.expansion_factor_real_data, )
[docs] def make_olmo_grain_eval_iterator(config, global_mesh, process_indices): """Eval iterator for ``dataset_type=olmo_grain``. Currently reuses the train data with a different seed: the OLMo mix is a pretraining corpus with no canonical eval partition, so eval here means "deterministic held-out shuffle" rather than "held-out documents". For a real eval split, point a future ``config.eval_olmo_index_path`` at a separate index built over different files; the rest of this function is unchanged. """ if not getattr(config, "olmo_index_path", ""): raise ValueError("When dataset_type=olmo_grain, please set config.olmo_index_path.") loader = _make_loader_for_host( config, process_indices=process_indices, # Distinct seed so eval doesn't overlap train batch order. seed=int(getattr(config, "data_shuffle_seed", 0)) ^ 0x1F1F1F1F, ) return multihost_dataloading.MultiHostDataLoadIterator( loader, global_mesh, config.generate_padding_batch_eval, expansion_loading_factor_for_grain=config.expansion_factor_real_data, )