Lazy Evaluation

Why Lazy Evaluation

When you perform operations in MLX, no computation actually happens. Instead, a compute graph is recorded. The actual computation only happens when an eval() is performed.

MLX uses lazy evaluation because it has several useful properties, some of which we describe below.

Transforming Compute Graphs

Lazy evaluation lets us record a compute graph without actually doing any computations. This is useful for function transformations like grad() and vmap(), and for graph optimizations.

Currently, MLX does not compile and rerun compute graphs. They are all generated dynamically. However, lazy evaluation makes it much easier to integrate compilation for future performance enhancements.

Only Compute What You Use

In MLX you do not need to worry as much about computing outputs that are never used. For example:

def fun(x):
    a = fun1(x)
    b = expensive_fun(a)
    return a, b

y, _ = fun(x)

Here, we never actually compute the output of expensive_fun. Use this pattern with care, though: the graph of expensive_fun is still built, and building it has some cost.

Similarly, lazy evaluation can be beneficial for saving memory while keeping code simple. Say you have a very large model Model derived from mlx.nn.Module. You can instantiate this model with model = Model(). Typically, this will initialize all of the weights as float32, but the initialization does not actually compute anything until you perform an eval(). If you then update the model with float16 weights before evaluating, the float32 initial values are never materialized, so your peak memory use will be half of what eager computation would require.

This pattern is simple to do in MLX thanks to lazy computation:

model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors") # hypothetical path to float16 weights

When to Evaluate

A common question is when to use eval(). The trade-off is between letting graphs get too large and not batching enough useful work.

For example:

for _ in range(100):
    a = a + b
    mx.eval(a)
    b = b * 2
    mx.eval(b)

This is a bad idea because there is some fixed overhead with each graph evaluation. On the other hand, there is some slight overhead which grows with the compute graph size, so extremely large graphs (while computationally correct) can be costly.

Luckily, a wide range of compute graph sizes work pretty well with MLX: anything from a few tens of operations to many thousands of operations per evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration in stochastic gradient descent). A natural and usually efficient place to use eval() is at each iteration of this outer loop.

Here is a concrete example:

for batch in dataset:

    # Nothing has been evaluated yet
    loss, grad = value_and_grad_fn(model, batch)

    # Still nothing has been evaluated
    optimizer.update(model, grad)

    # Evaluate the loss and the new parameters which will
    # run the full gradient computation and optimizer update
    mx.eval(loss, model.parameters())

An important behavior to be aware of is when the graph will be implicitly evaluated. Any time you print an array, convert it to a numpy.ndarray, or otherwise access its memory via memoryview, the graph will be evaluated. Saving arrays via save() (or any other MLX saving function) will also evaluate the array.

Calling array.item() on a scalar array will also evaluate it. In the example above, printing the loss (print(loss)) or adding the loss scalar to a list (losses.append(loss.item())) would cause a graph evaluation. If these lines come before mx.eval(loss, model.parameters()), this will be a partial evaluation that computes only the forward pass.

Also, calling eval() on an array or set of arrays multiple times is perfectly fine: once an array has been evaluated, further calls are effectively no-ops.


Using scalar arrays for control-flow will cause an evaluation.

Here is an example:

def fun(x):
    h, y = first_layer(x)
    if y > 0:  # An evaluation is done here!
        z = second_layer_a(h)
    else:
        z = second_layer_b(h)
    return z

Using arrays for control flow should be done with care. The above example works and can even be used with gradient transformations. However, this can be very inefficient if evaluations are done too frequently.