Neural Networks#
Writing arbitrarily complex neural networks in MLX can be done using only mlx.core.array and mlx.core.value_and_grad(). However, this requires the user to repeatedly write the same simple neural network operations and to handle all parameter state and initialization manually and explicitly. The module mlx.nn solves this problem by providing an intuitive way of composing neural network layers, initializing their parameters, freezing them for fine-tuning, and more.
Quick Start with Neural Networks#
import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()

        self.layers = [
            nn.Linear(in_dims, 128),
            nn.Linear(128, 128),
            nn.Linear(128, out_dims),
        ]

    def __call__(self, x):
        for i, l in enumerate(self.layers):
            x = mx.maximum(x, 0) if i > 0 else x
            x = l(x)
        return x

# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)

# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)

# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])

# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())

# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
#       it from the local scope. It could be a positional argument or a
#       keyword argument.
def l2_loss(x, y):
    y_hat = mlp(x)
    return (y_hat - y).square().mean()

# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
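As a minimal sketch of how this fits into a training step, assuming the mlx.optimizers package (its SGD optimizer and Optimizer.update() method) and using random data purely for illustration:

import mlx.optimizers as optim

optimizer = optim.SGD(learning_rate=0.1)

# Dummy inputs and targets, only to exercise one training step
x = mx.random.normal((4, 2))
y = mx.random.normal((4, 10))

# Compute the loss and the gradients with respect to mlp.trainable_parameters()
loss, grads = loss_and_grad(x, y)

# Apply the gradients to the model's parameters in place
optimizer.update(mlp, grads)

# Evaluate the new parameters (and the optimizer state) to materialize the update
mx.eval(mlp.parameters(), optimizer.state)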
The Module Class#
The workhorse of any neural network library is the Module class. In MLX the Module class is a container of mlx.core.array or Module instances. Its main function is to provide a way to recursively access and update its parameters and those of its submodules.
Parameters#
A parameter of a module is any public member of type mlx.core.array (its name should not start with _). It can be arbitrarily nested in other Module instances or in lists and dictionaries.

Module.parameters() can be used to extract a nested dictionary with all the parameters of a module and its submodules.

A Module can also keep track of “frozen” parameters. See the Module.freeze() method for more details. When using mlx.nn.value_and_grad(), the gradients returned will be with respect to these trainable (i.e. unfrozen) parameters.
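For instance, continuing with the MLP from the quick start, freezing one layer removes its parameters from trainable_parameters() while parameters() still returns everything. A small illustrative sketch:

from mlx.utils import tree_flatten

# Freeze the first layer; its weight and bias are no longer trainable
mlp.layers[0].freeze()

all_params = dict(tree_flatten(mlp.parameters()))
trainable = dict(tree_flatten(mlp.trainable_parameters()))

print("layers.0.weight" in all_params)   # True
print("layers.0.weight" in trainable)    # False

# Unfreeze it again to make every parameter trainable
mlp.layers[0].unfreeze()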
Updating the Parameters#
MLX modules allow accessing and updating individual parameters. However, most of the time we need to update large subsets of a module’s parameters. This action is performed by Module.update().
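As a sketch, Module.update() accepts a (possibly partial) nested dictionary of parameters with the same structure as Module.parameters(), so it composes naturally with mlx.utils.tree_map(). The 0.9 scaling below is arbitrary and purely illustrative:

from mlx.utils import tree_map

# Scale every parameter by 0.9 and write the result back into the model
new_params = tree_map(lambda p: 0.9 * p, mlp.parameters())
mlp.update(new_params)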
Inspecting Modules#
The simplest way to see the model architecture is to print it. Following along with the above example, you can print the MLP with:
print(mlp)
This will display:
MLP(
  (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
  (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
  (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)
To get more detailed information on the arrays in a Module you can use mlx.utils.tree_map() on the parameters. For example, to see the shapes of all the parameters in a Module do:
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
As another example, you can count the number of parameters in a Module with:
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value and Grad#
Using a Module does not preclude using MLX’s higher-order function transformations (mlx.core.value_and_grad(), mlx.core.grad(), etc.). However, these function transformations assume pure functions, namely the parameters should be passed as an argument to the function being transformed. There is an easy pattern to achieve that with MLX modules:
model = ...

def f(params, other_inputs):
    model.update(params)  # <---- Necessary to make the model use the passed parameters
    return model(other_inputs)

f(model.trainable_parameters(), mx.zeros((10,)))
However, mlx.nn.value_and_grad() provides precisely this pattern and only computes the gradients with respect to the trainable parameters of the model. In detail, it does the following (sketched below):

- it wraps the passed function with a function that calls Module.update() to make sure the model is using the provided parameters,
- it calls mlx.core.value_and_grad() to transform the function into a function that also computes the gradients with respect to the passed parameters,
- it wraps the returned function with a function that passes the trainable parameters as the first argument to the function returned by mlx.core.value_and_grad().
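A minimal sketch of those three steps, reusing mlp and l2_loss from the quick start. This is an illustration of the behaviour, not the library’s actual implementation:

import mlx.core as mx
import mlx.nn as nn

def my_value_and_grad(model: nn.Module, fn):
    # Step 1: make the wrapped function pure by routing the parameters
    # through Module.update() before calling the user's function.
    def inner_fn(params, *args, **kwargs):
        model.update(params)
        return fn(*args, **kwargs)

    # Step 2: differentiate with respect to the first argument (the parameters).
    value_grad_fn = mx.value_and_grad(inner_fn)

    # Step 3: always pass the trainable parameters as that first argument.
    def wrapped_fn(*args, **kwargs):
        return value_grad_fn(model.trainable_parameters(), *args, **kwargs)

    return wrapped_fn

# Behaves like nn.value_and_grad(mlp, l2_loss) from the quick start
loss_and_grad = my_value_and_grad(mlp, l2_loss)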
| mlx.nn.value_and_grad | Transform the passed function fn to a function that also computes the gradients of fn with respect to the model’s trainable parameters. |
| mlx.nn.quantize | Quantize the sub-modules of a module according to a predicate. |
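For example, assuming nn.quantize accepts group_size and bits keyword arguments (as in recent MLX releases), quantizing the quick-start MLP in place is a one-liner:

# Replace every quantizable sub-module (e.g. Linear) with its quantized equivalent
nn.quantize(mlp, group_size=64, bits=4)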
- Module
- mlx.nn.Module.training
- mlx.nn.Module.state
- mlx.nn.Module.apply
- mlx.nn.Module.apply_to_modules
- mlx.nn.Module.children
- mlx.nn.Module.eval
- mlx.nn.Module.filter_and_map
- mlx.nn.Module.freeze
- mlx.nn.Module.leaf_modules
- mlx.nn.Module.load_weights
- mlx.nn.Module.modules
- mlx.nn.Module.named_modules
- mlx.nn.Module.parameters
- mlx.nn.Module.save_weights
- mlx.nn.Module.set_dtype
- mlx.nn.Module.train
- mlx.nn.Module.trainable_parameters
- mlx.nn.Module.unfreeze
- mlx.nn.Module.update
- mlx.nn.Module.update_modules
- Layers
- mlx.nn.ALiBi
- mlx.nn.AvgPool1d
- mlx.nn.AvgPool2d
- mlx.nn.BatchNorm
- mlx.nn.CELU
- mlx.nn.Conv1d
- mlx.nn.Conv2d
- mlx.nn.Conv3d
- mlx.nn.ConvTranspose1d
- mlx.nn.ConvTranspose2d
- mlx.nn.ConvTranspose3d
- mlx.nn.Dropout
- mlx.nn.Dropout2d
- mlx.nn.Dropout3d
- mlx.nn.Embedding
- mlx.nn.ELU
- mlx.nn.GELU
- mlx.nn.GLU
- mlx.nn.GroupNorm
- mlx.nn.GRU
- mlx.nn.HardShrink
- mlx.nn.HardTanh
- mlx.nn.Hardswish
- mlx.nn.InstanceNorm
- mlx.nn.LayerNorm
- mlx.nn.LeakyReLU
- mlx.nn.Linear
- mlx.nn.LogSigmoid
- mlx.nn.LogSoftmax
- mlx.nn.LSTM
- mlx.nn.MaxPool1d
- mlx.nn.MaxPool2d
- mlx.nn.Mish
- mlx.nn.MultiHeadAttention
- mlx.nn.PReLU
- mlx.nn.QuantizedEmbedding
- mlx.nn.QuantizedLinear
- mlx.nn.RMSNorm
- mlx.nn.ReLU
- mlx.nn.ReLU6
- mlx.nn.RNN
- mlx.nn.RoPE
- mlx.nn.SELU
- mlx.nn.Sequential
- mlx.nn.Sigmoid
- mlx.nn.SiLU
- mlx.nn.SinusoidalPositionalEncoding
- mlx.nn.Softmin
- mlx.nn.Softshrink
- mlx.nn.Softsign
- mlx.nn.Softmax
- mlx.nn.Softplus
- mlx.nn.Step
- mlx.nn.Tanh
- mlx.nn.Transformer
- mlx.nn.Upsample
- Functions
- mlx.nn.elu
- mlx.nn.celu
- mlx.nn.gelu
- mlx.nn.gelu_approx
- mlx.nn.gelu_fast_approx
- mlx.nn.glu
- mlx.nn.hard_shrink
- mlx.nn.hard_tanh
- mlx.nn.hardswish
- mlx.nn.leaky_relu
- mlx.nn.log_sigmoid
- mlx.nn.log_softmax
- mlx.nn.mish
- mlx.nn.prelu
- mlx.nn.relu
- mlx.nn.relu6
- mlx.nn.selu
- mlx.nn.sigmoid
- mlx.nn.silu
- mlx.nn.softmax
- mlx.nn.softmin
- mlx.nn.softplus
- mlx.nn.softshrink
- mlx.nn.step
- mlx.nn.tanh
- Loss Functions
- mlx.nn.losses.binary_cross_entropy
- mlx.nn.losses.cosine_similarity_loss
- mlx.nn.losses.cross_entropy
- mlx.nn.losses.gaussian_nll_loss
- mlx.nn.losses.hinge_loss
- mlx.nn.losses.huber_loss
- mlx.nn.losses.kl_div_loss
- mlx.nn.losses.l1_loss
- mlx.nn.losses.log_cosh_loss
- mlx.nn.losses.margin_ranking_loss
- mlx.nn.losses.mse_loss
- mlx.nn.losses.nll_loss
- mlx.nn.losses.smooth_l1_loss
- mlx.nn.losses.triplet_loss
- Initializers