mlx.nn.fsdp_apply_gradients
- fsdp_apply_gradients(gradients, parameters, optimizer, fsdp_group=None, dp_group=None, communication_size=33554432, communication_stream=None, max_norm=None)
Perform a distributed optimizer step by sharding gradients and optimizer states across ranks.
This helper function performs the following steps:
1. Reduce-scatter the gradients across ranks so each rank gets a shard of the averaged gradients.
2. Optionally clip the sharded gradients by global norm.
3. Apply the optimizer update on the local parameter slice using the sharded gradients.
4. All-gather the updated parameter slices from all ranks to reconstruct the full parameter tree.
This is similar to PyTorch’s FSDP with reshard_after_forward=False.
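The four steps can be sketched in a single process, with NumPy standing in for the collective operations (names such as world_size and lr are illustrative, not part of the API):

```python
import numpy as np

world_size = 4  # illustrative FSDP group size
lr = 0.1        # illustrative learning rate

# One "full" parameter and one gradient per rank; the first dimension (8)
# is divisible by the group size, as the function requires.
param = np.ones((8, 3))
rank_grads = [np.full((8, 3), r + 1.0) for r in range(world_size)]

# 1. Reduce-scatter: each rank ends up with one shard of the averaged gradient.
avg_grad = np.mean(rank_grads, axis=0)        # the "reduce" part
grad_shards = np.split(avg_grad, world_size)  # the "scatter" part

# 2. (Optional clipping by global norm would happen here.)

# 3. Each rank applies the optimizer update only to its local parameter slice.
param_shards = np.split(param, world_size)
updated_shards = [p - lr * g for p, g in zip(param_shards, grad_shards)]

# 4. All-gather: concatenate the updated slices back into the full parameter.
updated_param = np.concatenate(updated_shards)

print(updated_param.shape)  # (8, 3)
```

In the real function each rank holds only its own shard and the reduce-scatter/all-gather are true inter-process collectives; the arithmetic is the same.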
- Parameters:
  - gradients (Any) – The Python tree containing the full gradients (it should have the same structure as parameters). Each gradient’s first dimension must be divisible by fsdp_group.size().
  - parameters (Any) – The Python tree containing the full parameters (it should have the same structure across processes). Each parameter’s first dimension must be divisible by fsdp_group.size().
  - optimizer – Optimizer with an apply_gradients method.
  - fsdp_group (Optional[Group]) – The group of processes for FSDP sharding. If None, the global group is used.
  - dp_group (Optional[Group]) – The group of processes for data-parallel gradient averaging. Required when fsdp_group is smaller than the world (e.g. FSDP intra-node, DDP inter-node). Default: None.
  - communication_size (int) – Group arrays until their size in bytes exceeds this number, and perform one communication step per group of arrays. If less than or equal to 0, array grouping is disabled. Default: 32 MiB.
  - communication_stream (Optional[Stream]) – The stream to use for communication. If unspecified, the default communication stream is used, which can vary by back-end. Default: None.
  - max_norm (Optional[float]) – If provided, clip gradients to this maximum global norm before applying the optimizer update. Default: None.
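The bucketing policy behind communication_size can be sketched as follows (group_by_bytes is a hypothetical helper for illustration, not the library's internal code):

```python
import numpy as np

def group_by_bytes(arrays, communication_size):
    """Greedily bucket arrays; each bucket triggers one communication step.

    A bucket is flushed once its total size in bytes exceeds the threshold.
    If communication_size <= 0, grouping is disabled (one array per step).
    """
    if communication_size <= 0:
        return [[a] for a in arrays]
    buckets, current, current_bytes = [], [], 0
    for a in arrays:
        current.append(a)
        current_bytes += a.nbytes
        if current_bytes > communication_size:
            buckets.append(current)
            current, current_bytes = [], 0
    if current:
        buckets.append(current)
    return buckets

arrays = [np.zeros(1024, dtype=np.float32) for _ in range(8)]  # 4 KiB each
buckets = group_by_bytes(arrays, communication_size=8192)      # 8 KiB threshold
print([len(b) for b in buckets])  # [3, 3, 2]
```

Fewer, larger communications amortize per-collective latency, which is why grouping is on by default.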
- Returns:
  If max_norm is None, returns the updated full-parameter tree. Otherwise returns (parameters, grad_norm), where grad_norm is the global gradient norm before clipping.
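The returned grad_norm is the L2 norm taken over every gradient shard, computed before any scaling. Conceptually (a NumPy sketch, not the library implementation; the shard values and epsilon are illustrative):

```python
import numpy as np

max_norm = 1.0
grad_shards = [np.full(4, 2.0), np.full(4, 2.0)]  # shards held by two ranks

# Global L2 norm over all shards. In the real function this requires a
# cross-rank reduction, since each rank only holds its own shard.
grad_norm = np.sqrt(sum(np.sum(g * g) for g in grad_shards))

# Clip: scale every shard down if the global norm exceeds max_norm.
scale = min(1.0, max_norm / (grad_norm + 1e-6))
clipped = [g * scale for g in grad_shards]
# After clipping, the global norm of `clipped` is (approximately) max_norm.
```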
Example
>>> optimizer = optim.SGD(learning_rate=0.01)
>>> # Without gradient clipping
>>> updated_params = fsdp_apply_gradients(grads, params, optimizer)
>>> model.update(updated_params)
>>>
>>> # With gradient clipping
>>> updated_params, grad_norm = fsdp_apply_gradients(
...     grads, params, optimizer, max_norm=1.0
... )
>>> model.update(updated_params)