# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input pipeline using Huggingface datasets."""
from typing import Optional
import ml_collections
import jax
import transformers
import grain.python as grain
import numpy as np
from maxtext.input_pipeline import data_processing_utils
from maxtext.input_pipeline import input_pipeline_utils
from maxtext.input_pipeline import instruction_data_processing
from maxtext.input_pipeline import multihost_dataloading
from maxtext.utils import elastic_utils
def _get_pad_id(tokenizer):
  """Choose the token id used to pad sequences.

  The tokenizer's ``pad_token_id`` wins when it is set (including ``0``).
  Otherwise ``unk_token_id`` is used. If neither is set, ``-1`` is returned.

  Args:
    tokenizer: Object exposing ``pad_token_id`` and ``unk_token_id`` attributes
      (a Hugging Face tokenizer in this module).

  Returns:
    The padding token id.
  """
  pad_token_id = tokenizer.pad_token_id
  if pad_token_id is not None:
    return pad_token_id
  # Only consult the unknown-token id when no pad token is configured.
  unk_token_id = tokenizer.unk_token_id
  if unk_token_id is not None:
    return unk_token_id
  return -1
[docs]
def vision_sft_preprocessing_pipeline(
    dataset,
    config,
    dataloading_host_index,
    dataloading_host_count,
    global_mesh,
    text_columns,
    image_column,
    global_batch_size,
):
  """Build the input pipeline for multimodal (vision) SFT on a Hugging Face dataset.

  Steps, in order:
    1. Shuffle and/or repeat ``dataset`` according to
       ``config.enable_data_shuffling`` and ``config.num_epoch``.
    2. If ``image_column`` is a list or tuple, merge those columns into a single
       ``"images"`` column. Then keep only the text columns and the image column,
       renaming the image column to ``"images"``.
    3. Reformat the prompt and response text, preprocess images, tokenize both
       text columns (without truncation) and prepare the prompt for image fusion.
    4. Wrap the result in ``input_pipeline_utils.HFDataSource`` and run a grain
       pipeline: prompt masking -> pad/trim -> extract images and masks ->
       batch -> fold images into the batch -> shift inputs/targets.

  Args:
    dataset: Hugging Face dataset. The callers in this file pass
      ``load_dataset(..., streaming=True)``.
    config: MaxText config. It is read for shuffling/epoch settings, batch
      sizing, image settings, model name and tokenizer settings.
    dataloading_host_index: Index of this host among the data-loading hosts.
    dataloading_host_count: Number of data-loading hosts.
    global_mesh: Mesh passed to ``MultiHostDataLoadIterator``.
    text_columns: Exactly two column names, (query, response), as a list or
      tuple.
    image_column: A single image column name, or a list/tuple of image column
      names to merge.
    global_batch_size: Global batch size. It is divided across JAX processes
      unless elastic training supplies the local batch size.

  Returns:
    A ``multihost_dataloading.MultiHostDataLoadIterator`` over the batches.
  """
  assert len(text_columns) == 2, f"Need two text_columns for query and response, received {text_columns=}"

  # Tunix GA requires per-micro-batch slicing at the data level,
  # whereas Native GA processes the full batch and splits it internally.
  if config.elastic_enabled:
    local_batch_size = elastic_utils.get_local_batch_size(config)
  else:
    local_batch_size = global_batch_size // jax.process_count()
  if config.use_tunix_gradient_accumulation:
    batch_size = local_batch_size // config.gradient_accumulation_steps
  else:
    batch_size = local_batch_size

  # for multi-epoch with shuffle, shuffle each epoch with different seeds then concat
  import datasets  # pylint: disable=import-outside-toplevel

  if config.enable_data_shuffling and config.num_epoch > 1:
    epoch_datasets = [dataset.shuffle(seed=config.data_shuffle_seed + i) for i in range(config.num_epoch)]
    dataset = datasets.concatenate_datasets(epoch_datasets)
  elif config.enable_data_shuffling:
    dataset = dataset.shuffle(seed=config.data_shuffle_seed)
  elif config.num_epoch > 1:
    dataset = dataset.repeat(config.num_epoch)

  # If multiple image columns are provided (as a list or a tuple), merge them
  # into a single 'images' column. Normalizing to a list means a tuple behaves
  # exactly like a list. Previously a tuple skipped the merge and ended up
  # nested inside the `select_columns` argument.
  if isinstance(image_column, (list, tuple)):
    image_column = list(image_column)
    dataset = dataset.map(
        input_pipeline_utils.merge_image_columns,
        fn_kwargs={
            "image_columns": image_column,
            "max_num_images_per_example": config.max_num_images_per_example,
        },
        remove_columns=image_column,  # Drop the original image columns
    )
    image_column = "images"
  # Build the column list by unpacking so a tuple of text columns also works.
  # `tuple + list` would raise TypeError.
  dataset = dataset.select_columns([*text_columns, image_column])
  if image_column != "images":
    dataset = dataset.rename_column(image_column, "images")

  dataset = dataset.map(
      input_pipeline_utils.reformat_prompt,
      fn_kwargs={
          "column": text_columns[0],
          "image_placeholder": config.image_placeholder,
          "model_name": config.model_name,
      },
  )
  dataset = dataset.map(
      input_pipeline_utils.reformat_response,
      fn_kwargs={"column": text_columns[1], "model_name": config.model_name},
  )
  dataset = dataset.map(
      input_pipeline_utils.pre_process_image_sft,
      fn_kwargs={"image_column": "images", "model_name": config.model_name},
  )

  # The tokenizer itself adds no BOS/EOS here.
  # NOTE(review): presumably the prompt/response reformatting supplies any
  # needed special tokens; confirm against input_pipeline_utils.
  tokenizer = transformers.AutoTokenizer.from_pretrained(
      config.tokenizer_path,
      add_bos_token=False,
      add_eos_token=False,
      legacy=False,
      token=config.hf_access_token,
  )
  pad_id = _get_pad_id(tokenizer)
  # Tokenize query and response without truncation. `batch_size` here is the
  # chunk size of the batched HF `map`, not the training batch size.
  dataset = dataset.map(
      input_pipeline_utils.tokenization,
      batched=True,
      batch_size=global_batch_size,
      fn_kwargs={
          "hf_tokenizer": tokenizer,
          "truncation": False,
          "max_length": config.max_target_length,
          "column_names": text_columns,
      },
  )
  dataset = dataset.map(
      input_pipeline_utils.prepare_text_for_image_fusion,
      fn_kwargs={"column_name": text_columns[0], "config": config},
  )
  dataset = input_pipeline_utils.HFDataSource(
      dataset=dataset,
      dataloading_host_index=dataloading_host_index,
      dataloading_host_count=dataloading_host_count,
      num_threads=1,
      max_target_length=config.max_target_length,
      data_column_names=text_columns,
  )

  # The order of these grain operations matters: masking and padding act on
  # single examples, while the image folding and shift act on batches.
  operations = []
  operations.append(
      input_pipeline_utils.SFTPromptMaskingVision(
          query_column=text_columns[0],
          response_column=text_columns[1],
          max_target_length=config.max_target_length,
          pad_id=pad_id,
      )
  )
  # TODO(aireenmei, hengtaoguo): support packing
  operations.append(
      input_pipeline_utils.PadOrTrimToMaxLength(
          config.max_target_length,
          pad_id,
          config=config,
          max_num_images_per_example=config.max_num_images_per_example,
      )
  )
  operations.append(input_pipeline_utils.ExtractImagesAndMasks())
  operations.append(grain.Batch(batch_size=batch_size, drop_remainder=True))
  operations.append(input_pipeline_utils.FoldImagesIntoBatch(model_name=config.model_name))
  operations.append(input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))

  # The sampler is a placeholder required by grain.DataLoader. Sharding across
  # hosts is also configured on HFDataSource above.
  dummy_index_sampler = grain.IndexSampler(
      num_records=len(dataset),
      num_epochs=1,
      shard_options=grain.ShardOptions(
          shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False
      ),
      shuffle=False,
      seed=0,
  )
  dataloader = grain.DataLoader(
      data_source=dataset,
      operations=operations,
      sampler=dummy_index_sampler,
      worker_count=1,  # only supports <=1 for now, more workers results in duplicated data
      worker_buffer_size=1,
      read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=batch_size * 4),
  )
  multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)
  # Return multi-host jax.Array prep iterator
  return multihost_gen
[docs]
def preprocessing_pipeline(
    dataloading_host_index,
    dataloading_host_count,
    global_mesh,
    dataset,
    config,
    data_column_names,
    tokenize,
    tokenizer_path,
    hf_access_token,
    global_batch_size,
    max_target_length,
    shuffle,
    data_shuffle_seed,
    chat_template_path="",
    add_bos=True,
    add_eos=True,
    packing=True,
    shift=True,
    num_threads=1,
    drop_remainder=True,
    generate_padding_batch=False,
    use_dpo=None,
    use_sft=None,
    use_tunix_gradient_accumulation=False,
    num_microbatches=1,
    sft_train_on_completion_only=True,
    grain_worker_count=1,  # only support 0 or 1
    max_segments_per_seq=None,
    num_epoch=1,
    chat_template: Optional[str] = None,
    formatting_func_path: Optional[str] = None,
    formatting_func_kwargs: Optional[dict] = None,
):
  """Build a multi-host grain iterator over a Hugging Face text dataset.

  The pipeline runs these stages:
    1. Shuffle and/or repeat the dataset per epoch.
    2. Keep only ``data_column_names``.
    3. For SFT, convert the data to conversational format and apply the chat
       template.
    4. Optionally tokenize.
    5. Wrap the result in ``input_pipeline_utils.HFDataSource``.
    6. Run a grain pipeline. It first applies SFT prompt masking, DPO
       list-to-array conversion, or feature normalization. It then either
       packs and batches, or pads and batches. Finally it optionally shifts
       inputs into targets.

  Args:
    dataloading_host_index: Index of this host among the data-loading hosts.
    dataloading_host_count: Number of data-loading hosts.
    global_mesh: Device mesh. ``global_batch_size`` must be divisible by
      ``global_mesh.size``.
    dataset: Hugging Face dataset to read from.
    config: MaxText config. Only ``config.elastic_enabled`` is read here
      directly; the whole config is passed to
      ``elastic_utils.get_local_batch_size``.
    data_column_names: Columns to keep from ``dataset``.
    tokenize: Whether to tokenize ``data_column_names`` with the HF tokenizer.
    tokenizer_path: Name or path passed to ``AutoTokenizer.from_pretrained``.
    hf_access_token: Hugging Face token used to load the tokenizer.
    global_batch_size: Batch size summed over all hosts.
    max_target_length: Sequence length used for truncation, padding and packing.
    shuffle: Whether to shuffle the dataset.
    data_shuffle_seed: Base shuffle seed. Epoch ``i`` uses
      ``data_shuffle_seed + i`` when ``num_epoch > 1``.
    chat_template_path: File to load a chat template from. It is used for SFT
      when ``chat_template`` is empty.
    add_bos: Whether the tokenizer adds BOS. It is forced to False for SFT.
    add_eos: Whether the tokenizer adds EOS. It is forced to False for SFT.
    packing: Pack sequences with ``grain.experimental.PackAndBatchOperation``.
      Packing is never used for DPO.
    shift: Append ``ShiftData`` (skipped for DPO).
    num_threads: Thread count for ``HFDataSource`` and grain ``ReadOptions``.
    drop_remainder: Passed to ``grain.Batch``. It applies to the non-packing
      path only.
    generate_padding_batch: Forwarded to ``MultiHostDataLoadIterator``.
    use_dpo: Enable the DPO path.
    use_sft: Enable the SFT path (chat template and prompt masking).
    use_tunix_gradient_accumulation: If True, the per-host batch is divided by
      ``num_microbatches``.
    num_microbatches: Divisor applied with Tunix gradient accumulation.
    sft_train_on_completion_only: Passed to ``SFTPromptMasking`` as
      ``completion_only``.
    grain_worker_count: grain ``worker_count``. Only 0 or 1 is supported.
    max_segments_per_seq: Passed as ``max_sequences_per_bin`` when packing.
      Values <= 0 are mapped to ``None``.
    num_epoch: Number of passes over the dataset.
    chat_template: Inline chat template. When non-empty it is used instead of
      ``chat_template_path``.
    formatting_func_path: Forwarded to ``convert_to_conversational_format``.
    formatting_func_kwargs: Forwarded to ``convert_to_conversational_format``.

  Returns:
    A ``multihost_dataloading.MultiHostDataLoadIterator``.
  """
  import datasets  # pylint: disable=import-outside-toplevel

  assert global_batch_size % global_mesh.size == 0, "Batch size should be divisible by number of global devices."
  # Tunix GA requires per-micro-batch slicing at the data level,
  # whereas Native GA processes the full batch and splits it internally.
  if config.elastic_enabled:
    local_batch_size = elastic_utils.get_local_batch_size(config)
  else:
    local_batch_size = global_batch_size // jax.process_count()
  # NOTE(review): integer division silently drops any remainder when
  # local_batch_size is not a multiple of num_microbatches.
  if use_tunix_gradient_accumulation:
    batch_size = local_batch_size // num_microbatches
  else:
    batch_size = local_batch_size
  # for multi-epoch with shuffle, shuffle each epoch with different seeds then concat
  if shuffle and num_epoch > 1:
    epoch_datasets = [dataset.shuffle(seed=data_shuffle_seed + i) for i in range(num_epoch)]
    dataset = datasets.concatenate_datasets(epoch_datasets)
  elif shuffle:
    dataset = dataset.shuffle(seed=data_shuffle_seed)
  elif num_epoch > 1:
    dataset = dataset.repeat(num_epoch)
  # For SFT the tokenizer is created without automatic BOS/EOS.
  # NOTE(review): presumably the chat template supplies them; confirm.
  tokenizer = transformers.AutoTokenizer.from_pretrained(
      tokenizer_path,
      add_bos_token=add_bos if not use_sft else False,
      add_eos_token=add_eos if not use_sft else False,
      legacy=False,
      token=hf_access_token,
  )
  dataset = dataset.select_columns(data_column_names)
  if use_sft:
    # An inline chat template takes precedence over the template file.
    if not chat_template:
      chat_template = instruction_data_processing.load_chat_template_from_file(chat_template_path)
    data_processing_utils.validate_and_configure_sft_columns(data_column_names, tokenizer, chat_template)
    # convert instruction dataset to conversational format
    dataset, data_column_names = instruction_data_processing.convert_to_conversational_format(
        dataset=dataset,
        data_columns=data_column_names,
        formatting_func_path=formatting_func_path,
        formatting_func_kwargs=formatting_func_kwargs,
    )
    assert input_pipeline_utils.is_conversational(
        dataset.features, data_column_names
    ), "Dataset is not in conversational format."
    # Several conversational columns are combined into one "messages" column.
    # Each entry is a {"content", "role"} record, so the chat template is
    # applied to a single column.
    if len(data_column_names) > 1:
      combined_column_name = "messages"
      dataset_features = datasets.Features(
          {combined_column_name: [{"content": datasets.Value(dtype="string"), "role": datasets.Value(dtype="string")}]}
      )
      dataset = dataset.map(
          input_pipeline_utils.combine_columns,
          fn_kwargs={"columns": data_column_names, "data_column": combined_column_name},
          remove_columns=data_column_names,
          features=dataset_features,
      )
    # Re-derive the column list from the (possibly combined) features; only the
    # first column is passed to the chat template below.
    data_column_names = list(dataset.features.keys())
    dataset = dataset.map(
        input_pipeline_utils.apply_chat_template,
        fn_kwargs={"tokenizer_model": tokenizer, "data_column_name": data_column_names[0]},
    )
  pad_id = _get_pad_id(tokenizer)
  if tokenize:
    # Truncation to max_target_length is applied only to non-SFT data.
    # NOTE(review): SFT examples are presumably trimmed later by SFTPromptMasking
    # (it receives max_target_length); confirm.
    dataset = dataset.map(
        input_pipeline_utils.tokenization,
        batched=True,
        fn_kwargs={
            "hf_tokenizer": tokenizer,
            "truncation": not use_sft,
            "max_length": max_target_length,
            "column_names": data_column_names,
        },
    )
  dataset = input_pipeline_utils.HFDataSource(
      dataset,
      dataloading_host_index,
      dataloading_host_count,
      num_threads,
      max_target_length,
      data_column_names,
  )
  # Order of grain operations matters. First come per-example transforms
  # (masking / array conversion / normalization), then batching (packed or
  # padded), then the optional shift on the batched arrays (axis=1).
  operations = []
  if use_sft:
    input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer)
    operations.append(
        input_pipeline_utils.SFTPromptMasking(
            text_column_name=data_column_names[0],
            completion_only=sft_train_on_completion_only,
            max_target_length=max_target_length,
            unk_id=pad_id,
        )
    )
    # From here on, downstream operations refer to the "inputs"/"targets" features.
    data_column_names = ("inputs", "targets")
  elif use_dpo:

    def lists2array(x):
      """Convert lists/tuples to array"""
      return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple)))

    operations.append(grain.MapOperation(lists2array))
  else:
    # The plain pre-training path expects exactly one (text) column.
    assert len(data_column_names) == 1
    operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0]))
    # From here on, downstream operations refer to the "inputs"/"targets" features.
    data_column_names = ("inputs", "targets")
  if packing and not use_dpo:
    length_struct = {col: max_target_length for col in data_column_names}
    max_segments = max_segments_per_seq
    # Non-positive values are mapped to None before being passed as
    # max_sequences_per_bin.
    if max_segments is not None and max_segments <= 0:
      max_segments = None
    operations.append(
        grain.experimental.PackAndBatchOperation(
            batch_size=batch_size,
            length_struct=length_struct,
            max_sequences_per_bin=max_segments,
        )
    )
    operations.append(input_pipeline_utils.ReformatPacking(data_column_names))
  else:
    # PackAndBatchOperation already batches; only this path needs grain.Batch.
    operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
    operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))
  if shift and not use_dpo:
    # NOTE(review): tokenizer.bos_token_id may be None for some tokenizers; it
    # is still included in ignored_ids here. Confirm ShiftData tolerates that.
    operations.append(input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1))
  # Since HuggingFace IterableDataset does not support access through index
  # Indexes generated by dummy_index_sampler is not used.
  # dummy_index_sampler is used as an input place holder for grain.Dataloader
  dummy_index_sampler = grain.IndexSampler(
      num_records=len(dataset),
      num_epochs=1,
      shard_options=grain.ShardOptions(
          shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False
      ),
      shuffle=False,
      seed=0,
  )
  dataloader = grain.DataLoader(
      data_source=dataset,
      operations=operations,
      sampler=dummy_index_sampler,
      worker_count=grain_worker_count,  # only supports <=1 for now, more workers results in duplicated data
      worker_buffer_size=1,
      read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128),
  )
  multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch)
  # Return multi-host jax.Array prep iterator
  return multihost_gen
[docs]
def make_hf_train_iterator(
    config: ml_collections.ConfigDict,
    global_mesh,
    process_indices_train,
):
  """Create the training-data iterator for a Hugging Face dataset.

  It streams ``config.train_split`` of the dataset described by
  ``config.hf_path`` / ``hf_name`` / ``hf_data_dir`` / ``hf_train_files``. The
  split is then preprocessed with one of two pipelines:
    * ``vision_sft_preprocessing_pipeline`` when both ``config.use_sft`` and
      ``config.use_multimodal`` are set;
    * ``preprocessing_pipeline`` otherwise, configured from the training-side
      config fields (shuffling, epochs, Tunix gradient accumulation, ...).

  Args:
    config: MaxText config.
    global_mesh: Device mesh for the multi-host iterator.
    process_indices_train: Indices of the JAX processes that load training
      data. It must contain the current ``jax.process_index()``, because its
      ``.index`` method is used to find this host's shard.

  Returns:
    The iterator returned by the selected preprocessing pipeline.
  """
  import datasets  # pylint: disable=import-outside-toplevel

  train_ds = datasets.load_dataset(
      config.hf_path,
      name=config.hf_name,
      data_dir=config.hf_data_dir,
      data_files=config.hf_train_files,
      split=config.train_split,
      streaming=True,
      token=config.hf_access_token,
  )

  # This host's position among the data-loading hosts, and the host count.
  host_index = process_indices_train.index(jax.process_index())
  host_count = len(process_indices_train)

  if config.use_sft and config.use_multimodal:
    return vision_sft_preprocessing_pipeline(
        dataset=train_ds,
        config=config,
        dataloading_host_index=host_index,
        dataloading_host_count=host_count,
        global_mesh=global_mesh,
        text_columns=config.train_data_columns,
        image_column=config.train_image_column,
        global_batch_size=config.global_batch_size_to_load,
    )

  return preprocessing_pipeline(
      # Data source and sharding.
      dataset=train_ds,
      config=config,
      global_mesh=global_mesh,
      dataloading_host_index=host_index,
      dataloading_host_count=host_count,
      data_column_names=config.train_data_columns,
      # Tokenization.
      tokenize=config.tokenize_train_data,
      tokenizer_path=config.tokenizer_path,
      hf_access_token=config.hf_access_token,
      add_bos=config.add_bos,
      add_eos=config.add_eos,
      # Batching, packing, ordering and epochs.
      global_batch_size=config.global_batch_size_to_load,
      max_target_length=config.max_target_length,
      packing=config.packing,
      max_segments_per_seq=config.max_segments_per_seq,
      generate_padding_batch=config.generate_padding_batch_train,
      shuffle=config.enable_data_shuffling,
      data_shuffle_seed=config.data_shuffle_seed,
      num_epoch=config.num_epoch,
      use_tunix_gradient_accumulation=config.use_tunix_gradient_accumulation,
      num_microbatches=config.gradient_accumulation_steps,
      # Training objective and chat formatting.
      use_dpo=config.use_dpo,
      use_sft=config.use_sft,
      sft_train_on_completion_only=config.sft_train_on_completion_only,
      chat_template_path=config.chat_template_path,
      chat_template=config.chat_template,
      formatting_func_path=config.formatting_func_path,
      formatting_func_kwargs=config.formatting_func_kwargs,
  )
[docs]
def make_hf_eval_iterator(
    config: ml_collections.ConfigDict,
    global_mesh,
    process_indices_eval,
):
  """Create the evaluation-data iterator for a Hugging Face dataset.

  It streams ``config.hf_eval_split`` of the dataset described by
  ``config.hf_path`` / ``hf_name`` / ``hf_data_dir`` / ``hf_eval_files``. The
  split is then preprocessed with ``vision_sft_preprocessing_pipeline`` (when
  both ``config.use_sft`` and ``config.use_multimodal`` are set) or with
  ``preprocessing_pipeline``.

  On the text path, evaluation data is never shuffled. The default of a single
  epoch is used, and Tunix gradient-accumulation slicing is left at its
  default (off), so the ``num_microbatches`` passed there has no effect.

  NOTE(review): the multimodal path reads shuffling, epoch and Tunix
  gradient-accumulation settings from ``config`` itself, so eval there follows
  the training settings. Confirm this asymmetry is intended.

  Args:
    config: MaxText config.
    global_mesh: Device mesh for the multi-host iterator.
    process_indices_eval: Indices of the JAX processes that load eval data. It
      must contain the current ``jax.process_index()``, because its ``.index``
      method is used to find this host's shard.

  Returns:
    The iterator returned by the selected preprocessing pipeline.
  """
  import datasets  # pylint: disable=import-outside-toplevel

  eval_ds = datasets.load_dataset(
      config.hf_path,
      name=config.hf_name,
      data_dir=config.hf_data_dir,
      data_files=config.hf_eval_files,
      split=config.hf_eval_split,
      streaming=True,
      token=config.hf_access_token,
  )

  # This host's position among the data-loading hosts, and the host count.
  host_index = process_indices_eval.index(jax.process_index())
  host_count = len(process_indices_eval)

  if config.use_sft and config.use_multimodal:
    return vision_sft_preprocessing_pipeline(
        dataset=eval_ds,
        config=config,
        dataloading_host_index=host_index,
        dataloading_host_count=host_count,
        global_mesh=global_mesh,
        text_columns=config.eval_data_columns,
        image_column=config.eval_image_column,
        global_batch_size=config.global_batch_size_to_load_eval,
    )

  return preprocessing_pipeline(
      # Data source and sharding.
      dataset=eval_ds,
      config=config,
      global_mesh=global_mesh,
      dataloading_host_index=host_index,
      dataloading_host_count=host_count,
      data_column_names=config.eval_data_columns,
      # Tokenization.
      tokenize=config.tokenize_eval_data,
      tokenizer_path=config.tokenizer_path,
      hf_access_token=config.hf_access_token,
      add_bos=config.add_bos,
      add_eos=config.add_eos,
      # Batching, packing and ordering (eval is never shuffled).
      global_batch_size=config.global_batch_size_to_load_eval,
      max_target_length=config.max_target_length,
      packing=config.packing,
      max_segments_per_seq=config.max_segments_per_seq,
      generate_padding_batch=config.generate_padding_batch_eval,
      shuffle=False,
      data_shuffle_seed=config.data_shuffle_seed,
      num_microbatches=config.gradient_accumulation_steps,
      # Training objective and chat formatting.
      use_dpo=config.use_dpo,
      use_sft=config.use_sft,
      sft_train_on_completion_only=config.sft_train_on_completion_only,
      chat_template_path=config.chat_template_path,
      chat_template=config.chat_template,
      formatting_func_path=config.formatting_func_path,
      formatting_func_kwargs=config.formatting_func_kwargs,
  )