// File: mlx_rs/ops/reduction.rs

1use crate::array::Array;
2use crate::error::Result;
3use crate::utils::axes_or_default_to_all;
4use crate::utils::guard::Guarded;
5use crate::Stream;
6use mlx_internal_macros::{default_device, generate_macro};
7
8impl Array {
9    /// An `and` reduction over the given axes returning an error if the axes are invalid.
10    ///
11    /// # Params
12    ///
13    /// - axes: The axes to reduce over -- defaults to all axes if not provided
14    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
15    ///
16    /// # Example
17    ///
18    /// ```rust
19    /// use mlx_rs::Array;
20    /// let a = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
21    /// let mut b = a.all_axes(&[0], None).unwrap();
22    ///
23    /// let results: &[bool] = b.as_slice();
24    /// // results == [false, true, true, true]
25    /// ```
26    #[default_device]
27    pub fn all_axes_device(
28        &self,
29        axes: &[i32],
30        keep_dims: impl Into<Option<bool>>,
31        stream: impl AsRef<Stream>,
32    ) -> Result<Array> {
33        Array::try_from_op(|res| unsafe {
34            mlx_sys::mlx_all_axes(
35                res,
36                self.as_ptr(),
37                axes.as_ptr(),
38                axes.len(),
39                keep_dims.into().unwrap_or(false),
40                stream.as_ref().as_ptr(),
41            )
42        })
43    }
44
45    /// Similar to [`Array::all_axes`] but only reduces over a single axis.
46    #[default_device]
47    pub fn all_axis_device(
48        &self,
49        axis: i32,
50        keep_dims: impl Into<Option<bool>>,
51        stream: impl AsRef<Stream>,
52    ) -> Result<Array> {
53        Array::try_from_op(|res| unsafe {
54            mlx_sys::mlx_all_axis(
55                res,
56                self.as_ptr(),
57                axis,
58                keep_dims.into().unwrap_or(false),
59                stream.as_ref().as_ptr(),
60            )
61        })
62    }
63
64    /// Similar to [`Array::all_axes`] but reduces over all axes.
65    #[default_device]
66    pub fn all_device(
67        &self,
68        keep_dims: impl Into<Option<bool>>,
69        stream: impl AsRef<Stream>,
70    ) -> Result<Array> {
71        Array::try_from_op(|res| unsafe {
72            mlx_sys::mlx_all(
73                res,
74                self.as_ptr(),
75                keep_dims.into().unwrap_or(false),
76                stream.as_ref().as_ptr(),
77            )
78        })
79    }
80
81    /// A `product` reduction over the given axes returning an error if the axes are invalid.
82    ///
83    /// # Params
84    ///
85    /// - axes: axes to reduce over
86    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
87    ///
88    /// # Example
89    ///
90    /// ```rust
91    /// use mlx_rs::Array;
92    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
93    ///
94    /// // result is [20, 72]
95    /// let result = array.prod_axes(&[0], None).unwrap();
96    /// ```
97    #[default_device]
98    pub fn prod_axes_device(
99        &self,
100        axes: &[i32],
101        keep_dims: impl Into<Option<bool>>,
102        stream: impl AsRef<Stream>,
103    ) -> Result<Array> {
104        Array::try_from_op(|res| unsafe {
105            mlx_sys::mlx_prod_axes(
106                res,
107                self.as_ptr(),
108                axes.as_ptr(),
109                axes.len(),
110                keep_dims.into().unwrap_or(false),
111                stream.as_ref().as_ptr(),
112            )
113        })
114    }
115
116    /// Similar to [`Array::prod_axes`] but only reduces over a single axis.
117    #[default_device]
118    pub fn prod_axis_device(
119        &self,
120        axis: i32,
121        keep_dims: impl Into<Option<bool>>,
122        stream: impl AsRef<Stream>,
123    ) -> Result<Array> {
124        Array::try_from_op(|res| unsafe {
125            mlx_sys::mlx_prod_axis(
126                res,
127                self.as_ptr(),
128                axis,
129                keep_dims.into().unwrap_or(false),
130                stream.as_ref().as_ptr(),
131            )
132        })
133    }
134
135    /// Similar to [`Array::prod_axes`] but reduces over all axes.
136    #[default_device]
137    pub fn prod_device(
138        &self,
139        keep_dims: impl Into<Option<bool>>,
140        stream: impl AsRef<Stream>,
141    ) -> Result<Array> {
142        Array::try_from_op(|res| unsafe {
143            mlx_sys::mlx_prod(
144                res,
145                self.as_ptr(),
146                keep_dims.into().unwrap_or(false),
147                stream.as_ref().as_ptr(),
148            )
149        })
150    }
151
152    /// A `max` reduction over the given axes returning an error if the axes are invalid.
153    ///
154    /// # Params
155    ///
156    /// - axes: axes to reduce over
157    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
158    ///
159    /// # Example
160    ///
161    /// ```rust
162    /// use mlx_rs::Array;
163    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
164    ///
165    /// // result is [5, 9]
166    /// let result = array.max_axes(&[0], None).unwrap();
167    /// ```
168    #[default_device]
169    pub fn max_axes_device(
170        &self,
171        axes: &[i32],
172        keep_dims: impl Into<Option<bool>>,
173        stream: impl AsRef<Stream>,
174    ) -> Result<Array> {
175        Array::try_from_op(|res| unsafe {
176            mlx_sys::mlx_max_axes(
177                res,
178                self.as_ptr(),
179                axes.as_ptr(),
180                axes.len(),
181                keep_dims.into().unwrap_or(false),
182                stream.as_ref().as_ptr(),
183            )
184        })
185    }
186
187    /// Similar to [`Array::max_axes`] but only reduces over a single axis.
188    #[default_device]
189    pub fn max_axis_device(
190        &self,
191        axis: i32,
192        keep_dims: impl Into<Option<bool>>,
193        stream: impl AsRef<Stream>,
194    ) -> Result<Array> {
195        Array::try_from_op(|res| unsafe {
196            mlx_sys::mlx_max_axis(
197                res,
198                self.as_ptr(),
199                axis,
200                keep_dims.into().unwrap_or(false),
201                stream.as_ref().as_ptr(),
202            )
203        })
204    }
205
206    /// Similar to [`Array::max_axes`] but reduces over all axes.
207    #[default_device]
208    pub fn max_device(
209        &self,
210        keep_dims: impl Into<Option<bool>>,
211        stream: impl AsRef<Stream>,
212    ) -> Result<Array> {
213        Array::try_from_op(|res| unsafe {
214            mlx_sys::mlx_max(
215                res,
216                self.as_ptr(),
217                keep_dims.into().unwrap_or(false),
218                stream.as_ref().as_ptr(),
219            )
220        })
221    }
222
223    /// Sum reduce the array over the given axes returning an error if the axes are invalid.
224    ///
225    /// # Params
226    ///
227    /// - axes: axes to reduce over
228    /// - keep_dims: if `true`, keep the reduces axes as singleton dimensions
229    ///
230    /// # Example
231    ///
232    /// ```rust
233    /// use mlx_rs::Array;
234    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
235    ///
236    /// // result is [9, 17]
237    /// let result = array.sum_axes(&[0], None).unwrap();
238    /// ```
239    #[default_device]
240    pub fn sum_axes_device(
241        &self,
242        axes: &[i32],
243        keep_dims: impl Into<Option<bool>>,
244        stream: impl AsRef<Stream>,
245    ) -> Result<Array> {
246        Array::try_from_op(|res| unsafe {
247            mlx_sys::mlx_sum_axes(
248                res,
249                self.as_ptr(),
250                axes.as_ptr(),
251                axes.len(),
252                keep_dims.into().unwrap_or(false),
253                stream.as_ref().as_ptr(),
254            )
255        })
256    }
257
258    /// Similar to [`Array::sum_axes`] but only reduces over a single axis.
259    #[default_device]
260    pub fn sum_axis_device(
261        &self,
262        axis: i32,
263        keep_dims: impl Into<Option<bool>>,
264        stream: impl AsRef<Stream>,
265    ) -> Result<Array> {
266        Array::try_from_op(|res| unsafe {
267            mlx_sys::mlx_sum_axis(
268                res,
269                self.as_ptr(),
270                axis,
271                keep_dims.into().unwrap_or(false),
272                stream.as_ref().as_ptr(),
273            )
274        })
275    }
276
277    /// Similar to [`Array::sum_axes`] but reduces over all axes.
278    #[default_device]
279    pub fn sum_device(
280        &self,
281        keep_dims: impl Into<Option<bool>>,
282        stream: impl AsRef<Stream>,
283    ) -> Result<Array> {
284        Array::try_from_op(|res| unsafe {
285            mlx_sys::mlx_sum(
286                res,
287                self.as_ptr(),
288                keep_dims.into().unwrap_or(false),
289                stream.as_ref().as_ptr(),
290            )
291        })
292    }
293
294    /// A `mean` reduction over the given axes returning an error if the axes are invalid.
295    ///
296    /// # Params
297    ///
298    /// - axes: axes to reduce over
299    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
300    ///
301    /// # Example
302    ///
303    /// ```rust
304    /// use mlx_rs::Array;
305    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
306    ///
307    /// // result is [4.5, 8.5]
308    /// let result = array.mean_axes(&[0], None).unwrap();
309    /// ```
310    #[default_device]
311    pub fn mean_axes_device(
312        &self,
313        axes: &[i32],
314        keep_dims: impl Into<Option<bool>>,
315        stream: impl AsRef<Stream>,
316    ) -> Result<Array> {
317        let axes = axes_or_default_to_all(axes, self.ndim() as i32);
318        Array::try_from_op(|res| unsafe {
319            mlx_sys::mlx_mean_axes(
320                res,
321                self.as_ptr(),
322                axes.as_ptr(),
323                axes.len(),
324                keep_dims.into().unwrap_or(false),
325                stream.as_ref().as_ptr(),
326            )
327        })
328    }
329
330    /// Similar to [`Array::mean_axes`] but only reduces over a single axis.
331    #[default_device]
332    pub fn mean_axis_device(
333        &self,
334        axis: i32,
335        keep_dims: impl Into<Option<bool>>,
336        stream: impl AsRef<Stream>,
337    ) -> Result<Array> {
338        Array::try_from_op(|res| unsafe {
339            mlx_sys::mlx_mean_axis(
340                res,
341                self.as_ptr(),
342                axis,
343                keep_dims.into().unwrap_or(false),
344                stream.as_ref().as_ptr(),
345            )
346        })
347    }
348
349    /// Similar to [`Array::mean_axes`] but reduces over all axes.
350    #[default_device]
351    pub fn mean_device(
352        &self,
353        keep_dims: impl Into<Option<bool>>,
354        stream: impl AsRef<Stream>,
355    ) -> Result<Array> {
356        Array::try_from_op(|res| unsafe {
357            mlx_sys::mlx_mean(
358                res,
359                self.as_ptr(),
360                keep_dims.into().unwrap_or(false),
361                stream.as_ref().as_ptr(),
362            )
363        })
364    }
365
366    /// A `min` reduction over the given axes returning an error if the axes are invalid.
367    ///
368    /// # Params
369    ///
370    /// - axes: axes to reduce over
371    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
372    ///
373    /// # Example
374    ///
375    /// ```rust
376    /// use mlx_rs::Array;
377    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
378    ///
379    /// // result is [4, 8]
380    /// let result = array.min_axes(&[0], None).unwrap();
381    /// ```
382    #[default_device]
383    pub fn min_axes_device(
384        &self,
385        axes: &[i32],
386        keep_dims: impl Into<Option<bool>>,
387        stream: impl AsRef<Stream>,
388    ) -> Result<Array> {
389        Array::try_from_op(|res| unsafe {
390            mlx_sys::mlx_min_axes(
391                res,
392                self.as_ptr(),
393                axes.as_ptr(),
394                axes.len(),
395                keep_dims.into().unwrap_or(false),
396                stream.as_ref().as_ptr(),
397            )
398        })
399    }
400
401    /// Similar to [`Array::min_axes`] but only reduces over a single axis.
402    #[default_device]
403    pub fn min_axis_device(
404        &self,
405        axis: i32,
406        keep_dims: impl Into<Option<bool>>,
407        stream: impl AsRef<Stream>,
408    ) -> Result<Array> {
409        Array::try_from_op(|res| unsafe {
410            mlx_sys::mlx_min_axis(
411                res,
412                self.as_ptr(),
413                axis,
414                keep_dims.into().unwrap_or(false),
415                stream.as_ref().as_ptr(),
416            )
417        })
418    }
419
420    /// Similar to [`Array::min_axes`] but reduces over all axes.
421    #[default_device]
422    pub fn min_device(
423        &self,
424        keep_dims: impl Into<Option<bool>>,
425        stream: impl AsRef<Stream>,
426    ) -> Result<Array> {
427        Array::try_from_op(|res| unsafe {
428            mlx_sys::mlx_min(
429                res,
430                self.as_ptr(),
431                keep_dims.into().unwrap_or(false),
432                stream.as_ref().as_ptr(),
433            )
434        })
435    }
436
437    /// Compute the variance(s) over the given axes returning an error if the axes are invalid.
438    ///
439    /// # Params
440    ///
441    /// - axes: axes to reduce over
442    /// - keep_dims: if `true`, keep the reduces axes as singleton dimensions
443    /// - ddof: the divisor to compute the variance is `N - ddof`
444    #[default_device]
445    pub fn var_axes_device(
446        &self,
447        axes: &[i32],
448        keep_dims: impl Into<Option<bool>>,
449        ddof: impl Into<Option<i32>>,
450        stream: impl AsRef<Stream>,
451    ) -> Result<Array> {
452        Array::try_from_op(|res| unsafe {
453            mlx_sys::mlx_var_axes(
454                res,
455                self.as_ptr(),
456                axes.as_ptr(),
457                axes.len(),
458                keep_dims.into().unwrap_or(false),
459                ddof.into().unwrap_or(0),
460                stream.as_ref().as_ptr(),
461            )
462        })
463    }
464
465    /// Similar to [`Array::var_axes`] but only reduces over a single axis.
466    #[default_device]
467    pub fn var_axis_device(
468        &self,
469        axis: i32,
470        keep_dims: impl Into<Option<bool>>,
471        ddof: impl Into<Option<i32>>,
472        stream: impl AsRef<Stream>,
473    ) -> Result<Array> {
474        Array::try_from_op(|res| unsafe {
475            mlx_sys::mlx_var_axis(
476                res,
477                self.as_ptr(),
478                axis,
479                keep_dims.into().unwrap_or(false),
480                ddof.into().unwrap_or(0),
481                stream.as_ref().as_ptr(),
482            )
483        })
484    }
485
486    /// Similar to [`Array::var_axes`] but reduces over all axes.
487    #[default_device]
488    pub fn var_device(
489        &self,
490        keep_dims: impl Into<Option<bool>>,
491        ddof: impl Into<Option<i32>>,
492        stream: impl AsRef<Stream>,
493    ) -> Result<Array> {
494        Array::try_from_op(|res| unsafe {
495            mlx_sys::mlx_var(
496                res,
497                self.as_ptr(),
498                keep_dims.into().unwrap_or(false),
499                ddof.into().unwrap_or(0),
500                stream.as_ref().as_ptr(),
501            )
502        })
503    }
504
505    /// A `log-sum-exp` reduction over the given axes returning an error if the axes are invalid.
506    ///
507    /// The log-sum-exp reduction is a numerically stable version of using the individual operations.
508    ///
509    /// # Params
510    ///
511    /// - axes: axes to reduce over
512    /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided
513    #[default_device]
514    pub fn logsumexp_axes_device(
515        &self,
516        axes: &[i32],
517        keep_dims: impl Into<Option<bool>>,
518        stream: impl AsRef<Stream>,
519    ) -> Result<Array> {
520        Array::try_from_op(|res| unsafe {
521            mlx_sys::mlx_logsumexp_axes(
522                res,
523                self.as_ptr(),
524                axes.as_ptr(),
525                axes.len(),
526                keep_dims.into().unwrap_or(false),
527                stream.as_ref().as_ptr(),
528            )
529        })
530    }
531
532    /// Similar to [`Array::logsumexp_axes`] but only reduces over a single axis.
533    #[default_device]
534    pub fn logsumexp_axis_device(
535        &self,
536        axis: i32,
537        keep_dims: impl Into<Option<bool>>,
538        stream: impl AsRef<Stream>,
539    ) -> Result<Array> {
540        Array::try_from_op(|res| unsafe {
541            mlx_sys::mlx_logsumexp_axis(
542                res,
543                self.as_ptr(),
544                axis,
545                keep_dims.into().unwrap_or(false),
546                stream.as_ref().as_ptr(),
547            )
548        })
549    }
550
551    /// Similar to [`Array::logsumexp_axes`] but reduces over all axes.
552    #[default_device]
553    pub fn logsumexp_device(
554        &self,
555        keep_dims: impl Into<Option<bool>>,
556        stream: impl AsRef<Stream>,
557    ) -> Result<Array> {
558        Array::try_from_op(|res| unsafe {
559            mlx_sys::mlx_logsumexp(
560                res,
561                self.as_ptr(),
562                keep_dims.into().unwrap_or(false),
563                stream.as_ref().as_ptr(),
564            )
565        })
566    }
567}
568
569/// See [`Array::all_axes`]
570#[generate_macro]
571#[default_device]
572pub fn all_axes_device(
573    array: impl AsRef<Array>,
574    axes: &[i32],
575    #[optional] keep_dims: impl Into<Option<bool>>,
576    #[optional] stream: impl AsRef<Stream>,
577) -> Result<Array> {
578    array.as_ref().all_axes_device(axes, keep_dims, stream)
579}
580
581/// See [`Array::all_axis`]
582#[generate_macro]
583#[default_device]
584pub fn all_axis_device(
585    array: impl AsRef<Array>,
586    axis: i32,
587    #[optional] keep_dims: impl Into<Option<bool>>,
588    #[optional] stream: impl AsRef<Stream>,
589) -> Result<Array> {
590    array.as_ref().all_axis_device(axis, keep_dims, stream)
591}
592
593/// See [`Array::all`]
594#[generate_macro]
595#[default_device]
596pub fn all_device(
597    array: impl AsRef<Array>,
598    #[optional] keep_dims: impl Into<Option<bool>>,
599    #[optional] stream: impl AsRef<Stream>,
600) -> Result<Array> {
601    array.as_ref().all_device(keep_dims, stream)
602}
603
604/// See [`Array::prod_axes`]
605#[generate_macro]
606#[default_device]
607pub fn prod_axes_device(
608    array: impl AsRef<Array>,
609    axes: &[i32],
610    #[optional] keep_dims: impl Into<Option<bool>>,
611    #[optional] stream: impl AsRef<Stream>,
612) -> Result<Array> {
613    array.as_ref().prod_axes_device(axes, keep_dims, stream)
614}
615
616/// See [`Array::prod_axis`]
617#[generate_macro]
618#[default_device]
619pub fn prod_axis_device(
620    array: impl AsRef<Array>,
621    axis: i32,
622    #[optional] keep_dims: impl Into<Option<bool>>,
623    #[optional] stream: impl AsRef<Stream>,
624) -> Result<Array> {
625    array.as_ref().prod_axis_device(axis, keep_dims, stream)
626}
627
628/// See [`Array::prod`]
629#[generate_macro]
630#[default_device]
631pub fn prod_device(
632    array: impl AsRef<Array>,
633    #[optional] keep_dims: impl Into<Option<bool>>,
634    #[optional] stream: impl AsRef<Stream>,
635) -> Result<Array> {
636    array.as_ref().prod_device(keep_dims, stream)
637}
638
639/// See [`Array::max_axes`]
640#[generate_macro]
641#[default_device]
642pub fn max_axes_device(
643    array: impl AsRef<Array>,
644    axes: &[i32],
645    #[optional] keep_dims: impl Into<Option<bool>>,
646    #[optional] stream: impl AsRef<Stream>,
647) -> Result<Array> {
648    array.as_ref().max_axes_device(axes, keep_dims, stream)
649}
650
651/// See [`Array::max_axis`]
652#[generate_macro]
653#[default_device]
654pub fn max_axis_device(
655    array: impl AsRef<Array>,
656    axis: i32,
657    #[optional] keep_dims: impl Into<Option<bool>>,
658    #[optional] stream: impl AsRef<Stream>,
659) -> Result<Array> {
660    array.as_ref().max_axis_device(axis, keep_dims, stream)
661}
662
663/// See [`Array::max`]
664#[generate_macro]
665#[default_device]
666pub fn max_device(
667    array: impl AsRef<Array>,
668    #[optional] keep_dims: impl Into<Option<bool>>,
669    #[optional] stream: impl AsRef<Stream>,
670) -> Result<Array> {
671    array.as_ref().max_device(keep_dims, stream)
672}
673
674/// Compute the standard deviation(s) over the given axes.
675///
676/// # Params
677///
678/// - `a`: Input array
679/// - `axes`: Optional axis or axes to reduce over. If unspecified this defaults to reducing over
680///   the entire array.
681/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
682/// - `ddof`: The divisor to compute the variance is `N - ddof`, defaults to `0`.
683#[generate_macro]
684#[default_device]
685pub fn std_axes_device(
686    a: impl AsRef<Array>,
687    axes: &[i32],
688    #[optional] keep_dims: impl Into<Option<bool>>,
689    #[optional] ddof: impl Into<Option<i32>>,
690    #[optional] stream: impl AsRef<Stream>,
691) -> Result<Array> {
692    let a = a.as_ref();
693    let keep_dims = keep_dims.into().unwrap_or(false);
694    let ddof = ddof.into().unwrap_or(0);
695    Array::try_from_op(|res| unsafe {
696        mlx_sys::mlx_std_axes(
697            res,
698            a.as_ptr(),
699            axes.as_ptr(),
700            axes.len(),
701            keep_dims,
702            ddof,
703            stream.as_ref().as_ptr(),
704        )
705    })
706}
707
708/// Similar to [`std_axes`] but only reduces over a single axis.
709#[generate_macro]
710#[default_device]
711pub fn std_axis_device(
712    a: impl AsRef<Array>,
713    axis: i32,
714    #[optional] keep_dims: impl Into<Option<bool>>,
715    #[optional] ddof: impl Into<Option<i32>>,
716    #[optional] stream: impl AsRef<Stream>,
717) -> Result<Array> {
718    let a = a.as_ref();
719    let keep_dims = keep_dims.into().unwrap_or(false);
720    let ddof = ddof.into().unwrap_or(0);
721    Array::try_from_op(|res| unsafe {
722        mlx_sys::mlx_std_axis(
723            res,
724            a.as_ptr(),
725            axis,
726            keep_dims,
727            ddof,
728            stream.as_ref().as_ptr(),
729        )
730    })
731}
732
733/// Similar to [`std_axes`] but reduces over all axes.
734#[generate_macro]
735#[default_device]
736pub fn std_device(
737    a: impl AsRef<Array>,
738    #[optional] keep_dims: impl Into<Option<bool>>,
739    #[optional] ddof: impl Into<Option<i32>>,
740    #[optional] stream: impl AsRef<Stream>,
741) -> Result<Array> {
742    let a = a.as_ref();
743    let keep_dims = keep_dims.into().unwrap_or(false);
744    let ddof = ddof.into().unwrap_or(0);
745    Array::try_from_op(|res| unsafe {
746        mlx_sys::mlx_std(res, a.as_ptr(), keep_dims, ddof, stream.as_ref().as_ptr())
747    })
748}
749
750/// See [`Array::sum_axes`]
751#[generate_macro]
752#[default_device]
753pub fn sum_axes_device(
754    array: impl AsRef<Array>,
755    axes: &[i32],
756    #[optional] keep_dims: impl Into<Option<bool>>,
757    #[optional] stream: impl AsRef<Stream>,
758) -> Result<Array> {
759    array.as_ref().sum_axes_device(axes, keep_dims, stream)
760}
761
762/// See [`Array::sum_axis`]
763#[generate_macro]
764#[default_device]
765pub fn sum_axis_device(
766    array: impl AsRef<Array>,
767    axis: i32,
768    #[optional] keep_dims: impl Into<Option<bool>>,
769    #[optional] stream: impl AsRef<Stream>,
770) -> Result<Array> {
771    array.as_ref().sum_axis_device(axis, keep_dims, stream)
772}
773
774/// See [`Array::sum`]
775#[generate_macro]
776#[default_device]
777pub fn sum_device(
778    array: impl AsRef<Array>,
779    #[optional] keep_dims: impl Into<Option<bool>>,
780    #[optional] stream: impl AsRef<Stream>,
781) -> Result<Array> {
782    array.as_ref().sum_device(keep_dims, stream)
783}
784
785/// See [`Array::mean_axes`]
786#[generate_macro]
787#[default_device]
788pub fn mean_axes_device(
789    array: impl AsRef<Array>,
790    axes: &[i32],
791    #[optional] keep_dims: impl Into<Option<bool>>,
792    #[optional] stream: impl AsRef<Stream>,
793) -> Result<Array> {
794    array.as_ref().mean_axes_device(axes, keep_dims, stream)
795}
796
797/// See [`Array::mean_axis`]
798#[generate_macro]
799#[default_device]
800pub fn mean_axis_device(
801    array: impl AsRef<Array>,
802    axis: i32,
803    #[optional] keep_dims: impl Into<Option<bool>>,
804    #[optional] stream: impl AsRef<Stream>,
805) -> Result<Array> {
806    array.as_ref().mean_axis_device(axis, keep_dims, stream)
807}
808
809/// See [`Array::mean`]
810#[generate_macro]
811#[default_device]
812pub fn mean_device(
813    array: impl AsRef<Array>,
814    #[optional] keep_dims: impl Into<Option<bool>>,
815    #[optional] stream: impl AsRef<Stream>,
816) -> Result<Array> {
817    array.as_ref().mean_device(keep_dims, stream)
818}
819
820/// See [`Array::min`]
821#[generate_macro]
822#[default_device]
823pub fn min_axes_device(
824    array: impl AsRef<Array>,
825    axes: &[i32],
826    #[optional] keep_dims: impl Into<Option<bool>>,
827    #[optional] stream: impl AsRef<Stream>,
828) -> Result<Array> {
829    array.as_ref().min_axes_device(axes, keep_dims, stream)
830}
831
832/// See [`Array::min_axis`]
833#[generate_macro]
834#[default_device]
835pub fn min_axis_device(
836    array: impl AsRef<Array>,
837    axis: i32,
838    #[optional] keep_dims: impl Into<Option<bool>>,
839    #[optional] stream: impl AsRef<Stream>,
840) -> Result<Array> {
841    array.as_ref().min_axis_device(axis, keep_dims, stream)
842}
843
844/// See [`Array::min`]
845#[generate_macro]
846#[default_device]
847pub fn min_device(
848    array: impl AsRef<Array>,
849    #[optional] keep_dims: impl Into<Option<bool>>,
850    #[optional] stream: impl AsRef<Stream>,
851) -> Result<Array> {
852    array.as_ref().min_device(keep_dims, stream)
853}
854
855/// See [`Array::var_axes`]
856#[generate_macro]
857#[default_device]
858pub fn var_axes_device(
859    array: impl AsRef<Array>,
860    axes: &[i32],
861    #[optional] keep_dims: impl Into<Option<bool>>,
862    #[optional] ddof: impl Into<Option<i32>>,
863    #[optional] stream: impl AsRef<Stream>,
864) -> Result<Array> {
865    array
866        .as_ref()
867        .var_axes_device(axes, keep_dims, ddof, stream)
868}
869
870/// See [`Array::var_axis`]
871#[generate_macro]
872#[default_device]
873pub fn var_axis_device(
874    array: impl AsRef<Array>,
875    axis: i32,
876    #[optional] keep_dims: impl Into<Option<bool>>,
877    #[optional] ddof: impl Into<Option<i32>>,
878    #[optional] stream: impl AsRef<Stream>,
879) -> Result<Array> {
880    array
881        .as_ref()
882        .var_axis_device(axis, keep_dims, ddof, stream)
883}
884
885/// See [`Array::var`]
886#[generate_macro]
887#[default_device]
888pub fn var_device(
889    array: impl AsRef<Array>,
890    #[optional] keep_dims: impl Into<Option<bool>>,
891    #[optional] ddof: impl Into<Option<i32>>,
892    #[optional] stream: impl AsRef<Stream>,
893) -> Result<Array> {
894    array.as_ref().var_device(keep_dims, ddof, stream)
895}
896
897/// See [`Array::logsumexp_axes`]
898#[generate_macro]
899#[default_device]
900pub fn logsumexp_axes_device(
901    array: impl AsRef<Array>,
902    axes: &[i32],
903    #[optional] keep_dims: impl Into<Option<bool>>,
904    #[optional] stream: impl AsRef<Stream>,
905) -> Result<Array> {
906    array
907        .as_ref()
908        .logsumexp_axes_device(axes, keep_dims, stream)
909}
910
911/// See [`Array::logsumexp_axis`]
912#[generate_macro]
913#[default_device]
914pub fn logsumexp_axis_device(
915    array: impl AsRef<Array>,
916    axis: i32,
917    #[optional] keep_dims: impl Into<Option<bool>>,
918    #[optional] stream: impl AsRef<Stream>,
919) -> Result<Array> {
920    array
921        .as_ref()
922        .logsumexp_axis_device(axis, keep_dims, stream)
923}
924
925/// See [`Array::logsumexp`]
926#[generate_macro]
927#[default_device]
928pub fn logsumexp_device(
929    array: impl AsRef<Array>,
930    #[optional] keep_dims: impl Into<Option<bool>>,
931    #[optional] stream: impl AsRef<Stream>,
932) -> Result<Array> {
933    array.as_ref().logsumexp_device(keep_dims, stream)
934}
935
936#[cfg(test)]
937mod tests {
938    use super::*;
939    use pretty_assertions::assert_eq;
940
941    #[test]
942    fn test_all() {
943        let array = Array::from_slice(&[true, false, true, false], &[2, 2]);
944
945        assert_eq!(array.all(None).unwrap().item::<bool>(), false);
946        assert_eq!(array.all(true).unwrap().shape(), &[1, 1]);
947        assert_eq!(array.all_axes(&[0, 1], None).unwrap().item::<bool>(), false);
948
949        let result = array.all_axis(0, None).unwrap();
950        assert_eq!(result.as_slice::<bool>(), &[true, false]);
951
952        let result = array.all_axis(1, None).unwrap();
953        assert_eq!(result.as_slice::<bool>(), &[false, false]);
954    }
955
956    #[test]
957    fn test_all_empty_axes() {
958        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
959        let all = array.all_axes(&[], None).unwrap();
960
961        let results: &[bool] = all.as_slice();
962        assert_eq!(
963            results,
964            &[false, true, true, true, true, true, true, true, true, true, true, true]
965        );
966    }
967
968    #[test]
969    fn test_prod() {
970        let x = Array::from_slice(&[1, 2, 3, 3], &[2, 2]);
971        assert_eq!(x.prod(None).unwrap().item::<i32>(), 18);
972
973        let y = x.prod(true).unwrap();
974        assert_eq!(y.item::<i32>(), 18);
975        assert_eq!(y.shape(), &[1, 1]);
976
977        let result = x.prod_axis(0, None).unwrap();
978        assert_eq!(result.as_slice::<i32>(), &[3, 6]);
979
980        let result = x.prod_axis(1, None).unwrap();
981        assert_eq!(result.as_slice::<i32>(), &[2, 9])
982    }
983
984    #[test]
985    fn test_prod_empty_axes() {
986        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
987        let result = array.prod_axes(&[], None).unwrap();
988
989        let results: &[i32] = result.as_slice();
990        assert_eq!(results, &[5, 8, 4, 9]);
991    }
992
993    #[test]
994    fn test_max() {
995        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
996        assert_eq!(x.max(None).unwrap().item::<i32>(), 4);
997        let y = x.max(true).unwrap();
998        assert_eq!(y.item::<i32>(), 4);
999        assert_eq!(y.shape(), &[1, 1]);
1000
1001        let result = x.max_axis(0, None).unwrap();
1002        assert_eq!(result.as_slice::<i32>(), &[3, 4]);
1003
1004        let result = x.max_axis(1, None).unwrap();
1005        assert_eq!(result.as_slice::<i32>(), &[2, 4]);
1006    }
1007
1008    #[test]
1009    fn test_max_empty_axes() {
1010        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1011        let result = array.max_axes(&[], None).unwrap();
1012
1013        let results: &[i32] = result.as_slice();
1014        assert_eq!(results, &[5, 8, 4, 9]);
1015    }
1016
1017    #[test]
1018    fn test_sum() {
1019        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1020        let result = array.sum_axis(0, None).unwrap();
1021
1022        let results: &[i32] = result.as_slice();
1023        assert_eq!(results, &[9, 17]);
1024    }
1025
1026    #[test]
1027    fn test_sum_empty_axes() {
1028        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1029        let result = array.sum_axes(&[], None).unwrap();
1030
1031        let results: &[i32] = result.as_slice();
1032        assert_eq!(results, &[5, 8, 4, 9]);
1033    }
1034
1035    #[test]
1036    fn test_mean() {
1037        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1038        assert_eq!(x.mean(None).unwrap().item::<f32>(), 2.5);
1039        let y = x.mean(true).unwrap();
1040        assert_eq!(y.item::<f32>(), 2.5);
1041        assert_eq!(y.shape(), &[1, 1]);
1042
1043        let result = x.mean_axis(0, None).unwrap();
1044        assert_eq!(result.as_slice::<f32>(), &[2.0, 3.0]);
1045
1046        let result = x.mean_axis(1, None).unwrap();
1047        assert_eq!(result.as_slice::<f32>(), &[1.5, 3.5]);
1048    }
1049
1050    #[test]
1051    fn test_mean_empty_axes() {
1052        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1053        let result = array.mean_axes(&[], None).unwrap();
1054
1055        let results: &[f32] = result.as_slice();
1056        assert_eq!(results, &[5.0, 8.0, 4.0, 9.0]);
1057    }
1058
1059    #[test]
1060    fn test_mean_out_of_bounds() {
1061        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1062        let result = array.mean_axis(2, None);
1063        assert!(result.is_err());
1064    }
1065
1066    #[test]
1067    fn test_min() {
1068        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1069        assert_eq!(x.min(None).unwrap().item::<i32>(), 1);
1070        let y = x.min(true).unwrap();
1071        assert_eq!(y.item::<i32>(), 1);
1072        assert_eq!(y.shape(), &[1, 1]);
1073
1074        let result = x.min_axis(0, None).unwrap();
1075        assert_eq!(result.as_slice::<i32>(), &[1, 2]);
1076
1077        let result = x.min_axis(1, None).unwrap();
1078        assert_eq!(result.as_slice::<i32>(), &[1, 3]);
1079    }
1080
1081    #[test]
1082    fn test_min_empty_axes() {
1083        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1084        let result = array.min_axes(&[], None).unwrap();
1085
1086        let results: &[i32] = result.as_slice();
1087        assert_eq!(results, &[5, 8, 4, 9]);
1088    }
1089
1090    #[test]
1091    fn test_var() {
1092        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1093        assert_eq!(x.var(None, None).unwrap().item::<f32>(), 1.25);
1094        let y = x.var(true, None).unwrap();
1095        assert_eq!(y.item::<f32>(), 1.25);
1096        assert_eq!(y.shape(), &[1, 1]);
1097
1098        let result = x.var_axis(0, None, None).unwrap();
1099        assert_eq!(result.as_slice::<f32>(), &[1.0, 1.0]);
1100
1101        let result = x.var_axis(1, None, None).unwrap();
1102        assert_eq!(result.as_slice::<f32>(), &[0.25, 0.25]);
1103
1104        let x = Array::from_slice(&[1.0, 2.0], &[2]);
1105        let out = x.var(None, Some(3)).unwrap();
1106        assert_eq!(out.item::<f32>(), f32::INFINITY);
1107    }
1108
1109    #[test]
1110    fn test_var_empty_axes() {
1111        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1112        let result = array.var_axes(&[], None, 0).unwrap();
1113
1114        let results: &[f32] = result.as_slice();
1115        assert_eq!(results, &[0.0, 0.0, 0.0, 0.0]);
1116    }
1117
1118    #[test]
1119    fn test_log_sum_exp() {
1120        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1121        let result = array.logsumexp_axis(0, None).unwrap();
1122
1123        let results: &[f32] = result.as_slice();
1124        assert_eq!(results, &[5.3132615, 9.313262]);
1125    }
1126
1127    #[test]
1128    fn test_log_sum_exp_empty_axes() {
1129        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
1130        let result = array.logsumexp_axes(&[], None).unwrap();
1131
1132        let results: &[f32] = result.as_slice();
1133        assert_eq!(results, &[5.0, 8.0, 4.0, 9.0]);
1134    }
1135}