mlx.core.quantize#
- quantize(w: array, /, group_size: int | None = None, bits: int | None = None, mode: str = 'affine', *, stream: None | Stream | Device = None) tuple[array, array, array]#
Quantize the array `w`.

Note

Every `group_size` elements in a row of `w` are quantized together. Hence, the last dimension of `w` should be divisible by `group_size`.

Warning

`quantize` only supports inputs with two or more dimensions whose last dimension is divisible by `group_size`.

The supported quantization modes are `"affine"`, `"mxfp4"`, `"mxfp8"`, and `"nvfp4"`. They are described in more detail below.

- Parameters:
  - w (array) – Array to be quantized.
  - group_size (int, optional) – The size of the group in `w` that shares a scale and bias. See supported values and defaults in the table of quantization modes. Default: `None`.
  - bits (int, optional) – The number of bits occupied by each element of `w` in the quantized array. See supported values and defaults in the table of quantization modes. Default: `None`.
  - mode (str, optional) – The quantization mode. Default: `"affine"`.
- Returns:
  A tuple with either two or three elements containing:
  - `w_q` (array): The quantized version of `w`.
  - `scales` (array): The quantization scales.
  - `biases` (array): The quantization biases (returned only for `mode=="affine"`).
- Return type: tuple[array, array, array]
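As a rough illustration of these return values, the sketch below (a hypothetical helper, not part of MLX) computes the expected output shapes for the `"affine"` mode from the grouping rule above and the uint32 packing described in the Notes:

```python
# Hypothetical helper (not part of mlx): sketch the output shapes of
# quantize under "affine" mode. Each uint32 in w_q holds 32 // bits
# elements, and each group of group_size elements in a row shares one
# scale and one bias.

def affine_output_shapes(shape, group_size=64, bits=4):
    *lead, k = shape
    assert k % group_size == 0, "last dim must be divisible by group_size"
    return (
        (*lead, k * bits // 32),   # w_q: packed uint32 payload
        (*lead, k // group_size),  # scales: one per group
        (*lead, k // group_size),  # biases: one per group
    )

affine_output_shapes((128, 256))  # ((128, 32), (128, 4), (128, 4))
```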
Notes
Quantization modes#

| mode   | group size   | bits              | scale type    | bias |
|--------|--------------|-------------------|---------------|------|
| affine | 32, 64*, 128 | 2, 3, 4*, 5, 6, 8 | same as input | yes  |
| mxfp4  | 32*          | 4*                | e8m0          | no   |
| mxfp8  | 32*          | 8*                | e8m0          | no   |
| nvfp4  | 16*          | 4*                | e4m3          | no   |

* indicates the default value when unspecified.
The `"affine"` mode quantizes groups of \(g\) consecutive elements in a row of `w`. For each group, the quantized representation of each element \(\hat{w_i}\) is computed as follows:

\[\begin{split}\begin{aligned} \alpha &= \max_i w_i \\ \beta &= \min_i w_i \\ s &= \frac{\alpha - \beta}{2^b - 1} \\ \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). \end{aligned}\end{split}\]

After the above computation, \(\hat{w_i}\) fits in \(b\) bits and is packed into an unsigned 32-bit integer from the lower to the upper bits. For instance, with 4-bit quantization we fit 8 elements in an unsigned 32-bit integer, where the 1st element occupies the 4 least significant bits, the 2nd occupies bits 4–7, and so on.
To dequantize the elements of `w`, we also save \(s\) and \(\beta\), which are the returned `scales` and `biases` respectively.

The `"mxfp4"`, `"mxfp8"`, and `"nvfp4"` modes similarly quantize groups of \(g\) elements of `w`. For the `"mx"` modes, the group size must be 32. For `"nvfp4"` the group size must be 16. The elements are quantized to 4-bit or 8-bit precision floating-point values: E2M1 for `"fp4"` and E4M3 for `"fp8"`. There is a shared 8-bit scale per group. The `"mx"` modes use an E8M0 scale and the `"nv"` mode uses an E4M3 scale. Unlike `"affine"` quantization, these modes do not have a bias value.

More details on the `"mx"` formats can be found in the specification.