mlx.nn.Module.set_dtype

Module.set_dtype(dtype: mlx.core.Dtype, predicate: typing.Callable[[mlx.core.Dtype], bool] | None = <function Module.<lambda>>)

Set the dtype of the module’s parameters.

Parameters:
  • dtype (Dtype) – The new dtype.

  • predicate (Callable, optional) – A function that takes a parameter's current Dtype and returns True if that parameter should be cast. By default, only floating-point parameters (those whose dtype is a subtype of floating) are cast, so integer parameters keep their dtype.