Fast Custom Ops
-
typedef struct mlx_fast_cuda_kernel_config_ mlx_fast_cuda_kernel_config
-
typedef struct mlx_fast_cuda_kernel_ mlx_fast_cuda_kernel
-
typedef struct mlx_fast_metal_kernel_config_ mlx_fast_metal_kernel_config
-
typedef struct mlx_fast_metal_kernel_ mlx_fast_metal_kernel
-
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void)
-
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls)
-
int mlx_fast_cuda_kernel_config_add_output_arg(mlx_fast_cuda_kernel_config cls, const int *shape, size_t size, mlx_dtype dtype)
-
int mlx_fast_cuda_kernel_config_set_grid(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3)
-
int mlx_fast_cuda_kernel_config_set_thread_group(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3)
-
int mlx_fast_cuda_kernel_config_set_init_value(mlx_fast_cuda_kernel_config cls, float value)
-
int mlx_fast_cuda_kernel_config_set_verbose(mlx_fast_cuda_kernel_config cls, bool verbose)
-
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(mlx_fast_cuda_kernel_config cls, const char *name, mlx_dtype dtype)
-
int mlx_fast_cuda_kernel_config_add_template_arg_int(mlx_fast_cuda_kernel_config cls, const char *name, int value)
-
int mlx_fast_cuda_kernel_config_add_template_arg_bool(mlx_fast_cuda_kernel_config cls, const char *name, bool value)
-
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(const char *name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char *source, const char *header, bool ensure_row_contiguous, int shared_memory)
-
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls)
-
int mlx_fast_cuda_kernel_apply(mlx_vector_array *outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream)
-
int mlx_fast_layer_norm(mlx_array *res, const mlx_array x, const mlx_array weight, const mlx_array bias, float eps, const mlx_stream s)
-
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void)
-
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls)
-
int mlx_fast_metal_kernel_config_add_output_arg(mlx_fast_metal_kernel_config cls, const int *shape, size_t size, mlx_dtype dtype)
-
int mlx_fast_metal_kernel_config_set_grid(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3)
-
int mlx_fast_metal_kernel_config_set_thread_group(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3)
-
int mlx_fast_metal_kernel_config_set_init_value(mlx_fast_metal_kernel_config cls, float value)
-
int mlx_fast_metal_kernel_config_set_verbose(mlx_fast_metal_kernel_config cls, bool verbose)
-
int mlx_fast_metal_kernel_config_add_template_arg_dtype(mlx_fast_metal_kernel_config cls, const char *name, mlx_dtype dtype)
-
int mlx_fast_metal_kernel_config_add_template_arg_int(mlx_fast_metal_kernel_config cls, const char *name, int value)
-
int mlx_fast_metal_kernel_config_add_template_arg_bool(mlx_fast_metal_kernel_config cls, const char *name, bool value)
-
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(const char *name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char *source, const char *header, bool ensure_row_contiguous, bool atomic_outputs)
-
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls)
-
int mlx_fast_metal_kernel_apply(mlx_vector_array *outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream)
-
int mlx_fast_rms_norm(mlx_array *res, const mlx_array x, const mlx_array weight, float eps, const mlx_stream s)
-
int mlx_fast_rope(mlx_array *res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs, const mlx_stream s)
-
int mlx_fast_rope_dynamic(mlx_array *res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs, const mlx_stream s)
-
int mlx_fast_scaled_dot_product_attention(mlx_array *res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char *mask_mode, const mlx_array mask_arr, const mlx_array sinks, const mlx_stream s)
-
struct mlx_fast_cuda_kernel_config_
Struct underlying the `mlx_fast_cuda_kernel_config` typedef. Header: `#include <fast.h>`
-
struct mlx_fast_cuda_kernel_
Struct underlying the `mlx_fast_cuda_kernel` typedef. Header: `#include <fast.h>`
-
struct mlx_fast_metal_kernel_config_
Struct underlying the `mlx_fast_metal_kernel_config` typedef. Header: `#include <fast.h>`
-
struct mlx_fast_metal_kernel_
Struct underlying the `mlx_fast_metal_kernel` typedef. Header: `#include <fast.h>`