mlx_rs/fft/
utils.rs

1use smallvec::SmallVec;
2
3use crate::{constants::DEFAULT_STACK_VEC_LEN, utils::resolve_index_unchecked, Array};
4
5#[inline]
6pub(super) fn resolve_size_and_axis_unchecked(
7    a: &Array,
8    n: Option<i32>,
9    axis: Option<i32>,
10) -> (i32, i32) {
11    let axis = axis.unwrap_or(-1);
12    let n = n.unwrap_or_else(|| {
13        let axis_index = resolve_index_unchecked(axis, a.ndim());
14        a.shape()[axis_index]
15    });
16    (n, axis)
17}
18
19// Use Cow or SmallVec?
20#[inline]
21pub(super) fn resolve_sizes_and_axes_unchecked<'a>(
22    a: &Array,
23    s: Option<&'a [i32]>,
24    axes: Option<&'a [i32]>,
25) -> (
26    SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>,
27    SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>,
28) {
29    match (s, axes) {
30        (Some(s), Some(axes)) => {
31            let valid_s = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::from_slice(s);
32            let valid_axes = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::from_slice(axes);
33            (valid_s, valid_axes)
34        }
35        (Some(s), None) => {
36            let valid_s = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::from_slice(s);
37            let valid_axes = (-(valid_s.len() as i32)..0).collect();
38            (valid_s, valid_axes)
39        }
40        (None, Some(axes)) => {
41            let valid_s = axes
42                .iter()
43                .map(|&axis| {
44                    let axis_index = resolve_index_unchecked(axis, a.ndim());
45                    a.shape()[axis_index]
46                })
47                .collect();
48            let valid_axes = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::from_slice(axes);
49            (valid_s, valid_axes)
50        }
51        (None, None) => {
52            let valid_s: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> =
53                (0..a.ndim()).map(|axis| a.shape()[axis]).collect();
54            let valid_axes = (-(valid_s.len() as i32)..0).collect();
55            (valid_s, valid_axes)
56        }
57    }
58}