mlx.core.gather_qmm


gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: array | None = None, rhs_indices: array | None = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: None | Stream | Device = None) -> array

Perform quantized matrix multiplication with matrix-level gather.

This operation is the quantized equivalent of gather_mm(). As in gather_mm(), lhs_indices and rhs_indices contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of x and w respectively.
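To make the role of the flat indices concrete, here is a minimal NumPy sketch of the gather step applied to an ordinary (already dequantized) weight. The helper name gather_mm_reference is hypothetical and not part of MLX. The sketch assumes that the batch dimensions are flattened in row-major order, that the two index arrays are broadcast against each other, and that the output batch shape equals the broadcast index shape.

```python
import numpy as np

def gather_mm_reference(x, w, lhs_indices, rhs_indices, transpose=True):
    """Hypothetical reference: gather matrices by flat batch index, then multiply."""
    # Collapse all batch dimensions (everything but the last two) into one axis.
    x_flat = x.reshape(-1, *x.shape[-2:])
    w_flat = w.reshape(-1, *w.shape[-2:])
    # Assumption: the two index arrays broadcast against each other.
    lhs, rhs = np.broadcast_arrays(np.asarray(lhs_indices), np.asarray(rhs_indices))
    xs = x_flat[lhs]                      # index_shape + (M, K)
    ws = w_flat[rhs]                      # index_shape + (N, K) when transpose=True
    if transpose:
        ws = np.swapaxes(ws, -1, -2)      # index_shape + (K, N)
    return xs @ ws

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 8))     # batch dims (2, 3): six 5x8 matrices
w = rng.standard_normal((4, 16, 8))       # batch dim (4,): four 16x8 matrices
lhs = np.array([0, 5, 3])                 # flat index 5 selects x[1, 2]
rhs = np.array([1, 1, 2])
out = gather_mm_reference(x, w, lhs, rhs)
print(out.shape)                          # (3, 5, 16)
```

Each output matrix out[i] is x_flat[lhs[i]] @ w_flat[rhs[i]].T, so no gathered copy of x or w has to be built by the caller.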

Note that scales and biases must have the same batch dimensions as w since they represent the same quantized matrix.
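The following NumPy sketch illustrates why scales and biases share the batch dimensions of w: each group of group_size packed elements is expanded with its own scale and bias. It is an illustration only, not the MLX kernel. It assumes affine quantization of the form w ≈ scale * q + bias, 32-bit words with the lowest bits holding the first element, and a bits value that divides 32. The helper name dequantize_reference is hypothetical.

```python
import numpy as np

def dequantize_reference(w_packed, scales, biases, group_size=64, bits=4):
    """Hypothetical reference: unpack uint32 words and apply per-group scale and bias."""
    per_word = 32 // bits                                  # elements stored in each word
    shifts = np.arange(per_word, dtype=np.uint32) * np.uint32(bits)
    mask = np.uint32((1 << bits) - 1)
    # Assumption: the first element sits in the least significant bits.
    q = (w_packed[..., None] >> shifts) & mask             # (..., n_words, per_word)
    q = q.reshape(*w_packed.shape[:-1], -1).astype(np.float32)
    # Split the last axis into groups that share one scale and one bias.
    q = q.reshape(*q.shape[:-1], -1, group_size)
    w = q * scales[..., None] + biases[..., None]
    return w.reshape(*w_packed.shape[:-1], -1)

# Two quantized matrices (batch dim 2), each with one row of 64 elements.
w_packed = np.full((2, 1, 8), 0x76543210, dtype=np.uint32)  # nibbles 0..7 in every word
scales = np.array([2.0, 0.5], dtype=np.float32).reshape(2, 1, 1)
biases = np.array([1.0, -1.0], dtype=np.float32).reshape(2, 1, 1)
w = dequantize_reference(w_packed, scales, biases)
print(w.shape)          # (2, 1, 64)
print(w[0, 0, :4])      # [1. 3. 5. 7.]
```

Under these assumptions the last axis of scales and biases has in_dims / group_size entries, while the last axis of the packed w has in_dims * bits / 32 entries.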

Parameters:
  • x (array) – Input array

  • w (array) – Quantized matrix packed in unsigned integers

  • scales (array) – The scales to use per group_size elements of w

  • biases (array) – The biases to use per group_size elements of w

  • lhs_indices (array, optional) – Integer indices for x. Default: None.

  • rhs_indices (array, optional) – Integer indices for w. Default: None.

  • transpose (bool, optional) – Defines whether to multiply with the transposed w or not, namely whether we are performing x @ w.T or x @ w. Default: True.

  • group_size (int, optional) – The size of the group in w that shares a scale and bias. Default: 64.

  • bits (int, optional) – The number of bits occupied by each element in w. Default: 4.

Returns:

The result of the multiplication of x with w after gathering using lhs_indices and rhs_indices.

Return type:

array