mlx_rs/nn/
upsample.rs

1use crate::{
2    array,
3    error::Exception,
4    macros::ModuleParameters,
5    module::Module,
6    ops::{
7        abs, broadcast_to, ceil, clip, expand_dims, floor,
8        indexing::{ArrayIndex, ArrayIndexOp, Ellipsis, IndexOp, NewAxis, TryIndexOp},
9    },
10    transforms::compile::compile,
11    Array,
12};
13
14use crate::utils::SingleOrVec;
15
/// Upsample mode
///
/// Selects the algorithm [`Upsample`] uses to produce the enlarged output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpsampleMode {
    /// Nearest neighbor upsampling
    Nearest,

    /// Linear interpolation upsampling.
    Linear {
        /// If `true`, the top and left edge of the input and output
        /// will match as will the bottom right edge
        align_corners: bool,
    },

    /// Cubic interpolation upsampling.
    Cubic {
        /// If `true`, the top and left edge of the input and output
        /// will match as will the bottom right edge
        align_corners: bool,
    },
}
35
/// Upsample the input signal spatially
///
/// The spatial dimensions are every axis except the first and the last one
/// of the input (e.g. `H` and `W` for a `BHWC` input).
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Upsample {
    /// The multiplier for the spatial size.
    ///
    /// If a single `f32` is provided, it is the multiplier for all spatial dimensions.
    /// Otherwise, the number of scale factors provided must match the
    /// number of spatial dimensions.
    pub scale_factor: SingleOrVec<f32>,

    /// The upsampling algorithm (nearest, linear or cubic).
    pub mode: UpsampleMode,
}
50
51impl Upsample {
52    /// Create a new `Upsample` module
53    pub fn new(scale_factor: impl Into<SingleOrVec<f32>>, mode: UpsampleMode) -> Self {
54        let scale_factor = scale_factor.into();
55        Upsample { scale_factor, mode }
56    }
57
58    fn forward_inner(&self, x: &Array, scale: &[f32]) -> Result<Array, Exception> {
59        match self.mode {
60            UpsampleMode::Nearest => upsample_nearest(x, scale),
61            UpsampleMode::Linear { align_corners } => {
62                interpolate(x, scale, linear_indices, align_corners)
63            }
64            UpsampleMode::Cubic { align_corners } => {
65                interpolate(x, scale, cubic_indices, align_corners)
66            }
67        }
68    }
69}
70
71impl Module<&Array> for Upsample {
72    type Error = Exception;
73    type Output = Array;
74
75    fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> {
76        let dimensions = x.ndim() - 2;
77
78        if dimensions == 0 {
79            return Err(Exception::custom(format!(
80                "[Upsample] The input should have at least 
81                1 spatial dimension which means it should be at least 
82                3D but {}D was provided",
83                x.ndim()
84            )));
85        }
86
87        match &self.scale_factor {
88            SingleOrVec::Single(scale) => {
89                let scale = vec![*scale; dimensions];
90                self.forward_inner(x, &scale[..])
91            }
92            SingleOrVec::Vec(scales) => self.forward_inner(x, &scales[..]),
93        }
94    }
95
96    fn training_mode(&mut self, _mode: bool) {}
97}
98
99#[allow(non_snake_case)]
100fn upsample_nearest(x: &Array, scale: &[f32]) -> Result<Array, Exception> {
101    let dimensions = x.ndim() - 2;
102    if dimensions != scale.len() {
103        return Err(Exception::custom(format!(
104            "The number of scale factors ({}) must match the number of spatial dimensions ({})",
105            scale.len(),
106            dimensions
107        )));
108    }
109
110    // Get a truncated version of the scales
111    let int_scales = scale.iter().map(|&s| s as i32).collect::<Vec<_>>();
112    let int_float_scales = int_scales.iter().map(|&s| s as f32).collect::<Vec<_>>();
113
114    if int_float_scales == scale {
115        // Int scale means we can simply expand-broadcast and reshape
116        let mut shape = x.shape().to_vec();
117        (0..dimensions).for_each(|d| {
118            shape.insert(2 + 2 * d, 1);
119        });
120        let mut x = x.reshape(&shape)?;
121
122        (0..dimensions).for_each(|d| {
123            shape[2 + 2 * d] = int_scales[d];
124        });
125        x = broadcast_to(&x, &shape)?;
126
127        (0..dimensions).for_each(|d| {
128            shape[d + 1] *= shape[d + 2];
129            shape.remove(d + 2);
130        });
131        x = x.reshape(&shape)?;
132
133        Ok(x)
134    } else {
135        // Float scales
136        let shape_len = x.shape().len();
137        let N = &x.shape()[1..shape_len - 1];
138        let mut indices: Vec<ArrayIndexOp> = vec![(..).index_op()];
139
140        for (i, (n, s)) in N.iter().zip(scale.iter()).enumerate() {
141            indices.push(nearest_indices(*n, *s, i, dimensions)?.index_op());
142        }
143
144        x.try_index(&indices[..])
145    }
146}
147
/// A pair of `(indices, weights)` arrays describing one interpolation tap
/// along a single spatial axis.
type IndexWeight = (Array, Array);

/// Signature shared by the functions that compute the interpolation taps of one axis.
///
/// Arguments, in order: size of the axis, scale factor, `align_corners`,
/// position of the axis among the spatial dimensions, and the number of
/// spatial dimensions.
type IndicesFn = fn(i32, f32, bool, usize, usize) -> Result<Vec<IndexWeight>, Exception>;
151
152#[allow(non_snake_case)]
153fn interpolate(
154    x: &Array,
155    scale: &[f32],
156    indices_fn: IndicesFn,
157    align_corners: bool,
158) -> Result<Array, Exception> {
159    let dimensions = x.ndim() - 2;
160    if dimensions != scale.len() {
161        return Err(Exception::custom(format!(
162            "The number of scale factors ({}) must match the number of spatial dimensions ({})",
163            scale.len(),
164            dimensions
165        )));
166    }
167
168    let N = &x.shape()[1..x.ndim() - 1];
169
170    // compute the sampling grid
171    let mut index_weights = Vec::with_capacity(N.len());
172    for (i, (n, s)) in N.iter().zip(scale.iter()).enumerate() {
173        index_weights.push(indices_fn(*n, *s, align_corners, i, dimensions)?);
174    }
175
176    // sample and compute the weights
177    let prod = product(&index_weights);
178    let mut samples = Vec::with_capacity(prod.len());
179    let mut weights = Vec::with_capacity(prod.len());
180    for index_weight in prod {
181        let (index, weight): (Vec<&Array>, Vec<&Array>) =
182            index_weight.iter().map(|(i, w)| (i, w)).unzip();
183        let mut index_ops = index.iter().map(|i| i.index_op()).collect::<Vec<_>>();
184
185        let mut sample_indices = vec![(..).index_op()];
186        sample_indices.append(&mut index_ops);
187        samples.push(x.index(&sample_indices[..]));
188
189        weights.push(weight.into_iter().product::<Array>());
190    }
191
192    // interpolate
193    let acc = &weights[0] * &samples[0];
194    weights[1..]
195        .iter()
196        .zip(samples[1..].iter())
197        .try_fold(acc, |acc, (w, s)| acc.add(w.multiply(s)?))
198}
199
/// Cartesian product of the given lists.
///
/// Every returned tuple holds one reference per input list, in the order of
/// the input lists. The first list varies fastest. The lists may have
/// different lengths; if any list is empty (or `values` itself is empty) the
/// result is empty.
fn product<T>(values: &[Vec<T>]) -> Vec<Vec<&T>> {
    if values.is_empty() {
        return vec![];
    }

    // The number of tuples is the product of the individual list lengths.
    let count: usize = values.iter().map(Vec::len).product();

    (0..count)
        .map(|combination| {
            // Decode `combination` as a mixed-radix number where the radix of
            // each digit is the length of the corresponding list.
            let mut remainder = combination;
            values
                .iter()
                .map(|candidates| {
                    let item = &candidates[remainder % candidates.len()];
                    remainder /= candidates.len();
                    item
                })
                .collect()
        })
        .collect()
}
227
228fn nearest_indices(
229    dimension: i32,
230    scale: f32,
231    dim: usize,
232    ndim: usize,
233) -> Result<Array, Exception> {
234    scaled_indices(dimension, scale, true, dim, ndim).and_then(|i| i.as_type::<i32>())
235}
236
237fn linear_indices(
238    dimension: i32,
239    scale: f32,
240    align_corners: bool,
241    dim: usize,
242    ndim: usize,
243) -> Result<Vec<IndexWeight>, Exception> {
244    let mut indices = scaled_indices(dimension, scale, align_corners, dim, ndim)?;
245    indices = clip(&indices, (0, dimension - 1))?;
246    let indices_left = floor(&indices)?;
247    let indices_right = ceil(&indices)?;
248    let weight = expand_dims(&indices.subtract(&indices_left)?, &[-1])?;
249
250    let indices_left = indices_left.as_type::<i32>()?;
251    let indices_right = indices_right.as_type::<i32>()?;
252
253    Ok(vec![
254        // SAFETY: arith ops with scalars won't panic
255        (indices_left, array!(1.0) - &weight),
256        (indices_right, weight),
257    ])
258}
259
260fn cubic_indices(
261    dimension: i32,
262    scale: f32,
263    align_corners: bool,
264    dim: usize,
265    ndim: usize,
266) -> Result<Vec<IndexWeight>, Exception> {
267    let indices = scaled_indices(dimension, scale, align_corners, dim, ndim)?;
268
269    // SAFETY: arith ops with scalars won't panic
270    let mut indices_l1 = floor(&indices)?;
271    let mut indices_r1 = floor(&(&indices + 1))?;
272    let mut indices_l2 = (&indices_l1) - 1;
273    let mut indices_r2 = (&indices_r1) + 1;
274
275    let weight_l1 = compiled_get_weight1(&indices, &indices_l1)?.index((Ellipsis, NewAxis));
276    let weight_r1 = compiled_get_weight1(&indices, &indices_r1)?.index((Ellipsis, NewAxis));
277    let weight_l2 = compiled_get_weight2(&indices, &indices_l2)?.index((Ellipsis, NewAxis));
278    let weight_r2 = compiled_get_weight2(&indices, &indices_r2)?.index((Ellipsis, NewAxis));
279
280    // Padding with border value
281    indices_l1 = clip(&indices_l1, (0, dimension - 1))?.as_type::<i32>()?;
282    indices_r1 = clip(&indices_r1, (0, dimension - 1))?.as_type::<i32>()?;
283    indices_l2 = clip(&indices_l2, (0, dimension - 1))?.as_type::<i32>()?;
284    indices_r2 = clip(&indices_r2, (0, dimension - 1))?.as_type::<i32>()?;
285
286    Ok(vec![
287        (indices_l1, weight_l1),
288        (indices_r1, weight_r1),
289        (indices_l2, weight_l2),
290        (indices_r2, weight_r2),
291    ])
292}
293
/// Cubic interpolation weight used for the two inner taps.
///
/// With `x = |ind - grid|` and `a = -0.75`, evaluates
/// `((a + 2) * x - (a + 3)) * x * x + 1`. The computation is wrapped with
/// `compile` before being invoked.
fn compiled_get_weight1(ind: &Array, grid: &Array) -> Result<Array, Exception> {
    // PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
    // and uses -0.75 for antialiasing=false (compatibility with OpenCV)

    let get_weight1 = |(ind_, grid_): (&Array, &Array)| {
        let a = -0.75;
        // Distance between the sampling position and the tap.
        let x = abs(ind_ - grid_)?;
        Ok((array!(a + 2.0) * &x - array!(a + 3.0)) * &x * &x + 1.0)
    };
    let mut compiled = compile(get_weight1, true);
    compiled((ind, grid))
}
306
/// Cubic interpolation weight used for the two outer taps.
///
/// With `x = |ind - grid|` and `a = -0.75`, evaluates
/// `(((x - 5) * x + 8) * x - 4) * a`. The computation is wrapped with
/// `compile` before being invoked.
fn compiled_get_weight2(ind: &Array, grid: &Array) -> Result<Array, Exception> {
    let get_weight2 = |(ind_, grid_): (&Array, &Array)| {
        // Same coefficient as in `compiled_get_weight1`.
        let a = -0.75;
        // Distance between the sampling position and the tap.
        let x = abs(ind_ - grid_)?;
        Ok((((&x - 5.0) * &x + 8.0) * &x - 4.0) * a)
    };
    let mut compiled = compile(get_weight2, true);
    compiled((ind, grid))
}
316
317#[allow(non_snake_case)]
318fn scaled_indices(
319    N: i32,
320    scale: f32,
321    align_corners: bool,
322    dim: usize,
323    ndim: usize,
324) -> Result<Array, Exception> {
325    let M = (scale * N as f32) as i32;
326
327    let indices = match align_corners {
328        true => {
329            // SAFETY: arith ops on with scalars won't panic
330            Array::from_iter(0..M, &[M]).as_type::<f32>()? * ((N as f32 - 1.0) / (M as f32 - 1.0))
331        }
332        false => {
333            let step = 1.0 / scale;
334            let start = ((M as f32 - 1.0) * step - N as f32 + 1.0) / 2.0;
335            // SAFETY: arith ops with scalars won't panic
336            Array::from_iter(0..M, &[M]).as_type::<f32>()? * step - start
337        }
338    };
339
340    let mut shape = vec![1; ndim];
341    shape[dim] = -1;
342
343    indices.reshape(&shape)
344}
345
/// Tests comparing each upsampling mode against reference outputs on a
/// 2x2 single-channel input scaled by a factor of 2.
#[cfg(test)]
mod tests {
    use crate::assert_array_eq;

    use super::*;

    // The unit test below is adapted from the swift binding.
    #[test]
    fn test_nearest() {
        // BHWC
        let input = array!([1, 2, 3, 4], shape = [1, 2, 2, 1]);

        let mut up = Upsample::new(2.0, UpsampleMode::Nearest);
        // Drop the batch and channel axes so only the 4x4 image remains.
        let result = up.forward(&input).and_then(|r| r.squeeze(None)).unwrap();

        assert_eq!(result.shape(), &[4, 4]);

        // array([[1, 1, 2, 2],
        //        [1, 1, 2, 2],
        //        [3, 3, 4, 4],
        //        [3, 3, 4, 4]], dtype=int32)
        let expected = array!(
            [1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4],
            shape = [4, 4]
        )
        .as_type::<i32>()
        .unwrap();
        assert_eq!(result, expected);
    }

    // The unit test below is adapted from the swift binding.
    #[test]
    fn test_linear() {
        // BHWC
        let input = array!([1, 2, 3, 4], shape = [1, 2, 2, 1]);

        let mut up = Upsample::new(
            2.0,
            UpsampleMode::Linear {
                align_corners: false,
            },
        );
        // Drop the batch and channel axes so only the 4x4 image remains.
        let result = up.forward(&input).and_then(|r| r.squeeze(None)).unwrap();

        assert_eq!(result.shape(), &[4, 4]);

        // array([[1, 1.25, 1.75, 2],
        //        [1.5, 1.75, 2.25, 2.5],
        //        [2.5, 2.75, 3.25, 3.5],
        //        [3, 3.25, 3.75, 4]], dtype=float32)
        let expected = array!(
            [
                1.0, 1.25, 1.75, 2.0, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3.0, 3.25, 3.75,
                4.0
            ],
            shape = [4, 4]
        )
        .as_type::<f32>()
        .unwrap();
        assert_eq!(result, expected);
    }

    // The expected output for the test case below is obtained from the python binding.
    #[test]
    fn test_cubic() {
        // BHWC
        let input = array!([1, 2, 3, 4], shape = [1, 2, 2, 1]);

        let mut up = Upsample::new(
            2.0,
            UpsampleMode::Cubic {
                align_corners: false,
            },
        );
        // Drop the batch and channel axes so only the 4x4 image remains.
        let result = up.forward(&input).and_then(|r| r.squeeze(None)).unwrap();

        assert_eq!(result.shape(), &[4, 4]);

        // Expected output from the python binding version 0.17.2
        // array([[0.683594, 1.01562, 1.5625, 1.89453],
        //     [1.34766, 1.67969, 2.22656, 2.55859],
        //     [2.44141, 2.77344, 3.32031, 3.65234],
        //     [3.10547, 3.4375, 3.98438, 4.31641]], dtype=float32)
        let expected = array!(
            [
                0.683594, 1.01562, 1.5625, 1.89453, 1.34766, 1.67969, 2.22656, 2.55859, 2.44141,
                2.77344, 3.32031, 3.65234, 3.10547, 3.4375, 3.98438, 4.31641
            ],
            shape = [4, 4]
        )
        .as_type::<f32>()
        .unwrap();

        // Compared with a tolerance because the reference values are rounded.
        assert_array_eq!(result, expected, 1e-5);
    }
}
441}