mlx.nn.RMSNorm

class RMSNorm(dims: int, eps: float = 1e-05)

Applies Root Mean Square normalization [1] to the inputs.

Computes

\[y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma\]

where \(\gamma\) is a learned per-feature scale parameter, initialized to 1, and \(\epsilon\) is the `eps` constant.

Note that the accumulation for the mean is done in 32-bit precision.
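
As a rough illustration of the formula above, here is a minimal plain-Python sketch for a single feature vector. It is not the MLX implementation: the helper name `rms_norm` and the list-based inputs are assumptions made for this example, and Python floats are 64-bit, so the sketch does not reproduce the 32-bit accumulation noted above.

```python
import math


def rms_norm(x, gamma, eps=1e-5):
    """Hypothetical reference RMSNorm over one feature vector (plain lists)."""
    # E[x^2]: mean of the squared features
    mean_sq = sum(v * v for v in x) / len(x)
    # 1 / sqrt(E[x^2] + eps)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    # Scale each feature, then apply the per-feature weight gamma
    return [v * scale * g for v, g in zip(x, gamma)]


x = [3.0, 4.0]
gamma = [1.0, 1.0]  # mirrors the documented initialization at 1
y = rms_norm(x, gamma)
```

With `gamma` left at 1, the output has a root mean square very close to 1; `eps` keeps it slightly below 1 and guards against division by zero for all-zero inputs.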

[1]: https://arxiv.org/abs/1910.07467

Parameters:
  • dims (int) – The feature dimension of the input to normalize over.

  • eps (float) – A small additive constant for numerical stability.
