mlx_rs/fft/
rfftn.rs

1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{
4    error::Result,
5    utils::{guard::Guarded, IntoOption},
6    Array, Stream,
7};
8
9use super::utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked};
10
11/// One dimensional discrete Fourier Transform on a real input.
12///
13/// The output has the same shape as the input except along `axis` in which case it has size `n // 2
14/// + 1`.
15///
16/// # Params
17///
18/// - `a`: The input array. If the array is complex it will be silently cast to a real type.
19/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded
20///  with zeros to match `n`. The default value is `a.shape[axis]` if not specified.
21/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified.
22#[generate_macro(customize(root = "$crate::fft"))]
23#[default_device]
24pub fn rfft_device(
25    a: impl AsRef<Array>,
26    #[optional] n: impl Into<Option<i32>>,
27    #[optional] axis: impl Into<Option<i32>>,
28    #[optional] stream: impl AsRef<Stream>,
29) -> Result<Array> {
30    let a = a.as_ref();
31    let (n, axis) = resolve_size_and_axis_unchecked(a, n.into(), axis.into());
32    Array::try_from_op(|res| unsafe {
33        mlx_sys::mlx_fft_rfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
34    })
35}
36
37/// Two-dimensional real discrete Fourier Transform.
38///
39/// The output has the same shape as the input except along the dimensions in `axes` in which case
40/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size
41/// `s[s.len()-1] // 2 + 1`.
42///
43/// # Params
44///
45/// - `a`: The input array. If the array is complex it will be silently cast to a real type.
46/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
47/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`.
48/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`.
49#[generate_macro(customize(root = "$crate::fft"))]
50#[default_device]
51pub fn rfft2_device<'a>(
52    a: impl AsRef<Array>,
53    #[optional] s: impl IntoOption<&'a [i32]>,
54    #[optional] axes: impl IntoOption<&'a [i32]>,
55    #[optional] stream: impl AsRef<Stream>,
56) -> Result<Array> {
57    let a = a.as_ref();
58    let axes = axes.into_option().unwrap_or(&[-2, -1]);
59    let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), Some(axes));
60
61    let num_s = s.len();
62    let num_axes = axes.len();
63
64    let s_ptr = s.as_ptr();
65    let axes_ptr = axes.as_ptr();
66
67    Array::try_from_op(|res| unsafe {
68        mlx_sys::mlx_fft_rfft2(
69            res,
70            a.as_ptr(),
71            s_ptr,
72            num_s,
73            axes_ptr,
74            num_axes,
75            stream.as_ref().as_ptr(),
76        )
77    })
78}
79
80/// n-dimensional real discrete Fourier Transform.
81///
82/// The output has the same shape as the input except along the dimensions in `axes` in which case
83/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size
84/// `s[s.len()-1] // 2 + 1`.
85///
86/// # Params
87///
88/// - `a`: The input array. If the array is complex it will be silently cast to a real type.
89/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
90/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`.
91/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is over
92///   the last `len(s)` axes or all axes if `s` is also `None`.
93#[generate_macro(customize(root = "$crate::fft"))]
94#[default_device]
95pub fn rfftn_device<'a>(
96    a: impl AsRef<Array>,
97    #[optional] s: impl IntoOption<&'a [i32]>,
98    #[optional] axes: impl IntoOption<&'a [i32]>,
99    #[optional] stream: impl AsRef<Stream>,
100) -> Result<Array> {
101    let a = a.as_ref();
102    let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), axes.into_option());
103
104    let num_s = s.len();
105    let num_axes = axes.len();
106
107    let s_ptr = s.as_ptr();
108    let axes_ptr = axes.as_ptr();
109
110    Array::try_from_op(|res| unsafe {
111        mlx_sys::mlx_fft_rfftn(
112            res,
113            a.as_ptr(),
114            s_ptr,
115            num_s,
116            axes_ptr,
117            num_axes,
118            stream.as_ref().as_ptr(),
119        )
120    })
121}
122
123/// The inverse of [`rfft()`].
124///
125/// The output has the same shape as the input except along axis in which case it has size n.
126///
127/// # Params
128///
129/// - `a`: The input array.
130/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded
131///   with zeros to match `n // 2 + 1`. The default value is `a.shape[axis] // 2 + 1`.
132/// - `axis`: Axis along which to perform the FFT. The default is `-1`.
133#[generate_macro(customize(root = "$crate::fft"))]
134#[default_device]
135pub fn irfft_device(
136    a: impl AsRef<Array>,
137    #[optional] n: impl Into<Option<i32>>,
138    #[optional] axis: impl Into<Option<i32>>,
139    #[optional] stream: impl AsRef<Stream>,
140) -> Result<Array> {
141    let a = a.as_ref();
142    let n = n.into();
143    let axis = axis.into();
144    let modify_n = n.is_none();
145    let (mut n, axis) = resolve_size_and_axis_unchecked(a, n, axis);
146    if modify_n {
147        n = (n - 1) * 2;
148    }
149
150    Array::try_from_op(|res| unsafe {
151        mlx_sys::mlx_fft_irfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
152    })
153}
154
155/// The inverse of [`rfft2()`].
156///
157/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded
158/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis
159/// and will have size `s[s.len()-1] // 2 + 1`.
160///
161/// # Params
162///
163/// - `a`: The input array.
164/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
165///   padded with zeros to match the sizes in `s` except for the last axis which has size
166///   `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`.
167/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`.
168#[generate_macro(customize(root = "$crate::fft"))]
169#[default_device]
170pub fn irfft2_device<'a>(
171    a: impl AsRef<Array>,
172    #[optional] s: impl IntoOption<&'a [i32]>,
173    #[optional] axes: impl IntoOption<&'a [i32]>,
174    #[optional] stream: impl AsRef<Stream>,
175) -> Result<Array> {
176    let a = a.as_ref();
177    let s = s.into_option();
178    let axes = axes.into_option().unwrap_or(&[-2, -1]);
179    let modify_last_axis = s.is_none();
180
181    let (mut s, axes) = resolve_sizes_and_axes_unchecked(a, s, Some(axes));
182    if modify_last_axis {
183        let end = s.len() - 1;
184        s[end] = (s[end] - 1) * 2;
185    }
186
187    let num_s = s.len();
188    let num_axes = axes.len();
189
190    let s_ptr = s.as_ptr();
191    let axes_ptr = axes.as_ptr();
192
193    Array::try_from_op(|res| unsafe {
194        mlx_sys::mlx_fft_irfft2(
195            res,
196            a.as_ptr(),
197            s_ptr,
198            num_s,
199            axes_ptr,
200            num_axes,
201            stream.as_ref().as_ptr(),
202        )
203    })
204}
205
206/// The inverse of [`rfftn()`].
207///
208/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded
209/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis
210/// and will have size `s[s.len()-1] // 2 + 1`.
211///
212/// # Params
213///
214/// - `a`: The input array.
215/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
216///   padded with zeros to match the sizes in `s` except for the last axis which has size
217///   `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`.
218/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is
219///  over the last `len(s)` axes or all axes if `s` is also `None`.
220#[generate_macro(customize(root = "$crate::fft"))]
221#[default_device]
222pub fn irfftn_device<'a>(
223    a: impl AsRef<Array>,
224    #[optional] s: impl IntoOption<&'a [i32]>,
225    #[optional] axes: impl IntoOption<&'a [i32]>,
226    #[optional] stream: impl AsRef<Stream>,
227) -> Result<Array> {
228    let a = a.as_ref();
229    let s = s.into_option();
230    let axes = axes.into_option();
231    let modify_last_axis = s.is_none();
232
233    let (mut s, axes) = resolve_sizes_and_axes_unchecked(a, s, axes);
234    if modify_last_axis {
235        let end = s.len() - 1;
236        s[end] = (s[end] - 1) * 2;
237    }
238
239    let num_s = s.len();
240    let num_axes = axes.len();
241
242    let s_ptr = s.as_ptr();
243    let axes_ptr = axes.as_ptr();
244
245    Array::try_from_op(|res| unsafe {
246        mlx_sys::mlx_fft_irfftn(
247            res,
248            a.as_ptr(),
249            s_ptr,
250            num_s,
251            axes_ptr,
252            num_axes,
253            stream.as_ref().as_ptr(),
254        )
255    })
256}
257
#[cfg(test)]
mod tests {
    // Round-trip and default-shape tests for the real FFT wrappers in this module.
    use crate::{complex64, Array, Dtype};

    #[test]
    fn test_rfft() {
        const RFFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const RFFT_N: i32 = 4;
        const RFFT_SHAPE: &[i32] = &[RFFT_N];
        const RFFT_AXIS: i32 = -1;
        const RFFT_EXPECTED: &[complex64] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
        ];

        let a = Array::from_slice(RFFT_DATA, RFFT_SHAPE);
        let rfft = super::rfft(&a, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(rfft.dtype(), Dtype::Complex64);
        assert_eq!(rfft.as_slice::<complex64>(), RFFT_EXPECTED);

        let irfft = super::irfft(&rfft, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(irfft.dtype(), Dtype::Float32);
        assert_eq!(irfft.as_slice::<f32>(), RFFT_DATA);
    }

    #[test]
    fn test_rfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = IN_N / 2 + 1;

        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let rfft = super::rfft(&a, None, None).unwrap();
        assert_eq!(rfft.shape(), &[OUT_N]);
    }

    #[test]
    fn test_irfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = (IN_N - 1) * 2;

        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let irfft = super::irfft(&a, None, None).unwrap();
        assert_eq!(irfft.shape(), &[OUT_N]);
    }

    #[test]
    fn test_rfft2() {
        const RFFT2_DATA: &[f32] = &[1.0; 4];
        const RFFT2_SHAPE: &[i32] = &[2, 2];
        const RFFT2_EXPECTED: &[complex64] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let a = Array::from_slice(RFFT2_DATA, RFFT2_SHAPE);
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.dtype(), Dtype::Complex64);
        assert_eq!(rfft2.as_slice::<complex64>(), RFFT2_EXPECTED);

        let irfft2 = super::irfft2(&rfft2, None, None).unwrap();
        assert_eq!(irfft2.dtype(), Dtype::Float32);
        assert_eq!(irfft2.as_slice::<f32>(), RFFT2_DATA);
    }

    #[test]
    fn test_rfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6 / 2 + 1];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_irfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, (6 - 1) * 2];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfft2 = super::irfft2(&a, None, None).unwrap();
        assert_eq!(irfft2.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_rfftn() {
        const RFFTN_DATA: &[f32] = &[1.0; 8];
        const RFFTN_SHAPE: &[i32] = &[2, 2, 2];
        const RFFTN_EXPECTED: &[complex64] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let a = Array::from_slice(RFFTN_DATA, RFFTN_SHAPE);
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.dtype(), Dtype::Complex64);
        assert_eq!(rfftn.as_slice::<complex64>(), RFFTN_EXPECTED);

        let irfftn = super::irfftn(&rfftn, None, None).unwrap();
        assert_eq!(irfftn.dtype(), Dtype::Float32);
        assert_eq!(irfftn.as_slice::<f32>(), RFFTN_DATA);
    }

    #[test]
    fn test_rfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, 6 / 2 + 1];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_irfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, (6 - 1) * 2];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfftn = super::irfftn(&a, None, None).unwrap();
        assert_eq!(irfftn.shape(), OUT_SHAPE);
    }
}