1use crate::error::{Exception, Result};
4use crate::utils::guard::Guarded;
5use crate::utils::{IntoOption, VectorArray};
6use crate::{Array, Stream};
7use mlx_internal_macros::{default_device, generate_macro};
8use smallvec::SmallVec;
9use std::f64;
10use std::ffi::CString;
11
/// Order of a norm.
///
/// An order is either a named order given as a string (for example `"fro"`) or a numeric
/// order `p`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Ord<'a> {
    /// Named order, forwarded as a string (e.g. `"fro"`).
    Str(&'a str),

    /// Numeric order `p`.
    P(f64),
}
23
24impl Default for Ord<'_> {
25 fn default() -> Self {
26 Ord::Str("fro")
27 }
28}
29
30impl std::fmt::Display for Ord<'_> {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Ord::Str(s) => write!(f, "{}", s),
34 Ord::P(p) => write!(f, "{}", p),
35 }
36 }
37}
38
39impl<'a> From<&'a str> for Ord<'a> {
40 fn from(value: &'a str) -> Self {
41 Ord::Str(value)
42 }
43}
44
45impl From<f64> for Ord<'_> {
46 fn from(value: f64) -> Self {
47 Ord::P(value)
48 }
49}
50
51impl<'a> IntoOption<Ord<'a>> for &'a str {
52 fn into_option(self) -> Option<Ord<'a>> {
53 Some(Ord::Str(self))
54 }
55}
56
57impl<'a> IntoOption<Ord<'a>> for f64 {
58 fn into_option(self) -> Option<Ord<'a>> {
59 Some(Ord::P(self))
60 }
61}
62
63#[generate_macro(customize(root = "$crate::linalg"))]
65#[default_device]
66pub fn norm_device<'a>(
67 array: impl AsRef<Array>,
68 ord: f64,
69 #[optional] axes: impl IntoOption<&'a [i32]>,
70 #[optional] keep_dims: impl Into<Option<bool>>,
71 #[optional] stream: impl AsRef<Stream>,
72) -> Result<Array> {
73 let keep_dims = keep_dims.into().unwrap_or(false);
74
75 match axes.into_option() {
76 Some(axes) => Array::try_from_op(|res| unsafe {
77 mlx_sys::mlx_linalg_norm(
78 res,
79 array.as_ref().as_ptr(),
80 ord,
81 axes.as_ptr(),
82 axes.len(),
83 keep_dims,
84 stream.as_ref().as_ptr(),
85 )
86 }),
87 None => Array::try_from_op(|res| unsafe {
88 mlx_sys::mlx_linalg_norm(
89 res,
90 array.as_ref().as_ptr(),
91 ord,
92 std::ptr::null(),
93 0,
94 keep_dims,
95 stream.as_ref().as_ptr(),
96 )
97 }),
98 }
99}
100
101#[generate_macro(customize(root = "$crate::linalg"))]
103#[default_device]
104pub fn norm_matrix_device<'a>(
105 array: impl AsRef<Array>,
106 ord: &'a str,
107 #[optional] axes: impl IntoOption<&'a [i32]>,
108 #[optional] keep_dims: impl Into<Option<bool>>,
109 #[optional] stream: impl AsRef<Stream>,
110) -> Result<Array> {
111 let ord = CString::new(ord).map_err(|e| Exception::custom(format!("{}", e)))?;
112 let keep_dims = keep_dims.into().unwrap_or(false);
113
114 match axes.into_option() {
115 Some(axes) => Array::try_from_op(|res| unsafe {
116 mlx_sys::mlx_linalg_norm_matrix(
117 res,
118 array.as_ref().as_ptr(),
119 ord.as_ptr(),
120 axes.as_ptr(),
121 axes.len(),
122 keep_dims,
123 stream.as_ref().as_ptr(),
124 )
125 }),
126 None => Array::try_from_op(|res| unsafe {
127 mlx_sys::mlx_linalg_norm_matrix(
128 res,
129 array.as_ref().as_ptr(),
130 ord.as_ptr(),
131 std::ptr::null(),
132 0,
133 keep_dims,
134 stream.as_ref().as_ptr(),
135 )
136 }),
137 }
138}
139
140#[generate_macro(customize(root = "$crate::linalg"))]
142#[default_device]
143pub fn norm_l2_device<'a>(
144 array: impl AsRef<Array>,
145 #[optional] axes: impl IntoOption<&'a [i32]>,
146 #[optional] keep_dims: impl Into<Option<bool>>,
147 #[optional] stream: impl AsRef<Stream>,
148) -> Result<Array> {
149 let keep_dims = keep_dims.into().unwrap_or(false);
150
151 match axes.into_option() {
152 Some(axis) => Array::try_from_op(|res| unsafe {
153 mlx_sys::mlx_linalg_norm_l2(
154 res,
155 array.as_ref().as_ptr(),
156 axis.as_ptr(),
157 axis.len(),
158 keep_dims,
159 stream.as_ref().as_ptr(),
160 )
161 }),
162 None => Array::try_from_op(|res| unsafe {
163 mlx_sys::mlx_linalg_norm_l2(
164 res,
165 array.as_ref().as_ptr(),
166 std::ptr::null(),
167 0,
168 keep_dims,
169 stream.as_ref().as_ptr(),
170 )
171 }),
172 }
173}
174
175#[generate_macro(customize(root = "$crate::linalg"))]
293#[default_device]
294pub fn qr_device(
295 a: impl AsRef<Array>,
296 #[optional] stream: impl AsRef<Stream>,
297) -> Result<(Array, Array)> {
298 <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
299 mlx_sys::mlx_linalg_qr(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
300 })
301}
302
303#[generate_macro(customize(root = "$crate::linalg"))]
331#[default_device]
332pub fn svd_device(
333 array: impl AsRef<Array>,
334 #[optional] stream: impl AsRef<Stream>,
335) -> Result<(Array, Array, Array)> {
336 let v = VectorArray::try_from_op(|res| unsafe {
337 mlx_sys::mlx_linalg_svd(res, array.as_ref().as_ptr(), true, stream.as_ref().as_ptr())
338 })?;
339
340 let vals: SmallVec<[Array; 3]> = v.try_into_values()?;
341 let mut iter = vals.into_iter();
342 let u = iter.next().unwrap();
343 let s = iter.next().unwrap();
344 let vt = iter.next().unwrap();
345
346 Ok((u, s, vt))
347}
348
349#[generate_macro(customize(root = "$crate::linalg"))]
371#[default_device]
372pub fn inv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
373 Array::try_from_op(|res| unsafe {
374 mlx_sys::mlx_linalg_inv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
375 })
376}
377
378#[generate_macro(customize(root = "$crate::linalg"))]
392#[default_device]
393pub fn cholesky_device(
394 a: impl AsRef<Array>,
395 #[optional] upper: Option<bool>,
396 #[optional] stream: impl AsRef<Stream>,
397) -> Result<Array> {
398 let upper = upper.unwrap_or(false);
399 Array::try_from_op(|res| unsafe {
400 mlx_sys::mlx_linalg_cholesky(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
401 })
402}
403
404#[generate_macro(customize(root = "$crate::linalg"))]
408#[default_device]
409pub fn cholesky_inv_device(
410 a: impl AsRef<Array>,
411 #[optional] upper: Option<bool>,
412 #[optional] stream: impl AsRef<Stream>,
413) -> Result<Array> {
414 let upper = upper.unwrap_or(false);
415 Array::try_from_op(|res| unsafe {
416 mlx_sys::mlx_linalg_cholesky_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
417 })
418}
419
420#[generate_macro(customize(root = "$crate::linalg"))]
425#[default_device]
426pub fn cross_device(
427 a: impl AsRef<Array>,
428 b: impl AsRef<Array>,
429 #[optional] axis: Option<i32>,
430 #[optional] stream: impl AsRef<Stream>,
431) -> Result<Array> {
432 let axis = axis.unwrap_or(-1);
433 Array::try_from_op(|res| unsafe {
434 mlx_sys::mlx_linalg_cross(
435 res,
436 a.as_ref().as_ptr(),
437 b.as_ref().as_ptr(),
438 axis,
439 stream.as_ref().as_ptr(),
440 )
441 })
442}
443
444#[generate_macro(customize(root = "$crate::linalg"))]
450#[default_device]
451pub fn eigh_device(
452 a: impl AsRef<Array>,
453 #[optional] uplo: Option<&str>,
454 #[optional] stream: impl AsRef<Stream>,
455) -> Result<(Array, Array)> {
456 let a = a.as_ref();
457 let uplo =
458 CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{}", e)))?;
459
460 <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
461 mlx_sys::mlx_linalg_eigh(
462 res_0,
463 res_1,
464 a.as_ptr(),
465 uplo.as_ptr(),
466 stream.as_ref().as_ptr(),
467 )
468 })
469}
470
471#[generate_macro(customize(root = "$crate::linalg"))]
476#[default_device]
477pub fn eigvalsh_device(
478 a: impl AsRef<Array>,
479 #[optional] uplo: Option<&str>,
480 #[optional] stream: impl AsRef<Stream>,
481) -> Result<Array> {
482 let a = a.as_ref();
483 let uplo =
484 CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{}", e)))?;
485 Array::try_from_op(|res| unsafe {
486 mlx_sys::mlx_linalg_eigvalsh(res, a.as_ptr(), uplo.as_ptr(), stream.as_ref().as_ptr())
487 })
488}
489
490#[generate_macro(customize(root = "$crate::linalg"))]
492#[default_device]
493pub fn pinv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
494 Array::try_from_op(|res| unsafe {
495 mlx_sys::mlx_linalg_pinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
496 })
497}
498
499#[generate_macro(customize(root = "$crate::linalg"))]
504#[default_device]
505pub fn tri_inv_device(
506 a: impl AsRef<Array>,
507 #[optional] upper: Option<bool>,
508 #[optional] stream: impl AsRef<Stream>,
509) -> Result<Array> {
510 let upper = upper.unwrap_or(false);
511 Array::try_from_op(|res| unsafe {
512 mlx_sys::mlx_linalg_tri_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
513 })
514}
515
516#[generate_macro(customize(root = "$crate::linalg"))]
544#[default_device]
545pub fn lu_device(
546 a: impl AsRef<Array>,
547 #[optional] stream: impl AsRef<Stream>,
548) -> Result<(Array, Array, Array)> {
549 let v = Vec::<Array>::try_from_op(|res| unsafe {
550 mlx_sys::mlx_linalg_lu(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
551 })?;
552 let mut iter = v.into_iter();
553 let p = iter.next().ok_or_else(|| Exception::custom("missing P"))?;
554 let l = iter.next().ok_or_else(|| Exception::custom("missing L"))?;
555 let u = iter.next().ok_or_else(|| Exception::custom("missing U"))?;
556 Ok((p, l, u))
557}
558
559#[generate_macro(customize(root = "$crate::linalg"))]
570#[default_device]
571pub fn lu_factor_device(
572 a: impl AsRef<Array>,
573 #[optional] stream: impl AsRef<Stream>,
574) -> Result<(Array, Array)> {
575 <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
576 mlx_sys::mlx_linalg_lu_factor(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
577 })
578}
579
580#[generate_macro(customize(root = "$crate::linalg"))]
592#[default_device]
593pub fn solve_device(
594 a: impl AsRef<Array>,
595 b: impl AsRef<Array>,
596 #[optional] stream: impl AsRef<Stream>,
597) -> Result<Array> {
598 Array::try_from_op(|res| unsafe {
599 mlx_sys::mlx_linalg_solve(
600 res,
601 a.as_ref().as_ptr(),
602 b.as_ref().as_ptr(),
603 stream.as_ref().as_ptr(),
604 )
605 })
606}
607
608#[generate_macro(customize(root = "$crate::linalg"))]
621#[default_device]
622pub fn solve_triangular_device(
623 a: impl AsRef<Array>,
624 b: impl AsRef<Array>,
625 #[optional] upper: impl Into<Option<bool>>,
626 #[optional] stream: impl AsRef<Stream>,
627) -> Result<Array> {
628 let upper = upper.into().unwrap_or(false);
629
630 Array::try_from_op(|res| unsafe {
631 mlx_sys::mlx_linalg_solve_triangular(
632 res,
633 a.as_ref().as_ptr(),
634 b.as_ref().as_ptr(),
635 upper,
636 stream.as_ref().as_ptr(),
637 )
638 })
639}
640
// Unit tests for the linalg wrappers. Most factorization tests run on
// `StreamOrDevice::cpu()` (presumably because MLX implements these ops only on the CPU —
// confirm).
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::{
        array,
        ops::{eye, indexing::IndexOp, tril, triu},
        StreamOrDevice,
    };

    use super::*;

    /// Norms computed without an axis list, over the vector `a = [-4, -3, ..., 4]` and the
    /// same values reshaped into the 3x3 matrix `b = [[-4, -3, -2], [-1, 0, 1], [2, 3, 4]]`.
    #[test]
    fn test_norm_no_axes() {
        let a = Array::from_iter(0..9, &[9]) - 4;
        let b = a.reshape(&[3, 3]).unwrap();

        // sqrt(16 + 9 + 4 + 1 + 0 + 1 + 4 + 9 + 16) = sqrt(60) ≈ 7.74597 for both shapes.
        assert_float_eq!(
            norm_l2(&a, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );
        assert_float_eq!(
            norm_l2(&b, None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        // The "fro" order on the matrix gives the same sqrt(60).
        assert_float_eq!(
            norm_matrix(&b, "fro", None, None).unwrap().item::<f32>(),
            7.74597,
            abs <= 0.001
        );

        // +inf: vector -> max |a_i| = 4; matrix -> max absolute row sum = max(9, 2, 9) = 9.
        assert_float_eq!(
            norm(&a, f64::INFINITY, None, None).unwrap().item::<f32>(),
            4.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::INFINITY, None, None).unwrap().item::<f32>(),
            9.0,
            abs <= 0.001
        );

        // -inf: vector -> min |a_i| = 0; matrix -> min absolute row sum = 2.
        assert_float_eq!(
            norm(&a, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, f64::NEG_INFINITY, None, None)
                .unwrap()
                .item::<f32>(),
            2.0,
            abs <= 0.001
        );

        // 1: vector -> sum |a_i| = 20; matrix -> max absolute column sum = max(7, 6, 7) = 7.
        assert_float_eq!(
            norm(&a, 1.0, None, None).unwrap().item::<f32>(),
            20.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, 1.0, None, None).unwrap().item::<f32>(),
            7.0,
            abs <= 0.001
        );

        // -1: vector -> 0 (a contains a zero); matrix -> min absolute column sum = 6.
        assert_float_eq!(
            norm(&a, -1.0, None, None).unwrap().item::<f32>(),
            0.0,
            abs <= 0.001
        );
        assert_float_eq!(
            norm(&b, -1.0, None, None).unwrap().item::<f32>(),
            6.0,
            abs <= 0.001
        );
    }

    /// L2 norm along axis 0 of [[1, 2, 3], [-1, 1, 4]]: sqrt(2), sqrt(5), sqrt(25).
    #[test]
    fn test_norm_axis() {
        let c = Array::from_slice(&[1, 2, 3, -1, 1, 4], &[2, 3]);

        let result = norm_l2(&c, &[0], None).unwrap();
        let expected = Array::from_slice(&[1.41421, 2.23607, 5.0], &[3]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    /// Norm over axes [1, 2] of 0..8 shaped [2, 2, 2]: each 2x2 slice gives
    /// sqrt(0 + 1 + 4 + 9) = sqrt(14) and sqrt(16 + 25 + 36 + 49) = sqrt(126).
    #[test]
    fn test_norm_axes() {
        let m = Array::from_iter(0..8, &[2, 2, 2]);

        let result = norm_l2(&m, &[1, 2][..], None).unwrap();
        let expected = Array::from_slice(&[3.74166, 11.225], &[2]);
        assert!(result
            .all_close(&expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    /// QR of [[2, 3], [1, 2]]: the first column of q is -(2, 1)/sqrt(5), and r[0][0] = -sqrt(5).
    #[test]
    fn test_qr() {
        let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);

        let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();

        let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
        let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);

        assert!(q
            .all_close(&q_expected, None, None, None)
            .unwrap()
            .item::<bool>());
        assert!(r
            .all_close(&r_expected, None, None, None)
            .unwrap()
            .item::<bool>());
    }

    /// `svd` must reject a scalar, a 1-D vector and an integer-typed 1x2 matrix (the last one
    /// presumably because of its dtype — confirm).
    #[test]
    fn test_svd() {
        let stream = StreamOrDevice::cpu();

        let a = Array::from_f32(0.0);
        assert!(svd_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(svd_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0, 1], &[1, 2]);
        assert!(svd_device(&a, &stream).is_err());
    }

    /// `inv` must reject a scalar, a 1-D vector and a non-square (integer) 2x3 matrix.
    #[test]
    fn test_inv() {
        let stream = StreamOrDevice::cpu();

        let a = Array::from_f32(0.0);
        assert!(inv_device(&a, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(inv_device(&a, &stream).is_err());

        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(inv_device(&a, &stream).is_err());
    }

    /// `cholesky` must reject a scalar, a 1-D vector, an integer-typed 2x2 matrix, and a
    /// non-square 2x3 matrix.
    #[test]
    fn test_cholesky() {
        let stream = StreamOrDevice::cpu();

        let a = Array::from_f32(0.0);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[0.0, 1.0], &[2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[0, 1, 1, 2], &[2, 2]);
        assert!(cholesky_device(&a, None, &stream).is_err());

        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        assert!(cholesky_device(&a, None, &stream).is_err());
    }

    /// `lu` rejects a scalar. For a 3x3 matrix, `l` with its rows indexed by `p`, times `u`,
    /// must reconstruct `a`.
    #[test]
    fn test_lu() {
        let scalar = array!(1.0);
        let result = lu_device(&scalar, StreamOrDevice::cpu());
        assert!(result.is_err());

        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
        let a_rec = l.index((p, ..)).matmul(u).unwrap();
        assert_array_all_close!(a, a_rec);
    }

    /// `lu_factor` on a seeded random 5x5 matrix. The pivots are applied as sequential row
    /// swaps (row i <-> row pivots[i]) to build a permutation. L is the strict lower
    /// triangle of `lu` plus the identity, and U is its upper triangle. L @ U must equal
    /// `a` with its rows permuted.
    #[test]
    fn test_lu_factor() {
        crate::random::seed(7).unwrap();

        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[5, 5], None).unwrap();
        let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
        let shape = a.shape();
        let n = shape[shape.len() - 1];

        // Replay the pivot swaps on the identity permutation.
        let pivots: Vec<u32> = pivots.as_slice().to_vec();
        let mut perm: Vec<u32> = (0..n as u32).collect();
        for (i, p) in pivots.iter().enumerate() {
            perm.swap(i, *p as usize);
        }

        // Unit-diagonal L from the strictly-lower part of the packed factor.
        let l = tril(&lu, -1)
            .and_then(|l| l.add(eye::<f32>(n, None, None)?))
            .unwrap();
        let u = triu(&lu, None).unwrap();

        let lhs = l.matmul(&u).unwrap();
        let perm = Array::from_slice(&perm, &[n]);
        let rhs = a.index((perm, ..));
        assert_array_all_close!(lhs, rhs);
    }

    /// Solves a 3x3 system whose exact solution is [1, 2, 3]:
    /// 3+2+6 = 11, 1+16+18 = 35, 9+4+15 = 28.
    #[test]
    fn test_solve() {
        // NOTE(review): no random values are drawn in this test, so this seed appears unused.
        crate::random::seed(7).unwrap();

        let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
        let b = array!([11.0f32, 35.0, 28.0]);

        let result = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
        let expected = array!([1.0f32, 2.0, 3.0]);
        assert_array_all_close!(result, expected);
    }

    /// Lower-triangular solve by forward substitution: x0 = 8/4 = 2,
    /// x1 = (14 - 2*2)/3 ≈ 3.3333, x2 = (3 - 2 + 2*3.3333)/5 ≈ 1.5333.
    #[test]
    fn test_solve_triangular() {
        let a = array!([[4.0f32, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]);
        let b = array!([8.0f32, 14.0, 3.0]);

        let result = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
        let expected = array!([2.0f32, 3.333_333_3, 1.533_333_3]);
        assert_array_all_close!(result, expected);
    }
}