mlx.data.Buffer.batch
- Buffer.batch(self: mlx.data._c.Buffer, batch_size: int | List[int], pad: Dict[str, float] = {}, dim: Dict[str, int] = {}) -> mlx.data._c.Buffer
Creates batches from batch_size consecutive samples.

When two samples have arrays that are not the same shape, the batch shape is the smallest shape that contains all samples in each dimension. The places that do not have values are filled with pad values.

When a batch dimension is not provided, the arrays are stacked. If it is provided, the arrays are concatenated along that dimension.
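
The padding and stacking rules above can be emulated with plain NumPy to see the shape arithmetic. This is only a sketch, not the mlx.data implementation: pad_to_shape and batch_arrays are hypothetical helper names, and the sketch assumes padding is placed at the trailing positions of each dimension and that the concatenation dimension itself is left unpadded.

```python
import numpy as np


def pad_to_shape(array, shape, pad_value=0.0):
    """Copy `array` into the leading corner of a `shape`-sized array of `pad_value`."""
    out = np.full(shape, pad_value, dtype=array.dtype)
    out[tuple(slice(0, n) for n in array.shape)] = array
    return out


def batch_arrays(arrays, pad_value=0.0, dim=None):
    """Hypothetical stand-in for the batching rule (not part of mlx.data)."""
    ndim = arrays[0].ndim
    # Smallest shape that contains every sample in each dimension.
    target = [max(a.shape[d] for a in arrays) for d in range(ndim)]
    if dim is None:
        # No batch dimension: pad every sample to `target`, then stack on a new axis.
        return np.stack([pad_to_shape(a, target, pad_value) for a in arrays])
    # Batch dimension given: pad all other dimensions, then concatenate along `dim`.
    padded = []
    for a in arrays:
        shape = list(target)
        shape[dim] = a.shape[dim]  # assumption: the concatenation axis is not padded
        padded.append(pad_to_shape(a, shape, pad_value))
    return np.concatenate(padded, axis=dim)


samples = [np.ones((2, i + 1)) for i in range(3)]  # shapes (2, 1), (2, 2), (2, 3)

stacked = batch_arrays(samples, pad_value=-1.0)
print(stacked.shape)  # (3, 2, 3)

concat = batch_arrays(samples, dim=1)
print(concat.shape)  # (2, 6)
```

In the stacked result, the first sample occupies only column 0 of its slot, and the remaining columns hold the pad value -1.0.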
The following example showcases the use of the dim argument.

    import mlx.data as dx
    import numpy as np

    # Ten samples whose "x" arrays have shapes (10, 1), (10, 2), ..., (10, 10).
    dset = dx.buffer_from_vector([{"x": np.random.randn(10, i+1)} for i in range(10)])

    # No dim: samples are padded to a common shape and stacked.
    print(dset.batch(4)[0]["x"].shape)  # prints (4, 10, 4)
    print(dset.batch(4)[1]["x"].shape)  # prints (4, 10, 8)

    # dim 0: samples are concatenated along the first axis.
    print(dset.batch(4, dim=dict(x=0))[0]["x"].shape)  # prints (40, 4)
    print(dset.batch(4, dim=dict(x=0))[1]["x"].shape)  # prints (40, 8)

    # dim 1: samples are concatenated along the second axis.
    print(dset.batch(4, dim=dict(x=1))[0]["x"].shape)  # prints (10, 10)
    print(dset.batch(4, dim=dict(x=1))[1]["x"].shape)  # prints (10, 26)