mlx.optimizers.clip_grad_norm#
- clip_grad_norm(grads, max_norm)#
Clips the global norm of the gradients.
This function ensures that the global norm of the gradients does not exceed
max_norm. It scales down the gradients proportionally if their norm is
greater than max_norm.

Example
>>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
>>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
>>> print(clipped_grads)
{"w1": mx.array([...]), "w2": mx.array([...])}
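To make the scaling rule concrete, here is a minimal pure-Python sketch of global-norm clipping. This is an illustration of the algorithm, not MLX's actual implementation (which operates on mx.array values): the global L2 norm is computed over every element of every gradient, and all gradients are multiplied by max_norm / total_norm only when that norm exceeds max_norm.

```python
import math

def clip_grad_norm_sketch(grads, max_norm):
    # Global L2 norm across every element of every gradient array.
    total_norm = math.sqrt(sum(g * g for arr in grads.values() for g in arr))
    # Scale all gradients by the same factor only if the norm is too large,
    # so their relative directions are preserved.
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = {k: [g * scale for g in arr] for k, arr in grads.items()}
    return grads, total_norm

grads = {"w1": [2.0, 3.0], "w2": [1.0]}
clipped, norm = clip_grad_norm_sketch(grads, max_norm=2.0)
# norm is sqrt(2^2 + 3^2 + 1^2) = sqrt(14) ≈ 3.742, so every gradient
# is scaled by 2.0 / sqrt(14) and the clipped global norm is exactly 2.0.
```

Note that the returned total_norm is the norm before clipping, which is useful for logging how aggressively gradients are being clipped during training.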