mlx.core.conv_general

conv_general(input: array, weight: array, /, stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Tuple[Sequence[int], Sequence[int]] = 0, kernel_dilation: int | Sequence[int] = 1, input_dilation: int | Sequence[int] = 1, groups: int = 1, flip: bool = False, *, stream: None | Stream | Device = None) → array

General convolution over an input with several channels.

Note

  • Only 1D and 2D convolutions are supported at the moment.

  • Only the default groups=1 is currently supported.

Parameters:
  • input (array) – Input array of shape (N, ..., C_in)

  • weight (array) – Weight array of shape (C_out, ..., C_in)

  • stride (int or list(int), optional) – List of kernel strides, one per spatial dimension. If a single number is given, all spatial dimensions use that stride. Default: 1.

  • padding (int, list(int), or tuple(list(int), list(int)), optional) – List of input padding amounts, one per spatial dimension. If a single number is given, all spatial dimensions use that padding. Default: 0.

  • kernel_dilation (int or list(int), optional) – List of kernel dilation factors, one per spatial dimension. If a single number is given, all spatial dimensions use that dilation. Default: 1.

  • input_dilation (int or list(int), optional) – List of input dilation factors, one per spatial dimension. If a single number is given, all spatial dimensions use that dilation. Default: 1.

  • groups (int, optional) – Input feature groups. Default: 1.

  • flip (bool, optional) – Flip the order in which the spatial dimensions of the weights are processed. Performs the cross-correlation operator when flip is False and the convolution operator otherwise. Default: False.
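
The effect of flip can be illustrated with a minimal pure-Python sketch for a single-channel 1D signal with stride 1 and no padding. The helper correlate_1d is hypothetical and is not part of MLX; it only mirrors the cross-correlation versus convolution distinction described above.

```python
def correlate_1d(signal, kernel, flip=False):
    """Slide `kernel` over `signal` (stride 1, no padding, one channel).

    Hypothetical helper for illustration only, not an MLX function.
    flip=False computes cross-correlation; flip=True reverses the kernel
    along its spatial axis first, which yields a true convolution.
    """
    taps = kernel[::-1] if flip else kernel
    n_out = len(signal) - len(taps) + 1
    return [
        sum(signal[i + j] * taps[j] for j in range(len(taps)))
        for i in range(n_out)
    ]


signal = [1, 2, 3, 4]
kernel = [1, 0, -1]

cross_corr = correlate_1d(signal, kernel, flip=False)  # [-2, -2]
conv = correlate_1d(signal, kernel, flip=True)         # [2, 2]
```

For an antisymmetric kernel such as this one, the two results differ only in sign. For a symmetric kernel they would be identical.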

Returns:

The convolved array.

Return type:

array
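
Example:

This page does not state the shape of the returned array. The sketch below estimates the size of one spatial dimension using the conventional formula for strided, padded, dilated convolutions. The helper conv_output_size and its pad_lo and pad_hi arguments are hypothetical names, and the formula is an assumption that should be checked against the shapes conv_general actually returns.

```python
def conv_output_size(in_size, kernel_size, stride=1, pad_lo=0, pad_hi=0,
                     kernel_dilation=1, input_dilation=1):
    """Estimate the output length along one spatial dimension.

    Hypothetical helper, not part of MLX. Assumes the conventional
    formula: dilate the input and the kernel, add the padding, then
    count how many stride-spaced kernel positions fit.
    """
    dilated_in = (in_size - 1) * input_dilation + 1
    dilated_kernel = (kernel_size - 1) * kernel_dilation + 1
    return (dilated_in + pad_lo + pad_hi - dilated_kernel) // stride + 1


same = conv_output_size(32, 3, pad_lo=1, pad_hi=1)               # 32
halved = conv_output_size(32, 3, stride=2, pad_lo=1, pad_hi=1)   # 16
dilated_kernel = conv_output_size(32, 3, kernel_dilation=2)      # 28
dilated_input = conv_output_size(4, 3, input_dilation=2)         # 5
```

Under this assumption, an input of shape (N, 32, C_in) and a weight of shape (C_out, 3, C_in) with a padding of 1 would give an output of shape (N, 32, C_out).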