mlx.core.custom_function
- class custom_function
Set up a function for custom gradient and vmap definitions.
This class is meant to be used as a function decorator. Instances are callables that behave identically to the wrapped function. However, when a function transformation is applied (e.g. computing gradients using value_and_grad()), the functions defined via custom_function.vjp(), custom_function.jvp() and custom_function.vmap() are used instead of the default transformation.
Note that all custom transformations are optional. Undefined transformations fall back to the default behaviour.
Example
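    import mlx.core as mx

    @mx.custom_function
    def f(x, y):
        return mx.sin(x) * y

    @f.vjp
    def f_vjp(primals, cotangent, output):
        x, y = primals
        # One gradient per input, each scaled by the incoming cotangent
        return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)

    @f.jvp
    def f_jvp(primals, tangents):
        x, y = primals
        dx, dy = tangents
        return dx * mx.cos(x) * y + dy * mx.sin(x)

    @f.vmap
    def f_vmap(inputs, axes):
        x, y = inputs
        ax, ay = axes
        # Align the batch axis of y with that of x when both are batched
        if ax is not None and ay is not None and ay != ax:
            y = y.swapaxes(ay, ax)
        # Compare against None explicitly: axis 0 is falsy, so `ax or ay` would skip it
        out_axis = ax if ax is not None else ay
        return mx.sin(x) * y, out_axis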
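The rules used by the vjp and jvp in the example above follow directly from the partial derivatives of f. The short derivation below spells them out. The notation is an assumption made here for illustration and is not part of the MLX API: c stands for the incoming cotangent, barred symbols for the values returned by the vjp, and dotted symbols for tangents. All products are elementwise.

```latex
\begin{aligned}
f(x, y) &= \sin(x)\, y \\
\frac{\partial f}{\partial x} &= \cos(x)\, y, \qquad \frac{\partial f}{\partial y} = \sin(x) \\
\text{vjp:}\quad (\bar{x}, \bar{y}) &= \bigl(c \cos(x)\, y,\; c \sin(x)\bigr) \\
\text{jvp:}\quad \dot{f} &= \dot{x} \cos(x)\, y + \dot{y} \sin(x)
\end{aligned}
```

The vjp returns one value per input, while the jvp returns a single tangent for the output.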
All custom_function instances behave as pure functions. Namely, any variables captured will be treated as constants and no gradients will be computed with respect to the captured arrays. For instance:

    import mlx.core as mx

    def g(x, y):
        @mx.custom_function
        def f(x):
            return x * y

        @f.vjp
        def f_vjp(x, dx, fx):
            # Note that we have only x, dx and fx and nothing with respect to y
            raise ValueError("Abort!")

        return f(x)

    x = mx.array(2.0)
    y = mx.array(3.0)
    print(g(x, y))                       # prints 6.0
    print(mx.grad(g)(x, y))              # Raises exception
    print(mx.grad(g, argnums=1)(x, y))   # prints 0.0
- __init__(self, f: Callable)
Methods
- __init__(self, f)
- jvp(self, f): Define a custom jvp for the wrapped function.
- vjp(self, f): Define a custom vjp for the wrapped function.
- vmap(self, f): Define a custom vectorization transformation for the wrapped function.