# Copyright 2023-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline layer wrapping a decoder layer(s). Supports circular pipelining."""
import functools
from typing import Any
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jax
import jax.ad_checkpoint
from flax.core import meta
from flax import linen as nn
from flax.linen.spmd import LogicallyPartitioned
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, ShardMode
from maxtext.utils.sharding import (
maybe_shard_with_logical,
maybe_shard_with_name,
create_sharding,
logical_to_mesh_axes,
logical_to_mesh,
)
from maxtext.utils import pipeline_utils
[docs]
class PipelineBase(nn.Module):
"""Base module that implements shared pipelining logic across stages."""
# Run configuration: read throughout for pipeline sizes, shard mode, logical axis rules and feature flags.
config: Config
# The decoder layer(s) run by every stage; passed as the module argument of the vmapped stage function.
layers: nn.Module
# Device mesh used to build the NamedSharding objects and to translate logical axes.
mesh: Mesh
# Optional extra checkpoint policy; consumed by get_pipeline_remat_policy.
remat_policy: Any = None
[docs]
def setup(self):
    """Derive pipeline geometry, axis names and sharding specs from the config.

    Sets the attributes the rest of the module reads: stage and microbatch
    counts, the activation-forwarding delay, whether circular storage is
    needed, the logical axis tuples of the per-stage buffers and their mesh
    specs. Concrete shardings are only built when ``config.shard_mode`` is
    ``ShardMode.EXPLICIT``; otherwise every ``*_sharding`` attribute is None.
    """
    cfg = self.config
    rules = cfg.logical_axis_rules
    explicit = cfg.shard_mode == ShardMode.EXPLICIT

    def _named_sharding_if_explicit(spec):
        # Mesh-level sharding for an already translated spec.
        return NamedSharding(self.mesh, spec) if explicit else None

    def _logical_sharding_if_explicit(logical_axes):
        # Sharding created straight from logical axis names.
        if not explicit:
            return None
        return create_sharding(self.mesh, logical_axes, rules=rules)

    # Geometry. need_circ_storage() reads num_stages and forwarding_delay, so
    # both have to be assigned before it is called.
    self.num_stages = cfg.ici_pipeline_parallelism * cfg.dcn_pipeline_parallelism
    self.forwarding_delay = 2 if cfg.pipeline_delay_activation_forwarding else 1
    self.pipeline_microbatch_size = cfg.micro_batch_size_to_train_on // cfg.num_pipeline_microbatches
    self.microbatches_per_stage = cfg.num_pipeline_microbatches // self.num_stages
    self.use_circ_storage = self.need_circ_storage()

    # Axis names.
    self.batch_axis_name = "activation_batch"
    self.seq_len_axis_name = "activation_length"
    self.spmd_axis_name = "stage" if cfg.shard_mode == ShardMode.AUTO else None
    activation_axes = (self.batch_axis_name, self.seq_len_axis_name, "activation_embed")

    # Per-stage activations: one leading stage axis in front of the activation axes.
    self.stages_in_logical = ("activation_stage",) + activation_axes
    self.stages_in_spec = logical_to_mesh_axes(self.stages_in_logical, self.mesh, rules=rules)
    self.stages_in_sharding = _named_sharding_if_explicit(self.stages_in_spec)

    # state_io: a stage axis followed by an unsharded axis, then the activation axes.
    self.state_io_logical = ("activation_stage", None) + activation_axes
    self.state_io_spec = logical_to_mesh_axes(self.state_io_logical, self.mesh, rules=rules)
    self.state_io_sharding = _named_sharding_if_explicit(self.state_io_spec)

    # Shardings for the microbatched input and for the flattened output.
    self.input_sharding = _logical_sharding_if_explicit((None,) + activation_axes)
    self.output_sharding = _logical_sharding_if_explicit(activation_axes)
[docs]
def need_circ_storage(self):
    """Return whether a separate circular storage buffer is required.

    True only for circular pipelines (more than one repeat) whose microbatch
    count exceeds ``num_stages * forwarding_delay``.
    """
    if not self.config.num_pipeline_repeats > 1:
        return False
    stage_capacity = self.num_stages * self.forwarding_delay
    return self.config.num_pipeline_microbatches > stage_capacity
[docs]
def iterations_to_complete_first_microbatch_one_repeat(self):
    """Return the number of loop iterations microbatch 0 needs to finish one repeat."""
    # One hop per stage boundary; each hop costs `forwarding_delay` iterations.
    stage_hops = self.num_stages - 1
    return self.forwarding_delay * stage_hops
[docs]
def iterations_to_complete_first_microbatch(self):
    """Return the loop iterations until microbatch 0 leaves the last stage of the last repeat."""
    # Every repeat before the last one costs a full pass over all microbatches.
    earlier_repeats = self.config.num_pipeline_repeats - 1
    iterations_before_last_repeat = self.config.num_pipeline_microbatches * earlier_repeats
    return iterations_before_last_repeat + self.iterations_to_complete_first_microbatch_one_repeat()
def _maybe_shard_with_logical(self, inputs, logical_axes):
    """Call ``maybe_shard_with_logical`` with this module's mesh, rules and shard settings."""
    cfg = self.config
    module_settings = {
        "shard_mode": cfg.shard_mode,
        "mesh": self.mesh,
        "rules": cfg.logical_axis_rules,
        "debug_sharding": cfg.debug_sharding,
        # Passed through unchanged; the helper is called directly from this
        # method so the number of wrapper frames stays the same.
        "extra_stack_level": 1,
    }
    return maybe_shard_with_logical(inputs, logical_axes, **module_settings)
def _maybe_shard_with_name(self, inputs, sharding_name):
    """Call ``maybe_shard_with_name`` with this module's shard mode and debug flag."""
    cfg = self.config
    return maybe_shard_with_name(
        inputs,
        sharding_name,
        shard_mode=cfg.shard_mode,
        debug_sharding=cfg.debug_sharding,
    )
[docs]
def get_microbatch_and_repeat_ids(self, loop_iteration):
    """Return ``(microbatch_ids, repeat_ids)`` for every stage at ``loop_iteration``.

    Both results have one entry per stage. Works for circular and
    non-circular pipelines alike.
    """
    # Stage s starts `forwarding_delay * s` iterations after stage 0 because of
    # the pipeline bubble; the count is clamped at zero while a stage is idle.
    stage_lag = self.forwarding_delay * jnp.arange(self.num_stages)
    processed = jnp.maximum(loop_iteration - stage_lag, 0)
    stage_sharding = NamedSharding(self.mesh, P("stage"))
    processed = self._maybe_shard_with_name(processed, stage_sharding)

    num_microbatches = self.config.num_pipeline_microbatches
    microbatch_ids = processed % num_microbatches
    repeat_ids = processed // num_microbatches
    return microbatch_ids, repeat_ids
[docs]
def get_pipeline_remat_policy(self):
    """Return the checkpoint policy applied to pipeline iterations.

    With ``config.remat_policy == "custom"`` the module's own ``remat_policy``
    is returned untouched. Otherwise the result always saves the values named
    "iteration_input" and "decoder_layer_input", combined with
    ``self.remat_policy`` when one is set.
    """
    if self.config.remat_policy == "custom":
        return self.remat_policy

    keep_inputs = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input")
    if self.remat_policy is None:
        return keep_inputs
    return jax.checkpoint_policies.save_from_both_policies(self.remat_policy, keep_inputs)
[docs]
def get_weight_sharding(self, *init_args):
    """Return the logical partition specs of the pipelined layer weights.

    Initializes the module with a fixed PRNG key, reads the partition spec of
    every ``LogicallyPartitioned`` leaf, and returns the specs found under
    ``["params"]["layers"]`` re-rooted as ``{"params": ...}``.
    """
    rng = jax.random.PRNGKey(0)
    init_rngs = dict.fromkeys(("params", "dropout", "aqt"), rng)
    variables = self.init(init_rngs, *init_args)

    def _is_boxed(node):
        return isinstance(node, nn.spmd.LogicallyPartitioned)

    def _spec_of(boxed):
        return boxed.get_partition_spec()

    all_specs = jax.tree.map(_spec_of, variables, is_leaf=_is_boxed)
    # Drop the extra "layers" level so the result lines up with the layer variables.
    return {"params": all_specs["params"]["layers"]}
[docs]
def get_vmap_func_for_init(self):
    """Build the stage-vmapped function that is used only to initialize the weights."""
    initializing = self.is_initializing()

    def _call_stage(stage_module, inputs, segment_ids, positions, deterministic, model_mode):
        return stage_module(inputs, segment_ids, positions, deterministic, model_mode)

    layers_axis_metadata = {
        nn.PARTITION_NAME: "layers",
        # NOTE(review): `(None)` is plain None, not the 1-tuple `(None,)` that other
        # metadata dicts in this file use - kept as-is, confirm which is intended.
        "sub_weight_split_dims_mapping": (None),
        "is_initializing": initializing,
        "x_times": self.num_stages,
    }
    return nn.vmap(
        _call_stage,
        # Inputs, segment ids and positions carry a leading stage axis; the two flags are shared.
        in_axes=(0, 0, 0, None, None),
        spmd_axis_name=self.spmd_axis_name,
        variable_axes={"params": 0, "_overwrite_with_gradient": 0},
        split_rngs={"params": initializing, "dropout": self.config.enable_dropout},
        metadata_params=layers_axis_metadata,
    )
[docs]
def get_main_vmap_func_for_iterations(self):
    """Return the main stage function vmapped over the stage axis.

    The vmapped function strips the "layers" axis metadata from the weights it
    receives and then applies the body instance to them. The body instance may
    be a single layer or a set of layers.
    """

    def _apply_stage(stage_module, weights, inputs, segment_ids, positions, deterministic, model_mode):
        layers_axis_metadata = {
            nn.PARTITION_NAME: "layers",
            "sub_weight_split_dims_mapping": (None,),
            "is_initializing": self.is_initializing(),
            "x_times": self.num_stages,
        }
        unstacked = meta.remove_axis(weights, 0, layers_axis_metadata)
        return stage_module.apply(unstacked, inputs, segment_ids, positions, deterministic, model_mode)

    initializing = self.is_initializing()
    return nn.vmap(
        _apply_stage,
        # Weights, inputs, segment ids and positions carry a leading stage axis; the two flags are shared.
        in_axes=(0, 0, 0, 0, None, None),
        spmd_axis_name=self.spmd_axis_name,
        variable_axes={"params": 0},
        split_rngs={"params": initializing, "dropout": self.config.enable_dropout},
        metadata_params={
            nn.PARTITION_NAME: "layers",
            # NOTE(review): `(None)` is plain None, unlike the 1-tuple `(None,)` passed to
            # meta.remove_axis above - kept as-is, confirm which is intended.
            "sub_weight_split_dims_mapping": (None),
            "is_initializing": initializing,
            "x_times": self.num_stages,
        },
    )
def _run_weight_initialization(
    self, example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
):
    """Run the layers once on example inputs so their weights get created.

    For circular pipelines the init function is additionally vmapped over a
    "circular_repeats" axis and the example inputs are broadcast to match.
    Returns an array of shape
    ``[micro_batch_size_to_train_on, max_target_length, emb_dim]`` built by
    broadcasting the first stage's output.
    """
    cfg = self.config
    num_repeats = cfg.num_pipeline_repeats
    circular = num_repeats > 1

    vmap_func = self.get_vmap_func_for_init()
    if circular:
        vmap_func = nn.vmap(
            vmap_func,
            in_axes=(0, segment_idx, position_idx, None, None),
            variable_axes={"params": 0, "_overwrite_with_gradient": 0, "non_trainable": 0, "hyper_params": 0},
            split_rngs={"params": True, "dropout": cfg.enable_dropout},
            metadata_params={
                nn.PARTITION_NAME: "circular_repeats",
                "sub_weight_split_dims_mapping": (None,),
                "is_initializing": True,
                "x_times": num_repeats,
                "optimizer_dims_mapping": None,
            },
        )
        # Give every example a leading repeat axis to match the extra vmap.
        example_inputs = jax.lax.broadcast(example_inputs, [num_repeats])
        if example_segmentation is not None:
            example_segmentation = jax.lax.broadcast(example_segmentation, [num_repeats])
        if example_position is not None:
            example_position = jax.lax.broadcast(example_position, [num_repeats])

    example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
    stage_outputs = vmap_func(
        self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
    )
    if cfg.scan_layers:
        stage_outputs = stage_outputs[0]
    if circular:
        # Drop the repeat axis added above.
        stage_outputs = stage_outputs[0]

    # Replicate the first stage's output once per microbatch, then flatten to the full batch.
    num_microbatches = cfg.micro_batch_size_to_train_on // self.pipeline_microbatch_size
    broadcasted_stage_outputs = jax.lax.broadcast(stage_outputs[0], [num_microbatches])
    full_batch_shape = [cfg.micro_batch_size_to_train_on, cfg.max_target_length, cfg.emb_dim]
    return jnp.reshape(broadcasted_stage_outputs, full_batch_shape, out_sharding=self.output_sharding)
@staticmethod
def _remove_fsdp_from_physical_partition_spec(pps):
    """Return ``pps`` with the 'fsdp' and 'fsdp_transpose' mesh axes removed.

    Args:
        pps: A ``PartitionSpec``. Any other value is returned unchanged.

    Returns:
        A new ``PartitionSpec`` in which every fsdp axis is dropped. A dimension
        left without any mesh axis becomes ``None`` (replicated), whether it
        was written as a string or as a tuple. ``None`` and
        ``PartitionSpec.UNCONSTRAINED`` entries are kept as they are.

    Raises:
        ValueError: If an entry is not None, UNCONSTRAINED, a string, a list or
            a tuple.
    """
    if not isinstance(pps, P):
        return pps

    fsdp_axes = ("fsdp", "fsdp_transpose")
    new_spec = []
    for axis in pps:
        if axis is None or axis is P.UNCONSTRAINED:
            # Nothing to strip. UNCONSTRAINED is a valid entry (this file builds
            # specs with it), so it must pass through instead of raising.
            new_spec.append(axis)
        elif isinstance(axis, str):
            new_spec.append(None if axis in fsdp_axes else axis)
        elif isinstance(axis, (list, tuple)):
            kept = tuple(a for a in axis if a not in fsdp_axes)
            # An emptied collection means "replicated", same as the string case.
            new_spec.append(kept or None)
        else:
            raise ValueError(f"Unsupported_axis_type: {type(axis)}")
    return P(*new_spec)
[docs]
class Pipeline(PipelineBase):
"""Original Pipeline implementation.

Keeps a dictionary of loop state buffers (state_io, shift, circ_storage,
circ_storage_mover, prev_outputs, loop_iteration) and advances it one
iteration at a time, running all stages in parallel on each iteration.
"""
[docs]
def init_states(self, inputs):
    """Build the initial loop state for the pipeline.

    Assumes ``inputs`` has already been reshaped into microbatches:
    ``[num_micro_batches, micro_batch_size, sequence, embed]``.

    Returns:
        A dictionary with the entries
        * ``state_io``: the inputs reshaped to
          ``[num_stages, microbatches/stages, micro_size, sequence, embed]``. It
          starts out holding all inputs and is overwritten with outputs as the
          pipeline runs.
        * ``shift``: zeros of shape ``[num_stages, micro_size, sequence, embed]``,
          used to pass each stage's output on to the next stage.
        * ``prev_outputs``: zeros shaped like ``shift`` when
          ``pipeline_delay_activation_forwarding`` is set, else None.
        * ``circ_storage``: zeros of shape
          ``[num_stages, microbatches, micro_size, sequence, embed]`` when circular
          storage is in use, else None.
        * ``circ_storage_mover``: the ``shift`` buffer when circular storage is in
          use, else None.
        * ``loop_iteration``: 0.
    """
    dtype = inputs.dtype
    microbatch_shape = inputs.shape[1:]
    per_stage_shape = (self.num_stages,) + microbatch_shape

    # shift rotates the output of each stage into the input of the next one.
    shift = jnp.zeros(per_stage_shape, dtype=dtype)
    shift = self._maybe_shard_with_logical(shift, self.stages_in_logical)

    # prev_outputs only exists when activation forwarding is delayed.
    prev_outputs = None
    if self.config.pipeline_delay_activation_forwarding:
        prev_outputs = jnp.zeros(per_stage_shape, dtype=dtype)
        prev_outputs = self._maybe_shard_with_logical(prev_outputs, self.stages_in_logical)

    # The pipeline_microbatch_size axis is the one sharded by data/fsdp; the
    # microbatch axes are looped over and therefore stay unsharded.
    state_io_shape = (self.num_stages, self.microbatches_per_stage) + microbatch_shape
    state_io = jnp.reshape(inputs, state_io_shape, out_sharding=self.state_io_sharding)
    state_io = self._maybe_shard_with_logical(state_io, self.state_io_logical)

    # circ_storage holds the final stage's outputs until the next repeat uses
    # them. Every stage gets a slot for the whole batch although only stage 0
    # holds real activations - a factor of num_stages more than strictly needed,
    # which is acceptable as long as other axes (e.g. FSDP, TP, DP) shard it
    # further (alternatives: b/347603101). circ_storage_mover pushes microbatches
    # into circ_storage with one iteration of delay and starts as the shift buffer.
    circ_storage = None
    circ_storage_mover = None
    if self.use_circ_storage:
        circ_storage = jnp.zeros(
            (self.num_stages,) + inputs.shape, dtype=dtype, out_sharding=self.state_io_sharding
        )
        circ_storage_mover = shift

    return {
        "state_io": state_io,
        "shift": shift,
        "circ_storage": circ_storage,
        "circ_storage_mover": circ_storage_mover,
        "loop_iteration": 0,
        "prev_outputs": prev_outputs,
    }
[docs]
def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False):
    """Constrain ``x`` so that dimension ``dim`` is sharded over the "stage" mesh axis.

    Args:
        x: Array to constrain.
        dim: Index of the dimension that receives the "stage" axis.
        physical_partition_spec: Existing spec for ``x`` or None. When given,
            fsdp axes are stripped from it first; when None every other
            dimension gets a placeholder.
        is_stage_weight: False when ``x`` still carries the repeat dimension
            that the given spec does not describe; True when the spec already
            lines up with ``x``.

    Returns:
        The result of ``_maybe_shard_with_name`` for ``x`` and the new sharding.
    """
    explicit = self.config.shard_mode == ShardMode.EXPLICIT
    # Explicit mode spells "no mesh axis" as None; other modes leave it to the compiler.
    filler = None if explicit else P.UNCONSTRAINED

    if physical_partition_spec is None:
        axes = [filler] * x.ndim
    else:
        physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
        axes = list(physical_partition_spec)
        if not is_stage_weight:
            # Inflate by one dimension for num_repeats, which the spec lacks.
            axes = [filler] * (dim + 1) + axes[dim:]
    axes[dim] = "stage"
    axes = tuple(axes)

    # The reduced marking is only added when a spec was given for a stage weight.
    if physical_partition_spec and is_stage_weight and explicit:
        reduced_axes = {name for name in ("data", "fsdp") if self.mesh.shape[name] > 1}
        pspec = P(*axes, reduced=reduced_axes)
    else:
        pspec = P(*axes)

    return self._maybe_shard_with_name(x, jax.sharding.NamedSharding(self.mesh, pspec))
[docs]
def vmap_parallel_gather(
    self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights
):
    """Use vmap to implement a sharded parallel gather.

    Parallel gather means each stage has its own weights and takes exactly one
    slice (the one for its current repeat) from them.

    Args:
        weights: Per-stage data to be gathered from.
        physical_partition_spec: Spec forwarded to ``shard_dim_by_stages`` for the
            weights before and after the gather; may be None.
        repeat_ids: Integer tensor of shape [num_stages], the repeat of each stage.
        repeat_dim_in_weights: The dimension in weights where repeat_ids are
            applied. The output does not have this dimension.
        stages_dim_in_weights: The dimension in weights that represents parallel stages.

    Returns:
        The per-stage gathered values: weights.shape with repeat_dim_in_weights
        removed and the stage dimension leading.
    """
    stage_dim_out = 0

    def _pick_repeat(stage_weights, repeat_id):
        one_repeat = jax.lax.dynamic_slice_in_dim(stage_weights, repeat_id, 1, repeat_dim_in_weights)
        return jnp.squeeze(one_repeat, repeat_dim_in_weights)

    repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None)
    weights = self.shard_dim_by_stages(
        weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False
    )
    gather_all_stages = jax.vmap(_pick_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=stage_dim_out)
    gathered = gather_all_stages(weights, repeat_ids)
    return self.shard_dim_by_stages(
        gathered, stage_dim_out, physical_partition_spec=physical_partition_spec, is_stage_weight=True
    )
[docs]
def vmap_gather(self, xs, ids, ids_dim):
    """Use vmap to implement a stage-wise sharded gather.

    All stages read from the same input, each at its own offset.

    Args:
        xs: Data shared by all stages, to be gathered from.
        ids: Integer tensor of shape [num_stages], the offsets of the stages.
        ids_dim: The dimension in xs where ids are applied. In the output this
            dimension has size [num_stages], since each stage gets one slice.

    Returns:
        The per-stage gathered values: xs.shape with the size of ids_dim
        replaced by [num_stages].
    """
    replicated = NamedSharding(self.mesh, P())

    def _take_one(shared, offset):
        # Index with `offset` along ids_dim and take everything along the other dims.
        selector = tuple(offset if axis == ids_dim else slice(None) for axis in range(shared.ndim))
        return shared.at[selector].get(out_sharding=replicated)

    ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
    gathered = jax.vmap(_take_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
    # NOTE(review): the constraint is placed on dim 0 regardless of ids_dim; the
    # callers visible in this file pass ids_dim=0 - confirm for other values.
    return self.shard_dim_by_stages(gathered, 0, physical_partition_spec=None)
[docs]
def get_new_loop_state(self, output, loop_state):
    """Update the loop buffers from the output of the most recent iteration.

    * state_io: the slice at the rotating microbatch index moves left/up by one
      stage and the slot freed on the last stage is filled with ``output``.
    * shift: ``output`` (or ``prev_outputs`` when forwarding is delayed) moved
      right/down by one stage.
    * circ_storage: receives circ_storage_mover (the previous iteration's
      output) at a rotating microbatch index.
    * circ_storage_mover: set to ``output`` and pushed into circ_storage on the
      next iteration.
    * prev_outputs: set to ``output`` when forwarding is delayed, else None.
    * loop_iteration: incremented by one.
    """
    loop_iteration = loop_state["loop_iteration"]
    prior_state_io = loop_state["state_io"]
    prior_circ_storage = loop_state["circ_storage"]
    prior_circ_storage_mover = loop_state["circ_storage_mover"]
    prior_outputs = loop_state["prev_outputs"]

    def _ring_perm(stage_size, step):
        # (source, destination) pairs sending each stage's data `step` stages along the ring.
        return [(src, (src + step) % stage_size) for src in range(stage_size)]

    @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True)
    def _rotate_right(arr):
        # +1 moves right; the last stage wraps around to the first.
        stage_size = jax.lax.axis_size("stage")
        return jax.lax.ppermute(arr, axis_name="stage", perm=_ring_perm(stage_size, 1))

    @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True)
    def _shift_right(arr):
        # Like _rotate_right, but the first stage receives zeros instead of the wrap-around.
        stage_idx = jax.lax.axis_index("stage")
        stage_size = jax.lax.axis_size("stage")
        moved = jax.lax.ppermute(arr, axis_name="stage", perm=_ring_perm(stage_size, 1))
        return jnp.where(stage_idx == 0, jnp.zeros_like(moved), moved)

    # Whether the last stage must hand its output straight to the first stage:
    # * non-circular pipelines: no
    # * circular with #micro = #stages: yes
    # * circular with #micro > #stages (circ_storage): no, it goes to circ storage
    last_feeds_first = self.config.num_pipeline_repeats != 1 and not self.use_circ_storage
    advance = _rotate_right if last_feeds_first else _shift_right

    if self.config.pipeline_delay_activation_forwarding:
        new_shift = advance(prior_outputs)
        new_prev_outputs = output
    else:
        new_shift = advance(output)
        new_prev_outputs = None

    new_circ_storage = None
    new_circ_storage_mover = None
    if self.use_circ_storage:
        # circ_storage_mover still holds the PREVIOUS iteration's output, which
        # should help overlap compute with the transfer.
        incoming = jnp.expand_dims(_rotate_right(prior_circ_storage_mover), 1)
        # Rotating write index, chosen so that microbatch 0 lands in index 0.
        write_idx = (
            loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1
        ) % self.config.num_pipeline_microbatches
        new_circ_storage = jax.lax.dynamic_update_slice_in_dim(prior_circ_storage, incoming, write_idx, axis=1)
        new_circ_storage_mover = output

    # Slot of state_io that is consumed and refilled on this iteration.
    stream_buf_idx = loop_iteration % self.microbatches_per_stage
    stream_slice = prior_state_io[:, stream_buf_idx]

    @jax.shard_map(
        mesh=self.mesh,
        in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()),
        out_specs=self.state_io_spec,
    )
    def _update_state_io(state_in, current_slice, stage_output, slot_idx):
        # Move the current slice one stage to the left (-1), then put the
        # newest output on the last stage.
        stage_size = jax.lax.axis_size("stage")
        stage_idx = jax.lax.axis_index("stage")
        moved = jax.lax.ppermute(current_slice, axis_name="stage", perm=_ring_perm(stage_size, -1))
        refreshed = jnp.where(stage_idx == stage_size - 1, stage_output, moved)
        refreshed = jnp.expand_dims(refreshed, 1)
        return jax.lax.dynamic_update_slice_in_dim(state_in, refreshed, slot_idx, axis=1)

    new_state_io = _update_state_io(prior_state_io, stream_slice, output, stream_buf_idx)

    return {
        "state_io": new_state_io,
        "shift": new_shift,
        "circ_storage": new_circ_storage,
        "circ_storage_mover": new_circ_storage_mover,
        "loop_iteration": loop_iteration + 1,
        "prev_outputs": new_prev_outputs,
    }
[docs]
def permute_output_micro_per_stage_dim(self, output):
    """Reorder the output microbatches so they match the input order.

    Because of the pipeline bubble, microbatch 0 only finishes after a number
    of iterations and therefore lands at a shifted slot of the output buffer
    (``state_io``). This computes that slot and permutes axis 1 of ``output``
    so that microbatch 0 sits at index 0, microbatch 1 at index 1, and so on.
    """
    per_stage = self.microbatches_per_stage
    # Slot in which microbatch 0 was written.
    first_slot = self.iterations_to_complete_first_microbatch() % per_stage
    order = (np.arange(per_stage) + first_slot) % per_stage
    return output[:, order]
[docs]
def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None):
    """Return the weights used by the stages on this iteration.

    The result is a pytree whose arrays have a leading stage dimension, e.g.
    {'mlp': 'wo': [stages, mlp, embed]}; stage 0 uses index 0, stage 1 index 1,
    and so on. Non-circular pipelines use every weight on every iteration, so
    the weights are returned as they are. Circular pipelines pick, per stage,
    only the weights of the current repeat.
    """
    if not self.config.num_pipeline_repeats > 1:
        return pipeline_weights
    return self.get_current_repeat_from_stages(
        pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec
    )
[docs]
def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None):
    """Fetch, for every stage, the weights of the repeat it is currently on.

    Args:
        weights: Weight pytree that still carries the "circular_repeats" axis.
        loop_iteration: Current loop iteration; determines each stage's repeat.
        physical_partition_spec: Optional pytree of specs matching ``weights``.

    Returns:
        The weight pytree with the repeat dimension gathered away.
    """
    _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)

    repeat_axis_metadata = {
        nn.PARTITION_NAME: "circular_repeats",
        "sub_weight_split_dims_mapping": (None,),
        "is_initializing": self.is_initializing(),
        "x_times": self.config.num_pipeline_repeats,
        "optimizer_dims_mapping": None,
    }
    # Strip the circular axis metadata: after the gather each stage keeps only
    # one circular entry, so the axis no longer exists for the main vmap.
    unboxed = self._remove_logically_partition(meta.remove_axis(weights, 0, repeat_axis_metadata))

    def _gather_for_stages(leaf, spec=None):
        return self.vmap_parallel_gather(
            leaf, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
        )

    # Map over the specs as well only when they were provided.
    trees = (unboxed,) if physical_partition_spec is None else (unboxed, physical_partition_spec)
    return jax.tree.map(_gather_for_stages, *trees)
[docs]
def run_one_iteration(
    self,
    loop_state,
    pipeline_weights,
    positions,
    segment_ids,
    deterministic,
    model_mode,
    decoder_layer_instance,
    logical_partition_spec=None,
):
    """Run one loop iteration: fetch weights and inputs for every stage, run the stages in
    parallel, and return the updated loop state.

    Args:
        loop_state: Dictionary containing the current state of the pipeline (state_io, shift, etc.)
        pipeline_weights: Weights of the whole pipeline, from which each stage's weights are taken.
        positions: Positional encodings, or None.
        segment_ids: Segment IDs for packed sequences, or None.
        deterministic: Boolean indicating if execution should be deterministic (e.g. for dropout).
        model_mode: Current model mode (train/predict).
        decoder_layer_instance: Module instance handed to the vmapped stage function.
        logical_partition_spec: Logical partition specification for weights.

    Returns:
        The loop state for the next iteration.
    """
    loop_iteration = loop_state["loop_iteration"]
    state_io = loop_state["state_io"]
    shift = loop_state["shift"]
    circ_storage = loop_state["circ_storage"]

    microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration)
    physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules)

    stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift)
    # Named so that the pipeline remat policy can choose to save it.
    stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input")

    def _slice_for_stages(per_microbatch):
        # Each stage reads the entry of the microbatch it is currently processing.
        if per_microbatch is None:
            return None
        return self.vmap_gather(per_microbatch, microbatch_ids, 0)

    stages_positions = _slice_for_stages(positions)
    stages_segment_ids = _slice_for_stages(segment_ids)

    vmap_func = self.get_main_vmap_func_for_iterations()

    if self.config.num_pipeline_repeats > 1:
        _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)

        def _select_current_repeat(variables, physical_partition_spec=None):
            # Strip the circular axis and keep, per stage, only the current repeat's weights.
            repeat_axis_metadata = {
                nn.PARTITION_NAME: "circular_repeats",
                "sub_weight_split_dims_mapping": (None,),
                "is_initializing": self.is_initializing(),
                "x_times": self.config.num_pipeline_repeats,
                "optimizer_dims_mapping": None,
            }
            variables = meta.remove_axis(variables, 0, repeat_axis_metadata)
            variables = self._remove_logically_partition(variables)

            def _gather_for_stages(leaf, spec=None):
                return self.vmap_parallel_gather(
                    leaf, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
                )

            trees = (variables,) if physical_partition_spec is None else (variables, physical_partition_spec)
            return jax.tree.map(_gather_for_stages, *trees)

        vmap_func = nn.map_variables(
            vmap_func,
            mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"],
            mutable=True,
            trans_in_fn=functools.partial(_select_current_repeat, physical_partition_spec=physical_partition_spec),
        )

    stage_weights = self.get_current_stage_weights(
        pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec
    )
    stages_output = vmap_func(
        decoder_layer_instance,
        stage_weights,
        stages_inputs,
        stages_segment_ids,
        stages_positions,
        deterministic,
        model_mode,
    )
    if self.config.scan_layers:
        stages_output = stages_output[0]

    return self.get_new_loop_state(stages_output, loop_state)
[docs]
@staticmethod
def get_logical_spec_repeats_removed(full_logical):
    """Return a copy of the logical specs without any 'circular_repeats' entry.

    ``full_logical`` is a pytree of partition specs, or None (returned as None).
    """
    if full_logical is None:
        return None

    def _drop_repeats(spec):
        kept = (entry for entry in spec if entry != "circular_repeats")
        return jax.sharding.PartitionSpec(*kept)

    return jax.tree.map(_drop_repeats, full_logical)
@staticmethod
def _remove_logically_partition(weights):
    """Replace every LogicallyPartitioned box in ``weights`` by the value it wraps."""

    def _is_boxed(node):
        return isinstance(node, LogicallyPartitioned)

    def _unbox(node):
        # Nodes that are not boxed are passed through untouched.
        return node.value if _is_boxed(node) else node

    return jax.tree.map(_unbox, weights, is_leaf=_is_boxed)
[docs]
def all_gather_over_fsdp(self, variables, logical_partition_spec):
    """Re-shard the variables with the fsdp axes removed from their specs.

    The logical specs are translated to mesh specs, 'fsdp' and
    'fsdp_transpose' are stripped from each, and every variable is constrained
    to the resulting sharding.
    """
    physical_specs = logical_to_mesh(logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules)
    specs_without_fsdp = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_specs)

    def _constrain(weight, spec):
        return self._maybe_shard_with_name(weight, NamedSharding(self.mesh, spec))

    return jax.tree.map(_constrain, variables, specs_without_fsdp)
@nn.compact
def __call__(
self,
inputs: jnp.ndarray,
segment_ids: jnp.ndarray,
positions: jnp.ndarray,
deterministic: bool,
model_mode=MODEL_MODE_TRAIN,
logical_partition_spec=None, # Pytree of sharding specifications of the weights (aka self.layers.variables)
) -> jnp.ndarray:
"""The main method that maps the series of decoder layer inputs to final layer outputs.
Has the same signature of a single decoder layer, and expects the same shapes, e.g. the inputs should have shape
[global_batch], and internally this will be reshapped into microbatches.
"""
inputs = inputs.reshape(
(
self.config.num_pipeline_microbatches,
self.pipeline_microbatch_size,
self.config.max_target_length,
self.config.emb_dim,
),
out_sharding=self.input_sharding,
)
example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages])
ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None))
if positions is not None:
positions = self._maybe_shard_with_name(positions, ag_sharding)
positions = positions.reshape(
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
)
example_position = jax.lax.broadcast(positions[0], [self.num_stages])
position_idx = 0
else:
example_position = None
position_idx = None
if segment_ids is not None:
segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding)
segment_ids = segment_ids.reshape(
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
)
example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages])
segment_idx = 0
else:
example_segmentation = None
segment_idx = None
loop_state = self.init_states(inputs)
# Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats)
# compute to perform
# Each iteration is vmapped by num_stages, so the number of iterations should be
# num_micro * num_stages * repeats / num_stages = num_micro * repeats
# However due to the pipeline bubble some iterations process less than num_stages microbatches. It takes
# num_micro * repeat iterations for the last microbatch to start the final repeat, then an additional
# num_stages - 1 to finish the final repeat.
# Thus the total iterations is num_micro * repeat + num_stages - 1, & we may consider the num_stages - 1 as bubble.
# The bubble doubles when we use forwarding delay.
bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats
total_iterations = real_iterations + bubble_iterations
if self.is_initializing():
return self._run_weight_initialization(
example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
)
if self.config.pipeline_fsdp_ag_once:
variables = self._remove_logically_partition(self.layers.variables)
all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec)
else:
all_pipeline_weights = self.layers.variables
logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec)
def run_iteration_scannable(model, loop_state, xs):
# flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we
# explicitly wrap the run_one_iteration in this method - the 1st argument model (`self`) is a nn.module instance.
return (
model.run_one_iteration(
loop_state,
all_pipeline_weights,
positions,
segment_ids,
deterministic,
model_mode,
model.layers,
logical_partition_spec=logical_partition_spec,
),
None,
)
if self.config.set_remat_policy_on_pipeline_iterations:
run_iteration_scannable = nn.remat(
run_iteration_scannable,
prevent_cse=not self.config.scan_pipeline_iterations, # prevent_cse not used with scan
policy=self.get_pipeline_remat_policy(),
)
if self.config.scan_pipeline_iterations:
variable_carry = []
variable_broadcast = [
"params",
"_overwrite_with_gradient",
] # All loop iterations need the weights for the full pipeline.
if self.is_mutable_collection("non_trainable"):
variable_carry.append("non_trainable")
else:
variable_broadcast.append("non_trainable")
run_all_iterations_scanned = nn.scan(
run_iteration_scannable,
variable_axes={"summaries": 0, "aux_loss": 0, "intermediates": 0, "hyper_params": 0},
variable_broadcast=variable_broadcast,
variable_carry=variable_carry,
# Dropout/aqt keys will be split for each iteration.
split_rngs={"random": True},
length=total_iterations,
)
loop_state, _ = run_all_iterations_scanned(self, loop_state, None)
else:
for _ in range(total_iterations):
loop_state, _ = run_iteration_scannable(self, loop_state, None)
# The final output is located in the input/output array, however the output microbatches may be permuted relative to
# the input
final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"])
# reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed]
final_output = jnp.reshape(
final_output,
(self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim),
out_sharding=self.output_sharding,
)
return final_output
[docs]
class CircularPipeline(PipelineBase):
  """Implements a circular pipeline schedule with asynchronous weight prefetching.

  Circular pipelining reduces the pipeline "bubble" by interleaving multiple pipeline
  stages on the same physical devices. To hide the communication overhead of Fully
  Sharded Data Parallelism (FSDP), this module utilizes a Buffer Sliding Window (BSW)
  to prefetch and all-gather the weights for the *next* pipeline repeat while the
  *current* repeat is executing.
  """
[docs]
def init_states(self, inputs):
  """Creates the initial loop state and the empty Buffer Sliding Window (BSW).

  Args:
    inputs: Microbatched activations. The leading axis is split into
      `(num_stages, microbatches_per_stage)` to build `state_io`.

  Returns:
    A `(loop_state, bsw)` tuple. `loop_state` is a dict holding `state_io`, `shift`,
    `circ_storage`, `circ_storage_mover`, `loop_iteration` (starting at 0) and
    `prev_outputs`. `bsw` is `None` while the module is initializing; otherwise it is a
    pair of zero-filled pytrees, each leaf shaped like index 0 of the matching layer variable.
  """
  trailing_dims = inputs.shape[1:]
  stage_buffer_shape = (self.num_stages,) + trailing_dims

  def _zero_stage_buffer():
    # One zero-filled slot per stage, sharded with the per-stage logical axes.
    zeros = jnp.zeros(stage_buffer_shape, dtype=inputs.dtype)
    return self._maybe_shard_with_logical(zeros, self.stages_in_logical)

  shift = _zero_stage_buffer()
  # prev_outputs is only allocated when activation forwarding is delayed.
  prev_outputs = _zero_stage_buffer() if self.config.pipeline_delay_activation_forwarding else None

  state_io = jnp.reshape(
      inputs,
      (self.num_stages, self.microbatches_per_stage) + trailing_dims,
      out_sharding=self.state_io_sharding,
  )
  state_io = self._maybe_shard_with_logical(state_io, self.state_io_logical)

  circ_storage = None
  circ_storage_mover = None
  if self.use_circ_storage:
    circ_storage = jnp.zeros(
        (self.num_stages,) + inputs.shape, dtype=inputs.dtype, out_sharding=self.state_io_sharding
    )
    circ_storage_mover = shift

  bsw = None
  if not self.is_initializing():
    layer_variables = pipeline_utils.remove_logically_partition(self.layers.variables)
    # The sliding window needs two buffers: one for the current weights and one for the next.
    bsw = tuple(jax.tree.map(lambda leaf: jnp.zeros_like(leaf[0]), layer_variables) for _ in range(2))

  loop_state = {
      "state_io": state_io,
      "shift": shift,
      "circ_storage": circ_storage,
      "circ_storage_mover": circ_storage_mover,
      "loop_iteration": 0,
      "prev_outputs": prev_outputs,
  }
  return loop_state, bsw
[docs]
def gather_weights_across_stages_vmap(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights):
  """Picks, for every stage, the weight slice that belongs to that stage's repeat.

  Args:
    weights: Array carrying both a stage axis and a repeat axis.
    repeat_ids: One repeat index per stage (mapped over its axis 0).
    repeat_dim_in_weights: Repeat axis as seen *inside* the vmap, i.e. after the stage
      axis has been mapped away.
    stages_dim_in_weights: Stage axis of `weights` that the vmap maps over.

  Returns:
    The selected slices stacked along a new leading stage axis; the repeat axis is removed.
  """

  def _pick_repeat(stage_weights, repeat_id):
    # Take a length-1 window at the (possibly traced) repeat index, then drop that axis.
    one_repeat = jax.lax.dynamic_slice_in_dim(stage_weights, repeat_id, 1, repeat_dim_in_weights)
    return jnp.squeeze(one_repeat, repeat_dim_in_weights)

  per_stage_gather = jax.vmap(_pick_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=0)
  return per_stage_gather(weights, repeat_ids)
[docs]
def advance_circular_buffers(self, output, loop_state):
  """Rotates pipeline activations to the next physical device stage.

  Uses `jax.lax.ppermute` over the "stage" mesh axis to perform ring communication, moving
  each stage's data from stage i to stage i+1, and rebuilds every buffer of the loop state.

  Args:
    output: Per-stage outputs of the current iteration (consumed with `self.stages_in_spec`).
    loop_state: Dict with keys "state_io", "shift", "circ_storage", "circ_storage_mover",
      "loop_iteration" and "prev_outputs".

  Returns:
    A new loop-state dict with the same keys and "loop_iteration" incremented by one.
  """
  old_state_io = loop_state["state_io"]
  old_circ_storage = loop_state["circ_storage"]
  old_circ_storage_mover = loop_state["circ_storage_mover"]
  loop_iteration = loop_state["loop_iteration"]

  # Full ring rotation: stage i sends to stage (i + 1) % stage_size, so the last stage wraps to stage 0.
  @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True)
  def _rotate_right(arr):
    stage_size = jax.lax.axis_size("stage")
    perm = [(i, (i + 1) % stage_size) for i in range(stage_size)]
    return jax.lax.ppermute(arr, axis_name="stage", perm=perm)

  # Same ring rotation, but stage 0 receives zeros instead of the last stage's data.
  @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True)
  def _shift_right(arr):
    stage_idx = jax.lax.axis_index("stage")
    stage_size = jax.lax.axis_size("stage")
    perm = [(i, (i + 1) % stage_size) for i in range(stage_size)]
    arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm)
    return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr)

  def _update_shift(output_in):
    # With a single repeat, or when circ_storage is in use, the last stage's output is not
    # wrapped around to stage 0 through `shift`; otherwise it is.
    if self.config.num_pipeline_repeats == 1 or self.use_circ_storage:
      return _shift_right(output_in)
    else:
      return _rotate_right(output_in)

  new_shift = _update_shift(output)
  # NOTE(review): prev_outputs is always reset to None here, although init_states allocates it
  # when config.pipeline_delay_activation_forwarding is set - confirm that delayed forwarding
  # is intentionally unsupported on this path.
  new_prev_outputs = None
  if self.use_circ_storage:

    def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in):
      # Rotate the previous iteration's outputs and write them into one microbatch slot
      # (axis 1) of circ_storage; the slot index wraps modulo num_pipeline_microbatches.
      rotated = _rotate_right(circ_storage_mover_in)
      rotated = jnp.expand_dims(rotated, 1)
      offset = (
          loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1
      ) % self.config.num_pipeline_microbatches
      return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1)

    new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage)
    # The current outputs become the mover, to be stored on the next call.
    new_circ_storage_mover = output
  else:
    new_circ_storage = None
    new_circ_storage_mover = None

  # Index (along axis 1 of state_io) of the slot that is streamed during this iteration.
  stream_buf_idx = loop_iteration % self.microbatches_per_stage
  stream_slice = old_state_io[:, stream_buf_idx]

  # The two helpers below use "stage" collectives, so they are only called inside the
  # shard_map-wrapped _update_state_io.
  def _rotate_left(arr, stage_size):
    perm = [(i, (i - 1) % stage_size) for i in range(stage_size)]
    return jax.lax.ppermute(arr, axis_name="stage", perm=perm)

  def _shift_left(arr, stage_size, output):
    # Rotate towards lower stage indices; the last stage takes the fresh `output` instead
    # of the value wrapped around from stage 0.
    stage_idx = jax.lax.axis_index("stage")
    arr = _rotate_left(arr, stage_size)
    return jnp.where(stage_idx == stage_size - 1, output, arr)

  @jax.shard_map(
      mesh=self.mesh,
      in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()),
      out_specs=self.state_io_spec,
      check_vma=True,
  )
  def _update_state_io(state_in, stream_slice, output, stream_buf_idx):
    # Write the left-shifted slice back into the same slot of state_io.
    stage_size = jax.lax.axis_size("stage")
    stream_slice = _shift_left(stream_slice, stage_size, output)
    stream_slice = jnp.expand_dims(stream_slice, 1)
    return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1)

  new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx)
  new_loop_state = {
      "state_io": new_state,
      "shift": new_shift,
      "circ_storage": new_circ_storage,
      "circ_storage_mover": new_circ_storage_mover,
      "loop_iteration": loop_iteration + 1,
      "prev_outputs": new_prev_outputs,
  }
  return new_loop_state
[docs]
def realign_output_microbatches(self, output):
  """Undoes the circular offset of the output microbatches.

  The microbatch axis (axis 1) of `output` is rolled back by the slot index at which
  microbatch 0 finished, so that microbatch 0 ends up in slot 0 again, and the result is
  re-sharded with the `state_io` logical axes.

  Args:
    output: The final `state_io` buffer of the pipeline loop.

  Returns:
    `output` with its microbatch axis rolled back into sequential order.
  """
  first_microbatch_slot = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage
  realigned = jnp.roll(output, shift=-first_microbatch_slot, axis=1)
  return self._maybe_shard_with_logical(realigned, self.state_io_logical)
[docs]
def fetch_active_stage_weights(self, bsw, loop_iteration, physical_partition_spec=None, is_initializing=None):
  """Returns the weights each stage should use at `loop_iteration`, read from the BSW.

  Thin wrapper that forwards every argument unchanged to `get_current_weights_from_bsw`;
  the weights come from the already prefetched Buffer Sliding Window rather than from a
  fresh gather.
  """
  return self.get_current_weights_from_bsw(
      bsw,
      loop_iteration,
      physical_partition_spec=physical_partition_spec,
      is_initializing=is_initializing,
  )
[docs]
def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec, is_initializing=None):
  """Selects, per stage, which of the two BSW buffers holds the weights to use now.

  Args:
    bsw: Pair `(buffer_0, buffer_1)` of weight pytrees.
    loop_iteration: Current pipeline loop iteration.
    physical_partition_spec: Pytree of partition specs; every entry is passed through
      `_remove_fsdp_from_physical_partition_spec` to obtain the specs used for the BSW.
    is_initializing: Optional override; defaults to `self.is_initializing()`.

  Returns:
    The selected weights with axis 0 metadata removed via `meta.remove_axis`.
  """
  if is_initializing is None:
    is_initializing = self.is_initializing()

  gathered_specs = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
  _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
  stage0_repeat_id = jnp.maximum(loop_iteration, 0) // self.config.num_pipeline_microbatches

  def _pick_buffer_per_stage(buffers, repeat_id):
    # Stages differ in which buffer they need: a stage on the same repeat as stage 0 reads
    # the newer buffer (index 1), every other stage reads buffer 0.
    on_stage0_repeat = repeat_id[0] == stage0_repeat_id
    return jax.tree.map(
        lambda older, newer: jax.lax.select(on_stage0_repeat, newer, older), buffers[0], buffers[1]
    )

  select_weights_from_bsw = jax.shard_map(
      _pick_buffer_per_stage,
      mesh=self.mesh,
      in_specs=((gathered_specs, gathered_specs), P("stage")),
      out_specs=gathered_specs,
      check_vma=True,
  )
  selected = select_weights_from_bsw(bsw, repeat_ids)

  return meta.remove_axis(
      selected,
      0,
      {
          nn.PARTITION_NAME: "circular_repeats",
          "sub_weight_split_dims_mapping": (None,),
          "is_initializing": is_initializing,
          "x_times": self.config.num_pipeline_repeats,
          "optimizer_dims_mapping": None,
      },
  )
[docs]
def from_all_variables_to_repeat_weights(self, weights, loop_iteration):
  """Extracts, for every stage, the weights of the repeat it runs at `loop_iteration`.

  The logical-partition wrappers are stripped from `weights`, each leaf is gathered with
  `gather_weights_across_stages_vmap` (repeat axis 0, stage axis 1), and the axis 0
  metadata is then removed with `meta.remove_axis`.
  """
  _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
  gather_for_stages = functools.partial(
      self.gather_weights_across_stages_vmap,
      repeat_ids=repeat_ids,
      repeat_dim_in_weights=0,
      stages_dim_in_weights=1,
  )

  unboxed = pipeline_utils.remove_logically_partition(weights)
  gathered = jax.tree.map(gather_for_stages, unboxed)

  return meta.remove_axis(
      gathered,
      0,
      {
          nn.PARTITION_NAME: "circular_repeats",
          "sub_weight_split_dims_mapping": (None,),
          "is_initializing": self.is_initializing(),
          "x_times": self.config.num_pipeline_repeats,
          "optimizer_dims_mapping": None,
      },
  )
[docs]
def from_repeat_weights_to_bsw(
    self,
    repeat_weights,
    physical_partition_spec,
    axes_to_gather=("fsdp", "fsdp_transpose", "context", "expert"),
    # TODO (chengnuojin) set use_shardmap=true after JAX >= 10.0.0 and use all_gather(..., to='invarying')
    use_shardmap=False,  # using shardmap produces additional reduce-scatter in backward pass
):
  """Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW.

  Args:
    repeat_weights: Pytree of per-repeat weights (as produced by
      `from_all_variables_to_repeat_weights`).
    physical_partition_spec: Pytree of `PartitionSpec`s matching the weights.
    axes_to_gather: Mesh axis names to all-gather over. Only used when `use_shardmap=True`.
    use_shardmap: If True, gather explicitly with `jax.lax.all_gather` inside a `shard_map`;
      if False (default), only apply the target sharding as a constraint/hint.

  Returns:
    Pytree of weights laid out according to the BSW partition specs.
  """
  # Target specs for the BSW: the physical specs transformed by
  # derive_stage_weight_partition_specs with these axes (presumably they are stripped so the
  # weights end up unsharded along them - confirm against pipeline_utils).
  # NOTE(review): "expert" is in the default axes_to_gather but not in axes_to_remove - confirm intended.
  axes_to_remove = ["fsdp", "fsdp_transpose", "context"]
  bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove)

  def _from_repeat_weights_to_bsw_shardmap(
      repeat_weights,
      physical_partition_spec,
      axes_to_gather,
  ):
    # Input specs for the repeat weights: the physical specs without their first entry.
    repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
    # Dynamically gather the index pytrees for all specified axes
    axis_indices_dict = {
        axis: pipeline_utils.get_mesh_axis_dim_indices(physical_partition_spec, axis) for axis in axes_to_gather
    }
    axis_names = list(axis_indices_dict.keys())
    axis_pytrees = list(axis_indices_dict.values())

    def should_skip_gather(axis_name, path_keys):
      """Defines specific rule-based exceptions for gathering certain axes."""
      # Leaves under "MoeBlock_0" are never gathered over the "expert" axis.
      if axis_name == "expert" and "MoeBlock_0" in path_keys:
        return True
      # Add more exclusion rules for other axes here if needed in the future
      return False

    # Gathers every weight leaf over the requested mesh axes inside a single shard_map.
    @jax.shard_map(
        mesh=self.mesh,
        in_specs=(repeat_weights_pps, None),  # 'None' covers the entire axis_pytrees list
        out_specs=bsw_pps,
        check_vma=False,
    )
    def _shard_map_gather_weights(sharded_weights, indices_pytrees_list):
      # Gathers one leaf along each requested mesh axis in turn; `indices` holds, per mesh
      # axis, the tensor dimension sharded over it (a negative value means "not sharded").
      def _gather_tensor_along_axes(path, x, *indices):
        path_keys = [getattr(p, "key", str(p)) for p in path]
        # Iterate through the provided axes and their corresponding indices
        for axis_name, axis_idx in zip(axis_names, indices):
          if axis_idx >= 0 and not should_skip_gather(axis_name, path_keys):
            # axis_idx - 1: presumably the indices refer to the full spec, whose first
            # entry was dropped for these per-repeat weights - TODO confirm.
            x = jax.lax.all_gather(x, axis_name=axis_name, axis=axis_idx - 1, tiled=True)
        return x

      return jax.tree_util.tree_map_with_path(_gather_tensor_along_axes, sharded_weights, *indices_pytrees_list)

    return _shard_map_gather_weights(repeat_weights, axis_pytrees)

  def _from_repeat_weights_to_bsw_hint(
      repeat_weights,
  ):
    # Rather than gathering explicitly, tag every weight with its BSW sharding via
    # maybe_shard_with_name.
    def _apply_sharding_hint(weight, pspec):
      sharding_name = NamedSharding(self.mesh, pspec)
      return maybe_shard_with_name(
          weight,
          sharding_name,
          shard_mode=self.config.shard_mode,
          debug_sharding=self.config.debug_sharding,
          extra_stack_level=0,
      )

    return jax.tree.map(_apply_sharding_hint, repeat_weights, bsw_pps)

  if use_shardmap:
    return _from_repeat_weights_to_bsw_shardmap(repeat_weights, physical_partition_spec, axes_to_gather=axes_to_gather)
  return _from_repeat_weights_to_bsw_hint(repeat_weights)
[docs]
def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
  """Builds the BSW weights needed by the iteration after `loop_iteration`.

  The per-stage repeat weights for `loop_iteration + 1` are extracted and handed to
  `from_repeat_weights_to_bsw`. Requesting them one iteration ahead is intended to let the
  gather communication overlap with the compute of the current iteration.
  """
  next_iteration = loop_iteration + 1
  upcoming_repeat_weights = self.from_all_variables_to_repeat_weights(weights, next_iteration)
  return self.from_repeat_weights_to_bsw(upcoming_repeat_weights, physical_partition_spec)
[docs]
def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec):
  """Runs one pipeline iteration across all stages and advances the loop state.

  The per-stage inputs, positions and segment ids for this iteration are assembled, the
  active weights are read from the BSW, the layers are applied through the vmapped stage
  function, and the communication buffers are then advanced with
  `advance_circular_buffers`.

  Args:
    loop_state: Dict produced by `init_states` / `advance_circular_buffers`.
    bsw: Buffer Sliding Window pair holding the prefetched weights.
    positions: Microbatched positions, or None.
    segment_ids: Microbatched segment ids, or None.
    deterministic: Forwarded to the layers.
    model_mode: Forwarded to the layers.
    logical_partition_spec: Logical partition specs of the weights; converted to mesh axes here.

  Returns:
    The loop state for the next iteration.
  """
  loop_iteration = loop_state["loop_iteration"]
  microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration)
  physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules)

  stages_inputs = self.get_iteration_inputs(
      loop_iteration, loop_state["state_io"], loop_state["circ_storage"], loop_state["shift"]
  )
  # Name the inputs so that remat policies can refer to them.
  stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input")

  stages_positions = None
  if positions is not None:
    stages_positions = self.gather_microbatch_inputs_vmap(positions, microbatch_ids, 0)
  stages_segment_ids = None
  if segment_ids is not None:
    stages_segment_ids = self.gather_microbatch_inputs_vmap(segment_ids, microbatch_ids, 0)

  vmap_func = self.get_main_vmap_func_for_iterations()
  stage_weights = self.fetch_active_stage_weights(
      bsw,
      loop_iteration,
      physical_partition_spec=physical_partition_spec,
      is_initializing=self.is_initializing(),
  )
  stages_output = vmap_func(
      self.layers, stage_weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode
  )
  if self.config.scan_layers:
    # With scanned layers only the first element of the result is the activations.
    stages_output = stages_output[0]

  return self.advance_circular_buffers(stages_output, loop_state)
@nn.compact
def __call__(
    self,
    inputs: jnp.ndarray,
    segment_ids: jnp.ndarray,
    positions: jnp.ndarray,
    deterministic: bool,
    model_mode=MODEL_MODE_TRAIN,
    logical_partition_spec=None,
) -> jnp.ndarray:
  """Entry point for the Pipeline Module. Sets up microbatch schedules and executes scans.

  Args:
    inputs: Activations that reshape to
      `(num_pipeline_microbatches, pipeline_microbatch_size, max_target_length, emb_dim)`.
    segment_ids: Optional segment ids (may be None); reshaped to the microbatched layout.
    positions: Optional positions (may be None); reshaped to the microbatched layout.
    deterministic: Forwarded to the wrapped layers.
    model_mode: Forwarded to the wrapped layers.
    logical_partition_spec: Logical partition specs of the pipeline weights.

  Returns:
    The pipeline output reshaped to `(micro_batch_size_to_train_on, max_target_length, emb_dim)`.
    While the module is initializing, the result of `_run_weight_initialization` is returned
    instead.
  """
  microbatched_shape = (
      self.config.num_pipeline_microbatches,
      self.pipeline_microbatch_size,
      self.config.max_target_length,
  )
  inputs = inputs.reshape(microbatched_shape + (self.config.emb_dim,), out_sharding=self.input_sharding)
  example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages])
  # Fully replicated sharding applied to positions / segment_ids before microbatching.
  ag_sharding = NamedSharding(self.mesh, P(None, None))

  def _microbatch_token_input(token_input):
    """Returns (microbatched input, per-stage example, index) for an optional token-level input."""
    if token_input is None:
      return None, None, None
    token_input = self._maybe_shard_with_name(token_input, ag_sharding)
    token_input = token_input.reshape(microbatched_shape)
    return token_input, jax.lax.broadcast(token_input[0], [self.num_stages]), 0

  positions, example_position, position_idx = _microbatch_token_input(positions)
  segment_ids, example_segmentation, segment_idx = _microbatch_token_input(segment_ids)

  loop_state, bsw = self.init_states(inputs)
  physical_partition_spec = logical_to_mesh(
      logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
  )

  if self.is_initializing():
    return self._run_weight_initialization(
        example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
    )

  logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec)
  # Extra iterations needed to drain the pipeline; doubled when activation forwarding is delayed.
  bubble_iterations = self.forwarding_delay * (self.num_stages - 1)

  # NOTE(review): a per-iteration `run_iteration_scannable` wrapper (optionally nn.remat-ed under
  # config.set_remat_policy_on_pipeline_iterations) used to be built here but was never called,
  # so it has been removed. The remat policy below is passed unconditionally - confirm intended.
  remat_policy = self.get_pipeline_remat_policy()

  # Stage-function factory shared by the real run (one repeat) and the bubble run.
  base_scannable = functools.partial(
      pipeline_utils.create_pipeline_stage,
      deterministic=deterministic,
      model_mode=model_mode,
      logical_partition_spec=logical_partition_spec,
      physical_partition_spec=physical_partition_spec,
      positions=positions,
      segment_ids=segment_ids,
  )
  run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
  run_bubbles_scannable = base_scannable(length=bubble_iterations)

  run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
      pipeline_stage_fn=run_one_repeat_scannable,
      length=self.config.num_pipeline_repeats,
      remat_policy=remat_policy,
      use_scan=self.config.scan_pipeline_repeats,
  )
  run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
      pipeline_stage_fn=run_bubbles_scannable,
      length=1,
      remat_policy=remat_policy,
      use_scan=self.config.scan_pipeline_repeats,
  )

  # Carry layout: (loop state, current BSW buffer, all pipeline weights).
  initial_carry_repeats = (loop_state, bsw[0], self.layers.variables)
  (loop_state, w_curr, pipeline_weights), _ = run_repeats_scanned(self, initial_carry_repeats)
  # Only the loop state is needed once the bubble iterations have drained the pipeline.
  (loop_state, _, _), _ = run_bubbles_scanned(self, (loop_state, w_curr, pipeline_weights))

  final_output = self.realign_output_microbatches(loop_state["state_io"])
  # Merge the microbatches back into a single batch: [batch, sequence, embed].
  return jnp.reshape(
      final_output,
      (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim),
      out_sharding=self.output_sharding,
  )
[docs]
def create_pipeline(config: Config, layers: nn.Module, mesh: Mesh, remat_policy: Any = None) -> PipelineBase:
  """Instantiates the pipeline module selected by the config.

  Returns a `CircularPipeline` when `config.pipeline_fsdp_ag_per_repeat` is truthy and a
  `Pipeline` otherwise; both receive the same constructor arguments.
  """
  pipeline_cls = CircularPipeline if config.pipeline_fsdp_ag_per_repeat else Pipeline
  return pipeline_cls(config=config, layers=layers, mesh=mesh, remat_policy=remat_policy)