class Module

Base class for building neural networks with MLX.

All the layers provided in mlx.nn.layers subclass this class, and your own models should do the same.

A Module can contain other Module instances or mlx.core.array instances in arbitrarily nested Python lists or dicts. The Module then allows recursively extracting all the mlx.core.array instances using mlx.nn.Module.parameters().

In addition, the Module has the concept of trainable and non-trainable parameters (the latter are called “frozen”). When using mlx.nn.value_and_grad(), gradients are returned only with respect to the trainable parameters. All arrays in a module are trainable unless they are added to the “frozen” set by calling freeze().

import mlx.core as mx
import mlx.nn as nn

class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()  # required before assigning layers or arrays
        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        x = self.in_proj(x)
        x = mx.maximum(x, 0)
        return self.out_proj(x)

model = MyMLP(2, 1)

# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())

# Setting a parameter to a new value is as simple as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())


Boolean indicating if the model is in training mode.


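The module's state dictionary. Unlike mlx.nn.Module.parameters(), it is a reference to the module's state rather than a copy, so updates to it are reflected in the module.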


Module.apply(map_fn[, filter_fn])

Map all the parameters using the provided map_fn and immediately update the module with the mapped parameters.


Apply a function to all the modules in this instance (including this instance).


Return the direct descendants of this Module instance.


Set the model to evaluation mode.

Module.filter_and_map(filter_fn[, map_fn, ...])

Recursively filter the contents of the module using filter_fn, keeping only the keys and values for which filter_fn returns True; the selected leaf values can optionally be transformed with map_fn.

Module.freeze(*[, recurse, keys, strict])

Freeze all or some of the Module's parameters.


Return the submodules that do not contain other modules.

Module.load_weights(file_or_weights[, strict])

Update the model's weights from a .npz or .safetensors file, or from a list of (name, array) pairs.


Return a list with all the modules in this instance.


Return a list with all the modules in this instance, paired with their names in dot notation.


Recursively return all the mlx.core.array members of this Module as a dict of dicts and lists.


Save the model's weights to a file.

Module.set_dtype(dtype[, predicate])

Set the dtype of the module's parameters.


Set the model in or out of training mode.


Recursively return all the non-frozen mlx.core.array members of this Module as a dict of dicts and lists.

Module.unfreeze(*[, recurse, keys, strict])

Unfreeze all or some of the Module's parameters.


Replace the parameters of this Module with the provided ones in the dict of dicts and lists.


Replace the child modules of this Module instance with the provided ones in the dict of dicts and lists.