Function mlx_rs::linalg::inv_device

pub fn inv_device(
    a: impl AsRef<Array>,
    stream: impl AsRef<Stream>,
) -> Result<Array>

Compute the inverse of a square matrix. Returns an error if the input is invalid, for example if the last two dimensions of a are not equal.
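To make the operation concrete, here is a minimal plain-Rust sketch of inverting one square matrix with Gauss-Jordan elimination and partial pivoting. It does not use mlx-rs and is not how mlx-rs or MLX computes the inverse internally. invert is a hypothetical helper that takes a row-major slice plus the side length n, and it rejects non-square input and (numerically) singular matrices with an error string.

```rust
/// Hypothetical helper (not part of mlx-rs): inverts an n x n matrix stored
/// row-major in `m` using Gauss-Jordan elimination with partial pivoting.
/// Returns an error if `m` does not hold exactly n * n values or if the
/// matrix is (numerically) singular.
fn invert(m: &[f64], n: usize) -> Result<Vec<f64>, String> {
    if n == 0 || m.len() != n * n {
        return Err(format!("expected a square {n}x{n} matrix, got {} elements", m.len()));
    }
    // Build the augmented matrix [A | I], row-major with 2n columns.
    let w = 2 * n;
    let mut aug = vec![0.0; n * w];
    for r in 0..n {
        for c in 0..n {
            aug[r * w + c] = m[r * n + c];
        }
        aug[r * w + n + r] = 1.0;
    }
    for col in 0..n {
        // Partial pivoting: choose the row with the largest |value| in this column.
        let pivot = (col..n)
            .max_by(|&a, &b| aug[a * w + col].abs().total_cmp(&aug[b * w + col].abs()))
            .unwrap();
        if aug[pivot * w + col].abs() < 1e-12 {
            return Err("matrix is singular".to_string());
        }
        if pivot != col {
            for c in 0..w {
                aug.swap(pivot * w + c, col * w + c);
            }
        }
        // Scale the pivot row so the pivot entry becomes 1.
        let p = aug[col * w + col];
        for c in 0..w {
            aug[col * w + c] /= p;
        }
        // Eliminate this column from every other row.
        for r in 0..n {
            if r != col {
                let f = aug[r * w + col];
                if f != 0.0 {
                    for c in 0..w {
                        let v = aug[col * w + c];
                        aug[r * w + c] -= f * v;
                    }
                }
            }
        }
    }
    // The right half of the augmented matrix now holds the inverse.
    Ok((0..n).flat_map(|r| aug[r * w + n..r * w + w].to_vec()).collect())
}
```

Partial pivoting is used so that a zero or tiny entry on the diagonal does not cause a division blow-up; the 1e-12 singularity threshold is an arbitrary choice for this sketch.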

The input must have at least two dimensions. If it has more than two, the leading dimensions are treated as batch dimensions and the inverse is computed independently for each matrix formed by the last two dimensions of a.
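The batching rule can be sketched in plain Rust, independent of mlx-rs. map_matrices is a hypothetical helper, not an mlx-rs function: it checks that the shape has at least two dimensions with equal trailing sizes, treats every other dimension as a batch dimension, and applies a caller-supplied per-matrix function (for example an inverse) to each n x n block of row-major data.

```rust
/// Hypothetical helper (not part of mlx-rs): applies `f` to every n x n matrix
/// formed by the last two dimensions of a row-major array with shape `shape`,
/// concatenating the per-matrix results in batch order.
fn map_matrices<F>(shape: &[usize], data: &[f64], mut f: F) -> Result<Vec<f64>, String>
where
    F: FnMut(&[f64], usize) -> Result<Vec<f64>, String>,
{
    if shape.len() < 2 {
        return Err(format!("expected at least 2 dimensions, got {}", shape.len()));
    }
    let (rows, cols) = (shape[shape.len() - 2], shape[shape.len() - 1]);
    if rows != cols {
        return Err(format!("last two dimensions must be equal, got {rows}x{cols}"));
    }
    let n = rows;
    // All leading dimensions collapse into a single batch count
    // (the empty product is 1, so a plain 2-D input is a batch of one).
    let batch: usize = shape[..shape.len() - 2].iter().product();
    if data.len() != batch * n * n {
        return Err("data length does not match shape".to_string());
    }
    let mut out = Vec::with_capacity(data.len());
    for b in 0..batch {
        // Each matrix occupies a contiguous block of n * n elements.
        let matrix = &data[b * n * n..(b + 1) * n * n];
        out.extend(f(matrix, n)?);
    }
    Ok(out)
}
```

For example, an input of shape [2, 2, 2] is processed as two independent 2 x 2 matrices, and the output keeps the same shape.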

Evaluation on the GPU is not yet implemented, so stream should refer to the CPU (for example StreamOrDevice::cpu()).

§Params

  • a: input array with at least two dimensions
  • stream: stream or device on which to evaluate the operation

§Example

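```rust
use mlx_rs::{Array, StreamOrDevice, linalg::*};

// The 2x2 matrix [[1, 2], [3, 4]], stored row-major.
let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
// GPU evaluation is not implemented, so run on a CPU stream.
let a_inv = inv_device(&a, StreamOrDevice::cpu()).unwrap();
// Expected inverse of a; the f32 suffix keeps the dtype consistent with a.
let expected = Array::from_slice(&[-2.0f32, 1.0, 1.5, -0.5], &[2, 2]);
assert!(a_inv.all_close(&expected, None, None, None).unwrap().item::<bool>());
```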
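The expected values in the example follow from the closed-form inverse of a 2 x 2 matrix, which divides the adjugate by the determinant:

$$
A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \qquad
\det A = 1 \cdot 4 - 2 \cdot 3 = -2, \qquad
A^{-1} = \frac{1}{\det A} \begin{pmatrix} 4 & -2 \\ -3 & 1 \end{pmatrix}
= \begin{pmatrix} -2 & 1 \\ 1.5 & -0.5 \end{pmatrix}
$$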