mlx.core.custom_function
- class custom_function
Set up a function for custom gradient and vmap definitions.
This class is meant to be used as a function decorator. Instances are callables that behave identically to the wrapped function. However, when a function transformation is used (e.g. computing gradients using value_and_grad()), the functions defined via custom_function.vjp(), custom_function.jvp() and custom_function.vmap() are used instead of the default transformation.
Note that all custom transformations are optional. Undefined transformations fall back to the default behaviour.
Example usage:
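    import mlx.core as mx

    @mx.custom_function
    def f(x, y):
        return mx.sin(x) * y

    @f.vjp
    def f_vjp(primals, cotangent, output):
        # Return one cotangent per positional input of f.
        x, y = primals
        return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)

    @f.jvp
    def f_jvp(primals, tangents):
        # Push the input tangents forward through f.
        x, y = primals
        dx, dy = tangents
        return dx * mx.cos(x) * y + dy * mx.sin(x)

    @f.vmap
    def f_vmap(inputs, axes):
        # axes holds the vectorized axis of each input, or None if unbatched.
        x, y = inputs
        ax, ay = axes
        if ax is not None and ay is not None and ay != ax:
            # Put y's batch axis at the same position as x's.
            y = y.swapaxes(ay, ax)
        # Return the outputs and the axis along which they are vectorized.
        return mx.sin(x) * y, (ax if ax is not None else ay)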
- __init__(self, f: Callable)
Methods
- __init__(self, f)
- jvp(self, f): Define a custom jvp for the wrapped function.
- vjp(self, f): Define a custom vjp for the wrapped function.
- vmap(self, f): Define a custom vectorization transformation for the wrapped function.