mlx.data.Buffer.shard

Buffer.shard(self: mlx.data._c.Buffer, key: str, num_shards: int, output_key: str = '') → mlx.data._c.Buffer

Split the first dimension of the array into num_shards shards.

This operation performs the following NumPy-style reshape:

def shard(x):
  shape = x.shape
  return x.reshape(num_shards, -1, *shape[1:])

Parameters:
  • key (str) – The sample key that contains the array we are operating on.

  • num_shards (int) – The size of the first dimension of the reshaped array.

  • output_key (str) – If not empty, write the result to this key instead of overwriting key. (default: '')
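
Example:

The reshape and the output_key behavior described above can be illustrated with plain NumPy. This is a minimal sketch, not the mlx.data implementation: shard_sample is a hypothetical helper that treats a sample as a dict of arrays, and the key names "image" and "image_sharded" are made up for the illustration. Note that NumPy's reshape raises a ValueError when the first dimension is not divisible by num_shards; how mlx.data handles that case is not stated here.

```python
import numpy as np


def shard(x, num_shards):
    """NumPy stand-in for the reshape shown in the description."""
    shape = x.shape
    return x.reshape(num_shards, -1, *shape[1:])


def shard_sample(sample, key, num_shards, output_key=""):
    """Hypothetical helper: shard sample[key] and honor output_key.

    An empty output_key overwrites key; otherwise the result is
    written to output_key and the original array is left in place.
    """
    result = dict(sample)
    result[output_key or key] = shard(sample[key], num_shards)
    return result


x = np.arange(24).reshape(6, 4)  # first dimension of size 6
sample = {"image": x}

# Default: the array stored under "image" is replaced by the sharded one.
overwritten = shard_sample(sample, "image", num_shards=3)
print(overwritten["image"].shape)  # (3, 2, 4)

# With output_key set, "image" is untouched and the result gets its own key.
kept = shard_sample(sample, "image", num_shards=3, output_key="image_sharded")
print(kept["image"].shape, kept["image_sharded"].shape)  # (6, 4) (3, 2, 4)
```

Each shard holds a contiguous slice of the original first dimension, so shard i of the result corresponds to rows i * 2 through i * 2 + 1 of x in this example.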