mlx_rs/ops/
arithmetic.rs

1use crate::array::Array;
2use crate::error::Result;
3use crate::sealed::Sealed;
4
5use crate::utils::guard::Guarded;
6use crate::utils::{IntoOption, ScalarOrArray, VectorArray};
7use crate::Stream;
8use mlx_internal_macros::{default_device, generate_macro};
9use smallvec::SmallVec;
10
11impl Array {
12    /// Element-wise absolute value.
13    ///
14    /// # Example
15    ///
16    /// ```rust
17    /// use mlx_rs::Array;
18    /// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]);
19    /// let mut result = array.abs().unwrap();
20    ///
21    /// let data: &[i32] = result.as_slice();
22    /// // data == [1, 2, 3, 4, 5]
23    /// ```
24    #[default_device]
25    pub fn abs_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
26        Array::try_from_op(|res| unsafe {
27            mlx_sys::mlx_abs(res, self.as_ptr(), stream.as_ref().as_ptr())
28        })
29    }
30
31    /// Element-wise addition returning an error if arrays are not broadcastable.
32    ///
33    /// Add two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
34    ///
35    /// # Params
36    ///
37    /// - other: array to add
38    ///
39    /// # Example
40    ///
41    /// ```rust
42    /// use mlx_rs::Array;
43    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
44    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
45    /// let mut c = a.add(&b).unwrap();
46    ///
47    /// let c_data: &[f32] = c.as_slice();
48    /// // c_data == [5.0, 7.0, 9.0]
49    /// ```
50    #[default_device]
51    pub fn add_device(
52        &self,
53        other: impl AsRef<Array>,
54        stream: impl AsRef<Stream>,
55    ) -> Result<Array> {
56        Array::try_from_op(|res| unsafe {
57            mlx_sys::mlx_add(
58                res,
59                self.as_ptr(),
60                other.as_ref().as_ptr(),
61                stream.as_ref().as_ptr(),
62            )
63        })
64    }
65
66    /// Element-wise subtraction returning an error if arrays are not broadcastable.
67    ///
68    /// Subtract two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
69    ///
70    /// # Params
71    ///
72    /// - other: array to subtract
73    ///
74    /// # Example
75    ///
76    /// ```rust
77    /// use mlx_rs::Array;
78    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
79    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
80    /// let mut c = a.subtract(&b).unwrap();
81    ///
82    /// let c_data: &[f32] = c.as_slice();
83    /// // c_data == [-3.0, -3.0, -3.0]
84    /// ```
85    #[default_device]
86    pub fn subtract_device(
87        &self,
88        other: impl AsRef<Array>,
89        stream: impl AsRef<Stream>,
90    ) -> Result<Array> {
91        Array::try_from_op(|res| unsafe {
92            mlx_sys::mlx_subtract(
93                res,
94                self.as_ptr(),
95                other.as_ref().as_ptr(),
96                stream.as_ref().as_ptr(),
97            )
98        })
99    }
100
101    /// Unary element-wise negation. Returns an error if the array is of type bool.
102    ///
103    /// Negate the values in the array.
104    ///
105    /// # Example
106    ///
107    /// ```rust
108    /// use mlx_rs::Array;
109    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
110    /// let mut b = a.negative().unwrap();
111    ///
112    /// let b_data: &[f32] = b.as_slice();
113    /// // b_data == [-1.0, -2.0, -3.0]
114    /// ```
115    #[default_device]
116    pub fn negative_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
117        Array::try_from_op(|res| unsafe {
118            mlx_sys::mlx_negative(res, self.as_ptr(), stream.as_ref().as_ptr())
119        })
120    }
121
122    /// Element-wise multiplication returning an error if arrays are not broadcastable.
123    ///
124    /// Multiply two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
125    ///
126    /// # Example
127    ///
128    /// ```rust
129    /// use mlx_rs::Array;
130    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
131    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
132    /// let mut c = a.multiply(&b).unwrap();
133    ///
134    /// let c_data: &[f32] = c.as_slice();
135    /// // c_data == [4.0, 10.0, 18.0]
136    /// ```
137    #[default_device]
138    pub fn multiply_device(
139        &self,
140        other: impl AsRef<Array>,
141        stream: impl AsRef<Stream>,
142    ) -> Result<Array> {
143        Array::try_from_op(|res| unsafe {
144            mlx_sys::mlx_multiply(
145                res,
146                self.as_ptr(),
147                other.as_ref().as_ptr(),
148                stream.as_ref().as_ptr(),
149            )
150        })
151    }
152
153    /// Replace NaN and Inf values with finite numbers.
154    ///
155    /// # Params
156    /// - nan: value to replace NaN with
157    /// - posInf: value to replace positive inifinites with.  If not specified will use
158    ///     the largest finite value for the given dtype.
159    /// - negInf: value to replace negative inifinites with.  If not specified will use
160    ///     the negative of the largest finite value for the given dtype.
161    /// - stream: stream or device to evaluate on
162    #[default_device]
163    pub fn nan_to_num_device(
164        &self,
165        nan: impl IntoOption<f32>,
166        pos_inf: impl IntoOption<f32>,
167        neg_inf: impl IntoOption<f32>,
168        stream: impl AsRef<Stream>,
169    ) -> Result<Array> {
170        let pos_inf = pos_inf.into_option();
171        let neg_inf = neg_inf.into_option();
172
173        let pos_inf = mlx_sys::mlx_optional_float {
174            value: pos_inf.unwrap_or(0.0),
175            has_value: pos_inf.is_some(),
176        };
177        let neg_inf = mlx_sys::mlx_optional_float {
178            value: neg_inf.unwrap_or(0.0),
179            has_value: neg_inf.is_some(),
180        };
181
182        Array::try_from_op(|res| unsafe {
183            mlx_sys::mlx_nan_to_num(
184                res,
185                self.as_ptr(),
186                nan.into_option().unwrap_or(0.),
187                pos_inf,
188                neg_inf,
189                stream.as_ref().as_ptr(),
190            )
191        })
192    }
193
194    /// Element-wise division returning an error if arrays are not broadcastable.
195    ///
196    /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
197    ///
198    /// # Params
199    ///
200    /// - other: array to divide
201    ///
202    /// # Example
203    ///
204    /// ```rust
205    /// use mlx_rs::Array;
206    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
207    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
208    /// let mut c = a.divide(&b).unwrap();
209    ///
210    /// let c_data: &[f32] = c.as_slice();
211    /// // c_data == [0.25, 0.4, 0.5]
212    /// ```
213    #[default_device]
214    pub fn divide_device(
215        &self,
216        other: impl AsRef<Array>,
217        stream: impl AsRef<Stream>,
218    ) -> Result<Array> {
219        Array::try_from_op(|res| unsafe {
220            mlx_sys::mlx_divide(
221                res,
222                self.as_ptr(),
223                other.as_ref().as_ptr(),
224                stream.as_ref().as_ptr(),
225            )
226        })
227    }
228
229    /// Element-wise power operation returning an error if arrays are not broadcastable if they have different shapes.
230    ///
231    /// Raise the elements of the array to the power of the elements of another array.
232    ///
233    /// # Params
234    ///
235    /// - other: array to raise to the power of
236    ///
237    /// # Example
238    ///
239    /// ```rust
240    /// use mlx_rs::Array;
241    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
242    /// let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
243    /// let mut c = a.power(&b).unwrap();
244    ///
245    /// let c_data: &[f32] = c.as_slice();
246    /// // c_data == [1.0, 8.0, 81.0]
247    /// ```
248    #[default_device]
249    pub fn power_device(
250        &self,
251        other: impl AsRef<Array>,
252        stream: impl AsRef<Stream>,
253    ) -> Result<Array> {
254        Array::try_from_op(|res| unsafe {
255            mlx_sys::mlx_power(
256                res,
257                self.as_ptr(),
258                other.as_ref().as_ptr(),
259                stream.as_ref().as_ptr(),
260            )
261        })
262    }
263
264    /// Element-wise remainder of division returning an error if arrays are not broadcastable.
265    ///
266    /// Computes the remainder of dividing `lhs` with `rhs` with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
267    ///
268    /// # Params
269    ///
270    /// - other: array to divide
271    ///
272    /// # Example
273    ///
274    /// ```rust
275    /// use mlx_rs::Array;
276    /// let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
277    /// let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
278    /// let mut c = a.remainder(&b).unwrap();
279    ///
280    /// let c_data: &[f32] = c.as_slice();
281    /// // c_data == [1.0, 3.0, 2.0]
282    /// ```
283    #[default_device]
284    pub fn remainder_device(
285        &self,
286        other: impl AsRef<Array>,
287        stream: impl AsRef<Stream>,
288    ) -> Result<Array> {
289        Array::try_from_op(|res| unsafe {
290            mlx_sys::mlx_remainder(
291                res,
292                self.as_ptr(),
293                other.as_ref().as_ptr(),
294                stream.as_ref().as_ptr(),
295            )
296        })
297    }
298
299    /// Element-wise square root
300    ///
301    /// # Example
302    ///
303    /// ```rust
304    /// use mlx_rs::Array;
305    /// let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
306    /// let mut b = a.sqrt().unwrap();
307    ///
308    /// let b_data: &[f32] = b.as_slice();
309    /// // b_data == [1.0, 2.0, 3.0]
310    /// ```
311    #[default_device]
312    pub fn sqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
313        Array::try_from_op(|res| unsafe {
314            mlx_sys::mlx_sqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
315        })
316    }
317
318    /// Element-wise cosine
319    ///
320    /// # Example
321    ///
322    /// ```rust
323    /// use mlx_rs::Array;
324    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
325    /// let mut b = a.cos().unwrap();
326    ///
327    /// let b_data: &[f32] = b.as_slice();
328    /// // b_data == [1.0, 0.54030234, -0.41614687]
329    /// ```
330    #[default_device]
331    pub fn cos_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
332        Array::try_from_op(|res| unsafe {
333            mlx_sys::mlx_cos(res, self.as_ptr(), stream.as_ref().as_ptr())
334        })
335    }
336
337    /// Element-wise exponential.
338    ///
339    /// # Example
340    ///
341    /// ```rust
342    /// use mlx_rs::Array;
343    ///
344    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
345    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
346    /// let mut b = a.exp().unwrap();
347    ///
348    /// let b_data: &[f32] = b.as_slice();
349    /// // b_data == [1.0, 2.7182817, 7.389056]
350    /// ```
351    #[default_device]
352    pub fn exp_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
353        Array::try_from_op(|res| unsafe {
354            mlx_sys::mlx_exp(res, self.as_ptr(), stream.as_ref().as_ptr())
355        })
356    }
357
358    /// Element-wise floor returning an error if the array is of type complex64.
359    ///
360    /// # Example
361    ///
362    /// ```rust
363    /// use mlx_rs::Array;
364    /// let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
365    /// let mut b = a.floor().unwrap();
366    ///
367    /// let b_data: &[f32] = b.as_slice();
368    /// // b_data == [0.0, 1.0, 2.0]
369    /// ```
370    #[default_device]
371    pub fn floor_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
372        Array::try_from_op(|res| unsafe {
373            mlx_sys::mlx_floor(res, self.as_ptr(), stream.as_ref().as_ptr())
374        })
375    }
376
377    /// Element-wise integer division returning an error if arrays are not broadcastable.
378    ///
379    /// Divide two arrays with
380    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
381    ///
382    /// If either array is a floating point type then it is equivalent to calling [`Array::floor()`]
383    /// after `/`.
384    ///
385    /// # Params
386    ///
387    /// - other: array to divide
388    ///
389    /// # Example
390    ///
391    /// ```rust
392    /// use mlx_rs::Array;
393    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
394    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
395    /// let mut c = a.floor_divide(&b).unwrap();
396    ///
397    /// let c_data: &[f32] = c.as_slice();
398    /// // c_data == [0.25, 0.4, 0.5]
399    /// ```
400    #[default_device]
401    pub fn floor_divide_device(
402        &self,
403        other: impl AsRef<Array>,
404        stream: impl AsRef<Stream>,
405    ) -> Result<Array> {
406        Array::try_from_op(|res| unsafe {
407            mlx_sys::mlx_floor_divide(
408                res,
409                self.as_ptr(),
410                other.as_ref().as_ptr(),
411                stream.as_ref().as_ptr(),
412            )
413        })
414    }
415
416    /// Return a boolean array indicating which elements are NaN.
417    ///
418    /// # Params
419    /// - stream: stream or device to evaluate on
420    #[default_device]
421    pub fn is_nan_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
422        Array::try_from_op(|res| unsafe {
423            mlx_sys::mlx_isnan(res, self.as_ptr(), stream.as_ref().as_ptr())
424        })
425    }
426
427    /// Return a boolean array indicating which elements are infinity.
428    ///
429    /// # Params
430    /// - stream: stream or device to evaluate on
431    #[default_device]
432    pub fn is_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
433        Array::try_from_op(|res| unsafe {
434            mlx_sys::mlx_isinf(res, self.as_ptr(), stream.as_ref().as_ptr())
435        })
436    }
437
438    /// Return a boolean array indicating which elements are finite.
439    ///
440    /// # Params
441    /// - stream: stream or device to evaluate on
442    #[default_device]
443    pub fn is_finite_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
444        Array::try_from_op(|res| unsafe {
445            mlx_sys::mlx_isfinite(res, self.as_ptr(), stream.as_ref().as_ptr())
446        })
447    }
448
449    /// Return a boolean array indicating which elements are negative infinity.
450    ///
451    /// # Params
452    /// - stream: stream or device to evaluate on
453    #[default_device]
454    pub fn is_neg_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
455        Array::try_from_op(|res| unsafe {
456            mlx_sys::mlx_isneginf(res, self.as_ptr(), stream.as_ref().as_ptr())
457        })
458    }
459
460    /// Return a boolean array indicating which elements are positive infinity.
461    ///
462    /// # Params
463    /// - stream: stream or device to evaluate on
464    #[default_device]
465    pub fn is_pos_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
466        Array::try_from_op(|res| unsafe {
467            mlx_sys::mlx_isposinf(res, self.as_ptr(), stream.as_ref().as_ptr())
468        })
469    }
470
471    /// Element-wise natural logarithm.
472    ///
473    /// # Example
474    ///
475    /// ```rust
476    /// use mlx_rs::Array;
477    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
478    /// let mut b = a.log().unwrap();
479    ///
480    /// let b_data: &[f32] = b.as_slice();
481    /// // b_data == [0.0, 0.6931472, 1.0986123]
482    /// ```
483    #[default_device]
484    pub fn log_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
485        Array::try_from_op(|res| unsafe {
486            mlx_sys::mlx_log(res, self.as_ptr(), stream.as_ref().as_ptr())
487        })
488    }
489
490    /// Element-wise base-2 logarithm.
491    ///
492    /// # Example
493    ///
494    /// ```rust
495    /// use mlx_rs::Array;
496    /// let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
497    /// let mut b = a.log2().unwrap();
498    ///
499    /// let b_data: &[f32] = b.as_slice();
500    /// // b_data == [0.0, 1.0, 2.0, 3.0]
501    /// ```
502    #[default_device]
503    pub fn log2_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
504        Array::try_from_op(|res| unsafe {
505            mlx_sys::mlx_log2(res, self.as_ptr(), stream.as_ref().as_ptr())
506        })
507    }
508
509    /// Element-wise base-10 logarithm.
510    ///
511    /// # Example
512    ///
513    /// ```rust
514    /// use mlx_rs::Array;
515    /// let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
516    /// let mut b = a.log10().unwrap();
517    ///
518    /// let b_data: &[f32] = b.as_slice();
519    /// // b_data == [0.0, 1.0, 2.0]
520    /// ```
521    #[default_device]
522    pub fn log10_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
523        Array::try_from_op(|res| unsafe {
524            mlx_sys::mlx_log10(res, self.as_ptr(), stream.as_ref().as_ptr())
525        })
526    }
527
528    /// Element-wise natural log of one plus the array.
529    ///
530    /// # Example
531    ///
532    /// ```rust
533    /// use mlx_rs::Array;
534    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
535    /// let mut b = a.log1p().unwrap();
536    ///
537    /// let b_data: &[f32] = b.as_slice();
538    /// // b_data == [0.6931472, 1.0986123, 1.3862944]
539    /// ```
540    #[default_device]
541    pub fn log1p_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
542        Array::try_from_op(|res| unsafe {
543            mlx_sys::mlx_log1p(res, self.as_ptr(), stream.as_ref().as_ptr())
544        })
545    }
546
547    /// Matrix multiplication returning an error if inputs are not valid.
548    ///
549    /// Perform the (possibly batched) matrix multiplication of two arrays. This function supports
550    /// broadcasting for arrays with more than two dimensions.
551    ///
552    /// - If the first array is 1-D then a 1 is prepended to its shape to make it
553    ///   a matrix. Similarly, if the second array is 1-D then a 1 is appended to its
554    ///   shape to make it a matrix. In either case the singleton dimension is removed
555    ///   from the result.
556    /// - A batched matrix multiplication is performed if the arrays have more than
557    ///   2 dimensions.  The matrix dimensions for the matrix product are the last
558    ///   two dimensions of each input.
559    /// - All but the last two dimensions of each input are broadcast with one another using
560    ///   standard [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
561    ///
562    /// # Params
563    ///
564    /// - other: array to multiply
565    ///
566    /// # Example
567    ///
568    /// ```rust
569    /// use mlx_rs::Array;
570    /// let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
571    /// let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
572    ///
573    /// // produces a [2, 3] result
574    /// let mut c = a.matmul(&b);
575    /// ```
576    #[default_device]
577    pub fn matmul_device(
578        &self,
579        other: impl AsRef<Array>,
580        stream: impl AsRef<Stream>,
581    ) -> Result<Array> {
582        Array::try_from_op(|res| unsafe {
583            mlx_sys::mlx_matmul(
584                res,
585                self.as_ptr(),
586                other.as_ref().as_ptr(),
587                stream.as_ref().as_ptr(),
588            )
589        })
590    }
591
592    /// Element-wise reciprocal.
593    ///
594    /// # Example
595    ///
596    /// ```rust
597    /// use mlx_rs::Array;
598    /// let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
599    /// let mut b = a.reciprocal().unwrap();
600    ///
601    /// let b_data: &[f32] = b.as_slice();
602    /// // b_data == [1.0, 0.5, 0.25]
603    /// ```
604    #[default_device]
605    pub fn reciprocal_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
606        Array::try_from_op(|res| unsafe {
607            mlx_sys::mlx_reciprocal(res, self.as_ptr(), stream.as_ref().as_ptr())
608        })
609    }
610
611    /// Round to the given number of decimals.
612    ///
613    /// # Params
614    ///
615    /// - decimals: number of decimals to round to - default is 0 if not provided
616    #[default_device]
617    pub fn round_device(
618        &self,
619        decimals: impl Into<Option<i32>>,
620        stream: impl AsRef<Stream>,
621    ) -> Result<Array> {
622        Array::try_from_op(|res| unsafe {
623            mlx_sys::mlx_round(
624                res,
625                self.as_ptr(),
626                decimals.into().unwrap_or(0),
627                stream.as_ref().as_ptr(),
628            )
629        })
630    }
631
632    /// Element-wise reciprocal and square root.
633    #[default_device]
634    pub fn rsqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
635        Array::try_from_op(|res| unsafe {
636            mlx_sys::mlx_rsqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
637        })
638    }
639
640    /// Element-wise sine.
641    #[default_device]
642    pub fn sin_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
643        Array::try_from_op(|res| unsafe {
644            mlx_sys::mlx_sin(res, self.as_ptr(), stream.as_ref().as_ptr())
645        })
646    }
647
648    /// Element-wise square.
649    #[default_device]
650    pub fn square_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
651        Array::try_from_op(|res| unsafe {
652            mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
653        })
654    }
655
656    /// Element-wise real part from a complex array.
657    #[default_device]
658    pub fn real_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
659        Array::try_from_op(|res| unsafe {
660            mlx_sys::mlx_real(res, self.as_ptr(), stream.as_ref().as_ptr())
661        })
662    }
663
664    /// Element-wise imag part from a complex array.
665    #[default_device]
666    pub fn imag_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
667        Array::try_from_op(|res| unsafe {
668            mlx_sys::mlx_imag(res, self.as_ptr(), stream.as_ref().as_ptr())
669        })
670    }
671}
672
673/// Element-wise absolute value.
674///
675/// # Example
676///
677/// ```rust
678/// use mlx_rs::{Array, ops};
679///
680/// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]);
681/// let result = ops::abs(&array).unwrap();
682/// ```
683#[generate_macro]
684#[default_device]
685pub fn abs_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
686    a.as_ref().abs_device(stream)
687}
688
689/// Element-wise inverse cosine.
690#[generate_macro]
691#[default_device]
692pub fn acos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
693    Array::try_from_op(|res| unsafe {
694        mlx_sys::mlx_arccos(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
695    })
696}
697
698/// Element-wise inverse hyperbolic cosine.
699#[generate_macro]
700#[default_device]
701pub fn acosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
702    Array::try_from_op(|res| unsafe {
703        mlx_sys::mlx_arccosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
704    })
705}
706
707/// See [`Array::add`].
708#[generate_macro]
709#[default_device]
710pub fn add_device(
711    lhs: impl AsRef<Array>,
712    rhs: impl AsRef<Array>,
713    #[optional] stream: impl AsRef<Stream>,
714) -> Result<Array> {
715    lhs.as_ref().add_device(rhs, stream)
716}
717
718/// Element-wise inverse sine.
719#[generate_macro]
720#[default_device]
721pub fn asin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
722    Array::try_from_op(|res| unsafe {
723        mlx_sys::mlx_arcsin(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
724    })
725}
726
727/// Element-wise inverse hyperbolic sine.
728#[generate_macro]
729#[default_device]
730pub fn asinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
731    Array::try_from_op(|res| unsafe {
732        mlx_sys::mlx_arcsinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
733    })
734}
735
736/// Element-wise inverse tangent.
737#[generate_macro]
738#[default_device]
739pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
740    Array::try_from_op(|res| unsafe {
741        mlx_sys::mlx_arctan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
742    })
743}
744
745/// Element-wise inverse tangent of b/a choosing the quadrant correctly.
746#[generate_macro]
747#[default_device]
748pub fn atan2_device(
749    a: impl AsRef<Array>,
750    b: impl AsRef<Array>,
751    #[optional] stream: impl AsRef<Stream>,
752) -> Result<Array> {
753    let a = a.as_ref();
754    let b = b.as_ref();
755
756    Array::try_from_op(|res| unsafe {
757        mlx_sys::mlx_arctan2(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
758    })
759}
760
761/// Element-wise inverse hyperbolic tangent.
762#[generate_macro]
763#[default_device]
764pub fn atanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
765    Array::try_from_op(|res| unsafe {
766        mlx_sys::mlx_arctanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
767    })
768}
769
770/// Element-wise ceiling.
771#[generate_macro]
772#[default_device]
773pub fn ceil_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
774    Array::try_from_op(|res| unsafe {
775        mlx_sys::mlx_ceil(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
776    })
777}
778
779/// A custom trait for the bound of the clip operation.
780///
781/// This trait is only implemented for tuples of the form `(Min, Max)`, `(Min, ())`, and `((),
782/// Max)`. The `Min` and `Max` types must implement the `ScalarOrArray` trait.
783pub trait ClipBound<'min, 'max>: Sealed {
784    /// Convert the bound into a tuple of optional minimum and maximum values.
785    fn into_min_max(
786        self,
787    ) -> (
788        Option<impl ScalarOrArray<'min>>,
789        Option<impl ScalarOrArray<'max>>,
790    );
791}
792
793impl<'min, Min> ClipBound<'min, 'min> for (Min, ())
794where
795    Min: ScalarOrArray<'min> + Sealed,
796{
797    fn into_min_max(
798        self,
799    ) -> (
800        Option<impl ScalarOrArray<'min>>,
801        Option<impl ScalarOrArray<'min>>,
802    ) {
803        (Some(self.0), Option::<Min>::None)
804    }
805}
806
807impl<'max, Max> ClipBound<'max, 'max> for ((), Max)
808where
809    Max: ScalarOrArray<'max> + Sealed,
810{
811    fn into_min_max(
812        self,
813    ) -> (
814        Option<impl ScalarOrArray<'max>>,
815        Option<impl ScalarOrArray<'max>>,
816    ) {
817        (Option::<Max>::None, Some(self.1))
818    }
819}
820
821impl<'min, 'max, Min, Max> ClipBound<'min, 'max> for (Min, Max)
822where
823    Min: ScalarOrArray<'min> + Sealed,
824    Max: ScalarOrArray<'max> + Sealed,
825{
826    fn into_min_max(
827        self,
828    ) -> (
829        Option<impl ScalarOrArray<'min>>,
830        Option<impl ScalarOrArray<'max>>,
831    ) {
832        (Some(self.0), Some(self.1))
833    }
834}
835
836/// Clip the values of the array between the given minimum and maximum.
837///
838/// If either `a_min` or `a_max` are None, then corresponding edge is ignored. At least one of
839/// `a_min` and `a_max` cannot be `None`. The input `a` and the limits must broadcast with one
840/// another.
841///
842/// # Params
843///
844/// - `a`: Input array.
845/// - `bound`: minimum and/or maximum values to clip the array to.
846///
847/// # Example
848///
849/// ```rust
850/// use mlx_rs::{Array, ops::clip, array};
851///
852/// let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
853/// let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);
854/// let clipped = clip(&a, (2.0, 6.0)).unwrap();
855/// assert_eq!(clipped, expected);
856/// ```
857#[generate_macro]
858#[default_device]
859pub fn clip_device<'min, 'max>(
860    a: impl AsRef<Array>,
861    bound: impl ClipBound<'min, 'max>,
862    #[optional] stream: impl AsRef<Stream>,
863) -> Result<Array> {
864    let (a_min, a_max) = bound.into_min_max();
865
866    // This is needed to keep the lifetime of the min/max arrays in scope.
867    let a_min = a_min.map(|min| min.into_owned_or_ref_array());
868    let a_max = a_max.map(|max| max.into_owned_or_ref_array());
869
870    unsafe {
871        let min_ptr = match &a_min {
872            Some(a_min) => a_min.as_ref().as_ptr(),
873            None => mlx_sys::mlx_array_new(),
874        };
875        let max_ptr = match &a_max {
876            Some(a_max) => a_max.as_ref().as_ptr(),
877            None => mlx_sys::mlx_array_new(),
878        };
879
880        Array::try_from_op(|res| {
881            mlx_sys::mlx_clip(
882                res,
883                a.as_ref().as_ptr(),
884                min_ptr,
885                max_ptr,
886                stream.as_ref().as_ptr(),
887            )
888        })
889    }
890}
891
892/// Element-wise cosine.
893#[generate_macro]
894#[default_device]
895pub fn cos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
896    a.as_ref().cos_device(stream)
897}
898
899/// Element-wise hyperbolic cosine.
900#[generate_macro]
901#[default_device]
902pub fn cosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
903    Array::try_from_op(|res| unsafe {
904        mlx_sys::mlx_cosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
905    })
906}
907
908/// Convert angles from radians to degrees.
909#[generate_macro]
910#[default_device]
911pub fn degrees_device(
912    a: impl AsRef<Array>,
913    #[optional] stream: impl AsRef<Stream>,
914) -> Result<Array> {
915    Array::try_from_op(|res| unsafe {
916        mlx_sys::mlx_degrees(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
917    })
918}
919
920/// See [`Array::divide`].
921#[generate_macro]
922#[default_device]
923pub fn divide_device(
924    a: impl AsRef<Array>,
925    b: impl AsRef<Array>,
926    #[optional] stream: impl AsRef<Stream>,
927) -> Result<Array> {
928    a.as_ref().divide_device(b, stream)
929}
930
931/// Element-wise quotient and remainder.
932///
933/// The fuction `divmod(a, b)` is equivalent to but faster than `(a // b, a % b)`. The function uses
934/// numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
935///
936/// Returns Ok((quotient, remainder)) if the operation was successful.
937#[generate_macro]
938#[default_device]
939pub fn divmod_device(
940    a: impl AsRef<Array>,
941    b: impl AsRef<Array>,
942    #[optional] stream: impl AsRef<Stream>,
943) -> Result<(Array, Array)> {
944    let a_ptr = a.as_ref().as_ptr();
945    let b_ptr = b.as_ref().as_ptr();
946
947    let vec = VectorArray::try_from_op(|res| unsafe {
948        mlx_sys::mlx_divmod(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
949    })?;
950
951    let vals: SmallVec<[_; 2]> = vec.try_into_values()?;
952    let mut iter = vals.into_iter();
953    let quotient = iter.next().unwrap();
954    let remainder = iter.next().unwrap();
955
956    Ok((quotient, remainder))
957}
958
959/// Element-wise error function.
960#[generate_macro]
961#[default_device]
962pub fn erf_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
963    Array::try_from_op(|res| unsafe {
964        mlx_sys::mlx_erf(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
965    })
966}
967
968/// Element-wise inverse error function.
969#[generate_macro]
970#[default_device]
971pub fn erfinv_device(
972    a: impl AsRef<Array>,
973    #[optional] stream: impl AsRef<Stream>,
974) -> Result<Array> {
975    Array::try_from_op(|res| unsafe {
976        mlx_sys::mlx_erfinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
977    })
978}
979
980/// See [`Array::exp`].
981#[generate_macro]
982#[default_device]
983pub fn exp_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
984    a.as_ref().exp_device(stream)
985}
986
987/// Element-wise exponential minus 1.
988#[generate_macro]
989#[default_device]
990pub fn expm1_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
991    Array::try_from_op(|res| unsafe {
992        mlx_sys::mlx_expm1(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
993    })
994}
995
996/// See [`Array::floor`].
997#[generate_macro]
998#[default_device]
999pub fn floor_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1000    a.as_ref().floor_device(stream)
1001}
1002
1003/// See [`Array::floor_divide`].
1004#[generate_macro]
1005#[default_device]
1006pub fn floor_divide_device(
1007    a: impl AsRef<Array>,
1008    other: impl AsRef<Array>,
1009    #[optional] stream: impl AsRef<Stream>,
1010) -> Result<Array> {
1011    a.as_ref().floor_divide_device(other, stream)
1012}
1013
1014/// See [`Array::log`].
1015#[generate_macro]
1016#[default_device]
1017pub fn log_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1018    a.as_ref().log_device(stream)
1019}
1020
1021/// See [`Array::log10`].
1022#[generate_macro]
1023#[default_device]
1024pub fn log10_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1025    a.as_ref().log10_device(stream)
1026}
1027
1028/// See [`Array::log1p`].
1029#[generate_macro]
1030#[default_device]
1031pub fn log1p_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1032    a.as_ref().log1p_device(stream)
1033}
1034
1035/// See [`Array::log2`].
1036#[generate_macro]
1037#[default_device]
1038pub fn log2_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1039    a.as_ref().log2_device(stream)
1040}
1041
1042/// Element-wise log-add-exp.
1043///
1044/// This is a numerically stable log-add-exp of two arrays with numpy-style broadcasting semantics.
1045/// Either or both input arrays can also be scalars.
1046///
1047/// The computation is is a numerically stable version of `log(exp(a) + exp(b))`.
1048#[generate_macro]
1049#[default_device]
1050pub fn logaddexp_device(
1051    a: impl AsRef<Array>,
1052    b: impl AsRef<Array>,
1053    #[optional] stream: impl AsRef<Stream>,
1054) -> Result<Array> {
1055    let a_ptr = a.as_ref().as_ptr();
1056    let b_ptr = b.as_ref().as_ptr();
1057
1058    Array::try_from_op(|res| unsafe {
1059        mlx_sys::mlx_logaddexp(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1060    })
1061}
1062
1063/// See [`Array::matmul`].
1064#[generate_macro]
1065#[default_device]
1066pub fn matmul_device(
1067    a: impl AsRef<Array>,
1068    b: impl AsRef<Array>,
1069    #[optional] stream: impl AsRef<Stream>,
1070) -> Result<Array> {
1071    a.as_ref().matmul_device(b, stream)
1072}
1073
1074/// Element-wise maximum.
1075///
1076/// Take the element-wise max of two arrays with numpy-style broadcasting semantics. Either or both
1077/// input arrays can also be scalars.
1078#[generate_macro]
1079#[default_device]
1080pub fn maximum_device(
1081    a: impl AsRef<Array>,
1082    b: impl AsRef<Array>,
1083    #[optional] stream: impl AsRef<Stream>,
1084) -> Result<Array> {
1085    let a_ptr = a.as_ref().as_ptr();
1086    let b_ptr = b.as_ref().as_ptr();
1087
1088    Array::try_from_op(|res| unsafe {
1089        mlx_sys::mlx_maximum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1090    })
1091}
1092
1093/// Element-wise minimum.
1094///
1095/// Take the element-wise min of two arrays with numpy-style broadcasting semantics. Either or both
1096/// input arrays can also be scalars.
1097#[generate_macro]
1098#[default_device]
1099pub fn minimum_device(
1100    a: impl AsRef<Array>,
1101    b: impl AsRef<Array>,
1102    #[optional] stream: impl AsRef<Stream>,
1103) -> Result<Array> {
1104    let a_ptr = a.as_ref().as_ptr();
1105    let b_ptr = b.as_ref().as_ptr();
1106
1107    Array::try_from_op(|res| unsafe {
1108        mlx_sys::mlx_minimum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1109    })
1110}
1111
1112/// See [`Array::multiply`].
1113#[generate_macro]
1114#[default_device]
1115pub fn multiply_device(
1116    a: impl AsRef<Array>,
1117    b: impl AsRef<Array>,
1118    #[optional] stream: impl AsRef<Stream>,
1119) -> Result<Array> {
1120    a.as_ref().multiply_device(b, stream)
1121}
1122
1123/// See [`Array::negative`].
1124#[generate_macro]
1125#[default_device]
1126pub fn negative_device(
1127    a: impl AsRef<Array>,
1128    #[optional] stream: impl AsRef<Stream>,
1129) -> Result<Array> {
1130    a.as_ref().negative_device(stream)
1131}
1132
1133/// See [`Array::power`].
1134#[generate_macro]
1135#[default_device]
1136pub fn power_device(
1137    a: impl AsRef<Array>,
1138    b: impl AsRef<Array>,
1139    #[optional] stream: impl AsRef<Stream>,
1140) -> Result<Array> {
1141    a.as_ref().power_device(b, stream)
1142}
1143
1144/// Convert angles from degrees to radians.
1145#[generate_macro]
1146#[default_device]
1147pub fn radians_device(
1148    a: impl AsRef<Array>,
1149    #[optional] stream: impl AsRef<Stream>,
1150) -> Result<Array> {
1151    Array::try_from_op(|res| unsafe {
1152        mlx_sys::mlx_radians(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1153    })
1154}
1155
1156/// See [`Array::reciprocal`].
1157#[generate_macro]
1158#[default_device]
1159pub fn reciprocal_device(
1160    a: impl AsRef<Array>,
1161    #[optional] stream: impl AsRef<Stream>,
1162) -> Result<Array> {
1163    a.as_ref().reciprocal_device(stream)
1164}
1165
1166/// See [`Array::remainder`].
1167#[generate_macro]
1168#[default_device]
1169pub fn remainder_device(
1170    a: impl AsRef<Array>,
1171    b: impl AsRef<Array>,
1172    #[optional] stream: impl AsRef<Stream>,
1173) -> Result<Array> {
1174    a.as_ref().remainder_device(b, stream)
1175}
1176
1177/// See [`Array::round`].
1178#[generate_macro]
1179#[default_device]
1180pub fn round_device(
1181    a: impl AsRef<Array>,
1182    decimals: impl Into<Option<i32>>,
1183    #[optional] stream: impl AsRef<Stream>,
1184) -> Result<Array> {
1185    a.as_ref().round_device(decimals, stream)
1186}
1187
1188/// See [`Array::rsqrt`].
1189#[generate_macro]
1190#[default_device]
1191pub fn rsqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1192    a.as_ref().rsqrt_device(stream)
1193}
1194
1195/// Element-wise logistic sigmoid.
1196///
1197/// See the [python API
1198/// docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html#mlx.core.sigmoid)
1199/// for more information
1200#[generate_macro]
1201#[default_device]
1202pub fn sigmoid_device(
1203    a: impl AsRef<Array>,
1204    #[optional] stream: impl AsRef<Stream>,
1205) -> Result<Array> {
1206    Array::try_from_op(|res| unsafe {
1207        mlx_sys::mlx_sigmoid(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1208    })
1209}
1210
1211/// Element-wise sign.
1212#[generate_macro]
1213#[default_device]
1214pub fn sign_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1215    Array::try_from_op(|res| unsafe {
1216        mlx_sys::mlx_sign(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1217    })
1218}
1219
1220/// See [`Array::sin`].
1221#[generate_macro]
1222#[default_device]
1223pub fn sin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1224    a.as_ref().sin_device(stream)
1225}
1226
1227/// Element-wise hyperbolic sine.
1228#[generate_macro]
1229#[default_device]
1230pub fn sinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1231    Array::try_from_op(|res| unsafe {
1232        mlx_sys::mlx_sinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1233    })
1234}
1235
1236/// Perform the softmax along the given axis.
1237///
1238/// See the [python API
1239/// docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.softmax.html#mlx.core.softmax)
1240/// for more information.
1241#[generate_macro]
1242#[default_device]
1243pub fn softmax_axes_device(
1244    a: impl AsRef<Array>,
1245    axes: &[i32],
1246    precise: impl Into<Option<bool>>,
1247    #[optional] stream: impl AsRef<Stream>,
1248) -> Result<Array> {
1249    let precise = precise.into().unwrap_or(false);
1250    let s = stream.as_ref().as_ptr();
1251
1252    Array::try_from_op(|res| unsafe {
1253        mlx_sys::mlx_softmax_axes(
1254            res,
1255            a.as_ref().as_ptr(),
1256            axes.as_ptr(),
1257            axes.len(),
1258            precise,
1259            s,
1260        )
1261    })
1262}
1263
1264/// Similar to [`softmax_axes`] but with a single axis.
1265#[generate_macro]
1266#[default_device]
1267pub fn softmax_axis_device(
1268    a: impl AsRef<Array>,
1269    axis: i32,
1270    precise: impl Into<Option<bool>>,
1271    #[optional] stream: impl AsRef<Stream>,
1272) -> Result<Array> {
1273    let precise = precise.into().unwrap_or(false);
1274    let s = stream.as_ref().as_ptr();
1275
1276    Array::try_from_op(|res| unsafe {
1277        mlx_sys::mlx_softmax_axis(res, a.as_ref().as_ptr(), axis, precise, s)
1278    })
1279}
1280
1281/// Similar to [`softmax_axes`] but with no axis specified.
1282#[generate_macro]
1283#[default_device]
1284pub fn softmax_device(
1285    a: impl AsRef<Array>,
1286    precise: impl Into<Option<bool>>,
1287    #[optional] stream: impl AsRef<Stream>,
1288) -> Result<Array> {
1289    let precise = precise.into().unwrap_or(false);
1290    let s = stream.as_ref().as_ptr();
1291
1292    Array::try_from_op(|res| unsafe { mlx_sys::mlx_softmax(res, a.as_ref().as_ptr(), precise, s) })
1293}
1294
1295/// See [`Array::sqrt`].
1296#[generate_macro]
1297#[default_device]
1298pub fn sqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1299    a.as_ref().sqrt_device(stream)
1300}
1301
1302/// See [`Array::square`].
1303#[generate_macro]
1304#[default_device]
1305pub fn square_device(
1306    a: impl AsRef<Array>,
1307    #[optional] stream: impl AsRef<Stream>,
1308) -> Result<Array> {
1309    a.as_ref().square_device(stream)
1310}
1311
1312/// See [`Array::subtract`].
1313#[generate_macro]
1314#[default_device]
1315pub fn subtract_device(
1316    a: impl AsRef<Array>,
1317    b: impl AsRef<Array>,
1318    #[optional] stream: impl AsRef<Stream>,
1319) -> Result<Array> {
1320    a.as_ref().subtract_device(b, stream)
1321}
1322
1323/// See [`Array::tan`].
1324#[generate_macro]
1325#[default_device]
1326pub fn tan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1327    Array::try_from_op(|res| unsafe {
1328        mlx_sys::mlx_tan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1329    })
1330}
1331
1332/// Element-wise hyperbolic tangent.
1333#[generate_macro]
1334#[default_device]
1335pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1336    Array::try_from_op(|res| unsafe {
1337        mlx_sys::mlx_tanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1338    })
1339}
1340
1341/// Element-wise real part from a complex array.
1342#[generate_macro]
1343#[default_device]
1344pub fn real_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1345    Array::try_from_op(|res| unsafe {
1346        mlx_sys::mlx_real(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1347    })
1348}
1349
1350/// Element-wise imaginary part from a complex array.
1351#[generate_macro]
1352#[default_device]
1353pub fn imag_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1354    Array::try_from_op(|res| unsafe {
1355        mlx_sys::mlx_imag(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1356    })
1357}
1358
1359/// Matrix multiplication with block masking.
1360///
1361/// See the [python API docs](
1362/// https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.block_masked_mm.html#mlx.core.block_masked_mm
1363/// ) for more information.
1364#[generate_macro]
1365#[default_device]
1366pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
1367    a: impl AsRef<Array>,
1368    b: impl AsRef<Array>,
1369    #[optional] block_size: impl Into<Option<i32>>,
1370    #[optional] mask_out: impl Into<Option<&'mo Array>>,
1371    #[optional] mask_lhs: impl Into<Option<&'lhs Array>>,
1372    #[optional] mask_rhs: impl Into<Option<&'rhs Array>>,
1373    #[optional] stream: impl AsRef<Stream>,
1374) -> Result<Array> {
1375    let a_ptr = a.as_ref().as_ptr();
1376    let b_ptr = b.as_ref().as_ptr();
1377    unsafe {
1378        let mask_out_ptr = mask_out
1379            .into()
1380            .map(|m| m.as_ptr())
1381            .unwrap_or(mlx_sys::mlx_array_new());
1382        let mask_lhs_ptr = mask_lhs
1383            .into()
1384            .map(|m| m.as_ptr())
1385            .unwrap_or(mlx_sys::mlx_array_new());
1386        let mask_rhs_ptr = mask_rhs
1387            .into()
1388            .map(|m| m.as_ptr())
1389            .unwrap_or(mlx_sys::mlx_array_new());
1390
1391        Array::try_from_op(|res| {
1392            mlx_sys::mlx_block_masked_mm(
1393                res,
1394                a_ptr,
1395                b_ptr,
1396                block_size.into().unwrap_or(32),
1397                mask_out_ptr,
1398                mask_lhs_ptr,
1399                mask_rhs_ptr,
1400                stream.as_ref().as_ptr(),
1401            )
1402        })
1403    }
1404}
1405
1406/// Matrix multiplication with addition and optional scaling.
1407///
1408/// Perform the (possibly batched) matrix multiplication of two arrays and add to the result with
1409/// optional scaling factors.
1410///
1411/// # Params
1412///
1413/// - `c`: input array,
1414/// - `a`: input array,
1415/// - `b`: input array,
1416/// - `alpha`: Scaling factor for the matrix product of `a` and `b` (default: `1`)
1417/// - `beta`: Scaling factor for `c` (default: `1`)
1418#[generate_macro]
1419#[default_device]
1420pub fn addmm_device(
1421    c: impl AsRef<Array>,
1422    a: impl AsRef<Array>,
1423    b: impl AsRef<Array>,
1424    #[optional] alpha: impl Into<Option<f32>>,
1425    #[optional] beta: impl Into<Option<f32>>,
1426    #[optional] stream: impl AsRef<Stream>,
1427) -> Result<Array> {
1428    let c_ptr = c.as_ref().as_ptr();
1429    let a_ptr = a.as_ref().as_ptr();
1430    let b_ptr = b.as_ref().as_ptr();
1431    let alpha = alpha.into().unwrap_or(1.0);
1432    let beta = beta.into().unwrap_or(1.0);
1433
1434    Array::try_from_op(|res| unsafe {
1435        mlx_sys::mlx_addmm(
1436            res,
1437            c_ptr,
1438            a_ptr,
1439            b_ptr,
1440            alpha,
1441            beta,
1442            stream.as_ref().as_ptr(),
1443        )
1444    })
1445}
1446
1447/// Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the
1448/// last axes.
1449#[generate_macro]
1450#[default_device]
1451pub fn inner_device(
1452    a: impl AsRef<Array>,
1453    b: impl AsRef<Array>,
1454    #[optional] stream: impl AsRef<Stream>,
1455) -> Result<Array> {
1456    let a = a.as_ref();
1457    let b = b.as_ref();
1458    Array::try_from_op(|res| unsafe {
1459        mlx_sys::mlx_inner(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1460    })
1461}
1462
1463/// Compute the outer product of two 1-D arrays, if the array’s passed are not 1-D a flatten op will
1464/// be run beforehand.
1465#[generate_macro]
1466#[default_device]
1467pub fn outer_device(
1468    a: impl AsRef<Array>,
1469    b: impl AsRef<Array>,
1470    #[optional] stream: impl AsRef<Stream>,
1471) -> Result<Array> {
1472    let a = a.as_ref();
1473    let b = b.as_ref();
1474    Array::try_from_op(|res| unsafe {
1475        mlx_sys::mlx_outer(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1476    })
1477}
1478
1479/// Compute the tensor dot product along the specified axes.
1480#[generate_macro]
1481#[default_device]
1482pub fn tensordot_axes_device(
1483    a: impl AsRef<Array>,
1484    b: impl AsRef<Array>,
1485    axes_a: &[i32],
1486    axes_b: &[i32],
1487    #[optional] stream: impl AsRef<Stream>,
1488) -> Result<Array> {
1489    let a = a.as_ref();
1490    let b = b.as_ref();
1491    Array::try_from_op(|res| unsafe {
1492        mlx_sys::mlx_tensordot(
1493            res,
1494            a.as_ptr(),
1495            b.as_ptr(),
1496            axes_a.as_ptr(),
1497            axes_a.len(),
1498            axes_b.as_ptr(),
1499            axes_b.len(),
1500            stream.as_ref().as_ptr(),
1501        )
1502    })
1503}
1504
1505/// Similar to [`tensordot_axes`] but with a single axis.
1506#[generate_macro]
1507#[default_device]
1508pub fn tensordot_axis_device(
1509    a: impl AsRef<Array>,
1510    b: impl AsRef<Array>,
1511    axis: i32,
1512    #[optional] stream: impl AsRef<Stream>,
1513) -> Result<Array> {
1514    let a = a.as_ref();
1515    let b = b.as_ref();
1516    Array::try_from_op(|res| unsafe {
1517        mlx_sys::mlx_tensordot_axis(res, a.as_ptr(), b.as_ptr(), axis, stream.as_ref().as_ptr())
1518    })
1519}
1520
1521#[cfg(test)]
1522mod tests {
1523    use std::f32::consts::PI;
1524
1525    use super::*;
1526    use crate::{
1527        array, complex64,
1528        ops::{all_close, arange, broadcast_to, eye, full, linspace, ones, reshape, split},
1529        transforms::eval,
1530        Dtype, StreamOrDevice,
1531    };
1532    use float_eq::assert_float_eq;
1533    use pretty_assertions::assert_eq;
1534
1535    #[test]
1536    fn test_abs() {
1537        let data = [1i32, 2, -3, -4, -5];
1538        let array = Array::from_slice(&data, &[5]);
1539        let result = array.abs().unwrap();
1540
1541        let data: &[i32] = result.as_slice();
1542        assert_eq!(data, [1, 2, 3, 4, 5]);
1543
1544        // test that previous array is not modified and valid
1545        let data: &[i32] = array.as_slice();
1546        assert_eq!(data, [1, 2, -3, -4, -5]);
1547    }
1548
1549    #[test]
1550    fn test_add() {
1551        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1552        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1553
1554        let c = &a + &b;
1555
1556        let c_data: &[f32] = c.as_slice();
1557        assert_eq!(c_data, &[5.0, 7.0, 9.0]);
1558
1559        // check a and b are not modified
1560        let a_data: &[f32] = a.as_slice();
1561        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1562
1563        let b_data: &[f32] = b.as_slice();
1564        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1565    }
1566
1567    #[test]
1568    fn test_add_invalid_broadcast() {
1569        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1570        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1571
1572        let c = a.add(&b);
1573        assert!(c.is_err());
1574    }
1575
1576    #[test]
1577    fn test_sub() {
1578        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1579        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1580
1581        let c = &a - &b;
1582
1583        let c_data: &[f32] = c.as_slice();
1584        assert_eq!(c_data, &[-3.0, -3.0, -3.0]);
1585
1586        // check a and b are not modified
1587        let a_data: &[f32] = a.as_slice();
1588        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1589
1590        let b_data: &[f32] = b.as_slice();
1591        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1592    }
1593
1594    #[test]
1595    fn test_sub_invalid_broadcast() {
1596        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1597        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1598        let c = a.subtract(&b);
1599        assert!(c.is_err());
1600    }
1601
1602    #[test]
1603    fn test_neg() {
1604        let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]);
1605        let b = a.negative().unwrap();
1606
1607        let b_data: &[f32] = b.as_slice();
1608        assert_eq!(b_data, &[-1.0, -2.0, -3.0]);
1609
1610        // check a is not modified
1611        let a_data: &[f32] = a.as_slice();
1612        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1613    }
1614
1615    #[test]
1616    fn test_neg_bool() {
1617        let a = Array::from_slice(&[true, false, true], &[3]);
1618        let b = a.negative();
1619        assert!(b.is_err());
1620    }
1621
1622    #[test]
1623    fn test_logical_not() {
1624        let a: Array = false.into();
1625        let b = a.logical_not().unwrap();
1626
1627        let b_data: &[bool] = b.as_slice();
1628        assert_eq!(b_data, [true]);
1629    }
1630
1631    #[test]
1632    fn test_mul() {
1633        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1634        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1635
1636        let c = &a * &b;
1637
1638        let c_data: &[f32] = c.as_slice();
1639        assert_eq!(c_data, &[4.0, 10.0, 18.0]);
1640
1641        // check a and b are not modified
1642        let a_data: &[f32] = a.as_slice();
1643        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1644
1645        let b_data: &[f32] = b.as_slice();
1646        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1647    }
1648
1649    #[test]
1650    fn test_mul_invalid_broadcast() {
1651        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1652        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1653        let c = a.multiply(&b);
1654        assert!(c.is_err());
1655    }
1656
1657    #[test]
1658    fn test_nan_to_num() {
1659        let a = array!([1.0, 2.0, f32::NAN, 4.0, 5.0]);
1660        let b = a.nan_to_num(0.0, 1.0, 0.0).unwrap();
1661
1662        let b_data: &[f32] = b.as_slice();
1663        assert_eq!(b_data, &[1.0, 2.0, 0.0, 4.0, 5.0]);
1664    }
1665
1666    #[test]
1667    fn test_div() {
1668        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1669        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1670
1671        let c = &a / &b;
1672
1673        let c_data: &[f32] = c.as_slice();
1674        assert_eq!(c_data, &[0.25, 0.4, 0.5]);
1675
1676        // check a and b are not modified
1677        let a_data: &[f32] = a.as_slice();
1678        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1679
1680        let b_data: &[f32] = b.as_slice();
1681        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1682    }
1683
1684    #[test]
1685    fn test_div_invalid_broadcast() {
1686        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1687        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1688        let c = a.divide(&b);
1689        assert!(c.is_err());
1690    }
1691
1692    #[test]
1693    fn test_pow() {
1694        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1695        let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
1696
1697        let c = a.power(&b).unwrap();
1698
1699        let c_data: &[f32] = c.as_slice();
1700        assert_eq!(c_data, &[1.0, 8.0, 81.0]);
1701
1702        // check a and b are not modified
1703        let a_data: &[f32] = a.as_slice();
1704        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1705
1706        let b_data: &[f32] = b.as_slice();
1707        assert_eq!(b_data, &[2.0, 3.0, 4.0]);
1708    }
1709
1710    #[test]
1711    fn test_pow_invalid_broadcast() {
1712        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1713        let b = Array::from_slice(&[2.0, 3.0], &[2]);
1714        let c = a.power(&b);
1715        assert!(c.is_err());
1716    }
1717
1718    #[test]
1719    fn test_rem() {
1720        let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1721        let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
1722
1723        let c = &a % &b;
1724
1725        let c_data: &[f32] = c.as_slice();
1726        assert_eq!(c_data, &[1.0, 3.0, 2.0]);
1727
1728        // check a and b are not modified
1729        let a_data: &[f32] = a.as_slice();
1730        assert_eq!(a_data, &[10.0, 11.0, 12.0]);
1731
1732        let b_data: &[f32] = b.as_slice();
1733        assert_eq!(b_data, &[3.0, 4.0, 5.0]);
1734    }
1735
1736    #[test]
1737    fn test_rem_invalid_broadcast() {
1738        let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1739        let b = Array::from_slice(&[3.0, 4.0], &[2]);
1740        let c = a.remainder(&b);
1741        assert!(c.is_err());
1742    }
1743
1744    #[test]
1745    fn test_sqrt() {
1746        let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
1747        let b = a.sqrt().unwrap();
1748
1749        let b_data: &[f32] = b.as_slice();
1750        assert_eq!(b_data, &[1.0, 2.0, 3.0]);
1751
1752        // check a is not modified
1753        let a_data: &[f32] = a.as_slice();
1754        assert_eq!(a_data, &[1.0, 4.0, 9.0]);
1755    }
1756
1757    #[test]
1758    fn test_cos() {
1759        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1760        let b = a.cos().unwrap();
1761
1762        let b_expected = array!([1.0, 0.54030234, -0.41614687]);
1763        assert_array_all_close!(b, b_expected);
1764
1765        // check a is not modified
1766        let a_expected = array!([0.0, 1.0, 2.0]);
1767        assert_array_all_close!(a, a_expected);
1768    }
1769
1770    #[test]
1771    fn test_exp() {
1772        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1773        let b = a.exp().unwrap();
1774
1775        let b_expected = array!([1.0, 2.7182817, 7.389056]);
1776        assert_array_all_close!(b, b_expected);
1777
1778        // check a is not modified
1779        let a_expected = array!([0.0, 1.0, 2.0]);
1780        assert_array_all_close!(a, a_expected);
1781    }
1782
1783    #[test]
1784    fn test_floor() {
1785        let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
1786        let b = a.floor().unwrap();
1787
1788        let b_data: &[f32] = b.as_slice();
1789        assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1790
1791        // check a is not modified
1792        let a_data: &[f32] = a.as_slice();
1793        assert_eq!(a_data, &[0.1, 1.9, 2.5]);
1794    }
1795
1796    #[test]
1797    fn test_floor_complex64() {
1798        let val = complex64::new(1.0, 2.0);
1799        let a = Array::from_complex(val);
1800        let b = a.floor_device(StreamOrDevice::default());
1801        assert!(b.is_err());
1802    }
1803
1804    #[test]
1805    fn test_floor_divide() {
1806        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1807        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1808
1809        let c = a.floor_divide(&b).unwrap();
1810
1811        let c_data: &[f32] = c.as_slice();
1812        assert_eq!(c_data, &[0.0, 0.0, 0.0]);
1813
1814        // check a and b are not modified
1815        let a_data: &[f32] = a.as_slice();
1816        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1817
1818        let b_data: &[f32] = b.as_slice();
1819        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1820    }
1821
1822    #[test]
1823    fn test_floor_divide_complex64() {
1824        let val = complex64::new(1.0, 2.0);
1825        let a = Array::from_complex(val);
1826        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1827        let c = a.floor_divide_device(&b, StreamOrDevice::default());
1828        assert!(c.is_err());
1829    }
1830
1831    #[test]
1832    fn test_floor_divide_invalid_broadcast() {
1833        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1834        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1835        let c = a.floor_divide_device(&b, StreamOrDevice::default());
1836        assert!(c.is_err());
1837    }
1838
1839    #[test]
1840    fn test_is_nan() {
1841        let a = Array::from_slice(&[1.0, f32::NAN, 3.0], &[3]);
1842        let b = a.is_nan().unwrap();
1843
1844        let b_data: &[bool] = b.as_slice();
1845        assert_eq!(b_data, &[false, true, false]);
1846    }
1847
1848    #[test]
1849    fn test_is_inf() {
1850        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1851        let b = a.is_inf().unwrap();
1852
1853        let b_data: &[bool] = b.as_slice();
1854        assert_eq!(b_data, &[false, true, false]);
1855    }
1856
1857    #[test]
1858    fn test_is_finite() {
1859        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1860        let b = a.is_finite().unwrap();
1861
1862        let b_data: &[bool] = b.as_slice();
1863        assert_eq!(b_data, &[true, false, true]);
1864    }
1865
1866    #[test]
1867    fn test_is_neg_inf() {
1868        let a = Array::from_slice(&[1.0, f32::NEG_INFINITY, 3.0], &[3]);
1869        let b = a.is_neg_inf().unwrap();
1870
1871        let b_data: &[bool] = b.as_slice();
1872        assert_eq!(b_data, &[false, true, false]);
1873    }
1874
1875    #[test]
1876    fn test_is_pos_inf() {
1877        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1878        let b = a.is_pos_inf().unwrap();
1879
1880        let b_data: &[bool] = b.as_slice();
1881        assert_eq!(b_data, &[false, true, false]);
1882    }
1883
1884    #[test]
1885    fn test_log() {
1886        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1887        let b = a.log().unwrap();
1888
1889        let b_data: &[f32] = b.as_slice();
1890        assert_eq!(b_data, &[0.0, 0.6931472, 1.0986123]);
1891
1892        // check a is not modified
1893        let a_data: &[f32] = a.as_slice();
1894        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1895    }
1896
1897    #[test]
1898    fn test_log2() {
1899        let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
1900        let b = a.log2().unwrap();
1901
1902        let b_data: &[f32] = b.as_slice();
1903        assert_eq!(b_data, &[0.0, 1.0, 2.0, 3.0]);
1904
1905        // check a is not modified
1906        let a_data: &[f32] = a.as_slice();
1907        assert_eq!(a_data, &[1.0, 2.0, 4.0, 8.0]);
1908    }
1909
1910    #[test]
1911    fn test_log10() {
1912        let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
1913        let b = a.log10().unwrap();
1914
1915        let b_data: &[f32] = b.as_slice();
1916        assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1917
1918        // check a is not modified
1919        let a_data: &[f32] = a.as_slice();
1920        assert_eq!(a_data, &[1.0, 10.0, 100.0]);
1921    }
1922
1923    #[test]
1924    fn test_log1p() {
1925        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1926        let b = a.log1p().unwrap();
1927
1928        let b_data: &[f32] = b.as_slice();
1929        assert_eq!(b_data, &[0.6931472, 1.0986123, 1.3862944]);
1930
1931        // check a is not modified
1932        let a_data: &[f32] = a.as_slice();
1933        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1934    }
1935
1936    #[test]
1937    fn test_matmul() {
1938        let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1939        let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
1940
1941        let c = a.matmul(&b).unwrap();
1942
1943        assert_eq!(c.shape(), &[2, 3]);
1944        let c_data: &[f32] = c.as_slice();
1945        assert_eq!(c_data, &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]);
1946
1947        // check a and b are not modified
1948        let a_data: &[i32] = a.as_slice();
1949        assert_eq!(a_data, &[1, 2, 3, 4]);
1950
1951        let b_data: &[f32] = b.as_slice();
1952        assert_eq!(b_data, &[-5.0, 37.5, 4., 7., 1., 0.]);
1953    }
1954
1955    #[test]
1956    fn test_matmul_ndim_zero() {
1957        let a: Array = 1.0.into();
1958        let b = Array::from_slice::<i32>(&[1], &[1]);
1959        let c = a.matmul(&b);
1960        assert!(c.is_err());
1961    }
1962
1963    #[test]
1964    fn test_matmul_ndim_one() {
1965        let a = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1966        let b = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1967        let c = a.matmul(&b);
1968        assert!(c.is_ok());
1969    }
1970
1971    #[test]
1972    fn test_matmul_dim_mismatch() {
1973        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
1974        let b = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]);
1975        let c = a.matmul(&b);
1976        assert!(c.is_err());
1977    }
1978
1979    #[test]
1980    fn test_matmul_non_float_output_type() {
1981        let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1982        let b = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]);
1983
1984        let c = a.matmul(&b);
1985        assert!(c.is_err());
1986    }
1987
1988    #[test]
1989    fn test_reciprocal() {
1990        let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
1991        let b = a.reciprocal().unwrap();
1992
1993        let b_data: &[f32] = b.as_slice();
1994        assert_eq!(b_data, &[1.0, 0.5, 0.25]);
1995
1996        // check a is not modified
1997        let a_data: &[f32] = a.as_slice();
1998        assert_eq!(a_data, &[1.0, 2.0, 4.0]);
1999    }
2000
2001    #[test]
2002    fn test_round() {
2003        let a = Array::from_slice(&[1.1, 2.9, 3.5], &[3]);
2004        let b = a.round(None).unwrap();
2005
2006        let b_data: &[f32] = b.as_slice();
2007        assert_eq!(b_data, &[1.0, 3.0, 4.0]);
2008
2009        // check a is not modified
2010        let a_data: &[f32] = a.as_slice();
2011        assert_eq!(a_data, &[1.1, 2.9, 3.5]);
2012    }
2013
2014    #[test]
2015    fn test_rsqrt() {
2016        let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
2017        let b = a.rsqrt().unwrap();
2018
2019        let b_data: &[f32] = b.as_slice();
2020        assert_eq!(b_data, &[1.0, 0.70710677, 0.5]);
2021
2022        // check a is not modified
2023        let a_data: &[f32] = a.as_slice();
2024        assert_eq!(a_data, &[1.0, 2.0, 4.0]);
2025    }
2026
2027    #[test]
2028    fn test_sin() {
2029        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
2030        let b = a.sin().unwrap();
2031
2032        let b_data: &[f32] = b.as_slice();
2033        assert_eq!(b_data, &[0.0, 0.841471, 0.9092974]);
2034
2035        // check a is not modified
2036        let a_data: &[f32] = a.as_slice();
2037        assert_eq!(a_data, &[0.0, 1.0, 2.0]);
2038    }
2039
2040    #[test]
2041    fn test_square() {
2042        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
2043        let b = a.square().unwrap();
2044
2045        let b_data: &[f32] = b.as_slice();
2046        assert_eq!(b_data, &[1.0, 4.0, 9.0]);
2047
2048        // check a is not modified
2049        let a_data: &[f32] = a.as_slice();
2050        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
2051    }
2052
    // The unit tests below are adapted from the original MLX C++ test suite.
2054
2055    #[test]
2056    fn test_unary_neg() {
2057        let x = array!(1.0);
2058        assert_eq!(negative(&x).unwrap().item::<f32>(), -1.0);
2059        assert_eq!((-x).item::<f32>(), -1.0);
2060
2061        // works on empty array
2062        assert_eq!(-array!(), array!());
2063
2064        // Throws on bool
2065        let x = array!(true);
2066        assert!(negative(&x).is_err());
2067    }
2068
2069    #[test]
2070    fn test_unary_abs() {
2071        let x = array!([-1.0, 0.0, 1.0]);
2072        assert_eq!(abs(&x).unwrap(), array!([1.0, 0.0, 1.0]));
2073
2074        // works on empty array
2075        assert_eq!(abs(array!()).unwrap(), array!());
2076
2077        // int32
2078        let x = array!([-1, 0, 1]);
2079        assert_eq!(abs(&x).unwrap(), array!([1, 0, 1]));
2080
2081        // uint32
2082        let x = array!([1u32, 0, 1]);
2083        assert_eq!(abs(&x).unwrap(), array!([1u32, 0, 1]));
2084
2085        // bool
2086        let x = array!([false, true]);
2087        assert_eq!(abs(&x).unwrap(), array!([false, true]));
2088    }
2089
2090    #[test]
2091    fn test_unary_sign() {
2092        let x = array!([-1.0, 0.0, 1.0]);
2093        assert_eq!(sign(&x).unwrap(), x);
2094
2095        // works on empty array
2096        assert_eq!(sign(array!()).unwrap(), array!());
2097
2098        // int32
2099        let x = array!([-1, 0, 1]);
2100        assert_eq!(sign(&x).unwrap(), x);
2101
2102        // uint32
2103        let x = array!([1u32, 0, 1]);
2104        assert_eq!(sign(&x).unwrap(), x);
2105
2106        // bool
2107        let x = array!([false, true]);
2108        assert_eq!(sign(&x).unwrap(), x);
2109    }
2110
2111    const NEG_INF: f32 = f32::NEG_INFINITY;
2112
2113    #[test]
2114    fn test_unary_floor_ceil() {
2115        let x = array![1.0];
2116        assert_eq!(floor(&x).unwrap().item::<f32>(), 1.0);
2117        assert_eq!(ceil(&x).unwrap().item::<f32>(), 1.0);
2118
2119        let x = array![1.5];
2120        assert_eq!(floor(&x).unwrap().item::<f32>(), 1.0);
2121        assert_eq!(ceil(&x).unwrap().item::<f32>(), 2.0);
2122
2123        let x = array![-1.5];
2124        assert_eq!(floor(&x).unwrap().item::<f32>(), -2.0);
2125        assert_eq!(ceil(&x).unwrap().item::<f32>(), -1.0);
2126
2127        let x = array![NEG_INF];
2128        assert_eq!(floor(&x).unwrap().item::<f32>(), NEG_INF);
2129        assert_eq!(ceil(&x).unwrap().item::<f32>(), NEG_INF);
2130
2131        let x = array!([1.0, 1.0]).as_type::<complex64>().unwrap();
2132        assert!(floor(&x).is_err());
2133        assert!(ceil(&x).is_err());
2134    }
2135
2136    #[test]
2137    fn test_unary_round() {
2138        let x = array!([0.5, -0.5, 1.5, -1.5, 2.3, 2.6]);
2139        assert_eq!(round(&x, None).unwrap(), array!([0, 0, 2, -2, 2, 3]));
2140
2141        let x = array!([11, 222, 32]);
2142        assert_eq!(round(&x, -1).unwrap(), array!([10, 220, 30]));
2143    }
2144
2145    #[test]
2146    fn test_unary_exp() {
2147        let x = array![0.0];
2148        assert_eq!(exp(&x).unwrap().item::<f32>(), 1.0);
2149
2150        let x = array![2.0];
2151        assert_float_eq! {
2152            exp(&x).unwrap().item::<f32>(),
2153            2.0f32.exp(),
2154            abs <= 1e-5
2155        };
2156
2157        assert_eq!(exp(array!()).unwrap(), array!());
2158
2159        let x = array![NEG_INF];
2160        assert_eq!(exp(&x).unwrap().item::<f32>(), 0.0);
2161
2162        // Integer input type
2163        let x = array![2];
2164        assert_eq!(x.dtype(), Dtype::Int32);
2165        assert_float_eq! {
2166            exp(&x).unwrap().item::<f32>(),
2167            2.0f32.exp(),
2168            abs <= 1e-5
2169        };
2170
2171        // Input is irregularly strided
2172        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2173        let res = exp(&x).unwrap();
2174        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.exp())).unwrap();
2175        assert!(all_close(&res, &expected, None, None, None)
2176            .unwrap()
2177            .item::<bool>());
2178
2179        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2180        let x = split(&data, 2, 1).unwrap();
2181        let expected = Array::from_slice(&[0.0f32.exp(), 2.0f32.exp()], &[2, 1]);
2182        assert!(all_close(exp(&x[0]).unwrap(), &expected, None, None, None)
2183            .unwrap()
2184            .item::<bool>());
2185    }
2186
2187    #[test]
2188    fn test_unary_expm1() {
2189        let x = array![-1.0];
2190        assert_float_eq! {
2191            expm1(&x).unwrap().item::<f32>(),
2192            (-1.0f32).exp_m1(),
2193            abs <= 1e-5
2194        };
2195
2196        let x = array![1.0];
2197        assert_float_eq! {
2198            expm1(&x).unwrap().item::<f32>(),
2199            1.0f32.exp_m1(),
2200            abs <= 1e-5
2201        };
2202
2203        // Integer input type
2204        let x = array![1];
2205        assert_eq!(expm1(&x).unwrap().dtype(), Dtype::Float32);
2206        assert_float_eq! {
2207            expm1(&x).unwrap().item::<f32>(),
2208            1.0f32.exp_m1(),
2209            abs <= 1e-5
2210        };
2211    }
2212
2213    #[test]
2214    fn test_unary_sin() {
2215        let x = array![0.0];
2216        assert_eq!(sin(&x).unwrap().item::<f32>(), 0.0);
2217
2218        let x = array![std::f32::consts::PI / 2.0];
2219        assert_float_eq! {
2220            sin(&x).unwrap().item::<f32>(),
2221            (std::f32::consts::PI / 2.0f32).sin(),
2222            abs <= 1e-5
2223        };
2224
2225        assert_eq!(sin(array!()).unwrap(), array!());
2226
2227        // Integer input type
2228        let x = array![0];
2229        assert_eq!(x.dtype(), Dtype::Int32);
2230        assert_float_eq! {
2231            sin(&x).unwrap().item::<f32>(),
2232            0.0f32.sin(),
2233            abs <= 1e-5
2234        };
2235
2236        // Input is irregularly strided
2237        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2238        let res = sin(&x).unwrap();
2239        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.sin())).unwrap();
2240        assert!(all_close(&res, &expected, None, None, None)
2241            .unwrap()
2242            .item::<bool>());
2243
2244        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2245        let x = split(&data, 2, 1).unwrap();
2246        let expected = Array::from_slice(&[0.0f32.sin(), 2.0f32.sin()], &[2, 1]);
2247        assert!(all_close(sin(&x[0]).unwrap(), &expected, None, None, None)
2248            .unwrap()
2249            .item::<bool>());
2250    }
2251
2252    #[test]
2253    fn test_unary_cos() {
2254        let x = array![0.0];
2255        assert_float_eq! {
2256            cos(&x).unwrap().item::<f32>(),
2257            0.0f32.cos(),
2258            abs <= 1e-5
2259        };
2260
2261        let x = array![std::f32::consts::PI / 2.0];
2262        assert_float_eq! {
2263            cos(&x).unwrap().item::<f32>(),
2264            (std::f32::consts::PI / 2.0f32).cos(),
2265            abs <= 1e-5
2266        };
2267
2268        assert_eq!(cos(array!()).unwrap(), array!());
2269
2270        // Integer input type
2271        let x = array![0];
2272        assert_eq!(x.dtype(), Dtype::Int32);
2273        assert_float_eq! {
2274            cos(&x).unwrap().item::<f32>(),
2275            0.0f32.cos(),
2276            abs <= 1e-5
2277        };
2278
2279        // Input is irregularly strided
2280        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2281        let res = cos(&x).unwrap();
2282        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.cos())).unwrap();
2283        assert!(all_close(&res, &expected, None, None, None)
2284            .unwrap()
2285            .item::<bool>());
2286
2287        let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2288        let x = split(&data, 2, 1).unwrap();
2289        let expected = Array::from_slice(&[0.0f32.cos(), 2.0f32.cos()], &[2, 1]);
2290        assert!(all_close(cos(&x[0]).unwrap(), &expected, None, None, None)
2291            .unwrap()
2292            .item::<bool>());
2293    }
2294
2295    #[test]
2296    fn test_unary_degrees() {
2297        let x = array![0.0];
2298        assert_eq!(degrees(&x).unwrap().item::<f32>(), 0.0);
2299
2300        let x = array![std::f32::consts::PI / 2.0];
2301        assert_eq!(degrees(&x).unwrap().item::<f32>(), 90.0);
2302
2303        assert_eq!(degrees(array!()).unwrap(), array!());
2304
2305        // Integer input type
2306        let x = array![0];
2307        assert_eq!(x.dtype(), Dtype::Int32);
2308        assert_eq!(degrees(&x).unwrap().item::<f32>(), 0.0);
2309
2310        // Input is irregularly strided
2311        let x = broadcast_to(&array!(std::f32::consts::PI / 2.0), &[2, 2, 2]).unwrap();
2312        let res = degrees(&x).unwrap();
2313        let expected = Array::full::<f32>(&[2, 2, 2], array!(90.0)).unwrap();
2314        assert!(all_close(&res, &expected, None, None, None)
2315            .unwrap()
2316            .item::<bool>());
2317
2318        let angles = Array::from_slice(&[0.0, PI / 2.0, PI, 1.5 * PI], &[2, 2]);
2319        let x = split(&angles, 2, 1).unwrap();
2320        let expected = Array::from_slice(&[0.0, 180.0], &[2, 1]);
2321        assert!(
2322            all_close(degrees(&x[0]).unwrap(), &expected, None, None, None)
2323                .unwrap()
2324                .item::<bool>()
2325        );
2326    }
2327
2328    #[test]
2329    fn test_unary_radians() {
2330        let x = array![0.0];
2331        assert_eq!(radians(&x).unwrap().item::<f32>(), 0.0);
2332
2333        let x = array![90.0];
2334        assert_eq!(
2335            radians(&x).unwrap().item::<f32>(),
2336            std::f32::consts::PI / 2.0
2337        );
2338
2339        assert_eq!(radians(array!()).unwrap(), array!());
2340
2341        // Integer input type
2342        let x = array![90];
2343        assert_eq!(x.dtype(), Dtype::Int32);
2344        assert_eq!(
2345            radians(&x).unwrap().item::<f32>(),
2346            std::f32::consts::PI / 2.0
2347        );
2348
2349        // Input is irregularly strided
2350        let x = broadcast_to(&array!(90.0), &[2, 2, 2]).unwrap();
2351        let res = radians(&x).unwrap();
2352        let expected = Array::full::<f32>(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap();
2353        assert!(all_close(&res, &expected, None, None, None)
2354            .unwrap()
2355            .item::<bool>());
2356
2357        let angles = Array::from_slice(&[0.0, 90.0, 180.0, 270.0], &[2, 2]);
2358        let x = split(&angles, 2, 1).unwrap();
2359        let expected = Array::from_slice(&[0.0, PI], &[2, 1]);
2360        assert!(
2361            all_close(radians(&x[0]).unwrap(), &expected, None, None, None)
2362                .unwrap()
2363                .item::<bool>()
2364        );
2365    }
2366
2367    #[test]
2368    fn test_unary_log() {
2369        let x = array![0.0];
2370        assert_eq!(log(&x).unwrap().item::<f32>(), NEG_INF);
2371
2372        let x = array![1.0];
2373        assert_eq!(log(&x).unwrap().item::<f32>(), 0.0);
2374
2375        // Integer input type
2376        let x = array![1];
2377        assert_eq!(log(&x).unwrap().dtype(), Dtype::Float32);
2378        assert_eq!(log(&x).unwrap().item::<f32>(), 0.0);
2379
2380        // Input is irregularly strided
2381        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2382        let res = log(&x).unwrap();
2383        let expected = Array::full::<f32>(&[2, 2, 2], array!(0.0)).unwrap();
2384        assert!(all_close(&res, &expected, None, None, None)
2385            .unwrap()
2386            .item::<bool>());
2387
2388        let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
2389        let x = split(&data, 2, 1).unwrap();
2390        let expected = Array::from_slice(&[1.0f32.ln(), 3.0f32.ln()], &[2, 1]);
2391        assert!(all_close(log(&x[0]).unwrap(), &expected, None, None, None)
2392            .unwrap()
2393            .item::<bool>());
2394    }
2395
2396    #[test]
2397    fn test_unary_log2() {
2398        let x = array![0.0];
2399        assert_eq!(log2(&x).unwrap().item::<f32>(), NEG_INF);
2400
2401        let x = array![1.0];
2402        assert_eq!(log2(&x).unwrap().item::<f32>(), 0.0);
2403
2404        let x = array![1024.0];
2405        assert_eq!(log2(&x).unwrap().item::<f32>(), 10.0);
2406    }
2407
2408    #[test]
2409    fn test_unary_log10() {
2410        let x = array![0.0];
2411        assert_eq!(log10(&x).unwrap().item::<f32>(), NEG_INF);
2412
2413        let x = array![1.0];
2414        assert_eq!(log10(&x).unwrap().item::<f32>(), 0.0);
2415
2416        let x = array![1000.0];
2417        assert_eq!(log10(&x).unwrap().item::<f32>(), 3.0);
2418    }
2419
2420    #[test]
2421    fn test_unary_log1p() {
2422        let x = array![-1.0];
2423        assert_float_eq! {
2424            log1p(&x).unwrap().item::<f32>(),
2425            (-1.0f32).ln_1p(),
2426            abs <= 1e-5
2427        };
2428
2429        let x = array![1.0];
2430        assert_float_eq! {
2431            log1p(&x).unwrap().item::<f32>(),
2432            1.0f32.ln_1p(),
2433            abs <= 1e-5
2434        };
2435
2436        // Integer input type
2437        let x = array![1];
2438        assert_eq!(log1p(&x).unwrap().dtype(), Dtype::Float32);
2439        assert_float_eq! {
2440            log1p(&x).unwrap().item::<f32>(),
2441            1.0f32.ln_1p(),
2442            abs <= 1e-5
2443        };
2444
2445        // Input is irregularly strided
2446        let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2447        let res = log1p(&x).unwrap();
2448        let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap();
2449        assert!(all_close(&res, &expected, None, None, None)
2450            .unwrap()
2451            .item::<bool>());
2452
2453        let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
2454        let x = split(&data, 2, 1).unwrap();
2455        let expected = Array::from_slice(&[1.0f32.ln_1p(), 3.0f32.ln_1p()], &[2, 1]);
2456        assert!(
2457            all_close(log1p(&x[0]).unwrap(), &expected, None, None, None)
2458                .unwrap()
2459                .item::<bool>()
2460        );
2461    }
2462
2463    #[test]
2464    fn test_unary_sigmoid() {
2465        let x = array![0.0];
2466        assert_float_eq! {
2467            sigmoid(&x).unwrap().item::<f32>(),
2468            0.5,
2469            abs <= 1e-5
2470        };
2471
2472        // Integer input type
2473        let x = array![0];
2474        assert_eq!(sigmoid(&x).unwrap().dtype(), Dtype::Float32);
2475        assert_float_eq! {
2476            sigmoid(&x).unwrap().item::<f32>(),
2477            0.5,
2478            abs <= 1e-5
2479        };
2480
2481        let inf = f32::INFINITY;
2482        let x = array![inf];
2483        assert_eq!(sigmoid(&x).unwrap().item::<f32>(), 1.0);
2484
2485        let x = array![-inf];
2486        assert_eq!(sigmoid(&x).unwrap().item::<f32>(), 0.0);
2487    }
2488
2489    #[test]
2490    fn test_unary_square() {
2491        let x = array![3.0];
2492        assert_eq!(square(&x).unwrap().item::<f32>(), 9.0);
2493
2494        let x = array![2];
2495        assert_eq!(square(&x).unwrap().item::<i32>(), 4);
2496
2497        let x = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
2498        assert!(all_close(
2499            square(&x).unwrap(),
2500            Array::full::<f32>(&[3, 3], array!(4.0)).unwrap(),
2501            None,
2502            None,
2503            None
2504        )
2505        .unwrap()
2506        .item::<bool>());
2507    }
2508
2509    #[test]
2510    fn test_unary_sqrt_rsqrt() {
2511        let x = array![4.0];
2512        assert_eq!(sqrt(&x).unwrap().item::<f32>(), 2.0);
2513        assert_eq!(rsqrt(&x).unwrap().item::<f32>(), 0.5);
2514
2515        let x = Array::full::<f32>(&[3, 3], array!(9.0)).unwrap();
2516        assert!(all_close(
2517            sqrt(&x).unwrap(),
2518            Array::full::<f32>(&[3, 3], array!(3.0)).unwrap(),
2519            None,
2520            None,
2521            None
2522        )
2523        .unwrap()
2524        .item::<bool>());
2525
2526        let x = array![4i32];
2527        assert_eq!(sqrt(&x).unwrap().item::<f32>(), 2.0);
2528        assert_eq!(rsqrt(&x).unwrap().item::<f32>(), 0.5);
2529    }
2530
2531    #[test]
2532    fn test_unary_reciprocal() {
2533        let x = array![8.0];
2534        assert_eq!(reciprocal(&x).unwrap().item::<f32>(), 0.125);
2535
2536        let x = array![2];
2537        let out = reciprocal(&x).unwrap();
2538        assert_eq!(out.dtype(), Dtype::Float32);
2539        assert_eq!(out.item::<f32>(), 0.5);
2540
2541        let x = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
2542        assert!(all_close(
2543            reciprocal(&x).unwrap(),
2544            Array::full::<f32>(&[3, 3], array!(0.5)).unwrap(),
2545            None,
2546            None,
2547            None
2548        )
2549        .unwrap()
2550        .item::<bool>());
2551    }
2552
2553    #[test]
2554    fn test_unary_real_imag() {
2555        let x = Array::from_complex(complex64::new(0.0, 1.0));
2556        assert_eq!(real(&x).unwrap(), Array::from_f32(0.0));
2557        assert_eq!(imag(&x).unwrap(), Array::from_f32(1.0));
2558    }
2559
2560    #[test]
2561    fn test_binary_add() {
2562        let x = array![1.0];
2563        let y = array![1.0];
2564        let z = add(&x, &y).unwrap();
2565        assert_eq!(z.item::<f32>(), 2.0);
2566
2567        let z = &x + y;
2568        assert_eq!(z.item::<f32>(), 2.0);
2569
2570        let z = add(z, &x).unwrap();
2571        assert_eq!(z.item::<f32>(), 3.0);
2572
2573        // Chain a few adds:
2574        let mut out = x.deep_clone();
2575        for _ in 0..10 {
2576            out = add(&out, &x).unwrap();
2577        }
2578        assert_eq!(out.item::<f32>(), 11.0);
2579
2580        // Works for different shapes
2581        let x = array!([1.0, 2.0, 3.0]);
2582        let y = array!([1.0, 2.0, 3.0]);
2583        let z = add(&x, &y).unwrap();
2584        assert_eq!(z.shape(), &[3]);
2585        assert_eq!(z, array!([2.0, 4.0, 6.0]));
2586
2587        // Works with scalars
2588        let x = array!([1.0, 2.0, 3.0]);
2589        let y = &x + 2.0;
2590        assert_eq!(y.dtype(), Dtype::Float32);
2591        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2592        let y = &x + 2.0;
2593        assert_eq!(y.dtype(), Dtype::Float32);
2594        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2595
2596        // Check type promotion
2597        let y = x + 2;
2598        assert_eq!(y.dtype(), Dtype::Float32);
2599
2600        let y = array!([1, 2, 3]) + 2.0;
2601        assert_eq!(y.dtype(), Dtype::Float32);
2602        // assert!(array_equal(&y, &array![3.0, 4.0, 5.0]).item::<bool>());
2603        assert_eq!(y, array!([3.0, 4.0, 5.0]));
2604
2605        // Broadcasting works
2606        let x = broadcast_to(&array!(1.0), &[10]).unwrap();
2607        let y = broadcast_to(&array!(2.0), &[10]).unwrap();
2608        let z = add(&x, &y).unwrap();
2609        assert_eq!(z, full::<f32>(&[10], array!(3.0)).unwrap());
2610
2611        let x = Array::from_slice(&[1.0, 2.0], &[1, 2]);
2612        let y = Array::from_slice(&[1.0, 2.0], &[2, 1]);
2613        let z = add(&x, &y).unwrap();
2614        assert_eq!(z.shape(), &[2, 2]);
2615        assert_eq!(z, Array::from_slice(&[2.0, 3.0, 3.0, 4.0], &[2, 2]));
2616
2617        let x = ones::<f32>(&[3, 2, 1]).unwrap();
2618        let z = x + 2.0;
2619        assert_eq!(z.shape(), &[3, 2, 1]);
2620        let expected = Array::from_slice(&[3.0, 3.0, 3.0, 3.0, 3.0, 3.0], &[3, 2, 1]);
2621        assert_eq!(z, expected);
2622
2623        // Works for empty arrays
2624        let x = array!();
2625        let y = array!();
2626        let z = x + y;
2627        z.eval().unwrap();
2628        assert_eq!(z.size(), 0);
2629        assert_eq!(z.shape(), &[0]);
2630    }
2631
2632    #[test]
2633    fn test_binary_sub() {
2634        let x = array!([3.0, 2.0, 1.0]);
2635        let y = array!([1.0, 1.0, 1.0]);
2636        assert_eq!(x - y, array!([2.0, 1.0, 0.0]));
2637    }
2638
2639    #[test]
2640    fn test_binary_mul() {
2641        let x = array!([1.0, 2.0, 3.0]);
2642        let y = array!([2.0, 2.0, 2.0]);
2643        assert_eq!(x * y, array!([2.0, 4.0, 6.0]));
2644    }
2645
2646    #[test]
2647    fn test_binary_div() {
2648        let x = array![1.0];
2649        let y = array![1.0];
2650        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 1.0);
2651
2652        let x = array![1.0];
2653        let y = array![0.5];
2654        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 2.0);
2655
2656        let x = array![1.0];
2657        let y = array![4.0];
2658        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 0.25);
2659
2660        let x = array![true];
2661        let y = array![true];
2662        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 1.0);
2663
2664        let x = array![false];
2665        let y = array![true];
2666        assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 0.0);
2667
2668        let x = array![true];
2669        let y = array![false];
2670        assert!(divide(&x, &y).unwrap().item::<f32>().is_infinite());
2671
2672        let x = array![false];
2673        let y = array![false];
2674        assert!(divide(&x, &y).unwrap().item::<f32>().is_nan());
2675    }
2676
2677    #[test]
2678    fn test_binary_maximum_minimum() {
2679        let x = array![1.0];
2680        let y = array![0.0];
2681        assert_eq!(maximum(&x, &y).unwrap().item::<f32>(), 1.0);
2682        assert_eq!(minimum(&x, &y).unwrap().item::<f32>(), 0.0);
2683
2684        let y = array![2.0];
2685        assert_eq!(maximum(&x, &y).unwrap().item::<f32>(), 2.0);
2686        assert_eq!(minimum(&x, &y).unwrap().item::<f32>(), 1.0);
2687    }
2688
2689    #[test]
2690    fn test_binary_logaddexp() {
2691        let x = array![0.0];
2692        let y = array![0.0];
2693        assert_float_eq! {
2694            logaddexp(&x, &y).unwrap().item::<f32>(),
2695            2.0f32.ln(),
2696            abs <= 1e-5
2697        };
2698
2699        let x = array!([0u32]);
2700        let y = array!([10000u32]);
2701        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), 10000.0);
2702
2703        let x = array![f32::INFINITY];
2704        let y = array![3.0];
2705        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2706
2707        let x = array![f32::NEG_INFINITY];
2708        let y = array![3.0];
2709        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), 3.0);
2710
2711        let x = array![f32::NEG_INFINITY];
2712        let y = array![f32::NEG_INFINITY];
2713        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::NEG_INFINITY);
2714
2715        let x = array![f32::INFINITY];
2716        let y = array![f32::INFINITY];
2717        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2718
2719        let x = array![f32::NEG_INFINITY];
2720        let y = array![f32::INFINITY];
2721        assert_eq!(logaddexp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2722    }
2723
2724    #[test]
2725    fn test_basic_clip() {
2726        let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
2727        let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);
2728        let clipped = clip(&a, (array!(2.0), array!(6.0))).unwrap();
2729        assert_eq!(clipped, expected);
2730
2731        // Test with scalar
2732        let clipped = clip(&a, (2.0, 6.0)).unwrap();
2733        assert_eq!(clipped, expected);
2734    }
2735
2736    #[test]
2737    fn test_clip_with_only_min() {
2738        let a = array!([-1.0, 1.0, 0.0, 5.0]);
2739        let expected = array!([0.0, 1.0, 0.0, 5.0]);
2740        let clipped = clip(&a, (array!(0.0), ())).unwrap();
2741        assert_eq!(clipped, expected);
2742
2743        // Test with scalar
2744        let clipped = clip(&a, (0.0, ())).unwrap();
2745        assert_eq!(clipped, expected);
2746    }
2747
2748    #[test]
2749    fn test_clip_with_only_max() {
2750        let a = array!([2.0, 3.0, 4.0, 5.0]);
2751        let expected = array!([2.0, 3.0, 4.0, 4.0]);
2752        let clipped = clip(&a, ((), array!(4.0))).unwrap();
2753        assert_eq!(clipped, expected);
2754
2755        // Test with scalar
2756        let clipped = clip(&a, ((), 4.0)).unwrap();
2757        assert_eq!(clipped, expected);
2758    }
2759
2760    #[test]
2761    fn test_tensordot() {
2762        let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
2763        let y = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[4, 3, 2]).unwrap();
2764        let z = tensordot_axes(&x, &y, &[1i32, 0], &[0i32, 1]).unwrap();
2765        let expected = Array::from_slice(
2766            &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306],
2767            &[5, 2],
2768        );
2769        assert_eq!(z, expected);
2770
2771        let x = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[3, 4, 5, 6]).unwrap();
2772        let y = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[6, 4, 5, 3]).unwrap();
2773        assert!(tensordot_axes(&x, &y, &[2, 1, 3], &[1, 2, 0]).is_err());
2774
2775        let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
2776        let y = reshape(arange::<_, f32>(None, 120.0, None).unwrap(), &[4, 5, 6]).unwrap();
2777
2778        let z = tensordot_axis(&x, &y, 2).unwrap();
2779        let expected = Array::from_slice(
2780            &[
2781                14820.0, 15010.0, 15200.0, 15390.0, 15580.0, 15770.0, 37620.0, 38210.0, 38800.0,
2782                39390.0, 39980.0, 40570.0, 60420.0, 61410.0, 62400.0, 63390.0, 64380.0, 65370.0,
2783            ],
2784            &[3, 6],
2785        );
2786        assert_eq!(z, expected);
2787    }
2788
2789    #[test]
2790    fn test_outer() {
2791        let x = arange::<_, f32>(1.0, 5.0, None).unwrap();
2792        let y = arange::<_, f32>(1.0, 4.0, None).unwrap();
2793        let z = outer(&x, &y).unwrap();
2794        let expected = Array::from_slice(
2795            &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0],
2796            &[4, 3],
2797        );
2798        assert_eq!(z, expected);
2799
2800        let x = ones::<f32>(&[5]).unwrap();
2801        let y = linspace::<_, f32>(-2.0, 2.0, 5).unwrap();
2802        let z = outer(&x, &y).unwrap();
2803        let expected = Array::from_slice(
2804            &[
2805                -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
2806                -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
2807            ],
2808            &[5, 5],
2809        );
2810        assert_eq!(z, expected);
2811    }
2812
2813    #[test]
2814    fn test_inner() {
2815        let x = reshape(arange::<_, f32>(None, 5.0, None).unwrap(), &[1, 5]).unwrap();
2816        let y = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[2, 3]).unwrap();
2817        assert!(inner(&x, &y).is_err());
2818
2819        let x = array!([1.0, 2.0, 3.0]);
2820        let y = array!([0.0, 1.0, 0.0]);
2821        let z = inner(&x, &y).unwrap();
2822        assert_eq!(z.item::<f32>(), 2.0);
2823
2824        let x = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[2, 3, 4]).unwrap();
2825        let y = arange::<_, f32>(None, 4.0, None).unwrap();
2826        let z = inner(&x, &y).unwrap();
2827        let expected = Array::from_slice(&[14.0, 38.0, 62.0, 86.0, 110.0, 134.0], &[2, 3]);
2828        assert_eq!(z, expected);
2829
2830        let x = reshape(arange::<_, f32>(None, 2.0, None).unwrap(), &[1, 1, 2]).unwrap();
2831        let y = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[3, 2]).unwrap();
2832        let z = inner(&x, &y).unwrap();
2833        let expected = Array::from_slice(&[1.0, 3.0, 5.0], &[1, 1, 3]);
2834        assert_eq!(z, expected);
2835
2836        let x = eye::<f32>(2, None, None).unwrap();
2837        let y = Array::from_f32(7.0);
2838        let z = inner(&x, &y).unwrap();
2839        let expected = Array::from_slice(&[7.0, 0.0, 0.0, 7.0], &[2, 2]);
2840        assert_eq!(z, expected);
2841    }
2842
2843    #[test]
2844    fn test_divmod() {
2845        let x = array!([1.0, 2.0, 3.0]);
2846        let y = array!([1.0, 1.0, 1.0]);
2847        let out = divmod(&x, &y).unwrap();
2848        assert_eq!(out.0, array!([1.0, 2.0, 3.0]));
2849        assert_eq!(out.1, array!([0.0, 0.0, 0.0]));
2850
2851        let x = array!([5.0, 6.0, 7.0]);
2852        let y = array!([2.0, 2.0, 2.0]);
2853        let out = divmod(&x, &y).unwrap();
2854        assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
2855        assert_eq!(out.1, array!([1.0, 0.0, 1.0]));
2856
2857        let x = array!([5.0, 6.0, 7.0]);
2858        let y = array!([2.0, 2.0, 2.0]);
2859        let out = divmod(&x, &y).unwrap();
2860        assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
2861        assert_eq!(out.1, array!([1.0, 0.0, 1.0]));
2862
2863        let x = array![complex64::new(1.0, 0.0)];
2864        let y = array![complex64::new(2.0, 0.0)];
2865        assert!(divmod(&x, &y).is_err());
2866
2867        // Check that we can eval on both outputs
2868        let x = array![1.0];
2869        let y = array![2.0];
2870        let (quo, rem) = divmod(&x, &y).unwrap();
2871        eval([&quo, &rem]).unwrap();
2872        assert_eq!(quo.item::<f32>(), 0.0);
2873        assert_eq!(rem.item::<f32>(), 1.0);
2874
2875        // Check nested in the graph
2876        let x = array![1.0];
2877        let y = array![2.0];
2878        let (quo, rem) = divmod(&x, &y).unwrap();
2879        let z = quo + rem;
2880        assert_eq!(z.item::<f32>(), 1.0);
2881
2882        // Check that we can still eval when one output goes out of scope
2883        let mut out_holder = {
2884            let (quo, _) = divmod(&x, &y).unwrap();
2885            vec![quo]
2886        };
2887        eval(out_holder.iter()).unwrap();
2888        assert_eq!(out_holder[0].item::<f32>(), 0.0);
2889
2890        // Check that we can still eval when the other output goes out of scope
2891        out_holder.clear();
2892        let out_holder = {
2893            let (_, rem) = divmod(&x, &y).unwrap();
2894            vec![rem]
2895        };
2896        eval(out_holder.iter()).unwrap();
2897        assert_eq!(out_holder[0].item::<f32>(), 1.0);
2898    }
2899}