maxtext.layers.pipeline_deprecated module

Pipeline layer wrapping one or more decoder layers. Supports circular pipelining.

class maxtext.layers.pipeline_deprecated.Pipeline(config, layers, mesh, remat_policy=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Module that implements pipelining across stages.

This module loops over microbatches and executes the main body with a vmap over both the inputs and the weights. This produces a pipeline pattern if the stage dimension is sharded.

Supports circular pipelines. Multiple layers per stage are used when a module that executes multiple layers is passed as the layers input.
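As a rough illustration of the loop-plus-vmap pattern described above, the following toy sketch pushes scalar microbatches through three stages. It is not the MaxText implementation: `stage_fn`, `run_toy_pipeline`, and the additive weights are hypothetical names chosen for this example, the pipeline is non-circular, and a plain Python loop stands in for the real loop over iterations.

```python
import jax
import jax.numpy as jnp


def stage_fn(w, x):
    # Hypothetical stage body: each stage adds its own weight to its input.
    return x + w


def run_toy_pipeline(weights, inputs):
    """Toy non-circular pipeline; weights: [num_stages], inputs: [num_microbatches]."""
    num_stages = weights.shape[0]
    num_microbatches = inputs.shape[0]
    shift = jnp.zeros(num_stages, dtype=inputs.dtype)  # activations entering each stage
    outputs = jnp.zeros_like(inputs)
    for it in range(num_microbatches + num_stages - 1):
        # Stage 0 reads a microbatch; the other stages read the shifted output
        # of their predecessor. Once all microbatches have entered, stage 0
        # re-reads stale inputs whose results are never stored (bubble).
        stages_in = shift.at[0].set(inputs[it % num_microbatches])
        # vmap runs all stages at once, each with its own weight and input.
        stages_out = jax.vmap(stage_fn)(weights, stages_in)
        finished = it - (num_stages - 1)  # microbatch leaving the last stage
        if finished >= 0:
            outputs = outputs.at[finished].set(stages_out[-1])
        shift = jnp.roll(stages_out, 1)  # move activations one stage down
    return outputs


weights = jnp.array([1.0, 10.0, 100.0])
inputs = jnp.arange(4, dtype=jnp.float32)
result = run_toy_pipeline(weights, inputs)
```

Every microbatch passes through all three stages, so `result` matches `inputs + 111.0`. In the real module the stage dimension would additionally be sharded across devices, which this sketch leaves out.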

Parameters:
  • config (Any)

  • layers (Module)

  • mesh (Mesh)

  • remat_policy (Any)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

config: Any

Contains, among other settings, num_pipeline_microbatches and num_pipeline_repeats.

layers: Module

A module instance that each stage can execute. It can be either a single layer, such as a LlamaDecoderLayer instance, or a scanned/looped set of decoder layers that executes multiple layers per stage. The name of this property (layers) is reflected in the state pytree and thus also in checkpoints.

mesh: Mesh

The device mesh of the system.

remat_policy: Any = None

Remat policy to use for the loop iterations.

setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

need_circ_storage()

iterations_to_complete_first_microbatch_one_repeat()

iterations_to_complete_first_microbatch()

init_states(inputs)

Initialize components of state: state_io, shift, circular_storage and circular_storage_mover. Assumes the input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed].

Returns a dictionary with the following properties:

  • shift: zeros of shape [num_stages, micro_size, sequence, embed].

  • prev_outputs: same shape as shift; only used when pipeline_delay_activation_forwarding is set to true, else None.

  • state_io: reshaped inputs of shape [num_stages, microbatches/stages, micro_size, sequence, embed].

  • circ_storage: zeros of shape [num_stages, microbatches, micro_size, sequence, embed] when needed, else None.

  • circ_storage_mover: zeros of shape [num_stages, micro_size, sequence, embed] when needed, else None.

  • loop_iteration: scalar, initially set to 0.
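The shapes listed for the returned dictionary can be made concrete with a small sketch. This is only an illustration under stated assumptions: `init_states_sketch` and its flag arguments `use_circ_storage` and `delay_forwarding` are hypothetical stand-ins for the module's config-driven checks, the number of microbatches is assumed to be divisible by the number of stages, and sharding is ignored.

```python
import jax.numpy as jnp


def init_states_sketch(inputs, num_stages, use_circ_storage=False, delay_forwarding=False):
    """inputs: [num_microbatches, micro_size, sequence, embed]."""
    num_microbatches, micro_size, sequence, embed = inputs.shape
    stage_shape = (num_stages, micro_size, sequence, embed)

    shift = jnp.zeros(stage_shape, dtype=inputs.dtype)
    prev_outputs = jnp.zeros(stage_shape, dtype=inputs.dtype) if delay_forwarding else None
    # Split the microbatch axis into [num_stages, microbatches / stages].
    state_io = inputs.reshape(
        num_stages, num_microbatches // num_stages, micro_size, sequence, embed
    )
    if use_circ_storage:
        circ_storage = jnp.zeros((num_stages,) + inputs.shape, dtype=inputs.dtype)
        circ_storage_mover = jnp.zeros(stage_shape, dtype=inputs.dtype)
    else:
        circ_storage = None
        circ_storage_mover = None

    return {
        "shift": shift,
        "prev_outputs": prev_outputs,
        "state_io": state_io,
        "circ_storage": circ_storage,
        "circ_storage_mover": circ_storage_mover,
        "loop_iteration": 0,
    }


example_inputs = jnp.zeros((8, 2, 16, 4))
state = init_states_sketch(example_inputs, num_stages=4, use_circ_storage=True)
```

With 8 microbatches and 4 stages, `state["state_io"]` holds 2 microbatches per stage, and `state["circ_storage"]` has room for all 8 microbatches on every stage.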

get_iteration_inputs(loop_iteration, state_io, circ_storage, shift)

Construct stages_in: the global array that is operated on for this iteration, with the same shape as shift = [stages, micro_size, sequence, embed]. This is almost a rotated version of the last outputs, except for the first stage, which must grab a new batch from state_io or an old one from circ_storage.

shard_dim_by_stages(x, dim, physical_partition_spec, is_stage_weight=False)

Shards x using the provided partition_spec, but adds the “stage” mesh axis to the existing sharding at the specified dimension.

Parameters:
  • dim (int)

  • physical_partition_spec (PartitionSpec | None)

  • is_stage_weight (bool)

get_microbatch_and_repeat_ids(loop_iteration)

Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and non-circular pipelines.
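One plausible way to compute these ids is sketched below. It assumes that each stage lags the previous one by exactly one iteration (no delayed activation forwarding) and that a stage moves on to the next repeat after it has seen every microbatch once. The function name and arguments are hypothetical, and the real method may differ in detail.

```python
def microbatch_and_repeat_ids_sketch(loop_iteration, num_stages, num_microbatches):
    # Stage s starts s iterations after stage 0, so it has processed
    # loop_iteration - s microbatches (clamped at 0 while still idle).
    processed = [max(loop_iteration - s, 0) for s in range(num_stages)]
    microbatch_ids = [p % num_microbatches for p in processed]
    repeat_ids = [p // num_microbatches for p in processed]
    return microbatch_ids, repeat_ids


microbatch_ids, repeat_ids = microbatch_and_repeat_ids_sketch(
    loop_iteration=5, num_stages=4, num_microbatches=4
)
# microbatch_ids == [1, 0, 3, 2]
# repeat_ids == [1, 1, 0, 0]
```

At iteration 5, stages 0 and 1 have already wrapped around to the second repeat, while stages 2 and 3 are still finishing the first. For a non-circular pipeline the same arithmetic applies, and repeat ids stay at 0 for as long as microbatches are still entering the pipeline.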

vmap_parallel_gather(weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights)

Use vmap to implement a sharded parallel gather.

Parallel gather means each stage has its own weights, and gets one slice from it.

Parameters:
  • weights – Per-stage data to be gathered from.

  • physical_partition_spec – Physical partition spec of the input weight.

  • repeat_ids – Integer tensor of shape [num_stages], the repeats of the stages.

  • repeat_dim_in_weights – The dimension in weights where repeat_ids are applied. The output will not have this dimension.

  • stages_dim_in_weights – The dimension in weights that represents parallel stages.

Returns:

The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights removed.
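The parallel gather can be illustrated with a minimal sketch that fixes the layout to weights of shape [num_repeats, num_stages, ...], that is, repeat_dim_in_weights = 0 and stages_dim_in_weights = 1. This layout, the name `parallel_gather_sketch`, and the omission of all sharding constraints are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def parallel_gather_sketch(weights, repeat_ids):
    """weights: [num_repeats, num_stages, ...]; repeat_ids: [num_stages]."""

    def gather_one(stage_weights, repeat_id):
        # stage_weights: [num_repeats, ...], the weights of a single stage.
        # Pick the slice for this stage's current repeat and drop the repeat axis.
        return jax.lax.dynamic_index_in_dim(stage_weights, repeat_id, axis=0, keepdims=False)

    # Map over the stage axis of the weights (axis 1) and over repeat_ids (axis 0).
    return jax.vmap(gather_one, in_axes=(1, 0), out_axes=0)(weights, repeat_ids)


weights = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)  # 2 repeats, 3 stages
repeat_ids = jnp.array([1, 0, 1])
gathered = parallel_gather_sketch(weights, repeat_ids)
```

Here `gathered[s]` equals `weights[repeat_ids[s], s]`, so the result has shape [3, 4]: the repeat dimension is gone and each stage keeps only the slice for its current repeat.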

vmap_gather(xs, ids, ids_dim)

Use vmap to implement a stage-wise sharded gather.

The stages share the same input, but they have different offsets.

Parameters:
  • xs – Data shared by all stages, to be gathered from.

  • ids – Integer tensor of shape [num_stages], the offsets of the stages.

  • ids_dim – The dimension in xs where ids are applied. In the output, this dimension will be [num_stages], since each stage gets one slice.

Returns:

The per-stage gathered values. The shape is xs.shape but with ids_dim size replaced with [num_stages].
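A minimal sketch of such a stage-wise gather follows. The helper name `vmap_gather_sketch` is hypothetical and the sharding of ids across stages is left out, so this shows only the indexing behaviour.

```python
import jax
import jax.numpy as jnp


def vmap_gather_sketch(xs, ids, ids_dim):
    """xs: data shared by all stages; ids: [num_stages] offsets along ids_dim."""

    def gather_one(x, i):
        # Select index i along ids_dim and drop that axis.
        return jax.lax.dynamic_index_in_dim(x, i, axis=ids_dim, keepdims=False)

    # xs is not mapped (every stage sees the same data); ids is mapped over
    # stages, and the stage axis is placed back at position ids_dim.
    return jax.vmap(gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)


xs = jnp.arange(10).reshape(2, 5)
ids = jnp.array([3, 0, 4])  # one offset per stage
out = vmap_gather_sketch(xs, ids, ids_dim=1)
```

The output has shape [2, 3]: the size-5 dimension of `xs` has been replaced by the number of stages, and the values match `jnp.take(xs, ids, axis=1)`.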

get_new_loop_state(output, loop_state)

Update the various buffers given the output of the most recent iteration:

  • state_io: rotates left/up by 1 (the hole created in the last slot is filled with the most recent pipeline output)

      • Pushing inputs up from top of state_io into first stage of shift

      • Pulling outputs up from last stage of shift into bottom of state_io

  • shift: rotate output (or prev_outputs if using delay) right/down by 1 - we imagine the pipeline moves to right/down

  • circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage

  • circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration

  • prev_outputs: is set to the current output

permute_output_micro_per_stage_dim(output)

get_current_stage_weights(pipeline_weights, loop_iteration, physical_partition_spec=None)

Gets the current weights used for one iteration. Outputs a pytree whose arrays have a leading dimension of stages, e.g. {'mlp': {'wo': [stages, mlp, embed]}}. Stage 0 will use the 0th index of this pytree, stage 1 the 1st index, etc. For non-circular pipelines, this simply returns all weights, since every weight is used in every iteration. For circular pipelines, however, each stage grabs only the weights corresponding to the current repeat.

get_current_repeat_from_stages(weights, loop_iteration, physical_partition_spec=None)

Gets the current repeat from the stages.

get_vmap_func_for_init()

This vmap func is used to initialize the weights only on init.

get_main_vmap_func_for_iterations()

Returns main stage function vmapped by number of stages. This becomes a vmap over a single layer instance if body_instance is a single layer, else a set of layers if body_instance is a set of layers.

run_one_iteration(loop_state, pipeline_weights, positions, segment_ids, deterministic, model_mode, decoder_layer_instance, logical_partition_spec=None)

Runs one loop iteration: gets the weights and inputs for each stage, runs the stages in parallel, and updates the loop state.

get_pipeline_remat_policy()

Returns the pipeline remat policy for this pipeline.

get_weight_sharding(*init_args)

Gets the weight sharding function for this pipeline.

static get_logical_spec_repeats_removed(full_logical)

all_gather_over_fsdp(variables, logical_partition_spec)

name: str | None = None

parent: Module | Scope | _Sentinel | None = None

scope: Scope | None = None