mlx.core.fast.rms_norm

rms_norm(x: array, weight: array, eps: float, *, stream: None | Stream | Device = None) → array

Root Mean Square normalization (RMS norm).

The normalization is with respect to the last axis of the input x.

Parameters:
  • x (array) – Input array.

  • weight (array) – A multiplicative weight to scale the result by. The weight should be one-dimensional with the same size as the last axis of x.

  • eps (float) – A small additive constant for numerical stability.

Returns:

The normalized array, with the same shape as x.

Return type:

array
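
Example:

The following is a minimal NumPy sketch of what the operation computes. It is not the MLX kernel. It assumes the conventional RMS norm definition, in which each vector along the last axis is divided by the square root of its mean square plus eps, and the result is then scaled by weight. The name rms_norm_reference is hypothetical and used only for illustration.

```python
import numpy as np


def rms_norm_reference(x, weight, eps):
    """Reference sketch of RMS norm over the last axis (not the MLX kernel)."""
    x = np.asarray(x, dtype=np.float32)
    weight = np.asarray(weight, dtype=np.float32)
    # Mean of squares along the last axis; keepdims lets it broadcast against x.
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    # Assumption: eps is added to the mean square before taking the square root.
    return x / np.sqrt(mean_sq + eps) * weight


# Two rows of size 2; weight is one-dimensional and matches the last axis of x.
x = np.array([[3.0, 4.0], [1.0, 1.0]], dtype=np.float32)
weight = np.ones(2, dtype=np.float32)
y = rms_norm_reference(x, weight, eps=1e-5)
print(y.shape)
```

With MLX installed, the corresponding call would follow the signature documented above, for example mx.fast.rms_norm(x, weight, eps) with mlx.core imported as mx and the inputs given as MLX arrays.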