mlx.nn.average_gradients
- average_gradients(gradients: Any, group: Group | None = None, all_reduce_size: int = 33554432, communication_type: Dtype | None = None)
Average the gradients across the distributed processes in the passed group.
This helper concatenates the gradients of several small arrays into one large all-reduce call, which gives better networking performance than reducing each array separately.
- Parameters:
gradients (Any) – The Python tree containing the gradients. It should have the same structure across processes.
group (Optional[Group]) – The group of processes over which to average the gradients. If set to None, the global group is used. Default: None.
all_reduce_size (int) – Group arrays until their size in bytes exceeds this number. Perform one communication step per group of arrays. If less than or equal to 0, array grouping is disabled. Default: 32MiB.
communication_type (Optional[Dtype]) – If provided, cast to this type before performing the communication. Typically cast to a smaller float to reduce the communication size. Default: None.
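The grouping rule described for all_reduce_size can be illustrated with a small pure-Python sketch. It is not MLX's implementation: the function name group_by_size is hypothetical, it works on plain byte counts rather than arrays, and it assumes that the array which pushes a group over the limit stays in that group. MLX may place that boundary differently.

```python
from typing import List


def group_by_size(sizes_in_bytes: List[int],
                  all_reduce_size: int = 32 * 1024 * 1024) -> List[List[int]]:
    """Hypothetical helper: return groups of array indices.

    Each inner list stands for the arrays that would be sent together
    in a single communication step.
    """
    if all_reduce_size <= 0:
        # Grouping disabled: every array gets its own communication step.
        return [[i] for i in range(len(sizes_in_bytes))]

    groups: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for i, nbytes in enumerate(sizes_in_bytes):
        current.append(i)
        current_bytes += nbytes
        # Assumption: close the group once its total size exceeds the limit.
        if current_bytes > all_reduce_size:
            groups.append(current)
            current = []
            current_bytes = 0
    if current:
        # Leftover arrays that never reached the limit form the last group.
        groups.append(current)
    return groups


# Six gradient arrays, sizes in bytes, with a 100-byte limit for readability.
print(group_by_size([40, 40, 40, 10, 200, 5], all_reduce_size=100))
# prints [[0, 1, 2], [3, 4], [5]]
```

Under these assumptions six arrays need three communication steps instead of six. The default of 33554432 bytes equals 32 * 1024 * 1024, that is 32MiB. In a training step the function itself would typically be applied to the gradient tree before the optimizer update, for example grads = mlx.nn.average_gradients(grads).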