value_and_grad(fun: callable, argnums: int | List[int] | None = None, argnames: str | List[str] = []) → callable

Returns a function which computes the value and gradient of fun.

The function passed to value_and_grad() should return either a scalar loss or a tuple in which the first element is a scalar loss and the remaining elements can be anything.

import mlx.core as mx

# forward(), params, inputs, and targets are assumed to be defined elsewhere.

def mse(params, inputs, targets):
    outputs = forward(params, inputs)
    lvalue = (outputs - targets).square().mean()
    return lvalue

# Returns lvalue, dlvalue/dparams
lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)

def lasso(params, inputs, targets, a=1.0, b=1.0):
    outputs = forward(params, inputs)
    mse = (outputs - targets).square().mean()
    l1 = mx.abs(outputs - targets).mean()

    loss = a*mse + b*l1

    return loss, mse, l1

(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)

Parameters:
  • fun (callable) – A function which takes a variable number of arrays or trees of arrays and returns either a scalar output array or a tuple whose first element is a scalar array.

  • argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames is provided, argnums defaults to 0, indicating fun’s first argument.

  • argnames (str or list(str), optional) – Specify keyword arguments of fun to compute gradients with respect to. Defaults to [], so no gradients are computed with respect to keyword arguments by default.


Returns:
A function which returns a tuple where the first element is the output of fun and the second element is the gradients of the loss with respect to the arguments selected by argnums and argnames.

Return type: