mlx.nn.quantize

quantize(model: Module, group_size: int = 64, bits: int = 4, class_predicate: callable | None = None)

Quantize the sub-modules of a module according to a predicate.

By default, all Linear and Embedding layers are quantized. Note also that the module is updated in-place.
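For instance, the following minimal sketch quantizes a model in-place with the default settings (the small Sequential model here is purely illustrative):

    import mlx.nn as nn

    # An illustrative model; quantize() replaces its Linear layers
    # with quantized equivalents in-place.
    model = nn.Sequential(
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 10),
    )

    # Quantize all Linear (and Embedding) layers to 4 bits
    # with a group size of 64.
    nn.quantize(model, group_size=64, bits=4)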

Parameters:
  • model (Module) – The model whose leaf modules may be quantized.

  • group_size (int) – The quantization group size (see mlx.core.quantize()). Default: 64.

  • bits (int) – The number of bits per parameter (see mlx.core.quantize()). Default: 4.

  • class_predicate (Optional[Callable]) – A callable which receives the Module path and the Module itself and returns True if it should be quantized and False otherwise. If None, all Linear and Embedding layers are quantized. Default: None.
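As a sketch of the predicate hook (the predicate function and the model below are illustrative, not part of the API), the callable receives the module path first and the module second, and only modules for which it returns True are quantized:

    import mlx.nn as nn

    model = nn.Sequential(
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 10),
    )

    # Illustrative predicate: quantize only Linear layers whose input
    # dimension (weight.shape[1]) is divisible by the group size.
    def predicate(path, module):
        return isinstance(module, nn.Linear) and module.weight.shape[1] % 64 == 0

    nn.quantize(model, group_size=64, bits=4, class_predicate=predicate)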