mlx.core.split

split(a: array, /, indices_or_sections: int | Sequence[int], axis: int = 0, *, stream: None | Stream | Device = None) → list[array]

Split an array along a given axis.

Parameters:
  • a (array) – Input array.

  • indices_or_sections (int or list(int)) – If indices_or_sections is an integer, the array is split into that many sections of equal size; an error is raised if the length of axis is not evenly divisible by it. If indices_or_sections is a list, its entries are the indices at which to split, and the array is divided into len(indices_or_sections) + 1 sub-arrays.

  • axis (int, optional) – Axis to split along, defaults to 0.

Returns:

A list of split arrays.

Return type:

list(array)
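The two forms of indices_or_sections both reduce to a set of slice boundaries along the split axis. The pure-Python sketch below illustrates that arithmetic on a plain list; split_bounds is a hypothetical helper written for this illustration, not part of MLX, and it assumes non-negative, in-range split points.

```python
from typing import List, Sequence, Tuple, Union


def split_bounds(
    length: int, indices_or_sections: Union[int, Sequence[int]]
) -> List[Tuple[int, int]]:
    """Hypothetical helper (not MLX): return the (start, stop) bounds of each
    piece when an axis of size `length` is split. Assumes non-negative,
    in-range split points."""
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections
        # An integer must divide the axis length exactly.
        if sections <= 0 or length % sections != 0:
            raise ValueError(
                f"cannot split axis of size {length} into {sections} equal sections"
            )
        step = length // sections
        points = [step * i for i in range(1, sections)]
    else:
        # A sequence gives the split points directly.
        points = list(indices_or_sections)
    # k split points yield k + 1 pieces: [0, p0), [p0, p1), ..., [p_last, length).
    starts = [0] + points
    stops = points + [length]
    return list(zip(starts, stops))


data = [1, 2, 3, 4]
print([data[s:e] for s, e in split_bounds(len(data), 2)])       # [[1, 2], [3, 4]]
print([data[s:e] for s, e in split_bounds(len(data), [1, 3])])  # [[1], [2, 3], [4]]
```

The pieces printed here match the mx.split results shown in the example that follows.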

Example

>>> import mlx.core as mx
>>> a = mx.array([1, 2, 3, 4], dtype=mx.int32)
>>> mx.split(a, 2)
[array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
>>> mx.split(a, [1, 3])
[array([1], dtype=int32), array([2, 3], dtype=int32), array([4], dtype=int32)]
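The example above only splits along the default axis 0. To illustrate the axis parameter on a 2-D input, the sketch below uses NumPy's numpy.split as a stand-in, on the assumption that mx.split follows the same conventions for these cases (the documented behavior suggests it, but this sketch does not verify MLX itself). Replacing np with mx and building the input with mx.arange(8).reshape(2, 4) should give pieces of the same shapes.

```python
import numpy as np

# A 2 x 4 array; axis 1 has length 4.
b = np.arange(8).reshape(2, 4)

# Two equal sections along axis 1: each piece has shape (2, 2).
left, right = np.split(b, 2, axis=1)
print(left.tolist())   # [[0, 1], [4, 5]]
print(right.tolist())  # [[2, 3], [6, 7]]

# Split points [1, 3] along axis 1 give pieces of width 1, 2 and 1.
parts = np.split(b, [1, 3], axis=1)
print([p.shape for p in parts])  # [(2, 1), (2, 2), (2, 1)]

# 4 is not divisible by 3, so an integer of 3 raises an error.
try:
    np.split(b, 3, axis=1)
except ValueError as err:
    print("ValueError:", err)
```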