mlx_rs/fft/mod.rs
1//! Fast Fourier Transform (FFT) and its inverse (IFFT) for one, two, and `N` dimensions.
2//!
//! As with other functions in `mlx-rs`, three variants are provided for each FFT function, and
//! each variant comes in two versions: one that uses the default `StreamOrDevice` and one that
//! takes a user-specified `StreamOrDevice`.
6//!
//! The differences are explained below using `fftn` as an example:
8//!
//! 1. `fftn_unchecked`/`fftn_device_unchecked`: This function is simply a wrapper around the C API
//! and does not perform any checks on the input. It may panic or encounter a fatal error that
//! cannot be caught by the Rust runtime if the input is invalid.
12//! 2. `try_fftn`/`try_fftn_device`: This function performs checks on the input and returns a
13//! `Result` instead of panicking.
14//! 3. `fftn`/`fftn_device`: This function is a wrapper around `try_fftn` and unwraps the result. It
15//! panics if the input is invalid.
16//!
//! The functions that contain `device` in their name are meant to be used with a user-specified
//! `StreamOrDevice`. If you don't care about the stream, you can use the functions without `device`
//! in their names. Please note that GPU device support is not yet implemented.
20//!
21//! # Examples
22//!
23//! ## One dimension
24//!
25//! ```rust
26//! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*};
27//!
28//! let src = [1.0f32, 2.0, 3.0, 4.0];
29//! let mut array = Array::from_slice(&src[..], &[4]);
30//!
31//! let mut fft_result = fft(&array, 4, 0).unwrap();
32//! assert_eq!(fft_result.dtype(), Dtype::Complex64);
33//!
34//! let expected = &[
35//! complex64::new(10.0, 0.0),
36//! complex64::new(-2.0, 2.0),
37//! complex64::new(-2.0, 0.0),
38//! complex64::new(-2.0, -2.0),
39//! ];
40//! assert_eq!(fft_result.as_slice::<complex64>(), &expected[..]);
41//!
42//! let mut ifft_result = ifft(&fft_result, 4, 0).unwrap();
43//! assert_eq!(ifft_result.dtype(), Dtype::Complex64);
44//!
45//! let expected = &[
46//! complex64::new(1.0, 0.0),
47//! complex64::new(2.0, 0.0),
48//! complex64::new(3.0, 0.0),
49//! complex64::new(4.0, 0.0),
50//! ];
51//! assert_eq!(ifft_result.as_slice::<complex64>(), &expected[..]);
52//!
53//! let mut rfft_result = rfft(&array, 4, 0).unwrap();
54//! assert_eq!(rfft_result.dtype(), Dtype::Complex64);
55//!
56//! let expected = &[
57//! complex64::new(10.0, 0.0),
58//! complex64::new(-2.0, 2.0),
59//! complex64::new(-2.0, 0.0),
60//! ];
61//! assert_eq!(rfft_result.as_slice::<complex64>(), &expected[..]);
62//!
63//! let mut irfft_result = irfft(&rfft_result, 4, 0).unwrap();
64//! assert_eq!(irfft_result.dtype(), Dtype::Float32);
65//! assert_eq!(irfft_result.as_slice::<f32>(), &src[..]);
66//!
67//! // The original array is not modified
68//! let data: &[f32] = array.as_slice();
69//! assert_eq!(data, &src[..]);
70//! ```
71//!
72//! ## Two dimensions
73//!
74//! ```rust
75//! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*};
76//!
77//! let src = [1.0f32, 1.0, 1.0, 1.0];
78//! let mut array = Array::from_slice(&src[..], &[2, 2]);
79//!
80//! let mut fft2_result = fft2(&array, None, None).unwrap();
81//! assert_eq!(fft2_result.dtype(), Dtype::Complex64);
82//! let expected = &[
83//! complex64::new(4.0, 0.0),
84//! complex64::new(0.0, 0.0),
85//! complex64::new(0.0, 0.0),
86//! complex64::new(0.0, 0.0),
87//! ];
88//! assert_eq!(fft2_result.as_slice::<complex64>(), &expected[..]);
89//!
90//! let mut ifft2_result = ifft2(&fft2_result, None, None).unwrap();
91//! assert_eq!(ifft2_result.dtype(), Dtype::Complex64);
92//!
93//! let expected = &[
94//! complex64::new(1.0, 0.0),
95//! complex64::new(1.0, 0.0),
96//! complex64::new(1.0, 0.0),
97//! complex64::new(1.0, 0.0),
98//! ];
99//! assert_eq!(ifft2_result.as_slice::<complex64>(), &expected[..]);
100//!
101//! let mut rfft2_result = rfft2(&array, None, None).unwrap();
102//! assert_eq!(rfft2_result.dtype(), Dtype::Complex64);
103//!
104//! let expected = &[
105//! complex64::new(4.0, 0.0),
106//! complex64::new(0.0, 0.0),
107//! complex64::new(0.0, 0.0),
108//! complex64::new(0.0, 0.0),
109//! ];
110//! assert_eq!(rfft2_result.as_slice::<complex64>(), &expected[..]);
111//!
112//! let mut irfft2_result = irfft2(&rfft2_result, None, None).unwrap();
113//! assert_eq!(irfft2_result.dtype(), Dtype::Float32);
114//! assert_eq!(irfft2_result.as_slice::<f32>(), &src[..]);
115//!
116//! // The original array is not modified
117//! let data: &[f32] = array.as_slice();
118//! assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]);
119//! ```
120//!
121//! ## `N` dimensions
122//!
123//! ```rust
124//! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*};
125//!
126//! let mut array = Array::ones::<f32>(&[2, 2, 2]).unwrap();
127//! let mut fftn_result = fftn(&array, None, None).unwrap();
128//! assert_eq!(fftn_result.dtype(), Dtype::Complex64);
129//!
130//! let mut expected = [complex64::new(0.0, 0.0); 8];
131//! expected[0] = complex64::new(8.0, 0.0);
132//! assert_eq!(fftn_result.as_slice::<complex64>(), &expected[..]);
133//!
134//! let mut ifftn_result = ifftn(&fftn_result, None, None).unwrap();
135//! assert_eq!(ifftn_result.dtype(), Dtype::Complex64);
136//!
137//! let expected = [complex64::new(1.0, 0.0); 8];
138//! assert_eq!(ifftn_result.as_slice::<complex64>(), &expected[..]);
139//!
140//! let mut rfftn_result = rfftn(&array, None, None).unwrap();
141//! assert_eq!(rfftn_result.dtype(), Dtype::Complex64);
142//!
143//! let mut expected = [complex64::new(0.0, 0.0); 8];
144//! expected[0] = complex64::new(8.0, 0.0);
145//! assert_eq!(rfftn_result.as_slice::<complex64>(), &expected[..]);
146//!
147//! let mut irfftn_result = irfftn(&rfftn_result, None, None).unwrap();
148//! assert_eq!(irfftn_result.dtype(), Dtype::Float32);
149//!
150//! let expected = [1.0; 8];
151//! assert_eq!(irfftn_result.as_slice::<f32>(), &expected[..]);
152//!
153//! // The original array is not modified
154//! let data: &[f32] = array.as_slice();
155//! assert_eq!(data, &[1.0; 8]);
156//! ```
157
158mod fftn;
159mod rfftn;
160mod utils;
161
162pub use self::{fftn::*, rfftn::*};
163
164/* -------------------------------------------------------------------------- */
165/* Helper functions */
166/* -------------------------------------------------------------------------- */
167
168use crate::{complex64, error::Exception, Array, Dtype};
169use std::borrow::Cow;
170
171fn as_complex64(src: &Array) -> Result<Cow<'_, Array>, Exception> {
172 match src.dtype() {
173 Dtype::Complex64 => Ok(Cow::Borrowed(src)),
174 _ => {
175 let new_array = src.as_type::<complex64>()?;
176 new_array.eval()?;
177 Ok(Cow::Owned(new_array))
178 }
179 }
180}