1use std::iter::{once, zip};
2
3use crate::{error::Exception, module::Module, ops::as_strided, Array};
4use dyn_clone::DynClone;
5use mlx_macros::ModuleParameters;
6
7use crate::utils::SingleOrPair;
8
9pub trait Pooling
11where
12    Self: Fn(&Array, &[i32]) -> Result<Array, Exception> + DynClone,
13{
14}
15
16impl<T> Pooling for T where T: Fn(&Array, &[i32]) -> Result<Array, Exception> + DynClone {}
17
18#[derive(ModuleParameters)]
27#[module(root = crate)]
28pub struct Pool {
29    kernel_size: Vec<i32>,
31
32    stride: Vec<i64>,
34
35    axes: Vec<i32>,
37
38    pooling_op: Box<dyn Pooling>,
42}
43
44impl Clone for Pool {
45    fn clone(&self) -> Self {
46        Self {
47            kernel_size: self.kernel_size.clone(),
48            stride: self.stride.clone(),
49            axes: self.axes.clone(),
50            pooling_op: dyn_clone::clone_box(&*self.pooling_op),
51        }
52    }
53}
54
55impl std::fmt::Debug for Pool {
56    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
57        f.debug_struct("Pool")
58            .field("kernel_size", &self.kernel_size)
59            .field("stride", &self.stride)
60            .field("axes", &self.axes)
61            .finish()
62    }
63}
64
65impl Pool {
66    pub fn new(kernel_size: Vec<i32>, stride: Vec<i64>, op: impl Pooling + 'static) -> Self {
68        let start = -(kernel_size.len() as i32) - 1;
69        let axes: Vec<_> = (start..-1).collect();
70        Self {
71            kernel_size,
72            stride,
73            axes,
74            pooling_op: Box::new(op),
75        }
76    }
77}
78
79impl Module<&Array> for Pool {
80    type Error = Exception;
81    type Output = Array;
82
83    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
84        let shape = x.shape();
85        let rest = &shape[1..shape.len() - 1];
86
87        let iter = zip(zip(rest, &self.kernel_size), &self.stride)
88            .map(|((size, window), stride)| (size - window) / *stride as i32 + 1);
89
90        let final_shape = once(shape[0])
91            .chain(iter)
92            .chain(self.kernel_size.iter().copied())
93            .chain(once(shape[shape.len() - 1]))
94            .collect::<Vec<_>>();
95
96        let strides = shape
97            .iter()
98            .map(|s| *s as i64)
99            .chain(once(1))
100            .rev()
101            .fold(vec![], |mut acc, a| {
102                match acc.last() {
103                    Some(&element) => acc.push(a * element),
104                    None => acc.push(a),
105                }
106                acc
107            })
108            .into_iter()
109            .rev()
110            .skip(1)
111            .collect::<Vec<_>>();
112        let middle_strides = &strides[1..strides.len() - 1];
113
114        let final_strides = once(strides[0])
115            .chain(zip(middle_strides, &self.stride).map(|(ms, s)| ms * s))
116            .chain(middle_strides.iter().copied())
117            .chain(once(1))
118            .collect::<Vec<_>>();
119
120        let strided = as_strided(x, &final_shape, &final_strides, None)?;
122        (self.pooling_op)(&strided, &self.axes)
123    }
124
125    fn training_mode(&mut self, _mode: bool) {}
126}
127
// Implements `Module<&Array>` for a wrapper type by delegating both trait
// methods to its `inner` field. The wrapper must therefore have a field named
// `inner` that itself implements `Module<&Array>` with the same associated types.
macro_rules! impl_module {
    ($wrapper:ident) => {
        impl Module<&Array> for $wrapper {
            type Error = Exception;
            type Output = Array;

            fn forward(&mut self, input: &Array) -> Result<Array, Self::Error> {
                self.inner.forward(input)
            }

            fn training_mode(&mut self, training: bool) {
                self.inner.training_mode(training);
            }
        }
    };
}
144
145#[derive(Debug, Clone, ModuleParameters)]
154#[module(root = crate)]
155pub struct MaxPool1d {
156    #[param]
157    inner: Pool,
158}
159
160impl MaxPool1d {
161    pub fn new(kernel_size: i32, stride: i64) -> Self {
168        let op = |x: &Array, axes: &[i32]| x.max_axes(axes, None);
169        let inner = Pool::new(vec![kernel_size], vec![stride], op);
170        Self { inner }
171    }
172}
173
174impl_module!(MaxPool1d);
175
176#[derive(Debug, Clone, ModuleParameters)]
185#[module(root = crate)]
186pub struct MaxPool2d {
187    #[param]
188    inner: Pool,
189}
190
191impl MaxPool2d {
192    pub fn new(
199        kernel_size: impl Into<SingleOrPair<i32>>,
200        stride: impl Into<SingleOrPair<i64>>,
201    ) -> Self {
202        let kernel_size = kernel_size.into();
203        let kernel_size = vec![kernel_size.first(), kernel_size.second()];
204        let stride = stride.into();
205        let stride = vec![stride.first(), stride.second()];
206
207        let op = |x: &Array, axes: &[i32]| x.max_axes(axes, None);
208        let inner = Pool::new(kernel_size, stride, op);
209        Self { inner }
210    }
211}
212
213impl_module!(MaxPool2d);
214
215#[derive(Debug, Clone, ModuleParameters)]
224#[module(root = crate)]
225pub struct AvgPool1d {
226    #[param]
227    inner: Pool,
228}
229
230impl AvgPool1d {
231    pub fn new(kernel_size: i32, stride: i64) -> Self {
238        let op = |x: &Array, axes: &[i32]| x.mean_axes(axes, None);
239        let inner = Pool::new(vec![kernel_size], vec![stride], op);
240        Self { inner }
241    }
242}
243
244impl_module!(AvgPool1d);
245
246#[derive(Debug, Clone, ModuleParameters)]
255#[module(root = crate)]
256pub struct AvgPool2d {
257    #[param]
258    inner: Pool,
259}
260
261impl AvgPool2d {
262    pub fn new(
269        kernel_size: impl Into<SingleOrPair<i32>>,
270        stride: impl Into<SingleOrPair<i64>>,
271    ) -> Self {
272        let kernel_size = kernel_size.into();
273        let kernel_size = vec![kernel_size.first(), kernel_size.second()];
274        let stride = stride.into();
275        let stride = vec![stride.first(), stride.second()];
276
277        let op = |x: &Array, axes: &[i32]| x.mean_axes(axes, None);
278        let inner = Pool::new(kernel_size, stride, op);
279        Self { inner }
280    }
281}
282
283impl_module!(AvgPool2d);
284
#[cfg(test)]
mod tests {
    use crate::{array, assert_array_eq, module::ModuleParameters};

    use super::*;

    /// The wrapped `Pool` declares no parameters, so the flattened map is empty.
    #[test]
    fn test_pool_has_no_learnable_params() {
        let layer = MaxPool1d::new(2, 1);
        let flattened = layer.parameters().flatten();
        assert_eq!(flattened.len(), 0);
    }

    /// Window 2, stride 1 over `[0, 1, 2, 3]`: overlapping windows.
    #[test]
    fn test_max_pooling_1d_stride_1() {
        let expected = array!([1, 2, 3], shape = [1, 3, 1]);

        let x = Array::from_iter(0..4, &[1, 4, 1]);
        let mut layer = MaxPool1d::new(2, 1);
        let y = layer.forward(&x).unwrap();

        assert_array_eq!(y, expected);
    }

    /// Window 2, stride 2 over two batch entries: disjoint windows.
    #[test]
    fn test_max_pooling_1d_stride_2() {
        let expected = array!([1, 3, 5, 7], shape = [2, 2, 1]);

        let x = Array::from_iter(0..8, &[2, 4, 1]);
        let mut layer = MaxPool1d::new(2, 2);
        let y = layer.forward(&x).unwrap();

        assert_array_eq!(y, expected);
    }

    /// 2x2 window, stride 1 over a 4x4 grid: overlapping windows.
    #[test]
    fn test_max_pooling_2d_stride_1() {
        let expected = array!([5, 6, 7, 9, 10, 11, 13, 14, 15], shape = [1, 3, 3, 1]);

        let x = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut layer = MaxPool2d::new(2, 1);
        let y = layer.forward(&x).unwrap();

        assert_array_eq!(y, expected);
    }

    /// 2x2 window, stride 2 over two 4x4 grids: disjoint windows.
    #[test]
    fn test_max_pooling_2d_stride_2() {
        let expected = array!([5, 7, 13, 15, 21, 23, 29, 31], shape = [2, 2, 2, 1]);

        let x = Array::from_iter(0..32, &[2, 4, 4, 1]);
        let mut layer = MaxPool2d::new(2, 2);
        let y = layer.forward(&x).unwrap();

        assert_array_eq!(y, expected);
    }

    /// Window 2, stride 1 over `[0, 1, 2, 3]`: means of neighbouring pairs.
    #[test]
    fn test_avg_pooling_1d_stride_1() {
        let expected = array!([0.5, 1.5, 2.5], shape = [1, 3, 1]);

        let x = Array::from_iter(0..4, &[1, 4, 1]);
        let mut layer = AvgPool1d::new(2, 1);
        let y = layer.forward(&x).unwrap();

        assert_array_eq!(y, expected);
    }

    /// Window 2, stride 2 over two batch entries: means of disjoint pairs.
    #[test]
    fn test_avg_pooling_1d_stride_2() {
        let expected = array!([0.5, 2.5, 4.5, 6.5], shape = [2, 2, 1]);

        let x = Array::from_iter(0..8, &[2, 4, 1]);
        let mut layer = AvgPool1d::new(2, 2);
        let y = layer.forward(&x).unwrap();

        assert_array_eq!(y, expected);
    }

    /// 2x2 window, stride 1 over a 4x4 grid: means of overlapping windows.
    #[test]
    fn test_avg_pooling_2d_stride_1() {
        let expected = array!(
            [2.5, 3.5, 4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5],
            shape = [1, 3, 3, 1]
        );

        let x = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut layer = AvgPool2d::new(2, 1);
        let y = layer.forward(&x).unwrap();

        assert_array_eq!(y, expected);
    }

    /// 2x2 window, stride 2 over a 4x4 grid: means of disjoint windows.
    #[test]
    fn test_avg_pooling_2d_stride_2() {
        let expected = array!([2.5, 4.5, 10.5, 12.5], shape = [1, 2, 2, 1]);

        let x = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut layer = AvgPool2d::new(2, 2);
        let y = layer.forward(&x).unwrap();

        assert_array_eq!(y, expected);
    }
}