mlx.core.gather_mm
- gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: None | Stream | Device = None) -> array
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices, followed by a (possibly batched) matrix multiplication of the two arrays. This operation is more efficient than explicitly applying a take() followed by a matmul().
The indices lhs_indices and rhs_indices contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of a and b respectively.
For a with shape (A1, A2, ..., AS, M, K), lhs_indices contains indices from the range [0, A1 * A2 * ... * AS).
For b with shape (B1, B2, ..., BS, K, N), rhs_indices contains indices from the range [0, B1 * B2 * ... * BS).
If only one index array is passed and it is sorted, the sorted_indices flag can be passed for a possibly faster implementation.
- Parameters:
a (array) – Input array.
b (array) – Input array.
lhs_indices (array, optional) – Integer indices for a. Default: None.
rhs_indices (array, optional) – Integer indices for b. Default: None.
sorted_indices (bool, optional) – May allow a faster implementation if the passed indices are sorted. Default: False.
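To make the flat-index semantics concrete, here is a minimal plain-Python sketch of what the operation computes: the batch dimensions of a and b are collapsed in row-major order, and output i is the (M, K) @ (K, N) product of the lhs_indices[i]-th matrix of a with the rhs_indices[i]-th matrix of b. It is not MLX's implementation and does not call MLX. The helper names (flatten_batch, matmul2d, gather_mm_reference) are hypothetical. The sketch assumes that both index arrays are given as 1-D lists of equal length; broadcasting or multi-dimensional index arrays are not handled, and sorted_indices is ignored because it only affects performance.

```python
from math import prod


def flatten_batch(x, batch_shape):
    """Collapse the leading batch dimensions of a nested list into a flat list
    of matrices, in row-major order (the last batch dimension varies fastest)."""
    mats = [x]
    for _ in batch_shape:
        # Peel off one batch dimension at a time, preserving row-major order.
        mats = [sub for m in mats for sub in m]
    assert len(mats) == prod(batch_shape)
    return mats


def matmul2d(x, y):
    """Plain (M, K) @ (K, N) product of two nested-list matrices."""
    k = len(y)
    assert all(len(row) == k for row in x), "inner dimensions must match"
    n = len(y[0])
    return [[sum(row[t] * y[t][j] for t in range(k)) for j in range(n)] for row in x]


def gather_mm_reference(a, a_batch, b, b_batch, lhs_indices, rhs_indices):
    """Hypothetical reference helper, not an MLX API:
    out[i] = flat(a)[lhs_indices[i]] @ flat(b)[rhs_indices[i]]."""
    a_mats = flatten_batch(a, a_batch)  # A1 * ... * AS matrices of shape (M, K)
    b_mats = flatten_batch(b, b_batch)  # B1 * ... * BS matrices of shape (K, N)
    assert len(lhs_indices) == len(rhs_indices)
    out = []
    for li, ri in zip(lhs_indices, rhs_indices):
        # Flat indices must lie in [0, A1*...*AS) and [0, B1*...*BS).
        assert 0 <= li < len(a_mats) and 0 <= ri < len(b_mats)
        out.append(matmul2d(a_mats[li], b_mats[ri]))
    return out


# a has shape (2, 2, 2): batch (2,), M = 2, K = 2.
a = [[[1, 0], [0, 1]],   # identity
     [[2, 0], [0, 2]]]   # 2 * identity
# b has shape (3, 2, 1): batch (3,), K = 2, N = 1.
b = [[[1], [2]], [[3], [4]], [[5], [6]]]

print(gather_mm_reference(a, (2,), b, (3,), [1, 0, 1], [0, 2, 2]))
# [[[2], [4]], [[5], [6]], [[10], [12]]]
```

With more than one batch dimension, flat index i1 * A2 + i2 selects a[i1][i2]. For example, with batch shape (2, 2), flat index 2 picks a[1][0]. Reusing the same matrix of a or b under several indices, as above, is what the single gather-then-multiply call avoids materializing with a separate take().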
- Returns:
The output array.
- Return type: