1use std::iter::{once, zip};
2
3use crate::{error::Exception, module::Module, ops::as_strided, Array};
4use dyn_clone::DynClone;
5use mlx_macros::ModuleParameters;
6
7use crate::utils::SingleOrPair;
8
9pub trait Pooling
11where
12 Self: Fn(&Array, &[i32]) -> Result<Array, Exception> + DynClone,
13{
14}
15
16impl<T> Pooling for T where T: Fn(&Array, &[i32]) -> Result<Array, Exception> + DynClone {}
17
18#[derive(ModuleParameters)]
27#[module(root = crate)]
28pub struct Pool {
29 kernel_size: Vec<i32>,
31
32 stride: Vec<i64>,
34
35 axes: Vec<i32>,
37
38 pooling_op: Box<dyn Pooling>,
42}
43
44impl Clone for Pool {
45 fn clone(&self) -> Self {
46 Self {
47 kernel_size: self.kernel_size.clone(),
48 stride: self.stride.clone(),
49 axes: self.axes.clone(),
50 pooling_op: dyn_clone::clone_box(&*self.pooling_op),
51 }
52 }
53}
54
55impl std::fmt::Debug for Pool {
56 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
57 f.debug_struct("Pool")
58 .field("kernel_size", &self.kernel_size)
59 .field("stride", &self.stride)
60 .field("axes", &self.axes)
61 .finish()
62 }
63}
64
65impl Pool {
66 pub fn new(kernel_size: Vec<i32>, stride: Vec<i64>, op: impl Pooling + 'static) -> Self {
68 let start = -(kernel_size.len() as i32) - 1;
69 let axes: Vec<_> = (start..-1).collect();
70 Self {
71 kernel_size,
72 stride,
73 axes,
74 pooling_op: Box::new(op),
75 }
76 }
77}
78
79impl Module<&Array> for Pool {
80 type Error = Exception;
81 type Output = Array;
82
83 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
84 let shape = x.shape();
85 let rest = &shape[1..shape.len() - 1];
86
87 let iter = zip(zip(rest, &self.kernel_size), &self.stride)
88 .map(|((size, window), stride)| (size - window) / *stride as i32 + 1);
89
90 let final_shape = once(shape[0])
91 .chain(iter)
92 .chain(self.kernel_size.iter().copied())
93 .chain(once(shape[shape.len() - 1]))
94 .collect::<Vec<_>>();
95
96 let strides = shape
97 .iter()
98 .map(|s| *s as i64)
99 .chain(once(1))
100 .rev()
101 .fold(vec![], |mut acc, a| {
102 match acc.last() {
103 Some(&element) => acc.push(a * element),
104 None => acc.push(a),
105 }
106 acc
107 })
108 .into_iter()
109 .rev()
110 .skip(1)
111 .collect::<Vec<_>>();
112 let middle_strides = &strides[1..strides.len() - 1];
113
114 let final_strides = once(strides[0])
115 .chain(zip(middle_strides, &self.stride).map(|(ms, s)| ms * s))
116 .chain(middle_strides.iter().copied())
117 .chain(once(1))
118 .collect::<Vec<_>>();
119
120 let strided = as_strided(x, &final_shape, &final_strides, None)?;
122 (self.pooling_op)(&strided, &self.axes)
123 }
124
125 fn training_mode(&mut self, _mode: bool) {}
126}
127
/// Implements `Module<&Array>` for a wrapper type by delegating every call to
/// its `inner` field.
macro_rules! impl_module {
    ($layer:ident) => {
        impl Module<&Array> for $layer {
            type Error = Exception;
            type Output = Array;

            fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
                // Pure delegation: the wrapped value does all the work.
                self.inner.forward(x)
            }

            fn training_mode(&mut self, mode: bool) {
                self.inner.training_mode(mode);
            }
        }
    };
}
144
145#[derive(Debug, Clone, ModuleParameters)]
154#[module(root = crate)]
155pub struct MaxPool1d {
156 #[param]
157 inner: Pool,
158}
159
160impl MaxPool1d {
161 pub fn new(kernel_size: i32, stride: i64) -> Self {
168 let op = |x: &Array, axes: &[i32]| x.max(axes, None);
169 let inner = Pool::new(vec![kernel_size], vec![stride], op);
170 Self { inner }
171 }
172}
173
174impl_module!(MaxPool1d);
175
176#[derive(Debug, Clone, ModuleParameters)]
185#[module(root = crate)]
186pub struct MaxPool2d {
187 #[param]
188 inner: Pool,
189}
190
191impl MaxPool2d {
192 pub fn new(
199 kernel_size: impl Into<SingleOrPair<i32>>,
200 stride: impl Into<SingleOrPair<i64>>,
201 ) -> Self {
202 let kernel_size = kernel_size.into();
203 let kernel_size = vec![kernel_size.first(), kernel_size.second()];
204 let stride = stride.into();
205 let stride = vec![stride.first(), stride.second()];
206
207 let op = |x: &Array, axes: &[i32]| x.max(axes, None);
208 let inner = Pool::new(kernel_size, stride, op);
209 Self { inner }
210 }
211}
212
213impl_module!(MaxPool2d);
214
215#[derive(Debug, Clone, ModuleParameters)]
224#[module(root = crate)]
225pub struct AvgPool1d {
226 #[param]
227 inner: Pool,
228}
229
230impl AvgPool1d {
231 pub fn new(kernel_size: i32, stride: i64) -> Self {
238 let op = |x: &Array, axes: &[i32]| x.mean(axes, None);
239 let inner = Pool::new(vec![kernel_size], vec![stride], op);
240 Self { inner }
241 }
242}
243
244impl_module!(AvgPool1d);
245
246#[derive(Debug, Clone, ModuleParameters)]
255#[module(root = crate)]
256pub struct AvgPool2d {
257 #[param]
258 inner: Pool,
259}
260
261impl AvgPool2d {
262 pub fn new(
269 kernel_size: impl Into<SingleOrPair<i32>>,
270 stride: impl Into<SingleOrPair<i64>>,
271 ) -> Self {
272 let kernel_size = kernel_size.into();
273 let kernel_size = vec![kernel_size.first(), kernel_size.second()];
274 let stride = stride.into();
275 let stride = vec![stride.first(), stride.second()];
276
277 let op = |x: &Array, axes: &[i32]| x.mean(axes, None);
278 let inner = Pool::new(kernel_size, stride, op);
279 Self { inner }
280 }
281}
282
283impl_module!(AvgPool2d);
284
#[cfg(test)]
mod tests {
    use crate::{array, assert_array_eq, module::ModuleParameters};

    use super::*;

    /// A pooling layer exposes no parameters once flattened.
    #[test]
    fn test_pool_has_no_learnable_params() {
        let layer = MaxPool1d::new(2, 1);
        let flat = layer.parameters().flatten();
        assert_eq!(flat.len(), 0);
    }

    /// Window 2, stride 1 over a length-4 ramp.
    #[test]
    fn test_max_pooling_1d_stride_1() {
        let x = Array::from_iter(0..4, &[1, 4, 1]);
        let mut layer = MaxPool1d::new(2, 1);
        let y = layer.forward(&x).unwrap();
        assert_array_eq!(y, array!([1, 2, 3], shape = [1, 3, 1]));
    }

    /// Window 2, stride 2 over two length-4 ramps.
    #[test]
    fn test_max_pooling_1d_stride_2() {
        let x = Array::from_iter(0..8, &[2, 4, 1]);
        let mut layer = MaxPool1d::new(2, 2);
        let y = layer.forward(&x).unwrap();
        assert_array_eq!(y, array!([1, 3, 5, 7], shape = [2, 2, 1]));
    }

    /// 2x2 window, stride 1 over a 4x4 ramp.
    #[test]
    fn test_max_pooling_2d_stride_1() {
        let x = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut layer = MaxPool2d::new(2, 1);
        let y = layer.forward(&x).unwrap();
        assert_array_eq!(
            y,
            array!([5, 6, 7, 9, 10, 11, 13, 14, 15], shape = [1, 3, 3, 1])
        );
    }

    /// 2x2 window, stride 2 over two 4x4 ramps.
    #[test]
    fn test_max_pooling_2d_stride_2() {
        let x = Array::from_iter(0..32, &[2, 4, 4, 1]);
        let mut layer = MaxPool2d::new(2, 2);
        let y = layer.forward(&x).unwrap();
        assert_array_eq!(
            y,
            array!([5, 7, 13, 15, 21, 23, 29, 31], shape = [2, 2, 2, 1])
        );
    }

    /// Window 2, stride 1 averages neighbouring pairs.
    #[test]
    fn test_avg_pooling_1d_stride_1() {
        let x = Array::from_iter(0..4, &[1, 4, 1]);
        let mut layer = AvgPool1d::new(2, 1);
        let y = layer.forward(&x).unwrap();
        assert_array_eq!(y, array!([0.5, 1.5, 2.5], shape = [1, 3, 1]));
    }

    /// Window 2, stride 2 averages disjoint pairs.
    #[test]
    fn test_avg_pooling_1d_stride_2() {
        let x = Array::from_iter(0..8, &[2, 4, 1]);
        let mut layer = AvgPool1d::new(2, 2);
        let y = layer.forward(&x).unwrap();
        assert_array_eq!(y, array!([0.5, 2.5, 4.5, 6.5], shape = [2, 2, 1]));
    }

    /// 2x2 window, stride 1 averages over a 4x4 ramp.
    #[test]
    fn test_avg_pooling_2d_stride_1() {
        let x = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut layer = AvgPool2d::new(2, 1);
        let y = layer.forward(&x).unwrap();
        assert_array_eq!(
            y,
            array!(
                [2.5, 3.5, 4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5],
                shape = [1, 3, 3, 1]
            )
        );
    }

    /// 2x2 window, stride 2 averages over a 4x4 ramp.
    #[test]
    fn test_avg_pooling_2d_stride_2() {
        let x = Array::from_iter(0..16, &[1, 4, 4, 1]);
        let mut layer = AvgPool2d::new(2, 2);
        let y = layer.forward(&x).unwrap();
        assert_array_eq!(y, array!([2.5, 4.5, 10.5, 12.5], shape = [1, 2, 2, 1]));
    }
}