mlx.optimizers.clip_grad_norm

clip_grad_norm(grads, max_norm)

Clips the global norm of the gradients.

This function ensures that the global norm of the gradients, computed across all gradient arrays together, does not exceed max_norm. If the norm is greater than max_norm, every gradient is scaled down by the same factor; otherwise the gradients are returned unchanged.
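The arithmetic behind this kind of clipping can be sketched in plain Python. This is only an illustration, not MLX's implementation: the helper name clip_by_global_norm is hypothetical, gradients are represented as lists of floats rather than mx.array values, and the global norm is assumed to be the L2 norm taken over all gradient entries together.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Hypothetical pure-Python illustration; not MLX's implementation."""
    # Global L2 norm (assumed): square root of the sum of squares
    # over every entry of every gradient list.
    total_norm = math.sqrt(
        sum(g * g for values in grads.values() for g in values)
    )
    # The scale is 1.0 when the norm is already within the limit,
    # so the gradients pass through unchanged in that case.
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    clipped = {
        name: [g * scale for g in values] for name, values in grads.items()
    }
    # Return the rescaled gradients and the norm measured before clipping.
    return clipped, total_norm


grads = {"w1": [2.0, 3.0], "w2": [1.0]}
clipped, total_norm = clip_by_global_norm(grads, max_norm=2.0)
```

With these values the global norm is sqrt(14), about 3.74, which is above the limit of 2.0, so every entry is multiplied by roughly 0.53. Because all entries share one scale factor, the direction of the overall gradient is preserved and only its length changes.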

Example

>>> import mlx.core as mx
>>> from mlx.optimizers import clip_grad_norm
>>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
>>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
>>> print(clipped_grads)
{"w1": mx.array([...]), "w2": mx.array([...])}

Parameters:
  • grads (dict) – A dictionary containing the gradient arrays.

  • max_norm (float) – The maximum allowed global norm of the gradients.

Returns:

The gradients, rescaled if their global norm exceeded max_norm, and the global norm of the original (unclipped) gradients.

Return type:

(dict, float)