mlx_rs/
linalg.rs

1//! Linear algebra operations.
2
3use crate::error::{Exception, Result};
4use crate::utils::guard::Guarded;
5use crate::utils::{IntoOption, VectorArray};
6use crate::{Array, Stream};
7use mlx_internal_macros::{default_device, generate_macro};
8use smallvec::SmallVec;
9use std::f64;
10use std::ffi::CString;
11
/// Order of a norm, given either by name or by number.
///
/// Numeric orders are what [`norm`] accepts and named orders (such as `"fro"`) are what
/// [`norm_matrix`] accepts. A unified norm API taking this enum directly is still a TODO in this
/// module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Ord<'a> {
    /// Named order, e.g. `"fro"` for the Frobenius norm.
    Str(&'a str),

    /// Numeric order `p`; may be `f64::INFINITY` or `f64::NEG_INFINITY`.
    P(f64),
}

impl Default for Ord<'_> {
    /// Defaults to the named order `"fro"` (Frobenius norm).
    fn default() -> Self {
        Ord::Str("fro")
    }
}

impl std::fmt::Display for Ord<'_> {
    /// Writes the name for [`Ord::Str`], or the number using `f64`'s `Display` for [`Ord::P`]
    /// (so infinities render as `inf` / `-inf`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Ord::Str(s) => write!(f, "{s}"),
            Ord::P(p) => write!(f, "{p}"),
        }
    }
}

impl<'a> From<&'a str> for Ord<'a> {
    /// Wraps a string slice as a named order.
    fn from(value: &'a str) -> Self {
        Ord::Str(value)
    }
}

impl From<f64> for Ord<'_> {
    /// Wraps a number as a numeric order.
    fn from(value: f64) -> Self {
        Ord::P(value)
    }
}
50
51impl<'a> IntoOption<Ord<'a>> for &'a str {
52    fn into_option(self) -> Option<Ord<'a>> {
53        Some(Ord::Str(self))
54    }
55}
56
57impl<'a> IntoOption<Ord<'a>> for f64 {
58    fn into_option(self) -> Option<Ord<'a>> {
59        Some(Ord::P(self))
60    }
61}
62
63/// Compute p-norm of an [`Array`]
64#[generate_macro(customize(root = "$crate::linalg"))]
65#[default_device]
66pub fn norm_device<'a>(
67    array: impl AsRef<Array>,
68    ord: f64,
69    #[optional] axes: impl IntoOption<&'a [i32]>,
70    #[optional] keep_dims: impl Into<Option<bool>>,
71    #[optional] stream: impl AsRef<Stream>,
72) -> Result<Array> {
73    let keep_dims = keep_dims.into().unwrap_or(false);
74
75    match axes.into_option() {
76        Some(axes) => Array::try_from_op(|res| unsafe {
77            mlx_sys::mlx_linalg_norm(
78                res,
79                array.as_ref().as_ptr(),
80                ord,
81                axes.as_ptr(),
82                axes.len(),
83                keep_dims,
84                stream.as_ref().as_ptr(),
85            )
86        }),
87        None => Array::try_from_op(|res| unsafe {
88            mlx_sys::mlx_linalg_norm(
89                res,
90                array.as_ref().as_ptr(),
91                ord,
92                std::ptr::null(),
93                0,
94                keep_dims,
95                stream.as_ref().as_ptr(),
96            )
97        }),
98    }
99}
100
101/// Matrix or vector norm.
102#[generate_macro(customize(root = "$crate::linalg"))]
103#[default_device]
104pub fn norm_matrix_device<'a>(
105    array: impl AsRef<Array>,
106    ord: &'a str,
107    #[optional] axes: impl IntoOption<&'a [i32]>,
108    #[optional] keep_dims: impl Into<Option<bool>>,
109    #[optional] stream: impl AsRef<Stream>,
110) -> Result<Array> {
111    let ord = CString::new(ord).map_err(|e| Exception::custom(format!("{e}")))?;
112    let keep_dims = keep_dims.into().unwrap_or(false);
113
114    match axes.into_option() {
115        Some(axes) => Array::try_from_op(|res| unsafe {
116            mlx_sys::mlx_linalg_norm_matrix(
117                res,
118                array.as_ref().as_ptr(),
119                ord.as_ptr(),
120                axes.as_ptr(),
121                axes.len(),
122                keep_dims,
123                stream.as_ref().as_ptr(),
124            )
125        }),
126        None => Array::try_from_op(|res| unsafe {
127            mlx_sys::mlx_linalg_norm_matrix(
128                res,
129                array.as_ref().as_ptr(),
130                ord.as_ptr(),
131                std::ptr::null(),
132                0,
133                keep_dims,
134                stream.as_ref().as_ptr(),
135            )
136        }),
137    }
138}
139
140/// Compute the L2 norm of an [`Array`]
141#[generate_macro(customize(root = "$crate::linalg"))]
142#[default_device]
143pub fn norm_l2_device<'a>(
144    array: impl AsRef<Array>,
145    #[optional] axes: impl IntoOption<&'a [i32]>,
146    #[optional] keep_dims: impl Into<Option<bool>>,
147    #[optional] stream: impl AsRef<Stream>,
148) -> Result<Array> {
149    let keep_dims = keep_dims.into().unwrap_or(false);
150
151    match axes.into_option() {
152        Some(axis) => Array::try_from_op(|res| unsafe {
153            mlx_sys::mlx_linalg_norm_l2(
154                res,
155                array.as_ref().as_ptr(),
156                axis.as_ptr(),
157                axis.len(),
158                keep_dims,
159                stream.as_ref().as_ptr(),
160            )
161        }),
162        None => Array::try_from_op(|res| unsafe {
163            mlx_sys::mlx_linalg_norm_l2(
164                res,
165                array.as_ref().as_ptr(),
166                std::ptr::null(),
167                0,
168                keep_dims,
169                stream.as_ref().as_ptr(),
170            )
171        }),
172    }
173}
174
175// TODO: Change the original `norm` function to use builder pattern
176// /// Matrix or vector norm.
177// ///
178// /// For values of `ord < 1`, the result is, strictly speaking, not a
179// /// mathematical norm, but it may still be useful for various numerical
180// /// purposes.
181// ///
182// /// The following norms can be calculated:
183// ///
184// /// ord   | norm for matrices            | norm for vectors
185// /// ----- | ---------------------------- | --------------------------
186// /// None  | Frobenius norm               | 2-norm
187// /// 'fro' | Frobenius norm               | --
// /// inf   | max(sum(abs(x), axis=1))     | max(abs(x))
// /// -inf  | min(sum(abs(x), axis=1))     | min(abs(x))
// /// 0     | --                           | sum(x != 0)
// /// 1     | max(sum(abs(x), axis=0))     | as below
// /// -1    | min(sum(abs(x), axis=0))     | as below
193// /// 2     | 2-norm (largest sing. value) | as below
194// /// -2    | smallest singular value      | as below
195// /// other | --                           | sum(abs(x)**ord)**(1./ord)
196// ///
197// /// > Nuclear norm and norms based on singular values are not yet implemented.
198// ///
199// /// The Frobenius norm is given by G. H. Golub and C. F. Van Loan, *Matrix Computations*,
200// ///        Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
201// ///
202// /// The nuclear norm is the sum of the singular values.
203// ///
204// /// Both the Frobenius and nuclear norm orders are only defined for
205// /// matrices and produce a fatal error when `array.ndim != 2`
206// ///
207// /// # Params
208// ///
209// /// - `array`: input array
210// /// - `ord`: order of the norm, see table
211// /// - `axes`: axes that hold 2d matrices
212// /// - `keep_dims`: if `true` the axes which are normed over are left in the result as dimensions
213// ///   with size one
214// #[generate_macro(customize(root = "$crate::linalg"))]
215// #[default_device]
216// pub fn norm_device<'a>(
217//     array: impl AsRef<Array>,
218//     #[optional] ord: impl IntoOption<Ord<'a>>,
219//     #[optional] axes: impl IntoOption<&'a [i32]>,
220//     #[optional] keep_dims: impl Into<Option<bool>>,
221//     #[optional] stream: impl AsRef<Stream>,
222// ) -> Result<Array> {
223//     let ord = ord.into_option();
224//     let axes = axes.into_option();
225//     let keep_dims = keep_dims.into().unwrap_or(false);
226
227//     match (ord, axes) {
228//         // If axis and ord are both unspecified, computes the 2-norm of flatten(x).
229//         (None, None) => {
230//             let axes_ptr = std::ptr::null(); // mlx-c already handles the case where axes is null
231//             Array::try_from_op(|res| unsafe {
232//                 mlx_sys::mlx_linalg_norm(
233//                     res,
234//                     array.as_ref().as_ptr(),
235//                     axes_ptr,
236//                     0,
237//                     keep_dims,
238//                     stream.as_ref().as_ptr(),
239//                 )
240//             })
241//         }
242//         // If axis is not provided but ord is, then x must be either 1D or 2D.
243//         //
244//         // Frobenius norm is only supported for matrices
245//         (Some(Ord::Str(ord)), None) => norm_ord_device(array, ord, axes, keep_dims, stream),
246//         (Some(Ord::P(p)), None) => norm_p_device(array, p, axes, keep_dims, stream),
247//         // If axis is provided, but ord is not, then the 2-norm (or Frobenius norm for matrices) is
248//         // computed along the given axes. At most 2 axes can be specified.
249//         (None, Some(axes)) => Array::try_from_op(|res| unsafe {
250//             mlx_sys::mlx_linalg_norm(
251//                 res,
252//                 array.as_ref().as_ptr(),
253//                 axes.as_ptr(),
254//                 axes.len(),
255//                 keep_dims,
256//                 stream.as_ref().as_ptr(),
257//             )
258//         }),
259//         // If both axis and ord are provided, then the corresponding matrix or vector
260//         // norm is computed. At most 2 axes can be specified.
261//         (Some(Ord::Str(ord)), Some(axes)) => norm_ord_device(array, ord, axes, keep_dims, stream),
262//         (Some(Ord::P(p)), Some(axes)) => norm_p_device(array, p, axes, keep_dims, stream),
263//     }
264// }
265
266/// The QR factorization of the input matrix. Returns an error if the input is not valid.
267///
268/// This function supports arrays with at least 2 dimensions. The matrices which are factorized are
269/// assumed to be in the last two dimensions of the input.
270///
271/// Evaluation on the GPU is not yet implemented.
272///
273/// # Params
274///
275/// - `array`: input array
276///
277/// # Example
278///
279/// ```rust
280/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
281///
282/// let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);
283///
284/// let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();
285///
286/// let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
287/// let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);
288///
289/// assert!(q.all_close(&q_expected, None, None, None).unwrap().item::<bool>());
290/// assert!(r.all_close(&r_expected, None, None, None).unwrap().item::<bool>());
291/// ```
292#[generate_macro(customize(root = "$crate::linalg"))]
293#[default_device]
294pub fn qr_device(
295    a: impl AsRef<Array>,
296    #[optional] stream: impl AsRef<Stream>,
297) -> Result<(Array, Array)> {
298    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
299        mlx_sys::mlx_linalg_qr(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
300    })
301}
302
303/// The Singular Value Decomposition (SVD) of the input matrix. Returns an error if the input is not
304/// valid.
305///
306/// This function supports arrays with at least 2 dimensions. When the input has more than two
307/// dimensions, the function iterates over all indices of the first a.ndim - 2 dimensions and for
308/// each combination SVD is applied to the last two indices.
309///
310/// Evaluation on the GPU is not yet implemented.
311///
312/// # Params
313///
314/// - `array`: input array
315///
316/// # Example
317///
318/// ```rust
319/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
320///
321/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
322/// let (u, s, vt) = svd_device(&a, StreamOrDevice::cpu()).unwrap();
323/// let u_expected = Array::from_slice(&[-0.404554, 0.914514, -0.914514, -0.404554], &[2, 2]);
324/// let s_expected = Array::from_slice(&[5.46499, 0.365966], &[2]);
325/// let vt_expected = Array::from_slice(&[-0.576048, -0.817416, -0.817415, 0.576048], &[2, 2]);
326/// assert!(u.all_close(&u_expected, None, None, None).unwrap().item::<bool>());
327/// assert!(s.all_close(&s_expected, None, None, None).unwrap().item::<bool>());
328/// assert!(vt.all_close(&vt_expected, None, None, None).unwrap().item::<bool>());
329/// ```
330#[generate_macro(customize(root = "$crate::linalg"))]
331#[default_device]
332pub fn svd_device(
333    array: impl AsRef<Array>,
334    #[optional] stream: impl AsRef<Stream>,
335) -> Result<(Array, Array, Array)> {
336    let v = VectorArray::try_from_op(|res| unsafe {
337        mlx_sys::mlx_linalg_svd(res, array.as_ref().as_ptr(), true, stream.as_ref().as_ptr())
338    })?;
339
340    let vals: SmallVec<[Array; 3]> = v.try_into_values()?;
341    let mut iter = vals.into_iter();
342    let u = iter.next().unwrap();
343    let s = iter.next().unwrap();
344    let vt = iter.next().unwrap();
345
346    Ok((u, s, vt))
347}
348
349/// Compute the inverse of a square matrix. Returns an error if the input is not valid.
350///
351/// This function supports arrays with at least 2 dimensions. When the input has more than two
352/// dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
353///
354/// Evaluation on the GPU is not yet implemented.
355///
356/// # Params
357///
358/// - `a`: input array
359///
360/// # Example
361///
362/// ```rust
363/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
364///
365/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
366/// let a_inv = inv_device(&a, StreamOrDevice::cpu()).unwrap();
367/// let expected = Array::from_slice(&[-2.0, 1.0, 1.5, -0.5], &[2, 2]);
368/// assert!(a_inv.all_close(&expected, None, None, None).unwrap().item::<bool>());
369/// ```
370#[generate_macro(customize(root = "$crate::linalg"))]
371#[default_device]
372pub fn inv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
373    Array::try_from_op(|res| unsafe {
374        mlx_sys::mlx_linalg_inv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
375    })
376}
377
378/// Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.
379///
380/// This function supports arrays with at least 2 dimensions. When the input has more than two
381/// dimensions, the Cholesky decomposition is computed for each matrix in the last two dimensions of
382/// `a`.
383///
384/// If the input matrix is not symmetric positive semi-definite, behaviour is undefined.
385///
386/// # Params
387///
388/// - `a`: input array
389/// - `upper`: If `true`, return the upper triangular Cholesky factor. If `false`, return the lower
390///   triangular Cholesky factor. Default: `false`.
391#[generate_macro(customize(root = "$crate::linalg"))]
392#[default_device]
393pub fn cholesky_device(
394    a: impl AsRef<Array>,
395    #[optional] upper: Option<bool>,
396    #[optional] stream: impl AsRef<Stream>,
397) -> Result<Array> {
398    let upper = upper.unwrap_or(false);
399    Array::try_from_op(|res| unsafe {
400        mlx_sys::mlx_linalg_cholesky(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
401    })
402}
403
404/// Compute the inverse of a real symmetric positive semi-definite matrix using it’s Cholesky decomposition.
405///
406/// Please see the python documentation for more details.
407#[generate_macro(customize(root = "$crate::linalg"))]
408#[default_device]
409pub fn cholesky_inv_device(
410    a: impl AsRef<Array>,
411    #[optional] upper: Option<bool>,
412    #[optional] stream: impl AsRef<Stream>,
413) -> Result<Array> {
414    let upper = upper.unwrap_or(false);
415    Array::try_from_op(|res| unsafe {
416        mlx_sys::mlx_linalg_cholesky_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
417    })
418}
419
420/// Compute the cross product of two arrays along a specified axis.
421///
422/// The cross product is defined for arrays with size 2 or 3 in the specified axis. If the size is 2
423/// then the third value is assumed to be zero.
424#[generate_macro(customize(root = "$crate::linalg"))]
425#[default_device]
426pub fn cross_device(
427    a: impl AsRef<Array>,
428    b: impl AsRef<Array>,
429    #[optional] axis: Option<i32>,
430    #[optional] stream: impl AsRef<Stream>,
431) -> Result<Array> {
432    let axis = axis.unwrap_or(-1);
433    Array::try_from_op(|res| unsafe {
434        mlx_sys::mlx_linalg_cross(
435            res,
436            a.as_ref().as_ptr(),
437            b.as_ref().as_ptr(),
438            axis,
439            stream.as_ref().as_ptr(),
440        )
441    })
442}
443
444/// Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix.
445///
446/// This function supports arrays with at least 2 dimensions. When the input has more than two
447/// dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two
448/// dimensions.
449#[generate_macro(customize(root = "$crate::linalg"))]
450#[default_device]
451pub fn eigh_device(
452    a: impl AsRef<Array>,
453    #[optional] uplo: Option<&str>,
454    #[optional] stream: impl AsRef<Stream>,
455) -> Result<(Array, Array)> {
456    let a = a.as_ref();
457    let uplo = CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{e}")))?;
458
459    <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
460        mlx_sys::mlx_linalg_eigh(
461            res_0,
462            res_1,
463            a.as_ptr(),
464            uplo.as_ptr(),
465            stream.as_ref().as_ptr(),
466        )
467    })
468}
469
470/// Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
471///
472/// This function supports arrays with at least 2 dimensions. When the input has more than two
473/// dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
474#[generate_macro(customize(root = "$crate::linalg"))]
475#[default_device]
476pub fn eigvalsh_device(
477    a: impl AsRef<Array>,
478    #[optional] uplo: Option<&str>,
479    #[optional] stream: impl AsRef<Stream>,
480) -> Result<Array> {
481    let a = a.as_ref();
482    let uplo = CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{e}")))?;
483    Array::try_from_op(|res| unsafe {
484        mlx_sys::mlx_linalg_eigvalsh(res, a.as_ptr(), uplo.as_ptr(), stream.as_ref().as_ptr())
485    })
486}
487
488/// Compute the (Moore-Penrose) pseudo-inverse of a matrix.
489#[generate_macro(customize(root = "$crate::linalg"))]
490#[default_device]
491pub fn pinv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
492    Array::try_from_op(|res| unsafe {
493        mlx_sys::mlx_linalg_pinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
494    })
495}
496
497/// Compute the inverse of a triangular square matrix.
498///
499/// This function supports arrays with at least 2 dimensions. When the input has more than two
500/// dimensions, the inverse is computed for each matrix in the last two dimensions of a.
501#[generate_macro(customize(root = "$crate::linalg"))]
502#[default_device]
503pub fn tri_inv_device(
504    a: impl AsRef<Array>,
505    #[optional] upper: Option<bool>,
506    #[optional] stream: impl AsRef<Stream>,
507) -> Result<Array> {
508    let upper = upper.unwrap_or(false);
509    Array::try_from_op(|res| unsafe {
510        mlx_sys::mlx_linalg_tri_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
511    })
512}
513
514/// Compute the LU factorization of the given matrix A.
515///
516/// Note, unlike the default behavior of scipy.linalg.lu, the pivots are
517/// indices. To reconstruct the input use L[P, :] @ U for 2 dimensions or
518/// mx.take_along_axis(L, P[..., None], axis=-2) @ U for more than 2 dimensions.
519///
520/// To construct the full permuation matrix do:
521///
522/// ```rust,ignore
523/// // python
524/// // P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
525/// let p = mlx_rs::ops::put_along_axis(
526///     mlx_rs::ops::zeros_like(&l),
527///     p.index((Ellipsis, NewAxis)),
528///     array!(1.0),
529///     -1,
530/// ).unwrap();
531/// ```
532///
533/// # Params
534///
535/// - `a`: input array
536/// - `stream`: stream to execute the operation
537///
538/// # Returns
539///
540/// The `p`, `L`, and `U` arrays, such that `A = L[P, :] @ U`
541#[generate_macro(customize(root = "$crate::linalg"))]
542#[default_device]
543pub fn lu_device(
544    a: impl AsRef<Array>,
545    #[optional] stream: impl AsRef<Stream>,
546) -> Result<(Array, Array, Array)> {
547    let v = Vec::<Array>::try_from_op(|res| unsafe {
548        mlx_sys::mlx_linalg_lu(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
549    })?;
550    let mut iter = v.into_iter();
551    let p = iter.next().ok_or_else(|| Exception::custom("missing P"))?;
552    let l = iter.next().ok_or_else(|| Exception::custom("missing L"))?;
553    let u = iter.next().ok_or_else(|| Exception::custom("missing U"))?;
554    Ok((p, l, u))
555}
556
557/// Computes a compact representation of the LU factorization.
558///
559/// # Params
560///
561/// - `a`: input array
562/// - `stream`: stream to execute the operation
563///
564/// # Returns
565///
566/// The `LU` matrix and `pivots` array.
567#[generate_macro(customize(root = "$crate::linalg"))]
568#[default_device]
569pub fn lu_factor_device(
570    a: impl AsRef<Array>,
571    #[optional] stream: impl AsRef<Stream>,
572) -> Result<(Array, Array)> {
573    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
574        mlx_sys::mlx_linalg_lu_factor(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
575    })
576}
577
578/// Compute the solution to a system of linear equations `AX = B`
579///
580/// # Params
581///
582/// - `a`: input array
583/// - `b`: input array
584/// - `stream`: stream to execute the operation
585///
586/// # Returns
587///
588/// The unique solution to the system `AX = B`
589#[generate_macro(customize(root = "$crate::linalg"))]
590#[default_device]
591pub fn solve_device(
592    a: impl AsRef<Array>,
593    b: impl AsRef<Array>,
594    #[optional] stream: impl AsRef<Stream>,
595) -> Result<Array> {
596    Array::try_from_op(|res| unsafe {
597        mlx_sys::mlx_linalg_solve(
598            res,
599            a.as_ref().as_ptr(),
600            b.as_ref().as_ptr(),
601            stream.as_ref().as_ptr(),
602        )
603    })
604}
605
606/// Computes the solution of a triangular system of linear equations `AX = B`
607///
608/// # Params
609///
610/// - `a`: input array
611/// - `b`: input array
612/// - `upper`: whether the matrix is upper triangular. Default: `false`
613/// - `stream`: stream to execute the operation
614///
615/// # Returns
616///
617/// The unique solution to the system `AX = B`
618#[generate_macro(customize(root = "$crate::linalg"))]
619#[default_device]
620pub fn solve_triangular_device(
621    a: impl AsRef<Array>,
622    b: impl AsRef<Array>,
623    #[optional] upper: impl Into<Option<bool>>,
624    #[optional] stream: impl AsRef<Stream>,
625) -> Result<Array> {
626    let upper = upper.into().unwrap_or(false);
627
628    Array::try_from_op(|res| unsafe {
629        mlx_sys::mlx_linalg_solve_triangular(
630            res,
631            a.as_ref().as_ptr(),
632            b.as_ref().as_ptr(),
633            upper,
634            stream.as_ref().as_ptr(),
635        )
636    })
637}
638
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::{
        array,
        ops::{eye, indexing::IndexOp, tril, triu},
        StreamOrDevice,
    };

    use super::*;

    // The tests below are adapted from the swift bindings tests
    // and they are not exhaustive. Additional tests should be added
    // to cover the error cases

    #[test]
    fn test_norm_no_axes() {
        // `a` is the vector [-4, -3, ..., 4]; `b` holds the same data as a 3x3 matrix.
        let a = Array::from_iter(0..9, &[9]) - 4;
        let b = a.reshape(&[3, 3]).unwrap();

        assert_float_eq!(
            norm_l2(&a, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );
        assert_float_eq!(
            norm_l2(&b, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        assert_float_eq!(
            norm_matrix(&b, "fro", None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, f64::INFINITY, None, None).unwrap().item::<f32>(),
            4.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::INFINITY, None, None).unwrap().item::<f32>(),
            9.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            2.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, 1.0, None, None).unwrap().item::<f32>(),
            20.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, 1.0, None, None).unwrap().item::<f32>(),
            7.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, -1.0, None, None).unwrap().item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, -1.0, None, None).unwrap().item::<f32>(),
            6.0,
            abs <= 0.001
        );
    }

    #[test]
    fn test_norm_axis() {
        let c = Array::from_slice(&[1, 2, 3, -1, 1, 4], &[2, 3]);

        let result = norm_l2(&c, &[0], None).unwrap();
        let expected = Array::from_slice(&[1.41421, 2.23607, 5.0], &[3]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_norm_axes() {
        let m = Array::from_iter(0..8, &[2, 2, 2]);

        let result = norm_l2(&m, &[1, 2][..], None).unwrap();
        let expected = Array::from_slice(&[3.74166, 11.225], &[2]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_qr() {
        let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);

        let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();

        let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
        let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);

        assert!(q
            .all_close(&q_expected, None, None, None)
            .unwrap()
            .item::<bool>());
        assert!(r
            .all_close(&r_expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    // The tests below are adapted from the c++ tests

    #[test]
    fn test_svd() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D returns error
        let a = Array::from_f32(0.0);
        assert!(svd_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(svd_device(&a, &stream).is_err());

        // Unsupported types returns error
        let a = Array::from_slice(&[0, 1], &[1, 2]);
        assert!(svd_device(&a, &stream).is_err());

        // TODO: wait for random
    }

    #[test]
    fn test_inv() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D returns error
        let a = Array::from_f32(0.0);
        assert!(inv_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(inv_device(&a, &stream).is_err());

        // Unsupported types returns error
        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(inv_device(&a, &stream).is_err());

        // TODO: wait for random
    }

    #[test]
    fn test_cholesky() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D returns error
        let a = Array::from_f32(0.0);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // Unsupported types returns error
        let a = Array::from_slice(&[0, 1, 1, 2], &[2, 2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // Non-square returns error
        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // TODO: wait for random
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu`
    #[test]
    fn test_lu() {
        let scalar = array!(1.0);
        let result = lu_device(&scalar, StreamOrDevice::cpu());
        assert!(result.is_err());

        // Test 3x3 matrix
        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
        // Reconstruct the input as L[P, :] @ U.
        let a_rec = l.index((p, ..)).matmul(u).unwrap();
        assert_array_all_close!(a, a_rec);
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu_factor`
    #[test]
    fn test_lu_factor() {
        crate::random::seed(7).unwrap();

        // Test 5x5 matrix
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[5, 5], None).unwrap();
        let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
        let shape = a.shape();
        let n = shape[shape.len() - 1];

        // Replay the pivots as successive row swaps to build the row permutation of `a`.
        let pivots: Vec<u32> = pivots.as_slice().to_vec();
        let mut perm: Vec<u32> = (0..n as u32).collect();
        for (i, p) in pivots.iter().enumerate() {
            perm.swap(i, *p as usize);
        }

        // L is the strictly-lower part of `lu` plus the identity; U is the upper part.
        let l = tril(&lu, -1)
            .and_then(|l| l.add(eye::<f32>(n, None, None)?))
            .unwrap();
        let u = triu(&lu, None).unwrap();

        let lhs = l.matmul(&u).unwrap();
        let perm = Array::from_slice(&perm, &[n]);
        let rhs = a.index((perm, ..));
        assert_array_all_close!(lhs, rhs);
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_solve`
    #[test]
    fn test_solve() {
        // Test 3x3 matrix with 1D rhs
        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let b = array!([11.0f32, 35.0, 28.0]);

        let result = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
        let expected = array!([1.0f32, 2.0, 3.0]);
        assert_array_all_close!(result, expected);
    }

    #[test]
    fn test_solve_triangular() {
        let a = array!([[4.0f32, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]);
        let b = array!([8.0f32, 14.0, 3.0]);

        let result = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
        let expected = array!([2.0f32, 3.333_333_3, 1.533_333_3]);
        assert_array_all_close!(result, expected);
    }
}