mlx.core.value_and_grad
- value_and_grad(fun: Callable, argnums: int | list[int] | None = None, argnames: str | list[str] = []) -> Callable
Returns a function which computes the value and gradient of `fun`.

The function passed to `value_and_grad()` should return either a scalar loss or a tuple in which the first element is a scalar loss and the remaining elements can be anything.

    import mlx.core as mx

    def mse(params, inputs, targets):
        outputs = forward(params, inputs)
        lvalue = (outputs - targets).square().mean()
        return lvalue

    # Returns lvalue, dlvalue/dparams
    lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)

    def lasso(params, inputs, targets, a=1.0, b=1.0):
        outputs = forward(params, inputs)
        mse = (outputs - targets).square().mean()
        l1 = mx.abs(outputs - targets).mean()

        loss = a*mse + b*l1

        return loss, mse, l1

    (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
- Parameters:
  - fun (Callable) – A function which takes a variable number of `array` or trees of `array` and returns a scalar output `array` or a tuple, the first element of which should be a scalar `array`.
  - argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of `fun` to compute the gradient with respect to. If neither `argnums` nor `argnames` are provided, `argnums` defaults to `0`, indicating `fun`'s first argument.
  - argnames (str or list(str), optional) – Specify keyword arguments of `fun` to compute gradients with respect to. It defaults to `[]`, so no gradients for keyword arguments by default.
- Returns:
A function which returns a tuple where the first element is the output of `fun` and the second element is the gradient of the loss with respect to the arguments selected by `argnums` and `argnames`.
- Return type:
Callable