mlx_rs/
linalg.rs

1//! Linear algebra operations.
2
3use crate::error::{Exception, Result};
4use crate::utils::guard::Guarded;
5use crate::utils::{IntoOption, VectorArray};
6use crate::{Array, Stream, StreamOrDevice};
7use mlx_internal_macros::{default_device, generate_macro};
8use smallvec::SmallVec;
9use std::f64;
10use std::ffi::CString;
11
/// Order of a matrix or vector norm.
///
/// Either a named order such as `"fro"` ([`Ord::Str`]) or a numeric order `p`
/// ([`Ord::P`]). See [`norm`] for the orders that are understood.
#[derive(Debug, Clone, Copy)]
pub enum Ord<'a> {
    /// Named order, e.g. `"fro"`.
    Str(&'a str),

    /// Numeric order `p`, e.g. `1.0`, `2.0` or `f64::INFINITY`.
    P(f64),
}

/// The default order is the Frobenius norm, `"fro"`.
impl Default for Ord<'_> {
    fn default() -> Self {
        Self::Str("fro")
    }
}

/// Renders the order as it was given: the name for [`Ord::Str`], the number
/// for [`Ord::P`].
impl std::fmt::Display for Ord<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Str(name) => write!(f, "{}", name),
            Self::P(p) => write!(f, "{}", p),
        }
    }
}

/// Wraps a string slice as a named order.
impl<'a> From<&'a str> for Ord<'a> {
    fn from(name: &'a str) -> Self {
        Self::Str(name)
    }
}

/// Wraps a number as a numeric order.
impl From<f64> for Ord<'_> {
    fn from(p: f64) -> Self {
        Self::P(p)
    }
}
50
51impl<'a> IntoOption<Ord<'a>> for &'a str {
52    fn into_option(self) -> Option<Ord<'a>> {
53        Some(Ord::Str(self))
54    }
55}
56
57impl<'a> IntoOption<Ord<'a>> for f64 {
58    fn into_option(self) -> Option<Ord<'a>> {
59        Some(Ord::P(self))
60    }
61}
62
63/// Compute p-norm of an [`Array`]
64#[generate_macro(customize(root = "$crate::linalg"))]
65#[default_device]
66pub fn norm_p_device<'a>(
67    array: impl AsRef<Array>,
68    ord: f64,
69    #[optional] axes: impl IntoOption<&'a [i32]>,
70    #[optional] keep_dims: impl Into<Option<bool>>,
71    #[optional] stream: impl AsRef<Stream>,
72) -> Result<Array> {
73    let keep_dims = keep_dims.into().unwrap_or(false);
74
75    match axes.into_option() {
76        Some(axes) => Array::try_from_op(|res| unsafe {
77            mlx_sys::mlx_linalg_norm_p(
78                res,
79                array.as_ref().as_ptr(),
80                ord,
81                axes.as_ptr(),
82                axes.len(),
83                keep_dims,
84                stream.as_ref().as_ptr(),
85            )
86        }),
87        None => Array::try_from_op(|res| unsafe {
88            mlx_sys::mlx_linalg_norm_p(
89                res,
90                array.as_ref().as_ptr(),
91                ord,
92                std::ptr::null(),
93                0,
94                keep_dims,
95                stream.as_ref().as_ptr(),
96            )
97        }),
98    }
99}
100
101/// Matrix or vector norm.
102#[generate_macro(customize(root = "$crate::linalg"))]
103#[default_device]
104pub fn norm_ord_device<'a>(
105    array: impl AsRef<Array>,
106    ord: &'a str,
107    #[optional] axes: impl IntoOption<&'a [i32]>,
108    #[optional] keep_dims: impl Into<Option<bool>>,
109    #[optional] stream: impl AsRef<Stream>,
110) -> Result<Array> {
111    let ord = CString::new(ord).map_err(|e| Exception::custom(format!("{}", e)))?;
112    let keep_dims = keep_dims.into().unwrap_or(false);
113
114    match axes.into_option() {
115        Some(axes) => Array::try_from_op(|res| unsafe {
116            mlx_sys::mlx_linalg_norm_ord(
117                res,
118                array.as_ref().as_ptr(),
119                ord.as_ptr(),
120                axes.as_ptr(),
121                axes.len(),
122                keep_dims,
123                stream.as_ref().as_ptr(),
124            )
125        }),
126        None => Array::try_from_op(|res| unsafe {
127            mlx_sys::mlx_linalg_norm_ord(
128                res,
129                array.as_ref().as_ptr(),
130                ord.as_ptr(),
131                std::ptr::null(),
132                0,
133                keep_dims,
134                stream.as_ref().as_ptr(),
135            )
136        }),
137    }
138}
139
140/// Matrix or vector norm.
141///
142/// For values of `ord < 1`, the result is, strictly speaking, not a
143/// mathematical norm, but it may still be useful for various numerical
144/// purposes.
145///
146/// The following norms can be calculated:
147///
148/// ord   | norm for matrices            | norm for vectors
149/// ----- | ---------------------------- | --------------------------
150/// None  | Frobenius norm               | 2-norm
151/// 'fro' | Frobenius norm               | --
152/// inf   | max(sum(abs(x), axis-1))     | max(abs(x))
153/// -inf  | min(sum(abs(x), axis-1))     | min(abs(x))
154/// 0     | --                           | sum(x !- 0)
155/// 1     | max(sum(abs(x), axis-0))     | as below
156/// -1    | min(sum(abs(x), axis-0))     | as below
157/// 2     | 2-norm (largest sing. value) | as below
158/// -2    | smallest singular value      | as below
159/// other | --                           | sum(abs(x)**ord)**(1./ord)
160///
161/// > Nuclear norm and norms based on singular values are not yet implemented.
162///
163/// The Frobenius norm is given by G. H. Golub and C. F. Van Loan, *Matrix Computations*,
164///        Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
165///
166/// The nuclear norm is the sum of the singular values.
167///
168/// Both the Frobenius and nuclear norm orders are only defined for
169/// matrices and produce a fatal error when `array.ndim != 2`
170///
171/// # Params
172///
173/// - `array`: input array
174/// - `ord`: order of the norm, see table
175/// - `axes`: axes that hold 2d matrices
176/// - `keep_dims`: if `true` the axes which are normed over are left in the result as dimensions
177///   with size one
178#[generate_macro(customize(root = "$crate::linalg"))]
179#[default_device]
180pub fn norm_device<'a>(
181    array: impl AsRef<Array>,
182    #[optional] ord: impl IntoOption<Ord<'a>>,
183    #[optional] axes: impl IntoOption<&'a [i32]>,
184    #[optional] keep_dims: impl Into<Option<bool>>,
185    #[optional] stream: impl AsRef<Stream>,
186) -> Result<Array> {
187    let ord = ord.into_option();
188    let axes = axes.into_option();
189    let keep_dims = keep_dims.into().unwrap_or(false);
190
191    match (ord, axes) {
192        // If axis and ord are both unspecified, computes the 2-norm of flatten(x).
193        (None, None) => {
194            let axes_ptr = std::ptr::null(); // mlx-c already handles the case where axes is null
195            Array::try_from_op(|res| unsafe {
196                mlx_sys::mlx_linalg_norm(
197                    res,
198                    array.as_ref().as_ptr(),
199                    axes_ptr,
200                    0,
201                    keep_dims,
202                    stream.as_ref().as_ptr(),
203                )
204            })
205        }
206        // If axis is not provided but ord is, then x must be either 1D or 2D.
207        //
208        // Frobenius norm is only supported for matrices
209        (Some(Ord::Str(ord)), None) => norm_ord_device(array, ord, axes, keep_dims, stream),
210        (Some(Ord::P(p)), None) => norm_p_device(array, p, axes, keep_dims, stream),
211        // If axis is provided, but ord is not, then the 2-norm (or Frobenius norm for matrices) is
212        // computed along the given axes. At most 2 axes can be specified.
213        (None, Some(axes)) => Array::try_from_op(|res| unsafe {
214            mlx_sys::mlx_linalg_norm(
215                res,
216                array.as_ref().as_ptr(),
217                axes.as_ptr(),
218                axes.len(),
219                keep_dims,
220                stream.as_ref().as_ptr(),
221            )
222        }),
223        // If both axis and ord are provided, then the corresponding matrix or vector
224        // norm is computed. At most 2 axes can be specified.
225        (Some(Ord::Str(ord)), Some(axes)) => norm_ord_device(array, ord, axes, keep_dims, stream),
226        (Some(Ord::P(p)), Some(axes)) => norm_p_device(array, p, axes, keep_dims, stream),
227    }
228}
229
230/// The QR factorization of the input matrix. Returns an error if the input is not valid.
231///
232/// This function supports arrays with at least 2 dimensions. The matrices which are factorized are
233/// assumed to be in the last two dimensions of the input.
234///
235/// Evaluation on the GPU is not yet implemented.
236///
237/// # Params
238///
239/// - `array`: input array
240///
241/// # Example
242///
243/// ```rust
244/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
245///
246/// let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);
247///
248/// let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();
249///
250/// let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
251/// let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);
252///
253/// assert!(q.all_close(&q_expected, None, None, None).unwrap().item::<bool>());
254/// assert!(r.all_close(&r_expected, None, None, None).unwrap().item::<bool>());
255/// ```
256#[generate_macro(customize(root = "$crate::linalg"))]
257#[default_device]
258pub fn qr_device(
259    a: impl AsRef<Array>,
260    #[optional] stream: impl AsRef<Stream>,
261) -> Result<(Array, Array)> {
262    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
263        mlx_sys::mlx_linalg_qr(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
264    })
265}
266
267/// The Singular Value Decomposition (SVD) of the input matrix. Returns an error if the input is not
268/// valid.
269///
270/// This function supports arrays with at least 2 dimensions. When the input has more than two
271/// dimensions, the function iterates over all indices of the first a.ndim - 2 dimensions and for
272/// each combination SVD is applied to the last two indices.
273///
274/// Evaluation on the GPU is not yet implemented.
275///
276/// # Params
277///
278/// - `array`: input array
279///
280/// # Example
281///
282/// ```rust
283/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
284///
285/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
286/// let (u, s, vt) = svd_device(&a, StreamOrDevice::cpu()).unwrap();
287/// let u_expected = Array::from_slice(&[-0.404554, 0.914514, -0.914514, -0.404554], &[2, 2]);
288/// let s_expected = Array::from_slice(&[5.46499, 0.365966], &[2]);
289/// let vt_expected = Array::from_slice(&[-0.576048, -0.817416, -0.817415, 0.576048], &[2, 2]);
290/// assert!(u.all_close(&u_expected, None, None, None).unwrap().item::<bool>());
291/// assert!(s.all_close(&s_expected, None, None, None).unwrap().item::<bool>());
292/// assert!(vt.all_close(&vt_expected, None, None, None).unwrap().item::<bool>());
293/// ```
294#[generate_macro(customize(root = "$crate::linalg"))]
295#[default_device]
296pub fn svd_device(
297    array: impl AsRef<Array>,
298    #[optional] stream: impl AsRef<Stream>,
299) -> Result<(Array, Array, Array)> {
300    let v = VectorArray::try_from_op(|res| unsafe {
301        mlx_sys::mlx_linalg_svd(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
302    })?;
303
304    let vals: SmallVec<[Array; 3]> = v.try_into_values()?;
305    let mut iter = vals.into_iter();
306    let u = iter.next().unwrap();
307    let s = iter.next().unwrap();
308    let vt = iter.next().unwrap();
309
310    Ok((u, s, vt))
311}
312
313/// Compute the inverse of a square matrix. Returns an error if the input is not valid.
314///
315/// This function supports arrays with at least 2 dimensions. When the input has more than two
316/// dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
317///
318/// Evaluation on the GPU is not yet implemented.
319///
320/// # Params
321///
322/// - `a`: input array
323///
324/// # Example
325///
326/// ```rust
327/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
328///
329/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
330/// let a_inv = inv_device(&a, StreamOrDevice::cpu()).unwrap();
331/// let expected = Array::from_slice(&[-2.0, 1.0, 1.5, -0.5], &[2, 2]);
332/// assert!(a_inv.all_close(&expected, None, None, None).unwrap().item::<bool>());
333/// ```
334#[generate_macro(customize(root = "$crate::linalg"))]
335#[default_device]
336pub fn inv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
337    Array::try_from_op(|res| unsafe {
338        mlx_sys::mlx_linalg_inv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
339    })
340}
341
342/// Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.
343///
344/// This function supports arrays with at least 2 dimensions. When the input has more than two
345/// dimensions, the Cholesky decomposition is computed for each matrix in the last two dimensions of
346/// `a`.
347///
348/// If the input matrix is not symmetric positive semi-definite, behaviour is undefined.
349///
350/// # Params
351///
352/// - `a`: input array
353/// - `upper`: If `true`, return the upper triangular Cholesky factor. If `false`, return the lower
354///   triangular Cholesky factor. Default: `false`.
355#[generate_macro(customize(root = "$crate::linalg"))]
356#[default_device]
357pub fn cholesky_device(
358    a: impl AsRef<Array>,
359    #[optional] upper: Option<bool>,
360    #[optional] stream: impl AsRef<Stream>,
361) -> Result<Array> {
362    let upper = upper.unwrap_or(false);
363    Array::try_from_op(|res| unsafe {
364        mlx_sys::mlx_linalg_cholesky(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
365    })
366}
367
368/// Compute the inverse of a real symmetric positive semi-definite matrix using it’s Cholesky decomposition.
369///
370/// Please see the python documentation for more details.
371#[generate_macro(customize(root = "$crate::linalg"))]
372#[default_device]
373pub fn cholesky_inv_device(
374    a: impl AsRef<Array>,
375    #[optional] upper: Option<bool>,
376    #[optional] stream: impl AsRef<Stream>,
377) -> Result<Array> {
378    let upper = upper.unwrap_or(false);
379    Array::try_from_op(|res| unsafe {
380        mlx_sys::mlx_linalg_cholesky_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
381    })
382}
383
384/// Compute the cross product of two arrays along a specified axis.
385///
386/// The cross product is defined for arrays with size 2 or 3 in the specified axis. If the size is 2
387/// then the third value is assumed to be zero.
388#[generate_macro(customize(root = "$crate::linalg"))]
389#[default_device]
390pub fn cross_device(
391    a: impl AsRef<Array>,
392    b: impl AsRef<Array>,
393    #[optional] axis: Option<i32>,
394    #[optional] stream: impl AsRef<Stream>,
395) -> Result<Array> {
396    let axis = axis.unwrap_or(-1);
397    Array::try_from_op(|res| unsafe {
398        mlx_sys::mlx_linalg_cross(
399            res,
400            a.as_ref().as_ptr(),
401            b.as_ref().as_ptr(),
402            axis,
403            stream.as_ref().as_ptr(),
404        )
405    })
406}
407
408/// Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix.
409///
410/// This function supports arrays with at least 2 dimensions. When the input has more than two
411/// dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two
412/// dimensions.
413#[generate_macro(customize(root = "$crate::linalg"))]
414#[default_device]
415pub fn eigh_device(
416    a: impl AsRef<Array>,
417    #[optional] uplo: Option<&str>,
418    #[optional] stream: impl AsRef<Stream>,
419) -> Result<(Array, Array)> {
420    let a = a.as_ref();
421    let uplo =
422        CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{}", e)))?;
423
424    <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
425        mlx_sys::mlx_linalg_eigh(
426            res_0,
427            res_1,
428            a.as_ptr(),
429            uplo.as_ptr(),
430            stream.as_ref().as_ptr(),
431        )
432    })
433}
434
435/// Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
436///
437/// This function supports arrays with at least 2 dimensions. When the input has more than two
438/// dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
439#[generate_macro(customize(root = "$crate::linalg"))]
440#[default_device]
441pub fn eigvalsh_device(
442    a: impl AsRef<Array>,
443    #[optional] uplo: Option<&str>,
444    #[optional] stream: impl AsRef<Stream>,
445) -> Result<Array> {
446    let a = a.as_ref();
447    let uplo =
448        CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{}", e)))?;
449    Array::try_from_op(|res| unsafe {
450        mlx_sys::mlx_linalg_eigvalsh(res, a.as_ptr(), uplo.as_ptr(), stream.as_ref().as_ptr())
451    })
452}
453
454/// Compute the (Moore-Penrose) pseudo-inverse of a matrix.
455#[generate_macro(customize(root = "$crate::linalg"))]
456#[default_device]
457pub fn pinv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
458    Array::try_from_op(|res| unsafe {
459        mlx_sys::mlx_linalg_pinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
460    })
461}
462
463/// Compute the inverse of a triangular square matrix.
464///
465/// This function supports arrays with at least 2 dimensions. When the input has more than two
466/// dimensions, the inverse is computed for each matrix in the last two dimensions of a.
467#[generate_macro(customize(root = "$crate::linalg"))]
468#[default_device]
469pub fn tri_inv_device(
470    a: impl AsRef<Array>,
471    #[optional] upper: Option<bool>,
472    #[optional] stream: impl AsRef<Stream>,
473) -> Result<Array> {
474    let upper = upper.unwrap_or(false);
475    Array::try_from_op(|res| unsafe {
476        mlx_sys::mlx_linalg_tri_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
477    })
478}
479
480/// Compute the LU factorization of the given matrix A.
481///
482/// Note, unlike the default behavior of scipy.linalg.lu, the pivots are
483/// indices. To reconstruct the input use L[P, :] @ U for 2 dimensions or
484/// mx.take_along_axis(L, P[..., None], axis=-2) @ U for more than 2 dimensions.
485///
486/// To construct the full permuation matrix do:
487///
488/// ```rust,ignore
489/// // python
490/// // P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
491/// let p = mlx_rs::ops::put_along_axis(
492///     mlx_rs::ops::zeros_like(&l),
493///     p.index((Ellipsis, NewAxis)),
494///     array!(1.0),
495///     -1,
496/// ).unwrap();
497/// ```
498///
499/// # Params
500///
501/// - `a`: input array
502/// - `stream`: stream to execute the operation
503///
504/// # Returns
505///
506/// The `p`, `L`, and `U` arrays, such that `A = L[P, :] @ U`
507#[generate_macro(customize(root = "$crate::linalg"))]
508#[default_device]
509pub fn lu_device(
510    a: impl AsRef<Array>,
511    #[optional] stream: impl AsRef<Stream>,
512) -> Result<(Array, Array, Array)> {
513    let v = Vec::<Array>::try_from_op(|res| unsafe {
514        mlx_sys::mlx_linalg_lu(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
515    })?;
516    let mut iter = v.into_iter();
517    let p = iter.next().ok_or_else(|| Exception::custom("missing P"))?;
518    let l = iter.next().ok_or_else(|| Exception::custom("missing L"))?;
519    let u = iter.next().ok_or_else(|| Exception::custom("missing U"))?;
520    Ok((p, l, u))
521}
522
523/// Computes a compact representation of the LU factorization.
524///
525/// # Params
526///
527/// - `a`: input array
528/// - `stream`: stream to execute the operation
529///
530/// # Returns
531///
532/// The `LU` matrix and `pivots` array.
533#[generate_macro(customize(root = "$crate::linalg"))]
534#[default_device]
535pub fn lu_factor_device(
536    a: impl AsRef<Array>,
537    #[optional] stream: impl AsRef<Stream>,
538) -> Result<(Array, Array)> {
539    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
540        mlx_sys::mlx_linalg_lu_factor(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
541    })
542}
543
544/// Compute the solution to a system of linear equations `AX = B`
545///
546/// # Params
547///
548/// - `a`: input array
549/// - `b`: input array
550/// - `stream`: stream to execute the operation
551///
552/// # Returns
553///
554/// The unique solution to the system `AX = B`
555#[generate_macro(customize(root = "$crate::linalg"))]
556#[default_device]
557pub fn solve_device(
558    a: impl AsRef<Array>,
559    b: impl AsRef<Array>,
560    #[optional] stream: impl AsRef<Stream>,
561) -> Result<Array> {
562    Array::try_from_op(|res| unsafe {
563        mlx_sys::mlx_linalg_solve(
564            res,
565            a.as_ref().as_ptr(),
566            b.as_ref().as_ptr(),
567            stream.as_ref().as_ptr(),
568        )
569    })
570}
571
572/// Computes the solution of a triangular system of linear equations `AX = B`
573///
574/// # Params
575///
576/// - `a`: input array
577/// - `b`: input array
578/// - `upper`: whether the matrix is upper triangular. Default: `false`
579/// - `stream`: stream to execute the operation
580///
581/// # Returns
582///
583/// The unique solution to the system `AX = B`
584#[generate_macro(customize(root = "$crate::linalg"))]
585#[default_device]
586pub fn solve_triangular_device(
587    a: impl AsRef<Array>,
588    b: impl AsRef<Array>,
589    #[optional] upper: impl Into<Option<bool>>,
590    #[optional] stream: impl AsRef<Stream>,
591) -> Result<Array> {
592    let upper = upper.into().unwrap_or(false);
593
594    Array::try_from_op(|res| unsafe {
595        mlx_sys::mlx_linalg_solve_triangular(
596            res,
597            a.as_ref().as_ptr(),
598            b.as_ref().as_ptr(),
599            upper,
600            stream.as_ref().as_ptr(),
601        )
602    })
603}
604
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::{
        array,
        ops::{eye, indexing::IndexOp, tril, triu},
    };

    use super::*;

    // The tests below are adapted from the swift bindings tests
    // and they are not exhaustive. Additional tests should be added
    // to cover the error cases

    #[test]
    fn test_norm_no_axes() {
        // `a` is the vector [-4, -3, ..., 4]; `b` holds the same values as a 3x3 matrix.
        let a = Array::from_iter(0..9, &[9]) - 4;
        let b = a.reshape(&[3, 3]).unwrap();

        assert_float_eq!(
            norm(&a, None, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, None, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&b, "fro", None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        // inf / -inf: max / min of |x| for the vector, max / min absolute row sum for the matrix.
        assert_float_eq!(
            norm(&a, f64::INFINITY, None, None).unwrap().item::<f32>(),
            4.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::INFINITY, None, None).unwrap().item::<f32>(),
            9.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            2.0,
            abs <= 0.001
        );

        // 1 / -1: sum of |x| for the vector, max / min absolute column sum for the matrix.
        assert_float_eq!(
            norm(&a, 1.0, None, None).unwrap().item::<f32>(),
            20.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, 1.0, None, None).unwrap().item::<f32>(),
            7.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, -1.0, None, None).unwrap().item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, -1.0, None, None).unwrap().item::<f32>(),
            6.0,
            abs <= 0.001
        );
    }

    #[test]
    fn test_norm_axis() {
        let c = Array::from_slice(&[1, 2, 3, -1, 1, 4], &[2, 3]);

        // Default (2-)norm of each column.
        let result = norm(&c, None, &[0][..], None).unwrap();
        let expected = Array::from_slice(&[1.41421, 2.23607, 5.0], &[3]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_norm_axes() {
        let m = Array::from_iter(0..8, &[2, 2, 2]);

        // Default norm over the last two axes, i.e. one value per 2x2 matrix.
        let result = norm(&m, None, &[1, 2][..], None).unwrap();
        let expected = Array::from_slice(&[3.74166, 11.225], &[2]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_qr() {
        let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);

        let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();

        let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
        let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);

        assert!(q
            .all_close(&q_expected, None, None, None)
            .unwrap()
            .item::<bool>());
        assert!(r
            .all_close(&r_expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    // The tests below are adapted from the c++ tests

    #[test]
    fn test_svd() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D returns error
        let a = Array::from_f32(0.0);
        assert!(svd_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(svd_device(&a, &stream).is_err());

        // Unsupported types returns error
        let a = Array::from_slice(&[0, 1], &[1, 2]);
        assert!(svd_device(&a, &stream).is_err());

        // TODO: wait for random
    }

    #[test]
    fn test_inv() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D returns error
        let a = Array::from_f32(0.0);
        assert!(inv_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(inv_device(&a, &stream).is_err());

        // Non-square integer matrix returns error
        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(inv_device(&a, &stream).is_err());

        // TODO: wait for random
    }

    #[test]
    fn test_cholesky() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D returns error
        let a = Array::from_f32(0.0);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // Unsupported types returns error
        let a = Array::from_slice(&[0, 1, 1, 2], &[2, 2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // Non-square returns error
        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // TODO: wait for random
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu`
    #[test]
    fn test_lu() {
        // A scalar (0D) input returns error
        let scalar = array!(1.0);
        let result = lu_device(&scalar, StreamOrDevice::cpu());
        assert!(result.is_err());

        // Test 3x3 matrix: the input is recovered as L[P, :] @ U
        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
        let a_rec = l.index((p, ..)).matmul(u).unwrap();
        assert_array_all_close!(a, a_rec);
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu_factor`
    #[test]
    fn test_lu_factor() {
        crate::random::seed(7).unwrap();

        // Test 5x5 matrix
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[5, 5], None).unwrap();
        let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
        let shape = a.shape();
        let n = shape[shape.len() - 1];

        // Replay the row swaps recorded in `pivots` to build the row permutation of `a`.
        let pivots: Vec<u32> = pivots.as_slice().to_vec();
        let mut perm: Vec<u32> = (0..n as u32).collect();
        for (i, p) in pivots.iter().enumerate() {
            perm.swap(i, *p as usize);
        }

        // Unpack the compact LU matrix: L is the strict lower triangle plus the identity,
        // U is the upper triangle including the diagonal.
        let l = tril(&lu, -1)
            .and_then(|l| l.add(eye::<f32>(n, None, None)?))
            .unwrap();
        let u = triu(&lu, None).unwrap();

        // L @ U must equal the rows of `a` taken in permuted order.
        let lhs = l.matmul(&u).unwrap();
        let perm = Array::from_slice(&perm, &[n]);
        let rhs = a.index((perm, ..));
        assert_array_all_close!(lhs, rhs);
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_solve`
    #[test]
    fn test_solve() {
        crate::random::seed(7).unwrap();

        // Test 3x3 matrix with 1D rhs
        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let b = array!([11.0f32, 35.0, 28.0]);

        let result = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
        let expected = array!([1.0f32, 2.0, 3.0]);
        assert_array_all_close!(result, expected);
    }

    #[test]
    fn test_solve_triangular() {
        // Lower triangular system (upper = false)
        let a = array!([[4.0f32, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]);
        let b = array!([8.0f32, 14.0, 3.0]);

        let result = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
        let expected = array!([2.0f32, 3.333_333_3, 1.533_333_3]);
        assert_array_all_close!(result, expected);
    }
}
862}