# Module
*class* `Module`
Base class for building neural networks with MLX.

All the layers provided in `mlx.nn.layers` subclass this class and your models should do the same. A `Module` can contain other `Module` instances or `mlx.core.array` instances in arbitrary nesting of Python lists or dicts. The `Module` then allows recursively extracting all the `mlx.core.array` instances using `mlx.nn.Module.parameters()`.

In addition, the `Module` has the concept of trainable and non-trainable parameters (called "frozen"). When using `mlx.nn.value_and_grad()` the gradients are returned only with respect to the trainable parameters. All arrays in a module are trainable unless they are added to the "frozen" set by calling `freeze()`.

```python
import mlx.core as mx
import mlx.nn as nn

class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()
        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        x = self.in_proj(x)
        x = mx.maximum(x, 0)
        return self.out_proj(x)

model = MyMLP(2, 1)

# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())

# Setting a parameter to a new value is as simple as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())
```

**Attributes**

- `Module.training`: Boolean indicating if the model is in training mode.
- `Module.state`: The module's state dictionary.

**Methods**

- `Module.apply(map_fn[, filter_fn])`: Map all the parameters using the provided `map_fn` and immediately update the module with the mapped parameters.
- `Module.apply_to_modules(apply_fn)`: Apply a function to all the modules in this instance (including this instance).
- `Module.children()`: Return the direct descendants of this `Module` instance.
- `Module.eval()`: Set the model to evaluation mode.
- `Module.filter_and_map(filter_fn[, map_fn, ...])`: Recursively filter the contents of the module using `filter_fn`, namely only select keys and values where `filter_fn` returns true.
- `Module.freeze(*[, recurse, keys, strict])`: Freeze the `Module`'s parameters or some of them.
- `Module.leaf_modules()`: Return the submodules that do not contain other modules.
- `Module.load_weights(file_or_weights[, strict])`: Update the model's weights from a `.npz`, a `.safetensors` file, or a list.
- `Module.modules()`: Return a list with all the modules in this instance.
- `Module.named_modules()`: Return a list with all the modules in this instance and their name with dot notation.
- `Module.parameters()`: Recursively return all the `mlx.core.array` members of this `Module` as a dict of dicts and lists.
- `Module.save_weights(file)`: Save the model's weights to a file.
- `Module.set_dtype(dtype[, predicate])`: Set the dtype of the module's parameters.
- `Module.train([mode])`: Set the model in or out of training mode.
- `Module.trainable_parameters()`: Recursively return all the non-frozen `mlx.core.array` members of this `Module` as a dict of dicts and lists.
- `Module.unfreeze(*[, recurse, keys, strict])`: Unfreeze the `Module`'s parameters or some of them.
- `Module.update(parameters[, strict])`: Replace the parameters of this `Module` with the provided ones in the dict of dicts and lists.
- `Module.update_modules(modules[, strict])`: Replace the child modules of this `Module` instance with the provided ones in the dict of dicts and lists.
