mlx.core.grad
- grad(fun: Callable, argnums: int | Sequence[int] | None = None, argnames: str | Sequence[str] = []) → Callable
Returns a function which computes the gradient of fun.
- Parameters:
fun (Callable) – A function which takes a variable number of arrays or trees of arrays and returns a scalar output array.
argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames is provided, argnums defaults to 0, indicating fun's first argument.
argnames (str or list(str), optional) – Specify the keyword arguments of fun to compute gradients with respect to. It defaults to [], so no gradients are computed for keyword arguments by default.
- Returns:
A function which has the same input arguments as fun and returns the gradient(s).
- Return type:
Callable