mlx_rs/fft/mod.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
//! Fast Fourier Transform (FFT) and its inverse (IFFT) for one, two, and `N` dimensions.
//!
//! Like all other functions in `mlx-rs`, three variants are provided for each FFT function, plus
//! each variant has a version that uses the default `StreamOrDevice` or takes a user-specified
//! `StreamOrDevice`.
//!
//! The difference are explained below using `fftn` as an example:
//!
//! 1. `fftn_unchecked`/`fftn_device_unchecked`: This function is simply a wrapper around the C API
//! and does not perform any checks on the input. It may panic or get an fatal error that cannot
//! be caught by the rust runtime if the input is invalid.
//! 2. `try_fftn`/`try_fftn_device`: This function performs checks on the input and returns a
//! `Result` instead of panicking.
//! 3. `fftn`/`fftn_device`: This function is a wrapper around `try_fftn` and unwraps the result. It
//! panics if the input is invalid.
//!
//! The functions that contains `device` in their name are meant to be used with a user-specified
//! `StreamOrDevice`. If you don't care about the stream, you can use the functions without `device`
//! in their names. Please note that GPU device support is not yet implemented.
//!
//! # Examples
//!
//! ## One dimension
//!
//! ```rust
//! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*};
//!
//! let src = [1.0f32, 2.0, 3.0, 4.0];
//! let mut array = Array::from_slice(&src[..], &[4]);
//!
//! let mut fft_result = fft(&array, 4, 0).unwrap();
//! assert_eq!(fft_result.dtype(), Dtype::Complex64);
//!
//! let expected = &[
//! complex64::new(10.0, 0.0),
//! complex64::new(-2.0, 2.0),
//! complex64::new(-2.0, 0.0),
//! complex64::new(-2.0, -2.0),
//! ];
//! assert_eq!(fft_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut ifft_result = ifft(&fft_result, 4, 0).unwrap();
//! assert_eq!(ifft_result.dtype(), Dtype::Complex64);
//!
//! let expected = &[
//! complex64::new(1.0, 0.0),
//! complex64::new(2.0, 0.0),
//! complex64::new(3.0, 0.0),
//! complex64::new(4.0, 0.0),
//! ];
//! assert_eq!(ifft_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut rfft_result = rfft(&array, 4, 0).unwrap();
//! assert_eq!(rfft_result.dtype(), Dtype::Complex64);
//!
//! let expected = &[
//! complex64::new(10.0, 0.0),
//! complex64::new(-2.0, 2.0),
//! complex64::new(-2.0, 0.0),
//! ];
//! assert_eq!(rfft_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut irfft_result = irfft(&rfft_result, 4, 0).unwrap();
//! assert_eq!(irfft_result.dtype(), Dtype::Float32);
//! assert_eq!(irfft_result.as_slice::<f32>(), &src[..]);
//!
//! // The original array is not modified
//! let data: &[f32] = array.as_slice();
//! assert_eq!(data, &src[..]);
//! ```
//!
//! ## Two dimensions
//!
//! ```rust
//! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*};
//!
//! let src = [1.0f32, 1.0, 1.0, 1.0];
//! let mut array = Array::from_slice(&src[..], &[2, 2]);
//!
//! let mut fft2_result = fft2(&array, None, None).unwrap();
//! assert_eq!(fft2_result.dtype(), Dtype::Complex64);
//! let expected = &[
//! complex64::new(4.0, 0.0),
//! complex64::new(0.0, 0.0),
//! complex64::new(0.0, 0.0),
//! complex64::new(0.0, 0.0),
//! ];
//! assert_eq!(fft2_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut ifft2_result = ifft2(&fft2_result, None, None).unwrap();
//! assert_eq!(ifft2_result.dtype(), Dtype::Complex64);
//!
//! let expected = &[
//! complex64::new(1.0, 0.0),
//! complex64::new(1.0, 0.0),
//! complex64::new(1.0, 0.0),
//! complex64::new(1.0, 0.0),
//! ];
//! assert_eq!(ifft2_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut rfft2_result = rfft2(&array, None, None).unwrap();
//! assert_eq!(rfft2_result.dtype(), Dtype::Complex64);
//!
//! let expected = &[
//! complex64::new(4.0, 0.0),
//! complex64::new(0.0, 0.0),
//! complex64::new(0.0, 0.0),
//! complex64::new(0.0, 0.0),
//! ];
//! assert_eq!(rfft2_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut irfft2_result = irfft2(&rfft2_result, None, None).unwrap();
//! assert_eq!(irfft2_result.dtype(), Dtype::Float32);
//! assert_eq!(irfft2_result.as_slice::<f32>(), &src[..]);
//!
//! // The original array is not modified
//! let data: &[f32] = array.as_slice();
//! assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]);
//! ```
//!
//! ## `N` dimensions
//!
//! ```rust
//! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*};
//!
//! let mut array = Array::ones::<f32>(&[2, 2, 2]).unwrap();
//! let mut fftn_result = fftn(&array, None, None).unwrap();
//! assert_eq!(fftn_result.dtype(), Dtype::Complex64);
//!
//! let mut expected = [complex64::new(0.0, 0.0); 8];
//! expected[0] = complex64::new(8.0, 0.0);
//! assert_eq!(fftn_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut ifftn_result = ifftn(&fftn_result, None, None).unwrap();
//! assert_eq!(ifftn_result.dtype(), Dtype::Complex64);
//!
//! let expected = [complex64::new(1.0, 0.0); 8];
//! assert_eq!(ifftn_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut rfftn_result = rfftn(&array, None, None).unwrap();
//! assert_eq!(rfftn_result.dtype(), Dtype::Complex64);
//!
//! let mut expected = [complex64::new(0.0, 0.0); 8];
//! expected[0] = complex64::new(8.0, 0.0);
//! assert_eq!(rfftn_result.as_slice::<complex64>(), &expected[..]);
//!
//! let mut irfftn_result = irfftn(&rfftn_result, None, None).unwrap();
//! assert_eq!(irfftn_result.dtype(), Dtype::Float32);
//!
//! let expected = [1.0; 8];
//! assert_eq!(irfftn_result.as_slice::<f32>(), &expected[..]);
//!
//! // The original array is not modified
//! let data: &[f32] = array.as_slice();
//! assert_eq!(data, &[1.0; 8]);
//! ```
mod fftn;
mod rfftn;
mod utils;
pub use self::{fftn::*, rfftn::*};
/* -------------------------------------------------------------------------- */
/* Helper functions */
/* -------------------------------------------------------------------------- */
use crate::{complex64, error::Exception, Array, Dtype};
use std::borrow::Cow;
fn as_complex64(src: &Array) -> Result<Cow<'_, Array>, Exception> {
match src.dtype() {
Dtype::Complex64 => Ok(Cow::Borrowed(src)),
_ => {
let new_array = src.as_type::<complex64>()?;
new_array.eval()?;
Ok(Cow::Owned(new_array))
}
}
}