# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input pipeline for a LM1B dataset."""
import warnings
import functools
import ml_collections
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
from maxtext.input_pipeline import multihost_dataloading
from maxtext.input_pipeline.packing import sequence_packing
from maxtext.input_pipeline import input_pipeline_utils
# Keep TensorFlow off the GPUs so their memory stays available to JAX. A
# NotFoundError from set_visible_devices is ignored; per the original note this
# is the case where TensorFlow is built without GPU support
# (NOTE(review): confirm which builds raise it).
try:
  tf.config.experimental.set_visible_devices([], "GPU")
except tf.errors.NotFoundError:
  pass

# tf.data sentinel that lets the runtime pick parallelism / buffer sizes; used by
# the map and prefetch calls throughout this module.
AUTOTUNE = tf.data.experimental.AUTOTUNE
[docs]
def get_datasets(
    dataset_name,
    data_split,
    shuffle_files,
    shuffle_seed,
    dataloading_host_index,
    dataloading_host_count,
    dataset_path=None,
):
  """Load this host's share of one split of a TFDS dataset.

  When the split has at least as many files (shards) as there are data-loading
  hosts, TFDS assigns each host a disjoint subset of files via an
  `InputContext`. Otherwise a warning is emitted, every host reads the whole
  split, and `Dataset.shard` keeps one element out of every
  `dataloading_host_count`.

  Args:
    dataset_name: TFDS builder name passed to `tfds.builder`.
    data_split: Name of the split to read.
    shuffle_files: Whether TFDS shuffles the order of the split's files.
    shuffle_seed: Seed for the file shuffle; only used when `shuffle_files`.
    dataloading_host_index: Position of this host among the loading hosts.
    dataloading_host_count: Total number of data-loading hosts.
    dataset_path: Optional TFDS `data_dir`; `None` uses the TFDS default.

  Returns:
    A `tf.data.Dataset` holding only this host's portion of the split.
  """
  ds_builder = tfds.builder(dataset_name, data_dir=dataset_path)
  read_config = tfds.ReadConfig(shuffle_seed=shuffle_seed) if shuffle_files else tfds.ReadConfig()
  num_shards = ds_builder.info.splits[data_split].num_shards

  shard_by_file = num_shards >= dataloading_host_count
  if shard_by_file:
    # Enough files for every host: let TFDS hand each host its own files.
    read_config.input_context = tf.distribute.InputContext(
        input_pipeline_id=dataloading_host_index,
        num_input_pipelines=dataloading_host_count,
    )
  else:
    warnings.warn(
        f"WARNING: Inefficient dataloading. Your {dataset_name} contains {num_shards} "
        f"shards, smaller than {dataloading_host_count=}. This is known to lead to inefficient dataloading. "
        "See https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md"
        "#multihost-dataloading-best-practice"
    )

  ds = ds_builder.as_dataset(split=data_split, read_config=read_config, shuffle_files=shuffle_files)
  if not shard_by_file:
    # Too few files to split across hosts: every host reads all files and keeps
    # only the elements whose position maps to its index.
    ds = ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
  return ds
[docs]
def preprocessing_pipeline(
    dataset,
    tokenizer_path,
    tokenizer_type: str,
    global_batch_size: int,
    max_target_length: int,
    data_column_names,
    shuffle: bool = False,
    data_shuffle_seed=0,
    tokenize: bool = True,
    add_bos: bool = True,
    add_eos: bool = True,
    num_epochs: None | int = 1,
    pack_examples: bool = True,
    shuffle_buffer_size: int = 1024,
    shift: bool = True,
    drop_remainder: bool = True,
    prefetch_size=tf.data.experimental.AUTOTUNE,
    use_dpo: bool = False,
    hf_access_token: str = "",
):
  """Turn a raw TFDS dataset into per-process batches.

  Stages, in order: column selection/normalization, optional tokenization,
  truncation, optional shuffle, repeat, optional teacher-forcing shift, then
  either greedy packing + batching or static-shape padded batching, and finally
  optional prefetch.

  Args:
    dataset: Source `tf.data.Dataset` whose `element_spec` is a mapping keyed
      by column name.
    tokenizer_path: Forwarded to `input_pipeline_utils.get_tokenizer`.
    tokenizer_type: Forwarded to `input_pipeline_utils.get_tokenizer`.
    global_batch_size: Batch size across all processes; each process batches
      `global_batch_size // jax.process_count()` examples.
    max_target_length: Sequence length. When > 0 examples are truncated; it is
      also the packing length and the padded length.
    data_column_names: Column name(s) to read. Exactly one column is required
      unless `use_dpo`; a bare string is treated as a single column.
    shuffle: Whether to shuffle examples with `shuffle_buffer_size`.
    data_shuffle_seed: Seed for the example shuffle.
    tokenize: Tokenize the selected columns; set False for pre-tokenized data.
    add_bos: Forwarded to `input_pipeline_utils.get_tokenizer`.
    add_eos: Forwarded to `input_pipeline_utils.get_tokenizer`.
    num_epochs: Passed to `Dataset.repeat`; `None` repeats indefinitely.
    pack_examples: Greedily pack several examples per row (not used for DPO).
    shuffle_buffer_size: Buffer size for the example shuffle.
    shift: Apply `shift_data_by_truncation` for teacher forcing (not for DPO).
    drop_remainder: Drop a final partial batch.
    prefetch_size: Buffer size for `Dataset.prefetch`; falsy disables prefetch.
    use_dpo: Keep every requested column and skip shifting and packing.
    hf_access_token: Forwarded to `input_pipeline_utils.get_tokenizer`.

  Returns:
    The batched (and optionally prefetched) `tf.data.Dataset`.

  Raises:
    ValueError: If a requested column is missing from `dataset`, or if
      `use_dpo` is False and not exactly one column was requested.
  """
  # Accept a bare column name as shorthand for a one-element sequence.
  if isinstance(data_column_names, str):
    data_column_names = (data_column_names,)

  missing = [c for c in data_column_names if c not in dataset.element_spec]
  if missing:
    raise ValueError(
        f"Column {missing} not found in dataset. Available columns: {sorted(dataset.element_spec.keys())}. "
        "Please set train_data_columns or eval_data_columns accordingly."
    )

  if use_dpo:
    # Keep exactly the requested columns, untouched.
    dataset = dataset.map(lambda x: {col: x[col] for col in data_column_names}, num_parallel_calls=AUTOTUNE)
  else:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(data_column_names) != 1:
      raise ValueError(f"Expected exactly one data column when use_dpo is False, got {list(data_column_names)}.")
    text_column = data_column_names[0]
    dataset = dataset.map(
        lambda x: input_pipeline_utils.normalize_features(x, text_column), num_parallel_calls=AUTOTUNE
    )
    # Later stages address the normalized features by these keys
    # (NOTE(review): assumes normalize_features emits "inputs"/"targets").
    data_column_names = ("inputs", "targets")

  tokenizer_model = input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token)
  # Pad with the tokenizer's pad id, falling back to its unk id, then to -1.
  if tokenizer_model.pad_id is not None:
    pad_id = tokenizer_model.pad_id
  elif tokenizer_model.unk_id is not None:
    pad_id = tokenizer_model.unk_id
  else:
    pad_id = -1

  if tokenize:
    dataset = dataset.map(
        lambda x: input_pipeline_utils.TokenizeOp(
            tokenizer_model=tokenizer_model, features=x, data_keys=data_column_names
        ),
        num_parallel_calls=AUTOTUNE,
    )

  apply_shift = shift and not use_dpo
  if max_target_length > 0:
    # The shift below truncates one token from inputs and targets, so keep one
    # extra token only when that shift actually runs. Otherwise the extra token
    # would survive and exceed the static length used for packing/padding.
    extra_tokens = 1 if apply_shift else 0
    dataset = dataset.map(
        lambda x: input_pipeline_utils.truncate_to_max_allowable_length(x, max_target_length + extra_tokens),
        num_parallel_calls=AUTOTUNE,
    )

  # Shuffle and repeat.
  if shuffle:
    dataset = dataset.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
  dataset = dataset.repeat(num_epochs)

  # Shift inputs for teacher-forced training.
  if apply_shift:
    dataset = dataset.map(
        input_pipeline_utils.shift_data_by_truncation, num_parallel_calls=AUTOTUNE, deterministic=True
    )

  per_process_batch_size = global_batch_size // jax.process_count()
  if pack_examples and not use_dpo:
    # Greedy sequence packing, then batching.
    dataset = sequence_packing.pack_dataset(dataset, max_target_length, pad_id)
    dataset = dataset.batch(per_process_batch_size, drop_remainder=drop_remainder)
  else:
    # Simple (static-shape) padded batching; segmentation/position features are
    # added afterwards since no packing step produced them.
    dataset = dataset.padded_batch(
        per_process_batch_size,
        padded_shapes={k: max_target_length for k in data_column_names},
        padding_values={k: pad_id for k in data_column_names},
        drop_remainder=drop_remainder,
    )
    dataset = dataset.map(
        lambda x: input_pipeline_utils.add_segmentation_and_position(x, data_column_names, padding_token=pad_id),
        num_parallel_calls=AUTOTUNE,
        deterministic=True,
    )

  if prefetch_size:
    dataset = dataset.prefetch(prefetch_size)
  return dataset
[docs]
def make_tfds_train_iterator(
    config: ml_collections.ConfigDict,
    global_mesh,
    process_indices_train,
):
  """Create the training-data iterator for a TFDS dataset.

  The mode is chosen by `config.colocated_python_data_input`:
    * set: the dataset-loading and preprocessing callables, with every
      config-derived argument bound, are handed to
      `multihost_dataloading.RemoteIteratorWrapper` along with the global
      batch shape `(global_batch_size_to_load, max_target_length)`.
    * unset: this process loads its share of `config.train_split` (its
      position in `process_indices_train` selects the share), preprocesses it,
      and wraps it in `multihost_dataloading.MultiHostDataLoadIterator`.

  Args:
    config: Run configuration providing the dataset, tokenizer and batching
      fields read below.
    global_mesh: Device mesh; `global_mesh.size` must divide
      `config.global_batch_size_to_load`.
    process_indices_train: Sequence of process indices that load training data.

  Returns:
    The iterator built by `multihost_dataloading` for the selected mode.
  """
  assert (
      config.global_batch_size_to_load % global_mesh.size == 0
  ), "Batch size should be divisible by number of global devices."

  dataset_kwargs = dict(
      dataset_name=config.dataset_name,
      dataset_path=config.dataset_path,
      data_split=config.train_split,
      shuffle_files=config.enable_data_shuffling,
      shuffle_seed=config.data_shuffle_seed,
  )
  pipeline_kwargs = dict(
      tokenizer_path=config.tokenizer_path,
      tokenizer_type=config.tokenizer_type,
      global_batch_size=config.global_batch_size_to_load,
      max_target_length=config.max_target_length,
      data_column_names=config.train_data_columns,
      shuffle=config.enable_data_shuffling,
      data_shuffle_seed=config.data_shuffle_seed,
      tokenize=config.tokenize_train_data,
      add_bos=config.add_bos,
      add_eos=config.add_eos,
      num_epochs=config.num_epoch,
      pack_examples=config.packing,
      use_dpo=config.use_dpo,
      hf_access_token=config.hf_access_token,
  )

  if config.colocated_python_data_input:
    # Host-index arguments are left unbound here; the remote side presumably
    # supplies them (NOTE(review): confirm in multihost_dataloading).
    batch_shape = (config.global_batch_size_to_load, config.max_target_length)
    return multihost_dataloading.RemoteIteratorWrapper(
        functools.partial(get_datasets, **dataset_kwargs),
        functools.partial(preprocessing_pipeline, **pipeline_kwargs),
        global_mesh,
        batch_shape,
        checkpoint_path=config.checkpoint_dir,
    )

  host_dataset = get_datasets(
      dataloading_host_index=process_indices_train.index(jax.process_index()),
      dataloading_host_count=len(process_indices_train),
      **dataset_kwargs,
  )
  host_batches = preprocessing_pipeline(dataset=host_dataset, **pipeline_kwargs)
  return multihost_dataloading.MultiHostDataLoadIterator(
      host_batches, global_mesh, config.generate_padding_batch_train
  )
[docs]
def make_tfds_eval_iterator(
    config: ml_collections.ConfigDict,
    global_mesh,
    process_indices_eval,
):
  """Create the evaluation-data iterator for a TFDS dataset.

  Reads `config.eval_dataset_name` / `config.eval_split` from the same TFDS
  `data_dir` as training (`config.dataset_path`). It never shuffles and relies
  on `preprocessing_pipeline`'s default of a single epoch. Batches are sized
  from `config.global_batch_size_to_load_eval`.

  The mode is chosen by `config.colocated_python_data_input`:
    * set: bound loading/preprocessing callables are passed to
      `multihost_dataloading.RemoteIteratorWrapper` with the mesh and the global
      batch shape `(global_batch_size_to_load_eval, max_target_length)`.
    * unset: this process loads its share of the split (selected by its
      position in `process_indices_eval`) and wraps it in
      `multihost_dataloading.MultiHostDataLoadIterator`.

  Args:
    config: Run configuration providing the eval dataset, tokenizer and
      batching fields read below.
    global_mesh: Device mesh; `global_mesh.size` must divide
      `config.global_batch_size_to_load_eval`.
    process_indices_eval: Sequence of process indices that load eval data.

  Returns:
    The iterator built by `multihost_dataloading` for the selected mode.
  """
  assert (
      config.global_batch_size_to_load_eval % global_mesh.size == 0
  ), "Batch size should be divisible by number of global devices."

  dataset_kwargs = {
      "dataset_name": config.eval_dataset_name,
      # Use the same TFDS data_dir as the training data.
      "dataset_path": config.dataset_path,
      "data_split": config.eval_split,
      "shuffle_files": False,
      "shuffle_seed": config.data_shuffle_seed,
  }
  pipeline_kwargs = {
      "tokenizer_path": config.tokenizer_path,
      "tokenizer_type": config.tokenizer_type,
      "global_batch_size": config.global_batch_size_to_load_eval,
      "max_target_length": config.max_target_length,
      "data_column_names": config.eval_data_columns,
      "shuffle": False,
      "data_shuffle_seed": config.data_shuffle_seed,
      "tokenize": config.tokenize_eval_data,
      "add_bos": config.add_bos,
      "add_eos": config.add_eos,
      "pack_examples": config.packing,
      "use_dpo": config.use_dpo,
      "hf_access_token": config.hf_access_token,
  }

  if not config.colocated_python_data_input:
    eval_ds = get_datasets(
        dataloading_host_index=process_indices_eval.index(jax.process_index()),
        dataloading_host_count=len(process_indices_eval),
        **dataset_kwargs,
    )
    eval_dataloader = preprocessing_pipeline(dataset=eval_ds, **pipeline_kwargs)
    return multihost_dataloading.MultiHostDataLoadIterator(
        eval_dataloader, global_mesh, config.generate_padding_batch_eval
    )

  get_ds_fn = functools.partial(get_datasets, **dataset_kwargs)
  preprocessing_fn = functools.partial(preprocessing_pipeline, **pipeline_kwargs)
  # Same positional layout as the training path: (…, mesh, global batch shape).
  global_shape = (config.global_batch_size_to_load_eval, config.max_target_length)
  return multihost_dataloading.RemoteIteratorWrapper(
      get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path=config.checkpoint_dir
  )