mlx.core.random.categorical
- categorical(logits: array, axis: int = -1, shape: Sequence[int] | None = None, num_samples: int | None = None, key: array | None = None, stream: None | Stream | Device = None) → array
Sample from a categorical distribution.
The values are sampled from the categorical distribution specified by the unnormalized values in logits. Note that at most one of shape or num_samples can be specified. If both are None, the output has the same shape as logits with the axis dimension removed.
- Parameters:
logits (array) – The unnormalized categorical distribution(s).
axis (int, optional) – The axis which specifies the distribution. Default: -1.
shape (list(int), optional) – The shape of the output. This must be broadcast compatible with logits.shape with the axis dimension removed. Default: None.
num_samples (int, optional) – The number of samples to draw from each of the categorical distributions in logits. The output will have num_samples in the last dimension. Default: None.
key (array, optional) – A PRNG key. Default: None.
- Returns:
The shape-sized output array with type uint32.
- Return type:
array