mlx_rs/
linalg.rs

1//! Linear algebra operations.
2
3use crate::error::{Exception, Result};
4use crate::utils::guard::Guarded;
5use crate::utils::{IntoOption, VectorArray};
6use crate::{Array, Stream};
7use mlx_internal_macros::{default_device, generate_macro};
8use smallvec::SmallVec;
9use std::f64;
10use std::ffi::CString;
11
/// Order of a norm.
///
/// An order is either a named order given as a string (for example `"fro"`, which is the
/// [`Default`] value) or a numeric order `p`.
///
/// See [`norm`] for more details.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Ord<'a> {
    /// Named order, e.g. `"fro"`
    Str(&'a str),

    /// Numeric order of the norm
    P(f64),
}
23
24impl Default for Ord<'_> {
25    fn default() -> Self {
26        Ord::Str("fro")
27    }
28}
29
30impl std::fmt::Display for Ord<'_> {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Ord::Str(s) => write!(f, "{s}"),
34            Ord::P(p) => write!(f, "{p}"),
35        }
36    }
37}
38
39impl<'a> From<&'a str> for Ord<'a> {
40    fn from(value: &'a str) -> Self {
41        Ord::Str(value)
42    }
43}
44
45impl From<f64> for Ord<'_> {
46    fn from(value: f64) -> Self {
47        Ord::P(value)
48    }
49}
50
51impl<'a> IntoOption<Ord<'a>> for &'a str {
52    fn into_option(self) -> Option<Ord<'a>> {
53        Some(Ord::Str(self))
54    }
55}
56
57impl<'a> IntoOption<Ord<'a>> for f64 {
58    fn into_option(self) -> Option<Ord<'a>> {
59        Some(Ord::P(self))
60    }
61}
62
63/// Compute p-norm of an [`Array`]
64#[generate_macro(customize(root = "$crate::linalg"))]
65#[default_device]
66pub fn norm_device<'a>(
67    array: impl AsRef<Array>,
68    ord: f64,
69    #[optional] axes: impl IntoOption<&'a [i32]>,
70    #[optional] keep_dims: impl Into<Option<bool>>,
71    #[optional] stream: impl AsRef<Stream>,
72) -> Result<Array> {
73    let keep_dims = keep_dims.into().unwrap_or(false);
74
75    match axes.into_option() {
76        Some(axes) => Array::try_from_op(|res| unsafe {
77            mlx_sys::mlx_linalg_norm(
78                res,
79                array.as_ref().as_ptr(),
80                ord,
81                axes.as_ptr(),
82                axes.len(),
83                keep_dims,
84                stream.as_ref().as_ptr(),
85            )
86        }),
87        None => Array::try_from_op(|res| unsafe {
88            mlx_sys::mlx_linalg_norm(
89                res,
90                array.as_ref().as_ptr(),
91                ord,
92                std::ptr::null(),
93                0,
94                keep_dims,
95                stream.as_ref().as_ptr(),
96            )
97        }),
98    }
99}
100
101/// Matrix or vector norm.
102#[generate_macro(customize(root = "$crate::linalg"))]
103#[default_device]
104pub fn norm_matrix_device<'a>(
105    array: impl AsRef<Array>,
106    ord: &'a str,
107    #[optional] axes: impl IntoOption<&'a [i32]>,
108    #[optional] keep_dims: impl Into<Option<bool>>,
109    #[optional] stream: impl AsRef<Stream>,
110) -> Result<Array> {
111    let ord = CString::new(ord).map_err(|e| Exception::custom(format!("{e}")))?;
112    let keep_dims = keep_dims.into().unwrap_or(false);
113
114    match axes.into_option() {
115        Some(axes) => Array::try_from_op(|res| unsafe {
116            mlx_sys::mlx_linalg_norm_matrix(
117                res,
118                array.as_ref().as_ptr(),
119                ord.as_ptr(),
120                axes.as_ptr(),
121                axes.len(),
122                keep_dims,
123                stream.as_ref().as_ptr(),
124            )
125        }),
126        None => Array::try_from_op(|res| unsafe {
127            mlx_sys::mlx_linalg_norm_matrix(
128                res,
129                array.as_ref().as_ptr(),
130                ord.as_ptr(),
131                std::ptr::null(),
132                0,
133                keep_dims,
134                stream.as_ref().as_ptr(),
135            )
136        }),
137    }
138}
139
140/// Compute the L2 norm of an [`Array`]
141#[generate_macro(customize(root = "$crate::linalg"))]
142#[default_device]
143pub fn norm_l2_device<'a>(
144    array: impl AsRef<Array>,
145    #[optional] axes: impl IntoOption<&'a [i32]>,
146    #[optional] keep_dims: impl Into<Option<bool>>,
147    #[optional] stream: impl AsRef<Stream>,
148) -> Result<Array> {
149    let keep_dims = keep_dims.into().unwrap_or(false);
150
151    match axes.into_option() {
152        Some(axis) => Array::try_from_op(|res| unsafe {
153            mlx_sys::mlx_linalg_norm_l2(
154                res,
155                array.as_ref().as_ptr(),
156                axis.as_ptr(),
157                axis.len(),
158                keep_dims,
159                stream.as_ref().as_ptr(),
160            )
161        }),
162        None => Array::try_from_op(|res| unsafe {
163            mlx_sys::mlx_linalg_norm_l2(
164                res,
165                array.as_ref().as_ptr(),
166                std::ptr::null(),
167                0,
168                keep_dims,
169                stream.as_ref().as_ptr(),
170            )
171        }),
172    }
173}
174
175// TODO: Change the original `norm` function to use builder pattern
176// /// Matrix or vector norm.
177// ///
178// /// For values of `ord < 1`, the result is, strictly speaking, not a
179// /// mathematical norm, but it may still be useful for various numerical
180// /// purposes.
181// ///
182// /// The following norms can be calculated:
183// ///
184// /// ord   | norm for matrices            | norm for vectors
185// /// ----- | ---------------------------- | --------------------------
186// /// None  | Frobenius norm               | 2-norm
187// /// 'fro' | Frobenius norm               | --
188// /// inf   | max(sum(abs(x), axis-1))     | max(abs(x))
189// /// -inf  | min(sum(abs(x), axis-1))     | min(abs(x))
// /// 0     | --                           | sum(x != 0)
191// /// 1     | max(sum(abs(x), axis-0))     | as below
192// /// -1    | min(sum(abs(x), axis-0))     | as below
193// /// 2     | 2-norm (largest sing. value) | as below
194// /// -2    | smallest singular value      | as below
195// /// other | --                           | sum(abs(x)**ord)**(1./ord)
196// ///
197// /// > Nuclear norm and norms based on singular values are not yet implemented.
198// ///
199// /// The Frobenius norm is given by G. H. Golub and C. F. Van Loan, *Matrix Computations*,
200// ///        Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
201// ///
202// /// The nuclear norm is the sum of the singular values.
203// ///
204// /// Both the Frobenius and nuclear norm orders are only defined for
205// /// matrices and produce a fatal error when `array.ndim != 2`
206// ///
207// /// # Params
208// ///
209// /// - `array`: input array
210// /// - `ord`: order of the norm, see table
211// /// - `axes`: axes that hold 2d matrices
212// /// - `keep_dims`: if `true` the axes which are normed over are left in the result as dimensions
213// ///   with size one
214// #[generate_macro(customize(root = "$crate::linalg"))]
215// #[default_device]
216// pub fn norm_device<'a>(
217//     array: impl AsRef<Array>,
218//     #[optional] ord: impl IntoOption<Ord<'a>>,
219//     #[optional] axes: impl IntoOption<&'a [i32]>,
220//     #[optional] keep_dims: impl Into<Option<bool>>,
221//     #[optional] stream: impl AsRef<Stream>,
222// ) -> Result<Array> {
223//     let ord = ord.into_option();
224//     let axes = axes.into_option();
225//     let keep_dims = keep_dims.into().unwrap_or(false);
226
227//     match (ord, axes) {
228//         // If axis and ord are both unspecified, computes the 2-norm of flatten(x).
229//         (None, None) => {
230//             let axes_ptr = std::ptr::null(); // mlx-c already handles the case where axes is null
231//             Array::try_from_op(|res| unsafe {
232//                 mlx_sys::mlx_linalg_norm(
233//                     res,
234//                     array.as_ref().as_ptr(),
235//                     axes_ptr,
236//                     0,
237//                     keep_dims,
238//                     stream.as_ref().as_ptr(),
239//                 )
240//             })
241//         }
242//         // If axis is not provided but ord is, then x must be either 1D or 2D.
243//         //
244//         // Frobenius norm is only supported for matrices
245//         (Some(Ord::Str(ord)), None) => norm_ord_device(array, ord, axes, keep_dims, stream),
246//         (Some(Ord::P(p)), None) => norm_p_device(array, p, axes, keep_dims, stream),
247//         // If axis is provided, but ord is not, then the 2-norm (or Frobenius norm for matrices) is
248//         // computed along the given axes. At most 2 axes can be specified.
249//         (None, Some(axes)) => Array::try_from_op(|res| unsafe {
250//             mlx_sys::mlx_linalg_norm(
251//                 res,
252//                 array.as_ref().as_ptr(),
253//                 axes.as_ptr(),
254//                 axes.len(),
255//                 keep_dims,
256//                 stream.as_ref().as_ptr(),
257//             )
258//         }),
259//         // If both axis and ord are provided, then the corresponding matrix or vector
260//         // norm is computed. At most 2 axes can be specified.
261//         (Some(Ord::Str(ord)), Some(axes)) => norm_ord_device(array, ord, axes, keep_dims, stream),
262//         (Some(Ord::P(p)), Some(axes)) => norm_p_device(array, p, axes, keep_dims, stream),
263//     }
264// }
265
266/// The QR factorization of the input matrix. Returns an error if the input is not valid.
267///
268/// This function supports arrays with at least 2 dimensions. The matrices which are factorized are
269/// assumed to be in the last two dimensions of the input.
270///
271/// Evaluation on the GPU is not yet implemented.
272///
273/// # Params
274///
275/// - `array`: input array
276///
277/// # Example
278///
279/// ```rust
280/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
281///
282/// let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);
283///
284/// let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();
285///
286/// let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
287/// let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);
288///
289/// assert!(q.all_close(&q_expected, None, None, None).unwrap().item::<bool>());
290/// assert!(r.all_close(&r_expected, None, None, None).unwrap().item::<bool>());
291/// ```
292#[generate_macro(customize(root = "$crate::linalg"))]
293#[default_device]
294pub fn qr_device(
295    a: impl AsRef<Array>,
296    #[optional] stream: impl AsRef<Stream>,
297) -> Result<(Array, Array)> {
298    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
299        mlx_sys::mlx_linalg_qr(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
300    })
301}
302
303/// The Singular Value Decomposition (SVD) of the input matrix. Returns an error if the input is not
304/// valid.
305///
306/// This function supports arrays with at least 2 dimensions. When the input has more than two
307/// dimensions, the function iterates over all indices of the first a.ndim - 2 dimensions and for
308/// each combination SVD is applied to the last two indices.
309///
310/// Evaluation on the GPU is not yet implemented.
311///
312/// # Params
313///
314/// - `array`: input array
315///
316/// # Example
317///
318/// ```rust
319/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
320///
321/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
322/// let (u, s, vt) = svd_device(&a, StreamOrDevice::cpu()).unwrap();
323/// let u_expected = Array::from_slice(&[-0.404554, 0.914514, -0.914514, -0.404554], &[2, 2]);
324/// let s_expected = Array::from_slice(&[5.46499, 0.365966], &[2]);
325/// let vt_expected = Array::from_slice(&[-0.576048, -0.817416, -0.817415, 0.576048], &[2, 2]);
326/// assert!(u.all_close(&u_expected, None, None, None).unwrap().item::<bool>());
327/// assert!(s.all_close(&s_expected, None, None, None).unwrap().item::<bool>());
328/// assert!(vt.all_close(&vt_expected, None, None, None).unwrap().item::<bool>());
329/// ```
330#[generate_macro(customize(root = "$crate::linalg"))]
331#[default_device]
332pub fn svd_device(
333    array: impl AsRef<Array>,
334    #[optional] stream: impl AsRef<Stream>,
335) -> Result<(Array, Array, Array)> {
336    let v = VectorArray::try_from_op(|res| unsafe {
337        mlx_sys::mlx_linalg_svd(res, array.as_ref().as_ptr(), true, stream.as_ref().as_ptr())
338    })?;
339
340    let vals: SmallVec<[Array; 3]> = v.try_into_values()?;
341    let mut iter = vals.into_iter();
342    let u = iter.next().unwrap();
343    let s = iter.next().unwrap();
344    let vt = iter.next().unwrap();
345
346    Ok((u, s, vt))
347}
348
349/// Compute the inverse of a square matrix. Returns an error if the input is not valid.
350///
351/// This function supports arrays with at least 2 dimensions. When the input has more than two
352/// dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
353///
354/// Evaluation on the GPU is not yet implemented.
355///
356/// # Params
357///
358/// - `a`: input array
359///
360/// # Example
361///
362/// ```rust
363/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
364///
365/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
366/// let a_inv = inv_device(&a, StreamOrDevice::cpu()).unwrap();
367/// let expected = Array::from_slice(&[-2.0, 1.0, 1.5, -0.5], &[2, 2]);
368/// assert!(a_inv.all_close(&expected, None, None, None).unwrap().item::<bool>());
369/// ```
370#[generate_macro(customize(root = "$crate::linalg"))]
371#[default_device]
372pub fn inv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
373    Array::try_from_op(|res| unsafe {
374        mlx_sys::mlx_linalg_inv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
375    })
376}
377
378/// Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.
379///
380/// This function supports arrays with at least 2 dimensions. When the input has more than two
381/// dimensions, the Cholesky decomposition is computed for each matrix in the last two dimensions of
382/// `a`.
383///
384/// If the input matrix is not symmetric positive semi-definite, behaviour is undefined.
385///
386/// # Params
387///
388/// - `a`: input array
389/// - `upper`: If `true`, return the upper triangular Cholesky factor. If `false`, return the lower
390///   triangular Cholesky factor. Default: `false`.
391#[generate_macro(customize(root = "$crate::linalg"))]
392#[default_device]
393pub fn cholesky_device(
394    a: impl AsRef<Array>,
395    #[optional] upper: Option<bool>,
396    #[optional] stream: impl AsRef<Stream>,
397) -> Result<Array> {
398    let upper = upper.unwrap_or(false);
399    Array::try_from_op(|res| unsafe {
400        mlx_sys::mlx_linalg_cholesky(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
401    })
402}
403
404/// Compute the inverse of a real symmetric positive semi-definite matrix using it’s Cholesky decomposition.
405///
406/// Please see the python documentation for more details.
407#[generate_macro(customize(root = "$crate::linalg"))]
408#[default_device]
409pub fn cholesky_inv_device(
410    a: impl AsRef<Array>,
411    #[optional] upper: Option<bool>,
412    #[optional] stream: impl AsRef<Stream>,
413) -> Result<Array> {
414    let upper = upper.unwrap_or(false);
415    Array::try_from_op(|res| unsafe {
416        mlx_sys::mlx_linalg_cholesky_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
417    })
418}
419
420/// Compute the cross product of two arrays along a specified axis.
421///
422/// The cross product is defined for arrays with size 2 or 3 in the specified axis. If the size is 2
423/// then the third value is assumed to be zero.
424#[generate_macro(customize(root = "$crate::linalg"))]
425#[default_device]
426pub fn cross_device(
427    a: impl AsRef<Array>,
428    b: impl AsRef<Array>,
429    #[optional] axis: Option<i32>,
430    #[optional] stream: impl AsRef<Stream>,
431) -> Result<Array> {
432    let axis = axis.unwrap_or(-1);
433    Array::try_from_op(|res| unsafe {
434        mlx_sys::mlx_linalg_cross(
435            res,
436            a.as_ref().as_ptr(),
437            b.as_ref().as_ptr(),
438            axis,
439            stream.as_ref().as_ptr(),
440        )
441    })
442}
443
444/// Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix.
445///
446/// This function supports arrays with at least 2 dimensions. When the input has more than two
447/// dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two
448/// dimensions.
449#[generate_macro(customize(root = "$crate::linalg"))]
450#[default_device]
451pub fn eigh_device(
452    a: impl AsRef<Array>,
453    #[optional] uplo: Option<&str>,
454    #[optional] stream: impl AsRef<Stream>,
455) -> Result<(Array, Array)> {
456    let a = a.as_ref();
457    let uplo = CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{e}")))?;
458
459    <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
460        mlx_sys::mlx_linalg_eigh(
461            res_0,
462            res_1,
463            a.as_ptr(),
464            uplo.as_ptr(),
465            stream.as_ref().as_ptr(),
466        )
467    })
468}
469
470/// Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
471///
472/// This function supports arrays with at least 2 dimensions. When the input has more than two
473/// dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
474#[generate_macro(customize(root = "$crate::linalg"))]
475#[default_device]
476pub fn eigvalsh_device(
477    a: impl AsRef<Array>,
478    #[optional] uplo: Option<&str>,
479    #[optional] stream: impl AsRef<Stream>,
480) -> Result<Array> {
481    let a = a.as_ref();
482    let uplo = CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{e}")))?;
483    Array::try_from_op(|res| unsafe {
484        mlx_sys::mlx_linalg_eigvalsh(res, a.as_ptr(), uplo.as_ptr(), stream.as_ref().as_ptr())
485    })
486}
487
488/// Compute the eigenvalues and eigenvectors of a square matrix.
489///
490/// This function supports arrays with at least 2 dimensions. When the input has more than two
491/// dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two
492/// dimensions.
493///
494/// Unlike [`eigh`], this function computes eigenvalues for general (not necessarily symmetric
495/// or Hermitian) matrices. The eigenvalues and eigenvectors may be complex.
496///
497/// # Params
498///
499/// - `a`: Input array. Must be a square matrix.
500///
501/// # Returns
502///
503/// A tuple `(eigenvalues, eigenvectors)` where eigenvalues has shape `(..., N)` and
504/// eigenvectors has shape `(..., N, N)`. The eigenvectors are stored as columns.
505///
506/// # Example
507///
508/// ```rust
509/// use mlx_rs::{Array, linalg::*, StreamOrDevice};
510///
511/// let a = Array::from_slice(&[1.0f32, 1.0, 3.0, 4.0], &[2, 2]);
512/// let (eigenvalues, eigenvectors) = eig_device(&a, StreamOrDevice::cpu()).unwrap();
513/// // eigenvalues and eigenvectors are complex even for real input
514/// ```
515#[generate_macro(customize(root = "$crate::linalg"))]
516#[default_device]
517pub fn eig_device(
518    a: impl AsRef<Array>,
519    #[optional] stream: impl AsRef<Stream>,
520) -> Result<(Array, Array)> {
521    <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
522        mlx_sys::mlx_linalg_eig(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
523    })
524}
525
526/// Compute the eigenvalues of a square matrix.
527///
528/// This function supports arrays with at least 2 dimensions. When the input has more than two
529/// dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
530///
531/// Unlike [`eigvalsh`], this function computes eigenvalues for general (not necessarily symmetric
532/// or Hermitian) matrices. The eigenvalues may be complex.
533///
534/// # Params
535///
536/// - `a`: Input array. Must be a square matrix.
537///
538/// # Returns
539///
540/// An array of eigenvalues with shape `(..., N)`.
541///
542/// # Example
543///
544/// ```rust
545/// use mlx_rs::{Array, linalg::*, StreamOrDevice};
546///
547/// let a = Array::from_slice(&[1.0f32, 1.0, 3.0, 4.0], &[2, 2]);
548/// let eigenvalues = eigvals_device(&a, StreamOrDevice::cpu()).unwrap();
549/// ```
550#[generate_macro(customize(root = "$crate::linalg"))]
551#[default_device]
552pub fn eigvals_device(
553    a: impl AsRef<Array>,
554    #[optional] stream: impl AsRef<Stream>,
555) -> Result<Array> {
556    Array::try_from_op(|res| unsafe {
557        mlx_sys::mlx_linalg_eigvals(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
558    })
559}
560
561/// Compute the (Moore-Penrose) pseudo-inverse of a matrix.
562#[generate_macro(customize(root = "$crate::linalg"))]
563#[default_device]
564pub fn pinv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
565    Array::try_from_op(|res| unsafe {
566        mlx_sys::mlx_linalg_pinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
567    })
568}
569
570/// Compute the inverse of a triangular square matrix.
571///
572/// This function supports arrays with at least 2 dimensions. When the input has more than two
573/// dimensions, the inverse is computed for each matrix in the last two dimensions of a.
574#[generate_macro(customize(root = "$crate::linalg"))]
575#[default_device]
576pub fn tri_inv_device(
577    a: impl AsRef<Array>,
578    #[optional] upper: Option<bool>,
579    #[optional] stream: impl AsRef<Stream>,
580) -> Result<Array> {
581    let upper = upper.unwrap_or(false);
582    Array::try_from_op(|res| unsafe {
583        mlx_sys::mlx_linalg_tri_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
584    })
585}
586
587/// Compute the LU factorization of the given matrix A.
588///
589/// Note, unlike the default behavior of scipy.linalg.lu, the pivots are
590/// indices. To reconstruct the input use L[P, :] @ U for 2 dimensions or
591/// mx.take_along_axis(L, P[..., None], axis=-2) @ U for more than 2 dimensions.
592///
593/// To construct the full permuation matrix do:
594///
595/// ```rust,ignore
596/// // python
597/// // P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
598/// let p = mlx_rs::ops::put_along_axis(
599///     mlx_rs::ops::zeros_like(&l),
600///     p.index((Ellipsis, NewAxis)),
601///     array!(1.0),
602///     -1,
603/// ).unwrap();
604/// ```
605///
606/// # Params
607///
608/// - `a`: input array
609/// - `stream`: stream to execute the operation
610///
611/// # Returns
612///
613/// The `p`, `L`, and `U` arrays, such that `A = L[P, :] @ U`
614#[generate_macro(customize(root = "$crate::linalg"))]
615#[default_device]
616pub fn lu_device(
617    a: impl AsRef<Array>,
618    #[optional] stream: impl AsRef<Stream>,
619) -> Result<(Array, Array, Array)> {
620    let v = Vec::<Array>::try_from_op(|res| unsafe {
621        mlx_sys::mlx_linalg_lu(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
622    })?;
623    let mut iter = v.into_iter();
624    let p = iter.next().ok_or_else(|| Exception::custom("missing P"))?;
625    let l = iter.next().ok_or_else(|| Exception::custom("missing L"))?;
626    let u = iter.next().ok_or_else(|| Exception::custom("missing U"))?;
627    Ok((p, l, u))
628}
629
630/// Computes a compact representation of the LU factorization.
631///
632/// # Params
633///
634/// - `a`: input array
635/// - `stream`: stream to execute the operation
636///
637/// # Returns
638///
639/// The `LU` matrix and `pivots` array.
640#[generate_macro(customize(root = "$crate::linalg"))]
641#[default_device]
642pub fn lu_factor_device(
643    a: impl AsRef<Array>,
644    #[optional] stream: impl AsRef<Stream>,
645) -> Result<(Array, Array)> {
646    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
647        mlx_sys::mlx_linalg_lu_factor(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
648    })
649}
650
651/// Compute the solution to a system of linear equations `AX = B`
652///
653/// # Params
654///
655/// - `a`: input array
656/// - `b`: input array
657/// - `stream`: stream to execute the operation
658///
659/// # Returns
660///
661/// The unique solution to the system `AX = B`
662#[generate_macro(customize(root = "$crate::linalg"))]
663#[default_device]
664pub fn solve_device(
665    a: impl AsRef<Array>,
666    b: impl AsRef<Array>,
667    #[optional] stream: impl AsRef<Stream>,
668) -> Result<Array> {
669    Array::try_from_op(|res| unsafe {
670        mlx_sys::mlx_linalg_solve(
671            res,
672            a.as_ref().as_ptr(),
673            b.as_ref().as_ptr(),
674            stream.as_ref().as_ptr(),
675        )
676    })
677}
678
679/// Computes the solution of a triangular system of linear equations `AX = B`
680///
681/// # Params
682///
683/// - `a`: input array
684/// - `b`: input array
685/// - `upper`: whether the matrix is upper triangular. Default: `false`
686/// - `stream`: stream to execute the operation
687///
688/// # Returns
689///
690/// The unique solution to the system `AX = B`
691#[generate_macro(customize(root = "$crate::linalg"))]
692#[default_device]
693pub fn solve_triangular_device(
694    a: impl AsRef<Array>,
695    b: impl AsRef<Array>,
696    #[optional] upper: impl Into<Option<bool>>,
697    #[optional] stream: impl AsRef<Stream>,
698) -> Result<Array> {
699    let upper = upper.into().unwrap_or(false);
700
701    Array::try_from_op(|res| unsafe {
702        mlx_sys::mlx_linalg_solve_triangular(
703            res,
704            a.as_ref().as_ptr(),
705            b.as_ref().as_ptr(),
706            upper,
707            stream.as_ref().as_ptr(),
708        )
709    })
710}
711
712#[cfg(test)]
713mod tests {
714    use float_eq::assert_float_eq;
715
716    use crate::{
717        array,
718        ops::{eye, indexing::IndexOp, tril, triu},
719        StreamOrDevice,
720    };
721
722    use super::*;
723
724    // The tests below are adapted from the swift bindings tests
725    // and they are not exhaustive. Additional tests should be added
726    // to cover the error cases
727
728    #[test]
729    fn test_norm_no_axes() {
730        let a = Array::from_iter(0..9, &[9]) - 4;
731        let b = a.reshape(&[3, 3]).unwrap();
732
733        assert_float_eq!(
734            norm_l2(&a, None, None).unwrap().item::<f32>(),
735            7.74597,
736            abs <= 0.001
737        );
738        assert_float_eq!(
739            norm_l2(&b, None, None).unwrap().item::<f32>(),
740            7.74597,
741            abs <= 0.001
742        );
743
744        assert_float_eq!(
745            norm_matrix(&b, "fro", None, None).unwrap().item::<f32>(),
746            7.74597,
747            abs <= 0.001
748        );
749
750        assert_float_eq!(
751            norm(&a, f64::INFINITY, None, None).unwrap().item::<f32>(),
752            4.0,
753            abs <= 0.001
754        );
755        assert_float_eq!(
756            norm(&b, f64::INFINITY, None, None).unwrap().item::<f32>(),
757            9.0,
758            abs <= 0.001
759        );
760
761        assert_float_eq!(
762            norm(&a, f64::NEG_INFINITY, None, None)
763                .unwrap()
764                .item::<f32>(),
765            0.0,
766            abs <= 0.001
767        );
768        assert_float_eq!(
769            norm(&b, f64::NEG_INFINITY, None, None)
770                .unwrap()
771                .item::<f32>(),
772            2.0,
773            abs <= 0.001
774        );
775
776        assert_float_eq!(
777            norm(&a, 1.0, None, None).unwrap().item::<f32>(),
778            20.0,
779            abs <= 0.001
780        );
781        assert_float_eq!(
782            norm(&b, 1.0, None, None).unwrap().item::<f32>(),
783            7.0,
784            abs <= 0.001
785        );
786
787        assert_float_eq!(
788            norm(&a, -1.0, None, None).unwrap().item::<f32>(),
789            0.0,
790            abs <= 0.001
791        );
792        assert_float_eq!(
793            norm(&b, -1.0, None, None).unwrap().item::<f32>(),
794            6.0,
795            abs <= 0.001
796        );
797    }
798
799    #[test]
800    fn test_norm_axis() {
801        let c = Array::from_slice(&[1, 2, 3, -1, 1, 4], &[2, 3]);
802
803        let result = norm_l2(&c, &[0], None).unwrap();
804        let expected = Array::from_slice(&[1.41421, 2.23607, 5.0], &[3]);
805        assert!(result
806            .all_close(&expected, None, None, None)
807            .unwrap()
808            .item::<bool>());
809    }
810
811    #[test]
812    fn test_norm_axes() {
813        let m = Array::from_iter(0..8, &[2, 2, 2]);
814
815        let result = norm_l2(&m, &[1, 2][..], None).unwrap();
816        let expected = Array::from_slice(&[3.74166, 11.225], &[2]);
817        assert!(result
818            .all_close(&expected, None, None, None)
819            .unwrap()
820            .item::<bool>());
821    }
822
823    #[test]
824    fn test_qr() {
825        let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);
826
827        let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();
828
829        let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
830        let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);
831
832        assert!(q
833            .all_close(&q_expected, None, None, None)
834            .unwrap()
835            .item::<bool>());
836        assert!(r
837            .all_close(&r_expected, None, None, None)
838            .unwrap()
839            .item::<bool>());
840    }
841
842    // The tests below are adapted from the c++ tests
843
844    #[test]
845    fn test_svd() {
846        // eval_gpu is not implemented yet.
847        let stream = StreamOrDevice::cpu();
848
849        // 0D and 1D returns error
850        let a = Array::from_f32(0.0);
851        assert!(svd_device(&a, &stream).is_err());
852
853        let a = Array::from_slice(&[0.0, 1.0], &[2]);
854        assert!(svd_device(&a, &stream).is_err());
855
856        // Unsupported types returns error
857        let a = Array::from_slice(&[0, 1], &[1, 2]);
858        assert!(svd_device(&a, &stream).is_err());
859
860        // TODO: wait for random
861    }
862
863    #[test]
864    fn test_inv() {
865        // eval_gpu is not implemented yet.
866        let stream = StreamOrDevice::cpu();
867
868        // 0D and 1D returns error
869        let a = Array::from_f32(0.0);
870        assert!(inv_device(&a, &stream).is_err());
871
872        let a = Array::from_slice(&[0.0, 1.0], &[2]);
873        assert!(inv_device(&a, &stream).is_err());
874
875        // Unsupported types returns error
876        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
877        assert!(inv_device(&a, &stream).is_err());
878
879        // TODO: wait for random
880    }
881
882    #[test]
883    fn test_cholesky() {
884        // eval_gpu is not implemented yet.
885        let stream = StreamOrDevice::cpu();
886
887        // 0D and 1D returns error
888        let a = Array::from_f32(0.0);
889        assert!(cholesky_device(&a, None, &stream).is_err());
890
891        let a = Array::from_slice(&[0.0, 1.0], &[2]);
892        assert!(cholesky_device(&a, None, &stream).is_err());
893
894        // Unsupported types returns error
895        let a = Array::from_slice(&[0, 1, 1, 2], &[2, 2]);
896        assert!(cholesky_device(&a, None, &stream).is_err());
897
898        // Non-square returns error
899        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
900        assert!(cholesky_device(&a, None, &stream).is_err());
901
902        // TODO: wait for random
903    }
904
905    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu`
906    #[test]
907    fn test_lu() {
908        let scalar = array!(1.0);
909        let result = lu_device(&scalar, StreamOrDevice::cpu());
910        assert!(result.is_err());
911
912        // # Test 3x3 matrix
913        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
914        let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
915        let a_rec = l.index((p, ..)).matmul(u).unwrap();
916        assert_array_all_close!(a, a_rec);
917    }
918
919    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu_factor`
920    #[test]
921    fn test_lu_factor() {
922        crate::random::seed(7).unwrap();
923
924        // Test 3x3 matrix
925        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[5, 5], None).unwrap();
926        let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
927        let shape = a.shape();
928        let n = shape[shape.len() - 1];
929
930        let pivots: Vec<u32> = pivots.as_slice().to_vec();
931        let mut perm: Vec<u32> = (0..n as u32).collect();
932        for (i, p) in pivots.iter().enumerate() {
933            perm.swap(i, *p as usize);
934        }
935
936        let l = tril(&lu, -1)
937            .and_then(|l| l.add(eye::<f32>(n, None, None)?))
938            .unwrap();
939        let u = triu(&lu, None).unwrap();
940
941        let lhs = l.matmul(&u).unwrap();
942        let perm = Array::from_slice(&perm, &[n]);
943        let rhs = a.index((perm, ..));
944        assert_array_all_close!(lhs, rhs);
945    }
946
947    // The unit test below is adapted from the python unit test `test_linalg.py/test_solve`
948    #[test]
949    fn test_solve() {
950        crate::random::seed(7).unwrap();
951
952        // Test 3x3 matrix with 1D rhs
953        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
954        let b = array!([11.0f32, 35.0, 28.0]);
955
956        let result = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
957        let expected = array!([1.0f32, 2.0, 3.0]);
958        assert_array_all_close!(result, expected);
959    }
960
961    #[test]
962    fn test_solve_triangular() {
963        let a = array!([[4.0f32, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]);
964        let b = array!([8.0f32, 14.0, 3.0]);
965
966        let result = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
967        let expected = array!([2.0f32, 3.333_333_3, 1.533_333_3]);
968        assert_array_all_close!(result, expected);
969    }
970
971    // The tests below are adapted from the python unit test `test_linalg.py/test_eig`
972    #[test]
973    fn test_eig() {
974        use crate::ops::expand_dims;
975
976        // Helper to check eigenvalues and eigenvectors
977        fn check_eigs_and_vecs(a: &Array) {
978            let (eig_vals, eig_vecs) = eig_device(a, StreamOrDevice::cpu()).unwrap();
979
980            // Check A @ eig_vecs == eig_vals * eig_vecs
981            let lhs = a.matmul(&eig_vecs).unwrap();
982            // eig_vals[..., None, :] * eig_vecs - broadcast eigenvalues
983            // For a 1D eigenvalues array (n,), we need shape (1, n) to broadcast with eigenvectors (n, n)
984            // For batched eigenvalues (..., n), we need shape (..., 1, n)
985            let eig_vals_broadcast = expand_dims(&eig_vals, -2).unwrap();
986            let rhs = eig_vals_broadcast.multiply(&eig_vecs).unwrap();
987            assert!(
988                lhs.all_close(&rhs, 1e-4, 1e-4, None)
989                    .unwrap()
990                    .item::<bool>(),
991                "A @ eig_vecs should equal eig_vals * eig_vecs"
992            );
993
994            // Check eigvals returns same values
995            let eig_vals_only = eigvals_device(a, StreamOrDevice::cpu()).unwrap();
996            assert!(
997                eig_vals
998                    .all_close(&eig_vals_only, 1e-4, 1e-4, None)
999                    .unwrap()
1000                    .item::<bool>(),
1001                "eigvals should return same eigenvalues as eig"
1002            );
1003        }
1004
1005        // Test a simple 2x2 matrix
1006        let a = array!([[1.0f32, 1.0], [3.0, 4.0]]);
1007        check_eigs_and_vecs(&a);
1008
1009        // Test complex eigenvalues (rotation-like matrix)
1010        let a = array!([[1.0f32, -1.0], [1.0, 1.0]]);
1011        check_eigs_and_vecs(&a);
1012
1013        // Test a larger random matrix
1014        crate::random::seed(1).unwrap();
1015        let a = crate::random::normal::<f32>(&[5, 5], None, None, None).unwrap();
1016        check_eigs_and_vecs(&a);
1017
1018        // Test with batched input
1019        let a = crate::random::normal::<f32>(&[3, 5, 5], None, None, None).unwrap();
1020        check_eigs_and_vecs(&a);
1021    }
1022
1023    #[test]
1024    fn test_eig_errors() {
1025        // 1D array should fail
1026        let a = array!([1.0f32, 2.0]);
1027        assert!(eig_device(&a, StreamOrDevice::cpu()).is_err());
1028        assert!(eigvals_device(&a, StreamOrDevice::cpu()).is_err());
1029
1030        // Non-square matrix should fail
1031        let a = array!([[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
1032        assert!(eig_device(&a, StreamOrDevice::cpu()).is_err());
1033        assert!(eigvals_device(&a, StreamOrDevice::cpu()).is_err());
1034    }
1035}