mlx.nn.Module.filter_and_map

Module.filter_and_map(filter_fn: Callable[[Module, str, Any], bool], map_fn: Callable | None = None, is_leaf_fn: Callable[[Module, str, Any], bool] | None = None)

Recursively filter the contents of the module using filter_fn, keeping only the keys and values for which filter_fn returns true.

This is used to implement parameters() and trainable_parameters() but it can also be used to extract any subset of the module’s parameters.

Parameters:
  • filter_fn (Callable) – Called with the containing module, the key under which a value is found, and the value itself, in that order; returns whether to keep the value or drop it.

  • map_fn (Callable, optional) – Optionally transform the value before returning it.

  • is_leaf_fn (Callable, optional) – Called with the containing module, the key under which a value is found, and the value itself, in that order; returns whether the value is a leaf.

Returns:

A dictionary containing the contents of the module, recursively filtered.
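
To make the interplay of the three callbacks concrete, here is a simplified pure-Python stand-in for the recursion. It is not MLX's implementation: filter_and_map_sketch is a hypothetical helper, nested dictionaries play the role of modules, lists are treated as leaves, and the default leaf rule is an assumption made for this sketch.

```python
from typing import Any, Callable, Optional


def filter_and_map_sketch(
    container: dict,
    filter_fn: Callable[[dict, str, Any], bool],
    map_fn: Optional[Callable[[Any], Any]] = None,
    is_leaf_fn: Optional[Callable[[dict, str, Any], bool]] = None,
) -> dict:
    """Hypothetical stand-in: nested dicts take the place of modules."""
    # Assumed defaults for this sketch: identity map, non-dict values are leaves.
    map_fn = map_fn or (lambda value: value)
    is_leaf_fn = is_leaf_fn or (lambda c, k, v: not isinstance(v, dict))

    result = {}
    for key, value in container.items():
        if not filter_fn(container, key, value):
            continue  # rejected entries do not appear in the result
        if is_leaf_fn(container, key, value):
            result[key] = map_fn(value)  # transform the leaf before returning it
        else:
            # Not a leaf: recurse with the same three callbacks.
            result[key] = filter_and_map_sketch(value, filter_fn, map_fn, is_leaf_fn)
    return result


tree = {
    "encoder": {"weight": [1.0, 2.0], "name": "enc"},
    "bias": [0.5],
    "step": 3,
}

# Keep nested containers and list-valued leaves; map each kept leaf to its length.
kept = filter_and_map_sketch(
    tree,
    filter_fn=lambda c, k, v: isinstance(v, (dict, list)),
    map_fn=len,
)
print(kept)
```

In this sketch a container must itself pass filter_fn before its children are visited, so a filter that rejects an intermediate key also drops everything beneath it.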