use crate::error::Result;
use crate::utils::guard::Guarded;
use crate::{Array, Stream, StreamOrDevice};
use mlx_internal_macros::default_device;
/// Safe wrapper around `mlx_sys::mlx_fast_rope`.
///
/// The name suggests rotary position embeddings (RoPE); consult the MLX
/// documentation for the exact numerical semantics of each argument.
///
/// # Params
///
/// - `array`: input array handed to the FFI call.
/// - `dimensions`, `traditional`, `scale`, `offset`: forwarded unchanged.
/// - `base`: optional value. `None` is passed as an `mlx_optional_float` with
///   `has_value == false` (its `value` field is then a `0.0` placeholder).
/// - `freqs`: optional array. When `None`, a handle from
///   `mlx_sys::mlx_array_new()` is passed in its place.
/// - `stream`: stream the operation is issued on.
///
/// # Errors
///
/// Returns whatever error `Array::try_from_op` reports for the FFI call.
// Allowed: the parameter list mirrors the underlying C function one-to-one.
#[allow(clippy::too_many_arguments)]
#[default_device]
pub fn rope_device<'a>(
    array: impl AsRef<Array>,
    dimensions: i32,
    traditional: bool,
    base: impl Into<Option<f32>>,
    scale: f32,
    offset: i32,
    freqs: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let base = base.into();
    let base = mlx_sys::mlx_optional_float {
        value: base.unwrap_or(0.0),
        has_value: base.is_some(),
    };
    let freqs = freqs.into();
    // SAFETY: every handle passed below comes from a reference that is borrowed
    // for the whole call, and `res` is supplied by `try_from_op`.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rope(
            res,
            array.as_ref().as_ptr(),
            dimensions,
            traditional,
            base,
            scale,
            offset,
            // Only ask the FFI layer for a placeholder handle when the caller
            // did not supply `freqs`; `unwrap_or` would call it unconditionally.
            freqs
                .map(|a| a.as_ptr())
                .unwrap_or_else(|| mlx_sys::mlx_array_new()),
            stream.as_ref().as_ptr(),
        )
    })
}
/// Safe wrapper around `mlx_sys::mlx_fast_scaled_dot_product_attention`.
///
/// # Params
///
/// - `queries`, `keys`, `values`: input arrays handed to the FFI call.
/// - `scale`: forwarded unchanged.
/// - `mask`: optional array. When `None`, a handle from
///   `mlx_sys::mlx_array_new()` is passed in its place.
/// - `memory_efficient_threshold`: optional value. `None` is passed as an
///   `mlx_optional_int` with `has_value == false` (its `value` field is then a
///   `0` placeholder).
/// - `stream`: stream the operation is issued on.
///
/// # Errors
///
/// Returns whatever error `Array::try_from_op` reports for the FFI call.
#[default_device]
pub fn scaled_dot_product_attention_device<'a>(
    queries: impl AsRef<Array>,
    keys: impl AsRef<Array>,
    values: impl AsRef<Array>,
    scale: f32,
    mask: impl Into<Option<&'a Array>>,
    memory_efficient_threshold: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let memory_efficient_threshold = memory_efficient_threshold.into();
    let memory_efficient_threshold = mlx_sys::mlx_optional_int {
        value: memory_efficient_threshold.unwrap_or(0),
        has_value: memory_efficient_threshold.is_some(),
    };
    // Resolve the optional mask up front, the same way `rope_device` handles
    // its optional `freqs` argument.
    let mask: Option<&Array> = mask.into();
    // SAFETY: every handle passed below comes from a reference that is borrowed
    // for the whole call, and `res` is supplied by `try_from_op`.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_scaled_dot_product_attention(
            res,
            queries.as_ref().as_ptr(),
            keys.as_ref().as_ptr(),
            values.as_ref().as_ptr(),
            scale,
            // Only ask the FFI layer for a placeholder handle when no mask was
            // supplied; `unwrap_or` would call it unconditionally.
            mask.map(|a| a.as_ptr())
                .unwrap_or_else(|| mlx_sys::mlx_array_new()),
            memory_efficient_threshold,
            stream.as_ref().as_ptr(),
        )
    })
}
/// Safe wrapper around `mlx_sys::mlx_fast_rms_norm`.
///
/// # Params
///
/// - `x`: input array handed to the FFI call.
/// - `weight`: weight array handed to the FFI call.
/// - `eps`: forwarded unchanged.
/// - `stream`: stream the operation is issued on.
///
/// # Errors
///
/// Returns whatever error `Array::try_from_op` reports for the FFI call.
#[default_device]
pub fn rms_norm_device(
    x: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    eps: f32,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    // Borrow the concrete handles once so the FFI call below reads flat.
    let input = x.as_ref();
    let weight_ref = weight.as_ref();
    let target_stream = stream.as_ref();

    // SAFETY: every handle passed below comes from a reference that is borrowed
    // for the whole call, and `out` is supplied by `try_from_op`.
    Array::try_from_op(|out| unsafe {
        mlx_sys::mlx_fast_rms_norm(
            out,
            input.as_ptr(),
            weight_ref.as_ptr(),
            eps,
            target_stream.as_ptr(),
        )
    })
}
/// Safe wrapper around `mlx_sys::mlx_fast_layer_norm`.
///
/// # Params
///
/// - `x`: input array handed to the FFI call.
/// - `weight`: optional array. When `None`, a handle from
///   `mlx_sys::mlx_array_new()` is passed in its place.
/// - `bias`: optional array, handled the same way as `weight`.
/// - `eps`: forwarded unchanged.
/// - `stream`: stream the operation is issued on.
///
/// # Errors
///
/// Returns whatever error `Array::try_from_op` reports for the FFI call.
#[default_device]
pub fn layer_norm_device<'a>(
    x: impl AsRef<Array>,
    weight: impl Into<Option<&'a Array>>,
    bias: impl Into<Option<&'a Array>>,
    eps: f32,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    // Resolve the optional inputs up front, the same way `rope_device` handles
    // its optional `freqs` argument.
    let weight: Option<&Array> = weight.into();
    let bias: Option<&Array> = bias.into();
    // SAFETY: every handle passed below comes from a reference that is borrowed
    // for the whole call, and `res` is supplied by `try_from_op`.
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_layer_norm(
            res,
            x.as_ref().as_ptr(),
            // Placeholder handles are requested lazily: `unwrap_or` would call
            // `mlx_array_new` even when the caller supplied an array.
            weight
                .map(|a| a.as_ptr())
                .unwrap_or_else(|| mlx_sys::mlx_array_new()),
            bias.map(|a| a.as_ptr())
                .unwrap_or_else(|| mlx_sys::mlx_array_new()),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::indexing::{ArrayIndexOp, IndexOp};
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    /// `rope` keeps shape and dtype, and reproduces the recorded mean/sum for a
    /// seeded uniform input.
    #[test]
    fn test_rope() {
        // Fixed seed so the expected statistics below are reproducible.
        crate::random::seed(71).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let result = rope(a, 8, false, 10000., 1.0, 0, None).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.456_253_77,
            abs <= 0.009_125_075
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            116.800_964,
            abs <= 2.336_019_3
        );
    }

    /// `rms_norm` with an all-ones weight keeps shape and dtype, and reproduces
    /// the recorded mean/sum for a seeded uniform input.
    #[test]
    fn test_rms_norm() {
        // Fixed seed so the expected statistics below are reproducible.
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let result = rms_norm(a, weight, 1e-5).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    /// `layer_norm` with an all-ones weight and all-zeros bias: checks the
    /// statistics of the first element along the last axis.
    #[test]
    fn test_layer_norm_affine() {
        // Fixed seed so the expected statistics below are reproducible.
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let bias = Array::zeros::<f32>(&[16]).unwrap();
        let result = layer_norm(a, &weight, &bias, 1e-5).unwrap();
        // Select index 0 on the trailing axis, leaving a [2, 8] array.
        let result = result.index((ArrayIndexOp::Ellipsis, 0));
        assert_eq!(result.shape(), [2, 8]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }
}