mlx_rs/fast.rs

//! Fast implementations of commonly used multi-op functions.

use std::ffi::CString;

use crate::error::{Exception, Result};
use crate::utils::guard::Guarded;
use crate::utils::VectorArray;
use crate::{Array, Stream};
use mlx_internal_macros::default_device;

/// Optimized implementation of `NN.RoPE`.
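///
/// # Example
///
/// A minimal usage sketch, assuming this module is exposed as `mlx_rs::fast` and that
/// `rope` is the default-stream variant generated by `#[default_device]`:
///
/// ```rust,ignore
/// use mlx_rs::{fast::rope, Array};
///
/// let x = Array::zeros::<f32>(&[2, 8, 16]).unwrap();
/// // Rotate the first 8 feature dimensions with base 10000, scale 1.0, and no offset.
/// let y = rope(x, 8, false, 10000.0, 1.0, 0, None).unwrap();
/// assert_eq!(y.shape(), [2, 8, 16]);
/// ```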
#[allow(clippy::too_many_arguments)]
#[default_device]
pub fn rope_device<'a>(
    array: impl AsRef<Array>,
    dimensions: i32,
    traditional: bool,
    base: impl Into<Option<f32>>,
    scale: f32,
    offset: i32,
    freqs: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let base = base.into();
    let base = mlx_sys::mlx_optional_float {
        value: base.unwrap_or(0.0),
        has_value: base.is_some(),
    };
    let freqs = freqs.into();
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rope(
            res,
            array.as_ref().as_ptr(),
            dimensions,
            traditional,
            base,
            scale,
            offset,
            freqs
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            stream.as_ref().as_ptr(),
        )
    })
}

/// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`
///
/// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
///
/// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations.
///
/// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).
///
/// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array.
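///
/// # Example
///
/// A minimal usage sketch, assuming this module is exposed as `mlx_rs::fast` and that
/// `scaled_dot_product_attention` is the default-stream variant generated by
/// `#[default_device]`:
///
/// ```rust,ignore
/// use mlx_rs::{fast::scaled_dot_product_attention, Array};
///
/// // Inputs follow the usual (batch, heads, sequence, head_dim) layout.
/// let q = Array::zeros::<f32>(&[1, 4, 8, 16]).unwrap();
/// let k = Array::zeros::<f32>(&[1, 4, 8, 16]).unwrap();
/// let v = Array::zeros::<f32>(&[1, 4, 8, 16]).unwrap();
/// let scale = 1.0 / (16.0f32).sqrt();
/// let o = scaled_dot_product_attention(q, k, v, scale, None).unwrap();
/// assert_eq!(o.shape(), [1, 4, 8, 16]);
/// ```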
#[default_device]
pub fn scaled_dot_product_attention_device<'a>(
    queries: impl AsRef<Array>,
    keys: impl AsRef<Array>,
    values: impl AsRef<Array>,
    scale: f32,
    mask: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let mask_mode = CString::new("").map_err(|e| Exception::custom(format!("{}", e)))?;
    let masks = match mask.into() {
        Some(m) => VectorArray::try_from_iter([m].iter())?,
        None => unsafe { VectorArray::from_ptr(mlx_sys::mlx_vector_array_new()) },
    };

    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_scaled_dot_product_attention(
            res,
            queries.as_ref().as_ptr(),
            keys.as_ref().as_ptr(),
            values.as_ref().as_ptr(),
            scale,
            mask_mode.as_ptr(),
            masks.as_ptr(),
            stream.as_ref().as_ptr(),
        )
    })
}

/// Root Mean Square normalization (RMS norm).
///
/// The normalization is with respect to the last axis of the input `x`.
///
/// # Params
///
/// - x: input array
/// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional with the same size as the last axis of `x`.
/// - eps: A small additive constant for numerical stability
/// - stream: stream or device to evaluate on
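///
/// # Example
///
/// A minimal usage sketch, assuming this module is exposed as `mlx_rs::fast` and that
/// `rms_norm` is the default-stream variant generated by `#[default_device]`:
///
/// ```rust,ignore
/// use mlx_rs::{fast::rms_norm, Array};
///
/// let x = Array::zeros::<f32>(&[2, 8, 16]).unwrap();
/// // `weight` is one-dimensional with the same size as the last axis of `x`.
/// let weight = Array::ones::<f32>(&[16]).unwrap();
/// let y = rms_norm(x, weight, 1e-5).unwrap();
/// assert_eq!(y.shape(), [2, 8, 16]);
/// ```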
#[default_device]
pub fn rms_norm_device(
    x: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    eps: f32,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rms_norm(
            res,
            x.as_ref().as_ptr(),
            weight.as_ref().as_ptr(),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Layer normalization.
///
/// The normalization is with respect to the last axis of the input `x`.
///
/// # Params
///
/// - x: input array
/// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional
///   with the same size as the last axis of `x`. If not given, no scaling will occur.
/// - bias: An additive offset to be added to the result. The `bias` should be one-dimensional
///   with the same size as the last axis of `x`. If not given, no offset will occur.
/// - eps: A small additive constant for numerical stability
/// - stream: stream or device to evaluate on
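///
/// # Example
///
/// A minimal usage sketch, assuming this module is exposed as `mlx_rs::fast` and that
/// `layer_norm` is the default-stream variant generated by `#[default_device]`:
///
/// ```rust,ignore
/// use mlx_rs::{fast::layer_norm, Array};
///
/// let x = Array::zeros::<f32>(&[2, 8, 16]).unwrap();
/// let weight = Array::ones::<f32>(&[16]).unwrap();
/// let bias = Array::zeros::<f32>(&[16]).unwrap();
/// // `weight` and `bias` are optional; pass `None` to skip scaling or the offset.
/// let y = layer_norm(x, &weight, &bias, 1e-5).unwrap();
/// assert_eq!(y.shape(), [2, 8, 16]);
/// ```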
#[default_device]
pub fn layer_norm_device<'a>(
    x: impl AsRef<Array>,
    weight: impl Into<Option<&'a Array>>,
    bias: impl Into<Option<&'a Array>>,
    eps: f32,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_layer_norm(
            res,
            x.as_ref().as_ptr(),
            weight
                .into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            bias.into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::indexing::{ArrayIndexOp, IndexOp};
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_rope() {
        crate::random::seed(71).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let result = rope(a, 8, false, 10000., 1.0, 0, None).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.456_253_77,
            abs <= 0.009_125_075
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            116.800_964,
            abs <= 2.336_019_3
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let result = rms_norm(a, weight, 1e-5).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    #[test]
    pub fn test_layer_norm_affine() {
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let bias = Array::zeros::<f32>(&[16]).unwrap();
        let result = layer_norm(a, &weight, &bias, 1e-5).unwrap();
        let result = result.index((ArrayIndexOp::Ellipsis, 0));
        assert_eq!(result.shape(), [2, 8]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }
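
    // A minimal smoke-test sketch for the fused attention path (query sequence length 1).
    // It assumes `scaled_dot_product_attention` is the default-stream variant generated by
    // `#[default_device]`, mirroring `rope`/`rms_norm`/`layer_norm` above. With all-zero
    // queries and keys the attention weights are uniform, so the output equals the mean of
    // `values` along the sequence axis, i.e. all ones here.
    #[test]
    fn test_scaled_dot_product_attention_uniform() {
        let queries = Array::zeros::<f32>(&[1, 4, 1, 16]).unwrap();
        let keys = Array::zeros::<f32>(&[1, 4, 8, 16]).unwrap();
        let values = Array::ones::<f32>(&[1, 4, 8, 16]).unwrap();
        let scale = 1.0 / (16.0f32).sqrt();

        let result = scaled_dot_product_attention(queries, keys, values, scale, None).unwrap();
        assert_eq!(result.shape(), [1, 4, 1, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(result.mean(None).unwrap().item::<f32>(), 1.0, abs <= 1e-5);
    }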
}