Transforms

eval(*args)

Evaluate an array or a tree of arrays.

compile(fun[, inputs, outputs, shapeless])

Returns a compiled function which produces the same output as fun.

custom_function

Set up a function for custom gradient and vmap definitions.

disable_compile()

Globally disable compilation.

enable_compile()

Globally enable compilation.

grad(fun[, argnums, argnames])

Returns a function which computes the gradient of fun.

value_and_grad(fun[, argnums, argnames])

Returns a function which computes the value and gradient of fun.

jvp(fun, primals, tangents)

Compute the Jacobian-vector product.

vjp(fun, primals, cotangents)

Compute the vector-Jacobian product.

vmap(fun[, in_axes, out_axes])

Returns a vectorized version of fun.