mlx.nn.quantize


quantize(model: Module, group_size: int | None = None, bits: int | None = None, *, mode: str = 'affine', quantize_input: bool = False, class_predicate: Callable[[str, Module], bool | dict] | None = None)

Quantize the sub-modules of a module according to a predicate.

By default, all layers that define a to_quantized() method are quantized; this includes Linear and Embedding layers. The module is updated in-place.

Note

quantize_input=True is only supported for "nvfp4" and "mxfp8" modes and Linear layers.

Parameters:
  • model (Module) – The model whose leaf modules may be quantized.

  • group_size (Optional[int]) – The quantization group size (see mlx.core.quantize()). Default: None.

  • bits (Optional[int]) – The number of bits per parameter (see mlx.core.quantize()). Default: None.

  • mode (str) – The quantization method to use (see mlx.core.quantize()). Default: "affine".

  • quantize_input (bool) – Whether to quantize activations. Default: False.

  • class_predicate (Optional[Callable]) – A callable that receives the path to a Module and the Module itself, and returns True (or a dict of parameters to pass to to_quantized()) if the module should be quantized, and False otherwise. If None, all layers that define a to_quantized() method are quantized. Default: None.

Example

Weight-only quantization for all layers that define a to_quantized() method:

>>> import mlx.nn as nn
>>> nn.quantize(model, group_size=64, bits=4, mode="affine")

Weight and input quantization for all linear layers:

>>> predicate = lambda p, m: isinstance(m, nn.Linear)
>>> nn.quantize(model, mode="nvfp4", quantize_input=True, class_predicate=predicate)
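Per-layer quantization parameters via class_predicate. As noted above, the predicate may return a dict instead of True; the dict is passed as keyword arguments to that layer's to_quantized(), overriding the global settings. The sketch below assumes a model whose output projection path ends in "lm_head" (a hypothetical name) and keeps that layer at higher precision:

```python
def predicate(path, module):
    # Hypothetical layer name: keep the output projection at 8 bits
    # while the rest of the model uses the global group_size/bits.
    if path.endswith("lm_head"):
        return {"group_size": 32, "bits": 8}
    # Quantize everything else that supports it with the global settings.
    return hasattr(module, "to_quantized")

# nn.quantize(model, group_size=64, bits=4, class_predicate=predicate)
```

Layers for which the predicate returns False (here, any module without a to_quantized() method) are left unchanged.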