maxtext.models.gpt3 module
Transformer model definition.
- class maxtext.models.gpt3.Gpt3LayerNorm(*args, **kwargs)
Bases: Module
GPT3 layer normalization operating on the last axis of the input data.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- maxtext.models.gpt3.gpt3_layer_norm(*, num_features, epsilon=1e-06, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, kernel_axes=(), scale_init=<function zeros>, use_bias=True, reductions_in_fp32=False, parameter_memory_host_offload=False, name=None)
Initializes the gpt3_layer_norm module.
- Parameters:
num_features (int) – the number of features.
epsilon (float) – the epsilon for the layer norm.
dtype (Any) – the dtype of the computation (default: float32).
weight_dtype (Any) – the dtype of the weights (default: float32).
kernel_axes (tuple[None | str, ...]) – logical axes for partitioning the kernel.
scale_init (Callable[[Array, Sequence[int], dtype], Array]) – initializer for the scale.
use_bias (bool) – whether to add a bias term to the output.
reductions_in_fp32 (bool) – whether to do reductions in fp32.
parameter_memory_host_offload (bool) – whether to offload parameters to host memory.
name (None | str) – name passed to the ToLinen module.
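The parameters above can be illustrated with a minimal JAX sketch of layer normalization over the last axis. This is not the MaxText implementation: the function name `layer_norm_last_axis` is hypothetical, and applying the scale as `1 + scale` is an assumption suggested by the zero default of `scale_init`.

```python
import jax
import jax.numpy as jnp


def layer_norm_last_axis(x, scale, bias=None, epsilon=1e-6, reductions_in_fp32=False):
  """Normalizes x over its last axis, then applies a scale and an optional bias."""
  out_dtype = x.dtype
  if reductions_in_fp32:
    # Upcast so the mean and variance reductions run in float32.
    x = x.astype(jnp.float32)
  mean = jnp.mean(x, axis=-1, keepdims=True)
  var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
  normed = (x - mean) * jax.lax.rsqrt(var + epsilon)
  # Cast back to the computation dtype before applying the parameters.
  normed = normed.astype(out_dtype)
  # Assumption: a zero-initialized scale is applied as (1 + scale), so the
  # layer starts out as plain normalization.
  out = normed * (1.0 + scale.astype(out_dtype))
  if bias is not None:
    out = out + bias.astype(out_dtype)
  return out


x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
scale = jnp.zeros((4,), dtype=jnp.float32)  # plays the role of num_features=4
bias = jnp.zeros((4,), dtype=jnp.float32)
y = layer_norm_last_axis(x, scale, bias)
```

With zero scale and bias, each row of `y` has approximately zero mean and unit variance. Passing `reductions_in_fp32=True` with a bfloat16 input keeps the output in bfloat16 while the statistics are computed in float32.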
- class maxtext.models.gpt3.Gpt3MultiHeadAttention(*args, **kwargs)
Bases: Module
Multi-head attention in GPT3.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- num_heads
number of attention heads. The feature dimension (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.
- head_dim
dimension of each head.
- max_target_length
maximum length of the output.
- max_prefill_predict_length
maximum size of the prefill.
- mesh
device mesh.
- dtype
the dtype of the computation.
- dropout_rate
dropout rate.
- kernel_init
initializer for the kernel of the Dense layers.
- float32_qk_product
bool; if True, compute logits via a float32 qk product to avoid numerical issues with bfloat16.
- float32_logits
bool; if True, cast logits to float32 before softmax to avoid numerical issues with bfloat16.
- fused_qkv
whether to fuse query, key, and value into one projection.
- quant
Quant; stores the quantization config. Defaults to None, meaning no quantization.
- use_bias
whether to add bias in linear transformation.
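To make the num_heads divisibility constraint and the two float32 flags concrete, here is a minimal JAX sketch of scaled dot-product attention weights. The helpers `split_heads` and `attention_weights` are hypothetical names, not part of MaxText, and the sketch omits causal masking, dropout, and quantization.

```python
import jax
import jax.numpy as jnp


def split_heads(x, num_heads):
  """Reshapes [batch, length, features] to [batch, length, num_heads, head_dim]."""
  batch, length, features = x.shape
  if features % num_heads != 0:
    raise ValueError("features must be divisible by num_heads")
  return x.reshape(batch, length, num_heads, features // num_heads)


def attention_weights(query, key, float32_qk_product=False, float32_logits=False):
  """Computes softmax attention weights from per-head query and key tensors."""
  if float32_qk_product:
    # Upcast the operands so the qk product itself is computed in float32.
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)
  head_dim = query.shape[-1]
  # Standard 1/sqrt(head_dim) scaling; the Python float keeps the input dtype.
  logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * (head_dim ** -0.5)
  if float32_logits:
    # Cast only the logits, so the softmax runs in float32.
    logits = logits.astype(jnp.float32)
  return jax.nn.softmax(logits, axis=-1)


x = jnp.ones((2, 5, 8), dtype=jnp.bfloat16)
q = split_heads(x, num_heads=4)  # 8 features / 4 heads gives head_dim = 2
weights = attention_weights(q, q, float32_qk_product=True, float32_logits=True)
```

The two flags act at different points: float32_qk_product changes the precision of the matrix product, while float32_logits only changes the precision of the softmax input.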
- create_projection_layer(input_shape, output_shape, kernel_axes, axis=-1)
Creates the projection layer for query, key, value, and output.
- Parameters:
input_shape (tuple[int, ...])
output_shape (tuple[int, ...] | int)
kernel_axes (tuple[str, ...])
axis (int | tuple[int, ...])
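As an illustration of how input_shape, output_shape, and axis could interact, the following JAX sketch builds a general dense projection that contracts the chosen input axes against a kernel. It is an assumption about the general technique, not the MaxText code: `init_projection` and `apply_projection` are hypothetical helpers, and the sketch ignores kernel_axes (logical partitioning) and quantization.

```python
import jax
import jax.numpy as jnp


def normalize_axes(axis, ndim):
  """Turns an int or tuple of possibly negative axes into non-negative axes."""
  axes = (axis,) if isinstance(axis, int) else tuple(axis)
  return tuple(a % ndim for a in axes)


def init_projection(rng, input_shape, output_shape, axis=-1, use_bias=True):
  """Creates a kernel shaped (contracted input dims..., output dims...)."""
  out_shape = (output_shape,) if isinstance(output_shape, int) else tuple(output_shape)
  axes = normalize_axes(axis, len(input_shape))
  kernel_shape = tuple(input_shape[a] for a in axes) + out_shape
  kernel = jax.random.normal(rng, kernel_shape) * 0.02
  bias = jnp.zeros(out_shape) if use_bias else None
  return kernel, bias


def apply_projection(x, kernel, bias, axis=-1):
  """Contracts the chosen axes of x with the leading axes of the kernel."""
  axes = normalize_axes(axis, x.ndim)
  contract_dims = tuple(range(len(axes)))
  y = jnp.tensordot(x, kernel, axes=(axes, contract_dims))
  return y if bias is None else y + bias


rng_q, rng_out = jax.random.split(jax.random.PRNGKey(0))
x = jnp.ones((2, 5, 8))  # [batch, length, features]

# Query-style projection: features -> (num_heads, head_dim).
q_kernel, q_bias = init_projection(rng_q, x.shape, (4, 2), axis=-1)
q = apply_projection(x, q_kernel, q_bias, axis=-1)

# Output-style projection: (num_heads, head_dim) -> features.
out_kernel, out_bias = init_projection(rng_out, q.shape, 8, axis=(-2, -1))
out = apply_projection(q, out_kernel, out_bias, axis=(-2, -1))
```

Here the query projection maps shape (2, 5, 8) to (2, 5, 4, 2), and the output projection contracts the last two axes to return to (2, 5, 8). This is why axis accepts either an int or a tuple of ints.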