mlx.utils.tree_map
- tree_map(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None) → Any
Applies fn to the leaves of the Python tree passed as tree and returns a new collection with the results.

If rest is provided, every item in it is assumed to be a superset of tree, and the corresponding leaves are passed as extra positional arguments to fn. In that respect, tree_map() is closer to itertools.starmap() than to map().

The keyword argument is_leaf decides what constitutes a leaf in tree, similar to tree_flatten().

For example, to square every parameter of a linear layer:

    import mlx.nn as nn
    from mlx.utils import tree_map

    model = nn.Linear(10, 10)
    print(model.parameters().keys())
    # dict_keys(['weight', 'bias'])

    # square the parameters
    model.update(tree_map(lambda x: x * x, model.parameters()))
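The traversal described above can be approximated by the following minimal, hypothetical sketch. It is written in pure Python so it runs without MLX, and it is not MLX's source: the name tree_map_sketch is made up, and it assumes, as a simplification, that only lists, tuples and dicts are containers while everything else is a leaf. It shows how leaves from rest are looked up by the positions and keys of tree (so extra keys in rest are ignored) and how is_leaf can stop the recursion early.

```python
from typing import Any, Callable, Optional


def tree_map_sketch(
    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
) -> Any:
    """Hypothetical approximation of tree_map (not MLX's implementation).

    Assumption: lists, tuples and dicts are containers; anything else is a leaf.
    """
    # A user-supplied predicate can declare a whole subtree to be a leaf.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)
    if isinstance(tree, (list, tuple)):
        # Walk the children of `tree`; index each extra tree by the same position.
        return type(tree)(
            tree_map_sketch(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
            for i, child in enumerate(tree)
        )
    if isinstance(tree, dict):
        # Keys come from `tree`; extra keys in the other trees are never visited.
        return {
            k: tree_map_sketch(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
            for k, child in tree.items()
        }
    # Plain leaf: call fn with this leaf plus the matching leaves from `rest`.
    return fn(tree, *rest)


# Single tree: every leaf is doubled and the container types are preserved.
params = {"weight": [1, 2], "bias": (3,)}
doubled = tree_map_sketch(lambda x: 2 * x, params)

# Extra tree: `grads` is a superset of `params` (it has an extra "step" key).
grads = {"weight": [10, 20], "bias": (30,), "step": 7}
summed = tree_map_sketch(lambda p, g: p + g, params, grads)

# is_leaf: treat whole lists and tuples as leaves instead of recursing into them.
lengths = tree_map_sketch(
    len, params, is_leaf=lambda x: isinstance(x, (list, tuple))
)
```

Because positions and keys are taken from tree, each tree in rest must contain at least those entries; in this sketch a missing one raises an IndexError or KeyError.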
- Parameters:
fn (callable) – The function that processes the leaves of the tree.
tree (Any) – The main Python tree that will be iterated upon.
rest (tuple[Any]) – Extra trees to be iterated together with tree.
is_leaf (callable, optional) – An optional callable that returns True if the passed object is considered a leaf or False otherwise.
- Returns:
A Python tree with the new values returned by fn.