mlx.nn.average_gradients


average_gradients(gradients: Any, group: Group | None = None, all_reduce_size: int = 33554432, communication_type: Dtype | None = None)

Average the gradients across the distributed processes in the given group.

This helper concatenates the gradients of many small arrays into fewer, larger all-reduce calls, which improves networking performance.
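The batching idea can be illustrated in plain Python without MLX. This is a simplified simulation, not the library's implementation: each "process" is a flat dict of float lists with identical keys and sizes, and the hypothetical `simulated_all_sum` stands in for a real all-sum collective. Averaging array by array and the concatenate, reduce once, split approach give the same result, but the second needs only one communication call.

```python
from typing import Dict, List

# Simplified gradient "tree": a flat dict of key -> list of floats.
Grads = Dict[str, List[float]]


def simulated_all_sum(buffers: List[List[float]]) -> List[float]:
    """Hypothetical stand-in for an all-sum collective: elementwise sum
    of one equally sized buffer per process."""
    return [sum(vals) for vals in zip(*buffers)]


def average_per_array(per_process: List[Grads]) -> Grads:
    """Baseline: one simulated all-reduce call per gradient array."""
    n = len(per_process)
    return {
        k: [s / n for s in simulated_all_sum([g[k] for g in per_process])]
        for k in per_process[0]
    }


def average_concatenated(per_process: List[Grads]) -> Grads:
    """Concatenate every array, run a single all-reduce, then split back."""
    n = len(per_process)
    keys = list(per_process[0])
    sizes = [len(per_process[0][k]) for k in keys]

    # Each process flattens its arrays into one buffer, in the same key order.
    flat = [[x for k in keys for x in g[k]] for g in per_process]
    summed = simulated_all_sum(flat)  # the only communication step

    # Split the reduced buffer back into arrays of the original sizes.
    out: Grads = {}
    offset = 0
    for k, size in zip(keys, sizes):
        out[k] = [s / n for s in summed[offset:offset + size]]
        offset += size
    return out


rank0 = {"w": [1.0, 2.0, 3.0], "b": [0.5]}
rank1 = {"w": [3.0, 4.0, 5.0], "b": [1.5]}
print(average_concatenated([rank0, rank1]))  # {'w': [2.0, 3.0, 4.0], 'b': [1.0]}
```

The real function accepts arbitrary (possibly nested) Python trees and reduces across the processes of the group; the sketch only shows why packing many small arrays into one buffer saves communication calls.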

Parameters:
  • gradients (Any) – The Python tree containing the gradients. It should have the same structure across processes.

  • group (Optional[Group]) – The group of processes across which to average the gradients. If None, the global group is used. Default: None.

  • all_reduce_size (int) – Arrays are grouped until their combined size in bytes exceeds this number, and one communication step is performed per group of arrays. If less than or equal to 0, array grouping is disabled. Default: 32 MiB.

  • communication_type (Optional[Dtype]) – If provided, the gradients are cast to this type before the communication is performed. Typically this is a smaller float type (e.g. float16 or bfloat16) to reduce the communication size. Default: None.
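The interaction between all_reduce_size and communication_type can be sketched in plain Python. This is a hypothetical illustration following the grouping rule as documented for all_reduce_size, not the library's code: `group_by_size` works only on byte sizes, and `to_float16` uses the standard `struct` half-precision format to mimic casting to a 2-byte type. It also assumes, for illustration, that the grouping is computed on the sizes after casting.

```python
import struct
from typing import List


def group_by_size(nbytes: List[int], all_reduce_size: int) -> List[List[int]]:
    """Group array indices so that each group is one communication step.

    Hypothetical helper: arrays are added to the current group until the
    group's total size in bytes exceeds all_reduce_size. A value <= 0
    disables grouping, i.e. one step per array.
    """
    if all_reduce_size <= 0:
        return [[i] for i in range(len(nbytes))]

    groups: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for i, size in enumerate(nbytes):
        current.append(i)
        current_bytes += size
        if current_bytes > all_reduce_size:  # threshold exceeded: close group
            groups.append(current)
            current, current_bytes = [], 0
    if current:  # leftover arrays form a final, smaller group
        groups.append(current)
    return groups


def to_float16(values: List[float]) -> List[float]:
    """Round-trip values through IEEE half precision (2 bytes each).

    Hypothetical stand-in for casting to a smaller communication dtype:
    half the bytes of float32, at the cost of precision.
    """
    packed = struct.pack(f"<{len(values)}e", *values)
    return list(struct.unpack(f"<{len(values)}e", packed))


MiB = 1024 ** 2
# Byte sizes of four float32 gradient arrays (10, 20, 6 and 30 MiB).
sizes = [10 * MiB, 20 * MiB, 6 * MiB, 30 * MiB]
print(group_by_size(sizes, 32 * MiB))                    # [[0, 1, 2], [3]]
# Casting to a 2-byte type halves every size, so fewer steps are needed.
print(group_by_size([s // 2 for s in sizes], 32 * MiB))  # [[0, 1, 2, 3]]
print(to_float16([0.1]))                                 # [0.0999755859375]
```

With the default 32 MiB threshold the four float32 arrays need two communication steps, while the halved sizes fit into one; the last line shows the precision lost by sending a value in half precision.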