maxtext.experimental.rl.grpo_input_pipeline module
Input pipeline for GRPO training using Hugging Face datasets.
This module provides functions to create a data loading and preprocessing pipeline for Group Relative Policy Optimization (GRPO). It leverages the datasets library to stream data from Hugging Face and grain for efficient processing and batching. The pipeline tokenizes, pads/trims, and batches text data to be used as prompts for the GRPO generation and training loop.
- class maxtext.experimental.rl.grpo_input_pipeline.SingleHostDataLoader(dataloader, global_mesh)
Bases: object
A data loader for a single host that wraps a grain.DataLoader.
This class provides a standard Python iterator interface over a grain.DataLoader. It is designed to be used on a single host and ensures that the iterator can be reset.
- Parameters:
dataloader (grain.python.DataLoader)
global_mesh (Mesh)
- global_mesh
The JAX device mesh.
- dataloader
The underlying grain.DataLoader instance.
- local_iterator
The Python iterator created from the dataloader.
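The resettable-iterator idea described for this class can be illustrated with a minimal sketch. The class below, `ResettableLoader`, is a hypothetical stand-in written for illustration only; it is not the MaxText implementation, it ignores the device mesh, and it wraps a plain Python list instead of a grain.DataLoader.

```python
from typing import Any, Iterable, Iterator


class ResettableLoader:
  """Hypothetical stand-in for a loader wrapper; not the MaxText class."""

  def __init__(self, dataloader: Iterable[Any]):
    self.dataloader = dataloader
    # Mirrors the documented `local_iterator` attribute.
    self.local_iterator: Iterator[Any] = iter(dataloader)

  def reset(self) -> None:
    # Recreate the iterator so iteration starts over from the first batch.
    self.local_iterator = iter(self.dataloader)

  def __iter__(self) -> "ResettableLoader":
    self.reset()
    return self

  def __next__(self) -> Any:
    return next(self.local_iterator)


# A list of dicts stands in for the batches a real data loader would yield.
loader = ResettableLoader([{"prompt": [1, 2]}, {"prompt": [3, 4]}])
first_pass = list(loader)
second_pass = list(loader)  # Works again because __iter__ resets the iterator.
```

Resetting in `__iter__` is one possible design choice; the actual class may expose reset differently.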
- maxtext.experimental.rl.grpo_input_pipeline.preprocessing_pipeline(dataloading_host_index, dataloading_host_count, global_mesh, dataset, data_column_names, tokenize, tokenizer_path, hf_access_token, global_batch_size, max_target_length, shuffle=False, data_shuffle_seed=0, add_bos=True, add_eos=True, num_threads=1, drop_remainder=False)
Creates a preprocessing pipeline for a Hugging Face dataset.
This function sets up a series of operations to tokenize, pad or trim to max_target_length, and batch data from a streaming Hugging Face dataset. It is designed to be used within a multi-host data loading setup.
- Parameters:
dataloading_host_index – The index of the current host in the data loading process.
dataloading_host_count – The total number of hosts involved in data loading.
global_mesh – The JAX device mesh.
dataset – The Hugging Face IterableDataset to preprocess.
data_column_names – The list of column names in the dataset to use.
tokenize – A boolean indicating whether to tokenize the data.
tokenizer_path – The path to the tokenizer model.
hf_access_token – The Hugging Face access token.
global_batch_size – The total batch size across all devices.
max_target_length – The maximum sequence length for padding or trimming.
shuffle (bool) – Whether to shuffle the dataset.
data_shuffle_seed – The seed for shuffling the data.
add_bos – Whether to add a beginning-of-sequence token.
add_eos – Whether to add an end-of-sequence token.
num_threads – The number of threads to use for data processing.
drop_remainder – Whether to drop the last batch if it’s smaller than the batch size.
- Returns:
An iterator that yields preprocessed data batches for the local host.
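The effect of max_target_length and drop_remainder can be sketched in plain Python. The helpers `pad_or_trim` and `batch_examples`, and the pad ID of 0, are assumptions made for illustration; the real pipeline performs these steps with grain operations, and its padding value may differ.

```python
from typing import Iterable, Iterator, List


def pad_or_trim(tokens: List[int], max_target_length: int, pad_id: int = 0) -> List[int]:
  """Trims sequences that are too long and right-pads those that are too short."""
  trimmed = tokens[:max_target_length]
  return trimmed + [pad_id] * (max_target_length - len(trimmed))


def batch_examples(
    examples: Iterable[List[int]], batch_size: int, drop_remainder: bool = False
) -> Iterator[List[List[int]]]:
  """Groups examples into batches, optionally dropping a short final batch."""
  current: List[List[int]] = []
  for example in examples:
    current.append(example)
    if len(current) == batch_size:
      yield current
      current = []
  # A partial batch is emitted only when drop_remainder is False.
  if current and not drop_remainder:
    yield current


# Hypothetical token IDs standing in for tokenized prompts.
prompts = [[5, 6, 7], [1, 2, 3, 4, 5, 6], [9], [8, 8], [3, 1, 4]]
fixed_length = [pad_or_trim(p, max_target_length=4) for p in prompts]
kept = list(batch_examples(fixed_length, batch_size=2))
dropped = list(batch_examples(fixed_length, batch_size=2, drop_remainder=True))
```

With five examples and a batch size of two, the final single-example batch is kept by default and discarded when drop_remainder is True.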
- maxtext.experimental.rl.grpo_input_pipeline.make_hf_train_iterator(config, global_mesh, process_indices_train)
Loads a Hugging Face dataset and creates a local preprocessed iterator.
This function loads a streaming dataset from the Hugging Face Hub, then applies the preprocessing_pipeline to create an iterator that yields batches of data suitable for training on the current host.
- Parameters:
config – The configuration object with dataset and model parameters.
global_mesh – The JAX device mesh.
process_indices_train – A list of process indices that are loading data.
- Returns:
A local data iterator for the training set.
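One way a list such as process_indices_train could translate into the dataloading_host_index and dataloading_host_count arguments of preprocessing_pipeline is sketched below. This mapping is an assumption, not confirmed by this page: `host_position` and `shard_for_host` are hypothetical helpers, and the strided split is only one possible way to divide a dataset among hosts.

```python
from typing import List, Tuple


def host_position(process_index: int, process_indices_train: List[int]) -> Tuple[int, int]:
  """Returns (host_index, host_count) for one process.

  Assumption: the host index is the position of the process in the list and
  the host count is the length of the list.
  """
  if process_index not in process_indices_train:
    raise ValueError(f"Process {process_index} is not a data-loading process.")
  return process_indices_train.index(process_index), len(process_indices_train)


def shard_for_host(examples: List[str], host_index: int, host_count: int) -> List[str]:
  # Strided split: host i takes examples i, i + host_count, i + 2 * host_count, ...
  return examples[host_index::host_count]


# Processes 0 and 4 load data; process 4 is therefore the second loading host.
index, count = host_position(4, [0, 4])
local_examples = shard_for_host(["p0", "p1", "p2", "p3", "p4"], index, count)
```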
- maxtext.experimental.rl.grpo_input_pipeline.create_data_iterator(config, mesh)
Creates a data iterator for GRPO training.
This function determines which processes should load data and then creates a process-specific data iterator for the training prompts. It currently does not support evaluation data.
- Parameters:
config – The configuration object containing data loading settings.
mesh – The JAX device mesh.
- Returns:
A data iterator that yields batches of training prompts.
- Raises:
ValueError – If evaluation is configured (eval_interval > 0), as the GRPO input pipeline does not support it.
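The documented failure mode can be illustrated with a small guard. `ExampleConfig` and `check_eval_not_configured` are hypothetical names used only for this sketch; the real configuration object has many more fields, and the exact error message is not specified here.

```python
from dataclasses import dataclass


@dataclass
class ExampleConfig:
  """Hypothetical stand-in for the configuration object."""
  eval_interval: int = -1


def check_eval_not_configured(config: ExampleConfig) -> None:
  # Mirrors the documented condition: evaluation is configured when eval_interval > 0.
  if config.eval_interval > 0:
    raise ValueError("The GRPO input pipeline does not support evaluation data.")


check_eval_not_configured(ExampleConfig(eval_interval=-1))  # Passes silently.

try:
  check_eval_not_configured(ExampleConfig(eval_interval=100))
  error_message = None
except ValueError as err:
  error_message = str(err)
```

A caller who wants training only would therefore leave eval_interval at a non-positive value before calling create_data_iterator.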