mlx.core.unflatten

unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: None | Stream | Device = None) -> array

Unflatten one axis of an array into several axes with the given shape. All other axes are left unchanged.

Parameters:
  • a (array) – Input array.

  • axis (int) – The axis to unflatten.

  • shape (tuple(int)) – The shape to unflatten to. At most one entry can be -1, in which case the corresponding size is inferred.

  • stream (Stream, optional) – Stream or device. Defaults to None, in which case the default stream of the default device is used.
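The shape rule implied by these parameters can be sketched in plain Python. The helper unflatten_shape below is hypothetical and is not part of MLX. It only computes the output shape and handles a non-negative axis. It also assumes, as with reshape, that the entries of shape must multiply out to the size of the axis being unflattened, with at most one -1 filled in from that size.

```python
from math import prod
from typing import Sequence, Tuple


def unflatten_shape(in_shape: Sequence[int], axis: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of unflattening in_shape[axis] into shape.

    Assumes 0 <= axis < len(in_shape); negative axes are not handled here.
    """
    size = in_shape[axis]
    shape = list(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one entry of shape can be -1")
    # Product of the explicitly given sizes (prod of an empty iterable is 1).
    known = prod(s for s in shape if s != -1)
    if -1 in shape:
        # Infer the missing entry from the size of the axis being unflattened.
        if known == 0 or size % known != 0:
            raise ValueError(f"cannot infer -1: {size} is not divisible by {known}")
        shape[shape.index(-1)] = size // known
    elif known != size:
        raise ValueError(f"shape {tuple(shape)} does not match axis size {size}")
    # Replace the single axis with the new axes; all other axes are unchanged.
    return tuple(in_shape[:axis]) + tuple(shape) + tuple(in_shape[axis + 1:])


print(unflatten_shape((4,), 0, (2, -1)))         # (2, 2)
print(unflatten_shape((2, 12, 5), 1, (3, 4)))    # (2, 3, 4, 5)
```

The first call mirrors the shape change in the example on this page. The second call shows that only the chosen axis is split.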

Returns:

The unflattened array.

Return type:

array

Example

>>> import mlx.core as mx
>>> a = mx.array([1, 2, 3, 4])
>>> mx.unflatten(a, 0, (2, -1))
array([[1, 2],
       [3, 4]], dtype=int32)