mlx.core.distributed.all_gather

all_gather(x: array, *, group: Group | None = None, stream: None | Stream | Device = None) -> array

Gather arrays from all processes.

Gather the x arrays from all processes in the group and concatenate them along the first axis. The arrays should all have the same shape, so with N processes each passing an array of shape (d0, ...), the result has shape (N * d0, ...).

Parameters:
  • x (array) – Input array.

  • group (Group, optional) – The group of processes that will participate in the gather. If set to None, the global group is used. Default: None.

  • stream (Stream | Device, optional) – Stream or device. Defaults to None, in which case the default stream of the default device is used.

Returns:

The concatenation of all x arrays along the first axis.

Return type:

array