mlx_rs/fast.rs

//! Fast implementations of commonly used multi-op functions.

use std::ffi::CStr;

use crate::error::Result;
use crate::utils::guard::Guarded;
use crate::utils::{IntoOption, VectorArray};
use crate::{Array, Stream};
use mlx_internal_macros::{default_device, generate_macro};

/// Optimized implementation of `NN.RoPE`.
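///
/// # Example
///
/// A minimal usage sketch mirroring the call in this module's tests; the input
/// shape and the `dimensions`/`base`/`scale`/`offset` values are illustrative only.
///
/// ```rust,ignore
/// use mlx_rs::fast::rope;
///
/// // Apply rotary embeddings to the first 8 feature dimensions of a
/// // `[batch, sequence, dims]` array with base 10000, scale 1.0, offset 0,
/// // and no custom frequencies.
/// let a = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// let rotated = rope(a, 8, false, 10000.0, 1.0, 0, None).unwrap();
/// ```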
#[allow(clippy::too_many_arguments)]
#[generate_macro(customize(root = "$crate::fast"))]
#[default_device]
pub fn rope_device<'a>(
    #[named] array: impl AsRef<Array>,
    #[named] dimensions: i32,
    #[named] traditional: bool,
    #[optional] base: impl Into<Option<f32>>,
    #[named] scale: f32,
    #[named] offset: i32,
    #[optional] freqs: impl Into<Option<&'a Array>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    // Convert the optional base into the C API's optional-float representation.
    let base = base.into();
    let base = mlx_sys::mlx_optional_float {
        value: base.unwrap_or(0.0),
        has_value: base.is_some(),
    };
    let freqs = freqs.into();
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rope(
            res,
            array.as_ref().as_ptr(),
            dimensions,
            traditional,
            base,
            scale,
            offset,
            // An empty array stands in for "no custom frequencies".
            freqs
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            stream.as_ref().as_ptr(),
        )
    })
}

const DEFAULT_MASK_MODE: &CStr = c"";
const CAUSAL_MASK_MODE: &CStr = c"causal";

/// Mask modes for scaled dot product attention.
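///
/// A mask can be constructed directly or converted from an [`Array`] reference.
/// A minimal sketch (the zero mask here is a placeholder; a real mask must be
/// broadcast-compatible with the attention scores):
///
/// ```rust,ignore
/// use mlx_rs::fast::ScaledDotProductAttentionMask;
/// use mlx_rs::Array;
///
/// let additive = Array::zeros::<f32>(&[8, 8]).unwrap();
/// let array_mask = ScaledDotProductAttentionMask::from(&additive);
/// let causal_mask = ScaledDotProductAttentionMask::Causal;
/// ```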
#[derive(Debug)]
pub enum ScaledDotProductAttentionMask<'a> {
    /// A single mask array.
    Array(&'a Array),

    /// Multiple mask arrays.
    Arrays(&'a [Array]),

    /// Use a causal mask.
    Causal,
}

impl<'a> From<&'a Array> for ScaledDotProductAttentionMask<'a> {
    fn from(mask: &'a Array) -> Self {
        ScaledDotProductAttentionMask::Array(mask)
    }
}

impl<'a> From<&'a [Array]> for ScaledDotProductAttentionMask<'a> {
    fn from(masks: &'a [Array]) -> Self {
        ScaledDotProductAttentionMask::Arrays(masks)
    }
}

impl<'a> IntoOption<ScaledDotProductAttentionMask<'a>> for &'a Array {
    fn into_option(self) -> Option<ScaledDotProductAttentionMask<'a>> {
        Some(ScaledDotProductAttentionMask::Array(self))
    }
}

impl<'a> IntoOption<ScaledDotProductAttentionMask<'a>> for &'a [Array] {
    fn into_option(self) -> Option<ScaledDotProductAttentionMask<'a>> {
        Some(ScaledDotProductAttentionMask::Arrays(self))
    }
}

impl ScaledDotProductAttentionMask<'_> {
    /// Maps the mask to the C API's mask-mode string and vector of mask arrays.
    fn as_mode_and_masks(&self) -> (&'static CStr, VectorArray) {
        match self {
            // Explicit mask arrays use the default ("") mode.
            ScaledDotProductAttentionMask::Array(mask) => (
                DEFAULT_MASK_MODE,
                VectorArray::try_from_iter([mask].iter()).unwrap(),
            ),
            ScaledDotProductAttentionMask::Arrays(masks) => (
                DEFAULT_MASK_MODE,
                VectorArray::try_from_iter(masks.iter()).unwrap(),
            ),
            // The causal mode carries no mask arrays, so pass an empty vector.
            ScaledDotProductAttentionMask::Causal => (CAUSAL_MASK_MODE, unsafe {
                VectorArray::from_ptr(mlx_sys::mlx_vector_array_new())
            }),
        }
    }
}

/// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`
///
/// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
///
/// This function dispatches to an optimized Metal kernel when the query sequence length is 1, and handles other cases with regular MLX operations.
///
/// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).
///
/// > Note: For Grouped Query Attention and Multi-Query Attention, the `keys` and `values` arrays should not be pre-tiled to match the `queries` array.
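///
/// # Example
///
/// A minimal sketch based on this module's tests, using unmasked attention over
/// randomly generated inputs; the shapes and the `None` mask are illustrative only
/// (see [`ScaledDotProductAttentionMask`] for masking options).
///
/// ```rust,ignore
/// use mlx_rs::fast::scaled_dot_product_attention;
/// use mlx_rs::random::normal;
///
/// // Queries, keys, and values with shape [batch, heads, seq_len, head_dim].
/// let q = normal::<f32>(&[2, 8, 32, 64], None, None, None).unwrap();
/// let k = normal::<f32>(&[2, 8, 32, 64], None, None, None).unwrap();
/// let v = normal::<f32>(&[2, 8, 32, 64], None, None, None).unwrap();
/// let scale = 1.0 / (64f32).sqrt();
///
/// let out = scaled_dot_product_attention(q, k, v, scale, None).unwrap();
/// ```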
#[generate_macro(customize(root = "$crate::fast"))]
#[default_device]
pub fn scaled_dot_product_attention_device<'a>(
    queries: impl AsRef<Array>,
    keys: impl AsRef<Array>,
    values: impl AsRef<Array>,
    scale: f32,
    #[optional] mask: impl IntoOption<ScaledDotProductAttentionMask<'a>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let (mask_mode, masks) = mask.into_option().map_or_else(
        || {
            (DEFAULT_MASK_MODE, unsafe {
                VectorArray::from_ptr(mlx_sys::mlx_vector_array_new())
            })
        },
        |m| m.as_mode_and_masks(),
    );

    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_scaled_dot_product_attention(
            res,
            queries.as_ref().as_ptr(),
            keys.as_ref().as_ptr(),
            values.as_ref().as_ptr(),
            scale,
            mask_mode.as_ptr(),
            masks.as_ptr(),
            stream.as_ref().as_ptr(),
        )
    })
}

/// Root Mean Square normalization (RMS norm).
///
/// The normalization is with respect to the last axis of the input `x`.
///
/// # Params
///
/// - x: input array
/// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional with the same size as the last axis of `x`.
/// - eps: A small additive constant for numerical stability
/// - stream: stream or device to evaluate on
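///
/// # Example
///
/// A minimal sketch following this module's tests; the shape, weight, and `eps`
/// value are illustrative only.
///
/// ```rust,ignore
/// use mlx_rs::fast::rms_norm;
/// use mlx_rs::Array;
///
/// // Normalize over the last axis (size 16) with a unit weight vector.
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// let weight = Array::ones::<f32>(&[16]).unwrap();
/// let y = rms_norm(x, weight, 1e-5).unwrap();
/// ```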
#[generate_macro(customize(root = "$crate::fast"))]
#[default_device]
pub fn rms_norm_device(
    x: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    eps: f32,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rms_norm(
            res,
            x.as_ref().as_ptr(),
            weight.as_ref().as_ptr(),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Layer normalization.
///
/// The normalization is with respect to the last axis of the input `x`.
///
/// # Params
///
/// - x: input array
/// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional
///   with the same size as the last axis of `x`. If not given, no scaling will occur.
/// - bias: An additive offset to be added to the result. The `bias` should be one-dimensional
///   with the same size as the last axis of `x`. If not given, no offset will be applied.
/// - eps: A small additive constant for numerical stability
/// - stream: stream or device to evaluate on
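///
/// # Example
///
/// A minimal sketch following this module's tests; the shape, weight, bias, and
/// `eps` value are illustrative only.
///
/// ```rust,ignore
/// use mlx_rs::fast::layer_norm;
/// use mlx_rs::Array;
///
/// // Affine layer norm over the last axis (size 16).
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// let weight = Array::ones::<f32>(&[16]).unwrap();
/// let bias = Array::zeros::<f32>(&[16]).unwrap();
/// let y = layer_norm(x, &weight, &bias, 1e-5).unwrap();
/// ```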
#[generate_macro(customize(root = "$crate::fast"))]
#[default_device]
pub fn layer_norm_device<'a>(
    #[named] x: impl AsRef<Array>,
    #[optional] weight: impl Into<Option<&'a Array>>,
    #[optional] bias: impl Into<Option<&'a Array>>,
    #[named] eps: f32,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_layer_norm(
            res,
            x.as_ref().as_ptr(),
            // Empty arrays signal that `weight`/`bias` were not provided.
            weight
                .into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            bias.into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ops::indexing::{ArrayIndexOp, IndexOp},
        random::normal,
    };
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_rope() {
        crate::random::seed(71).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let result = rope(a, 8, false, 10000., 1.0, 0, None).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.456_253_77,
            abs <= 0.009_125_075
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            116.800_964,
            abs <= 2.336_019_3
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let result = rms_norm(a, weight, 1e-5).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    #[test]
    pub fn test_layer_norm_affine() {
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let bias = Array::zeros::<f32>(&[16]).unwrap();
        let result = layer_norm(a, &weight, &bias, 1e-5).unwrap();
        let result = result.index((ArrayIndexOp::Ellipsis, 0));
        assert_eq!(result.shape(), [2, 8]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_fast_sdpa() {
        // This test just makes sure that `scaled_dot_product_attention` is callable
        // in the various cases, based on the Python test `test_fast_sdpa`.

        let Dk = 64;
        let scale = 1.0 / (Dk as f32).sqrt();
        for seq_len in [63, 129, 400] {
            for dtype in [crate::Dtype::Float32, crate::Dtype::Float16] {
                let B = 2;
                let H = 24;
                let q = normal::<f32>(&[B, H, seq_len, Dk], None, None, None)
                    .unwrap()
                    .as_dtype(dtype)
                    .unwrap();
                let k = normal::<f32>(&[B, H, seq_len, Dk], None, None, None)
                    .unwrap()
                    .as_dtype(dtype)
                    .unwrap();
                let v = normal::<f32>(&[B, H, seq_len, Dk], None, None, None)
                    .unwrap()
                    .as_dtype(dtype)
                    .unwrap();

                let result = scaled_dot_product_attention(q, k, v, scale, None).unwrap();
                assert_eq!(result.shape(), [B, H, seq_len, Dk]);
                assert_eq!(result.dtype(), dtype);
            }
        }
    }
}
327}