# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Pipeline layer wrapping a decoder layer(s). Supports circular pipelining """
import functools
from typing import Any
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jax
import jax.ad_checkpoint
from flax.core import meta
from flax import linen as nn
from flax.linen.spmd import LogicallyPartitioned
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, ShardMode
from maxtext.utils.sharding import (
maybe_shard_with_logical,
maybe_shard_with_name,
create_sharding,
logical_to_mesh_axes,
logical_to_mesh,
)
[docs]
class Pipeline(nn.Module):
"""Module that implements pipelining across stages.
This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights.
This will produce a pipeline pattern if the stage dimension is sharded.
Supports circular pipelines, and multiple layers per stage are used when a module that executes multiple layers
is passed as the layers input.
"""
#: Importantly contains num_pipeline_microbatches, num_pipeline_repeats.
config: Config
#: A module instance that each stage can execute. It can either be a single layer such as a LlamaDecoderLayer instance or
#: scanned/looped set of decoder layers to execute multiple layers per stage. The name of this property (layers) is
#: reflected in the state pytree and thus also checkpoints.
layers: nn.Module
#: The device mesh of the system.
mesh: Mesh
#: Remat policy to use for the loop iterations
remat_policy: Any = None
[docs]
def setup(self):
  """Derive pipeline geometry and sharding metadata from the config and mesh.

  Sets the stage count, forwarding delay, microbatch sizes and the circular-storage flag, then builds the logical
  axis tuples for the per-stage and state_io buffers together with their mesh partition specs. Sharding objects
  (stages_in / state_io / input / output) are only constructed when config.shard_mode is ShardMode.EXPLICIT;
  in every other mode they are set to None.
  """
  cfg = self.config
  mesh = self.mesh
  rules = cfg.logical_axis_rules
  explicit_mode = cfg.shard_mode == ShardMode.EXPLICIT

  # Static pipeline geometry.
  self.num_stages = cfg.ici_pipeline_parallelism * cfg.dcn_pipeline_parallelism
  if cfg.pipeline_delay_activation_forwarding:
    self.forwarding_delay = 2
  else:
    self.forwarding_delay = 1
  self.pipeline_microbatch_size = cfg.micro_batch_size_to_train_on // cfg.num_pipeline_microbatches
  self.microbatches_per_stage = cfg.num_pipeline_microbatches // self.num_stages
  # need_circ_storage reads num_stages and forwarding_delay, so it must run after they are set.
  self.use_circ_storage = self.need_circ_storage()

  batch_axis = "activation_batch"
  seq_axis = "activation_length"
  self.batch_axis_name = batch_axis
  self.seq_len_axis_name = seq_axis
  # TODO(b/470167805): replace self.spmd_axis_name with "stage" when JAX >= 0.8.2.
  if cfg.shard_mode == ShardMode.AUTO:
    self.spmd_axis_name = "stage"
  else:
    self.spmd_axis_name = None

  # Per-stage activations: one leading stage axis.
  stages_in_logical = ("activation_stage", batch_axis, seq_axis, "activation_embed")
  self.stages_in_logical = stages_in_logical
  self.stages_in_spec = logical_to_mesh_axes(stages_in_logical, mesh, rules=rules)
  self.stages_in_sharding = NamedSharding(mesh, self.stages_in_spec) if explicit_mode else None

  # state_io buffer: stage axis followed by an unsharded (None) axis.
  state_io_logical = ("activation_stage", None, batch_axis, seq_axis, "activation_embed")
  self.state_io_logical = state_io_logical
  self.state_io_spec = logical_to_mesh_axes(state_io_logical, mesh, rules=rules)
  self.state_io_sharding = NamedSharding(mesh, self.state_io_spec) if explicit_mode else None

  # Shardings for the 4-axis pipeline input and the 3-axis pipeline output.
  if explicit_mode:
    self.input_sharding = create_sharding(mesh, (None, batch_axis, seq_axis, "activation_embed"), rules=rules)
  else:
    self.input_sharding = None
  if explicit_mode:
    self.output_sharding = create_sharding(mesh, (batch_axis, seq_axis, "activation_embed"), rules=rules)
  else:
    self.output_sharding = None
[docs]
def need_circ_storage(self):
  """Return True when the pipeline is circular and has more microbatches than num_stages * forwarding_delay."""
  if self.config.num_pipeline_repeats <= 1:
    # A single repeat never needs the extra buffer.
    return False
  microbatch_threshold = self.num_stages * self.forwarding_delay
  return self.config.num_pipeline_microbatches > microbatch_threshold
[docs]
def iterations_to_complete_first_microbatch_one_repeat(self):
  """Return the number of iterations it takes for microbatch 0 to finish one repeat."""
  stage_hops = self.num_stages - 1
  return stage_hops * self.forwarding_delay
[docs]
def iterations_to_complete_first_microbatch(self):
  """Return the number of iterations for microbatch 0 to finish the last stage of the last repeat."""
  cfg = self.config
  earlier_repeats = cfg.num_pipeline_repeats - 1
  # Each earlier repeat costs num_pipeline_microbatches iterations before microbatch 0 re-enters the pipeline.
  iterations_before_last_repeat = cfg.num_pipeline_microbatches * earlier_repeats
  return iterations_before_last_repeat + self.iterations_to_complete_first_microbatch_one_repeat()
def _maybe_shard_with_logical(self, inputs, logical_axes):
  """Call maybe_shard_with_logical on `inputs` with this module's mesh, shard mode, rules and debug flag."""
  cfg = self.config
  sharding_options = {
      "shard_mode": cfg.shard_mode,
      "mesh": self.mesh,
      "rules": cfg.logical_axis_rules,
      "debug_sharding": cfg.debug_sharding,
      # This wrapper adds one call frame between the real caller and the helper.
      "extra_stack_level": 1,
  }
  return maybe_shard_with_logical(inputs, logical_axes, **sharding_options)
def _maybe_shard_with_name(self, inputs, sharding_name):
  """Call maybe_shard_with_name on `inputs` with this module's shard mode and debug flag."""
  cfg = self.config
  sharding_options = {
      "shard_mode": cfg.shard_mode,
      "debug_sharding": cfg.debug_sharding,
  }
  return maybe_shard_with_name(inputs, sharding_name, **sharding_options)
[docs]
def init_states(self, inputs):
"""Initialize components of state: state_io, shift, circular_storage and circular_storage_mover
Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed]
Returns a dictionary with properties
shift: zeros shape [num_stages, micro_size, sequence, embed]
prev_outputs: same shape as shift, only used when pipeline_delay_activation_forwarding is set to true, else None
state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed]
circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] when needed, else None
circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] when needed, else None
loop_iteration: scalar set initially to 0.
"""
# Shift is used to rotate the output of each pipeline into the input of the next
# shift has shape [num_stages, micro_size, sequence, embed]
shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype)
shift = self._maybe_shard_with_logical(shift, self.stages_in_logical)
# Prev outputs has the same shape of the output (and shift)
if self.config.pipeline_delay_activation_forwarding:
prev_outputs = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype)
prev_outputs = self._maybe_shard_with_logical(prev_outputs, self.stages_in_logical)
else:
prev_outputs = None
# state_io (state input output) at first holds all of the input batches, but also will hold the outputs
# as the pipeline runs/finishes
# state_io has shape [num_stages, microbatches/stages, micro_size, sequence, embed]
state_io = jnp.reshape(
inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:], out_sharding=self.state_io_sharding
)
# We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over.
state_io = self._maybe_shard_with_logical(state_io, self.state_io_logical)
# circ_storage is used to hold the final pipeline stage outputs before it is used for the next repeat. It is only
# needed when num_microbatches > num_stages, else instead the final stage will immediately pass to the first without
# additional storage.
# circ_storage has shape [num_stages, microbatches, micro_size, sequence, embed].
# Note that this shape is a factor of num_stages larger than necessary - each stage holds the global batch, but only
# stage 0 holds the real activations (since it will use them), the rest hold dummy ones. This amount of storage
# [global_batch, sequence, embed] is fine as long as there is some amount of additional sharding axes, e.g. FSDP,
# TP, DP (e.g. there are many devices that shard stage 0)
# We may look into alternatives using less storage if this becomes an issue (ideas in b/347603101).
if self.use_circ_storage:
circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype, out_sharding=self.state_io_sharding)
else:
circ_storage = None
# circ_storage_mover is used to push the microbatches from the pipeline into circ_storage with one buffer iteration
# of delay circ_storage_mover shape is same as shift: [num_stages, micro_size, sequence, embed]
if self.use_circ_storage:
circ_storage_mover = shift
else:
circ_storage_mover = None
init_loop_state = {
"state_io": state_io,
"shift": shift,
"circ_storage": circ_storage,
"circ_storage_mover": circ_storage_mover,
"loop_iteration": 0,
"prev_outputs": prev_outputs,
}
return init_loop_state
[docs]
def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False):
  """Constrain `x` so that dimension `dim` is partitioned over the "stage" mesh axis.

  Args:
    x: Array to constrain.
    dim: Index of the dimension that receives the "stage" mesh axis.
    physical_partition_spec: Existing physical spec of `x`, or None. When None, every other dimension gets a
      placeholder (None in explicit shard mode, P.UNCONSTRAINED otherwise). When given, fsdp axes are stripped
      from it first.
    is_stage_weight: When False and a spec is given, the spec entries before `dim` are replaced by `dim + 1`
      placeholders, which makes the spec one entry longer to account for the repeat dimension.

  Returns:
    The result of self._maybe_shard_with_name applied to `x` and the derived NamedSharding.
  """
  explicit_mode = self.config.shard_mode == ShardMode.EXPLICIT
  filler = None if explicit_mode else P.UNCONSTRAINED

  if physical_partition_spec is None:
    axes = [filler] * x.ndim
  else:
    physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
    axes = list(physical_partition_spec)
    if not is_stage_weight:
      # Inflate by one dimension for num_repeats.
      axes = [filler] * (dim + 1) + axes[dim:]
  axes[dim] = "stage"

  # The reduced rule is only added when a (non-empty) pspec is given for a stage weight in explicit mode.
  spec_kwargs = {}
  if physical_partition_spec and is_stage_weight and explicit_mode:
    spec_kwargs["reduced"] = {axis for axis in ("data", "fsdp") if self.mesh.shape[axis] > 1}
  stage_sharding = jax.sharding.NamedSharding(self.mesh, P(*axes, **spec_kwargs))
  return self._maybe_shard_with_name(x, stage_sharding)
[docs]
def get_microbatch_and_repeat_ids(self, loop_iteration):
  """Return (microbatch_ids, repeat_ids) for all stages at `loop_iteration`.

  Both results have one entry per stage. Works for circular and non-circular pipelines.
  """
  num_microbatches = self.config.num_pipeline_microbatches
  # Stage s starts forwarding_delay * s iterations after stage 0 because of the pipeline bubble.
  stage_lag = self.forwarding_delay * jnp.arange(self.num_stages)
  # Clamp at 0 for stages that have not started yet.
  processed = jnp.maximum(loop_iteration - stage_lag, 0)
  return processed % num_microbatches, processed // num_microbatches
[docs]
def vmap_parallel_gather(
    self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights
):
  """Use vmap to implement a sharded parallel gather.

  Parallel gather means each stage has its own weights and picks one slice from them.

  Args:
    weights: Per-stage data to be gathered from.
    physical_partition_spec: Physical partition spec of the input weight.
    repeat_ids: Integer tensor of shape [num_stages], the repeats of the stages.
    repeat_dim_in_weights: The dimension in weights where repeat_ids are applied. The output will not
      have this dimension.
    stages_dim_in_weights: The dimension in weights that represents parallel stages.

  Returns:
    The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights removed.
  """
  out_stage_dim = 0

  def _select_repeat(per_stage_weights, repeat_id):
    # Pick entry `repeat_id` along the repeat dimension and drop that dimension.
    return jax.lax.dynamic_index_in_dim(per_stage_weights, repeat_id, repeat_dim_in_weights, keepdims=False)

  repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None)
  # num_repeats x num_stages x *param_dim
  weights = self.shard_dim_by_stages(
      weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False
  )
  gather_over_stages = jax.vmap(_select_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=out_stage_dim)
  gathered = gather_over_stages(weights, repeat_ids)
  # num_stages x *param_dim
  return self.shard_dim_by_stages(
      gathered, out_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True
  )
[docs]
def vmap_gather(self, xs, ids, ids_dim):
"""Use vmap to implement a stage-wise sharded gather.
The stages share the same input, but they have different offsets.
Args:
xs: Data shared by all stages, to be gathered from.
ids: Integer tensor of shape [num_stages], the offsets of the stages.
ids_dim: The dimension in xs where ids are applied. In the output, this
dimension will be [num_stages], since each stage gets one slice.
Returns:
The per-stage gathered values. The shape is xs.shape but with ids_dim size
replaced with [num_stages].
"""
def _gather_one(x, i):
idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
replicated_sharding = NamedSharding(self.mesh, P())
return x.at[idx].get(out_sharding=replicated_sharding)
ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None)
[docs]
def get_new_loop_state(self, output, loop_state):
"""
Update the various buffers given the output of the most recent iteration
* state_io: rotates left/up by 1 (the whole created in the last slot is filled with the most recent pipeline output)
* Pushing inputs up from top of state_io into first stage of shift
* Pulling outputs up from last stage of shift into bottom of state_io
* shift: rotate output (or prev_outputs if using delay) right/down by 1 - we imagine the pipeline moves to
right/down
* circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage
* circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration
* prev_outputs: is set to the current output
"""
old_state_io = loop_state["state_io"]
old_circ_storage = loop_state["circ_storage"]
old_circ_storage_mover = loop_state["circ_storage_mover"]
loop_iteration = loop_state["loop_iteration"]
old_prev_outputs = loop_state["prev_outputs"]
@jax.shard_map(
mesh=self.mesh,
in_specs=self.stages_in_spec,
out_specs=self.stages_in_spec,
check_vma=True,
)
def _rotate_right(arr):
# we use +1 for right shifting
stage_size = jax.lax.axis_size("stage")
perm = [(i, (i + 1) % stage_size) for i in range(stage_size)]
arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm)
return arr
@jax.shard_map(
mesh=self.mesh,
in_specs=self.stages_in_spec,
out_specs=self.stages_in_spec,
check_vma=True,
)
def _shift_right(arr):
stage_idx = jax.lax.axis_index("stage")
stage_size = jax.lax.axis_size("stage")
perm = [(i, (i + 1) % stage_size) for i in range(stage_size)]
arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm)
return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr)
# Shift either rotates or shifts depending on if the last stage immediately must send to first or not
# For non-circular pipelines, the last stage does not need to send to first
# For circular pipelines with #micro = #stages, last stage immediately sends to first
# For circular pipelines with #micro > stages (circ_storage), last stage sends to circ storage
def _update_shift(output_in):
if self.config.num_pipeline_repeats == 1 or self.use_circ_storage:
return _shift_right(output_in) # last stage does not have to send to first immediately
else:
return _rotate_right(output_in) # last stage must immediately send to first
if self.config.pipeline_delay_activation_forwarding:
new_shift = _update_shift(old_prev_outputs)
new_prev_outputs = output
else:
new_shift = _update_shift(output)
new_prev_outputs = None
if self.use_circ_storage:
# Insert the circ_storage_mover into new_circ_storage at a microbatch-rotating index.
# circ_storage_mover still points to the output of PREVIOUS iteration, which should aid in allowing overlapped
# compute/async transfers
def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in):
rotated = _rotate_right(circ_storage_mover_in)
rotated = jnp.expand_dims(rotated, 1)
# We rotate the pushing index into circ storage, and ensure that microbatch 0 lands in index 0
offset = (
loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1
) % self.config.num_pipeline_microbatches # Note extra -1 b/c grabbing from the
# previous output - using circ_storage_mover before it is updated
return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1)
new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage)
new_circ_storage_mover = output
else:
new_circ_storage = None
new_circ_storage_mover = None
# Rotate stream_io left/up by 1 on rotating micro/stage index (stream_buf_idx), replacing the last/bottom with the
# last stage output
stream_buf_idx = loop_iteration % self.microbatches_per_stage
stream_slice = old_state_io[:, stream_buf_idx]
def _rotate_left(arr, stage_size):
# we use -1 for left shifting
perm = [(i, (i - 1) % stage_size) for i in range(stage_size)]
arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm)
return arr
def _shift_left(arr, stage_size, output):
stage_idx = jax.lax.axis_index("stage")
arr = _rotate_left(arr, stage_size)
return jnp.where(stage_idx == stage_size - 1, output, arr)
@jax.shard_map(
mesh=self.mesh,
in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()),
out_specs=self.state_io_spec,
)
def _update_state_io(state_in, stream_slice, output, stream_buf_idx):
# Shift the current slice to the left, then fill the last stage with the final output.
stage_size = jax.lax.axis_size("stage")
stream_slice = _shift_left(stream_slice, stage_size, output)
stream_slice = jnp.expand_dims(stream_slice, 1)
return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1)
new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx)
new_loop_state = {
"state_io": new_state,
"shift": new_shift,
"circ_storage": new_circ_storage,
"circ_storage_mover": new_circ_storage_mover,
"loop_iteration": loop_iteration + 1,
"prev_outputs": new_prev_outputs,
}
return new_loop_state
[docs]
def permute_output_micro_per_stage_dim(self, output):
  """Reorder axis 1 of `output` so that microbatch 0 sits at index 0.

  The first real output (microbatch 0) takes a certain number of loop iterations to finish and be pushed to
  state_io, so it lands on an index of state_io that depends on that iteration count.
  """
  slots = self.microbatches_per_stage
  landing_idx = self.iterations_to_complete_first_microbatch() % slots
  # Entry i of the permutation is (i + landing_idx) % slots: the value at landing_idx moves to index 0, the
  # value at landing_idx + 1 to index 1, and so on.
  permutation = np.roll(np.arange(slots), -landing_idx)
  return output[:, permutation]
[docs]
def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None):
  """Get the weights used for one iteration.

  Outputs a pytree whose arrays have a leading dimension of stages, e.g. {'mlp': 'wo': [stages, mlp, embed]}.
  Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc.

  For non-circular pipelines (num_pipeline_repeats <= 1) this returns `pipeline_weights` unchanged, since every
  weight is used in every iteration. For circular pipelines each stage grabs only the weights corresponding to
  its current repeat.
  """
  if self.config.num_pipeline_repeats <= 1:
    return pipeline_weights
  return self.get_current_repeat_from_stages(
      pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec
  )
[docs]
def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None):
  """For every stage, gather the weights of the repeat that stage is executing at `loop_iteration`.

  Args:
    weights: Pytree of weights carrying a "circular_repeats" metadata axis at index 0; after unboxing, each leaf
      is gathered along dimension 0 (repeats) with stages on dimension 1.
    loop_iteration: Current loop iteration, used to derive the per-stage repeat ids.
    physical_partition_spec: Optional pytree of physical specs, mapped leaf-by-leaf alongside `weights`.

  Returns:
    A pytree of gathered weights with the repeat dimension removed.
  """
  repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)[1]

  # Remove the circular metadata axis. Only one circular entry per stage is passed on to the main vmap.
  repeats_axis_metadata = {
      nn.PARTITION_NAME: "circular_repeats",
      "sub_weight_split_dims_mapping": (None,),
      "is_initializing": self.is_initializing(),
      "x_times": self.config.num_pipeline_repeats,
      "optimizer_dims_mapping": None,
  }
  unboxed = self._remove_logically_partition(meta.remove_axis(weights, 0, repeats_axis_metadata))

  def _gather_leaf(leaf, leaf_spec=None):
    return self.vmap_parallel_gather(
        leaf,
        physical_partition_spec=leaf_spec,
        repeat_ids=repeat_ids,
        repeat_dim_in_weights=0,
        stages_dim_in_weights=1,
    )

  # Without a spec tree every leaf is gathered with leaf_spec=None.
  trees = (unboxed,) if physical_partition_spec is None else (unboxed, physical_partition_spec)
  return jax.tree.map(_gather_leaf, *trees)
[docs]
def get_vmap_func_for_init(self):
  """Return the stage body wrapped in nn.vmap over the stage axis; used to initialize the weights only on init.

  The returned callable takes (module, stages_inputs, stages_segment_ids, stages_positions, deterministic,
  model_mode). The first three array arguments are mapped over axis 0; deterministic and model_mode are broadcast.
  """

  def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode):
    """nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance."""
    return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode)

  vmap_func = nn.vmap(
      func_to_vmap,
      in_axes=(0, 0, 0, None, None),
      # TODO(b/470167805): replace self.spmd_axis_name with "stage" when JAX >= 0.8.2.
      spmd_axis_name=self.spmd_axis_name,
      variable_axes={"params": 0, "_overwrite_with_gradient": 0},
      split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout},
      metadata_params={
          nn.PARTITION_NAME: "layers",
          # Must be a 1-tuple: `(None)` without the trailing comma is just `None`.
          "sub_weight_split_dims_mapping": (None,),
          "is_initializing": self.is_initializing(),
          "x_times": self.num_stages,
      },
  )
  return vmap_func
[docs]
def get_main_vmap_func_for_iterations(self):
  """Return the main stage function vmapped by number of stages.

  This becomes a vmap over a single layer instance if body_instance is a single layer,
  else a set of layers if body_instance is a set of layers.

  The returned callable takes (module, weights, stages_inputs, stages_segment_ids, stages_positions,
  deterministic, model_mode). Weights and the three array arguments are mapped over axis 0; deterministic and
  model_mode are broadcast.
  """

  def _layers_axis_metadata():
    # Metadata describing the "layers" (stage) axis. The same description is used when nn.vmap adds the axis
    # and when remove_axis strips it again inside the mapped function.
    return {
        nn.PARTITION_NAME: "layers",
        # Must be a 1-tuple: `(None)` without the trailing comma is just `None`.
        "sub_weight_split_dims_mapping": (None,),
        "is_initializing": self.is_initializing(),
        "x_times": self.num_stages,
    }

  def func_to_vmap(
      body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode
  ):
    """nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance."""
    weights = meta.remove_axis(weights, 0, _layers_axis_metadata())
    return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode)

  vmap_func = nn.vmap(
      func_to_vmap,
      in_axes=(0, 0, 0, 0, None, None),
      # TODO(b/470167805): replace self.spmd_axis_name with "stage" when JAX >= 0.8.2.
      spmd_axis_name=self.spmd_axis_name,
      variable_axes={"params": 0},
      split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout},
      metadata_params=_layers_axis_metadata(),
  )
  return vmap_func
[docs]
def run_one_iteration(
    self,
    loop_state,
    pipeline_weights,
    positions,
    segment_ids,
    deterministic,
    model_mode,
    decoder_layer_instance,
    logical_partition_spec=None,
):
  """Run one loop iteration.

  Gets the weights and inputs for each stage, runs the stages in parallel and updates the loop state.

  Args:
    loop_state: Dict of loop buffers (state_io, shift, circ_storage, loop_iteration, ...).
    pipeline_weights: Weights of the whole pipeline.
    positions: Per-microbatch positions, or None.
    segment_ids: Per-microbatch segment ids, or None.
    deterministic: Forwarded unchanged to the stage body.
    model_mode: Forwarded unchanged to the stage body.
    decoder_layer_instance: Module instance executed by every stage.
    logical_partition_spec: Optional pytree of logical weight specs; converted to physical specs here.

  Returns:
    The new loop state produced by get_new_loop_state.
  """
  state_io = loop_state["state_io"]
  shift = loop_state["shift"]
  circ_storage = loop_state["circ_storage"]
  loop_iteration = loop_state["loop_iteration"]
  microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration)

  # Convert logical spec to physical spec.
  physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules)

  stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift)
  # We checkpoint stages_inputs since we are grabbing only one slice of the state_io, don't need to save the
  # entire buffer.
  stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input")
  stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None
  stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None

  vmap_func = self.get_main_vmap_func_for_iterations()

  if self.config.num_pipeline_repeats > 1:

    def prepare_vars_for_main_vmap(weights):
      # Circular pipelines: strip the repeat axis and keep, for every stage, only the weights of the repeat it
      # is currently executing. Shares the implementation used by get_current_stage_weights.
      return self.get_current_repeat_from_stages(
          weights, loop_iteration, physical_partition_spec=physical_partition_spec
      )

    vmap_func = nn.map_variables(
        vmap_func,
        mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"],
        mutable=True,
        trans_in_fn=prepare_vars_for_main_vmap,
    )

  stage_weights = self.get_current_stage_weights(
      pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec
  )
  stages_output = vmap_func(
      decoder_layer_instance,
      stage_weights,
      stages_inputs,
      stages_segment_ids,
      stages_positions,
      deterministic,
      model_mode,
  )
  if self.config.scan_layers:
    # With scan_layers the stage body's result is indexable and only element 0 is carried forward.
    stages_output = stages_output[0]

  return self.get_new_loop_state(stages_output, loop_state)
[docs]
def get_pipeline_remat_policy(self):
  """Return the remat policy for the pipeline iterations.

  With config.remat_policy == "custom", self.remat_policy is returned as is, so the custom policy decides whether
  the decoder layer inputs are saved to device or offloaded. Otherwise a policy saving the names
  "iteration_input" and "decoder_layer_input" is returned, combined with self.remat_policy when one is set.
  """
  if self.config.remat_policy == "custom":
    return self.remat_policy

  policies = jax.checkpoint_policies
  # Ensure the decoder layer inputs are saved.
  save_inputs = policies.save_only_these_names("iteration_input", "decoder_layer_input")
  if self.remat_policy is None:
    return save_inputs
  return policies.save_from_both_policies(self.remat_policy, save_inputs)
[docs]
def get_weight_sharding(self, *init_args):
  """Return the logical partition specs of the pipeline weights.

  Initializes the module with a fixed PRNG key (hence `init_args`, the arguments to init, are required) and
  reads the partition spec of every leaf of the resulting variables.

  Returns:
    {"params": <spec pytree found under ["params"]["layers"]>}.
  """
  seed_key = jax.random.PRNGKey(0)
  rngs = dict.fromkeys(("params", "dropout", "aqt"), seed_key)
  variables = self.init(rngs, *init_args)

  # Treat each LogicallyPartitioned box as a leaf so its spec can be read directly.
  spec_tree = jax.tree.map(
      lambda boxed: boxed.get_partition_spec(),
      variables,
      is_leaf=lambda node: isinstance(node, nn.spmd.LogicallyPartitioned),
  )
  return {"params": spec_tree["params"]["layers"]}
[docs]
@staticmethod
def get_logical_spec_repeats_removed(full_logical):
if full_logical is None:
return None
def _remove_from_spec(spec):
return jax.sharding.PartitionSpec(*[dim for dim in spec if dim != "circular_repeats"])
return jax.tree.map(_remove_from_spec, full_logical)
# TODO(chengnuojin) Remove this function and its usage after pipeline nnx migration
@staticmethod
def _remove_logically_partition(weights):
  """Return `weights` with every LogicallyPartitioned box replaced by its `.value`; other leaves are kept."""

  def _is_boxed(node):
    return isinstance(node, LogicallyPartitioned)

  def _unbox(node):
    if _is_boxed(node):
      return node.value
    return node

  return jax.tree.map(_unbox, weights, is_leaf=_is_boxed)
@staticmethod
def _remove_fsdp_from_physical_partition_spec(pps):
"""Removes 'fsdp' and 'fsdp_transpose' from a physical PartitionSpec."""
if isinstance(pps, P):
new_spec = []
# Iterate through each axis in the original PartitionSpec.
for axis in pps:
if axis is None:
new_spec.append(None)
elif isinstance(axis, str):
# If the axis is 'fsdp', replace it with None to signify replication.
if axis not in ("fsdp", "fsdp_transpose"):
new_spec.append(axis)
else:
new_spec.append(None)
elif isinstance(axis, (list, tuple)):
# If the axis is a collection, filter out 'fsdp'.
new_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")]
new_spec.append(tuple(new_axis))
else:
raise ValueError(f"Unsupported_axis_type: {type(axis)}")
# Return a new sharding object with the modified spec.
return P(*new_spec)
return pps
[docs]
def all_gather_over_fsdp(self, variables, logical_partition_spec):
  """Re-shard every leaf of `variables` with its physical spec after the fsdp axes have been removed.

  Args:
    variables: Pytree of weights.
    logical_partition_spec: Pytree of logical specs, converted to physical specs with the config's rules.

  Returns:
    A pytree with self._maybe_shard_with_name applied to each leaf.
  """
  physical_specs = logical_to_mesh(logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules)
  specs_without_fsdp = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_specs)

  def _reshard(weight, spec):
    return self._maybe_shard_with_name(weight, NamedSharding(self.mesh, spec))

  return jax.tree.map(_reshard, variables, specs_without_fsdp)
@nn.compact
def __call__(
self,
inputs: jnp.ndarray,
segment_ids: jnp.ndarray,
positions: jnp.ndarray,
deterministic: bool,
model_mode=MODEL_MODE_TRAIN,
logical_partition_spec=None, # Pytree of sharding specifications of the weights (aka self.layers.variables)
) -> jnp.ndarray:
"""The main method that maps the series of decoder layer inputs to final layer outputs.
Has the same signature of a single decoder layer, and expects the same shapes, e.g. the inputs should have shape
[global_batch], and internally this will be reshapped into microbatches.
"""
# Reshape inputs of [global_batch, ...] to [microbatches, pipeline_microbatch_sizes, ...]
inputs = inputs.reshape(
(
self.config.num_pipeline_microbatches,
self.pipeline_microbatch_size,
self.config.max_target_length,
self.config.emb_dim,
),
out_sharding=self.input_sharding,
)
example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) # dummy inputs fed to initialize the module
ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None))
if positions is not None:
# AG positions
positions = self._maybe_shard_with_name(positions, ag_sharding)
positions = positions.reshape(
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
)
example_position = jax.lax.broadcast(positions[0], [self.num_stages])
position_idx = 0
else:
example_position = None
position_idx = None
if segment_ids is not None:
segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding)
segment_ids = segment_ids.reshape(
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
)
example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages])
segment_idx = 0
else:
example_segmentation = None
segment_idx = None
loop_state = self.init_states(inputs)
# Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats)
# compute to perform
# Each iteration is vmapped by num_stages, so the number of iterations should be
# num_micro * num_stages * repeats / num_stages = num_micro * repeats
# However due to the pipeline bubble some iterations process less than num_stages microbatches. It takes
# num_micro * repeat iterations for the last microbatch to start the final repeat, then an additional
# num_stages - 1 to finish the final repeat.
# Thus the total iterations is num_micro * repeat + num_stages - 1, & we may consider the num_stages - 1 as bubble.
# The bubble doubles when we use forwarding delay.
bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats
total_iterations = real_iterations + bubble_iterations
if self.is_initializing():
vmap_func = self.get_vmap_func_for_init()
if self.config.num_pipeline_repeats > 1:
# To shard the weights on initialization for the circular pipeline we create weights of
# shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
# We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
vmap_func = nn.vmap(
vmap_func,
in_axes=(0, segment_idx, position_idx, None, None),
variable_axes={
"params": 0,
"_overwrite_with_gradient": 0,
"non_trainable": 0,
"hyper_params": 0,
},
split_rngs={"params": True, "dropout": self.config.enable_dropout},
metadata_params={
nn.PARTITION_NAME: "circular_repeats",
"sub_weight_split_dims_mapping": (None,),
"is_initializing": True,
"x_times": self.config.num_pipeline_repeats,
"optimizer_dims_mapping": None,
},
)
example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
example_segmentation = (
jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
if example_segmentation is not None
else None
)
example_position = (
jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
if example_position is not None
else None
)
# We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
# the full total_iterations.
example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
stage_outputs = vmap_func(
self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
)
if self.config.scan_layers:
stage_outputs = stage_outputs[0]
# We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
# which has shape [pipeline_microbatch_size, sequence, embed]
if self.config.num_pipeline_repeats > 1:
stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap
broadcasted_stage_outpus = jax.lax.broadcast(
stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
)
return jnp.reshape(
broadcasted_stage_outpus,
[self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
out_sharding=self.output_sharding,
)
if self.config.pipeline_fsdp_ag_once:
variables = self._remove_logically_partition(self.layers.variables)
all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec)
else:
all_pipeline_weights = self.layers.variables
logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec)
def run_iteration_scannable(model, loop_state, xs):
  """Run a single pipeline iteration in a form that flax lifted transforms accept.

  Flax transforms such as nn.scan and nn.remat can only be applied to nn.Module classes or nn.Module instances, so
  the call to run_one_iteration is wrapped here with the module instance (`self`) passed explicitly as `model`.

  Args:
    model: The nn.Module instance providing `run_one_iteration` and `layers`.
    loop_state: The loop state forwarded to `run_one_iteration`.
    xs: Unused; the visible call sites always pass None.

  Returns:
    A 2-tuple of (the value returned by `run_one_iteration`, None).
  """
  # Weights, positions, segment ids, flags and the partition spec are read from the enclosing scope.
  next_loop_state = model.run_one_iteration(
      loop_state,
      all_pipeline_weights,
      positions,
      segment_ids,
      deterministic,
      model_mode,
      model.layers,
      logical_partition_spec=logical_partition_spec,
  )
  return next_loop_state, None
if self.config.set_remat_policy_on_pipeline_iterations:
run_iteration_scannable = nn.remat(
run_iteration_scannable,
prevent_cse=not self.config.scan_pipeline_iterations, # prevent_cse not used with scan
policy=self.get_pipeline_remat_policy(),
)
# The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized.
if self.config.scan_pipeline_iterations:
variable_carry = []
variable_broadcast = [
"params",
"_overwrite_with_gradient",
] # All loop iterations need the weights for the full pipeline.
if self.is_mutable_collection("non_trainable"):
variable_carry.append("non_trainable")
else:
variable_broadcast.append("non_trainable")
run_all_iterations_scanned = nn.scan(
run_iteration_scannable,
variable_axes={
"summaries": 0,
"aux_loss": 0,
"intermediates": 0,
"hyper_params": 0,
},
variable_broadcast=variable_broadcast,
variable_carry=variable_carry,
# Dropout/aqt keys will be split for each iteration.
split_rngs={"random": True},
length=total_iterations,
)
loop_state, _ = run_all_iterations_scanned(self, loop_state, None)
else:
for _ in range(total_iterations):
loop_state, _ = run_iteration_scannable(self, loop_state, None)
# The final output is located in the input/output array, however the output microbatches may be permuted relative to
# the input
final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"])
# reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed]
final_output = jnp.reshape(
final_output,
(self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim),
out_sharding=self.output_sharding,
)
return final_output