1use crate::error::Result;
2use crate::utils::guard::Guarded;
3use crate::{Array, Stream, StreamOrDevice};
4use mlx_internal_macros::{default_device, generate_macro};
5
6impl Array {
7 #[default_device]
25 pub fn cummax_device(
26 &self,
27 axis: impl Into<Option<i32>>,
28 reverse: impl Into<Option<bool>>,
29 inclusive: impl Into<Option<bool>>,
30 stream: impl AsRef<Stream>,
31 ) -> Result<Array> {
32 let stream = stream.as_ref();
33
34 match axis.into() {
35 Some(axis) => Array::try_from_op(|res| unsafe {
36 mlx_sys::mlx_cummax(
37 res,
38 self.as_ptr(),
39 axis,
40 reverse.into().unwrap_or(false),
41 inclusive.into().unwrap_or(true),
42 stream.as_ptr(),
43 )
44 }),
45 None => {
46 let shape = &[-1];
47 let flat = self.reshape_device(shape, stream)?;
48 Array::try_from_op(|res| unsafe {
49 mlx_sys::mlx_cummax(
50 res,
51 flat.as_ptr(),
52 0,
53 reverse.into().unwrap_or(false),
54 inclusive.into().unwrap_or(true),
55 stream.as_ptr(),
56 )
57 })
58 }
59 }
60 }
61
62 #[default_device]
80 pub fn cummin_device(
81 &self,
82 axis: impl Into<Option<i32>>,
83 reverse: impl Into<Option<bool>>,
84 inclusive: impl Into<Option<bool>>,
85 stream: impl AsRef<Stream>,
86 ) -> Result<Array> {
87 let stream = stream.as_ref();
88
89 match axis.into() {
90 Some(axis) => Array::try_from_op(|res| unsafe {
91 mlx_sys::mlx_cummin(
92 res,
93 self.as_ptr(),
94 axis,
95 reverse.into().unwrap_or(false),
96 inclusive.into().unwrap_or(true),
97 stream.as_ptr(),
98 )
99 }),
100 None => {
101 let shape = &[-1];
102 let flat = self.reshape_device(shape, stream)?;
103 Array::try_from_op(|res| unsafe {
104 mlx_sys::mlx_cummin(
105 res,
106 flat.as_ptr(),
107 0,
108 reverse.into().unwrap_or(false),
109 inclusive.into().unwrap_or(true),
110 stream.as_ptr(),
111 )
112 })
113 }
114 }
115 }
116
117 #[default_device]
135 pub fn cumprod_device(
136 &self,
137 axis: impl Into<Option<i32>>,
138 reverse: impl Into<Option<bool>>,
139 inclusive: impl Into<Option<bool>>,
140 stream: impl AsRef<Stream>,
141 ) -> Result<Array> {
142 let stream = stream.as_ref();
143
144 match axis.into() {
145 Some(axis) => Array::try_from_op(|res| unsafe {
146 mlx_sys::mlx_cumprod(
147 res,
148 self.as_ptr(),
149 axis,
150 reverse.into().unwrap_or(false),
151 inclusive.into().unwrap_or(true),
152 stream.as_ptr(),
153 )
154 }),
155 None => {
156 let shape = &[-1];
157 let flat = self.reshape_device(shape, stream)?;
158 Array::try_from_op(|res| unsafe {
159 mlx_sys::mlx_cumprod(
160 res,
161 flat.as_ptr(),
162 0,
163 reverse.into().unwrap_or(false),
164 inclusive.into().unwrap_or(true),
165 stream.as_ptr(),
166 )
167 })
168 }
169 }
170 }
171
172 #[default_device]
190 pub fn cumsum_device(
191 &self,
192 axis: impl Into<Option<i32>>,
193 reverse: impl Into<Option<bool>>,
194 inclusive: impl Into<Option<bool>>,
195 stream: impl AsRef<Stream>,
196 ) -> Result<Array> {
197 let stream = stream.as_ref();
198
199 match axis.into() {
200 Some(axis) => Array::try_from_op(|res| unsafe {
201 mlx_sys::mlx_cumsum(
202 res,
203 self.as_ptr(),
204 axis,
205 reverse.into().unwrap_or(false),
206 inclusive.into().unwrap_or(true),
207 stream.as_ptr(),
208 )
209 }),
210 None => {
211 let shape = &[-1];
212 let flat = self.reshape_device(shape, stream)?;
213 Array::try_from_op(|res| unsafe {
214 mlx_sys::mlx_cumsum(
215 res,
216 flat.as_ptr(),
217 0,
218 reverse.into().unwrap_or(false),
219 inclusive.into().unwrap_or(true),
220 stream.as_ptr(),
221 )
222 })
223 }
224 }
225 }
226}
227
228#[generate_macro]
230#[default_device]
231pub fn cummax_device(
232 a: impl AsRef<Array>,
233 #[optional] axis: impl Into<Option<i32>>,
234 #[optional] reverse: impl Into<Option<bool>>,
235 #[optional] inclusive: impl Into<Option<bool>>,
236 #[optional] stream: impl AsRef<Stream>,
237) -> Result<Array> {
238 a.as_ref().cummax_device(axis, reverse, inclusive, stream)
239}
240
241#[generate_macro]
243#[default_device]
244pub fn cummin_device(
245 a: impl AsRef<Array>,
246 #[optional] axis: impl Into<Option<i32>>,
247 #[optional] reverse: impl Into<Option<bool>>,
248 #[optional] inclusive: impl Into<Option<bool>>,
249 #[optional] stream: impl AsRef<Stream>,
250) -> Result<Array> {
251 a.as_ref().cummin_device(axis, reverse, inclusive, stream)
252}
253
254#[generate_macro]
256#[default_device]
257pub fn cumprod_device(
258 a: impl AsRef<Array>,
259 #[optional] axis: impl Into<Option<i32>>,
260 #[optional] reverse: impl Into<Option<bool>>,
261 #[optional] inclusive: impl Into<Option<bool>>,
262 #[optional] stream: impl AsRef<Stream>,
263) -> Result<Array> {
264 a.as_ref().cumprod_device(axis, reverse, inclusive, stream)
265}
266
267#[generate_macro]
269#[default_device]
270pub fn cumsum_device(
271 a: impl AsRef<Array>,
272 #[optional] axis: impl Into<Option<i32>>,
273 #[optional] reverse: impl Into<Option<bool>>,
274 #[optional] inclusive: impl Into<Option<bool>>,
275 #[optional] stream: impl AsRef<Stream>,
276) -> Result<Array> {
277 a.as_ref().cumsum_device(axis, reverse, inclusive, stream)
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    // Every test uses the 2x2 input [[5, 8], [4, 9]].

    #[test]
    fn test_cummax() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        let result = array.cummax(0, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 5, 9]);

        let result = array.cummax(1, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 4, 9]);

        // No axis: the input is flattened first, so the result is 1-D.
        let result = array.cummax(None, None, None).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 8, 9]);

        let result = array.cummax(0, Some(true), None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 9, 4, 9]);

        // Explicit `inclusive = true` must match the default.
        let result = array.cummax(0, None, Some(true)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 5, 9]);
    }

    #[test]
    fn test_cummax_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.cummax(2, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cummin() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        let result = array.cummin(0, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 4, 8]);

        let result = array.cummin(1, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 5, 4, 4]);

        // No axis: the input is flattened first, so the result is 1-D.
        let result = array.cummin(None, None, None).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[5, 5, 4, 4]);

        let result = array.cummin(0, Some(true), None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[4, 8, 4, 9]);

        // Explicit `inclusive = true` must match the default.
        let result = array.cummin(0, None, Some(true)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 4, 8]);
    }

    #[test]
    fn test_cummin_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.cummin(2, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cumprod() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        let result = array.cumprod(0, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 20, 72]);

        let result = array.cumprod(1, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 40, 4, 36]);

        // No axis: the input is flattened first, so the result is 1-D.
        let result = array.cumprod(None, None, None).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[5, 40, 160, 1440]);

        let result = array.cumprod(0, Some(true), None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[20, 72, 4, 9]);

        // Explicit `inclusive = true` must match the default.
        let result = array.cumprod(0, None, Some(true)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 20, 72]);

        // Exclusive scan: the first row is the multiplicative identity and each later row
        // excludes its own element.
        let result = array.cumprod(0, None, Some(false)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[1, 1, 5, 8]);
    }

    #[test]
    fn test_cumprod_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.cumprod(2, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cumsum() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        let result = array.cumsum(0, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 9, 17]);

        let result = array.cumsum(1, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 13, 4, 13]);

        // No axis: the input is flattened first, so the result is 1-D.
        let result = array.cumsum(None, None, None).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[5, 13, 17, 26]);

        let result = array.cumsum(0, Some(true), None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[9, 17, 4, 9]);

        // Explicit `inclusive = true` must match the default.
        let result = array.cumsum(0, None, Some(true)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 9, 17]);

        // Exclusive scan: the first row is the additive identity and each later row
        // excludes its own element.
        let result = array.cumsum(0, None, Some(false)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[0, 0, 5, 8]);
    }

    #[test]
    fn test_cumsum_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.cumsum(2, None, None);
        assert!(result.is_err());
    }
}