maxtext.layers.pipeline module


Pipeline layer wrapping one or more decoder layers. Supports circular pipelining.

class maxtext.layers.pipeline.PipelineBase(config, layers, mesh, remat_policy=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Base module that implements shared pipelining logic across stages.

Parameters:
  • config (Any)

  • layers (Module)

  • mesh (Mesh)

  • remat_policy (Any)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

config: Any
layers: Module
mesh: Mesh
remat_policy: Any = None
setup()

Initializes the configuration, calculating num_stages, delay, axes, and partition specs.

need_circ_storage()
iterations_to_complete_first_microbatch_one_repeat()
iterations_to_complete_first_microbatch()
get_iteration_inputs(loop_iteration, state_io, circ_storage, shift)

Constructs stages_in, the global array that is operated on in this iteration. It has the same shape as shift: [stages, micro_size, sequence, embed]. It is almost a rotated version of the previous outputs, except for the first stage, which must grab a new microbatch from state_io or an old one from circ_storage.
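As a rough illustration of the selection described above, the following plain-Python sketch uses lists in place of arrays. The helper name iteration_inputs, its arguments, and the rule that fresh microbatches are read only during the first num_microbatches iterations are assumptions made for illustration; they are not part of this module's API.

```python
def iteration_inputs(loop_iteration, shift, new_microbatch, stored_microbatch, num_microbatches):
    """Build the per-stage inputs for one iteration (hypothetical helper)."""
    # Assumption: during the first pass, stage 0 consumes fresh microbatches;
    # afterwards (circular pipelines) it re-reads activations stored earlier.
    if loop_iteration < num_microbatches:
        first_stage_in = new_microbatch
    else:
        first_stage_in = stored_microbatch
    # Every other stage takes the rotated output of the previous iteration.
    return [first_stage_in] + list(shift[1:])


shift = ["unused", "out_of_stage_0", "out_of_stage_1"]
print(iteration_inputs(0, shift, "fresh_mb", "stored_mb", num_microbatches=4))
print(iteration_inputs(5, shift, "fresh_mb", "stored_mb", num_microbatches=4))
```

Only slot 0 differs between the two calls; slots 1 and above always come from shift.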

get_microbatch_and_repeat_ids(loop_iteration)

Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and non-circular pipelines.
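One way to picture the per-stage ids is the plain-Python sketch below. It assumes that each stage lags the previous one by a fixed number of iterations (the hypothetical forwarding_delay argument, 1 by default) and that a stage moves to its next repeat once it has seen every microbatch. The function name, arguments, and formula are illustrative assumptions, not a transcription of the method.

```python
def microbatch_and_repeat_ids(loop_iteration, num_stages, num_microbatches, forwarding_delay=1):
    """Return (microbatch_ids, repeat_ids), one entry per stage (illustrative)."""
    microbatch_ids, repeat_ids = [], []
    for stage in range(num_stages):
        # Assumption: stage s starts forwarding_delay * s iterations after stage 0.
        processed = max(loop_iteration - forwarding_delay * stage, 0)
        microbatch_ids.append(processed % num_microbatches)
        repeat_ids.append(processed // num_microbatches)
    return microbatch_ids, repeat_ids


# 4 stages, 4 microbatches, iteration 5: stages 0 and 1 are already on repeat 1.
print(microbatch_and_repeat_ids(5, num_stages=4, num_microbatches=4))
```

With a single repeat (non-circular), repeat_ids stays at 0 for every iteration that carries real data, so the same sketch covers both cases.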

get_pipeline_remat_policy()

Returns the pipeline remat policy for this pipeline.

get_weight_sharding(*init_args)

Gets the weight sharding function for this pipeline.

get_vmap_func_for_init()

Returns the vmapped function that is used only at init time, to initialize the weights.

get_main_vmap_func_for_iterations()

Returns the main stage function vmapped over the number of stages. This is a vmap over a single layer instance if body_instance is a single layer, or over a set of layers if body_instance is a set of layers.

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class maxtext.layers.pipeline.Pipeline(config, layers, mesh, remat_policy=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: PipelineBase

Original Pipeline implementation.

Parameters:
  • config (Any)

  • layers (Module)

  • mesh (Mesh)

  • remat_policy (Any)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

init_states(inputs)

Initializes the components of the state: state_io, shift, circ_storage and circ_storage_mover. Assumes the input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed].

Returns a dictionary with the following entries:

  • shift: zeros of shape [num_stages, micro_size, sequence, embed].

  • prev_outputs: same shape as shift; only used when pipeline_delay_activation_forwarding is set to true, else None.

  • state_io: the reshaped inputs, [num_stages, microbatches/stages, micro_size, sequence, embed].

  • circ_storage: zeros of shape [num_stages, microbatches, micro_size, sequence, embed] when needed, else None.

  • circ_storage_mover: zeros of shape [num_stages, micro_size, sequence, embed] when needed, else None.

  • loop_iteration: scalar, initially set to 0.

  • bsw: a pytree with the same structure as the weights, in which the leading num_repeats dimension of each leaf array is replaced by 2; e.g. a leaf of shape [num_repeats, num_stages, mlp, embed] is mapped to [2, num_stages, mlp, embed].

shard_dim_by_stages(x, dim, physical_partition_spec, is_stage_weight=False)

Shards x using the provided partition_spec, but adds the “stage” mesh axis to the existing sharding at the specified dimension.

Parameters:
  • dim (int)

  • physical_partition_spec (PartitionSpec | None)

  • is_stage_weight (bool)

vmap_parallel_gather(weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights)

Use vmap to implement a sharded parallel gather. Parallel gather means each stage has its own weights and gets one slice from them.

Parameters:
  • weights – Per-stage data to be gathered from.

  • repeat_ids – Integer tensor of shape [num_stages], the repeats of the stages.

  • repeat_dim_in_weights – The dimension in weights where repeat_ids are applied. The output will not have this dimension.

  • stages_dim_in_weights – The dimension in weights that represents parallel stages.

Returns:

The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights removed.

vmap_gather(xs, ids, ids_dim)

Use vmap to implement a stage-wise sharded gather.

The stages share the same input, but they have different offsets.

Parameters:
  • xs – Data shared by all stages, to be gathered from.

  • ids – Integer tensor of shape [num_stages], the offsets of the stages.

  • ids_dim – The dimension in xs where ids are applied. In the output, this dimension will be [num_stages], since each stage gets one slice.

Returns:

The per-stage gathered values. The shape is xs.shape but with ids_dim size replaced with [num_stages].
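The difference between this stage-wise gather and the parallel gather of vmap_parallel_gather can be sketched with nested Python lists standing in for arrays. The helper names and the [repeat][stage] indexing order are assumptions chosen for illustration; the real methods operate on sharded arrays through vmap.

```python
def gather_shared(xs, ids):
    """Stage-wise gather: every stage reads the same xs at its own offset."""
    return [xs[i] for i in ids]


def gather_per_stage(weights, repeat_ids):
    """Parallel gather: weights[repeat][stage], each stage reads only its own column."""
    return [weights[r][stage] for stage, r in enumerate(repeat_ids)]


positions = ["pos_mb0", "pos_mb1", "pos_mb2", "pos_mb3"]
print(gather_shared(positions, ids=[2, 1, 0]))

weights = [["r0s0", "r0s1", "r0s2"],
           ["r1s0", "r1s1", "r1s2"]]
print(gather_per_stage(weights, repeat_ids=[1, 1, 0]))
```

In the first case the indexed dimension is replaced by one entry per stage; in the second the repeat dimension disappears, matching the return shapes documented for the two methods.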

get_new_loop_state(output, loop_state)

Updates the various buffers given the output of the most recent iteration:

  • state_io: rotates left/up by 1 (the hole created in the last slot is filled with the most recent pipeline output).

      • Pushing inputs up from the top of state_io into the first stage of shift.

      • Pulling outputs up from the last stage of shift into the bottom of state_io.

  • shift: rotates output (or prev_outputs if using delay) right/down by 1; we imagine the pipeline moves right/down.

  • circ_storage: pushes circ_storage_mover (the output of the previous iteration) into the rotating index of circ_storage.

  • circ_storage_mover: assigned to the rotated output and pushed into circ_storage on the next iteration.

  • prev_outputs: set to the current output.
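The two central rotations can be sketched with plain lists. This is a deliberately simplified model: state_io is treated as one flat queue rather than a per-stage array, and circ_storage, circ_storage_mover and prev_outputs are left out. The helper name and the dictionary layout are assumptions for illustration only.

```python
def new_loop_state(output, loop_state):
    """Advance the buffers by one iteration (hypothetical, simplified helper)."""
    # shift: rotate the stage outputs right by 1, so the output of stage i
    # becomes the input of stage i + 1.
    shift = output[-1:] + output[:-1]
    # state_io, modelled as a flat queue: drop the consumed head slot and fill
    # the hole at the end with the newest last-stage output.
    state_io = loop_state["state_io"][1:] + [output[-1]]
    return {
        "shift": shift,
        "state_io": state_io,
        "loop_iteration": loop_state["loop_iteration"] + 1,
    }


state = {"state_io": ["mb1", "mb2", "mb3"], "loop_iteration": 0}
print(new_loop_state(["out0", "out1", "out2"], state))
```

Slot 0 of the rotated shift holds a wrapped-around value; in this model it is the slot that the first stage overwrites with its own input on the next iteration.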

permute_output_micro_per_stage_dim(output)

Permutes the output microbatches to match the input order.

The pipeline execution introduces a delay (bubble) for each stage. Consequently, the first microbatch (index 0) finishes after a certain number of iterations and lands at a shifted position in the output buffer (state_io). This function calculates the offset (microbatch_0_idx) and permutes the output along the microbatch dimension so that microbatch 0 is at index 0, microbatch 1 at index 1, etc.
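The permutation can be illustrated on a one-dimensional buffer. The sketch assumes the landing index of microbatch 0 equals the number of iterations it needed modulo the buffer length; the helper name and its second argument are hypothetical, and the real method permutes along the microbatch dimension of a larger array.

```python
def permute_output(output, iterations_to_complete_first_microbatch):
    """Rotate the buffer so that microbatch 0 ends up at index 0 (illustrative)."""
    n = len(output)
    # Assumption: microbatch 0 lands at this slot of the output buffer.
    microbatch_0_idx = iterations_to_complete_first_microbatch % n
    permutation = [(i + microbatch_0_idx) % n for i in range(n)]
    return [output[p] for p in permutation]


# Microbatch 0 landed in slot 3 of a 4-slot buffer.
print(permute_output(["mb1", "mb2", "mb3", "mb0"], iterations_to_complete_first_microbatch=3))
```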

get_current_stage_weights(pipeline_weights, loop_iteration, physical_partition_spec=None)

Gets the current weights used for one iteration. Outputs a pytree whose arrays have a leading dimension of stages, e.g. {‘mlp’: {‘wo’: [stages, mlp, embed]}}. Stage 0 will use the 0th index of this pytree, stage 1 the 1st index, etc. For non-circular pipelines, this simply returns all weights, since every weight is used in every iteration. For circular pipelines, however, each stage grabs only the weights corresponding to the current repeat.
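The two cases can be sketched over a nested dictionary, with lists standing in for weight arrays. The helper below takes the per-stage repeat ids directly instead of a loop iteration, and assumes circular leaves are indexed as [repeat][stage]; both choices, like the helper name, are illustrative assumptions.

```python
def current_stage_weights(pipeline_weights, repeat_ids, circular):
    """Select the weights each stage uses in one iteration (hypothetical helper)."""
    if not circular:
        # Non-circular: every weight is used in every iteration.
        return pipeline_weights

    def select(node):
        if isinstance(node, dict):
            return {name: select(child) for name, child in node.items()}
        # Leaf indexed as [repeat][stage]: stage s keeps only its current repeat.
        return [node[r][stage] for stage, r in enumerate(repeat_ids)]

    return select(pipeline_weights)


circular_weights = {"mlp": {"wo": [["r0s0", "r0s1"], ["r1s0", "r1s1"]]}}
print(current_stage_weights(circular_weights, repeat_ids=[1, 0], circular=True))
```

The result keeps the pytree structure and has one entry per stage in each leaf, as the description above requires.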

get_current_repeat_from_stages(weights, loop_iteration, physical_partition_spec=None)

Fetches the weights for the current repeat from the stages.

run_one_iteration(loop_state, pipeline_weights, positions, segment_ids, deterministic, model_mode, decoder_layer_instance, logical_partition_spec=None)

Runs one loop iteration: gets the weights and inputs for each stage, runs the stages in parallel, and updates the loop state.

Parameters:
  • loop_state – Dictionary containing the current state of the pipeline (state_io, shift, etc.)

  • positions – Positional encodings.

  • segment_ids – Segment IDs for packed sequences.

  • deterministic – Boolean indicating if execution should be deterministic (e.g. for dropout).

  • model_mode – Current model mode (train/predict).

  • logical_partition_spec – Logical partition specification for weights.

static get_logical_spec_repeats_removed(full_logical)

Returns a new logical spec with ‘circular_repeats’ removed.
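A minimal sketch of the idea, assuming the logical spec can be modelled as a nested dictionary whose leaves are tuples of logical axis names. The helper name and the axis names other than ‘circular_repeats’ are made up for illustration; the real method works on the module's own spec objects.

```python
def remove_circular_repeats(full_logical):
    """Return a copy of the spec tree without the 'circular_repeats' axis (illustrative)."""
    if isinstance(full_logical, dict):
        return {name: remove_circular_repeats(child) for name, child in full_logical.items()}
    # Leaf: a tuple of logical axis names, one per array dimension.
    return tuple(axis for axis in full_logical if axis != "circular_repeats")


spec = {"mlp": {"wo": ("circular_repeats", "layers", "mlp", "embed")}}
print(remove_circular_repeats(spec))
```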

all_gather_over_fsdp(variables, logical_partition_spec)

Gathers FSDP partitioned variables to reconstruct them fully.

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class maxtext.layers.pipeline.CircularPipeline(config, layers, mesh, remat_policy=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: PipelineBase

Implements a circular pipeline schedule with asynchronous weight prefetching.

Circular pipelining reduces the pipeline “bubble” by interleaving multiple pipeline stages on the same physical devices. To hide the communication overhead of Fully Sharded Data Parallelism (FSDP), this module utilizes a Buffer Sliding Window (BSW) to prefetch and all-gather the weights for the next pipeline repeat while the current repeat is executing.
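The effect on the bubble can be estimated with simple arithmetic. The sketch below assumes every iteration costs the same on every stage, that a stage lags its predecessor by exactly one iteration, and that a run takes num_microbatches * num_repeats iterations of useful work plus num_stages - 1 iterations of fill and drain. The function and its parameter names are illustrative, not part of this module.

```python
from fractions import Fraction


def bubble_fraction(num_stages, num_microbatches, num_repeats=1):
    """Fraction of iterations a stage spends idle under the stated assumptions."""
    bubble = num_stages - 1
    total = num_microbatches * num_repeats + bubble
    return Fraction(bubble, total)


print(bubble_fraction(num_stages=4, num_microbatches=8))                 # non-circular
print(bubble_fraction(num_stages=4, num_microbatches=8, num_repeats=2))  # circular
```

Under these assumptions the idle iterations stay fixed at num_stages - 1 while the useful iterations grow with the number of repeats, so the idle fraction shrinks as repeats are added.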

Parameters:
  • config (Any)

  • layers (Module)

  • mesh (Mesh)

  • remat_policy (Any)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

init_states(inputs)

Initializes the pipeline execution state and communication buffers.

This sets up the memory needed to pass activations between pipeline stages (state_io and shift) and allocates the empty Buffer Sliding Window (BSW) that will hold the gathered FSDP weights.

gather_weights_across_stages_vmap(weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights)

Uses jax.vmap to dynamically slice and gather weights for specific pipeline repeats.

gather_microbatch_inputs_vmap(xs, ids, ids_dim)

Slices out the specific sequence inputs (e.g., positions, segments) for the current microbatch.

advance_circular_buffers(output, loop_state)

Rotates pipeline activations to the next physical device stage.

Uses jax.lax.ppermute to perform cross-device ring communication, shifting the forward activations (state_io and shift) from stage $i$ to stage $i+1$.

realign_output_microbatches(output)

Reorders the output tensor to reverse the circular shifts applied during execution.

Because the pipeline operates circularly, the output microbatches are shifted out of order by the time the final stage is completed. This rolls them back into their original sequential layout.

fetch_active_stage_weights(bsw, loop_iteration, physical_partition_spec=None, is_initializing=None)

Fetches the prefetched weights for the active stages from the Buffer Sliding Window, avoiding mid-iteration FSDP all-gathers.

get_current_weights_from_bsw(bsw, loop_iteration, physical_partition_spec, is_initializing=None)

Pulls the fully gathered parameters for the current repeat from the BSW dual-buffer.

from_all_variables_to_repeat_weights(weights, loop_iteration)

Gathers the weights corresponding to the repeat IDs for the current iteration.

from_repeat_weights_to_bsw(repeat_weights, physical_partition_spec, axes_to_gather=('fsdp', 'fsdp_transpose', 'context', 'expert'), use_shardmap=False)

Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW.

weight_prefetching(weights, physical_partition_spec, loop_iteration)

Triggers asynchronous FSDP-like all-gathers for the next pipeline steps.

By gathering weights for loop_iteration + 1 right now, the network communication can overlap with the compute happening in loop_iteration.
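The ordering that makes this overlap possible can be sketched as a two-slot buffer: the gather for the next step is issued before the compute for the current one. Everything here is a stand-in: gather and compute are ordinary callables supplied by the caller, the two-element list plays the role of the BSW, and plain Python runs the calls one after the other, so the sketch shows only the issue order and does not model asynchronous execution.

```python
def run_with_prefetch(sharded_weights, gather, compute):
    """Two-slot prefetch loop (hypothetical helper); returns (outputs, issue log)."""
    log, outputs = [], []
    # Slot 0 holds the weights in use, slot 1 the weights being prefetched.
    bsw = [gather(sharded_weights[0]), None]
    log.append("gather 0")
    for step in range(len(sharded_weights)):
        if step + 1 < len(sharded_weights):
            # Issue the gather for the next step before computing this one.
            bsw[1] = gather(sharded_weights[step + 1])
            log.append(f"gather {step + 1}")
        outputs.append(compute(bsw[0]))
        log.append(f"compute {step}")
        # Slide the window: the prefetched weights become the current ones.
        bsw = [bsw[1], None]
    return outputs, log


outputs, log = run_with_prefetch(["a", "b", "c"], gather=str.upper, compute=lambda w: w + "!")
print(outputs)
print(log)
```

Because only two slots are ever live, memory for gathered weights stays bounded regardless of the number of steps, which is consistent with the leading dimension of 2 used for the BSW.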

run_one_iteration(loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec)

Executes the forward/backward logic for a single microbatch inside the pipeline.

This acts as the core step function that our jax.lax.scan wrappers call. It routes the active BSW weights, sequences, and position IDs into the layer blocks, and then advances the pipeline communication buffers via advance_circular_buffers.

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
maxtext.layers.pipeline.create_pipeline(config, layers, mesh, remat_policy=None)

Factory function to instantiate the correct Pipeline module based on config.

Parameters:
  • config (Any)

  • layers (Module)

  • mesh (Mesh)

  • remat_policy (Any)

Return type:

PipelineBase