Function mlx_rs::linalg::svd_device

pub fn svd_device(
    array: impl AsRef<Array>,
    stream: impl AsRef<Stream>,
) -> Result<(Array, Array, Array)>

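Computes the singular value decomposition (SVD) of the input matrix A and returns the tuple (U, S, Vt) such that A = U · diag(S) · Vt, where S holds the singular values. Returns an error if the input is not valid, for example if it has fewer than two dimensions.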
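To make the meaning of the three returned factors concrete, the plain-Rust sketch below multiplies U · diag(S) · Vt back together for a 2×2 case, assuming the usual convention that Vt is the transpose of V (not V itself). It does not call mlx-rs; the helper name reconstruct_2x2 is hypothetical, and the factors used with it are the expected values from the example on this page.

```rust
/// Hypothetical helper: rebuilds A = U · diag(S) · Vt for 2×2 factors
/// stored as row-major nested arrays.
fn reconstruct_2x2(u: [[f64; 2]; 2], s: [f64; 2], vt: [[f64; 2]; 2]) -> [[f64; 2]; 2] {
    let mut a = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            // a[i][j] = sum over k of u[i][k] * s[k] * vt[k][j];
            // diag(S) scales column k of U (equivalently row k of Vt).
            a[i][j] = (0..2).map(|k| u[i][k] * s[k] * vt[k][j]).sum();
        }
    }
    a
}
```

Called with U = [[-0.404554, 0.914514], [-0.914514, -0.404554]], S = [5.46499, 0.365966] and Vt = [[-0.576048, -0.817416], [-0.817416, 0.576048]], it returns values within about 1e-5 of [[1, 2], [3, 4]]; the residual comes from the six-digit rounding of the factors.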

This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the SVD is computed independently for each matrix formed by the last two dimensions, iterating over every index combination of the first array.ndim() - 2 dimensions.
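The batching rule can be illustrated without mlx-rs. The sketch below, which assumes a contiguous row-major buffer and uses a hypothetical helper batch_matrices, splits an input of shape [d0, ..., m, n] into the m×n matrices that would each be decomposed separately, paired with their index in the leading dimensions. It only shows which matrices are visited; it is not how MLX implements the loop internally.

```rust
/// Hypothetical helper: splits a row-major buffer of the given shape
/// (at least 2 dimensions) into the matrices formed by the last two
/// dimensions, each paired with its multi-index in the leading dimensions.
fn batch_matrices<'a, T>(data: &'a [T], shape: &[usize]) -> Vec<(Vec<usize>, &'a [T])> {
    assert!(shape.len() >= 2, "need at least 2 dimensions");
    assert_eq!(data.len(), shape.iter().product::<usize>());
    let (batch_shape, mat_shape) = shape.split_at(shape.len() - 2);
    let mat_len = mat_shape[0] * mat_shape[1];
    // The empty product is 1, so a plain 2-D input yields exactly one matrix.
    let batch_count: usize = batch_shape.iter().product();
    (0..batch_count)
        .map(move |b| {
            // Convert the flat batch number into a row-major multi-index.
            let mut idx = vec![0; batch_shape.len()];
            let mut rem = b;
            for d in (0..batch_shape.len()).rev() {
                idx[d] = rem % batch_shape[d];
                rem /= batch_shape[d];
            }
            (idx, &data[b * mat_len..(b + 1) * mat_len])
        })
        .collect()
}
```

For an input of shape [2, 3, 4, 5], this yields six 4×5 matrices with leading indices [0, 0], [0, 1], ..., [1, 2], matching the six independent decompositions described above.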

Evaluation on the GPU is not yet implemented, so pass a CPU stream (for example StreamOrDevice::cpu()).

§Params

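  • array: input array with at least 2 dimensions
  • stream: stream or device on which to evaluate the decomposition (currently CPU only)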

§Example

use mlx_rs::{Array, StreamOrDevice, linalg::*};

let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
// SVD is only implemented on the CPU, so evaluate on a CPU stream.
let (u, s, vt) = svd_device(&a, StreamOrDevice::cpu()).unwrap();
// Expected factors, as f32 to match the input dtype; a == u * diag(s) * vt.
let u_expected = Array::from_slice(&[-0.404554f32, 0.914514, -0.914514, -0.404554], &[2, 2]);
let s_expected = Array::from_slice(&[5.46499f32, 0.365966], &[2]);
let vt_expected = Array::from_slice(&[-0.576048f32, -0.817416, -0.817416, 0.576048], &[2, 2]);
// Compare with all_close using its default tolerances (rtol, atol, equal_nan all None).
assert!(u.all_close(&u_expected, None, None, None).unwrap().item::<bool>());
assert!(s.all_close(&s_expected, None, None, None).unwrap().item::<bool>());
assert!(vt.all_close(&vt_expected, None, None, None).unwrap().item::<bool>());
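The expected singular values in the example can be cross-checked by hand: they are the square roots of the eigenvalues of AᵀA. The sketch below computes them in closed form for any 2×2 matrix. It is an independent check written in plain Rust with a hypothetical helper name, not the algorithm MLX uses.

```rust
/// Hypothetical helper: singular values of the 2×2 matrix [[a, b], [c, d]],
/// largest first, computed as square roots of the eigenvalues of AᵀA.
fn singular_values_2x2(a: f64, b: f64, c: f64, d: f64) -> [f64; 2] {
    // AᵀA is symmetric: [[p, q], [q, r]].
    let p = a * a + c * c;
    let q = a * b + c * d;
    let r = b * b + d * d;
    // Eigenvalues of a symmetric 2×2 matrix: mean ± sqrt(((p - r) / 2)² + q²).
    let mean = 0.5 * (p + r);
    let radius = (0.25 * (p - r).powi(2) + q * q).sqrt();
    let lambda_max = mean + radius;
    // Clamp so rounding error cannot produce a tiny negative eigenvalue.
    let lambda_min = (mean - radius).max(0.0);
    [lambda_max.sqrt(), lambda_min.sqrt()]
}
```

For the example matrix [[1, 2], [3, 4]], AᵀA = [[10, 14], [14, 20]] with eigenvalues 15 ± √221, so the function returns approximately [5.4649857, 0.3659662], agreeing with s_expected.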