Operations

Contents

Operations

int mlx_abs(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_add(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_addmm(mlx_array *res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s)#
int mlx_all_axes(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, const mlx_stream s)#
int mlx_all_axis(mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)#
int mlx_all_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_allclose(mlx_array *res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s)#
int mlx_any(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, const mlx_stream s)#
int mlx_any_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_arange(mlx_array *res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s)#
int mlx_arccos(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_arccosh(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_arcsin(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_arcsinh(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_arctan(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_arctan2(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_arctanh(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_argmax(mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)#
int mlx_argmax_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_argmin(mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)#
int mlx_argmin_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_argpartition(mlx_array *res, const mlx_array a, int kth, int axis, const mlx_stream s)#
int mlx_argpartition_all(mlx_array *res, const mlx_array a, int kth, const mlx_stream s)#
int mlx_argsort(mlx_array *res, const mlx_array a, int axis, const mlx_stream s)#
int mlx_argsort_all(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_array_equal(mlx_array *res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s)#
int mlx_as_strided(mlx_array *res, const mlx_array a, const int *shape, size_t shape_num, const int64_t *strides, size_t strides_num, size_t offset, const mlx_stream s)#
int mlx_astype(mlx_array *res, const mlx_array a, mlx_dtype dtype, const mlx_stream s)#
int mlx_atleast_1d(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_atleast_2d(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_atleast_3d(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_bitwise_and(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_bitwise_invert(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_bitwise_or(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_bitwise_xor(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_block_masked_mm(mlx_array *res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out, const mlx_array mask_lhs, const mlx_array mask_rhs, const mlx_stream s)#
int mlx_broadcast_arrays(mlx_vector_array *res, const mlx_vector_array inputs, const mlx_stream s)#
int mlx_broadcast_to(mlx_array *res, const mlx_array a, const int *shape, size_t shape_num, const mlx_stream s)#
int mlx_ceil(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_clip(mlx_array *res, const mlx_array a, const mlx_array a_min, const mlx_array a_max, const mlx_stream s)#
int mlx_concatenate(mlx_array *res, const mlx_vector_array arrays, int axis, const mlx_stream s)#
int mlx_concatenate_all(mlx_array *res, const mlx_vector_array arrays, const mlx_stream s)#
int mlx_conjugate(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_contiguous(mlx_array *res, const mlx_array a, bool allow_col_major, const mlx_stream s)#
int mlx_conv1d(mlx_array *res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s)#
int mlx_conv2d(mlx_array *res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s)#
int mlx_conv3d(mlx_array *res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s)#
int mlx_conv_general(mlx_array *res, const mlx_array input, const mlx_array weight, const int *stride, size_t stride_num, const int *padding_lo, size_t padding_lo_num, const int *padding_hi, size_t padding_hi_num, const int *kernel_dilation, size_t kernel_dilation_num, const int *input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s)#
int mlx_conv_transpose1d(mlx_array *res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s)#
int mlx_conv_transpose2d(mlx_array *res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s)#
int mlx_conv_transpose3d(mlx_array *res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s)#
int mlx_copy(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_cos(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_cosh(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_cummax(mlx_array *res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s)#
int mlx_cummin(mlx_array *res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s)#
int mlx_cumprod(mlx_array *res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s)#
int mlx_cumsum(mlx_array *res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s)#
int mlx_degrees(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_depends(mlx_vector_array *res, const mlx_vector_array inputs, const mlx_vector_array dependencies)#
int mlx_dequantize(mlx_array *res, const mlx_array w, const mlx_array scales, const mlx_array biases, int group_size, int bits, const mlx_stream s)#
int mlx_diag(mlx_array *res, const mlx_array a, int k, const mlx_stream s)#
int mlx_diagonal(mlx_array *res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s)#
int mlx_divide(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_divmod(mlx_vector_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_einsum(mlx_array *res, const char *subscripts, const mlx_vector_array operands, const mlx_stream s)#
int mlx_equal(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_erf(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_erfinv(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_exp(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_expand_dims(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_expm1(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_eye(mlx_array *res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s)#
int mlx_flatten(mlx_array *res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s)#
int mlx_floor(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_floor_divide(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_full(mlx_array *res, const int *shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s)#
int mlx_gather(mlx_array *res, const mlx_array a, const mlx_vector_array indices, const int *axes, size_t axes_num, const int *slice_sizes, size_t slice_sizes_num, const mlx_stream s)#
int mlx_gather_mm(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices, const mlx_array rhs_indices, const mlx_stream s)#
int mlx_gather_qmm(mlx_array *res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases, const mlx_array lhs_indices, const mlx_array rhs_indices, bool transpose, int group_size, int bits, const mlx_stream s)#
int mlx_greater(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_greater_equal(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_hadamard_transform(mlx_array *res, const mlx_array a, mlx_optional_float scale, const mlx_stream s)#
int mlx_identity(mlx_array *res, int n, mlx_dtype dtype, const mlx_stream s)#
int mlx_imag(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_inner(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_isclose(mlx_array *res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s)#
int mlx_isfinite(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_isinf(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_isnan(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_isneginf(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_isposinf(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_kron(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_left_shift(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_less(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_less_equal(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_linspace(mlx_array *res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s)#
int mlx_log(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_log10(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_log1p(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_log2(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_logaddexp(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_logical_and(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_logical_not(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_logical_or(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_logsumexp(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, const mlx_stream s)#
int mlx_logsumexp_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_matmul(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_max(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, const mlx_stream s)#
int mlx_max_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_maximum(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_mean(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, const mlx_stream s)#
int mlx_mean_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_meshgrid(mlx_vector_array *res, const mlx_vector_array arrays, bool sparse, const char *indexing, const mlx_stream s)#
int mlx_min(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, const mlx_stream s)#
int mlx_min_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_minimum(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_moveaxis(mlx_array *res, const mlx_array a, int source, int destination, const mlx_stream s)#
int mlx_multiply(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_nan_to_num(mlx_array *res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s)#
int mlx_negative(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_not_equal(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_number_of_elements(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s)#
int mlx_ones(mlx_array *res, const int *shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s)#
int mlx_ones_like(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_outer(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_pad(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, const int *low_pad_size, size_t low_pad_size_num, const int *high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char *mode, const mlx_stream s)#
int mlx_partition(mlx_array *res, const mlx_array a, int kth, int axis, const mlx_stream s)#
int mlx_partition_all(mlx_array *res, const mlx_array a, int kth, const mlx_stream s)#
int mlx_power(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_prod(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, const mlx_stream s)#
int mlx_prod_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_put_along_axis(mlx_array *res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s)#
int mlx_quantize(mlx_array *res_0, mlx_array *res_1, mlx_array *res_2, const mlx_array w, int group_size, int bits, const mlx_stream s)#
int mlx_quantized_matmul(mlx_array *res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases, bool transpose, int group_size, int bits, const mlx_stream s)#
int mlx_radians(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_real(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_reciprocal(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_remainder(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_repeat(mlx_array *res, const mlx_array arr, int repeats, int axis, const mlx_stream s)#
int mlx_repeat_all(mlx_array *res, const mlx_array arr, int repeats, const mlx_stream s)#
int mlx_reshape(mlx_array *res, const mlx_array a, const int *shape, size_t shape_num, const mlx_stream s)#
int mlx_right_shift(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_roll(mlx_array *res, const mlx_array a, int shift, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_roll_all(mlx_array *res, const mlx_array a, int shift, const mlx_stream s)#
int mlx_round(mlx_array *res, const mlx_array a, int decimals, const mlx_stream s)#
int mlx_rsqrt(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_scatter(mlx_array *res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_scatter_add(mlx_array *res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_scatter_add_axis(mlx_array *res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s)#
int mlx_scatter_max(mlx_array *res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_scatter_min(mlx_array *res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_scatter_prod(mlx_array *res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_sigmoid(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_sign(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_sin(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_sinh(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_slice(mlx_array *res, const mlx_array a, const int *start, size_t start_num, const int *stop, size_t stop_num, const int *strides, size_t strides_num, const mlx_stream s)#
int mlx_slice_update(mlx_array *res, const mlx_array src, const mlx_array update, const int *start, size_t start_num, const int *stop, size_t stop_num, const int *strides, size_t strides_num, const mlx_stream s)#
int mlx_softmax(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool precise, const mlx_stream s)#
int mlx_softmax_all(mlx_array *res, const mlx_array a, bool precise, const mlx_stream s)#
int mlx_sort(mlx_array *res, const mlx_array a, int axis, const mlx_stream s)#
int mlx_sort_all(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_split_equal_parts(mlx_vector_array *res, const mlx_array a, int num_splits, int axis, const mlx_stream s)#
int mlx_split(mlx_vector_array *res, const mlx_array a, const int *indices, size_t indices_num, int axis, const mlx_stream s)#
int mlx_sqrt(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_square(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_squeeze(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_squeeze_all(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_stack(mlx_array *res, const mlx_vector_array arrays, int axis, const mlx_stream s)#
int mlx_stack_all(mlx_array *res, const mlx_vector_array arrays, const mlx_stream s)#
int mlx_std(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s)#
int mlx_std_all(mlx_array *res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s)#
int mlx_stop_gradient(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_subtract(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)#
int mlx_sum(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, const mlx_stream s)#
int mlx_sum_all(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)#
int mlx_swapaxes(mlx_array *res, const mlx_array a, int axis1, int axis2, const mlx_stream s)#
int mlx_take(mlx_array *res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s)#
int mlx_take_all(mlx_array *res, const mlx_array a, const mlx_array indices, const mlx_stream s)#
int mlx_take_along_axis(mlx_array *res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s)#
int mlx_tan(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_tanh(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_tensordot(mlx_array *res, const mlx_array a, const mlx_array b, const int *axes_a, size_t axes_a_num, const int *axes_b, size_t axes_b_num, const mlx_stream s)#
int mlx_tensordot_along_axis(mlx_array *res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s)#
int mlx_tile(mlx_array *res, const mlx_array arr, const int *reps, size_t reps_num, const mlx_stream s)#
int mlx_topk(mlx_array *res, const mlx_array a, int k, int axis, const mlx_stream s)#
int mlx_topk_all(mlx_array *res, const mlx_array a, int k, const mlx_stream s)#
int mlx_trace(mlx_array *res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s)#
int mlx_transpose(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, const mlx_stream s)#
int mlx_transpose_all(mlx_array *res, const mlx_array a, const mlx_stream s)#
int mlx_tri(mlx_array *res, int n, int m, int k, mlx_dtype type, const mlx_stream s)#
int mlx_tril(mlx_array *res, const mlx_array x, int k, const mlx_stream s)#
int mlx_triu(mlx_array *res, const mlx_array x, int k, const mlx_stream s)#
int mlx_unflatten(mlx_array *res, const mlx_array a, int axis, const int *shape, size_t shape_num, const mlx_stream s)#
int mlx_var(mlx_array *res, const mlx_array a, const int *axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s)#
int mlx_var_all(mlx_array *res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s)#
int mlx_view(mlx_array *res, const mlx_array a, mlx_dtype dtype, const mlx_stream s)#
int mlx_where(mlx_array *res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s)#
int mlx_zeros(mlx_array *res, const int *shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s)#
int mlx_zeros_like(mlx_array *res, const mlx_array a, const mlx_stream s)#