mlx.core.grad

grad(fun: callable, argnums: int | List[int] | None = None, argnames: str | List[str] = []) → callable

Returns a function which computes the gradient of fun.

Parameters:
  • fun (callable) – A function which takes a variable number of arrays or trees of arrays and returns a scalar output array.

  • argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames is provided, argnums defaults to 0, indicating fun’s first argument.

  • argnames (str or list(str), optional) – Specify keyword arguments of fun to compute gradients with respect to. Defaults to [], so no gradients are computed with respect to keyword arguments by default.

Returns:

A function which has the same input arguments as fun and returns the gradient(s).

Return type:

callable