Transforms

int mlx_async_eval(const mlx_vector_array outputs)
int mlx_checkpoint(mlx_closure *res, const mlx_closure fun)
int mlx_custom_function(mlx_closure *res, const mlx_closure fun, const mlx_closure_custom fun_vjp, const mlx_closure_custom_jvp fun_jvp, const mlx_closure_custom_vmap fun_vmap)
int mlx_custom_vjp(mlx_closure *res, const mlx_closure fun, const mlx_closure_custom fun_vjp)
int mlx_eval(const mlx_vector_array outputs)
int mlx_jvp(mlx_vector_array *res_0, mlx_vector_array *res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents)
int mlx_value_and_grad(mlx_closure_value_and_grad *res, const mlx_closure fun, const int *argnums, size_t argnums_num)
int mlx_vjp(mlx_vector_array *res_0, mlx_vector_array *res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents)