mlx_rs/
linalg.rs

1//! Linear algebra operations.
2
3use crate::error::{Exception, Result};
4use crate::utils::guard::Guarded;
5use crate::utils::{IntoOption, VectorArray};
6use crate::{Array, Stream};
7use mlx_internal_macros::{default_device, generate_macro};
8use smallvec::SmallVec;
9use std::f64;
10use std::ffi::CString;
11
/// Order of a matrix or vector norm.
///
/// An order is either a named norm such as `"fro"` (Frobenius) or a numeric order `p`
/// (which may be `f64::INFINITY` or `f64::NEG_INFINITY`).
///
/// See [`norm`] for more details.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Ord<'a> {
    /// A named order, e.g. `"fro"`.
    Str(&'a str),

    /// A numeric order `p`.
    P(f64),
}

/// The default order is the Frobenius norm, `"fro"`.
impl Default for Ord<'_> {
    fn default() -> Self {
        Ord::Str("fro")
    }
}

/// Formats a named order as its name and a numeric order as its number.
impl std::fmt::Display for Ord<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Ord::Str(s) => write!(f, "{}", s),
            Ord::P(p) => write!(f, "{}", p),
        }
    }
}

/// Converts a string slice into a named order ([`Ord::Str`]).
impl<'a> From<&'a str> for Ord<'a> {
    fn from(value: &'a str) -> Self {
        Ord::Str(value)
    }
}

/// Converts a number into a numeric order ([`Ord::P`]).
impl From<f64> for Ord<'_> {
    fn from(value: f64) -> Self {
        Ord::P(value)
    }
}
50
51impl<'a> IntoOption<Ord<'a>> for &'a str {
52    fn into_option(self) -> Option<Ord<'a>> {
53        Some(Ord::Str(self))
54    }
55}
56
57impl<'a> IntoOption<Ord<'a>> for f64 {
58    fn into_option(self) -> Option<Ord<'a>> {
59        Some(Ord::P(self))
60    }
61}
62
63/// Compute p-norm of an [`Array`]
64#[generate_macro(customize(root = "$crate::linalg"))]
65#[default_device]
66pub fn norm_device<'a>(
67    array: impl AsRef<Array>,
68    ord: f64,
69    #[optional] axes: impl IntoOption<&'a [i32]>,
70    #[optional] keep_dims: impl Into<Option<bool>>,
71    #[optional] stream: impl AsRef<Stream>,
72) -> Result<Array> {
73    let keep_dims = keep_dims.into().unwrap_or(false);
74
75    match axes.into_option() {
76        Some(axes) => Array::try_from_op(|res| unsafe {
77            mlx_sys::mlx_linalg_norm(
78                res,
79                array.as_ref().as_ptr(),
80                ord,
81                axes.as_ptr(),
82                axes.len(),
83                keep_dims,
84                stream.as_ref().as_ptr(),
85            )
86        }),
87        None => Array::try_from_op(|res| unsafe {
88            mlx_sys::mlx_linalg_norm(
89                res,
90                array.as_ref().as_ptr(),
91                ord,
92                std::ptr::null(),
93                0,
94                keep_dims,
95                stream.as_ref().as_ptr(),
96            )
97        }),
98    }
99}
100
101/// Matrix or vector norm.
102#[generate_macro(customize(root = "$crate::linalg"))]
103#[default_device]
104pub fn norm_matrix_device<'a>(
105    array: impl AsRef<Array>,
106    ord: &'a str,
107    #[optional] axes: impl IntoOption<&'a [i32]>,
108    #[optional] keep_dims: impl Into<Option<bool>>,
109    #[optional] stream: impl AsRef<Stream>,
110) -> Result<Array> {
111    let ord = CString::new(ord).map_err(|e| Exception::custom(format!("{}", e)))?;
112    let keep_dims = keep_dims.into().unwrap_or(false);
113
114    match axes.into_option() {
115        Some(axes) => Array::try_from_op(|res| unsafe {
116            mlx_sys::mlx_linalg_norm_matrix(
117                res,
118                array.as_ref().as_ptr(),
119                ord.as_ptr(),
120                axes.as_ptr(),
121                axes.len(),
122                keep_dims,
123                stream.as_ref().as_ptr(),
124            )
125        }),
126        None => Array::try_from_op(|res| unsafe {
127            mlx_sys::mlx_linalg_norm_matrix(
128                res,
129                array.as_ref().as_ptr(),
130                ord.as_ptr(),
131                std::ptr::null(),
132                0,
133                keep_dims,
134                stream.as_ref().as_ptr(),
135            )
136        }),
137    }
138}
139
140/// Compute the L2 norm of an [`Array`]
141#[generate_macro(customize(root = "$crate::linalg"))]
142#[default_device]
143pub fn norm_l2_device<'a>(
144    array: impl AsRef<Array>,
145    #[optional] axes: impl IntoOption<&'a [i32]>,
146    #[optional] keep_dims: impl Into<Option<bool>>,
147    #[optional] stream: impl AsRef<Stream>,
148) -> Result<Array> {
149    let keep_dims = keep_dims.into().unwrap_or(false);
150
151    match axes.into_option() {
152        Some(axis) => Array::try_from_op(|res| unsafe {
153            mlx_sys::mlx_linalg_norm_l2(
154                res,
155                array.as_ref().as_ptr(),
156                axis.as_ptr(),
157                axis.len(),
158                keep_dims,
159                stream.as_ref().as_ptr(),
160            )
161        }),
162        None => Array::try_from_op(|res| unsafe {
163            mlx_sys::mlx_linalg_norm_l2(
164                res,
165                array.as_ref().as_ptr(),
166                std::ptr::null(),
167                0,
168                keep_dims,
169                stream.as_ref().as_ptr(),
170            )
171        }),
172    }
173}
174
175// TODO: Change the original `norm` function to use builder pattern
176// /// Matrix or vector norm.
177// ///
178// /// For values of `ord < 1`, the result is, strictly speaking, not a
179// /// mathematical norm, but it may still be useful for various numerical
180// /// purposes.
181// ///
182// /// The following norms can be calculated:
183// ///
184// /// ord   | norm for matrices            | norm for vectors
185// /// ----- | ---------------------------- | --------------------------
186// /// None  | Frobenius norm               | 2-norm
187// /// 'fro' | Frobenius norm               | --
// /// inf   | max(sum(abs(x), axis=1))     | max(abs(x))
// /// -inf  | min(sum(abs(x), axis=1))     | min(abs(x))
// /// 0     | --                           | sum(x != 0)
// /// 1     | max(sum(abs(x), axis=0))     | as below
// /// -1    | min(sum(abs(x), axis=0))     | as below
193// /// 2     | 2-norm (largest sing. value) | as below
194// /// -2    | smallest singular value      | as below
195// /// other | --                           | sum(abs(x)**ord)**(1./ord)
196// ///
197// /// > Nuclear norm and norms based on singular values are not yet implemented.
198// ///
199// /// The Frobenius norm is given by G. H. Golub and C. F. Van Loan, *Matrix Computations*,
200// ///        Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
201// ///
202// /// The nuclear norm is the sum of the singular values.
203// ///
204// /// Both the Frobenius and nuclear norm orders are only defined for
205// /// matrices and produce a fatal error when `array.ndim != 2`
206// ///
207// /// # Params
208// ///
209// /// - `array`: input array
210// /// - `ord`: order of the norm, see table
211// /// - `axes`: axes that hold 2d matrices
212// /// - `keep_dims`: if `true` the axes which are normed over are left in the result as dimensions
213// ///   with size one
214// #[generate_macro(customize(root = "$crate::linalg"))]
215// #[default_device]
216// pub fn norm_device<'a>(
217//     array: impl AsRef<Array>,
218//     #[optional] ord: impl IntoOption<Ord<'a>>,
219//     #[optional] axes: impl IntoOption<&'a [i32]>,
220//     #[optional] keep_dims: impl Into<Option<bool>>,
221//     #[optional] stream: impl AsRef<Stream>,
222// ) -> Result<Array> {
223//     let ord = ord.into_option();
224//     let axes = axes.into_option();
225//     let keep_dims = keep_dims.into().unwrap_or(false);
226
227//     match (ord, axes) {
228//         // If axis and ord are both unspecified, computes the 2-norm of flatten(x).
229//         (None, None) => {
230//             let axes_ptr = std::ptr::null(); // mlx-c already handles the case where axes is null
231//             Array::try_from_op(|res| unsafe {
232//                 mlx_sys::mlx_linalg_norm(
233//                     res,
234//                     array.as_ref().as_ptr(),
235//                     axes_ptr,
236//                     0,
237//                     keep_dims,
238//                     stream.as_ref().as_ptr(),
239//                 )
240//             })
241//         }
242//         // If axis is not provided but ord is, then x must be either 1D or 2D.
243//         //
244//         // Frobenius norm is only supported for matrices
245//         (Some(Ord::Str(ord)), None) => norm_ord_device(array, ord, axes, keep_dims, stream),
246//         (Some(Ord::P(p)), None) => norm_p_device(array, p, axes, keep_dims, stream),
247//         // If axis is provided, but ord is not, then the 2-norm (or Frobenius norm for matrices) is
248//         // computed along the given axes. At most 2 axes can be specified.
249//         (None, Some(axes)) => Array::try_from_op(|res| unsafe {
250//             mlx_sys::mlx_linalg_norm(
251//                 res,
252//                 array.as_ref().as_ptr(),
253//                 axes.as_ptr(),
254//                 axes.len(),
255//                 keep_dims,
256//                 stream.as_ref().as_ptr(),
257//             )
258//         }),
259//         // If both axis and ord are provided, then the corresponding matrix or vector
260//         // norm is computed. At most 2 axes can be specified.
261//         (Some(Ord::Str(ord)), Some(axes)) => norm_ord_device(array, ord, axes, keep_dims, stream),
262//         (Some(Ord::P(p)), Some(axes)) => norm_p_device(array, p, axes, keep_dims, stream),
263//     }
264// }
265
266/// The QR factorization of the input matrix. Returns an error if the input is not valid.
267///
268/// This function supports arrays with at least 2 dimensions. The matrices which are factorized are
269/// assumed to be in the last two dimensions of the input.
270///
271/// Evaluation on the GPU is not yet implemented.
272///
273/// # Params
274///
275/// - `array`: input array
276///
277/// # Example
278///
279/// ```rust
280/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
281///
282/// let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);
283///
284/// let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();
285///
286/// let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
287/// let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);
288///
289/// assert!(q.all_close(&q_expected, None, None, None).unwrap().item::<bool>());
290/// assert!(r.all_close(&r_expected, None, None, None).unwrap().item::<bool>());
291/// ```
292#[generate_macro(customize(root = "$crate::linalg"))]
293#[default_device]
294pub fn qr_device(
295    a: impl AsRef<Array>,
296    #[optional] stream: impl AsRef<Stream>,
297) -> Result<(Array, Array)> {
298    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
299        mlx_sys::mlx_linalg_qr(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
300    })
301}
302
303/// The Singular Value Decomposition (SVD) of the input matrix. Returns an error if the input is not
304/// valid.
305///
306/// This function supports arrays with at least 2 dimensions. When the input has more than two
307/// dimensions, the function iterates over all indices of the first a.ndim - 2 dimensions and for
308/// each combination SVD is applied to the last two indices.
309///
310/// Evaluation on the GPU is not yet implemented.
311///
312/// # Params
313///
314/// - `array`: input array
315///
316/// # Example
317///
318/// ```rust
319/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
320///
321/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
322/// let (u, s, vt) = svd_device(&a, StreamOrDevice::cpu()).unwrap();
323/// let u_expected = Array::from_slice(&[-0.404554, 0.914514, -0.914514, -0.404554], &[2, 2]);
324/// let s_expected = Array::from_slice(&[5.46499, 0.365966], &[2]);
325/// let vt_expected = Array::from_slice(&[-0.576048, -0.817416, -0.817415, 0.576048], &[2, 2]);
326/// assert!(u.all_close(&u_expected, None, None, None).unwrap().item::<bool>());
327/// assert!(s.all_close(&s_expected, None, None, None).unwrap().item::<bool>());
328/// assert!(vt.all_close(&vt_expected, None, None, None).unwrap().item::<bool>());
329/// ```
330#[generate_macro(customize(root = "$crate::linalg"))]
331#[default_device]
332pub fn svd_device(
333    array: impl AsRef<Array>,
334    #[optional] stream: impl AsRef<Stream>,
335) -> Result<(Array, Array, Array)> {
336    let v = VectorArray::try_from_op(|res| unsafe {
337        mlx_sys::mlx_linalg_svd(res, array.as_ref().as_ptr(), true, stream.as_ref().as_ptr())
338    })?;
339
340    let vals: SmallVec<[Array; 3]> = v.try_into_values()?;
341    let mut iter = vals.into_iter();
342    let u = iter.next().unwrap();
343    let s = iter.next().unwrap();
344    let vt = iter.next().unwrap();
345
346    Ok((u, s, vt))
347}
348
349/// Compute the inverse of a square matrix. Returns an error if the input is not valid.
350///
351/// This function supports arrays with at least 2 dimensions. When the input has more than two
352/// dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
353///
354/// Evaluation on the GPU is not yet implemented.
355///
356/// # Params
357///
358/// - `a`: input array
359///
360/// # Example
361///
362/// ```rust
363/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
364///
365/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
366/// let a_inv = inv_device(&a, StreamOrDevice::cpu()).unwrap();
367/// let expected = Array::from_slice(&[-2.0, 1.0, 1.5, -0.5], &[2, 2]);
368/// assert!(a_inv.all_close(&expected, None, None, None).unwrap().item::<bool>());
369/// ```
370#[generate_macro(customize(root = "$crate::linalg"))]
371#[default_device]
372pub fn inv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
373    Array::try_from_op(|res| unsafe {
374        mlx_sys::mlx_linalg_inv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
375    })
376}
377
378/// Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.
379///
380/// This function supports arrays with at least 2 dimensions. When the input has more than two
381/// dimensions, the Cholesky decomposition is computed for each matrix in the last two dimensions of
382/// `a`.
383///
384/// If the input matrix is not symmetric positive semi-definite, behaviour is undefined.
385///
386/// # Params
387///
388/// - `a`: input array
389/// - `upper`: If `true`, return the upper triangular Cholesky factor. If `false`, return the lower
390///   triangular Cholesky factor. Default: `false`.
391#[generate_macro(customize(root = "$crate::linalg"))]
392#[default_device]
393pub fn cholesky_device(
394    a: impl AsRef<Array>,
395    #[optional] upper: Option<bool>,
396    #[optional] stream: impl AsRef<Stream>,
397) -> Result<Array> {
398    let upper = upper.unwrap_or(false);
399    Array::try_from_op(|res| unsafe {
400        mlx_sys::mlx_linalg_cholesky(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
401    })
402}
403
404/// Compute the inverse of a real symmetric positive semi-definite matrix using it’s Cholesky decomposition.
405///
406/// Please see the python documentation for more details.
407#[generate_macro(customize(root = "$crate::linalg"))]
408#[default_device]
409pub fn cholesky_inv_device(
410    a: impl AsRef<Array>,
411    #[optional] upper: Option<bool>,
412    #[optional] stream: impl AsRef<Stream>,
413) -> Result<Array> {
414    let upper = upper.unwrap_or(false);
415    Array::try_from_op(|res| unsafe {
416        mlx_sys::mlx_linalg_cholesky_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
417    })
418}
419
420/// Compute the cross product of two arrays along a specified axis.
421///
422/// The cross product is defined for arrays with size 2 or 3 in the specified axis. If the size is 2
423/// then the third value is assumed to be zero.
424#[generate_macro(customize(root = "$crate::linalg"))]
425#[default_device]
426pub fn cross_device(
427    a: impl AsRef<Array>,
428    b: impl AsRef<Array>,
429    #[optional] axis: Option<i32>,
430    #[optional] stream: impl AsRef<Stream>,
431) -> Result<Array> {
432    let axis = axis.unwrap_or(-1);
433    Array::try_from_op(|res| unsafe {
434        mlx_sys::mlx_linalg_cross(
435            res,
436            a.as_ref().as_ptr(),
437            b.as_ref().as_ptr(),
438            axis,
439            stream.as_ref().as_ptr(),
440        )
441    })
442}
443
444/// Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix.
445///
446/// This function supports arrays with at least 2 dimensions. When the input has more than two
447/// dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two
448/// dimensions.
449#[generate_macro(customize(root = "$crate::linalg"))]
450#[default_device]
451pub fn eigh_device(
452    a: impl AsRef<Array>,
453    #[optional] uplo: Option<&str>,
454    #[optional] stream: impl AsRef<Stream>,
455) -> Result<(Array, Array)> {
456    let a = a.as_ref();
457    let uplo =
458        CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{}", e)))?;
459
460    <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
461        mlx_sys::mlx_linalg_eigh(
462            res_0,
463            res_1,
464            a.as_ptr(),
465            uplo.as_ptr(),
466            stream.as_ref().as_ptr(),
467        )
468    })
469}
470
471/// Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
472///
473/// This function supports arrays with at least 2 dimensions. When the input has more than two
474/// dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
475#[generate_macro(customize(root = "$crate::linalg"))]
476#[default_device]
477pub fn eigvalsh_device(
478    a: impl AsRef<Array>,
479    #[optional] uplo: Option<&str>,
480    #[optional] stream: impl AsRef<Stream>,
481) -> Result<Array> {
482    let a = a.as_ref();
483    let uplo =
484        CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{}", e)))?;
485    Array::try_from_op(|res| unsafe {
486        mlx_sys::mlx_linalg_eigvalsh(res, a.as_ptr(), uplo.as_ptr(), stream.as_ref().as_ptr())
487    })
488}
489
490/// Compute the (Moore-Penrose) pseudo-inverse of a matrix.
491#[generate_macro(customize(root = "$crate::linalg"))]
492#[default_device]
493pub fn pinv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
494    Array::try_from_op(|res| unsafe {
495        mlx_sys::mlx_linalg_pinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
496    })
497}
498
499/// Compute the inverse of a triangular square matrix.
500///
501/// This function supports arrays with at least 2 dimensions. When the input has more than two
502/// dimensions, the inverse is computed for each matrix in the last two dimensions of a.
503#[generate_macro(customize(root = "$crate::linalg"))]
504#[default_device]
505pub fn tri_inv_device(
506    a: impl AsRef<Array>,
507    #[optional] upper: Option<bool>,
508    #[optional] stream: impl AsRef<Stream>,
509) -> Result<Array> {
510    let upper = upper.unwrap_or(false);
511    Array::try_from_op(|res| unsafe {
512        mlx_sys::mlx_linalg_tri_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
513    })
514}
515
516/// Compute the LU factorization of the given matrix A.
517///
518/// Note, unlike the default behavior of scipy.linalg.lu, the pivots are
519/// indices. To reconstruct the input use L[P, :] @ U for 2 dimensions or
520/// mx.take_along_axis(L, P[..., None], axis=-2) @ U for more than 2 dimensions.
521///
522/// To construct the full permuation matrix do:
523///
524/// ```rust,ignore
525/// // python
526/// // P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
527/// let p = mlx_rs::ops::put_along_axis(
528///     mlx_rs::ops::zeros_like(&l),
529///     p.index((Ellipsis, NewAxis)),
530///     array!(1.0),
531///     -1,
532/// ).unwrap();
533/// ```
534///
535/// # Params
536///
537/// - `a`: input array
538/// - `stream`: stream to execute the operation
539///
540/// # Returns
541///
542/// The `p`, `L`, and `U` arrays, such that `A = L[P, :] @ U`
543#[generate_macro(customize(root = "$crate::linalg"))]
544#[default_device]
545pub fn lu_device(
546    a: impl AsRef<Array>,
547    #[optional] stream: impl AsRef<Stream>,
548) -> Result<(Array, Array, Array)> {
549    let v = Vec::<Array>::try_from_op(|res| unsafe {
550        mlx_sys::mlx_linalg_lu(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
551    })?;
552    let mut iter = v.into_iter();
553    let p = iter.next().ok_or_else(|| Exception::custom("missing P"))?;
554    let l = iter.next().ok_or_else(|| Exception::custom("missing L"))?;
555    let u = iter.next().ok_or_else(|| Exception::custom("missing U"))?;
556    Ok((p, l, u))
557}
558
559/// Computes a compact representation of the LU factorization.
560///
561/// # Params
562///
563/// - `a`: input array
564/// - `stream`: stream to execute the operation
565///
566/// # Returns
567///
568/// The `LU` matrix and `pivots` array.
569#[generate_macro(customize(root = "$crate::linalg"))]
570#[default_device]
571pub fn lu_factor_device(
572    a: impl AsRef<Array>,
573    #[optional] stream: impl AsRef<Stream>,
574) -> Result<(Array, Array)> {
575    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
576        mlx_sys::mlx_linalg_lu_factor(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
577    })
578}
579
580/// Compute the solution to a system of linear equations `AX = B`
581///
582/// # Params
583///
584/// - `a`: input array
585/// - `b`: input array
586/// - `stream`: stream to execute the operation
587///
588/// # Returns
589///
590/// The unique solution to the system `AX = B`
591#[generate_macro(customize(root = "$crate::linalg"))]
592#[default_device]
593pub fn solve_device(
594    a: impl AsRef<Array>,
595    b: impl AsRef<Array>,
596    #[optional] stream: impl AsRef<Stream>,
597) -> Result<Array> {
598    Array::try_from_op(|res| unsafe {
599        mlx_sys::mlx_linalg_solve(
600            res,
601            a.as_ref().as_ptr(),
602            b.as_ref().as_ptr(),
603            stream.as_ref().as_ptr(),
604        )
605    })
606}
607
608/// Computes the solution of a triangular system of linear equations `AX = B`
609///
610/// # Params
611///
612/// - `a`: input array
613/// - `b`: input array
614/// - `upper`: whether the matrix is upper triangular. Default: `false`
615/// - `stream`: stream to execute the operation
616///
617/// # Returns
618///
619/// The unique solution to the system `AX = B`
620#[generate_macro(customize(root = "$crate::linalg"))]
621#[default_device]
622pub fn solve_triangular_device(
623    a: impl AsRef<Array>,
624    b: impl AsRef<Array>,
625    #[optional] upper: impl Into<Option<bool>>,
626    #[optional] stream: impl AsRef<Stream>,
627) -> Result<Array> {
628    let upper = upper.into().unwrap_or(false);
629
630    Array::try_from_op(|res| unsafe {
631        mlx_sys::mlx_linalg_solve_triangular(
632            res,
633            a.as_ref().as_ptr(),
634            b.as_ref().as_ptr(),
635            upper,
636            stream.as_ref().as_ptr(),
637        )
638    })
639}
640
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::{
        array,
        ops::{eye, indexing::IndexOp, tril, triu},
        StreamOrDevice,
    };

    use super::*;

    // The tests below are adapted from the swift bindings tests and are not exhaustive.
    // Additional tests should be added to cover the error cases.

    #[test]
    fn test_norm_no_axes() {
        let a = Array::from_iter(0..9, &[9]) - 4;
        let b = a.reshape(&[3, 3]).unwrap();

        assert_float_eq!(
            norm_l2(&a, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );
        assert_float_eq!(
            norm_l2(&b, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        assert_float_eq!(
            norm_matrix(&b, "fro", None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, f64::INFINITY, None, None).unwrap().item::<f32>(),
            4.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::INFINITY, None, None).unwrap().item::<f32>(),
            9.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            2.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, 1.0, None, None).unwrap().item::<f32>(),
            20.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, 1.0, None, None).unwrap().item::<f32>(),
            7.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, -1.0, None, None).unwrap().item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, -1.0, None, None).unwrap().item::<f32>(),
            6.0,
            abs <= 0.001
        );
    }

    #[test]
    fn test_norm_axis() {
        let c = Array::from_slice(&[1, 2, 3, -1, 1, 4], &[2, 3]);

        let result = norm_l2(&c, &[0], None).unwrap();
        let expected = Array::from_slice(&[1.41421, 2.23607, 5.0], &[3]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_norm_axes() {
        let m = Array::from_iter(0..8, &[2, 2, 2]);

        let result = norm_l2(&m, &[1, 2][..], None).unwrap();
        let expected = Array::from_slice(&[3.74166, 11.225], &[2]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_qr() {
        let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);

        let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();

        let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
        let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);

        assert!(q
            .all_close(&q_expected, None, None, None)
            .unwrap()
            .item::<bool>());
        assert!(r
            .all_close(&r_expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    // The tests below are adapted from the C++ tests.

    #[test]
    fn test_svd() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D inputs return an error
        let a = Array::from_f32(0.0);
        assert!(svd_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(svd_device(&a, &stream).is_err());

        // Unsupported (integer) dtypes return an error
        let a = Array::from_slice(&[0, 1], &[1, 2]);
        assert!(svd_device(&a, &stream).is_err());

        // TODO: wait for random
    }

    #[test]
    fn test_inv() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D inputs return an error
        let a = Array::from_f32(0.0);
        assert!(inv_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(inv_device(&a, &stream).is_err());

        // Unsupported (integer, and here also non-square) input returns an error
        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(inv_device(&a, &stream).is_err());

        // TODO: wait for random
    }

    #[test]
    fn test_cholesky() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D inputs return an error
        let a = Array::from_f32(0.0);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // Unsupported (integer) dtypes return an error
        let a = Array::from_slice(&[0, 1, 1, 2], &[2, 2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // Non-square input returns an error
        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // TODO: wait for random
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu`
    #[test]
    fn test_lu() {
        let scalar = array!(1.0);
        let result = lu_device(&scalar, StreamOrDevice::cpu());
        assert!(result.is_err());

        // Test 3x3 matrix
        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
        let a_rec = l.index((p, ..)).matmul(u).unwrap();
        assert_array_all_close!(a, a_rec);
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu_factor`
    #[test]
    fn test_lu_factor() {
        crate::random::seed(7).unwrap();

        // Test 5x5 matrix
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[5, 5], None).unwrap();
        let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
        let shape = a.shape();
        let n = shape[shape.len() - 1];

        // Replay the row swaps recorded in `pivots` to obtain the row permutation.
        let pivots: Vec<u32> = pivots.as_slice().to_vec();
        let mut perm: Vec<u32> = (0..n as u32).collect();
        for (i, p) in pivots.iter().enumerate() {
            perm.swap(i, *p as usize);
        }

        // L is the strictly lower part of `lu` plus an implicit unit diagonal.
        let l = tril(&lu, -1)
            .and_then(|l| l.add(eye::<f32>(n, None, None)?))
            .unwrap();
        let u = triu(&lu, None).unwrap();

        let lhs = l.matmul(&u).unwrap();
        let perm = Array::from_slice(&perm, &[n]);
        let rhs = a.index((perm, ..));
        assert_array_all_close!(lhs, rhs);
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_solve`
    #[test]
    fn test_solve() {
        // Test 3x3 matrix with 1D rhs
        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let b = array!([11.0f32, 35.0, 28.0]);

        let result = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
        let expected = array!([1.0f32, 2.0, 3.0]);
        assert_array_all_close!(result, expected);
    }

    #[test]
    fn test_solve_triangular() {
        let a = array!([[4.0f32, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]);
        let b = array!([8.0f32, 14.0, 3.0]);

        let result = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
        let expected = array!([2.0f32, 3.333_333_3, 1.533_333_3]);
        assert_array_all_close!(result, expected);
    }
}