mlx.utils.tree_map_with_path
- tree_map_with_path(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None, path: Any | None = None) -> Any
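Applies fn to the path and leaves of the Python tree passed as tree, and returns a new collection with the results.

This function is the same as tree_map() except that fn takes the path as its first argument, followed by the corresponding leaf of tree and of each tree in rest. As the example at the end of this page shows, the path is a dot-separated string of the keys and indices that lead to the leaf, such as model.0.w.

- Parameters: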
fn (callable) – The function that processes the leaves of the tree.
tree (Any) – The main Python tree that will be iterated upon.
rest (tuple[Any]) – Extra trees to be iterated together with tree.
is_leaf (callable, optional) – An optional callable that returns True if the passed object is considered a leaf or False otherwise.
- Returns:
A Python tree with the new values returned by fn.
Example
>>> from mlx.utils import tree_map_with_path
>>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
>>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
model.0.w
model.0.b
model.1.w
model.1.b
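
The example above covers only the single-tree case. To show how rest and is_leaf interact with the path, here is a minimal pure-Python sketch of the behaviour described on this page. It is not the MLX source. The name tree_map_with_path_sketch is hypothetical, and the dot-joined path format is inferred from the example output above. The sketch also assumes that every tree in rest has the same structure as tree.

```python
from typing import Any, Callable, Optional


def tree_map_with_path_sketch(
    fn: Callable,
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable] = None,
    path: Optional[str] = None,
) -> Any:
    """Hypothetical stand-in that mimics the documented behaviour."""
    # A node accepted by is_leaf is passed to fn whole, without descending.
    if is_leaf is not None and is_leaf(tree):
        return fn(path, tree, *rest)

    # Assumption: child paths are joined with "." as in the example output.
    prefix = f"{path}." if path else ""

    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type, indexing rest trees in lockstep.
        return type(tree)(
            tree_map_with_path_sketch(
                fn, child, *(r[i] for r in rest),
                is_leaf=is_leaf, path=f"{prefix}{i}",
            )
            for i, child in enumerate(tree)
        )
    if isinstance(tree, dict):
        return {
            k: tree_map_with_path_sketch(
                fn, child, *(r[k] for r in rest),
                is_leaf=is_leaf, path=f"{prefix}{k}",
            )
            for k, child in tree.items()
        }
    # Anything else is a leaf: call fn with the path and the matching leaves.
    return fn(path, tree, *rest)


params = {"model": [{"w": 10, "b": 1}, {"w": 20, "b": 2}]}
grads = {"model": [{"w": 1, "b": 5}, {"w": 2, "b": 5}]}


def step(path, p, g):
    # Use the path to update only the "w" entries; leave "b" untouched.
    return p - g if path.endswith("w") else p


updated = tree_map_with_path_sketch(step, params, grads)

# Treat each layer dict as a single leaf, so paths stop one level higher.
paths = []
tree_map_with_path_sketch(
    lambda path, layer: paths.append(path),
    params,
    is_leaf=lambda node: isinstance(node, dict) and "w" in node,
)
```

In this sketch, updated keeps the structure of params with only the w entries changed, and paths collects one entry per layer rather than one per parameter. With MLX installed, the same calls should work with mlx.utils.tree_map_with_path in place of the stand-in.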