# Source code for maxtext.input_pipeline.input_pipeline_interface

# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline"""
import functools

import jax
from jax.sharding import PartitionSpec as P

from maxtext.configs import pyconfig
from maxtext.input_pipeline.grain_data_processing import make_grain_train_iterator
from maxtext.input_pipeline.grain_data_processing import make_grain_eval_iterator
from maxtext.input_pipeline.hf_data_processing import make_hf_train_iterator
from maxtext.input_pipeline.hf_data_processing import make_hf_eval_iterator
from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_train_iterator
from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_eval_iterator
from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator
from maxtext.input_pipeline.tfds_data_processing import make_tfds_eval_iterator
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_eval_iterator
from maxtext.input_pipeline.synthetic_data_processing import SyntheticDataIterator
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
from maxtext.utils import max_logging
from maxtext.utils.sharding import remove_size_one_mesh_axis


def get_process_loading_real_data(
    data_sharding, global_batch_size_to_load, global_batch_size_to_train_on, max_target_length, mesh
):
  """Return the indices of the processes that should load real data.

  Per the original note, this is meant for runs where
  `expansion_factor_real_data != -1`.

  A `NamedSharding` is built over `mesh` from `data_sharding` (after passing it
  through `remove_size_one_mesh_axis`), and every device is mapped to the index
  it holds of an array shaped `(global_batch_size_to_load, max_target_length)`.
  A device's process is selected when the `stop` of its batch-dimension slice is
  falsy (presumably `None`, i.e. the batch dimension is not split for that
  device -- confirm against JAX docs) or is at most
  `global_batch_size_to_train_on`.

  Args:
    data_sharding: Sequence of partition entries, unpacked into a PartitionSpec.
    global_batch_size_to_load: Batch dimension of the array being sharded.
    global_batch_size_to_train_on: Cutoff applied to each device's batch slice.
    max_target_length: Second dimension of the array being sharded.
    mesh: Device mesh the sharding is defined over.

  Returns:
    A list of unique process indices; the order is whatever the underlying set
    yields.
  """
  pspec = remove_size_one_mesh_axis(P(*data_sharding), mesh)
  named_sharding = jax.sharding.NamedSharding(mesh, pspec)
  global_shape = (global_batch_size_to_load, max_target_length)
  index_by_device = named_sharding.devices_indices_map(global_shape)

  def _within_cutoff(batch_slice):
    # A falsy stop is accepted outright; otherwise the slice must end at or
    # before the number of rows actually trained on.
    upper = batch_slice.stop
    return not upper or upper <= global_batch_size_to_train_on

  selected = {device.process_index for device, index in index_by_device.items() if _within_cutoff(index[0])}
  return list(selected)
def create_process_specific_iterator(config: pyconfig.HyperParameters, mesh, process_indices, input_iterator):
  """Build the iterator that the current process should consume.

  Args:
    config: Run configuration, forwarded to whichever iterator is constructed.
    mesh: Device mesh, forwarded to whichever iterator is constructed.
    process_indices: Collection of process indices that read real data.
    input_iterator: Callable invoked as
      `input_iterator(config, mesh, process_indices)` to build the real iterator.

  Returns:
    The result of `input_iterator(...)` when `jax.process_index()` is contained
    in `process_indices`; otherwise a `PlaceHolderDataIterator(config, mesh)`.
  """
  # Processes outside the selected set never touch the real dataset.
  if jax.process_index() not in process_indices:
    return PlaceHolderDataIterator(config, mesh)
  return input_iterator(config, mesh, process_indices)
def create_data_iterator(config: pyconfig.HyperParameters, mesh):
  """Create train and eval data iterators given configs and mesh.

  Behavior by `config.dataset_type`:
    * "synthetic": returns `SyntheticDataIterator`s (the eval one only when
      `config.eval_interval > 0`, otherwise `None`).
    * A key of the dispatch table below ("tfds", "grain", "hf", "c4_mlperf",
      "olmo_grain"): builds per-process iterators via
      `create_process_specific_iterator`, using the process sets computed by
      `get_process_loading_real_data`.
    * Anything else: logs a warning and falls back to a synthetic train
      iterator with no eval iterator.

  Args:
    config: Run configuration; the attributes read are visible in the body.
    mesh: Device mesh forwarded to the iterator constructors.

  Returns:
    A `(train_iterator, eval_iterator)` tuple; `eval_iterator` is `None` when
    evaluation is disabled or the dataset type is unsupported.

  Raises:
    AssertionError: if "c4_mlperf" is selected without `config.packing`, or if
      `config.expansion_factor_real_data > 1` and the number of processes
      selected to load real data is not
      `jax.process_count() // config.expansion_factor_real_data`.
  """
  # Return synthetic dataset if selected
  if config.dataset_type == "synthetic":
    eval_iterator = SyntheticDataIterator(config, mesh) if config.eval_interval > 0 else None
    return SyntheticDataIterator(config, mesh), eval_iterator

  dataset_type_to_train_eval_iterator = {
      "tfds": (make_tfds_train_iterator, make_tfds_eval_iterator),
      "grain": (make_grain_train_iterator, make_grain_eval_iterator),
      "hf": (make_hf_train_iterator, make_hf_eval_iterator),
      "c4_mlperf": (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator),
      "olmo_grain": (make_olmo_grain_train_iterator, make_olmo_grain_eval_iterator),
  }

  # The dispatch table is the single source of truth for supported types, so
  # the membership test and the warning text cannot drift apart.
  if config.dataset_type not in dataset_type_to_train_eval_iterator:
    supported = ", ".join(f"'{name}'" for name in dataset_type_to_train_eval_iterator)
    max_logging.log(
        f"WARNING: '{config.dataset_type}' is not a supported dataset type. "
        f"Using synthetic data. Please choose from {supported}, or 'synthetic'."
    )
    return SyntheticDataIterator(config, mesh), None

  if config.dataset_type == "c4_mlperf":
    assert config.packing, "c4_mlperf dataloader only works with packing. For padded version, use tfds dataloader"
  train_iterator, eval_iterator = dataset_type_to_train_eval_iterator[config.dataset_type]

  # Generate output train iterator
  process_indices_train = get_process_loading_real_data(
      config.data_sharding,
      config.global_batch_size_to_load,
      config.global_batch_size_to_train_on,
      config.max_target_length,
      mesh,
  )
  if config.expansion_factor_real_data > 1:
    # Check the number of hosts loading real data before building the
    # iterator, mirroring the eval path below.
    expected_train = jax.process_count() // config.expansion_factor_real_data
    assert len(process_indices_train) == expected_train, (
        f"Expected {expected_train} processes loading real train data, got {len(process_indices_train)}."
    )
  output_train_iterator = create_process_specific_iterator(config, mesh, process_indices_train, train_iterator)

  # Generate output eval iterator
  output_eval_iterator = None
  if config.eval_interval > 0:
    process_indices_eval = get_process_loading_real_data(
        config.data_sharding,
        config.global_batch_size_to_load_eval,
        config.global_batch_size_to_eval_on,
        config.max_target_length,
        mesh,
    )
    if config.expansion_factor_real_data > 1:
      expected_eval = jax.process_count() // config.expansion_factor_real_data
      assert len(process_indices_eval) == expected_eval, (
          f"Expected {expected_eval} processes loading real eval data, got {len(process_indices_eval)}."
      )
    output_eval_iterator = create_process_specific_iterator(config, mesh, process_indices_eval, eval_iterator)

  return output_train_iterator, output_eval_iterator