Operations#
-
array arange(double start, double stop, double step, Dtype dtype, StreamOrDevice s = {})#
A 1D array of numbers starting at start (optional), stopping at stop, stepping by step (optional).
-
array arange(double start, double stop, double step, StreamOrDevice s = {})#
-
array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {})#
-
array arange(double start, double stop, StreamOrDevice s = {})#
-
array arange(double stop, Dtype dtype, StreamOrDevice s = {})#
-
array arange(double stop, StreamOrDevice s = {})#
-
array arange(int start, int stop, int step, StreamOrDevice s = {})#
-
array arange(int start, int stop, StreamOrDevice s = {})#
-
array arange(int stop, StreamOrDevice s = {})#
-
array linspace(double start, double stop, int num = 50, Dtype dtype = float32, StreamOrDevice s = {})#
A 1D array of num evenly spaced numbers in the range [start, stop].
-
array astype(array a, Dtype dtype, StreamOrDevice s = {})#
Convert an array to the given data type.
-
array as_strided(array a, std::vector<int> shape, std::vector<size_t> strides, size_t offset, StreamOrDevice s = {})#
Create a view of an array with the given shape and strides.
-
array copy(array a, StreamOrDevice s = {})#
Copy another array.
-
array full(std::vector<int> shape, array vals, Dtype dtype, StreamOrDevice s = {})#
Fill an array of the given shape with the given value(s).
-
array zeros(const std::vector<int> &shape, Dtype dtype, StreamOrDevice s = {})#
Fill an array of the given shape with zeros.
-
array zeros_like(const array &a, StreamOrDevice s = {})#
-
array ones(const std::vector<int> &shape, Dtype dtype, StreamOrDevice s = {})#
Fill an array of the given shape with ones.
-
array ones_like(const array &a, StreamOrDevice s = {})#
-
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {})#
Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
-
inline array eye(int n, Dtype dtype, StreamOrDevice s = {})#
-
inline array eye(int n, int m, StreamOrDevice s = {})#
-
inline array eye(int n, int m, int k, StreamOrDevice s = {})#
-
inline array eye(int n, StreamOrDevice s = {})#
-
array identity(int n, Dtype dtype, StreamOrDevice s = {})#
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
-
inline array identity(int n, StreamOrDevice s = {})#
-
array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {})#
-
inline array tri(int n, Dtype type, StreamOrDevice s = {})#
-
array tril(array x, int k = 0, StreamOrDevice s = {})#
-
array triu(array x, int k = 0, StreamOrDevice s = {})#
-
array reshape(const array &a, std::vector<int> shape, StreamOrDevice s = {})#
Reshape an array to the given shape.
-
array flatten(const array &a, int start_axis, int end_axis = -1, StreamOrDevice s = {})#
Flatten the dimensions in the range [start_axis, end_axis].
-
array flatten(const array &a, StreamOrDevice s = {})#
Flatten the array to 1D.
-
array squeeze(const array &a, const std::vector<int> &axes, StreamOrDevice s = {})#
Remove singleton dimensions at the given axes.
-
inline array squeeze(const array &a, int axis, StreamOrDevice s = {})#
Remove singleton dimensions at the given axis.
-
array squeeze(const array &a, StreamOrDevice s = {})#
Remove all singleton dimensions.
-
array expand_dims(const array &a, const std::vector<int> &axes, StreamOrDevice s = {})#
Add a singleton dimension at the given axes.
-
array expand_dims(const array &a, int axis, StreamOrDevice s = {})#
Add a singleton dimension at the given axis.
-
array slice(const array &a, std::vector<int> start, std::vector<int> stop, std::vector<int> strides, StreamOrDevice s = {})#
Slice an array.
-
array slice(const array &a, const std::vector<int> &start, const std::vector<int> &stop, StreamOrDevice s = {})#
Slice an array with a stride of 1 in each dimension.
-
array slice_update(const array &src, const array &update, std::vector<int> start, std::vector<int> stop, std::vector<int> strides, StreamOrDevice s = {})#
Update a slice from the source array.
-
array slice_update(const array &src, const array &update, std::vector<int> start, std::vector<int> stop, StreamOrDevice s = {})#
Update a slice from the source array with stride 1 in each dimension.
-
std::vector<array> split(const array &a, int num_splits, int axis, StreamOrDevice s = {})#
Split an array into sub-arrays along a given axis.
-
std::vector<array> split(const array &a, const std::vector<int> &indices, int axis, StreamOrDevice s = {})#
-
std::vector<array> meshgrid(const std::vector<array> &arrays, bool sparse = false, std::string indexing = "xy", StreamOrDevice s = {})#
A vector of coordinate arrays from coordinate vectors.
-
array clip(const array &a, const std::optional<array> &a_min = std::nullopt, const std::optional<array> &a_max = std::nullopt, StreamOrDevice s = {})#
Clip (limit) the values in an array.
-
array concatenate(const std::vector<array> &arrays, int axis, StreamOrDevice s = {})#
Concatenate arrays along a given axis.
-
array stack(const std::vector<array> &arrays, int axis, StreamOrDevice s = {})#
Stack arrays along a new axis.
-
array repeat(const array &arr, int repeats, int axis, StreamOrDevice s = {})#
Repeat an array along an axis.
-
array repeat(const array &arr, int repeats, StreamOrDevice s = {})#
-
array transpose(const array &a, std::vector<int> axes, StreamOrDevice s = {})#
Permutes the dimensions according to the given axes.
-
array swapaxes(const array &a, int axis1, int axis2, StreamOrDevice s = {})#
Swap two axes of an array.
-
array moveaxis(const array &a, int source, int destination, StreamOrDevice s = {})#
Move an axis of an array.
-
array pad(const array &a, const std::vector<int> &axes, const std::vector<int> &low_pad_size, const std::vector<int> &high_pad_size, const array &pad_value = array(0), StreamOrDevice s = {})#
Pad an array with a constant value.
-
array pad(const array &a, const std::vector<std::pair<int, int>> &pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
Pad an array with a constant value along all axes.
-
array pad(const array &a, const std::pair<int, int> &pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
-
array pad(const array &a, int pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
-
array transpose(const array &a, StreamOrDevice s = {})#
Permutes the dimensions in reverse order.
-
array broadcast_to(const array &a, const std::vector<int> &shape, StreamOrDevice s = {})#
Broadcast an array to a given shape.
-
std::vector<array> broadcast_arrays(const std::vector<array> &inputs, StreamOrDevice s = {})#
Broadcast a vector of arrays against one another.
-
array equal(const array &a, const array &b, StreamOrDevice s = {})#
Returns the bool array with (a == b) element-wise.
-
inline array operator==(const array &a, const array &b)#
-
array not_equal(const array &a, const array &b, StreamOrDevice s = {})#
Returns the bool array with (a != b) element-wise.
-
inline array operator!=(const array &a, const array &b)#
-
array greater(const array &a, const array &b, StreamOrDevice s = {})#
Returns bool array with (a > b) element-wise.
-
inline array operator>(const array &a, const array &b)#
-
array greater_equal(const array &a, const array &b, StreamOrDevice s = {})#
Returns bool array with (a >= b) element-wise.
-
inline array operator>=(const array &a, const array &b)#
-
array less(const array &a, const array &b, StreamOrDevice s = {})#
Returns bool array with (a < b) element-wise.
-
inline array operator<(const array &a, const array &b)#
-
array less_equal(const array &a, const array &b, StreamOrDevice s = {})#
Returns bool array with (a <= b) element-wise.
-
inline array operator<=(const array &a, const array &b)#
-
array array_equal(const array &a, const array &b, bool equal_nan, StreamOrDevice s = {})#
True if two arrays have the same shape and elements.
-
inline array array_equal(const array &a, const array &b, StreamOrDevice s = {})#
-
array isnan(const array &a, StreamOrDevice s = {})#
-
array isinf(const array &a, StreamOrDevice s = {})#
-
array isposinf(const array &a, StreamOrDevice s = {})#
-
array isneginf(const array &a, StreamOrDevice s = {})#
-
array where(const array &condition, const array &x, const array &y, StreamOrDevice s = {})#
Select from x or y depending on condition.
-
array all(const array &a, bool keepdims, StreamOrDevice s = {})#
True if all elements in the array are true (or non-zero).
-
inline array all(const array &a, StreamOrDevice s = {})#
-
array allclose(const array &a, const array &b, double rtol = 1e-5, double atol = 1e-8, bool equal_nan = false, StreamOrDevice s = {})#
True if the two arrays are equal within the specified tolerance.
-
array isclose(const array &a, const array &b, double rtol = 1e-5, double atol = 1e-8, bool equal_nan = false, StreamOrDevice s = {})#
Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
-
array all(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
Reduces the input along the given axes.
An output value is true if all the corresponding inputs are true.
-
array all(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
Reduces the input along the given axis.
An output value is true if all the corresponding inputs are true.
-
array any(const array &a, bool keepdims, StreamOrDevice s = {})#
True if any elements in the array are true (or non-zero).
-
inline array any(const array &a, StreamOrDevice s = {})#
-
array any(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
Reduces the input along the given axes.
An output value is true if any of the corresponding inputs are true.
-
array any(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
Reduces the input along the given axis.
An output value is true if any of the corresponding inputs are true.
-
array sum(const array &a, bool keepdims, StreamOrDevice s = {})#
Sums the elements of an array.
-
inline array sum(const array &a, StreamOrDevice s = {})#
-
array sum(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
Sums the elements of an array along the given axes.
-
array sum(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
Sums the elements of an array along the given axis.
-
array mean(const array &a, bool keepdims, StreamOrDevice s = {})#
Computes the mean of the elements of an array.
-
inline array mean(const array &a, StreamOrDevice s = {})#
-
array mean(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
Computes the mean of the elements of an array along the given axes.
-
array mean(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
Computes the mean of the elements of an array along the given axis.
-
array var(const array &a, bool keepdims, int ddof = 0, StreamOrDevice s = {})#
Computes the variance of the elements of an array.
-
inline array var(const array &a, StreamOrDevice s = {})#
-
array var(const array &a, const std::vector<int> &axes, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
Computes the variance of the elements of an array along the given axes.
-
array var(const array &a, int axis, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
Computes the variance of the elements of an array along the given axis.
-
array std(const array &a, bool keepdims, int ddof = 0, StreamOrDevice s = {})#
Computes the standard deviation of the elements of an array.
-
inline array std(const array &a, StreamOrDevice s = {})#
-
array std(const array &a, const std::vector<int> &axes, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
Computes the standard deviation of the elements of an array along the given axes.
-
array std(const array &a, int axis, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
Computes the standard deviation of the elements of an array along the given axis.
-
array prod(const array &a, bool keepdims, StreamOrDevice s = {})#
The product of all elements of the array.
-
inline array prod(const array &a, StreamOrDevice s = {})#
-
array prod(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
The product of the elements of an array along the given axes.
-
array prod(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
The product of the elements of an array along the given axis.
-
array max(const array &a, bool keepdims, StreamOrDevice s = {})#
The maximum of all elements of the array.
-
inline array max(const array &a, StreamOrDevice s = {})#
-
array max(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
The maximum of the elements of an array along the given axes.
-
array max(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
The maximum of the elements of an array along the given axis.
-
array min(const array &a, bool keepdims, StreamOrDevice s = {})#
The minimum of all elements of the array.
-
inline array min(const array &a, StreamOrDevice s = {})#
-
array min(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
The minimum of the elements of an array along the given axes.
-
array min(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
The minimum of the elements of an array along the given axis.
-
array argmin(const array &a, bool keepdims, StreamOrDevice s = {})#
Returns the index of the minimum value in the array.
-
inline array argmin(const array &a, StreamOrDevice s = {})#
-
array argmin(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
Returns the indices of the minimum values along a given axis.
-
array argmax(const array &a, bool keepdims, StreamOrDevice s = {})#
Returns the index of the maximum value in the array.
-
inline array argmax(const array &a, StreamOrDevice s = {})#
-
array argmax(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
Returns the indices of the maximum values along a given axis.
-
array sort(const array &a, StreamOrDevice s = {})#
Returns a sorted copy of the flattened array.
-
array sort(const array &a, int axis, StreamOrDevice s = {})#
Returns a sorted copy of the array along a given axis.
-
array argsort(const array &a, StreamOrDevice s = {})#
Returns indices that sort the flattened array.
-
array argsort(const array &a, int axis, StreamOrDevice s = {})#
Returns indices that sort the array along a given axis.
-
array partition(const array &a, int kth, StreamOrDevice s = {})#
Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
-
array partition(const array &a, int kth, int axis, StreamOrDevice s = {})#
Returns a partitioned copy of the array along a given axis such that the smaller kth elements are first.
-
array argpartition(const array &a, int kth, StreamOrDevice s = {})#
Returns indices that partition the flattened array such that the smaller kth elements are first.
-
array argpartition(const array &a, int kth, int axis, StreamOrDevice s = {})#
Returns indices that partition the array along a given axis such that the smaller kth elements are first.
-
array topk(const array &a, int k, StreamOrDevice s = {})#
Returns topk elements of the flattened array.
-
array topk(const array &a, int k, int axis, StreamOrDevice s = {})#
Returns topk elements of the array along a given axis.
-
array logsumexp(const array &a, bool keepdims, StreamOrDevice s = {})#
The logsumexp of all elements of the array.
-
inline array logsumexp(const array &a, StreamOrDevice s = {})#
-
array logsumexp(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
The logsumexp of the elements of an array along the given axes.
-
array logsumexp(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
The logsumexp of the elements of an array along the given axis.
-
array abs(const array &a, StreamOrDevice s = {})#
Absolute value of elements in an array.
-
array negative(const array &a, StreamOrDevice s = {})#
Negate an array.
-
array operator-(const array &a)#
-
array sign(const array &a, StreamOrDevice s = {})#
The sign of the elements in an array.
-
array logical_not(const array &a, StreamOrDevice s = {})#
Logical not of an array.
-
array logical_and(const array &a, const array &b, StreamOrDevice s = {})#
Logical and of two arrays.
-
array operator&&(const array &a, const array &b)#
-
array logical_or(const array &a, const array &b, StreamOrDevice s = {})#
Logical or of two arrays.
-
array operator||(const array &a, const array &b)#
-
array reciprocal(const array &a, StreamOrDevice s = {})#
The reciprocal (1/x) of the elements in an array.
-
array add(const array &a, const array &b, StreamOrDevice s = {})#
Add two arrays.
-
array operator+(const array &a, const array &b)#
-
array subtract(const array &a, const array &b, StreamOrDevice s = {})#
Subtract two arrays.
-
array operator-(const array &a, const array &b)#
-
array multiply(const array &a, const array &b, StreamOrDevice s = {})#
Multiply two arrays.
-
array operator*(const array &a, const array &b)#
-
array divide(const array &a, const array &b, StreamOrDevice s = {})#
Divide two arrays.
-
array operator/(const array &a, const array &b)#
-
array operator/(double a, const array &b)#
-
array operator/(const array &a, double b)#
-
std::vector<array> divmod(const array &a, const array &b, StreamOrDevice s = {})#
Compute the element-wise quotient and remainder.
-
array floor_divide(const array &a, const array &b, StreamOrDevice s = {})#
Compute integer division.
Equivalent to doing floor(a / b).
-
array remainder(const array &a, const array &b, StreamOrDevice s = {})#
Compute the element-wise remainder of division.
-
array operator%(const array &a, const array &b)#
-
array maximum(const array &a, const array &b, StreamOrDevice s = {})#
Element-wise maximum between two arrays.
-
array minimum(const array &a, const array &b, StreamOrDevice s = {})#
Element-wise minimum between two arrays.
-
array floor(const array &a, StreamOrDevice s = {})#
Floor the elements of an array.
-
array ceil(const array &a, StreamOrDevice s = {})#
Ceil the elements of an array.
-
array square(const array &a, StreamOrDevice s = {})#
Square the elements of an array.
-
array exp(const array &a, StreamOrDevice s = {})#
Exponential of the elements of an array.
-
array sin(const array &a, StreamOrDevice s = {})#
Sine of the elements of an array.
-
array cos(const array &a, StreamOrDevice s = {})#
Cosine of the elements of an array.
-
array tan(const array &a, StreamOrDevice s = {})#
Tangent of the elements of an array.
-
array arcsin(const array &a, StreamOrDevice s = {})#
Arc Sine of the elements of an array.
-
array arccos(const array &a, StreamOrDevice s = {})#
Arc Cosine of the elements of an array.
-
array arctan(const array &a, StreamOrDevice s = {})#
Arc Tangent of the elements of an array.
-
array arctan2(const array &a, const array &b, StreamOrDevice s = {})#
Inverse tangent of the ratio of two arrays.
-
array sinh(const array &a, StreamOrDevice s = {})#
Hyperbolic Sine of the elements of an array.
-
array cosh(const array &a, StreamOrDevice s = {})#
Hyperbolic Cosine of the elements of an array.
-
array tanh(const array &a, StreamOrDevice s = {})#
Hyperbolic Tangent of the elements of an array.
-
array arcsinh(const array &a, StreamOrDevice s = {})#
Inverse Hyperbolic Sine of the elements of an array.
-
array arccosh(const array &a, StreamOrDevice s = {})#
Inverse Hyperbolic Cosine of the elements of an array.
-
array arctanh(const array &a, StreamOrDevice s = {})#
Inverse Hyperbolic Tangent of the elements of an array.
-
array degrees(const array &a, StreamOrDevice s = {})#
Convert the elements of an array from Radians to Degrees.
-
array radians(const array &a, StreamOrDevice s = {})#
Convert the elements of an array from Degrees to Radians.
-
array log(const array &a, StreamOrDevice s = {})#
Natural logarithm of the elements of an array.
-
array log2(const array &a, StreamOrDevice s = {})#
Log base 2 of the elements of an array.
-
array log10(const array &a, StreamOrDevice s = {})#
Log base 10 of the elements of an array.
-
array log1p(const array &a, StreamOrDevice s = {})#
Natural logarithm of one plus the elements in the array: log(1 + a).
-
array logaddexp(const array &a, const array &b, StreamOrDevice s = {})#
Log-add-exp of the elements of two arrays: log(exp(a) + exp(b)).
-
array sigmoid(const array &a, StreamOrDevice s = {})#
Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
-
array erf(const array &a, StreamOrDevice s = {})#
Computes the error function of the elements of an array.
-
array erfinv(const array &a, StreamOrDevice s = {})#
Computes the inverse error function of the elements of an array.
-
array expm1(const array &a, StreamOrDevice s = {})#
Computes the expm1 function of the elements of an array.
-
array stop_gradient(const array &a, StreamOrDevice s = {})#
Stop the flow of gradients.
-
array round(const array &a, int decimals, StreamOrDevice s = {})#
Round a floating point number.
-
inline array round(const array &a, StreamOrDevice s = {})#
-
array matmul(const array &a, const array &b, StreamOrDevice s = {})#
Matrix-matrix multiplication.
-
array gather(const array &a, const std::vector<array> &indices, const std::vector<int> &axes, const std::vector<int> &slice_sizes, StreamOrDevice s = {})#
Gather array entries given indices and slices.
-
inline array gather(const array &a, const array &indices, int axis, const std::vector<int> &slice_sizes, StreamOrDevice s = {})#
-
array take(const array &a, const array &indices, int axis, StreamOrDevice s = {})#
Take array slices at the given indices of the specified axis.
-
array take(const array &a, const array &indices, StreamOrDevice s = {})#
Take array entries at the given indices treating the array as flattened.
-
array take_along_axis(const array &a, const array &indices, int axis, StreamOrDevice s = {})#
Take array entries given indices along the axis.
-
array scatter(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
Scatter updates to given linear indices.
-
inline array scatter(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
-
array scatter_add(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
Scatter and add updates to given indices.
-
inline array scatter_add(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
-
array scatter_prod(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
Scatter and prod updates to given indices.
-
inline array scatter_prod(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
-
array scatter_max(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
Scatter and max updates to given linear indices.
-
inline array scatter_max(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
-
array scatter_min(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
Scatter and min updates to given linear indices.
-
inline array scatter_min(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
-
array sqrt(const array &a, StreamOrDevice s = {})#
Square root the elements of an array.
-
array rsqrt(const array &a, StreamOrDevice s = {})#
Reciprocal of the square root of the elements of an array.
-
array softmax(const array &a, const std::vector<int> &axes, bool precise = false, StreamOrDevice s = {})#
Softmax of an array.
-
array softmax(const array &a, bool precise = false, StreamOrDevice s = {})#
Softmax of an array.
-
inline array softmax(const array &a, int axis, bool precise = false, StreamOrDevice s = {})#
Softmax of an array.
-
array power(const array &a, const array &b, StreamOrDevice s = {})#
Raise elements of a to the power of b element-wise.
-
array cumsum(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
Cumulative sum of an array.
-
array cumprod(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
Cumulative product of an array.
-
array cummax(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
Cumulative max of an array.
-
array cummin(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
Cumulative min of an array.
-
array conv_general(array input, array weight, std::vector<int> stride = {}, std::vector<int> padding_lo = {}, std::vector<int> padding_hi = {}, std::vector<int> kernel_dilation = {}, std::vector<int> input_dilation = {}, int groups = 1, bool flip = false, StreamOrDevice s = {})#
General convolution with a filter.
-
inline array conv_general(const array &input, const array &weight, std::vector<int> stride = {}, std::vector<int> padding = {}, std::vector<int> kernel_dilation = {}, std::vector<int> input_dilation = {}, int groups = 1, bool flip = false, StreamOrDevice s = {})#
General convolution with a filter.
-
array conv1d(const array &input, const array &weight, int stride = 1, int padding = 0, int dilation = 1, int groups = 1, StreamOrDevice s = {})#
1D convolution with a filter.
-
array conv2d(const array &input, const array &weight, const std::pair<int, int> &stride = {1, 1}, const std::pair<int, int> &padding = {0, 0}, const std::pair<int, int> &dilation = {1, 1}, int groups = 1, StreamOrDevice s = {})#
2D convolution with a filter.
-
array quantized_matmul(const array &x, const array &w, const array &scales, const array &biases, bool transpose = true, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
Quantized matmul multiplies x with a quantized matrix w.
-
std::tuple<array, array, array> quantize(const array &w, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
Quantize a matrix along its last axis.
-
array dequantize(const array &w, const array &scales, const array &biases, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
Dequantize a matrix produced by quantize().
-
array tensordot(const array &a, const array &b, const int axis = 2, StreamOrDevice s = {})#
Returns a contraction of a and b over multiple dimensions.
-
array tensordot(const array &a, const array &b, const std::vector<int> &axes_a, const std::vector<int> &axes_b, StreamOrDevice s = {})#
-
array outer(const array &a, const array &b, StreamOrDevice s = {})#
Compute the outer product of two vectors.
-
array inner(const array &a, const array &b, StreamOrDevice s = {})#
Compute the inner product of two vectors.
-
array addmm(array c, array a, array b, const float &alpha = 1.f, const float &beta = 1.f, StreamOrDevice s = {})#
Compute D = beta * C + alpha * (A @ B)
-
array block_masked_mm(array a, array b, int block_size, std::optional<array> mask_out = std::nullopt, std::optional<array> mask_lhs = std::nullopt, std::optional<array> mask_rhs = std::nullopt, StreamOrDevice s = {})#
Compute matrix product with block masking.
-
array block_sparse_mm(array a, array b, std::optional<array> lhs_indices = std::nullopt, std::optional<array> rhs_indices = std::nullopt, StreamOrDevice s = {})#
Compute matrix product with matrix-level gather.
-
array diagonal(const array &a, int offset = 0, int axis1 = 0, int axis2 = 1, StreamOrDevice s = {})#
Extract a diagonal or construct a diagonal array.
-
array diag(const array &a, int k = 0, StreamOrDevice s = {})#
Extract diagonal from a 2d array or create a diagonal matrix.
-
std::vector<array> depends(const std::vector<array> &inputs, const std::vector<array> &dependencies)#
Implements the identity function but allows injecting dependencies to other arrays.
This ensures that these other arrays will have been computed when the outputs of this function are computed.
-
array atleast_1d(const array &a, StreamOrDevice s = {})#
Convert an array to an array with at least one dimension.
-
array atleast_2d(const array &a, StreamOrDevice s = {})#
-
array atleast_3d(const array &a, StreamOrDevice s = {})#
-
array number_of_elements(const array &a, std::vector<int> axes, bool inverted, Dtype dtype = int32, StreamOrDevice s = {})#
Extract the number of elements along some axes as a scalar array.
Used to allow shape dependent shapeless compilation (pun intended).
-
array conjugate(const array &a, StreamOrDevice s = {})#
-
array bitwise_and(const array &a, const array &b, StreamOrDevice s = {})#
Bitwise and.
-
array operator&(const array &a, const array &b)#
-
array bitwise_or(const array &a, const array &b, StreamOrDevice s = {})#
Bitwise inclusive or.
-
array operator|(const array &a, const array &b)#
-
array bitwise_xor(const array &a, const array &b, StreamOrDevice s = {})#
Bitwise exclusive or.
-
array operator^(const array &a, const array &b)#
-
array left_shift(const array &a, const array &b, StreamOrDevice s = {})#
Shift bits to the left.
-
array operator<<(const array &a, const array &b)#
-
array right_shift(const array &a, const array &b, StreamOrDevice s = {})#
Shift bits to the right.
-
array operator>>(const array &a, const array &b)#