mlx.core.linalg.lu
lu(a: array, *, stream: None | Stream | Device = None) → Tuple[array, array, array]
Compute the LU factorization of the given matrix A.

Note: unlike the default behavior of scipy.linalg.lu, the pivots are returned as indices rather than as a permutation matrix. With p denoting the pivot indices (the first element of the returned tuple), reconstruct the input with L[p, :] @ U for 2 dimensions, or with mx.take_along_axis(L, p[..., None], axis=-2) @ U for more than 2 dimensions.

To construct the full permutation matrix, do:

P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)