Function mlx_rs::linalg::qr

pub fn qr(a: impl AsRef<Array>) -> Result<(Array, Array)>

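Computes the QR factorization A = QR of the input matrix, where Q has orthonormal columns and R is upper triangular, and returns the pair (Q, R). Returns an error if the input is invalid, for example if it has fewer than 2 dimensions.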
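To make the definition concrete, here is a minimal pure-Rust sketch of one standard way to compute A = QR: Householder reflections applied to a single row-major m x n matrix. It is an illustration only, not the mlx_rs implementation; householder_qr is a hypothetical helper, and it works in f64 rather than the f32 used by the example. QR is unique only up to the signs of the columns of Q (and matching rows of R); this sketch uses the usual LAPACK-style choice of making each new diagonal entry -sign(x0) times the column norm, which gives the negative diagonal seen in the example.

```rust
/// Hypothetical helper (not an mlx_rs API): Householder QR of a row-major
/// m x n matrix. Returns (q, r) with q as m x m and r as m x n, both
/// row-major, such that a = q * r.
fn householder_qr(a: &[f64], m: usize, n: usize) -> (Vec<f64>, Vec<f64>) {
    assert_eq!(a.len(), m * n, "input length must equal m * n");
    let mut r = a.to_vec();
    // Q starts as the m x m identity; each reflector is multiplied into it.
    let mut q = vec![0.0; m * m];
    for i in 0..m {
        q[i * m + i] = 1.0;
    }
    for j in 0..m.min(n) {
        // Norm of the entries strictly below the diagonal in column j.
        let below = (j + 1..m).map(|i| r[i * n + j].powi(2)).sum::<f64>().sqrt();
        if below == 0.0 {
            continue; // Column is already reduced, so the reflector is the identity.
        }
        let x0 = r[j * n + j];
        let norm = (x0 * x0 + below * below).sqrt();
        // New diagonal entry is -sign(x0) * norm, which avoids cancellation in v.
        let alpha = if x0 >= 0.0 { -norm } else { norm };
        // Householder vector v = x - alpha * e_1 over rows j..m.
        let mut v: Vec<f64> = (j..m).map(|i| r[i * n + j]).collect();
        v[0] -= alpha;
        let vtv: f64 = v.iter().map(|x| x * x).sum();
        // R <- H * R with H = I - 2 v v^T / (v^T v), touching rows j..m, columns j..n.
        for c in j..n {
            let dot: f64 = v.iter().enumerate().map(|(k, vk)| vk * r[(j + k) * n + c]).sum();
            let s = 2.0 * dot / vtv;
            for (k, vk) in v.iter().enumerate() {
                r[(j + k) * n + c] -= s * vk;
            }
        }
        // Entries below the diagonal in column j are zero in exact arithmetic.
        for i in j + 1..m {
            r[i * n + j] = 0.0;
        }
        // Q <- Q * H (H is symmetric), touching columns j..m of Q.
        for row in 0..m {
            let dot: f64 = v.iter().enumerate().map(|(k, vk)| q[row * m + j + k] * vk).sum();
            let s = 2.0 * dot / vtv;
            for (k, vk) in v.iter().enumerate() {
                q[row * m + j + k] -= s * vk;
            }
        }
    }
    (q, r)
}
```

Called as householder_qr(&[2.0, 3.0, 1.0, 2.0], 2, 2), it yields Q ≈ [-0.894427, -0.447214, -0.447214, 0.894427] and R ≈ [-2.236068, -3.577709, 0.0, 0.447214] (row-major), matching the expected values in the example.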

The input must have at least 2 dimensions. The matrices to factorize are taken from the last two dimensions of the input; any leading dimensions are treated as batch dimensions.
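The rule about the last two dimensions can be read as follows: every dimension before the last two is a batch dimension, and each matrix in the batch is factorized independently. The sketch below is a hypothetical helper, batch_layout (not an mlx_rs API, and using usize shapes for simplicity), that computes how many matrices a shape holds and rejects shapes with fewer than 2 dimensions, mirroring the documented requirement.

```rust
/// Hypothetical helper (not an mlx_rs API): splits a shape into
/// (number of matrices in the batch, rows, cols).
fn batch_layout(shape: &[usize]) -> Result<(usize, usize, usize), String> {
    if shape.len() < 2 {
        return Err(format!("expected at least 2 dimensions, got {}", shape.len()));
    }
    // The last two dimensions are the matrix; everything before is batch.
    let (batch_dims, matrix_dims) = shape.split_at(shape.len() - 2);
    // The product over zero batch dimensions is 1 (a single matrix).
    let batch: usize = batch_dims.iter().product();
    Ok((batch, matrix_dims[0], matrix_dims[1]))
}
```

For example, a shape of [4, 3, 5, 5] describes 12 matrices of size 5 x 5, while a one-dimensional shape such as [7] is rejected.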

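Evaluation on the GPU is not yet implemented, so run the factorization on a CPU stream, for example with qr_device(&a, StreamOrDevice::cpu()) as in the example below.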

§Params

  • a: input array with at least 2 dimensions

§Example

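```rust
use mlx_rs::{Array, StreamOrDevice, linalg::*};

// A 2 x 2 float32 matrix, stored row-major: [[2, 3], [1, 2]].
let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);

// GPU evaluation is not implemented, so run the factorization on a CPU stream.
let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();

// Expected factors (rounded): Q has orthonormal columns, R is upper triangular.
let q_expected = Array::from_slice(&[-0.894427f32, -0.447214, -0.447214, 0.894427], &[2, 2]);
let r_expected = Array::from_slice(&[-2.23607f32, -3.57771, 0.0, 0.447214], &[2, 2]);

// Compare with default tolerances; all_close returns a boolean scalar array.
assert!(q.all_close(&q_expected, None, None, None).unwrap().item::<bool>());
assert!(r.all_close(&r_expected, None, None, None).unwrap().item::<bool>());
```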
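Because QR is unique only up to signs, checking the defining properties is a more robust test than comparing exact entries. The pure-Rust sketch below verifies that Q times R reproduces A, that the transpose of Q times Q is the identity, and that R is upper triangular, for a square row-major matrix. The helpers matmul, approx_eq_slices and is_valid_qr are hypothetical illustrations, not mlx_rs APIs.

```rust
/// Hypothetical helper: row-major n x n matrix product C = A * B.
fn matmul(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let mut c = vec![0.0; n * n];
    for i in 0..n {
        for k in 0..n {
            for j in 0..n {
                c[i * n + j] += a[i * n + k] * b[k * n + j];
            }
        }
    }
    c
}

/// Hypothetical helper: element-wise comparison with an absolute tolerance.
fn approx_eq_slices(x: &[f64], y: &[f64], tol: f64) -> bool {
    x.len() == y.len() && x.iter().zip(y).all(|(u, v)| (u - v).abs() <= tol)
}

/// Hypothetical helper: checks the defining properties of a square QR factorization.
fn is_valid_qr(a: &[f64], q: &[f64], r: &[f64], n: usize, tol: f64) -> bool {
    // Q^T in row-major order: entry (i, j) of Q^T is entry (j, i) of Q.
    let qt: Vec<f64> = (0..n * n).map(|idx| q[(idx % n) * n + idx / n]).collect();
    let identity: Vec<f64> = (0..n * n)
        .map(|idx| if idx / n == idx % n { 1.0 } else { 0.0 })
        .collect();
    // R must be upper triangular: entries below the diagonal are (close to) zero.
    let upper = (0..n).all(|i| (0..i).all(|j| r[i * n + j].abs() <= tol));
    upper
        && approx_eq_slices(&matmul(q, r, n), a, tol) // Q * R reproduces A
        && approx_eq_slices(&matmul(&qt, q, n), &identity, tol) // Q^T * Q = I
}
```

With the example's input [2.0, 3.0, 1.0, 2.0] and the rounded expected factors, is_valid_qr(..., 2, 1e-5) returns true; flipping the sign of a single entry of R makes it return false.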