mlx_rs/fft/
fftn.rs

1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{
4    array::Array,
5    error::Result,
6    utils::{guard::Guarded, IntoOption},
7    Stream,
8};
9
10use super::utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked};
11
12/// One dimensional discrete Fourier Transform.
13///
14/// # Params
15///
16/// - `a`: The input array.
17/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded
18///   with zeros to match `n`. The default value is `a.shape[axis]`.
19/// - `axis`: Axis along which to perform the FFT. The default is -1.
20#[generate_macro(customize(root = "$crate::fft"))]
21#[default_device]
22pub fn fft_device(
23    a: impl AsRef<Array>,
24    #[optional] n: impl Into<Option<i32>>,
25    #[optional] axis: impl Into<Option<i32>>,
26    #[optional] stream: impl AsRef<Stream>,
27) -> Result<Array> {
28    let a = a.as_ref();
29    let (n, axis) = resolve_size_and_axis_unchecked(a, n.into(), axis.into());
30    Array::try_from_op(|res| unsafe {
31        mlx_sys::mlx_fft_fft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
32    })
33}
34
35/// Two dimensional discrete Fourier Transform.
36///
37/// # Params
38///
39/// - `a`: The input array.
40/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded
41/// with zeros to match `s`. The default value is the sizes of `a` along `axes`.
42/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`.
43#[generate_macro(customize(root = "$crate::fft"))]
44#[default_device]
45pub fn fft2_device<'a>(
46    a: impl AsRef<Array>,
47    #[optional] s: impl IntoOption<&'a [i32]>,
48    #[optional] axes: impl IntoOption<&'a [i32]>,
49    #[optional] stream: impl AsRef<Stream>,
50) -> Result<Array> {
51    let a = a.as_ref();
52    let axes = axes.into_option().unwrap_or(&[-2, -1]);
53    let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), Some(axes));
54
55    let num_s = s.len();
56    let num_axes = axes.len();
57
58    let s_ptr = s.as_ptr();
59    let axes_ptr = axes.as_ptr();
60
61    Array::try_from_op(|res| unsafe {
62        mlx_sys::mlx_fft_fft2(
63            res,
64            a.as_ptr(),
65            s_ptr,
66            num_s,
67            axes_ptr,
68            num_axes,
69            stream.as_ref().as_ptr(),
70        )
71    })
72}
73
74/// n-dimensional discrete Fourier Transform.
75///
76/// # Params
77///
78/// - `a`: The input array.
79/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
80/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`
81/// if not specified.
82/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is
83/// over the last `len(s)` axes are or all axes if `s` is also `None`.
84#[generate_macro(customize(root = "$crate::fft"))]
85#[default_device]
86pub fn fftn_device<'a>(
87    a: impl AsRef<Array>,
88    #[optional] s: impl IntoOption<&'a [i32]>,
89    #[optional] axes: impl IntoOption<&'a [i32]>,
90    #[optional] stream: impl AsRef<Stream>,
91) -> Result<Array> {
92    let a = a.as_ref();
93    let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), axes.into_option());
94    let num_s = s.len();
95    let num_axes = axes.len();
96
97    let s_ptr = s.as_ptr();
98    let axes_ptr = axes.as_ptr();
99
100    Array::try_from_op(|res| unsafe {
101        mlx_sys::mlx_fft_fftn(
102            res,
103            a.as_ptr(),
104            s_ptr,
105            num_s,
106            axes_ptr,
107            num_axes,
108            stream.as_ref().as_ptr(),
109        )
110    })
111}
112
113/// One dimensional inverse discrete Fourier Transform.
114///
115/// # Params
116///
117/// - `a`: Input array.
118/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded
119///  with zeros to match `n`. The default value is `a.shape[axis]` if not specified.
120/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified.
121#[generate_macro(customize(root = "$crate::fft"))]
122#[default_device]
123pub fn ifft_device(
124    a: impl AsRef<Array>,
125    #[optional] n: impl Into<Option<i32>>,
126    #[optional] axis: impl Into<Option<i32>>,
127    #[optional] stream: impl AsRef<Stream>,
128) -> Result<Array> {
129    let a = a.as_ref();
130    let (n, axis) = resolve_size_and_axis_unchecked(a, n.into(), axis.into());
131
132    Array::try_from_op(|res| unsafe {
133        mlx_sys::mlx_fft_ifft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
134    })
135}
136
137/// Two dimensional inverse discrete Fourier Transform.
138///
139/// # Params
140///
141/// - `a`: The input array.
142/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded
143/// with zeros to match `s`. The default value is the sizes of `a` along `axes`.
144/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`.
145#[generate_macro(customize(root = "$crate::fft"))]
146#[default_device]
147pub fn ifft2_device<'a>(
148    a: impl AsRef<Array>,
149    #[optional] s: impl IntoOption<&'a [i32]>,
150    #[optional] axes: impl IntoOption<&'a [i32]>,
151    #[optional] stream: impl AsRef<Stream>,
152) -> Result<Array> {
153    let a = a.as_ref();
154    let axes = axes.into_option().unwrap_or(&[-2, -1]);
155    let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), Some(axes));
156
157    let num_s = s.len();
158    let num_axes = axes.len();
159
160    let s_ptr = s.as_ptr();
161    let axes_ptr = axes.as_ptr();
162
163    Array::try_from_op(|res| unsafe {
164        mlx_sys::mlx_fft_ifft2(
165            res,
166            a.as_ptr(),
167            s_ptr,
168            num_s,
169            axes_ptr,
170            num_axes,
171            stream.as_ref().as_ptr(),
172        )
173    })
174}
175
176/// n-dimensional inverse discrete Fourier Transform.
177///
178/// # Params
179///
180/// - `a`: The input array.
181/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or
182/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`
183/// if not specified.
184/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is
185/// over the last `len(s)` axes are or all axes if `s` is also `None`.
186#[generate_macro(customize(root = "$crate::fft"))]
187#[default_device]
188pub fn ifftn_device<'a>(
189    a: impl AsRef<Array>,
190    #[optional] s: impl IntoOption<&'a [i32]>,
191    #[optional] axes: impl IntoOption<&'a [i32]>,
192    #[optional] stream: impl AsRef<Stream>,
193) -> Result<Array> {
194    let a = a.as_ref();
195    let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), axes.into_option());
196    let num_s = s.len();
197    let num_axes = axes.len();
198
199    let s_ptr = s.as_ptr();
200    let axes_ptr = axes.as_ptr();
201
202    Array::try_from_op(|res| unsafe {
203        mlx_sys::mlx_fft_ifftn(
204            res,
205            a.as_ptr(),
206            s_ptr,
207            num_s,
208            axes_ptr,
209            num_axes,
210            stream.as_ref().as_ptr(),
211        )
212    })
213}
214
#[cfg(test)]
mod tests {
    use crate::{complex64, fft::*, Array, Dtype};

    /// Lifts real samples into complex values with a zero imaginary part, which is the form the
    /// round-trip (forward then inverse) results are compared against.
    fn as_complex(samples: &[f32]) -> Vec<complex64> {
        samples.iter().map(|&re| complex64::new(re, 0.0)).collect()
    }

    #[test]
    fn test_fft() {
        const FFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const FFT_SHAPE: &[i32] = &[4];
        const FFT_EXPECTED: &[complex64; 4] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
            complex64::new(-2.0, -2.0),
        ];

        let input = Array::from_slice(FFT_DATA, FFT_SHAPE);

        // Forward transform with default size and axis.
        let spectrum = fft(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFT_EXPECTED);

        // The inverse transform must give back the original samples.
        let restored = ifft(&spectrum, None, None).unwrap();
        assert_eq!(restored.dtype(), Dtype::Complex64);
        assert_eq!(restored.as_slice::<complex64>(), as_complex(FFT_DATA));

        // The input array is left untouched and still readable.
        assert_eq!(input.as_slice::<f32>(), FFT_DATA);
    }

    #[test]
    fn test_fft2() {
        const FFT2_DATA: &[f32] = &[1.0, 1.0, 1.0, 1.0];
        const FFT2_SHAPE: &[i32] = &[2, 2];
        const FFT2_EXPECTED: &[complex64; 4] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let input = Array::from_slice(FFT2_DATA, FFT2_SHAPE);

        // Forward transform with default sizes and axes.
        let spectrum = fft2(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFT2_EXPECTED);

        // The inverse transform must give back the original samples.
        let restored = ifft2(&spectrum, None, None).unwrap();
        assert_eq!(restored.dtype(), Dtype::Complex64);
        assert_eq!(restored.as_slice::<complex64>(), as_complex(FFT2_DATA));

        // The input array is left untouched and still readable.
        assert_eq!(input.as_slice::<f32>(), FFT2_DATA);
    }

    #[test]
    fn test_fftn() {
        const FFTN_DATA: &[f32] = &[1.0; 8];
        const FFTN_SHAPE: &[i32] = &[2, 2, 2];
        const FFTN_EXPECTED: &[complex64; 8] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let input = Array::from_slice(FFTN_DATA, FFTN_SHAPE);

        // Forward transform with default sizes and axes.
        let spectrum = fftn(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFTN_EXPECTED);

        // Inverse transform with explicit sizes and axes this time.
        let restored = ifftn(&spectrum, FFTN_SHAPE, &[0, 1, 2]).unwrap();
        assert_eq!(restored.dtype(), Dtype::Complex64);
        assert_eq!(restored.as_slice::<complex64>(), as_complex(FFTN_DATA));

        // The input array is left untouched and still readable.
        assert_eq!(input.as_slice::<f32>(), FFTN_DATA);
    }
}