Function lu

pub fn lu(a: impl AsRef<Array>) -> Result<(Array, Array, Array)>

Compute the LU factorization of the given matrix A.

Note: unlike the default behavior of scipy.linalg.lu, which returns a permutation matrix, the pivots here are returned as row indices. To reconstruct the input, use L[p, :] @ U for 2 dimensions, or mx.take_along_axis(L, p[..., None], axis=-2) @ U for more than 2 dimensions (both in Python/MLX notation).
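To make the index convention concrete, here is a plain-Rust sketch that does not use mlx-rs and is not its implementation. It assumes a small, square, non-singular f64 matrix stored as Vec<Vec<f64>>; lu_indices and reconstruct are hypothetical helper names. Gaussian elimination with partial pivoting records which row of A lands in each row of U; inverting that order gives indices p with A = L[p, :] @ U, and reconstruct performs that row gather followed by the matrix product.

```rust
/// Dense row-major matrix stored as a vector of rows (illustrative type).
type Matrix = Vec<Vec<f64>>;

/// Hypothetical helper: LU factorization with partial pivoting that returns
/// pivot *indices* `p` (not a permutation matrix) such that A = L[p, :] @ U.
/// Assumes `a` is square and non-singular; panics if a pivot column holds NaN.
fn lu_indices(a: &Matrix) -> (Vec<usize>, Matrix, Matrix) {
    let n = a.len();
    let mut u = a.clone();
    let mut l = vec![vec![0.0; n]; n];
    // order[i] = index of the row of A that currently sits in row i of U.
    let mut order: Vec<usize> = (0..n).collect();

    for k in 0..n {
        // Partial pivoting: choose the row at or below k with the largest |u[i][k]|.
        let pivot = (k..n)
            .max_by(|&i, &j| u[i][k].abs().partial_cmp(&u[j][k].abs()).unwrap())
            .unwrap();
        // Swap the rows of U, the multipliers already stored in L, and the row order.
        u.swap(k, pivot);
        l.swap(k, pivot);
        order.swap(k, pivot);

        l[k][k] = 1.0;
        let pivot_row = u[k].clone();
        for i in (k + 1)..n {
            let factor = u[i][k] / pivot_row[k];
            l[i][k] = factor;
            u[i][k] = 0.0; // eliminated entry, set exactly to avoid rounding residue
            for j in (k + 1)..n {
                u[i][j] -= factor * pivot_row[j];
            }
        }
    }

    // Elimination gives (L @ U)[i, :] = A[order[i], :]. Inverting the order
    // yields indices with A[r, :] = (L @ U)[p[r], :] = L[p[r], :] @ U.
    let mut p = vec![0; n];
    for (i, &row) in order.iter().enumerate() {
        p[row] = i;
    }
    (p, l, u)
}

/// Hypothetical helper: rebuild A as L[p, :] @ U (gather rows of L, then multiply).
fn reconstruct(p: &[usize], l: &Matrix, u: &Matrix) -> Matrix {
    let n = p.len();
    let mut a = vec![vec![0.0; n]; n];
    for r in 0..n {
        let l_row = &l[p[r]]; // row r of L[p, :]
        for c in 0..n {
            a[r][c] = (0..n).map(|k| l_row[k] * u[k][c]).sum();
        }
    }
    a
}

fn main() {
    let a: Matrix = vec![
        vec![1.0, 2.0, 3.0],
        vec![4.0, 5.0, 6.0],
        vec![7.0, 8.0, 10.0],
    ];
    let (p, l, u) = lu_indices(&a);
    println!("p = {:?}", p); // prints p = [1, 2, 0]
    println!("L[p, :] @ U = {:?}", reconstruct(&p, &l, &u));
}
```

One reason to prefer indices is size: p stores n integers instead of an n × n matrix, and the row gather L[p, :] avoids a full matrix product with a permutation matrix.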

To construct the full permutation matrix P from the pivot indices p, do:

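```rust
// Python equivalent:
// P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
use mlx_rs::array;
use mlx_rs::ops::indexing::{Ellipsis, IndexOp, NewAxis};

// `p` (pivot indices) and `l` (the L factor) come from `lu`; the result is
// named `perm` so it does not shadow the pivot indices.
let perm = mlx_rs::ops::put_along_axis(
    mlx_rs::ops::zeros_like(&l).unwrap(),
    p.index((Ellipsis, NewAxis)),
    array!(1.0),
    -1,
).unwrap();
```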
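For a single 2-D matrix, that put_along_axis call writes 1.0 at column p[i] of row i of an all-zeros matrix. The plain-Rust sketch below (no mlx-rs; permutation_matrix and matmul are hypothetical helpers on a Vec<Vec<f64>> matrix) builds P the same way and checks that P @ L equals L[p, :], so A = P @ L @ U whenever A = L[p, :] @ U.

```rust
/// Dense row-major matrix stored as a vector of rows (illustrative type).
type Matrix = Vec<Vec<f64>>;

/// Hypothetical helper mirroring the put_along_axis call for one matrix:
/// start from zeros and write 1.0 at column p[i] of each row i.
fn permutation_matrix(p: &[usize]) -> Matrix {
    let n = p.len();
    let mut perm = vec![vec![0.0; n]; n];
    for (i, &col) in p.iter().enumerate() {
        perm[i][col] = 1.0;
    }
    perm
}

/// Plain dense matrix product (illustrative, O(n^3)).
fn matmul(x: &Matrix, y: &Matrix) -> Matrix {
    let (rows, inner, cols) = (x.len(), y.len(), y[0].len());
    let mut out = vec![vec![0.0; cols]; rows];
    for r in 0..rows {
        for c in 0..cols {
            out[r][c] = (0..inner).map(|k| x[r][k] * y[k][c]).sum();
        }
    }
    out
}

fn main() {
    let p: [usize; 3] = [1, 2, 0];
    let l: Matrix = vec![
        vec![1.0, 0.0, 0.0],
        vec![0.5, 1.0, 0.0],
        vec![0.25, 0.75, 1.0],
    ];
    let perm = permutation_matrix(&p);
    let pl = matmul(&perm, &l);
    // Row i of P @ L is row p[i] of L, i.e. P @ L == L[p, :].
    for (i, &src) in p.iter().enumerate() {
        assert_eq!(pl[i], l[src]);
    }
    println!("P = {:?}", perm);
}
```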

§Params

  • a: input array
  • stream: stream to execute the operation (not part of the lu signature shown above, which runs on the default stream)

§Returns

A tuple (p, L, U) of the pivot indices, the lower-triangular factor, and the upper-triangular factor, such that A = L[p, :] @ U.