// mlx_rs/ops/arithmetic.rs

1use crate::array::Array;
2use crate::error::Result;
3use crate::sealed::Sealed;
4
5use crate::utils::guard::Guarded;
6use crate::utils::{IntoOption, ScalarOrArray, VectorArray};
7use crate::Stream;
8use mlx_internal_macros::{default_device, generate_macro};
9use smallvec::SmallVec;
10
11impl Array {
12    /// Element-wise absolute value.
13    ///
14    /// # Example
15    ///
16    /// ```rust
17    /// use mlx_rs::Array;
18    /// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]);
19    /// let mut result = array.abs().unwrap();
20    ///
21    /// let data: &[i32] = result.as_slice();
22    /// // data == [1, 2, 3, 4, 5]
23    /// ```
24    #[default_device]
25    pub fn abs_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
26        Array::try_from_op(|res| unsafe {
27            mlx_sys::mlx_abs(res, self.as_ptr(), stream.as_ref().as_ptr())
28        })
29    }
30
31    /// Element-wise addition returning an error if arrays are not broadcastable.
32    ///
33    /// Add two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
34    ///
35    /// # Params
36    ///
37    /// - other: array to add
38    ///
39    /// # Example
40    ///
41    /// ```rust
42    /// use mlx_rs::Array;
43    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
44    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
45    /// let mut c = a.add(&b).unwrap();
46    ///
47    /// let c_data: &[f32] = c.as_slice();
48    /// // c_data == [5.0, 7.0, 9.0]
49    /// ```
50    #[default_device]
51    pub fn add_device(
52        &self,
53        other: impl AsRef<Array>,
54        stream: impl AsRef<Stream>,
55    ) -> Result<Array> {
56        Array::try_from_op(|res| unsafe {
57            mlx_sys::mlx_add(
58                res,
59                self.as_ptr(),
60                other.as_ref().as_ptr(),
61                stream.as_ref().as_ptr(),
62            )
63        })
64    }
65
66    /// Element-wise subtraction returning an error if arrays are not broadcastable.
67    ///
68    /// Subtract two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
69    ///
70    /// # Params
71    ///
72    /// - other: array to subtract
73    ///
74    /// # Example
75    ///
76    /// ```rust
77    /// use mlx_rs::Array;
78    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
79    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
80    /// let mut c = a.subtract(&b).unwrap();
81    ///
82    /// let c_data: &[f32] = c.as_slice();
83    /// // c_data == [-3.0, -3.0, -3.0]
84    /// ```
85    #[default_device]
86    pub fn subtract_device(
87        &self,
88        other: impl AsRef<Array>,
89        stream: impl AsRef<Stream>,
90    ) -> Result<Array> {
91        Array::try_from_op(|res| unsafe {
92            mlx_sys::mlx_subtract(
93                res,
94                self.as_ptr(),
95                other.as_ref().as_ptr(),
96                stream.as_ref().as_ptr(),
97            )
98        })
99    }
100
101    /// Unary element-wise negation. Returns an error if the array is of type bool.
102    ///
103    /// Negate the values in the array.
104    ///
105    /// # Example
106    ///
107    /// ```rust
108    /// use mlx_rs::Array;
109    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
110    /// let mut b = a.negative().unwrap();
111    ///
112    /// let b_data: &[f32] = b.as_slice();
113    /// // b_data == [-1.0, -2.0, -3.0]
114    /// ```
115    #[default_device]
116    pub fn negative_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
117        Array::try_from_op(|res| unsafe {
118            mlx_sys::mlx_negative(res, self.as_ptr(), stream.as_ref().as_ptr())
119        })
120    }
121
122    /// Element-wise multiplication returning an error if arrays are not broadcastable.
123    ///
124    /// Multiply two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
125    ///
126    /// # Example
127    ///
128    /// ```rust
129    /// use mlx_rs::Array;
130    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
131    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
132    /// let mut c = a.multiply(&b).unwrap();
133    ///
134    /// let c_data: &[f32] = c.as_slice();
135    /// // c_data == [4.0, 10.0, 18.0]
136    /// ```
137    #[default_device]
138    pub fn multiply_device(
139        &self,
140        other: impl AsRef<Array>,
141        stream: impl AsRef<Stream>,
142    ) -> Result<Array> {
143        Array::try_from_op(|res| unsafe {
144            mlx_sys::mlx_multiply(
145                res,
146                self.as_ptr(),
147                other.as_ref().as_ptr(),
148                stream.as_ref().as_ptr(),
149            )
150        })
151    }
152
153    /// Replace NaN and Inf values with finite numbers.
154    ///
155    /// # Params
156    /// - nan: value to replace NaN with
157    /// - posInf: value to replace positive inifinites with.  If not specified will use
158    ///     the largest finite value for the given dtype.
159    /// - negInf: value to replace negative inifinites with.  If not specified will use
160    ///     the negative of the largest finite value for the given dtype.
161    /// - stream: stream or device to evaluate on
162    #[default_device]
163    pub fn nan_to_num_device(
164        &self,
165        nan: impl IntoOption<f32>,
166        pos_inf: impl IntoOption<f32>,
167        neg_inf: impl IntoOption<f32>,
168        stream: impl AsRef<Stream>,
169    ) -> Result<Array> {
170        let pos_inf = pos_inf.into_option();
171        let neg_inf = neg_inf.into_option();
172
173        let pos_inf = mlx_sys::mlx_optional_float {
174            value: pos_inf.unwrap_or(0.0),
175            has_value: pos_inf.is_some(),
176        };
177        let neg_inf = mlx_sys::mlx_optional_float {
178            value: neg_inf.unwrap_or(0.0),
179            has_value: neg_inf.is_some(),
180        };
181
182        Array::try_from_op(|res| unsafe {
183            mlx_sys::mlx_nan_to_num(
184                res,
185                self.as_ptr(),
186                nan.into_option().unwrap_or(0.),
187                pos_inf,
188                neg_inf,
189                stream.as_ref().as_ptr(),
190            )
191        })
192    }
193
194    /// Element-wise division returning an error if arrays are not broadcastable.
195    ///
196    /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
197    ///
198    /// # Params
199    ///
200    /// - other: array to divide
201    ///
202    /// # Example
203    ///
204    /// ```rust
205    /// use mlx_rs::Array;
206    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
207    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
208    /// let mut c = a.divide(&b).unwrap();
209    ///
210    /// let c_data: &[f32] = c.as_slice();
211    /// // c_data == [0.25, 0.4, 0.5]
212    /// ```
213    #[default_device]
214    pub fn divide_device(
215        &self,
216        other: impl AsRef<Array>,
217        stream: impl AsRef<Stream>,
218    ) -> Result<Array> {
219        Array::try_from_op(|res| unsafe {
220            mlx_sys::mlx_divide(
221                res,
222                self.as_ptr(),
223                other.as_ref().as_ptr(),
224                stream.as_ref().as_ptr(),
225            )
226        })
227    }
228
229    /// Element-wise power operation returning an error if arrays are not broadcastable if they have different shapes.
230    ///
231    /// Raise the elements of the array to the power of the elements of another array.
232    ///
233    /// # Params
234    ///
235    /// - other: array to raise to the power of
236    ///
237    /// # Example
238    ///
239    /// ```rust
240    /// use mlx_rs::Array;
241    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
242    /// let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
243    /// let mut c = a.power(&b).unwrap();
244    ///
245    /// let c_data: &[f32] = c.as_slice();
246    /// // c_data == [1.0, 8.0, 81.0]
247    /// ```
248    #[default_device]
249    pub fn power_device(
250        &self,
251        other: impl AsRef<Array>,
252        stream: impl AsRef<Stream>,
253    ) -> Result<Array> {
254        Array::try_from_op(|res| unsafe {
255            mlx_sys::mlx_power(
256                res,
257                self.as_ptr(),
258                other.as_ref().as_ptr(),
259                stream.as_ref().as_ptr(),
260            )
261        })
262    }
263
264    /// Element-wise remainder of division returning an error if arrays are not broadcastable.
265    ///
266    /// Computes the remainder of dividing `lhs` with `rhs` with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
267    ///
268    /// # Params
269    ///
270    /// - other: array to divide
271    ///
272    /// # Example
273    ///
274    /// ```rust
275    /// use mlx_rs::Array;
276    /// let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
277    /// let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
278    /// let mut c = a.remainder(&b).unwrap();
279    ///
280    /// let c_data: &[f32] = c.as_slice();
281    /// // c_data == [1.0, 3.0, 2.0]
282    /// ```
283    #[default_device]
284    pub fn remainder_device(
285        &self,
286        other: impl AsRef<Array>,
287        stream: impl AsRef<Stream>,
288    ) -> Result<Array> {
289        Array::try_from_op(|res| unsafe {
290            mlx_sys::mlx_remainder(
291                res,
292                self.as_ptr(),
293                other.as_ref().as_ptr(),
294                stream.as_ref().as_ptr(),
295            )
296        })
297    }
298
299    /// Element-wise square root
300    ///
301    /// # Example
302    ///
303    /// ```rust
304    /// use mlx_rs::Array;
305    /// let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
306    /// let mut b = a.sqrt().unwrap();
307    ///
308    /// let b_data: &[f32] = b.as_slice();
309    /// // b_data == [1.0, 2.0, 3.0]
310    /// ```
311    #[default_device]
312    pub fn sqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
313        Array::try_from_op(|res| unsafe {
314            mlx_sys::mlx_sqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
315        })
316    }
317
318    /// Element-wise cosine
319    ///
320    /// # Example
321    ///
322    /// ```rust
323    /// use mlx_rs::Array;
324    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
325    /// let mut b = a.cos().unwrap();
326    ///
327    /// let b_data: &[f32] = b.as_slice();
328    /// // b_data == [1.0, 0.54030234, -0.41614687]
329    /// ```
330    #[default_device]
331    pub fn cos_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
332        Array::try_from_op(|res| unsafe {
333            mlx_sys::mlx_cos(res, self.as_ptr(), stream.as_ref().as_ptr())
334        })
335    }
336
337    /// Element-wise exponential.
338    ///
339    /// # Example
340    ///
341    /// ```rust
342    /// use mlx_rs::Array;
343    ///
344    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
345    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
346    /// let mut b = a.exp().unwrap();
347    ///
348    /// let b_data: &[f32] = b.as_slice();
349    /// // b_data == [1.0, 2.7182817, 7.389056]
350    /// ```
351    #[default_device]
352    pub fn exp_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
353        Array::try_from_op(|res| unsafe {
354            mlx_sys::mlx_exp(res, self.as_ptr(), stream.as_ref().as_ptr())
355        })
356    }
357
358    /// Element-wise floor returning an error if the array is of type complex64.
359    ///
360    /// # Example
361    ///
362    /// ```rust
363    /// use mlx_rs::Array;
364    /// let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
365    /// let mut b = a.floor().unwrap();
366    ///
367    /// let b_data: &[f32] = b.as_slice();
368    /// // b_data == [0.0, 1.0, 2.0]
369    /// ```
370    #[default_device]
371    pub fn floor_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
372        Array::try_from_op(|res| unsafe {
373            mlx_sys::mlx_floor(res, self.as_ptr(), stream.as_ref().as_ptr())
374        })
375    }
376
377    /// Element-wise integer division returning an error if arrays are not broadcastable.
378    ///
379    /// Divide two arrays with
380    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
381    ///
382    /// If either array is a floating point type then it is equivalent to calling [`Array::floor()`]
383    /// after `/`.
384    ///
385    /// # Params
386    ///
387    /// - other: array to divide
388    ///
389    /// # Example
390    ///
391    /// ```rust
392    /// use mlx_rs::Array;
393    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
394    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
395    /// let mut c = a.floor_divide(&b).unwrap();
396    ///
397    /// let c_data: &[f32] = c.as_slice();
398    /// // c_data == [0.25, 0.4, 0.5]
399    /// ```
400    #[default_device]
401    pub fn floor_divide_device(
402        &self,
403        other: impl AsRef<Array>,
404        stream: impl AsRef<Stream>,
405    ) -> Result<Array> {
406        Array::try_from_op(|res| unsafe {
407            mlx_sys::mlx_floor_divide(
408                res,
409                self.as_ptr(),
410                other.as_ref().as_ptr(),
411                stream.as_ref().as_ptr(),
412            )
413        })
414    }
415
416    /// Return a boolean array indicating which elements are NaN.
417    ///
418    /// # Params
419    /// - stream: stream or device to evaluate on
420    #[default_device]
421    pub fn is_nan_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
422        Array::try_from_op(|res| unsafe {
423            mlx_sys::mlx_isnan(res, self.as_ptr(), stream.as_ref().as_ptr())
424        })
425    }
426
427    /// Return a boolean array indicating which elements are infinity.
428    ///
429    /// # Params
430    /// - stream: stream or device to evaluate on
431    #[default_device]
432    pub fn is_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
433        Array::try_from_op(|res| unsafe {
434            mlx_sys::mlx_isinf(res, self.as_ptr(), stream.as_ref().as_ptr())
435        })
436    }
437
438    /// Return a boolean array indicating which elements are finite.
439    ///
440    /// # Params
441    /// - stream: stream or device to evaluate on
442    #[default_device]
443    pub fn is_finite_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
444        Array::try_from_op(|res| unsafe {
445            mlx_sys::mlx_isfinite(res, self.as_ptr(), stream.as_ref().as_ptr())
446        })
447    }
448
449    /// Return a boolean array indicating which elements are negative infinity.
450    ///
451    /// # Params
452    /// - stream: stream or device to evaluate on
453    #[default_device]
454    pub fn is_neg_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
455        Array::try_from_op(|res| unsafe {
456            mlx_sys::mlx_isneginf(res, self.as_ptr(), stream.as_ref().as_ptr())
457        })
458    }
459
460    /// Return a boolean array indicating which elements are positive infinity.
461    ///
462    /// # Params
463    /// - stream: stream or device to evaluate on
464    #[default_device]
465    pub fn is_pos_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
466        Array::try_from_op(|res| unsafe {
467            mlx_sys::mlx_isposinf(res, self.as_ptr(), stream.as_ref().as_ptr())
468        })
469    }
470
471    /// Element-wise natural logarithm.
472    ///
473    /// # Example
474    ///
475    /// ```rust
476    /// use mlx_rs::Array;
477    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
478    /// let mut b = a.log().unwrap();
479    ///
480    /// let b_data: &[f32] = b.as_slice();
481    /// // b_data == [0.0, 0.6931472, 1.0986123]
482    /// ```
483    #[default_device]
484    pub fn log_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
485        Array::try_from_op(|res| unsafe {
486            mlx_sys::mlx_log(res, self.as_ptr(), stream.as_ref().as_ptr())
487        })
488    }
489
490    /// Element-wise base-2 logarithm.
491    ///
492    /// # Example
493    ///
494    /// ```rust
495    /// use mlx_rs::Array;
496    /// let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
497    /// let mut b = a.log2().unwrap();
498    ///
499    /// let b_data: &[f32] = b.as_slice();
500    /// // b_data == [0.0, 1.0, 2.0, 3.0]
501    /// ```
502    #[default_device]
503    pub fn log2_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
504        Array::try_from_op(|res| unsafe {
505            mlx_sys::mlx_log2(res, self.as_ptr(), stream.as_ref().as_ptr())
506        })
507    }
508
509    /// Element-wise base-10 logarithm.
510    ///
511    /// # Example
512    ///
513    /// ```rust
514    /// use mlx_rs::Array;
515    /// let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
516    /// let mut b = a.log10().unwrap();
517    ///
518    /// let b_data: &[f32] = b.as_slice();
519    /// // b_data == [0.0, 1.0, 2.0]
520    /// ```
521    #[default_device]
522    pub fn log10_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
523        Array::try_from_op(|res| unsafe {
524            mlx_sys::mlx_log10(res, self.as_ptr(), stream.as_ref().as_ptr())
525        })
526    }
527
528    /// Element-wise natural log of one plus the array.
529    ///
530    /// # Example
531    ///
532    /// ```rust
533    /// use mlx_rs::Array;
534    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
535    /// let mut b = a.log1p().unwrap();
536    ///
537    /// let b_data: &[f32] = b.as_slice();
538    /// // b_data == [0.6931472, 1.0986123, 1.3862944]
539    /// ```
540    #[default_device]
541    pub fn log1p_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
542        Array::try_from_op(|res| unsafe {
543            mlx_sys::mlx_log1p(res, self.as_ptr(), stream.as_ref().as_ptr())
544        })
545    }
546
547    /// Matrix multiplication returning an error if inputs are not valid.
548    ///
549    /// Perform the (possibly batched) matrix multiplication of two arrays. This function supports
550    /// broadcasting for arrays with more than two dimensions.
551    ///
552    /// - If the first array is 1-D then a 1 is prepended to its shape to make it
553    ///   a matrix. Similarly, if the second array is 1-D then a 1 is appended to its
554    ///   shape to make it a matrix. In either case the singleton dimension is removed
555    ///   from the result.
556    /// - A batched matrix multiplication is performed if the arrays have more than
557    ///   2 dimensions.  The matrix dimensions for the matrix product are the last
558    ///   two dimensions of each input.
559    /// - All but the last two dimensions of each input are broadcast with one another using
560    ///   standard [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
561    ///
562    /// # Params
563    ///
564    /// - other: array to multiply
565    ///
566    /// # Example
567    ///
568    /// ```rust
569    /// use mlx_rs::Array;
570    /// let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
571    /// let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
572    ///
573    /// // produces a [2, 3] result
574    /// let mut c = a.matmul(&b);
575    /// ```
576    #[default_device]
577    pub fn matmul_device(
578        &self,
579        other: impl AsRef<Array>,
580        stream: impl AsRef<Stream>,
581    ) -> Result<Array> {
582        Array::try_from_op(|res| unsafe {
583            mlx_sys::mlx_matmul(
584                res,
585                self.as_ptr(),
586                other.as_ref().as_ptr(),
587                stream.as_ref().as_ptr(),
588            )
589        })
590    }
591
592    /// Element-wise reciprocal.
593    ///
594    /// # Example
595    ///
596    /// ```rust
597    /// use mlx_rs::Array;
598    /// let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
599    /// let mut b = a.reciprocal().unwrap();
600    ///
601    /// let b_data: &[f32] = b.as_slice();
602    /// // b_data == [1.0, 0.5, 0.25]
603    /// ```
604    #[default_device]
605    pub fn reciprocal_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
606        Array::try_from_op(|res| unsafe {
607            mlx_sys::mlx_reciprocal(res, self.as_ptr(), stream.as_ref().as_ptr())
608        })
609    }
610
611    /// Round to the given number of decimals.
612    ///
613    /// # Params
614    ///
615    /// - decimals: number of decimals to round to - default is 0 if not provided
616    #[default_device]
617    pub fn round_device(
618        &self,
619        decimals: impl Into<Option<i32>>,
620        stream: impl AsRef<Stream>,
621    ) -> Result<Array> {
622        Array::try_from_op(|res| unsafe {
623            mlx_sys::mlx_round(
624                res,
625                self.as_ptr(),
626                decimals.into().unwrap_or(0),
627                stream.as_ref().as_ptr(),
628            )
629        })
630    }
631
632    /// Element-wise reciprocal and square root.
633    #[default_device]
634    pub fn rsqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
635        Array::try_from_op(|res| unsafe {
636            mlx_sys::mlx_rsqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
637        })
638    }
639
640    /// Element-wise sine.
641    #[default_device]
642    pub fn sin_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
643        Array::try_from_op(|res| unsafe {
644            mlx_sys::mlx_sin(res, self.as_ptr(), stream.as_ref().as_ptr())
645        })
646    }
647
648    /// Element-wise square.
649    #[default_device]
650    pub fn square_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
651        Array::try_from_op(|res| unsafe {
652            mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
653        })
654    }
655
656    /// Element-wise real part from a complex array.
657    #[default_device]
658    pub fn real_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
659        Array::try_from_op(|res| unsafe {
660            mlx_sys::mlx_real(res, self.as_ptr(), stream.as_ref().as_ptr())
661        })
662    }
663
664    /// Element-wise imag part from a complex array.
665    #[default_device]
666    pub fn imag_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
667        Array::try_from_op(|res| unsafe {
668            mlx_sys::mlx_imag(res, self.as_ptr(), stream.as_ref().as_ptr())
669        })
670    }
671}
672
673/// Element-wise absolute value.
674///
675/// # Example
676///
677/// ```rust
678/// use mlx_rs::{Array, ops};
679///
680/// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]);
681/// let result = ops::abs(&array).unwrap();
682/// ```
683#[generate_macro]
684#[default_device]
685pub fn abs_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
686    a.as_ref().abs_device(stream)
687}
688
689/// Element-wise inverse cosine.
690#[generate_macro]
691#[default_device]
692pub fn acos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
693    Array::try_from_op(|res| unsafe {
694        mlx_sys::mlx_arccos(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
695    })
696}
697
698/// Element-wise inverse hyperbolic cosine.
699#[generate_macro]
700#[default_device]
701pub fn acosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
702    Array::try_from_op(|res| unsafe {
703        mlx_sys::mlx_arccosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
704    })
705}
706
707/// See [`Array::add`].
708#[generate_macro]
709#[default_device]
710pub fn add_device(
711    lhs: impl AsRef<Array>,
712    rhs: impl AsRef<Array>,
713    #[optional] stream: impl AsRef<Stream>,
714) -> Result<Array> {
715    lhs.as_ref().add_device(rhs, stream)
716}
717
718/// Element-wise inverse sine.
719#[generate_macro]
720#[default_device]
721pub fn asin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
722    Array::try_from_op(|res| unsafe {
723        mlx_sys::mlx_arcsin(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
724    })
725}
726
727/// Element-wise inverse hyperbolic sine.
728#[generate_macro]
729#[default_device]
730pub fn asinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
731    Array::try_from_op(|res| unsafe {
732        mlx_sys::mlx_arcsinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
733    })
734}
735
736/// Element-wise inverse tangent.
737#[generate_macro]
738#[default_device]
739pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
740    Array::try_from_op(|res| unsafe {
741        mlx_sys::mlx_arctan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
742    })
743}
744
745/// Element-wise inverse hyperbolic tangent.
746#[generate_macro]
747#[default_device]
748pub fn atanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
749    Array::try_from_op(|res| unsafe {
750        mlx_sys::mlx_arctanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
751    })
752}
753
754/// Element-wise ceiling.
755#[generate_macro]
756#[default_device]
757pub fn ceil_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
758    Array::try_from_op(|res| unsafe {
759        mlx_sys::mlx_ceil(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
760    })
761}
762
763/// A custom trait for the bound of the clip operation.
764///
765/// This trait is only implemented for tuples of the form `(Min, Max)`, `(Min, ())`, and `((),
766/// Max)`. The `Min` and `Max` types must implement the `ScalarOrArray` trait.
767pub trait ClipBound<'min, 'max>: Sealed {
768    /// Convert the bound into a tuple of optional minimum and maximum values.
769    fn into_min_max(
770        self,
771    ) -> (
772        Option<impl ScalarOrArray<'min>>,
773        Option<impl ScalarOrArray<'max>>,
774    );
775}
776
777impl<'min, Min> ClipBound<'min, 'min> for (Min, ())
778where
779    Min: ScalarOrArray<'min> + Sealed,
780{
781    fn into_min_max(
782        self,
783    ) -> (
784        Option<impl ScalarOrArray<'min>>,
785        Option<impl ScalarOrArray<'min>>,
786    ) {
787        (Some(self.0), Option::<Min>::None)
788    }
789}
790
791impl<'max, Max> ClipBound<'max, 'max> for ((), Max)
792where
793    Max: ScalarOrArray<'max> + Sealed,
794{
795    fn into_min_max(
796        self,
797    ) -> (
798        Option<impl ScalarOrArray<'max>>,
799        Option<impl ScalarOrArray<'max>>,
800    ) {
801        (Option::<Max>::None, Some(self.1))
802    }
803}
804
805impl<'min, 'max, Min, Max> ClipBound<'min, 'max> for (Min, Max)
806where
807    Min: ScalarOrArray<'min> + Sealed,
808    Max: ScalarOrArray<'max> + Sealed,
809{
810    fn into_min_max(
811        self,
812    ) -> (
813        Option<impl ScalarOrArray<'min>>,
814        Option<impl ScalarOrArray<'max>>,
815    ) {
816        (Some(self.0), Some(self.1))
817    }
818}
819
820/// Clip the values of the array between the given minimum and maximum.
821///
822/// If either `a_min` or `a_max` are None, then corresponding edge is ignored. At least one of
823/// `a_min` and `a_max` cannot be `None`. The input `a` and the limits must broadcast with one
824/// another.
825///
826/// # Params
827///
828/// - `a`: Input array.
829/// - `bound`: minimum and/or maximum values to clip the array to.
830///
831/// # Example
832///
833/// ```rust
834/// use mlx_rs::{Array, ops::clip, array};
835///
836/// let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
837/// let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);
838/// let clipped = clip(&a, (2.0, 6.0)).unwrap();
839/// assert_eq!(clipped, expected);
840/// ```
841#[generate_macro]
842#[default_device]
843pub fn clip_device<'min, 'max>(
844    a: impl AsRef<Array>,
845    bound: impl ClipBound<'min, 'max>,
846    #[optional] stream: impl AsRef<Stream>,
847) -> Result<Array> {
848    let (a_min, a_max) = bound.into_min_max();
849
850    // This is needed to keep the lifetime of the min/max arrays in scope.
851    let a_min = a_min.map(|min| min.into_owned_or_ref_array());
852    let a_max = a_max.map(|max| max.into_owned_or_ref_array());
853
854    unsafe {
855        let min_ptr = match &a_min {
856            Some(a_min) => a_min.as_ref().as_ptr(),
857            None => mlx_sys::mlx_array_new(),
858        };
859        let max_ptr = match &a_max {
860            Some(a_max) => a_max.as_ref().as_ptr(),
861            None => mlx_sys::mlx_array_new(),
862        };
863
864        Array::try_from_op(|res| {
865            mlx_sys::mlx_clip(
866                res,
867                a.as_ref().as_ptr(),
868                min_ptr,
869                max_ptr,
870                stream.as_ref().as_ptr(),
871            )
872        })
873    }
874}
875
876/// Element-wise cosine.
877#[generate_macro]
878#[default_device]
879pub fn cos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
880    a.as_ref().cos_device(stream)
881}
882
883/// Element-wise hyperbolic cosine.
884#[generate_macro]
885#[default_device]
886pub fn cosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
887    Array::try_from_op(|res| unsafe {
888        mlx_sys::mlx_cosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
889    })
890}
891
892/// Convert angles from radians to degrees.
893#[generate_macro]
894#[default_device]
895pub fn degrees_device(
896    a: impl AsRef<Array>,
897    #[optional] stream: impl AsRef<Stream>,
898) -> Result<Array> {
899    Array::try_from_op(|res| unsafe {
900        mlx_sys::mlx_degrees(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
901    })
902}
903
904/// See [`Array::divide`].
905#[generate_macro]
906#[default_device]
907pub fn divide_device(
908    a: impl AsRef<Array>,
909    b: impl AsRef<Array>,
910    #[optional] stream: impl AsRef<Stream>,
911) -> Result<Array> {
912    a.as_ref().divide_device(b, stream)
913}
914
915/// Element-wise quotient and remainder.
916///
917/// The fuction `divmod(a, b)` is equivalent to but faster than `(a // b, a % b)`. The function uses
918/// numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
919///
920/// Returns Ok((quotient, remainder)) if the operation was successful.
921#[generate_macro]
922#[default_device]
923pub fn divmod_device(
924    a: impl AsRef<Array>,
925    b: impl AsRef<Array>,
926    #[optional] stream: impl AsRef<Stream>,
927) -> Result<(Array, Array)> {
928    let a_ptr = a.as_ref().as_ptr();
929    let b_ptr = b.as_ref().as_ptr();
930
931    let vec = VectorArray::try_from_op(|res| unsafe {
932        mlx_sys::mlx_divmod(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
933    })?;
934
935    let vals: SmallVec<[_; 2]> = vec.try_into_values()?;
936    let mut iter = vals.into_iter();
937    let quotient = iter.next().unwrap();
938    let remainder = iter.next().unwrap();
939
940    Ok((quotient, remainder))
941}
942
943/// Element-wise error function.
944#[generate_macro]
945#[default_device]
946pub fn erf_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
947    Array::try_from_op(|res| unsafe {
948        mlx_sys::mlx_erf(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
949    })
950}
951
952/// Element-wise inverse error function.
953#[generate_macro]
954#[default_device]
955pub fn erfinv_device(
956    a: impl AsRef<Array>,
957    #[optional] stream: impl AsRef<Stream>,
958) -> Result<Array> {
959    Array::try_from_op(|res| unsafe {
960        mlx_sys::mlx_erfinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
961    })
962}
963
964/// See [`Array::exp`].
965#[generate_macro]
966#[default_device]
967pub fn exp_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
968    a.as_ref().exp_device(stream)
969}
970
971/// Element-wise exponential minus 1.
972#[generate_macro]
973#[default_device]
974pub fn expm1_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
975    Array::try_from_op(|res| unsafe {
976        mlx_sys::mlx_expm1(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
977    })
978}
979
980/// See [`Array::floor`].
981#[generate_macro]
982#[default_device]
983pub fn floor_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
984    a.as_ref().floor_device(stream)
985}
986
987/// See [`Array::floor_divide`].
988#[generate_macro]
989#[default_device]
990pub fn floor_divide_device(
991    a: impl AsRef<Array>,
992    other: impl AsRef<Array>,
993    #[optional] stream: impl AsRef<Stream>,
994) -> Result<Array> {
995    a.as_ref().floor_divide_device(other, stream)
996}
997
998/// See [`Array::log`].
999#[generate_macro]
1000#[default_device]
1001pub fn log_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1002    a.as_ref().log_device(stream)
1003}
1004
1005/// See [`Array::log10`].
1006#[generate_macro]
1007#[default_device]
1008pub fn log10_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1009    a.as_ref().log10_device(stream)
1010}
1011
1012/// See [`Array::log1p`].
1013#[generate_macro]
1014#[default_device]
1015pub fn log1p_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1016    a.as_ref().log1p_device(stream)
1017}
1018
1019/// See [`Array::log2`].
1020#[generate_macro]
1021#[default_device]
1022pub fn log2_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1023    a.as_ref().log2_device(stream)
1024}
1025
1026/// Element-wise log-add-exp.
1027///
1028/// This is a numerically stable log-add-exp of two arrays with numpy-style broadcasting semantics.
1029/// Either or both input arrays can also be scalars.
1030///
1031/// The computation is is a numerically stable version of `log(exp(a) + exp(b))`.
1032#[generate_macro]
1033#[default_device]
1034pub fn logaddexp_device(
1035    a: impl AsRef<Array>,
1036    b: impl AsRef<Array>,
1037    #[optional] stream: impl AsRef<Stream>,
1038) -> Result<Array> {
1039    let a_ptr = a.as_ref().as_ptr();
1040    let b_ptr = b.as_ref().as_ptr();
1041
1042    Array::try_from_op(|res| unsafe {
1043        mlx_sys::mlx_logaddexp(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1044    })
1045}
1046
1047/// See [`Array::matmul`].
1048#[generate_macro]
1049#[default_device]
1050pub fn matmul_device(
1051    a: impl AsRef<Array>,
1052    b: impl AsRef<Array>,
1053    #[optional] stream: impl AsRef<Stream>,
1054) -> Result<Array> {
1055    a.as_ref().matmul_device(b, stream)
1056}
1057
1058/// Element-wise maximum.
1059///
1060/// Take the element-wise max of two arrays with numpy-style broadcasting semantics. Either or both
1061/// input arrays can also be scalars.
1062#[generate_macro]
1063#[default_device]
1064pub fn maximum_device(
1065    a: impl AsRef<Array>,
1066    b: impl AsRef<Array>,
1067    #[optional] stream: impl AsRef<Stream>,
1068) -> Result<Array> {
1069    let a_ptr = a.as_ref().as_ptr();
1070    let b_ptr = b.as_ref().as_ptr();
1071
1072    Array::try_from_op(|res| unsafe {
1073        mlx_sys::mlx_maximum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1074    })
1075}
1076
1077/// Element-wise minimum.
1078///
1079/// Take the element-wise min of two arrays with numpy-style broadcasting semantics. Either or both
1080/// input arrays can also be scalars.
1081#[generate_macro]
1082#[default_device]
1083pub fn minimum_device(
1084    a: impl AsRef<Array>,
1085    b: impl AsRef<Array>,
1086    #[optional] stream: impl AsRef<Stream>,
1087) -> Result<Array> {
1088    let a_ptr = a.as_ref().as_ptr();
1089    let b_ptr = b.as_ref().as_ptr();
1090
1091    Array::try_from_op(|res| unsafe {
1092        mlx_sys::mlx_minimum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1093    })
1094}
1095
1096/// See [`Array::multiply`].
1097#[generate_macro]
1098#[default_device]
1099pub fn multiply_device(
1100    a: impl AsRef<Array>,
1101    b: impl AsRef<Array>,
1102    #[optional] stream: impl AsRef<Stream>,
1103) -> Result<Array> {
1104    a.as_ref().multiply_device(b, stream)
1105}
1106
1107/// See [`Array::negative`].
1108#[generate_macro]
1109#[default_device]
1110pub fn negative_device(
1111    a: impl AsRef<Array>,
1112    #[optional] stream: impl AsRef<Stream>,
1113) -> Result<Array> {
1114    a.as_ref().negative_device(stream)
1115}
1116
1117/// See [`Array::power`].
1118#[generate_macro]
1119#[default_device]
1120pub fn power_device(
1121    a: impl AsRef<Array>,
1122    b: impl AsRef<Array>,
1123    #[optional] stream: impl AsRef<Stream>,
1124) -> Result<Array> {
1125    a.as_ref().power_device(b, stream)
1126}
1127
1128/// Convert angles from degrees to radians.
1129#[generate_macro]
1130#[default_device]
1131pub fn radians_device(
1132    a: impl AsRef<Array>,
1133    #[optional] stream: impl AsRef<Stream>,
1134) -> Result<Array> {
1135    Array::try_from_op(|res| unsafe {
1136        mlx_sys::mlx_radians(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1137    })
1138}
1139
1140/// See [`Array::reciprocal`].
1141#[generate_macro]
1142#[default_device]
1143pub fn reciprocal_device(
1144    a: impl AsRef<Array>,
1145    #[optional] stream: impl AsRef<Stream>,
1146) -> Result<Array> {
1147    a.as_ref().reciprocal_device(stream)
1148}
1149
1150/// See [`Array::remainder`].
1151#[generate_macro]
1152#[default_device]
1153pub fn remainder_device(
1154    a: impl AsRef<Array>,
1155    b: impl AsRef<Array>,
1156    #[optional] stream: impl AsRef<Stream>,
1157) -> Result<Array> {
1158    a.as_ref().remainder_device(b, stream)
1159}
1160
1161/// See [`Array::round`].
1162#[generate_macro]
1163#[default_device]
1164pub fn round_device(
1165    a: impl AsRef<Array>,
1166    decimals: impl Into<Option<i32>>,
1167    #[optional] stream: impl AsRef<Stream>,
1168) -> Result<Array> {
1169    a.as_ref().round_device(decimals, stream)
1170}
1171
1172/// See [`Array::rsqrt`].
1173#[generate_macro]
1174#[default_device]
1175pub fn rsqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1176    a.as_ref().rsqrt_device(stream)
1177}
1178
1179/// Element-wise logistic sigmoid.
1180///
1181/// See the [python API
1182/// docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html#mlx.core.sigmoid)
1183/// for more information
1184#[generate_macro]
1185#[default_device]
1186pub fn sigmoid_device(
1187    a: impl AsRef<Array>,
1188    #[optional] stream: impl AsRef<Stream>,
1189) -> Result<Array> {
1190    Array::try_from_op(|res| unsafe {
1191        mlx_sys::mlx_sigmoid(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1192    })
1193}
1194
1195/// Element-wise sign.
1196#[generate_macro]
1197#[default_device]
1198pub fn sign_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1199    Array::try_from_op(|res| unsafe {
1200        mlx_sys::mlx_sign(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1201    })
1202}
1203
1204/// See [`Array::sin`].
1205#[generate_macro]
1206#[default_device]
1207pub fn sin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1208    a.as_ref().sin_device(stream)
1209}
1210
1211/// Element-wise hyperbolic sine.
1212#[generate_macro]
1213#[default_device]
1214pub fn sinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1215    Array::try_from_op(|res| unsafe {
1216        mlx_sys::mlx_sinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1217    })
1218}
1219
1220/// Perform the softmax along the given axis.
1221///
1222/// See the [python API
1223/// docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.softmax.html#mlx.core.softmax)
1224/// for more information.
1225#[generate_macro]
1226#[default_device]
1227pub fn softmax_axes_device(
1228    a: impl AsRef<Array>,
1229    axes: &[i32],
1230    precise: impl Into<Option<bool>>,
1231    #[optional] stream: impl AsRef<Stream>,
1232) -> Result<Array> {
1233    let precise = precise.into().unwrap_or(false);
1234    let s = stream.as_ref().as_ptr();
1235
1236    Array::try_from_op(|res| unsafe {
1237        mlx_sys::mlx_softmax_axes(
1238            res,
1239            a.as_ref().as_ptr(),
1240            axes.as_ptr(),
1241            axes.len(),
1242            precise,
1243            s,
1244        )
1245    })
1246}
1247
1248/// Similar to [`softmax_axes`] but with a single axis.
1249#[generate_macro]
1250#[default_device]
1251pub fn softmax_axis_device(
1252    a: impl AsRef<Array>,
1253    axis: i32,
1254    precise: impl Into<Option<bool>>,
1255    #[optional] stream: impl AsRef<Stream>,
1256) -> Result<Array> {
1257    let precise = precise.into().unwrap_or(false);
1258    let s = stream.as_ref().as_ptr();
1259
1260    Array::try_from_op(|res| unsafe {
1261        mlx_sys::mlx_softmax_axis(res, a.as_ref().as_ptr(), axis, precise, s)
1262    })
1263}
1264
1265/// Similar to [`softmax_axes`] but with no axis specified.
1266#[generate_macro]
1267#[default_device]
1268pub fn softmax_device(
1269    a: impl AsRef<Array>,
1270    precise: impl Into<Option<bool>>,
1271    #[optional] stream: impl AsRef<Stream>,
1272) -> Result<Array> {
1273    let precise = precise.into().unwrap_or(false);
1274    let s = stream.as_ref().as_ptr();
1275
1276    Array::try_from_op(|res| unsafe { mlx_sys::mlx_softmax(res, a.as_ref().as_ptr(), precise, s) })
1277}
1278
1279/// See [`Array::sqrt`].
1280#[generate_macro]
1281#[default_device]
1282pub fn sqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1283    a.as_ref().sqrt_device(stream)
1284}
1285
1286/// See [`Array::square`].
1287#[generate_macro]
1288#[default_device]
1289pub fn square_device(
1290    a: impl AsRef<Array>,
1291    #[optional] stream: impl AsRef<Stream>,
1292) -> Result<Array> {
1293    a.as_ref().square_device(stream)
1294}
1295
1296/// See [`Array::subtract`].
1297#[generate_macro]
1298#[default_device]
1299pub fn subtract_device(
1300    a: impl AsRef<Array>,
1301    b: impl AsRef<Array>,
1302    #[optional] stream: impl AsRef<Stream>,
1303) -> Result<Array> {
1304    a.as_ref().subtract_device(b, stream)
1305}
1306
1307/// See [`Array::tan`].
1308#[generate_macro]
1309#[default_device]
1310pub fn tan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1311    Array::try_from_op(|res| unsafe {
1312        mlx_sys::mlx_tan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1313    })
1314}
1315
1316/// Element-wise hyperbolic tangent.
1317#[generate_macro]
1318#[default_device]
1319pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1320    Array::try_from_op(|res| unsafe {
1321        mlx_sys::mlx_tanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1322    })
1323}
1324
1325/// Element-wise real part from a complex array.
1326#[generate_macro]
1327#[default_device]
1328pub fn real_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1329    Array::try_from_op(|res| unsafe {
1330        mlx_sys::mlx_real(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1331    })
1332}
1333
1334/// Element-wise imaginary part from a complex array.
1335#[generate_macro]
1336#[default_device]
1337pub fn imag_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1338    Array::try_from_op(|res| unsafe {
1339        mlx_sys::mlx_imag(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1340    })
1341}
1342
1343/// Matrix multiplication with block masking.
1344///
1345/// See the [python API docs](
1346/// https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.block_masked_mm.html#mlx.core.block_masked_mm
1347/// ) for more information.
1348#[generate_macro]
1349#[default_device]
1350pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
1351    a: impl AsRef<Array>,
1352    b: impl AsRef<Array>,
1353    #[optional] block_size: impl Into<Option<i32>>,
1354    #[optional] mask_out: impl Into<Option<&'mo Array>>,
1355    #[optional] mask_lhs: impl Into<Option<&'lhs Array>>,
1356    #[optional] mask_rhs: impl Into<Option<&'rhs Array>>,
1357    #[optional] stream: impl AsRef<Stream>,
1358) -> Result<Array> {
1359    let a_ptr = a.as_ref().as_ptr();
1360    let b_ptr = b.as_ref().as_ptr();
1361    unsafe {
1362        let mask_out_ptr = mask_out
1363            .into()
1364            .map(|m| m.as_ptr())
1365            .unwrap_or(mlx_sys::mlx_array_new());
1366        let mask_lhs_ptr = mask_lhs
1367            .into()
1368            .map(|m| m.as_ptr())
1369            .unwrap_or(mlx_sys::mlx_array_new());
1370        let mask_rhs_ptr = mask_rhs
1371            .into()
1372            .map(|m| m.as_ptr())
1373            .unwrap_or(mlx_sys::mlx_array_new());
1374
1375        Array::try_from_op(|res| {
1376            mlx_sys::mlx_block_masked_mm(
1377                res,
1378                a_ptr,
1379                b_ptr,
1380                block_size.into().unwrap_or(32),
1381                mask_out_ptr,
1382                mask_lhs_ptr,
1383                mask_rhs_ptr,
1384                stream.as_ref().as_ptr(),
1385            )
1386        })
1387    }
1388}
1389
1390/// Matrix multiplication with addition and optional scaling.
1391///
1392/// Perform the (possibly batched) matrix multiplication of two arrays and add to the result with
1393/// optional scaling factors.
1394///
1395/// # Params
1396///
1397/// - `c`: input array,
1398/// - `a`: input array,
1399/// - `b`: input array,
1400/// - `alpha`: Scaling factor for the matrix product of `a` and `b` (default: `1`)
1401/// - `beta`: Scaling factor for `c` (default: `1`)
1402#[generate_macro]
1403#[default_device]
1404pub fn addmm_device(
1405    c: impl AsRef<Array>,
1406    a: impl AsRef<Array>,
1407    b: impl AsRef<Array>,
1408    #[optional] alpha: impl Into<Option<f32>>,
1409    #[optional] beta: impl Into<Option<f32>>,
1410    #[optional] stream: impl AsRef<Stream>,
1411) -> Result<Array> {
1412    let c_ptr = c.as_ref().as_ptr();
1413    let a_ptr = a.as_ref().as_ptr();
1414    let b_ptr = b.as_ref().as_ptr();
1415    let alpha = alpha.into().unwrap_or(1.0);
1416    let beta = beta.into().unwrap_or(1.0);
1417
1418    Array::try_from_op(|res| unsafe {
1419        mlx_sys::mlx_addmm(
1420            res,
1421            c_ptr,
1422            a_ptr,
1423            b_ptr,
1424            alpha,
1425            beta,
1426            stream.as_ref().as_ptr(),
1427        )
1428    })
1429}
1430
1431/// Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the
1432/// last axes.
1433#[generate_macro]
1434#[default_device]
1435pub fn inner_device(
1436    a: impl AsRef<Array>,
1437    b: impl AsRef<Array>,
1438    #[optional] stream: impl AsRef<Stream>,
1439) -> Result<Array> {
1440    let a = a.as_ref();
1441    let b = b.as_ref();
1442    Array::try_from_op(|res| unsafe {
1443        mlx_sys::mlx_inner(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1444    })
1445}
1446
1447/// Compute the outer product of two 1-D arrays, if the array’s passed are not 1-D a flatten op will
1448/// be run beforehand.
1449#[generate_macro]
1450#[default_device]
1451pub fn outer_device(
1452    a: impl AsRef<Array>,
1453    b: impl AsRef<Array>,
1454    #[optional] stream: impl AsRef<Stream>,
1455) -> Result<Array> {
1456    let a = a.as_ref();
1457    let b = b.as_ref();
1458    Array::try_from_op(|res| unsafe {
1459        mlx_sys::mlx_outer(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1460    })
1461}
1462
1463/// Compute the tensor dot product along the specified axes.
1464#[generate_macro]
1465#[default_device]
1466pub fn tensordot_axes_device(
1467    a: impl AsRef<Array>,
1468    b: impl AsRef<Array>,
1469    axes_a: &[i32],
1470    axes_b: &[i32],
1471    #[optional] stream: impl AsRef<Stream>,
1472) -> Result<Array> {
1473    let a = a.as_ref();
1474    let b = b.as_ref();
1475    Array::try_from_op(|res| unsafe {
1476        mlx_sys::mlx_tensordot(
1477            res,
1478            a.as_ptr(),
1479            b.as_ptr(),
1480            axes_a.as_ptr(),
1481            axes_a.len(),
1482            axes_b.as_ptr(),
1483            axes_b.len(),
1484            stream.as_ref().as_ptr(),
1485        )
1486    })
1487}
1488
1489/// Similar to [`tensordot_axes`] but with a single axis.
1490#[generate_macro]
1491#[default_device]
1492pub fn tensordot_axis_device(
1493    a: impl AsRef<Array>,
1494    b: impl AsRef<Array>,
1495    axis: i32,
1496    #[optional] stream: impl AsRef<Stream>,
1497) -> Result<Array> {
1498    let a = a.as_ref();
1499    let b = b.as_ref();
1500    Array::try_from_op(|res| unsafe {
1501        mlx_sys::mlx_tensordot_axis(res, a.as_ptr(), b.as_ptr(), axis, stream.as_ref().as_ptr())
1502    })
1503}
1504
1505#[cfg(test)]
1506mod tests {
1507    use std::f32::consts::PI;
1508
1509    use super::*;
1510    use crate::{
1511        array, complex64,
1512        ops::{all_close, arange, broadcast_to, eye, full, linspace, ones, reshape, split},
1513        transforms::eval,
1514        Dtype, StreamOrDevice,
1515    };
1516    use float_eq::assert_float_eq;
1517    use pretty_assertions::assert_eq;
1518
1519    #[test]
1520    fn test_abs() {
1521        let data = [1i32, 2, -3, -4, -5];
1522        let array = Array::from_slice(&data, &[5]);
1523        let result = array.abs().unwrap();
1524
1525        let data: &[i32] = result.as_slice();
1526        assert_eq!(data, [1, 2, 3, 4, 5]);
1527
1528        // test that previous array is not modified and valid
1529        let data: &[i32] = array.as_slice();
1530        assert_eq!(data, [1, 2, -3, -4, -5]);
1531    }
1532
1533    #[test]
1534    fn test_add() {
1535        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1536        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1537
1538        let c = &a + &b;
1539
1540        let c_data: &[f32] = c.as_slice();
1541        assert_eq!(c_data, &[5.0, 7.0, 9.0]);
1542
1543        // check a and b are not modified
1544        let a_data: &[f32] = a.as_slice();
1545        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1546
1547        let b_data: &[f32] = b.as_slice();
1548        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1549    }
1550
1551    #[test]
1552    fn test_add_invalid_broadcast() {
1553        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1554        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1555
1556        let c = a.add(&b);
1557        assert!(c.is_err());
1558    }
1559
1560    #[test]
1561    fn test_sub() {
1562        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1563        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1564
1565        let c = &a - &b;
1566
1567        let c_data: &[f32] = c.as_slice();
1568        assert_eq!(c_data, &[-3.0, -3.0, -3.0]);
1569
1570        // check a and b are not modified
1571        let a_data: &[f32] = a.as_slice();
1572        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1573
1574        let b_data: &[f32] = b.as_slice();
1575        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1576    }
1577
1578    #[test]
1579    fn test_sub_invalid_broadcast() {
1580        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1581        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1582        let c = a.subtract(&b);
1583        assert!(c.is_err());
1584    }
1585
1586    #[test]
1587    fn test_neg() {
1588        let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]);
1589        let b = a.negative().unwrap();
1590
1591        let b_data: &[f32] = b.as_slice();
1592        assert_eq!(b_data, &[-1.0, -2.0, -3.0]);
1593
1594        // check a is not modified
1595        let a_data: &[f32] = a.as_slice();
1596        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1597    }
1598
1599    #[test]
1600    fn test_neg_bool() {
1601        let a = Array::from_slice(&[true, false, true], &[3]);
1602        let b = a.negative();
1603        assert!(b.is_err());
1604    }
1605
1606    #[test]
1607    fn test_logical_not() {
1608        let a: Array = false.into();
1609        let b = a.logical_not().unwrap();
1610
1611        let b_data: &[bool] = b.as_slice();
1612        assert_eq!(b_data, [true]);
1613    }
1614
1615    #[test]
1616    fn test_mul() {
1617        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1618        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1619
1620        let c = &a * &b;
1621
1622        let c_data: &[f32] = c.as_slice();
1623        assert_eq!(c_data, &[4.0, 10.0, 18.0]);
1624
1625        // check a and b are not modified
1626        let a_data: &[f32] = a.as_slice();
1627        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1628
1629        let b_data: &[f32] = b.as_slice();
1630        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1631    }
1632
1633    #[test]
1634    fn test_mul_invalid_broadcast() {
1635        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1636        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1637        let c = a.multiply(&b);
1638        assert!(c.is_err());
1639    }
1640
1641    #[test]
1642    fn test_nan_to_num() {
1643        let a = array!([1.0, 2.0, f32::NAN, 4.0, 5.0]);
1644        let b = a.nan_to_num(0.0, 1.0, 0.0).unwrap();
1645
1646        let b_data: &[f32] = b.as_slice();
1647        assert_eq!(b_data, &[1.0, 2.0, 0.0, 4.0, 5.0]);
1648    }
1649
1650    #[test]
1651    fn test_div() {
1652        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1653        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1654
1655        let c = &a / &b;
1656
1657        let c_data: &[f32] = c.as_slice();
1658        assert_eq!(c_data, &[0.25, 0.4, 0.5]);
1659
1660        // check a and b are not modified
1661        let a_data: &[f32] = a.as_slice();
1662        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1663
1664        let b_data: &[f32] = b.as_slice();
1665        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1666    }
1667
1668    #[test]
1669    fn test_div_invalid_broadcast() {
1670        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1671        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1672        let c = a.divide(&b);
1673        assert!(c.is_err());
1674    }
1675
1676    #[test]
1677    fn test_pow() {
1678        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1679        let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
1680
1681        let c = a.power(&b).unwrap();
1682
1683        let c_data: &[f32] = c.as_slice();
1684        assert_eq!(c_data, &[1.0, 8.0, 81.0]);
1685
1686        // check a and b are not modified
1687        let a_data: &[f32] = a.as_slice();
1688        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1689
1690        let b_data: &[f32] = b.as_slice();
1691        assert_eq!(b_data, &[2.0, 3.0, 4.0]);
1692    }
1693
1694    #[test]
1695    fn test_pow_invalid_broadcast() {
1696        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1697        let b = Array::from_slice(&[2.0, 3.0], &[2]);
1698        let c = a.power(&b);
1699        assert!(c.is_err());
1700    }
1701
1702    #[test]
1703    fn test_rem() {
1704        let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1705        let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
1706
1707        let c = &a % &b;
1708
1709        let c_data: &[f32] = c.as_slice();
1710        assert_eq!(c_data, &[1.0, 3.0, 2.0]);
1711
1712        // check a and b are not modified
1713        let a_data: &[f32] = a.as_slice();
1714        assert_eq!(a_data, &[10.0, 11.0, 12.0]);
1715
1716        let b_data: &[f32] = b.as_slice();
1717        assert_eq!(b_data, &[3.0, 4.0, 5.0]);
1718    }
1719
1720    #[test]
1721    fn test_rem_invalid_broadcast() {
1722        let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1723        let b = Array::from_slice(&[3.0, 4.0], &[2]);
1724        let c = a.remainder(&b);
1725        assert!(c.is_err());
1726    }
1727
1728    #[test]
1729    fn test_sqrt() {
1730        let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
1731        let b = a.sqrt().unwrap();
1732
1733        let b_data: &[f32] = b.as_slice();
1734        assert_eq!(b_data, &[1.0, 2.0, 3.0]);
1735
1736        // check a is not modified
1737        let a_data: &[f32] = a.as_slice();
1738        assert_eq!(a_data, &[1.0, 4.0, 9.0]);
1739    }
1740
1741    #[test]
1742    fn test_cos() {
1743        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1744        let b = a.cos().unwrap();
1745
1746        let b_expected = array!([1.0, 0.54030234, -0.41614687]);
1747        assert_array_all_close!(b, b_expected);
1748
1749        // check a is not modified
1750        let a_expected = array!([0.0, 1.0, 2.0]);
1751        assert_array_all_close!(a, a_expected);
1752    }
1753
1754    #[test]
1755    fn test_exp() {
1756        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1757        let b = a.exp().unwrap();
1758
1759        let b_expected = array!([1.0, 2.7182817, 7.389056]);
1760        assert_array_all_close!(b, b_expected);
1761
1762        // check a is not modified
1763        let a_expected = array!([0.0, 1.0, 2.0]);
1764        assert_array_all_close!(a, a_expected);
1765    }
1766
1767    #[test]
1768    fn test_floor() {
1769        let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
1770        let b = a.floor().unwrap();
1771
1772        let b_data: &[f32] = b.as_slice();
1773        assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1774
1775        // check a is not modified
1776        let a_data: &[f32] = a.as_slice();
1777        assert_eq!(a_data, &[0.1, 1.9, 2.5]);
1778    }
1779
1780    #[test]
1781    fn test_floor_complex64() {
1782        let val = complex64::new(1.0, 2.0);
1783        let a = Array::from_complex(val);
1784        let b = a.floor_device(StreamOrDevice::default());
1785        assert!(b.is_err());
1786    }
1787
1788    #[test]
1789    fn test_floor_divide() {
1790        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1791        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1792
1793        let c = a.floor_divide(&b).unwrap();
1794
1795        let c_data: &[f32] = c.as_slice();
1796        assert_eq!(c_data, &[0.0, 0.0, 0.0]);
1797
1798        // check a and b are not modified
1799        let a_data: &[f32] = a.as_slice();
1800        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1801
1802        let b_data: &[f32] = b.as_slice();
1803        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1804    }
1805
1806    #[test]
1807    fn test_floor_divide_complex64() {
1808        let val = complex64::new(1.0, 2.0);
1809        let a = Array::from_complex(val);
1810        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1811        let c = a.floor_divide_device(&b, StreamOrDevice::default());
1812        assert!(c.is_err());
1813    }
1814
1815    #[test]
1816    fn test_floor_divide_invalid_broadcast() {
1817        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1818        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1819        let c = a.floor_divide_device(&b, StreamOrDevice::default());
1820        assert!(c.is_err());
1821    }
1822
1823    #[test]
1824    fn test_is_nan() {
1825        let a = Array::from_slice(&[1.0, f32::NAN, 3.0], &[3]);
1826        let b = a.is_nan().unwrap();
1827
1828        let b_data: &[bool] = b.as_slice();
1829        assert_eq!(b_data, &[false, true, false]);
1830    }
1831
1832    #[test]
1833    fn test_is_inf() {
1834        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1835        let b = a.is_inf().unwrap();
1836
1837        let b_data: &[bool] = b.as_slice();
1838        assert_eq!(b_data, &[false, true, false]);
1839    }
1840
1841    #[test]
1842    fn test_is_finite() {
1843        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1844        let b = a.is_finite().unwrap();
1845
1846        let b_data: &[bool] = b.as_slice();
1847        assert_eq!(b_data, &[true, false, true]);
1848    }
1849
1850    #[test]
1851    fn test_is_neg_inf() {
1852        let a = Array::from_slice(&[1.0, f32::NEG_INFINITY, 3.0], &[3]);
1853        let b = a.is_neg_inf().unwrap();
1854
1855        let b_data: &[bool] = b.as_slice();
1856        assert_eq!(b_data, &[false, true, false]);
1857    }
1858
1859    #[test]
1860    fn test_is_pos_inf() {
1861        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1862        let b = a.is_pos_inf().unwrap();
1863
1864        let b_data: &[bool] = b.as_slice();
1865        assert_eq!(b_data, &[false, true, false]);
1866    }
1867
1868    #[test]
1869    fn test_log() {
1870        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1871        let b = a.log().unwrap();
1872
1873        let b_data: &[f32] = b.as_slice();
1874        assert_eq!(b_data, &[0.0, 0.6931472, 1.0986123]);
1875
1876        // check a is not modified
1877        let a_data: &[f32] = a.as_slice();
1878        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1879    }
1880
1881    #[test]
1882    fn test_log2() {
1883        let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
1884        let b = a.log2().unwrap();
1885
1886        let b_data: &[f32] = b.as_slice();
1887        assert_eq!(b_data, &[0.0, 1.0, 2.0, 3.0]);
1888
1889        // check a is not modified
1890        let a_data: &[f32] = a.as_slice();
1891        assert_eq!(a_data, &[1.0, 2.0, 4.0, 8.0]);
1892    }
1893
1894    #[test]
1895    fn test_log10() {
1896        let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
1897        let b = a.log10().unwrap();
1898
1899        let b_data: &[f32] = b.as_slice();
1900        assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1901
1902        // check a is not modified
1903        let a_data: &[f32] = a.as_slice();
1904        assert_eq!(a_data, &[1.0, 10.0, 100.0]);
1905    }
1906
1907    #[test]
1908    fn test_log1p() {
1909        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1910        let b = a.log1p().unwrap();
1911
1912        let b_data: &[f32] = b.as_slice();
1913        assert_eq!(b_data, &[0.6931472, 1.0986123, 1.3862944]);
1914
1915        // check a is not modified
1916        let a_data: &[f32] = a.as_slice();
1917        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1918    }
1919
1920    #[test]
1921    fn test_matmul() {
1922        let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1923        let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
1924
1925        let c = a.matmul(&b).unwrap();
1926
1927        assert_eq!(c.shape(), &[2, 3]);
1928        let c_data: &[f32] = c.as_slice();
1929        assert_eq!(c_data, &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]);
1930
1931        // check a and b are not modified
1932        let a_data: &[i32] = a.as_slice();
1933        assert_eq!(a_data, &[1, 2, 3, 4]);
1934
1935        let b_data: &[f32] = b.as_slice();
1936        assert_eq!(b_data, &[-5.0, 37.5, 4., 7., 1., 0.]);
1937    }
1938
1939    #[test]
1940    fn test_matmul_ndim_zero() {
1941        let a: Array = 1.0.into();
1942        let b = Array::from_slice::<i32>(&[1], &[1]);
1943        let c = a.matmul(&b);
1944        assert!(c.is_err());
1945    }
1946
1947    #[test]
1948    fn test_matmul_ndim_one() {
1949        let a = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1950        let b = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1951        let c = a.matmul(&b);
1952        assert!(c.is_ok());
1953    }
1954
1955    #[test]
1956    fn test_matmul_dim_mismatch() {
1957        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
1958        let b = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]);
1959        let c = a.matmul(&b);
1960        assert!(c.is_err());
1961    }
1962
1963    #[test]
1964    fn test_matmul_non_float_output_type() {
1965        let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1966        let b = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]);
1967
1968        let c = a.matmul(&b);
1969        assert!(c.is_err());
1970    }
1971
1972    #[test]
1973    fn test_reciprocal() {
1974        let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
1975        let b = a.reciprocal().unwrap();
1976
1977        let b_data: &[f32] = b.as_slice();
1978        assert_eq!(b_data, &[1.0, 0.5, 0.25]);
1979
1980        // check a is not modified
1981        let a_data: &[f32] = a.as_slice();
1982        assert_eq!(a_data, &[1.0, 2.0, 4.0]);
1983    }
1984
1985    #[test]
1986    fn test_round() {
1987        let a = Array::from_slice(&[1.1, 2.9, 3.5], &[3]);
1988        let b = a.round(None).unwrap();
1989
1990        let b_data: &[f32] = b.as_slice();
1991        assert_eq!(b_data, &[1.0, 3.0, 4.0]);
1992
1993        // check a is not modified
1994        let a_data: &[f32] = a.as_slice();
1995        assert_eq!(a_data, &[1.1, 2.9, 3.5]);
1996    }
1997
1998    #[test]
1999    fn test_rsqrt() {
2000        let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
2001        let b = a.rsqrt().unwrap();
2002
2003        let b_data: &[f32] = b.as_slice();
2004        assert_eq!(b_data, &[1.0, 0.70710677, 0.5]);
2005
2006        // check a is not modified
2007        let a_data: &[f32] = a.as_slice();
2008        assert_eq!(a_data, &[1.0, 2.0, 4.0]);
2009    }
2010
2011    #[test]
2012    fn test_sin() {
2013        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
2014        let b = a.sin().unwrap();
2015
2016        let b_data: &[f32] = b.as_slice();
2017        assert_eq!(b_data, &[0.0, 0.841471, 0.9092974]);
2018
2019        // check a is not modified
2020        let a_data: &[f32] = a.as_slice();
2021        assert_eq!(a_data, &[0.0, 1.0, 2.0]);
2022    }
2023
2024    #[test]
2025    fn test_square() {
2026        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
2027        let b = a.square().unwrap();
2028
2029        let b_data: &[f32] = b.as_slice();
2030        assert_eq!(b_data, &[1.0, 4.0, 9.0]);
2031
2032        // check a is not modified
2033        let a_data: &[f32] = a.as_slice();
2034        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
2035    }
2036
    // The unit tests below are adapted from the original MLX C++ codebase.
2038
2039    #[test]
2040    fn test_unary_neg() {
2041        let x = array!(1.0);
2042        assert_eq!(negative(&x).unwrap().item::<f32>(), -1.0);
2043        assert_eq!((-x).item::<f32>(), -1.0);
2044
2045        // works on empty array
2046        assert_eq!(-array!(), array!());
2047
2048        // Throws on bool
2049        let x = array!(true);
2050        assert!(negative(&x).is_err());
2051    }
2052
2053    #[test]
2054    fn test_unary_abs() {
2055        let x = array!([-1.0, 0.0, 1.0]);
2056        assert_eq!(abs(&x).unwrap(), array!([1.0, 0.0, 1.0]));
2057
2058        // works on empty array
2059        assert_eq!(abs(array!()).unwrap(), array!());
2060
2061        // int32
2062        let x = array!([-1, 0, 1]);
2063        assert_eq!(abs(&x).unwrap(), array!([1, 0, 1]));
2064
2065        // uint32
2066        let x = array!([1u32, 0, 1]);
2067        assert_eq!(abs(&x).unwrap(), array!([1u32, 0, 1]));
2068
2069        // bool
2070        let x = array!([false, true]);
2071        assert_eq!(abs(&x).unwrap(), array!([false, true]));
2072    }
2073
2074    #[test]
2075    fn test_unary_sign() {
2076        let x = array!([-1.0, 0.0, 1.0]);
2077        assert_eq!(sign(&x).unwrap(), x);
2078
2079        // works on empty array
2080        assert_eq!(sign(array!()).unwrap(), array!());
2081
2082        // int32
2083        let x = array!([-1, 0, 1]);
2084        assert_eq!(sign(&x).unwrap(), x);
2085
2086        // uint32
2087        let x = array!([1u32, 0, 1]);
2088        assert_eq!(sign(&x).unwrap(), x);
2089
2090        // bool
2091        let x = array!([false, true]);
2092        assert_eq!(sign(&x).unwrap(), x);
2093    }
2094
2095    const NEG_INF: f32 = f32::NEG_INFINITY;
2096
2097    #[test]
2098    fn test_unary_floor_ceil() {
2099        let x = array![1.0];
2100        assert_eq!(floor(&x).unwrap().item::<f32>(), 1.0);
2101        assert_eq!(ceil(&x).unwrap().item::<f32>(), 1.0);
2102
2103        let x = array![1.5];
2104        assert_eq!(floor(&x).unwrap().item::<f32>(), 1.0);
2105        assert_eq!(ceil(&x).unwrap().item::<f32>(), 2.0);
2106
2107        let x = array![-1.5];
2108        assert_eq!(floor(&x).unwrap().item::<f32>(), -2.0);
2109        assert_eq!(ceil(&x).unwrap().item::<f32>(), -1.0);
2110
2111        let x = array![NEG_INF];
2112        assert_eq!(floor(&x).unwrap().item::<f32>(), NEG_INF);
2113        assert_eq!(ceil(&x).unwrap().item::<f32>(), NEG_INF);
2114
2115        let x = array!([1.0, 1.0]).as_type::<complex64>().unwrap();
2116        assert!(floor(&x).is_err());
2117        assert!(ceil(&x).is_err());
2118    }
2119
2120    #[test]
2121    fn test_unary_round() {
2122        let x = array!([0.5, -0.5, 1.5, -1.5, 2.3, 2.6]);
2123        assert_eq!(round(&x, None).unwrap(), array!([0, 0, 2, -2, 2, 3]));
2124
2125        let x = array!([11, 222, 32]);
2126        assert_eq!(round(&x, -1).unwrap(), array!([10, 220, 30]));
2127    }
2128
2129    #[test]
2130    fn test_unary_exp() {
2131        let x = array![0.0];
2132        assert_eq!(exp(&x).unwrap().item::<f32>(), 1.0);
2133
2134        let x = array![2.0];
2135        assert_float_eq! {
2136            exp(&x).unwrap().item::<f32>(),
2137            2.0f32.exp(),
2138            abs <= 1e-5
2139        };
2140
2141        assert_eq!(exp(array!()).unwrap(), array!());
2142
2143        let x = array![NEG_INF];
2144        assert_eq!(exp(&x).unwrap().item::<f32>(), 0.0);
2145
2146        // Integer input type
2147        let x = array![2];
2148        assert_eq!(x.dtype(), Dtype::Int32);
2149        assert_float_eq! {
2150            exp(&x).unwrap().item::<f32>(),
2151            2.0f32.exp(),
2152            abs <= 1e-5
2153        };
2154
2155        // Input is irregularly strided
2156        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2157        let res = exp(&x).unwrap();
2158        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.exp())).unwrap();
2159        assert!(all_close(&res, &expected, None, None, None)
2160            .unwrap()
2161            .item::<bool>());
2162
2163        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2164        let x = split(&data, 2, 1).unwrap();
2165        let expected = Array::from_slice(&[0.0f32.exp(), 2.0f32.exp()], &[2, 1]);
2166        assert!(all_close(exp(&x[0]).unwrap(), &expected, None, None, None)
2167            .unwrap()
2168            .item::<bool>());
2169    }
2170
2171    #[test]
2172    fn test_unary_expm1() {
2173        let x = array![-1.0];
2174        assert_float_eq! {
2175            expm1(&x).unwrap().item::<f32>(),
2176            (-1.0f32).exp_m1(),
2177            abs <= 1e-5
2178        };
2179
2180        let x = array![1.0];
2181        assert_float_eq! {
2182            expm1(&x).unwrap().item::<f32>(),
2183            1.0f32.exp_m1(),
2184            abs <= 1e-5
2185        };
2186
2187        // Integer input type
2188        let x = array![1];
2189        assert_eq!(expm1(&x).unwrap().dtype(), Dtype::Float32);
2190        assert_float_eq! {
2191            expm1(&x).unwrap().item::<f32>(),
2192            1.0f32.exp_m1(),
2193            abs <= 1e-5
2194        };
2195    }
2196
2197    #[test]
2198    fn test_unary_sin() {
2199        let x = array![0.0];
2200        assert_eq!(sin(&x).unwrap().item::<f32>(), 0.0);
2201
2202        let x = array![std::f32::consts::PI / 2.0];
2203        assert_float_eq! {
2204            sin(&x).unwrap().item::<f32>(),
2205            (std::f32::consts::PI / 2.0f32).sin(),
2206            abs <= 1e-5
2207        };
2208
2209        assert_eq!(sin(array!()).unwrap(), array!());
2210
2211        // Integer input type
2212        let x = array![0];
2213        assert_eq!(x.dtype(), Dtype::Int32);
2214        assert_float_eq! {
2215            sin(&x).unwrap().item::<f32>(),
2216            0.0f32.sin(),
2217            abs <= 1e-5
2218        };
2219
2220        // Input is irregularly strided
2221        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2222        let res = sin(&x).unwrap();
2223        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.sin())).unwrap();
2224        assert!(all_close(&res, &expected, None, None, None)
2225            .unwrap()
2226            .item::<bool>());
2227
2228        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2229        let x = split(&data, 2, 1).unwrap();
2230        let expected = Array::from_slice(&[0.0f32.sin(), 2.0f32.sin()], &[2, 1]);
2231        assert!(all_close(sin(&x[0]).unwrap(), &expected, None, None, None)
2232            .unwrap()
2233            .item::<bool>());
2234    }
2235
2236    #[test]
2237    fn test_unary_cos() {
2238        let x = array![0.0];
2239        assert_float_eq! {
2240            cos(&x).unwrap().item::<f32>(),
2241            0.0f32.cos(),
2242            abs <= 1e-5
2243        };
2244
2245        let x = array![std::f32::consts::PI / 2.0];
2246        assert_float_eq! {
2247            cos(&x).unwrap().item::<f32>(),
2248            (std::f32::consts::PI / 2.0f32).cos(),
2249            abs <= 1e-5
2250        };
2251
2252        assert_eq!(cos(array!()).unwrap(), array!());
2253
2254        // Integer input type
2255        let x = array![0];
2256        assert_eq!(x.dtype(), Dtype::Int32);
2257        assert_float_eq! {
2258            cos(&x).unwrap().item::<f32>(),
2259            0.0f32.cos(),
2260            abs <= 1e-5
2261        };
2262
2263        // Input is irregularly strided
2264        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2265        let res = cos(&x).unwrap();
2266        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.cos())).unwrap();
2267        assert!(all_close(&res, &expected, None, None, None)
2268            .unwrap()
2269            .item::<bool>());
2270
2271        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2272        let x = split(&data, 2, 1).unwrap();
2273        let expected = Array::from_slice(&[0.0f32.cos(), 2.0f32.cos()], &[2, 1]);
2274        assert!(all_close(cos(&x[0]).unwrap(), &expected, None, None, None)
2275            .unwrap()
2276            .item::<bool>());
2277    }
2278
2279    #[test]
2280    fn test_unary_degrees() {
2281        let x = array![0.0];
2282        assert_eq!(degrees(&x).unwrap().item::<f32>(), 0.0);
2283
2284        let x = array![std::f32::consts::PI / 2.0];
2285        assert_eq!(degrees(&x).unwrap().item::<f32>(), 90.0);
2286
2287        assert_eq!(degrees(array!()).unwrap(), array!());
2288
2289        // Integer input type
2290        let x = array![0];
2291        assert_eq!(x.dtype(), Dtype::Int32);
2292        assert_eq!(degrees(&x).unwrap().item::<f32>(), 0.0);
2293
2294        // Input is irregularly strided
2295        let x = broadcast_to(&array!(std::f32::consts::PI / 2.0), &[2, 2, 2]).unwrap();
2296        let res = degrees(&x).unwrap();
2297        let expected = Array::full::<f32>(&[2, 2, 2], array!(90.0)).unwrap();
2298        assert!(all_close(&res, &expected, None, None, None)
2299            .unwrap()
2300            .item::<bool>());
2301
2302        let angles = Array::from_slice(&[0.0, PI / 2.0, PI, 1.5 * PI], &[2, 2]);
2303        let x = split(&angles, 2, 1).unwrap();
2304        let expected = Array::from_slice(&[0.0, 180.0], &[2, 1]);
2305        assert!(
2306            all_close(degrees(&x[0]).unwrap(), &expected, None, None, None)
2307                .unwrap()
2308                .item::<bool>()
2309        );
2310    }
2311
2312    #[test]
2313    fn test_unary_radians() {
2314        let x = array![0.0];
2315        assert_eq!(radians(&x).unwrap().item::<f32>(), 0.0);
2316
2317        let x = array![90.0];
2318        assert_eq!(
2319            radians(&x).unwrap().item::<f32>(),
2320            std::f32::consts::PI / 2.0
2321        );
2322
2323        assert_eq!(radians(array!()).unwrap(), array!());
2324
2325        // Integer input type
2326        let x = array![90];
2327        assert_eq!(x.dtype(), Dtype::Int32);
2328        assert_eq!(
2329            radians(&x).unwrap().item::<f32>(),
2330            std::f32::consts::PI / 2.0
2331        );
2332
2333        // Input is irregularly strided
2334        let x = broadcast_to(&array!(90.0), &[2, 2, 2]).unwrap();
2335        let res = radians(&x).unwrap();
2336        let expected = Array::full::<f32>(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap();
2337        assert!(all_close(&res, &expected, None, None, None)
2338            .unwrap()
2339            .item::<bool>());
2340
2341        let angles = Array::from_slice(&[0.0, 90.0, 180.0, 270.0], &[2, 2]);
2342        let x = split(&angles, 2, 1).unwrap();
2343        let expected = Array::from_slice(&[0.0, PI], &[2, 1]);
2344        assert!(
2345            all_close(radians(&x[0]).unwrap(), &expected, None, None, None)
2346                .unwrap()
2347                .item::<bool>()
2348        );
2349    }
2350
2351    #[test]
2352    fn test_unary_log() {
2353        let x = array![0.0];
2354        assert_eq!(log(&x).unwrap().item::<f32>(), NEG_INF);
2355
2356        let x = array![1.0];
2357        assert_eq!(log(&x).unwrap().item::<f32>(), 0.0);
2358
2359        // Integer input type
2360        let x = array![1];
2361        assert_eq!(log(&x).unwrap().dtype(), Dtype::Float32);
2362        assert_eq!(log(&x).unwrap().item::<f32>(), 0.0);
2363
2364        // Input is irregularly strided
2365        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2366        let res = log(&x).unwrap();
2367        let expected = Array::full::<f32>(&[2, 2, 2], array!(0.0)).unwrap();
2368        assert!(all_close(&res, &expected, None, None, None)
2369            .unwrap()
2370            .item::<bool>());
2371
2372        let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
2373        let x = split(&data, 2, 1).unwrap();
2374        let expected = Array::from_slice(&[1.0f32.ln(), 3.0f32.ln()], &[2, 1]);
2375        assert!(all_close(log(&x[0]).unwrap(), &expected, None, None, None)
2376            .unwrap()
2377            .item::<bool>());
2378    }
2379
2380    #[test]
2381    fn test_unary_log2() {
2382        let x = array![0.0];
2383        assert_eq!(log2(&x).unwrap().item::<f32>(), NEG_INF);
2384
2385        let x = array![1.0];
2386        assert_eq!(log2(&x).unwrap().item::<f32>(), 0.0);
2387
2388        let x = array![1024.0];
2389        assert_eq!(log2(&x).unwrap().item::<f32>(), 10.0);
2390    }
2391
2392    #[test]
2393    fn test_unary_log10() {
2394        let x = array![0.0];
2395        assert_eq!(log10(&x).unwrap().item::<f32>(), NEG_INF);
2396
2397        let x = array![1.0];
2398        assert_eq!(log10(&x).unwrap().item::<f32>(), 0.0);
2399
2400        let x = array![1000.0];
2401        assert_eq!(log10(&x).unwrap().item::<f32>(), 3.0);
2402    }
2403
2404    #[test]
2405    fn test_unary_log1p() {
2406        let x = array![-1.0];
2407        assert_float_eq! {
2408            log1p(&x).unwrap().item::<f32>(),
2409            (-1.0f32).ln_1p(),
2410            abs <= 1e-5
2411        };
2412
2413        let x = array![1.0];
2414        assert_float_eq! {
2415            log1p(&x).unwrap().item::<f32>(),
2416            1.0f32.ln_1p(),
2417            abs <= 1e-5
2418        };
2419
2420        // Integer input type
2421        let x = array![1];
2422        assert_eq!(log1p(&x).unwrap().dtype(), Dtype::Float32);
2423        assert_float_eq! {
2424            log1p(&x).unwrap().item::<f32>(),
2425            1.0f32.ln_1p(),
2426            abs <= 1e-5
2427        };
2428
2429        // Input is irregularly strided
2430        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2431        let res = log1p(&x).unwrap();
2432        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap();
2433        assert!(all_close(&res, &expected, None, None, None)
2434            .unwrap()
2435            .item::<bool>());
2436
2437        let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
2438        let x = split(&data, 2, 1).unwrap();
2439        let expected = Array::from_slice(&[1.0f32.ln_1p(), 3.0f32.ln_1p()], &[2, 1]);
2440        assert!(
2441            all_close(log1p(&x[0]).unwrap(), &expected, None, None, None)
2442                .unwrap()
2443                .item::<bool>()
2444        );
2445    }
2446
2447    #[test]
2448    fn test_unary_sigmoid() {
2449        let x = array![0.0];
2450        assert_float_eq! {
2451            sigmoid(&x).unwrap().item::<f32>(),
2452            0.5,
2453            abs <= 1e-5
2454        };
2455
2456        // Integer input type
2457        let x = array![0];
2458        assert_eq!(sigmoid(&x).unwrap().dtype(), Dtype::Float32);
2459        assert_float_eq! {
2460            sigmoid(&x).unwrap().item::<f32>(),
2461            0.5,
2462            abs <= 1e-5
2463        };
2464
2465        let inf = f32::INFINITY;
2466        let x = array![inf];
2467        assert_eq!(sigmoid(&x).unwrap().item::<f32>(), 1.0);
2468
2469        let x = array![-inf];
2470        assert_eq!(sigmoid(&x).unwrap().item::<f32>(), 0.0);
2471    }
2472
2473    #[test]
2474    fn test_unary_square() {
2475        let x = array![3.0];
2476        assert_eq!(square(&x).unwrap().item::<f32>(), 9.0);
2477
2478        let x = array![2];
2479        assert_eq!(square(&x).unwrap().item::<i32>(), 4);
2480
2481        let x = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
2482        assert!(all_close(
2483            square(&x).unwrap(),
2484            Array::full::<f32>(&[3, 3], array!(4.0)).unwrap(),
2485            None,
2486            None,
2487            None
2488        )
2489        .unwrap()
2490        .item::<bool>());
2491    }
2492
2493    #[test]
2494    fn test_unary_sqrt_rsqrt() {
2495        let x = array![4.0];
2496        assert_eq!(sqrt(&x).unwrap().item::<f32>(), 2.0);
2497        assert_eq!(rsqrt(&x).unwrap().item::<f32>(), 0.5);
2498
2499        let x = Array::full::<f32>(&[3, 3], array!(9.0)).unwrap();
2500        assert!(all_close(
2501            sqrt(&x).unwrap(),
2502            Array::full::<f32>(&[3, 3], array!(3.0)).unwrap(),
2503            None,
2504            None,
2505            None
2506        )
2507        .unwrap()
2508        .item::<bool>());
2509
2510        let x = array![4i32];
2511        assert_eq!(sqrt(&x).unwrap().item::<f32>(), 2.0);
2512        assert_eq!(rsqrt(&x).unwrap().item::<f32>(), 0.5);
2513    }
2514
2515    #[test]
2516    fn test_unary_reciprocal() {
2517        let x = array![8.0];
2518        assert_eq!(reciprocal(&x).unwrap().item::<f32>(), 0.125);
2519
2520        let x = array![2];
2521        let out = reciprocal(&x).unwrap();
2522        assert_eq!(out.dtype(), Dtype::Float32);
2523        assert_eq!(out.item::<f32>(), 0.5);
2524
2525        let x = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
2526        assert!(all_close(
2527            reciprocal(&x).unwrap(),
2528            Array::full::<f32>(&[3, 3], array!(0.5)).unwrap(),
2529            None,
2530            None,
2531            None
2532        )
2533        .unwrap()
2534        .item::<bool>());
2535    }
2536
2537    #[test]
2538    fn test_unary_real_imag() {
2539        let x = Array::from_complex(complex64::new(0.0, 1.0));
2540        assert_eq!(real(&x).unwrap(), Array::from_f32(0.0));
2541        assert_eq!(imag(&x).unwrap(), Array::from_f32(1.0));
2542    }
2543
2544    #[test]
2545    fn test_binary_add() {
2546        let x = array![1.0];
2547        let y = array![1.0];
2548        let z = add(&x, &y).unwrap();
2549        assert_eq!(z.item::<f32>(), 2.0);
2550
2551        let z = &x + y;
2552        assert_eq!(z.item::<f32>(), 2.0);
2553
2554        let z = add(z, &x).unwrap();
2555        assert_eq!(z.item::<f32>(), 3.0);
2556
2557        // Chain a few adds:
2558        let mut out = x.deep_clone();
2559        for _ in 0..10 {
2560            out = add(&out, &x).unwrap();
2561        }
2562        assert_eq!(out.item::<f32>(), 11.0);
2563
2564        // Works for different shapes
2565        let x = array!([1.0, 2.0, 3.0]);
2566        let y = array!([1.0, 2.0, 3.0]);
2567        let z = add(&x, &y).unwrap();
2568        assert_eq!(z.shape(), &[3]);
2569        assert_eq!(z, array!([2.0, 4.0, 6.0]));
2570
2571        // Works with scalars
2572        let x = array!([1.0, 2.0, 3.0]);
2573        let y = &x + 2.0;
2574        assert_eq!(y.dtype(), Dtype::Float32);
2575        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2576        let y = &x + 2.0;
2577        assert_eq!(y.dtype(), Dtype::Float32);
2578        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2579
2580        // Check type promotion
2581        let y = x + 2;
2582        assert_eq!(y.dtype(), Dtype::Float32);
2583
2584        let y = array!([1, 2, 3]) + 2.0;
2585        assert_eq!(y.dtype(), Dtype::Float32);
2586        // assert!(array_equal(&y, &array![3.0, 4.0, 5.0]).item::<bool>());
2587        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2588
2589        // Broadcasting works
2590        let x = broadcast_to(&array!(1.0), &[10]).unwrap();
2591        let y = broadcast_to(&array!(2.0), &[10]).unwrap();
2592        let z = add(&x, &y).unwrap();
2593        assert_eq!(z, full::<f32>(&[10], array!(3.0)).unwrap());
2594
2595        let x = Array::from_slice(&[1.0, 2.0], &[1, 2]);
2596        let y = Array::from_slice(&[1.0, 2.0], &[2, 1]);
2597        let z = add(&x, &y).unwrap();
2598        assert_eq!(z.shape(), &[2, 2]);
2599        assert_eq!(z, Array::from_slice(&[2.0, 3.0, 3.0, 4.0], &[2, 2]));
2600
2601        let x = ones::<f32>(&[3, 2, 1]).unwrap();
2602        let z = x + 2.0;
2603        assert_eq!(z.shape(), &[3, 2, 1]);
2604        let expected = Array::from_slice(&[3.0, 3.0, 3.0, 3.0, 3.0, 3.0], &[3, 2, 1]);
2605        assert_eq!(z, expected);
2606
2607        // Works for empty arrays
2608        let x = array!();
2609        let y = array!();
2610        let z = x + y;
2611        z.eval().unwrap();
2612        assert_eq!(z.size(), 0);
2613        assert_eq!(z.shape(), &[0]);
2614    }
2615
2616    #[test]
2617    fn test_binary_sub() {
2618        let x = array!([3.0, 2.0, 1.0]);
2619        let y = array!([1.0, 1.0, 1.0]);
2620        assert_eq!(x - y, array!([2.0, 1.0, 0.0]));
2621    }
2622
2623    #[test]
2624    fn test_binary_mul() {
2625        let x = array!([1.0, 2.0, 3.0]);
2626        let y = array!([2.0, 2.0, 2.0]);
2627        assert_eq!(x * y, array!([2.0, 4.0, 6.0]));
2628    }
2629
2630    #[test]
2631    fn test_binary_div() {
2632        let x = array![1.0];
2633        let y = array![1.0];
2634        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 1.0);
2635
2636        let x = array![1.0];
2637        let y = array![0.5];
2638        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 2.0);
2639
2640        let x = array![1.0];
2641        let y = array![4.0];
2642        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 0.25);
2643
2644        let x = array![true];
2645        let y = array![true];
2646        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 1.0);
2647
2648        let x = array![false];
2649        let y = array![true];
2650        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 0.0);
2651
2652        let x = array![true];
2653        let y = array![false];
2654        assert!(divide(&x, &y).unwrap().item::<f32>().is_infinite());
2655
2656        let x = array![false];
2657        let y = array![false];
2658        assert!(divide(&x, &y).unwrap().item::<f32>().is_nan());
2659    }
2660
2661    #[test]
2662    fn test_binary_maximum_minimum() {
2663        let x = array![1.0];
2664        let y = array![0.0];
2665        assert_eq!(maximum(&x, &y).unwrap().item::<f32>(), 1.0);
2666        assert_eq!(minimum(&x, &y).unwrap().item::<f32>(), 0.0);
2667
2668        let y = array![2.0];
2669        assert_eq!(maximum(&x, &y).unwrap().item::<f32>(), 2.0);
2670        assert_eq!(minimum(&x, &y).unwrap().item::<f32>(), 1.0);
2671    }
2672
2673    #[test]
2674    fn test_binary_logaddexp() {
2675        let x = array![0.0];
2676        let y = array![0.0];
2677        assert_float_eq! {
2678            logaddexp(&x, &y).unwrap().item::<f32>(),
2679            2.0f32.ln(),
2680            abs <= 1e-5
2681        };
2682
2683        let x = array!([0u32]);
2684        let y = array!([10000u32]);
2685        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), 10000.0);
2686
2687        let x = array![f32::INFINITY];
2688        let y = array![3.0];
2689        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2690
2691        let x = array![f32::NEG_INFINITY];
2692        let y = array![3.0];
2693        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), 3.0);
2694
2695        let x = array![f32::NEG_INFINITY];
2696        let y = array![f32::NEG_INFINITY];
2697        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::NEG_INFINITY);
2698
2699        let x = array![f32::INFINITY];
2700        let y = array![f32::INFINITY];
2701        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2702
2703        let x = array![f32::NEG_INFINITY];
2704        let y = array![f32::INFINITY];
2705        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2706    }
2707
2708    #[test]
2709    fn test_basic_clip() {
2710        let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
2711        let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);
2712        let clipped = clip(&a, (array!(2.0), array!(6.0))).unwrap();
2713        assert_eq!(clipped, expected);
2714
2715        // Test with scalar
2716        let clipped = clip(&a, (2.0, 6.0)).unwrap();
2717        assert_eq!(clipped, expected);
2718    }
2719
2720    #[test]
2721    fn test_clip_with_only_min() {
2722        let a = array!([-1.0, 1.0, 0.0, 5.0]);
2723        let expected = array!([0.0, 1.0, 0.0, 5.0]);
2724        let clipped = clip(&a, (array!(0.0), ())).unwrap();
2725        assert_eq!(clipped, expected);
2726
2727        // Test with scalar
2728        let clipped = clip(&a, (0.0, ())).unwrap();
2729        assert_eq!(clipped, expected);
2730    }
2731
2732    #[test]
2733    fn test_clip_with_only_max() {
2734        let a = array!([2.0, 3.0, 4.0, 5.0]);
2735        let expected = array!([2.0, 3.0, 4.0, 4.0]);
2736        let clipped = clip(&a, ((), array!(4.0))).unwrap();
2737        assert_eq!(clipped, expected);
2738
2739        // Test with scalar
2740        let clipped = clip(&a, ((), 4.0)).unwrap();
2741        assert_eq!(clipped, expected);
2742    }
2743
2744    #[test]
2745    fn test_tensordot() {
2746        let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
2747        let y = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[4, 3, 2]).unwrap();
2748        let z = tensordot_axes(&x, &y, &[1i32, 0], &[0i32, 1]).unwrap();
2749        let expected = Array::from_slice(
2750            &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306],
2751            &[5, 2],
2752        );
2753        assert_eq!(z, expected);
2754
2755        let x = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[3, 4, 5, 6]).unwrap();
2756        let y = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[6, 4, 5, 3]).unwrap();
2757        assert!(tensordot_axes(&x, &y, &[2, 1, 3], &[1, 2, 0]).is_err());
2758
2759        let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
2760        let y = reshape(arange::<_, f32>(None, 120.0, None).unwrap(), &[4, 5, 6]).unwrap();
2761
2762        let z = tensordot_axis(&x, &y, 2).unwrap();
2763        let expected = Array::from_slice(
2764            &[
2765                14820.0, 15010.0, 15200.0, 15390.0, 15580.0, 15770.0, 37620.0, 38210.0, 38800.0,
2766                39390.0, 39980.0, 40570.0, 60420.0, 61410.0, 62400.0, 63390.0, 64380.0, 65370.0,
2767            ],
2768            &[3, 6],
2769        );
2770        assert_eq!(z, expected);
2771    }
2772
2773    #[test]
2774    fn test_outer() {
2775        let x = arange::<_, f32>(1.0, 5.0, None).unwrap();
2776        let y = arange::<_, f32>(1.0, 4.0, None).unwrap();
2777        let z = outer(&x, &y).unwrap();
2778        let expected = Array::from_slice(
2779            &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0],
2780            &[4, 3],
2781        );
2782        assert_eq!(z, expected);
2783
2784        let x = ones::<f32>(&[5]).unwrap();
2785        let y = linspace::<_, f32>(-2.0, 2.0, 5).unwrap();
2786        let z = outer(&x, &y).unwrap();
2787        let expected = Array::from_slice(
2788            &[
2789                -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
2790                -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
2791            ],
2792            &[5, 5],
2793        );
2794        assert_eq!(z, expected);
2795    }
2796
2797    #[test]
2798    fn test_inner() {
2799        let x = reshape(arange::<_, f32>(None, 5.0, None).unwrap(), &[1, 5]).unwrap();
2800        let y = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[2, 3]).unwrap();
2801        assert!(inner(&x, &y).is_err());
2802
2803        let x = array!([1.0, 2.0, 3.0]);
2804        let y = array!([0.0, 1.0, 0.0]);
2805        let z = inner(&x, &y).unwrap();
2806        assert_eq!(z.item::<f32>(), 2.0);
2807
2808        let x = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[2, 3, 4]).unwrap();
2809        let y = arange::<_, f32>(None, 4.0, None).unwrap();
2810        let z = inner(&x, &y).unwrap();
2811        let expected = Array::from_slice(&[14.0, 38.0, 62.0, 86.0, 110.0, 134.0], &[2, 3]);
2812        assert_eq!(z, expected);
2813
2814        let x = reshape(arange::<_, f32>(None, 2.0, None).unwrap(), &[1, 1, 2]).unwrap();
2815        let y = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[3, 2]).unwrap();
2816        let z = inner(&x, &y).unwrap();
2817        let expected = Array::from_slice(&[1.0, 3.0, 5.0], &[1, 1, 3]);
2818        assert_eq!(z, expected);
2819
2820        let x = eye::<f32>(2, None, None).unwrap();
2821        let y = Array::from_f32(7.0);
2822        let z = inner(&x, &y).unwrap();
2823        let expected = Array::from_slice(&[7.0, 0.0, 0.0, 7.0], &[2, 2]);
2824        assert_eq!(z, expected);
2825    }
2826
2827    #[test]
2828    fn test_divmod() {
2829        let x = array!([1.0, 2.0, 3.0]);
2830        let y = array!([1.0, 1.0, 1.0]);
2831        let out = divmod(&x, &y).unwrap();
2832        assert_eq!(out.0, array!([1.0, 2.0, 3.0]));
2833        assert_eq!(out.1, array!([0.0, 0.0, 0.0]));
2834
2835        let x = array!([5.0, 6.0, 7.0]);
2836        let y = array!([2.0, 2.0, 2.0]);
2837        let out = divmod(&x, &y).unwrap();
2838        assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
2839        assert_eq!(out.1, array!([1.0, 0.0, 1.0]));
2840
2841        let x = array!([5.0, 6.0, 7.0]);
2842        let y = array!([2.0, 2.0, 2.0]);
2843        let out = divmod(&x, &y).unwrap();
2844        assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
2845        assert_eq!(out.1, array!([1.0, 0.0, 1.0]));
2846
2847        let x = array![complex64::new(1.0, 0.0)];
2848        let y = array![complex64::new(2.0, 0.0)];
2849        assert!(divmod(&x, &y).is_err());
2850
2851        // Check that we can eval on both outputs
2852        let x = array![1.0];
2853        let y = array![2.0];
2854        let (quo, rem) = divmod(&x, &y).unwrap();
2855        eval([&quo, &rem]).unwrap();
2856        assert_eq!(quo.item::<f32>(), 0.0);
2857        assert_eq!(rem.item::<f32>(), 1.0);
2858
2859        // Check nested in the graph
2860        let x = array![1.0];
2861        let y = array![2.0];
2862        let (quo, rem) = divmod(&x, &y).unwrap();
2863        let z = quo + rem;
2864        assert_eq!(z.item::<f32>(), 1.0);
2865
2866        // Check that we can still eval when one output goes out of scope
2867        let mut out_holder = {
2868            let (quo, _) = divmod(&x, &y).unwrap();
2869            vec![quo]
2870        };
2871        eval(out_holder.iter()).unwrap();
2872        assert_eq!(out_holder[0].item::<f32>(), 0.0);
2873
2874        // Check that we can still eval when the other output goes out of scope
2875        out_holder.clear();
2876        let out_holder = {
2877            let (_, rem) = divmod(&x, &y).unwrap();
2878            vec![rem]
2879        };
2880        eval(out_holder.iter()).unwrap();
2881        assert_eq!(out_holder[0].item::<f32>(), 1.0);
2882    }
2883}