maxtext.input_pipeline.packing.sequence_packing module

Packed Sequence Op.

maxtext.input_pipeline.packing.sequence_packing.pack_dataset(dataset, key2length, pad_id, keys=None)

Creates a "packed" version of a dataset on-the-fly. Adapted from the mesh-tf implementation. This removes the need to build and store a separate, pre-packed copy of a dataset in order to train efficiently on TPU. Each example in the output dataset represents several examples in the input dataset. For each key in the input dataset, two additional keys are created:

  • <key>_segmentation: an int32 tensor that labels each token with the 1-based index of the original example it came from (0 for padding).

  • <key>_position: an int32 tensor giving each token's 0-based position within its original example (0 for padding).
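To make the two id arrays concrete, here is a minimal pure-Python sketch that splits one packed sequence back into its original examples. The helper name unpack is hypothetical and not part of MaxText. It assumes the conventions shown in this page's example: segment ids start at 1, 0 marks padding, and positions restart at 0 within each segment. The sample arrays are the packed "inputs" feature from the example below.

```python
def unpack(tokens, segmentation, position):
    """Split one packed sequence back into its original examples.

    Hypothetical helper. Assumes segment ids start at 1, 0 marks padding,
    and positions count up from 0 inside each segment.
    """
    examples = {}
    for tok, seg, pos in zip(tokens, segmentation, position):
        if seg == 0:  # padding slot, not part of any original example
            continue
        seq = examples.setdefault(seg, [])
        if pos != len(seq):
            raise ValueError(f"segment {seg}: expected position {len(seq)}, got {pos}")
        seq.append(tok)
    # Return the recovered examples in segment-id order.
    return [examples[s] for s in sorted(examples)]


packed_inputs = [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
inputs_segmentation = [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
inputs_position = [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
print(unpack(packed_inputs, inputs_segmentation, inputs_position))
# [[8, 7, 1], [2, 3, 4, 1]]
```

Note that the recovered examples no longer carry their original trailing padding, since padding slots are marked with segment id 0 and skipped.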

Example: two input examples are combined to form one output example. The input examples are:

  {"inputs": [8, 7, 1, 0], "targets": [4, 1, 0]}
  {"inputs": [2, 3, 4, 1], "targets": [5, 6, 1]}

The output example is:

  {
    "inputs":               [8, 7, 1, 2, 3, 4, 1, 0, 0, 0],
    "inputs_segmentation":  [1, 1, 1, 2, 2, 2, 2, 0, 0, 0],
    "inputs_position":      [0, 1, 2, 0, 1, 2, 3, 0, 0, 0],
    "targets":              [4, 1, 5, 6, 1, 0, 0, 0, 0, 0],
    "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
    "targets_position":     [0, 1, 0, 1, 2, 0, 0, 0, 0, 0],
  }

In this example, 0 represents padding in both the inputs and the outputs, and the packed length is 10. Sequences in the incoming examples are truncated to the length given by key2length, and every sequence in the output examples has exactly that fixed (padded) length.

Returns:

a tf.data.Dataset of packed examples

Parameters:
  • dataset (DatasetV2): the input tf.data.Dataset.

  • key2length (int | dict[str, int]): the packed length, either a single integer or a dict mapping each feature key to its own length.

  • pad_id (int): the padding token id.

  • keys (None | list[str]): the feature keys to pack (e.g. ["inputs", "targets"]).

Return type:

DatasetV2
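The packing behavior described above can be approximated in pure Python. This is a minimal sketch, not the tf.data implementation behind pack_dataset: the function pack_examples is hypothetical, and it assumes greedy in-order packing (an example starts a new row as soon as any of its features would overflow the current row), that trailing pad_id tokens in an input example are padding, and that an int key2length applies to every key. Under those assumptions it reproduces the documented output for the example inputs.

```python
def pack_examples(examples, key2length, pad_id=0, keys=None):
    """Greedily pack variable-length examples into fixed-length rows.

    Hypothetical pure-Python sketch of the semantics documented for
    pack_dataset; assumes in-order, first-fit packing.
    """
    keys = list(keys) if keys is not None else list(examples[0])
    if isinstance(key2length, int):  # one length shared by every key
        key2length = {k: key2length for k in keys}

    rows, row, num_segments = [], None, 0

    def flush():
        # Pad every feature of the finished row out to its fixed length.
        out = {}
        for k in keys:
            n = key2length[k] - len(row[k])
            out[k] = row[k] + [pad_id] * n
            out[k + "_segmentation"] = row[k + "_segmentation"] + [0] * n
            out[k + "_position"] = row[k + "_position"] + [0] * n
        rows.append(out)

    for ex in examples:
        # Assumption: trailing pad_id tokens are padding; strip them,
        # then truncate each sequence to its per-key length.
        seqs = {}
        for k in keys:
            seq = list(ex[k])
            while seq and seq[-1] == pad_id:
                seq.pop()
            seqs[k] = seq[: key2length[k]]
        # Start a new row if any feature would overflow the current one.
        if row is None or any(len(row[k]) + len(seqs[k]) > key2length[k] for k in keys):
            if row is not None:
                flush()
            row = {name: [] for k in keys for name in (k, k + "_segmentation", k + "_position")}
            num_segments = 0
        num_segments += 1  # segment ids are 1-based; 0 is reserved for padding
        for k in keys:
            row[k] += seqs[k]
            row[k + "_segmentation"] += [num_segments] * len(seqs[k])
            row[k + "_position"] += list(range(len(seqs[k])))
    if row is not None:
        flush()
    return rows


rows = pack_examples(
    [{"inputs": [8, 7, 1, 0], "targets": [4, 1, 0]},
     {"inputs": [2, 3, 4, 1], "targets": [5, 6, 1]}],
    key2length=10,
)
print(rows[0]["targets_segmentation"])  # [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
```

Because this sketch treats any trailing pad_id as padding, a genuine final token equal to pad_id would also be stripped; the real implementation may handle that case differently.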