mlx_rs/ops/
cumulative.rs

1use crate::error::Result;
2use crate::utils::guard::Guarded;
3use crate::{Array, Stream, StreamOrDevice};
4use mlx_internal_macros::{default_device, generate_macro};
5
6impl Array {
7    /// Return the cumulative maximum of the elements along the given axis returning an error if the inputs are invalid.
8    ///
9    /// # Params
10    ///
11    /// - axis: Optional axis to compute the cumulative maximum over. If unspecified the cumulative maximum of the flattened array is returned.
12    /// - reverse: If true, the cumulative maximum is computed in reverse - defaults to false if unspecified.
13    /// - inclusive: If true, the i-th element of the output includes the i-th element of the input - defaults to true if unspecified.
14    ///
15    /// # Example
16    ///
17    /// ```rust
18    /// use mlx_rs::Array;
19    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
20    ///
21    /// // result is [[5, 8], [5, 9]] -- cumulative max along the columns
22    /// let result = array.cummax(0, None, None).unwrap();
23    /// ```
24    #[default_device]
25    pub fn cummax_device(
26        &self,
27        axis: impl Into<Option<i32>>,
28        reverse: impl Into<Option<bool>>,
29        inclusive: impl Into<Option<bool>>,
30        stream: impl AsRef<Stream>,
31    ) -> Result<Array> {
32        let stream = stream.as_ref();
33
34        match axis.into() {
35            Some(axis) => Array::try_from_op(|res| unsafe {
36                mlx_sys::mlx_cummax(
37                    res,
38                    self.as_ptr(),
39                    axis,
40                    reverse.into().unwrap_or(false),
41                    inclusive.into().unwrap_or(true),
42                    stream.as_ptr(),
43                )
44            }),
45            None => {
46                let shape = &[-1];
47                let flat = self.reshape_device(shape, stream)?;
48                Array::try_from_op(|res| unsafe {
49                    mlx_sys::mlx_cummax(
50                        res,
51                        flat.as_ptr(),
52                        0,
53                        reverse.into().unwrap_or(false),
54                        inclusive.into().unwrap_or(true),
55                        stream.as_ptr(),
56                    )
57                })
58            }
59        }
60    }
61
62    /// Return the cumulative minimum of the elements along the given axis returning an error if the inputs are invalid.
63    ///
64    /// # Params
65    ///
66    /// - axis: Optional axis to compute the cumulative minimum over. If unspecified the cumulative maximum of the flattened array is returned.
67    /// - reverse: If true, the cumulative minimum is computed in reverse - defaults to false if unspecified.
68    /// - inclusive: If true, the i-th element of the output includes the i-th element of the input - defaults to true if unspecified.
69    ///
70    /// # Example
71    ///
72    /// ```rust
73    /// use mlx_rs::Array;
74    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
75    ///
76    /// // result is [[5, 8], [4, 8]] -- cumulative min along the columns
77    /// let result = array.cummin(0, None, None).unwrap();
78    /// ```
79    #[default_device]
80    pub fn cummin_device(
81        &self,
82        axis: impl Into<Option<i32>>,
83        reverse: impl Into<Option<bool>>,
84        inclusive: impl Into<Option<bool>>,
85        stream: impl AsRef<Stream>,
86    ) -> Result<Array> {
87        let stream = stream.as_ref();
88
89        match axis.into() {
90            Some(axis) => Array::try_from_op(|res| unsafe {
91                mlx_sys::mlx_cummin(
92                    res,
93                    self.as_ptr(),
94                    axis,
95                    reverse.into().unwrap_or(false),
96                    inclusive.into().unwrap_or(true),
97                    stream.as_ptr(),
98                )
99            }),
100            None => {
101                let shape = &[-1];
102                let flat = self.reshape_device(shape, stream)?;
103                Array::try_from_op(|res| unsafe {
104                    mlx_sys::mlx_cummin(
105                        res,
106                        flat.as_ptr(),
107                        0,
108                        reverse.into().unwrap_or(false),
109                        inclusive.into().unwrap_or(true),
110                        stream.as_ptr(),
111                    )
112                })
113            }
114        }
115    }
116
117    /// Return the cumulative product of the elements along the given axis returning an error if the inputs are invalid.
118    ///
119    /// # Params
120    ///
121    /// - axis: Optional axis to compute the cumulative product over. If unspecified the cumulative maximum of the flattened array is returned.
122    /// - reverse: If true, the cumulative product is computed in reverse - defaults to false if unspecified.
123    /// - inclusive: If true, the i-th element of the output includes the i-th element of the input - defaults to true if unspecified.
124    ///
125    /// # Example
126    ///
127    /// ```rust
128    /// use mlx_rs::Array;
129    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
130    ///
131    /// // result is [[5, 8], [20, 72]] -- cumulative min along the columns
132    /// let result = array.cumprod(0, None, None).unwrap();
133    /// ```
134    #[default_device]
135    pub fn cumprod_device(
136        &self,
137        axis: impl Into<Option<i32>>,
138        reverse: impl Into<Option<bool>>,
139        inclusive: impl Into<Option<bool>>,
140        stream: impl AsRef<Stream>,
141    ) -> Result<Array> {
142        let stream = stream.as_ref();
143
144        match axis.into() {
145            Some(axis) => Array::try_from_op(|res| unsafe {
146                mlx_sys::mlx_cumprod(
147                    res,
148                    self.as_ptr(),
149                    axis,
150                    reverse.into().unwrap_or(false),
151                    inclusive.into().unwrap_or(true),
152                    stream.as_ptr(),
153                )
154            }),
155            None => {
156                let shape = &[-1];
157                let flat = self.reshape_device(shape, stream)?;
158                Array::try_from_op(|res| unsafe {
159                    mlx_sys::mlx_cumprod(
160                        res,
161                        flat.as_ptr(),
162                        0,
163                        reverse.into().unwrap_or(false),
164                        inclusive.into().unwrap_or(true),
165                        stream.as_ptr(),
166                    )
167                })
168            }
169        }
170    }
171
172    /// Return the cumulative sum of the elements along the given axis returning an error if the inputs are invalid.
173    ///
174    /// # Params
175    ///
176    /// - axis: Optional axis to compute the cumulative sum over. If unspecified the cumulative maximum of the flattened array is returned.
177    /// - reverse: If true, the cumulative sum is computed in reverse - defaults to false if unspecified.
178    /// - inclusive: If true, the i-th element of the output includes the i-th element of the input - defaults to true if unspecified.
179    ///
180    /// # Example
181    ///
182    /// ```rust
183    /// use mlx_rs::Array;
184    /// let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
185    ///
186    /// // result is [[5, 8], [9, 17]] -- cumulative min along the columns
187    /// let result = array.cumsum(0, None, None).unwrap();
188    /// ```
189    #[default_device]
190    pub fn cumsum_device(
191        &self,
192        axis: impl Into<Option<i32>>,
193        reverse: impl Into<Option<bool>>,
194        inclusive: impl Into<Option<bool>>,
195        stream: impl AsRef<Stream>,
196    ) -> Result<Array> {
197        let stream = stream.as_ref();
198
199        match axis.into() {
200            Some(axis) => Array::try_from_op(|res| unsafe {
201                mlx_sys::mlx_cumsum(
202                    res,
203                    self.as_ptr(),
204                    axis,
205                    reverse.into().unwrap_or(false),
206                    inclusive.into().unwrap_or(true),
207                    stream.as_ptr(),
208                )
209            }),
210            None => {
211                let shape = &[-1];
212                let flat = self.reshape_device(shape, stream)?;
213                Array::try_from_op(|res| unsafe {
214                    mlx_sys::mlx_cumsum(
215                        res,
216                        flat.as_ptr(),
217                        0,
218                        reverse.into().unwrap_or(false),
219                        inclusive.into().unwrap_or(true),
220                        stream.as_ptr(),
221                    )
222                })
223            }
224        }
225    }
226}
227
228/// See [`Array::cummax`]
229#[generate_macro]
230#[default_device]
231pub fn cummax_device(
232    a: impl AsRef<Array>,
233    #[optional] axis: impl Into<Option<i32>>,
234    #[optional] reverse: impl Into<Option<bool>>,
235    #[optional] inclusive: impl Into<Option<bool>>,
236    #[optional] stream: impl AsRef<Stream>,
237) -> Result<Array> {
238    a.as_ref().cummax_device(axis, reverse, inclusive, stream)
239}
240
241/// See [`Array::cummin`]
242#[generate_macro]
243#[default_device]
244pub fn cummin_device(
245    a: impl AsRef<Array>,
246    #[optional] axis: impl Into<Option<i32>>,
247    #[optional] reverse: impl Into<Option<bool>>,
248    #[optional] inclusive: impl Into<Option<bool>>,
249    #[optional] stream: impl AsRef<Stream>,
250) -> Result<Array> {
251    a.as_ref().cummin_device(axis, reverse, inclusive, stream)
252}
253
254/// See [`Array::cumprod`]
255#[generate_macro]
256#[default_device]
257pub fn cumprod_device(
258    a: impl AsRef<Array>,
259    #[optional] axis: impl Into<Option<i32>>,
260    #[optional] reverse: impl Into<Option<bool>>,
261    #[optional] inclusive: impl Into<Option<bool>>,
262    #[optional] stream: impl AsRef<Stream>,
263) -> Result<Array> {
264    a.as_ref().cumprod_device(axis, reverse, inclusive, stream)
265}
266
267/// See [`Array::cumsum`]
268#[generate_macro]
269#[default_device]
270pub fn cumsum_device(
271    a: impl AsRef<Array>,
272    #[optional] axis: impl Into<Option<i32>>,
273    #[optional] reverse: impl Into<Option<bool>>,
274    #[optional] inclusive: impl Into<Option<bool>>,
275    #[optional] stream: impl AsRef<Stream>,
276) -> Result<Array> {
277    a.as_ref().cumsum_device(axis, reverse, inclusive, stream)
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_cummax() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        let result = array.cummax(0, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 5, 9]);

        let result = array.cummax(1, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 4, 9]);

        let result = array.cummax(None, None, None).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 8, 9]);

        let result = array.cummax(0, Some(true), None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 9, 4, 9]);

        // `Some(true)` is the default for `inclusive`, so this must match the `None` case.
        let result = array.cummax(0, None, Some(true)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 5, 9]);
    }

    #[test]
    fn test_cummax_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.cummax(2, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cummin() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        let result = array.cummin(0, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 4, 8]);

        let result = array.cummin(1, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 5, 4, 4]);

        let result = array.cummin(None, None, None).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[5, 5, 4, 4]);

        let result = array.cummin(0, Some(true), None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[4, 8, 4, 9]);

        // `Some(true)` is the default for `inclusive`, so this must match the `None` case.
        let result = array.cummin(0, None, Some(true)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 4, 8]);
    }

    #[test]
    fn test_cummin_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.cummin(2, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cumprod() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        let result = array.cumprod(0, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 20, 72]);

        let result = array.cumprod(1, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 40, 4, 36]);

        let result = array.cumprod(None, None, None).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[5, 40, 160, 1440]);

        let result = array.cumprod(0, Some(true), None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[20, 72, 4, 9]);

        // `Some(true)` is the default for `inclusive`, so this must match the `None` case.
        let result = array.cumprod(0, None, Some(true)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 20, 72]);
    }

    #[test]
    fn test_cumprod_exclusive() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        // Exclusive scans shift the result by one position, starting from the product identity 1.
        let result = array.cumprod(0, None, Some(false)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[1, 1, 5, 8]);

        let result = array.cumprod(None, None, Some(false)).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[1, 5, 40, 160]);
    }

    #[test]
    fn test_cumprod_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.cumprod(2, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cumsum() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        let result = array.cumsum(0, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 9, 17]);

        let result = array.cumsum(1, None, None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 13, 4, 13]);

        let result = array.cumsum(None, None, None).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[5, 13, 17, 26]);

        let result = array.cumsum(0, Some(true), None).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[9, 17, 4, 9]);

        // `Some(true)` is the default for `inclusive`, so this must match the `None` case.
        let result = array.cumsum(0, None, Some(true)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[5, 8, 9, 17]);
    }

    #[test]
    fn test_cumsum_exclusive() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        // Exclusive scans shift the result by one position, starting from the sum identity 0.
        let result = array.cumsum(0, None, Some(false)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[0, 0, 5, 8]);

        let result = array.cumsum(1, None, Some(false)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice::<i32>(), &[0, 5, 0, 4]);

        let result = array.cumsum(None, None, Some(false)).unwrap();
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.as_slice::<i32>(), &[0, 5, 13, 17]);
    }

    #[test]
    fn test_cumsum_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.cumsum(2, None, None);
        assert!(result.is_err());
    }
}