mlx.core.unflatten
- unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: None | Stream | Device = None) -> array
Unflatten an axis of an array to a shape.
- Parameters:
a (array) – Input array.
axis (int) – The axis to unflatten.
shape (tuple(int)) – The shape to unflatten to. At most one entry can be -1, in which case that entry's size is inferred so that the product of all entries equals the size of the unflattened axis.
stream (Stream, optional) – Stream or device. Defaults to None, in which case the default stream of the default device is used.
- Returns:
The unflattened array.
- Return type:
array
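The shape rule can be made concrete by computing the output shape directly. The following is a minimal pure-Python sketch, not MLX's implementation. unflatten_shape is a hypothetical helper for illustration. Counting a negative axis from the end is an assumption, following the usual NumPy-style convention. The unflattened axis is replaced in place by the resolved entries of shape, and every other dimension is left unchanged.

```python
from math import prod
from typing import Sequence, Tuple


def unflatten_shape(in_shape: Sequence[int], axis: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper (not part of MLX): output shape of unflattening
    `axis` of an array with shape `in_shape` into `shape`."""
    ndim = len(in_shape)
    # Assumption: negative axes count from the end, NumPy-style.
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for {ndim}-D input")
    axis %= ndim
    size = in_shape[axis]

    if sum(1 for d in shape if d == -1) > 1:
        raise ValueError("at most one entry of shape can be -1")
    known = prod(d for d in shape if d != -1)
    if -1 in shape:
        # Infer the missing entry so the product of shape equals the axis size.
        if known == 0 or size % known != 0:
            raise ValueError(f"cannot infer -1: {size} is not divisible by {known}")
        resolved = tuple(size // known if d == -1 else d for d in shape)
    else:
        if known != size:
            raise ValueError(f"shape {tuple(shape)} does not match axis size {size}")
        resolved = tuple(shape)

    # The unflattened axis is replaced in place by the resolved dimensions.
    return tuple(in_shape[:axis]) + resolved + tuple(in_shape[axis + 1:])


print(unflatten_shape((4,), 0, (2, -1)))          # (2, 2)
print(unflatten_shape((5, 12, 3), 1, (3, 4)))     # (5, 3, 4, 3)
print(unflatten_shape((5, 12, 3), -2, (-1, 6)))   # (5, 2, 6, 3)
```

The first call matches the example on this page, where an axis of size 4 split as (2, -1) infers the -1 as 2.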
Example
>>> import mlx.core as mx
>>> a = mx.array([1, 2, 3, 4])
>>> mx.unflatten(a, 0, (2, -1))
array([[1, 2],
       [3, 4]], dtype=int32)
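In the example, the four elements of a are regrouped in row-major order into two rows of two. The following pure-Python sketch mimics that regrouping on a plain list. regroup is a hypothetical helper, not part of MLX. It takes a fully resolved shape with no -1 entry and assumes row-major (C) ordering, which matches the output shown in the example.

```python
from math import prod
from typing import Any, List, Sequence


def regroup(flat: Sequence[Any], shape: Sequence[int]) -> List[Any]:
    """Hypothetical helper (not part of MLX): split a row-major 1-D
    sequence into nested lists with the given resolved shape."""
    if len(shape) == 0:
        raise ValueError("shape must have at least one dimension")
    if prod(shape) != len(flat):
        raise ValueError("shape does not match the number of elements")
    if len(shape) == 1:
        return list(flat)
    # Number of consecutive elements in each slice along the leading dimension.
    step = prod(shape[1:])
    return [regroup(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


print(regroup([1, 2, 3, 4], (2, 2)))     # [[1, 2], [3, 4]]
print(regroup(list(range(6)), (2, 3)))   # [[0, 1, 2], [3, 4, 5]]
```

Because the data order is unchanged, unflattening only reinterprets how existing elements are grouped. It does not reorder them.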