1use crate::{
2 array,
3 error::Exception,
4 macros::ModuleParameters,
5 module::Module,
6 ops::{
7 abs, broadcast_to, ceil, clip, expand_dims, floor,
8 indexing::{ArrayIndex, ArrayIndexOp, Ellipsis, IndexOp, NewAxis, TryIndexOp},
9 },
10 transforms::compile::compile,
11 Array,
12};
13
14use crate::utils::SingleOrVec;
15
/// Sampling scheme used by [`Upsample`] to produce the enlarged output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpsampleMode {
    /// Nearest-neighbour sampling: every output element copies one input element.
    Nearest,

    /// Linear interpolation between the two neighbouring input samples along each
    /// spatial axis.
    Linear {
        /// When `true`, the first and last output samples are placed exactly on the
        /// first and last input samples; when `false`, the output grid is evenly spaced
        /// with step `1 / scale` and centred on the input grid.
        align_corners: bool,
    },

    /// Cubic interpolation over four neighbouring input samples along each spatial axis.
    Cubic {
        /// Same meaning as the `align_corners` flag of [`UpsampleMode::Linear`].
        align_corners: bool,
    },
}
35
36#[derive(Debug, Clone, ModuleParameters)]
38#[module(root = crate)]
39pub struct Upsample {
40 pub scale_factor: SingleOrVec<f32>,
46
47 pub mode: UpsampleMode,
49}
50
51impl Upsample {
52 pub fn new(scale_factor: impl Into<SingleOrVec<f32>>, mode: UpsampleMode) -> Self {
54 let scale_factor = scale_factor.into();
55 Upsample { scale_factor, mode }
56 }
57
58 fn forward_inner(&self, x: &Array, scale: &[f32]) -> Result<Array, Exception> {
59 match self.mode {
60 UpsampleMode::Nearest => upsample_nearest(x, scale),
61 UpsampleMode::Linear { align_corners } => {
62 interpolate(x, scale, linear_indices, align_corners)
63 }
64 UpsampleMode::Cubic { align_corners } => {
65 interpolate(x, scale, cubic_indices, align_corners)
66 }
67 }
68 }
69}
70
71impl Module<&Array> for Upsample {
72 type Error = Exception;
73 type Output = Array;
74
75 fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> {
76 let dimensions = x.ndim() - 2;
77
78 if dimensions == 0 {
79 return Err(Exception::custom(format!(
80 "[Upsample] The input should have at least
81 1 spatial dimension which means it should be at least
82 3D but {}D was provided",
83 x.ndim()
84 )));
85 }
86
87 match &self.scale_factor {
88 SingleOrVec::Single(scale) => {
89 let scale = vec![*scale; dimensions];
90 self.forward_inner(x, &scale[..])
91 }
92 SingleOrVec::Vec(scales) => self.forward_inner(x, &scales[..]),
93 }
94 }
95
96 fn training_mode(&mut self, _mode: bool) {}
97}
98
99#[allow(non_snake_case)]
100fn upsample_nearest(x: &Array, scale: &[f32]) -> Result<Array, Exception> {
101 let dimensions = x.ndim() - 2;
102 if dimensions != scale.len() {
103 return Err(Exception::custom(format!(
104 "The number of scale factors ({}) must match the number of spatial dimensions ({})",
105 scale.len(),
106 dimensions
107 )));
108 }
109
110 let int_scales = scale.iter().map(|&s| s as i32).collect::<Vec<_>>();
112 let int_float_scales = int_scales.iter().map(|&s| s as f32).collect::<Vec<_>>();
113
114 if int_float_scales == scale {
115 let mut shape = x.shape().to_vec();
117 (0..dimensions).for_each(|d| {
118 shape.insert(2 + 2 * d, 1);
119 });
120 let mut x = x.reshape(&shape)?;
121
122 (0..dimensions).for_each(|d| {
123 shape[2 + 2 * d] = int_scales[d];
124 });
125 x = broadcast_to(&x, &shape)?;
126
127 (0..dimensions).for_each(|d| {
128 shape[d + 1] *= shape[d + 2];
129 shape.remove(d + 2);
130 });
131 x = x.reshape(&shape)?;
132
133 Ok(x)
134 } else {
135 let shape_len = x.shape().len();
137 let N = &x.shape()[1..shape_len - 1];
138 let mut indices: Vec<ArrayIndexOp> = vec![(..).index_op()];
139
140 for (i, (n, s)) in N.iter().zip(scale.iter()).enumerate() {
141 indices.push(nearest_indices(*n, *s, i, dimensions)?.index_op());
142 }
143
144 x.try_index(&indices[..])
145 }
146}
147
148type IndexWeight = (Array, Array);
149
150type IndicesFn = fn(i32, f32, bool, usize, usize) -> Result<Vec<IndexWeight>, Exception>;
151
152#[allow(non_snake_case)]
153fn interpolate(
154 x: &Array,
155 scale: &[f32],
156 indices_fn: IndicesFn,
157 align_corners: bool,
158) -> Result<Array, Exception> {
159 let dimensions = x.ndim() - 2;
160 if dimensions != scale.len() {
161 return Err(Exception::custom(format!(
162 "The number of scale factors ({}) must match the number of spatial dimensions ({})",
163 scale.len(),
164 dimensions
165 )));
166 }
167
168 let N = &x.shape()[1..x.ndim() - 1];
169
170 let mut index_weights = Vec::with_capacity(N.len());
172 for (i, (n, s)) in N.iter().zip(scale.iter()).enumerate() {
173 index_weights.push(indices_fn(*n, *s, align_corners, i, dimensions)?);
174 }
175
176 let prod = product(&index_weights);
178 let mut samples = Vec::with_capacity(prod.len());
179 let mut weights = Vec::with_capacity(prod.len());
180 for index_weight in prod {
181 let (index, weight): (Vec<&Array>, Vec<&Array>) =
182 index_weight.iter().map(|(i, w)| (i, w)).unzip();
183 let mut index_ops = index.iter().map(|i| i.index_op()).collect::<Vec<_>>();
184
185 let mut sample_indices = vec![(..).index_op()];
186 sample_indices.append(&mut index_ops);
187 samples.push(x.index(&sample_indices[..]));
188
189 weights.push(weight.into_iter().product::<Array>());
190 }
191
192 let acc = &weights[0] * &samples[0];
194 weights[1..]
195 .iter()
196 .zip(samples[1..].iter())
197 .try_fold(acc, |acc, (w, s)| acc.add(w.multiply(s)?))
198}
199
/// Cartesian product of `values`, returning references into the input.
///
/// The first inner vector varies fastest. For `[[a, b], [c, d]]` the result is
/// `[[a, c], [b, c], [a, d], [b, d]]`.
///
/// Inner vectors may have different lengths; each one is used as its own radix. The
/// result is empty when `values` is empty or when any inner vector is empty.
fn product<T>(values: &[Vec<T>]) -> Vec<Vec<&T>> {
    if values.is_empty() {
        return vec![];
    }

    // Zero as soon as any inner vector is empty, so the modulo below never sees a 0 radix.
    let count: usize = values.iter().map(Vec::len).product();

    (0..count)
        .map(|mut remainder| {
            // Mixed-radix decomposition of the combination number, first vector fastest.
            values
                .iter()
                .map(|value| {
                    let item = &value[remainder % value.len()];
                    remainder /= value.len();
                    item
                })
                .collect::<Vec<_>>()
        })
        .collect()
}
227
228fn nearest_indices(
229 dimension: i32,
230 scale: f32,
231 dim: usize,
232 ndim: usize,
233) -> Result<Array, Exception> {
234 scaled_indices(dimension, scale, true, dim, ndim).and_then(|i| i.as_type::<i32>())
235}
236
237fn linear_indices(
238 dimension: i32,
239 scale: f32,
240 align_corners: bool,
241 dim: usize,
242 ndim: usize,
243) -> Result<Vec<IndexWeight>, Exception> {
244 let mut indices = scaled_indices(dimension, scale, align_corners, dim, ndim)?;
245 indices = clip(&indices, (0, dimension - 1))?;
246 let indices_left = floor(&indices)?;
247 let indices_right = ceil(&indices)?;
248 let weight = expand_dims(&indices.subtract(&indices_left)?, &[-1])?;
249
250 let indices_left = indices_left.as_type::<i32>()?;
251 let indices_right = indices_right.as_type::<i32>()?;
252
253 Ok(vec![
254 (indices_left, array!(1.0) - &weight),
256 (indices_right, weight),
257 ])
258}
259
260fn cubic_indices(
261 dimension: i32,
262 scale: f32,
263 align_corners: bool,
264 dim: usize,
265 ndim: usize,
266) -> Result<Vec<IndexWeight>, Exception> {
267 let indices = scaled_indices(dimension, scale, align_corners, dim, ndim)?;
268
269 let mut indices_l1 = floor(&indices)?;
271 let mut indices_r1 = floor(&(&indices + 1))?;
272 let mut indices_l2 = (&indices_l1) - 1;
273 let mut indices_r2 = (&indices_r1) + 1;
274
275 let weight_l1 = compiled_get_weight1(&indices, &indices_l1)?.index((Ellipsis, NewAxis));
276 let weight_r1 = compiled_get_weight1(&indices, &indices_r1)?.index((Ellipsis, NewAxis));
277 let weight_l2 = compiled_get_weight2(&indices, &indices_l2)?.index((Ellipsis, NewAxis));
278 let weight_r2 = compiled_get_weight2(&indices, &indices_r2)?.index((Ellipsis, NewAxis));
279
280 indices_l1 = clip(&indices_l1, (0, dimension - 1))?.as_type::<i32>()?;
282 indices_r1 = clip(&indices_r1, (0, dimension - 1))?.as_type::<i32>()?;
283 indices_l2 = clip(&indices_l2, (0, dimension - 1))?.as_type::<i32>()?;
284 indices_r2 = clip(&indices_r2, (0, dimension - 1))?.as_type::<i32>()?;
285
286 Ok(vec![
287 (indices_l1, weight_l1),
288 (indices_r1, weight_r1),
289 (indices_l2, weight_l2),
290 (indices_r2, weight_r2),
291 ])
292}
293
294fn compiled_get_weight1(ind: &Array, grid: &Array) -> Result<Array, Exception> {
295 let get_weight1 = |(ind_, grid_): (&Array, &Array)| {
299 let a = -0.75;
300 let x = abs(ind_ - grid_)?;
301 Ok((array!(a + 2.0) * &x - array!(a + 3.0)) * &x * &x + 1.0)
302 };
303 let mut compiled = compile(get_weight1, true);
304 compiled((ind, grid))
305}
306
307fn compiled_get_weight2(ind: &Array, grid: &Array) -> Result<Array, Exception> {
308 let get_weight2 = |(ind_, grid_): (&Array, &Array)| {
309 let a = -0.75;
310 let x = abs(ind_ - grid_)?;
311 Ok((((&x - 5.0) * &x + 8.0) * &x - 4.0) * a)
312 };
313 let mut compiled = compile(get_weight2, true);
314 compiled((ind, grid))
315}
316
317#[allow(non_snake_case)]
318fn scaled_indices(
319 N: i32,
320 scale: f32,
321 align_corners: bool,
322 dim: usize,
323 ndim: usize,
324) -> Result<Array, Exception> {
325 let M = (scale * N as f32) as i32;
326
327 let indices = match align_corners {
328 true => {
329 Array::from_iter(0..M, &[M]).as_type::<f32>()? * ((N as f32 - 1.0) / (M as f32 - 1.0))
331 }
332 false => {
333 let step = 1.0 / scale;
334 let start = ((M as f32 - 1.0) * step - N as f32 + 1.0) / 2.0;
335 Array::from_iter(0..M, &[M]).as_type::<f32>()? * step - start
337 }
338 };
339
340 let mut shape = vec![1; ndim];
341 shape[dim] = -1;
342
343 indices.reshape(&shape)
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_array_eq;

    /// `[1, 2, 2, 1]` input holding the values 1..=4: one 2x2 spatial grid, one channel.
    fn grid_2x2() -> Array {
        array!([1, 2, 3, 4], shape = [1, 2, 2, 1])
    }

    /// Upsamples `grid_2x2()` by a factor of 2 with `mode` and squeezes away unit axes.
    fn upsample_2x(mode: UpsampleMode) -> Array {
        let mut layer = Upsample::new(2.0, mode);
        layer
            .forward(&grid_2x2())
            .and_then(|out| out.squeeze(None))
            .unwrap()
    }

    #[test]
    fn test_nearest() {
        let result = upsample_2x(UpsampleMode::Nearest);
        assert_eq!(result.shape(), &[4, 4]);

        // Every input value is repeated into a 2x2 tile.
        let expected = array!(
            [1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4],
            shape = [4, 4]
        )
        .as_type::<i32>()
        .unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_linear() {
        let result = upsample_2x(UpsampleMode::Linear {
            align_corners: false,
        });
        assert_eq!(result.shape(), &[4, 4]);

        let expected = array!(
            [
                1.0, 1.25, 1.75, 2.0, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3.0, 3.25, 3.75,
                4.0
            ],
            shape = [4, 4]
        )
        .as_type::<f32>()
        .unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_cubic() {
        let result = upsample_2x(UpsampleMode::Cubic {
            align_corners: false,
        });
        assert_eq!(result.shape(), &[4, 4]);

        let expected = array!(
            [
                0.683594, 1.01562, 1.5625, 1.89453, 1.34766, 1.67969, 2.22656, 2.55859, 2.44141,
                2.77344, 3.32031, 3.65234, 3.10547, 3.4375, 3.98438, 4.31641
            ],
            shape = [4, 4]
        )
        .as_type::<f32>()
        .unwrap();

        assert_array_eq!(result, expected, 1e-5);
    }
}