mlx_rs/nn/pooling.rs

use std::iter::{once, zip};

use crate::{error::Exception, module::Module, ops::as_strided, Array};
use dyn_clone::DynClone;
use mlx_macros::ModuleParameters;

use crate::utils::SingleOrPair;

/// Marker trait for pooling operations.
pub trait Pooling
where
    Self: Fn(&Array, &[i32]) -> Result<Array, Exception> + DynClone,
{
}

impl<T> Pooling for T where T: Fn(&Array, &[i32]) -> Result<Array, Exception> + DynClone {}

/// Abstract pooling layer.
///
/// See also:
///
/// - [`MaxPool1d`]
/// - [`MaxPool2d`]
/// - [`AvgPool1d`]
/// - [`AvgPool2d`]
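///
/// # Example
///
/// A minimal sketch of how the concrete layers wrap this type: `MaxPool1d::new`
/// builds a `Pool` from a max reduction over the window axes.
///
/// ```rust,ignore
/// // Window of size 2, stride 1, reducing each window with `Array::max`.
/// let op = |x: &Array, axes: &[i32]| x.max(axes, None);
/// let pool = Pool::new(vec![2], vec![1], op);
/// ```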
#[derive(ModuleParameters)]
#[module(root = crate)]
pub struct Pool {
    /// Size of the pooling window
    kernel_size: Vec<i32>,

    /// Stride of the pooling window
    stride: Vec<i64>,

    /// Axes to pool over
    axes: Vec<i32>,

    /// Pooling operation
    ///
    /// TODO: The op is boxed (with `DynClone`) just to make `Pool` cloneable. Is this necessary?
    pooling_op: Box<dyn Pooling>,
}

impl Clone for Pool {
    fn clone(&self) -> Self {
        Self {
            kernel_size: self.kernel_size.clone(),
            stride: self.stride.clone(),
            axes: self.axes.clone(),
            pooling_op: dyn_clone::clone_box(&*self.pooling_op),
        }
    }
}

impl std::fmt::Debug for Pool {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("Pool")
            .field("kernel_size", &self.kernel_size)
            .field("stride", &self.stride)
            .field("axes", &self.axes)
            .finish()
    }
}

impl Pool {
    /// Create a new abstract pooling layer.
    pub fn new(kernel_size: Vec<i32>, stride: Vec<i64>, op: impl Pooling + 'static) -> Self {
        // Pool over the window axes, which `forward` places just before the
        // trailing channel axis (e.g. [-2] for 1-D, [-3, -2] for 2-D).
        let start = -(kernel_size.len() as i32) - 1;
        let axes: Vec<_> = (start..-1).collect();
        Self {
            kernel_size,
            stride,
            axes,
            pooling_op: Box::new(op),
        }
    }
}

impl Module<&Array> for Pool {
    type Error = Exception;
    type Output = Array;

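    // Pooling is implemented as a reduction over a strided view: `as_strided`
    // exposes the sliding windows as extra axes, giving a view of shape
    // [batch, out_spatial..., window..., channels], and `pooling_op` (max or
    // mean) then reduces over the window axes in `self.axes`.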
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let shape = x.shape();
        let rest = &shape[1..shape.len() - 1];

        // Output extent per pooled axis: floor((size - window) / stride) + 1
        let iter = zip(zip(rest, &self.kernel_size), &self.stride)
            .map(|((size, window), stride)| (size - window) / *stride as i32 + 1);

        // Shape of the strided view: [batch, out_spatial..., window..., channels]
        let final_shape = once(shape[0])
            .chain(iter)
            .chain(self.kernel_size.iter().copied())
            .chain(once(shape[shape.len() - 1]))
            .collect::<Vec<_>>();

        // Row-major element strides of the input: a reversed cumulative
        // product of the shape, dropping the total element count.
        let strides = shape
            .iter()
            .map(|s| *s as i64)
            .chain(once(1))
            .rev()
            .fold(vec![], |mut acc, a| {
                match acc.last() {
                    Some(&element) => acc.push(a * element),
                    None => acc.push(a),
                }
                acc
            })
            .into_iter()
            .rev()
            .skip(1)
            .collect::<Vec<_>>();
        let middle_strides = &strides[1..strides.len() - 1];

        // Spatial axes advance by the pooling stride; the window axes reuse
        // the original spatial strides.
        let final_strides = once(strides[0])
            .chain(zip(middle_strides, &self.stride).map(|(ms, s)| ms * s))
            .chain(middle_strides.iter().copied())
            .chain(once(1))
            .collect::<Vec<_>>();

        // TODO: double check if as_strided would ever panic
        let strided = as_strided(x, &final_shape, &final_strides, None)?;
        (self.pooling_op)(&strided, &self.axes)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

macro_rules! impl_module {
    ($name:ident) => {
        impl Module<&Array> for $name {
            type Output = Array;
            type Error = Exception;

            fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
                self.inner.forward(x)
            }

            fn training_mode(&mut self, mode: bool) {
                self.inner.training_mode(mode);
            }
        }
    };
}

/// Applies 1-dimensional max pooling.
///
/// The input is expected to have shape `NLC`. The output keeps the same `N` and
/// `C` dimensions, with the new `L = floor((L - kernel)/stride) + 1`.
///
/// See the [MaxPool1d Python
/// docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MaxPool1d.html)
/// for more information.
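///
/// # Example
///
/// A minimal sketch, mirroring the unit tests below:
///
/// ```rust,ignore
/// let input = Array::from_iter(0..4, &[1, 4, 1]); // NLC input with L = 4
/// let mut pool = MaxPool1d::new(2, 1);            // window 2, stride 1
/// // New L = floor((4 - 2) / 1) + 1 = 3
/// let output = pool.forward(&input).unwrap();     // [1, 2, 3] with shape [1, 3, 1]
/// ```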
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct MaxPool1d {
    #[param]
    inner: Pool,
}

impl MaxPool1d {
    /// Create a new 1-dimensional max pooling layer.
    ///
    /// # Params
    ///
    /// - `kernel_size`: The size of the pooling window.
    /// - `stride`: The stride of the pooling window.
    pub fn new(kernel_size: i32, stride: i64) -> Self {
        let op = |x: &Array, axes: &[i32]| x.max(axes, None);
        let inner = Pool::new(vec![kernel_size], vec![stride], op);
        Self { inner }
    }
}

impl_module!(MaxPool1d);

/// Applies 2-dimensional max pooling.
///
/// The input is expected to have shape `NHWC`. The output keeps the same `N` and
/// `C` dimensions, with the new `H/W = floor((H/W - kernel)/stride) + 1`.
///
/// See the [MaxPool2d Python
/// docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MaxPool2d.html)
/// for more information.
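///
/// # Example
///
/// A minimal sketch, mirroring the unit tests below:
///
/// ```rust,ignore
/// let input = Array::from_iter(0..16, &[1, 4, 4, 1]); // NHWC input, H = W = 4
/// let mut pool = MaxPool2d::new(2, 1);                // 2x2 window, stride 1
/// // New H = W = floor((4 - 2) / 1) + 1 = 3
/// let output = pool.forward(&input).unwrap();         // shape [1, 3, 3, 1]
/// ```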
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct MaxPool2d {
    #[param]
    inner: Pool,
}

impl MaxPool2d {
    /// Create a new 2-dimensional max pooling layer.
    ///
    /// # Params
    ///
    /// - `kernel_size`: The size of the pooling window.
    /// - `stride`: The stride of the pooling window.
    pub fn new(
        kernel_size: impl Into<SingleOrPair<i32>>,
        stride: impl Into<SingleOrPair<i64>>,
    ) -> Self {
        let kernel_size = kernel_size.into();
        let kernel_size = vec![kernel_size.first(), kernel_size.second()];
        let stride = stride.into();
        let stride = vec![stride.first(), stride.second()];

        let op = |x: &Array, axes: &[i32]| x.max(axes, None);
        let inner = Pool::new(kernel_size, stride, op);
        Self { inner }
    }
}

impl_module!(MaxPool2d);

/// Applies 1-dimensional average pooling.
///
/// The input is expected to have shape `NLC`. The output keeps the same `N` and
/// `C` dimensions, with the new `L = floor((L - kernel)/stride) + 1`.
///
/// See the [AvgPool1d Python
/// docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.AvgPool1d.html)
/// for more information.
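///
/// # Example
///
/// A minimal sketch, mirroring the unit tests below:
///
/// ```rust,ignore
/// let input = Array::from_iter(0..4, &[1, 4, 1]); // NLC input with L = 4
/// let mut pool = AvgPool1d::new(2, 1);            // window 2, stride 1
/// // New L = floor((4 - 2) / 1) + 1 = 3
/// let output = pool.forward(&input).unwrap();     // [0.5, 1.5, 2.5] with shape [1, 3, 1]
/// ```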
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct AvgPool1d {
    #[param]
    inner: Pool,
}

impl AvgPool1d {
    /// Create a new 1-dimensional average pooling layer.
    ///
    /// # Params
    ///
    /// - `kernel_size`: The size of the pooling window.
    /// - `stride`: The stride of the pooling window.
    pub fn new(kernel_size: i32, stride: i64) -> Self {
        let op = |x: &Array, axes: &[i32]| x.mean(axes, None);
        let inner = Pool::new(vec![kernel_size], vec![stride], op);
        Self { inner }
    }
}

impl_module!(AvgPool1d);

/// Applies 2-dimensional average pooling.
///
/// The input is expected to have shape `NHWC`. The output keeps the same `N` and
/// `C` dimensions, with the new `H/W = floor((H/W - kernel)/stride) + 1`.
///
/// See the [AvgPool2d Python
/// docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.AvgPool2d.html)
/// for more information.
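///
/// # Example
///
/// A minimal sketch, mirroring the unit tests below:
///
/// ```rust,ignore
/// let input = Array::from_iter(0..16, &[1, 4, 4, 1]); // NHWC input, H = W = 4
/// let mut pool = AvgPool2d::new(2, 2);                // 2x2 window, stride 2
/// // New H = W = floor((4 - 2) / 2) + 1 = 2
/// let output = pool.forward(&input).unwrap();         // [2.5, 4.5, 10.5, 12.5], shape [1, 2, 2, 1]
/// ```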
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct AvgPool2d {
    #[param]
    inner: Pool,
}

impl AvgPool2d {
    /// Create a new 2-dimensional average pooling layer.
    ///
    /// # Params
    ///
    /// - `kernel_size`: The size of the pooling window.
    /// - `stride`: The stride of the pooling window.
    pub fn new(
        kernel_size: impl Into<SingleOrPair<i32>>,
        stride: impl Into<SingleOrPair<i64>>,
    ) -> Self {
        let kernel_size = kernel_size.into();
        let kernel_size = vec![kernel_size.first(), kernel_size.second()];
        let stride = stride.into();
        let stride = vec![stride.first(), stride.second()];

        let op = |x: &Array, axes: &[i32]| x.mean(axes, None);
        let inner = Pool::new(kernel_size, stride, op);
        Self { inner }
    }
}

impl_module!(AvgPool2d);

#[cfg(test)]
mod tests {
    use crate::{array, assert_array_eq, module::ModuleParameters};

    use super::*;

    #[test]
    fn test_pool_has_no_learnable_params() {
        let pool = MaxPool1d::new(2, 1);
        let params = pool.parameters().flatten();
        assert_eq!(params.len(), 0);
    }

    #[test]
    fn test_max_pooling_1d_stride_1() {
        let input = Array::from_iter(0..4, &[1, 4, 1]);
        let mut pool = MaxPool1d::new(2, 1);
        let output = pool.forward(&input).unwrap();
        assert_array_eq!(output, array!([1, 2, 3], shape = [1, 3, 1]));
    }

    #[test]
    fn test_max_pooling_1d_stride_2() {
        let input = Array::from_iter(0..8, &[2, 4, 1]);
        let mut pool = MaxPool1d::new(2, 2);
        let output = pool.forward(&input).unwrap();
        assert_array_eq!(output, array!([1, 3, 5, 7], shape = [2, 2, 1]));
    }

    #[test]
    fn test_max_pooling_2d_stride_1() {
        let input = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut pool = MaxPool2d::new(2, 1);
        let output = pool.forward(&input).unwrap();
        assert_array_eq!(
            output,
            array!([5, 6, 7, 9, 10, 11, 13, 14, 15], shape = [1, 3, 3, 1])
        );
    }

    #[test]
    fn test_max_pooling_2d_stride_2() {
        let input = Array::from_iter(0..32, &[2, 4, 4, 1]);
        let mut pool = MaxPool2d::new(2, 2);
        let output = pool.forward(&input).unwrap();
        assert_array_eq!(
            output,
            array!([5, 7, 13, 15, 21, 23, 29, 31], shape = [2, 2, 2, 1])
        );
    }

    #[test]
    fn test_avg_pooling_1d_stride_1() {
        let input = Array::from_iter(0..4, &[1, 4, 1]);
        let mut pool = AvgPool1d::new(2, 1);
        let output = pool.forward(&input).unwrap();
        assert_array_eq!(output, array!([0.5, 1.5, 2.5], shape = [1, 3, 1]));
    }

    #[test]
    fn test_avg_pooling_1d_stride_2() {
        let input = Array::from_iter(0..8, &[2, 4, 1]);
        let mut pool = AvgPool1d::new(2, 2);
        let output = pool.forward(&input).unwrap();
        assert_array_eq!(output, array!([0.5, 2.5, 4.5, 6.5], shape = [2, 2, 1]));
    }

    #[test]
    fn test_avg_pooling_2d_stride_1() {
        let input = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut pool = AvgPool2d::new(2, 1);
        let output = pool.forward(&input).unwrap();
        assert_array_eq!(
            output,
            array!(
                [2.5, 3.5, 4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5],
                shape = [1, 3, 3, 1]
            )
        );
    }

    #[test]
    fn test_avg_pooling_2d_stride_2() {
        let input = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut pool = AvgPool2d::new(2, 2);
        let output = pool.forward(&input).unwrap();
        assert_array_eq!(output, array!([2.5, 4.5, 10.5, 12.5], shape = [1, 2, 2, 1]));
    }
}