mlx.nn.Module.load_weights

Module.load_weights(file_or_weights: str | List[Tuple[str, array]], strict: bool = True) → Module

Update the model’s weights from a .npz file, a .safetensors file, or a list of (name, array) pairs.

Parameters:
  • file_or_weights (str or list(tuple(str, mx.array))) – The path to the weights file (.npz or .safetensors) or a list of pairs of parameter names and arrays.

  • strict (bool, optional) – If True, checks that the provided weights exactly match the parameters of the model. If False, only the provided weights that correspond to parameters of the model are loaded, and shapes are not checked. Default: True.

Returns:

The module instance after updating the weights.

Example

import mlx.core as mx
import mlx.nn as nn
model = nn.Linear(10, 10)

# Load from .npz file
model.load_weights("weights.npz")

# Load from .safetensors file
model.load_weights("weights.safetensors")

# Load from list
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
    ("bias",  mx.zeros((10,))),
]
model.load_weights(weights)

# The bias is missing from this list
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
]

# Raises a ValueError because strict=True by default
model.load_weights(weights)

# OK: updates the weight and leaves the bias unchanged
model.load_weights(weights, strict=False)
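
The behaviour of `strict` described above can be pictured with a small pure-Python sketch. This is an illustration under stated assumptions, not MLX's implementation: `check_weights` is a hypothetical helper, and it works on (name, shape) pairs rather than real arrays so that it runs without MLX installed.

```python
def check_weights(param_shapes, weights, strict=True):
    """Hypothetical helper: return the names that would be loaded.

    param_shapes maps parameter names to shapes; weights is a list of
    (name, shape) pairs standing in for (name, array) pairs.
    """
    provided = dict(weights)
    if strict:
        # Strict mode: names must match exactly in both directions.
        missing = sorted(set(param_shapes) - set(provided))
        unexpected = sorted(set(provided) - set(param_shapes))
        if missing or unexpected:
            raise ValueError(f"missing={missing}, unexpected={unexpected}")
        # Strict mode also compares shapes.
        for name, shape in provided.items():
            if tuple(shape) != tuple(param_shapes[name]):
                raise ValueError(f"shape mismatch for {name}")
    # Non-strict mode: keep only names the model knows, skip shape checks.
    return sorted(name for name in provided if name in param_shapes)


params = {"weight": (10, 10), "bias": (10,)}
full = check_weights(params, [("weight", (10, 10)), ("bias", (10,))])
partial = check_weights(params, [("weight", (10, 10))], strict=False)
```

With the bias left out, the strict call raises a ValueError while the non-strict call reports only `weight` as loaded, mirroring the last two calls in the example above.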