mlx_rs/ops/
logical.rs

1use crate::array::Array;
2use crate::error::Result;
3use crate::utils::guard::Guarded;
4use crate::Stream;
5use mlx_internal_macros::{default_device, generate_macro};
6
7impl Array {
8    /// Element-wise equality returning an error if the arrays are not broadcastable.
9    ///
10    /// Equality comparison on two arrays with
11    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
12    ///
13    /// # Params
14    ///
15    /// - other: array to compare
16    ///
17    /// # Example
18    ///
19    /// ```rust
20    /// use mlx_rs::Array;
21    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
22    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
23    /// let mut c = a.eq(&b).unwrap();
24    ///
25    /// let c_data: &[bool] = c.as_slice();
26    /// // c_data == [true, true, true]
27    /// ```
28    #[default_device]
29    pub fn eq_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
30        Array::try_from_op(|res| unsafe {
31            mlx_sys::mlx_equal(
32                res,
33                self.as_ptr(),
34                other.as_ref().as_ptr(),
35                stream.as_ref().as_ptr(),
36            )
37        })
38    }
39
40    /// Element-wise less than or equal returning an error if the arrays are not broadcastable.
41    ///
42    /// Less than or equal on two arrays with
43    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
44    ///
45    /// # Params
46    ///
47    /// - other: array to compare
48    ///
49    /// # Example
50    ///
51    /// ```rust
52    /// use mlx_rs::Array;
53    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
54    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
55    /// let mut c = a.le(&b).unwrap();
56    ///
57    /// let c_data: &[bool] = c.as_slice();
58    /// // c_data == [true, true, true]
59    /// ```
60    #[default_device]
61    pub fn le_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
62        Array::try_from_op(|res| unsafe {
63            mlx_sys::mlx_less_equal(
64                res,
65                self.as_ptr(),
66                other.as_ref().as_ptr(),
67                stream.as_ref().as_ptr(),
68            )
69        })
70    }
71
72    /// Element-wise greater than or equal returning an error if the arrays are not broadcastable.
73    ///
74    /// Greater than or equal on two arrays with
75    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
76    ///
77    /// # Params
78    ///
79    /// - other: array to compare
80    ///
81    /// # Example
82    ///
83    /// ```rust
84    /// use mlx_rs::Array;
85    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
86    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
87    /// let mut c = a.ge(&b).unwrap();
88    ///
89    /// let c_data: &[bool] = c.as_slice();
90    /// // c_data == [true, true, true]
91    /// ```
92    #[default_device]
93    pub fn ge_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
94        Array::try_from_op(|res| unsafe {
95            mlx_sys::mlx_greater_equal(
96                res,
97                self.as_ptr(),
98                other.as_ref().as_ptr(),
99                stream.as_ref().as_ptr(),
100            )
101        })
102    }
103
104    /// Element-wise not equal returning an error if the arrays are not broadcastable.
105    ///
106    /// Not equal on two arrays with
107    /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
108    ///
109    /// # Params
110    ///
111    /// - other: array to compare
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use mlx_rs::Array;
117    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
118    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
119    /// let mut c = a.ne(&b).unwrap();
120    ///
121    /// let c_data: &[bool] = c.as_slice();
122    /// // c_data == [false, false, false]
123    /// ```
124    #[default_device]
125    pub fn ne_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
126        Array::try_from_op(|res| unsafe {
127            mlx_sys::mlx_not_equal(
128                res,
129                self.as_ptr(),
130                other.as_ref().as_ptr(),
131                stream.as_ref().as_ptr(),
132            )
133        })
134    }
135
136    /// Element-wise less than returning an error if the arrays are not broadcastable.
137    ///
138    /// Less than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
139    ///
140    /// # Params
141    ///
142    /// - other: array to compare
143    ///
144    /// # Example
145    ///
146    /// ```rust
147    /// use mlx_rs::Array;
148    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
149    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
150    /// let mut c = a.lt(&b).unwrap();
151    ///
152    /// let c_data: &[bool] = c.as_slice();
153    /// // c_data == [false, false, false]
154    /// ```
155    #[default_device]
156    pub fn lt_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
157        Array::try_from_op(|res| unsafe {
158            mlx_sys::mlx_less(
159                res,
160                self.as_ptr(),
161                other.as_ref().as_ptr(),
162                stream.as_ref().as_ptr(),
163            )
164        })
165    }
166
167    /// Element-wise greater than returning an error if the arrays are not broadcastable.
168    ///
169    /// Greater than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
170    ///
171    /// # Params
172    ///
173    /// - other: array to compare
174    ///
175    /// # Example
176    ///
177    /// ```rust
178    /// use mlx_rs::Array;
179    /// let a = Array::from_slice(&[1, 2, 3], &[3]);
180    /// let b = Array::from_slice(&[1, 2, 3], &[3]);
181    /// let mut c = a.gt(&b).unwrap();
182    ///
183    /// let c_data: &[bool] = c.as_slice();
184    /// // c_data == [false, false, false]
185    /// ```
186    #[default_device]
187    pub fn gt_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
188        Array::try_from_op(|res| unsafe {
189            mlx_sys::mlx_greater(
190                res,
191                self.as_ptr(),
192                other.as_ref().as_ptr(),
193                stream.as_ref().as_ptr(),
194            )
195        })
196    }
197
198    /// Element-wise logical and returning an error if the arrays are not broadcastable.
199    ///
200    /// Logical and on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
201    ///
202    /// # Params
203    ///
204    /// - other: array to compare
205    ///
206    /// # Example
207    ///
208    /// ```rust
209    /// use mlx_rs::Array;
210    /// let a = Array::from_slice(&[true, false, true], &[3]);
211    /// let b = Array::from_slice(&[true, true, false], &[3]);
212    /// let mut c = a.logical_and(&b).unwrap();
213    ///
214    /// let c_data: &[bool] = c.as_slice();
215    /// // c_data == [true, false, false]
216    /// ```
217    #[default_device]
218    pub fn logical_and_device(
219        &self,
220        other: impl AsRef<Array>,
221        stream: impl AsRef<Stream>,
222    ) -> Result<Array> {
223        Array::try_from_op(|res| unsafe {
224            mlx_sys::mlx_logical_and(
225                res,
226                self.as_ptr(),
227                other.as_ref().as_ptr(),
228                stream.as_ref().as_ptr(),
229            )
230        })
231    }
232
233    /// Element-wise logical or returning an error if the arrays are not broadcastable.
234    ///
235    /// Logical or on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
236    ///
237    /// # Params
238    ///
239    /// - other: array to compare
240    ///
241    /// # Example
242    ///
243    /// ```rust
244    /// use mlx_rs::Array;
245    /// let a = Array::from_slice(&[true, false, true], &[3]);
246    /// let b = Array::from_slice(&[true, true, false], &[3]);
247    /// let mut c = a.logical_or(&b).unwrap();
248    ///
249    /// let c_data: &[bool] = c.as_slice();
250    /// // c_data == [true, true, true]
251    /// ```
252    #[default_device]
253    pub fn logical_or_device(
254        &self,
255        other: impl AsRef<Array>,
256        stream: impl AsRef<Stream>,
257    ) -> Result<Array> {
258        Array::try_from_op(|res| unsafe {
259            mlx_sys::mlx_logical_or(
260                res,
261                self.as_ptr(),
262                other.as_ref().as_ptr(),
263                stream.as_ref().as_ptr(),
264            )
265        })
266    }
267
268    /// Unary element-wise logical not.
269    ///
270    /// # Example
271    ///
272    /// ```rust
273    /// use mlx_rs::{Array, StreamOrDevice};
274    /// let a: Array = false.into();
275    /// let mut b = a.logical_not_device(StreamOrDevice::default()).unwrap();
276    ///
277    /// let b_data: &[bool] = b.as_slice();
278    /// // b_data == [true]
279    /// ```
280    #[default_device]
281    pub fn logical_not_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
282        Array::try_from_op(|res| unsafe {
283            mlx_sys::mlx_logical_not(res, self.as_ptr(), stream.as_ref().as_ptr())
284        })
285    }
286
287    /// Approximate comparison of two arrays returning an error if the inputs aren't valid.
288    ///
289    /// The arrays are considered equal if:
290    ///
291    /// ```text
292    /// all(abs(a - b) <= (atol + rtol * abs(b)))
293    /// ```
294    ///
295    /// # Params
296    ///
297    /// - other: array to compare
298    /// - rtol: relative tolerance = defaults to 1e-5 when None
299    /// - atol: absolute tolerance - defaults to 1e-8 when None
300    /// - equal_nan: whether to consider NaNs equal -- default is false when None
301    ///
302    /// # Example
303    ///
304    /// ```rust
305    /// use num_traits::Pow;
306    /// use mlx_rs::array;
307    /// let a = array!([0., 1., 2., 3.]).sqrt().unwrap();
308    /// let b = array!([0., 1., 2., 3.]).power(array!(0.5)).unwrap();
309    /// let mut c = a.all_close(&b, None, None, None).unwrap();
310    ///
311    /// let c_data: &[bool] = c.as_slice();
312    /// // c_data == [true]
313    /// ```
314    #[default_device]
315    pub fn all_close_device(
316        &self,
317        other: impl AsRef<Array>,
318        rtol: impl Into<Option<f64>>,
319        atol: impl Into<Option<f64>>,
320        equal_nan: impl Into<Option<bool>>,
321        stream: impl AsRef<Stream>,
322    ) -> Result<Array> {
323        Array::try_from_op(|res| unsafe {
324            mlx_sys::mlx_allclose(
325                res,
326                self.as_ptr(),
327                other.as_ref().as_ptr(),
328                rtol.into().unwrap_or(1e-5),
329                atol.into().unwrap_or(1e-8),
330                equal_nan.into().unwrap_or(false),
331                stream.as_ref().as_ptr(),
332            )
333        })
334    }
335
336    /// Returns a boolean array where two arrays are element-wise equal within a tolerance returning an error if the arrays are not broadcastable.
337    ///
338    /// Infinite values are considered equal if they have the same sign, NaN values are not equal unless
339    /// `equalNAN` is `true`.
340    ///
341    /// Two values are considered close if:
342    ///
343    /// ```text
344    /// abs(a - b) <= (atol + rtol * abs(b))
345    /// ```
346    ///
347    /// Unlike [self.array_eq] this function supports [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting).
348    #[default_device]
349    pub fn is_close_device(
350        &self,
351        other: impl AsRef<Array>,
352        rtol: impl Into<Option<f64>>,
353        atol: impl Into<Option<f64>>,
354        equal_nan: impl Into<Option<bool>>,
355        stream: impl AsRef<Stream>,
356    ) -> Result<Array> {
357        Array::try_from_op(|res| unsafe {
358            mlx_sys::mlx_isclose(
359                res,
360                self.as_ptr(),
361                other.as_ref().as_ptr(),
362                rtol.into().unwrap_or(1e-5),
363                atol.into().unwrap_or(1e-8),
364                equal_nan.into().unwrap_or(false),
365                stream.as_ref().as_ptr(),
366            )
367        })
368    }
369
370    /// Array equality check.
371    ///
372    /// Compare two arrays for equality. Returns `true` iff the arrays have
373    /// the same shape and their values are equal. The arrays need not have
374    /// the same type to be considered equal.
375    ///
376    /// # Params
377    ///
378    /// - other: array to compare
379    /// - equal_nan: whether to consider NaNs equal -- default is false when None
380    ///
381    /// # Example
382    ///
383    /// ```rust
384    /// use mlx_rs::Array;
385    /// let a = Array::from_slice(&[0, 1, 2, 3], &[4]);
386    /// let b = Array::from_slice(&[0., 1., 2., 3.], &[4]);
387    ///
388    /// let c = a.array_eq(&b, None);
389    /// // c == [true]
390    /// ```
391    #[default_device]
392    pub fn array_eq_device(
393        &self,
394        other: impl AsRef<Array>,
395        equal_nan: impl Into<Option<bool>>,
396        stream: impl AsRef<Stream>,
397    ) -> Result<Array> {
398        Array::try_from_op(|res| unsafe {
399            mlx_sys::mlx_array_equal(
400                res,
401                self.as_ptr(),
402                other.as_ref().as_ptr(),
403                equal_nan.into().unwrap_or(false),
404                stream.as_ref().as_ptr(),
405            )
406        })
407    }
408
409    /// An `or` reduction over the given axes returning an error if the axes are invalid.
410    ///
411    /// # Params
412    ///
413    /// - axes: axes to reduce over -- defaults to all axes if not provided
414    /// - keep_dims: if `true` keep reduced axis as singleton dimension -- defaults to false if not provided
415    ///
416    ///  # Example
417    ///
418    /// ```rust
419    /// use mlx_rs::Array;
420    ///
421    /// let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
422    ///
423    /// // will produce a scalar Array with true -- some of the values are non-zero
424    /// let all = array.any(None).unwrap();
425    ///
426    /// // produces an Array([true, true, true, true]) -- all rows have non-zeros
427    /// let all_rows = array.any_axes(&[0], None).unwrap();
428    /// ```
429    #[default_device]
430    pub fn any_axes_device(
431        &self,
432        axes: &[i32],
433        keep_dims: impl Into<Option<bool>>,
434        stream: impl AsRef<Stream>,
435    ) -> Result<Array> {
436        Array::try_from_op(|res| unsafe {
437            mlx_sys::mlx_any_axes(
438                res,
439                self.as_ptr(),
440                axes.as_ptr(),
441                axes.len(),
442                keep_dims.into().unwrap_or(false),
443                stream.as_ref().as_ptr(),
444            )
445        })
446    }
447
448    /// Similar to [`any_axes`] but defaults to all axes.
449    #[default_device]
450    pub fn any_axis_device(
451        &self,
452        axis: i32,
453        keep_dims: impl Into<Option<bool>>,
454        stream: impl AsRef<Stream>,
455    ) -> Result<Array> {
456        Array::try_from_op(|res| unsafe {
457            mlx_sys::mlx_any_axis(
458                res,
459                self.as_ptr(),
460                axis,
461                keep_dims.into().unwrap_or(false),
462                stream.as_ref().as_ptr(),
463            )
464        })
465    }
466
467    /// Similar to [`any_axes`] but defaults to all axes.
468    #[default_device]
469    pub fn any_device(
470        &self,
471        keep_dims: impl Into<Option<bool>>,
472        stream: impl AsRef<Stream>,
473    ) -> Result<Array> {
474        Array::try_from_op(|res| unsafe {
475            mlx_sys::mlx_any(
476                res,
477                self.as_ptr(),
478                keep_dims.into().unwrap_or(false),
479                stream.as_ref().as_ptr(),
480            )
481        })
482    }
483}
484
485/// See [`Array::any`]
486#[generate_macro]
487#[default_device]
488pub fn any_axes_device(
489    array: impl AsRef<Array>,
490    axes: &[i32],
491    #[optional] keep_dims: impl Into<Option<bool>>,
492    #[optional] stream: impl AsRef<Stream>,
493) -> Result<Array> {
494    array.as_ref().any_axes_device(axes, keep_dims, stream)
495}
496
497/// See [`Array::any`]
498#[generate_macro]
499#[default_device]
500pub fn any_axis_device(
501    array: impl AsRef<Array>,
502    axis: i32,
503    #[optional] keep_dims: impl Into<Option<bool>>,
504    #[optional] stream: impl AsRef<Stream>,
505) -> Result<Array> {
506    array.as_ref().any_axis_device(axis, keep_dims, stream)
507}
508
509/// See [`Array::any`]
510#[generate_macro]
511#[default_device]
512pub fn any_device(
513    array: impl AsRef<Array>,
514    #[optional] keep_dims: impl Into<Option<bool>>,
515    #[optional] stream: impl AsRef<Stream>,
516) -> Result<Array> {
517    array.as_ref().any_device(keep_dims, stream)
518}
519
520/// See [`Array::logical_and`]
521#[generate_macro]
522#[default_device]
523pub fn logical_and_device(
524    a: impl AsRef<Array>,
525    b: impl AsRef<Array>,
526    #[optional] stream: impl AsRef<Stream>,
527) -> Result<Array> {
528    a.as_ref().logical_and_device(b, stream)
529}
530
531/// See [`Array::logical_or`]
532#[generate_macro]
533#[default_device]
534pub fn logical_or_device(
535    a: impl AsRef<Array>,
536    b: impl AsRef<Array>,
537    #[optional] stream: impl AsRef<Stream>,
538) -> Result<Array> {
539    a.as_ref().logical_or_device(b, stream)
540}
541
542/// See [`Array::logical_not`]
543#[generate_macro]
544#[default_device]
545pub fn logical_not_device(
546    a: impl AsRef<Array>,
547    #[optional] stream: impl AsRef<Stream>,
548) -> Result<Array> {
549    a.as_ref().logical_not_device(stream)
550}
551
552/// See [`Array::all_close`]
553#[generate_macro]
554#[default_device]
555pub fn all_close_device(
556    a: impl AsRef<Array>,
557    b: impl AsRef<Array>,
558    #[optional] rtol: impl Into<Option<f64>>,
559    #[optional] atol: impl Into<Option<f64>>,
560    #[optional] equal_nan: impl Into<Option<bool>>,
561    #[optional] stream: impl AsRef<Stream>,
562) -> Result<Array> {
563    a.as_ref()
564        .all_close_device(b, rtol, atol, equal_nan, stream)
565}
566
567/// See [`Array::is_close`]
568#[generate_macro]
569#[default_device]
570pub fn is_close_device(
571    a: impl AsRef<Array>,
572    b: impl AsRef<Array>,
573    #[optional] rtol: impl Into<Option<f64>>,
574    #[optional] atol: impl Into<Option<f64>>,
575    #[optional] equal_nan: impl Into<Option<bool>>,
576    #[optional] stream: impl AsRef<Stream>,
577) -> Result<Array> {
578    a.as_ref().is_close_device(b, rtol, atol, equal_nan, stream)
579}
580
581/// See [`Array::array_eq`]
582#[generate_macro]
583#[default_device]
584pub fn array_eq_device(
585    a: impl AsRef<Array>,
586    b: impl AsRef<Array>,
587    #[optional] equal_nan: impl Into<Option<bool>>,
588    #[optional] stream: impl AsRef<Stream>,
589) -> Result<Array> {
590    a.as_ref().array_eq_device(b, equal_nan, stream)
591}
592
593/// See [`Array::eq`]
594#[generate_macro]
595#[default_device]
596pub fn eq_device(
597    a: impl AsRef<Array>,
598    b: impl AsRef<Array>,
599    #[optional] stream: impl AsRef<Stream>,
600) -> Result<Array> {
601    a.as_ref().eq_device(b, stream)
602}
603
604/// See [`Array::le`]
605#[generate_macro]
606#[default_device]
607pub fn le_device(
608    a: impl AsRef<Array>,
609    b: impl AsRef<Array>,
610    #[optional] stream: impl AsRef<Stream>,
611) -> Result<Array> {
612    a.as_ref().le_device(b, stream)
613}
614
615/// See [`Array::ge`]
616#[generate_macro]
617#[default_device]
618pub fn ge_device(
619    a: impl AsRef<Array>,
620    b: impl AsRef<Array>,
621    #[optional] stream: impl AsRef<Stream>,
622) -> Result<Array> {
623    a.as_ref().ge_device(b, stream)
624}
625
626/// See [`Array::ne`]
627#[generate_macro]
628#[default_device]
629pub fn ne_device(
630    a: impl AsRef<Array>,
631    b: impl AsRef<Array>,
632    #[optional] stream: impl AsRef<Stream>,
633) -> Result<Array> {
634    a.as_ref().ne_device(b, stream)
635}
636
637/// See [`Array::lt`]
638#[generate_macro]
639#[default_device]
640pub fn lt_device(
641    a: impl AsRef<Array>,
642    b: impl AsRef<Array>,
643    #[optional] stream: impl AsRef<Stream>,
644) -> Result<Array> {
645    a.as_ref().lt_device(b, stream)
646}
647
648/// See [`Array::gt`]
649#[generate_macro]
650#[default_device]
651pub fn gt_device(
652    a: impl AsRef<Array>,
653    b: impl AsRef<Array>,
654    #[optional] stream: impl AsRef<Stream>,
655) -> Result<Array> {
656    a.as_ref().gt_device(b, stream)
657}
658
659// TODO: check if the functions below could throw an exception.
660
661/// Return a boolean array indicating which elements are NaN.
662#[generate_macro]
663#[default_device]
664pub fn is_nan_device(
665    array: impl AsRef<Array>,
666    #[optional] stream: impl AsRef<Stream>,
667) -> Result<Array> {
668    Array::try_from_op(|res| unsafe {
669        mlx_sys::mlx_isnan(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
670    })
671}
672
673/// Return a boolean array indicating which elements are +/- inifnity.
674#[generate_macro]
675#[default_device]
676pub fn is_inf_device(
677    array: impl AsRef<Array>,
678    #[optional] stream: impl AsRef<Stream>,
679) -> Result<Array> {
680    Array::try_from_op(|res| unsafe {
681        mlx_sys::mlx_isinf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
682    })
683}
684
685/// Return a boolean array indicating which elements are positive infinity.
686#[generate_macro]
687#[default_device]
688pub fn is_pos_inf_device(
689    array: impl AsRef<Array>,
690    #[optional] stream: impl AsRef<Stream>,
691) -> Result<Array> {
692    Array::try_from_op(|res| unsafe {
693        mlx_sys::mlx_isposinf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
694    })
695}
696
697/// Return a boolean array indicating which elements are negative infinity.
698#[generate_macro]
699#[default_device]
700pub fn is_neg_inf_device(
701    array: impl AsRef<Array>,
702    #[optional] stream: impl AsRef<Stream>,
703) -> Result<Array> {
704    Array::try_from_op(|res| unsafe {
705        mlx_sys::mlx_isneginf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
706    })
707}
708
709/// Select from `a` or `b` according to `condition` returning an error if the arrays are not
710/// broadcastable.
711///
712/// The condition and input arrays must be the same shape or
713/// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting)
714/// with each another.
715///
716/// # Params
717///
718/// - condition: condition array
719/// - a: input selected from where condition is non-zero or `true`
720/// - b: input selected from where condition is zero or `false`
721#[default_device]
722pub fn r#where_device(
723    condition: impl AsRef<Array>,
724    a: impl AsRef<Array>,
725    b: impl AsRef<Array>,
726    stream: impl AsRef<Stream>,
727) -> Result<Array> {
728    Array::try_from_op(|res| unsafe {
729        mlx_sys::mlx_where(
730            res,
731            condition.as_ref().as_ptr(),
732            a.as_ref().as_ptr(),
733            b.as_ref().as_ptr(),
734            stream.as_ref().as_ptr(),
735        )
736    })
737}
738
739/// Alias for [`r#where`]
740#[generate_macro]
741#[default_device]
742pub fn which_device(
743    condition: impl AsRef<Array>,
744    a: impl AsRef<Array>,
745    b: impl AsRef<Array>,
746    #[optional] stream: impl AsRef<Stream>,
747) -> Result<Array> {
748    r#where_device(condition, a, b, stream)
749}
750
751#[cfg(test)]
752mod tests {
753    use crate::{array, Dtype};
754
755    use super::*;
756
757    #[test]
758    fn test_eq() {
759        let a = Array::from_slice(&[1, 2, 3], &[3]);
760        let b = Array::from_slice(&[1, 2, 3], &[3]);
761        let c = a.eq(&b).unwrap();
762
763        let c_data: &[bool] = c.as_slice();
764        assert_eq!(c_data, [true, true, true]);
765
766        // check a and b are not modified
767        let a_data: &[i32] = a.as_slice();
768        assert_eq!(a_data, [1, 2, 3]);
769
770        let b_data: &[i32] = b.as_slice();
771        assert_eq!(b_data, [1, 2, 3]);
772    }
773
774    #[test]
775    fn test_eq_invalid_broadcast() {
776        let a = Array::from_slice(&[1, 2, 3], &[3]);
777        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
778        let c = a.eq(&b);
779        assert!(c.is_err());
780    }
781
782    #[test]
783    fn test_le() {
784        let a = Array::from_slice(&[1, 2, 3], &[3]);
785        let b = Array::from_slice(&[1, 2, 3], &[3]);
786        let c = a.le(&b).unwrap();
787
788        let c_data: &[bool] = c.as_slice();
789        assert_eq!(c_data, [true, true, true]);
790
791        // check a and b are not modified
792        let a_data: &[i32] = a.as_slice();
793        assert_eq!(a_data, [1, 2, 3]);
794
795        let b_data: &[i32] = b.as_slice();
796        assert_eq!(b_data, [1, 2, 3]);
797    }
798
799    #[test]
800    fn test_le_invalid_broadcast() {
801        let a = Array::from_slice(&[1, 2, 3], &[3]);
802        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
803        let c = a.le(&b);
804        assert!(c.is_err());
805    }
806
807    #[test]
808    fn test_ge() {
809        let a = Array::from_slice(&[1, 2, 3], &[3]);
810        let b = Array::from_slice(&[1, 2, 3], &[3]);
811        let c = a.ge(&b).unwrap();
812
813        let c_data: &[bool] = c.as_slice();
814        assert_eq!(c_data, [true, true, true]);
815
816        // check a and b are not modified
817        let a_data: &[i32] = a.as_slice();
818        assert_eq!(a_data, [1, 2, 3]);
819
820        let b_data: &[i32] = b.as_slice();
821        assert_eq!(b_data, [1, 2, 3]);
822    }
823
824    #[test]
825    fn test_ge_invalid_broadcast() {
826        let a = Array::from_slice(&[1, 2, 3], &[3]);
827        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
828        let c = a.ge(&b);
829        assert!(c.is_err());
830    }
831
832    #[test]
833    fn test_ne() {
834        let a = Array::from_slice(&[1, 2, 3], &[3]);
835        let b = Array::from_slice(&[1, 2, 3], &[3]);
836        let c = a.ne(&b).unwrap();
837
838        let c_data: &[bool] = c.as_slice();
839        assert_eq!(c_data, [false, false, false]);
840
841        // check a and b are not modified
842        let a_data: &[i32] = a.as_slice();
843        assert_eq!(a_data, [1, 2, 3]);
844
845        let b_data: &[i32] = b.as_slice();
846        assert_eq!(b_data, [1, 2, 3]);
847    }
848
849    #[test]
850    fn test_ne_invalid_broadcast() {
851        let a = Array::from_slice(&[1, 2, 3], &[3]);
852        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
853        let c = a.ne(&b);
854        assert!(c.is_err());
855    }
856
857    #[test]
858    fn test_lt() {
859        let a = Array::from_slice(&[1, 0, 3], &[3]);
860        let b = Array::from_slice(&[1, 2, 3], &[3]);
861        let c = a.lt(&b).unwrap();
862
863        let c_data: &[bool] = c.as_slice();
864        assert_eq!(c_data, [false, true, false]);
865
866        // check a and b are not modified
867        let a_data: &[i32] = a.as_slice();
868        assert_eq!(a_data, [1, 0, 3]);
869
870        let b_data: &[i32] = b.as_slice();
871        assert_eq!(b_data, [1, 2, 3]);
872    }
873
874    #[test]
875    fn test_lt_invalid_broadcast() {
876        let a = Array::from_slice(&[1, 2, 3], &[3]);
877        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
878        let c = a.lt(&b);
879        assert!(c.is_err());
880    }
881
882    #[test]
883    fn test_gt() {
884        let a = Array::from_slice(&[1, 4, 3], &[3]);
885        let b = Array::from_slice(&[1, 2, 3], &[3]);
886        let c = a.gt(&b).unwrap();
887
888        let c_data: &[bool] = c.as_slice();
889        assert_eq!(c_data, [false, true, false]);
890
891        // check a and b are not modified
892        let a_data: &[i32] = a.as_slice();
893        assert_eq!(a_data, [1, 4, 3]);
894
895        let b_data: &[i32] = b.as_slice();
896        assert_eq!(b_data, [1, 2, 3]);
897    }
898
899    #[test]
900    fn test_gt_invalid_broadcast() {
901        let a = Array::from_slice(&[1, 2, 3], &[3]);
902        let b = Array::from_slice(&[1, 2, 3, 4], &[4]);
903        let c = a.gt(&b);
904        assert!(c.is_err());
905    }
906
907    #[test]
908    fn test_logical_and() {
909        let a = Array::from_slice(&[true, false, true], &[3]);
910        let b = Array::from_slice(&[true, true, false], &[3]);
911        let c = a.logical_and(&b).unwrap();
912
913        let c_data: &[bool] = c.as_slice();
914        assert_eq!(c_data, [true, false, false]);
915
916        // check a and b are not modified
917        let a_data: &[bool] = a.as_slice();
918        assert_eq!(a_data, [true, false, true]);
919
920        let b_data: &[bool] = b.as_slice();
921        assert_eq!(b_data, [true, true, false]);
922    }
923
924    #[test]
925    fn test_logical_and_invalid_broadcast() {
926        let a = Array::from_slice(&[true, false, true], &[3]);
927        let b = Array::from_slice(&[true, true, false, true], &[4]);
928        let c = a.logical_and(&b);
929        assert!(c.is_err());
930    }
931
932    #[test]
933    fn test_logical_or() {
934        let a = Array::from_slice(&[true, false, true], &[3]);
935        let b = Array::from_slice(&[true, true, false], &[3]);
936        let c = a.logical_or(&b).unwrap();
937
938        let c_data: &[bool] = c.as_slice();
939        assert_eq!(c_data, [true, true, true]);
940
941        // check a and b are not modified
942        let a_data: &[bool] = a.as_slice();
943        assert_eq!(a_data, [true, false, true]);
944
945        let b_data: &[bool] = b.as_slice();
946        assert_eq!(b_data, [true, true, false]);
947    }
948
949    #[test]
950    fn test_logical_or_invalid_broadcast() {
951        let a = Array::from_slice(&[true, false, true], &[3]);
952        let b = Array::from_slice(&[true, true, false, true], &[4]);
953        let c = a.logical_or(&b);
954        assert!(c.is_err());
955    }
956
957    #[test]
958    fn test_all_close() {
959        let a = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt().unwrap();
960        let b = Array::from_slice(&[0., 1., 2., 3.], &[4])
961            .power(array!(0.5))
962            .unwrap();
963        let c = a.all_close(&b, 1e-5, None, None).unwrap();
964
965        let c_data: &[bool] = c.as_slice();
966        assert_eq!(c_data, [true]);
967    }
968
969    #[test]
970    fn test_all_close_invalid_broadcast() {
971        let a = Array::from_slice(&[0., 1., 2., 3.], &[4]);
972        let b = Array::from_slice(&[0., 1., 2., 3., 4.], &[5]);
973        let c = a.all_close(&b, 1e-5, None, None);
974        assert!(c.is_err());
975    }
976
977    #[test]
978    fn test_is_close_false() {
979        let a = Array::from_slice(&[1., 2., 3.], &[3]);
980        let b = Array::from_slice(&[1.1, 2.2, 3.3], &[3]);
981        let c = a.is_close(&b, None, None, false).unwrap();
982
983        let c_data: &[bool] = c.as_slice();
984        assert_eq!(c_data, [false, false, false]);
985    }
986
987    #[test]
988    fn test_is_close_true() {
989        let a = Array::from_slice(&[1., 2., 3.], &[3]);
990        let b = Array::from_slice(&[1.1, 2.2, 3.3], &[3]);
991        let c = a.is_close(&b, 0.1, 0.2, true).unwrap();
992
993        let c_data: &[bool] = c.as_slice();
994        assert_eq!(c_data, [true, true, true]);
995    }
996
997    #[test]
998    fn test_is_close_invalid_broadcast() {
999        let a = Array::from_slice(&[1., 2., 3.], &[3]);
1000        let b = Array::from_slice(&[1.1, 2.2, 3.3, 4.4], &[4]);
1001        let c = a.is_close(&b, None, None, false);
1002        assert!(c.is_err());
1003    }
1004
1005    #[test]
1006    fn test_array_eq() {
1007        let a = Array::from_slice(&[0, 1, 2, 3], &[4]);
1008        let b = Array::from_slice(&[0., 1., 2., 3.], &[4]);
1009        let c = a.array_eq(&b, None).unwrap();
1010
1011        let c_data: &[bool] = c.as_slice();
1012        assert_eq!(c_data, [true]);
1013    }
1014
1015    #[test]
1016    fn test_any() {
1017        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
1018        let all = array.any_axes(&[0][..], None).unwrap();
1019
1020        let results: &[bool] = all.as_slice();
1021        assert_eq!(results, &[true, true, true, true]);
1022    }
1023
1024    #[test]
1025    fn test_any_empty_axes() {
1026        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
1027        let all = array.any_axes(&[][..], None).unwrap();
1028
1029        let results: &[bool] = all.as_slice();
1030        assert_eq!(
1031            results,
1032            &[false, true, true, true, true, true, true, true, true, true, true, true]
1033        );
1034    }
1035
1036    #[test]
1037    fn test_any_out_of_bounds() {
1038        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[12]);
1039        let result = array.any_axes(&[1][..], None);
1040        assert!(result.is_err());
1041    }
1042
1043    #[test]
1044    fn test_any_duplicate_axes() {
1045        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
1046        let result = array.any_axes(&[0, 0][..], None);
1047        assert!(result.is_err());
1048    }
1049
1050    #[test]
1051    fn test_which() {
1052        let condition = Array::from_slice(&[true, false, true], &[3]);
1053        let a = Array::from_slice(&[1, 2, 3], &[3]);
1054        let b = Array::from_slice(&[4, 5, 6], &[3]);
1055        let c = which(&condition, &a, &b).unwrap();
1056
1057        let c_data: &[i32] = c.as_slice();
1058        assert_eq!(c_data, [1, 5, 3]);
1059    }
1060
1061    #[test]
1062    fn test_which_invalid_broadcast() {
1063        let condition = Array::from_slice(&[true, false, true], &[3]);
1064        let a = Array::from_slice(&[1, 2, 3], &[3]);
1065        let b = Array::from_slice(&[4, 5, 6, 7], &[4]);
1066        let c = which(&condition, &a, &b);
1067        assert!(c.is_err());
1068    }
1069
    // The unit tests below are adapted from the MLX C++ codebase.
1071
1072    #[test]
1073    fn test_unary_logical_not() {
1074        let x = array!(false);
1075        assert!(logical_not(&x).unwrap().item::<bool>());
1076
1077        let x = array!(1.0);
1078        let y = logical_not(&x).unwrap();
1079        assert_eq!(y.dtype(), Dtype::Bool);
1080        assert!(!y.item::<bool>());
1081
1082        let x = array!(0);
1083        let y = logical_not(&x).unwrap();
1084        assert_eq!(y.dtype(), Dtype::Bool);
1085        assert!(y.item::<bool>());
1086    }
1087
1088    #[test]
1089    fn test_unary_logical_and() {
1090        let x = array!(true);
1091        let y = array!(true);
1092        assert!(logical_and(&x, &y).unwrap().item::<bool>());
1093
1094        let x = array!(1.0);
1095        let y = array!(1.0);
1096        let z = logical_and(&x, &y).unwrap();
1097        assert_eq!(z.dtype(), Dtype::Bool);
1098        assert!(z.item::<bool>());
1099
1100        let x = array!(0);
1101        let y = array!(1.0);
1102        let z = logical_and(&x, &y).unwrap();
1103        assert_eq!(z.dtype(), Dtype::Bool);
1104        assert!(!z.item::<bool>());
1105    }
1106
1107    #[test]
1108    fn test_unary_logical_or() {
1109        let a = array!(false);
1110        let b = array!(false);
1111        assert!(!logical_or(&a, &b).unwrap().item::<bool>());
1112
1113        let a = array!(1.0);
1114        let b = array!(1.0);
1115        let c = logical_or(&a, &b).unwrap();
1116        assert_eq!(c.dtype(), Dtype::Bool);
1117        assert!(c.item::<bool>());
1118
1119        let a = array!(0);
1120        let b = array!(1.0);
1121        let c = logical_or(&a, &b).unwrap();
1122        assert_eq!(c.dtype(), Dtype::Bool);
1123        assert!(c.item::<bool>());
1124    }
1125}