1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{
4 array::Array,
5 error::Result,
6 stream::StreamOrDevice,
7 utils::{guard::Guarded, IntoOption},
8 Stream,
9};
10
11use super::{
12 as_complex64,
13 utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked},
14};
15
16#[generate_macro(customize(root = "$crate::fft"))]
25#[default_device]
26pub fn fft_device(
27 a: impl AsRef<Array>,
28 #[optional] n: impl Into<Option<i32>>,
29 #[optional] axis: impl Into<Option<i32>>,
30 #[optional] stream: impl AsRef<Stream>,
31) -> Result<Array> {
32 let a = as_complex64(a.as_ref())?;
33
34 let (n, axis) = resolve_size_and_axis_unchecked(&a, n.into(), axis.into());
35 Array::try_from_op(|res| unsafe {
36 mlx_sys::mlx_fft_fft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
37 })
38}
39
40#[generate_macro(customize(root = "$crate::fft"))]
49#[default_device]
50pub fn fft2_device<'a>(
51 a: impl AsRef<Array>,
52 #[optional] s: impl IntoOption<&'a [i32]>,
53 #[optional] axes: impl IntoOption<&'a [i32]>,
54 #[optional] stream: impl AsRef<Stream>,
55) -> Result<Array> {
56 let a = as_complex64(a.as_ref())?;
57 let axes = axes.into_option().unwrap_or(&[-2, -1]);
58 let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), Some(axes));
59
60 let num_s = s.len();
61 let num_axes = axes.len();
62
63 let s_ptr = s.as_ptr();
64 let axes_ptr = axes.as_ptr();
65
66 Array::try_from_op(|res| unsafe {
67 mlx_sys::mlx_fft_fft2(
68 res,
69 a.as_ptr(),
70 s_ptr,
71 num_s,
72 axes_ptr,
73 num_axes,
74 stream.as_ref().as_ptr(),
75 )
76 })
77}
78
79#[generate_macro(customize(root = "$crate::fft"))]
90#[default_device]
91pub fn fftn_device<'a>(
92 a: impl AsRef<Array>,
93 #[optional] s: impl IntoOption<&'a [i32]>,
94 #[optional] axes: impl IntoOption<&'a [i32]>,
95 #[optional] stream: impl AsRef<Stream>,
96) -> Result<Array> {
97 let a = as_complex64(a.as_ref())?;
98 let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), axes.into_option());
99 let num_s = s.len();
100 let num_axes = axes.len();
101
102 let s_ptr = s.as_ptr();
103 let axes_ptr = axes.as_ptr();
104
105 Array::try_from_op(|res| unsafe {
106 mlx_sys::mlx_fft_fftn(
107 res,
108 a.as_ptr(),
109 s_ptr,
110 num_s,
111 axes_ptr,
112 num_axes,
113 stream.as_ref().as_ptr(),
114 )
115 })
116}
117
118#[generate_macro(customize(root = "$crate::fft"))]
127#[default_device]
128pub fn ifft_device(
129 a: impl AsRef<Array>,
130 #[optional] n: impl Into<Option<i32>>,
131 #[optional] axis: impl Into<Option<i32>>,
132 #[optional] stream: impl AsRef<Stream>,
133) -> Result<Array> {
134 let a = as_complex64(a.as_ref())?;
135 let (n, axis) = resolve_size_and_axis_unchecked(&a, n.into(), axis.into());
136
137 Array::try_from_op(|res| unsafe {
138 mlx_sys::mlx_fft_ifft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
139 })
140}
141
142#[generate_macro(customize(root = "$crate::fft"))]
151#[default_device]
152pub fn ifft2_device<'a>(
153 a: impl AsRef<Array>,
154 #[optional] s: impl IntoOption<&'a [i32]>,
155 #[optional] axes: impl IntoOption<&'a [i32]>,
156 #[optional] stream: impl AsRef<Stream>,
157) -> Result<Array> {
158 let a = as_complex64(a.as_ref())?;
159 let axes = axes.into_option().unwrap_or(&[-2, -1]);
160 let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), Some(axes));
161
162 let num_s = s.len();
163 let num_axes = axes.len();
164
165 let s_ptr = s.as_ptr();
166 let axes_ptr = axes.as_ptr();
167
168 Array::try_from_op(|res| unsafe {
169 mlx_sys::mlx_fft_ifft2(
170 res,
171 a.as_ptr(),
172 s_ptr,
173 num_s,
174 axes_ptr,
175 num_axes,
176 stream.as_ref().as_ptr(),
177 )
178 })
179}
180
181#[generate_macro(customize(root = "$crate::fft"))]
192#[default_device]
193pub fn ifftn_device<'a>(
194 a: impl AsRef<Array>,
195 #[optional] s: impl IntoOption<&'a [i32]>,
196 #[optional] axes: impl IntoOption<&'a [i32]>,
197 #[optional] stream: impl AsRef<Stream>,
198) -> Result<Array> {
199 let a = as_complex64(a.as_ref())?;
200 let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), axes.into_option());
201 let num_s = s.len();
202 let num_axes = axes.len();
203
204 let s_ptr = s.as_ptr();
205 let axes_ptr = axes.as_ptr();
206
207 Array::try_from_op(|res| unsafe {
208 mlx_sys::mlx_fft_ifftn(
209 res,
210 a.as_ptr(),
211 s_ptr,
212 num_s,
213 axes_ptr,
214 num_axes,
215 stream.as_ref().as_ptr(),
216 )
217 })
218}
219
#[cfg(test)]
mod tests {
    use crate::{complex64, fft::*, Array, Dtype};

    /// Lifts real samples to complex values with a zero imaginary part; this is
    /// what an inverse transform of a forward transform should give back.
    fn real_to_complex(samples: &[f32]) -> Vec<complex64> {
        samples.iter().map(|&re| complex64::new(re, 0.0)).collect()
    }

    #[test]
    fn test_fft() {
        const FFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const FFT_SHAPE: &[i32] = &[4];
        const FFT_EXPECTED: &[complex64; 4] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
            complex64::new(-2.0, -2.0),
        ];

        let input = Array::from_slice(FFT_DATA, FFT_SHAPE);

        let spectrum = fft(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFT_EXPECTED);

        let recovered = ifft(&spectrum, None, None).unwrap();
        assert_eq!(recovered.dtype(), Dtype::Complex64);
        assert_eq!(recovered.as_slice::<complex64>(), real_to_complex(FFT_DATA));

        // The forward transform must leave its input array untouched.
        let untouched: &[f32] = input.as_slice();
        assert_eq!(untouched, FFT_DATA);
    }

    #[test]
    fn test_fft2() {
        const FFT2_DATA: &[f32] = &[1.0, 1.0, 1.0, 1.0];
        const FFT2_SHAPE: &[i32] = &[2, 2];
        const FFT2_EXPECTED: &[complex64; 4] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let input = Array::from_slice(FFT2_DATA, FFT2_SHAPE);

        let spectrum = fft2(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFT2_EXPECTED);

        let recovered = ifft2(&spectrum, None, None).unwrap();
        assert_eq!(recovered.dtype(), Dtype::Complex64);
        assert_eq!(recovered.as_slice::<complex64>(), real_to_complex(FFT2_DATA));

        // The forward transform must leave its input array untouched.
        let untouched: &[f32] = input.as_slice();
        assert_eq!(untouched, FFT2_DATA);
    }

    #[test]
    fn test_fftn() {
        const FFTN_DATA: &[f32] = &[1.0; 8];
        const FFTN_SHAPE: &[i32] = &[2, 2, 2];
        const FFTN_EXPECTED: &[complex64; 8] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let input = Array::from_slice(FFTN_DATA, FFTN_SHAPE);

        let spectrum = fftn(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFTN_EXPECTED);

        // Explicit sizes and axes exercise the non-default argument path.
        let recovered = ifftn(&spectrum, FFTN_SHAPE, &[0, 1, 2]).unwrap();
        assert_eq!(recovered.dtype(), Dtype::Complex64);
        assert_eq!(recovered.as_slice::<complex64>(), real_to_complex(FFTN_DATA));

        // The forward transform must leave its input array untouched.
        let untouched: &[f32] = input.as_slice();
        assert_eq!(untouched, FFTN_DATA);
    }
}