mlx_rs/ops/
arithmetic.rs

1use crate::array::Array;
2use crate::error::Result;
3use crate::sealed::Sealed;
4
5use crate::utils::guard::Guarded;
6use crate::utils::{IntoOption, ScalarOrArray, VectorArray};
7use crate::Stream;
8use mlx_internal_macros::{default_device, generate_macro};
9use smallvec::SmallVec;
10
11impl Array {
12    /// Element-wise absolute value.
13    ///
14    /// # Example
15    ///
16    /// ```rust
17    /// use mlx_rs::Array;
18    /// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]);
19    /// let mut result = array.abs().unwrap();
20    ///
21    /// let data: &[i32] = result.as_slice();
22    /// // data == [1, 2, 3, 4, 5]
23    /// ```
24    #[default_device]
25    pub fn abs_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
26        Array::try_from_op(|res| unsafe {
27            mlx_sys::mlx_abs(res, self.as_ptr(), stream.as_ref().as_ptr())
28        })
29    }
30
31    /// Element-wise addition returning an error if arrays are not broadcastable.
32    ///
33    /// Add two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
34    ///
35    /// # Params
36    ///
37    /// - other: array to add
38    ///
39    /// # Example
40    ///
41    /// ```rust
42    /// use mlx_rs::Array;
43    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
44    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
45    /// let mut c = a.add(&b).unwrap();
46    ///
47    /// let c_data: &[f32] = c.as_slice();
48    /// // c_data == [5.0, 7.0, 9.0]
49    /// ```
50    #[default_device]
51    pub fn add_device(
52        &self,
53        other: impl AsRef<Array>,
54        stream: impl AsRef<Stream>,
55    ) -> Result<Array> {
56        Array::try_from_op(|res| unsafe {
57            mlx_sys::mlx_add(
58                res,
59                self.as_ptr(),
60                other.as_ref().as_ptr(),
61                stream.as_ref().as_ptr(),
62            )
63        })
64    }
65
66    /// Element-wise subtraction returning an error if arrays are not broadcastable.
67    ///
68    /// Subtract two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
69    ///
70    /// # Params
71    ///
72    /// - other: array to subtract
73    ///
74    /// # Example
75    ///
76    /// ```rust
77    /// use mlx_rs::Array;
78    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
79    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
80    /// let mut c = a.subtract(&b).unwrap();
81    ///
82    /// let c_data: &[f32] = c.as_slice();
83    /// // c_data == [-3.0, -3.0, -3.0]
84    /// ```
85    #[default_device]
86    pub fn subtract_device(
87        &self,
88        other: impl AsRef<Array>,
89        stream: impl AsRef<Stream>,
90    ) -> Result<Array> {
91        Array::try_from_op(|res| unsafe {
92            mlx_sys::mlx_subtract(
93                res,
94                self.as_ptr(),
95                other.as_ref().as_ptr(),
96                stream.as_ref().as_ptr(),
97            )
98        })
99    }
100
101    /// Unary element-wise negation. Returns an error if the array is of type bool.
102    ///
103    /// Negate the values in the array.
104    ///
105    /// # Example
106    ///
107    /// ```rust
108    /// use mlx_rs::Array;
109    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
110    /// let mut b = a.negative().unwrap();
111    ///
112    /// let b_data: &[f32] = b.as_slice();
113    /// // b_data == [-1.0, -2.0, -3.0]
114    /// ```
115    #[default_device]
116    pub fn negative_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
117        Array::try_from_op(|res| unsafe {
118            mlx_sys::mlx_negative(res, self.as_ptr(), stream.as_ref().as_ptr())
119        })
120    }
121
122    /// Element-wise multiplication returning an error if arrays are not broadcastable.
123    ///
124    /// Multiply two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
125    ///
126    /// # Example
127    ///
128    /// ```rust
129    /// use mlx_rs::Array;
130    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
131    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
132    /// let mut c = a.multiply(&b).unwrap();
133    ///
134    /// let c_data: &[f32] = c.as_slice();
135    /// // c_data == [4.0, 10.0, 18.0]
136    /// ```
137    #[default_device]
138    pub fn multiply_device(
139        &self,
140        other: impl AsRef<Array>,
141        stream: impl AsRef<Stream>,
142    ) -> Result<Array> {
143        Array::try_from_op(|res| unsafe {
144            mlx_sys::mlx_multiply(
145                res,
146                self.as_ptr(),
147                other.as_ref().as_ptr(),
148                stream.as_ref().as_ptr(),
149            )
150        })
151    }
152
153    /// Replace NaN and Inf values with finite numbers.
154    ///
155    /// # Params
156    /// - nan: value to replace NaN with
157    /// - posInf: value to replace positive inifinites with.  If not specified will use
158    ///     the largest finite value for the given dtype.
159    /// - negInf: value to replace negative inifinites with.  If not specified will use
160    ///     the negative of the largest finite value for the given dtype.
161    /// - stream: stream or device to evaluate on
162    #[default_device]
163    pub fn nan_to_num_device(
164        &self,
165        nan: impl IntoOption<f32>,
166        pos_inf: impl IntoOption<f32>,
167        neg_inf: impl IntoOption<f32>,
168        stream: impl AsRef<Stream>,
169    ) -> Result<Array> {
170        let pos_inf = pos_inf.into_option();
171        let neg_inf = neg_inf.into_option();
172
173        let pos_inf = mlx_sys::mlx_optional_float {
174            value: pos_inf.unwrap_or(0.0),
175            has_value: pos_inf.is_some(),
176        };
177        let neg_inf = mlx_sys::mlx_optional_float {
178            value: neg_inf.unwrap_or(0.0),
179            has_value: neg_inf.is_some(),
180        };
181
182        Array::try_from_op(|res| unsafe {
183            mlx_sys::mlx_nan_to_num(
184                res,
185                self.as_ptr(),
186                nan.into_option().unwrap_or(0.),
187                pos_inf,
188                neg_inf,
189                stream.as_ref().as_ptr(),
190            )
191        })
192    }
193
194    /// Element-wise division returning an error if arrays are not broadcastable.
195    ///
196    /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
197    ///
198    /// # Params
199    ///
200    /// - other: array to divide
201    ///
202    /// # Example
203    ///
204    /// ```rust
205    /// use mlx_rs::Array;
206    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
207    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
208    /// let mut c = a.divide(&b).unwrap();
209    ///
210    /// let c_data: &[f32] = c.as_slice();
211    /// // c_data == [0.25, 0.4, 0.5]
212    /// ```
213    #[default_device]
214    pub fn divide_device(
215        &self,
216        other: impl AsRef<Array>,
217        stream: impl AsRef<Stream>,
218    ) -> Result<Array> {
219        Array::try_from_op(|res| unsafe {
220            mlx_sys::mlx_divide(
221                res,
222                self.as_ptr(),
223                other.as_ref().as_ptr(),
224                stream.as_ref().as_ptr(),
225            )
226        })
227    }
228
229    /// Element-wise power operation returning an error if arrays are not broadcastable if they have different shapes.
230    ///
231    /// Raise the elements of the array to the power of the elements of another array.
232    ///
233    /// # Params
234    ///
235    /// - other: array to raise to the power of
236    ///
237    /// # Example
238    ///
239    /// ```rust
240    /// use mlx_rs::Array;
241    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
242    /// let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
243    /// let mut c = a.power(&b).unwrap();
244    ///
245    /// let c_data: &[f32] = c.as_slice();
246    /// // c_data == [1.0, 8.0, 81.0]
247    /// ```
248    #[default_device]
249    pub fn power_device(
250        &self,
251        other: impl AsRef<Array>,
252        stream: impl AsRef<Stream>,
253    ) -> Result<Array> {
254        Array::try_from_op(|res| unsafe {
255            mlx_sys::mlx_power(
256                res,
257                self.as_ptr(),
258                other.as_ref().as_ptr(),
259                stream.as_ref().as_ptr(),
260            )
261        })
262    }
263
264    /// Element-wise remainder of division returning an error if arrays are not broadcastable.
265    ///
266    /// Computes the remainder of dividing `lhs` with `rhs` with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
267    ///
268    /// # Params
269    ///
270    /// - other: array to divide
271    ///
272    /// # Example
273    ///
274    /// ```rust
275    /// use mlx_rs::Array;
276    /// let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
277    /// let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
278    /// let mut c = a.remainder(&b).unwrap();
279    ///
280    /// let c_data: &[f32] = c.as_slice();
281    /// // c_data == [1.0, 3.0, 2.0]
282    /// ```
283    #[default_device]
284    pub fn remainder_device(
285        &self,
286        other: impl AsRef<Array>,
287        stream: impl AsRef<Stream>,
288    ) -> Result<Array> {
289        Array::try_from_op(|res| unsafe {
290            mlx_sys::mlx_remainder(
291                res,
292                self.as_ptr(),
293                other.as_ref().as_ptr(),
294                stream.as_ref().as_ptr(),
295            )
296        })
297    }
298
299    /// Element-wise square root
300    ///
301    /// # Example
302    ///
303    /// ```rust
304    /// use mlx_rs::Array;
305    /// let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
306    /// let mut b = a.sqrt().unwrap();
307    ///
308    /// let b_data: &[f32] = b.as_slice();
309    /// // b_data == [1.0, 2.0, 3.0]
310    /// ```
311    #[default_device]
312    pub fn sqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
313        Array::try_from_op(|res| unsafe {
314            mlx_sys::mlx_sqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
315        })
316    }
317
318    /// Element-wise cosine
319    ///
320    /// # Example
321    ///
322    /// ```rust
323    /// use mlx_rs::Array;
324    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
325    /// let mut b = a.cos().unwrap();
326    ///
327    /// let b_data: &[f32] = b.as_slice();
328    /// // b_data == [1.0, 0.54030234, -0.41614687]
329    /// ```
330    #[default_device]
331    pub fn cos_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
332        Array::try_from_op(|res| unsafe {
333            mlx_sys::mlx_cos(res, self.as_ptr(), stream.as_ref().as_ptr())
334        })
335    }
336
337    /// Element-wise exponential.
338    ///
339    /// # Example
340    ///
341    /// ```rust
342    /// use mlx_rs::Array;
343    ///
344    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
345    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
346    /// let mut b = a.exp().unwrap();
347    ///
348    /// let b_data: &[f32] = b.as_slice();
349    /// // b_data == [1.0, 2.7182817, 7.389056]
350    /// ```
351    #[default_device]
352    pub fn exp_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
353        Array::try_from_op(|res| unsafe {
354            mlx_sys::mlx_exp(res, self.as_ptr(), stream.as_ref().as_ptr())
355        })
356    }
357
358    /// Element-wise floor returning an error if the array is of type complex64.
359    ///
360    /// # Example
361    ///
362    /// ```rust
363    /// use mlx_rs::Array;
364    /// let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
365    /// let mut b = a.floor().unwrap();
366    ///
367    /// let b_data: &[f32] = b.as_slice();
368    /// // b_data == [0.0, 1.0, 2.0]
369    /// ```
370    #[default_device]
371    pub fn floor_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
372        Array::try_from_op(|res| unsafe {
373            mlx_sys::mlx_floor(res, self.as_ptr(), stream.as_ref().as_ptr())
374        })
375    }
376
377    /// Element-wise integer division returning an error if arrays are not broadcastable.
378    ///
379    /// Divide two arrays with
380    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
381    ///
382    /// If either array is a floating point type then it is equivalent to calling [`Array::floor()`]
383    /// after `/`.
384    ///
385    /// # Params
386    ///
387    /// - other: array to divide
388    ///
389    /// # Example
390    ///
391    /// ```rust
392    /// use mlx_rs::Array;
393    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
394    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
395    /// let mut c = a.floor_divide(&b).unwrap();
396    ///
397    /// let c_data: &[f32] = c.as_slice();
398    /// // c_data == [0.25, 0.4, 0.5]
399    /// ```
400    #[default_device]
401    pub fn floor_divide_device(
402        &self,
403        other: impl AsRef<Array>,
404        stream: impl AsRef<Stream>,
405    ) -> Result<Array> {
406        Array::try_from_op(|res| unsafe {
407            mlx_sys::mlx_floor_divide(
408                res,
409                self.as_ptr(),
410                other.as_ref().as_ptr(),
411                stream.as_ref().as_ptr(),
412            )
413        })
414    }
415
416    /// Return a boolean array indicating which elements are NaN.
417    ///
418    /// # Params
419    /// - stream: stream or device to evaluate on
420    #[default_device]
421    pub fn is_nan_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
422        Array::try_from_op(|res| unsafe {
423            mlx_sys::mlx_isnan(res, self.as_ptr(), stream.as_ref().as_ptr())
424        })
425    }
426
427    /// Return a boolean array indicating which elements are infinity.
428    ///
429    /// # Params
430    /// - stream: stream or device to evaluate on
431    #[default_device]
432    pub fn is_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
433        Array::try_from_op(|res| unsafe {
434            mlx_sys::mlx_isinf(res, self.as_ptr(), stream.as_ref().as_ptr())
435        })
436    }
437
438    /// Return a boolean array indicating which elements are finite.
439    ///
440    /// # Params
441    /// - stream: stream or device to evaluate on
442    #[default_device]
443    pub fn is_finite_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
444        Array::try_from_op(|res| unsafe {
445            mlx_sys::mlx_isfinite(res, self.as_ptr(), stream.as_ref().as_ptr())
446        })
447    }
448
449    /// Return a boolean array indicating which elements are negative infinity.
450    ///
451    /// # Params
452    /// - stream: stream or device to evaluate on
453    #[default_device]
454    pub fn is_neg_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
455        Array::try_from_op(|res| unsafe {
456            mlx_sys::mlx_isneginf(res, self.as_ptr(), stream.as_ref().as_ptr())
457        })
458    }
459
460    /// Return a boolean array indicating which elements are positive infinity.
461    ///
462    /// # Params
463    /// - stream: stream or device to evaluate on
464    #[default_device]
465    pub fn is_pos_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
466        Array::try_from_op(|res| unsafe {
467            mlx_sys::mlx_isposinf(res, self.as_ptr(), stream.as_ref().as_ptr())
468        })
469    }
470
471    /// Element-wise natural logarithm.
472    ///
473    /// # Example
474    ///
475    /// ```rust
476    /// use mlx_rs::Array;
477    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
478    /// let mut b = a.log().unwrap();
479    ///
480    /// let b_data: &[f32] = b.as_slice();
481    /// // b_data == [0.0, 0.6931472, 1.0986123]
482    /// ```
483    #[default_device]
484    pub fn log_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
485        Array::try_from_op(|res| unsafe {
486            mlx_sys::mlx_log(res, self.as_ptr(), stream.as_ref().as_ptr())
487        })
488    }
489
490    /// Element-wise base-2 logarithm.
491    ///
492    /// # Example
493    ///
494    /// ```rust
495    /// use mlx_rs::Array;
496    /// let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
497    /// let mut b = a.log2().unwrap();
498    ///
499    /// let b_data: &[f32] = b.as_slice();
500    /// // b_data == [0.0, 1.0, 2.0, 3.0]
501    /// ```
502    #[default_device]
503    pub fn log2_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
504        Array::try_from_op(|res| unsafe {
505            mlx_sys::mlx_log2(res, self.as_ptr(), stream.as_ref().as_ptr())
506        })
507    }
508
509    /// Element-wise base-10 logarithm.
510    ///
511    /// # Example
512    ///
513    /// ```rust
514    /// use mlx_rs::Array;
515    /// let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
516    /// let mut b = a.log10().unwrap();
517    ///
518    /// let b_data: &[f32] = b.as_slice();
519    /// // b_data == [0.0, 1.0, 2.0]
520    /// ```
521    #[default_device]
522    pub fn log10_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
523        Array::try_from_op(|res| unsafe {
524            mlx_sys::mlx_log10(res, self.as_ptr(), stream.as_ref().as_ptr())
525        })
526    }
527
528    /// Element-wise natural log of one plus the array.
529    ///
530    /// # Example
531    ///
532    /// ```rust
533    /// use mlx_rs::Array;
534    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
535    /// let mut b = a.log1p().unwrap();
536    ///
537    /// let b_data: &[f32] = b.as_slice();
538    /// // b_data == [0.6931472, 1.0986123, 1.3862944]
539    /// ```
540    #[default_device]
541    pub fn log1p_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
542        Array::try_from_op(|res| unsafe {
543            mlx_sys::mlx_log1p(res, self.as_ptr(), stream.as_ref().as_ptr())
544        })
545    }
546
547    /// Matrix multiplication returning an error if inputs are not valid.
548    ///
549    /// Perform the (possibly batched) matrix multiplication of two arrays. This function supports
550    /// broadcasting for arrays with more than two dimensions.
551    ///
552    /// - If the first array is 1-D then a 1 is prepended to its shape to make it
553    ///   a matrix. Similarly, if the second array is 1-D then a 1 is appended to its
554    ///   shape to make it a matrix. In either case the singleton dimension is removed
555    ///   from the result.
556    /// - A batched matrix multiplication is performed if the arrays have more than
557    ///   2 dimensions.  The matrix dimensions for the matrix product are the last
558    ///   two dimensions of each input.
559    /// - All but the last two dimensions of each input are broadcast with one another using
560    ///   standard [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
561    ///
562    /// # Params
563    ///
564    /// - other: array to multiply
565    ///
566    /// # Example
567    ///
568    /// ```rust
569    /// use mlx_rs::Array;
570    /// let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
571    /// let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
572    ///
573    /// // produces a [2, 3] result
574    /// let mut c = a.matmul(&b);
575    /// ```
576    #[default_device]
577    pub fn matmul_device(
578        &self,
579        other: impl AsRef<Array>,
580        stream: impl AsRef<Stream>,
581    ) -> Result<Array> {
582        Array::try_from_op(|res| unsafe {
583            mlx_sys::mlx_matmul(
584                res,
585                self.as_ptr(),
586                other.as_ref().as_ptr(),
587                stream.as_ref().as_ptr(),
588            )
589        })
590    }
591
592    /// Element-wise reciprocal.
593    ///
594    /// # Example
595    ///
596    /// ```rust
597    /// use mlx_rs::Array;
598    /// let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
599    /// let mut b = a.reciprocal().unwrap();
600    ///
601    /// let b_data: &[f32] = b.as_slice();
602    /// // b_data == [1.0, 0.5, 0.25]
603    /// ```
604    #[default_device]
605    pub fn reciprocal_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
606        Array::try_from_op(|res| unsafe {
607            mlx_sys::mlx_reciprocal(res, self.as_ptr(), stream.as_ref().as_ptr())
608        })
609    }
610
611    /// Round to the given number of decimals.
612    ///
613    /// # Params
614    ///
615    /// - decimals: number of decimals to round to - default is 0 if not provided
616    #[default_device]
617    pub fn round_device(
618        &self,
619        decimals: impl Into<Option<i32>>,
620        stream: impl AsRef<Stream>,
621    ) -> Result<Array> {
622        Array::try_from_op(|res| unsafe {
623            mlx_sys::mlx_round(
624                res,
625                self.as_ptr(),
626                decimals.into().unwrap_or(0),
627                stream.as_ref().as_ptr(),
628            )
629        })
630    }
631
632    /// Element-wise reciprocal and square root.
633    #[default_device]
634    pub fn rsqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
635        Array::try_from_op(|res| unsafe {
636            mlx_sys::mlx_rsqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
637        })
638    }
639
640    /// Element-wise sine.
641    #[default_device]
642    pub fn sin_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
643        Array::try_from_op(|res| unsafe {
644            mlx_sys::mlx_sin(res, self.as_ptr(), stream.as_ref().as_ptr())
645        })
646    }
647
648    /// Element-wise square.
649    #[default_device]
650    pub fn square_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
651        Array::try_from_op(|res| unsafe {
652            mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
653        })
654    }
655}
656
657/// Element-wise absolute value.
658///
659/// # Example
660///
661/// ```rust
662/// use mlx_rs::{Array, ops};
663///
664/// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]);
665/// let result = ops::abs(&array).unwrap();
666/// ```
667#[generate_macro]
668#[default_device]
669pub fn abs_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
670    a.as_ref().abs_device(stream)
671}
672
673/// Element-wise inverse cosine.
674#[generate_macro]
675#[default_device]
676pub fn acos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
677    Array::try_from_op(|res| unsafe {
678        mlx_sys::mlx_arccos(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
679    })
680}
681
682/// Element-wise inverse hyperbolic cosine.
683#[generate_macro]
684#[default_device]
685pub fn acosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
686    Array::try_from_op(|res| unsafe {
687        mlx_sys::mlx_arccosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
688    })
689}
690
691/// See [`Array::add`].
692#[generate_macro]
693#[default_device]
694pub fn add_device(
695    lhs: impl AsRef<Array>,
696    rhs: impl AsRef<Array>,
697    #[optional] stream: impl AsRef<Stream>,
698) -> Result<Array> {
699    lhs.as_ref().add_device(rhs, stream)
700}
701
702/// Element-wise inverse sine.
703#[generate_macro]
704#[default_device]
705pub fn asin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
706    Array::try_from_op(|res| unsafe {
707        mlx_sys::mlx_arcsin(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
708    })
709}
710
711/// Element-wise inverse hyperbolic sine.
712#[generate_macro]
713#[default_device]
714pub fn asinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
715    Array::try_from_op(|res| unsafe {
716        mlx_sys::mlx_arcsinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
717    })
718}
719
720/// Element-wise inverse tangent.
721#[generate_macro]
722#[default_device]
723pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
724    Array::try_from_op(|res| unsafe {
725        mlx_sys::mlx_arctan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
726    })
727}
728
729/// Element-wise inverse hyperbolic tangent.
730#[generate_macro]
731#[default_device]
732pub fn atanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
733    Array::try_from_op(|res| unsafe {
734        mlx_sys::mlx_arctanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
735    })
736}
737
738/// Element-wise ceiling.
739#[generate_macro]
740#[default_device]
741pub fn ceil_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
742    Array::try_from_op(|res| unsafe {
743        mlx_sys::mlx_ceil(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
744    })
745}
746
747/// A custom trait for the bound of the clip operation.
748///
749/// This trait is only implemented for tuples of the form `(Min, Max)`, `(Min, ())`, and `((),
750/// Max)`. The `Min` and `Max` types must implement the `ScalarOrArray` trait.
751pub trait ClipBound<'min, 'max>: Sealed {
752    /// Convert the bound into a tuple of optional minimum and maximum values.
753    fn into_min_max(
754        self,
755    ) -> (
756        Option<impl ScalarOrArray<'min>>,
757        Option<impl ScalarOrArray<'max>>,
758    );
759}
760
761impl<'min, Min> ClipBound<'min, 'min> for (Min, ())
762where
763    Min: ScalarOrArray<'min> + Sealed,
764{
765    fn into_min_max(
766        self,
767    ) -> (
768        Option<impl ScalarOrArray<'min>>,
769        Option<impl ScalarOrArray<'min>>,
770    ) {
771        (Some(self.0), Option::<Min>::None)
772    }
773}
774
775impl<'max, Max> ClipBound<'max, 'max> for ((), Max)
776where
777    Max: ScalarOrArray<'max> + Sealed,
778{
779    fn into_min_max(
780        self,
781    ) -> (
782        Option<impl ScalarOrArray<'max>>,
783        Option<impl ScalarOrArray<'max>>,
784    ) {
785        (Option::<Max>::None, Some(self.1))
786    }
787}
788
789impl<'min, 'max, Min, Max> ClipBound<'min, 'max> for (Min, Max)
790where
791    Min: ScalarOrArray<'min> + Sealed,
792    Max: ScalarOrArray<'max> + Sealed,
793{
794    fn into_min_max(
795        self,
796    ) -> (
797        Option<impl ScalarOrArray<'min>>,
798        Option<impl ScalarOrArray<'max>>,
799    ) {
800        (Some(self.0), Some(self.1))
801    }
802}
803
804/// Clip the values of the array between the given minimum and maximum.
805///
806/// If either `a_min` or `a_max` are None, then corresponding edge is ignored. At least one of
807/// `a_min` and `a_max` cannot be `None`. The input `a` and the limits must broadcast with one
808/// another.
809///
810/// # Params
811///
812/// - `a`: Input array.
813/// - `bound`: minimum and/or maximum values to clip the array to.
814///
815/// # Example
816///
817/// ```rust
818/// use mlx_rs::{Array, ops::clip, array};
819///
820/// let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
821/// let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);
822/// let clipped = clip(&a, (2.0, 6.0)).unwrap();
823/// assert_eq!(clipped, expected);
824/// ```
825#[generate_macro]
826#[default_device]
827pub fn clip_device<'min, 'max>(
828    a: impl AsRef<Array>,
829    bound: impl ClipBound<'min, 'max>,
830    #[optional] stream: impl AsRef<Stream>,
831) -> Result<Array> {
832    let (a_min, a_max) = bound.into_min_max();
833
834    // This is needed to keep the lifetime of the min/max arrays in scope.
835    let a_min = a_min.map(|min| min.into_owned_or_ref_array());
836    let a_max = a_max.map(|max| max.into_owned_or_ref_array());
837
838    unsafe {
839        let min_ptr = match &a_min {
840            Some(a_min) => a_min.as_ref().as_ptr(),
841            None => mlx_sys::mlx_array_new(),
842        };
843        let max_ptr = match &a_max {
844            Some(a_max) => a_max.as_ref().as_ptr(),
845            None => mlx_sys::mlx_array_new(),
846        };
847
848        Array::try_from_op(|res| {
849            mlx_sys::mlx_clip(
850                res,
851                a.as_ref().as_ptr(),
852                min_ptr,
853                max_ptr,
854                stream.as_ref().as_ptr(),
855            )
856        })
857    }
858}
859
860/// Element-wise cosine.
861#[generate_macro]
862#[default_device]
863pub fn cos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
864    a.as_ref().cos_device(stream)
865}
866
867/// Element-wise hyperbolic cosine.
868#[generate_macro]
869#[default_device]
870pub fn cosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
871    Array::try_from_op(|res| unsafe {
872        mlx_sys::mlx_cosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
873    })
874}
875
876/// Convert angles from radians to degrees.
877#[generate_macro]
878#[default_device]
879pub fn degrees_device(
880    a: impl AsRef<Array>,
881    #[optional] stream: impl AsRef<Stream>,
882) -> Result<Array> {
883    Array::try_from_op(|res| unsafe {
884        mlx_sys::mlx_degrees(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
885    })
886}
887
888/// See [`Array::divide`].
889#[generate_macro]
890#[default_device]
891pub fn divide_device(
892    a: impl AsRef<Array>,
893    b: impl AsRef<Array>,
894    #[optional] stream: impl AsRef<Stream>,
895) -> Result<Array> {
896    a.as_ref().divide_device(b, stream)
897}
898
899/// Element-wise quotient and remainder.
900///
901/// The fuction `divmod(a, b)` is equivalent to but faster than `(a // b, a % b)`. The function uses
902/// numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
903///
904/// Returns Ok((quotient, remainder)) if the operation was successful.
905#[generate_macro]
906#[default_device]
907pub fn divmod_device(
908    a: impl AsRef<Array>,
909    b: impl AsRef<Array>,
910    #[optional] stream: impl AsRef<Stream>,
911) -> Result<(Array, Array)> {
912    let a_ptr = a.as_ref().as_ptr();
913    let b_ptr = b.as_ref().as_ptr();
914
915    let vec = VectorArray::try_from_op(|res| unsafe {
916        mlx_sys::mlx_divmod(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
917    })?;
918
919    let vals: SmallVec<[_; 2]> = vec.try_into_values()?;
920    let mut iter = vals.into_iter();
921    let quotient = iter.next().unwrap();
922    let remainder = iter.next().unwrap();
923
924    Ok((quotient, remainder))
925}
926
927/// Element-wise error function.
928#[generate_macro]
929#[default_device]
930pub fn erf_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
931    Array::try_from_op(|res| unsafe {
932        mlx_sys::mlx_erf(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
933    })
934}
935
936/// Element-wise inverse error function.
937#[generate_macro]
938#[default_device]
939pub fn erfinv_device(
940    a: impl AsRef<Array>,
941    #[optional] stream: impl AsRef<Stream>,
942) -> Result<Array> {
943    Array::try_from_op(|res| unsafe {
944        mlx_sys::mlx_erfinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
945    })
946}
947
948/// See [`Array::exp`].
949#[generate_macro]
950#[default_device]
951pub fn exp_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
952    a.as_ref().exp_device(stream)
953}
954
955/// Element-wise exponential minus 1.
956#[generate_macro]
957#[default_device]
958pub fn expm1_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
959    Array::try_from_op(|res| unsafe {
960        mlx_sys::mlx_expm1(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
961    })
962}
963
964/// See [`Array::floor`].
965#[generate_macro]
966#[default_device]
967pub fn floor_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
968    a.as_ref().floor_device(stream)
969}
970
971/// See [`Array::floor_divide`].
972#[generate_macro]
973#[default_device]
974pub fn floor_divide_device(
975    a: impl AsRef<Array>,
976    other: impl AsRef<Array>,
977    #[optional] stream: impl AsRef<Stream>,
978) -> Result<Array> {
979    a.as_ref().floor_divide_device(other, stream)
980}
981
982/// See [`Array::log`].
983#[generate_macro]
984#[default_device]
985pub fn log_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
986    a.as_ref().log_device(stream)
987}
988
989/// See [`Array::log10`].
990#[generate_macro]
991#[default_device]
992pub fn log10_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
993    a.as_ref().log10_device(stream)
994}
995
996/// See [`Array::log1p`].
997#[generate_macro]
998#[default_device]
999pub fn log1p_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1000    a.as_ref().log1p_device(stream)
1001}
1002
1003/// See [`Array::log2`].
1004#[generate_macro]
1005#[default_device]
1006pub fn log2_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1007    a.as_ref().log2_device(stream)
1008}
1009
1010/// Element-wise log-add-exp.
1011///
1012/// This is a numerically stable log-add-exp of two arrays with numpy-style broadcasting semantics.
1013/// Either or both input arrays can also be scalars.
1014///
1015/// The computation is is a numerically stable version of `log(exp(a) + exp(b))`.
1016#[generate_macro]
1017#[default_device]
1018pub fn logaddexp_device(
1019    a: impl AsRef<Array>,
1020    b: impl AsRef<Array>,
1021    #[optional] stream: impl AsRef<Stream>,
1022) -> Result<Array> {
1023    let a_ptr = a.as_ref().as_ptr();
1024    let b_ptr = b.as_ref().as_ptr();
1025
1026    Array::try_from_op(|res| unsafe {
1027        mlx_sys::mlx_logaddexp(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1028    })
1029}
1030
1031/// See [`Array::matmul`].
1032#[generate_macro]
1033#[default_device]
1034pub fn matmul_device(
1035    a: impl AsRef<Array>,
1036    b: impl AsRef<Array>,
1037    #[optional] stream: impl AsRef<Stream>,
1038) -> Result<Array> {
1039    a.as_ref().matmul_device(b, stream)
1040}
1041
1042/// Element-wise maximum.
1043///
1044/// Take the element-wise max of two arrays with numpy-style broadcasting semantics. Either or both
1045/// input arrays can also be scalars.
1046#[generate_macro]
1047#[default_device]
1048pub fn maximum_device(
1049    a: impl AsRef<Array>,
1050    b: impl AsRef<Array>,
1051    #[optional] stream: impl AsRef<Stream>,
1052) -> Result<Array> {
1053    let a_ptr = a.as_ref().as_ptr();
1054    let b_ptr = b.as_ref().as_ptr();
1055
1056    Array::try_from_op(|res| unsafe {
1057        mlx_sys::mlx_maximum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1058    })
1059}
1060
1061/// Element-wise minimum.
1062///
1063/// Take the element-wise min of two arrays with numpy-style broadcasting semantics. Either or both
1064/// input arrays can also be scalars.
1065#[generate_macro]
1066#[default_device]
1067pub fn minimum_device(
1068    a: impl AsRef<Array>,
1069    b: impl AsRef<Array>,
1070    #[optional] stream: impl AsRef<Stream>,
1071) -> Result<Array> {
1072    let a_ptr = a.as_ref().as_ptr();
1073    let b_ptr = b.as_ref().as_ptr();
1074
1075    Array::try_from_op(|res| unsafe {
1076        mlx_sys::mlx_minimum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1077    })
1078}
1079
1080/// See [`Array::multiply`].
1081#[generate_macro]
1082#[default_device]
1083pub fn multiply_device(
1084    a: impl AsRef<Array>,
1085    b: impl AsRef<Array>,
1086    #[optional] stream: impl AsRef<Stream>,
1087) -> Result<Array> {
1088    a.as_ref().multiply_device(b, stream)
1089}
1090
1091/// See [`Array::negative`].
1092#[generate_macro]
1093#[default_device]
1094pub fn negative_device(
1095    a: impl AsRef<Array>,
1096    #[optional] stream: impl AsRef<Stream>,
1097) -> Result<Array> {
1098    a.as_ref().negative_device(stream)
1099}
1100
1101/// See [`Array::power`].
1102#[generate_macro]
1103#[default_device]
1104pub fn power_device(
1105    a: impl AsRef<Array>,
1106    b: impl AsRef<Array>,
1107    #[optional] stream: impl AsRef<Stream>,
1108) -> Result<Array> {
1109    a.as_ref().power_device(b, stream)
1110}
1111
1112/// Convert angles from degrees to radians.
1113#[generate_macro]
1114#[default_device]
1115pub fn radians_device(
1116    a: impl AsRef<Array>,
1117    #[optional] stream: impl AsRef<Stream>,
1118) -> Result<Array> {
1119    Array::try_from_op(|res| unsafe {
1120        mlx_sys::mlx_radians(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1121    })
1122}
1123
1124/// See [`Array::reciprocal`].
1125#[generate_macro]
1126#[default_device]
1127pub fn reciprocal_device(
1128    a: impl AsRef<Array>,
1129    #[optional] stream: impl AsRef<Stream>,
1130) -> Result<Array> {
1131    a.as_ref().reciprocal_device(stream)
1132}
1133
1134/// See [`Array::remainder`].
1135#[generate_macro]
1136#[default_device]
1137pub fn remainder_device(
1138    a: impl AsRef<Array>,
1139    b: impl AsRef<Array>,
1140    #[optional] stream: impl AsRef<Stream>,
1141) -> Result<Array> {
1142    a.as_ref().remainder_device(b, stream)
1143}
1144
1145/// See [`Array::round`].
1146#[generate_macro]
1147#[default_device]
1148pub fn round_device(
1149    a: impl AsRef<Array>,
1150    decimals: impl Into<Option<i32>>,
1151    #[optional] stream: impl AsRef<Stream>,
1152) -> Result<Array> {
1153    a.as_ref().round_device(decimals, stream)
1154}
1155
1156/// See [`Array::rsqrt`].
1157#[generate_macro]
1158#[default_device]
1159pub fn rsqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1160    a.as_ref().rsqrt_device(stream)
1161}
1162
1163/// Element-wise logistic sigmoid.
1164///
1165/// See the [python API
1166/// docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html#mlx.core.sigmoid)
1167/// for more information
1168#[generate_macro]
1169#[default_device]
1170pub fn sigmoid_device(
1171    a: impl AsRef<Array>,
1172    #[optional] stream: impl AsRef<Stream>,
1173) -> Result<Array> {
1174    Array::try_from_op(|res| unsafe {
1175        mlx_sys::mlx_sigmoid(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1176    })
1177}
1178
1179/// Element-wise sign.
1180#[generate_macro]
1181#[default_device]
1182pub fn sign_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1183    Array::try_from_op(|res| unsafe {
1184        mlx_sys::mlx_sign(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1185    })
1186}
1187
1188/// See [`Array::sin`].
1189#[generate_macro]
1190#[default_device]
1191pub fn sin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1192    a.as_ref().sin_device(stream)
1193}
1194
1195/// Element-wise hyperbolic sine.
1196#[generate_macro]
1197#[default_device]
1198pub fn sinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1199    Array::try_from_op(|res| unsafe {
1200        mlx_sys::mlx_sinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1201    })
1202}
1203
1204/// Perform the softmax along the given axis.
1205///
1206/// See the [python API
1207/// docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.softmax.html#mlx.core.softmax)
1208/// for more information.
1209#[generate_macro]
1210#[default_device]
1211pub fn softmax_axes_device(
1212    a: impl AsRef<Array>,
1213    axes: &[i32],
1214    precise: impl Into<Option<bool>>,
1215    #[optional] stream: impl AsRef<Stream>,
1216) -> Result<Array> {
1217    let precise = precise.into().unwrap_or(false);
1218    let s = stream.as_ref().as_ptr();
1219
1220    Array::try_from_op(|res| unsafe {
1221        mlx_sys::mlx_softmax_axes(
1222            res,
1223            a.as_ref().as_ptr(),
1224            axes.as_ptr(),
1225            axes.len(),
1226            precise,
1227            s,
1228        )
1229    })
1230}
1231
1232/// Similar to [`softmax_axes`] but with a single axis.
1233#[generate_macro]
1234#[default_device]
1235pub fn softmax_axis_device(
1236    a: impl AsRef<Array>,
1237    axis: i32,
1238    precise: impl Into<Option<bool>>,
1239    #[optional] stream: impl AsRef<Stream>,
1240) -> Result<Array> {
1241    let precise = precise.into().unwrap_or(false);
1242    let s = stream.as_ref().as_ptr();
1243
1244    Array::try_from_op(|res| unsafe {
1245        mlx_sys::mlx_softmax_axis(res, a.as_ref().as_ptr(), axis, precise, s)
1246    })
1247}
1248
1249/// Similar to [`softmax_axes`] but with no axis specified.
1250#[generate_macro]
1251#[default_device]
1252pub fn softmax_device(
1253    a: impl AsRef<Array>,
1254    precise: impl Into<Option<bool>>,
1255    #[optional] stream: impl AsRef<Stream>,
1256) -> Result<Array> {
1257    let precise = precise.into().unwrap_or(false);
1258    let s = stream.as_ref().as_ptr();
1259
1260    Array::try_from_op(|res| unsafe { mlx_sys::mlx_softmax(res, a.as_ref().as_ptr(), precise, s) })
1261}
1262
1263/// See [`Array::sqrt`].
1264#[generate_macro]
1265#[default_device]
1266pub fn sqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1267    a.as_ref().sqrt_device(stream)
1268}
1269
1270/// See [`Array::square`].
1271#[generate_macro]
1272#[default_device]
1273pub fn square_device(
1274    a: impl AsRef<Array>,
1275    #[optional] stream: impl AsRef<Stream>,
1276) -> Result<Array> {
1277    a.as_ref().square_device(stream)
1278}
1279
1280/// See [`Array::subtract`].
1281#[generate_macro]
1282#[default_device]
1283pub fn subtract_device(
1284    a: impl AsRef<Array>,
1285    b: impl AsRef<Array>,
1286    #[optional] stream: impl AsRef<Stream>,
1287) -> Result<Array> {
1288    a.as_ref().subtract_device(b, stream)
1289}
1290
1291/// See [`Array::tan`].
1292#[generate_macro]
1293#[default_device]
1294pub fn tan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1295    Array::try_from_op(|res| unsafe {
1296        mlx_sys::mlx_tan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1297    })
1298}
1299
1300/// Element-wise hyperbolic tangent.
1301#[generate_macro]
1302#[default_device]
1303pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1304    Array::try_from_op(|res| unsafe {
1305        mlx_sys::mlx_tanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1306    })
1307}
1308
1309/// Matrix multiplication with block masking.
1310///
1311/// See the [python API docs](
1312/// https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.block_masked_mm.html#mlx.core.block_masked_mm
1313/// ) for more information.
1314#[generate_macro]
1315#[default_device]
1316pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
1317    a: impl AsRef<Array>,
1318    b: impl AsRef<Array>,
1319    #[optional] block_size: impl Into<Option<i32>>,
1320    #[optional] mask_out: impl Into<Option<&'mo Array>>,
1321    #[optional] mask_lhs: impl Into<Option<&'lhs Array>>,
1322    #[optional] mask_rhs: impl Into<Option<&'rhs Array>>,
1323    #[optional] stream: impl AsRef<Stream>,
1324) -> Result<Array> {
1325    let a_ptr = a.as_ref().as_ptr();
1326    let b_ptr = b.as_ref().as_ptr();
1327    unsafe {
1328        let mask_out_ptr = mask_out
1329            .into()
1330            .map(|m| m.as_ptr())
1331            .unwrap_or(mlx_sys::mlx_array_new());
1332        let mask_lhs_ptr = mask_lhs
1333            .into()
1334            .map(|m| m.as_ptr())
1335            .unwrap_or(mlx_sys::mlx_array_new());
1336        let mask_rhs_ptr = mask_rhs
1337            .into()
1338            .map(|m| m.as_ptr())
1339            .unwrap_or(mlx_sys::mlx_array_new());
1340
1341        Array::try_from_op(|res| {
1342            mlx_sys::mlx_block_masked_mm(
1343                res,
1344                a_ptr,
1345                b_ptr,
1346                block_size.into().unwrap_or(32),
1347                mask_out_ptr,
1348                mask_lhs_ptr,
1349                mask_rhs_ptr,
1350                stream.as_ref().as_ptr(),
1351            )
1352        })
1353    }
1354}
1355
1356/// Matrix multiplication with addition and optional scaling.
1357///
1358/// Perform the (possibly batched) matrix multiplication of two arrays and add to the result with
1359/// optional scaling factors.
1360///
1361/// # Params
1362///
1363/// - `c`: input array,
1364/// - `a`: input array,
1365/// - `b`: input array,
1366/// - `alpha`: Scaling factor for the matrix product of `a` and `b` (default: `1`)
1367/// - `beta`: Scaling factor for `c` (default: `1`)
1368#[generate_macro]
1369#[default_device]
1370pub fn addmm_device(
1371    c: impl AsRef<Array>,
1372    a: impl AsRef<Array>,
1373    b: impl AsRef<Array>,
1374    #[optional] alpha: impl Into<Option<f32>>,
1375    #[optional] beta: impl Into<Option<f32>>,
1376    #[optional] stream: impl AsRef<Stream>,
1377) -> Result<Array> {
1378    let c_ptr = c.as_ref().as_ptr();
1379    let a_ptr = a.as_ref().as_ptr();
1380    let b_ptr = b.as_ref().as_ptr();
1381    let alpha = alpha.into().unwrap_or(1.0);
1382    let beta = beta.into().unwrap_or(1.0);
1383
1384    Array::try_from_op(|res| unsafe {
1385        mlx_sys::mlx_addmm(
1386            res,
1387            c_ptr,
1388            a_ptr,
1389            b_ptr,
1390            alpha,
1391            beta,
1392            stream.as_ref().as_ptr(),
1393        )
1394    })
1395}
1396
1397/// Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the
1398/// last axes.
1399#[generate_macro]
1400#[default_device]
1401pub fn inner_device(
1402    a: impl AsRef<Array>,
1403    b: impl AsRef<Array>,
1404    #[optional] stream: impl AsRef<Stream>,
1405) -> Result<Array> {
1406    let a = a.as_ref();
1407    let b = b.as_ref();
1408    Array::try_from_op(|res| unsafe {
1409        mlx_sys::mlx_inner(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1410    })
1411}
1412
1413/// Compute the outer product of two 1-D arrays, if the array’s passed are not 1-D a flatten op will
1414/// be run beforehand.
1415#[generate_macro]
1416#[default_device]
1417pub fn outer_device(
1418    a: impl AsRef<Array>,
1419    b: impl AsRef<Array>,
1420    #[optional] stream: impl AsRef<Stream>,
1421) -> Result<Array> {
1422    let a = a.as_ref();
1423    let b = b.as_ref();
1424    Array::try_from_op(|res| unsafe {
1425        mlx_sys::mlx_outer(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1426    })
1427}
1428
1429/// Compute the tensor dot product along the specified axes.
1430#[generate_macro]
1431#[default_device]
1432pub fn tensordot_axes_device(
1433    a: impl AsRef<Array>,
1434    b: impl AsRef<Array>,
1435    axes_a: &[i32],
1436    axes_b: &[i32],
1437    #[optional] stream: impl AsRef<Stream>,
1438) -> Result<Array> {
1439    let a = a.as_ref();
1440    let b = b.as_ref();
1441    Array::try_from_op(|res| unsafe {
1442        mlx_sys::mlx_tensordot(
1443            res,
1444            a.as_ptr(),
1445            b.as_ptr(),
1446            axes_a.as_ptr(),
1447            axes_a.len(),
1448            axes_b.as_ptr(),
1449            axes_b.len(),
1450            stream.as_ref().as_ptr(),
1451        )
1452    })
1453}
1454
1455/// Similar to [`tensordot_axes`] but with a single axis.
1456#[generate_macro]
1457#[default_device]
1458pub fn tensordot_axis_device(
1459    a: impl AsRef<Array>,
1460    b: impl AsRef<Array>,
1461    axis: i32,
1462    #[optional] stream: impl AsRef<Stream>,
1463) -> Result<Array> {
1464    let a = a.as_ref();
1465    let b = b.as_ref();
1466    Array::try_from_op(|res| unsafe {
1467        mlx_sys::mlx_tensordot_axis(res, a.as_ptr(), b.as_ptr(), axis, stream.as_ref().as_ptr())
1468    })
1469}
1470
1471#[cfg(test)]
1472mod tests {
1473    use std::f32::consts::PI;
1474
1475    use super::*;
1476    use crate::{
1477        array, complex64,
1478        ops::{all_close, arange, broadcast_to, eye, full, linspace, ones, reshape, split},
1479        transforms::eval,
1480        Dtype, StreamOrDevice,
1481    };
1482    use float_eq::assert_float_eq;
1483    use pretty_assertions::assert_eq;
1484
1485    #[test]
1486    fn test_abs() {
1487        let data = [1i32, 2, -3, -4, -5];
1488        let array = Array::from_slice(&data, &[5]);
1489        let result = array.abs().unwrap();
1490
1491        let data: &[i32] = result.as_slice();
1492        assert_eq!(data, [1, 2, 3, 4, 5]);
1493
1494        // test that previous array is not modified and valid
1495        let data: &[i32] = array.as_slice();
1496        assert_eq!(data, [1, 2, -3, -4, -5]);
1497    }
1498
1499    #[test]
1500    fn test_add() {
1501        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1502        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1503
1504        let c = &a + &b;
1505
1506        let c_data: &[f32] = c.as_slice();
1507        assert_eq!(c_data, &[5.0, 7.0, 9.0]);
1508
1509        // check a and b are not modified
1510        let a_data: &[f32] = a.as_slice();
1511        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1512
1513        let b_data: &[f32] = b.as_slice();
1514        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1515    }
1516
1517    #[test]
1518    fn test_add_invalid_broadcast() {
1519        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1520        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1521
1522        let c = a.add(&b);
1523        assert!(c.is_err());
1524    }
1525
1526    #[test]
1527    fn test_sub() {
1528        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1529        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1530
1531        let c = &a - &b;
1532
1533        let c_data: &[f32] = c.as_slice();
1534        assert_eq!(c_data, &[-3.0, -3.0, -3.0]);
1535
1536        // check a and b are not modified
1537        let a_data: &[f32] = a.as_slice();
1538        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1539
1540        let b_data: &[f32] = b.as_slice();
1541        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1542    }
1543
1544    #[test]
1545    fn test_sub_invalid_broadcast() {
1546        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1547        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1548        let c = a.subtract(&b);
1549        assert!(c.is_err());
1550    }
1551
1552    #[test]
1553    fn test_neg() {
1554        let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]);
1555        let b = a.negative().unwrap();
1556
1557        let b_data: &[f32] = b.as_slice();
1558        assert_eq!(b_data, &[-1.0, -2.0, -3.0]);
1559
1560        // check a is not modified
1561        let a_data: &[f32] = a.as_slice();
1562        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1563    }
1564
1565    #[test]
1566    fn test_neg_bool() {
1567        let a = Array::from_slice(&[true, false, true], &[3]);
1568        let b = a.negative();
1569        assert!(b.is_err());
1570    }
1571
1572    #[test]
1573    fn test_logical_not() {
1574        let a: Array = false.into();
1575        let b = a.logical_not().unwrap();
1576
1577        let b_data: &[bool] = b.as_slice();
1578        assert_eq!(b_data, [true]);
1579    }
1580
1581    #[test]
1582    fn test_mul() {
1583        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1584        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1585
1586        let c = &a * &b;
1587
1588        let c_data: &[f32] = c.as_slice();
1589        assert_eq!(c_data, &[4.0, 10.0, 18.0]);
1590
1591        // check a and b are not modified
1592        let a_data: &[f32] = a.as_slice();
1593        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1594
1595        let b_data: &[f32] = b.as_slice();
1596        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1597    }
1598
1599    #[test]
1600    fn test_mul_invalid_broadcast() {
1601        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1602        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1603        let c = a.multiply(&b);
1604        assert!(c.is_err());
1605    }
1606
1607    #[test]
1608    fn test_nan_to_num() {
1609        let a = array!([1.0, 2.0, f32::NAN, 4.0, 5.0]);
1610        let b = a.nan_to_num(0.0, 1.0, 0.0).unwrap();
1611
1612        let b_data: &[f32] = b.as_slice();
1613        assert_eq!(b_data, &[1.0, 2.0, 0.0, 4.0, 5.0]);
1614    }
1615
1616    #[test]
1617    fn test_div() {
1618        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1619        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1620
1621        let c = &a / &b;
1622
1623        let c_data: &[f32] = c.as_slice();
1624        assert_eq!(c_data, &[0.25, 0.4, 0.5]);
1625
1626        // check a and b are not modified
1627        let a_data: &[f32] = a.as_slice();
1628        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1629
1630        let b_data: &[f32] = b.as_slice();
1631        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1632    }
1633
1634    #[test]
1635    fn test_div_invalid_broadcast() {
1636        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1637        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1638        let c = a.divide(&b);
1639        assert!(c.is_err());
1640    }
1641
1642    #[test]
1643    fn test_pow() {
1644        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1645        let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
1646
1647        let c = a.power(&b).unwrap();
1648
1649        let c_data: &[f32] = c.as_slice();
1650        assert_eq!(c_data, &[1.0, 8.0, 81.0]);
1651
1652        // check a and b are not modified
1653        let a_data: &[f32] = a.as_slice();
1654        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1655
1656        let b_data: &[f32] = b.as_slice();
1657        assert_eq!(b_data, &[2.0, 3.0, 4.0]);
1658    }
1659
1660    #[test]
1661    fn test_pow_invalid_broadcast() {
1662        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1663        let b = Array::from_slice(&[2.0, 3.0], &[2]);
1664        let c = a.power(&b);
1665        assert!(c.is_err());
1666    }
1667
1668    #[test]
1669    fn test_rem() {
1670        let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1671        let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
1672
1673        let c = &a % &b;
1674
1675        let c_data: &[f32] = c.as_slice();
1676        assert_eq!(c_data, &[1.0, 3.0, 2.0]);
1677
1678        // check a and b are not modified
1679        let a_data: &[f32] = a.as_slice();
1680        assert_eq!(a_data, &[10.0, 11.0, 12.0]);
1681
1682        let b_data: &[f32] = b.as_slice();
1683        assert_eq!(b_data, &[3.0, 4.0, 5.0]);
1684    }
1685
1686    #[test]
1687    fn test_rem_invalid_broadcast() {
1688        let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1689        let b = Array::from_slice(&[3.0, 4.0], &[2]);
1690        let c = a.remainder(&b);
1691        assert!(c.is_err());
1692    }
1693
1694    #[test]
1695    fn test_sqrt() {
1696        let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
1697        let b = a.sqrt().unwrap();
1698
1699        let b_data: &[f32] = b.as_slice();
1700        assert_eq!(b_data, &[1.0, 2.0, 3.0]);
1701
1702        // check a is not modified
1703        let a_data: &[f32] = a.as_slice();
1704        assert_eq!(a_data, &[1.0, 4.0, 9.0]);
1705    }
1706
1707    #[test]
1708    fn test_cos() {
1709        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1710        let b = a.cos().unwrap();
1711
1712        let b_expected = array!([1.0, 0.54030234, -0.41614687]);
1713        assert_array_all_close!(b, b_expected);
1714
1715        // check a is not modified
1716        let a_expected = array!([0.0, 1.0, 2.0]);
1717        assert_array_all_close!(a, a_expected);
1718    }
1719
1720    #[test]
1721    fn test_exp() {
1722        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1723        let b = a.exp().unwrap();
1724
1725        let b_expected = array!([1.0, 2.7182817, 7.389056]);
1726        assert_array_all_close!(b, b_expected);
1727
1728        // check a is not modified
1729        let a_expected = array!([0.0, 1.0, 2.0]);
1730        assert_array_all_close!(a, a_expected);
1731    }
1732
1733    #[test]
1734    fn test_floor() {
1735        let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
1736        let b = a.floor().unwrap();
1737
1738        let b_data: &[f32] = b.as_slice();
1739        assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1740
1741        // check a is not modified
1742        let a_data: &[f32] = a.as_slice();
1743        assert_eq!(a_data, &[0.1, 1.9, 2.5]);
1744    }
1745
1746    #[test]
1747    fn test_floor_complex64() {
1748        let val = complex64::new(1.0, 2.0);
1749        let a = Array::from_complex(val);
1750        let b = a.floor_device(StreamOrDevice::default());
1751        assert!(b.is_err());
1752    }
1753
1754    #[test]
1755    fn test_floor_divide() {
1756        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1757        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1758
1759        let c = a.floor_divide(&b).unwrap();
1760
1761        let c_data: &[f32] = c.as_slice();
1762        assert_eq!(c_data, &[0.0, 0.0, 0.0]);
1763
1764        // check a and b are not modified
1765        let a_data: &[f32] = a.as_slice();
1766        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1767
1768        let b_data: &[f32] = b.as_slice();
1769        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1770    }
1771
1772    #[test]
1773    fn test_floor_divide_complex64() {
1774        let val = complex64::new(1.0, 2.0);
1775        let a = Array::from_complex(val);
1776        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1777        let c = a.floor_divide_device(&b, StreamOrDevice::default());
1778        assert!(c.is_err());
1779    }
1780
1781    #[test]
1782    fn test_floor_divide_invalid_broadcast() {
1783        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1784        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1785        let c = a.floor_divide_device(&b, StreamOrDevice::default());
1786        assert!(c.is_err());
1787    }
1788
1789    #[test]
1790    fn test_is_nan() {
1791        let a = Array::from_slice(&[1.0, f32::NAN, 3.0], &[3]);
1792        let b = a.is_nan().unwrap();
1793
1794        let b_data: &[bool] = b.as_slice();
1795        assert_eq!(b_data, &[false, true, false]);
1796    }
1797
1798    #[test]
1799    fn test_is_inf() {
1800        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1801        let b = a.is_inf().unwrap();
1802
1803        let b_data: &[bool] = b.as_slice();
1804        assert_eq!(b_data, &[false, true, false]);
1805    }
1806
1807    #[test]
1808    fn test_is_finite() {
1809        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1810        let b = a.is_finite().unwrap();
1811
1812        let b_data: &[bool] = b.as_slice();
1813        assert_eq!(b_data, &[true, false, true]);
1814    }
1815
1816    #[test]
1817    fn test_is_neg_inf() {
1818        let a = Array::from_slice(&[1.0, f32::NEG_INFINITY, 3.0], &[3]);
1819        let b = a.is_neg_inf().unwrap();
1820
1821        let b_data: &[bool] = b.as_slice();
1822        assert_eq!(b_data, &[false, true, false]);
1823    }
1824
1825    #[test]
1826    fn test_is_pos_inf() {
1827        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1828        let b = a.is_pos_inf().unwrap();
1829
1830        let b_data: &[bool] = b.as_slice();
1831        assert_eq!(b_data, &[false, true, false]);
1832    }
1833
1834    #[test]
1835    fn test_log() {
1836        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1837        let b = a.log().unwrap();
1838
1839        let b_data: &[f32] = b.as_slice();
1840        assert_eq!(b_data, &[0.0, 0.6931472, 1.0986123]);
1841
1842        // check a is not modified
1843        let a_data: &[f32] = a.as_slice();
1844        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1845    }
1846
1847    #[test]
1848    fn test_log2() {
1849        let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
1850        let b = a.log2().unwrap();
1851
1852        let b_data: &[f32] = b.as_slice();
1853        assert_eq!(b_data, &[0.0, 1.0, 2.0, 3.0]);
1854
1855        // check a is not modified
1856        let a_data: &[f32] = a.as_slice();
1857        assert_eq!(a_data, &[1.0, 2.0, 4.0, 8.0]);
1858    }
1859
1860    #[test]
1861    fn test_log10() {
1862        let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
1863        let b = a.log10().unwrap();
1864
1865        let b_data: &[f32] = b.as_slice();
1866        assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1867
1868        // check a is not modified
1869        let a_data: &[f32] = a.as_slice();
1870        assert_eq!(a_data, &[1.0, 10.0, 100.0]);
1871    }
1872
1873    #[test]
1874    fn test_log1p() {
1875        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1876        let b = a.log1p().unwrap();
1877
1878        let b_data: &[f32] = b.as_slice();
1879        assert_eq!(b_data, &[0.6931472, 1.0986123, 1.3862944]);
1880
1881        // check a is not modified
1882        let a_data: &[f32] = a.as_slice();
1883        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1884    }
1885
1886    #[test]
1887    fn test_matmul() {
1888        let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1889        let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
1890
1891        let c = a.matmul(&b).unwrap();
1892
1893        assert_eq!(c.shape(), &[2, 3]);
1894        let c_data: &[f32] = c.as_slice();
1895        assert_eq!(c_data, &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]);
1896
1897        // check a and b are not modified
1898        let a_data: &[i32] = a.as_slice();
1899        assert_eq!(a_data, &[1, 2, 3, 4]);
1900
1901        let b_data: &[f32] = b.as_slice();
1902        assert_eq!(b_data, &[-5.0, 37.5, 4., 7., 1., 0.]);
1903    }
1904
1905    #[test]
1906    fn test_matmul_ndim_zero() {
1907        let a: Array = 1.0.into();
1908        let b = Array::from_slice::<i32>(&[1], &[1]);
1909        let c = a.matmul(&b);
1910        assert!(c.is_err());
1911    }
1912
1913    #[test]
1914    fn test_matmul_ndim_one() {
1915        let a = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1916        let b = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1917        let c = a.matmul(&b);
1918        assert!(c.is_ok());
1919    }
1920
1921    #[test]
1922    fn test_matmul_dim_mismatch() {
1923        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
1924        let b = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]);
1925        let c = a.matmul(&b);
1926        assert!(c.is_err());
1927    }
1928
1929    #[test]
1930    fn test_matmul_non_float_output_type() {
1931        let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1932        let b = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]);
1933
1934        let c = a.matmul(&b);
1935        assert!(c.is_err());
1936    }
1937
1938    #[test]
1939    fn test_reciprocal() {
1940        let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
1941        let b = a.reciprocal().unwrap();
1942
1943        let b_data: &[f32] = b.as_slice();
1944        assert_eq!(b_data, &[1.0, 0.5, 0.25]);
1945
1946        // check a is not modified
1947        let a_data: &[f32] = a.as_slice();
1948        assert_eq!(a_data, &[1.0, 2.0, 4.0]);
1949    }
1950
1951    #[test]
1952    fn test_round() {
1953        let a = Array::from_slice(&[1.1, 2.9, 3.5], &[3]);
1954        let b = a.round(None).unwrap();
1955
1956        let b_data: &[f32] = b.as_slice();
1957        assert_eq!(b_data, &[1.0, 3.0, 4.0]);
1958
1959        // check a is not modified
1960        let a_data: &[f32] = a.as_slice();
1961        assert_eq!(a_data, &[1.1, 2.9, 3.5]);
1962    }
1963
1964    #[test]
1965    fn test_rsqrt() {
1966        let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
1967        let b = a.rsqrt().unwrap();
1968
1969        let b_data: &[f32] = b.as_slice();
1970        assert_eq!(b_data, &[1.0, 0.70710677, 0.5]);
1971
1972        // check a is not modified
1973        let a_data: &[f32] = a.as_slice();
1974        assert_eq!(a_data, &[1.0, 2.0, 4.0]);
1975    }
1976
1977    #[test]
1978    fn test_sin() {
1979        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1980        let b = a.sin().unwrap();
1981
1982        let b_data: &[f32] = b.as_slice();
1983        assert_eq!(b_data, &[0.0, 0.841471, 0.9092974]);
1984
1985        // check a is not modified
1986        let a_data: &[f32] = a.as_slice();
1987        assert_eq!(a_data, &[0.0, 1.0, 2.0]);
1988    }
1989
1990    #[test]
1991    fn test_square() {
1992        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1993        let b = a.square().unwrap();
1994
1995        let b_data: &[f32] = b.as_slice();
1996        assert_eq!(b_data, &[1.0, 4.0, 9.0]);
1997
1998        // check a is not modified
1999        let a_data: &[f32] = a.as_slice();
2000        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
2001    }
2002
    // The unit tests below are adapted from the original MLX C++ codebase.
2004
2005    #[test]
2006    fn test_unary_neg() {
2007        let x = array!(1.0);
2008        assert_eq!(negative(&x).unwrap().item::<f32>(), -1.0);
2009        assert_eq!((-x).item::<f32>(), -1.0);
2010
2011        // works on empty array
2012        assert_eq!(-array!(), array!());
2013
2014        // Throws on bool
2015        let x = array!(true);
2016        assert!(negative(&x).is_err());
2017    }
2018
2019    #[test]
2020    fn test_unary_abs() {
2021        let x = array!([-1.0, 0.0, 1.0]);
2022        assert_eq!(abs(&x).unwrap(), array!([1.0, 0.0, 1.0]));
2023
2024        // works on empty array
2025        assert_eq!(abs(array!()).unwrap(), array!());
2026
2027        // int32
2028        let x = array!([-1, 0, 1]);
2029        assert_eq!(abs(&x).unwrap(), array!([1, 0, 1]));
2030
2031        // uint32
2032        let x = array!([1u32, 0, 1]);
2033        assert_eq!(abs(&x).unwrap(), array!([1u32, 0, 1]));
2034
2035        // bool
2036        let x = array!([false, true]);
2037        assert_eq!(abs(&x).unwrap(), array!([false, true]));
2038    }
2039
2040    #[test]
2041    fn test_unary_sign() {
2042        let x = array!([-1.0, 0.0, 1.0]);
2043        assert_eq!(sign(&x).unwrap(), x);
2044
2045        // works on empty array
2046        assert_eq!(sign(array!()).unwrap(), array!());
2047
2048        // int32
2049        let x = array!([-1, 0, 1]);
2050        assert_eq!(sign(&x).unwrap(), x);
2051
2052        // uint32
2053        let x = array!([1u32, 0, 1]);
2054        assert_eq!(sign(&x).unwrap(), x);
2055
2056        // bool
2057        let x = array!([false, true]);
2058        assert_eq!(sign(&x).unwrap(), x);
2059    }
2060
2061    const NEG_INF: f32 = f32::NEG_INFINITY;
2062
2063    #[test]
2064    fn test_unary_floor_ceil() {
2065        let x = array![1.0];
2066        assert_eq!(floor(&x).unwrap().item::<f32>(), 1.0);
2067        assert_eq!(ceil(&x).unwrap().item::<f32>(), 1.0);
2068
2069        let x = array![1.5];
2070        assert_eq!(floor(&x).unwrap().item::<f32>(), 1.0);
2071        assert_eq!(ceil(&x).unwrap().item::<f32>(), 2.0);
2072
2073        let x = array![-1.5];
2074        assert_eq!(floor(&x).unwrap().item::<f32>(), -2.0);
2075        assert_eq!(ceil(&x).unwrap().item::<f32>(), -1.0);
2076
2077        let x = array![NEG_INF];
2078        assert_eq!(floor(&x).unwrap().item::<f32>(), NEG_INF);
2079        assert_eq!(ceil(&x).unwrap().item::<f32>(), NEG_INF);
2080
2081        let x = array!([1.0, 1.0]).as_type::<complex64>().unwrap();
2082        assert!(floor(&x).is_err());
2083        assert!(ceil(&x).is_err());
2084    }
2085
2086    #[test]
2087    fn test_unary_round() {
2088        let x = array!([0.5, -0.5, 1.5, -1.5, 2.3, 2.6]);
2089        assert_eq!(round(&x, None).unwrap(), array!([0, 0, 2, -2, 2, 3]));
2090
2091        let x = array!([11, 222, 32]);
2092        assert_eq!(round(&x, -1).unwrap(), array!([10, 220, 30]));
2093    }
2094
2095    #[test]
2096    fn test_unary_exp() {
2097        let x = array![0.0];
2098        assert_eq!(exp(&x).unwrap().item::<f32>(), 1.0);
2099
2100        let x = array![2.0];
2101        assert_float_eq! {
2102            exp(&x).unwrap().item::<f32>(),
2103            2.0f32.exp(),
2104            abs <= 1e-5
2105        };
2106
2107        assert_eq!(exp(array!()).unwrap(), array!());
2108
2109        let x = array![NEG_INF];
2110        assert_eq!(exp(&x).unwrap().item::<f32>(), 0.0);
2111
2112        // Integer input type
2113        let x = array![2];
2114        assert_eq!(x.dtype(), Dtype::Int32);
2115        assert_float_eq! {
2116            exp(&x).unwrap().item::<f32>(),
2117            2.0f32.exp(),
2118            abs <= 1e-5
2119        };
2120
2121        // Input is irregularly strided
2122        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2123        let res = exp(&x).unwrap();
2124        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.exp())).unwrap();
2125        assert!(all_close(&res, &expected, None, None, None)
2126            .unwrap()
2127            .item::<bool>());
2128
2129        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2130        let x = split(&data, 2, 1).unwrap();
2131        let expected = Array::from_slice(&[0.0f32.exp(), 2.0f32.exp()], &[2, 1]);
2132        assert!(all_close(exp(&x[0]).unwrap(), &expected, None, None, None)
2133            .unwrap()
2134            .item::<bool>());
2135    }
2136
2137    #[test]
2138    fn test_unary_expm1() {
2139        let x = array![-1.0];
2140        assert_float_eq! {
2141            expm1(&x).unwrap().item::<f32>(),
2142            (-1.0f32).exp_m1(),
2143            abs <= 1e-5
2144        };
2145
2146        let x = array![1.0];
2147        assert_float_eq! {
2148            expm1(&x).unwrap().item::<f32>(),
2149            1.0f32.exp_m1(),
2150            abs <= 1e-5
2151        };
2152
2153        // Integer input type
2154        let x = array![1];
2155        assert_eq!(expm1(&x).unwrap().dtype(), Dtype::Float32);
2156        assert_float_eq! {
2157            expm1(&x).unwrap().item::<f32>(),
2158            1.0f32.exp_m1(),
2159            abs <= 1e-5
2160        };
2161    }
2162
2163    #[test]
2164    fn test_unary_sin() {
2165        let x = array![0.0];
2166        assert_eq!(sin(&x).unwrap().item::<f32>(), 0.0);
2167
2168        let x = array![std::f32::consts::PI / 2.0];
2169        assert_float_eq! {
2170            sin(&x).unwrap().item::<f32>(),
2171            (std::f32::consts::PI / 2.0f32).sin(),
2172            abs <= 1e-5
2173        };
2174
2175        assert_eq!(sin(array!()).unwrap(), array!());
2176
2177        // Integer input type
2178        let x = array![0];
2179        assert_eq!(x.dtype(), Dtype::Int32);
2180        assert_float_eq! {
2181            sin(&x).unwrap().item::<f32>(),
2182            0.0f32.sin(),
2183            abs <= 1e-5
2184        };
2185
2186        // Input is irregularly strided
2187        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2188        let res = sin(&x).unwrap();
2189        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.sin())).unwrap();
2190        assert!(all_close(&res, &expected, None, None, None)
2191            .unwrap()
2192            .item::<bool>());
2193
2194        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2195        let x = split(&data, 2, 1).unwrap();
2196        let expected = Array::from_slice(&[0.0f32.sin(), 2.0f32.sin()], &[2, 1]);
2197        assert!(all_close(sin(&x[0]).unwrap(), &expected, None, None, None)
2198            .unwrap()
2199            .item::<bool>());
2200    }
2201
2202    #[test]
2203    fn test_unary_cos() {
2204        let x = array![0.0];
2205        assert_float_eq! {
2206            cos(&x).unwrap().item::<f32>(),
2207            0.0f32.cos(),
2208            abs <= 1e-5
2209        };
2210
2211        let x = array![std::f32::consts::PI / 2.0];
2212        assert_float_eq! {
2213            cos(&x).unwrap().item::<f32>(),
2214            (std::f32::consts::PI / 2.0f32).cos(),
2215            abs <= 1e-5
2216        };
2217
2218        assert_eq!(cos(array!()).unwrap(), array!());
2219
2220        // Integer input type
2221        let x = array![0];
2222        assert_eq!(x.dtype(), Dtype::Int32);
2223        assert_float_eq! {
2224            cos(&x).unwrap().item::<f32>(),
2225            0.0f32.cos(),
2226            abs <= 1e-5
2227        };
2228
2229        // Input is irregularly strided
2230        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2231        let res = cos(&x).unwrap();
2232        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.cos())).unwrap();
2233        assert!(all_close(&res, &expected, None, None, None)
2234            .unwrap()
2235            .item::<bool>());
2236
2237        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2238        let x = split(&data, 2, 1).unwrap();
2239        let expected = Array::from_slice(&[0.0f32.cos(), 2.0f32.cos()], &[2, 1]);
2240        assert!(all_close(cos(&x[0]).unwrap(), &expected, None, None, None)
2241            .unwrap()
2242            .item::<bool>());
2243    }
2244
2245    #[test]
2246    fn test_unary_degrees() {
2247        let x = array![0.0];
2248        assert_eq!(degrees(&x).unwrap().item::<f32>(), 0.0);
2249
2250        let x = array![std::f32::consts::PI / 2.0];
2251        assert_eq!(degrees(&x).unwrap().item::<f32>(), 90.0);
2252
2253        assert_eq!(degrees(array!()).unwrap(), array!());
2254
2255        // Integer input type
2256        let x = array![0];
2257        assert_eq!(x.dtype(), Dtype::Int32);
2258        assert_eq!(degrees(&x).unwrap().item::<f32>(), 0.0);
2259
2260        // Input is irregularly strided
2261        let x = broadcast_to(&array!(std::f32::consts::PI / 2.0), &[2, 2, 2]).unwrap();
2262        let res = degrees(&x).unwrap();
2263        let expected = Array::full::<f32>(&[2, 2, 2], array!(90.0)).unwrap();
2264        assert!(all_close(&res, &expected, None, None, None)
2265            .unwrap()
2266            .item::<bool>());
2267
2268        let angles = Array::from_slice(&[0.0, PI / 2.0, PI, 1.5 * PI], &[2, 2]);
2269        let x = split(&angles, 2, 1).unwrap();
2270        let expected = Array::from_slice(&[0.0, 180.0], &[2, 1]);
2271        assert!(
2272            all_close(degrees(&x[0]).unwrap(), &expected, None, None, None)
2273                .unwrap()
2274                .item::<bool>()
2275        );
2276    }
2277
2278    #[test]
2279    fn test_unary_radians() {
2280        let x = array![0.0];
2281        assert_eq!(radians(&x).unwrap().item::<f32>(), 0.0);
2282
2283        let x = array![90.0];
2284        assert_eq!(
2285            radians(&x).unwrap().item::<f32>(),
2286            std::f32::consts::PI / 2.0
2287        );
2288
2289        assert_eq!(radians(array!()).unwrap(), array!());
2290
2291        // Integer input type
2292        let x = array![90];
2293        assert_eq!(x.dtype(), Dtype::Int32);
2294        assert_eq!(
2295            radians(&x).unwrap().item::<f32>(),
2296            std::f32::consts::PI / 2.0
2297        );
2298
2299        // Input is irregularly strided
2300        let x = broadcast_to(&array!(90.0), &[2, 2, 2]).unwrap();
2301        let res = radians(&x).unwrap();
2302        let expected = Array::full::<f32>(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap();
2303        assert!(all_close(&res, &expected, None, None, None)
2304            .unwrap()
2305            .item::<bool>());
2306
2307        let angles = Array::from_slice(&[0.0, 90.0, 180.0, 270.0], &[2, 2]);
2308        let x = split(&angles, 2, 1).unwrap();
2309        let expected = Array::from_slice(&[0.0, PI], &[2, 1]);
2310        assert!(
2311            all_close(radians(&x[0]).unwrap(), &expected, None, None, None)
2312                .unwrap()
2313                .item::<bool>()
2314        );
2315    }
2316
2317    #[test]
2318    fn test_unary_log() {
2319        let x = array![0.0];
2320        assert_eq!(log(&x).unwrap().item::<f32>(), NEG_INF);
2321
2322        let x = array![1.0];
2323        assert_eq!(log(&x).unwrap().item::<f32>(), 0.0);
2324
2325        // Integer input type
2326        let x = array![1];
2327        assert_eq!(log(&x).unwrap().dtype(), Dtype::Float32);
2328        assert_eq!(log(&x).unwrap().item::<f32>(), 0.0);
2329
2330        // Input is irregularly strided
2331        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2332        let res = log(&x).unwrap();
2333        let expected = Array::full::<f32>(&[2, 2, 2], array!(0.0)).unwrap();
2334        assert!(all_close(&res, &expected, None, None, None)
2335            .unwrap()
2336            .item::<bool>());
2337
2338        let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
2339        let x = split(&data, 2, 1).unwrap();
2340        let expected = Array::from_slice(&[1.0f32.ln(), 3.0f32.ln()], &[2, 1]);
2341        assert!(all_close(log(&x[0]).unwrap(), &expected, None, None, None)
2342            .unwrap()
2343            .item::<bool>());
2344    }
2345
2346    #[test]
2347    fn test_unary_log2() {
2348        let x = array![0.0];
2349        assert_eq!(log2(&x).unwrap().item::<f32>(), NEG_INF);
2350
2351        let x = array![1.0];
2352        assert_eq!(log2(&x).unwrap().item::<f32>(), 0.0);
2353
2354        let x = array![1024.0];
2355        assert_eq!(log2(&x).unwrap().item::<f32>(), 10.0);
2356    }
2357
2358    #[test]
2359    fn test_unary_log10() {
2360        let x = array![0.0];
2361        assert_eq!(log10(&x).unwrap().item::<f32>(), NEG_INF);
2362
2363        let x = array![1.0];
2364        assert_eq!(log10(&x).unwrap().item::<f32>(), 0.0);
2365
2366        let x = array![1000.0];
2367        assert_eq!(log10(&x).unwrap().item::<f32>(), 3.0);
2368    }
2369
2370    #[test]
2371    fn test_unary_log1p() {
2372        let x = array![-1.0];
2373        assert_float_eq! {
2374            log1p(&x).unwrap().item::<f32>(),
2375            (-1.0f32).ln_1p(),
2376            abs <= 1e-5
2377        };
2378
2379        let x = array![1.0];
2380        assert_float_eq! {
2381            log1p(&x).unwrap().item::<f32>(),
2382            1.0f32.ln_1p(),
2383            abs <= 1e-5
2384        };
2385
2386        // Integer input type
2387        let x = array![1];
2388        assert_eq!(log1p(&x).unwrap().dtype(), Dtype::Float32);
2389        assert_float_eq! {
2390            log1p(&x).unwrap().item::<f32>(),
2391            1.0f32.ln_1p(),
2392            abs <= 1e-5
2393        };
2394
2395        // Input is irregularly strided
2396        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2397        let res = log1p(&x).unwrap();
2398        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap();
2399        assert!(all_close(&res, &expected, None, None, None)
2400            .unwrap()
2401            .item::<bool>());
2402
2403        let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
2404        let x = split(&data, 2, 1).unwrap();
2405        let expected = Array::from_slice(&[1.0f32.ln_1p(), 3.0f32.ln_1p()], &[2, 1]);
2406        assert!(
2407            all_close(log1p(&x[0]).unwrap(), &expected, None, None, None)
2408                .unwrap()
2409                .item::<bool>()
2410        );
2411    }
2412
2413    #[test]
2414    fn test_unary_sigmoid() {
2415        let x = array![0.0];
2416        assert_float_eq! {
2417            sigmoid(&x).unwrap().item::<f32>(),
2418            0.5,
2419            abs <= 1e-5
2420        };
2421
2422        // Integer input type
2423        let x = array![0];
2424        assert_eq!(sigmoid(&x).unwrap().dtype(), Dtype::Float32);
2425        assert_float_eq! {
2426            sigmoid(&x).unwrap().item::<f32>(),
2427            0.5,
2428            abs <= 1e-5
2429        };
2430
2431        let inf = f32::INFINITY;
2432        let x = array![inf];
2433        assert_eq!(sigmoid(&x).unwrap().item::<f32>(), 1.0);
2434
2435        let x = array![-inf];
2436        assert_eq!(sigmoid(&x).unwrap().item::<f32>(), 0.0);
2437    }
2438
2439    #[test]
2440    fn test_unary_square() {
2441        let x = array![3.0];
2442        assert_eq!(square(&x).unwrap().item::<f32>(), 9.0);
2443
2444        let x = array![2];
2445        assert_eq!(square(&x).unwrap().item::<i32>(), 4);
2446
2447        let x = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
2448        assert!(all_close(
2449            square(&x).unwrap(),
2450            Array::full::<f32>(&[3, 3], array!(4.0)).unwrap(),
2451            None,
2452            None,
2453            None
2454        )
2455        .unwrap()
2456        .item::<bool>());
2457    }
2458
2459    #[test]
2460    fn test_unary_sqrt_rsqrt() {
2461        let x = array![4.0];
2462        assert_eq!(sqrt(&x).unwrap().item::<f32>(), 2.0);
2463        assert_eq!(rsqrt(&x).unwrap().item::<f32>(), 0.5);
2464
2465        let x = Array::full::<f32>(&[3, 3], array!(9.0)).unwrap();
2466        assert!(all_close(
2467            sqrt(&x).unwrap(),
2468            Array::full::<f32>(&[3, 3], array!(3.0)).unwrap(),
2469            None,
2470            None,
2471            None
2472        )
2473        .unwrap()
2474        .item::<bool>());
2475
2476        let x = array![4i32];
2477        assert_eq!(sqrt(&x).unwrap().item::<f32>(), 2.0);
2478        assert_eq!(rsqrt(&x).unwrap().item::<f32>(), 0.5);
2479    }
2480
2481    #[test]
2482    fn test_unary_reciprocal() {
2483        let x = array![8.0];
2484        assert_eq!(reciprocal(&x).unwrap().item::<f32>(), 0.125);
2485
2486        let x = array![2];
2487        let out = reciprocal(&x).unwrap();
2488        assert_eq!(out.dtype(), Dtype::Float32);
2489        assert_eq!(out.item::<f32>(), 0.5);
2490
2491        let x = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
2492        assert!(all_close(
2493            reciprocal(&x).unwrap(),
2494            Array::full::<f32>(&[3, 3], array!(0.5)).unwrap(),
2495            None,
2496            None,
2497            None
2498        )
2499        .unwrap()
2500        .item::<bool>());
2501    }
2502
2503    #[test]
2504    fn test_binary_add() {
2505        let x = array![1.0];
2506        let y = array![1.0];
2507        let z = add(&x, &y).unwrap();
2508        assert_eq!(z.item::<f32>(), 2.0);
2509
2510        let z = &x + y;
2511        assert_eq!(z.item::<f32>(), 2.0);
2512
2513        let z = add(z, &x).unwrap();
2514        assert_eq!(z.item::<f32>(), 3.0);
2515
2516        // Chain a few adds:
2517        let mut out = x.deep_clone();
2518        for _ in 0..10 {
2519            out = add(&out, &x).unwrap();
2520        }
2521        assert_eq!(out.item::<f32>(), 11.0);
2522
2523        // Works for different shapes
2524        let x = array!([1.0, 2.0, 3.0]);
2525        let y = array!([1.0, 2.0, 3.0]);
2526        let z = add(&x, &y).unwrap();
2527        assert_eq!(z.shape(), &[3]);
2528        assert_eq!(z, array!([2.0, 4.0, 6.0]));
2529
2530        // Works with scalars
2531        let x = array!([1.0, 2.0, 3.0]);
2532        let y = &x + 2.0;
2533        assert_eq!(y.dtype(), Dtype::Float32);
2534        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2535        let y = &x + 2.0;
2536        assert_eq!(y.dtype(), Dtype::Float32);
2537        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2538
2539        // Check type promotion
2540        let y = x + 2;
2541        assert_eq!(y.dtype(), Dtype::Float32);
2542
2543        let y = array!([1, 2, 3]) + 2.0;
2544        assert_eq!(y.dtype(), Dtype::Float32);
2545        // assert!(array_equal(&y, &array![3.0, 4.0, 5.0]).item::<bool>());
2546        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2547
2548        // Broadcasting works
2549        let x = broadcast_to(&array!(1.0), &[10]).unwrap();
2550        let y = broadcast_to(&array!(2.0), &[10]).unwrap();
2551        let z = add(&x, &y).unwrap();
2552        assert_eq!(z, full::<f32>(&[10], array!(3.0)).unwrap());
2553
2554        let x = Array::from_slice(&[1.0, 2.0], &[1, 2]);
2555        let y = Array::from_slice(&[1.0, 2.0], &[2, 1]);
2556        let z = add(&x, &y).unwrap();
2557        assert_eq!(z.shape(), &[2, 2]);
2558        assert_eq!(z, Array::from_slice(&[2.0, 3.0, 3.0, 4.0], &[2, 2]));
2559
2560        let x = ones::<f32>(&[3, 2, 1]).unwrap();
2561        let z = x + 2.0;
2562        assert_eq!(z.shape(), &[3, 2, 1]);
2563        let expected = Array::from_slice(&[3.0, 3.0, 3.0, 3.0, 3.0, 3.0], &[3, 2, 1]);
2564        assert_eq!(z, expected);
2565
2566        // Works for empty arrays
2567        let x = array!();
2568        let y = array!();
2569        let z = x + y;
2570        z.eval().unwrap();
2571        assert_eq!(z.size(), 0);
2572        assert_eq!(z.shape(), &[0]);
2573    }
2574
2575    #[test]
2576    fn test_binary_sub() {
2577        let x = array!([3.0, 2.0, 1.0]);
2578        let y = array!([1.0, 1.0, 1.0]);
2579        assert_eq!(x - y, array!([2.0, 1.0, 0.0]));
2580    }
2581
2582    #[test]
2583    fn test_binary_mul() {
2584        let x = array!([1.0, 2.0, 3.0]);
2585        let y = array!([2.0, 2.0, 2.0]);
2586        assert_eq!(x * y, array!([2.0, 4.0, 6.0]));
2587    }
2588
2589    #[test]
2590    fn test_binary_div() {
2591        let x = array![1.0];
2592        let y = array![1.0];
2593        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 1.0);
2594
2595        let x = array![1.0];
2596        let y = array![0.5];
2597        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 2.0);
2598
2599        let x = array![1.0];
2600        let y = array![4.0];
2601        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 0.25);
2602
2603        let x = array![true];
2604        let y = array![true];
2605        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 1.0);
2606
2607        let x = array![false];
2608        let y = array![true];
2609        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 0.0);
2610
2611        let x = array![true];
2612        let y = array![false];
2613        assert!(divide(&x, &y).unwrap().item::<f32>().is_infinite());
2614
2615        let x = array![false];
2616        let y = array![false];
2617        assert!(divide(&x, &y).unwrap().item::<f32>().is_nan());
2618    }
2619
2620    #[test]
2621    fn test_binary_maximum_minimum() {
2622        let x = array![1.0];
2623        let y = array![0.0];
2624        assert_eq!(maximum(&x, &y).unwrap().item::<f32>(), 1.0);
2625        assert_eq!(minimum(&x, &y).unwrap().item::<f32>(), 0.0);
2626
2627        let y = array![2.0];
2628        assert_eq!(maximum(&x, &y).unwrap().item::<f32>(), 2.0);
2629        assert_eq!(minimum(&x, &y).unwrap().item::<f32>(), 1.0);
2630    }
2631
2632    #[test]
2633    fn test_binary_logaddexp() {
2634        let x = array![0.0];
2635        let y = array![0.0];
2636        assert_float_eq! {
2637            logaddexp(&x, &y).unwrap().item::<f32>(),
2638            2.0f32.ln(),
2639            abs <= 1e-5
2640        };
2641
2642        let x = array!([0u32]);
2643        let y = array!([10000u32]);
2644        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), 10000.0);
2645
2646        let x = array![f32::INFINITY];
2647        let y = array![3.0];
2648        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2649
2650        let x = array![f32::NEG_INFINITY];
2651        let y = array![3.0];
2652        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), 3.0);
2653
2654        let x = array![f32::NEG_INFINITY];
2655        let y = array![f32::NEG_INFINITY];
2656        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::NEG_INFINITY);
2657
2658        let x = array![f32::INFINITY];
2659        let y = array![f32::INFINITY];
2660        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2661
2662        let x = array![f32::NEG_INFINITY];
2663        let y = array![f32::INFINITY];
2664        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2665    }
2666
2667    #[test]
2668    fn test_basic_clip() {
2669        let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
2670        let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);
2671        let clipped = clip(&a, (array!(2.0), array!(6.0))).unwrap();
2672        assert_eq!(clipped, expected);
2673
2674        // Test with scalar
2675        let clipped = clip(&a, (2.0, 6.0)).unwrap();
2676        assert_eq!(clipped, expected);
2677    }
2678
2679    #[test]
2680    fn test_clip_with_only_min() {
2681        let a = array!([-1.0, 1.0, 0.0, 5.0]);
2682        let expected = array!([0.0, 1.0, 0.0, 5.0]);
2683        let clipped = clip(&a, (array!(0.0), ())).unwrap();
2684        assert_eq!(clipped, expected);
2685
2686        // Test with scalar
2687        let clipped = clip(&a, (0.0, ())).unwrap();
2688        assert_eq!(clipped, expected);
2689    }
2690
2691    #[test]
2692    fn test_clip_with_only_max() {
2693        let a = array!([2.0, 3.0, 4.0, 5.0]);
2694        let expected = array!([2.0, 3.0, 4.0, 4.0]);
2695        let clipped = clip(&a, ((), array!(4.0))).unwrap();
2696        assert_eq!(clipped, expected);
2697
2698        // Test with scalar
2699        let clipped = clip(&a, ((), 4.0)).unwrap();
2700        assert_eq!(clipped, expected);
2701    }
2702
2703    #[test]
2704    fn test_tensordot() {
2705        let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
2706        let y = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[4, 3, 2]).unwrap();
2707        let z = tensordot_axes(&x, &y, &[1i32, 0], &[0i32, 1]).unwrap();
2708        let expected = Array::from_slice(
2709            &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306],
2710            &[5, 2],
2711        );
2712        assert_eq!(z, expected);
2713
2714        let x = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[3, 4, 5, 6]).unwrap();
2715        let y = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[6, 4, 5, 3]).unwrap();
2716        assert!(tensordot_axes(&x, &y, &[2, 1, 3], &[1, 2, 0]).is_err());
2717
2718        let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
2719        let y = reshape(arange::<_, f32>(None, 120.0, None).unwrap(), &[4, 5, 6]).unwrap();
2720
2721        let z = tensordot_axis(&x, &y, 2).unwrap();
2722        let expected = Array::from_slice(
2723            &[
2724                14820.0, 15010.0, 15200.0, 15390.0, 15580.0, 15770.0, 37620.0, 38210.0, 38800.0,
2725                39390.0, 39980.0, 40570.0, 60420.0, 61410.0, 62400.0, 63390.0, 64380.0, 65370.0,
2726            ],
2727            &[3, 6],
2728        );
2729        assert_eq!(z, expected);
2730    }
2731
2732    #[test]
2733    fn test_outer() {
2734        let x = arange::<_, f32>(1.0, 5.0, None).unwrap();
2735        let y = arange::<_, f32>(1.0, 4.0, None).unwrap();
2736        let z = outer(&x, &y).unwrap();
2737        let expected = Array::from_slice(
2738            &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0],
2739            &[4, 3],
2740        );
2741        assert_eq!(z, expected);
2742
2743        let x = ones::<f32>(&[5]).unwrap();
2744        let y = linspace::<_, f32>(-2.0, 2.0, 5).unwrap();
2745        let z = outer(&x, &y).unwrap();
2746        let expected = Array::from_slice(
2747            &[
2748                -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
2749                -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
2750            ],
2751            &[5, 5],
2752        );
2753        assert_eq!(z, expected);
2754    }
2755
2756    #[test]
2757    fn test_inner() {
2758        let x = reshape(arange::<_, f32>(None, 5.0, None).unwrap(), &[1, 5]).unwrap();
2759        let y = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[2, 3]).unwrap();
2760        assert!(inner(&x, &y).is_err());
2761
2762        let x = array!([1.0, 2.0, 3.0]);
2763        let y = array!([0.0, 1.0, 0.0]);
2764        let z = inner(&x, &y).unwrap();
2765        assert_eq!(z.item::<f32>(), 2.0);
2766
2767        let x = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[2, 3, 4]).unwrap();
2768        let y = arange::<_, f32>(None, 4.0, None).unwrap();
2769        let z = inner(&x, &y).unwrap();
2770        let expected = Array::from_slice(&[14.0, 38.0, 62.0, 86.0, 110.0, 134.0], &[2, 3]);
2771        assert_eq!(z, expected);
2772
2773        let x = reshape(arange::<_, f32>(None, 2.0, None).unwrap(), &[1, 1, 2]).unwrap();
2774        let y = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[3, 2]).unwrap();
2775        let z = inner(&x, &y).unwrap();
2776        let expected = Array::from_slice(&[1.0, 3.0, 5.0], &[1, 1, 3]);
2777        assert_eq!(z, expected);
2778
2779        let x = eye::<f32>(2, None, None).unwrap();
2780        let y = Array::from_f32(7.0);
2781        let z = inner(&x, &y).unwrap();
2782        let expected = Array::from_slice(&[7.0, 0.0, 0.0, 7.0], &[2, 2]);
2783        assert_eq!(z, expected);
2784    }
2785
2786    #[test]
2787    fn test_divmod() {
2788        let x = array!([1.0, 2.0, 3.0]);
2789        let y = array!([1.0, 1.0, 1.0]);
2790        let out = divmod(&x, &y).unwrap();
2791        assert_eq!(out.0, array!([1.0, 2.0, 3.0]));
2792        assert_eq!(out.1, array!([0.0, 0.0, 0.0]));
2793
2794        let x = array!([5.0, 6.0, 7.0]);
2795        let y = array!([2.0, 2.0, 2.0]);
2796        let out = divmod(&x, &y).unwrap();
2797        assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
2798        assert_eq!(out.1, array!([1.0, 0.0, 1.0]));
2799
2800        let x = array!([5.0, 6.0, 7.0]);
2801        let y = array!([2.0, 2.0, 2.0]);
2802        let out = divmod(&x, &y).unwrap();
2803        assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
2804        assert_eq!(out.1, array!([1.0, 0.0, 1.0]));
2805
2806        let x = array![complex64::new(1.0, 0.0)];
2807        let y = array![complex64::new(2.0, 0.0)];
2808        assert!(divmod(&x, &y).is_err());
2809
2810        // Check that we can eval on both outputs
2811        let x = array![1.0];
2812        let y = array![2.0];
2813        let (quo, rem) = divmod(&x, &y).unwrap();
2814        eval([&quo, &rem]).unwrap();
2815        assert_eq!(quo.item::<f32>(), 0.0);
2816        assert_eq!(rem.item::<f32>(), 1.0);
2817
2818        // Check nested in the graph
2819        let x = array![1.0];
2820        let y = array![2.0];
2821        let (quo, rem) = divmod(&x, &y).unwrap();
2822        let z = quo + rem;
2823        assert_eq!(z.item::<f32>(), 1.0);
2824
2825        // Check that we can still eval when one output goes out of scope
2826        let mut out_holder = {
2827            let (quo, _) = divmod(&x, &y).unwrap();
2828            vec![quo]
2829        };
2830        eval(out_holder.iter()).unwrap();
2831        assert_eq!(out_holder[0].item::<f32>(), 0.0);
2832
2833        // Check that we can still eval when the other output goes out of scope
2834        out_holder.clear();
2835        let out_holder = {
2836            let (_, rem) = divmod(&x, &y).unwrap();
2837            vec![rem]
2838        };
2839        eval(out_holder.iter()).unwrap();
2840        assert_eq!(out_holder[0].item::<f32>(), 1.0);
2841    }
2842}