1use crate::error::{Exception, Result};
4use crate::utils::guard::Guarded;
5use crate::utils::{IntoOption, VectorArray};
6use crate::{Array, Stream};
7use mlx_internal_macros::{default_device, generate_macro};
8use smallvec::SmallVec;
9use std::f64;
10use std::ffi::CString;
11
12#[derive(Debug, Clone, Copy)]
16pub enum Ord<'a> {
17 Str(&'a str),
19
20 P(f64),
22}
23
24impl Default for Ord<'_> {
25 fn default() -> Self {
26 Ord::Str("fro")
27 }
28}
29
30impl std::fmt::Display for Ord<'_> {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Ord::Str(s) => write!(f, "{s}"),
34 Ord::P(p) => write!(f, "{p}"),
35 }
36 }
37}
38
39impl<'a> From<&'a str> for Ord<'a> {
40 fn from(value: &'a str) -> Self {
41 Ord::Str(value)
42 }
43}
44
45impl From<f64> for Ord<'_> {
46 fn from(value: f64) -> Self {
47 Ord::P(value)
48 }
49}
50
51impl<'a> IntoOption<Ord<'a>> for &'a str {
52 fn into_option(self) -> Option<Ord<'a>> {
53 Some(Ord::Str(self))
54 }
55}
56
57impl<'a> IntoOption<Ord<'a>> for f64 {
58 fn into_option(self) -> Option<Ord<'a>> {
59 Some(Ord::P(self))
60 }
61}
62
63#[generate_macro(customize(root = "$crate::linalg"))]
65#[default_device]
66pub fn norm_device<'a>(
67 array: impl AsRef<Array>,
68 ord: f64,
69 #[optional] axes: impl IntoOption<&'a [i32]>,
70 #[optional] keep_dims: impl Into<Option<bool>>,
71 #[optional] stream: impl AsRef<Stream>,
72) -> Result<Array> {
73 let keep_dims = keep_dims.into().unwrap_or(false);
74
75 match axes.into_option() {
76 Some(axes) => Array::try_from_op(|res| unsafe {
77 mlx_sys::mlx_linalg_norm(
78 res,
79 array.as_ref().as_ptr(),
80 ord,
81 axes.as_ptr(),
82 axes.len(),
83 keep_dims,
84 stream.as_ref().as_ptr(),
85 )
86 }),
87 None => Array::try_from_op(|res| unsafe {
88 mlx_sys::mlx_linalg_norm(
89 res,
90 array.as_ref().as_ptr(),
91 ord,
92 std::ptr::null(),
93 0,
94 keep_dims,
95 stream.as_ref().as_ptr(),
96 )
97 }),
98 }
99}
100
101#[generate_macro(customize(root = "$crate::linalg"))]
103#[default_device]
104pub fn norm_matrix_device<'a>(
105 array: impl AsRef<Array>,
106 ord: &'a str,
107 #[optional] axes: impl IntoOption<&'a [i32]>,
108 #[optional] keep_dims: impl Into<Option<bool>>,
109 #[optional] stream: impl AsRef<Stream>,
110) -> Result<Array> {
111 let ord = CString::new(ord).map_err(|e| Exception::custom(format!("{e}")))?;
112 let keep_dims = keep_dims.into().unwrap_or(false);
113
114 match axes.into_option() {
115 Some(axes) => Array::try_from_op(|res| unsafe {
116 mlx_sys::mlx_linalg_norm_matrix(
117 res,
118 array.as_ref().as_ptr(),
119 ord.as_ptr(),
120 axes.as_ptr(),
121 axes.len(),
122 keep_dims,
123 stream.as_ref().as_ptr(),
124 )
125 }),
126 None => Array::try_from_op(|res| unsafe {
127 mlx_sys::mlx_linalg_norm_matrix(
128 res,
129 array.as_ref().as_ptr(),
130 ord.as_ptr(),
131 std::ptr::null(),
132 0,
133 keep_dims,
134 stream.as_ref().as_ptr(),
135 )
136 }),
137 }
138}
139
140#[generate_macro(customize(root = "$crate::linalg"))]
142#[default_device]
143pub fn norm_l2_device<'a>(
144 array: impl AsRef<Array>,
145 #[optional] axes: impl IntoOption<&'a [i32]>,
146 #[optional] keep_dims: impl Into<Option<bool>>,
147 #[optional] stream: impl AsRef<Stream>,
148) -> Result<Array> {
149 let keep_dims = keep_dims.into().unwrap_or(false);
150
151 match axes.into_option() {
152 Some(axis) => Array::try_from_op(|res| unsafe {
153 mlx_sys::mlx_linalg_norm_l2(
154 res,
155 array.as_ref().as_ptr(),
156 axis.as_ptr(),
157 axis.len(),
158 keep_dims,
159 stream.as_ref().as_ptr(),
160 )
161 }),
162 None => Array::try_from_op(|res| unsafe {
163 mlx_sys::mlx_linalg_norm_l2(
164 res,
165 array.as_ref().as_ptr(),
166 std::ptr::null(),
167 0,
168 keep_dims,
169 stream.as_ref().as_ptr(),
170 )
171 }),
172 }
173}
174
175#[generate_macro(customize(root = "$crate::linalg"))]
293#[default_device]
294pub fn qr_device(
295 a: impl AsRef<Array>,
296 #[optional] stream: impl AsRef<Stream>,
297) -> Result<(Array, Array)> {
298 <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
299 mlx_sys::mlx_linalg_qr(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
300 })
301}
302
303#[generate_macro(customize(root = "$crate::linalg"))]
331#[default_device]
332pub fn svd_device(
333 array: impl AsRef<Array>,
334 #[optional] stream: impl AsRef<Stream>,
335) -> Result<(Array, Array, Array)> {
336 let v = VectorArray::try_from_op(|res| unsafe {
337 mlx_sys::mlx_linalg_svd(res, array.as_ref().as_ptr(), true, stream.as_ref().as_ptr())
338 })?;
339
340 let vals: SmallVec<[Array; 3]> = v.try_into_values()?;
341 let mut iter = vals.into_iter();
342 let u = iter.next().unwrap();
343 let s = iter.next().unwrap();
344 let vt = iter.next().unwrap();
345
346 Ok((u, s, vt))
347}
348
349#[generate_macro(customize(root = "$crate::linalg"))]
371#[default_device]
372pub fn inv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
373 Array::try_from_op(|res| unsafe {
374 mlx_sys::mlx_linalg_inv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
375 })
376}
377
378#[generate_macro(customize(root = "$crate::linalg"))]
392#[default_device]
393pub fn cholesky_device(
394 a: impl AsRef<Array>,
395 #[optional] upper: Option<bool>,
396 #[optional] stream: impl AsRef<Stream>,
397) -> Result<Array> {
398 let upper = upper.unwrap_or(false);
399 Array::try_from_op(|res| unsafe {
400 mlx_sys::mlx_linalg_cholesky(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
401 })
402}
403
404#[generate_macro(customize(root = "$crate::linalg"))]
408#[default_device]
409pub fn cholesky_inv_device(
410 a: impl AsRef<Array>,
411 #[optional] upper: Option<bool>,
412 #[optional] stream: impl AsRef<Stream>,
413) -> Result<Array> {
414 let upper = upper.unwrap_or(false);
415 Array::try_from_op(|res| unsafe {
416 mlx_sys::mlx_linalg_cholesky_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
417 })
418}
419
420#[generate_macro(customize(root = "$crate::linalg"))]
425#[default_device]
426pub fn cross_device(
427 a: impl AsRef<Array>,
428 b: impl AsRef<Array>,
429 #[optional] axis: Option<i32>,
430 #[optional] stream: impl AsRef<Stream>,
431) -> Result<Array> {
432 let axis = axis.unwrap_or(-1);
433 Array::try_from_op(|res| unsafe {
434 mlx_sys::mlx_linalg_cross(
435 res,
436 a.as_ref().as_ptr(),
437 b.as_ref().as_ptr(),
438 axis,
439 stream.as_ref().as_ptr(),
440 )
441 })
442}
443
444#[generate_macro(customize(root = "$crate::linalg"))]
450#[default_device]
451pub fn eigh_device(
452 a: impl AsRef<Array>,
453 #[optional] uplo: Option<&str>,
454 #[optional] stream: impl AsRef<Stream>,
455) -> Result<(Array, Array)> {
456 let a = a.as_ref();
457 let uplo = CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{e}")))?;
458
459 <(Array, Array) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
460 mlx_sys::mlx_linalg_eigh(
461 res_0,
462 res_1,
463 a.as_ptr(),
464 uplo.as_ptr(),
465 stream.as_ref().as_ptr(),
466 )
467 })
468}
469
470#[generate_macro(customize(root = "$crate::linalg"))]
475#[default_device]
476pub fn eigvalsh_device(
477 a: impl AsRef<Array>,
478 #[optional] uplo: Option<&str>,
479 #[optional] stream: impl AsRef<Stream>,
480) -> Result<Array> {
481 let a = a.as_ref();
482 let uplo = CString::new(uplo.unwrap_or("L")).map_err(|e| Exception::custom(format!("{e}")))?;
483 Array::try_from_op(|res| unsafe {
484 mlx_sys::mlx_linalg_eigvalsh(res, a.as_ptr(), uplo.as_ptr(), stream.as_ref().as_ptr())
485 })
486}
487
488#[generate_macro(customize(root = "$crate::linalg"))]
490#[default_device]
491pub fn pinv_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
492 Array::try_from_op(|res| unsafe {
493 mlx_sys::mlx_linalg_pinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
494 })
495}
496
497#[generate_macro(customize(root = "$crate::linalg"))]
502#[default_device]
503pub fn tri_inv_device(
504 a: impl AsRef<Array>,
505 #[optional] upper: Option<bool>,
506 #[optional] stream: impl AsRef<Stream>,
507) -> Result<Array> {
508 let upper = upper.unwrap_or(false);
509 Array::try_from_op(|res| unsafe {
510 mlx_sys::mlx_linalg_tri_inv(res, a.as_ref().as_ptr(), upper, stream.as_ref().as_ptr())
511 })
512}
513
514#[generate_macro(customize(root = "$crate::linalg"))]
542#[default_device]
543pub fn lu_device(
544 a: impl AsRef<Array>,
545 #[optional] stream: impl AsRef<Stream>,
546) -> Result<(Array, Array, Array)> {
547 let v = Vec::<Array>::try_from_op(|res| unsafe {
548 mlx_sys::mlx_linalg_lu(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
549 })?;
550 let mut iter = v.into_iter();
551 let p = iter.next().ok_or_else(|| Exception::custom("missing P"))?;
552 let l = iter.next().ok_or_else(|| Exception::custom("missing L"))?;
553 let u = iter.next().ok_or_else(|| Exception::custom("missing U"))?;
554 Ok((p, l, u))
555}
556
557#[generate_macro(customize(root = "$crate::linalg"))]
568#[default_device]
569pub fn lu_factor_device(
570 a: impl AsRef<Array>,
571 #[optional] stream: impl AsRef<Stream>,
572) -> Result<(Array, Array)> {
573 <(Array, Array)>::try_from_op(|(res_0, res_1)| unsafe {
574 mlx_sys::mlx_linalg_lu_factor(res_0, res_1, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
575 })
576}
577
578#[generate_macro(customize(root = "$crate::linalg"))]
590#[default_device]
591pub fn solve_device(
592 a: impl AsRef<Array>,
593 b: impl AsRef<Array>,
594 #[optional] stream: impl AsRef<Stream>,
595) -> Result<Array> {
596 Array::try_from_op(|res| unsafe {
597 mlx_sys::mlx_linalg_solve(
598 res,
599 a.as_ref().as_ptr(),
600 b.as_ref().as_ptr(),
601 stream.as_ref().as_ptr(),
602 )
603 })
604}
605
606#[generate_macro(customize(root = "$crate::linalg"))]
619#[default_device]
620pub fn solve_triangular_device(
621 a: impl AsRef<Array>,
622 b: impl AsRef<Array>,
623 #[optional] upper: impl Into<Option<bool>>,
624 #[optional] stream: impl AsRef<Stream>,
625) -> Result<Array> {
626 let upper = upper.into().unwrap_or(false);
627
628 Array::try_from_op(|res| unsafe {
629 mlx_sys::mlx_linalg_solve_triangular(
630 res,
631 a.as_ref().as_ptr(),
632 b.as_ref().as_ptr(),
633 upper,
634 stream.as_ref().as_ptr(),
635 )
636 })
637}
638
639#[cfg(test)]
640mod tests {
641 use float_eq::assert_float_eq;
642
643 use crate::{
644 array,
645 ops::{eye, indexing::IndexOp, tril, triu},
646 StreamOrDevice,
647 };
648
649 use super::*;
650
651 #[test]
656 fn test_norm_no_axes() {
657 let a = Array::from_iter(0..9, &[9]) - 4;
658 let b = a.reshape(&[3, 3]).unwrap();
659
660 assert_float_eq!(
661 norm_l2(&a, None, None).unwrap().item::<f32>(),
662 7.74597,
663 abs <= 0.001
664 );
665 assert_float_eq!(
666 norm_l2(&b, None, None).unwrap().item::<f32>(),
667 7.74597,
668 abs <= 0.001
669 );
670
671 assert_float_eq!(
672 norm_matrix(&b, "fro", None, None).unwrap().item::<f32>(),
673 7.74597,
674 abs <= 0.001
675 );
676
677 assert_float_eq!(
678 norm(&a, f64::INFINITY, None, None).unwrap().item::<f32>(),
679 4.0,
680 abs <= 0.001
681 );
682 assert_float_eq!(
683 norm(&b, f64::INFINITY, None, None).unwrap().item::<f32>(),
684 9.0,
685 abs <= 0.001
686 );
687
688 assert_float_eq!(
689 norm(&a, f64::NEG_INFINITY, None, None)
690 .unwrap()
691 .item::<f32>(),
692 0.0,
693 abs <= 0.001
694 );
695 assert_float_eq!(
696 norm(&b, f64::NEG_INFINITY, None, None)
697 .unwrap()
698 .item::<f32>(),
699 2.0,
700 abs <= 0.001
701 );
702
703 assert_float_eq!(
704 norm(&a, 1.0, None, None).unwrap().item::<f32>(),
705 20.0,
706 abs <= 0.001
707 );
708 assert_float_eq!(
709 norm(&b, 1.0, None, None).unwrap().item::<f32>(),
710 7.0,
711 abs <= 0.001
712 );
713
714 assert_float_eq!(
715 norm(&a, -1.0, None, None).unwrap().item::<f32>(),
716 0.0,
717 abs <= 0.001
718 );
719 assert_float_eq!(
720 norm(&b, -1.0, None, None).unwrap().item::<f32>(),
721 6.0,
722 abs <= 0.001
723 );
724 }
725
726 #[test]
727 fn test_norm_axis() {
728 let c = Array::from_slice(&[1, 2, 3, -1, 1, 4], &[2, 3]);
729
730 let result = norm_l2(&c, &[0], None).unwrap();
731 let expected = Array::from_slice(&[1.41421, 2.23607, 5.0], &[3]);
732 assert!(result
733 .all_close(&expected, None, None, None)
734 .unwrap()
735 .item::<bool>());
736 }
737
738 #[test]
739 fn test_norm_axes() {
740 let m = Array::from_iter(0..8, &[2, 2, 2]);
741
742 let result = norm_l2(&m, &[1, 2][..], None).unwrap();
743 let expected = Array::from_slice(&[3.74166, 11.225], &[2]);
744 assert!(result
745 .all_close(&expected, None, None, None)
746 .unwrap()
747 .item::<bool>());
748 }
749
750 #[test]
751 fn test_qr() {
752 let a = Array::from_slice(&[2.0f32, 3.0, 1.0, 2.0], &[2, 2]);
753
754 let (q, r) = qr_device(&a, StreamOrDevice::cpu()).unwrap();
755
756 let q_expected = Array::from_slice(&[-0.894427, -0.447214, -0.447214, 0.894427], &[2, 2]);
757 let r_expected = Array::from_slice(&[-2.23607, -3.57771, 0.0, 0.447214], &[2, 2]);
758
759 assert!(q
760 .all_close(&q_expected, None, None, None)
761 .unwrap()
762 .item::<bool>());
763 assert!(r
764 .all_close(&r_expected, None, None, None)
765 .unwrap()
766 .item::<bool>());
767 }
768
769 #[test]
772 fn test_svd() {
773 let stream = StreamOrDevice::cpu();
775
776 let a = Array::from_f32(0.0);
778 assert!(svd_device(&a, &stream).is_err());
779
780 let a = Array::from_slice(&[0.0, 1.0], &[2]);
781 assert!(svd_device(&a, &stream).is_err());
782
783 let a = Array::from_slice(&[0, 1], &[1, 2]);
785 assert!(svd_device(&a, &stream).is_err());
786
787 }
789
790 #[test]
791 fn test_inv() {
792 let stream = StreamOrDevice::cpu();
794
795 let a = Array::from_f32(0.0);
797 assert!(inv_device(&a, &stream).is_err());
798
799 let a = Array::from_slice(&[0.0, 1.0], &[2]);
800 assert!(inv_device(&a, &stream).is_err());
801
802 let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
804 assert!(inv_device(&a, &stream).is_err());
805
806 }
808
809 #[test]
810 fn test_cholesky() {
811 let stream = StreamOrDevice::cpu();
813
814 let a = Array::from_f32(0.0);
816 assert!(cholesky_device(&a, None, &stream).is_err());
817
818 let a = Array::from_slice(&[0.0, 1.0], &[2]);
819 assert!(cholesky_device(&a, None, &stream).is_err());
820
821 let a = Array::from_slice(&[0, 1, 1, 2], &[2, 2]);
823 assert!(cholesky_device(&a, None, &stream).is_err());
824
825 let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
827 assert!(cholesky_device(&a, None, &stream).is_err());
828
829 }
831
832 #[test]
834 fn test_lu() {
835 let scalar = array!(1.0);
836 let result = lu_device(&scalar, StreamOrDevice::cpu());
837 assert!(result.is_err());
838
839 let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
841 let (p, l, u) = lu_device(&a, StreamOrDevice::cpu()).unwrap();
842 let a_rec = l.index((p, ..)).matmul(u).unwrap();
843 assert_array_all_close!(a, a_rec);
844 }
845
846 #[test]
848 fn test_lu_factor() {
849 crate::random::seed(7).unwrap();
850
851 let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[5, 5], None).unwrap();
853 let (lu, pivots) = lu_factor_device(&a, StreamOrDevice::cpu()).unwrap();
854 let shape = a.shape();
855 let n = shape[shape.len() - 1];
856
857 let pivots: Vec<u32> = pivots.as_slice().to_vec();
858 let mut perm: Vec<u32> = (0..n as u32).collect();
859 for (i, p) in pivots.iter().enumerate() {
860 perm.swap(i, *p as usize);
861 }
862
863 let l = tril(&lu, -1)
864 .and_then(|l| l.add(eye::<f32>(n, None, None)?))
865 .unwrap();
866 let u = triu(&lu, None).unwrap();
867
868 let lhs = l.matmul(&u).unwrap();
869 let perm = Array::from_slice(&perm, &[n]);
870 let rhs = a.index((perm, ..));
871 assert_array_all_close!(lhs, rhs);
872 }
873
874 #[test]
876 fn test_solve() {
877 crate::random::seed(7).unwrap();
878
879 let a = array!([[3.0f32, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]);
881 let b = array!([11.0f32, 35.0, 28.0]);
882
883 let result = solve_device(&a, &b, StreamOrDevice::cpu()).unwrap();
884 let expected = array!([1.0f32, 2.0, 3.0]);
885 assert_array_all_close!(result, expected);
886 }
887
888 #[test]
889 fn test_solve_triangular() {
890 let a = array!([[4.0f32, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]);
891 let b = array!([8.0f32, 14.0, 3.0]);
892
893 let result = solve_triangular_device(&a, &b, false, StreamOrDevice::cpu()).unwrap();
894 let expected = array!([2.0f32, 3.333_333_3, 1.533_333_3]);
895 assert_array_all_close!(result, expected);
896 }
897}