Tree Utils

In MLX, a Python tree is an arbitrarily nested collection of dictionaries, lists, and tuples that contains no cycles. Any other object in the tree is treated as a leaf. Functions in this module that return Python trees build them from the built-in dict, list, and tuple types, but they can usually process objects that inherit from any of these.


Dictionary keys should be valid Python identifiers, so that the dot-separated keys produced when a tree is flattened can be split back apart unambiguously.

tree_flatten(tree[, prefix, is_leaf])

Flattens a Python tree into a list of (key, value) tuples, one per leaf.
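As a rough illustration of what flattening produces, here is a pure-Python sketch. It assumes the naming scheme MLX uses for parameter keys such as layers.0.weight, where dictionary keys and list or tuple positions are joined with dots. The names flatten_tree and _join are hypothetical, this is not MLX's own implementation, and prepending a non-empty prefix to every key is a simplifying choice made here.

```python
from typing import Any, Callable, List, Optional, Tuple


def _join(prefix: str, key: Any) -> str:
    # Join path components with "."; no leading dot at the top level.
    return f"{prefix}.{key}" if prefix else str(key)


def flatten_tree(
    tree: Any,
    prefix: str = "",
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> List[Tuple[str, Any]]:
    """Hypothetical sketch (not mlx.utils.tree_flatten): one (key, value) pair per leaf."""
    # Nodes flagged by is_leaf are kept whole instead of being descended into.
    if is_leaf is not None and is_leaf(tree):
        return [(prefix, tree)]
    if isinstance(tree, (list, tuple)):
        # Sequence positions become integer path components.
        items = enumerate(tree)
    elif isinstance(tree, dict):
        items = tree.items()
    else:
        # Anything that is not a dict, list or tuple is a leaf.
        return [(prefix, tree)]
    flat: List[Tuple[str, Any]] = []
    for key, child in items:
        flat.extend(flatten_tree(child, _join(prefix, key), is_leaf))
    return flat
```

For example, flattening {"layers": [{"w": 1.0}, {"w": 2.0}]} with this sketch gives [("layers.0.w", 1.0), ("layers.1.w", 2.0)].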


Recreates a Python tree from its flat representation.
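The inverse direction can be sketched the same way. This hypothetical unflatten_tree assumes dot-separated keys and treats a node whose path components are all non-negative integers as a list. The flat form does not record whether a sequence was a list or a tuple, so the sketch rebuilds every sequence as a list. This is also where the identifier rule for dictionary keys matters: a key such as "0" or "a.b" would be misread as a list index or a nested path.

```python
from typing import Any, Dict, List, Tuple


def unflatten_tree(flat: List[Tuple[str, Any]]) -> Any:
    """Hypothetical sketch (not MLX's implementation): rebuild a dict/list tree."""
    # A single pair with an empty key means this node is a leaf.
    if len(flat) == 1 and flat[0][0] == "":
        return flat[0][1]

    # Group the pairs by their first path component, keeping the remainder.
    children: Dict[str, List[Tuple[str, Any]]] = {}
    for key, value in flat:
        head, _, rest = key.partition(".")
        children.setdefault(head, []).append((rest, value))

    # Assumption: if every component is a non-negative integer, the node was a sequence.
    if all(head.isdigit() for head in children):
        ordered = sorted(children.items(), key=lambda item: int(item[0]))
        return [unflatten_tree(group) for _, group in ordered]

    # Otherwise the components are dictionary keys.
    return {head: unflatten_tree(group) for head, group in children.items()}
```

Under these assumptions, unflatten_tree([("layers.0.w", 1.0), ("layers.1.w", 2.0)]) returns {"layers": [{"w": 1.0}, {"w": 2.0}]}.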

tree_map(fn, tree, *rest[, is_leaf])

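Applies fn to the leaves of the Python tree given as tree and returns a new tree with the same structure containing the results. If additional trees are passed in rest, fn receives the corresponding leaf from each of them as extra positional arguments.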
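A pure-Python sketch of the mapping, including the extra trees in rest, follows. It assumes each tree in rest has at least the structure of tree, so that indexing it by the same keys and positions succeeds. The name map_tree is hypothetical and this is not MLX's implementation. In line with the note at the top of this page, it always builds its output from the built-in dict, list, and tuple, even when the input uses subclasses of them.

```python
from typing import Any, Callable, Optional


def map_tree(
    fn: Callable[..., Any],
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:
    """Hypothetical sketch (not mlx.utils.tree_map): apply fn leaf by leaf."""
    # Nodes flagged by is_leaf are passed to fn whole, without recursing.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)
    if isinstance(tree, (list, tuple)):
        # Output uses the built-in list or tuple, not a subclass of it.
        seq_type = tuple if isinstance(tree, tuple) else list
        return seq_type(
            # Pull the matching position out of every extra tree.
            map_tree(fn, child, *[r[i] for r in rest], is_leaf=is_leaf)
            for i, child in enumerate(tree)
        )
    if isinstance(tree, dict):
        # Output is always a plain dict; extra trees are indexed by the same key.
        return {
            key: map_tree(fn, child, *[r[key] for r in rest], is_leaf=is_leaf)
            for key, child in tree.items()
        }
    # Any other object is a leaf: call fn on it and the matching leaves of rest.
    return fn(tree, *rest)
```

For instance, map_tree(lambda a, b: (a + b) / 2, p1, p2) would average two trees p1 and p2 of matching structure (p1 and p2 are placeholders).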