mlx.core.vmap
- vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) → Callable
Returns a vectorized version of fun.
- Parameters:
fun (Callable) – A function which takes a variable number of array or a tree of array and returns a variable number of array or a tree of array.
in_axes (int, optional) – An integer or a valid prefix tree of the inputs to fun where each node specifies the vmapped axis. If the value is None, the corresponding input(s) are not vmapped. Defaults to 0.
out_axes (int, optional) – An integer or a valid prefix tree of the outputs of fun where each node specifies the vmapped axis. If the value is None, the corresponding output(s) are not vmapped. Defaults to 0.
- Returns:
The vectorized function.
- Return type:
Callable