mlx_rs/fft/
rfftn.rs

1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{
4    error::Result,
5    utils::{guard::Guarded, IntoOption},
6    Array, Stream, StreamOrDevice,
7};
8
9use super::{
10    as_complex64,
11    utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked},
12};
13
14/// One dimensional discrete Fourier Transform on a real input.
15///
16/// The output has the same shape as the input except along `axis` in which case it has size `n // 2
17/// + 1`.
18///
19/// # Params
20///
21/// - `a`: The input array. If the array is complex it will be silently cast to a real type.
22/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded
23///  with zeros to match `n`. The default value is `a.shape[axis]` if not specified.
24/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified.
25#[generate_macro(customize(root = "$crate::fft"))]
26#[default_device]
27pub fn rfft_device(
28    a: impl AsRef<Array>,
29    #[optional] n: impl Into<Option<i32>>,
30    #[optional] axis: impl Into<Option<i32>>,
31    #[optional] stream: impl AsRef<Stream>,
32) -> Result<Array> {
33    let a = as_complex64(a.as_ref())?;
34    let (n, axis) = resolve_size_and_axis_unchecked(&a, n.into(), axis.into());
35    Array::try_from_op(|res| unsafe {
36        mlx_sys::mlx_fft_rfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
37    })
38}
39
40/// Two-dimensional real discrete Fourier Transform.
41///
42/// The output has the same shape as the input except along the dimensions in `axes` in which case
43/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size
44/// `s[s.len()-1] // 2 + 1`.
45///
46/// # Params
47///
48/// - `a`: The input array. If the array is complex it will be silently cast to a real type.
49/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
50/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`.
51/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`.
52#[generate_macro(customize(root = "$crate::fft"))]
53#[default_device]
54pub fn rfft2_device<'a>(
55    a: impl AsRef<Array>,
56    #[optional] s: impl IntoOption<&'a [i32]>,
57    #[optional] axes: impl IntoOption<&'a [i32]>,
58    #[optional] stream: impl AsRef<Stream>,
59) -> Result<Array> {
60    let a = as_complex64(a.as_ref())?;
61    let axes = axes.into_option().unwrap_or(&[-2, -1]);
62    let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), Some(axes));
63
64    let num_s = s.len();
65    let num_axes = axes.len();
66
67    let s_ptr = s.as_ptr();
68    let axes_ptr = axes.as_ptr();
69
70    Array::try_from_op(|res| unsafe {
71        mlx_sys::mlx_fft_rfft2(
72            res,
73            a.as_ptr(),
74            s_ptr,
75            num_s,
76            axes_ptr,
77            num_axes,
78            stream.as_ref().as_ptr(),
79        )
80    })
81}
82
83/// n-dimensional real discrete Fourier Transform.
84///
85/// The output has the same shape as the input except along the dimensions in `axes` in which case
86/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size
87/// `s[s.len()-1] // 2 + 1`.
88///
89/// # Params
90///
91/// - `a`: The input array. If the array is complex it will be silently cast to a real type.
92/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
93/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`.
94/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is over
95///   the last `len(s)` axes or all axes if `s` is also `None`.
96#[generate_macro(customize(root = "$crate::fft"))]
97#[default_device]
98pub fn rfftn_device<'a>(
99    a: impl AsRef<Array>,
100    #[optional] s: impl IntoOption<&'a [i32]>,
101    #[optional] axes: impl IntoOption<&'a [i32]>,
102    #[optional] stream: impl AsRef<Stream>,
103) -> Result<Array> {
104    let a = as_complex64(a.as_ref())?;
105    let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), axes.into_option());
106
107    let num_s = s.len();
108    let num_axes = axes.len();
109
110    let s_ptr = s.as_ptr();
111    let axes_ptr = axes.as_ptr();
112
113    Array::try_from_op(|res| unsafe {
114        mlx_sys::mlx_fft_rfftn(
115            res,
116            a.as_ptr(),
117            s_ptr,
118            num_s,
119            axes_ptr,
120            num_axes,
121            stream.as_ref().as_ptr(),
122        )
123    })
124}
125
126/// The inverse of [`rfft()`].
127///
128/// The output has the same shape as the input except along axis in which case it has size n.
129///
130/// # Params
131///
132/// - `a`: The input array.
133/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded
134///   with zeros to match `n // 2 + 1`. The default value is `a.shape[axis] // 2 + 1`.
135/// - `axis`: Axis along which to perform the FFT. The default is `-1`.
136#[generate_macro(customize(root = "$crate::fft"))]
137#[default_device]
138pub fn irfft_device(
139    a: impl AsRef<Array>,
140    #[optional] n: impl Into<Option<i32>>,
141    #[optional] axis: impl Into<Option<i32>>,
142    #[optional] stream: impl AsRef<Stream>,
143) -> Result<Array> {
144    let a = as_complex64(a.as_ref())?;
145    let n = n.into();
146    let axis = axis.into();
147    let modify_n = n.is_none();
148    let (mut n, axis) = resolve_size_and_axis_unchecked(&a, n, axis);
149    if modify_n {
150        n = (n - 1) * 2;
151    }
152
153    Array::try_from_op(|res| unsafe {
154        mlx_sys::mlx_fft_irfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
155    })
156}
157
158/// The inverse of [`rfft2()`].
159///
160/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded
161/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis
162/// and will have size `s[s.len()-1] // 2 + 1`.
163///
164/// # Params
165///
166/// - `a`: The input array.
167/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
168///   padded with zeros to match the sizes in `s` except for the last axis which has size
169///   `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`.
170/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`.
171#[generate_macro(customize(root = "$crate::fft"))]
172#[default_device]
173pub fn irfft2_device<'a>(
174    a: impl AsRef<Array>,
175    #[optional] s: impl IntoOption<&'a [i32]>,
176    #[optional] axes: impl IntoOption<&'a [i32]>,
177    #[optional] stream: impl AsRef<Stream>,
178) -> Result<Array> {
179    let a = as_complex64(a.as_ref())?;
180    let s = s.into_option();
181    let axes = axes.into_option().unwrap_or(&[-2, -1]);
182    let modify_last_axis = s.is_none();
183
184    let (mut s, axes) = resolve_sizes_and_axes_unchecked(&a, s, Some(axes));
185    if modify_last_axis {
186        let end = s.len() - 1;
187        s[end] = (s[end] - 1) * 2;
188    }
189
190    let num_s = s.len();
191    let num_axes = axes.len();
192
193    let s_ptr = s.as_ptr();
194    let axes_ptr = axes.as_ptr();
195
196    Array::try_from_op(|res| unsafe {
197        mlx_sys::mlx_fft_irfft2(
198            res,
199            a.as_ptr(),
200            s_ptr,
201            num_s,
202            axes_ptr,
203            num_axes,
204            stream.as_ref().as_ptr(),
205        )
206    })
207}
208
209/// The inverse of [`rfftn()`].
210///
211/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded
212/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis
213/// and will have size `s[s.len()-1] // 2 + 1`.
214///
215/// # Params
216///
217/// - `a`: The input array.
218/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
219///   padded with zeros to match the sizes in `s` except for the last axis which has size
220///   `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`.
221/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is
222///  over the last `len(s)` axes or all axes if `s` is also `None`.
223#[generate_macro(customize(root = "$crate::fft"))]
224#[default_device]
225pub fn irfftn_device<'a>(
226    a: impl AsRef<Array>,
227    #[optional] s: impl IntoOption<&'a [i32]>,
228    #[optional] axes: impl IntoOption<&'a [i32]>,
229    #[optional] stream: impl AsRef<Stream>,
230) -> Result<Array> {
231    let a = as_complex64(a.as_ref())?;
232    let s = s.into_option();
233    let axes = axes.into_option();
234    let modify_last_axis = s.is_none();
235
236    let (mut s, axes) = resolve_sizes_and_axes_unchecked(&a, s, axes);
237    if modify_last_axis {
238        let end = s.len() - 1;
239        s[end] = (s[end] - 1) * 2;
240    }
241
242    let num_s = s.len();
243    let num_axes = axes.len();
244
245    let s_ptr = s.as_ptr();
246    let axes_ptr = axes.as_ptr();
247
248    Array::try_from_op(|res| unsafe {
249        mlx_sys::mlx_fft_irfftn(
250            res,
251            a.as_ptr(),
252            s_ptr,
253            num_s,
254            axes_ptr,
255            num_axes,
256            stream.as_ref().as_ptr(),
257        )
258    })
259}
260
#[cfg(test)]
mod tests {
    use crate::{complex64, Array, Dtype};

    #[test]
    fn test_rfft() {
        const RFFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const RFFT_N: i32 = 4;
        const RFFT_SHAPE: &[i32] = &[RFFT_N];
        const RFFT_AXIS: i32 = -1;
        const RFFT_EXPECTED: &[complex64] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
        ];

        let a = Array::from_slice(RFFT_DATA, RFFT_SHAPE);
        let rfft = super::rfft(&a, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(rfft.dtype(), Dtype::Complex64);
        assert_eq!(rfft.as_slice::<complex64>(), RFFT_EXPECTED);

        let irfft = super::irfft(&rfft, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(irfft.dtype(), Dtype::Float32);
        assert_eq!(irfft.as_slice::<f32>(), RFFT_DATA);
    }

    #[test]
    fn test_rfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = IN_N / 2 + 1;

        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let rfft = super::rfft(&a, None, None).unwrap();
        assert_eq!(rfft.shape(), &[OUT_N]);
    }

    #[test]
    fn test_irfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = (IN_N - 1) * 2;

        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let irfft = super::irfft(&a, None, None).unwrap();
        assert_eq!(irfft.shape(), &[OUT_N]);
    }

    #[test]
    fn test_rfft2() {
        const RFFT2_DATA: &[f32] = &[1.0; 4];
        const RFFT2_SHAPE: &[i32] = &[2, 2];
        const RFFT2_EXPECTED: &[complex64] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let a = Array::from_slice(RFFT2_DATA, RFFT2_SHAPE);
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.dtype(), Dtype::Complex64);
        assert_eq!(rfft2.as_slice::<complex64>(), RFFT2_EXPECTED);

        let irfft2 = super::irfft2(&rfft2, None, None).unwrap();
        assert_eq!(irfft2.dtype(), Dtype::Float32);
        assert_eq!(irfft2.as_slice::<f32>(), RFFT2_DATA);
    }

    #[test]
    fn test_rfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6 / 2 + 1];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_irfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, (6 - 1) * 2];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfft2 = super::irfft2(&a, None, None).unwrap();
        assert_eq!(irfft2.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_rfftn() {
        const RFFTN_DATA: &[f32] = &[1.0; 8];
        const RFFTN_SHAPE: &[i32] = &[2, 2, 2];
        const RFFTN_EXPECTED: &[complex64] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let a = Array::from_slice(RFFTN_DATA, RFFTN_SHAPE);
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.dtype(), Dtype::Complex64);
        assert_eq!(rfftn.as_slice::<complex64>(), RFFTN_EXPECTED);

        let irfftn = super::irfftn(&rfftn, None, None).unwrap();
        assert_eq!(irfftn.dtype(), Dtype::Float32);
        assert_eq!(irfftn.as_slice::<f32>(), RFFTN_DATA);
    }

    // Exercises `rfftn` (previously misnamed `test_fftn_shape_with_default_params`).
    #[test]
    fn test_rfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, 6 / 2 + 1];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_irfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, (6 - 1) * 2];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfftn = super::irfftn(&a, None, None).unwrap();
        assert_eq!(irfftn.shape(), OUT_SHAPE);
    }
}