use std::ffi::CStr;

use crate::error::Result;
use crate::utils::guard::Guarded;
use crate::utils::{IntoOption, VectorArray};
use crate::{Array, Stream};
use mlx_internal_macros::{default_device, generate_macro};

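/// Apply rotary positional encoding (RoPE) to the input array.
///
/// Parameter summary, assuming the underlying `mlx_fast_rope` C call matches
/// the semantics of MLX's `mlx.core.fast.rope`:
///
/// - `array`: input array
/// - `dimensions`: number of feature dimensions to rotate; the rest are left unchanged
/// - `traditional`: if `true`, use the traditional implementation that rotates consecutive dimensions
/// - `base`: base used to compute the angular frequency for each dimension
/// - `scale`: scale applied to the positions
/// - `offset`: position offset to start at, useful for incremental decoding
/// - `freqs`: optional custom frequencies; when given, `base` should be left unset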
#[allow(clippy::too_many_arguments)]
#[generate_macro(customize(root = "$crate::fast"))]
#[default_device]
pub fn rope_device<'a>(
    #[named] array: impl AsRef<Array>,
    #[named] dimensions: i32,
    #[named] traditional: bool,
    #[optional] base: impl Into<Option<f32>>,
    #[named] scale: f32,
    #[named] offset: i32,
    #[optional] freqs: impl Into<Option<&'a Array>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let base = base.into();
    let base = mlx_sys::mlx_optional_float {
        value: base.unwrap_or(0.0),
        has_value: base.is_some(),
    };
    let freqs = freqs.into();
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rope(
            res,
            array.as_ref().as_ptr(),
            dimensions,
            traditional,
            base,
            scale,
            offset,
            freqs
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            stream.as_ref().as_ptr(),
        )
    })
}

const DEFAULT_MASK_MODE: &CStr = c"";
const CAUSAL_MASK_MODE: &CStr = c"causal";

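/// Mask to pass to [`scaled_dot_product_attention_device`].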
#[derive(Debug)]
pub enum ScaledDotProductAttentionMask<'a> {
    /// A single additive mask array, broadcastable to the attention scores.
    Array(&'a Array),

    /// Multiple additive mask arrays.
    Arrays(&'a [Array]),

    /// Apply a causal mask; no mask array is materialized.
    Causal,
}

impl<'a> From<&'a Array> for ScaledDotProductAttentionMask<'a> {
    fn from(mask: &'a Array) -> Self {
        ScaledDotProductAttentionMask::Array(mask)
    }
}

impl<'a> From<&'a [Array]> for ScaledDotProductAttentionMask<'a> {
    fn from(masks: &'a [Array]) -> Self {
        ScaledDotProductAttentionMask::Arrays(masks)
    }
}

impl<'a> IntoOption<ScaledDotProductAttentionMask<'a>> for &'a Array {
    fn into_option(self) -> Option<ScaledDotProductAttentionMask<'a>> {
        Some(ScaledDotProductAttentionMask::Array(self))
    }
}

impl<'a> IntoOption<ScaledDotProductAttentionMask<'a>> for &'a [Array] {
    fn into_option(self) -> Option<ScaledDotProductAttentionMask<'a>> {
        Some(ScaledDotProductAttentionMask::Arrays(self))
    }
}

impl ScaledDotProductAttentionMask<'_> {
    /// Convert the mask into the `(mask_mode, masks)` pair expected by the C API.
    fn as_mode_and_masks(&self) -> (&'static CStr, VectorArray) {
        match self {
            ScaledDotProductAttentionMask::Array(mask) => (
                DEFAULT_MASK_MODE,
                VectorArray::try_from_iter([mask].iter()).unwrap(),
            ),
            ScaledDotProductAttentionMask::Arrays(masks) => (
                DEFAULT_MASK_MODE,
                VectorArray::try_from_iter(masks.iter()).unwrap(),
            ),
            // The causal mask is selected via the mode string alone, so the
            // mask vector stays empty.
            ScaledDotProductAttentionMask::Causal => (CAUSAL_MASK_MODE, unsafe {
                VectorArray::from_ptr(mlx_sys::mlx_vector_array_new())
            }),
        }
    }
}

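/// Scaled dot-product attention.
///
/// Assuming the underlying C call matches MLX's
/// `mlx.core.fast.scaled_dot_product_attention`, this computes
///
/// `softmax(q @ k.T * scale + mask) @ v`
///
/// Supports one or more additive array masks as well as a `"causal"` mask
/// mode; see [`ScaledDotProductAttentionMask`].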
#[generate_macro(customize(root = "$crate::fast"))]
#[default_device]
pub fn scaled_dot_product_attention_device<'a>(
    queries: impl AsRef<Array>,
    keys: impl AsRef<Array>,
    values: impl AsRef<Array>,
    scale: f32,
    #[optional] mask: impl IntoOption<ScaledDotProductAttentionMask<'a>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let (mask_mode, masks) = mask.into_option().map_or_else(
        || {
            (DEFAULT_MASK_MODE, unsafe {
                VectorArray::from_ptr(mlx_sys::mlx_vector_array_new())
            })
        },
        |m| m.as_mode_and_masks(),
    );

    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_scaled_dot_product_attention(
            res,
            queries.as_ref().as_ptr(),
            keys.as_ref().as_ptr(),
            values.as_ref().as_ptr(),
            scale,
            mask_mode.as_ptr(),
            masks.as_ptr(),
            stream.as_ref().as_ptr(),
        )
    })
}

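/// Root-mean-square (RMS) normalization over the last axis.
///
/// Assuming the underlying `mlx_fast_rms_norm` matches MLX's
/// `mlx.core.fast.rms_norm`, this computes
/// `x * weight / sqrt(mean(x.square(), axis=-1) + eps)`.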
#[generate_macro(customize(root = "$crate::fast"))]
#[default_device]
pub fn rms_norm_device(
    x: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    eps: f32,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rms_norm(
            res,
            x.as_ref().as_ptr(),
            weight.as_ref().as_ptr(),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}

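/// Layer normalization over the last axis.
///
/// Assuming the underlying `mlx_fast_layer_norm` matches MLX's
/// `mlx.core.fast.layer_norm`, the input is normalized to zero mean and unit
/// variance over its last axis, then scaled by `weight` and shifted by `bias`
/// when those are provided.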
#[generate_macro(customize(root = "$crate::fast"))]
#[default_device]
pub fn layer_norm_device<'a>(
    #[named] x: impl AsRef<Array>,
    #[optional] weight: impl Into<Option<&'a Array>>,
    #[optional] bias: impl Into<Option<&'a Array>>,
    #[named] eps: f32,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_layer_norm(
            res,
            x.as_ref().as_ptr(),
            // An empty array stands in for an omitted weight or bias.
            weight
                .into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            bias.into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ops::indexing::{ArrayIndexOp, IndexOp},
        random::normal,
    };
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_rope() {
        crate::random::seed(71).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let result = rope(a, 8, false, 10000., 1.0, 0, None).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.456_253_77,
            abs <= 0.009_125_075
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            116.800_964,
            abs <= 2.336_019_3
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let result = rms_norm(a, weight, 1e-5).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    #[test]
    fn test_layer_norm_affine() {
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let bias = Array::zeros::<f32>(&[16]).unwrap();
        let result = layer_norm(a, &weight, &bias, 1e-5).unwrap();
        let result = result.index((ArrayIndexOp::Ellipsis, 0));
        assert_eq!(result.shape(), [2, 8]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_fast_sdpa() {
        let Dk = 64;
        // Exercise the fused SDPA kernel across several sequence lengths and
        // dtypes; only the output shape and dtype are checked.
        let scale = 1.0 / (Dk as f32).sqrt();
        for seq_len in [63, 129, 400] {
            for dtype in [crate::Dtype::Float32, crate::Dtype::Float16] {
                let B = 2;
                let H = 24;
                let q = normal::<f32>(&[B, H, seq_len, Dk], None, None, None)
                    .unwrap()
                    .as_dtype(dtype)
                    .unwrap();
                let k = normal::<f32>(&[B, H, seq_len, Dk], None, None, None)
                    .unwrap()
                    .as_dtype(dtype)
                    .unwrap();
                let v = normal::<f32>(&[B, H, seq_len, Dk], None, None, None)
                    .unwrap()
                    .as_dtype(dtype)
                    .unwrap();

                let result = scaled_dot_product_attention(q, k, v, scale, None).unwrap();
                assert_eq!(result.shape(), [B, H, seq_len, Dk]);
                assert_eq!(result.dtype(), dtype);
            }
        }
    }
}