mlx_rs/ops/
reduction.rs

1use crate::array::Array;
2use crate::error::Result;
3use crate::stream::StreamOrDevice;
4use crate::utils::guard::Guarded;
5use crate::utils::{axes_or_default_to_all, IntoOption};
6use crate::Stream;
7use mlx_internal_macros::{default_device, generate_macro};
8
9impl Array {
10    /// An `and` reduction over the given axes returning an error if the axes are invalid.
11    ///
12    /// # Params
13    ///
14    /// - axes: The axes to reduce over -- defaults to all axes if not provided
15    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
16    ///
17    /// # Example
18    ///
19    /// ```rust
20    /// use mlx_rs::Array;
21    /// let a = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
22    /// let mut b = a.all(&[0], None).unwrap();
23    ///
24    /// let results: &[bool] = b.as_slice();
25    /// // results == [false, true, true, true]
26    /// ```
27    #[default_device]
28    pub fn all_device<'a>(
29        &self,
30        axes: impl IntoOption<&'a [i32]>,
31        keep_dims: impl Into<Option<bool>>,
32        stream: impl AsRef<Stream>,
33    ) -> Result<Array> {
34        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
35        Array::try_from_op(|res| unsafe {
36            mlx_sys::mlx_all_axes(
37                res,
38                self.as_ptr(),
39                axes.as_ptr(),
40                axes.len(),
41                keep_dims.into().unwrap_or(false),
42                stream.as_ref().as_ptr(),
43            )
44        })
45    }
46
47    /// A `product` reduction over the given axes returning an error if the axes are invalid.
48    ///
49    /// # Params
50    ///
51    /// - axes: axes to reduce over
52    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
53    ///
54    /// # Example
55    ///
56    /// ```rust
57    /// use mlx_rs::Array;
58    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
59    ///
60    /// // result is [20, 72]
61    /// let result = array.prod(&[0], None).unwrap();
62    /// ```
63    #[default_device]
64    pub fn prod_device<'a>(
65        &self,
66        axes: impl IntoOption<&'a [i32]>,
67        keep_dims: impl Into<Option<bool>>,
68        stream: impl AsRef<Stream>,
69    ) -> Result<Array> {
70        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
71        Array::try_from_op(|res| unsafe {
72            mlx_sys::mlx_prod(
73                res,
74                self.as_ptr(),
75                axes.as_ptr(),
76                axes.len(),
77                keep_dims.into().unwrap_or(false),
78                stream.as_ref().as_ptr(),
79            )
80        })
81    }
82
83    /// A `max` reduction over the given axes returning an error if the axes are invalid.
84    ///
85    /// # Params
86    ///
87    /// - axes: axes to reduce over
88    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
89    ///
90    /// # Example
91    ///
92    /// ```rust
93    /// use mlx_rs::Array;
94    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
95    ///
96    /// // result is [5, 9]
97    /// let result = array.max(&[0], None).unwrap();
98    /// ```
99    #[default_device]
100    pub fn max_device<'a>(
101        &self,
102        axes: impl IntoOption<&'a [i32]>,
103        keep_dims: impl Into<Option<bool>>,
104        stream: impl AsRef<Stream>,
105    ) -> Result<Array> {
106        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
107        Array::try_from_op(|res| unsafe {
108            mlx_sys::mlx_max(
109                res,
110                self.as_ptr(),
111                axes.as_ptr(),
112                axes.len(),
113                keep_dims.into().unwrap_or(false),
114                stream.as_ref().as_ptr(),
115            )
116        })
117    }
118
119    /// Sum reduce the array over the given axes returning an error if the axes are invalid.
120    ///
121    /// # Params
122    ///
123    /// - axes: axes to reduce over
124    /// - keep_dims: if `true`, keep the reduces axes as singleton dimensions
125    ///
126    /// # Example
127    ///
128    /// ```rust
129    /// use mlx_rs::Array;
130    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
131    ///
132    /// // result is [9, 17]
133    /// let result = array.sum(&[0], None).unwrap();
134    /// ```
135    #[default_device]
136    pub fn sum_device<'a>(
137        &self,
138        axes: impl IntoOption<&'a [i32]>,
139        keep_dims: impl Into<Option<bool>>,
140        stream: impl AsRef<Stream>,
141    ) -> Result<Array> {
142        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
143        Array::try_from_op(|res| unsafe {
144            mlx_sys::mlx_sum(
145                res,
146                self.as_ptr(),
147                axes.as_ptr(),
148                axes.len(),
149                keep_dims.into().unwrap_or(false),
150                stream.as_ref().as_ptr(),
151            )
152        })
153    }
154
155    /// A `mean` reduction over the given axes returning an error if the axes are invalid.
156    ///
157    /// # Params
158    ///
159    /// - axes: axes to reduce over
160    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
161    ///
162    /// # Example
163    ///
164    /// ```rust
165    /// use mlx_rs::Array;
166    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
167    ///
168    /// // result is [4.5, 8.5]
169    /// let result = array.mean(&[0], None).unwrap();
170    /// ```
171    #[default_device]
172    pub fn mean_device<'a>(
173        &self,
174        axes: impl IntoOption<&'a [i32]>,
175        keep_dims: impl Into<Option<bool>>,
176        stream: impl AsRef<Stream>,
177    ) -> Result<Array> {
178        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
179        Array::try_from_op(|res| unsafe {
180            mlx_sys::mlx_mean(
181                res,
182                self.as_ptr(),
183                axes.as_ptr(),
184                axes.len(),
185                keep_dims.into().unwrap_or(false),
186                stream.as_ref().as_ptr(),
187            )
188        })
189    }
190
191    /// A `min` reduction over the given axes returning an error if the axes are invalid.
192    ///
193    /// # Params
194    ///
195    /// - axes: axes to reduce over
196    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
197    ///
198    /// # Example
199    ///
200    /// ```rust
201    /// use mlx_rs::Array;
202    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
203    ///
204    /// // result is [4, 8]
205    /// let result = array.min(&[0], None).unwrap();
206    /// ```
207    #[default_device]
208    pub fn min_device<'a>(
209        &self,
210        axes: impl IntoOption<&'a [i32]>,
211        keep_dims: impl Into<Option<bool>>,
212        stream: impl AsRef<Stream>,
213    ) -> Result<Array> {
214        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
215        Array::try_from_op(|res| unsafe {
216            mlx_sys::mlx_min(
217                res,
218                self.as_ptr(),
219                axes.as_ptr(),
220                axes.len(),
221                keep_dims.into().unwrap_or(false),
222                stream.as_ref().as_ptr(),
223            )
224        })
225    }
226
227    /// Compute the variance(s) over the given axes returning an error if the axes are invalid.
228    ///
229    /// # Params
230    ///
231    /// - axes: axes to reduce over
232    /// - keep_dims: if `true`, keep the reduces axes as singleton dimensions
233    /// - ddof: the divisor to compute the variance is `N - ddof`
234    #[default_device]
235    pub fn variance_device<'a>(
236        &self,
237        axes: impl IntoOption<&'a [i32]>,
238        keep_dims: impl Into<Option<bool>>,
239        ddof: impl Into<Option<i32>>,
240        stream: impl AsRef<Stream>,
241    ) -> Result<Array> {
242        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
243        Array::try_from_op(|res| unsafe {
244            mlx_sys::mlx_var(
245                res,
246                self.as_ptr(),
247                axes.as_ptr(),
248                axes.len(),
249                keep_dims.into().unwrap_or(false),
250                ddof.into().unwrap_or(0),
251                stream.as_ref().as_ptr(),
252            )
253        })
254    }
255
256    /// A `log-sum-exp` reduction over the given axes returning an error if the axes are invalid.
257    ///
258    /// The log-sum-exp reduction is a numerically stable version of using the individual operations.
259    ///
260    /// # Params
261    ///
262    /// - axes: axes to reduce over
263    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
264    #[default_device]
265    pub fn log_sum_exp_device<'a>(
266        &self,
267        axes: impl IntoOption<&'a [i32]>,
268        keep_dims: impl Into<Option<bool>>,
269        stream: impl AsRef<Stream>,
270    ) -> Result<Array> {
271        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
272        Array::try_from_op(|res| unsafe {
273            mlx_sys::mlx_logsumexp(
274                res,
275                self.as_ptr(),
276                axes.as_ptr(),
277                axes.len(),
278                keep_dims.into().unwrap_or(false),
279                stream.as_ref().as_ptr(),
280            )
281        })
282    }
283}
284
285/// See [`Array::all`]
286#[generate_macro]
287#[default_device]
288pub fn all_device<'a>(
289    array: impl AsRef<Array>,
290    #[optional] axes: impl IntoOption<&'a [i32]>,
291    #[optional] keep_dims: impl Into<Option<bool>>,
292    #[optional] stream: impl AsRef<Stream>,
293) -> Result<Array> {
294    array.as_ref().all_device(axes, keep_dims, stream)
295}
296
297/// See [`Array::prod`]
298#[generate_macro]
299#[default_device]
300pub fn prod_device<'a>(
301    array: impl AsRef<Array>,
302    #[optional] axes: impl IntoOption<&'a [i32]>,
303    #[optional] keep_dims: impl Into<Option<bool>>,
304    #[optional] stream: impl AsRef<Stream>,
305) -> Result<Array> {
306    array.as_ref().prod_device(axes, keep_dims, stream)
307}
308
309/// See [`Array::max`]
310#[generate_macro]
311#[default_device]
312pub fn max_device<'a>(
313    array: impl AsRef<Array>,
314    #[optional] axes: impl IntoOption<&'a [i32]>,
315    #[optional] keep_dims: impl Into<Option<bool>>,
316    #[optional] stream: impl AsRef<Stream>,
317) -> Result<Array> {
318    array.as_ref().max_device(axes, keep_dims, stream)
319}
320
321/// Compute the standard deviation(s) over the given axes.
322///
323/// # Params
324///
325/// - `a`: Input array
326/// - `axes`: Optional axis or axes to reduce over. If unspecified this defaults to reducing over
327///   the entire array.
328/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
329/// - `ddof`: The divisor to compute the variance is `N - ddof`, defaults to `0`.
330#[generate_macro]
331#[default_device]
332pub fn std_device<'a>(
333    a: impl AsRef<Array>,
334    #[optional] axes: impl IntoOption<&'a [i32]>,
335    #[optional] keep_dims: impl Into<Option<bool>>,
336    #[optional] ddof: impl Into<Option<i32>>,
337    #[optional] stream: impl AsRef<Stream>,
338) -> Result<Array> {
339    let a = a.as_ref();
340    let axes = axes_or_default_to_all(axes, a.ndim() as i32);
341    let keep_dims = keep_dims.into().unwrap_or(false);
342    let ddof = ddof.into().unwrap_or(0);
343    Array::try_from_op(|res| unsafe {
344        mlx_sys::mlx_std(
345            res,
346            a.as_ptr(),
347            axes.as_ptr(),
348            axes.len(),
349            keep_dims,
350            ddof,
351            stream.as_ref().as_ptr(),
352        )
353    })
354}
355
356/// See [`Array::sum`]
357#[generate_macro]
358#[default_device]
359pub fn sum_device<'a>(
360    array: impl AsRef<Array>,
361    #[optional] axes: impl IntoOption<&'a [i32]>,
362    #[optional] keep_dims: impl Into<Option<bool>>,
363    #[optional] stream: impl AsRef<Stream>,
364) -> Result<Array> {
365    array.as_ref().sum_device(axes, keep_dims, stream)
366}
367
368/// See [`Array::mean`]
369#[generate_macro]
370#[default_device]
371pub fn mean_device<'a>(
372    array: impl AsRef<Array>,
373    #[optional] axes: impl IntoOption<&'a [i32]>,
374    #[optional] keep_dims: impl Into<Option<bool>>,
375    #[optional] stream: impl AsRef<Stream>,
376) -> Result<Array> {
377    array.as_ref().mean_device(axes, keep_dims, stream)
378}
379
380/// See [`Array::min`]
381#[generate_macro]
382#[default_device]
383pub fn min_device<'a>(
384    array: impl AsRef<Array>,
385    #[optional] axes: impl IntoOption<&'a [i32]>,
386    #[optional] keep_dims: impl Into<Option<bool>>,
387    #[optional] stream: impl AsRef<Stream>,
388) -> Result<Array> {
389    array.as_ref().min_device(axes, keep_dims, stream)
390}
391
392/// See [`Array::variance`]
393#[generate_macro]
394#[default_device]
395pub fn variance_device<'a>(
396    array: impl AsRef<Array>,
397    #[optional] axes: impl IntoOption<&'a [i32]>,
398    #[optional] keep_dims: impl Into<Option<bool>>,
399    #[optional] ddof: impl Into<Option<i32>>,
400    #[optional] stream: impl AsRef<Stream>,
401) -> Result<Array> {
402    array
403        .as_ref()
404        .variance_device(axes, keep_dims, ddof, stream)
405}
406
407/// See [`Array::log_sum_exp`]
408#[generate_macro]
409#[default_device]
410pub fn log_sum_exp_device<'a>(
411    array: impl AsRef<Array>,
412    #[optional] axes: impl IntoOption<&'a [i32]>,
413    #[optional] keep_dims: impl Into<Option<bool>>,
414    #[optional] stream: impl AsRef<Stream>,
415) -> Result<Array> {
416    array.as_ref().log_sum_exp_device(axes, keep_dims, stream)
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Asserts that `actual` and `expected` have the same length and are element-wise equal
    /// within a small relative tolerance.
    ///
    /// Used for results of transcendental functions (`exp`, `log`, `sqrt`), where exact `f32`
    /// equality is brittle.
    fn assert_all_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            let tolerance = 1e-5 * e.abs().max(1.0);
            assert!(
                (a - e).abs() <= tolerance,
                "{actual:?} is not close to {expected:?}"
            );
        }
    }

    #[test]
    fn test_all() {
        let array = Array::from_slice(&[true, false, true, false], &[2, 2]);

        assert!(!array.all(None, None).unwrap().item::<bool>());
        assert_eq!(array.all(None, true).unwrap().shape(), &[1, 1]);
        assert!(!array.all(&[0, 1][..], None).unwrap().item::<bool>());

        let result = array.all(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<bool>(), &[true, false]);

        let result = array.all(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<bool>(), &[false, false]);
    }

    #[test]
    fn test_all_empty_axes() {
        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
        let all = array.all(&[][..], None).unwrap();

        let results: &[bool] = all.as_slice();
        assert_eq!(
            results,
            &[false, true, true, true, true, true, true, true, true, true, true, true]
        );
    }

    #[test]
    fn test_prod() {
        let x = Array::from_slice(&[1, 2, 3, 3], &[2, 2]);
        assert_eq!(x.prod(None, None).unwrap().item::<i32>(), 18);

        let y = x.prod(None, true).unwrap();
        assert_eq!(y.item::<i32>(), 18);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.prod(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[3, 6]);

        let result = x.prod(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[2, 9])
    }

    #[test]
    fn test_prod_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.prod(&[][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_max() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.max(None, None).unwrap().item::<i32>(), 4);
        let y = x.max(None, true).unwrap();
        assert_eq!(y.item::<i32>(), 4);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.max(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[3, 4]);

        let result = x.max(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[2, 4]);
    }

    #[test]
    fn test_max_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.max(&[][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_sum() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.sum(&[0][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[9, 17]);
    }

    #[test]
    fn test_sum_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.sum(&[][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_mean() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.mean(None, None).unwrap().item::<f32>(), 2.5);
        let y = x.mean(None, true).unwrap();
        assert_eq!(y.item::<f32>(), 2.5);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.mean(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[2.0, 3.0]);

        let result = x.mean(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[1.5, 3.5]);
    }

    #[test]
    fn test_mean_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.mean(&[][..], None).unwrap();

        let results: &[f32] = result.as_slice();
        assert_eq!(results, &[5.0, 8.0, 4.0, 9.0]);
    }

    #[test]
    fn test_mean_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.mean(&[2][..], None);
        assert!(result.is_err());
    }

    #[test]
    fn test_min() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.min(None, None).unwrap().item::<i32>(), 1);
        let y = x.min(None, true).unwrap();
        assert_eq!(y.item::<i32>(), 1);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.min(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[1, 2]);

        let result = x.min(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[1, 3]);
    }

    #[test]
    fn test_min_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.min(&[][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_var() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.variance(None, None, None).unwrap().item::<f32>(), 1.25);
        let y = x.variance(None, true, None).unwrap();
        assert_eq!(y.item::<f32>(), 1.25);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.variance(&[0][..], None, None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[1.0, 1.0]);

        let result = x.variance(&[1][..], None, None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[0.25, 0.25]);

        // ddof larger than the element count makes the divisor non-positive.
        let x = Array::from_slice(&[1.0, 2.0], &[2]);
        let out = x.variance(None, None, Some(3)).unwrap();
        assert_eq!(out.item::<f32>(), f32::INFINITY);
    }

    #[test]
    fn test_var_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.variance(&[][..], None, 0).unwrap();

        let results: &[f32] = result.as_slice();
        assert_eq!(results, &[0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_std() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);

        // The population variance over all elements is 1.25 (see `test_var`).
        let result = std(&x, None, None, None).unwrap();
        assert_all_close(&[result.item::<f32>()], &[1.25f32.sqrt()]);

        let kept = std(&x, None, true, None).unwrap();
        assert_eq!(kept.shape(), &[1, 1]);

        let result = std(&x, &[0][..], None, None).unwrap();
        assert_all_close(result.as_slice::<f32>(), &[1.0, 1.0]);

        // ddof = 1 rescales the variance by N / (N - ddof) = 4 / 3, giving 5 / 3.
        let result = std(&x, None, None, 1).unwrap();
        assert_all_close(&[result.item::<f32>()], &[(5.0f32 / 3.0).sqrt()]);
    }

    #[test]
    fn test_log_sum_exp() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.log_sum_exp(&[0][..], None).unwrap();

        // log(e^a + e^b) = max(a, b) + log(1 + e^-|a - b|); both columns differ by exactly 1.
        let offset = (1.0f32 + (-1.0f32).exp()).ln();
        assert_all_close(result.as_slice::<f32>(), &[5.0 + offset, 9.0 + offset]);
    }

    #[test]
    fn test_log_sum_exp_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.log_sum_exp(&[][..], None).unwrap();

        let results: &[f32] = result.as_slice();
        assert_eq!(results, &[5.0, 8.0, 4.0, 9.0]);
    }
}