mlx.nn.RMSNorm
- class RMSNorm(dims: int, eps: float = 1e-05)
Applies Root Mean Square normalization [1] to the inputs.
Computes
\[y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma\]
where \(\gamma\) is a learned per-feature parameter initialized to 1 and \(\epsilon\) is the constant `eps`.
Note that the mean is accumulated in 32-bit precision.
[1]: https://arxiv.org/abs/1910.07467
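The formula can be checked against a small reference implementation. The following is a minimal sketch in NumPy, not MLX's own code: the helper name `rms_norm_reference` is hypothetical, and it assumes the mean is taken over the last axis and that the result is cast back to the input dtype.

```python
import numpy as np


def rms_norm_reference(x, weight, eps=1e-5):
    """Illustrative RMS normalization over the last axis (not part of MLX)."""
    # Upcast so the mean of squares is accumulated in float32,
    # mirroring the 32-bit accumulation note in the description.
    x32 = x.astype(np.float32)
    mean_sq = np.mean(x32 * x32, axis=-1, keepdims=True)
    # y = x / sqrt(E[x^2] + eps) * gamma
    y = x32 / np.sqrt(mean_sq + eps) * weight.astype(np.float32)
    # Assumption: return the result in the caller's dtype.
    return y.astype(x.dtype)


x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float16)
gamma = np.ones(4, dtype=np.float16)  # gamma initialized at 1, as in the layer
y = rms_norm_reference(x, gamma)
```

With \(\gamma = 1\), the mean of the squared outputs along the last axis should be close to 1. With MLX installed, `mlx.nn.RMSNorm(dims=4)` applied to the same input would be expected to give comparable values.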
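- Parameters:
  - dims (int) – The feature dimension of the input to normalize over.
  - eps (float) – A small additive constant for numerical stability.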
Methods