mlx_rs/fft/
fftn.rs

1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{
4    array::Array,
5    error::Result,
6    stream::StreamOrDevice,
7    utils::{guard::Guarded, IntoOption},
8    Stream,
9};
10
11use super::{
12    as_complex64,
13    utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked},
14};
15
16/// One dimensional discrete Fourier Transform.
17///
18/// # Params
19///
20/// - `a`: The input array.
21/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded
22///   with zeros to match `n`. The default value is `a.shape[axis]`.
23/// - `axis`: Axis along which to perform the FFT. The default is -1.
24#[generate_macro(customize(root = "$crate::fft"))]
25#[default_device]
26pub fn fft_device(
27    a: impl AsRef<Array>,
28    #[optional] n: impl Into<Option<i32>>,
29    #[optional] axis: impl Into<Option<i32>>,
30    #[optional] stream: impl AsRef<Stream>,
31) -> Result<Array> {
32    let a = as_complex64(a.as_ref())?;
33
34    let (n, axis) = resolve_size_and_axis_unchecked(&a, n.into(), axis.into());
35    Array::try_from_op(|res| unsafe {
36        mlx_sys::mlx_fft_fft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
37    })
38}
39
40/// Two dimensional discrete Fourier Transform.
41///
42/// # Params
43///
44/// - `a`: The input array.
45/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded
46/// with zeros to match `s`. The default value is the sizes of `a` along `axes`.
47/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`.
48#[generate_macro(customize(root = "$crate::fft"))]
49#[default_device]
50pub fn fft2_device<'a>(
51    a: impl AsRef<Array>,
52    #[optional] s: impl IntoOption<&'a [i32]>,
53    #[optional] axes: impl IntoOption<&'a [i32]>,
54    #[optional] stream: impl AsRef<Stream>,
55) -> Result<Array> {
56    let a = as_complex64(a.as_ref())?;
57    let axes = axes.into_option().unwrap_or(&[-2, -1]);
58    let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), Some(axes));
59
60    let num_s = s.len();
61    let num_axes = axes.len();
62
63    let s_ptr = s.as_ptr();
64    let axes_ptr = axes.as_ptr();
65
66    Array::try_from_op(|res| unsafe {
67        mlx_sys::mlx_fft_fft2(
68            res,
69            a.as_ptr(),
70            s_ptr,
71            num_s,
72            axes_ptr,
73            num_axes,
74            stream.as_ref().as_ptr(),
75        )
76    })
77}
78
79/// n-dimensional discrete Fourier Transform.
80///
81/// # Params
82///
83/// - `a`: The input array.
84/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
85/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`
86/// if not specified.
87/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is
88/// over the last `len(s)` axes are or all axes if `s` is also `None`.
89#[generate_macro(customize(root = "$crate::fft"))]
90#[default_device]
91pub fn fftn_device<'a>(
92    a: impl AsRef<Array>,
93    #[optional] s: impl IntoOption<&'a [i32]>,
94    #[optional] axes: impl IntoOption<&'a [i32]>,
95    #[optional] stream: impl AsRef<Stream>,
96) -> Result<Array> {
97    let a = as_complex64(a.as_ref())?;
98    let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), axes.into_option());
99    let num_s = s.len();
100    let num_axes = axes.len();
101
102    let s_ptr = s.as_ptr();
103    let axes_ptr = axes.as_ptr();
104
105    Array::try_from_op(|res| unsafe {
106        mlx_sys::mlx_fft_fftn(
107            res,
108            a.as_ptr(),
109            s_ptr,
110            num_s,
111            axes_ptr,
112            num_axes,
113            stream.as_ref().as_ptr(),
114        )
115    })
116}
117
118/// One dimensional inverse discrete Fourier Transform.
119///
120/// # Params
121///
122/// - `a`: Input array.
123/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded
124///  with zeros to match `n`. The default value is `a.shape[axis]` if not specified.
125/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified.
126#[generate_macro(customize(root = "$crate::fft"))]
127#[default_device]
128pub fn ifft_device(
129    a: impl AsRef<Array>,
130    #[optional] n: impl Into<Option<i32>>,
131    #[optional] axis: impl Into<Option<i32>>,
132    #[optional] stream: impl AsRef<Stream>,
133) -> Result<Array> {
134    let a = as_complex64(a.as_ref())?;
135    let (n, axis) = resolve_size_and_axis_unchecked(&a, n.into(), axis.into());
136
137    Array::try_from_op(|res| unsafe {
138        mlx_sys::mlx_fft_ifft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
139    })
140}
141
142/// Two dimensional inverse discrete Fourier Transform.
143///
144/// # Params
145///
146/// - `a`: The input array.
147/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded
148/// with zeros to match `s`. The default value is the sizes of `a` along `axes`.
149/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`.
150#[generate_macro(customize(root = "$crate::fft"))]
151#[default_device]
152pub fn ifft2_device<'a>(
153    a: impl AsRef<Array>,
154    #[optional] s: impl IntoOption<&'a [i32]>,
155    #[optional] axes: impl IntoOption<&'a [i32]>,
156    #[optional] stream: impl AsRef<Stream>,
157) -> Result<Array> {
158    let a = as_complex64(a.as_ref())?;
159    let axes = axes.into_option().unwrap_or(&[-2, -1]);
160    let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), Some(axes));
161
162    let num_s = s.len();
163    let num_axes = axes.len();
164
165    let s_ptr = s.as_ptr();
166    let axes_ptr = axes.as_ptr();
167
168    Array::try_from_op(|res| unsafe {
169        mlx_sys::mlx_fft_ifft2(
170            res,
171            a.as_ptr(),
172            s_ptr,
173            num_s,
174            axes_ptr,
175            num_axes,
176            stream.as_ref().as_ptr(),
177        )
178    })
179}
180
181/// n-dimensional inverse discrete Fourier Transform.
182///
183/// # Params
184///
185/// - `a`: The input array.
186/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
187/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`
188/// if not specified.
189/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is
190/// over the last `len(s)` axes are or all axes if `s` is also `None`.
191#[generate_macro(customize(root = "$crate::fft"))]
192#[default_device]
193pub fn ifftn_device<'a>(
194    a: impl AsRef<Array>,
195    #[optional] s: impl IntoOption<&'a [i32]>,
196    #[optional] axes: impl IntoOption<&'a [i32]>,
197    #[optional] stream: impl AsRef<Stream>,
198) -> Result<Array> {
199    let a = as_complex64(a.as_ref())?;
200    let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), axes.into_option());
201    let num_s = s.len();
202    let num_axes = axes.len();
203
204    let s_ptr = s.as_ptr();
205    let axes_ptr = axes.as_ptr();
206
207    Array::try_from_op(|res| unsafe {
208        mlx_sys::mlx_fft_ifftn(
209            res,
210            a.as_ptr(),
211            s_ptr,
212            num_s,
213            axes_ptr,
214            num_axes,
215            stream.as_ref().as_ptr(),
216        )
217    })
218}
219
#[cfg(test)]
mod tests {
    use crate::{complex64, fft::*, Array, Dtype};

    /// Lifts real samples into complex values with a zero imaginary part.
    fn to_complex(real: &[f32]) -> Vec<complex64> {
        real.iter().copied().map(|re| complex64::new(re, 0.0)).collect()
    }

    #[test]
    fn test_fft() {
        const FFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const FFT_SHAPE: &[i32] = &[4];
        const FFT_EXPECTED: &[complex64; 4] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
            complex64::new(-2.0, -2.0),
        ];

        let input = Array::from_slice(FFT_DATA, FFT_SHAPE);

        // Forward transform.
        let spectrum = fft(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFT_EXPECTED);

        // Inverse transform recovers the real input as complex values.
        let restored = ifft(&spectrum, None, None).unwrap();
        assert_eq!(restored.dtype(), Dtype::Complex64);
        assert_eq!(restored.as_slice::<complex64>(), to_complex(FFT_DATA));

        // The source array must be left untouched and still readable.
        assert_eq!(input.as_slice::<f32>(), FFT_DATA);
    }

    #[test]
    fn test_fft2() {
        const FFT2_DATA: &[f32] = &[1.0, 1.0, 1.0, 1.0];
        const FFT2_SHAPE: &[i32] = &[2, 2];
        const FFT2_EXPECTED: &[complex64; 4] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let input = Array::from_slice(FFT2_DATA, FFT2_SHAPE);

        // Forward transform.
        let spectrum = fft2(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFT2_EXPECTED);

        // Inverse transform recovers the real input as complex values.
        let restored = ifft2(&spectrum, None, None).unwrap();
        assert_eq!(restored.dtype(), Dtype::Complex64);
        assert_eq!(restored.as_slice::<complex64>(), to_complex(FFT2_DATA));

        // The source array must be left untouched and still readable.
        assert_eq!(input.as_slice::<f32>(), FFT2_DATA);
    }

    #[test]
    fn test_fftn() {
        const FFTN_DATA: &[f32] = &[1.0; 8];
        const FFTN_SHAPE: &[i32] = &[2, 2, 2];
        const FFTN_EXPECTED: &[complex64; 8] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let input = Array::from_slice(FFTN_DATA, FFTN_SHAPE);

        // Forward transform over all axes (defaults).
        let spectrum = fftn(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFTN_EXPECTED);

        // Inverse transform with explicit sizes and axes.
        let restored = ifftn(&spectrum, FFTN_SHAPE, &[0, 1, 2]).unwrap();
        assert_eq!(restored.dtype(), Dtype::Complex64);
        assert_eq!(restored.as_slice::<complex64>(), to_complex(FFTN_DATA));

        // The source array must be left untouched and still readable.
        assert_eq!(input.as_slice::<f32>(), FFTN_DATA);
    }
}