mlx.nn.Transformer

class Transformer(dims: int = 512, num_heads: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, mlp_dims: int | None = None, dropout: float = 0.0, activation: Callable[[Any], Any] = mlx.nn.relu, custom_encoder: Any | None = None, custom_decoder: Any | None = None, norm_first: bool = True, checkpoint: bool = False)

Implements a standard Transformer model.

The implementation is based on the paper Attention Is All You Need (Vaswani et al., 2017).

The Transformer model contains an encoder and a decoder. The encoder turns the input (source) sequence into a sequence of hidden states, often called the memory, and the decoder generates the output (target) sequence. The decoder interacts with the encoder through cross-attention: each decoder layer attends over the encoder's output.

Parameters:
  • dims (int, optional) – The number of expected features in the encoder/decoder inputs. Default: 512.

  • num_heads (int, optional) – The number of attention heads. Default: 8.

  • num_encoder_layers (int, optional) – The number of encoder layers in the Transformer encoder. Default: 6.

  • num_decoder_layers (int, optional) – The number of decoder layers in the Transformer decoder. Default: 6.

  • mlp_dims (int, optional) – The hidden dimension of the MLP block in each Transformer layer. If None, 4 * dims is used. Default: None.

  • dropout (float, optional) – The dropout probability for the Transformer encoder and decoder. Dropout is applied after each attention layer and after the activation in the MLP block. Default: 0.0.

  • activation (function, optional) – The activation function for the MLP hidden layer. Default: mlx.nn.relu.

  • custom_encoder (Module, optional) – A custom encoder to replace the standard Transformer encoder. Default: None.

  • custom_decoder (Module, optional) – A custom decoder to replace the standard Transformer decoder. Default: None.

  • norm_first (bool, optional) – If True, encoder and decoder layers apply layer normalization before the attention and MLP operations (pre-norm); otherwise they apply it after (post-norm). Default: True.

  • checkpoint (bool, optional) – If True, perform gradient checkpointing to reduce memory usage at the expense of more computation. Default: False.

Methods