mlx.core.distributed.all_sum

all_sum(x: array, *, group: Group | None = None, stream: None | Stream | Device = None) → array

All reduce sum.

Sum the x arrays element-wise across all processes in the group. Every process in the group receives the result.

Parameters:
  • x (array) – Input array.

  • group (Group, optional) – The group of processes that will participate in the reduction. If set to None, the global group is used. Default: None.

  • stream (Stream | Device, optional) – Stream or device. Defaults to None, in which case the default stream of the default device is used.

Returns:

The sum of all x arrays.

Return type:

array