Quick Start Guide

Basics

Import mlx.core and make an array:

>> import mlx.core as mx
>> a = mx.array([1, 2, 3, 4])
>> a.shape
[4]
>> a.dtype
int32
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
>> b.dtype
float32

Operations in MLX are lazy: the outputs of MLX operations are not computed until they are needed. To force an array to be evaluated, use eval(). Arrays are also evaluated automatically in a few cases. For example, inspecting a scalar with array.item(), printing an array, or converting an array to a numpy.ndarray all evaluate the array.

>> c = a + b    # c not yet evaluated
>> mx.eval(c)  # evaluates c
>> c = a + b
>> print(c)     # Also evaluates c
array([2, 4, 6, 8], dtype=float32)
>> c = a + b
>> import numpy as np
>> np.array(c)   # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)

See the page on Lazy Evaluation for more details.

Function and Graph Transformations

MLX has standard function transformations like grad() and vmap(). Transformations can be composed arbitrarily. For example, grad(vmap(grad(fn))) (or any other composition) is allowed.

>> x = mx.array(0.0)
>> mx.sin(x)
array(0, dtype=float32)
>> mx.grad(mx.sin)(x)
array(1, dtype=float32)
>> mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)

Other gradient transformations include vjp() for vector-Jacobian products and jvp() for Jacobian-vector products.

Use value_and_grad() to efficiently compute both a function’s output and gradient with respect to the function’s input.