mlx.core.custom_function

class custom_function

Set up a function for custom gradient and vmap definitions.

This class is meant to be used as a function decorator. Instances are callables that behave identically to the wrapped function. However, when a function transformation is applied (for example, computing gradients with value_and_grad()), the functions defined via custom_function.vjp(), custom_function.jvp() and custom_function.vmap() are used instead of the default transformation.

Note that all custom transformations are optional. Any transformation left undefined falls back to the default behaviour.
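
The decorator pattern described above can be illustrated with a small plain-Python toy. This is only a sketch of the idea, not how MLX implements it: `ToyCustomFunction` and `toy_grad` are hypothetical names invented for illustration, the "transformation" is a scalar derivative, and the default behaviour is assumed to be a finite difference so that the example runs without MLX.

```python
class ToyCustomFunction:
    """Toy stand-in (not MLX's implementation) for a decorator with an optional custom VJP."""

    def __init__(self, f):
        self._f = f
        self._vjp = None  # no custom rule registered yet

    def __call__(self, *args):
        # Calling the instance behaves exactly like calling the wrapped function.
        return self._f(*args)

    def vjp(self, rule):
        # Used as a decorator: register the custom rule and hand it back unchanged.
        self._vjp = rule
        return rule


def toy_grad(fn, eps=1e-6):
    """Hypothetical scalar transformation: derivative of fn with respect to its argument."""

    def grad_fn(x):
        custom = getattr(fn, "_vjp", None)
        if custom is not None:
            # A custom rule is defined, so it replaces the default.
            return custom((x,), 1.0, fn(x))[0]
        # Default behaviour (assumed here): central finite difference.
        return (fn(x + eps) - fn(x - eps)) / (2 * eps)

    return grad_fn


@ToyCustomFunction
def square(x):
    return x * x


default_grad = toy_grad(square)(3.0)  # no rule registered: falls back to the default


@square.vjp
def square_vjp(primals, cotangent, output):
    (x,) = primals
    return (cotangent * 2 * x,)


custom_grad = toy_grad(square)(3.0)  # the registered rule is used instead
```

In this toy, `square(3.0)` returns the same value before and after a rule is registered; only the transformed function changes, which mirrors the behaviour described for custom_function.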

Example usage:

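import mlx.core as mx

@mx.custom_function
def f(x, y):
    return mx.sin(x) * y

@f.vjp
def f_vjp(primals, cotangent, output):
    x, y = primals
    # One gradient per input: d/dx = cos(x) * y, d/dy = sin(x)
    return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)

@f.jvp
def f_jvp(primals, tangents):
    x, y = primals
    dx, dy = tangents
    return dx * mx.cos(x) * y + dy * mx.sin(x)

@f.vmap
def f_vmap(inputs, axes):
    x, y = inputs
    ax, ay = axes
    # Move y's batch axis to match x's when both inputs are batched.
    if ax is not None and ay is not None and ay != ax:
        y = y.swapaxes(ay, ax)
    # The output batch axis is x's if x is batched, otherwise y's.
    return mx.sin(x) * y, (ax if ax is not None else ay)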
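
The derivative formulas used in the custom vjp and jvp above can be sanity-checked numerically. The sketch below deliberately uses plain Python floats and the standard `math` module rather than MLX, so it only verifies the math and not the MLX API. The helper `central_difference` and the test point are assumptions chosen for illustration.

```python
import math


def f(x, y):
    return math.sin(x) * y


def f_vjp(primals, cotangent, output):
    # Same formulas as the custom vjp above, written for Python floats.
    x, y = primals
    return cotangent * math.cos(x) * y, cotangent * math.sin(x)


def f_jvp(primals, tangents):
    # Same formula as the custom jvp above, written for Python floats.
    x, y = primals
    dx, dy = tangents
    return dx * math.cos(x) * y + dy * math.sin(x)


def central_difference(g, t, eps=1e-6):
    # Hypothetical helper: numerical derivative of a scalar function g at t.
    return (g(t + eps) - g(t - eps)) / (2 * eps)


x0, y0 = 0.7, 1.3      # arbitrary test point
dx0, dy0 = 0.5, -2.0   # arbitrary tangent direction

# With a cotangent of 1.0 the vjp returns the partial derivatives of f.
grad_x, grad_y = f_vjp((x0, y0), 1.0, f(x0, y0))
num_grad_x = central_difference(lambda t: f(t, y0), x0)
num_grad_y = central_difference(lambda t: f(x0, t), y0)

# The jvp is the directional derivative of f along (dx0, dy0).
jvp_value = f_jvp((x0, y0), (dx0, dy0))
num_jvp = central_difference(lambda t: f(x0 + t * dx0, y0 + t * dy0), 0.0)
```

For a scalar output the two rules are related: the jvp along `(dx0, dy0)` equals `dx0 * grad_x + dy0 * grad_y`, which gives a quick consistency check between a custom vjp and a custom jvp.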

__init__(self, f: Callable)

Methods

__init__(self, f)

jvp(self, f)

Define a custom Jacobian-vector product (jvp) for the wrapped function.

vjp(self, f)

Define a custom vector-Jacobian product (vjp) for the wrapped function.

vmap(self, f)

Define a custom vectorization transformation for the wrapped function.