mlx.nn.InstanceNorm

class InstanceNorm(dims: int, eps: float = 1e-05, affine: bool = False)

Applies instance normalization [1] on the inputs.

Computes

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,\]

where \(\gamma\) and \(\beta\) are learned per-feature scale and shift parameters, initialized to 1 and 0 respectively. Both have size dims and are used only if affine is True.
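
The formula can be illustrated with a small NumPy reference. This is only a sketch, not the MLX implementation: it assumes a channels-last input of shape (N, ..., C) with at least one axis between batch and features, and it assumes that the mean and variance are taken separately for each instance and each feature over all the axes in between. The function name instance_norm_reference and its gamma and beta arguments are hypothetical stand-ins for the learned parameters.

```python
import numpy as np


def instance_norm_reference(x, eps=1e-5, gamma=None, beta=None):
    """Hypothetical reference for instance normalization on (N, ..., C) input."""
    # Assumption: reduce over every axis except the first (batch) and the
    # last (features), so each instance and each feature gets its own statistics.
    axes = tuple(range(1, x.ndim - 1))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)  # population variance (ddof=0)
    y = (x - mean) / np.sqrt(var + eps)
    # gamma and beta play the role of the parameters used when affine is True.
    if gamma is not None:
        y = y * gamma
    if beta is not None:
        y = y + beta
    return y


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4, 4, 16))
y = instance_norm_reference(x)
```

Under these assumptions, y has the same shape as x, and every (instance, feature) slice of y has approximately zero mean and unit variance.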

Parameters:
  • dims (int) – The number of features of the input.

  • eps (float) – A value added to the denominator for numerical stability. Default: 1e-5.

  • affine (bool) – If True, apply the learned scale \(\gamma\) and shift \(\beta\). Default: False.

Shape:
  • Input: \((..., C)\) where \(C\) is equal to dims.

  • Output: Same shape as the input.

Examples

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((8, 4, 4, 16))
>>> inorm = nn.InstanceNorm(dims=16)
>>> output = inorm(x)

References

[1]: https://arxiv.org/abs/1607.08022
