mlx.nn.QuantizedEmbedding

class QuantizedEmbedding(num_embeddings: int, dims: int, group_size: int = 64, bits: int = 4)

The same as Embedding but with a quantized weight matrix.

QuantizedEmbedding also provides a from_embedding() classmethod to convert embedding layers to QuantizedEmbedding layers.

Parameters:
  • num_embeddings (int) – The number of distinct tokens that can be embedded. Usually called the vocabulary size.

  • dims (int) – The dimensionality of the embeddings.

  • group_size (int, optional) – The group size to use for the quantized weight. See quantize(). Default: 64.

  • bits (int, optional) – The bit width to use for the quantized weight. See quantize(). Default: 4.
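To make group_size and bits concrete, here is a pure-Python sketch of group-wise affine quantization, the general scheme that the quantize() documentation describes: each run of group_size consecutive values in a row shares one scale and one bias, and each value is stored as an integer code of the given bit width. The helper names are hypothetical, and the real implementation may differ in rounding, edge cases, and how codes are packed in memory.

```python
def quantize_group(values, bits):
    """Map one group of floats to integer codes in [0, 2**bits - 1]."""
    levels = 2**bits - 1
    lo, hi = min(values), max(values)
    # One scale and one bias (the group minimum) are shared by the group.
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo


def dequantize_group(codes, scale, bias):
    """Recover approximate floats from the codes of one group."""
    return [c * scale + bias for c in codes]


def quantize_row(row, group_size=64, bits=4):
    """Quantize one embedding row group by group (hypothetical helper)."""
    if len(row) % group_size != 0:
        raise ValueError("row length must be a multiple of group_size")
    return [
        quantize_group(row[i:i + group_size], bits)
        for i in range(0, len(row), group_size)
    ]


# One row with dims = 128 splits into 128 / 64 = 2 groups.
row = [i / 127 for i in range(128)]
groups = quantize_row(row, group_size=64, bits=4)

restored = []
for codes, scale, bias in groups:
    restored.extend(dequantize_group(codes, scale, bias))

max_error = max(abs(a - b) for a, b in zip(row, restored))
print(len(groups), max_error)
```

In this sketch a smaller group_size means more scale and bias pairs per row (more storage, lower error), while a larger bits value means more levels per group.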

Methods

as_linear(x)

Call the quantized embedding layer as a quantized linear layer.

from_embedding(embedding_layer[, ...])

Create a QuantizedEmbedding layer from an Embedding layer.