mlx.nn.init.normal

normal(mean: float = 0.0, std: float = 1.0, dtype: Dtype = mlx.core.float32) -> Callable[[array], array]

An initializer that returns samples from a normal distribution.

Parameters:
  • mean (float, optional) – Mean of the normal distribution. Default: 0.0.

  • std (float, optional) – Standard deviation of the normal distribution. Default: 1.0.

  • dtype (Dtype, optional) – The data type of the returned array. Default: float32.

Returns:

An initializer that returns an array with the same shape as the input, filled with samples from a normal distribution.

Return type:

Callable[[array], array]

Example

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> init_fn = nn.init.normal()
>>> init_fn(mx.zeros((2, 2)))
array([[-0.982273, -0.534422],
       [0.380709, 0.0645099]], dtype=float32)