mlx.nn.value_and_grad
- value_and_grad(model: Module, fn: Callable)
Transform the passed function fn into a function that computes both the value of fn and the gradients of fn with respect to the model’s trainable parameters.
- Parameters:
model (Module) – The model; gradients are computed with respect to its trainable parameters.
fn (Callable) – The scalar-valued function to differentiate.
- Returns:
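A callable that takes the same arguments as fn and returns a (value, gradients) tuple: the value of fn and the gradients of fn with respect to the trainable parameters of model. The gradients have the same nested-dictionary structure as model.trainable_parameters().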