mlx.core.grad
- grad(fun: Callable, argnums: int | list[int] | None = None, argnames: str | list[str] = []) → Callable
Returns a function which computes the gradient of fun.
- Parameters:
fun (Callable) – A function which takes a variable number of arrays or trees of arrays and returns a scalar output array.
argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames is provided, argnums defaults to 0, indicating fun's first argument.
argnames (str or list(str), optional) – Specify the keyword arguments of fun to compute gradients with respect to. Defaults to [], so no gradients are computed for keyword arguments by default.
- Returns:
A function which has the same input arguments as fun and returns the gradient(s).
- Return type:
Callable