maxtext.input_pipeline.data_processing_utils module

Utility functions for data processing pipelines.

maxtext.input_pipeline.data_processing_utils.parse_and_keep_features(dataset, config, data_columns, tokenize)

Parses ArrayRecord features, or keeps only the specified columns for other formats.

maxtext.input_pipeline.data_processing_utils.get_tokenizer_and_pad_id(config)

Builds tokenizer and extracts pad_id safely.

maxtext.input_pipeline.data_processing_utils.validate_and_configure_sft_columns(data_columns, tokenizer_model, chat_template=None)

Validates SFT data columns and configures the tokenizer chat template.

maxtext.input_pipeline.data_processing_utils.get_local_batch_size(config)

Computes the local batch size from the process count and the expansion factor.

maxtext.input_pipeline.data_processing_utils.format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model, shift=True)

Packs or pads the dataset, batches it, and optionally shifts tokens for next-token prediction.

When config.grain_use_elastic_iterator is True, batching is skipped because ElasticIterator performs it internally. If shift=True in that case, the shift is applied before batching on axis 0 of each unbatched example, which is equivalent to shifting on axis 1 after batching.

shift should be False for pipelines that don’t do next-token prediction (e.g. DPO, which scores full sequences).
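
The equivalence between a pre-batch shift on axis 0 and a post-batch shift on axis 1 can be checked with a small NumPy sketch. The helper shift_left below is hypothetical and is not the MaxText implementation: it assumes the shift drops the first token along the given axis and appends pad_id at the end, and it assumes batching is a plain stack of equal-length examples.

```python
import numpy as np


def shift_left(x, pad_id, axis):
    """Hypothetical shift: drop the first element along `axis`, append `pad_id`."""
    x = np.asarray(x)
    # Keep everything except index 0 along the chosen axis.
    body = np.take(x, np.arange(1, x.shape[axis]), axis=axis)
    # Build a slice of width 1 along `axis`, filled with pad_id.
    pad_shape = list(x.shape)
    pad_shape[axis] = 1
    pad = np.full(pad_shape, pad_id, dtype=x.dtype)
    return np.concatenate([body, pad], axis=axis)


pad_id = 0
# Two already padded, unbatched examples of shape [seq_len].
examples = [np.array([5, 6, 7, 0]), np.array([8, 9, 0, 0])]

# Pre-batch: shift each example on axis 0, then stack into [batch, seq_len].
pre_batch = np.stack([shift_left(e, pad_id, axis=0) for e in examples])

# Post-batch: stack first, then shift the batch on axis 1.
post_batch = shift_left(np.stack(examples), pad_id, axis=1)

same = np.array_equal(pre_batch, post_batch)
print(same)
```

Because the sequence dimension is axis 0 of an unbatched example and axis 1 of a batch, both orders act on the same positions. This holds for any shift that operates only along the sequence dimension.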

maxtext.input_pipeline.data_processing_utils.apply_multiprocessing_and_prefetch(dataset, config, grain_worker_count, grain_per_worker_buffer_size)

Applies multiprocessing and prefetching configurations to the dataset.