use mlx_internal_macros::default_device;
use crate::{
error::Result,
utils::{guard::Guarded, IntoOption},
Array, Stream, StreamOrDevice,
};
use super::{
as_complex64,
utils::{resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked},
};
/// One-dimensional discrete Fourier transform of a real input.
///
/// Returns a `complex64` array holding only the non-redundant half of the spectrum: the
/// transformed axis has `n / 2 + 1` entries.
///
/// # Params
///
/// - `a`: the input array. It is handed to MLX as-is (no `complex64` conversion is done here).
/// - `n`: transform size along `axis`. If `None`, it is resolved from the shape of `a` by
///   `resolve_size_and_axis_unchecked`.
/// - `axis`: the axis to transform. If `None`, it is resolved by the same helper (the tests
///   exercise this as the last axis of a 1-D input).
/// - `stream`: stream or device to run the operation on.
///
/// # Errors
///
/// Returns an error if the underlying MLX `rfft` operation fails.
#[default_device]
pub fn rfft_device(
    a: impl AsRef<Array>,
    n: impl Into<Option<i32>>,
    axis: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    // A real FFT consumes real data. Unlike the inverse transforms, there is no reason to widen
    // `a` to `complex64` first; doing so would only add a cast and an intermediate array twice
    // the size of a `float32` input.
    let a = a.as_ref();
    let (n, axis) = resolve_size_and_axis_unchecked(a, n.into(), axis.into());
    // SAFETY: `a` and `stream` are borrowed from this function's arguments and stay alive until
    // it returns; `res` is the output slot supplied by `Array::try_from_op`.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fft_rfft(res, a.as_ptr(), n, axis, stream.as_ref().as_ptr())
    })
}
/// Two-dimensional discrete Fourier transform of a real input.
///
/// Returns a `complex64` array. The last transformed axis keeps only the non-redundant half of
/// the spectrum (`len / 2 + 1` entries); the other transformed axis keeps its length.
///
/// # Params
///
/// - `a`: the input array. It is handed to MLX as-is (no `complex64` conversion is done here).
/// - `s`: transform sizes along `axes`. If `None`, they are resolved from the shape of `a` by
///   `resolve_sizes_and_axes_unchecked`.
/// - `axes`: the axes to transform; defaults to `[-2, -1]`.
/// - `stream`: stream or device to run the operation on.
///
/// # Errors
///
/// Returns an error if the underlying MLX `rfft2` operation fails.
#[default_device]
pub fn rfft2_device<'a>(
    a: impl AsRef<Array>,
    s: impl IntoOption<&'a [i32]>,
    axes: impl IntoOption<&'a [i32]>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    // A real FFT consumes real data, so `a` is not widened to `complex64`. The cast would only
    // add an intermediate array twice the size of a `float32` input.
    let a = a.as_ref();
    let axes = axes.into_option().unwrap_or(&[-2, -1]);
    let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), Some(axes));
    // SAFETY: `a`, `s`, `axes` and `stream` all stay alive until this function returns, and each
    // length passed alongside a pointer is the exact length of that buffer.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fft_rfft2(
            res,
            a.as_ptr(),
            s.as_ptr(),
            s.len(),
            axes.as_ptr(),
            axes.len(),
            stream.as_ref().as_ptr(),
        )
    })
}
/// N-dimensional discrete Fourier transform of a real input.
///
/// Returns a `complex64` array. The last transformed axis keeps only the non-redundant half of
/// the spectrum (`len / 2 + 1` entries); the other transformed axes keep their length.
///
/// # Params
///
/// - `a`: the input array. It is handed to MLX as-is (no `complex64` conversion is done here).
/// - `s`: transform sizes along `axes`. If `None`, they are resolved from the shape of `a` by
///   `resolve_sizes_and_axes_unchecked`.
/// - `axes`: the axes to transform. If `None`, they are resolved by the same helper (the tests
///   show every axis being transformed in that case).
/// - `stream`: stream or device to run the operation on.
///
/// # Errors
///
/// Returns an error if the underlying MLX `rfftn` operation fails.
#[default_device]
pub fn rfftn_device<'a>(
    a: impl AsRef<Array>,
    s: impl IntoOption<&'a [i32]>,
    axes: impl IntoOption<&'a [i32]>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    // A real FFT consumes real data, so `a` is not widened to `complex64`. The cast would only
    // add an intermediate array twice the size of a `float32` input.
    let a = a.as_ref();
    let (s, axes) = resolve_sizes_and_axes_unchecked(a, s.into_option(), axes.into_option());
    // SAFETY: `a`, `s`, `axes` and `stream` all stay alive until this function returns, and each
    // length passed alongside a pointer is the exact length of that buffer.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fft_rfftn(
            res,
            a.as_ptr(),
            s.as_ptr(),
            s.len(),
            axes.as_ptr(),
            axes.len(),
            stream.as_ref().as_ptr(),
        )
    })
}
/// One-dimensional inverse of [`rfft`]: maps a half spectrum back to a real signal.
///
/// The tests show a `float32` result whose length along `axis` is `n`.
///
/// # Params
///
/// - `a`: the half-spectrum input. It is converted with `as_complex64` before the transform.
/// - `n`: length of the real output along `axis`. If `None`, the number of entries along `axis`
///   (`bins`, as resolved by `resolve_size_and_axis_unchecked`) becomes `(bins - 1) * 2`, i.e.
///   an even-length signal is assumed. Pass `n` explicitly to recover an odd length.
/// - `axis`: the axis to transform. If `None`, it is resolved by the same helper.
/// - `stream`: stream or device to run the operation on.
///
/// # Errors
///
/// Returns an error if the `complex64` conversion or the MLX `irfft` operation fails.
#[default_device]
pub fn irfft_device(
    a: impl AsRef<Array>,
    n: impl Into<Option<i32>>,
    axis: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let spectrum = as_complex64(a.as_ref())?;
    let requested_n: Option<i32> = n.into();
    let (resolved_n, axis) =
        resolve_size_and_axis_unchecked(&spectrum, requested_n, axis.into());
    // Without an explicit `n`, the resolved size counts spectrum bins, not signal samples, so
    // convert it to the length of the even-sized signal that produced those bins.
    let n = match requested_n {
        Some(_) => resolved_n,
        None => (resolved_n - 1) * 2,
    };
    let stream = stream.as_ref();
    // SAFETY: `spectrum` and `stream` stay alive until this function returns; `res` is the
    // output slot supplied by `Array::try_from_op`.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fft_irfft(res, spectrum.as_ptr(), n, axis, stream.as_ptr())
    })
}
/// Two-dimensional inverse of [`rfft2`]: maps a half spectrum back to a real signal.
///
/// # Params
///
/// - `a`: the half-spectrum input. It is converted with `as_complex64` before the transform.
/// - `s`: output sizes along `axes`. If `None`, sizes are resolved from the shape of `a` by
///   `resolve_sizes_and_axes_unchecked`. The last one, which counts spectrum bins, is then
///   replaced by `(bins - 1) * 2`, i.e. an even-length signal is assumed along that axis.
/// - `axes`: the axes to transform; defaults to `[-2, -1]`.
/// - `stream`: stream or device to run the operation on.
///
/// # Errors
///
/// Returns an error if the `complex64` conversion or the MLX `irfft2` operation fails.
#[default_device]
pub fn irfft2_device<'a>(
    a: impl AsRef<Array>,
    s: impl IntoOption<&'a [i32]>,
    axes: impl IntoOption<&'a [i32]>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let a = as_complex64(a.as_ref())?;
    let s = s.into_option();
    let axes = axes.into_option().unwrap_or(&[-2, -1]);
    let infer_last_size = s.is_none();
    let (mut s, axes) = resolve_sizes_and_axes_unchecked(&a, s, Some(axes));
    if infer_last_size {
        // `last_mut` rather than `s[s.len() - 1]`: with an empty `axes` the resolved size list
        // is empty, and `len() - 1` would underflow and panic. MLX gets the empty list instead
        // and handles it itself.
        if let Some(last) = s.last_mut() {
            *last = (*last - 1) * 2;
        }
    }
    // SAFETY: `a`, `s`, `axes` and `stream` all stay alive until this function returns, and each
    // length passed alongside a pointer is the exact length of that buffer.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fft_irfft2(
            res,
            a.as_ptr(),
            s.as_ptr(),
            s.len(),
            axes.as_ptr(),
            axes.len(),
            stream.as_ref().as_ptr(),
        )
    })
}
/// N-dimensional inverse of [`rfftn`]: maps a half spectrum back to a real signal.
///
/// # Params
///
/// - `a`: the half-spectrum input. It is converted with `as_complex64` before the transform.
/// - `s`: output sizes along `axes`. If `None`, sizes are resolved from the shape of `a` by
///   `resolve_sizes_and_axes_unchecked`. The last one, which counts spectrum bins, is then
///   replaced by `(bins - 1) * 2`, i.e. an even-length signal is assumed along that axis.
/// - `axes`: the axes to transform. If `None`, they are chosen by the same helper.
/// - `stream`: stream or device to run the operation on.
///
/// # Errors
///
/// Returns an error if the `complex64` conversion or the MLX `irfftn` operation fails.
#[default_device]
pub fn irfftn_device<'a>(
    a: impl AsRef<Array>,
    s: impl IntoOption<&'a [i32]>,
    axes: impl IntoOption<&'a [i32]>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let a = as_complex64(a.as_ref())?;
    let s = s.into_option();
    let axes = axes.into_option();
    let infer_last_size = s.is_none();
    let (mut s, axes) = resolve_sizes_and_axes_unchecked(&a, s, axes);
    if infer_last_size {
        // `last_mut` rather than `s[s.len() - 1]`: with an empty `axes`, or no axes resolved at
        // all, the size list is empty, and `len() - 1` would underflow and panic. MLX gets the
        // empty list instead and reports or handles it itself.
        if let Some(last) = s.last_mut() {
            *last = (*last - 1) * 2;
        }
    }
    // SAFETY: `a`, `s`, `axes` and `stream` all stay alive until this function returns, and each
    // length passed alongside a pointer is the exact length of that buffer.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fft_irfftn(
            res,
            a.as_ptr(),
            s.as_ptr(),
            s.len(),
            axes.as_ptr(),
            axes.len(),
            stream.as_ref().as_ptr(),
        )
    })
}
#[cfg(test)]
mod tests {
    use crate::{complex64, Array, Dtype};

    /// Round trip of a small 1-D signal through `rfft` and `irfft` with an explicit size.
    #[test]
    fn test_rfft() {
        const RFFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0];
        const RFFT_N: i32 = 4;
        const RFFT_SHAPE: &[i32] = &[RFFT_N];
        const RFFT_AXIS: i32 = -1;
        const RFFT_EXPECTED: &[complex64] = &[
            complex64::new(10.0, 0.0),
            complex64::new(-2.0, 2.0),
            complex64::new(-2.0, 0.0),
        ];
        let a = Array::from_slice(RFFT_DATA, RFFT_SHAPE);
        let rfft = super::rfft(&a, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(rfft.dtype(), Dtype::Complex64);
        assert_eq!(rfft.as_slice::<complex64>(), RFFT_EXPECTED);
        let irfft = super::irfft(&rfft, RFFT_N, RFFT_AXIS).unwrap();
        assert_eq!(irfft.dtype(), Dtype::Float32);
        assert_eq!(irfft.as_slice::<f32>(), RFFT_DATA);
    }

    /// With default parameters the transformed axis keeps `n / 2 + 1` bins.
    #[test]
    fn test_rfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = IN_N / 2 + 1;
        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let rfft = super::rfft(&a, None, None).unwrap();
        assert_eq!(rfft.shape(), &[OUT_N]);
    }

    /// With default parameters `irfft` assumes an even-length signal: `(bins - 1) * 2`.
    #[test]
    fn test_irfft_shape_with_default_params() {
        const IN_N: i32 = 8;
        const OUT_N: i32 = (IN_N - 1) * 2;
        let a = Array::ones::<f32>(&[IN_N]).unwrap();
        let irfft = super::irfft(&a, None, None).unwrap();
        assert_eq!(irfft.shape(), &[OUT_N]);
    }

    /// An odd-length signal can only be recovered by passing its length explicitly.
    #[test]
    fn test_irfft_odd_length_with_explicit_n() {
        const IN_N: i32 = 5;
        let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[IN_N]);
        let rfft = super::rfft(&a, None, None).unwrap();
        assert_eq!(rfft.shape(), &[IN_N / 2 + 1]);
        let irfft = super::irfft(&rfft, IN_N, None).unwrap();
        assert_eq!(irfft.dtype(), Dtype::Float32);
        assert_eq!(irfft.shape(), &[IN_N]);
    }

    /// Round trip of a constant 2x2 input through `rfft2` and `irfft2`.
    #[test]
    fn test_rfft2() {
        const RFFT2_DATA: &[f32] = &[1.0; 4];
        const RFFT2_SHAPE: &[i32] = &[2, 2];
        const RFFT2_EXPECTED: &[complex64] = &[
            complex64::new(4.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];
        let a = Array::from_slice(RFFT2_DATA, RFFT2_SHAPE);
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.dtype(), Dtype::Complex64);
        assert_eq!(rfft2.as_slice::<complex64>(), RFFT2_EXPECTED);
        let irfft2 = super::irfft2(&rfft2, None, None).unwrap();
        assert_eq!(irfft2.dtype(), Dtype::Float32);
        assert_eq!(irfft2.as_slice::<f32>(), RFFT2_DATA);
    }

    #[test]
    fn test_rfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6 / 2 + 1];
        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfft2 = super::rfft2(&a, None, None).unwrap();
        assert_eq!(rfft2.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_irfft2_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6];
        const OUT_SHAPE: &[i32] = &[6, (6 - 1) * 2];
        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfft2 = super::irfft2(&a, None, None).unwrap();
        assert_eq!(irfft2.shape(), OUT_SHAPE);
    }

    /// Round trip of a constant 2x2x2 input through `rfftn` and `irfftn`.
    #[test]
    fn test_rfftn() {
        const RFFTN_DATA: &[f32] = &[1.0; 8];
        const RFFTN_SHAPE: &[i32] = &[2, 2, 2];
        const RFFTN_EXPECTED: &[complex64] = &[
            complex64::new(8.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
            complex64::new(0.0, 0.0),
        ];
        let a = Array::from_slice(RFFTN_DATA, RFFTN_SHAPE);
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.dtype(), Dtype::Complex64);
        assert_eq!(rfftn.as_slice::<complex64>(), RFFTN_EXPECTED);
        let irfftn = super::irfftn(&rfftn, None, None).unwrap();
        assert_eq!(irfftn.dtype(), Dtype::Float32);
        assert_eq!(irfftn.as_slice::<f32>(), RFFTN_DATA);
    }

    #[test]
    fn test_rfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, 6 / 2 + 1];
        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let rfftn = super::rfftn(&a, None, None).unwrap();
        assert_eq!(rfftn.shape(), OUT_SHAPE);
    }

    #[test]
    fn test_irfftn_shape_with_default_params() {
        const IN_SHAPE: &[i32] = &[6, 6, 6];
        const OUT_SHAPE: &[i32] = &[6, 6, (6 - 1) * 2];
        let a = Array::ones::<f32>(IN_SHAPE).unwrap();
        let irfftn = super::irfftn(&a, None, None).unwrap();
        assert_eq!(irfftn.shape(), OUT_SHAPE);
    }
}