Indexing Arrays

For the most part, indexing an MLX array works the same as indexing a NumPy array (numpy.ndarray). See the NumPy documentation for more details on how that works.

For example, you can use regular integers and slices (slice) to index arrays:

>>> arr = mx.arange(10)
>>> arr[3]
array(3, dtype=int32)
>>> arr[-2]  # negative indexing works
array(8, dtype=int32)
>>> arr[2:8:2] # start, stop, stride
array([2, 4, 6], dtype=int32)

For multi-dimensional arrays, the ... or Ellipsis syntax works as in NumPy:

>>> arr = mx.arange(8).reshape(2, 2, 2)
>>> arr[:, :, 0]
array([[0, 2],
       [4, 6]], dtype=int32)
>>> arr[..., 0]
array([[0, 2],
       [4, 6]], dtype=int32)

You can index with None to create a new axis:

>>> arr = mx.arange(8)
>>> arr.shape
[8]
>>> arr[None].shape
[1, 8]

You can also use an array to index another array:

>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)

Mixing and matching integers, slices, ..., and array indices works just as in NumPy.

Other functions that may be useful for indexing arrays are take() and take_along_axis().

Differences from NumPy

Note

MLX indexing is different from NumPy indexing in two important ways:

  • Indexing does not perform bounds checking. Indexing out of bounds is undefined behavior.

  • Boolean mask based indexing is not yet supported.

The reason for the lack of bounds checking is that exceptions cannot propagate from the GPU. Performing bounds checking for array indices before launching the kernel would be extremely inefficient.

MLX may support indexing with boolean masks in the future. In general, MLX has limited support for operations whose output shapes depend on input data. Other operations of this kind that MLX does not yet support include numpy.nonzero() and the single-input version of numpy.where().

In Place Updates

In-place updates to indexed arrays are possible in MLX. For example:

>>> a = mx.array([1, 2, 3])
>>> a[2] = 0
>>> a
array([1, 2, 0], dtype=int32)

Just as in NumPy, in-place updates are reflected in all references to the same array:

>>> a = mx.array([1, 2, 3])
>>> b = a
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 0], dtype=int32)

Transformations of functions which use in-place updates are allowed and work as expected. For example:

import mlx.core as mx

def fun(x, idx):
    # Overwrite the selected entries, so the output no longer depends on them
    x[idx] = 2.0
    return x.sum()

dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)

In the above, dfdx has the correct gradient: zeros at idx and ones elsewhere.