1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{
4 error::Result,
5 utils::{guard::Guarded, IntoOption},
6 Array, Stream, StreamOrDevice,
7};
8
9use super::{
10 as_complex64,
11 utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked},
12};
13
14#[generate_macro(customize(root = "$crate::fft"))]
26#[default_device]
27pub fn rfft_device(
28 a: impl AsRef<Array>,
29 #[optional] n: impl Into<Option<i32>>,
30 #[optional] axis: impl Into<Option<i32>>,
31 #[optional] stream: impl AsRef<Stream>,
32) -> Result<Array> {
33 let a = as_complex64(a.as_ref())?;
34 let (n, axis) = resolve_size_and_axis_unchecked(&a, n.into(), axis.into());
35 Array::try_from_op(|res| unsafe {
36 mlx_sys::mlx_fft_rfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
37 })
38}
39
40#[generate_macro(customize(root = "$crate::fft"))]
53#[default_device]
54pub fn rfft2_device<'a>(
55 a: impl AsRef<Array>,
56 #[optional] s: impl IntoOption<&'a [i32]>,
57 #[optional] axes: impl IntoOption<&'a [i32]>,
58 #[optional] stream: impl AsRef<Stream>,
59) -> Result<Array> {
60 let a = as_complex64(a.as_ref())?;
61 let axes = axes.into_option().unwrap_or(&[-2, -1]);
62 let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), Some(axes));
63
64 let num_s = s.len();
65 let num_axes = axes.len();
66
67 let s_ptr = s.as_ptr();
68 let axes_ptr = axes.as_ptr();
69
70 Array::try_from_op(|res| unsafe {
71 mlx_sys::mlx_fft_rfft2(
72 res,
73 a.as_ptr(),
74 s_ptr,
75 num_s,
76 axes_ptr,
77 num_axes,
78 stream.as_ref().as_ptr(),
79 )
80 })
81}
82
83#[generate_macro(customize(root = "$crate::fft"))]
97#[default_device]
98pub fn rfftn_device<'a>(
99 a: impl AsRef<Array>,
100 #[optional] s: impl IntoOption<&'a [i32]>,
101 #[optional] axes: impl IntoOption<&'a [i32]>,
102 #[optional] stream: impl AsRef<Stream>,
103) -> Result<Array> {
104 let a = as_complex64(a.as_ref())?;
105 let (s, axes) = resolve_sizes_and_axes_unchecked(&a, s.into_option(), axes.into_option());
106
107 let num_s = s.len();
108 let num_axes = axes.len();
109
110 let s_ptr = s.as_ptr();
111 let axes_ptr = axes.as_ptr();
112
113 Array::try_from_op(|res| unsafe {
114 mlx_sys::mlx_fft_rfftn(
115 res,
116 a.as_ptr(),
117 s_ptr,
118 num_s,
119 axes_ptr,
120 num_axes,
121 stream.as_ref().as_ptr(),
122 )
123 })
124}
125
126#[generate_macro(customize(root = "$crate::fft"))]
137#[default_device]
138pub fn irfft_device(
139 a: impl AsRef<Array>,
140 #[optional] n: impl Into<Option<i32>>,
141 #[optional] axis: impl Into<Option<i32>>,
142 #[optional] stream: impl AsRef<Stream>,
143) -> Result<Array> {
144 let a = as_complex64(a.as_ref())?;
145 let n = n.into();
146 let axis = axis.into();
147 let modify_n = n.is_none();
148 let (mut n, axis) = resolve_size_and_axis_unchecked(&a, n, axis);
149 if modify_n {
150 n = (n - 1) * 2;
151 }
152
153 Array::try_from_op(|res| unsafe {
154 mlx_sys::mlx_fft_irfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
155 })
156}
157
158#[generate_macro(customize(root = "$crate::fft"))]
172#[default_device]
173pub fn irfft2_device<'a>(
174 a: impl AsRef<Array>,
175 #[optional] s: impl IntoOption<&'a [i32]>,
176 #[optional] axes: impl IntoOption<&'a [i32]>,
177 #[optional] stream: impl AsRef<Stream>,
178) -> Result<Array> {
179 let a = as_complex64(a.as_ref())?;
180 let s = s.into_option();
181 let axes = axes.into_option().unwrap_or(&[-2, -1]);
182 let modify_last_axis = s.is_none();
183
184 let (mut s, axes) = resolve_sizes_and_axes_unchecked(&a, s, Some(axes));
185 if modify_last_axis {
186 let end = s.len() - 1;
187 s[end] = (s[end] - 1) * 2;
188 }
189
190 let num_s = s.len();
191 let num_axes = axes.len();
192
193 let s_ptr = s.as_ptr();
194 let axes_ptr = axes.as_ptr();
195
196 Array::try_from_op(|res| unsafe {
197 mlx_sys::mlx_fft_irfft2(
198 res,
199 a.as_ptr(),
200 s_ptr,
201 num_s,
202 axes_ptr,
203 num_axes,
204 stream.as_ref().as_ptr(),
205 )
206 })
207}
208
209#[generate_macro(customize(root = "$crate::fft"))]
224#[default_device]
225pub fn irfftn_device<'a>(
226 a: impl AsRef<Array>,
227 #[optional] s: impl IntoOption<&'a [i32]>,
228 #[optional] axes: impl IntoOption<&'a [i32]>,
229 #[optional] stream: impl AsRef<Stream>,
230) -> Result<Array> {
231 let a = as_complex64(a.as_ref())?;
232 let s = s.into_option();
233 let axes = axes.into_option();
234 let modify_last_axis = s.is_none();
235
236 let (mut s, axes) = resolve_sizes_and_axes_unchecked(&a, s, axes);
237 if modify_last_axis {
238 let end = s.len() - 1;
239 s[end] = (s[end] - 1) * 2;
240 }
241
242 let num_s = s.len();
243 let num_axes = axes.len();
244
245 let s_ptr = s.as_ptr();
246 let axes_ptr = axes.as_ptr();
247
248 Array::try_from_op(|res| unsafe {
249 mlx_sys::mlx_fft_irfftn(
250 res,
251 a.as_ptr(),
252 s_ptr,
253 num_s,
254 axes_ptr,
255 num_axes,
256 stream.as_ref().as_ptr(),
257 )
258 })
259}
260
#[cfg(test)]
mod tests {
    use crate::{complex64, Array, Dtype};

    /// Round-trips a length-4 real signal through `rfft` and `irfft` with explicit `n`/`axis`.
    #[test]
    fn test_rfft() {
        const RFFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const RFFT_N: i32 = 4;
        const RFFT_SHAPE: &[i32] = &[RFFT_N];
        const RFFT_AXIS: i32 = -1;
        // The full DFT of [1, 2, 3, 4] is [10, -2+2i, -2, -2-2i]; only the first
        // `RFFT_N / 2 + 1` entries are expected back.
        const RFFT_EXPECTED: &[complex64] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
        ];

        let a = Array::from_slice(RFFT_DATA, RFFT_SHAPE);
        let rfft = super::rfft(&a, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(rfft.dtype(), Dtype::Complex64);
        assert_eq!(rfft.as_slice::<complex64>(), RFFT_EXPECTED);

        let irfft = super::irfft(&rfft, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(irfft.dtype(), Dtype::Float32);
        assert_eq!(irfft.as_slice::<f32>(), RFFT_DATA);
    }

    /// With default `n`/`axis`, `rfft` maps a length-`n` axis to `n / 2 + 1`.
    #[test]
    fn test_rfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = IN_N / 2 + 1;

        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let rfft = super::rfft(&a, None, None).unwrap();
        assert_eq!(rfft.shape(), &[OUT_N]);
    }

    /// With default `n`/`axis`, `irfft` maps a length-`m` axis to `(m - 1) * 2`.
    #[test]
    fn test_irfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = (IN_N - 1) * 2;

        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let irfft = super::irfft(&a, None, None).unwrap();
        assert_eq!(irfft.shape(), &[OUT_N]);
    }

    /// Round-trips a constant 2x2 signal through `rfft2` and `irfft2` with default parameters.
    #[test]
    fn test_rfft2() {
        const RFFT2_DATA: &[f32] = &[1.0; 4];
        const RFFT2_SHAPE: &[i32] = &[2, 2];
        // A constant signal has all of its energy in the zero-frequency bin.
        const RFFT2_EXPECTED: &[complex64] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let a = Array::from_slice(RFFT2_DATA, RFFT2_SHAPE);
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.dtype(), Dtype::Complex64);
        assert_eq!(rfft2.as_slice::<complex64>(), RFFT2_EXPECTED);

        let irfft2 = super::irfft2(&rfft2, None, None).unwrap();
        assert_eq!(irfft2.dtype(), Dtype::Float32);
        assert_eq!(irfft2.as_slice::<f32>(), RFFT2_DATA);
    }

    /// With default parameters, `rfft2` only shrinks the last axis to `n / 2 + 1`.
    #[test]
    fn test_rfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6 / 2 + 1];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.shape(), OUT_SHAPE);
    }

    /// With default parameters, `irfft2` only expands the last axis to `(m - 1) * 2`.
    #[test]
    fn test_irfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, (6 - 1) * 2];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfft2 = super::irfft2(&a, None, None).unwrap();
        assert_eq!(irfft2.shape(), OUT_SHAPE);
    }

    /// Round-trips a constant 2x2x2 signal through `rfftn` and `irfftn` with default parameters.
    #[test]
    fn test_rfftn() {
        const RFFTN_DATA: &[f32] = &[1.0; 8];
        const RFFTN_SHAPE: &[i32] = &[2, 2, 2];
        // A constant signal has all of its energy in the zero-frequency bin.
        const RFFTN_EXPECTED: &[complex64] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];

        let a = Array::from_slice(RFFTN_DATA, RFFTN_SHAPE);
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.dtype(), Dtype::Complex64);
        assert_eq!(rfftn.as_slice::<complex64>(), RFFTN_EXPECTED);

        let irfftn = super::irfftn(&rfftn, None, None).unwrap();
        assert_eq!(irfftn.dtype(), Dtype::Float32);
        assert_eq!(irfftn.as_slice::<f32>(), RFFTN_DATA);
    }

    /// With default parameters, `rfftn` only shrinks the last axis to `n / 2 + 1`.
    #[test]
    fn test_rfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, 6 / 2 + 1];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.shape(), OUT_SHAPE);
    }

    /// With default parameters, `irfftn` only expands the last axis to `(m - 1) * 2`.
    #[test]
    fn test_irfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, (6 - 1) * 2];

        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfftn = super::irfftn(&a, None, None).unwrap();
        assert_eq!(irfftn.shape(), OUT_SHAPE);
    }
}