maxtext.models.gpt3 module

Transformer model definition for GPT-3.

class maxtext.models.gpt3.Gpt3LayerNorm(*args, **kwargs)

Bases: Module

GPT-3 layer normalization over the last axis of the input data.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.models.gpt3.gpt3_layer_norm(*, num_features, epsilon=1e-06, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, kernel_axes=(), scale_init=<function zeros>, use_bias=True, reductions_in_fp32=False, parameter_memory_host_offload=False, name=None)

Initializes a Gpt3LayerNorm and returns it wrapped as a Linen module via ToLinen.

Parameters:
  • num_features (int) – the number of features in the normalized (last) axis.

  • epsilon (float) – small constant added to the variance for numerical stability.

  • dtype (Any) – the dtype of the computation (default: float32).

  • weight_dtype (Any) – the dtype of the weights (default: float32).

  • kernel_axes (tuple[None | str, ...]) – logical axes for partitioning the kernel.

  • scale_init (Callable[[Array, Sequence[int], dtype], Array]) – initializer for the scale (default: zeros).

  • use_bias (bool) – whether to add a learned bias after normalization.

  • reductions_in_fp32 (bool) – whether to compute the mean and variance reductions in float32.

  • parameter_memory_host_offload (bool) – whether to offload the parameters to host memory.

  • name (None | str) – name passed to the ToLinen module.
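As a rough illustration of what these parameters control, the sketch below normalizes over the last axis in plain `jax.numpy`. The helper `gpt3_layer_norm_sketch` is hypothetical, not the MaxText implementation. It assumes the learned scale is applied as `(1 + scale)`; that assumption is what would let the zeros default for `scale_init` leave the normalized output unscaled at initialization.

```python
import jax
import jax.numpy as jnp


def gpt3_layer_norm_sketch(x, scale, bias=None, epsilon=1e-6,
                           dtype=jnp.float32, reductions_in_fp32=False):
  """Normalize over the last axis (hypothetical helper, not MaxText's API)."""
  if reductions_in_fp32:
    # Do the mean/variance reductions in float32 whatever the input dtype.
    x = x.astype(jnp.float32)
  mean = jnp.mean(x, axis=-1, keepdims=True)
  var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
  normed = ((x - mean) * jax.lax.rsqrt(var + epsilon)).astype(dtype)
  # Assumption: scale is applied as (1 + scale), so a zeros init is an identity.
  out = normed * (1.0 + scale.astype(dtype))
  if bias is not None:  # use_bias=True adds a learned bias
    out = out + bias.astype(dtype)
  return out


num_features = 8
x = jnp.arange(2 * num_features, dtype=jnp.float32).reshape(2, num_features)
y = gpt3_layer_norm_sketch(x, jnp.zeros(num_features), jnp.zeros(num_features))
print(y.mean(axis=-1), y.var(axis=-1))  # close to 0 and 1 for each row
```

Only the last axis is reduced, so each row (token) is normalized independently of the others.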

class maxtext.models.gpt3.Gpt3MultiHeadAttention(*args, **kwargs)

Bases: Module

Multi-head attention used in GPT-3.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

num_heads

number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.

head_dim

dimension of each head.

max_target_length

maximum target (output) sequence length.

max_prefill_predict_length

maximum length of the prefill (prompt) segment during inference.

mesh

device mesh

dtype

the dtype of the computation.

dropout_rate

dropout rate.

kernel_init

initializer for the kernel of the Dense layers.

float32_qk_product

bool; if True, compute the query-key product (the attention logits) in float32 to avoid numerical issues with bfloat16.

float32_logits

bool; if True, cast the logits to float32 before the softmax to avoid numerical issues with bfloat16.

fused_qkv

whether to fuse query, key and value into one projection.

quant

Quant; stores the quantization config. Defaults to None, meaning no quantization.

use_bias

whether to add a bias term in the linear projections.
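The divisibility constraint on `num_heads` and the two float32 flags can be sketched with plain `jax.numpy`. The helpers `split_heads` and `causal_attention` are hypothetical, and upcasting `q` and `k` is only one plausible way to realize `float32_qk_product`; MaxText's actual attention kernels are not shown here.

```python
import jax
import jax.numpy as jnp


def split_heads(x, num_heads):
  """Reshape [batch, seq, features] to [batch, seq, num_heads, head_dim]."""
  batch, seq, features = x.shape
  if features % num_heads != 0:
    raise ValueError(f"features={features} is not divisible by num_heads={num_heads}")
  return x.reshape(batch, seq, num_heads, features // num_heads)


def causal_attention(q, k, v, float32_qk_product=False, float32_logits=False):
  """Causal dot-product attention on [batch, seq, heads, head_dim] inputs."""
  q = q * (q.shape[-1] ** -0.5)  # scale by 1/sqrt(head_dim)
  if float32_qk_product:
    # One way to get a float32 q.k product: upcast both operands first.
    q, k = q.astype(jnp.float32), k.astype(jnp.float32)
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k)
  if float32_logits:
    # Cast the logits themselves to float32 before the softmax.
    logits = logits.astype(jnp.float32)
  seq = q.shape[1]
  causal_mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
  logits = jnp.where(causal_mask, logits, jnp.finfo(logits.dtype).min)
  weights = jax.nn.softmax(logits, axis=-1).astype(v.dtype)
  return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


batch, seq, features, num_heads = 2, 4, 16, 4
x = jax.random.normal(jax.random.PRNGKey(0), (batch, seq, features)).astype(jnp.bfloat16)
q = k = v = split_heads(x, num_heads)  # head_dim = 16 // 4 = 4
out = causal_attention(q, k, v, float32_qk_product=True, float32_logits=True)
print(out.shape, out.dtype)  # (2, 4, 4, 4) bfloat16
```

The two flags differ in where float32 begins: `float32_qk_product` affects the matmul that produces the logits, while `float32_logits` only affects the softmax input.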

create_projection_layer(input_shape, output_shape, kernel_axes, axis=-1)

Creates a projection layer for the query, key, value, or output.

Parameters:
  • input_shape (tuple[int, ...])

  • output_shape (tuple[int, ...] | int)

  • kernel_axes (tuple[str, ...])

  • axis (int | tuple[int, ...])
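A minimal sketch of how `input_shape`, `output_shape`, and `axis` fit together, assuming the layer contracts the input axes listed in `axis` against a kernel of shape (contracted input dims) + `output_shape`, in the style of a DenseGeneral layer. `make_projection` is a hypothetical stand-in, not `create_projection_layer` itself, and it ignores `kernel_axes` (sharding).

```python
import jax
import jax.numpy as jnp


def make_projection(key, input_shape, output_shape, axis=-1):
  """Hypothetical stand-in for create_projection_layer (not MaxText's code).

  The kernel has shape (input dims selected by `axis`) + output_shape, and the
  returned function contracts those input axes against the kernel.
  """
  axes = (axis,) if isinstance(axis, int) else tuple(axis)
  axes = tuple(a % len(input_shape) for a in axes)  # normalize negative axes
  out_shape = (output_shape,) if isinstance(output_shape, int) else tuple(output_shape)
  in_dims = tuple(input_shape[a] for a in axes)
  kernel = 0.02 * jax.random.normal(key, in_dims + out_shape)

  def apply(x):
    # Contract x's `axes` with the leading len(axes) kernel dimensions.
    return jnp.tensordot(x, kernel, axes=(axes, tuple(range(len(axes)))))

  return apply


batch, seq, embed, heads, head_dim = 2, 3, 16, 4, 4
k_in, k_out = jax.random.split(jax.random.PRNGKey(0))
x = jnp.ones((batch, seq, embed))

# Query/key/value style: contract the embedding axis into (heads, head_dim).
q = make_projection(k_in, x.shape, (heads, head_dim), axis=-1)(x)
print(q.shape)  # (2, 3, 4, 4)

# Output style: contract (heads, head_dim) back to the embedding size.
y = make_projection(k_out, q.shape, embed, axis=(-2, -1))(q)
print(y.shape)  # (2, 3, 16)
```

Allowing `axis` to be a tuple is what lets the same helper serve both the input projections and the output projection.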

qkv_projection(projection_layer, inputs)

Fused QKV projection: computes the query, key, and value with a single projection layer.

Parameters:
  • projection_layer (Any)

  • inputs (Array)
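Fusing Q, K, and V means one matrix multiplication instead of three. The sketch assumes a kernel layout of `[embed, 3, heads, head_dim]` and slices the result along the stacked axis; both the layout and the helper name `fused_qkv_projection` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def fused_qkv_projection(x, qkv_kernel):
  """Compute query, key, and value with a single einsum (hypothetical helper).

  x:          [batch, seq, embed]
  qkv_kernel: [embed, 3, heads, head_dim]  (assumed layout)
  """
  qkv = jnp.einsum("bse,ethd->bsthd", x, qkv_kernel)  # [batch, seq, 3, heads, head_dim]
  # Slice the stacked axis of size 3 into separate tensors.
  return qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]


batch, seq, embed, heads, head_dim = 2, 5, 8, 2, 4
k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (batch, seq, embed))
qkv_kernel = jax.random.normal(k_w, (embed, 3, heads, head_dim))
query, key, value = fused_qkv_projection(x, qkv_kernel)
print(query.shape, key.shape, value.shape)  # (2, 5, 2, 4) (2, 5, 2, 4) (2, 5, 2, 4)
```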

projection(projection_layer, inputs)

Individual projection for one of the query, key, or value.

Parameters:
  • projection_layer (Any)

  • inputs (Array)

Return type:

Array
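Without fusion, each of q, k, and v gets its own kernel. The sketch below uses a hypothetical `projection` helper with an assumed `[embed, heads, head_dim]` kernel layout, and checks that stacking the three kernels and projecting once yields the same tensors. This suggests that fusing changes how the computation is batched, not what it computes.

```python
import jax
import jax.numpy as jnp


def projection(x, kernel, bias=None):
  """Project [batch, seq, embed] to [batch, seq, heads, head_dim] (hypothetical helper)."""
  y = jnp.einsum("bse,ehd->bshd", x, kernel)
  if bias is not None:  # analogous to use_bias=True
    y = y + bias  # bias shape [heads, head_dim] broadcasts over batch and seq
  return y


batch, seq, embed, heads, head_dim = 2, 5, 8, 2, 4
keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (batch, seq, embed))
wq, wk, wv = (jax.random.normal(kk, (embed, heads, head_dim)) for kk in keys[1:])

# Three separate projections, one kernel each.
q, k, v = projection(x, wq), projection(x, wk), projection(x, wv)

# Stacking the kernels and projecting once should yield the same three tensors.
fused = jnp.einsum("bse,ethd->bsthd", x, jnp.stack([wq, wk, wv], axis=1))
print(all(jnp.allclose(fused[:, :, i], t, atol=1e-4) for i, t in enumerate((q, k, v))))
```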

class maxtext.models.gpt3.Gpt3DecoderLayer(*args, **kwargs)

Bases: Module

Transformer decoder layer for the decoder-only GPT-3 model (self-attention only; there is no encoder to attend to).

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
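The exact internals of Gpt3DecoderLayer (dropout, sharding constraints, whether the MLP applies its own norm) are not documented here. The sketch below is a generic pre-LayerNorm decoder block of the GPT family, written in plain `jax.numpy` with hypothetical helper names and parameter keys. It shows the residual structure and the causal masking that makes the layer decoder-only.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, bias, eps=1e-6):
  mean = x.mean(axis=-1, keepdims=True)
  var = jnp.square(x - mean).mean(axis=-1, keepdims=True)
  return (x - mean) * jax.lax.rsqrt(var + eps) * (1.0 + scale) + bias


def causal_self_attention(x, p, num_heads):
  b, s, e = x.shape
  head_dim = e // num_heads
  q, k, v = jnp.split(x @ p["w_qkv"], 3, axis=-1)  # fused QKV projection
  q, k, v = (t.reshape(b, s, num_heads, head_dim) for t in (q, k, v))
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
  mask = jnp.tril(jnp.ones((s, s), dtype=bool))  # token i sees tokens <= i
  weights = jax.nn.softmax(jnp.where(mask, logits, -1e9), axis=-1)
  out = jnp.einsum("bhqk,bkhd->bqhd", weights, v).reshape(b, s, e)
  return out @ p["w_o"]


def decoder_layer(x, p, num_heads):
  # Block 1: pre-norm causal self-attention with a residual connection.
  h = x + causal_self_attention(layer_norm(x, p["ln1_s"], p["ln1_b"]), p, num_heads)
  # Block 2: pre-norm position-wise MLP with a residual connection.
  mlp_in = layer_norm(h, p["ln2_s"], p["ln2_b"])
  return h + jax.nn.gelu(mlp_in @ p["w_in"]) @ p["w_out"]


embed, num_heads, mlp_dim = 8, 2, 32
ks = jax.random.split(jax.random.PRNGKey(0), 5)
p = {
    "w_qkv": 0.02 * jax.random.normal(ks[0], (embed, 3 * embed)),
    "w_o": 0.02 * jax.random.normal(ks[1], (embed, embed)),
    "w_in": 0.02 * jax.random.normal(ks[2], (embed, mlp_dim)),
    "w_out": 0.02 * jax.random.normal(ks[3], (mlp_dim, embed)),
    "ln1_s": jnp.zeros(embed), "ln1_b": jnp.zeros(embed),
    "ln2_s": jnp.zeros(embed), "ln2_b": jnp.zeros(embed),
}
x = jax.random.normal(ks[4], (2, 6, embed))
y = decoder_layer(x, p, num_heads)
print(y.shape)  # (2, 6, 8)
```

Because the mask only lets position i attend to positions up to i, changing a later token leaves the outputs at earlier positions unchanged.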