mlx.core.random.truncated_normal
- truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) → array
Generate values from a truncated normal distribution.
The values are sampled from the truncated normal distribution on the domain (lower, upper). The bounds lower and upper can be scalars or arrays and must be broadcastable to shape.
- Parameters:
lower (scalar or array) – Lower bound of the domain.
upper (scalar or array) – Upper bound of the domain.
shape (list(int), optional) – The shape of the output. Default is ().
dtype (Dtype, optional) – The data type of the output. Default is float32.
key (array, optional) – A PRNG key. Default is None.
- Returns:
The output array of random values.
- Return type:
array