mlx.core.vjp
- vjp(fun: Callable, primals: list[array], cotangents: list[array]) → tuple[list[array], list[array]]
Compute the vector-Jacobian product.
Computes the product of the cotangents with the Jacobian of a function fun evaluated at primals, that is, v^T J, where v is the cotangent vector and J is the Jacobian of fun at primals.
- Parameters:
  fun (Callable) – A function which takes a variable number of arrays (one per element of primals) and returns a single array or a list of arrays.
  primals (list(array)) – A list of arrays at which to evaluate the Jacobian.
  cotangents (list(array)) – A list of arrays which are the "vector" in the vector-Jacobian product. The cotangents should be the same in number, shape, and type as the outputs of fun.
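To make the shapes concrete, the following sketch does not use MLX at all. It builds the Jacobian of a small example function f: R^3 → R^2 (chosen only for illustration) explicitly with NumPy central finite differences, then multiplies it on the left by a cotangent vector. The cotangent v has the shape of f's output, while the product v @ J has the shape of the primal x. An autodiff vjp typically computes this same product without materializing J.

```python
import numpy as np

def f(x):
    # Illustrative function R^3 -> R^2 (not part of MLX).
    return np.array([x[0] * x[1], np.sin(x[2])])

def jacobian_fd(fun, x, eps=1e-6):
    # Central finite-difference Jacobian with shape (num_outputs, num_inputs).
    y = fun(x)
    J = np.zeros((y.size, x.size))
    for i in range(x.size):
        dx = np.zeros_like(x)
        dx[i] = eps
        J[:, i] = (fun(x + dx) - fun(x - dx)) / (2 * eps)
    return J

x = np.array([1.0, 2.0, 0.5])  # the primal
v = np.array([3.0, -1.0])      # the cotangent: same shape as f(x)

J = jacobian_fd(f, x)
vjp = v @ J                    # v^T J: same shape as the primal x

# Analytical Jacobian is [[x1, x0, 0], [0, 0, cos(x2)]], so v^T J is:
expected = np.array([v[0] * x[1], v[0] * x[0], v[1] * np.cos(x[2])])
```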
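- Returns:
  A tuple (outputs, vjps). outputs is the list of outputs of fun evaluated at primals; vjps is the list of vector-Jacobian products, one per primal, which is the same in number, shape, and type as primals.
- Return type:
  tuple(list(array), list(array))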