mlx.core.random.truncated_normal

truncated_normal(lower: scalar | array, upper: scalar | array, shape: Sequence[int] | None = None, dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) → array

Generate values from a truncated normal distribution.

The values are sampled from the truncated normal distribution on the domain (lower, upper). The bounds lower and upper can be scalars or arrays and must be broadcastable to shape.

Parameters:
  • lower (scalar or array) – Lower bound of the domain.

  • upper (scalar or array) – Upper bound of the domain.

  • shape (list(int), optional) – The shape of the output. Default: ().

  • dtype (Dtype, optional) – The data type of the output. Default: float32.

  • key (array, optional) – A PRNG key. Default: None.

Returns:

The output array of random values.

Return type:

array