mlx.core.vmap

vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) → callable

Returns a vectorized version of fun.

Parameters:
  • fun (callable) – A function that takes a variable number of arrays or trees of arrays and returns a variable number of arrays or trees of arrays.

  • in_axes (int, optional) – An integer or a valid prefix tree of the inputs to fun where each node specifies the vmapped axis. If the value is None then the corresponding input(s) are not vmapped. Defaults to 0.

  • out_axes (int, optional) – An integer or a valid prefix tree of the outputs of fun where each node specifies the vmapped axis. If the value is None then the corresponding output(s) are not vmapped. Defaults to 0.

Returns:

The vectorized function.

Return type:

callable