1use crate::error::Result;
4use crate::utils::guard::Guarded;
5use crate::{Array, Stream, StreamOrDevice};
6use mlx_internal_macros::default_device;
7
8#[allow(clippy::too_many_arguments)]
10#[default_device]
11pub fn rope_device<'a>(
12 array: impl AsRef<Array>,
13 dimensions: i32,
14 traditional: bool,
15 base: impl Into<Option<f32>>,
16 scale: f32,
17 offset: i32,
18 freqs: impl Into<Option<&'a Array>>,
19 stream: impl AsRef<Stream>,
20) -> Result<Array> {
21 let base = base.into();
22 let base = mlx_sys::mlx_optional_float {
23 value: base.unwrap_or(0.0),
24 has_value: base.is_some(),
25 };
26 let freqs = freqs.into();
27 Array::try_from_op(|res| unsafe {
28 mlx_sys::mlx_fast_rope(
29 res,
30 array.as_ref().as_ptr(),
31 dimensions,
32 traditional,
33 base,
34 scale,
35 offset,
36 freqs
37 .map(|a| a.as_ptr())
38 .unwrap_or(mlx_sys::mlx_array_new()),
39 stream.as_ref().as_ptr(),
40 )
41 })
42}
43
44#[default_device]
54pub fn scaled_dot_product_attention_device<'a>(
55 queries: impl AsRef<Array>,
56 keys: impl AsRef<Array>,
57 values: impl AsRef<Array>,
58 scale: f32,
59 mask: impl Into<Option<&'a Array>>,
60 memory_efficient_threshold: impl Into<Option<i32>>,
61 stream: impl AsRef<Stream>,
62) -> Result<Array> {
63 let memory_efficient_threshold = memory_efficient_threshold.into();
64 let memory_efficient_threshold = mlx_sys::mlx_optional_int {
65 value: memory_efficient_threshold.unwrap_or(0),
66 has_value: memory_efficient_threshold.is_some(),
67 };
68
69 Array::try_from_op(|res| unsafe {
70 mlx_sys::mlx_fast_scaled_dot_product_attention(
71 res,
72 queries.as_ref().as_ptr(),
73 keys.as_ref().as_ptr(),
74 values.as_ref().as_ptr(),
75 scale,
76 mask.into()
77 .map(|a| a.as_ptr())
78 .unwrap_or(mlx_sys::mlx_array_new()),
79 memory_efficient_threshold,
80 stream.as_ref().as_ptr(),
81 )
82 })
83}
84
85#[default_device]
96pub fn rms_norm_device(
97 x: impl AsRef<Array>,
98 weight: impl AsRef<Array>,
99 eps: f32,
100 stream: impl AsRef<Stream>,
101) -> Result<Array> {
102 Array::try_from_op(|res| unsafe {
103 mlx_sys::mlx_fast_rms_norm(
104 res,
105 x.as_ref().as_ptr(),
106 weight.as_ref().as_ptr(),
107 eps,
108 stream.as_ref().as_ptr(),
109 )
110 })
111}
112
113#[default_device]
127pub fn layer_norm_device<'a>(
128 x: impl AsRef<Array>,
129 weight: impl Into<Option<&'a Array>>,
130 bias: impl Into<Option<&'a Array>>,
131 eps: f32,
132 stream: impl AsRef<Stream>,
133) -> Result<Array> {
134 Array::try_from_op(|res| unsafe {
135 mlx_sys::mlx_fast_layer_norm(
136 res,
137 x.as_ref().as_ptr(),
138 weight
139 .into()
140 .map(|a| a.as_ptr())
141 .unwrap_or(mlx_sys::mlx_array_new()),
142 bias.into()
143 .map(|a| a.as_ptr())
144 .unwrap_or(mlx_sys::mlx_array_new()),
145 eps,
146 stream.as_ref().as_ptr(),
147 )
148 })
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::indexing::{ArrayIndexOp, IndexOp};
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_rope() {
        crate::random::seed(71).unwrap();
        let input = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(input.shape(), [2, 8, 16]);
        assert_eq!(input.dtype(), crate::Dtype::Float32);

        let output = rope(input, 8, false, 10000., 1.0, 0, None).unwrap();
        assert_eq!(output.shape(), [2, 8, 16]);
        assert_eq!(output.dtype(), crate::Dtype::Float32);

        // Tolerances are 2% of the expected value.
        let mean = output.mean(None, None).unwrap().item::<f32>();
        assert_float_eq!(mean, 0.456_253_77, abs <= 0.009_125_075);
        let total = output.sum(None, None).unwrap().item::<f32>();
        assert_float_eq!(total, 116.800_964, abs <= 2.336_019_3);
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let input = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(input.shape(), [2, 8, 16]);
        assert_eq!(input.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let output = rms_norm(input, weight, 1e-5).unwrap();
        assert_eq!(output.shape(), [2, 8, 16]);
        assert_eq!(output.dtype(), crate::Dtype::Float32);

        // Tolerances are 2% of the expected value.
        let mean = output.mean(None, None).unwrap().item::<f32>();
        assert_float_eq!(mean, 0.872_938_75, abs <= 0.017_458_774);
        let total = output.sum(None, None).unwrap().item::<f32>();
        assert_float_eq!(total, 223.472_32, abs <= 4.469_446);
    }

    #[test]
    pub fn test_layer_norm_affine() {
        crate::random::seed(635).unwrap();
        let input = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(input.shape(), [2, 8, 16]);
        assert_eq!(input.dtype(), crate::Dtype::Float32);

        let weight = Array::ones::<f32>(&[16]).unwrap();
        let bias = Array::zeros::<f32>(&[16]).unwrap();
        let normalized = layer_norm(input, &weight, &bias, 1e-5).unwrap();
        // Only the first feature of each row is checked.
        let first_feature = normalized.index((ArrayIndexOp::Ellipsis, 0));
        assert_eq!(first_feature.shape(), [2, 8]);
        assert_eq!(first_feature.dtype(), crate::Dtype::Float32);

        // Tolerances are 2% of the expected value.
        let mean = first_feature.mean(None, None).unwrap().item::<f32>();
        assert_float_eq!(mean, 0.290_990_38, abs <= 0.005_819_807_8);
        let total = first_feature.sum(None, None).unwrap().item::<f32>();
        assert_float_eq!(total, 4.655_846, abs <= 0.093_116_924);
    }
}