mlx.nn.InstanceNorm
- class InstanceNorm(dims: int, eps: float = 1e-05, affine: bool = False)
Applies instance normalization [1] on the inputs.
Computes
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,\]
where \(\gamma\) and \(\beta\) are learned per-feature-dimension parameters initialized at 1 and 0 respectively. Both are of size dims, if affine is True.
- Parameters:
- dims (int): The number of features \(C\) in the last axis of the input.
- eps (float): The value \(\epsilon\) added to the variance in the denominator. Default: 1e-05.
- affine (bool): If True, the learned \(\gamma\) and \(\beta\) are applied. Default: False.
- Shape:
Input: \((..., C)\) where \(C\) is equal to dims.
Output: Same shape as the input.
Examples
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((8, 4, 4, 16))
>>> inorm = nn.InstanceNorm(dims=16)
>>> output = inorm(x)
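To make the formula concrete, the following is a minimal NumPy reference sketch, not the MLX implementation. It assumes channels-last input with at least one axis between the batch axis and the channel axis. It also assumes that \(\mathrm{E}[x]\) and \(\mathrm{Var}[x]\) are taken separately for each sample and each channel over those intermediate axes, which is the usual definition of instance normalization in [1]; this page does not spell out the reduction axes. The function name instance_norm and its gamma and beta arguments are hypothetical and exist only for illustration.

```python
import numpy as np


def instance_norm(x, gamma=None, beta=None, eps=1e-5):
    """Reference instance norm for channels-last input of shape (N, ..., C).

    Hypothetical helper for illustration; not part of MLX.
    """
    # Assumption: reduce over every axis except the first (batch)
    # and the last (channels).
    axes = tuple(range(1, x.ndim - 1))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)  # population variance (ddof=0)
    y = (x - mean) / np.sqrt(var + eps)
    # Optional affine step: gamma and beta have shape (C,) and
    # broadcast along the last axis.
    if gamma is not None:
        y = y * gamma
    if beta is not None:
        y = y + beta
    return y


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4, 4, 16))  # same shape as the example above
y = instance_norm(x)
```

Under these assumptions, each (sample, channel) slice of y has mean close to 0 and variance close to 1, and the output has the same shape as the input.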
References
[1]: https://arxiv.org/abs/1607.08022
Methods