Transforms
-
void mlx_async_eval(const mlx_vector_array outputs)
-
mlx_closure mlx_checkpoint(mlx_closure fun)
-
void mlx_eval(const mlx_vector_array outputs)
-
mlx_vector_vector_array mlx_jvp(mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents)
-
mlx_closure_value_and_grad mlx_value_and_grad(mlx_closure fun, const int *argnums, size_t num_argnums)
-
mlx_vector_vector_array mlx_vjp(mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents)