mlx.core.random.categorical

categorical(logits: array, axis: int = -1, shape: Sequence[int] | None = None, num_samples: int | None = None, key: array | None = None, stream: None | Stream | Device = None) → array

Sample from a categorical distribution.

The values are sampled from the categorical distribution specified by the unnormalized values in logits. At most one of shape or num_samples can be specified. If both are None, the output has the same shape as logits with the axis dimension removed.

Parameters:
  • logits (array) – The unnormalized categorical distribution(s).

  • axis (int, optional) – The axis which specifies the distribution. Default: -1.

  • shape (list(int), optional) – The shape of the output. This must be broadcast compatible with logits.shape with the axis dimension removed. Default: None.

  • num_samples (int, optional) – The number of samples to draw from each of the categorical distributions in logits. The output will have num_samples in the last dimension. Default: None.

  • key (array, optional) – A PRNG key. Default: None.

Returns:

The output array of sampled category indices, with type uint32.

Return type:

array