1use crate::error::{Exception, Result};
4use crate::utils::guard::Guarded;
5use crate::utils::{IntoOption, VectorArray};
6use crate::{Array, Stream, StreamOrDevice};
7use mlx_internal_macros::{default_device, generate_macro};
8use smallvec::SmallVec;
9use std::f64;
10use std::ffi::CString;
11
/// Order of a matrix/vector norm, as accepted by [`norm_device`].
///
/// An order is either a named order (a string such as `"fro"`, forwarded to the
/// backend as a C string) or a numeric order `p` (which may also be
/// `f64::INFINITY` or `f64::NEG_INFINITY`).
#[derive(Copy, Clone, Debug)]
pub enum Ord<'a> {
    /// A named norm order, e.g. `"fro"`.
    Str(&'a str),

    /// A numeric norm order `p`.
    P(f64),
}

impl Default for Ord<'_> {
    /// The default order is the named order `"fro"` (Frobenius).
    fn default() -> Self {
        Self::Str("fro")
    }
}

impl std::fmt::Display for Ord<'_> {
    /// Writes a named order verbatim and a numeric order using `f64`'s `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Str(name) => write!(f, "{}", name),
            Self::P(p) => write!(f, "{}", p),
        }
    }
}

impl<'a> From<&'a str> for Ord<'a> {
    /// Wraps a string as a named order.
    fn from(name: &'a str) -> Self {
        Self::Str(name)
    }
}

impl From<f64> for Ord<'_> {
    /// Wraps a float as a numeric order.
    fn from(p: f64) -> Self {
        Self::P(p)
    }
}
50
51impl<'a> IntoOption<Ord<'a>> for &'a str {
52 fn into_option(self) -> Option<Ord<'a>> {
53 Some(Ord::Str(self))
54 }
55}
56
57impl<'a> IntoOption<Ord<'a>> for f64 {
58 fn into_option(self) -> Option<Ord<'a>> {
59 Some(Ord::P(self))
60 }
61}
62
63#[generate_macro(customize(root = "$crate::linalg"))]
65#[default_device]
66pub fn norm_p_device<'a>(
67 array: impl AsRef<Array>,
68 ord: f64,
69 #[optional] axes: impl IntoOption<&'a [i32]>,
70 #[optional] keep_dims: impl Into<Option<bool>>,
71 #[optional] stream: impl AsRef<Stream>,
72) -> Result<Array> {
73 let keep_dims = keep_dims.into().unwrap_or(false);
74
75 match axes.into_option() {
76 Some(axes) => Array::try_from_op(|res| unsafe {
77 mlx_sys::mlx_linalg_norm_p(
78 res,
79 array.as_ref().as_ptr(),
80 ord,
81 axes.as_ptr(),
82 axes.len(),
83 keep_dims,
84 stream.as_ref().as_ptr(),
85 )
86 }),
87 None => Array::try_from_op(|res| unsafe {
88 mlx_sys::mlx_linalg_norm_p(
89 res,
90 array.as_ref().as_ptr(),
91 ord,
92 std::ptr::null(),
93 0,
94 keep_dims,
95 stream.as_ref().as_ptr(),
96 )
97 }),
98 }
99}
100
101#[generate_macro(customize(root = "$crate::linalg"))]
103#[default_device]
104pub fn norm_ord_device<'a>(
105 array: impl AsRef<Array>,
106 ord: &'a str,
107 #[optional] axes: impl IntoOption<&'a [i32]>,
108 #[optional] keep_dims: impl Into<Option<bool>>,
109 #[optional] stream: impl AsRef<Stream>,
110) -> Result<Array> {
111 let ord = CString::new(ord).map_err(|e| Exception::custom(format!("{}", e)))?;
112 let keep_dims = keep_dims.into().unwrap_or(false);
113
114 match axes.into_option() {
115 Some(axes) => Array::try_from_op(|res| unsafe {
116 mlx_sys::mlx_linalg_norm_ord(
117 res,
118 array.as_ref().as_ptr(),
119 ord.as_ptr(),
120 axes.as_ptr(),
121 axes.len(),
122 keep_dims,
123 stream.as_ref().as_ptr(),
124 )
125 }),
126 None => Array::try_from_op(|res| unsafe {
127 mlx_sys::mlx_linalg_norm_ord(
128 res,
129 array.as_ref().as_ptr(),
130 ord.as_ptr(),
131 std::ptr::null(),
132 0,
133 keep_dims,
134 stream.as_ref().as_ptr(),
135 )
136 }),
137 }
138}
139
140#[generate_macro(customize(root = "$crate::linalg"))]
179#[default_device]
180pub fn norm_device<'a>(
181 array: impl AsRef<Array>,
182 #[optional] ord: impl IntoOption<Ord<'a>>,
183 #[optional] axes: impl IntoOption<&'a [i32]>,
184 #[optional] keep_dims: impl Into<Option<bool>>,
185 #[optional] stream: impl AsRef<Stream>,
186) -> Result<Array> {
187 let ord = ord.into_option();
188 let axes = axes.into_option();
189 let keep_dims = keep_dims.into().unwrap_or(false);
190
191 match (ord, axes) {
192 (None, None) => {
194 let axes_ptr = std::ptr::null(); Array::try_from_op(|res| unsafe {
196 mlx_sys::mlx_linalg_norm(
197 res,
198 array.as_ref().as_ptr(),
199 axes_ptr,
200 0,
201 keep_dims,
202 stream.as_ref().as_ptr(),
203 )
204 })
205 }
206 (Some(Ord::Str(ord)), None) => norm_ord_device(array, ord, axes, keep_dims, stream),
210 (Some(Ord::P(p)), None) => norm_p_device(array, p, axes, keep_dims, stream),
211 (None, Some(axes)) => Array::try_from_op(|res| unsafe {
214 mlx_sys::mlx_linalg_norm(
215 res,
216 array.as_ref().as_ptr(),
217 axes.as_ptr(),
218 axes.len(),
219 keep_dims,
220 stream.as_ref().as_ptr(),
221 )
222 }),
223 (Some(Ord::Str(ord)), Some(axes)) => norm_ord_device(array, ord, axes, keep_dims, stream),
226 (Some(Ord::P(p)), Some(axes)) => norm_p_device(array, p, axes, keep_dims, stream),
227 }
228}
229
230#[generate_macro(customize(root = "$crate::linalg"))]
257#[default_device]
258pub fn qr_device(
259 a: impl AsRef<Array>,
260 #[optional] stream: impl AsRef<Stream>,
261) -> Result<(Array, Array)> {
262 <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
263 mlx_sys::mlx_linalg_qr(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
264 })
265}
266
267#[generate_macro(customize(root = "$crate::linalg"))]
295#[default_device]
296pub fn svd_device(
297 array: impl AsRef<Array>,
298 #[optional] stream: impl AsRef<Stream>,
299) -> Result<(Array, Array, Array)> {
300 let v = VectorArray::try_from_op(|res| unsafe {
301 mlx_sys::mlx_linalg_svd(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
302 })?;
303
304 let vals: SmallVec<[Array; 3]> = v.try_into_values()?;
305 let mut iter = vals.into_iter();
306 let u = iter.next().unwrap();
307 let s = iter.next().unwrap();
308 let vt = iter.next().unwrap();
309
310 Ok((u, s, vt))
311}
312
313#[generate_macro(customize(root = "$crate::linalg"))]
335#[default_device]
336pub fn inv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
337 Array::try_from_op(|res| unsafe {
338 mlx_sys::mlx_linalg_inv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
339 })
340}
341
342#[generate_macro(customize(root = "$crate::linalg"))]
356#[default_device]
357pub fn cholesky_device(
358 a: impl AsRef<Array>,
359 #[optional] upper: Option<bool>,
360 #[optional] stream: impl AsRef<Stream>,
361) -> Result<Array> {
362 let upper = upper.unwrap_or(false);
363 Array::try_from_op(|res| unsafe {
364 mlx_sys::mlx_linalg_cholesky(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
365 })
366}
367
368#[generate_macro(customize(root = "$crate::linalg"))]
372#[default_device]
373pub fn cholesky_inv_device(
374 a: impl AsRef<Array>,
375 #[optional] upper: Option<bool>,
376 #[optional] stream: impl AsRef<Stream>,
377) -> Result<Array> {
378 let upper = upper.unwrap_or(false);
379 Array::try_from_op(|res| unsafe {
380 mlx_sys::mlx_linalg_cholesky_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
381 })
382}
383
384#[generate_macro(customize(root = "$crate::linalg"))]
389#[default_device]
390pub fn cross_device(
391 a: impl AsRef<Array>,
392 b: impl AsRef<Array>,
393 #[optional] axis: Option<i32>,
394 #[optional] stream: impl AsRef<Stream>,
395) -> Result<Array> {
396 let axis = axis.unwrap_or(-1);
397 Array::try_from_op(|res| unsafe {
398 mlx_sys::mlx_linalg_cross(
399 res,
400 a.as_ref().as_ptr(),
401 b.as_ref().as_ptr(),
402 axis,
403 stream.as_ref().as_ptr(),
404 )
405 })
406}
407
408#[generate_macro(customize(root = "$crate::linalg"))]
414#[default_device]
415pub fn eigh_device(
416 a: impl AsRef<Array>,
417 #[optional] uplo: Option<&str>,
418 #[optional] stream: impl AsRef<Stream>,
419) -> Result<(Array, Array)> {
420 let a = a.as_ref();
421 let uplo =
422 CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{}", e)))?;
423
424 <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
425 mlx_sys::mlx_linalg_eigh(
426 res_0,
427 res_1,
428 a.as_ptr(),
429 uplo.as_ptr(),
430 stream.as_ref().as_ptr(),
431 )
432 })
433}
434
435#[generate_macro(customize(root = "$crate::linalg"))]
440#[default_device]
441pub fn eigvalsh_device(
442 a: impl AsRef<Array>,
443 #[optional] uplo: Option<&str>,
444 #[optional] stream: impl AsRef<Stream>,
445) -> Result<Array> {
446 let a = a.as_ref();
447 let uplo =
448 CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{}", e)))?;
449 Array::try_from_op(|res| unsafe {
450 mlx_sys::mlx_linalg_eigvalsh(res, a.as_ptr(), uplo.as_ptr(), stream.as_ref().as_ptr())
451 })
452}
453
454#[generate_macro(customize(root = "$crate::linalg"))]
456#[default_device]
457pub fn pinv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
458 Array::try_from_op(|res| unsafe {
459 mlx_sys::mlx_linalg_pinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
460 })
461}
462
463#[generate_macro(customize(root = "$crate::linalg"))]
468#[default_device]
469pub fn tri_inv_device(
470 a: impl AsRef<Array>,
471 #[optional] upper: Option<bool>,
472 #[optional] stream: impl AsRef<Stream>,
473) -> Result<Array> {
474 let upper = upper.unwrap_or(false);
475 Array::try_from_op(|res| unsafe {
476 mlx_sys::mlx_linalg_tri_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
477 })
478}
479
480#[generate_macro(customize(root = "$crate::linalg"))]
508#[default_device]
509pub fn lu_device(
510 a: impl AsRef<Array>,
511 #[optional] stream: impl AsRef<Stream>,
512) -> Result<(Array, Array, Array)> {
513 let v = Vec::<Array>::try_from_op(|res| unsafe {
514 mlx_sys::mlx_linalg_lu(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
515 })?;
516 let mut iter = v.into_iter();
517 let p = iter.next().ok_or_else(|| Exception::custom("missing P"))?;
518 let l = iter.next().ok_or_else(|| Exception::custom("missing L"))?;
519 let u = iter.next().ok_or_else(|| Exception::custom("missing U"))?;
520 Ok((p, l, u))
521}
522
523#[generate_macro(customize(root = "$crate::linalg"))]
534#[default_device]
535pub fn lu_factor_device(
536 a: impl AsRef<Array>,
537 #[optional] stream: impl AsRef<Stream>,
538) -> Result<(Array, Array)> {
539 <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
540 mlx_sys::mlx_linalg_lu_factor(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
541 })
542}
543
544#[generate_macro(customize(root = "$crate::linalg"))]
556#[default_device]
557pub fn solve_device(
558 a: impl AsRef<Array>,
559 b: impl AsRef<Array>,
560 #[optional] stream: impl AsRef<Stream>,
561) -> Result<Array> {
562 Array::try_from_op(|res| unsafe {
563 mlx_sys::mlx_linalg_solve(
564 res,
565 a.as_ref().as_ptr(),
566 b.as_ref().as_ptr(),
567 stream.as_ref().as_ptr(),
568 )
569 })
570}
571
572#[generate_macro(customize(root = "$crate::linalg"))]
585#[default_device]
586pub fn solve_triangular_device(
587 a: impl AsRef<Array>,
588 b: impl AsRef<Array>,
589 #[optional] upper: impl Into<Option<bool>>,
590 #[optional] stream: impl AsRef<Stream>,
591) -> Result<Array> {
592 let upper = upper.into().unwrap_or(false);
593
594 Array::try_from_op(|res| unsafe {
595 mlx_sys::mlx_linalg_solve_triangular(
596 res,
597 a.as_ref().as_ptr(),
598 b.as_ref().as_ptr(),
599 upper,
600 stream.as_ref().as_ptr(),
601 )
602 })
603}
604
// Unit tests for the linalg wrappers. The factorization/solve tests pin the CPU
// stream explicitly via `StreamOrDevice::cpu()`.
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::{
        array,
        ops::{eye, indexing::IndexOp, tril, triu},
    };

    use super::*;

    /// Whole-array norms (no `axes`) of a vector `a` and of the same values
    /// reshaped into a 3x3 matrix `b`, across default, string and numeric orders.
    #[test]
    fn test_norm_no_axes() {
        // a = [-4, -3, ..., 3, 4]; b = a as [[-4, -3, -2], [-1, 0, 1], [2, 3, 4]]
        let a = Array::from_iter(0..9, &[9]) - 4;
        let b = a.reshape(&[3, 3]).unwrap();

        // No order: sqrt(60) for both the vector and the matrix.
        assert_float_eq!(
            norm(&a, None, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, None, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        // "fro" matches the default for the matrix.
        assert_float_eq!(
            norm(&b, "fro", None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        // +inf: 4 = max |a_i|; 9 = max absolute row sum of b.
        assert_float_eq!(
            norm(&a, f64::INFINITY, None, None).unwrap().item::<f32>(),
            4.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::INFINITY, None, None).unwrap().item::<f32>(),
            9.0,
            abs <= 0.001
        );

        // -inf: 0 = min |a_i|; 2 = min absolute row sum of b (row [-1, 0, 1]).
        assert_float_eq!(
            norm(&a, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            2.0,
            abs <= 0.001
        );

        // 1: 20 = sum |a_i|; 7 = max absolute column sum of b.
        assert_float_eq!(
            norm(&a, 1.0, None, None).unwrap().item::<f32>(),
            20.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, 1.0, None, None).unwrap().item::<f32>(),
            7.0,
            abs <= 0.001
        );

        // -1: 0 for a (it contains a zero); 6 = min absolute column sum of b.
        assert_float_eq!(
            norm(&a, -1.0, None, None).unwrap().item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, -1.0, None, None).unwrap().item::<f32>(),
            6.0,
            abs <= 0.001
        );
    }

    /// Default-order norm along a single axis: one value per column.
    #[test]
    fn test_norm_axis() {
        let c = Array::from_slice(&[1, 2, 3, -1, 1, 4], &[2, 3]);

        // sqrt(1 + 1), sqrt(4 + 1), sqrt(9 + 16)
        let result = norm(&c, None, &[0][..], None).unwrap();
        let expected = Array::from_slice(&[1.41421, 2.23607, 5.0], &[3]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    /// Default-order norm over the last two axes of a 2x2x2 array.
    #[test]
    fn test_norm_axes() {
        let m = Array::from_iter(0..8, &[2, 2, 2]);

        // sqrt(0 + 1 + 4 + 9), sqrt(16 + 25 + 36 + 49)
        let result = norm(&m, None, &[1, 2][..], None).unwrap();
        let expected = Array::from_slice(&[3.74166, 11.225], &[2]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    /// QR of [[2, 3], [1, 2]] against reference factors (including their signs).
    #[test]
    fn test_qr() {
        let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);

        let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();

        let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
        let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);

        assert!(q
            .all_close(&q_expected, None, None, None)
            .unwrap()
            .item::<bool>());
        assert!(r
            .all_close(&r_expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    /// SVD rejects a scalar, a 1-D array, and a 1x2 integer array.
    #[test]
    fn test_svd() {
        let stream = StreamOrDevice::cpu();

        let a = Array::from_f32(0.0);
        assert!(svd_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(svd_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0, 1], &[1, 2]);
        assert!(svd_device(&a, &stream).is_err());
    }

    /// Inverse rejects a scalar, a 1-D array, and a non-square 2x3 integer array.
    #[test]
    fn test_inv() {
        let stream = StreamOrDevice::cpu();

        let a = Array::from_f32(0.0);
        assert!(inv_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(inv_device(&a, &stream).is_err());

        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(inv_device(&a, &stream).is_err());
    }

    /// Cholesky rejects a scalar, a 1-D array, the 2x2 integer matrix
    /// [[0, 1], [1, 2]], and a non-square 2x3 integer array.
    #[test]
    fn test_cholesky() {
        let stream = StreamOrDevice::cpu();

        let a = Array::from_f32(0.0);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[0, 1, 1, 2], &[2, 2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(cholesky_device(&a, None, &stream).is_err());
    }

    /// LU rejects a scalar. For a 3x3 matrix, `l` with its rows indexed by `p`,
    /// times `u`, reconstructs the input.
    #[test]
    fn test_lu() {
        let scalar = array!(1.0);
        let result = lu_device(&scalar, StreamOrDevice::cpu());
        assert!(result.is_err());

        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
        let a_rec = l.index((p, ..)).matmul(u).unwrap();
        assert_array_all_close!(a, a_rec);
    }

    /// Compact LU of a seeded random 5x5 matrix: rebuild L and U from `lu`, and a
    /// row permutation from `pivots`, then check L @ U == A[perm].
    #[test]
    fn test_lu_factor() {
        crate::random::seed(7).unwrap();

        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[5, 5], None).unwrap();
        let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
        let shape = a.shape();
        let n = shape[shape.len() - 1];

        // Apply the pivots as successive row swaps to the identity permutation.
        let pivots: Vec<u32> = pivots.as_slice().to_vec();
        let mut perm: Vec<u32> = (0..n as u32).collect();
        for (i, p) in pivots.iter().enumerate() {
            perm.swap(i, *p as usize);
        }

        // L = strict lower triangle of `lu` plus the identity; U = upper triangle.
        let l = tril(&lu, -1)
            .and_then(|l| l.add(eye::<f32>(n, None, None)?))
            .unwrap();
        let u = triu(&lu, None).unwrap();

        let lhs = l.matmul(&u).unwrap();
        let perm = Array::from_slice(&perm, &[n]);
        let rhs = a.index((perm, ..));
        assert_array_all_close!(lhs, rhs);
    }

    /// a @ [1, 2, 3] = [11, 35, 28], so solving with that `b` recovers [1, 2, 3].
    #[test]
    fn test_solve() {
        // NOTE(review): the inputs below are fixed, so this seed appears unnecessary.
        crate::random::seed(7).unwrap();

        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let b = array!([11.0f32, 35.0, 28.0]);

        let result = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
        let expected = array!([1.0f32, 2.0, 3.0]);
        assert_array_all_close!(result, expected);
    }

    /// Lower-triangular system (`upper = false`): x = [2, 10/3, 23/15].
    #[test]
    fn test_solve_triangular() {
        let a = array!([[4.0f32, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]);
        let b = array!([8.0f32, 14.0, 3.0]);

        let result = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
        let expected = array!([2.0f32, 3.333_333_3, 1.533_333_3]);
        assert_array_all_close!(result, expected);
    }
}
862}