mlx.nn.GroupNorm#
- class GroupNorm(num_groups: int, dims: int, eps: float = 1e-05, affine: bool = True, pytorch_compatible: bool = False)#
Applies Group Normalization [1] to the inputs.
Computes the same normalization as layer norm, namely
\[y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,\]where \(\gamma\) and \(\beta\) are learned per feature dimension parameters initialized at 1 and 0 respectively. However, the mean and variance are computed over the spatial dimensions and each group of features. In particular, the input is split into num_groups across the feature dimension.
The feature dimension is assumed to be the last dimension and the dimensions that precede it (except the first) are considered the spatial dimensions.
[1]: https://arxiv.org/abs/1803.08494
- Parameters:
num_groups (int) – Number of groups to separate the features into
dims (int) – The feature dimensions of the input to normalize over
eps (float) – A small additive constant for numerical stability
affine (bool) – If True learn an affine transform to apply after the normalization.
pytorch_compatible (bool) – If True perform the group normalization in the same order/grouping as PyTorch.
Methods