1use crate::{utils::ScalarOrArray, Array, StreamOrDevice};
2use num_traits::Pow;
3use std::{
4 iter::Product,
5 ops::{
6 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Sub, SubAssign,
7 },
8};
9
/// Implements a binary operator trait (e.g. `Add`, `Pow`) for both `Array` and `&Array`
/// by forwarding to the matching `<$c_method>_device` method on the default
/// stream/device.
///
/// - `$trait`: the operator trait to implement (`Add`, `Sub`, `Pow`, ...).
/// - `$method`: the trait's method name (`add`, `sub`, `pow`, ...).
/// - `$c_method`: the `Array` method prefix; `paste` appends `_device` to it.
///
/// The right-hand side may be anything implementing `ScalarOrArray`.
///
/// # Panics
///
/// Operator traits cannot return a `Result`, so the generated impls panic if the
/// underlying `_device` call returns an error; the panic message names the failing
/// operation so the error can be traced back to the operator that produced it.
macro_rules! impl_binary_op {
    ($trait:ident, $method:ident, $c_method:ident) => {
        impl<'a, T> $trait<T> for Array
        where
            T: ScalarOrArray<'a>,
        {
            type Output = Array;

            fn $method(self, rhs: T) -> Self::Output {
                paste::paste! {
                    self.[<$c_method _device>](rhs.into_owned_or_ref_array(), StreamOrDevice::default())
                        .expect(concat!("Array::", stringify!($c_method), "_device failed"))
                }
            }
        }

        impl<'a, 't: 'a, T> $trait<T> for &'a Array
        where
            T: ScalarOrArray<'t>,
        {
            type Output = Array;

            fn $method(self, rhs: T) -> Self::Output {
                paste::paste! {
                    self.[<$c_method _device>](rhs.into_owned_or_ref_array(), StreamOrDevice::default())
                        .expect(concat!("Array::", stringify!($c_method), "_device failed"))
                }
            }
        }
    };
}
39
/// Implements a compound-assignment operator trait (e.g. `AddAssign`) for `Array`.
///
/// Two impls are generated:
/// - for any `T: Into<Array>`, the right-hand side is first converted into an `Array`;
/// - for `&Array`, the right-hand side is passed through by reference with no conversion.
///
/// In both cases `<$c_method>_device` is evaluated on the default stream/device and
/// `*self` is replaced by the newly computed array.
///
/// # Panics
///
/// Operator traits cannot return a `Result`, so the generated impls panic if the
/// underlying `_device` call returns an error; the panic message names the failing
/// operation.
macro_rules! impl_binary_op_assign {
    ($trait:ident, $method:ident, $c_method:ident) => {
        impl<T: Into<Array>> $trait<T> for Array {
            fn $method(&mut self, rhs: T) {
                let rhs: Array = rhs.into();
                let new_array = paste::paste! {
                    self.[<$c_method _device>](&rhs, StreamOrDevice::default())
                        .expect(concat!("Array::", stringify!($c_method), "_device failed"))
                };
                *self = new_array;
            }
        }

        impl $trait<&Array> for Array {
            fn $method(&mut self, rhs: &Self) {
                let new_array = paste::paste! {
                    self.[<$c_method _device>](rhs, StreamOrDevice::default())
                        .expect(concat!("Array::", stringify!($c_method), "_device failed"))
                };
                *self = new_array;
            }
        }
    };
}
61
62impl_binary_op!(Add, add, add);
63impl_binary_op_assign!(AddAssign, add_assign, add);
64impl_binary_op!(Sub, sub, subtract);
65impl_binary_op_assign!(SubAssign, sub_assign, subtract);
66impl_binary_op!(Mul, mul, multiply);
67impl_binary_op_assign!(MulAssign, mul_assign, multiply);
68impl_binary_op!(Div, div, divide);
69impl_binary_op_assign!(DivAssign, div_assign, divide);
70impl_binary_op!(Rem, rem, remainder);
71impl_binary_op_assign!(RemAssign, rem_assign, remainder);
72impl_binary_op!(Pow, pow, power);
73
74impl Neg for &Array {
75 type Output = Array;
76 fn neg(self) -> Self::Output {
77 self.negative_device(StreamOrDevice::default()).unwrap()
78 }
79}
80impl Neg for Array {
81 type Output = Array;
82 fn neg(self) -> Self::Output {
83 self.negative_device(StreamOrDevice::default()).unwrap()
84 }
85}
86
87impl Not for &Array {
88 type Output = Array;
89 fn not(self) -> Self::Output {
90 self.logical_not_device(StreamOrDevice::default()).unwrap()
91 }
92}
93impl Not for Array {
94 type Output = Array;
95 fn not(self) -> Self::Output {
96 self.logical_not_device(StreamOrDevice::default()).unwrap()
97 }
98}
99
100impl Product<Array> for Array {
101 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
102 iter.fold(1.0.into(), |acc, x| acc * x)
103 }
104}
105
106impl<'a> Product<&'a Array> for Array {
107 fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
108 iter.fold(1.0.into(), |acc, x| acc * x)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds the `[1, 2, 3]` / `[4, 5, 6]` operand pair shared by the
    /// compound-assignment tests below.
    fn operands() -> (Array, Array) {
        let lhs = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
        let rhs = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
        (lhs, rhs)
    }

    #[test]
    fn test_add_assign() {
        let (mut lhs, rhs) = operands();
        lhs += &rhs;
        assert_eq!(lhs.as_slice::<f32>(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_sub_assign() {
        let (mut lhs, rhs) = operands();
        lhs -= &rhs;
        assert_eq!(lhs.as_slice::<f32>(), &[-3.0, -3.0, -3.0]);
    }

    #[test]
    fn test_mul_assign() {
        let (mut lhs, rhs) = operands();
        lhs *= &rhs;
        assert_eq!(lhs.as_slice::<f32>(), &[4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_div_assign() {
        let (mut lhs, rhs) = operands();
        lhs /= &rhs;
        assert_eq!(lhs.as_slice::<f32>(), &[0.25, 0.4, 0.5]);
    }
}