mlx_rs/array/operators.rs

1use crate::{utils::ScalarOrArray, Array, StreamOrDevice};
2use num_traits::Pow;
3use std::{
4    iter::Product,
5    ops::{
6        Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Sub, SubAssign,
7    },
8};
9
/// Generates the binary operator impls (`a + b`, `a - b`, ...) for [`Array`].
///
/// Arguments, in order:
/// * `$trait`    – the operator trait to implement (e.g. `Add`),
/// * `$method`   – that trait's method name (e.g. `add`),
/// * `$c_method` – the stem of the `Array` method that does the work; `paste` appends
///   `_device` to it, so `add` resolves to `add_device`.
///
/// Two impls are emitted: one where the left operand is an owned `Array` and one where it
/// is a borrowed `&Array`. In both, the right-hand side is any `ScalarOrArray`. It is
/// converted with `into_owned_or_ref_array`, and the operation is issued on
/// `StreamOrDevice::default()`. Operators cannot return `Result`, so an error from the
/// underlying call panics through `unwrap`.
macro_rules! impl_binary_op {
    ($trait:ident, $method:ident, $c_method:ident) => {
        paste::paste! {
            // Owned left operand: `array <op> rhs`.
            impl<'a, T> $trait<T> for Array
            where
                T: ScalarOrArray<'a>,
            {
                type Output = Array;

                fn $method(self, rhs: T) -> Self::Output {
                    let rhs = rhs.into_owned_or_ref_array();
                    self.[<$c_method _device>](rhs, StreamOrDevice::default())
                        .unwrap()
                }
            }

            // Borrowed left operand: `&array <op> rhs`. The right-hand side's lifetime `'t`
            // must outlive the borrow of `self`.
            impl<'a, 't: 'a, T> $trait<T> for &'a Array
            where
                T: ScalarOrArray<'t>,
            {
                type Output = Array;

                fn $method(self, rhs: T) -> Self::Output {
                    let rhs = rhs.into_owned_or_ref_array();
                    self.[<$c_method _device>](rhs, StreamOrDevice::default())
                        .unwrap()
                }
            }
        }
    };
}
39
/// Generates the compound-assignment operator impls (`+=`, `-=`, ...) for [`Array`].
///
/// Arguments, in order:
/// * `$trait`    – the assignment trait to implement (e.g. `AddAssign`),
/// * `$method`   – that trait's method name (e.g. `add_assign`),
/// * `$c_method` – the stem of the `Array` method that does the work; `paste` appends
///   `_device` to it.
///
/// Two impls are emitted:
/// * one for any right-hand side convertible into an owned `Array` (`T: Into<Array>`),
///   which is converted first and then passed by reference, and
/// * one for a borrowed `&Array`, which is passed through unchanged.
///
/// Both run the operation on `StreamOrDevice::default()` and overwrite `*self` with the
/// freshly computed array. An error from the underlying call panics through `unwrap`.
macro_rules! impl_binary_op_assign {
    ($trait:ident, $method:ident, $c_method:ident) => {
        paste::paste! {
            impl<T: Into<Array>> $trait<T> for Array {
                fn $method(&mut self, rhs: T) {
                    let rhs: Array = rhs.into();
                    *self = self
                        .[<$c_method _device>](&rhs, StreamOrDevice::default())
                        .unwrap();
                }
            }

            impl $trait<&Array> for Array {
                fn $method(&mut self, rhs: &Self) {
                    *self = self
                        .[<$c_method _device>](rhs, StreamOrDevice::default())
                        .unwrap();
                }
            }
        }
    };
}
61
62impl_binary_op!(Add, add, add);
63impl_binary_op_assign!(AddAssign, add_assign, add);
64impl_binary_op!(Sub, sub, subtract);
65impl_binary_op_assign!(SubAssign, sub_assign, subtract);
66impl_binary_op!(Mul, mul, multiply);
67impl_binary_op_assign!(MulAssign, mul_assign, multiply);
68impl_binary_op!(Div, div, divide);
69impl_binary_op_assign!(DivAssign, div_assign, divide);
70impl_binary_op!(Rem, rem, remainder);
71impl_binary_op_assign!(RemAssign, rem_assign, remainder);
72impl_binary_op!(Pow, pow, power);
73
74impl Neg for &Array {
75    type Output = Array;
76    fn neg(self) -> Self::Output {
77        self.negative_device(StreamOrDevice::default()).unwrap()
78    }
79}
80impl Neg for Array {
81    type Output = Array;
82    fn neg(self) -> Self::Output {
83        self.negative_device(StreamOrDevice::default()).unwrap()
84    }
85}
86
87impl Not for &Array {
88    type Output = Array;
89    fn not(self) -> Self::Output {
90        self.logical_not_device(StreamOrDevice::default()).unwrap()
91    }
92}
93impl Not for Array {
94    type Output = Array;
95    fn not(self) -> Self::Output {
96        self.logical_not_device(StreamOrDevice::default()).unwrap()
97    }
98}
99
100impl Product<Array> for Array {
101    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
102        iter.fold(1.0.into(), |acc, x| acc * x)
103    }
104}
105
106impl<'a> Product<&'a Array> for Array {
107    fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
108        iter.fold(1.0.into(), |acc, x| acc * x)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_add_assign() {
        let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
        a += &b;

        assert_eq!(a.as_slice::<f32>(), &[5.0, 7.0, 9.0]);
    }

    // Exercises the `T: Into<Array>` assignment impl (owned right-hand side) rather than
    // the `&Array` one.
    #[test]
    fn test_add_assign_owned_rhs() {
        let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
        a += b;

        assert_eq!(a.as_slice::<f32>(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_sub_assign() {
        let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
        a -= &b;

        assert_eq!(a.as_slice::<f32>(), &[-3.0, -3.0, -3.0]);
    }

    #[test]
    fn test_mul_assign() {
        let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
        a *= &b;

        assert_eq!(a.as_slice::<f32>(), &[4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_div_assign() {
        let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
        a /= &b;

        assert_eq!(a.as_slice::<f32>(), &[0.25, 0.4, 0.5]);
    }

    #[test]
    fn test_rem_assign() {
        let mut a = Array::from_slice(&[5.0, 7.0, 9.0], &[3]);
        let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
        a %= &b;

        assert_eq!(a.as_slice::<f32>(), &[1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_neg() {
        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
        let b = -&a;

        assert_eq!(b.as_slice::<f32>(), &[-1.0, -2.0, -3.0]);
    }

    // Covers both `Product<&Array>` and `Product<Array>`.
    #[test]
    fn test_product() {
        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
        let arrays = vec![a, b];

        let by_ref: Array = arrays.iter().product();
        assert_eq!(by_ref.as_slice::<f32>(), &[4.0, 10.0, 18.0]);

        let by_value: Array = arrays.into_iter().product();
        assert_eq!(by_value.as_slice::<f32>(), &[4.0, 10.0, 18.0]);
    }
}