mlx.core.random.categorical
- categorical(logits: array, axis: int = -1, shape: Sequence[int] | None = None, num_samples: int | None = None, key: array | None = None, stream: None | Stream | Device = None) -> array
Sample from a categorical distribution.
The values are sampled from the categorical distribution specified by the unnormalized values in logits. Note that at most one of shape or num_samples can be specified. If both are None, the output has the same shape as logits with the axis dimension removed.
- Parameters:
logits (array) – The unnormalized categorical distribution(s).
axis (int, optional) – The axis which specifies the distribution. Default: -1.
shape (list(int), optional) – The shape of the output. This must be broadcast compatible with logits.shape with the axis dimension removed. Default: None.
num_samples (int, optional) – The number of samples to draw from each of the categorical distributions in logits. The output will have num_samples in the last dimension. Default: None.
key (array, optional) – A PRNG key. Default: None.
- Returns:
The shape-sized output array with type uint32.
- Return type:
array