mlx.core.gather_mm
gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: None | Stream | Device = None) -> array
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. This operation is more efficient than explicitly applying a take() followed by a matmul().

The indices lhs_indices and rhs_indices contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of a and b respectively.

For a with shape (A1, A2, ..., AS, M, K), lhs_indices contains indices from the range [0, A1 * A2 * ... * AS).
For b with shape (B1, B2, ..., BS, K, N), rhs_indices contains indices from the range [0, B1 * B2 * ... * BS).