Module

class Module
Base class for building neural networks with MLX.
All the layers provided in mlx.nn.layers subclass this class and your models should do the same.

A Module can contain other Module instances or mlx.core.array instances in arbitrary nesting of Python lists or dicts. The Module then allows recursively extracting all the mlx.core.array instances using mlx.nn.Module.parameters().

In addition, the Module has the concept of trainable and non-trainable parameters (called "frozen"). When using mlx.nn.value_and_grad() the gradients are returned only with respect to the trainable parameters. All arrays in a module are trainable unless they are added to the "frozen" set by calling freeze().

```python
import mlx.core as mx
import mlx.nn as nn

class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()

        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        x = self.in_proj(x)
        x = mx.maximum(x, 0)
        return self.out_proj(x)

model = MyMLP(2, 1)

# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())

# Setting a parameter to a new value is as simple as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())
```
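As a rough sketch of the nested structure returned by parameters() for the MyMLP model above (the keys mirror the attribute names; the exact contents come from nn.Linear's internals):

```python
# Continuing from the MyMLP example above.
params = model.parameters()

# `params` is a nested dict of mlx.core.array instances mirroring the
# module's attributes, roughly:
# {"in_proj": {"weight": ..., "bias": ...},
#  "out_proj": {"weight": ..., "bias": ...}}
print(params["in_proj"]["weight"].shape)
```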
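And a minimal sketch of how freezing interacts with mlx.nn.value_and_grad(), again assuming the MyMLP model from above (the loss function and data below are made up for illustration):

```python
# Freeze the first projection; its parameters are no longer trainable.
model.in_proj.freeze()

def loss_fn(model, x, y):
    return mx.mean((model(x) - y) ** 2)

x = mx.random.normal((4, 2))
y = mx.random.normal((4, 1))

# Gradients are returned only for the trainable (non-frozen) parameters,
# so `grads` has an entry for out_proj but not for in_proj.
loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, x, y)
```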
Attributes
Module.training
Boolean indicating if the model is in training mode.

Module.state
The module's state dictionary.
Methods
Module.apply(map_fn[, filter_fn])
Map all the parameters using the provided map_fn and immediately update the module with the mapped parameters.

Module.apply_to_modules(apply_fn)
Apply a function to all the modules in this instance (including this instance).

Module.children()
Return the direct descendants of this Module instance.

Module.eval()
Set the model to evaluation mode.

Module.filter_and_map(filter_fn[, map_fn, ...])
Recursively filter the contents of the module using filter_fn, namely only select keys and values where filter_fn returns true.

Module.freeze(*[, recurse, keys, strict])
Freeze the Module's parameters or some of them.

Module.leaf_modules()
Return the submodules that do not contain other modules.

Module.load_weights(file_or_weights[, strict])
Update the model's weights from a .npz, a .safetensors file, or a list.

Module.modules()
Return a list with all the modules in this instance.

Module.named_modules()
Return a list with all the modules in this instance and their name with dot notation.

Module.parameters()
Recursively return all the mlx.core.array members of this Module as a dict of dicts and lists.

Module.save_weights(file)
Save the model's weights to a file.

Module.set_dtype(dtype[, predicate])
Set the dtype of the module's parameters.

Module.train([mode])
Set the model in or out of training mode.

Module.trainable_parameters()
Recursively return all the non-frozen mlx.core.array members of this Module as a dict of dicts and lists.

Module.unfreeze(*[, recurse, keys, strict])
Unfreeze the Module's parameters or some of them.

Module.update(parameters)
Replace the parameters of this Module with the provided ones in the dict of dicts and lists.

Module.update_modules(modules)
Replace the child modules of this Module instance with the provided ones in the dict of dicts and lists.
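As an illustration of apply() and set_dtype() from the table above, a brief sketch continuing with the MyMLP model (set_dtype() covers the common dtype-casting case; its default predicate is expected to touch only floating-point parameters):

```python
# Cast every parameter with apply(); the lambda runs on each array leaf.
model.apply(lambda p: p.astype(mx.float16))

# set_dtype() is the shorthand for the same kind of cast.
model.set_dtype(mx.float16)
```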
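save_weights() and load_weights() round-trip the flat parameter arrays; the file name below is just an example, with the format selected from the extension:

```python
# Save to safetensors, then restore the weights into the same module.
model.save_weights("mymlp.safetensors")
model.load_weights("mymlp.safetensors")
```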
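train() and eval() toggle the training flag, which layers such as dropout consult to decide their behavior. A small sketch:

```python
model.train()   # sets Module.training to True
model.eval()    # equivalent to model.train(mode=False)
```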
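update() is how new parameter values are written back into the module (optimizers effectively use it under the hood). A hand-rolled gradient step might look like the following sketch, reusing loss_fn, x, and y from the earlier freezing example and an arbitrary learning rate:

```python
from mlx.utils import tree_map

loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)

# Plain SGD applied leaf by leaf over the nested dict of trainable
# parameters; update() merges the new values back into the module.
lr = 1e-2
model.update(tree_map(lambda p, g: p - lr * g, model.trainable_parameters(), grads))
mx.eval(model.parameters())
```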