mlx.utils.tree_flatten


tree_flatten(tree: Any, prefix: str = '', is_leaf: Callable | None = None) → Any

Flattens a Python tree into a list of (key, value) tuples.

Each key uses dot notation to record the path from the root to a leaf, so trees of arbitrary depth and complexity can be represented.

from mlx.utils import tree_flatten

print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]

print(tree_flatten([[[0]]], ".hello"))
# [("hello.0.0.0", 0)]

Note

Dictionaries should have keys that are valid Python identifiers.
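
A likely reason for this note is that keys are joined with dots, so a dictionary key that itself contains a dot cannot be told apart from nested keys. The sketch below is one way to check a tree before flattening it. The helper name invalid_keys is hypothetical and is not part of MLX.

```python
from typing import Any, List


def invalid_keys(tree: Any) -> List[str]:
    """Collect dictionary keys that are not valid Python identifiers.

    Hypothetical helper for illustration; MLX does not provide it.
    """
    found: List[str] = []
    if isinstance(tree, dict):
        for key, child in tree.items():
            if not (isinstance(key, str) and key.isidentifier()):
                found.append(str(key))
            found.extend(invalid_keys(child))
    elif isinstance(tree, (list, tuple)):
        for child in tree:
            found.extend(invalid_keys(child))
    return found


# One key "a.b" and two nested keys "a" -> "b" join to the same dotted path.
assert ".".join(["a.b"]) == ".".join(["a", "b"])

bad = invalid_keys({"layer": {"w.1": 0, "bias": 1}, "blocks": [{"2nd": 2}]})
print(bad)
```

Here bad contains the two offending keys, "w.1" and "2nd", which could be renamed before the tree is flattened.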

Parameters:
  • tree (Any) – The Python tree to be flattened.

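  • prefix (str) – A prefix to use for the keys. The first character is always discarded, so a non-empty prefix should start with a separator such as "." (for example, ".hello" yields keys beginning with "hello.").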

  • is_leaf (callable) – An optional callable that returns True if the passed object should be treated as a leaf and False otherwise. An object treated as a leaf is not flattened further.
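
To make the prefix and is_leaf parameters concrete, here is a minimal pure-Python sketch of the behavior described above. The function flatten_sketch is a hypothetical stand-in written for illustration, not MLX's implementation, and it assumes that only lists, tuples, and dictionaries are traversed.

```python
from typing import Any, Callable, List, Optional, Tuple


def flatten_sketch(
    tree: Any,
    prefix: str = "",
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> List[Tuple[str, Any]]:
    """Hypothetical stand-in for tree_flatten, for illustration only."""
    # Descend only when the caller has not declared this node a leaf.
    if is_leaf is None or not is_leaf(tree):
        if isinstance(tree, (list, tuple)):
            return [
                item
                for i, child in enumerate(tree)
                for item in flatten_sketch(child, f"{prefix}.{i}", is_leaf)
            ]
        if isinstance(tree, dict):
            return [
                item
                for key, child in tree.items()
                for item in flatten_sketch(child, f"{prefix}.{key}", is_leaf)
            ]
    # Each path segment is appended as ".segment", so the key starts with
    # a separator; dropping the first character removes it.
    return [(prefix[1:], tree)]


params = {"encoder": {"shape": (2, 3), "layers": [10, 20]}}

# Default: tuples are containers, so the shape is split element by element.
default_flat = flatten_sketch(params)

# With is_leaf, tuples are kept whole as single values.
tuple_leaf_flat = flatten_sketch(params, is_leaf=lambda x: isinstance(x, tuple))

# The prefix loses its first character: ".model" becomes "model".
prefixed_flat = flatten_sketch([[[0]]], ".model")
```

The is_leaf callable is consulted before a node is traversed, which is why the tuple under "shape" survives as one value in the second call.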

Returns:

The flat representation of the Python tree.

Return type:

List[Tuple[str, Any]]
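
Because the result is a plain list of (key, value) tuples, it can be consumed with ordinary Python. The sketch below uses a hand-written list in the documented shape; the key names and values are made up for illustration.

```python
from typing import Any, Dict, List, Tuple

# Illustrative flattened output; the names and values are invented.
flat: List[Tuple[str, Any]] = [
    ("encoder.layers.0.weight", 1.0),
    ("encoder.layers.0.bias", 0.5),
    ("decoder.weight", 2.0),
]

# A dict gives direct lookup by dotted key.
by_key: Dict[str, Any] = dict(flat)

# Select one subtree by matching the dotted prefix.
encoder_only = [(k, v) for k, v in flat if k.startswith("encoder.")]
```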