mlx.core.vjp

vjp(fun: callable, primals: List[array], cotangents: List[array]) → Tuple[List[array], List[array]]

Compute the vector-Jacobian product.

Computes the product of the cotangents with the Jacobian of a function fun evaluated at primals.
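To make the definition concrete, consider the simplest case of one input and one output. The symbols f, x, v, J, m, and n below are introduced only for illustration and are not part of the API; f stands for fun, x for the single primal, and v for the single cotangent.

```latex
\[
f : \mathbb{R}^{n} \to \mathbb{R}^{m}, \qquad
J = \left.\frac{\partial f}{\partial x}\right|_{x = \text{primal}} \in \mathbb{R}^{m \times n}, \qquad
v \in \mathbb{R}^{m}, \qquad
v^{\top} J \in \mathbb{R}^{1 \times n}.
\]
```

The cotangent v has the shape of the output of f, while the product v^T J has the shape of the input x.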

Parameters:
  • fun (callable) – A function that takes a variable number of array arguments and returns a single array or a list of array.

  • primals (list(array)) – A list of array at which to evaluate the Jacobian.

  • cotangents (list(array)) – A list of array which are the “vector” in the vector-Jacobian product. The cotangents should be the same in number, shape, and type as the outputs of fun.

Returns:

A tuple of two lists. The first holds the outputs of fun evaluated at primals. The second holds the vector-Jacobian products, which match primals in number, shape, and type.

Return type:

tuple(list(array), list(array))