mlx.core.random.truncated_normal

truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) → array

Generate values from a truncated normal distribution.

The values are sampled from the truncated normal distribution on the domain (lower, upper). The bounds lower and upper can be scalars or arrays and must be broadcastable to shape.

Parameters:
  • lower (scalar or array) – Lower bound of the domain.

  • upper (scalar or array) – Upper bound of the domain.

  • shape (list(int), optional) – The shape of the output. Default is ().

  • dtype (Dtype, optional) – The data type of the output. Default is float32.

  • key (array, optional) – A PRNG key. Default is None.

Returns:

The output array of random values.

Return type:

array
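
Example:

For intuition about what sampling on the domain (lower, upper) means, the following pure-Python sketch draws such values by inverse-CDF sampling. It rests on two assumptions that the text above does not state: that the underlying distribution is the standard normal (mean 0, standard deviation 1), and that inverse-CDF sampling is an acceptable stand-in for whatever method the library uses. The helper name truncated_normal_reference is hypothetical and is not part of MLX. Going by the signature above, the corresponding MLX call would look something like mx.random.truncated_normal(-2.0, 2.0, shape=(5,)) after import mlx.core as mx.

```python
import random
from statistics import NormalDist


def truncated_normal_reference(lower, upper, n, seed=None):
    """Draw n values from a standard normal restricted to [lower, upper].

    Hypothetical reference helper for illustration only; it is not MLX code
    and assumes the underlying distribution is the standard normal.
    """
    rng = random.Random(seed)
    std = NormalDist(0.0, 1.0)

    # Probability mass of the standard normal below each bound.
    p_lo, p_hi = std.cdf(lower), std.cdf(upper)

    samples = []
    for _ in range(n):
        # A uniform draw between the two CDF values maps back, through the
        # inverse CDF, to a normal value that lies between the bounds.
        u = rng.uniform(p_lo, p_hi)
        # inv_cdf is only defined for 0 < p < 1, so keep u strictly inside.
        u = min(max(u, 1e-12), 1.0 - 1e-12)
        x = std.inv_cdf(u)
        # Clamp to absorb floating-point round-off at the edges.
        samples.append(min(max(x, lower), upper))
    return samples


samples = truncated_normal_reference(-2.0, 2.0, n=5, seed=0)
print(samples)
```

Every printed value falls between -2.0 and 2.0. The sketch handles scalar bounds only; the array-valued bounds and broadcasting to shape described above are left out for brevity.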