mlx.nn.quantize
quantize(model: Module, group_size: int = 64, bits: int = 4, class_predicate: Callable[[str, Module], bool | dict] | None = None)
Quantize the sub-modules of a module according to a predicate.
By default all layers that define a to_quantized(group_size, bits) method will be quantized. Both Linear and Embedding layers will be quantized. Note also that the module is updated in-place.
Parameters:
model (Module) – The model whose leaf modules may be quantized.
group_size (int) – The quantization group size (see mlx.core.quantize()). Default: 64.
bits (int) – The number of bits per parameter (see mlx.core.quantize()). Default: 4.
class_predicate (Optional[Callable]) – A callable which receives the Module path and the Module itself and returns True or a dict of params for to_quantized if it should be quantized and False otherwise. If None, then all layers that define a to_quantized(group_size, bits) method are quantized. Default: None.