tree_map(fn, tree, *rest, is_leaf=None)

Applies fn to the leaves of the Python tree given as tree and returns a new collection, with the same nesting, that holds the results.

If rest is provided, every tree in rest is assumed to be a superset of tree (it has at least the same keys and indices at every level), and the corresponding leaves are passed to fn as extra positional arguments. In that respect, tree_map() is closer to itertools.starmap() than to map().

The keyword argument is_leaf decides what counts as a leaf in tree, in the same way as for tree_flatten().

import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 10)
print(model.parameters().keys())
# dict_keys(['weight', 'bias'])

# square the parameters
model.update(tree_map(lambda x: x*x, model.parameters()))
Parameters:
  • fn (Callable) – The function that processes the leaves of the tree.

  • tree (Any) – The main Python tree that will be iterated upon.

  • rest (Tuple[Any]) – Extra trees to be iterated together with tree.

  • is_leaf (Optional[Callable]) – An optional callable that returns True if the passed object is considered a leaf or False otherwise.


Returns:
  A Python tree with the new values returned by fn.