mlx.core.gather_mm


gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: None | Stream | Device = None) → array

Matrix multiplication with matrix-level gather.

Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. This operation is more efficient than explicitly applying a take() followed by a matmul().
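As an illustration of the semantics only, the unfused "gather, then multiply" route can be sketched in NumPy. This is a reference sketch, not the MLX implementation: the helper name gather_mm_reference is hypothetical, NumPy stands in for take() and matmul(), and broadcasting the two index arrays against each other is an assumption.

```python
import numpy as np

def gather_mm_reference(a, b, lhs_indices, rhs_indices):
    """Unfused reference: gather whole matrices by flat batch index, then matmul."""
    # Collapse all batch dimensions so every matrix has a single flat index.
    a_flat = a.reshape(-1, *a.shape[-2:])  # (A1 * ... * AS, M, K)
    b_flat = b.reshape(-1, *b.shape[-2:])  # (B1 * ... * BS, K, N)
    # Assumption: the two index arrays are broadcast against each other.
    lhs, rhs = np.broadcast_arrays(lhs_indices, rhs_indices)
    # Gather one matrix per index, then run a batched matrix multiplication.
    return np.matmul(a_flat[lhs], b_flat[rhs])

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 2, 4, 3))  # batch shape (2, 2): flat indices 0..3
b = rng.standard_normal((5, 3, 6))     # batch shape (5,):   flat indices 0..4
lhs = np.array([0, 3, 1])
rhs = np.array([4, 4, 2])

out = gather_mm_reference(a, b, lhs, rhs)
print(out.shape)  # (3, 4, 6)
```

Each output matrix pairs one gathered matrix from a with one from b. Here the second result is a[1, 1] @ b[4], since flat index 3 in a batch of shape (2, 2) is the coordinate (1, 1). A fused call such as gather_mm(a, b, lhs_indices=lhs, rhs_indices=rhs) avoids materializing the gathered copies.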

The indices lhs_indices and rhs_indices contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of a and b respectively.

For a with shape (A1, A2, ..., AS, M, K), lhs_indices contains indices from the range [0, A1 * A2 * ... * AS).

For b with shape (B1, B2, ..., BS, K, N), rhs_indices contains indices from the range [0, B1 * B2 * ... * BS).
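To make the index ranges above concrete, the following minimal sketch converts a multi-dimensional batch coordinate into a flat index. It assumes row-major (C-order) flattening of the batch dimensions, and the helper name flat_batch_index is hypothetical.

```python
def flat_batch_index(coords, batch_shape):
    """Row-major flat index of a batch coordinate (assumed convention)."""
    if len(coords) != len(batch_shape):
        raise ValueError("coords and batch_shape must have the same length")
    flat = 0
    for c, dim in zip(coords, batch_shape):
        if not 0 <= c < dim:
            raise IndexError(f"coordinate {c} out of range for dimension {dim}")
        # Horner-style accumulation: shift by this dimension, then add the offset.
        flat = flat * dim + c
    return flat

batch_shape = (2, 3)                          # valid flat indices: [0, 2 * 3)
idx = flat_batch_index((1, 2), batch_shape)   # 1 * 3 + 2
print(idx)  # 5
```

With a batch shape of (2, 3), the valid flat indices are 0 through 5, and the coordinate (1, 2) selects the last matrix in the batch.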

Parameters:
  • a (array) – Left-hand input array.

  • b (array) – Right-hand input array.

  • lhs_indices (array, optional) – Flat integer indices into the batch dimensions of a. Default: None.

  • rhs_indices (array, optional) – Flat integer indices into the batch dimensions of b. Default: None.

Returns:

The output array.

Return type:

array