mlx.utils.tree_map_with_path

tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None)

Applies fn to the path and leaves of the Python tree tree and returns a new collection with the results.

This function is the same as tree_map(), except that fn takes the path as its first argument, followed by the corresponding nodes from tree and from each tree in rest.

Parameters:
  • fn (callable) – The function that processes the leaves of the tree.

  • tree (Any) – The main Python tree that will be iterated upon.

  • rest (tuple[Any]) – Extra trees to be iterated together with tree.

  • is_leaf (callable, optional) – An optional callable that returns True if the passed object is considered a leaf or False otherwise.

Returns:

A Python tree with the new values returned by fn.

Example

>>> from mlx.utils import tree_map_with_path
>>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
>>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
model.0.w
model.0.b
model.1.w
model.1.b
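
The traversal described above can be approximated in plain Python. The following is a minimal sketch, not MLX's implementation: sketch_tree_map_with_path is a hypothetical stand-in, and it assumes that paths are dotted strings (as in the example output), that only dicts, lists, and tuples are treated as containers, and that is_leaf is called on the node from the main tree only. It illustrates how fn receives the path followed by one leaf per tree, and how is_leaf can stop the descent early.

```python
from typing import Any, Callable, Optional


def sketch_tree_map_with_path(
    fn: Callable,
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable] = None,
    path: Optional[str] = None,
) -> Any:
    """Hypothetical stand-in for tree_map_with_path (not MLX's code)."""
    # Assumption: is_leaf only inspects the node from the main tree.
    if is_leaf is not None and is_leaf(tree):
        return fn(path, tree, *rest)
    # Child paths are built by joining keys and indices with dots.
    prefix = f"{path}." if path else ""
    if isinstance(tree, (list, tuple)):
        return type(tree)(
            sketch_tree_map_with_path(
                fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
            )
            for i, child in enumerate(tree)
        )
    if isinstance(tree, dict):
        return {
            k: sketch_tree_map_with_path(
                fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
            )
            for k, child in tree.items()
        }
    # Anything else is a leaf: call fn with the path and one leaf per tree.
    return fn(path, tree, *rest)


params = {"model": [{"w": 1, "b": 2}, {"w": 3, "b": 4}]}
deltas = {"model": [{"w": 10, "b": 20}, {"w": 30, "b": 40}]}

seen = []


def add_to_biases(path, p, d):
    # Record the path, then update only the leaves whose path ends in ".b".
    seen.append(path)
    return p + d if path.endswith(".b") else p


updated = sketch_tree_map_with_path(add_to_biases, params, deltas)

# With is_leaf, every list is handed to fn whole instead of being traversed.
kinds = sketch_tree_map_with_path(
    lambda path, node: f"{path}:{type(node).__name__}",
    params,
    is_leaf=lambda node: isinstance(node, list),
)
```

Because the path is passed to fn, a single function can treat leaves differently by name (here, only the "b" entries change), which plain tree_map() cannot do without extra bookkeeping.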