maxtext.input_pipeline.multihost_dataloading module

SPMD Multihost Dataloading Utilities.

Adapted from Sholto’s multihost dataloading utilities: sholtodouglas/multihost_dataloading

class maxtext.input_pipeline.multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch=False, expansion_loading_factor_for_grain=-1)

Bases: object

Folds get_next_batch_sharded into an iterator class. expansion_loading_factor_for_grain is used only by the Grain pipeline, when only a subset of hosts loads real data.

Parameters:
  • dataloader (DatasetV2 | Iterable)

  • global_mesh (Mesh)

  • generate_padding_batch (bool)

  • expansion_loading_factor_for_grain (int)
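
To make the role of `expansion_loading_factor_for_grain` more concrete, here is a minimal pure-Python sketch of host selection. It rests on assumptions the docstring does not confirm: that the default of -1 (or 1) means every host loads real data, that a factor `e > 1` means the first `num_hosts // e` hosts load data, and that the global batch size stays fixed so each loading host reads `e` times as much. Both helper names are hypothetical.

```python
from typing import List


def hosts_loading_real_data(num_hosts: int, expansion_factor: int) -> List[int]:
    """Hypothetical helper: indices of the hosts that read real data.

    Assumption: -1 (the default) or 1 means every host loads data; a factor
    e > 1 means only the first num_hosts // e hosts load data.
    """
    if expansion_factor in (-1, 1):
        return list(range(num_hosts))
    if expansion_factor < 1 or num_hosts % expansion_factor != 0:
        raise ValueError("expansion factor must be -1 or evenly divide the host count")
    return list(range(num_hosts // expansion_factor))


def per_loading_host_batch(global_batch: int, num_hosts: int, expansion_factor: int) -> int:
    """Hypothetical: keep the global batch fixed, so each loading host reads more."""
    loaders = len(hosts_loading_real_data(num_hosts, expansion_factor))
    if global_batch % loaders != 0:
        raise ValueError("global batch must divide evenly across loading hosts")
    return global_batch // loaders
```

For example, under these assumptions, with 8 hosts and a factor of 4 only hosts 0 and 1 would load data, each reading 128 examples of a 256-example global batch instead of 32.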

reset()
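
The docstring describes folding `get_next_batch_sharded` into an iterator class. A minimal sketch of that pattern follows, assuming a re-iterable host-local dataloader. The class name `ShardedBatchIterator` and the `place_on_devices` callable (a stand-in for the real device-placement step) are hypothetical, and so is the padding behaviour: here, once the loader is exhausted and `generate_padding_batch=True`, the iterator keeps returning zero-filled batches shaped like the last real batch.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional

Batch = Dict[str, List[int]]


class ShardedBatchIterator:
    """Hypothetical sketch of an iterator wrapping a host-local dataloader."""

    def __init__(self, dataloader: Iterable[Batch],
                 place_on_devices: Callable[[Batch], Any] = lambda b: b,
                 generate_padding_batch: bool = False):
        self.dataloader = dataloader
        # Hypothetical stand-in for placing a local batch onto devices.
        self.place_on_devices = place_on_devices
        self.generate_padding_batch = generate_padding_batch
        self.reset()

    def reset(self) -> None:
        # Start a fresh pass; assumes the dataloader can be iterated again.
        self.local_iterator: Iterator[Batch] = iter(self.dataloader)
        self.last_batch: Optional[Batch] = None

    def __iter__(self) -> "ShardedBatchIterator":
        return self

    def __next__(self) -> Any:
        try:
            batch = next(self.local_iterator)
            self.last_batch = batch
        except StopIteration:
            if not self.generate_padding_batch or self.last_batch is None:
                raise
            # Assumption: pad with zeros shaped like the last real batch.
            batch = {k: [0] * len(v) for k, v in self.last_batch.items()}
        return self.place_on_devices(batch)
```

Keeping padding in its own branch of `__next__` lets a host that runs out of data early keep producing batches so that all hosts can keep stepping together; whether MaxText pads exactly this way is not stated here.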

class maxtext.input_pipeline.multihost_dataloading.RemoteIteratorWrapper(get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path='', elastic=False)

Bases: object

Wrapper for RemoteIterator that handles device placement.
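
Device placement of a sharded batch largely comes down to shape arithmetic. The sketch below assumes that the leading (batch) dimension of `global_shape` is divided evenly across hosts and then across each host's local devices; the helper names are hypothetical, and plain lists stand in for arrays. In JAX, the per-device pieces would then be assembled into one global array with an API such as `jax.make_array_from_single_device_arrays`.

```python
from typing import List, Sequence, Tuple


def local_batch_shape(global_shape: Tuple[int, ...], num_hosts: int) -> Tuple[int, ...]:
    """Hypothetical: the slice of global_shape each host holds (batch axis split evenly)."""
    if global_shape[0] % num_hosts != 0:
        raise ValueError("global batch is not divisible by the number of hosts")
    return (global_shape[0] // num_hosts,) + tuple(global_shape[1:])


def split_across_local_devices(local_batch: Sequence[int],
                               num_local_devices: int) -> List[List[int]]:
    """Hypothetical: split the host-local batch axis into one equal chunk per device."""
    if len(local_batch) % num_local_devices != 0:
        raise ValueError(
            f"local batch of {len(local_batch)} is not divisible by "
            f"{num_local_devices} local devices")
    chunk = len(local_batch) // num_local_devices
    return [list(local_batch[i * chunk:(i + 1) * chunk])
            for i in range(num_local_devices)]
```

With a global shape of `(32, 1024)` on 4 hosts, each host holds an `(8, 1024)` slice, and with 4 local devices each device receives 2 examples.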

reset()
save_state(step)
restore_state(step)
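
The signatures above do not say what `save_state(step)` and `restore_state(step)` persist. One plausible pattern, sketched here purely as an assumption, is to record the iterator's position in a per-step file under `checkpoint_path` and read it back on restore. The class name `StepCheckpointedIterator` and the `iterator_state_<step>.json` file layout are hypothetical.

```python
import json
import os
from typing import List


class StepCheckpointedIterator:
    """Hypothetical sketch: an iterator whose position is checkpointed per step."""

    def __init__(self, data: List[int], checkpoint_path: str):
        self.data = data
        self.checkpoint_path = checkpoint_path
        self.position = 0  # index of the next element to return

    def __iter__(self) -> "StepCheckpointedIterator":
        return self

    def __next__(self) -> int:
        if self.position >= len(self.data):
            raise StopIteration
        item = self.data[self.position]
        self.position += 1
        return item

    def _state_file(self, step: int) -> str:
        # Hypothetical layout: one JSON file per training step.
        return os.path.join(self.checkpoint_path, f"iterator_state_{step}.json")

    def save_state(self, step: int) -> None:
        os.makedirs(self.checkpoint_path, exist_ok=True)
        with open(self._state_file(step), "w") as f:
            json.dump({"position": self.position}, f)

    def restore_state(self, step: int) -> None:
        with open(self._state_file(step)) as f:
            self.position = json.load(f)["position"]
```

Keying the file by step means a job restored from training step N can also restore the data position it had at step N.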