mlx.nn.quantize
- quantize(model: Module, group_size: int = None, bits: int = None, *, mode: str = 'affine', quantize_input: bool = False, class_predicate: Callable[[str, Module], bool | dict] | None = None)
Quantize the sub-modules of a module according to a predicate.
By default all layers that define a to_quantized() method will be quantized; this includes both Linear and Embedding layers. The module is updated in-place.

Note

quantize_input=True is only supported for the "nvfp4" and "mxfp8" modes and for Linear layers.

- Parameters:
  - model (Module) – The model whose leaf modules may be quantized.
  - group_size (Optional[int]) – The quantization group size (see mlx.core.quantize()). Default: None.
  - bits (Optional[int]) – The number of bits per parameter (see mlx.core.quantize()). Default: None.
  - mode (str) – The quantization method to use (see mlx.core.quantize()). Default: "affine".
  - quantize_input (bool) – Whether to quantize the input activations as well as the weights. Default: False.
  - class_predicate (Optional[Callable]) – A callable which receives the Module path and the Module itself, and returns True (or a dict of parameters for to_quantized()) if the module should be quantized, and False otherwise. If None, all layers that define a to_quantized() method are quantized. Default: None.
Example
Weight-only quantization for all layers that define a to_quantized() method:

>>> import mlx.nn as nn
>>> nn.quantize(model, group_size=64, bits=4, mode="affine")
Weight and input quantization for all linear layers:
>>> predicate = lambda p, m: isinstance(m, nn.Linear)
>>> nn.quantize(model, mode="nvfp4", quantize_input=True, class_predicate=predicate)
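The class_predicate may also return a dict of parameters for to_quantized(), which allows mixed-precision quantization on a per-layer basis. The sketch below is not MLX itself: MockLinear, apply_predicate, the module paths, and the default parameter values are all hypothetical stand-ins used to illustrate how a dict-returning predicate could drive per-layer settings.

```python
# Minimal sketch of dict-returning class_predicate dispatch.
# MockLinear and apply_predicate are hypothetical, not part of mlx.nn.

class MockLinear:
    """Stand-in for an nn.Linear-like module."""
    def __init__(self, name):
        self.name = name

def apply_predicate(modules, class_predicate):
    """Mimic the dispatch described in the docs: True -> default
    params, dict -> per-layer overrides, False -> skip the layer."""
    plan = {}
    for path, module in modules:
        result = class_predicate(path, module)
        if result is True:
            plan[path] = {"group_size": 64, "bits": 4}  # assumed defaults
        elif isinstance(result, dict):
            plan[path] = result  # per-layer overrides
        # a falsy result leaves the layer unquantized
    return plan

# Quantize linear layers at 4 bits, but keep a (hypothetical)
# output head at higher precision:
def predicate(path, module):
    if not isinstance(module, MockLinear):
        return False
    if path == "lm_head":
        return {"group_size": 32, "bits": 8}
    return True

modules = [("layers.0.linear", MockLinear("a")), ("lm_head", MockLinear("b"))]
plan = apply_predicate(modules, predicate)
print(plan)
# {'layers.0.linear': {'group_size': 64, 'bits': 4}, 'lm_head': {'group_size': 32, 'bits': 8}}
```

With the real API, passing such a predicate as class_predicate lets nn.quantize() apply different group_size and bits values to different sub-modules in a single call.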