mlx.core.grad

grad(fun: Callable, argnums: int | list[int] | None = None, argnames: str | list[str] = []) → Callable

Returns a function which computes the gradient of fun.

Parameters:
  • fun (Callable) – A function which takes a variable number of arrays or trees of arrays and returns a scalar output array.

  • argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames is provided, argnums defaults to 0, indicating fun’s first argument.

  • argnames (str or list(str), optional) – Specify keyword arguments of fun to compute gradients with respect to. Defaults to [], so by default no gradients are computed with respect to keyword arguments.

Returns:

A function which takes the same input arguments as fun and returns the gradient(s) of fun with respect to the arguments selected by argnums and argnames.

Return type:

Callable