1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{
4 array::Array,
5 error::Result,
6 utils::{guard::Guarded, IntoOption},
7 Stream,
8};
9
10use super::utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked};
11
12#[generate_macro(customize(root = "$crate::fft"))]
21#[default_device]
22pub fn fft_device(
23 a: impl AsRef<Array>,
24 #[optional] n: impl Into<Option<i32>>,
25 #[optional] axis: impl Into<Option<i32>>,
26 #[optional] stream: impl AsRef<Stream>,
27) -> Result<Array> {
28 let a = a.as_ref();
29 let (n, axis) = resolve_size_and_axis_unchecked(a, n.into(), axis.into());
30 Array::try_from_op(|res| unsafe {
31 mlx_sys::mlx_fft_fft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
32 })
33}
34
35#[generate_macro(customize(root = "$crate::fft"))]
44#[default_device]
45pub fn fft2_device<'a>(
46 a: impl AsRef<Array>,
47 #[optional] s: impl IntoOption<&'a [i32]>,
48 #[optional] axes: impl IntoOption<&'a [i32]>,
49 #[optional] stream: impl AsRef<Stream>,
50) -> Result<Array> {
51 let a = a.as_ref();
52 let axes = axes.into_option().unwrap_or(&[-2, -1]);
53 let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), Some(axes));
54
55 let num_s = s.len();
56 let num_axes = axes.len();
57
58 let s_ptr = s.as_ptr();
59 let axes_ptr = axes.as_ptr();
60
61 Array::try_from_op(|res| unsafe {
62 mlx_sys::mlx_fft_fft2(
63 res,
64 a.as_ptr(),
65 s_ptr,
66 num_s,
67 axes_ptr,
68 num_axes,
69 stream.as_ref().as_ptr(),
70 )
71 })
72}
73
74#[generate_macro(customize(root = "$crate::fft"))]
85#[default_device]
86pub fn fftn_device<'a>(
87 a: impl AsRef<Array>,
88 #[optional] s: impl IntoOption<&'a [i32]>,
89 #[optional] axes: impl IntoOption<&'a [i32]>,
90 #[optional] stream: impl AsRef<Stream>,
91) -> Result<Array> {
92 let a = a.as_ref();
93 let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), axes.into_option());
94 let num_s = s.len();
95 let num_axes = axes.len();
96
97 let s_ptr = s.as_ptr();
98 let axes_ptr = axes.as_ptr();
99
100 Array::try_from_op(|res| unsafe {
101 mlx_sys::mlx_fft_fftn(
102 res,
103 a.as_ptr(),
104 s_ptr,
105 num_s,
106 axes_ptr,
107 num_axes,
108 stream.as_ref().as_ptr(),
109 )
110 })
111}
112
113#[generate_macro(customize(root = "$crate::fft"))]
122#[default_device]
123pub fn ifft_device(
124 a: impl AsRef<Array>,
125 #[optional] n: impl Into<Option<i32>>,
126 #[optional] axis: impl Into<Option<i32>>,
127 #[optional] stream: impl AsRef<Stream>,
128) -> Result<Array> {
129 let a = a.as_ref();
130 let (n, axis) = resolve_size_and_axis_unchecked(a, n.into(), axis.into());
131
132 Array::try_from_op(|res| unsafe {
133 mlx_sys::mlx_fft_ifft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
134 })
135}
136
137#[generate_macro(customize(root = "$crate::fft"))]
146#[default_device]
147pub fn ifft2_device<'a>(
148 a: impl AsRef<Array>,
149 #[optional] s: impl IntoOption<&'a [i32]>,
150 #[optional] axes: impl IntoOption<&'a [i32]>,
151 #[optional] stream: impl AsRef<Stream>,
152) -> Result<Array> {
153 let a = a.as_ref();
154 let axes = axes.into_option().unwrap_or(&[-2, -1]);
155 let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), Some(axes));
156
157 let num_s = s.len();
158 let num_axes = axes.len();
159
160 let s_ptr = s.as_ptr();
161 let axes_ptr = axes.as_ptr();
162
163 Array::try_from_op(|res| unsafe {
164 mlx_sys::mlx_fft_ifft2(
165 res,
166 a.as_ptr(),
167 s_ptr,
168 num_s,
169 axes_ptr,
170 num_axes,
171 stream.as_ref().as_ptr(),
172 )
173 })
174}
175
176#[generate_macro(customize(root = "$crate::fft"))]
187#[default_device]
188pub fn ifftn_device<'a>(
189 a: impl AsRef<Array>,
190 #[optional] s: impl IntoOption<&'a [i32]>,
191 #[optional] axes: impl IntoOption<&'a [i32]>,
192 #[optional] stream: impl AsRef<Stream>,
193) -> Result<Array> {
194 let a = a.as_ref();
195 let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), axes.into_option());
196 let num_s = s.len();
197 let num_axes = axes.len();
198
199 let s_ptr = s.as_ptr();
200 let axes_ptr = axes.as_ptr();
201
202 Array::try_from_op(|res| unsafe {
203 mlx_sys::mlx_fft_ifftn(
204 res,
205 a.as_ptr(),
206 s_ptr,
207 num_s,
208 axes_ptr,
209 num_axes,
210 stream.as_ref().as_ptr(),
211 )
212 })
213}
214
#[cfg(test)]
mod tests {
    use crate::{complex64, fft::*, Array, Dtype};

    /// Widens real samples into complex values with a zero imaginary part: the values an
    /// inverse transform is expected to hand back after a forward transform of real input.
    fn to_complex(samples: &[f32]) -> Vec<complex64> {
        samples
            .iter()
            .copied()
            .map(|re| complex64::new(re, 0.0))
            .collect()
    }

    #[test]
    fn test_fft() {
        const FFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const FFT_SHAPE: &[i32] = &[4];
        const FFT_EXPECTED: &[complex64; 4] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
            complex64::new(-2.0, -2.0),
        ];

        let input = Array::from_slice(FFT_DATA, FFT_SHAPE);

        // Forward transform of real data yields a complex spectrum.
        let spectrum = fft(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFT_EXPECTED);

        // The inverse transform recovers the original samples (as complex numbers).
        let roundtrip = ifft(&spectrum, None, None).unwrap();
        assert_eq!(roundtrip.dtype(), Dtype::Complex64);
        assert_eq!(roundtrip.as_slice::<complex64>(), to_complex(FFT_DATA));

        // The input array was only borrowed and must still hold its original data.
        let untouched: &[f32] = input.as_slice();
        assert_eq!(untouched, FFT_DATA);
    }

    #[test]
    fn test_fft2() {
        const FFT2_DATA: &[f32] = &[1.0, 1.0, 1.0, 1.0];
        const FFT2_SHAPE: &[i32] = &[2, 2];
        const FFT2_EXPECTED: &[complex64; 4] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let input = Array::from_slice(FFT2_DATA, FFT2_SHAPE);

        // A constant 2x2 signal concentrates all energy in the DC component.
        let spectrum = fft2(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFT2_EXPECTED);

        let roundtrip = ifft2(&spectrum, None, None).unwrap();
        assert_eq!(roundtrip.dtype(), Dtype::Complex64);
        assert_eq!(roundtrip.as_slice::<complex64>(), to_complex(FFT2_DATA));

        let untouched: &[f32] = input.as_slice();
        assert_eq!(untouched, FFT2_DATA);
    }

    #[test]
    fn test_fftn() {
        const FFTN_DATA: &[f32] = &[1.0; 8];
        const FFTN_SHAPE: &[i32] = &[2, 2, 2];
        const FFTN_EXPECTED: &[complex64; 8] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let input = Array::from_slice(FFTN_DATA, FFTN_SHAPE);

        let spectrum = fftn(&input, None, None).unwrap();
        assert_eq!(spectrum.dtype(), Dtype::Complex64);
        assert_eq!(spectrum.as_slice::<complex64>(), FFTN_EXPECTED);

        // Exercise the explicit `s` / `axes` arguments on the way back.
        let roundtrip = ifftn(&spectrum, FFTN_SHAPE, &[0, 1, 2]).unwrap();
        assert_eq!(roundtrip.dtype(), Dtype::Complex64);
        assert_eq!(roundtrip.as_slice::<complex64>(), to_complex(FFTN_DATA));

        let untouched: &[f32] = input.as_slice();
        assert_eq!(untouched, FFTN_DATA);
    }
}