mlx.nn.BatchNorm#
- class BatchNorm(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)#
Applies Batch Normalization over a 2D, 3D, or 4D input.
Computes
\[y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,\]where \(\gamma\) and \(\beta\) are learned per feature dimension parameters initialized at 1 and 0 respectively.
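The normalization can be reproduced directly from this formula. A minimal sketch, assuming that in training mode the layer normalizes with the biased batch statistics and that the affine parameters start at their defaults of 1 and 0:

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((5, 4))
>>> bn = nn.BatchNorm(num_features=4)
>>> y = bn(x)
>>> # manual computation over the batch axis
>>> mean = mx.mean(x, axis=0)
>>> var = mx.var(x, axis=0)
>>> manual = (x - mean) * mx.rsqrt(var + 1e-5)
>>> # y and manual should agree up to numerical precision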
The input shape is specified as NC or NLC, where N is the batch, C is the number of features or channels, and L is the sequence length. The output has the same shape as the input. For four-dimensional arrays, the shape is NHWC, where H and W are the height and width respectively.
For more information on Batch Normalization, see the original paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
- Parameters:
  - num_features (int) – The feature dimension to normalize over.
  - eps (float, optional) – A small additive constant for numerical stability. Default: 1e-5.
  - momentum (float, optional) – The momentum for updating the running mean and variance. Default: 0.1.
  - affine (bool, optional) – If True, apply a learned affine transformation after the normalization. Default: True.
  - track_running_stats (bool, optional) – If True, track the running mean and variance. Default: True.
Examples
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((5, 4))
>>> bn = nn.BatchNorm(num_features=4, affine=True)
>>> output = bn(x)
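When track_running_stats is True, the running statistics are what the layer uses at inference time. A minimal sketch, assuming Module.train() and Module.eval() switch the layer between batch statistics and the tracked running statistics:

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> bn = nn.BatchNorm(num_features=4, track_running_stats=True)
>>> for _ in range(10):
...     _ = bn(mx.random.normal((5, 4)))  # training mode: updates running_mean / running_var
>>> bn.eval()  # evaluation mode
>>> y = bn(mx.random.normal((5, 4)))  # normalizes with the tracked running statistics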
Methods
- unfreeze(*args, **kwargs) – Wrap unfreeze to make sure that running_mean and var are always frozen parameters.
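Because the running statistics stay frozen, they are never exposed as trainable parameters. A small sketch, assuming trainable_parameters() reflects the frozen state and that the affine parameters are named weight and bias:

>>> import mlx.nn as nn
>>> bn = nn.BatchNorm(num_features=4)
>>> bn.unfreeze()  # running_mean and running_var remain frozen
>>> params = bn.trainable_parameters()  # expected to contain only the affine weight and bias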