mlx_rs/ops/
logical.rs

1use crate::array::Array;
2use crate::error::Result;
3use crate::stream::StreamOrDevice;
4use crate::utils::guard::Guarded;
5use crate::utils::{axes_or_default_to_all, IntoOption};
6use crate::Stream;
7use mlx_internal_macros::{default_device, generate_macro};
8
9impl Array {
10    /// Element-wise equality returning an error if the arrays are not broadcastable.
11    ///
12    /// Equality comparison on two arrays with
13    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
14    ///
15    /// # Params
16    ///
17    /// - other: array to compare
18    ///
19    /// # Example
20    ///
21    /// ```rust
22    /// use mlx_rs::Array;
23    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
24    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
25    /// let mut c = a.eq(&b).unwrap();
26    ///
27    /// let c_data: &[bool] = c.as_slice();
28    /// // c_data == [true, true, true]
29    /// ```
30    #[default_device]
31    pub fn eq_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
32        Array::try_from_op(|res| unsafe {
33            mlx_sys::mlx_equal(
34                res,
35                self.as_ptr(),
36                other.as_ref().as_ptr(),
37                stream.as_ref().as_ptr(),
38            )
39        })
40    }
41
42    /// Element-wise less than or equal returning an error if the arrays are not broadcastable.
43    ///
44    /// Less than or equal on two arrays with
45    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
46    ///
47    /// # Params
48    ///
49    /// - other: array to compare
50    ///
51    /// # Example
52    ///
53    /// ```rust
54    /// use mlx_rs::Array;
55    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
56    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
57    /// let mut c = a.le(&b).unwrap();
58    ///
59    /// let c_data: &[bool] = c.as_slice();
60    /// // c_data == [true, true, true]
61    /// ```
62    #[default_device]
63    pub fn le_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
64        Array::try_from_op(|res| unsafe {
65            mlx_sys::mlx_less_equal(
66                res,
67                self.as_ptr(),
68                other.as_ref().as_ptr(),
69                stream.as_ref().as_ptr(),
70            )
71        })
72    }
73
74    /// Element-wise greater than or equal returning an error if the arrays are not broadcastable.
75    ///
76    /// Greater than or equal on two arrays with
77    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
78    ///
79    /// # Params
80    ///
81    /// - other: array to compare
82    ///
83    /// # Example
84    ///
85    /// ```rust
86    /// use mlx_rs::Array;
87    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
88    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
89    /// let mut c = a.ge(&b).unwrap();
90    ///
91    /// let c_data: &[bool] = c.as_slice();
92    /// // c_data == [true, true, true]
93    /// ```
94    #[default_device]
95    pub fn ge_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
96        Array::try_from_op(|res| unsafe {
97            mlx_sys::mlx_greater_equal(
98                res,
99                self.as_ptr(),
100                other.as_ref().as_ptr(),
101                stream.as_ref().as_ptr(),
102            )
103        })
104    }
105
106    /// Element-wise not equal returning an error if the arrays are not broadcastable.
107    ///
108    /// Not equal on two arrays with
109    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
110    ///
111    /// # Params
112    ///
113    /// - other: array to compare
114    ///
115    /// # Example
116    ///
117    /// ```rust
118    /// use mlx_rs::Array;
119    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
120    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
121    /// let mut c = a.ne(&b).unwrap();
122    ///
123    /// let c_data: &[bool] = c.as_slice();
124    /// // c_data == [false, false, false]
125    /// ```
126    #[default_device]
127    pub fn ne_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
128        Array::try_from_op(|res| unsafe {
129            mlx_sys::mlx_not_equal(
130                res,
131                self.as_ptr(),
132                other.as_ref().as_ptr(),
133                stream.as_ref().as_ptr(),
134            )
135        })
136    }
137
138    /// Element-wise less than returning an error if the arrays are not broadcastable.
139    ///
140    /// Less than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
141    ///
142    /// # Params
143    ///
144    /// - other: array to compare
145    ///
146    /// # Example
147    ///
148    /// ```rust
149    /// use mlx_rs::Array;
150    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
151    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
152    /// let mut c = a.lt(&b).unwrap();
153    ///
154    /// let c_data: &[bool] = c.as_slice();
155    /// // c_data == [false, false, false]
156    /// ```
157    #[default_device]
158    pub fn lt_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
159        Array::try_from_op(|res| unsafe {
160            mlx_sys::mlx_less(
161                res,
162                self.as_ptr(),
163                other.as_ref().as_ptr(),
164                stream.as_ref().as_ptr(),
165            )
166        })
167    }
168
169    /// Element-wise greater than returning an error if the arrays are not broadcastable.
170    ///
171    /// Greater than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
172    ///
173    /// # Params
174    ///
175    /// - other: array to compare
176    ///
177    /// # Example
178    ///
179    /// ```rust
180    /// use mlx_rs::Array;
181    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
182    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
183    /// let mut c = a.gt(&b).unwrap();
184    ///
185    /// let c_data: &[bool] = c.as_slice();
186    /// // c_data == [false, false, false]
187    /// ```
188    #[default_device]
189    pub fn gt_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
190        Array::try_from_op(|res| unsafe {
191            mlx_sys::mlx_greater(
192                res,
193                self.as_ptr(),
194                other.as_ref().as_ptr(),
195                stream.as_ref().as_ptr(),
196            )
197        })
198    }
199
200    /// Element-wise logical and returning an error if the arrays are not broadcastable.
201    ///
202    /// Logical and on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
203    ///
204    /// # Params
205    ///
206    /// - other: array to compare
207    ///
208    /// # Example
209    ///
210    /// ```rust
211    /// use mlx_rs::Array;
212    /// let a = Array::from_slice(&[true, false, true], &[3]);
213    /// let b = Array::from_slice(&[true, true, false], &[3]);
214    /// let mut c = a.logical_and(&b).unwrap();
215    ///
216    /// let c_data: &[bool] = c.as_slice();
217    /// // c_data == [true, false, false]
218    /// ```
219    #[default_device]
220    pub fn logical_and_device(
221        &self,
222        other: impl AsRef<Array>,
223        stream: impl AsRef<Stream>,
224    ) -> Result<Array> {
225        Array::try_from_op(|res| unsafe {
226            mlx_sys::mlx_logical_and(
227                res,
228                self.as_ptr(),
229                other.as_ref().as_ptr(),
230                stream.as_ref().as_ptr(),
231            )
232        })
233    }
234
235    /// Element-wise logical or returning an error if the arrays are not broadcastable.
236    ///
237    /// Logical or on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
238    ///
239    /// # Params
240    ///
241    /// - other: array to compare
242    ///
243    /// # Example
244    ///
245    /// ```rust
246    /// use mlx_rs::Array;
247    /// let a = Array::from_slice(&[true, false, true], &[3]);
248    /// let b = Array::from_slice(&[true, true, false], &[3]);
249    /// let mut c = a.logical_or(&b).unwrap();
250    ///
251    /// let c_data: &[bool] = c.as_slice();
252    /// // c_data == [true, true, true]
253    /// ```
254    #[default_device]
255    pub fn logical_or_device(
256        &self,
257        other: impl AsRef<Array>,
258        stream: impl AsRef<Stream>,
259    ) -> Result<Array> {
260        Array::try_from_op(|res| unsafe {
261            mlx_sys::mlx_logical_or(
262                res,
263                self.as_ptr(),
264                other.as_ref().as_ptr(),
265                stream.as_ref().as_ptr(),
266            )
267        })
268    }
269
270    /// Unary element-wise logical not.
271    ///
272    /// # Example
273    ///
274    /// ```rust
275    /// use mlx_rs::{Array, StreamOrDevice};
276    /// let a: Array = false.into();
277    /// let mut b = a.logical_not_device(StreamOrDevice::default()).unwrap();
278    ///
279    /// let b_data: &[bool] = b.as_slice();
280    /// // b_data == [true]
281    /// ```
282    #[default_device]
283    pub fn logical_not_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
284        Array::try_from_op(|res| unsafe {
285            mlx_sys::mlx_logical_not(res, self.as_ptr(), stream.as_ref().as_ptr())
286        })
287    }
288
289    /// Approximate comparison of two arrays returning an error if the inputs aren't valid.
290    ///
291    /// The arrays are considered equal if:
292    ///
293    /// ```text
294    /// all(abs(a - b) <= (atol + rtol * abs(b)))
295    /// ```
296    ///
297    /// # Params
298    ///
299    /// - other: array to compare
300    /// - rtol: relative tolerance = defaults to 1e-5 when None
301    /// - atol: absolute tolerance - defaults to 1e-8 when None
302    /// - equal_nan: whether to consider NaNs equal -- default is false when None
303    ///
304    /// # Example
305    ///
306    /// ```rust
307    /// use num_traits::Pow;
308    /// use mlx_rs::array;
309    /// let a = array!([0., 1., 2., 3.]).sqrt().unwrap();
310    /// let b = array!([0., 1., 2., 3.]).power(array!(0.5)).unwrap();
311    /// let mut c = a.all_close(&b, None, None, None).unwrap();
312    ///
313    /// let c_data: &[bool] = c.as_slice();
314    /// // c_data == [true]
315    /// ```
316    #[default_device]
317    pub fn all_close_device(
318        &self,
319        other: impl AsRef<Array>,
320        rtol: impl Into<Option<f64>>,
321        atol: impl Into<Option<f64>>,
322        equal_nan: impl Into<Option<bool>>,
323        stream: impl AsRef<Stream>,
324    ) -> Result<Array> {
325        Array::try_from_op(|res| unsafe {
326            mlx_sys::mlx_allclose(
327                res,
328                self.as_ptr(),
329                other.as_ref().as_ptr(),
330                rtol.into().unwrap_or(1e-5),
331                atol.into().unwrap_or(1e-8),
332                equal_nan.into().unwrap_or(false),
333                stream.as_ref().as_ptr(),
334            )
335        })
336    }
337
338    /// Returns a boolean array where two arrays are element-wise equal within a tolerance returning an error if the arrays are not broadcastable.
339    ///
340    /// Infinite values are considered equal if they have the same sign, NaN values are not equal unless
341    /// `equalNAN` is `true`.
342    ///
343    /// Two values are considered close if:
344    ///
345    /// ```text
346    /// abs(a - b) <= (atol + rtol * abs(b))
347    /// ```
348    ///
349    /// Unlike [self.array_eq] this function supports [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
350    #[default_device]
351    pub fn is_close_device(
352        &self,
353        other: impl AsRef<Array>,
354        rtol: impl Into<Option<f64>>,
355        atol: impl Into<Option<f64>>,
356        equal_nan: impl Into<Option<bool>>,
357        stream: impl AsRef<Stream>,
358    ) -> Result<Array> {
359        Array::try_from_op(|res| unsafe {
360            mlx_sys::mlx_isclose(
361                res,
362                self.as_ptr(),
363                other.as_ref().as_ptr(),
364                rtol.into().unwrap_or(1e-5),
365                atol.into().unwrap_or(1e-8),
366                equal_nan.into().unwrap_or(false),
367                stream.as_ref().as_ptr(),
368            )
369        })
370    }
371
372    /// Array equality check.
373    ///
374    /// Compare two arrays for equality. Returns `true` iff the arrays have
375    /// the same shape and their values are equal. The arrays need not have
376    /// the same type to be considered equal.
377    ///
378    /// # Params
379    ///
380    /// - other: array to compare
381    /// - equal_nan: whether to consider NaNs equal -- default is false when None
382    ///
383    /// # Example
384    ///
385    /// ```rust
386    /// use mlx_rs::Array;
387    /// let a = Array::from_slice(&[0, 1, 2, 3], &[4]);
388    /// let b = Array::from_slice(&[0., 1., 2., 3.], &[4]);
389    ///
390    /// let c = a.array_eq(&b, None);
391    /// // c == [true]
392    /// ```
393    #[default_device]
394    pub fn array_eq_device(
395        &self,
396        other: impl AsRef<Array>,
397        equal_nan: impl Into<Option<bool>>,
398        stream: impl AsRef<Stream>,
399    ) -> Result<Array> {
400        Array::try_from_op(|res| unsafe {
401            mlx_sys::mlx_array_equal(
402                res,
403                self.as_ptr(),
404                other.as_ref().as_ptr(),
405                equal_nan.into().unwrap_or(false),
406                stream.as_ref().as_ptr(),
407            )
408        })
409    }
410
411    /// An `or` reduction over the given axes returning an error if the axes are invalid.
412    ///
413    /// # Params
414    ///
415    /// - axes: axes to reduce over -- defaults to all axes if not provided
416    /// - keep_dims: if `true` keep reduced axis as singleton dimension -- defaults to false if not provided
417    ///
418    ///  # Example
419    ///
420    /// ```rust
421    /// use mlx_rs::Array;
422    ///
423    /// let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
424    ///
425    /// // will produce a scalar Array with true -- some of the values are non-zero
426    /// let all = array.any(None, None).unwrap();
427    ///
428    /// // produces an Array([true, true, true, true]) -- all rows have non-zeros
429    /// let all_rows = array.any(&[0], None).unwrap();
430    /// ```
431    #[default_device]
432    pub fn any_device<'a>(
433        &self,
434        axes: impl IntoOption<&'a [i32]>,
435        keep_dims: impl Into<Option<bool>>,
436        stream: impl AsRef<Stream>,
437    ) -> Result<Array> {
438        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
439
440        Array::try_from_op(|res| unsafe {
441            mlx_sys::mlx_any(
442                res,
443                self.as_ptr(),
444                axes.as_ptr(),
445                axes.len(),
446                keep_dims.into().unwrap_or(false),
447                stream.as_ref().as_ptr(),
448            )
449        })
450    }
451}
452
453/// See [`Array::any`]
454#[generate_macro]
455#[default_device]
456pub fn any_device<'a>(
457    array: impl AsRef<Array>,
458    #[optional] axes: impl IntoOption<&'a [i32]>,
459    #[optional] keep_dims: impl Into<Option<bool>>,
460    #[optional] stream: impl AsRef<Stream>,
461) -> Result<Array> {
462    array.as_ref().any_device(axes, keep_dims, stream)
463}
464
465/// See [`Array::logical_and`]
466#[generate_macro]
467#[default_device]
468pub fn logical_and_device(
469    a: impl AsRef<Array>,
470    b: impl AsRef<Array>,
471    #[optional] stream: impl AsRef<Stream>,
472) -> Result<Array> {
473    a.as_ref().logical_and_device(b, stream)
474}
475
476/// See [`Array::logical_or`]
477#[generate_macro]
478#[default_device]
479pub fn logical_or_device(
480    a: impl AsRef<Array>,
481    b: impl AsRef<Array>,
482    #[optional] stream: impl AsRef<Stream>,
483) -> Result<Array> {
484    a.as_ref().logical_or_device(b, stream)
485}
486
487/// See [`Array::logical_not`]
488#[generate_macro]
489#[default_device]
490pub fn logical_not_device(
491    a: impl AsRef<Array>,
492    #[optional] stream: impl AsRef<Stream>,
493) -> Result<Array> {
494    a.as_ref().logical_not_device(stream)
495}
496
497/// See [`Array::all_close`]
498#[generate_macro]
499#[default_device]
500pub fn all_close_device(
501    a: impl AsRef<Array>,
502    b: impl AsRef<Array>,
503    #[optional] rtol: impl Into<Option<f64>>,
504    #[optional] atol: impl Into<Option<f64>>,
505    #[optional] equal_nan: impl Into<Option<bool>>,
506    #[optional] stream: impl AsRef<Stream>,
507) -> Result<Array> {
508    a.as_ref()
509        .all_close_device(b, rtol, atol, equal_nan, stream)
510}
511
512/// See [`Array::is_close`]
513#[generate_macro]
514#[default_device]
515pub fn is_close_device(
516    a: impl AsRef<Array>,
517    b: impl AsRef<Array>,
518    #[optional] rtol: impl Into<Option<f64>>,
519    #[optional] atol: impl Into<Option<f64>>,
520    #[optional] equal_nan: impl Into<Option<bool>>,
521    #[optional] stream: impl AsRef<Stream>,
522) -> Result<Array> {
523    a.as_ref().is_close_device(b, rtol, atol, equal_nan, stream)
524}
525
526/// See [`Array::array_eq`]
527#[generate_macro]
528#[default_device]
529pub fn array_eq_device(
530    a: impl AsRef<Array>,
531    b: impl AsRef<Array>,
532    #[optional] equal_nan: impl Into<Option<bool>>,
533    #[optional] stream: impl AsRef<Stream>,
534) -> Result<Array> {
535    a.as_ref().array_eq_device(b, equal_nan, stream)
536}
537
538/// See [`Array::eq`]
539#[generate_macro]
540#[default_device]
541pub fn eq_device(
542    a: impl AsRef<Array>,
543    b: impl AsRef<Array>,
544    #[optional] stream: impl AsRef<Stream>,
545) -> Result<Array> {
546    a.as_ref().eq_device(b, stream)
547}
548
549/// See [`Array::le`]
550#[generate_macro]
551#[default_device]
552pub fn le_device(
553    a: impl AsRef<Array>,
554    b: impl AsRef<Array>,
555    #[optional] stream: impl AsRef<Stream>,
556) -> Result<Array> {
557    a.as_ref().le_device(b, stream)
558}
559
560/// See [`Array::ge`]
561#[generate_macro]
562#[default_device]
563pub fn ge_device(
564    a: impl AsRef<Array>,
565    b: impl AsRef<Array>,
566    #[optional] stream: impl AsRef<Stream>,
567) -> Result<Array> {
568    a.as_ref().ge_device(b, stream)
569}
570
571/// See [`Array::ne`]
572#[generate_macro]
573#[default_device]
574pub fn ne_device(
575    a: impl AsRef<Array>,
576    b: impl AsRef<Array>,
577    #[optional] stream: impl AsRef<Stream>,
578) -> Result<Array> {
579    a.as_ref().ne_device(b, stream)
580}
581
582/// See [`Array::lt`]
583#[generate_macro]
584#[default_device]
585pub fn lt_device(
586    a: impl AsRef<Array>,
587    b: impl AsRef<Array>,
588    #[optional] stream: impl AsRef<Stream>,
589) -> Result<Array> {
590    a.as_ref().lt_device(b, stream)
591}
592
593/// See [`Array::gt`]
594#[generate_macro]
595#[default_device]
596pub fn gt_device(
597    a: impl AsRef<Array>,
598    b: impl AsRef<Array>,
599    #[optional] stream: impl AsRef<Stream>,
600) -> Result<Array> {
601    a.as_ref().gt_device(b, stream)
602}
603
604// TODO: check if the functions below could throw an exception.
605
606/// Return a boolean array indicating which elements are NaN.
607#[generate_macro]
608#[default_device]
609pub fn is_nan_device(
610    array: impl AsRef<Array>,
611    #[optional] stream: impl AsRef<Stream>,
612) -> Result<Array> {
613    Array::try_from_op(|res| unsafe {
614        mlx_sys::mlx_isnan(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
615    })
616}
617
618/// Return a boolean array indicating which elements are +/- inifnity.
619#[generate_macro]
620#[default_device]
621pub fn is_inf_device(
622    array: impl AsRef<Array>,
623    #[optional] stream: impl AsRef<Stream>,
624) -> Result<Array> {
625    Array::try_from_op(|res| unsafe {
626        mlx_sys::mlx_isinf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
627    })
628}
629
630/// Return a boolean array indicating which elements are positive infinity.
631#[generate_macro]
632#[default_device]
633pub fn is_pos_inf_device(
634    array: impl AsRef<Array>,
635    #[optional] stream: impl AsRef<Stream>,
636) -> Result<Array> {
637    Array::try_from_op(|res| unsafe {
638        mlx_sys::mlx_isposinf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
639    })
640}
641
642/// Return a boolean array indicating which elements are negative infinity.
643#[generate_macro]
644#[default_device]
645pub fn is_neg_inf_device(
646    array: impl AsRef<Array>,
647    #[optional] stream: impl AsRef<Stream>,
648) -> Result<Array> {
649    Array::try_from_op(|res| unsafe {
650        mlx_sys::mlx_isneginf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
651    })
652}
653
654/// Select from `a` or `b` according to `condition` returning an error if the arrays are not
655/// broadcastable.
656///
657/// The condition and input arrays must be the same shape or
658/// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting)
659/// with each another.
660///
661/// # Params
662///
663/// - condition: condition array
664/// - a: input selected from where condition is non-zero or `true`
665/// - b: input selected from where condition is zero or `false`
666#[default_device]
667pub fn r#where_device(
668    condition: impl AsRef<Array>,
669    a: impl AsRef<Array>,
670    b: impl AsRef<Array>,
671    stream: impl AsRef<Stream>,
672) -> Result<Array> {
673    Array::try_from_op(|res| unsafe {
674        mlx_sys::mlx_where(
675            res,
676            condition.as_ref().as_ptr(),
677            a.as_ref().as_ptr(),
678            b.as_ref().as_ptr(),
679            stream.as_ref().as_ptr(),
680        )
681    })
682}
683
684/// Alias for [`r#where`]
685#[generate_macro]
686#[default_device]
687pub fn which_device(
688    condition: impl AsRef<Array>,
689    a: impl AsRef<Array>,
690    b: impl AsRef<Array>,
691    #[optional] stream: impl AsRef<Stream>,
692) -> Result<Array> {
693    r#where_device(condition, a, b, stream)
694}
695
#[cfg(test)]
mod tests {
    use crate::{array, Dtype};

    use super::*;

    #[test]
    fn test_eq() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3], &[3]);
        let c = a.eq(&b).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true, true, true]);

        // check a and b are not modified
        let a_data: &[i32] = a.as_slice();
        assert_eq!(a_data, [1, 2, 3]);

        let b_data: &[i32] = b.as_slice();
        assert_eq!(b_data, [1, 2, 3]);
    }

    #[test]
    fn test_eq_invalid_broadcast() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
        let c = a.eq(&b);
        assert!(c.is_err());
    }

    #[test]
    fn test_le() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3], &[3]);
        let c = a.le(&b).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true, true, true]);

        // check a and b are not modified
        let a_data: &[i32] = a.as_slice();
        assert_eq!(a_data, [1, 2, 3]);

        let b_data: &[i32] = b.as_slice();
        assert_eq!(b_data, [1, 2, 3]);
    }

    #[test]
    fn test_le_invalid_broadcast() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
        let c = a.le(&b);
        assert!(c.is_err());
    }

    #[test]
    fn test_ge() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3], &[3]);
        let c = a.ge(&b).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true, true, true]);

        // check a and b are not modified
        let a_data: &[i32] = a.as_slice();
        assert_eq!(a_data, [1, 2, 3]);

        let b_data: &[i32] = b.as_slice();
        assert_eq!(b_data, [1, 2, 3]);
    }

    #[test]
    fn test_ge_invalid_broadcast() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
        let c = a.ge(&b);
        assert!(c.is_err());
    }

    #[test]
    fn test_ne() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3], &[3]);
        let c = a.ne(&b).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [false, false, false]);

        // check a and b are not modified
        let a_data: &[i32] = a.as_slice();
        assert_eq!(a_data, [1, 2, 3]);

        let b_data: &[i32] = b.as_slice();
        assert_eq!(b_data, [1, 2, 3]);
    }

    #[test]
    fn test_ne_invalid_broadcast() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
        let c = a.ne(&b);
        assert!(c.is_err());
    }

    #[test]
    fn test_lt() {
        let a = Array::from_slice(&[1, 0, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3], &[3]);
        let c = a.lt(&b).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [false, true, false]);

        // check a and b are not modified
        let a_data: &[i32] = a.as_slice();
        assert_eq!(a_data, [1, 0, 3]);

        let b_data: &[i32] = b.as_slice();
        assert_eq!(b_data, [1, 2, 3]);
    }

    #[test]
    fn test_lt_invalid_broadcast() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
        let c = a.lt(&b);
        assert!(c.is_err());
    }

    #[test]
    fn test_gt() {
        let a = Array::from_slice(&[1, 4, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3], &[3]);
        let c = a.gt(&b).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [false, true, false]);

        // check a and b are not modified
        let a_data: &[i32] = a.as_slice();
        assert_eq!(a_data, [1, 4, 3]);

        let b_data: &[i32] = b.as_slice();
        assert_eq!(b_data, [1, 2, 3]);
    }

    #[test]
    fn test_gt_invalid_broadcast() {
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
        let c = a.gt(&b);
        assert!(c.is_err());
    }

    #[test]
    fn test_logical_and() {
        let a = Array::from_slice(&[true, false, true], &[3]);
        let b = Array::from_slice(&[true, true, false], &[3]);
        let c = a.logical_and(&b).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true, false, false]);

        // check a and b are not modified
        let a_data: &[bool] = a.as_slice();
        assert_eq!(a_data, [true, false, true]);

        let b_data: &[bool] = b.as_slice();
        assert_eq!(b_data, [true, true, false]);
    }

    #[test]
    fn test_logical_and_invalid_broadcast() {
        let a = Array::from_slice(&[true, false, true], &[3]);
        let b = Array::from_slice(&[true, true, false, true], &[4]);
        let c = a.logical_and(&b);
        assert!(c.is_err());
    }

    #[test]
    fn test_logical_or() {
        let a = Array::from_slice(&[true, false, true], &[3]);
        let b = Array::from_slice(&[true, true, false], &[3]);
        let c = a.logical_or(&b).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true, true, true]);

        // check a and b are not modified
        let a_data: &[bool] = a.as_slice();
        assert_eq!(a_data, [true, false, true]);

        let b_data: &[bool] = b.as_slice();
        assert_eq!(b_data, [true, true, false]);
    }

    #[test]
    fn test_logical_or_invalid_broadcast() {
        let a = Array::from_slice(&[true, false, true], &[3]);
        let b = Array::from_slice(&[true, true, false, true], &[4]);
        let c = a.logical_or(&b);
        assert!(c.is_err());
    }

    #[test]
    fn test_all_close() {
        let a = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt().unwrap();
        let b = Array::from_slice(&[0., 1., 2., 3.], &[4])
            .power(array!(0.5))
            .unwrap();
        let c = a.all_close(&b, 1e-5, None, None).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true]);
    }

    #[test]
    fn test_all_close_invalid_broadcast() {
        let a = Array::from_slice(&[0., 1., 2., 3.], &[4]);
        let b = Array::from_slice(&[0., 1., 2., 3., 4.], &[5]);
        let c = a.all_close(&b, 1e-5, None, None);
        assert!(c.is_err());
    }

    #[test]
    fn test_is_close_false() {
        let a = Array::from_slice(&[1., 2., 3.], &[3]);
        let b = Array::from_slice(&[1.1, 2.2, 3.3], &[3]);
        let c = a.is_close(&b, None, None, false).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [false, false, false]);
    }

    #[test]
    fn test_is_close_true() {
        let a = Array::from_slice(&[1., 2., 3.], &[3]);
        let b = Array::from_slice(&[1.1, 2.2, 3.3], &[3]);
        let c = a.is_close(&b, 0.1, 0.2, true).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true, true, true]);
    }

    #[test]
    fn test_is_close_invalid_broadcast() {
        let a = Array::from_slice(&[1., 2., 3.], &[3]);
        let b = Array::from_slice(&[1.1, 2.2, 3.3, 4.4], &[4]);
        let c = a.is_close(&b, None, None, false);
        assert!(c.is_err());
    }

    #[test]
    fn test_array_eq() {
        let a = Array::from_slice(&[0, 1, 2, 3], &[4]);
        let b = Array::from_slice(&[0., 1., 2., 3.], &[4]);
        let c = a.array_eq(&b, None).unwrap();

        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true]);
    }

    #[test]
    fn test_array_eq_shape_mismatch() {
        // same values but different shapes: not equal, and not an error
        let a = Array::from_slice(&[0, 1, 2, 3], &[4]);
        let b = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let c = a.array_eq(&b, None).unwrap();
        assert!(!c.item::<bool>());
    }

    #[test]
    fn test_array_eq_equal_nan() {
        let a = Array::from_slice(&[f32::NAN, 1.0], &[2]);
        let b = Array::from_slice(&[f32::NAN, 1.0], &[2]);
        assert!(!a.array_eq(&b, false).unwrap().item::<bool>());
        assert!(a.array_eq(&b, true).unwrap().item::<bool>());
    }

    #[test]
    fn test_any() {
        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
        let all = array.any(&[0][..], None).unwrap();

        let results: &[bool] = all.as_slice();
        assert_eq!(results, &[true, true, true, true]);
    }

    #[test]
    fn test_any_empty_axes() {
        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
        let all = array.any(&[][..], None).unwrap();

        let results: &[bool] = all.as_slice();
        assert_eq!(
            results,
            &[false, true, true, true, true, true, true, true, true, true, true, true]
        );
    }

    #[test]
    fn test_any_out_of_bounds() {
        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[12]);
        let result = array.any(&[1][..], None);
        assert!(result.is_err());
    }

    #[test]
    fn test_any_duplicate_axes() {
        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
        let result = array.any(&[0, 0][..], None);
        assert!(result.is_err());
    }

    #[test]
    fn test_which() {
        let condition = Array::from_slice(&[true, false, true], &[3]);
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[4, 5, 6], &[3]);
        let c = which(&condition, &a, &b).unwrap();

        let c_data: &[i32] = c.as_slice();
        assert_eq!(c_data, [1, 5, 3]);
    }

    #[test]
    fn test_which_invalid_broadcast() {
        let condition = Array::from_slice(&[true, false, true], &[3]);
        let a = Array::from_slice(&[1, 2, 3], &[3]);
        let b = Array::from_slice(&[4, 5, 6, 7], &[4]);
        let c = which(&condition, &a, &b);
        assert!(c.is_err());
    }

    #[test]
    fn test_is_nan_and_is_inf() {
        let x = Array::from_slice(&[f32::NAN, 1.0, f32::INFINITY, f32::NEG_INFINITY], &[4]);

        let nan = is_nan(&x).unwrap();
        let nan_data: &[bool] = nan.as_slice();
        assert_eq!(nan_data, [true, false, false, false]);

        let inf = is_inf(&x).unwrap();
        let inf_data: &[bool] = inf.as_slice();
        assert_eq!(inf_data, [false, false, true, true]);

        let pos_inf = is_pos_inf(&x).unwrap();
        let pos_inf_data: &[bool] = pos_inf.as_slice();
        assert_eq!(pos_inf_data, [false, false, true, false]);

        let neg_inf = is_neg_inf(&x).unwrap();
        let neg_inf_data: &[bool] = neg_inf.as_slice();
        assert_eq!(neg_inf_data, [false, false, false, true]);
    }

    // The unit tests below are adapted from the mlx c++ codebase

    #[test]
    fn test_unary_logical_not() {
        let x = array!(false);
        assert!(logical_not(&x).unwrap().item::<bool>());

        let x = array!(1.0);
        let y = logical_not(&x).unwrap();
        assert_eq!(y.dtype(), Dtype::Bool);
        assert!(!y.item::<bool>());

        let x = array!(0);
        let y = logical_not(&x).unwrap();
        assert_eq!(y.dtype(), Dtype::Bool);
        assert!(y.item::<bool>());
    }

    #[test]
    fn test_logical_and_scalars() {
        let x = array!(true);
        let y = array!(true);
        assert!(logical_and(&x, &y).unwrap().item::<bool>());

        let x = array!(1.0);
        let y = array!(1.0);
        let z = logical_and(&x, &y).unwrap();
        assert_eq!(z.dtype(), Dtype::Bool);
        assert!(z.item::<bool>());

        let x = array!(0);
        let y = array!(1.0);
        let z = logical_and(&x, &y).unwrap();
        assert_eq!(z.dtype(), Dtype::Bool);
        assert!(!z.item::<bool>());
    }

    #[test]
    fn test_logical_or_scalars() {
        let a = array!(false);
        let b = array!(false);
        assert!(!logical_or(&a, &b).unwrap().item::<bool>());

        let a = array!(1.0);
        let b = array!(1.0);
        let c = logical_or(&a, &b).unwrap();
        assert_eq!(c.dtype(), Dtype::Bool);
        assert!(c.item::<bool>());

        let a = array!(0);
        let b = array!(1.0);
        let c = logical_or(&a, &b).unwrap();
        assert_eq!(c.dtype(), Dtype::Bool);
        assert!(c.item::<bool>());
    }
}