mlx_rs/ops/
arithmetic.rs

1use crate::array::Array;
2use crate::error::Result;
3use crate::sealed::Sealed;
4use crate::stream::StreamOrDevice;
5
6use crate::utils::guard::Guarded;
7use crate::utils::{IntoOption, ScalarOrArray, VectorArray};
8use crate::Stream;
9use mlx_internal_macros::{default_device, generate_macro};
10use smallvec::SmallVec;
11
12impl Array {
13    /// Element-wise absolute value.
14    ///
15    /// # Example
16    ///
17    /// ```rust
18    /// use mlx_rs::Array;
19    /// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]);
20    /// let mut result = array.abs().unwrap();
21    ///
22    /// let data: &[i32] = result.as_slice();
23    /// // data == [1, 2, 3, 4, 5]
24    /// ```
25    #[default_device]
26    pub fn abs_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
27        Array::try_from_op(|res| unsafe {
28            mlx_sys::mlx_abs(res, self.as_ptr(), stream.as_ref().as_ptr())
29        })
30    }
31
32    /// Element-wise addition returning an error if arrays are not broadcastable.
33    ///
34    /// Add two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
35    ///
36    /// # Params
37    ///
38    /// - other: array to add
39    ///
40    /// # Example
41    ///
42    /// ```rust
43    /// use mlx_rs::Array;
44    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
45    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
46    /// let mut c = a.add(&b).unwrap();
47    ///
48    /// let c_data: &[f32] = c.as_slice();
49    /// // c_data == [5.0, 7.0, 9.0]
50    /// ```
51    #[default_device]
52    pub fn add_device(
53        &self,
54        other: impl AsRef<Array>,
55        stream: impl AsRef<Stream>,
56    ) -> Result<Array> {
57        Array::try_from_op(|res| unsafe {
58            mlx_sys::mlx_add(
59                res,
60                self.as_ptr(),
61                other.as_ref().as_ptr(),
62                stream.as_ref().as_ptr(),
63            )
64        })
65    }
66
67    /// Element-wise subtraction returning an error if arrays are not broadcastable.
68    ///
69    /// Subtract two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
70    ///
71    /// # Params
72    ///
73    /// - other: array to subtract
74    ///
75    /// # Example
76    ///
77    /// ```rust
78    /// use mlx_rs::Array;
79    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
80    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
81    /// let mut c = a.subtract(&b).unwrap();
82    ///
83    /// let c_data: &[f32] = c.as_slice();
84    /// // c_data == [-3.0, -3.0, -3.0]
85    /// ```
86    #[default_device]
87    pub fn subtract_device(
88        &self,
89        other: impl AsRef<Array>,
90        stream: impl AsRef<Stream>,
91    ) -> Result<Array> {
92        Array::try_from_op(|res| unsafe {
93            mlx_sys::mlx_subtract(
94                res,
95                self.as_ptr(),
96                other.as_ref().as_ptr(),
97                stream.as_ref().as_ptr(),
98            )
99        })
100    }
101
102    /// Unary element-wise negation. Returns an error if the array is of type bool.
103    ///
104    /// Negate the values in the array.
105    ///
106    /// # Example
107    ///
108    /// ```rust
109    /// use mlx_rs::Array;
110    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
111    /// let mut b = a.negative().unwrap();
112    ///
113    /// let b_data: &[f32] = b.as_slice();
114    /// // b_data == [-1.0, -2.0, -3.0]
115    /// ```
116    #[default_device]
117    pub fn negative_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
118        Array::try_from_op(|res| unsafe {
119            mlx_sys::mlx_negative(res, self.as_ptr(), stream.as_ref().as_ptr())
120        })
121    }
122
123    /// Element-wise multiplication returning an error if arrays are not broadcastable.
124    ///
125    /// Multiply two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
126    ///
127    /// # Example
128    ///
129    /// ```rust
130    /// use mlx_rs::Array;
131    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
132    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
133    /// let mut c = a.multiply(&b).unwrap();
134    ///
135    /// let c_data: &[f32] = c.as_slice();
136    /// // c_data == [4.0, 10.0, 18.0]
137    /// ```
138    #[default_device]
139    pub fn multiply_device(
140        &self,
141        other: impl AsRef<Array>,
142        stream: impl AsRef<Stream>,
143    ) -> Result<Array> {
144        Array::try_from_op(|res| unsafe {
145            mlx_sys::mlx_multiply(
146                res,
147                self.as_ptr(),
148                other.as_ref().as_ptr(),
149                stream.as_ref().as_ptr(),
150            )
151        })
152    }
153
154    /// Replace NaN and Inf values with finite numbers.
155    ///
156    /// # Params
157    /// - nan: value to replace NaN with
158    /// - posInf: value to replace positive inifinites with.  If not specified will use
159    ///     the largest finite value for the given dtype.
160    /// - negInf: value to replace negative inifinites with.  If not specified will use
161    ///     the negative of the largest finite value for the given dtype.
162    /// - stream: stream or device to evaluate on
163    #[default_device]
164    pub fn nan_to_num_device(
165        &self,
166        nan: impl IntoOption<f32>,
167        pos_inf: impl IntoOption<f32>,
168        neg_inf: impl IntoOption<f32>,
169        stream: impl AsRef<Stream>,
170    ) -> Result<Array> {
171        let pos_inf = pos_inf.into_option();
172        let neg_inf = neg_inf.into_option();
173
174        let pos_inf = mlx_sys::mlx_optional_float {
175            value: pos_inf.unwrap_or(0.0),
176            has_value: pos_inf.is_some(),
177        };
178        let neg_inf = mlx_sys::mlx_optional_float {
179            value: neg_inf.unwrap_or(0.0),
180            has_value: neg_inf.is_some(),
181        };
182
183        Array::try_from_op(|res| unsafe {
184            mlx_sys::mlx_nan_to_num(
185                res,
186                self.as_ptr(),
187                nan.into_option().unwrap_or(0.),
188                pos_inf,
189                neg_inf,
190                stream.as_ref().as_ptr(),
191            )
192        })
193    }
194
195    /// Element-wise division returning an error if arrays are not broadcastable.
196    ///
197    /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
198    ///
199    /// # Params
200    ///
201    /// - other: array to divide
202    ///
203    /// # Example
204    ///
205    /// ```rust
206    /// use mlx_rs::Array;
207    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
208    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
209    /// let mut c = a.divide(&b).unwrap();
210    ///
211    /// let c_data: &[f32] = c.as_slice();
212    /// // c_data == [0.25, 0.4, 0.5]
213    /// ```
214    #[default_device]
215    pub fn divide_device(
216        &self,
217        other: impl AsRef<Array>,
218        stream: impl AsRef<Stream>,
219    ) -> Result<Array> {
220        Array::try_from_op(|res| unsafe {
221            mlx_sys::mlx_divide(
222                res,
223                self.as_ptr(),
224                other.as_ref().as_ptr(),
225                stream.as_ref().as_ptr(),
226            )
227        })
228    }
229
230    /// Element-wise power operation returning an error if arrays are not broadcastable if they have different shapes.
231    ///
232    /// Raise the elements of the array to the power of the elements of another array.
233    ///
234    /// # Params
235    ///
236    /// - other: array to raise to the power of
237    ///
238    /// # Example
239    ///
240    /// ```rust
241    /// use mlx_rs::Array;
242    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
243    /// let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
244    /// let mut c = a.power(&b).unwrap();
245    ///
246    /// let c_data: &[f32] = c.as_slice();
247    /// // c_data == [1.0, 8.0, 81.0]
248    /// ```
249    #[default_device]
250    pub fn power_device(
251        &self,
252        other: impl AsRef<Array>,
253        stream: impl AsRef<Stream>,
254    ) -> Result<Array> {
255        Array::try_from_op(|res| unsafe {
256            mlx_sys::mlx_power(
257                res,
258                self.as_ptr(),
259                other.as_ref().as_ptr(),
260                stream.as_ref().as_ptr(),
261            )
262        })
263    }
264
265    /// Element-wise remainder of division returning an error if arrays are not broadcastable.
266    ///
267    /// Computes the remainder of dividing `lhs` with `rhs` with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
268    ///
269    /// # Params
270    ///
271    /// - other: array to divide
272    ///
273    /// # Example
274    ///
275    /// ```rust
276    /// use mlx_rs::Array;
277    /// let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
278    /// let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
279    /// let mut c = a.remainder(&b).unwrap();
280    ///
281    /// let c_data: &[f32] = c.as_slice();
282    /// // c_data == [1.0, 3.0, 2.0]
283    /// ```
284    #[default_device]
285    pub fn remainder_device(
286        &self,
287        other: impl AsRef<Array>,
288        stream: impl AsRef<Stream>,
289    ) -> Result<Array> {
290        Array::try_from_op(|res| unsafe {
291            mlx_sys::mlx_remainder(
292                res,
293                self.as_ptr(),
294                other.as_ref().as_ptr(),
295                stream.as_ref().as_ptr(),
296            )
297        })
298    }
299
300    /// Element-wise square root
301    ///
302    /// # Example
303    ///
304    /// ```rust
305    /// use mlx_rs::Array;
306    /// let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
307    /// let mut b = a.sqrt().unwrap();
308    ///
309    /// let b_data: &[f32] = b.as_slice();
310    /// // b_data == [1.0, 2.0, 3.0]
311    /// ```
312    #[default_device]
313    pub fn sqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
314        Array::try_from_op(|res| unsafe {
315            mlx_sys::mlx_sqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
316        })
317    }
318
319    /// Element-wise cosine
320    ///
321    /// # Example
322    ///
323    /// ```rust
324    /// use mlx_rs::Array;
325    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
326    /// let mut b = a.cos().unwrap();
327    ///
328    /// let b_data: &[f32] = b.as_slice();
329    /// // b_data == [1.0, 0.54030234, -0.41614687]
330    /// ```
331    #[default_device]
332    pub fn cos_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
333        Array::try_from_op(|res| unsafe {
334            mlx_sys::mlx_cos(res, self.as_ptr(), stream.as_ref().as_ptr())
335        })
336    }
337
338    /// Element-wise exponential.
339    ///
340    /// # Example
341    ///
342    /// ```rust
343    /// use mlx_rs::Array;
344    ///
345    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
346    /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
347    /// let mut b = a.exp().unwrap();
348    ///
349    /// let b_data: &[f32] = b.as_slice();
350    /// // b_data == [1.0, 2.7182817, 7.389056]
351    /// ```
352    #[default_device]
353    pub fn exp_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
354        Array::try_from_op(|res| unsafe {
355            mlx_sys::mlx_exp(res, self.as_ptr(), stream.as_ref().as_ptr())
356        })
357    }
358
359    /// Element-wise floor returning an error if the array is of type complex64.
360    ///
361    /// # Example
362    ///
363    /// ```rust
364    /// use mlx_rs::Array;
365    /// let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
366    /// let mut b = a.floor().unwrap();
367    ///
368    /// let b_data: &[f32] = b.as_slice();
369    /// // b_data == [0.0, 1.0, 2.0]
370    /// ```
371    #[default_device]
372    pub fn floor_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
373        Array::try_from_op(|res| unsafe {
374            mlx_sys::mlx_floor(res, self.as_ptr(), stream.as_ref().as_ptr())
375        })
376    }
377
378    /// Element-wise integer division returning an error if arrays are not broadcastable.
379    ///
380    /// Divide two arrays with
381    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
382    ///
383    /// If either array is a floating point type then it is equivalent to calling [`Array::floor()`]
384    /// after `/`.
385    ///
386    /// # Params
387    ///
388    /// - other: array to divide
389    ///
390    /// # Example
391    ///
392    /// ```rust
393    /// use mlx_rs::Array;
394    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
395    /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
396    /// let mut c = a.floor_divide(&b).unwrap();
397    ///
398    /// let c_data: &[f32] = c.as_slice();
399    /// // c_data == [0.25, 0.4, 0.5]
400    /// ```
401    #[default_device]
402    pub fn floor_divide_device(
403        &self,
404        other: impl AsRef<Array>,
405        stream: impl AsRef<Stream>,
406    ) -> Result<Array> {
407        Array::try_from_op(|res| unsafe {
408            mlx_sys::mlx_floor_divide(
409                res,
410                self.as_ptr(),
411                other.as_ref().as_ptr(),
412                stream.as_ref().as_ptr(),
413            )
414        })
415    }
416
417    /// Return a boolean array indicating which elements are NaN.
418    ///
419    /// # Params
420    /// - stream: stream or device to evaluate on
421    #[default_device]
422    pub fn is_nan_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
423        Array::try_from_op(|res| unsafe {
424            mlx_sys::mlx_isnan(res, self.as_ptr(), stream.as_ref().as_ptr())
425        })
426    }
427
428    /// Return a boolean array indicating which elements are infinity.
429    ///
430    /// # Params
431    /// - stream: stream or device to evaluate on
432    #[default_device]
433    pub fn is_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
434        Array::try_from_op(|res| unsafe {
435            mlx_sys::mlx_isinf(res, self.as_ptr(), stream.as_ref().as_ptr())
436        })
437    }
438
439    /// Return a boolean array indicating which elements are finite.
440    ///
441    /// # Params
442    /// - stream: stream or device to evaluate on
443    #[default_device]
444    pub fn is_finite_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
445        Array::try_from_op(|res| unsafe {
446            mlx_sys::mlx_isfinite(res, self.as_ptr(), stream.as_ref().as_ptr())
447        })
448    }
449
450    /// Return a boolean array indicating which elements are negative infinity.
451    ///
452    /// # Params
453    /// - stream: stream or device to evaluate on
454    #[default_device]
455    pub fn is_neg_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
456        Array::try_from_op(|res| unsafe {
457            mlx_sys::mlx_isneginf(res, self.as_ptr(), stream.as_ref().as_ptr())
458        })
459    }
460
461    /// Return a boolean array indicating which elements are positive infinity.
462    ///
463    /// # Params
464    /// - stream: stream or device to evaluate on
465    #[default_device]
466    pub fn is_pos_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
467        Array::try_from_op(|res| unsafe {
468            mlx_sys::mlx_isposinf(res, self.as_ptr(), stream.as_ref().as_ptr())
469        })
470    }
471
472    /// Element-wise natural logarithm.
473    ///
474    /// # Example
475    ///
476    /// ```rust
477    /// use mlx_rs::Array;
478    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
479    /// let mut b = a.log().unwrap();
480    ///
481    /// let b_data: &[f32] = b.as_slice();
482    /// // b_data == [0.0, 0.6931472, 1.0986123]
483    /// ```
484    #[default_device]
485    pub fn log_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
486        Array::try_from_op(|res| unsafe {
487            mlx_sys::mlx_log(res, self.as_ptr(), stream.as_ref().as_ptr())
488        })
489    }
490
491    /// Element-wise base-2 logarithm.
492    ///
493    /// # Example
494    ///
495    /// ```rust
496    /// use mlx_rs::Array;
497    /// let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
498    /// let mut b = a.log2().unwrap();
499    ///
500    /// let b_data: &[f32] = b.as_slice();
501    /// // b_data == [0.0, 1.0, 2.0, 3.0]
502    /// ```
503    #[default_device]
504    pub fn log2_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
505        Array::try_from_op(|res| unsafe {
506            mlx_sys::mlx_log2(res, self.as_ptr(), stream.as_ref().as_ptr())
507        })
508    }
509
510    /// Element-wise base-10 logarithm.
511    ///
512    /// # Example
513    ///
514    /// ```rust
515    /// use mlx_rs::Array;
516    /// let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
517    /// let mut b = a.log10().unwrap();
518    ///
519    /// let b_data: &[f32] = b.as_slice();
520    /// // b_data == [0.0, 1.0, 2.0]
521    /// ```
522    #[default_device]
523    pub fn log10_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
524        Array::try_from_op(|res| unsafe {
525            mlx_sys::mlx_log10(res, self.as_ptr(), stream.as_ref().as_ptr())
526        })
527    }
528
529    /// Element-wise natural log of one plus the array.
530    ///
531    /// # Example
532    ///
533    /// ```rust
534    /// use mlx_rs::Array;
535    /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
536    /// let mut b = a.log1p().unwrap();
537    ///
538    /// let b_data: &[f32] = b.as_slice();
539    /// // b_data == [0.6931472, 1.0986123, 1.3862944]
540    /// ```
541    #[default_device]
542    pub fn log1p_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
543        Array::try_from_op(|res| unsafe {
544            mlx_sys::mlx_log1p(res, self.as_ptr(), stream.as_ref().as_ptr())
545        })
546    }
547
548    /// Matrix multiplication returning an error if inputs are not valid.
549    ///
550    /// Perform the (possibly batched) matrix multiplication of two arrays. This function supports
551    /// broadcasting for arrays with more than two dimensions.
552    ///
553    /// - If the first array is 1-D then a 1 is prepended to its shape to make it
554    ///   a matrix. Similarly, if the second array is 1-D then a 1 is appended to its
555    ///   shape to make it a matrix. In either case the singleton dimension is removed
556    ///   from the result.
557    /// - A batched matrix multiplication is performed if the arrays have more than
558    ///   2 dimensions.  The matrix dimensions for the matrix product are the last
559    ///   two dimensions of each input.
560    /// - All but the last two dimensions of each input are broadcast with one another using
561    ///   standard [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
562    ///
563    /// # Params
564    ///
565    /// - other: array to multiply
566    ///
567    /// # Example
568    ///
569    /// ```rust
570    /// use mlx_rs::Array;
571    /// let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
572    /// let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
573    ///
574    /// // produces a [2, 3] result
575    /// let mut c = a.matmul(&b);
576    /// ```
577    #[default_device]
578    pub fn matmul_device(
579        &self,
580        other: impl AsRef<Array>,
581        stream: impl AsRef<Stream>,
582    ) -> Result<Array> {
583        Array::try_from_op(|res| unsafe {
584            mlx_sys::mlx_matmul(
585                res,
586                self.as_ptr(),
587                other.as_ref().as_ptr(),
588                stream.as_ref().as_ptr(),
589            )
590        })
591    }
592
593    /// Element-wise reciprocal.
594    ///
595    /// # Example
596    ///
597    /// ```rust
598    /// use mlx_rs::Array;
599    /// let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
600    /// let mut b = a.reciprocal().unwrap();
601    ///
602    /// let b_data: &[f32] = b.as_slice();
603    /// // b_data == [1.0, 0.5, 0.25]
604    /// ```
605    #[default_device]
606    pub fn reciprocal_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
607        Array::try_from_op(|res| unsafe {
608            mlx_sys::mlx_reciprocal(res, self.as_ptr(), stream.as_ref().as_ptr())
609        })
610    }
611
612    /// Round to the given number of decimals.
613    ///
614    /// # Params
615    ///
616    /// - decimals: number of decimals to round to - default is 0 if not provided
617    #[default_device]
618    pub fn round_device(
619        &self,
620        decimals: impl Into<Option<i32>>,
621        stream: impl AsRef<Stream>,
622    ) -> Result<Array> {
623        Array::try_from_op(|res| unsafe {
624            mlx_sys::mlx_round(
625                res,
626                self.as_ptr(),
627                decimals.into().unwrap_or(0),
628                stream.as_ref().as_ptr(),
629            )
630        })
631    }
632
633    /// Element-wise reciprocal and square root.
634    #[default_device]
635    pub fn rsqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
636        Array::try_from_op(|res| unsafe {
637            mlx_sys::mlx_rsqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
638        })
639    }
640
641    /// Element-wise sine.
642    #[default_device]
643    pub fn sin_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
644        Array::try_from_op(|res| unsafe {
645            mlx_sys::mlx_sin(res, self.as_ptr(), stream.as_ref().as_ptr())
646        })
647    }
648
649    /// Element-wise square.
650    #[default_device]
651    pub fn square_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
652        Array::try_from_op(|res| unsafe {
653            mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
654        })
655    }
656}
657
658/// Element-wise absolute value.
659///
660/// # Example
661///
662/// ```rust
663/// use mlx_rs::{Array, ops};
664///
665/// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]);
666/// let result = ops::abs(&array).unwrap();
667/// ```
668#[generate_macro]
669#[default_device]
670pub fn abs_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
671    a.as_ref().abs_device(stream)
672}
673
674/// Element-wise inverse cosine.
675#[generate_macro]
676#[default_device]
677pub fn acos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
678    Array::try_from_op(|res| unsafe {
679        mlx_sys::mlx_arccos(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
680    })
681}
682
683/// Element-wise inverse hyperbolic cosine.
684#[generate_macro]
685#[default_device]
686pub fn acosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
687    Array::try_from_op(|res| unsafe {
688        mlx_sys::mlx_arccosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
689    })
690}
691
692/// See [`Array::add`].
693#[generate_macro]
694#[default_device]
695pub fn add_device(
696    lhs: impl AsRef<Array>,
697    rhs: impl AsRef<Array>,
698    #[optional] stream: impl AsRef<Stream>,
699) -> Result<Array> {
700    lhs.as_ref().add_device(rhs, stream)
701}
702
703/// Element-wise inverse sine.
704#[generate_macro]
705#[default_device]
706pub fn asin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
707    Array::try_from_op(|res| unsafe {
708        mlx_sys::mlx_arcsin(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
709    })
710}
711
712/// Element-wise inverse hyperbolic sine.
713#[generate_macro]
714#[default_device]
715pub fn asinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
716    Array::try_from_op(|res| unsafe {
717        mlx_sys::mlx_arcsinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
718    })
719}
720
721/// Element-wise inverse tangent.
722#[generate_macro]
723#[default_device]
724pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
725    Array::try_from_op(|res| unsafe {
726        mlx_sys::mlx_arctan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
727    })
728}
729
730/// Element-wise inverse hyperbolic tangent.
731#[generate_macro]
732#[default_device]
733pub fn atanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
734    Array::try_from_op(|res| unsafe {
735        mlx_sys::mlx_arctanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
736    })
737}
738
739/// Element-wise ceiling.
740#[generate_macro]
741#[default_device]
742pub fn ceil_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
743    Array::try_from_op(|res| unsafe {
744        mlx_sys::mlx_ceil(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
745    })
746}
747
748/// A custom trait for the bound of the clip operation.
749///
750/// This trait is only implemented for tuples of the form `(Min, Max)`, `(Min, ())`, and `((),
751/// Max)`. The `Min` and `Max` types must implement the `ScalarOrArray` trait.
752pub trait ClipBound<'min, 'max>: Sealed {
753    /// Convert the bound into a tuple of optional minimum and maximum values.
754    fn into_min_max(
755        self,
756    ) -> (
757        Option<impl ScalarOrArray<'min>>,
758        Option<impl ScalarOrArray<'max>>,
759    );
760}
761
762impl<'min, Min> ClipBound<'min, 'min> for (Min, ())
763where
764    Min: ScalarOrArray<'min> + Sealed,
765{
766    fn into_min_max(
767        self,
768    ) -> (
769        Option<impl ScalarOrArray<'min>>,
770        Option<impl ScalarOrArray<'min>>,
771    ) {
772        (Some(self.0), Option::<Min>::None)
773    }
774}
775
776impl<'max, Max> ClipBound<'max, 'max> for ((), Max)
777where
778    Max: ScalarOrArray<'max> + Sealed,
779{
780    fn into_min_max(
781        self,
782    ) -> (
783        Option<impl ScalarOrArray<'max>>,
784        Option<impl ScalarOrArray<'max>>,
785    ) {
786        (Option::<Max>::None, Some(self.1))
787    }
788}
789
790impl<'min, 'max, Min, Max> ClipBound<'min, 'max> for (Min, Max)
791where
792    Min: ScalarOrArray<'min> + Sealed,
793    Max: ScalarOrArray<'max> + Sealed,
794{
795    fn into_min_max(
796        self,
797    ) -> (
798        Option<impl ScalarOrArray<'min>>,
799        Option<impl ScalarOrArray<'max>>,
800    ) {
801        (Some(self.0), Some(self.1))
802    }
803}
804
805/// Clip the values of the array between the given minimum and maximum.
806///
807/// If either `a_min` or `a_max` are None, then corresponding edge is ignored. At least one of
808/// `a_min` and `a_max` cannot be `None`. The input `a` and the limits must broadcast with one
809/// another.
810///
811/// # Params
812///
813/// - `a`: Input array.
814/// - `bound`: minimum and/or maximum values to clip the array to.
815///
816/// # Example
817///
818/// ```rust
819/// use mlx_rs::{Array, ops::clip, array};
820///
821/// let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
822/// let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);
823/// let clipped = clip(&a, (2.0, 6.0)).unwrap();
824/// assert_eq!(clipped, expected);
825/// ```
826#[generate_macro]
827#[default_device]
828pub fn clip_device<'min, 'max>(
829    a: impl AsRef<Array>,
830    bound: impl ClipBound<'min, 'max>,
831    #[optional] stream: impl AsRef<Stream>,
832) -> Result<Array> {
833    let (a_min, a_max) = bound.into_min_max();
834
835    // This is needed to keep the lifetime of the min/max arrays in scope.
836    let a_min = a_min.map(|min| min.into_owned_or_ref_array());
837    let a_max = a_max.map(|max| max.into_owned_or_ref_array());
838
839    unsafe {
840        let min_ptr = match &a_min {
841            Some(a_min) => a_min.as_ref().as_ptr(),
842            None => mlx_sys::mlx_array_new(),
843        };
844        let max_ptr = match &a_max {
845            Some(a_max) => a_max.as_ref().as_ptr(),
846            None => mlx_sys::mlx_array_new(),
847        };
848
849        Array::try_from_op(|res| {
850            mlx_sys::mlx_clip(
851                res,
852                a.as_ref().as_ptr(),
853                min_ptr,
854                max_ptr,
855                stream.as_ref().as_ptr(),
856            )
857        })
858    }
859}
860
861/// Element-wise cosine.
862#[generate_macro]
863#[default_device]
864pub fn cos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
865    a.as_ref().cos_device(stream)
866}
867
868/// Element-wise hyperbolic cosine.
869#[generate_macro]
870#[default_device]
871pub fn cosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
872    Array::try_from_op(|res| unsafe {
873        mlx_sys::mlx_cosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
874    })
875}
876
877/// Convert angles from radians to degrees.
878#[generate_macro]
879#[default_device]
880pub fn degrees_device(
881    a: impl AsRef<Array>,
882    #[optional] stream: impl AsRef<Stream>,
883) -> Result<Array> {
884    Array::try_from_op(|res| unsafe {
885        mlx_sys::mlx_degrees(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
886    })
887}
888
889/// See [`Array::divide`].
890#[generate_macro]
891#[default_device]
892pub fn divide_device(
893    a: impl AsRef<Array>,
894    b: impl AsRef<Array>,
895    #[optional] stream: impl AsRef<Stream>,
896) -> Result<Array> {
897    a.as_ref().divide_device(b, stream)
898}
899
900/// Element-wise quotient and remainder.
901///
902/// The fuction `divmod(a, b)` is equivalent to but faster than `(a // b, a % b)`. The function uses
903/// numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
904///
905/// Returns Ok((quotient, remainder)) if the operation was successful.
906#[generate_macro]
907#[default_device]
908pub fn divmod_device(
909    a: impl AsRef<Array>,
910    b: impl AsRef<Array>,
911    #[optional] stream: impl AsRef<Stream>,
912) -> Result<(Array, Array)> {
913    let a_ptr = a.as_ref().as_ptr();
914    let b_ptr = b.as_ref().as_ptr();
915
916    let vec = VectorArray::try_from_op(|res| unsafe {
917        mlx_sys::mlx_divmod(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
918    })?;
919
920    let vals: SmallVec<[_; 2]> = vec.try_into_values()?;
921    let mut iter = vals.into_iter();
922    let quotient = iter.next().unwrap();
923    let remainder = iter.next().unwrap();
924
925    Ok((quotient, remainder))
926}
927
928/// Element-wise error function.
929#[generate_macro]
930#[default_device]
931pub fn erf_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
932    Array::try_from_op(|res| unsafe {
933        mlx_sys::mlx_erf(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
934    })
935}
936
937/// Element-wise inverse error function.
938#[generate_macro]
939#[default_device]
940pub fn erfinv_device(
941    a: impl AsRef<Array>,
942    #[optional] stream: impl AsRef<Stream>,
943) -> Result<Array> {
944    Array::try_from_op(|res| unsafe {
945        mlx_sys::mlx_erfinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
946    })
947}
948
949/// See [`Array::exp`].
950#[generate_macro]
951#[default_device]
952pub fn exp_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
953    a.as_ref().exp_device(stream)
954}
955
956/// Element-wise exponential minus 1.
957#[generate_macro]
958#[default_device]
959pub fn expm1_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
960    Array::try_from_op(|res| unsafe {
961        mlx_sys::mlx_expm1(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
962    })
963}
964
965/// See [`Array::floor`].
966#[generate_macro]
967#[default_device]
968pub fn floor_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
969    a.as_ref().floor_device(stream)
970}
971
972/// See [`Array::floor_divide`].
973#[generate_macro]
974#[default_device]
975pub fn floor_divide_device(
976    a: impl AsRef<Array>,
977    other: impl AsRef<Array>,
978    #[optional] stream: impl AsRef<Stream>,
979) -> Result<Array> {
980    a.as_ref().floor_divide_device(other, stream)
981}
982
983/// See [`Array::log`].
984#[generate_macro]
985#[default_device]
986pub fn log_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
987    a.as_ref().log_device(stream)
988}
989
990/// See [`Array::log10`].
991#[generate_macro]
992#[default_device]
993pub fn log10_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
994    a.as_ref().log10_device(stream)
995}
996
997/// See [`Array::log1p`].
998#[generate_macro]
999#[default_device]
1000pub fn log1p_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1001    a.as_ref().log1p_device(stream)
1002}
1003
1004/// See [`Array::log2`].
1005#[generate_macro]
1006#[default_device]
1007pub fn log2_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1008    a.as_ref().log2_device(stream)
1009}
1010
1011/// Element-wise log-add-exp.
1012///
1013/// This is a numerically stable log-add-exp of two arrays with numpy-style broadcasting semantics.
1014/// Either or both input arrays can also be scalars.
1015///
1016/// The computation is is a numerically stable version of `log(exp(a) + exp(b))`.
1017#[generate_macro]
1018#[default_device]
1019pub fn log_add_exp_device(
1020    a: impl AsRef<Array>,
1021    b: impl AsRef<Array>,
1022    #[optional] stream: impl AsRef<Stream>,
1023) -> Result<Array> {
1024    let a_ptr = a.as_ref().as_ptr();
1025    let b_ptr = b.as_ref().as_ptr();
1026
1027    Array::try_from_op(|res| unsafe {
1028        mlx_sys::mlx_logaddexp(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1029    })
1030}
1031
1032/// See [`Array::matmul`].
1033#[generate_macro]
1034#[default_device]
1035pub fn matmul_device(
1036    a: impl AsRef<Array>,
1037    b: impl AsRef<Array>,
1038    #[optional] stream: impl AsRef<Stream>,
1039) -> Result<Array> {
1040    a.as_ref().matmul_device(b, stream)
1041}
1042
1043/// Element-wise maximum.
1044///
1045/// Take the element-wise max of two arrays with numpy-style broadcasting semantics. Either or both
1046/// input arrays can also be scalars.
1047#[generate_macro]
1048#[default_device]
1049pub fn maximum_device(
1050    a: impl AsRef<Array>,
1051    b: impl AsRef<Array>,
1052    #[optional] stream: impl AsRef<Stream>,
1053) -> Result<Array> {
1054    let a_ptr = a.as_ref().as_ptr();
1055    let b_ptr = b.as_ref().as_ptr();
1056
1057    Array::try_from_op(|res| unsafe {
1058        mlx_sys::mlx_maximum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1059    })
1060}
1061
1062/// Element-wise minimum.
1063///
1064/// Take the element-wise min of two arrays with numpy-style broadcasting semantics. Either or both
1065/// input arrays can also be scalars.
1066#[generate_macro]
1067#[default_device]
1068pub fn minimum_device(
1069    a: impl AsRef<Array>,
1070    b: impl AsRef<Array>,
1071    #[optional] stream: impl AsRef<Stream>,
1072) -> Result<Array> {
1073    let a_ptr = a.as_ref().as_ptr();
1074    let b_ptr = b.as_ref().as_ptr();
1075
1076    Array::try_from_op(|res| unsafe {
1077        mlx_sys::mlx_minimum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1078    })
1079}
1080
1081/// See [`Array::multiply`].
1082#[generate_macro]
1083#[default_device]
1084pub fn multiply_device(
1085    a: impl AsRef<Array>,
1086    b: impl AsRef<Array>,
1087    #[optional] stream: impl AsRef<Stream>,
1088) -> Result<Array> {
1089    a.as_ref().multiply_device(b, stream)
1090}
1091
1092/// See [`Array::negative`].
1093#[generate_macro]
1094#[default_device]
1095pub fn negative_device(
1096    a: impl AsRef<Array>,
1097    #[optional] stream: impl AsRef<Stream>,
1098) -> Result<Array> {
1099    a.as_ref().negative_device(stream)
1100}
1101
1102/// See [`Array::power`].
1103#[generate_macro]
1104#[default_device]
1105pub fn power_device(
1106    a: impl AsRef<Array>,
1107    b: impl AsRef<Array>,
1108    #[optional] stream: impl AsRef<Stream>,
1109) -> Result<Array> {
1110    a.as_ref().power_device(b, stream)
1111}
1112
1113/// Convert angles from degrees to radians.
1114#[generate_macro]
1115#[default_device]
1116pub fn radians_device(
1117    a: impl AsRef<Array>,
1118    #[optional] stream: impl AsRef<Stream>,
1119) -> Result<Array> {
1120    Array::try_from_op(|res| unsafe {
1121        mlx_sys::mlx_radians(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1122    })
1123}
1124
1125/// See [`Array::reciprocal`].
1126#[generate_macro]
1127#[default_device]
1128pub fn reciprocal_device(
1129    a: impl AsRef<Array>,
1130    #[optional] stream: impl AsRef<Stream>,
1131) -> Result<Array> {
1132    a.as_ref().reciprocal_device(stream)
1133}
1134
1135/// See [`Array::remainder`].
1136#[generate_macro]
1137#[default_device]
1138pub fn remainder_device(
1139    a: impl AsRef<Array>,
1140    b: impl AsRef<Array>,
1141    #[optional] stream: impl AsRef<Stream>,
1142) -> Result<Array> {
1143    a.as_ref().remainder_device(b, stream)
1144}
1145
1146/// See [`Array::round`].
1147#[generate_macro]
1148#[default_device]
1149pub fn round_device(
1150    a: impl AsRef<Array>,
1151    decimals: impl Into<Option<i32>>,
1152    #[optional] stream: impl AsRef<Stream>,
1153) -> Result<Array> {
1154    a.as_ref().round_device(decimals, stream)
1155}
1156
1157/// See [`Array::rsqrt`].
1158#[generate_macro]
1159#[default_device]
1160pub fn rsqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1161    a.as_ref().rsqrt_device(stream)
1162}
1163
1164/// Element-wise logistic sigmoid.
1165///
1166/// See the [python API
1167/// docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html#mlx.core.sigmoid)
1168/// for more information
1169#[generate_macro]
1170#[default_device]
1171pub fn sigmoid_device(
1172    a: impl AsRef<Array>,
1173    #[optional] stream: impl AsRef<Stream>,
1174) -> Result<Array> {
1175    Array::try_from_op(|res| unsafe {
1176        mlx_sys::mlx_sigmoid(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1177    })
1178}
1179
1180/// Element-wise sign.
1181#[generate_macro]
1182#[default_device]
1183pub fn sign_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1184    Array::try_from_op(|res| unsafe {
1185        mlx_sys::mlx_sign(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1186    })
1187}
1188
1189/// See [`Array::sin`].
1190#[generate_macro]
1191#[default_device]
1192pub fn sin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1193    a.as_ref().sin_device(stream)
1194}
1195
1196/// Element-wise hyperbolic sine.
1197#[generate_macro]
1198#[default_device]
1199pub fn sinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1200    Array::try_from_op(|res| unsafe {
1201        mlx_sys::mlx_sinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1202    })
1203}
1204
1205/// Perform the softmax along the given axis.
1206///
1207/// See the [python API
1208/// docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.softmax.html#mlx.core.softmax)
1209/// for more information.
1210#[generate_macro]
1211#[default_device]
1212pub fn softmax_device(
1213    a: impl AsRef<Array>,
1214    axes: &[i32],
1215    precise: impl Into<Option<bool>>,
1216    #[optional] stream: impl AsRef<Stream>,
1217) -> Result<Array> {
1218    let precise = precise.into().unwrap_or(false);
1219    let s = stream.as_ref().as_ptr();
1220
1221    Array::try_from_op(|res| unsafe {
1222        mlx_sys::mlx_softmax(
1223            res,
1224            a.as_ref().as_ptr(),
1225            axes.as_ptr(),
1226            axes.len(),
1227            precise,
1228            s,
1229        )
1230    })
1231}
1232
1233/// Perform the softmax along all axes.
1234#[generate_macro]
1235#[default_device]
1236pub fn softmax_all_device(
1237    a: impl AsRef<Array>,
1238    #[optional] precise: impl Into<Option<bool>>,
1239    #[optional] stream: impl AsRef<Stream>,
1240) -> Result<Array> {
1241    let precise = precise.into().unwrap_or(false);
1242    let s = stream.as_ref().as_ptr();
1243
1244    Array::try_from_op(|res| unsafe {
1245        mlx_sys::mlx_softmax_all(res, a.as_ref().as_ptr(), precise, s)
1246    })
1247}
1248
1249/// See [`Array::sqrt`].
1250#[generate_macro]
1251#[default_device]
1252pub fn sqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1253    a.as_ref().sqrt_device(stream)
1254}
1255
1256/// See [`Array::square`].
1257#[generate_macro]
1258#[default_device]
1259pub fn square_device(
1260    a: impl AsRef<Array>,
1261    #[optional] stream: impl AsRef<Stream>,
1262) -> Result<Array> {
1263    a.as_ref().square_device(stream)
1264}
1265
1266/// See [`Array::subtract`].
1267#[generate_macro]
1268#[default_device]
1269pub fn subtract_device(
1270    a: impl AsRef<Array>,
1271    b: impl AsRef<Array>,
1272    #[optional] stream: impl AsRef<Stream>,
1273) -> Result<Array> {
1274    a.as_ref().subtract_device(b, stream)
1275}
1276
1277/// See [`Array::tan`].
1278#[generate_macro]
1279#[default_device]
1280pub fn tan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1281    Array::try_from_op(|res| unsafe {
1282        mlx_sys::mlx_tan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1283    })
1284}
1285
1286/// Element-wise hyperbolic tangent.
1287#[generate_macro]
1288#[default_device]
1289pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1290    Array::try_from_op(|res| unsafe {
1291        mlx_sys::mlx_tanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1292    })
1293}
1294
1295/// Matrix multiplication with block masking.
1296///
1297/// See the [python API docs](
1298/// https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.block_masked_mm.html#mlx.core.block_masked_mm
1299/// ) for more information.
1300#[generate_macro]
1301#[default_device]
1302pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
1303    a: impl AsRef<Array>,
1304    b: impl AsRef<Array>,
1305    #[optional] block_size: impl Into<Option<i32>>,
1306    #[optional] mask_out: impl Into<Option<&'mo Array>>,
1307    #[optional] mask_lhs: impl Into<Option<&'lhs Array>>,
1308    #[optional] mask_rhs: impl Into<Option<&'rhs Array>>,
1309    #[optional] stream: impl AsRef<Stream>,
1310) -> Result<Array> {
1311    let a_ptr = a.as_ref().as_ptr();
1312    let b_ptr = b.as_ref().as_ptr();
1313    unsafe {
1314        let mask_out_ptr = mask_out
1315            .into()
1316            .map(|m| m.as_ptr())
1317            .unwrap_or(mlx_sys::mlx_array_new());
1318        let mask_lhs_ptr = mask_lhs
1319            .into()
1320            .map(|m| m.as_ptr())
1321            .unwrap_or(mlx_sys::mlx_array_new());
1322        let mask_rhs_ptr = mask_rhs
1323            .into()
1324            .map(|m| m.as_ptr())
1325            .unwrap_or(mlx_sys::mlx_array_new());
1326
1327        Array::try_from_op(|res| {
1328            mlx_sys::mlx_block_masked_mm(
1329                res,
1330                a_ptr,
1331                b_ptr,
1332                block_size.into().unwrap_or(32),
1333                mask_out_ptr,
1334                mask_lhs_ptr,
1335                mask_rhs_ptr,
1336                stream.as_ref().as_ptr(),
1337            )
1338        })
1339    }
1340}
1341
1342/// Matrix multiplication with addition and optional scaling.
1343///
1344/// Perform the (possibly batched) matrix multiplication of two arrays and add to the result with
1345/// optional scaling factors.
1346///
1347/// # Params
1348///
1349/// - `c`: input array,
1350/// - `a`: input array,
1351/// - `b`: input array,
1352/// - `alpha`: Scaling factor for the matrix product of `a` and `b` (default: `1`)
1353/// - `beta`: Scaling factor for `c` (default: `1`)
1354#[generate_macro]
1355#[default_device]
1356pub fn addmm_device(
1357    c: impl AsRef<Array>,
1358    a: impl AsRef<Array>,
1359    b: impl AsRef<Array>,
1360    #[optional] alpha: impl Into<Option<f32>>,
1361    #[optional] beta: impl Into<Option<f32>>,
1362    #[optional] stream: impl AsRef<Stream>,
1363) -> Result<Array> {
1364    let c_ptr = c.as_ref().as_ptr();
1365    let a_ptr = a.as_ref().as_ptr();
1366    let b_ptr = b.as_ref().as_ptr();
1367    let alpha = alpha.into().unwrap_or(1.0);
1368    let beta = beta.into().unwrap_or(1.0);
1369
1370    Array::try_from_op(|res| unsafe {
1371        mlx_sys::mlx_addmm(
1372            res,
1373            c_ptr,
1374            a_ptr,
1375            b_ptr,
1376            alpha,
1377            beta,
1378            stream.as_ref().as_ptr(),
1379        )
1380    })
1381}
1382
1383/// Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the
1384/// last axes.
1385#[generate_macro]
1386#[default_device]
1387pub fn inner_device(
1388    a: impl AsRef<Array>,
1389    b: impl AsRef<Array>,
1390    #[optional] stream: impl AsRef<Stream>,
1391) -> Result<Array> {
1392    let a = a.as_ref();
1393    let b = b.as_ref();
1394    Array::try_from_op(|res| unsafe {
1395        mlx_sys::mlx_inner(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1396    })
1397}
1398
1399/// Compute the outer product of two 1-D arrays, if the array’s passed are not 1-D a flatten op will
1400/// be run beforehand.
1401#[generate_macro]
1402#[default_device]
1403pub fn outer_device(
1404    a: impl AsRef<Array>,
1405    b: impl AsRef<Array>,
1406    #[optional] stream: impl AsRef<Stream>,
1407) -> Result<Array> {
1408    let a = a.as_ref();
1409    let b = b.as_ref();
1410    Array::try_from_op(|res| unsafe {
1411        mlx_sys::mlx_outer(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1412    })
1413}
1414
/// Which dimensions a tensor dot product sums over.
///
/// Both variants hold only plain integers or borrowed slices, so the type is cheap to copy and
/// compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorDotDims<'a> {
    /// Sum over the last `n` dimensions of `a` and the first `n` dimensions of `b`.
    Int(i32),

    /// Sum over the listed dimensions of `a` (first slice) paired with the listed dimensions of
    /// `b` (second slice).
    List((&'a [i32], &'a [i32])),
}
1424
1425impl From<i32> for TensorDotDims<'_> {
1426    fn from(i: i32) -> Self {
1427        TensorDotDims::Int(i)
1428    }
1429}
1430
1431impl<'a> From<(&'a [i32], &'a [i32])> for TensorDotDims<'a> {
1432    fn from((lhs, rhs): (&'a [i32], &'a [i32])) -> Self {
1433        TensorDotDims::List((lhs, rhs))
1434    }
1435}
1436
1437impl<'a, const M: usize, const N: usize> From<(&'a [i32; M], &'a [i32; N])> for TensorDotDims<'a> {
1438    fn from((lhs, rhs): (&'a [i32; M], &'a [i32; N])) -> Self {
1439        TensorDotDims::List((lhs, rhs))
1440    }
1441}
1442
1443/// Compute the tensor dot product along the specified axes.
1444///
1445/// # Params
1446///
1447/// - `a`: input array,
1448/// - `b`: input array,
1449/// - `axes`: The number of dimensions to sum over. If an integer is provided, then sum over
1450///   the last axes dimensions of a and the first axes dimensions of b. If a tuple of lists is
1451///   provided, then sum over the corresponding dimensions of a and b. (default: 2)
1452#[generate_macro]
1453#[default_device]
1454pub fn tensordot_device<'a>(
1455    a: impl AsRef<Array>,
1456    b: impl AsRef<Array>,
1457    #[optional] axes: impl Into<TensorDotDims<'a>>,
1458    #[optional] stream: impl AsRef<Stream>,
1459) -> Result<Array> {
1460    let a = a.as_ref();
1461    let b = b.as_ref();
1462    match axes.into() {
1463        TensorDotDims::Int(dim) => Array::try_from_op(|res| unsafe {
1464            mlx_sys::mlx_tensordot_along_axis(
1465                res,
1466                a.as_ptr(),
1467                b.as_ptr(),
1468                dim,
1469                stream.as_ref().as_ptr(),
1470            )
1471        }),
1472        TensorDotDims::List((lhs, rhs)) => Array::try_from_op(|res| unsafe {
1473            mlx_sys::mlx_tensordot(
1474                res,
1475                a.as_ptr(),
1476                b.as_ptr(),
1477                lhs.as_ptr(),
1478                lhs.len(),
1479                rhs.as_ptr(),
1480                rhs.len(),
1481                stream.as_ref().as_ptr(),
1482            )
1483        }),
1484    }
1485}
1486
1487#[cfg(test)]
1488mod tests {
1489    use std::f32::consts::PI;
1490
1491    use super::*;
1492    use crate::{
1493        array, complex64,
1494        ops::{all_close, arange, broadcast_to, eye, full, linspace, ones, reshape, split_equal},
1495        transforms::eval,
1496        Dtype,
1497    };
1498    use float_eq::assert_float_eq;
1499    use pretty_assertions::assert_eq;
1500
1501    #[test]
1502    fn test_abs() {
1503        let data = [1i32, 2, -3, -4, -5];
1504        let array = Array::from_slice(&data, &[5]);
1505        let result = array.abs().unwrap();
1506
1507        let data: &[i32] = result.as_slice();
1508        assert_eq!(data, [1, 2, 3, 4, 5]);
1509
1510        // test that previous array is not modified and valid
1511        let data: &[i32] = array.as_slice();
1512        assert_eq!(data, [1, 2, -3, -4, -5]);
1513    }
1514
1515    #[test]
1516    fn test_add() {
1517        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1518        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1519
1520        let c = &a + &b;
1521
1522        let c_data: &[f32] = c.as_slice();
1523        assert_eq!(c_data, &[5.0, 7.0, 9.0]);
1524
1525        // check a and b are not modified
1526        let a_data: &[f32] = a.as_slice();
1527        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1528
1529        let b_data: &[f32] = b.as_slice();
1530        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1531    }
1532
1533    #[test]
1534    fn test_add_invalid_broadcast() {
1535        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1536        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1537
1538        let c = a.add(&b);
1539        assert!(c.is_err());
1540    }
1541
1542    #[test]
1543    fn test_sub() {
1544        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1545        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1546
1547        let c = &a - &b;
1548
1549        let c_data: &[f32] = c.as_slice();
1550        assert_eq!(c_data, &[-3.0, -3.0, -3.0]);
1551
1552        // check a and b are not modified
1553        let a_data: &[f32] = a.as_slice();
1554        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1555
1556        let b_data: &[f32] = b.as_slice();
1557        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1558    }
1559
1560    #[test]
1561    fn test_sub_invalid_broadcast() {
1562        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1563        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1564        let c = a.subtract(&b);
1565        assert!(c.is_err());
1566    }
1567
1568    #[test]
1569    fn test_neg() {
1570        let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]);
1571        let b = a.negative().unwrap();
1572
1573        let b_data: &[f32] = b.as_slice();
1574        assert_eq!(b_data, &[-1.0, -2.0, -3.0]);
1575
1576        // check a is not modified
1577        let a_data: &[f32] = a.as_slice();
1578        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1579    }
1580
1581    #[test]
1582    fn test_neg_bool() {
1583        let a = Array::from_slice(&[true, false, true], &[3]);
1584        let b = a.negative();
1585        assert!(b.is_err());
1586    }
1587
1588    #[test]
1589    fn test_logical_not() {
1590        let a: Array = false.into();
1591        let b = a.logical_not().unwrap();
1592
1593        let b_data: &[bool] = b.as_slice();
1594        assert_eq!(b_data, [true]);
1595    }
1596
1597    #[test]
1598    fn test_mul() {
1599        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1600        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1601
1602        let c = &a * &b;
1603
1604        let c_data: &[f32] = c.as_slice();
1605        assert_eq!(c_data, &[4.0, 10.0, 18.0]);
1606
1607        // check a and b are not modified
1608        let a_data: &[f32] = a.as_slice();
1609        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1610
1611        let b_data: &[f32] = b.as_slice();
1612        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1613    }
1614
1615    #[test]
1616    fn test_mul_invalid_broadcast() {
1617        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1618        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1619        let c = a.multiply(&b);
1620        assert!(c.is_err());
1621    }
1622
1623    #[test]
1624    fn test_nan_to_num() {
1625        let a = array!([1.0, 2.0, f32::NAN, 4.0, 5.0]);
1626        let b = a.nan_to_num(0.0, 1.0, 0.0).unwrap();
1627
1628        let b_data: &[f32] = b.as_slice();
1629        assert_eq!(b_data, &[1.0, 2.0, 0.0, 4.0, 5.0]);
1630    }
1631
1632    #[test]
1633    fn test_div() {
1634        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1635        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1636
1637        let c = &a / &b;
1638
1639        let c_data: &[f32] = c.as_slice();
1640        assert_eq!(c_data, &[0.25, 0.4, 0.5]);
1641
1642        // check a and b are not modified
1643        let a_data: &[f32] = a.as_slice();
1644        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1645
1646        let b_data: &[f32] = b.as_slice();
1647        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1648    }
1649
1650    #[test]
1651    fn test_div_invalid_broadcast() {
1652        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1653        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1654        let c = a.divide(&b);
1655        assert!(c.is_err());
1656    }
1657
1658    #[test]
1659    fn test_pow() {
1660        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1661        let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
1662
1663        let c = a.power(&b).unwrap();
1664
1665        let c_data: &[f32] = c.as_slice();
1666        assert_eq!(c_data, &[1.0, 8.0, 81.0]);
1667
1668        // check a and b are not modified
1669        let a_data: &[f32] = a.as_slice();
1670        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1671
1672        let b_data: &[f32] = b.as_slice();
1673        assert_eq!(b_data, &[2.0, 3.0, 4.0]);
1674    }
1675
1676    #[test]
1677    fn test_pow_invalid_broadcast() {
1678        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1679        let b = Array::from_slice(&[2.0, 3.0], &[2]);
1680        let c = a.power(&b);
1681        assert!(c.is_err());
1682    }
1683
1684    #[test]
1685    fn test_rem() {
1686        let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1687        let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
1688
1689        let c = &a % &b;
1690
1691        let c_data: &[f32] = c.as_slice();
1692        assert_eq!(c_data, &[1.0, 3.0, 2.0]);
1693
1694        // check a and b are not modified
1695        let a_data: &[f32] = a.as_slice();
1696        assert_eq!(a_data, &[10.0, 11.0, 12.0]);
1697
1698        let b_data: &[f32] = b.as_slice();
1699        assert_eq!(b_data, &[3.0, 4.0, 5.0]);
1700    }
1701
1702    #[test]
1703    fn test_rem_invalid_broadcast() {
1704        let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1705        let b = Array::from_slice(&[3.0, 4.0], &[2]);
1706        let c = a.remainder(&b);
1707        assert!(c.is_err());
1708    }
1709
1710    #[test]
1711    fn test_sqrt() {
1712        let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
1713        let b = a.sqrt().unwrap();
1714
1715        let b_data: &[f32] = b.as_slice();
1716        assert_eq!(b_data, &[1.0, 2.0, 3.0]);
1717
1718        // check a is not modified
1719        let a_data: &[f32] = a.as_slice();
1720        assert_eq!(a_data, &[1.0, 4.0, 9.0]);
1721    }
1722
1723    #[test]
1724    fn test_cos() {
1725        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1726        let b = a.cos().unwrap();
1727
1728        let b_data: &[f32] = b.as_slice();
1729        assert_eq!(b_data, &[1.0, 0.54030234, -0.41614687]);
1730
1731        // check a is not modified
1732        let a_data: &[f32] = a.as_slice();
1733        assert_eq!(a_data, &[0.0, 1.0, 2.0]);
1734    }
1735
1736    #[test]
1737    fn test_exp() {
1738        let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1739        let b = a.exp().unwrap();
1740
1741        let b_data: &[f32] = b.as_slice();
1742        assert_eq!(b_data, &[1.0, 2.7182817, 7.389056]);
1743
1744        // check a is not modified
1745        let a_data: &[f32] = a.as_slice();
1746        assert_eq!(a_data, &[0.0, 1.0, 2.0]);
1747    }
1748
1749    #[test]
1750    fn test_floor() {
1751        let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
1752        let b = a.floor().unwrap();
1753
1754        let b_data: &[f32] = b.as_slice();
1755        assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1756
1757        // check a is not modified
1758        let a_data: &[f32] = a.as_slice();
1759        assert_eq!(a_data, &[0.1, 1.9, 2.5]);
1760    }
1761
1762    #[test]
1763    fn test_floor_complex64() {
1764        let val = complex64::new(1.0, 2.0);
1765        let a = Array::from_complex(val);
1766        let b = a.floor_device(StreamOrDevice::default());
1767        assert!(b.is_err());
1768    }
1769
1770    #[test]
1771    fn test_floor_divide() {
1772        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1773        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1774
1775        let c = a.floor_divide(&b).unwrap();
1776
1777        let c_data: &[f32] = c.as_slice();
1778        assert_eq!(c_data, &[0.0, 0.0, 0.0]);
1779
1780        // check a and b are not modified
1781        let a_data: &[f32] = a.as_slice();
1782        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1783
1784        let b_data: &[f32] = b.as_slice();
1785        assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1786    }
1787
1788    #[test]
1789    fn test_floor_divide_complex64() {
1790        let val = complex64::new(1.0, 2.0);
1791        let a = Array::from_complex(val);
1792        let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1793        let c = a.floor_divide_device(&b, StreamOrDevice::default());
1794        assert!(c.is_err());
1795    }
1796
1797    #[test]
1798    fn test_floor_divide_invalid_broadcast() {
1799        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1800        let b = Array::from_slice(&[4.0, 5.0], &[2]);
1801        let c = a.floor_divide_device(&b, StreamOrDevice::default());
1802        assert!(c.is_err());
1803    }
1804
1805    #[test]
1806    fn test_is_nan() {
1807        let a = Array::from_slice(&[1.0, f32::NAN, 3.0], &[3]);
1808        let b = a.is_nan().unwrap();
1809
1810        let b_data: &[bool] = b.as_slice();
1811        assert_eq!(b_data, &[false, true, false]);
1812    }
1813
1814    #[test]
1815    fn test_is_inf() {
1816        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1817        let b = a.is_inf().unwrap();
1818
1819        let b_data: &[bool] = b.as_slice();
1820        assert_eq!(b_data, &[false, true, false]);
1821    }
1822
1823    #[test]
1824    fn test_is_finite() {
1825        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1826        let b = a.is_finite().unwrap();
1827
1828        let b_data: &[bool] = b.as_slice();
1829        assert_eq!(b_data, &[true, false, true]);
1830    }
1831
1832    #[test]
1833    fn test_is_neg_inf() {
1834        let a = Array::from_slice(&[1.0, f32::NEG_INFINITY, 3.0], &[3]);
1835        let b = a.is_neg_inf().unwrap();
1836
1837        let b_data: &[bool] = b.as_slice();
1838        assert_eq!(b_data, &[false, true, false]);
1839    }
1840
1841    #[test]
1842    fn test_is_pos_inf() {
1843        let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1844        let b = a.is_pos_inf().unwrap();
1845
1846        let b_data: &[bool] = b.as_slice();
1847        assert_eq!(b_data, &[false, true, false]);
1848    }
1849
1850    #[test]
1851    fn test_log() {
1852        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1853        let b = a.log().unwrap();
1854
1855        let b_data: &[f32] = b.as_slice();
1856        assert_eq!(b_data, &[0.0, 0.6931472, 1.0986123]);
1857
1858        // check a is not modified
1859        let a_data: &[f32] = a.as_slice();
1860        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1861    }
1862
1863    #[test]
1864    fn test_log2() {
1865        let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
1866        let b = a.log2().unwrap();
1867
1868        let b_data: &[f32] = b.as_slice();
1869        assert_eq!(b_data, &[0.0, 1.0, 2.0, 3.0]);
1870
1871        // check a is not modified
1872        let a_data: &[f32] = a.as_slice();
1873        assert_eq!(a_data, &[1.0, 2.0, 4.0, 8.0]);
1874    }
1875
1876    #[test]
1877    fn test_log10() {
1878        let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
1879        let b = a.log10().unwrap();
1880
1881        let b_data: &[f32] = b.as_slice();
1882        assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1883
1884        // check a is not modified
1885        let a_data: &[f32] = a.as_slice();
1886        assert_eq!(a_data, &[1.0, 10.0, 100.0]);
1887    }
1888
1889    #[test]
1890    fn test_log1p() {
1891        let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1892        let b = a.log1p().unwrap();
1893
1894        let b_data: &[f32] = b.as_slice();
1895        assert_eq!(b_data, &[0.6931472, 1.0986123, 1.3862944]);
1896
1897        // check a is not modified
1898        let a_data: &[f32] = a.as_slice();
1899        assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1900    }
1901
1902    #[test]
1903    fn test_matmul() {
1904        let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1905        let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
1906
1907        let c = a.matmul(&b).unwrap();
1908
1909        assert_eq!(c.shape(), &[2, 3]);
1910        let c_data: &[f32] = c.as_slice();
1911        assert_eq!(c_data, &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]);
1912
1913        // check a and b are not modified
1914        let a_data: &[i32] = a.as_slice();
1915        assert_eq!(a_data, &[1, 2, 3, 4]);
1916
1917        let b_data: &[f32] = b.as_slice();
1918        assert_eq!(b_data, &[-5.0, 37.5, 4., 7., 1., 0.]);
1919    }
1920
1921    #[test]
1922    fn test_matmul_ndim_zero() {
1923        let a: Array = 1.0.into();
1924        let b = Array::from_slice::<i32>(&[1], &[1]);
1925        let c = a.matmul(&b);
1926        assert!(c.is_err());
1927    }
1928
1929    #[test]
1930    fn test_matmul_ndim_one() {
1931        let a = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1932        let b = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1933        let c = a.matmul(&b);
1934        assert!(c.is_ok());
1935    }
1936
1937    #[test]
1938    fn test_matmul_dim_mismatch() {
1939        let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
1940        let b = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]);
1941        let c = a.matmul(&b);
1942        assert!(c.is_err());
1943    }
1944
1945    #[test]
1946    fn test_matmul_non_float_output_type() {
1947        let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1948        let b = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]);
1949
1950        let c = a.matmul(&b);
1951        assert!(c.is_err());
1952    }
1953
#[test]
fn test_reciprocal() {
    let input = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
    let inverted = input.reciprocal().unwrap();

    // Element-wise 1 / x.
    let inverted_values: &[f32] = inverted.as_slice();
    assert_eq!(inverted_values, &[1.0, 0.5, 0.25]);

    // The receiver must be left untouched.
    let input_values: &[f32] = input.as_slice();
    assert_eq!(input_values, &[1.0, 2.0, 4.0]);
}
1966
#[test]
fn test_round() {
    let input = Array::from_slice(&[1.1, 2.9, 3.5], &[3]);
    let rounded = input.round(None).unwrap();

    // With no decimals argument, values round to whole numbers.
    let rounded_values: &[f32] = rounded.as_slice();
    assert_eq!(rounded_values, &[1.0, 3.0, 4.0]);

    // The receiver must be left untouched.
    let input_values: &[f32] = input.as_slice();
    assert_eq!(input_values, &[1.1, 2.9, 3.5]);
}
1979
#[test]
fn test_rsqrt() {
    let input = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
    let inv_roots = input.rsqrt().unwrap();

    // Element-wise 1 / sqrt(x).
    let inv_root_values: &[f32] = inv_roots.as_slice();
    assert_eq!(inv_root_values, &[1.0, 0.70710677, 0.5]);

    // The receiver must be left untouched.
    let input_values: &[f32] = input.as_slice();
    assert_eq!(input_values, &[1.0, 2.0, 4.0]);
}
1992
#[test]
fn test_sin() {
    let angles = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let sines = angles.sin().unwrap();

    let sine_values: &[f32] = sines.as_slice();
    assert_eq!(sine_values, &[0.0, 0.841471, 0.9092974]);

    // The receiver must be left untouched.
    let angle_values: &[f32] = angles.as_slice();
    assert_eq!(angle_values, &[0.0, 1.0, 2.0]);
}
2005
#[test]
fn test_square() {
    let input = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let squared = input.square().unwrap();

    // Element-wise x * x.
    let squared_values: &[f32] = squared.as_slice();
    assert_eq!(squared_values, &[1.0, 4.0, 9.0]);

    // The receiver must be left untouched.
    let input_values: &[f32] = input.as_slice();
    assert_eq!(input_values, &[1.0, 2.0, 3.0]);
}
2018
// The unit tests below are adapted from the original MLX C++ codebase.
2020
#[test]
fn test_unary_neg() {
    // The free function and the unary `-` operator must agree.
    let one = array!(1.0);
    assert_eq!(negative(&one).unwrap().item::<f32>(), -1.0);
    assert_eq!((-one).item::<f32>(), -1.0);

    // Negating an empty array yields an empty array.
    assert_eq!(-array!(), array!());

    // Booleans cannot be negated.
    let flag = array!(true);
    assert!(negative(&flag).is_err());
}
2034
#[test]
fn test_unary_abs() {
    // float32
    let floats = array!([-1.0, 0.0, 1.0]);
    assert_eq!(abs(&floats).unwrap(), array!([1.0, 0.0, 1.0]));

    // An empty input produces an empty output.
    assert_eq!(abs(array!()).unwrap(), array!());

    // int32: negative entries flip sign.
    let signed = array!([-1, 0, 1]);
    assert_eq!(abs(&signed).unwrap(), array!([1, 0, 1]));

    // uint32 and bool inputs come back unchanged.
    let unsigned = array!([1u32, 0, 1]);
    assert_eq!(abs(&unsigned).unwrap(), array!([1u32, 0, 1]));
    let flags = array!([false, true]);
    assert_eq!(abs(&flags).unwrap(), array!([false, true]));
}
2055
#[test]
fn test_unary_sign() {
    // Every entry is already -1, 0 or 1 (or a bool), so sign() must return
    // each input unchanged: float32, int32, uint32 and bool.
    let identity_cases = [
        array!([-1.0, 0.0, 1.0]),
        array!([-1, 0, 1]),
        array!([1u32, 0, 1]),
        array!([false, true]),
    ];
    for case in identity_cases.iter() {
        assert_eq!(sign(case).unwrap(), *case);
    }

    // An empty input produces an empty output.
    assert_eq!(sign(array!()).unwrap(), array!());
}
2076
/// Negative infinity as `f32`; used by the floor/ceil, exp and log tests below.
const NEG_INF: f32 = -f32::INFINITY;
2078
#[test]
fn test_unary_floor_ceil() {
    // (input, expected floor, expected ceil)
    let cases = [
        (1.0f32, 1.0f32, 1.0f32),
        (1.5, 1.0, 2.0),
        (-1.5, -2.0, -1.0),
        (NEG_INF, NEG_INF, NEG_INF),
    ];
    for (value, expected_floor, expected_ceil) in cases {
        let input = array![value];
        assert_eq!(floor(&input).unwrap().item::<f32>(), expected_floor);
        assert_eq!(ceil(&input).unwrap().item::<f32>(), expected_ceil);
    }

    // Complex inputs are rejected by both operations.
    let complex_input = array!([1.0, 1.0]).as_type::<complex64>().unwrap();
    assert!(floor(&complex_input).is_err());
    assert!(ceil(&complex_input).is_err());
}
2101
#[test]
fn test_unary_round() {
    // Exact halves go to the nearest even integer (0.5 -> 0, 1.5 -> 2).
    let halves = array!([0.5, -0.5, 1.5, -1.5, 2.3, 2.6]);
    assert_eq!(round(&halves, None).unwrap(), array!([0, 0, 2, -2, 2, 3]));

    // A rounding precision of -1 rounds to the nearest ten.
    let ints = array!([11, 222, 32]);
    assert_eq!(round(&ints, -1).unwrap(), array!([10, 220, 30]));
}
2110
#[test]
fn test_unary_exp() {
    let approx_eq = |lhs: &Array, rhs: &Array| -> bool {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    // exp(0) is exactly 1.
    let zero = array![0.0];
    assert_eq!(exp(&zero).unwrap().item::<f32>(), 1.0);

    let two = array![2.0];
    assert_float_eq!(exp(&two).unwrap().item::<f32>(), 2.0f32.exp(), abs <= 1e-5);

    // An empty input produces an empty output.
    assert_eq!(exp(array!()).unwrap(), array!());

    // exp(-inf) is exactly 0.
    let neg_inf = array![NEG_INF];
    assert_eq!(exp(&neg_inf).unwrap().item::<f32>(), 0.0);

    // Integer input type
    let int_two = array![2];
    assert_eq!(int_two.dtype(), Dtype::Int32);
    assert_float_eq!(exp(&int_two).unwrap().item::<f32>(), 2.0f32.exp(), abs <= 1e-5);

    // Input is irregularly strided: a broadcast scalar...
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.exp())).unwrap();
    assert!(approx_eq(&exp(&broadcast).unwrap(), &expected));

    // ...and the first column of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let columns = split_equal(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.exp(), 2.0f32.exp()], &[2, 1]);
    assert!(approx_eq(&exp(&columns[0]).unwrap(), &expected));
}
2152
#[test]
fn test_unary_expm1() {
    // Float inputs on either side of zero.
    for value in [-1.0f32, 1.0] {
        let input = array![value];
        assert_float_eq!(expm1(&input).unwrap().item::<f32>(), value.exp_m1(), abs <= 1e-5);
    }

    // Integer input produces a float32 result.
    let int_one = array![1];
    let result = expm1(&int_one).unwrap();
    assert_eq!(result.dtype(), Dtype::Float32);
    assert_float_eq!(result.item::<f32>(), 1.0f32.exp_m1(), abs <= 1e-5);
}
2178
#[test]
fn test_unary_sin() {
    let approx_eq = |lhs: &Array, rhs: &Array| -> bool {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    // sin(0) is exactly 0.
    let zero = array![0.0];
    assert_eq!(sin(&zero).unwrap().item::<f32>(), 0.0);

    let quarter_turn = std::f32::consts::PI / 2.0;
    let input = array![quarter_turn];
    assert_float_eq!(sin(&input).unwrap().item::<f32>(), quarter_turn.sin(), abs <= 1e-5);

    // An empty input produces an empty output.
    assert_eq!(sin(array!()).unwrap(), array!());

    // Integer input type
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_float_eq!(sin(&int_zero).unwrap().item::<f32>(), 0.0f32.sin(), abs <= 1e-5);

    // Input is irregularly strided: a broadcast scalar...
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.sin())).unwrap();
    assert!(approx_eq(&sin(&broadcast).unwrap(), &expected));

    // ...and the first column of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let columns = split_equal(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.sin(), 2.0f32.sin()], &[2, 1]);
    assert!(approx_eq(&sin(&columns[0]).unwrap(), &expected));
}
2217
#[test]
fn test_unary_cos() {
    let approx_eq = |lhs: &Array, rhs: &Array| -> bool {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    // Float scalar inputs compared against the standard library.
    for angle in [0.0f32, std::f32::consts::PI / 2.0] {
        let input = array![angle];
        assert_float_eq!(cos(&input).unwrap().item::<f32>(), angle.cos(), abs <= 1e-5);
    }

    // An empty input produces an empty output.
    assert_eq!(cos(array!()).unwrap(), array!());

    // Integer input type
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_float_eq!(cos(&int_zero).unwrap().item::<f32>(), 0.0f32.cos(), abs <= 1e-5);

    // Input is irregularly strided: a broadcast scalar...
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.cos())).unwrap();
    assert!(approx_eq(&cos(&broadcast).unwrap(), &expected));

    // ...and the first column of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let columns = split_equal(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.cos(), 2.0f32.cos()], &[2, 1]);
    assert!(approx_eq(&cos(&columns[0]).unwrap(), &expected));
}
2260
#[test]
fn test_unary_degrees() {
    let approx_eq = |lhs: &Array, rhs: &Array| -> bool {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    // (angle in radians, expected degrees)
    for (angle, expected) in [(0.0f32, 0.0f32), (std::f32::consts::PI / 2.0, 90.0)] {
        let input = array![angle];
        assert_eq!(degrees(&input).unwrap().item::<f32>(), expected);
    }

    // An empty input produces an empty output.
    assert_eq!(degrees(array!()).unwrap(), array!());

    // Integer input type
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_eq!(degrees(&int_zero).unwrap().item::<f32>(), 0.0);

    // Input is irregularly strided: a broadcast scalar...
    let broadcast = broadcast_to(&array!(std::f32::consts::PI / 2.0), &[2, 2, 2]).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(90.0)).unwrap();
    assert!(approx_eq(&degrees(&broadcast).unwrap(), &expected));

    // ...and the first column of a 2x2 matrix split along axis 1.
    let angle_grid = Array::from_slice(&[0.0, PI / 2.0, PI, 1.5 * PI], &[2, 2]);
    let columns = split_equal(&angle_grid, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0, 180.0], &[2, 1]);
    assert!(approx_eq(&degrees(&columns[0]).unwrap(), &expected));
}
2293
#[test]
fn test_unary_radians() {
    let approx_eq = |lhs: &Array, rhs: &Array| -> bool {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };
    let quarter_turn = std::f32::consts::PI / 2.0;

    // (angle in degrees, expected radians)
    for (angle, expected) in [(0.0f32, 0.0f32), (90.0, quarter_turn)] {
        let input = array![angle];
        assert_eq!(radians(&input).unwrap().item::<f32>(), expected);
    }

    // An empty input produces an empty output.
    assert_eq!(radians(array!()).unwrap(), array!());

    // Integer input type
    let int_ninety = array![90];
    assert_eq!(int_ninety.dtype(), Dtype::Int32);
    assert_eq!(radians(&int_ninety).unwrap().item::<f32>(), quarter_turn);

    // Input is irregularly strided: a broadcast scalar...
    let broadcast = broadcast_to(&array!(90.0), &[2, 2, 2]).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap();
    assert!(approx_eq(&radians(&broadcast).unwrap(), &expected));

    // ...and the first column of a 2x2 matrix split along axis 1.
    let angle_grid = Array::from_slice(&[0.0, 90.0, 180.0, 270.0], &[2, 2]);
    let columns = split_equal(&angle_grid, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0, PI], &[2, 1]);
    assert!(approx_eq(&radians(&columns[0]).unwrap(), &expected));
}
2332
#[test]
fn test_unary_log() {
    let approx_eq = |lhs: &Array, rhs: &Array| -> bool {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    // ln(0) = -inf, ln(1) = 0
    for (value, expected) in [(0.0f32, NEG_INF), (1.0, 0.0)] {
        let input = array![value];
        assert_eq!(log(&input).unwrap().item::<f32>(), expected);
    }

    // Integer input produces a float32 result.
    let int_one = array![1];
    let result = log(&int_one).unwrap();
    assert_eq!(result.dtype(), Dtype::Float32);
    assert_eq!(result.item::<f32>(), 0.0);

    // Input is irregularly strided: a broadcast scalar...
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(0.0)).unwrap();
    assert!(approx_eq(&log(&broadcast).unwrap(), &expected));

    // ...and the first column of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let columns = split_equal(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[1.0f32.ln(), 3.0f32.ln()], &[2, 1]);
    assert!(approx_eq(&log(&columns[0]).unwrap(), &expected));
}
2361
#[test]
fn test_unary_log2() {
    // (input, expected log2)
    let cases = [(0.0f32, NEG_INF), (1.0, 0.0), (1024.0, 10.0)];
    for (value, expected) in cases {
        let input = array![value];
        assert_eq!(log2(&input).unwrap().item::<f32>(), expected);
    }
}
2373
#[test]
fn test_unary_log10() {
    // (input, expected log10)
    let cases = [(0.0f32, NEG_INF), (1.0, 0.0), (1000.0, 3.0)];
    for (value, expected) in cases {
        let input = array![value];
        assert_eq!(log10(&input).unwrap().item::<f32>(), expected);
    }
}
2385
#[test]
fn test_unary_log1p() {
    let approx_eq = |lhs: &Array, rhs: &Array| -> bool {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    // Float scalars compared against the standard library.
    for value in [-1.0f32, 1.0] {
        let input = array![value];
        assert_float_eq!(log1p(&input).unwrap().item::<f32>(), value.ln_1p(), abs <= 1e-5);
    }

    // Integer input produces a float32 result.
    let int_one = array![1];
    let result = log1p(&int_one).unwrap();
    assert_eq!(result.dtype(), Dtype::Float32);
    assert_float_eq!(result.item::<f32>(), 1.0f32.ln_1p(), abs <= 1e-5);

    // Input is irregularly strided: a broadcast scalar...
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap();
    assert!(approx_eq(&log1p(&broadcast).unwrap(), &expected));

    // ...and the first column of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let columns = split_equal(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[1.0f32.ln_1p(), 3.0f32.ln_1p()], &[2, 1]);
    assert!(approx_eq(&log1p(&columns[0]).unwrap(), &expected));
}
2428
#[test]
fn test_unary_sigmoid() {
    // sigmoid(0) = 0.5 for a float input...
    let float_zero = array![0.0];
    assert_float_eq!(sigmoid(&float_zero).unwrap().item::<f32>(), 0.5, abs <= 1e-5);

    // ...and for an integer input, which is promoted to float32.
    let int_zero = array![0];
    let result = sigmoid(&int_zero).unwrap();
    assert_eq!(result.dtype(), Dtype::Float32);
    assert_float_eq!(result.item::<f32>(), 0.5, abs <= 1e-5);

    // At the infinities the output is exactly 1 and 0.
    let inf = f32::INFINITY;
    assert_eq!(sigmoid(&array![inf]).unwrap().item::<f32>(), 1.0);
    assert_eq!(sigmoid(&array![-inf]).unwrap().item::<f32>(), 0.0);
}
2454
#[test]
fn test_unary_square() {
    // Scalar float and int inputs keep their dtype.
    assert_eq!(square(&array![3.0]).unwrap().item::<f32>(), 9.0);
    assert_eq!(square(&array![2]).unwrap().item::<i32>(), 4);

    // Matrix input.
    let twos = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
    let fours = Array::full::<f32>(&[3, 3], array!(4.0)).unwrap();
    let matches = all_close(square(&twos).unwrap(), fours, None, None, None)
        .unwrap()
        .item::<bool>();
    assert!(matches);
}
2474
#[test]
fn test_unary_sqrt_rsqrt() {
    // sqrt(4) = 2 and 1/sqrt(4) = 0.5
    let four = array![4.0];
    assert_eq!(sqrt(&four).unwrap().item::<f32>(), 2.0);
    assert_eq!(rsqrt(&four).unwrap().item::<f32>(), 0.5);

    // Matrix input.
    let nines = Array::full::<f32>(&[3, 3], array!(9.0)).unwrap();
    let threes = Array::full::<f32>(&[3, 3], array!(3.0)).unwrap();
    let matches = all_close(sqrt(&nines).unwrap(), threes, None, None, None)
        .unwrap()
        .item::<bool>();
    assert!(matches);

    // Integer input produces floating-point results.
    let int_four = array![4i32];
    assert_eq!(sqrt(&int_four).unwrap().item::<f32>(), 2.0);
    assert_eq!(rsqrt(&int_four).unwrap().item::<f32>(), 0.5);
}
2496
#[test]
fn test_unary_reciprocal() {
    assert_eq!(reciprocal(&array![8.0]).unwrap().item::<f32>(), 0.125);

    // Integer input produces a float32 result.
    let int_two = array![2];
    let inverted = reciprocal(&int_two).unwrap();
    assert_eq!(inverted.dtype(), Dtype::Float32);
    assert_eq!(inverted.item::<f32>(), 0.5);

    // Matrix input.
    let twos = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
    let halves = Array::full::<f32>(&[3, 3], array!(0.5)).unwrap();
    let matches = all_close(reciprocal(&twos).unwrap(), halves, None, None, None)
        .unwrap()
        .item::<bool>();
    assert!(matches);
}
2518
#[test]
fn test_binary_add() {
    let x = array![1.0];
    let y = array![1.0];
    let z = add(&x, &y).unwrap();
    assert_eq!(z.item::<f32>(), 2.0);

    // The `+` operator agrees with the free function.
    let z = &x + y;
    assert_eq!(z.item::<f32>(), 2.0);

    let z = add(z, &x).unwrap();
    assert_eq!(z.item::<f32>(), 3.0);

    // Chain a few adds:
    let mut out = x.deep_clone();
    for _ in 0..10 {
        out = add(&out, &x).unwrap();
    }
    assert_eq!(out.item::<f32>(), 11.0);

    // Works for different shapes
    let x = array!([1.0, 2.0, 3.0]);
    let y = array!([1.0, 2.0, 3.0]);
    let z = add(&x, &y).unwrap();
    assert_eq!(z.shape(), &[3]);
    assert_eq!(z, array!([2.0, 4.0, 6.0]));

    // Works with a scalar on either side of the addition.
    let x = array!([1.0, 2.0, 3.0]);
    let y = &x + 2.0;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));
    // Scalar as the left operand.
    let y = add(array!(2.0), &x).unwrap();
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // Check type promotion
    let y = x + 2;
    assert_eq!(y.dtype(), Dtype::Float32);

    let y = array!([1, 2, 3]) + 2.0;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // Broadcasting works
    let x = broadcast_to(&array!(1.0), &[10]).unwrap();
    let y = broadcast_to(&array!(2.0), &[10]).unwrap();
    let z = add(&x, &y).unwrap();
    assert_eq!(z, full::<f32>(&[10], array!(3.0)).unwrap());

    let x = Array::from_slice(&[1.0, 2.0], &[1, 2]);
    let y = Array::from_slice(&[1.0, 2.0], &[2, 1]);
    let z = add(&x, &y).unwrap();
    assert_eq!(z.shape(), &[2, 2]);
    assert_eq!(z, Array::from_slice(&[2.0, 3.0, 3.0, 4.0], &[2, 2]));

    let x = ones::<f32>(&[3, 2, 1]).unwrap();
    let z = x + 2.0;
    assert_eq!(z.shape(), &[3, 2, 1]);
    let expected = Array::from_slice(&[3.0, 3.0, 3.0, 3.0, 3.0, 3.0], &[3, 2, 1]);
    assert_eq!(z, expected);

    // Works for empty arrays
    let x = array!();
    let y = array!();
    let z = x + y;
    z.eval().unwrap();
    assert_eq!(z.size(), 0);
    assert_eq!(z.shape(), &[0]);
}
2590
#[test]
fn test_binary_sub() {
    let minuend = array!([3.0, 2.0, 1.0]);
    let subtrahend = array!([1.0, 1.0, 1.0]);
    let difference = minuend - subtrahend;
    assert_eq!(difference, array!([2.0, 1.0, 0.0]));
}
2597
#[test]
fn test_binary_mul() {
    let factors = array!([1.0, 2.0, 3.0]);
    let doublers = array!([2.0, 2.0, 2.0]);
    let product = factors * doublers;
    assert_eq!(product, array!([2.0, 4.0, 6.0]));
}
2604
#[test]
fn test_binary_div() {
    // (numerator, denominator, exact quotient)
    let finite_cases = [
        (array![1.0], array![1.0], 1.0f32),
        (array![1.0], array![0.5], 2.0),
        (array![1.0], array![4.0], 0.25),
        // bool operands divide as 1 / 1 and 0 / 1.
        (array![true], array![true], 1.0),
        (array![false], array![true], 0.0),
    ];
    for (numerator, denominator, expected) in &finite_cases {
        assert_eq!(
            divide(numerator, denominator).unwrap().item::<f32>(),
            *expected
        );
    }

    // true / false -> infinity
    let quotient = divide(&array![true], &array![false]).unwrap();
    assert!(quotient.item::<f32>().is_infinite());

    // false / false -> NaN
    let quotient = divide(&array![false], &array![false]).unwrap();
    assert!(quotient.item::<f32>().is_nan());
}
2635
#[test]
fn test_binary_maximum_minimum() {
    let one = array![1.0];
    // (other operand, expected maximum, expected minimum)
    let cases = [(array![0.0], 1.0f32, 0.0f32), (array![2.0], 2.0, 1.0)];
    for (other, expected_max, expected_min) in cases {
        assert_eq!(maximum(&one, &other).unwrap().item::<f32>(), expected_max);
        assert_eq!(minimum(&one, &other).unwrap().item::<f32>(), expected_min);
    }
}
2647
#[test]
fn test_binary_logaddexp() {
    // log(e^0 + e^0) = ln 2
    let (lhs, rhs) = (array![0.0], array![0.0]);
    assert_float_eq!(
        log_add_exp(&lhs, &rhs).unwrap().item::<f32>(),
        2.0f32.ln(),
        abs <= 1e-5
    );

    // exp(10000) would overflow f32, yet the result is exactly the larger operand.
    let small = array!([0u32]);
    let large = array!([10000u32]);
    assert_eq!(log_add_exp(&small, &large).unwrap().item::<f32>(), 10000.0);

    // Infinite operands: (lhs, rhs, expected)
    let inf = f32::INFINITY;
    let neg_inf = f32::NEG_INFINITY;
    let infinite_cases = [
        (inf, 3.0f32, inf),
        (neg_inf, 3.0, 3.0),
        (neg_inf, neg_inf, neg_inf),
        (inf, inf, inf),
        (neg_inf, inf, inf),
    ];
    for (a, b, expected) in infinite_cases {
        let combined = log_add_exp(&array![a], &array![b]).unwrap();
        assert_eq!(combined.item::<f32>(), expected);
    }
}
2685
#[test]
fn test_basic_clip() {
    let values = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
    let bounded = array!([2.0, 4.0, 3.0, 6.0, 5.0]);

    // Bounds given as arrays...
    assert_eq!(clip(&values, (array!(2.0), array!(6.0))).unwrap(), bounded);
    // ...or as plain scalars.
    assert_eq!(clip(&values, (2.0, 6.0)).unwrap(), bounded);
}
2697
#[test]
fn test_clip_with_only_min() {
    let values = array!([-1.0, 1.0, 0.0, 5.0]);
    let floored = array!([0.0, 1.0, 0.0, 5.0]);

    // Only a lower bound (`()` means "no upper bound"), as an array...
    assert_eq!(clip(&values, (array!(0.0), ())).unwrap(), floored);
    // ...or as a scalar.
    assert_eq!(clip(&values, (0.0, ())).unwrap(), floored);
}
2709
#[test]
fn test_clip_with_only_max() {
    let values = array!([2.0, 3.0, 4.0, 5.0]);
    let capped = array!([2.0, 3.0, 4.0, 4.0]);

    // Only an upper bound (`()` means "no lower bound"), as an array...
    assert_eq!(clip(&values, ((), array!(4.0))).unwrap(), capped);
    // ...or as a scalar.
    assert_eq!(clip(&values, ((), 4.0)).unwrap(), capped);
}
2721
#[test]
fn test_tensordot() {
    // Contract x's axes [1, 0] with y's axes [0, 1]: (3, 4, 5) . (4, 3, 2) -> (5, 2).
    let lhs = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[4, 3, 2]).unwrap();
    let contracted = tensordot(&lhs, &rhs, (&[1i32, 0], &[0i32, 1])).unwrap();
    let expected = Array::from_slice(
        &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306],
        &[5, 2],
    );
    assert_eq!(contracted, expected);

    // The paired axis sizes disagree (5 vs 4, 4 vs 5), so contraction fails.
    let lhs = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[3, 4, 5, 6]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[6, 4, 5, 3]).unwrap();
    assert!(tensordot(&lhs, &rhs, (&[2, 1, 3], &[1, 2, 0])).is_err());

    // An axis count of 2: (3, 4, 5) . (4, 5, 6) -> (3, 6).
    let lhs = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 120.0, None).unwrap(), &[4, 5, 6]).unwrap();
    let contracted = tensordot(&lhs, &rhs, 2).unwrap();
    let expected = Array::from_slice(
        &[
            14820.0, 15010.0, 15200.0, 15390.0, 15580.0, 15770.0, 37620.0, 38210.0, 38800.0,
            39390.0, 39980.0, 40570.0, 60420.0, 61410.0, 62400.0, 63390.0, 64380.0, 65370.0,
        ],
        &[3, 6],
    );
    assert_eq!(contracted, expected);
}
2750
#[test]
fn test_outer() {
    // [1, 2, 3, 4] (x) [1, 2, 3] -> 4x3 multiplication table.
    let rows = arange::<_, f32>(1.0, 5.0, None).unwrap();
    let cols = arange::<_, f32>(1.0, 4.0, None).unwrap();
    let table = outer(&rows, &cols).unwrap();
    let expected = Array::from_slice(
        &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0],
        &[4, 3],
    );
    assert_eq!(table, expected);

    // ones(5) (x) linspace(-2, 2, 5) repeats the linspace on every row.
    let rows = ones::<f32>(&[5]).unwrap();
    let cols = linspace::<_, f32>(-2.0, 2.0, 5).unwrap();
    let table = outer(&rows, &cols).unwrap();
    let expected = Array::from_slice(
        &[
            -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
            -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
        ],
        &[5, 5],
    );
    assert_eq!(table, expected);
}
2774
#[test]
fn test_inner() {
    // Last dimensions differ (5 vs 3): error.
    let lhs = reshape(arange::<_, f32>(None, 5.0, None).unwrap(), &[1, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[2, 3]).unwrap();
    assert!(inner(&lhs, &rhs).is_err());

    // Two vectors: plain dot product.
    let lhs = array!([1.0, 2.0, 3.0]);
    let rhs = array!([0.0, 1.0, 0.0]);
    assert_eq!(inner(&lhs, &rhs).unwrap().item::<f32>(), 2.0);

    // (2, 3, 4) . (4) -> (2, 3)
    let lhs = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[2, 3, 4]).unwrap();
    let rhs = arange::<_, f32>(None, 4.0, None).unwrap();
    let expected = Array::from_slice(&[14.0, 38.0, 62.0, 86.0, 110.0, 134.0], &[2, 3]);
    assert_eq!(inner(&lhs, &rhs).unwrap(), expected);

    // (1, 1, 2) . (3, 2) -> (1, 1, 3)
    let lhs = reshape(arange::<_, f32>(None, 2.0, None).unwrap(), &[1, 1, 2]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[3, 2]).unwrap();
    let expected = Array::from_slice(&[1.0, 3.0, 5.0], &[1, 1, 3]);
    assert_eq!(inner(&lhs, &rhs).unwrap(), expected);

    // Identity matrix with a scalar operand scales every entry.
    let identity = eye::<f32>(2, None, None).unwrap();
    let seven = Array::from_f32(7.0);
    let expected = Array::from_slice(&[7.0, 0.0, 0.0, 7.0], &[2, 2]);
    assert_eq!(inner(&identity, &seven).unwrap(), expected);
}
2804
#[test]
fn test_divmod() {
    // Floating-point inputs
    let x = array!([1.0, 2.0, 3.0]);
    let y = array!([1.0, 1.0, 1.0]);
    let out = divmod(&x, &y).unwrap();
    assert_eq!(out.0, array!([1.0, 2.0, 3.0]));
    assert_eq!(out.1, array!([0.0, 0.0, 0.0]));

    let x = array!([5.0, 6.0, 7.0]);
    let y = array!([2.0, 2.0, 2.0]);
    let out = divmod(&x, &y).unwrap();
    assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
    assert_eq!(out.1, array!([1.0, 0.0, 1.0]));

    // Integer (int32) inputs with the same values, covering a non-float dtype.
    let x = array!([5, 6, 7]);
    let y = array!([2, 2, 2]);
    let out = divmod(&x, &y).unwrap();
    assert_eq!(out.0, array!([2, 3, 3]));
    assert_eq!(out.1, array!([1, 0, 1]));

    // Complex inputs are rejected
    let x = array![complex64::new(1.0, 0.0)];
    let y = array![complex64::new(2.0, 0.0)];
    assert!(divmod(&x, &y).is_err());

    // Check that we can eval on both outputs
    let x = array![1.0];
    let y = array![2.0];
    let (quo, rem) = divmod(&x, &y).unwrap();
    eval([&quo, &rem]).unwrap();
    assert_eq!(quo.item::<f32>(), 0.0);
    assert_eq!(rem.item::<f32>(), 1.0);

    // Check nested in the graph
    let x = array![1.0];
    let y = array![2.0];
    let (quo, rem) = divmod(&x, &y).unwrap();
    let z = quo + rem;
    assert_eq!(z.item::<f32>(), 1.0);

    // Check that we can still eval when one output goes out of scope
    let mut out_holder = {
        let (quo, _) = divmod(&x, &y).unwrap();
        vec![quo]
    };
    eval(out_holder.iter()).unwrap();
    assert_eq!(out_holder[0].item::<f32>(), 0.0);

    // Check that we can still eval when the other output goes out of scope
    out_holder.clear();
    let out_holder = {
        let (_, rem) = divmod(&x, &y).unwrap();
        vec![rem]
    };
    eval(out_holder.iter()).unwrap();
    assert_eq!(out_holder[0].item::<f32>(), 1.0);
}
2861}