mlx.nn.Embedding

class Embedding(num_embeddings: int, dims: int)

Implements a simple lookup table that maps each input integer to a high-dimensional vector.

Typically used to embed discrete tokens for processing by neural networks.

Parameters:
  • num_embeddings (int) – The number of distinct discrete tokens that can be embedded. Usually called the vocabulary size.

  • dims (int) – The dimensionality of the embeddings.

Methods

as_linear(x)

Call the embedding layer as a linear layer.

to_quantized([group_size, bits])

Return a QuantizedEmbedding layer that approximates this embedding layer.