mlx.core.gather_mm
- gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: None | Stream | Device = None) → array
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. This operation is more efficient than explicitly applying a take() followed by a matmul().
The indices lhs_indices and rhs_indices contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of a and b respectively.
For a with shape (A1, A2, ..., AS, M, K), lhs_indices contains indices from the range [0, A1 * A2 * ... * AS).
For b with shape (B1, B2, ..., BS, K, N), rhs_indices contains indices from the range [0, B1 * B2 * ... * BS).
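The gather-then-multiply behavior described above can be sketched in NumPy. This is only an illustrative reference, not the MLX implementation: the helper name reference_gather_mm is hypothetical, and the sketch assumes that flat indices address the batch dimensions in row-major order and that both index arrays have the same shape.

```python
import numpy as np

def reference_gather_mm(a, b, lhs_indices, rhs_indices):
    """Hypothetical NumPy reference for gather-then-matmul semantics."""
    # Collapse all batch dimensions into one, so that a flat index
    # selects a single matrix (row-major order is an assumption here).
    a_flat = a.reshape(-1, a.shape[-2], a.shape[-1])  # (A1*...*AS, M, K)
    b_flat = b.reshape(-1, b.shape[-2], b.shape[-1])  # (B1*...*BS, K, N)
    # Gather the selected matrices, then multiply them pairwise.
    return np.matmul(a_flat[lhs_indices], b_flat[rhs_indices])

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4, 5))  # batch dims (2, 3), M=4, K=5
b = rng.standard_normal((6, 5, 7))     # batch dim (6,), K=5, N=7

lhs_indices = np.array([0, 5, 2])  # each in [0, 2 * 3)
rhs_indices = np.array([1, 1, 4])  # each in [0, 6)

out = reference_gather_mm(a, b, lhs_indices, rhs_indices)
print(out.shape)  # (3, 4, 7)
```

Under the row-major assumption, flat index 5 into the (2, 3) batch of a selects a[1, 2], so out[1] equals a[1, 2] @ b[1].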
If only one index is passed and it is sorted, the sorted_indices flag can be passed for a possibly faster implementation.
- Parameters:
a (array) – Input array.
b (array) – Input array.
lhs_indices (array, optional) – Integer indices for a. Default: None.
rhs_indices (array, optional) – Integer indices for b. Default: None.
sorted_indices (bool, optional) – May allow a faster implementation if the passed indices are sorted. Default: False.
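When the indices are not already sorted, one possible way to prepare them for sorted_indices is to sort them and keep the permutation, so that results can be restored to the original order afterwards. The following is a NumPy sketch of that bookkeeping only. The index values are made up for illustration, the matching inputs would need to be permuted with the same order array, and whether sorting pays off in practice is not something this page states.

```python
import numpy as np

# Hypothetical right-hand indices, e.g. one selected matrix per input row.
rhs_indices = np.array([3, 0, 2, 0, 3, 1])

# A stable argsort gives a permutation that puts the indices in ascending order.
order = np.argsort(rhs_indices, kind="stable")
sorted_rhs = rhs_indices[order]

# Inverse permutation: maps results computed in sorted order
# back to the original order.
inverse = np.empty_like(order)
inverse[order] = np.arange(order.size)

print(sorted_rhs)           # [0 0 1 2 3 3]
print(sorted_rhs[inverse])  # [3 0 2 0 3 1]
```

Indexing any result computed in sorted order with inverse along its leading dimension recovers the original ordering.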
- Returns:
The output array.
- Return type: