he_normal(dtype: Dtype = mlx.core.float32) → Callable[[array, str, float], array]

Build a He normal initializer.

This initializer samples from a normal distribution with a standard deviation computed from the number of input (fan_in) or output (fan_out) units according to:

\[\sigma = \gamma \frac{1}{\sqrt{\text{fan}}}\]

where \(\gamma\) is the gain (the gain argument of the returned initializer) and \(\text{fan}\) is either the number of input units when the mode is "fan_in" or the number of output units when the mode is "fan_out".
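
As a rough illustration of the formula above, the standard deviation can be worked out by hand for a 2-D weight. This is a minimal sketch in plain Python, not the library's implementation: he_normal_std is a hypothetical helper, \(\gamma\) is read as the gain argument, and the shape is assumed to be laid out as (output units, input units).

```python
import math


def he_normal_std(shape, mode="fan_in", gain=1.0):
    """Return sigma = gain / sqrt(fan) for a 2-D weight shape.

    Hypothetical helper for illustration only. Assumption: the shape
    is laid out as (output units, input units).
    """
    if len(shape) != 2:
        raise ValueError("this sketch only handles 2-D shapes")
    fan_out, fan_in = shape
    if mode == "fan_in":
        fan = fan_in
    elif mode == "fan_out":
        fan = fan_out
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return gain / math.sqrt(fan)


# A (2, 2) weight has fan = 2 in either mode.
print(he_normal_std((2, 2)))                           # about 0.707
print(he_normal_std((2, 2), mode="fan_out", gain=5))   # about 3.536

# A non-square weight shows the difference between the two modes.
print(he_normal_std((4, 16)))                          # fan_in = 16
print(he_normal_std((4, 16), mode="fan_out"))          # fan_out = 4
```

Under these assumptions, a larger fan gives a smaller standard deviation, and the gain scales it linearly.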

For more details see the original reference: Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification


Parameters:

dtype (Dtype, optional) – The data type of the array. Defaults to mx.float32.


Returns:

An initializer that returns an array with the same shape as the input, filled with samples from the He normal distribution.

Return type:

Callable[[array, str, float], array]


>>> init_fn = nn.init.he_normal()
>>> init_fn(mx.zeros((2, 2)))  # uses fan_in
array([[-1.25211, 0.458835],
       [-0.177208, -0.0137595]], dtype=float32)
>>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
array([[5.6967, 4.02765],
       [-4.15268, -2.75787]], dtype=float32)