mlx.nn.value_and_grad#

value_and_grad(model: Module, fn: Callable)#

Transform the passed function fn to a function that computes the gradients of fn wrt the model’s trainable parameters and also its value.

Parameters:
  • model (Module) – The model whose trainable parameters to compute gradients for

  • fn (Callable) – The scalar function to compute gradients for

Returns:

A callable that returns the value of fn and the gradients wrt the trainable parameters of model