1use crate::array::Array;
2use crate::error::Result;
3use crate::stream::StreamOrDevice;
4use crate::utils::guard::Guarded;
5use crate::utils::{axes_or_default_to_all, IntoOption};
6use crate::Stream;
7use mlx_internal_macros::{default_device, generate_macro};
8
9impl Array {
10 #[default_device]
28 pub fn all_device<'a>(
29 &self,
30 axes: impl IntoOption<&'a [i32]>,
31 keep_dims: impl Into<Option<bool>>,
32 stream: impl AsRef<Stream>,
33 ) -> Result<Array> {
34 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
35 Array::try_from_op(|res| unsafe {
36 mlx_sys::mlx_all_axes(
37 res,
38 self.as_ptr(),
39 axes.as_ptr(),
40 axes.len(),
41 keep_dims.into().unwrap_or(false),
42 stream.as_ref().as_ptr(),
43 )
44 })
45 }
46
47 #[default_device]
64 pub fn prod_device<'a>(
65 &self,
66 axes: impl IntoOption<&'a [i32]>,
67 keep_dims: impl Into<Option<bool>>,
68 stream: impl AsRef<Stream>,
69 ) -> Result<Array> {
70 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
71 Array::try_from_op(|res| unsafe {
72 mlx_sys::mlx_prod(
73 res,
74 self.as_ptr(),
75 axes.as_ptr(),
76 axes.len(),
77 keep_dims.into().unwrap_or(false),
78 stream.as_ref().as_ptr(),
79 )
80 })
81 }
82
83 #[default_device]
100 pub fn max_device<'a>(
101 &self,
102 axes: impl IntoOption<&'a [i32]>,
103 keep_dims: impl Into<Option<bool>>,
104 stream: impl AsRef<Stream>,
105 ) -> Result<Array> {
106 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
107 Array::try_from_op(|res| unsafe {
108 mlx_sys::mlx_max(
109 res,
110 self.as_ptr(),
111 axes.as_ptr(),
112 axes.len(),
113 keep_dims.into().unwrap_or(false),
114 stream.as_ref().as_ptr(),
115 )
116 })
117 }
118
119 #[default_device]
136 pub fn sum_device<'a>(
137 &self,
138 axes: impl IntoOption<&'a [i32]>,
139 keep_dims: impl Into<Option<bool>>,
140 stream: impl AsRef<Stream>,
141 ) -> Result<Array> {
142 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
143 Array::try_from_op(|res| unsafe {
144 mlx_sys::mlx_sum(
145 res,
146 self.as_ptr(),
147 axes.as_ptr(),
148 axes.len(),
149 keep_dims.into().unwrap_or(false),
150 stream.as_ref().as_ptr(),
151 )
152 })
153 }
154
155 #[default_device]
172 pub fn mean_device<'a>(
173 &self,
174 axes: impl IntoOption<&'a [i32]>,
175 keep_dims: impl Into<Option<bool>>,
176 stream: impl AsRef<Stream>,
177 ) -> Result<Array> {
178 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
179 Array::try_from_op(|res| unsafe {
180 mlx_sys::mlx_mean(
181 res,
182 self.as_ptr(),
183 axes.as_ptr(),
184 axes.len(),
185 keep_dims.into().unwrap_or(false),
186 stream.as_ref().as_ptr(),
187 )
188 })
189 }
190
191 #[default_device]
208 pub fn min_device<'a>(
209 &self,
210 axes: impl IntoOption<&'a [i32]>,
211 keep_dims: impl Into<Option<bool>>,
212 stream: impl AsRef<Stream>,
213 ) -> Result<Array> {
214 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
215 Array::try_from_op(|res| unsafe {
216 mlx_sys::mlx_min(
217 res,
218 self.as_ptr(),
219 axes.as_ptr(),
220 axes.len(),
221 keep_dims.into().unwrap_or(false),
222 stream.as_ref().as_ptr(),
223 )
224 })
225 }
226
227 #[default_device]
235 pub fn variance_device<'a>(
236 &self,
237 axes: impl IntoOption<&'a [i32]>,
238 keep_dims: impl Into<Option<bool>>,
239 ddof: impl Into<Option<i32>>,
240 stream: impl AsRef<Stream>,
241 ) -> Result<Array> {
242 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
243 Array::try_from_op(|res| unsafe {
244 mlx_sys::mlx_var(
245 res,
246 self.as_ptr(),
247 axes.as_ptr(),
248 axes.len(),
249 keep_dims.into().unwrap_or(false),
250 ddof.into().unwrap_or(0),
251 stream.as_ref().as_ptr(),
252 )
253 })
254 }
255
256 #[default_device]
265 pub fn log_sum_exp_device<'a>(
266 &self,
267 axes: impl IntoOption<&'a [i32]>,
268 keep_dims: impl Into<Option<bool>>,
269 stream: impl AsRef<Stream>,
270 ) -> Result<Array> {
271 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
272 Array::try_from_op(|res| unsafe {
273 mlx_sys::mlx_logsumexp(
274 res,
275 self.as_ptr(),
276 axes.as_ptr(),
277 axes.len(),
278 keep_dims.into().unwrap_or(false),
279 stream.as_ref().as_ptr(),
280 )
281 })
282 }
283}
284
285#[generate_macro]
287#[default_device]
288pub fn all_device<'a>(
289 array: impl AsRef<Array>,
290 #[optional] axes: impl IntoOption<&'a [i32]>,
291 #[optional] keep_dims: impl Into<Option<bool>>,
292 #[optional] stream: impl AsRef<Stream>,
293) -> Result<Array> {
294 array.as_ref().all_device(axes, keep_dims, stream)
295}
296
297#[generate_macro]
299#[default_device]
300pub fn prod_device<'a>(
301 array: impl AsRef<Array>,
302 #[optional] axes: impl IntoOption<&'a [i32]>,
303 #[optional] keep_dims: impl Into<Option<bool>>,
304 #[optional] stream: impl AsRef<Stream>,
305) -> Result<Array> {
306 array.as_ref().prod_device(axes, keep_dims, stream)
307}
308
309#[generate_macro]
311#[default_device]
312pub fn max_device<'a>(
313 array: impl AsRef<Array>,
314 #[optional] axes: impl IntoOption<&'a [i32]>,
315 #[optional] keep_dims: impl Into<Option<bool>>,
316 #[optional] stream: impl AsRef<Stream>,
317) -> Result<Array> {
318 array.as_ref().max_device(axes, keep_dims, stream)
319}
320
321#[generate_macro]
331#[default_device]
332pub fn std_device<'a>(
333 a: impl AsRef<Array>,
334 #[optional] axes: impl IntoOption<&'a [i32]>,
335 #[optional] keep_dims: impl Into<Option<bool>>,
336 #[optional] ddof: impl Into<Option<i32>>,
337 #[optional] stream: impl AsRef<Stream>,
338) -> Result<Array> {
339 let a = a.as_ref();
340 let axes = axes_or_default_to_all(axes, a.ndim() as i32);
341 let keep_dims = keep_dims.into().unwrap_or(false);
342 let ddof = ddof.into().unwrap_or(0);
343 Array::try_from_op(|res| unsafe {
344 mlx_sys::mlx_std(
345 res,
346 a.as_ptr(),
347 axes.as_ptr(),
348 axes.len(),
349 keep_dims,
350 ddof,
351 stream.as_ref().as_ptr(),
352 )
353 })
354}
355
356#[generate_macro]
358#[default_device]
359pub fn sum_device<'a>(
360 array: impl AsRef<Array>,
361 #[optional] axes: impl IntoOption<&'a [i32]>,
362 #[optional] keep_dims: impl Into<Option<bool>>,
363 #[optional] stream: impl AsRef<Stream>,
364) -> Result<Array> {
365 array.as_ref().sum_device(axes, keep_dims, stream)
366}
367
368#[generate_macro]
370#[default_device]
371pub fn mean_device<'a>(
372 array: impl AsRef<Array>,
373 #[optional] axes: impl IntoOption<&'a [i32]>,
374 #[optional] keep_dims: impl Into<Option<bool>>,
375 #[optional] stream: impl AsRef<Stream>,
376) -> Result<Array> {
377 array.as_ref().mean_device(axes, keep_dims, stream)
378}
379
380#[generate_macro]
382#[default_device]
383pub fn min_device<'a>(
384 array: impl AsRef<Array>,
385 #[optional] axes: impl IntoOption<&'a [i32]>,
386 #[optional] keep_dims: impl Into<Option<bool>>,
387 #[optional] stream: impl AsRef<Stream>,
388) -> Result<Array> {
389 array.as_ref().min_device(axes, keep_dims, stream)
390}
391
392#[generate_macro]
394#[default_device]
395pub fn variance_device<'a>(
396 array: impl AsRef<Array>,
397 #[optional] axes: impl IntoOption<&'a [i32]>,
398 #[optional] keep_dims: impl Into<Option<bool>>,
399 #[optional] ddof: impl Into<Option<i32>>,
400 #[optional] stream: impl AsRef<Stream>,
401) -> Result<Array> {
402 array
403 .as_ref()
404 .variance_device(axes, keep_dims, ddof, stream)
405}
406
407#[generate_macro]
409#[default_device]
410pub fn log_sum_exp_device<'a>(
411 array: impl AsRef<Array>,
412 #[optional] axes: impl IntoOption<&'a [i32]>,
413 #[optional] keep_dims: impl Into<Option<bool>>,
414 #[optional] stream: impl AsRef<Stream>,
415) -> Result<Array> {
416 array.as_ref().log_sum_exp_device(axes, keep_dims, stream)
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Asserts that two `f32` slices agree element-wise within a small absolute tolerance.
    ///
    /// Results of transcendental functions (`exp`, `log`, `sqrt`) may differ in the last
    /// bits between backends, so exact equality is too strict for them.
    fn assert_all_close(actual: &[f32], expected: &[f32]) {
        const TOLERANCE: f32 = 1e-5;
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() <= TOLERANCE,
                "element {}: {} differs from expected {}",
                i,
                a,
                e
            );
        }
    }

    #[test]
    fn test_all() {
        let array = Array::from_slice(&[true, false, true, false], &[2, 2]);

        assert_eq!(array.all(None, None).unwrap().item::<bool>(), false);
        assert_eq!(array.all(None, true).unwrap().shape(), &[1, 1]);
        assert_eq!(array.all(&[0, 1][..], None).unwrap().item::<bool>(), false);

        let result = array.all(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<bool>(), &[true, false]);

        let result = array.all(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<bool>(), &[false, false]);
    }

    #[test]
    fn test_all_empty_axes() {
        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
        let all = array.all(&[][..], None).unwrap();

        let results: &[bool] = all.as_slice();
        assert_eq!(
            results,
            &[false, true, true, true, true, true, true, true, true, true, true, true]
        );
    }

    #[test]
    fn test_prod() {
        let x = Array::from_slice(&[1, 2, 3, 3], &[2, 2]);
        assert_eq!(x.prod(None, None).unwrap().item::<i32>(), 18);

        let y = x.prod(None, true).unwrap();
        assert_eq!(y.item::<i32>(), 18);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.prod(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[3, 6]);

        let result = x.prod(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[2, 9])
    }

    #[test]
    fn test_prod_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.prod(&[][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_max() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.max(None, None).unwrap().item::<i32>(), 4);
        let y = x.max(None, true).unwrap();
        assert_eq!(y.item::<i32>(), 4);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.max(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[3, 4]);

        let result = x.max(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[2, 4]);
    }

    #[test]
    fn test_max_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.max(&[][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_sum() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.sum(&[0][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[9, 17]);
    }

    #[test]
    fn test_sum_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.sum(&[][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_mean() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.mean(None, None).unwrap().item::<f32>(), 2.5);
        let y = x.mean(None, true).unwrap();
        assert_eq!(y.item::<f32>(), 2.5);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.mean(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[2.0, 3.0]);

        let result = x.mean(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[1.5, 3.5]);
    }

    #[test]
    fn test_mean_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.mean(&[][..], None).unwrap();

        let results: &[f32] = result.as_slice();
        assert_eq!(results, &[5.0, 8.0, 4.0, 9.0]);
    }

    #[test]
    fn test_mean_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.mean(&[2][..], None);
        assert!(result.is_err());
    }

    #[test]
    fn test_min() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.min(None, None).unwrap().item::<i32>(), 1);
        let y = x.min(None, true).unwrap();
        assert_eq!(y.item::<i32>(), 1);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.min(&[0][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[1, 2]);

        let result = x.min(&[1][..], None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[1, 3]);
    }

    #[test]
    fn test_min_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.min(&[][..], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_var() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.variance(None, None, None).unwrap().item::<f32>(), 1.25);
        let y = x.variance(None, true, None).unwrap();
        assert_eq!(y.item::<f32>(), 1.25);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.variance(&[0][..], None, None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[1.0, 1.0]);

        let result = x.variance(&[1][..], None, None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[0.25, 0.25]);

        let x = Array::from_slice(&[1.0, 2.0], &[2]);
        let out = x.variance(None, None, Some(3)).unwrap();
        assert_eq!(out.item::<f32>(), f32::INFINITY);
    }

    #[test]
    fn test_var_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.variance(&[][..], None, 0).unwrap();

        let results: &[f32] = result.as_slice();
        assert_eq!(results, &[0.0, 0.0, 0.0, 0.0]);
    }

    // `std_device` previously had no coverage; these mirror the variance tests
    // (std is the square root of the population variance with the default ddof of 0).
    #[test]
    fn test_std() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);

        let all = std(&x, None, None, None).unwrap();
        assert_all_close(&[all.item::<f32>()], &[1.25f32.sqrt()]);

        let kept = std(&x, None, true, None).unwrap();
        assert_eq!(kept.shape(), &[1, 1]);

        let result = std(&x, &[0][..], None, None).unwrap();
        assert_all_close(result.as_slice::<f32>(), &[1.0, 1.0]);

        let result = std(&x, &[1][..], None, None).unwrap();
        assert_all_close(result.as_slice::<f32>(), &[0.5, 0.5]);
    }

    #[test]
    fn test_std_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = std(&array, &[][..], None, None).unwrap();

        assert_all_close(result.as_slice::<f32>(), &[0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_log_sum_exp() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.log_sum_exp(&[0][..], None).unwrap();

        // Transcendental results: compare with a tolerance rather than bit-exactly.
        assert_all_close(result.as_slice::<f32>(), &[5.3132615, 9.313262]);
    }

    #[test]
    fn test_log_sum_exp_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.log_sum_exp(&[][..], None).unwrap();

        assert_all_close(result.as_slice::<f32>(), &[5.0, 8.0, 4.0, 9.0]);
    }
}