mlx_rs/linalg.rs
//! Linear algebra operations.

use crate::error::{Exception, Result};
use crate::utils::guard::Guarded;
use crate::utils::{IntoOption, VectorArray};
use crate::{Array, Stream};
use mlx_internal_macros::{default_device, generate_macro};
use smallvec::SmallVec;
use std::f64;
use std::ffi::CString;

/// Order of the norm
///
/// See [`norm`] for more details.
#[derive(Debug, Clone, Copy)]
pub enum Ord<'a> {
    /// String representation of the order
    Str(&'a str),

    /// Order of the norm
    P(f64),
}

impl Default for Ord<'_> {
    fn default() -> Self {
        Ord::Str("fro")
    }
}

impl std::fmt::Display for Ord<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Ord::Str(s) => write!(f, "{s}"),
            Ord::P(p) => write!(f, "{p}"),
        }
    }
}

impl<'a> From<&'a str> for Ord<'a> {
    fn from(value: &'a str) -> Self {
        Ord::Str(value)
    }
}

impl From<f64> for Ord<'_> {
    fn from(value: f64) -> Self {
        Ord::P(value)
    }
}

impl<'a> IntoOption<Ord<'a>> for &'a str {
    fn into_option(self) -> Option<Ord<'a>> {
        Some(Ord::Str(self))
    }
}

impl<'a> IntoOption<Ord<'a>> for f64 {
    fn into_option(self) -> Option<Ord<'a>> {
        Some(Ord::P(self))
    }
}

/// Compute the p-norm of an [`Array`]
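///
/// # Example
///
/// A minimal sketch using the CPU stream; the expected value is worked out by
/// hand for this small input (the 1-norm of `[1, 2, 3]` is `6`).
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0], &[3]);
/// let n = norm_device(&a, 1.0, None, None, StreamOrDevice::cpu()).unwrap();
/// assert!((n.item::<f32>() - 6.0).abs() < 1e-5);
/// ```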
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn norm_device<'a>(
    array: impl AsRef<Array>,
    ord: f64,
    #[optional] axes: impl IntoOption<&'a [i32]>,
    #[optional] keep_dims: impl Into<Option<bool>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let keep_dims = keep_dims.into().unwrap_or(false);

    match axes.into_option() {
        Some(axes) => Array::try_from_op(|res| unsafe {
            mlx_sys::mlx_linalg_norm(
                res,
                array.as_ref().as_ptr(),
                ord,
                axes.as_ptr(),
                axes.len(),
                keep_dims,
                stream.as_ref().as_ptr(),
            )
        }),
        None => Array::try_from_op(|res| unsafe {
            mlx_sys::mlx_linalg_norm(
                res,
                array.as_ref().as_ptr(),
                ord,
                std::ptr::null(),
                0,
                keep_dims,
                stream.as_ref().as_ptr(),
            )
        }),
    }
}

/// Matrix or vector norm where the order is given as a string (e.g. `"fro"`).
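///
/// # Example
///
/// A minimal sketch using the CPU stream; the Frobenius norm of this matrix is
/// `5` (worked out by hand), so a loose floating-point comparison is used.
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[3.0f32, 4.0, 0.0, 0.0], &[2, 2]);
/// let n = norm_matrix_device(&a, "fro", None, None, StreamOrDevice::cpu()).unwrap();
/// assert!((n.item::<f32>() - 5.0).abs() < 1e-5);
/// ```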
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn norm_matrix_device<'a>(
    array: impl AsRef<Array>,
    ord: &'a str,
    #[optional] axes: impl IntoOption<&'a [i32]>,
    #[optional] keep_dims: impl Into<Option<bool>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let ord = CString::new(ord).map_err(|e| Exception::custom(format!("{e}")))?;
    let keep_dims = keep_dims.into().unwrap_or(false);

    match axes.into_option() {
        Some(axes) => Array::try_from_op(|res| unsafe {
            mlx_sys::mlx_linalg_norm_matrix(
                res,
                array.as_ref().as_ptr(),
                ord.as_ptr(),
                axes.as_ptr(),
                axes.len(),
                keep_dims,
                stream.as_ref().as_ptr(),
            )
        }),
        None => Array::try_from_op(|res| unsafe {
            mlx_sys::mlx_linalg_norm_matrix(
                res,
                array.as_ref().as_ptr(),
                ord.as_ptr(),
                std::ptr::null(),
                0,
                keep_dims,
                stream.as_ref().as_ptr(),
            )
        }),
    }
}

/// Compute the L2 norm of an [`Array`]
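///
/// # Example
///
/// A minimal sketch using the CPU stream; the L2 norm of `[3, 4]` is `5`
/// (worked out by hand).
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[3.0f32, 4.0], &[2]);
/// let n = norm_l2_device(&a, None, None, StreamOrDevice::cpu()).unwrap();
/// assert!((n.item::<f32>() - 5.0).abs() < 1e-5);
/// ```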
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn norm_l2_device<'a>(
    array: impl AsRef<Array>,
    #[optional] axes: impl IntoOption<&'a [i32]>,
    #[optional] keep_dims: impl Into<Option<bool>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let keep_dims = keep_dims.into().unwrap_or(false);

    match axes.into_option() {
        Some(axis) => Array::try_from_op(|res| unsafe {
            mlx_sys::mlx_linalg_norm_l2(
                res,
                array.as_ref().as_ptr(),
                axis.as_ptr(),
                axis.len(),
                keep_dims,
                stream.as_ref().as_ptr(),
            )
        }),
        None => Array::try_from_op(|res| unsafe {
            mlx_sys::mlx_linalg_norm_l2(
                res,
                array.as_ref().as_ptr(),
                std::ptr::null(),
                0,
                keep_dims,
                stream.as_ref().as_ptr(),
            )
        }),
    }
}

// TODO: Change the original `norm` function to use builder pattern
// /// Matrix or vector norm.
// ///
// /// For values of `ord < 1`, the result is, strictly speaking, not a
// /// mathematical norm, but it may still be useful for various numerical
// /// purposes.
// ///
// /// The following norms can be calculated:
// ///
// /// ord   | norm for matrices            | norm for vectors
// /// ----- | ---------------------------- | --------------------------
// /// None  | Frobenius norm               | 2-norm
// /// 'fro' | Frobenius norm               | --
// /// inf   | max(sum(abs(x), axis=1))     | max(abs(x))
// /// -inf  | min(sum(abs(x), axis=1))     | min(abs(x))
// /// 0     | --                           | sum(x != 0)
// /// 1     | max(sum(abs(x), axis=0))     | as below
// /// -1    | min(sum(abs(x), axis=0))     | as below
// /// 2     | 2-norm (largest sing. value) | as below
// /// -2    | smallest singular value      | as below
// /// other | --                           | sum(abs(x)**ord)**(1./ord)
// ///
// /// > Nuclear norm and norms based on singular values are not yet implemented.
// ///
// /// The Frobenius norm is given by G. H. Golub and C. F. Van Loan, *Matrix Computations*,
// /// Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
// ///
// /// The nuclear norm is the sum of the singular values.
// ///
// /// Both the Frobenius and nuclear norm orders are only defined for
// /// matrices and produce a fatal error when `array.ndim != 2`
// ///
// /// # Params
// ///
// /// - `array`: input array
// /// - `ord`: order of the norm, see table
// /// - `axes`: axes that hold 2d matrices
// /// - `keep_dims`: if `true` the axes which are normed over are left in the result as dimensions
// ///   with size one
// #[generate_macro(customize(root = "$crate::linalg"))]
// #[default_device]
// pub fn norm_device<'a>(
//     array: impl AsRef<Array>,
//     #[optional] ord: impl IntoOption<Ord<'a>>,
//     #[optional] axes: impl IntoOption<&'a [i32]>,
//     #[optional] keep_dims: impl Into<Option<bool>>,
//     #[optional] stream: impl AsRef<Stream>,
// ) -> Result<Array> {
//     let ord = ord.into_option();
//     let axes = axes.into_option();
//     let keep_dims = keep_dims.into().unwrap_or(false);

//     match (ord, axes) {
//         // If axis and ord are both unspecified, computes the 2-norm of flatten(x).
//         (None, None) => {
//             let axes_ptr = std::ptr::null(); // mlx-c already handles the case where axes is null
//             Array::try_from_op(|res| unsafe {
//                 mlx_sys::mlx_linalg_norm(
//                     res,
//                     array.as_ref().as_ptr(),
//                     axes_ptr,
//                     0,
//                     keep_dims,
//                     stream.as_ref().as_ptr(),
//                 )
//             })
//         }
//         // If axis is not provided but ord is, then x must be either 1D or 2D.
//         //
//         // Frobenius norm is only supported for matrices
//         (Some(Ord::Str(ord)), None) => norm_ord_device(array, ord, axes, keep_dims, stream),
//         (Some(Ord::P(p)), None) => norm_p_device(array, p, axes, keep_dims, stream),
//         // If axis is provided, but ord is not, then the 2-norm (or Frobenius norm for matrices) is
//         // computed along the given axes. At most 2 axes can be specified.
//         (None, Some(axes)) => Array::try_from_op(|res| unsafe {
//             mlx_sys::mlx_linalg_norm(
//                 res,
//                 array.as_ref().as_ptr(),
//                 axes.as_ptr(),
//                 axes.len(),
//                 keep_dims,
//                 stream.as_ref().as_ptr(),
//             )
//         }),
//         // If both axis and ord are provided, then the corresponding matrix or vector
//         // norm is computed. At most 2 axes can be specified.
//         (Some(Ord::Str(ord)), Some(axes)) => norm_ord_device(array, ord, axes, keep_dims, stream),
//         (Some(Ord::P(p)), Some(axes)) => norm_p_device(array, p, axes, keep_dims, stream),
//     }
// }

/// The QR factorization of the input matrix. Returns an error if the input is not valid.
///
/// This function supports arrays with at least 2 dimensions. The matrices which are factorized are
/// assumed to be in the last two dimensions of the input.
///
/// Evaluation on the GPU is not yet implemented.
///
/// # Params
///
/// - `array`: input array
///
/// # Example
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);
///
/// let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();
///
/// let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
/// let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);
///
/// assert!(q.all_close(&q_expected, None, None, None).unwrap().item::<bool>());
/// assert!(r.all_close(&r_expected, None, None, None).unwrap().item::<bool>());
/// ```
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn qr_device(
    a: impl AsRef<Array>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<(Array, Array)> {
    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
        mlx_sys::mlx_linalg_qr(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
    })
}

/// The Singular Value Decomposition (SVD) of the input matrix. Returns an error if the input is not
/// valid.
///
/// This function supports arrays with at least 2 dimensions. When the input has more than two
/// dimensions, the function iterates over all indices of the first a.ndim - 2 dimensions and for
/// each combination SVD is applied to the last two indices.
///
/// Evaluation on the GPU is not yet implemented.
///
/// # Params
///
/// - `array`: input array
///
/// # Example
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
/// let (u, s, vt) = svd_device(&a, StreamOrDevice::cpu()).unwrap();
/// let u_expected = Array::from_slice(&[-0.404554, 0.914514, -0.914514, -0.404554], &[2, 2]);
/// let s_expected = Array::from_slice(&[5.46499, 0.365966], &[2]);
/// let vt_expected = Array::from_slice(&[-0.576048, -0.817416, -0.817415, 0.576048], &[2, 2]);
/// assert!(u.all_close(&u_expected, None, None, None).unwrap().item::<bool>());
/// assert!(s.all_close(&s_expected, None, None, None).unwrap().item::<bool>());
/// assert!(vt.all_close(&vt_expected, None, None, None).unwrap().item::<bool>());
/// ```
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn svd_device(
    array: impl AsRef<Array>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<(Array, Array, Array)> {
    let v = VectorArray::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_svd(res, array.as_ref().as_ptr(), true, stream.as_ref().as_ptr())
    })?;

    let vals: SmallVec<[Array; 3]> = v.try_into_values()?;
    let mut iter = vals.into_iter();
    let u = iter.next().unwrap();
    let s = iter.next().unwrap();
    let vt = iter.next().unwrap();

    Ok((u, s, vt))
}

/// Compute the inverse of a square matrix. Returns an error if the input is not valid.
///
/// This function supports arrays with at least 2 dimensions. When the input has more than two
/// dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
///
/// Evaluation on the GPU is not yet implemented.
///
/// # Params
///
/// - `a`: input array
///
/// # Example
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
/// let a_inv = inv_device(&a, StreamOrDevice::cpu()).unwrap();
/// let expected = Array::from_slice(&[-2.0, 1.0, 1.5, -0.5], &[2, 2]);
/// assert!(a_inv.all_close(&expected, None, None, None).unwrap().item::<bool>());
/// ```
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn inv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_inv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
    })
}

/// Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.
///
/// This function supports arrays with at least 2 dimensions. When the input has more than two
/// dimensions, the Cholesky decomposition is computed for each matrix in the last two dimensions of
/// `a`.
///
/// If the input matrix is not symmetric positive semi-definite, behaviour is undefined.
///
/// # Params
///
/// - `a`: input array
/// - `upper`: If `true`, return the upper triangular Cholesky factor. If `false`, return the lower
///   triangular Cholesky factor. Default: `false`.
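///
/// # Example
///
/// A minimal sketch on the CPU stream; for this diagonal matrix the lower
/// Cholesky factor can be checked by hand.
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[4.0f32, 0.0, 0.0, 9.0], &[2, 2]);
/// let l = cholesky_device(&a, None, StreamOrDevice::cpu()).unwrap();
/// let expected = Array::from_slice(&[2.0f32, 0.0, 0.0, 3.0], &[2, 2]);
/// assert!(l.all_close(&expected, None, None, None).unwrap().item::<bool>());
/// ```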
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn cholesky_device(
    a: impl AsRef<Array>,
    #[optional] upper: Option<bool>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let upper = upper.unwrap_or(false);
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_cholesky(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
    })
}

/// Compute the inverse of a real symmetric positive semi-definite matrix using its Cholesky decomposition.
///
/// Please see the Python documentation for more details.
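///
/// # Example
///
/// A minimal sketch on the CPU stream, assuming (as in the Python API) that the input is the
/// Cholesky factor `L` of `A = L L^T`; the product of `A` with the result should then be close
/// to the identity.
///
/// ```rust
/// use mlx_rs::{ops::eye, Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[4.0f32, 0.0, 0.0, 9.0], &[2, 2]);
/// let l = cholesky_device(&a, None, StreamOrDevice::cpu()).unwrap();
/// let a_inv = cholesky_inv_device(&l, None, StreamOrDevice::cpu()).unwrap();
/// let identity = eye::<f32>(2, None, None).unwrap();
/// let product = a.matmul(&a_inv).unwrap();
/// assert!(product.all_close(&identity, None, None, None).unwrap().item::<bool>());
/// ```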
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn cholesky_inv_device(
    a: impl AsRef<Array>,
    #[optional] upper: Option<bool>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let upper = upper.unwrap_or(false);
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_cholesky_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
    })
}

/// Compute the cross product of two arrays along a specified axis.
///
/// The cross product is defined for arrays with size 2 or 3 in the specified axis. If the size is 2
/// then the third value is assumed to be zero.
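///
/// # Example
///
/// A minimal sketch on the CPU stream; the cross product of the unit vectors
/// `x` and `y` is `z`.
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let x = Array::from_slice(&[1.0f32, 0.0, 0.0], &[3]);
/// let y = Array::from_slice(&[0.0f32, 1.0, 0.0], &[3]);
/// let z = cross_device(&x, &y, None, StreamOrDevice::cpu()).unwrap();
/// let expected = Array::from_slice(&[0.0f32, 0.0, 1.0], &[3]);
/// assert!(z.all_close(&expected, None, None, None).unwrap().item::<bool>());
/// ```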
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn cross_device(
    a: impl AsRef<Array>,
    b: impl AsRef<Array>,
    #[optional] axis: Option<i32>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let axis = axis.unwrap_or(-1);
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_cross(
            res,
            a.as_ref().as_ptr(),
            b.as_ref().as_ptr(),
            axis,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix.
///
/// This function supports arrays with at least 2 dimensions. When the input has more than two
/// dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two
/// dimensions.
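///
/// # Example
///
/// A minimal sketch on the CPU stream; the ascending eigenvalue order mentioned in the comment is
/// an assumption (LAPACK-style), so the snippet is marked `ignore` rather than run as a doctest.
///
/// ```rust,ignore
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[2.0f32, 1.0, 1.0, 2.0], &[2, 2]);
/// let (values, vectors) = eigh_device(&a, None, StreamOrDevice::cpu()).unwrap();
/// // For this matrix the eigenvalues are 1 and 3; `vectors` holds the eigenvectors as columns.
/// ```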
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn eigh_device(
    a: impl AsRef<Array>,
    #[optional] uplo: Option<&str>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<(Array, Array)> {
    let a = a.as_ref();
    let uplo = CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{e}")))?;

    <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
        mlx_sys::mlx_linalg_eigh(
            res_0,
            res_1,
            a.as_ptr(),
            uplo.as_ptr(),
            stream.as_ref().as_ptr(),
        )
    })
}

/// Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
///
/// This function supports arrays with at least 2 dimensions. When the input has more than two
/// dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
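///
/// # Example
///
/// A minimal sketch on the CPU stream; the ascending eigenvalue order mentioned in the comment is
/// an assumption (LAPACK-style), so the snippet is marked `ignore` rather than run as a doctest.
///
/// ```rust,ignore
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[2.0f32, 1.0, 1.0, 2.0], &[2, 2]);
/// let values = eigvalsh_device(&a, None, StreamOrDevice::cpu()).unwrap();
/// // For this matrix the eigenvalues are 1 and 3.
/// ```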
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn eigvalsh_device(
    a: impl AsRef<Array>,
    #[optional] uplo: Option<&str>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let a = a.as_ref();
    let uplo = CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{e}")))?;
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_eigvalsh(res, a.as_ptr(), uplo.as_ptr(), stream.as_ref().as_ptr())
    })
}

/// Compute the eigenvalues and eigenvectors of a square matrix.
///
/// This function supports arrays with at least 2 dimensions. When the input has more than two
/// dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two
/// dimensions.
///
/// Unlike [`eigh`], this function computes eigenvalues for general (not necessarily symmetric
/// or Hermitian) matrices. The eigenvalues and eigenvectors may be complex.
///
/// # Params
///
/// - `a`: Input array. Must be a square matrix.
///
/// # Returns
///
/// A tuple `(eigenvalues, eigenvectors)` where eigenvalues has shape `(..., N)` and
/// eigenvectors has shape `(..., N, N)`. The eigenvectors are stored as columns.
///
/// # Example
///
/// ```rust
/// use mlx_rs::{Array, linalg::*, StreamOrDevice};
///
/// let a = Array::from_slice(&[1.0f32, 1.0, 3.0, 4.0], &[2, 2]);
/// let (eigenvalues, eigenvectors) = eig_device(&a, StreamOrDevice::cpu()).unwrap();
/// // eigenvalues and eigenvectors are complex even for real input
/// ```
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn eig_device(
    a: impl AsRef<Array>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<(Array, Array)> {
    <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
        mlx_sys::mlx_linalg_eig(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
    })
}

/// Compute the eigenvalues of a square matrix.
///
/// This function supports arrays with at least 2 dimensions. When the input has more than two
/// dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
///
/// Unlike [`eigvalsh`], this function computes eigenvalues for general (not necessarily symmetric
/// or Hermitian) matrices. The eigenvalues may be complex.
///
/// # Params
///
/// - `a`: Input array. Must be a square matrix.
///
/// # Returns
///
/// An array of eigenvalues with shape `(..., N)`.
///
/// # Example
///
/// ```rust
/// use mlx_rs::{Array, linalg::*, StreamOrDevice};
///
/// let a = Array::from_slice(&[1.0f32, 1.0, 3.0, 4.0], &[2, 2]);
/// let eigenvalues = eigvals_device(&a, StreamOrDevice::cpu()).unwrap();
/// ```
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn eigvals_device(
    a: impl AsRef<Array>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_eigvals(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
    })
}

/// Compute the (Moore-Penrose) pseudo-inverse of a matrix.
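///
/// # Example
///
/// A minimal sketch on the CPU stream; for an invertible matrix the pseudo-inverse coincides with
/// the inverse, which is easy to check by hand for this diagonal input. Loose explicit tolerances
/// are used since the result comes from an SVD.
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[1.0f32, 0.0, 0.0, 2.0], &[2, 2]);
/// let a_pinv = pinv_device(&a, StreamOrDevice::cpu()).unwrap();
/// let expected = Array::from_slice(&[1.0f32, 0.0, 0.0, 0.5], &[2, 2]);
/// assert!(a_pinv.all_close(&expected, 1e-5, 1e-5, None).unwrap().item::<bool>());
/// ```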
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn pinv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_pinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
    })
}

/// Compute the inverse of a triangular square matrix.
///
/// This function supports arrays with at least 2 dimensions. When the input has more than two
/// dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
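///
/// # Example
///
/// A minimal sketch on the CPU stream; the inverse of this lower-triangular matrix can be checked
/// by hand.
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[1.0f32, 0.0, 2.0, 1.0], &[2, 2]);
/// let a_inv = tri_inv_device(&a, None, StreamOrDevice::cpu()).unwrap();
/// let expected = Array::from_slice(&[1.0f32, 0.0, -2.0, 1.0], &[2, 2]);
/// assert!(a_inv.all_close(&expected, None, None, None).unwrap().item::<bool>());
/// ```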
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn tri_inv_device(
    a: impl AsRef<Array>,
    #[optional] upper: Option<bool>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let upper = upper.unwrap_or(false);
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_tri_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
    })
}

/// Compute the LU factorization of the given matrix `A`.
///
/// Note that, unlike the default behavior of `scipy.linalg.lu`, the pivots are
/// indices. To reconstruct the input use `L[P, :] @ U` for 2 dimensions or
/// `mx.take_along_axis(L, P[..., None], axis=-2) @ U` for more than 2 dimensions.
///
/// To construct the full permutation matrix do:
///
/// ```rust,ignore
/// // python
/// // P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
/// let p = mlx_rs::ops::put_along_axis(
///     mlx_rs::ops::zeros_like(&l),
///     p.index((Ellipsis, NewAxis)),
///     array!(1.0),
///     -1,
/// ).unwrap();
/// ```
///
/// # Params
///
/// - `a`: input array
/// - `stream`: stream to execute the operation
///
/// # Returns
///
/// The `p`, `L`, and `U` arrays, such that `A = L[P, :] @ U`
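///
/// # Example
///
/// A minimal sketch on the CPU stream, mirroring the reconstruction used in the unit tests;
/// it indexes `L` with the pivot array `p` before multiplying by `U`.
///
/// ```rust,ignore
/// use mlx_rs::{array, ops::indexing::IndexOp, StreamOrDevice, linalg::*};
///
/// let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
/// let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
/// let a_rec = l.index((p, ..)).matmul(&u).unwrap(); // should reconstruct `a`
/// ```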
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn lu_device(
    a: impl AsRef<Array>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<(Array, Array, Array)> {
    let v = Vec::<Array>::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_lu(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
    })?;
    let mut iter = v.into_iter();
    let p = iter.next().ok_or_else(|| Exception::custom("missing P"))?;
    let l = iter.next().ok_or_else(|| Exception::custom("missing L"))?;
    let u = iter.next().ok_or_else(|| Exception::custom("missing U"))?;
    Ok((p, l, u))
}

/// Computes a compact representation of the LU factorization.
///
/// # Params
///
/// - `a`: input array
/// - `stream`: stream to execute the operation
///
/// # Returns
///
/// The `LU` matrix and `pivots` array.
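///
/// # Example
///
/// A minimal sketch on the CPU stream; as in the unit tests, `lu` is assumed to pack the strict
/// lower triangle of `L` together with `U`, and `pivots` holds the row swaps, so the snippet is
/// marked `ignore` rather than run as a doctest.
///
/// ```rust,ignore
/// use mlx_rs::{array, StreamOrDevice, linalg::*};
///
/// let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
/// let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
/// // See `test_lu_factor` below for how to rebuild `L` and `U` with `tril`/`triu`.
/// ```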
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn lu_factor_device(
    a: impl AsRef<Array>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<(Array, Array)> {
    <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
        mlx_sys::mlx_linalg_lu_factor(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
    })
}

/// Compute the solution to a system of linear equations `AX = B`
///
/// # Params
///
/// - `a`: input array
/// - `b`: input array
/// - `stream`: stream to execute the operation
///
/// # Returns
///
/// The unique solution to the system `AX = B`
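///
/// # Example
///
/// A minimal sketch on the CPU stream; the solution to this small system is `[2, 3]`, which can
/// be verified by substitution.
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[3.0f32, 1.0, 1.0, 2.0], &[2, 2]);
/// let b = Array::from_slice(&[9.0f32, 8.0], &[2]);
/// let x = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
/// let expected = Array::from_slice(&[2.0f32, 3.0], &[2]);
/// assert!(x.all_close(&expected, None, None, None).unwrap().item::<bool>());
/// ```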
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn solve_device(
    a: impl AsRef<Array>,
    b: impl AsRef<Array>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_solve(
            res,
            a.as_ref().as_ptr(),
            b.as_ref().as_ptr(),
            stream.as_ref().as_ptr(),
        )
    })
}

/// Computes the solution of a triangular system of linear equations `AX = B`
///
/// # Params
///
/// - `a`: input array
/// - `b`: input array
/// - `upper`: whether the matrix is upper triangular. Default: `false`
/// - `stream`: stream to execute the operation
///
/// # Returns
///
/// The unique solution to the system `AX = B`
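///
/// # Example
///
/// A minimal sketch on the CPU stream; for this lower-triangular system the solution `[1, 2]` can
/// be verified by forward substitution.
///
/// ```rust
/// use mlx_rs::{Array, StreamOrDevice, linalg::*};
///
/// let a = Array::from_slice(&[2.0f32, 0.0, 1.0, 3.0], &[2, 2]);
/// let b = Array::from_slice(&[2.0f32, 7.0], &[2]);
/// let x = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
/// let expected = Array::from_slice(&[1.0f32, 2.0], &[2]);
/// assert!(x.all_close(&expected, None, None, None).unwrap().item::<bool>());
/// ```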
#[generate_macro(customize(root = "$crate::linalg"))]
#[default_device]
pub fn solve_triangular_device(
    a: impl AsRef<Array>,
    b: impl AsRef<Array>,
    #[optional] upper: impl Into<Option<bool>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let upper = upper.into().unwrap_or(false);

    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_linalg_solve_triangular(
            res,
            a.as_ref().as_ptr(),
            b.as_ref().as_ptr(),
            upper,
            stream.as_ref().as_ptr(),
        )
    })
}

#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::{
        array,
        ops::{eye, indexing::IndexOp, tril, triu},
        StreamOrDevice,
    };

    use super::*;

    // The tests below are adapted from the swift bindings tests
    // and they are not exhaustive. Additional tests should be added
    // to cover the error cases

    #[test]
    fn test_norm_no_axes() {
        let a = Array::from_iter(0..9, &[9]) - 4;
        let b = a.reshape(&[3, 3]).unwrap();

        assert_float_eq!(
            norm_l2(&a, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );
        assert_float_eq!(
            norm_l2(&b, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        assert_float_eq!(
            norm_matrix(&b, "fro", None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, f64::INFINITY, None, None).unwrap().item::<f32>(),
            4.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::INFINITY, None, None).unwrap().item::<f32>(),
            9.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            2.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, 1.0, None, None).unwrap().item::<f32>(),
            20.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, 1.0, None, None).unwrap().item::<f32>(),
            7.0,
            abs <= 0.001
        );

        assert_float_eq!(
            norm(&a, -1.0, None, None).unwrap().item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, -1.0, None, None).unwrap().item::<f32>(),
            6.0,
            abs <= 0.001
        );
    }

    #[test]
    fn test_norm_axis() {
        let c = Array::from_slice(&[1, 2, 3, -1, 1, 4], &[2, 3]);

        let result = norm_l2(&c, &[0], None).unwrap();
        let expected = Array::from_slice(&[1.41421, 2.23607, 5.0], &[3]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_norm_axes() {
        let m = Array::from_iter(0..8, &[2, 2, 2]);

        let result = norm_l2(&m, &[1, 2][..], None).unwrap();
        let expected = Array::from_slice(&[3.74166, 11.225], &[2]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_qr() {
        let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);

        let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();

        let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
        let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);

        assert!(q
            .all_close(&q_expected, None, None, None)
            .unwrap()
            .item::<bool>());
        assert!(r
            .all_close(&r_expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    // The tests below are adapted from the c++ tests

    #[test]
    fn test_svd() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D inputs return an error
        let a = Array::from_f32(0.0);
        assert!(svd_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(svd_device(&a, &stream).is_err());

        // Unsupported types return an error
        let a = Array::from_slice(&[0, 1], &[1, 2]);
        assert!(svd_device(&a, &stream).is_err());

        // TODO: wait for random
    }

    #[test]
    fn test_inv() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D inputs return an error
        let a = Array::from_f32(0.0);
        assert!(inv_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(inv_device(&a, &stream).is_err());

        // Unsupported types return an error
        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(inv_device(&a, &stream).is_err());

        // TODO: wait for random
    }

    #[test]
    fn test_cholesky() {
        // eval_gpu is not implemented yet.
        let stream = StreamOrDevice::cpu();

        // 0D and 1D inputs return an error
        let a = Array::from_f32(0.0);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // Unsupported types return an error
        let a = Array::from_slice(&[0, 1, 1, 2], &[2, 2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // Non-square input returns an error
        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        // TODO: wait for random
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu`
    #[test]
    fn test_lu() {
        let scalar = array!(1.0);
        let result = lu_device(&scalar, StreamOrDevice::cpu());
        assert!(result.is_err());

        // Test 3x3 matrix
        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
        let a_rec = l.index((p, ..)).matmul(u).unwrap();
        assert_array_all_close!(a, a_rec);
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_lu_factor`
    #[test]
    fn test_lu_factor() {
        crate::random::seed(7).unwrap();

        // Test 5x5 matrix
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[5, 5], None).unwrap();
        let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
        let shape = a.shape();
        let n = shape[shape.len() - 1];

        let pivots: Vec<u32> = pivots.as_slice().to_vec();
        let mut perm: Vec<u32> = (0..n as u32).collect();
        for (i, p) in pivots.iter().enumerate() {
            perm.swap(i, *p as usize);
        }

        let l = tril(&lu, -1)
            .and_then(|l| l.add(eye::<f32>(n, None, None)?))
            .unwrap();
        let u = triu(&lu, None).unwrap();

        let lhs = l.matmul(&u).unwrap();
        let perm = Array::from_slice(&perm, &[n]);
        let rhs = a.index((perm, ..));
        assert_array_all_close!(lhs, rhs);
    }

    // The unit test below is adapted from the python unit test `test_linalg.py/test_solve`
    #[test]
    fn test_solve() {
        crate::random::seed(7).unwrap();

        // Test 3x3 matrix with 1D rhs
        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let b = array!([11.0f32, 35.0, 28.0]);

        let result = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
        let expected = array!([1.0f32, 2.0, 3.0]);
        assert_array_all_close!(result, expected);
    }

    #[test]
    fn test_solve_triangular() {
        let a = array!([[4.0f32, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]);
        let b = array!([8.0f32, 14.0, 3.0]);

        let result = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
        let expected = array!([2.0f32, 3.333_333_3, 1.533_333_3]);
        assert_array_all_close!(result, expected);
    }

    // The tests below are adapted from the python unit test `test_linalg.py/test_eig`
    #[test]
    fn test_eig() {
        use crate::ops::expand_dims;

        // Helper to check eigenvalues and eigenvectors
        fn check_eigs_and_vecs(a: &Array) {
            let (eig_vals, eig_vecs) = eig_device(a, StreamOrDevice::cpu()).unwrap();

            // Check A @ eig_vecs == eig_vals * eig_vecs
            let lhs = a.matmul(&eig_vecs).unwrap();
            // eig_vals[..., None, :] * eig_vecs - broadcast eigenvalues
            // For a 1D eigenvalues array (n,), we need shape (1, n) to broadcast with eigenvectors (n, n)
            // For batched eigenvalues (..., n), we need shape (..., 1, n)
            let eig_vals_broadcast = expand_dims(&eig_vals, -2).unwrap();
            let rhs = eig_vals_broadcast.multiply(&eig_vecs).unwrap();
            assert!(
                lhs.all_close(&rhs, 1e-4, 1e-4, None)
                    .unwrap()
                    .item::<bool>(),
                "A @ eig_vecs should equal eig_vals * eig_vecs"
            );

            // Check eigvals returns same values
            let eig_vals_only = eigvals_device(a, StreamOrDevice::cpu()).unwrap();
            assert!(
                eig_vals
                    .all_close(&eig_vals_only, 1e-4, 1e-4, None)
                    .unwrap()
                    .item::<bool>(),
                "eigvals should return same eigenvalues as eig"
            );
        }

        // Test a simple 2x2 matrix
        let a = array!([[1.0f32, 1.0], [3.0, 4.0]]);
        check_eigs_and_vecs(&a);

        // Test complex eigenvalues (rotation-like matrix)
        let a = array!([[1.0f32, -1.0], [1.0, 1.0]]);
        check_eigs_and_vecs(&a);

        // Test a larger random matrix
        crate::random::seed(1).unwrap();
        let a = crate::random::normal::<f32>(&[5, 5], None, None, None).unwrap();
        check_eigs_and_vecs(&a);

        // Test with batched input
        let a = crate::random::normal::<f32>(&[3, 5, 5], None, None, None).unwrap();
        check_eigs_and_vecs(&a);
    }

    #[test]
    fn test_eig_errors() {
        // 1D array should fail
        let a = array!([1.0f32, 2.0]);
        assert!(eig_device(&a, StreamOrDevice::cpu()).is_err());
        assert!(eigvals_device(&a, StreamOrDevice::cpu()).is_err());

        // Non-square matrix should fail
        let a = array!([[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        assert!(eig_device(&a, StreamOrDevice::cpu()).is_err());
        assert!(eigvals_device(&a, StreamOrDevice::cpu()).is_err());
1034 }
1035}