mlx.core.take_along_axis

take_along_axis(a: array, /, indices: array, axis: int | None = None, *, stream: None | Stream | Device = None) → array

Take values along an axis at the specified indices.

Parameters:
  • a (array) – Input array.

  • indices (array) – Array of indices. Its shape should be broadcastable with that of the input array in every dimension except axis.

  • axis (int or None) – Axis in the input to take the values from. If axis is None, the array is flattened to 1D prior to the indexing operation.

Returns:

The output array containing the values taken from a at indices along axis.

Return type:

array
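
Example:

The following is a minimal pure-Python sketch of the indexing rule described above. It is meant only to illustrate the semantics on small nested lists, not to show how MLX implements the operation. The helper names take_along_axis_2d and take_flat are hypothetical, the sketch assumes indices already has its final shape (no broadcasting is performed), and the flattened case assumes row-major order.

```python
def take_along_axis_2d(a, indices, axis):
    """Illustrative 2D semantics on nested lists (hypothetical helper, not MLX code)."""
    rows, cols = len(indices), len(indices[0])
    if axis == 0:
        # The index selects the row; the column position is kept.
        return [[a[indices[i][j]][j] for j in range(cols)] for i in range(rows)]
    if axis == 1:
        # The index selects the column; the row position is kept.
        return [[a[i][indices[i][j]] for j in range(cols)] for i in range(rows)]
    raise ValueError("this sketch only handles axis 0 or 1")


def take_flat(a, indices):
    """Illustrative axis=None semantics: flatten to 1D, then index (assumes row-major order)."""
    flat = [value for row in a for value in row]
    return [flat[k] for k in indices]


a = [[10, 20, 30],
     [40, 50, 60]]

along_rows = take_along_axis_2d(a, [[2, 0], [1, 1]], axis=1)  # [[30, 10], [50, 50]]
along_cols = take_along_axis_2d(a, [[1, 0, 1]], axis=0)       # [[40, 20, 60]]
flattened = take_flat(a, [5, 0, 3])                           # [60, 10, 40]

# With MLX installed, the first case would correspond to a call of the form
#   mx.take_along_axis(mx.array(a), mx.array([[2, 0], [1, 1]]), axis=1)
# where mx is mlx.core.
```

For axis=1, each entry of indices picks a column within its own row, so the result has the shape of indices. For axis=0 the roles of rows and columns are swapped.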