maxtext.input_pipeline.input_pipeline_utils module#

Operations used by Grain

maxtext.input_pipeline.input_pipeline_utils.normalize_features(x, column_name)[source]#
maxtext.input_pipeline.input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token=None)[source]#
maxtext.input_pipeline.input_pipeline_utils.truncate_to_max_allowable_length(x, max_length)[source]#
maxtext.input_pipeline.input_pipeline_utils.shift_data_by_truncation(x)[source]#
maxtext.input_pipeline.input_pipeline_utils.add_segmentation_and_position(x, data_columns, padding_token=0)[source]#
maxtext.input_pipeline.input_pipeline_utils.TokenizeOp(tokenizer_model, features, data_keys=('inputs', 'targets'))[source]#

Op for tokenization

Parameters:
  • features (dict[str, Any])

  • data_keys (Iterable[str])

Return type:

dict[str, Any]

maxtext.input_pipeline.input_pipeline_utils.reformat_prompt(example, column, image_placeholder, model_name)[source]#

Reformat the prompt for multimodal SFT.

maxtext.input_pipeline.input_pipeline_utils.reformat_response(example, column, model_name)[source]#

Reformat the response for multimodal SFT.

maxtext.input_pipeline.input_pipeline_utils.merge_image_columns(example, image_columns, max_num_images_per_example)[source]#

Merge multiple image columns into a single list of images.

maxtext.input_pipeline.input_pipeline_utils.pre_process_image_sft(example, image_column, model_name)[source]#

Pre-process the image for multimodal SFT.

maxtext.input_pipeline.input_pipeline_utils.prepare_text_for_image_fusion(example, column_name, config)[source]#

Prepare the text for image fusion in multimodal SFT.

maxtext.input_pipeline.input_pipeline_utils.combine_columns(example, columns, data_column)[source]#

Combine columns such as ‘prompt’ and ‘completion’ for SFT training.

maxtext.input_pipeline.input_pipeline_utils.is_conversational(features, data_columns)[source]#

Check if data is in a conversational format.

Examples:

features = {'prompt': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}], 'completion': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}]}
data_columns = ["prompt", "completion"]
is_conversational(features, data_columns) returns True.

features = {'prompt': [Value(dtype='string', id=None)], 'completion': [Value(dtype='string', id=None)]}
data_columns = ["prompt", "completion"]
is_conversational(features, data_columns) returns False.
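The distinction between the two examples can be illustrated with a minimal sketch. This is not the MaxText implementation: the helper name `looks_conversational` is hypothetical, the rule that every listed column must qualify is an assumption, and plain strings stand in for the `Value(...)` descriptors so the example runs without the `datasets` library.

```python
def looks_conversational(features, data_columns):
    """Return True if every listed column is a list of role/content message dicts."""
    for column in data_columns:
        messages = features.get(column)
        # A conversational column is a non-empty list ...
        if not isinstance(messages, list) or not messages:
            return False
        first = messages[0]
        # ... whose entries are dicts carrying both "role" and "content".
        if not (isinstance(first, dict) and "role" in first and "content" in first):
            return False
    return True


# Plain strings stand in for datasets.Value(dtype="string") (an assumption for this sketch).
chat_features = {
    "prompt": [{"content": "string", "role": "string"}],
    "completion": [{"content": "string", "role": "string"}],
}
flat_features = {"prompt": ["string"], "completion": ["string"]}

print(looks_conversational(chat_features, ["prompt", "completion"]))  # True
print(looks_conversational(flat_features, ["prompt", "completion"]))  # False
```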

maxtext.input_pipeline.input_pipeline_utils.extract_token_ids(tokens)[source]#

Extracts token IDs from various tokenizer output formats.

This helper function standardizes the extraction of tokenized integer IDs from common return types of Hugging Face tokenizers, including BatchEncoding objects, dictionaries, or simple lists.

Parameters:

tokens – The object containing token IDs. Supported types include:
  • A list of integers.

  • A dictionary containing the INPUT_TOKENS_KEY.

  • An object (e.g., BatchEncoding) with an attribute named INPUT_TOKENS_KEY.

Returns:

A list of integer token IDs.

Raises:

ValueError – If the input type is not supported or does not contain the expected key.
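The three supported input shapes can be sketched as follows. This is an illustration under stated assumptions, not the MaxText source: the function name `extract_ids` and the `FakeEncoding` class are hypothetical, and the value `"input_ids"` for INPUT_TOKENS_KEY is assumed.

```python
INPUT_TOKENS_KEY = "input_ids"  # assumed value; the real constant is defined in MaxText


def extract_ids(tokens):
    """Return token IDs from a list, a dict, or an object exposing INPUT_TOKENS_KEY."""
    if isinstance(tokens, list):
        return tokens
    if isinstance(tokens, dict) and INPUT_TOKENS_KEY in tokens:
        return tokens[INPUT_TOKENS_KEY]
    if hasattr(tokens, INPUT_TOKENS_KEY):
        return getattr(tokens, INPUT_TOKENS_KEY)
    raise ValueError(f"Unsupported tokenizer output: {type(tokens).__name__}")


class FakeEncoding:
    """Hypothetical stand-in for an object such as BatchEncoding."""

    def __init__(self, ids):
        self.input_ids = ids


print(extract_ids([1, 2, 3]))                # [1, 2, 3]
print(extract_ids({"input_ids": [4, 5]}))    # [4, 5]
print(extract_ids(FakeEncoding([6, 7, 8])))  # [6, 7, 8]
```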

maxtext.input_pipeline.input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer_model)[source]#

Verifies the tokenizer’s chat template for correct SFT loss masking.

This function ensures that the tokens added by add_generation_prompt=True are identical to the tokens that begin an assistant’s turn in a complete conversation, which is critical for masking prompt tokens during SFT loss calculation.

Example of a mismatch:

A ValueError is raised if the generation prompt and the actual assistant prefix do not match. For example:

  • add_generation_prompt=True on a user message produces a prompt ending in: `…<|im_start|>generation\n`

  • A full turn with an assistant message starts the reply with: `…<|im_start|>assistant\n…`

This function would fail because the tokens for “generation” do not match the tokens for “assistant”.

Parameters:

tokenizer_model – The Hugging Face tokenizer instance to verify.

Raises:

ValueError – If the add_generation_prompt tokens do not exactly match the beginning of an assistant message in the template.
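The prefix comparison described above can be sketched on plain token-ID lists. This is a hedged illustration, not the MaxText implementation: the function name `check_generation_prompt` and the toy token IDs are hypothetical. In practice the three lists would presumably be produced by applying the tokenizer's chat template with and without add_generation_prompt=True.

```python
def check_generation_prompt(prompt_ids, with_generation_ids, full_turn_ids):
    """Raise ValueError unless the generation-prompt tokens begin the assistant turn.

    prompt_ids:           tokens of the user turn alone
    with_generation_ids:  tokens of the user turn plus the generation prompt
    full_turn_ids:        tokens of the user turn plus a complete assistant turn
    """
    start = len(prompt_ids)
    generation_prompt = with_generation_ids[start:]
    assistant_prefix = full_turn_ids[start:start + len(generation_prompt)]
    if generation_prompt != assistant_prefix:
        raise ValueError(
            f"Generation prompt {generation_prompt} does not match "
            f"assistant prefix {assistant_prefix}"
        )


# Toy IDs (hypothetical): 2 = "<|im_start|>", 20 = "assistant", 21 = "generation".
user_ids = [1, 10, 11]
full_ids = user_ids + [2, 20, 30, 31, 3]

check_generation_prompt(user_ids, user_ids + [2, 20], full_ids)  # matches, no error

try:
    check_generation_prompt(user_ids, user_ids + [2, 21], full_ids)
except ValueError as err:
    print("mismatch detected:", err)
```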

maxtext.input_pipeline.input_pipeline_utils.apply_chat_template(example, tokenizer_model, data_column_name)[source]#

Formats conversational data by applying the tokenizer’s chat template and identifying prompt/completion segments for SFT masking.

Parameters:
  • example – A dictionary containing conversational data. It is expected to have a key specified by data_column_name that holds a list of messages.

  • tokenizer_model – The tokenizer instance associated with the language model, which contains the specific chat template.

  • data_column_name – The name of the column in the example dictionary that contains the list of messages.

Returns:

The modified example dictionary.
  • The data_column_name column will be updated to a list of messages, each formatted according to the tokenizer’s chat template.

  • A new column “is_prompt” is added, where True indicates the tokens contain the system message, user message, and generation prompt (if applicable). False indicates the expected LLM completion, excluding the assistant’s start tokens.

maxtext.input_pipeline.input_pipeline_utils.tokenization(example, hf_tokenizer, truncation, max_length, column_names)[source]#

Tokenize a HuggingFace dataset

class maxtext.input_pipeline.input_pipeline_utils.SFTPromptMasking(*args, **kwargs)[source]#

Bases: MapTransform

Constructs inputs and targets for SFT training. The prompt and completion are concatenated to form the inputs. For the targets, when training on completions only, the prompt tokens are masked with unk_id; otherwise the targets are the same as the inputs.

map(element)[source]#

Maps a single dataset element to an SFT training instance. It concatenates the prompt and completion to form the inputs sequence. For the targets sequence:
  • If self.completion_only is True, the prompt portion of the concatenated sequence is masked using self.unk_id.

  • If self.completion_only is False, the target sequence is identical to the input sequence.
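The two cases above can be sketched on plain lists of token IDs. This is a minimal sketch, not the class's actual code: the function name `build_sft_example` is hypothetical, and the real transform operates on dataset elements rather than bare lists.

```python
def build_sft_example(prompt_ids, completion_ids, unk_id, completion_only=True):
    """Concatenate prompt and completion; optionally mask the prompt in the targets."""
    inputs = prompt_ids + completion_ids
    if completion_only:
        # Replace every prompt position with unk_id so it can be excluded from the loss.
        targets = [unk_id] * len(prompt_ids) + completion_ids
    else:
        targets = list(inputs)
    return {"inputs": inputs, "targets": targets}


example = build_sft_example([5, 6, 7], [8, 9], unk_id=0)
print(example["inputs"])   # [5, 6, 7, 8, 9]
print(example["targets"])  # [0, 0, 0, 8, 9]
```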

class maxtext.input_pipeline.input_pipeline_utils.SFTPromptMaskingVision(*args, **kwargs)[source]#

Bases: MapTransform

SFT prompt masking for multimodal data.

map(element)[source]#
class maxtext.input_pipeline.input_pipeline_utils.HFNormalizeFeatures(*args, **kwargs)[source]#

Bases: MapTransform

Normalize feature keys for HuggingFace input

map(element)[source]#
class maxtext.input_pipeline.input_pipeline_utils.HFDataSource(*args, **kwargs)[source]#

Bases: RandomAccessDataSource

Wraps a HuggingFace IterableDataset as a Grain data source that does not support random access.

Parameters:
  • dataset (datasets.IterableDataset)

  • dataloading_host_index (int)

  • dataloading_host_count (int)

  • num_threads (int)

  • max_target_length (int)

  • data_column_names (list[str])

class maxtext.input_pipeline.input_pipeline_utils.GCSTFRecordIterDataset(*args, **kwargs)[source]#

Bases: TFRecordIterDataset

Extends Grain’s TFRecordIterDataset to support GCS paths.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.input_pipeline.input_pipeline_utils.make_tfrecord_iter_dataset(path)[source]#

Returns the appropriate TFRecordIterDataset for local or GCS paths.

Parameters:

path (str)

class maxtext.input_pipeline.input_pipeline_utils.ParseFeatures(*args, **kwargs)[source]#

Bases: MapTransform

Parse serialized example

map(element)[source]#

Parse a serialized tf.train.Example proto and extract features.

class maxtext.input_pipeline.input_pipeline_utils.NormalizeFeatures(*args, **kwargs)[source]#

Bases: MapTransform

Normalize text feature keys.

map(element)[source]#
class maxtext.input_pipeline.input_pipeline_utils.KeepFeatures(*args, **kwargs)[source]#

Bases: MapTransform

Keep only specified features in the dataset element.

This transform filters the input dictionary, retaining only the keys that are present in feature_names.

Parameters:

feature_names (list[str])

map(element)[source]#

Applies the feature filtering to the input element.

Parameters:

element (dict[str, Any])

Return type:

dict[str, Any]
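The filtering behavior amounts to a dictionary comprehension. The sketch below uses a hypothetical standalone function, `keep_features`, rather than the Grain transform itself.

```python
def keep_features(element, feature_names):
    """Return a new dict containing only the keys listed in feature_names."""
    return {key: value for key, value in element.items() if key in feature_names}


element = {"inputs": [1, 2], "targets": [2, 3], "metadata": "unused"}
print(keep_features(element, ["inputs", "targets"]))
# {'inputs': [1, 2], 'targets': [2, 3]}
```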

class maxtext.input_pipeline.input_pipeline_utils.Rekey(*args, **kwargs)[source]#

Bases: MapTransform

Rename keys according to a mapping dict

map(element)[source]#
class maxtext.input_pipeline.input_pipeline_utils.ReformatPacking(*args, **kwargs)[source]#

Bases: MapTransform

Reformat packing outputs.

map(element)[source]#
class maxtext.input_pipeline.input_pipeline_utils.PadOrTrimToMaxLength(*args, **kwargs)[source]#

Bases: MapTransform

Pads or trims each input to the specified length and optionally adds the true length of the input.

Parameters:
  • max_length (int)

  • pad_id (int)

  • add_true_length (bool)

  • max_num_images_per_example (int)

map(element)[source]#

Applies the padding or trimming to a single element.

Parameters:

element (dict[str, ndarray | PreprocessorOutput])

Return type:

dict[str, ndarray | PreprocessorOutput]
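Padding or trimming a single 1-D array can be sketched with NumPy. This is an illustration only: the function name `pad_or_trim` is hypothetical, and defining the true length as the pre-padding length capped at max_length is an assumption rather than something the documentation states.

```python
import numpy as np


def pad_or_trim(x, max_length, pad_id=0):
    """Return (array of length max_length, true length) for a 1-D input."""
    x = np.asarray(x)
    true_length = min(x.shape[0], max_length)  # assumed definition of "true length"
    trimmed = x[:max_length]
    padded = np.pad(trimmed, (0, max_length - trimmed.shape[0]), constant_values=pad_id)
    return padded, true_length


short, short_len = pad_or_trim([1, 2, 3], max_length=5)
print(short.tolist(), short_len)  # [1, 2, 3, 0, 0] 3

long, long_len = pad_or_trim([1, 2, 3, 4, 5, 6, 7], max_length=5)
print(long.tolist(), long_len)    # [1, 2, 3, 4, 5] 5
```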

class maxtext.input_pipeline.input_pipeline_utils.ExtractImagesAndMasks(*args, **kwargs)[source]#

Bases: MapTransform

Extracts images and masks from a PreprocessorOutput object.

This transform is used in multi-modal data pipelines to extract the image tensors and their corresponding masks from a PreprocessorOutput object. The extracted images and masks are then added to the data element under the keys ‘images’ and ‘image_masks’, respectively.

If the ‘images’ key is not present in the input element, the transform returns the element unchanged.

map(element)[source]#

Applies the extraction transformation to the ‘images’ field if present.

Parameters:

element (dict[str, ndarray])

Return type:

dict[str, ndarray]

class maxtext.input_pipeline.input_pipeline_utils.FoldImagesIntoBatch(*args, **kwargs)[source]#

Bases: MapTransform

Folds the ‘image’ dimension into the batch dimension.

This transform is used in multi-modal data pipelines where each data example might have multiple associated images. For model processing, it’s often efficient to treat each image as a separate item in a larger batch.

This operation reshapes the ‘images’ tensor from a shape like (B, N, T, H, W, C) to (B * N, T, H, W, C), where B is the batch size, N is the number of images per example, and T is the number of image tiles.

The transformation is triggered only if the input ‘images’ tensor has more dimensions than the expected batched image tensor.

Parameters:

model_name (str | None)

model_name: str | None = None#
map(element)[source]#

Applies the folding transformation to the ‘images’ field if present.

Parameters:

element (dict[str, ndarray])

Return type:

dict[str, ndarray]
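The reshape from (B, N, T, H, W, C) to (B * N, T, H, W, C) can be sketched with NumPy. The function name and the `expected_ndim` parameter are hypothetical; the actual transform determines the expected shape from model_name.

```python
import numpy as np


def fold_images_into_batch(images, expected_ndim=5):
    """Merge the batch and image axes when an extra image axis is present."""
    if images.ndim > expected_ndim:
        # Collapse the first two axes (B, N) into a single axis of size B * N.
        return images.reshape((-1,) + images.shape[2:])
    return images


images = np.zeros((2, 3, 4, 8, 8, 3))  # (B, N, T, H, W, C)
folded = fold_images_into_batch(images)
print(folded.shape)  # (6, 4, 8, 8, 3)
```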

maxtext.input_pipeline.input_pipeline_utils.shift_right(x, axis=1)[source]#

Shift the input to the right by padding and slicing on axis.
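A pad-then-slice right shift can be sketched as follows. The name `shift_right_sketch` is hypothetical, and padding with 0 is an assumption, since the signature exposes no pad value.

```python
import numpy as np


def shift_right_sketch(x, axis=1):
    """Prepend one pad element along axis, then drop the last element."""
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (1, 0)
    padded = np.pad(x, pad_widths, mode="constant", constant_values=0)
    slices = [slice(None)] * x.ndim
    slices[axis] = slice(0, -1)
    return padded[tuple(slices)]


x = np.array([[1, 2, 3], [4, 5, 6]])
print(shift_right_sketch(x).tolist())  # [[0, 1, 2], [0, 4, 5]]
```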

maxtext.input_pipeline.input_pipeline_utils.shift_left(x, pad_id, axis=1)[source]#

Shift to the left and pad.
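The mirror-image operation appends pad_id at the end and drops the first element. The sketch below uses a hypothetical name, `shift_left_sketch`, and is not taken from the MaxText source.

```python
import numpy as np


def shift_left_sketch(x, pad_id, axis=1):
    """Append one pad_id element along axis, then drop the first element."""
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (0, 1)
    padded = np.pad(x, pad_widths, mode="constant", constant_values=pad_id)
    slices = [slice(None)] * x.ndim
    slices[axis] = slice(1, None)
    return padded[tuple(slices)]


tokens = np.array([[1, 2, 3], [4, 5, 6]])
print(shift_left_sketch(tokens, pad_id=9).tolist())  # [[2, 3, 9], [5, 6, 9]]
```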

maxtext.input_pipeline.input_pipeline_utils.shift_and_refine(x, ignored_ids, axis=1)[source]#

Shifts the inputs and, if ignored_ids is provided, sets segmentation to 0 wherever the target element is in ignored_ids.

class maxtext.input_pipeline.input_pipeline_utils.ShiftData(*args, **kwargs)[source]#

Bases: MapTransform

Shift inputs and refine annotations.

map(element)[source]#
class maxtext.input_pipeline.input_pipeline_utils.ComputeQwen3OmniPositions(*args, **kwargs)[source]#

Bases: MapTransform

Computes 3D position IDs for Qwen3-Omni multimodal sequences.

This transform replaces the standard 1D sequential positions with 3D positions (temporal, height, width) for multimodal models like Qwen3-Omni.

For text-only sequences, all 3 dimensions receive the same sequential values. For multimodal sequences with vision/audio, vision tokens get true 3D positions and text tokens continue sequentially from max(vision_pos) + 1.

The actual position computation is delegated to multimodal_utils.get_rope_index(), which can be tested and modified independently.

Parameters:
  • data_column (str)

  • spatial_merge_size (int)

  • position_id_per_seconds (int)

  • use_audio_in_video (bool)

map(element)[source]#

Compute 3D position IDs for the batch element.

Parameters:

element (dict[str, ndarray]) – Dictionary containing:
  • {data_column}: Token IDs with shape (batch, seq_len)

  • {data_column}_segmentation: Attention mask (1=real, 0=padding)

  • image_grid_thw: Optional (num_images, 3) array

  • video_grid_thw: Optional (num_videos, 3) array

  • audio_lengths: Optional (num_audios,) array

  • second_per_grids: Optional (num_videos,) array

Returns:

element with {data_column}_position updated to shape (3, batch, seq_len) for 3D positions (always 3D, even for text-only sequences).

Return type:

dict[str, ndarray]
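For the text-only case, where all three dimensions receive the same sequential values, the output shape can be sketched with NumPy. This covers only that simple case: the function name `text_only_positions` is hypothetical, padding is ignored, and the multimodal computation delegated to multimodal_utils.get_rope_index() is not reproduced here.

```python
import numpy as np


def text_only_positions(batch, seq_len):
    """Return position IDs of shape (3, batch, seq_len) with identical sequential values."""
    sequential = np.arange(seq_len)
    # Broadcast the same 0..seq_len-1 ramp across the 3 position axes and the batch.
    return np.broadcast_to(sequential, (3, batch, seq_len)).copy()


positions = text_only_positions(batch=2, seq_len=4)
print(positions.shape)           # (3, 2, 4)
print(positions[0, 0].tolist())  # [0, 1, 2, 3]
```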