mlx_rs/fast.rs

//! Fast implementations of commonly used multi-op functions.

use crate::error::Result;
use crate::utils::guard::Guarded;
use crate::{Array, Stream, StreamOrDevice};
use mlx_internal_macros::default_device;

/// Optimized implementation of `NN.RoPE`.
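///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest).
/// It calls the `rope` convenience wrapper that `#[default_device]` generates
/// from `rope_device`, with the items assumed to be in scope:
///
/// ```rust,ignore
/// // (batch, sequence, feature) input
/// let x = Array::ones::<f32>(&[2, 8, 16])?;
/// // rotate the first 8 feature dimensions: non-traditional RoPE, base 10000,
/// // position scale 1.0, offset 0, and no custom frequencies
/// let y = rope(&x, 8, false, 10000.0, 1.0, 0, None)?;
/// ```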
#[allow(clippy::too_many_arguments)]
#[default_device]
pub fn rope_device<'a>(
    array: impl AsRef<Array>,
    dimensions: i32,
    traditional: bool,
    base: impl Into<Option<f32>>,
    scale: f32,
    offset: i32,
    freqs: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let base = base.into();
    let base = mlx_sys::mlx_optional_float {
        value: base.unwrap_or(0.0),
        has_value: base.is_some(),
    };
    let freqs = freqs.into();
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rope(
            res,
            array.as_ref().as_ptr(),
            dimensions,
            traditional,
            base,
            scale,
            offset,
            freqs
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            stream.as_ref().as_ptr(),
        )
    })
}

/// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`
///
/// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
///
/// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations.
///
/// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).
///
/// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array.
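///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest).
/// It assumes `#[default_device]` generates a `scaled_dot_product_attention`
/// wrapper from `scaled_dot_product_attention_device`, following the same naming
/// convention as the other functions in this module. Shapes follow the
/// `(batch, heads, sequence, head_dim)` layout:
///
/// ```rust,ignore
/// let q = Array::ones::<f32>(&[1, 8, 32, 64])?;
/// let k = Array::ones::<f32>(&[1, 8, 32, 64])?;
/// let v = Array::ones::<f32>(&[1, 8, 32, 64])?;
/// let scale = 1.0 / (64.0f32).sqrt();
/// // no mask, default memory-efficient threshold
/// let o = scaled_dot_product_attention(&q, &k, &v, scale, None, None)?;
/// ```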
#[default_device]
pub fn scaled_dot_product_attention_device<'a>(
    queries: impl AsRef<Array>,
    keys: impl AsRef<Array>,
    values: impl AsRef<Array>,
    scale: f32,
    mask: impl Into<Option<&'a Array>>,
    memory_efficient_threshold: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let memory_efficient_threshold = memory_efficient_threshold.into();
    let memory_efficient_threshold = mlx_sys::mlx_optional_int {
        value: memory_efficient_threshold.unwrap_or(0),
        has_value: memory_efficient_threshold.is_some(),
    };

    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_scaled_dot_product_attention(
            res,
            queries.as_ref().as_ptr(),
            keys.as_ref().as_ptr(),
            values.as_ref().as_ptr(),
            scale,
            mask.into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            memory_efficient_threshold,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Root Mean Square normalization (RMS norm).
///
/// The normalization is with respect to the last axis of the input `x`.
///
/// # Params
///
/// - x: input array
/// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional with the same size as the last axis of `x`.
/// - eps: A small additive constant for numerical stability
/// - stream: stream or device to evaluate on
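///
/// The result is equivalent to `x * weight / sqrt(mean(x^2) + eps)`, with the
/// mean taken over the last axis.
///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest),
/// calling the `rms_norm` wrapper generated by `#[default_device]`:
///
/// ```rust,ignore
/// let x = Array::ones::<f32>(&[2, 8, 16])?;
/// // one weight per feature on the last axis
/// let weight = Array::ones::<f32>(&[16])?;
/// let y = rms_norm(&x, &weight, 1e-5)?;
/// ```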
#[default_device]
pub fn rms_norm_device(
    x: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    eps: f32,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_rms_norm(
            res,
            x.as_ref().as_ptr(),
            weight.as_ref().as_ptr(),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Layer normalization.
///
/// The normalization is with respect to the last axis of the input `x`.
///
/// # Params
///
/// - x: input array
/// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional
///   with the same size as the last axis of `x`. If not given, no scaling will occur.
/// - bias: An additive offset to be added to the result. The `bias` should be one-dimensional
///   with the same size as the last axis of `x`. If not given, no offset will occur.
/// - eps: A small additive constant for numerical stability
/// - stream: stream or device to evaluate on
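///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest),
/// calling the `layer_norm` wrapper generated by `#[default_device]` with and
/// without the optional affine parameters:
///
/// ```rust,ignore
/// let x = Array::ones::<f32>(&[2, 8, 16])?;
/// let weight = Array::ones::<f32>(&[16])?;
/// let bias = Array::zeros::<f32>(&[16])?;
/// // affine: scale and offset over the last axis
/// let y = layer_norm(&x, &weight, &bias, 1e-5)?;
/// // plain normalization: no scaling, no offset
/// let y_plain = layer_norm(&x, None, None, 1e-5)?;
/// ```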
#[default_device]
pub fn layer_norm_device<'a>(
    x: impl AsRef<Array>,
    weight: impl Into<Option<&'a Array>>,
    bias: impl Into<Option<&'a Array>>,
    eps: f32,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_fast_layer_norm(
            res,
            x.as_ref().as_ptr(),
            weight
                .into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            bias.into()
                .map(|a| a.as_ptr())
                .unwrap_or(mlx_sys::mlx_array_new()),
            eps,
            stream.as_ref().as_ptr(),
        )
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::indexing::{ArrayIndexOp, IndexOp};
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_rope() {
        crate::random::seed(71).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let result = rope(a, 8, false, 10000., 1.0, 0, None).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.456_253_77,
            abs <= 0.009_125_075
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            116.800_964,
            abs <= 2.336_019_3
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let result = rms_norm(a, weight, 1e-5).unwrap();
        assert_eq!(result.shape(), [2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    #[test]
    pub fn test_layer_norm_affine() {
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), [2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let bias = Array::zeros::<f32>(&[16]).unwrap();
        let result = layer_norm(a, &weight, &bias, 1e-5).unwrap();
        let result = result.index((ArrayIndexOp::Ellipsis, 0));
        assert_eq!(result.shape(), [2, 8]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }
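
    // A sketch of a smoke test for `scaled_dot_product_attention` (the wrapper
    // name generated by `#[default_device]` is assumed to follow the same
    // `_device`-stripping convention as `rope` above). It only checks the output
    // shape and dtype and does not assert any specific numeric values.
    #[test]
    fn test_scaled_dot_product_attention_shape() {
        crate::random::seed(7).unwrap();
        let queries = crate::random::uniform::<_, f32>(0.0, 1.0, &[1, 4, 8, 16], None).unwrap();
        let keys = crate::random::uniform::<_, f32>(0.0, 1.0, &[1, 4, 8, 16], None).unwrap();
        let values = crate::random::uniform::<_, f32>(0.0, 1.0, &[1, 4, 8, 16], None).unwrap();
        let scale = 1.0 / (16.0f32).sqrt();

        let result =
            scaled_dot_product_attention(&queries, &keys, &values, scale, None, None).unwrap();
        assert_eq!(result.shape(), [1, 4, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
    }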
}