Source code for maxtext.input_pipeline.tfds_data_processing

# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline for TFDS datasets such as LM1B."""

import warnings
import functools

import ml_collections

import tensorflow as tf
import tensorflow_datasets as tfds

import jax

from maxtext.input_pipeline import multihost_dataloading
from maxtext.input_pipeline.packing import sequence_packing
from maxtext.input_pipeline import input_pipeline_utils

# Module-level shorthand for tf.data's autotune setting.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Hide all GPUs from TensorFlow (empty visible-device list) so that GPU memory
# is reserved for JAX only. This matters only if tensorflow is built with GPU
# support; the NotFoundError is deliberately ignored
# (presumably raised when no GPU device type is available -- TODO confirm).
try:
  tf.config.experimental.set_visible_devices([], "GPU")
except tf.errors.NotFoundError:
  pass


def get_datasets(
    dataset_name,
    data_split,
    shuffle_files,
    shuffle_seed,
    dataloading_host_index,
    dataloading_host_count,
    dataset_path=None,
):
  """Load this host's portion of a TFDS dataset split.

  If the split has at least `dataloading_host_count` shards, the host's slice is
  selected through a `tf.distribute.InputContext` on the TFDS read config.
  Otherwise a warning is emitted, the dataset is built without an input context
  and is then split with `Dataset.shard`.

  Args:
    dataset_name: Name passed to `tfds.builder`.
    data_split: Split name to load (also used to look up the shard count).
    shuffle_files: Whether TFDS should shuffle input files; when true,
      `shuffle_seed` is set on the read config.
    shuffle_seed: Seed used for file shuffling (ignored if `shuffle_files` is false).
    dataloading_host_index: Index of this host among the data-loading hosts.
    dataloading_host_count: Total number of data-loading hosts.
    dataset_path: Optional `data_dir` passed to `tfds.builder`.

  Returns:
    The dataset produced by `ds_builder.as_dataset`, restricted to this host.
  """
  ds_builder = tfds.builder(dataset_name, data_dir=dataset_path)
  read_config = tfds.ReadConfig(shuffle_seed=shuffle_seed) if shuffle_files else tfds.ReadConfig()

  num_shards = ds_builder.info.splits[data_split].num_shards
  if num_shards >= dataloading_host_count:
    # Enough shards for every host: let TFDS assign shards per input pipeline.
    read_config.input_context = tf.distribute.InputContext(
        input_pipeline_id=dataloading_host_index,
        num_input_pipelines=dataloading_host_count,
    )
    return ds_builder.as_dataset(split=data_split, read_config=read_config, shuffle_files=shuffle_files)

  # Adjacent literals are concatenated verbatim, so each piece carries its own
  # trailing space (the URL and its "#fragment" are intentionally joined).
  warnings.warn(
      f"WARNING: Inefficient dataloading. Your {dataset_name} contains {num_shards} "
      f"shards, smaller than {dataloading_host_count=}. This is known to lead to inefficient dataloading. "
      "see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md"
      "#multihost-dataloading-best-practice"
  )
  ds = ds_builder.as_dataset(split=data_split, read_config=read_config, shuffle_files=shuffle_files)
  return ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
def preprocessing_pipeline(
    dataset,
    tokenizer_path,
    tokenizer_type: str,
    global_batch_size: int,
    max_target_length: int,
    data_column_names,
    shuffle: bool = False,
    data_shuffle_seed=0,
    tokenize: bool = True,
    add_bos: bool = True,
    add_eos: bool = True,
    num_epochs: None | int = 1,
    pack_examples: bool = True,
    shuffle_buffer_size: int = 1024,
    shift: bool = True,
    drop_remainder: bool = True,
    prefetch_size=tf.data.experimental.AUTOTUNE,
    use_dpo: bool = False,
    hf_access_token: str = "",
):
  """Turn a raw TFDS dataset into per-host batches.

  Steps, in order: column validation, feature selection/normalization,
  optional tokenization, optional truncation, optional shuffle, repeat,
  optional shift, packing or padded batching, optional prefetch.

  Args:
    dataset: Input dataset; its `element_spec` must contain every name in
      `data_column_names`.
    tokenizer_path: Forwarded to `input_pipeline_utils.get_tokenizer`.
    tokenizer_type: Forwarded to `input_pipeline_utils.get_tokenizer`.
    global_batch_size: Global batch size; each host batches
      `global_batch_size // jax.process_count()` examples.
    max_target_length: Sequence length used for truncation (only when > 0),
      packing and padding.
    data_column_names: Columns to read. Must have exactly one entry unless
      `use_dpo` is set.
    shuffle: Whether to shuffle examples.
    data_shuffle_seed: Seed for the example shuffle.
    tokenize: Whether to apply `input_pipeline_utils.TokenizeOp`.
    add_bos: Forwarded to the tokenizer factory.
    add_eos: Forwarded to the tokenizer factory.
    num_epochs: Passed to `dataset.repeat`.
    pack_examples: Whether to pack sequences (ignored when `use_dpo` is set).
    shuffle_buffer_size: Buffer size for the example shuffle.
    shift: Whether to apply `shift_data_by_truncation` (skipped for DPO).
    drop_remainder: Whether to drop the final partial batch.
    prefetch_size: Prefetch buffer size; a falsy value disables prefetching.
    use_dpo: Keep the given columns as-is instead of normalizing a single
      column into "inputs"/"targets".
    hf_access_token: Forwarded to the tokenizer factory.

  Returns:
    The transformed, batched dataset.

  Raises:
    ValueError: If a requested column is not present in the dataset.
  """
  missing = [c for c in data_column_names if c not in dataset.element_spec]
  if missing:
    raise ValueError(
        f"Column {missing} not found in dataset. Available columns: {sorted(dataset.element_spec.keys())}. "
        "Please set train_data_columns or eval_data_columns accordingly."
    )

  if use_dpo:
    # DPO keeps the requested columns unchanged.
    dataset = dataset.map(
        lambda features: {col: features[col] for col in data_column_names},
        num_parallel_calls=AUTOTUNE,
    )
  else:
    assert len(data_column_names) == 1
    dataset = dataset.map(
        lambda features: input_pipeline_utils.normalize_features(features, data_column_names[0]),
        num_parallel_calls=AUTOTUNE,
    )
    # From here on the non-DPO pipeline works on the normalized feature names.
    data_column_names = ("inputs", "targets")

  tokenizer_model = input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token)

  # Padding id preference: tokenizer pad id, then unk id, then -1.
  pad_id = tokenizer_model.pad_id
  if pad_id is None:
    pad_id = tokenizer_model.unk_id
  if pad_id is None:
    pad_id = -1

  if tokenize:
    dataset = dataset.map(
        lambda features: input_pipeline_utils.TokenizeOp(
            tokenizer_model=tokenizer_model, features=features, data_keys=data_column_names
        ),
        num_parallel_calls=AUTOTUNE,
    )

  if max_target_length > 0:
    # In pre-training we can take up to max_length + 1 tokens because the later
    # shift truncates one token from both inputs and targets.
    extra_tokens = 0 if use_dpo else 1
    dataset = dataset.map(
        lambda features: input_pipeline_utils.truncate_to_max_allowable_length(
            features, max_target_length + extra_tokens
        ),
        num_parallel_calls=AUTOTUNE,
    )

  # Shuffle and repeat.
  if shuffle:
    dataset = dataset.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
  dataset = dataset.repeat(num_epochs)

  # Shift inputs for teacher-forced training.
  if shift and not use_dpo:
    dataset = dataset.map(
        input_pipeline_utils.shift_data_by_truncation,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )

  per_host_batch_size = global_batch_size // jax.process_count()
  if pack_examples and not use_dpo:
    # Greedy sequence packing followed by plain batching.
    dataset = sequence_packing.pack_dataset(dataset, max_target_length, pad_id)
    dataset = dataset.batch(per_host_batch_size, drop_remainder=drop_remainder)
  else:
    # Simple (static-shape) padded batching.
    dataset = dataset.padded_batch(
        per_host_batch_size,
        padded_shapes={k: max_target_length for k in data_column_names},
        padding_values={k: pad_id for k in data_column_names},
        drop_remainder=drop_remainder,
    )
    dataset = dataset.map(
        lambda features: input_pipeline_utils.add_segmentation_and_position(
            features, data_column_names, padding_token=pad_id
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )

  if prefetch_size:
    dataset = dataset.prefetch(prefetch_size)

  return dataset
def make_tfds_train_iterator(
    config: ml_collections.ConfigDict,
    global_mesh,
    process_indices_train,
):
  """Load the training split, preprocess it and return a data iterator.

  Args:
    config: Run configuration providing the dataset, tokenizer, batching and
      data-loading options read below.
    global_mesh: Device mesh; `global_mesh.size` must divide
      `config.global_batch_size_to_load`.
    process_indices_train: Process indices that load training data. In the
      non-colocated path it must contain `jax.process_index()`.

  Returns:
    A `MultiHostDataLoadIterator` over the preprocessed dataset, or a
    `RemoteIteratorWrapper` when `config.colocated_python_data_input` is set.
  """
  assert (
      config.global_batch_size_to_load % global_mesh.size == 0
  ), "Batch size should be divisible by number of global devices."

  dataset_args = {
      "dataset_name": config.dataset_name,
      "dataset_path": config.dataset_path,
      "data_split": config.train_split,
      "shuffle_files": config.enable_data_shuffling,
      "shuffle_seed": config.data_shuffle_seed,
  }
  # Shared by both data-loading modes so the two paths cannot drift apart.
  pipeline_args = {
      "tokenizer_path": config.tokenizer_path,
      "tokenizer_type": config.tokenizer_type,
      "global_batch_size": config.global_batch_size_to_load,
      "max_target_length": config.max_target_length,
      "data_column_names": config.train_data_columns,
      "shuffle": config.enable_data_shuffling,
      "data_shuffle_seed": config.data_shuffle_seed,
      "tokenize": config.tokenize_train_data,
      "add_bos": config.add_bos,
      "add_eos": config.add_eos,
      "num_epochs": config.num_epoch,
      "pack_examples": config.packing,
      "use_dpo": config.use_dpo,
      "hf_access_token": config.hf_access_token,
  }

  if config.colocated_python_data_input:
    # Defer loading and preprocessing: hand factories to the remote wrapper.
    get_ds_fn = functools.partial(get_datasets, **dataset_args)
    preprocessing_fn = functools.partial(preprocessing_pipeline, **pipeline_args)
    global_shape = (config.global_batch_size_to_load, config.max_target_length)
    return multihost_dataloading.RemoteIteratorWrapper(
        get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path=config.checkpoint_dir
    )

  host_position = process_indices_train.index(jax.process_index())
  train_ds = get_datasets(
      dataloading_host_index=host_position,
      dataloading_host_count=len(process_indices_train),
      **dataset_args,
  )
  train_dataloader = preprocessing_pipeline(dataset=train_ds, **pipeline_args)
  return multihost_dataloading.MultiHostDataLoadIterator(
      train_dataloader, global_mesh, config.generate_padding_batch_train
  )
def make_tfds_eval_iterator(
    config: ml_collections.ConfigDict,
    global_mesh,
    process_indices_eval,
):
  """Load the eval split, preprocess it and return a data iterator.

  Evaluation never shuffles files or examples and leaves `num_epochs` at the
  `preprocessing_pipeline` default.

  Args:
    config: Run configuration providing the eval dataset, tokenizer, batching
      and data-loading options read below.
    global_mesh: Device mesh; `global_mesh.size` must divide
      `config.global_batch_size_to_load_eval`.
    process_indices_eval: Process indices that load eval data. In the
      non-colocated path it must contain `jax.process_index()`.

  Returns:
    A `MultiHostDataLoadIterator` over the preprocessed dataset, or a
    `RemoteIteratorWrapper` when `config.colocated_python_data_input` is set.
  """
  assert (
      config.global_batch_size_to_load_eval % global_mesh.size == 0
  ), "Batch size should be divisible by number of global devices."

  # NOTE(review): unlike the training iterator, `dataset_path` is not forwarded
  # here, so `get_datasets` uses its default (`None`) -- confirm this is intended.
  get_datasets_kwargs = {
      "dataset_name": config.eval_dataset_name,
      "data_split": config.eval_split,
      "shuffle_files": False,
      "shuffle_seed": config.data_shuffle_seed,
  }
  # Shared by both data-loading modes so the two paths cannot drift apart.
  preprocessing_kwargs = {
      "tokenizer_path": config.tokenizer_path,
      "tokenizer_type": config.tokenizer_type,
      "global_batch_size": config.global_batch_size_to_load_eval,
      "max_target_length": config.max_target_length,
      "data_column_names": config.eval_data_columns,
      "shuffle": False,
      "data_shuffle_seed": config.data_shuffle_seed,
      "tokenize": config.tokenize_eval_data,
      "add_bos": config.add_bos,
      "add_eos": config.add_eos,
      "pack_examples": config.packing,
      "use_dpo": config.use_dpo,
      "hf_access_token": config.hf_access_token,
  }

  if not config.colocated_python_data_input:
    eval_ds = get_datasets(
        dataloading_host_index=process_indices_eval.index(jax.process_index()),
        dataloading_host_count=len(process_indices_eval),
        **get_datasets_kwargs,
    )
    eval_dataloader = preprocessing_pipeline(dataset=eval_ds, **preprocessing_kwargs)
    return multihost_dataloading.MultiHostDataLoadIterator(
        eval_dataloader, global_mesh, config.generate_padding_batch_eval
    )

  get_ds_fn = functools.partial(get_datasets, **get_datasets_kwargs)
  preprocessing_fn = functools.partial(preprocessing_pipeline, **preprocessing_kwargs)
  # Pass (global_mesh, global_shape) in the same positions as
  # make_tfds_train_iterator does; previously `config` and `global_mesh` were
  # passed in those slots.
  # NOTE(review): RemoteIteratorWrapper's signature is not visible from this
  # file -- confirm against multihost_dataloading.
  global_shape = (config.global_batch_size_to_load_eval, config.max_target_length)
  return multihost_dataloading.RemoteIteratorWrapper(
      get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path=config.checkpoint_dir
  )