mlx.core.value_and_grad
- value_and_grad(fun: Callable, argnums: int | Sequence[int] | None = None, argnames: str | Sequence[str] = []) → Callable
Returns a function which computes the value and gradient of fun. The function passed to value_and_grad() should return either a scalar loss or a tuple in which the first element is a scalar loss and the remaining elements can be anything.

```python
import mlx.core as mx

# `forward`, `params`, `inputs` and `targets` are placeholders assumed to be
# defined elsewhere (a model's forward pass and its data).

def mse(params, inputs, targets):
    outputs = forward(params, inputs)
    lvalue = (outputs - targets).square().mean()
    return lvalue

# Returns lvalue, dlvalue/dparams
lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)

def lasso(params, inputs, targets, a=1.0, b=1.0):
    outputs = forward(params, inputs)
    mse = (outputs - targets).square().mean()
    l1 = mx.abs(outputs - targets).mean()

    loss = a*mse + b*l1

    # Only `loss` is differentiated; `mse` and `l1` are returned alongside it
    return loss, mse, l1

(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
```
- Parameters:
  - fun (Callable) – A function which takes a variable number of arrays or trees of arrays and returns either a scalar output array or a tuple whose first element is a scalar array.
  - argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames is provided, argnums defaults to 0, indicating fun’s first argument.
  - argnames (str or list(str), optional) – Specify the keyword arguments of fun to compute gradients with respect to. Defaults to [], so no gradients are computed for keyword arguments by default.
- Returns:
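A function which returns a tuple whose first element is the output of fun (the scalar loss, or the full tuple if fun returns one) and whose second element is the gradient of the loss with respect to the arguments selected by argnums and argnames.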
- Return type:
Callable