mlx.core.hadamard_transform

hadamard_transform(a: array, scale: float | None = None, stream: None | Stream | Device = None) → array

Perform the Walsh-Hadamard transform along the final axis.

For a one-dimensional input x whose length is a power of two, this is equivalent to:

from scipy.linalg import hadamard

y = (hadamard(len(x)) @ x) * scale

Supports sizes n = m*2^k, where n is the size of the final axis, for m in (1, 12, 20, 28), with 2^k <= 8192 for float32 and 2^k <= 16384 for float16/bfloat16.

Parameters:
  • a (array) – Input array or scalar.

  • scale (float) – Scale the output by this factor. Defaults to 1/sqrt(a.shape[-1]) so that the Hadamard matrix is orthonormal.

Returns:

The transformed array.

Return type:

array
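
Example:

The semantics described above can be illustrated with a small pure-Python reference. This is only a sketch, not the MLX implementation: the helper name hadamard_transform_ref is hypothetical, it handles only one-dimensional inputs whose length is a power of two (the m = 1 case), and it assumes the natural (Sylvester) ordering used by scipy.linalg.hadamard. It applies the same default scale, 1/sqrt(n).

```python
import math


def hadamard_transform_ref(x, scale=None):
    """Hypothetical pure-Python reference for power-of-two lengths only."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    if scale is None:
        # Default scale makes the transform orthonormal.
        scale = 1.0 / math.sqrt(n)

    y = [float(v) for v in x]
    h = 1
    # Butterfly passes: combine blocks of width h into blocks of width 2h.
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                u, v = y[i], y[i + h]
                y[i], y[i + h] = u + v, u - v
        h *= 2
    return [v * scale for v in y]


x = [1.0, 0.0, 1.0, 0.0]
y = hadamard_transform_ref(x)        # [1.0, 1.0, 0.0, 0.0]
x_back = hadamard_transform_ref(y)   # orthonormal scale: applying twice recovers x
```

With the default scale the transform is its own inverse, which is why applying it twice returns the original input. Passing scale=1.0 instead gives the unnormalized product with the Hadamard matrix.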