maxtext.input_pipeline.input_pipeline_utils module
Operations used by Grain
- maxtext.input_pipeline.input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token=None)
- maxtext.input_pipeline.input_pipeline_utils.truncate_to_max_allowable_length(x, max_length)
- maxtext.input_pipeline.input_pipeline_utils.add_segmentation_and_position(x, data_columns, padding_token=0)
- maxtext.input_pipeline.input_pipeline_utils.TokenizeOp(tokenizer_model, features, data_keys=('inputs', 'targets'))
Op for tokenization
- Parameters:
features (dict[str, Any])
data_keys (Iterable[str])
- Return type:
dict[str, Any]
- maxtext.input_pipeline.input_pipeline_utils.reformat_prompt(example, column, image_placeholder, model_name)
Reformat the prompt for multimodal SFT.
- maxtext.input_pipeline.input_pipeline_utils.reformat_response(example, column, model_name)
Reformat the response for multimodal SFT.
- maxtext.input_pipeline.input_pipeline_utils.merge_image_columns(example, image_columns, max_num_images_per_example)
Merge multiple image columns into a single list of images.
- maxtext.input_pipeline.input_pipeline_utils.pre_process_image_sft(example, image_column, model_name)
Pre-process the image for multimodal SFT.
- maxtext.input_pipeline.input_pipeline_utils.prepare_text_for_image_fusion(example, column_name, config)
Prepare text for image fusion in multimodal SFT.
- maxtext.input_pipeline.input_pipeline_utils.combine_columns(example, columns, data_column)
Combine columns such as ‘prompt’ and ‘completion’ for SFT training.
- maxtext.input_pipeline.input_pipeline_utils.is_conversational(features, data_columns)
Check if data is in a conversational format.
- Examples:
features = {'prompt': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}], 'completion': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}]}
data_columns = ["prompt", "completion"]
is_conversational(features, data_columns) returns True.
features = {'prompt': [Value(dtype='string', id=None)], 'completion': [Value(dtype='string', id=None)]}
data_columns = ["prompt", "completion"]
is_conversational(features, data_columns) returns False.
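The two examples above can be sketched with plain dictionaries standing in for `datasets` feature objects. The helper name `looks_conversational` and the rule it applies (each listed column is a non-empty list whose first item is a dict with 'content' and 'role' keys) are assumptions made for illustration; the library's actual check may differ.

```python
def looks_conversational(features, data_columns):
    """Hypothetical stand-in: True if every column holds role/content messages."""
    for column in data_columns:
        messages = features.get(column)
        # A conversational column is assumed to be a non-empty list of message dicts.
        if not isinstance(messages, list) or not messages:
            return False
        first = messages[0]
        if not (isinstance(first, dict) and "content" in first and "role" in first):
            return False
    return True


# Plain strings stand in for Value(dtype='string', id=None).
chat_features = {
    "prompt": [{"content": "string", "role": "string"}],
    "completion": [{"content": "string", "role": "string"}],
}
flat_features = {"prompt": ["string"], "completion": ["string"]}

print(looks_conversational(chat_features, ["prompt", "completion"]))  # True
print(looks_conversational(flat_features, ["prompt", "completion"]))  # False
```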
- maxtext.input_pipeline.input_pipeline_utils.extract_token_ids(tokens)
Extracts token IDs from various tokenizer output formats.
This helper function standardizes the extraction of tokenized integer IDs from common return types of Hugging Face tokenizers, including BatchEncoding objects, dictionaries, or simple lists.
- Parameters:
tokens – The object containing token IDs. Supported types include:
- A list of integers.
- A dictionary containing the INPUT_TOKENS_KEY.
- An object (e.g., BatchEncoding) with an attribute named INPUT_TOKENS_KEY.
- Returns:
A list of integer token IDs.
- Raises:
ValueError – If the input type is not supported or does not contain the expected key.
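A minimal sketch of the three supported cases and the error path. The function name `extract_ids_sketch`, the `FakeEncoding` class, and the value `"input_ids"` for INPUT_TOKENS_KEY are assumptions; the real constant's value is not shown in this reference.

```python
INPUT_TOKENS_KEY = "input_ids"  # assumption: the library's actual value is not documented here


def extract_ids_sketch(tokens):
    """Hypothetical re-implementation of the three documented input types."""
    if isinstance(tokens, list):
        return tokens
    if isinstance(tokens, dict) and INPUT_TOKENS_KEY in tokens:
        return tokens[INPUT_TOKENS_KEY]
    if hasattr(tokens, INPUT_TOKENS_KEY):
        return getattr(tokens, INPUT_TOKENS_KEY)
    raise ValueError(f"Unsupported tokenizer output: {type(tokens).__name__}")


class FakeEncoding:
    """Stand-in for an object that exposes the IDs as an attribute."""

    def __init__(self, ids):
        self.input_ids = ids


print(extract_ids_sketch([1, 2, 3]))                 # [1, 2, 3]
print(extract_ids_sketch({"input_ids": [4, 5]}))     # [4, 5]
print(extract_ids_sketch(FakeEncoding([6, 7, 8])))   # [6, 7, 8]
```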
- maxtext.input_pipeline.input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer_model)
Verifies the tokenizer’s chat template for correct SFT loss masking.
This function ensures that the tokens added by add_generation_prompt=True are identical to the tokens that begin an assistant’s turn in a complete conversation, which is critical for masking prompt tokens during SFT loss calculation.
- Example of a mismatch:
A ValueError is raised if the generation prompt and the actual assistant prefix do not match. For example:
- add_generation_prompt=True on a user message produces a prompt ending in `…<|im_start|>generation\n`.
- A full turn with an assistant message starts the reply with `…<|im_start|>assistant\n…`.
This function would fail because the tokens for “generation” do not match the tokens for “assistant”.
- Parameters:
tokenizer_model – The Hugging Face tokenizer instance to verify.
- Raises:
ValueError – If the add_generation_prompt tokens do not exactly match the beginning of an assistant message in the template.
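The comparison described above can be sketched on plain token lists, without a real tokenizer. The helper name `generation_prompt_matches` and its three arguments are hypothetical, and tokens are written as strings for readability; the library function works on a tokenizer instance and may perform the check differently.

```python
def generation_prompt_matches(prompt_ids, prompt_with_gen_ids, full_conversation_ids):
    """Hypothetical check: the tokens appended by the generation prompt must be
    the same tokens that open the assistant turn in the full conversation."""
    n = len(prompt_ids)
    gen_tokens = prompt_with_gen_ids[n:]
    assistant_prefix = full_conversation_ids[n:n + len(gen_tokens)]
    return gen_tokens == assistant_prefix


user_turn = ["<|im_start|>", "user", "hi", "<|im_end|>"]
full = user_turn + ["<|im_start|>", "assistant", "hello", "<|im_end|>"]

good_prompt = user_turn + ["<|im_start|>", "assistant"]
bad_prompt = user_turn + ["<|im_start|>", "generation"]

print(generation_prompt_matches(user_turn, good_prompt, full))  # True
print(generation_prompt_matches(user_turn, bad_prompt, full))   # False
```

In the mismatching case, the prompt boundary used for loss masking would not line up with the start of the assistant's reply, which is why the documented function raises instead of continuing.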
- maxtext.input_pipeline.input_pipeline_utils.apply_chat_template(example, tokenizer_model, data_column_name)
Formats conversational data by applying the tokenizer’s chat template and identifying prompt/completion segments for SFT masking.
- Parameters:
example – A dictionary containing conversational data. It is expected to have a key specified by data_column_name that holds a list of messages.
tokenizer_model – The tokenizer instance associated with the language model, which contains the specific chat template.
data_column_name – The name of the column in the example dictionary that contains the list of messages.
- Returns:
The modified example dictionary.
- The data_column_name column is updated to a list of messages, each formatted according to the tokenizer’s chat template.
- A new column “is_prompt” is added, where True indicates that the tokens contain the system message, user message, and generation prompt (if applicable), and False indicates the expected LLM completion, excluding the assistant’s start tokens.
- maxtext.input_pipeline.input_pipeline_utils.tokenization(example, hf_tokenizer, truncation, max_length, column_names)
Tokenize a HuggingFace dataset.
- class maxtext.input_pipeline.input_pipeline_utils.SFTPromptMasking(*args, **kwargs)
Bases: MapTransform
Construct inputs and targets for SFT training. The prompt and completion are concatenated to form the inputs. For the targets, when training on completions only, the prompt is masked with unk_id; otherwise the targets are the same as the inputs.
- map(element)
Maps a single dataset element to an SFT training instance. It concatenates the prompt and completion to form the inputs sequence. For the targets sequence:
- If self.completion_only is True, the prompt portion of the concatenated sequence is masked using self.unk_id.
- If self.completion_only is False, the target sequence is identical to the input sequence.
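The masking rule above can be sketched on plain Python lists. The function name `build_sft_example` and its argument names are hypothetical; the real transform reads the prompt and completion from the dataset element and takes `unk_id` and `completion_only` from the instance.

```python
def build_sft_example(prompt_ids, completion_ids, unk_id, completion_only=True):
    """Hypothetical sketch: concatenate for inputs, optionally mask the prompt in targets."""
    inputs = prompt_ids + completion_ids
    if completion_only:
        # Replace every prompt position with unk_id so only the completion is trained on.
        targets = [unk_id] * len(prompt_ids) + completion_ids
    else:
        targets = list(inputs)
    return {"inputs": inputs, "targets": targets}


masked = build_sft_example([5, 6, 7], [8, 9], unk_id=0, completion_only=True)
print(masked["inputs"])   # [5, 6, 7, 8, 9]
print(masked["targets"])  # [0, 0, 0, 8, 9]

unmasked = build_sft_example([5, 6, 7], [8, 9], unk_id=0, completion_only=False)
print(unmasked["targets"])  # [5, 6, 7, 8, 9]
```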
- class maxtext.input_pipeline.input_pipeline_utils.SFTPromptMaskingVision(*args, **kwargs)
Bases: MapTransform
SFT prompt masking for multimodal data.
- class maxtext.input_pipeline.input_pipeline_utils.HFNormalizeFeatures(*args, **kwargs)
Bases: MapTransform
Normalize feature keys for HuggingFace input.
- class maxtext.input_pipeline.input_pipeline_utils.HFDataSource(*args, **kwargs)
Bases: RandomAccessDataSource
A class that wraps a HuggingFace IterableDataset as a Grain data source without random access support.
- Parameters:
dataset (datasets.IterableDataset)
dataloading_host_index (int)
dataloading_host_count (int)
num_threads (int)
max_target_length (int)
data_column_names (list[str])
- class maxtext.input_pipeline.input_pipeline_utils.GCSTFRecordIterDataset(*args, **kwargs)
Bases: TFRecordIterDataset
Extends Grain’s TFRecordIterDataset to support GCS paths.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- maxtext.input_pipeline.input_pipeline_utils.make_tfrecord_iter_dataset(path)
Returns the appropriate TFRecordIterDataset for local or GCS paths.
- Parameters:
path (str)
- class maxtext.input_pipeline.input_pipeline_utils.ParseFeatures(*args, **kwargs)
Bases: MapTransform
Parse a serialized example.
- class maxtext.input_pipeline.input_pipeline_utils.NormalizeFeatures(*args, **kwargs)
Bases: MapTransform
Normalize text feature keys.
- class maxtext.input_pipeline.input_pipeline_utils.KeepFeatures(*args, **kwargs)
Bases: MapTransform
Keep only specified features in the dataset element.
This transform filters the input dictionary, retaining only the keys that are present in feature_names.
- Parameters:
feature_names (list[str])
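A minimal sketch of the filtering behaviour, written as a free function rather than a Grain transform. The name `keep_features` and the sample keys are made up for illustration.

```python
def keep_features(element, feature_names):
    """Hypothetical sketch: retain only the keys listed in feature_names."""
    return {key: value for key, value in element.items() if key in feature_names}


element = {"inputs": [1, 2, 3], "targets": [2, 3, 4], "source_url": "ignored"}
print(keep_features(element, ["inputs", "targets"]))
# {'inputs': [1, 2, 3], 'targets': [2, 3, 4]}
```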
- class maxtext.input_pipeline.input_pipeline_utils.Rekey(*args, **kwargs)
Bases: MapTransform
Rename keys according to a mapping dict.
- class maxtext.input_pipeline.input_pipeline_utils.ReformatPacking(*args, **kwargs)
Bases: MapTransform
Reformat packing outputs.
- class maxtext.input_pipeline.input_pipeline_utils.PadOrTrimToMaxLength(*args, **kwargs)
Bases: MapTransform
Pads or trims each input to the specified length, and optionally adds the true length of the input.
- Parameters:
max_length (int)
pad_id (int)
add_true_length (bool)
max_num_images_per_example (int)
- map(element)
Apply the transform to each element.
- Parameters:
element (dict[str, ndarray | PreprocessorOutput])
- Return type:
dict[str, ndarray | PreprocessorOutput]
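The pad-or-trim behaviour can be sketched for a single 1D token list. The function name `pad_or_trim` is hypothetical, and reporting the true length as the number of non-pad tokens kept after trimming is an assumption; the real transform operates on arrays (and PreprocessorOutput objects) inside a dataset element.

```python
def pad_or_trim(tokens, max_length, pad_id=0):
    """Hypothetical sketch: return a list of exactly max_length items plus its true length."""
    trimmed = tokens[:max_length]
    true_length = len(trimmed)  # assumption: counted after trimming
    padded = trimmed + [pad_id] * (max_length - true_length)
    return padded, true_length


print(pad_or_trim([1, 2, 3], max_length=5))           # ([1, 2, 3, 0, 0], 3)
print(pad_or_trim([1, 2, 3, 4, 5, 6], max_length=4))  # ([1, 2, 3, 4], 4)
```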
- class maxtext.input_pipeline.input_pipeline_utils.ExtractImagesAndMasks(*args, **kwargs)
Bases: MapTransform
Extracts images and masks from a PreprocessorOutput object.
This transform is used in multi-modal data pipelines to extract the image tensors and their corresponding masks from a PreprocessorOutput object. The extracted images and masks are then added to the data element under the keys ‘images’ and ‘image_masks’, respectively.
If the ‘images’ key is not present in the input element, the transform returns the element unchanged.
- class maxtext.input_pipeline.input_pipeline_utils.FoldImagesIntoBatch(*args, **kwargs)
Bases: MapTransform
Folds the ‘image’ dimension into the batch dimension.
This transform is used in multi-modal data pipelines where each data example might have multiple associated images. For model processing, it’s often efficient to treat each image as a separate item in a larger batch.
This operation reshapes the ‘images’ tensor from a shape like (B, N, T, H, W, C) to (B * N, T, H, W, C), where B is the batch size, N is the number of images per example, and T is the number of image tiles.
The transformation is triggered only if the input ‘images’ tensor has more dimensions than the expected batched image tensor.
- Parameters:
model_name (str | None)
- model_name: str | None = None
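The folding step can be illustrated with nested lists, where the outer list is the batch axis (B) and the next level is the per-example image axis (N). The function name `fold_images_into_batch` and the string placeholders are hypothetical; on real arrays the same effect is a reshape from (B, N, T, H, W, C) to (B * N, T, H, W, C). The dimension check that triggers the transform is not modelled here.

```python
def fold_images_into_batch(images):
    """Hypothetical sketch: flatten the first two axes, (B, N, ...) -> (B * N, ...)."""
    return [image for example in images for image in example]


# B = 2 examples, N = 2 images each; strings stand in for (T, H, W, C) tensors.
images = [
    ["b0_img0", "b0_img1"],
    ["b1_img0", "b1_img1"],
]
print(fold_images_into_batch(images))
# ['b0_img0', 'b0_img1', 'b1_img0', 'b1_img1']
```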
- maxtext.input_pipeline.input_pipeline_utils.shift_right(x, axis=1)
Shift the input to the right by padding and slicing on axis.
- maxtext.input_pipeline.input_pipeline_utils.shift_left(x, pad_id, axis=1)
Shift to the left and pad.
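The two shift helpers above can be sketched on a single 1D list. The names `shift_right_1d` and `shift_left_1d` are hypothetical, and padding the right shift with 0 is an assumption, since `shift_right` takes no pad argument; the library functions operate on arrays along a chosen axis.

```python
def shift_right_1d(x):
    """Hypothetical sketch: prepend a 0 and drop the last element."""
    if not x:
        return []
    return [0] + x[:-1]


def shift_left_1d(x, pad_id):
    """Hypothetical sketch: drop the first element and append pad_id."""
    if not x:
        return []
    return x[1:] + [pad_id]


print(shift_right_1d([1, 2, 3]))           # [0, 1, 2]
print(shift_left_1d([1, 2, 3], pad_id=0))  # [2, 3, 0]
```

Both sketches keep the sequence length unchanged, which is what lets shifted inputs and targets stay aligned position by position.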
- maxtext.input_pipeline.input_pipeline_utils.shift_and_refine(x, ignored_ids, axis=1)
Shift inputs and, if ignored_ids is provided, set segmentation to 0 wherever the target element is in ignored_ids.
- class maxtext.input_pipeline.input_pipeline_utils.ShiftData(*args, **kwargs)
Bases: MapTransform
Shift inputs and refine annotations.
- class maxtext.input_pipeline.input_pipeline_utils.ComputeQwen3OmniPositions(*args, **kwargs)
Bases: MapTransform
Computes 3D position IDs for Qwen3-Omni multimodal sequences.
This transform replaces the standard 1D sequential positions with 3D positions (temporal, height, width) for multimodal models like Qwen3-Omni.
For text-only sequences, all 3 dimensions receive the same sequential values. For multimodal sequences with vision/audio, vision tokens get true 3D positions and text tokens continue sequentially from max(vision_pos) + 1.
The actual position computation is delegated to multimodal_utils.get_rope_index(), which can be tested and modified independently.
- Parameters:
data_column (str)
spatial_merge_size (int)
position_id_per_seconds (int)
use_audio_in_video (bool)
- map(element)
Compute 3D position IDs for the batch element.
- Parameters:
element (dict[str, ndarray]) – Dictionary containing:
- {data_column}: Token IDs with shape (batch, seq_len)
- {data_column}_segmentation: Attention mask (1=real, 0=padding)
- image_grid_thw: Optional (num_images, 3) array
- video_grid_thw: Optional (num_videos, 3) array
- audio_lengths: Optional (num_audios,) array
- second_per_grids: Optional (num_videos,) array
- Returns:
element with {data_column}_position updated to shape (3, batch, seq_len) for 3D positions (always 3D, even for text-only sequences).
- Return type:
dict[str, ndarray]
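For the text-only case, the output layout can be sketched with nested lists. The function name `text_only_positions` is hypothetical, and the sketch assumes positions start at 0 and that there is no padding; the real computation is delegated to multimodal_utils.get_rope_index() and also handles vision and audio tokens.

```python
def text_only_positions(batch_size, seq_len):
    """Hypothetical sketch: the same sequential positions on all 3 axes,
    laid out as (3, batch, seq_len)."""
    sequential = [list(range(seq_len)) for _ in range(batch_size)]
    # Copy the rows so the temporal, height, and width axes do not share lists.
    return [[row[:] for row in sequential] for _ in range(3)]


positions = text_only_positions(batch_size=2, seq_len=4)
print(len(positions), len(positions[0]), len(positions[0][0]))  # 3 2 4
print(positions[0][0])  # [0, 1, 2, 3]
```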