maxtext.layers.pipeline_deprecated module
Pipeline layer wrapping one or more decoder layers. Supports circular pipelining.
- class maxtext.layers.pipeline_deprecated.Pipeline(config, layers, mesh, remat_policy=None, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Module that implements pipelining across stages.
This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights. This will produce a pipeline pattern if the stage dimension is sharded.
Circular pipelines are supported. To run multiple layers per stage, pass a module that executes multiple layers as the layers input.
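The idea of running every stage at once with a vmap over both inputs and weights can be sketched as follows. This is a minimal illustration, not the module's implementation: `stage_fn`, the shapes, and the variable names are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

num_stages, micro_size, embed = 4, 2, 8

def stage_fn(weights, x):
  # Hypothetical stand-in for one stage's layer(s): a single dense projection.
  return jnp.tanh(x @ weights)

# The leading axis of both arrays is the stage axis.
stage_weights = 0.1 * jnp.ones((num_stages, embed, embed))
stages_in = jnp.ones((num_stages, micro_size, embed))

# vmap over the stage axis of the weights and of the inputs.
run_all_stages = jax.vmap(stage_fn, in_axes=(0, 0))
stages_out = run_all_stages(stage_weights, stages_in)
```

If the stage axis of these arrays is sharded across devices, each device only computes its own stage, which is what gives the pipeline pattern described above.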
- Parameters:
config (Any)
layers (Module)
mesh (Mesh)
remat_policy (Any)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- config: Any
Notably contains num_pipeline_microbatches and num_pipeline_repeats.
- layers: Module
A module instance that each stage can execute. It can be either a single layer, such as a LlamaDecoderLayer instance, or a scanned/looped set of decoder layers to execute multiple layers per stage. The name of this property (layers) is reflected in the state pytree and thus also in checkpoints.
- mesh: Mesh
The device mesh of the system.
- remat_policy: Any = None
Remat policy to use for the loop iterations.
- setup()
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
  - Immediately when invoking apply(), init() or init_and_output().
  - Once the module is given a name by being assigned to an attribute of another module inside the other module's setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
  - Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- init_states(inputs)
Initialize components of state: state_io, shift, circular_storage and circular_storage_mover. Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed].
- Returns a dictionary with properties:
  - shift: zeros of shape [num_stages, micro_size, sequence, embed].
  - prev_outputs: same shape as shift; only used when pipeline_delay_activation_forwarding is set to true, else None.
  - state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed].
  - circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] when needed, else None.
  - circ_storage_mover: zeros [num_stages, micro_size, sequence, embed] when needed, else None.
  - loop_iteration: scalar set initially to 0.
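The shapes listed above can be made concrete with a small sketch. It assumes delayed activation forwarding is off (so prev_outputs is None) and uses a hypothetical `use_circ_storage` flag in place of whatever condition the real module checks. All sizes are illustrative.

```python
import jax.numpy as jnp

num_stages, num_microbatches = 2, 4
micro_size, sequence, embed = 3, 5, 8
use_circ_storage = True  # assumed flag, stands in for "when needed"

# Inputs already reshaped into microbatches.
inputs = jnp.ones((num_microbatches, micro_size, sequence, embed))

stage_shape = (num_stages, micro_size, sequence, embed)
state = {
    "shift": jnp.zeros(stage_shape),
    "prev_outputs": None,  # assumes no delayed activation forwarding
    "state_io": inputs.reshape(
        (num_stages, num_microbatches // num_stages, micro_size, sequence, embed)
    ),
    "circ_storage": (
        jnp.zeros((num_stages, num_microbatches, micro_size, sequence, embed))
        if use_circ_storage
        else None
    ),
    "circ_storage_mover": jnp.zeros(stage_shape) if use_circ_storage else None,
    "loop_iteration": 0,
}
```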
- get_iteration_inputs(loop_iteration, state_io, circ_storage, shift)
Construct stages_in: the global array that is operated on for this iteration, with the same shape as shift = [stages, micro_size, sequence, embed]. This is almost a rotated version of the last outputs, except for the first stage, which must grab a new batch from state_io or an old one from circ_storage.
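One way to picture "rotated outputs everywhere except stage 0" is a masked select along the stage axis. The sketch below is a simplified illustration under assumed names (`select_stage_inputs`, `new_microbatch`). It takes `shift` as already rotated and ignores the choice between state_io and circ_storage.

```python
import jax
import jax.numpy as jnp

def select_stage_inputs(first_stage_in, shift):
  # Stage index of every element of `shift` (the stage axis is axis 0).
  stage_idx = jax.lax.broadcasted_iota(jnp.int32, shift.shape, 0)
  # Stage 0 takes the fresh input; all other stages take the shifted outputs.
  return jnp.where(stage_idx == 0, first_stage_in, shift)

num_stages, micro_size, sequence, embed = 3, 2, 4, 8
shift = jnp.ones((num_stages, micro_size, sequence, embed))
new_microbatch = jnp.full((micro_size, sequence, embed), 7.0)

stages_in = select_stage_inputs(new_microbatch, shift)
```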
- shard_dim_by_stages(x, dim, physical_partition_spec, is_stage_weight=False)
Shards x using the provided partition_spec, but adds the “stage” mesh axis to the existing sharding at the specified dimension.
- Parameters:
dim (int)
physical_partition_spec (PartitionSpec | None)
is_stage_weight (bool)
- get_microbatch_and_repeat_ids(loop_iteration)
Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and non-circular pipelines.
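A plausible way to derive these ids is sketched below. It assumes that stage k starts k iterations after stage 0 (a forwarding delay of one iteration per stage) and that each repeat processes all microbatches before the next repeat begins. The function name and arguments are hypothetical.

```python
import jax.numpy as jnp

def microbatch_and_repeat_ids(loop_iteration, num_stages, num_microbatches):
  # Stage k lags stage 0 by k iterations (assumed delay of 1 per stage).
  processed = jnp.maximum(loop_iteration - jnp.arange(num_stages), 0)
  microbatch_ids = processed % num_microbatches
  repeat_ids = processed // num_microbatches
  return microbatch_ids, repeat_ids

microbatch_ids, repeat_ids = microbatch_and_repeat_ids(
    5, num_stages=4, num_microbatches=4
)
```

With these values, the stages have processed [5, 4, 3, 2] microbatches, so microbatch_ids is [1, 0, 3, 2] and repeat_ids is [1, 1, 0, 0]. In a non-circular pipeline the repeat id stays at 0 while real microbatches are in flight.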
- vmap_parallel_gather(weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights)
Use vmap to implement a sharded parallel gather. Parallel gather means each stage has its own weights, and gets one slice from them.
- Parameters:
weights – Per-stage data to be gathered from.
physical_partition_spec – Physical partition spec of the input weight.
repeat_ids – Integer tensor of shape [num_stages], the repeats of the stages.
repeat_dim_in_weights – The dimension in weights where repeat_ids are applied. The output will not have this dimension.
stages_dim_in_weights – The dimension in weights that represents parallel stages.
- Returns:
The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights removed.
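A parallel gather of this kind can be sketched with `jax.vmap` and `jax.lax.dynamic_slice_in_dim`. The sketch omits sharding entirely, and assumes the repeat dimension comes before the stage dimension so that its index is unchanged once vmap strips the stage axis. The function name `parallel_gather` is hypothetical.

```python
import jax
import jax.numpy as jnp

def parallel_gather(weights, repeat_ids, repeat_dim, stages_dim):
  def gather_one(x, repeat_id):
    # Inside the vmap, `x` is one stage's weights (no stage axis).
    sliced = jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, axis=repeat_dim)
    return jnp.squeeze(sliced, axis=repeat_dim)

  # Map over the stage axis of the weights and over the per-stage repeat ids.
  return jax.vmap(gather_one, in_axes=(stages_dim, 0), out_axes=0)(
      weights, repeat_ids
  )

num_repeats, num_stages, embed = 2, 3, 4
weights = jnp.arange(num_repeats * num_stages * embed).reshape(
    (num_repeats, num_stages, embed)
)
repeat_ids = jnp.array([0, 1, 0])

stage_weights = parallel_gather(weights, repeat_ids, repeat_dim=0, stages_dim=1)
```

Here `stage_weights[s]` equals `weights[repeat_ids[s], s]`, and the result has shape [num_stages, embed]: the input shape with the repeat dimension removed.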
- vmap_gather(xs, ids, ids_dim)
Use vmap to implement a stage-wise sharded gather.
The stages share the same input, but they have different offsets.
- Parameters:
xs – Data shared by all stages, to be gathered from.
ids – Integer tensor of shape [num_stages], the offsets of the stages.
ids_dim – The dimension in xs where ids are applied. In the output, this dimension will be [num_stages], since each stage gets one slice.
- Returns:
The per-stage gathered values. The shape is xs.shape but with ids_dim size replaced with [num_stages].
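The stage-wise gather differs from the parallel gather in that the data is shared: only the offsets are mapped over. A minimal sketch without any sharding, using the hypothetical name `stagewise_gather`, could look like this:

```python
import jax
import jax.numpy as jnp

def stagewise_gather(xs, ids, ids_dim):
  def gather_one(x, i):
    # Take the length-1 slice at offset `i`, then drop that axis.
    sliced = jax.lax.dynamic_slice_in_dim(x, i, 1, axis=ids_dim)
    return jnp.squeeze(sliced, axis=ids_dim)

  # `xs` is shared by every stage (in_axes=None); `ids` is mapped per stage.
  # The stage axis of the output is placed at `ids_dim`.
  return jax.vmap(gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)

xs = jnp.arange(10).reshape((5, 2))  # 5 slices shared by all stages
ids = jnp.array([3, 0, 4])           # one offset per stage (3 stages)

out = stagewise_gather(xs, ids, ids_dim=0)
```

The result has shape [3, 2]: the size of `ids_dim` (5) is replaced by the number of stages, and `out[k]` equals `xs[ids[k]]`.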
- get_new_loop_state(output, loop_state)
Update the various buffers given the output of the most recent iteration:
  - state_io: rotates left/up by 1 (the hole created in the last slot is filled with the most recent pipeline output).
    - Pushing inputs up from the top of state_io into the first stage of shift.
    - Pulling outputs up from the last stage of shift into the bottom of state_io.
  - shift: rotates output (or prev_outputs if using delay) right/down by 1; we imagine the pipeline moves right/down.
  - circ_storage: pushes circ_storage_mover (the output of the previous iteration) into the rotating index of circ_storage.
  - circ_storage_mover: assigned to the rotated output and pushed into circ_storage on the next iteration.
  - prev_outputs: set to the current output.
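The "rotate right/down by 1" step applied to the stage outputs can be sketched as below, so that stage k's output becomes stage k+1's next input. This is an illustration only: `rotate_right` is a hypothetical helper, and building the rotation from two static slices and a concatenate is one possible choice (`jnp.roll` gives the same values).

```python
import jax
import jax.numpy as jnp

def rotate_right(arr):
  # Move the last stage's slice to the front along the stage axis (axis 0).
  num_stages = arr.shape[0]
  last = jax.lax.slice_in_dim(arr, num_stages - 1, num_stages, axis=0)
  except_last = jax.lax.slice_in_dim(arr, 0, num_stages - 1, axis=0)
  return jnp.concatenate([last, except_last], axis=0)

# One value per stage so the movement is easy to follow.
output = jnp.arange(4.0).reshape((4, 1))
new_shift = rotate_right(output)
```

After the rotation, the row for stage 1 holds stage 0's output, the row for stage 2 holds stage 1's output, and so on. The row for stage 0 holds the last stage's output, which is the value the description above routes into state_io or circ_storage.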
- get_current_stage_weights(pipeline_weights, loop_iteration, physical_partition_spec=None)
Gets the current weights used for one iteration. Outputs a pytree whose arrays have a leading dimension of stages, e.g. {'mlp': {'wo': [stages, mlp, embed]}}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc. For non-circular pipelines, this simply returns all weights, since every weight is used in every iteration. For circular pipelines, however, each stage grabs only the weights corresponding to the current repeat.
- get_current_repeat_from_stages(weights, loop_iteration, physical_partition_spec=None)
Gets the current repeat from stages.
- get_main_vmap_func_for_iterations()
Returns main stage function vmapped by number of stages. This becomes a vmap over a single layer instance if body_instance is a single layer, else a set of layers if body_instance is a set of layers.
- run_one_iteration(loop_state, pipeline_weights, positions, segment_ids, deterministic, model_mode, decoder_layer_instance, logical_partition_spec=None)
Runs one loop iteration: gets the weights and inputs for each stage, runs the stages in parallel, and updates the loop state.
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None