1use std::ffi::CString;
4
5use crate::error::{Exception, Result};
6use crate::utils::guard::Guarded;
7use crate::utils::VectorArray;
8use crate::{Array, Stream};
9use mlx_internal_macros::default_device;
10
11#[allow(clippy::too_many_arguments)]
13#[default_device]
14pub fn rope_device<'a>(
15 array: impl AsRef<Array>,
16 dimensions: i32,
17 traditional: bool,
18 base: impl Into<Option<f32>>,
19 scale: f32,
20 offset: i32,
21 freqs: impl Into<Option<&'a Array>>,
22 stream: impl AsRef<Stream>,
23) -> Result<Array> {
24 let base = base.into();
25 let base = mlx_sys::mlx_optional_float {
26 value: base.unwrap_or(0.0),
27 has_value: base.is_some(),
28 };
29 let freqs = freqs.into();
30 Array::try_from_op(|res| unsafe {
31 mlx_sys::mlx_fast_rope(
32 res,
33 array.as_ref().as_ptr(),
34 dimensions,
35 traditional,
36 base,
37 scale,
38 offset,
39 freqs
40 .map(|a| a.as_ptr())
41 .unwrap_or(mlx_sys::mlx_array_new()),
42 stream.as_ref().as_ptr(),
43 )
44 })
45}
46
47#[default_device]
57pub fn scaled_dot_product_attention_device<'a>(
58 queries: impl AsRef<Array>,
59 keys: impl AsRef<Array>,
60 values: impl AsRef<Array>,
61 scale: f32,
62 mask: impl Into<Option<&'a Array>>,
63 stream: impl AsRef<Stream>,
64) -> Result<Array> {
65 let mask_mode = CString::new("").map_err(|e| Exception::custom(format!("{}", e)))?;
66 let masks = match mask.into() {
67 Some(m) => VectorArray::try_from_iter([m].iter())?,
68 None => unsafe { VectorArray::from_ptr(mlx_sys::mlx_vector_array_new()) },
69 };
70
71 Array::try_from_op(|res| unsafe {
72 mlx_sys::mlx_fast_scaled_dot_product_attention(
73 res,
74 queries.as_ref().as_ptr(),
75 keys.as_ref().as_ptr(),
76 values.as_ref().as_ptr(),
77 scale,
78 mask_mode.as_ptr(),
79 masks.as_ptr(),
80 stream.as_ref().as_ptr(),
81 )
82 })
83}
84
85#[default_device]
96pub fn rms_norm_device(
97 x: impl AsRef<Array>,
98 weight: impl AsRef<Array>,
99 eps: f32,
100 stream: impl AsRef<Stream>,
101) -> Result<Array> {
102 Array::try_from_op(|res| unsafe {
103 mlx_sys::mlx_fast_rms_norm(
104 res,
105 x.as_ref().as_ptr(),
106 weight.as_ref().as_ptr(),
107 eps,
108 stream.as_ref().as_ptr(),
109 )
110 })
111}
112
113#[default_device]
127pub fn layer_norm_device<'a>(
128 x: impl AsRef<Array>,
129 weight: impl Into<Option<&'a Array>>,
130 bias: impl Into<Option<&'a Array>>,
131 eps: f32,
132 stream: impl AsRef<Stream>,
133) -> Result<Array> {
134 Array::try_from_op(|res| unsafe {
135 mlx_sys::mlx_fast_layer_norm(
136 res,
137 x.as_ref().as_ptr(),
138 weight
139 .into()
140 .map(|a| a.as_ptr())
141 .unwrap_or(mlx_sys::mlx_array_new()),
142 bias.into()
143 .map(|a| a.as_ptr())
144 .unwrap_or(mlx_sys::mlx_array_new()),
145 eps,
146 stream.as_ref().as_ptr(),
147 )
148 })
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::indexing::{ArrayIndexOp, IndexOp};
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    /// Asserts that the mean and the sum over all elements of `arr` match reference
    /// values. Each expectation is given as `(expected, absolute_tolerance)`.
    fn assert_mean_and_sum(arr: &Array, mean: (f32, f32), sum: (f32, f32)) {
        let (expected_mean, mean_tol) = mean;
        let (expected_sum, sum_tol) = sum;
        assert_float_eq!(
            arr.mean(None).unwrap().item::<f32>(),
            expected_mean,
            abs <= mean_tol
        );
        assert_float_eq!(
            arr.sum(None).unwrap().item::<f32>(),
            expected_sum,
            abs <= sum_tol
        );
    }

    #[test]
    fn test_rope() {
        crate::random::seed(71).unwrap();
        let input = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(input.shape(), [2, 8, 16]);
        assert_eq!(input.dtype(), crate::Dtype::Float32);

        let rotated = rope(input, 8, false, 10000., 1.0, 0, None).unwrap();
        assert_eq!(rotated.shape(), [2, 8, 16]);
        assert_eq!(rotated.dtype(), crate::Dtype::Float32);
        assert_mean_and_sum(
            &rotated,
            (0.456_253_77, 0.009_125_075),
            (116.800_964, 2.336_019_3),
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let input = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(input.shape(), [2, 8, 16]);
        assert_eq!(input.dtype(), crate::Dtype::Float32);

        let ones = Array::ones::<f32>(&[16]).unwrap();
        let normed = rms_norm(input, ones, 1e-5).unwrap();
        assert_eq!(normed.shape(), [2, 8, 16]);
        assert_eq!(normed.dtype(), crate::Dtype::Float32);
        assert_mean_and_sum(
            &normed,
            (0.872_938_75, 0.017_458_774),
            (223.472_32, 4.469_446),
        );
    }

    #[test]
    pub fn test_layer_norm_affine() {
        crate::random::seed(635).unwrap();
        let input = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(input.shape(), [2, 8, 16]);
        assert_eq!(input.dtype(), crate::Dtype::Float32);

        let ones = Array::ones::<f32>(&[16]).unwrap();
        let zeros = Array::zeros::<f32>(&[16]).unwrap();
        let normed = layer_norm(input, &ones, &zeros, 1e-5).unwrap();

        // Keep only index 0 of the last axis.
        let first_feature = normed.index((ArrayIndexOp::Ellipsis, 0));
        assert_eq!(first_feature.shape(), [2, 8]);
        assert_eq!(first_feature.dtype(), crate::Dtype::Float32);
        assert_mean_and_sum(
            &first_feature,
            (0.290_990_38, 0.005_819_807_8),
            (4.655_846, 0.093_116_924),
        );
    }
}