pub fn inv_device(
a: impl AsRef<Array>,
stream: impl AsRef<Stream>,
) -> Result<Array>
Compute the inverse of a square matrix. Returns an error if the input is not valid.
This function supports arrays with at least 2 dimensions. When the input has more than two
dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
Evaluation on the GPU is not yet implemented, so `stream` should target the CPU (as in the example below).
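# Params
- `a`: input array; must have at least two dimensions, with the last two forming square matrices
- `stream`: the stream or device on which to evaluate the operation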
# Example
use mlx_rs::{Array, StreamOrDevice, linalg::*};
let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
let a_inv = inv_device(&a, StreamOrDevice::cpu()).unwrap();
let expected = Array::from_slice(&[-2.0, 1.0, 1.5, -0.5], &[2, 2]);
assert!(a_inv.all_close(&expected, None, None, None).unwrap().item::<bool>());