Distributed Communication
MLX supports distributed communication operations that allow the computational cost of training or inference to be shared across many physical machines. At the moment we support two different communication backends:
- MPI, a full-featured and mature distributed communications library.
- A ring backend of our own that uses native TCP sockets and should be faster for Thunderbolt connections.
The list of all currently supported operations and their documentation can be seen in the API docs.
Some operations may not be supported yet, or may not be as fast as they should be. We are adding more operations and tuning the existing ones as we figure out the best way to do distributed computing on Macs using MLX.
Getting Started
A distributed program in MLX is as simple as:
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
The program above sums the array mx.ones(10) across all distributed processes. However, when this script is run with python, only one process is launched and no distributed communication takes place. In fact, all operations in mx.distributed are no-ops when the distributed group has a size of one. This property allows us to avoid code that checks whether we are in a distributed setting, such as the one below:
import mlx.core as mx
x = ...
world = mx.distributed.init()
# No need for this check; we can simply do x = mx.distributed.all_sum(x)
if world.size() > 1:
    x = mx.distributed.all_sum(x)
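To make these semantics concrete without MLX or several machines, here is a pure-Python model of what all_sum computes. The helper simulate_all_sum is hypothetical and not part of MLX; it simply treats a list of per-rank arrays as the whole group. With four ranks each contributing ten ones, every rank ends up with 4s, and with a single rank the array comes back unchanged, which is exactly the no-op behaviour described above.

```python
from typing import List


def simulate_all_sum(per_rank: List[List[float]]) -> List[List[float]]:
    """Hypothetical model: per_rank[r] is the array contributed by rank r.

    Returns what each rank holds afterwards: the elementwise sum over ranks.
    """
    if len(per_rank) == 1:
        # Group of size one: the operation leaves the array unchanged.
        return [list(per_rank[0])]
    total = [sum(values) for values in zip(*per_rank)]
    # Every rank receives its own copy of the same result.
    return [list(total) for _ in per_rank]


four = simulate_all_sum([[1.0] * 10 for _ in range(4)])  # like mlx.launch -n 4
single = simulate_all_sum([[1.0] * 10])  # like running the script with python
print(four[0][:3], single[0][:3])  # [4.0, 4.0, 4.0] [1.0, 1.0, 1.0]
```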
Running Distributed Programs
MLX provides mlx.launch, a helper script to launch distributed programs.
Continuing with our initial example, we can run it on localhost with 4 processes using:
$ mlx.launch -n 4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
We can also run it on remote hosts by passing their IPs, provided that the script exists on all hosts and the hosts are reachable by ssh:
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
Consult the dedicated usage guide for more information on using mlx.launch.
Selecting Backend
You can select the backend you want to use when calling init() by passing one of {'any', 'ring', 'mpi'}. When passing 'any', MLX will try to initialize the ring backend and, if that fails, the mpi backend. If both fail, a singleton group is created.
After a distributed backend is successfully initialized, init() returns the same backend if called without arguments or with backend set to 'any'.
The following examples aim to clarify the backend initialization logic in MLX:
# Case 1: Initialize MPI regardless of whether it was possible to initialize the ring backend
world = mx.distributed.init(backend="mpi")
world2 = mx.distributed.init() # subsequent calls return the MPI backend!
# Case 2: Initialize any backend
world = mx.distributed.init(backend="any") # equivalent to no arguments
world2 = mx.distributed.init() # same as above
# Case 3: Initialize both backends at the same time
world_mpi = mx.distributed.init(backend="mpi")
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init() # same as MPI because it was initialized first!
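The rules above can be summarised as a small state machine. The sketch below is a hypothetical model written for illustration, not MLX's implementation: the available set stands in for whether each backend could actually be initialized on this machine, and returning "singleton" when an explicitly requested backend is unavailable is an assumption that the text above does not specify.

```python
from typing import List, Set


class BackendSelector:
    """Hypothetical stand-in for the init() rules; not MLX code."""

    def __init__(self, available: Set[str]) -> None:
        # Backends that would initialize successfully on this machine.
        self.available = available
        # Backends initialized so far, in order.
        self.initialized: List[str] = []

    def init(self, backend: str = "any") -> str:
        if backend == "any":
            # Once a backend is up, "any" (or no argument) returns the first one.
            if self.initialized:
                return self.initialized[0]
            # Otherwise try ring first, then MPI.
            for candidate in ("ring", "mpi"):
                if candidate in self.available:
                    self.initialized.append(candidate)
                    return candidate
            # Both failed: fall back to a group containing only this process.
            return "singleton"
        if backend not in self.available:
            # Assumption: an unavailable explicit backend also yields a singleton.
            return "singleton"
        if backend not in self.initialized:
            self.initialized.append(backend)
        return backend


# Case 3 from above: MPI first, then ring, then no arguments.
selector = BackendSelector(available={"ring", "mpi"})
world_mpi = selector.init("mpi")
world_ring = selector.init("ring")
world_any = selector.init()
print(world_mpi, world_ring, world_any)  # mpi ring mpi
```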
Training Example
In this section we will adapt an MLX training loop to support data-parallel distributed training. Namely, we will average the gradients across a set of hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model, dataset and optimizer initialization.
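import mlx.core as mx
import mlx.nn as nn

model = ...
optimizer = ...
dataset = ...
# loss_fn(model, x, y) is a hypothetical loss function defined elsewhere
loss_grad_fn = nn.value_and_grad(model, loss_fn)

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    # Force evaluation of the lazy computation graph
    mx.eval(loss, model.parameters())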
All we have to do to average the gradients across machines is perform an all_sum() and divide by the size of the Group. That is, we apply the following function to every gradient with mlx.utils.tree_map():
import mlx.core as mx
def all_avg(x):
    # Sum x over all processes, then divide by the number of processes
    return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together, our training step looks as follows, with everything else remaining the same.
import mlx.core as mx
from mlx.utils import tree_map

def all_reduce_grads(grads):
    N = mx.distributed.init().size()
    if N == 1:
        # Single process: nothing to average
        return grads
    # Sum every gradient across processes and divide by the number of processes
    return tree_map(
        lambda x: mx.distributed.all_sum(x) / N,
        grads,
    )

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = all_reduce_grads(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss
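To check the arithmetic of all_reduce_grads without a cluster, the sketch below simulates several ranks inside one Python process. tree_map_lite is a hypothetical, minimal stand-in for mlx.utils.tree_map over nested dicts and lists, plain floats stand in for arrays, and average_across_ranks computes what every rank would hold after applying all_sum(g) / N to each gradient leaf.

```python
from typing import Any, Callable, Dict, List


def tree_map_lite(fn: Callable, *trees: Any) -> Any:
    """Minimal stand-in for mlx.utils.tree_map over nested dicts and lists."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map_lite(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, list):
        return [tree_map_lite(fn, *items) for items in zip(*trees)]
    # Leaf: apply fn to the matching leaves of every tree.
    return fn(*trees)


def average_across_ranks(per_rank_grads: List[Dict[str, Any]]) -> Dict[str, Any]:
    """What every rank holds after all_sum(g) / N on each gradient leaf."""
    n = len(per_rank_grads)
    return tree_map_lite(lambda *leaves: sum(leaves) / n, *per_rank_grads)


# Two simulated ranks with gradients for a tiny hypothetical model.
rank0 = {"layer": {"weight": [1.0, 2.0], "bias": 0.5}}
rank1 = {"layer": {"weight": [3.0, 4.0], "bias": 1.5}}
avg = average_across_ranks([rank0, rank1])
print(avg)  # {'layer': {'weight': [2.0, 3.0], 'bias': 1.0}}
```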
Utilizing nn.average_gradients
Although the code example above works correctly, it performs one communication per gradient. It is significantly more efficient to aggregate several gradients together and perform fewer communication steps.
This is the purpose of mlx.nn.average_gradients(). The final code looks almost identical to the example above:
import mlx.core as mx
import mlx.nn as nn

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = nn.average_gradients(grads)  # <---- This line was added
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())
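The idea behind aggregating gradients can be sketched in pure Python: concatenate several gradients into one flat buffer, perform a single (simulated) all-sum per buffer, then split the result back into the original gradients. This is an illustration under stated assumptions, not the actual code of mlx.nn.average_gradients; the greedy grouping, the hypothetical helpers make_buckets and bucketed_average, and the max_elems threshold (counted in elements rather than bytes) are all choices made for this sketch.

```python
from typing import Dict, List, Tuple

Grads = Dict[str, List[float]]


def make_buckets(grads: Grads, max_elems: int) -> List[List[str]]:
    """Greedily group gradient names so each bucket holds <= max_elems values.

    A single gradient larger than max_elems gets a bucket of its own.
    """
    buckets: List[List[str]] = []
    current: List[str] = []
    count = 0
    for name, values in grads.items():
        if current and count + len(values) > max_elems:
            buckets.append(current)
            current, count = [], 0
        current.append(name)
        count += len(values)
    if current:
        buckets.append(current)
    return buckets


def bucketed_average(per_rank: List[Grads], max_elems: int) -> Tuple[Grads, int]:
    """Average gradients across simulated ranks with one all-sum per bucket.

    Returns the averaged gradients and the number of communication steps.
    """
    n = len(per_rank)
    buckets = make_buckets(per_rank[0], max_elems)
    averaged: Grads = {}
    for names in buckets:
        # Concatenate the bucket into one flat buffer on every rank.
        flat = [[v for name in names for v in g[name]] for g in per_rank]
        # One simulated all-sum over the flat buffer, then divide by n.
        summed = [sum(col) / n for col in zip(*flat)]
        # Split the buffer back into the original gradients.
        offset = 0
        for name in names:
            size = len(per_rank[0][name])
            averaged[name] = summed[offset:offset + size]
            offset += size
    return averaged, len(buckets)


rank0 = {"w1": [1.0, 1.0], "b1": [0.0], "w2": [2.0, 2.0, 2.0]}
rank1 = {"w1": [3.0, 3.0], "b1": [2.0], "w2": [4.0, 4.0, 4.0]}
avg, steps = bucketed_average([rank0, rank1], max_elems=4)
print(f"{steps} communication steps instead of {len(rank0)}")
print(avg)
```

With three gradients and a threshold of four elements, the sketch needs two communication steps instead of three; larger buckets trade fewer steps for bigger messages.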
Getting Started with MPI
MLX already comes with the ability to “talk” to MPI if it is installed on the machine. Launching distributed MLX programs that use MPI can be done with mpirun as expected. However, in the following examples we will be using mlx.launch --backend mpi, which takes care of some nuisances such as setting absolute paths for the mpirun executable and the libmpi.dylib shared library.
The simplest possible usage is the following, which, assuming the minimal example at the beginning of this page, should result in:
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine, and we can see both standard output streams. The processes send the array of 1s to each other and compute the sum, which is printed. Launching with mlx.launch -n 4 ... would print 4, and so on.
Installing MPI
MPI can be installed with Homebrew or the Anaconda package manager, or compiled from source. Most of our testing is done using openmpi installed with the Anaconda package manager as follows:
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of libmpi.dylib so that MLX can find it and load it at runtime. This can be achieved by passing the DYLD_LIBRARY_PATH environment variable to mpirun, and it is done automatically by mlx.launch:
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch --backend mpi -n 2 test.py
Setting up Remote Hosts
MPI can automatically connect to remote hosts and set up the communication over the network if the remote hosts can be accessed via ssh. A good checklist to debug connectivity issues is the following:
- ssh hostname works from all machines to all machines without asking for a password or host confirmation.
- mpirun is accessible on all machines.
- Ensure that the hostname used by MPI is the one that you have configured in the .ssh/config files on all machines.
Tuning MPI All Reduce
For faster all-reduce, consider using the ring backend, either with Thunderbolt connections or over Ethernet.
- Configure MPI to use N TCP connections between each host to improve bandwidth by passing --mca btl_tcp_links N.
- Force MPI to use the most performant network interface by setting --mca btl_tcp_if_include <iface>, where <iface> should be the interface you want to use.
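As a sketch, the two options can be combined into a single mpirun invocation. The helper mpirun_command, the interface name en0 and the link count of 4 are placeholders for illustration; pick the values that match your network. The function only builds and prints the command so it can be inspected before running it, for example with subprocess.run(cmd).

```python
import shlex
from typing import List


def mpirun_command(script: str, num_procs: int, tcp_links: int, iface: str) -> List[str]:
    """Build an mpirun argument list that applies both TCP tuning options."""
    return [
        "mpirun",
        "-np", str(num_procs),
        # Open tcp_links TCP connections between each pair of hosts.
        "--mca", "btl_tcp_links", str(tcp_links),
        # Restrict MPI's TCP traffic to the chosen network interface.
        "--mca", "btl_tcp_if_include", iface,
        "python", script,
    ]


cmd = mpirun_command("test.py", num_procs=2, tcp_links=4, iface="en0")
print(shlex.join(cmd))
```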
Getting Started with Ring
The ring backend does not depend on any third-party library, so it is always available. It uses TCP sockets, so the nodes need to be reachable via a network. As the name suggests, the nodes are connected in a ring, which means that rank 1 can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3, and so on. As a result, send() and recv() with an arbitrary sender and receiver are not supported in the ring backend.
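To see why neighbour-only communication is still enough for an all-reduce, here is a pure-Python simulation of the textbook ring all-reduce: a reduce-scatter phase followed by an all-gather phase, in which rank r only ever sends to rank (r + 1) mod n. This is a sketch of the general technique with a hypothetical helper ring_all_sum, not a description of the internals of MLX's ring backend, which may differ.

```python
from typing import List


def ring_all_sum(data: List[List[float]]) -> List[List[float]]:
    """Simulate a ring all-reduce (sum); data[r] is the buffer held by rank r."""
    n = len(data)
    if n == 1:
        # Group of size one: nothing to communicate.
        return [list(data[0])]
    length = len(data[0])
    # Split every buffer into n contiguous chunks (some may be empty).
    bounds = [(i * length) // n for i in range(n + 1)]
    chunks = [[data[r][bounds[c]:bounds[c + 1]] for c in range(n)] for r in range(n)]

    # Phase 1, reduce-scatter: at each step every rank sends one chunk to its
    # right neighbour, which adds it to its own copy. After n - 1 steps rank r
    # holds the complete sum for chunk (r + 1) % n.
    for step in range(n - 1):
        sends = [(r, (r - step) % n, list(chunks[r][(r - step) % n])) for r in range(n)]
        for r, c, payload in sends:
            dst = (r + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], payload)]

    # Phase 2, all-gather: the completed chunks travel around the ring and
    # overwrite the partial sums, again only between neighbours.
    for step in range(n - 1):
        sends = [(r, (r + 1 - step) % n, list(chunks[r][(r + 1 - step) % n])) for r in range(n)]
        for r, c, payload in sends:
            chunks[(r + 1) % n][c] = payload

    # Reassemble each rank's buffer from its chunks.
    return [[v for c in range(n) for v in chunks[r][c]] for r in range(n)]


result = ring_all_sum([[1.0] * 10 for _ in range(4)])
print(result[0])  # [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0]
```

Each rank sends 2(n - 1) messages of roughly 1/n of the array, which is why this scheme makes good use of the bandwidth of every link in the ring.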
Defining a Ring
The easiest way to define and use a ring is via a JSON hostfile and the mlx.launch helper script. For each node, one defines a hostname to ssh into to run commands on this node, and one or more IPs that this node will listen to for connections.
For example, the hostfile below defines a 4-node ring. hostname1 will be rank 0, hostname2 rank 1, etc.
[
    {"ssh": "hostname1", "ips": [""]},
    {"ssh": "hostname2", "ips": [""]},
    {"ssh": "hostname3", "ips": [""]},
    {"ssh": "hostname4", "ips": [""]}
]
Running mlx.launch --hostfile ring-4.json my_script.py will ssh into each node and run the script, which will listen for connections on each of the provided IPs. Specifically, hostname1 will connect to one of its ring neighbours (hostname2 or hostname4) and accept a connection from the other, and so on.
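A hostfile like the one above can also be generated programmatically. In this sketch, make_ring_hostfile and ring_neighbors are hypothetical helpers that assign one IP per node, and the 192.0.2.x addresses are placeholders from the documentation-only range; replace them with the IPs your nodes should listen on. The script prints the JSON (which you could save as ring-4.json) and, for each rank, the two neighbours it communicates with in the ring.

```python
import json
from typing import Dict, List, Tuple


def make_ring_hostfile(hosts: List[str], ips: List[str]) -> List[Dict[str, object]]:
    """One entry per node; list order defines the rank (first entry is rank 0)."""
    if len(hosts) != len(ips):
        raise ValueError("need exactly one IP per host")
    return [{"ssh": host, "ips": [ip]} for host, ip in zip(hosts, ips)]


def ring_neighbors(n: int, rank: int) -> Tuple[int, int]:
    """Ranks adjacent to `rank` in an n-node ring: (previous, next)."""
    return (rank - 1) % n, (rank + 1) % n


hosts = ["hostname1", "hostname2", "hostname3", "hostname4"]
ips = ["192.0.2.1", "192.0.2.2", "192.0.2.3", "192.0.2.4"]  # placeholder addresses
hostfile = make_ring_hostfile(hosts, ips)
text = json.dumps(hostfile, indent=4)
print(text)
for rank in range(len(hostfile)):
    prev_rank, next_rank = ring_neighbors(len(hostfile), rank)
    print(rank, hosts[rank], "<->", hosts[prev_rank], "and", hosts[next_rank])
```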
Thunderbolt Ring
Although the ring backend can have benefits over MPI even over Ethernet, its main purpose is to use Thunderbolt rings for higher-bandwidth communication. Setting up such Thunderbolt rings can be done manually, but it is a relatively tedious process. To simplify this, we provide the utility mlx.distributed_config.
To use mlx.distributed_config, your computers need to be accessible by ssh via Ethernet or Wi-Fi. Then connect them with Thunderbolt cables and call the utility as follows:
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
By default, the script will attempt to discover the Thunderbolt ring and provide you with the commands to configure each node as well as the hostfile.json to use with mlx.launch. If password-less sudo is available on the nodes, then --auto-setup can be used to configure them automatically.
To validate your connection without configuring anything, mlx.distributed_config can also plot the ring using the DOT format.
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
If you want to go through the process manually, the steps are as follows:
1. Disable the Thunderbolt bridge interface.
2. For the cable connecting rank i to rank i + 1, find the interfaces corresponding to that cable in nodes i and i + 1.
3. Set up a unique subnetwork connecting the two nodes for the corresponding interfaces. For instance, we may assign the IP 192.168.0.1 to the interface on node i and a second address in the same subnetwork to the corresponding interface on node i + 1. For more details, see the commands prepared by the utility script.
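The addressing in step 3 can be planned with a few lines of Python. This sketch is hypothetical: the helper plan_ring_subnets assumes one 192.168.<cable>.0/24 subnetwork per cable, with .1 on node i and .2 on node i + 1, and it only prints the plan. Mapping each cable to actual interface names (step 2) and applying the addresses is left to you or to the commands generated by mlx.distributed_config.

```python
from typing import List, NamedTuple


class Link(NamedTuple):
    cable: int  # cable i connects node i and node (i + 1) % n
    node_a: str
    ip_a: str
    node_b: str
    ip_b: str


def plan_ring_subnets(hosts: List[str]) -> List[Link]:
    """Assign a separate 192.168.<cable>.0/24 subnetwork to every cable in the ring."""
    n = len(hosts)
    if not 2 <= n <= 256:
        # The cable index is used as the third octet, so at most 256 cables fit.
        raise ValueError("expected between 2 and 256 hosts")
    links = []
    for i, host in enumerate(hosts):
        peer = hosts[(i + 1) % n]
        links.append(Link(i, host, f"192.168.{i}.1", peer, f"192.168.{i}.2"))
    return links


for link in plan_ring_subnets(["host1", "host2", "host3", "host4"]):
    print(f"cable {link.cable}: {link.node_a} {link.ip_a} <-> {link.node_b} {link.ip_b}")
```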