1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{
4 error::Result,
5 utils::{guard::Guarded, IntoOption},
6 Array, Stream,
7};
8
9use super::utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked};
10
11#[generate_macro(customize(root = "$crate::fft"))]
23#[default_device]
24pub fn rfft_device(
25 a: impl AsRef<Array>,
26 #[optional] n: impl Into<Option<i32>>,
27 #[optional] axis: impl Into<Option<i32>>,
28 #[optional] stream: impl AsRef<Stream>,
29) -> Result<Array> {
30 let a = a.as_ref();
31 let (n, axis) = resolve_size_and_axis_unchecked(a, n.into(), axis.into());
32 Array::try_from_op(|res| unsafe {
33 mlx_sys::mlx_fft_rfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
34 })
35}
36
37#[generate_macro(customize(root = "$crate::fft"))]
50#[default_device]
51pub fn rfft2_device<'a>(
52 a: impl AsRef<Array>,
53 #[optional] s: impl IntoOption<&'a [i32]>,
54 #[optional] axes: impl IntoOption<&'a [i32]>,
55 #[optional] stream: impl AsRef<Stream>,
56) -> Result<Array> {
57 let a = a.as_ref();
58 let axes = axes.into_option().unwrap_or(&[-2, -1]);
59 let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), Some(axes));
60
61 let num_s = s.len();
62 let num_axes = axes.len();
63
64 let s_ptr = s.as_ptr();
65 let axes_ptr = axes.as_ptr();
66
67 Array::try_from_op(|res| unsafe {
68 mlx_sys::mlx_fft_rfft2(
69 res,
70 a.as_ptr(),
71 s_ptr,
72 num_s,
73 axes_ptr,
74 num_axes,
75 stream.as_ref().as_ptr(),
76 )
77 })
78}
79
80#[generate_macro(customize(root = "$crate::fft"))]
94#[default_device]
95pub fn rfftn_device<'a>(
96 a: impl AsRef<Array>,
97 #[optional] s: impl IntoOption<&'a [i32]>,
98 #[optional] axes: impl IntoOption<&'a [i32]>,
99 #[optional] stream: impl AsRef<Stream>,
100) -> Result<Array> {
101 let a = a.as_ref();
102 let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), axes.into_option());
103
104 let num_s = s.len();
105 let num_axes = axes.len();
106
107 let s_ptr = s.as_ptr();
108 let axes_ptr = axes.as_ptr();
109
110 Array::try_from_op(|res| unsafe {
111 mlx_sys::mlx_fft_rfftn(
112 res,
113 a.as_ptr(),
114 s_ptr,
115 num_s,
116 axes_ptr,
117 num_axes,
118 stream.as_ref().as_ptr(),
119 )
120 })
121}
122
123#[generate_macro(customize(root = "$crate::fft"))]
134#[default_device]
135pub fn irfft_device(
136 a: impl AsRef<Array>,
137 #[optional] n: impl Into<Option<i32>>,
138 #[optional] axis: impl Into<Option<i32>>,
139 #[optional] stream: impl AsRef<Stream>,
140) -> Result<Array> {
141 let a = a.as_ref();
142 let n = n.into();
143 let axis = axis.into();
144 let modify_n = n.is_none();
145 let (mut n, axis) = resolve_size_and_axis_unchecked(a, n, axis);
146 if modify_n {
147 n = (n - 1) * 2;
148 }
149
150 Array::try_from_op(|res| unsafe {
151 mlx_sys::mlx_fft_irfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
152 })
153}
154
155#[generate_macro(customize(root = "$crate::fft"))]
169#[default_device]
170pub fn irfft2_device<'a>(
171 a: impl AsRef<Array>,
172 #[optional] s: impl IntoOption<&'a [i32]>,
173 #[optional] axes: impl IntoOption<&'a [i32]>,
174 #[optional] stream: impl AsRef<Stream>,
175) -> Result<Array> {
176 let a = a.as_ref();
177 let s = s.into_option();
178 let axes = axes.into_option().unwrap_or(&[-2, -1]);
179 let modify_last_axis = s.is_none();
180
181 let (mut s, axes) = resolve_sizes_and_axes_unchecked(a, s, Some(axes));
182 if modify_last_axis {
183 let end = s.len() - 1;
184 s[end] = (s[end] - 1) * 2;
185 }
186
187 let num_s = s.len();
188 let num_axes = axes.len();
189
190 let s_ptr = s.as_ptr();
191 let axes_ptr = axes.as_ptr();
192
193 Array::try_from_op(|res| unsafe {
194 mlx_sys::mlx_fft_irfft2(
195 res,
196 a.as_ptr(),
197 s_ptr,
198 num_s,
199 axes_ptr,
200 num_axes,
201 stream.as_ref().as_ptr(),
202 )
203 })
204}
205
206#[generate_macro(customize(root = "$crate::fft"))]
221#[default_device]
222pub fn irfftn_device<'a>(
223 a: impl AsRef<Array>,
224 #[optional] s: impl IntoOption<&'a [i32]>,
225 #[optional] axes: impl IntoOption<&'a [i32]>,
226 #[optional] stream: impl AsRef<Stream>,
227) -> Result<Array> {
228 let a = a.as_ref();
229 let s = s.into_option();
230 let axes = axes.into_option();
231 let modify_last_axis = s.is_none();
232
233 let (mut s, axes) = resolve_sizes_and_axes_unchecked(a, s, axes);
234 if modify_last_axis {
235 let end = s.len() - 1;
236 s[end] = (s[end] - 1) * 2;
237 }
238
239 let num_s = s.len();
240 let num_axes = axes.len();
241
242 let s_ptr = s.as_ptr();
243 let axes_ptr = axes.as_ptr();
244
245 Array::try_from_op(|res| unsafe {
246 mlx_sys::mlx_fft_irfftn(
247 res,
248 a.as_ptr(),
249 s_ptr,
250 num_s,
251 axes_ptr,
252 num_axes,
253 stream.as_ref().as_ptr(),
254 )
255 })
256}
257
#[cfg(test)]
mod tests {
    // Round-trip values and default-argument output shapes for the real-input FFTs.
    use crate::{complex64, Array, Dtype};

    #[test]
    fn test_rfft() {
        const RFFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const RFFT_N: i32 = 4;
        const RFFT_SHAPE: &[i32] = &[RFFT_N];
        const RFFT_AXIS: i32 = -1;
        const RFFT_EXPECTED: &[complex64] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
        ];

        let a = Array::from_slice(RFFT_DATA, RFFT_SHAPE);
        let rfft = super::rfft(&a, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(rfft.dtype(), Dtype::Complex64);
        assert_eq!(rfft.as_slice::<complex64>(), RFFT_EXPECTED);

        let irfft = super::irfft(&rfft, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(irfft.dtype(), Dtype::Float32);
        assert_eq!(irfft.as_slice::<f32>(), RFFT_DATA);
    }

    #[test]
    fn test_rfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = IN_N / 2 + 1;

        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let rfft = super::rfft(&a, None, None).unwrap();
        assert_eq!(rfft.shape(), &[OUT_N]);
    }

    #[test]
    fn test_irfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = (IN_N - 1) * 2;

        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let irfft = super::irfft(&a, None, None).unwrap();
        assert_eq!(irfft.shape(), &[OUT_N]);
    }

    #[test]
    fn test_rfft2() {
        const RFFT2_DATA: &[f32] = &[1.0; 4];
        const RFFT2_SHAPE: &[i32] = &[2, 2];
        const RFFT2_EXPECTED: &[complex64] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let a = Array::from_slice(RFFT2_DATA, RFFT2_SHAPE);
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.dtype(), Dtype::Complex64);
        assert_eq!(rfft2.as_slice::<complex64>(), RFFT2_EXPECTED);

        let irfft2 = super::irfft2(&rfft2, None, None).unwrap();
        assert_eq!(irfft2.dtype(), Dtype::Float32);
        assert_eq!(irfft2.as_slice::<f32>(), RFFT2_DATA);
    }

    #[test]
    fn test_rfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6 / 2 + 1];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_irfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, (6 - 1) * 2];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfft2 = super::irfft2(&a, None, None).unwrap();
        assert_eq!(irfft2.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_rfftn() {
        const RFFTN_DATA: &[f32] = &[1.0; 8];
        const RFFTN_SHAPE: &[i32] = &[2, 2, 2];
        const RFFTN_EXPECTED: &[complex64] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let a = Array::from_slice(RFFTN_DATA, RFFTN_SHAPE);
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.dtype(), Dtype::Complex64);
        assert_eq!(rfftn.as_slice::<complex64>(), RFFTN_EXPECTED);

        let irfftn = super::irfftn(&rfftn, None, None).unwrap();
        assert_eq!(irfftn.dtype(), Dtype::Float32);
        assert_eq!(irfftn.as_slice::<f32>(), RFFTN_DATA);
    }

    // Renamed from `test_fftn_shape_with_default_params`: this exercises `rfftn`.
    #[test]
    fn test_rfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, 6 / 2 + 1];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_irfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, (6 - 1) * 2];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfftn = super::irfftn(&a, None, None).unwrap();
        assert_eq!(irfftn.shape(), OUT_SHAPE);
    }
}
389}