mlx.nn.quantize


quantize(model: Module, group_size: int = 64, bits: int = 4, class_predicate: Callable[[str, Module], bool | dict] | None = None)

Quantize the sub-modules of a module according to a predicate.

By default, every layer that defines a to_quantized(group_size, bits) method is quantized. This includes both Linear and Embedding layers. Note that the module is updated in place.

Parameters:
  • model (Module) – The model whose leaf modules may be quantized.

  • group_size (int) – The quantization group size (see mlx.core.quantize()). Default: 64.

  • bits (int) – The number of bits per parameter (see mlx.core.quantize()). Default: 4.

  • class_predicate (Optional[Callable]) – A callable that receives the module path and the module itself. It returns True if the module should be quantized, a dict of parameters for to_quantized if it should be quantized with those parameters, and False otherwise. If None, all layers that define a to_quantized(group_size, bits) method are quantized. Default: None.
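The three possible return values of class_predicate can be illustrated with a small predicate function. This is a sketch under stated assumptions: the function name and the path conventions ("lm_head", "embed") are hypothetical and depend entirely on how a given model names its sub-modules.

```python
def mixed_precision_predicate(path, module):
    """Decide per layer whether and how to quantize it.

    `path` is the module path and `module` is the layer itself, matching
    the two arguments described for class_predicate.
    """
    # Layers without a to_quantized method cannot be quantized, so skip them.
    if not hasattr(module, "to_quantized"):
        return False
    # Hypothetical naming convention: keep the output head in full precision.
    if path.endswith("lm_head"):
        return False
    # Hypothetical naming convention: give embeddings more bits by returning
    # a dict of keyword arguments for to_quantized.
    if "embed" in path:
        return {"group_size": 64, "bits": 8}
    # Everything else is quantized with the settings given to quantize().
    return True


# Intended use (requires MLX and a model whose paths follow the naming above):
#   import mlx.nn as nn
#   nn.quantize(model, group_size=64, bits=4,
#               class_predicate=mixed_precision_predicate)
```

Checking for to_quantized first mirrors the documented default selection rule, so the predicate only narrows or adjusts the default rather than selecting layers that cannot be quantized.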