pub fn lu_device(
    a: impl AsRef<Array>,
    stream: impl AsRef<Stream>,
) -> Result<(Array, Array, Array)>
Compute the LU factorization of the given matrix A.
Note: unlike the default behavior of scipy.linalg.lu, the pivots are returned as row indices rather than as a permutation matrix. To reconstruct the input, use L[P, :] @ U for 2-dimensional input, or mx.take_along_axis(L, P[…, None], axis=-2) @ U for input with more than 2 dimensions.
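The reconstruction rule can be illustrated without mlx-rs. The sketch below is an assumption-laden illustration, not the crate's code: plain Vec<Vec<f64>> matrices stand in for mlx_rs::Array, the factors are hand-picked rather than produced by lu_device, and matmul, reconstruct, and reconstruct_batched are hypothetical helpers. reconstruct gathers the rows of L in pivot order and multiplies by U (L[P, :] @ U); reconstruct_batched applies the same row gather to each matrix in a batch, which is what take_along_axis along axis -2 achieves for inputs with more than 2 dimensions.

```rust
/// Dense row-major matrix product: (n x k) @ (k x m).
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let (n, k, m) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0; m]; n];
    for i in 0..n {
        for j in 0..m {
            for t in 0..k {
                out[i][j] += a[i][t] * b[t][j];
            }
        }
    }
    out
}

/// A = L[p, :] @ U for one matrix: gather rows of L in pivot order, then multiply by U.
fn reconstruct(l: &[Vec<f64>], u: &[Vec<f64>], p: &[usize]) -> Vec<Vec<f64>> {
    let l_rows: Vec<Vec<f64>> = p.iter().map(|&i| l[i].clone()).collect();
    matmul(&l_rows, u)
}

/// Hypothetical batched variant: the same row gather for every matrix in the batch,
/// mirroring take_along_axis(L, P[..., None], axis=-2) @ U.
fn reconstruct_batched(
    ls: &[Vec<Vec<f64>>],
    us: &[Vec<Vec<f64>>],
    ps: &[Vec<usize>],
) -> Vec<Vec<Vec<f64>>> {
    ls.iter()
        .zip(us)
        .zip(ps)
        .map(|((l, u), p)| reconstruct(l, u, p))
        .collect()
}

fn main() {
    // Illustrative factors of A = [[2, 4], [4, 2]] with index pivots p = [1, 0].
    let l = vec![vec![1.0, 0.0], vec![0.5, 1.0]];
    let u = vec![vec![4.0, 2.0], vec![0.0, 3.0]];
    let p: Vec<usize> = vec![1, 0];
    println!("{:?}", reconstruct(&l, &u, &p));
    println!("{:?}", reconstruct_batched(&[l.clone()], &[u.clone()], &[p.clone()]));
}
```

Here p = [1, 0] says that row 0 of A is row 1 of L @ U, so the gather reorders rows instead of multiplying by a separate permutation matrix.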
To construct the full permutation matrix:
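```rust
use mlx_rs::array;
use mlx_rs::ops::indexing::{Ellipsis, IndexOp, NewAxis};
// Python: P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
// `l` and `p` are the L factor and pivot indices returned by `lu_device`;
// row i of `perm` gets a 1 in column p[i], so A = perm @ L @ U.
let perm = mlx_rs::ops::put_along_axis(
    mlx_rs::ops::zeros_like(&l),
    p.index((Ellipsis, NewAxis)),
    array!(1.0),
    -1,
).unwrap();
```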
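As a dependency-free sketch of what that put_along_axis call computes for a single square matrix, the hypothetical helper permutation_matrix below scatters a 1 into column p[i] of row i of an all-zero matrix. Plain Vec<Vec<f64>> stands in for mlx_rs::Array, and the pivots and L values are made up for illustration. Multiplying by the result selects the rows of L in pivot order, so it agrees with L[p, :].

```rust
/// Hypothetical helper: the n x n matrix M with M[i][p[i]] = 1 and zeros elsewhere.
/// For one square matrix this mirrors put_along_axis(zeros_like(L), p[..., None], 1.0, axis=-1).
fn permutation_matrix(p: &[usize]) -> Vec<Vec<f64>> {
    let n = p.len();
    let mut m = vec![vec![0.0; n]; n];
    for (i, &j) in p.iter().enumerate() {
        m[i][j] = 1.0; // scatter a 1 into column p[i] of row i
    }
    m
}

/// Dense row-major matrix product: (n x k) @ (k x m).
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let (n, k, m) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0; m]; n];
    for i in 0..n {
        for j in 0..m {
            for t in 0..k {
                out[i][j] += a[i][t] * b[t][j];
            }
        }
    }
    out
}

fn main() {
    let p: Vec<usize> = vec![2, 0, 1];
    let l = vec![
        vec![1.0, 0.0, 0.0],
        vec![0.5, 1.0, 0.0],
        vec![0.25, 0.75, 1.0],
    ];
    let m = permutation_matrix(&p);
    // M @ L picks row p[i] of L for output row i, i.e. it equals L[p, :].
    let gathered: Vec<Vec<f64>> = p.iter().map(|&i| l[i].clone()).collect();
    assert_eq!(matmul(&m, &l), gathered);
    println!("{:?}", m);
}
```

The sketch assumes a square input; zeros_like(L) only has room for column index p[i] when L has at least as many columns as rows.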
§Params
a: input array
stream: stream on which to execute the operation
§Returns
The pivot indices P, the lower-triangular factor L, and the upper-triangular factor U, such that A = L[P, :] @ U.