Random
Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional `key` keyword argument for when more fine-grained control or explicit state management is needed.
For example, you can generate random numbers with:
```python
import mlx.core as mx

for _ in range(3):
    print(mx.random.uniform())
```
which will print a sequence of different pseudorandom numbers. Alternatively, you can pass an explicit key:
```python
import mlx.core as mx

key = mx.random.key(0)
for _ in range(3):
    print(mx.random.uniform(key=key))
```
which will print the same pseudorandom number at each iteration, because the same key is reused for every call.
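Following JAX’s PRNG design, MLX uses a splittable version of Threefry, a counter-based PRNG: each output is a deterministic function of a key and a counter, so a key carries no hidden mutable state.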
| Function (`mlx.core.random`) | Description |
|---|---|
| `bernoulli` | Generate Bernoulli random values. |
| `categorical` | Sample from a categorical distribution. |
| `gumbel` | Sample from the standard Gumbel distribution. |
| `key` | Get a PRNG key from a seed. |
| `normal` | Generate normally distributed random numbers. |
| `multivariate_normal` | Generate jointly normal random samples given a mean and covariance. |
| `randint` | Generate random integers from the given interval. |
| `seed` | Seed the global PRNG. |
| `split` | Split a PRNG key into subkeys. |
| `truncated_normal` | Generate values from a truncated normal distribution. |
| `uniform` | Generate uniformly distributed random numbers. |
| `laplace` | Sample numbers from a Laplace distribution. |
| `permutation` | Generate a random permutation or permute the entries of an array. |