maxtext.models.simple_layer module


Simple decoder layers for testing and debugging purposes.

class maxtext.models.simple_layer.SimpleDecoderLayer(*args, **kwargs)

Bases: Module

Decoder layer consisting of a single [embed, embed] weight matrix.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
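The computation this docstring describes can be sketched as follows. This is a minimal illustration, not MaxText's implementation: the function name `simple_decoder_layer`, the `[batch, seq, embed]` activation layout, and the random weights are assumptions made for the example. NumPy is used here; `jax.numpy` exposes the same `einsum` call.

```python
import numpy as np

def simple_decoder_layer(x, w):
    """Hypothetical sketch: apply a single [embed, embed] weight matrix.

    x: activations of shape [batch, seq, embed] (assumed layout)
    w: weight matrix of shape [embed, embed]
    """
    embed = x.shape[-1]
    assert w.shape == (embed, embed), "weight must be [embed, embed]"
    # Contract the embed axis of x with the first axis of w.
    return np.einsum("bse,ef->bsf", x, w)

# Illustrative sizes: batch=2, seq=3, embed=4.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
w = rng.standard_normal((4, 4))
y = simple_decoder_layer(x, w)
print(y.shape)  # (2, 3, 4)
```

Because the weight is square, the output keeps the input shape, which is what lets a layer like this stand in for a full decoder block when testing the surrounding stacking and sharding logic.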

class maxtext.models.simple_layer.SimpleMlpDecoderLayer(*args, **kwargs)

Bases: Module

Decoder layer consisting of an [embed, mlp] matmul followed by an [mlp, embed] matmul.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
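A minimal sketch of the two matmuls described above, again assuming `[batch, seq, embed]` activations. The function name, weight names `w_in`/`w_out`, and sizes are hypothetical, and the sketch applies no activation function between the matmuls because the docstring mentions none; the actual layer may differ.

```python
import numpy as np

def simple_mlp_decoder_layer(x, w_in, w_out):
    """Hypothetical sketch: [embed, mlp] matmul followed by [mlp, embed] matmul.

    x:     activations of shape [batch, seq, embed] (assumed layout)
    w_in:  weight of shape [embed, mlp]
    w_out: weight of shape [mlp, embed]
    """
    embed, mlp = w_in.shape
    assert x.shape[-1] == embed and w_out.shape == (mlp, embed)
    # Project up from the embed dimension to the mlp dimension.
    h = np.einsum("bse,em->bsm", x, w_in)
    # Project back down to the embed dimension.
    return np.einsum("bsm,me->bse", h, w_out)

# Illustrative sizes: batch=2, seq=3, embed=4, mlp=8.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
w_in = rng.standard_normal((4, 8))
w_out = rng.standard_normal((8, 4))
y = simple_mlp_decoder_layer(x, w_in, w_out)
print(y.shape)  # (2, 3, 4)
```

With no nonlinearity in between, the two matmuls compose into a single [embed, embed] linear map, `x @ (w_in @ w_out)`. The intermediate [mlp] dimension still matters for testing, since it gives the layer a second weight whose axes can be sharded differently from the first.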