1use std::{cell::RefCell, collections::HashMap};
2
3use crate::{
4 array,
5 error::Exception,
6 module::{Module, Param},
7 ops::{
8 arange, concatenate_axis, exp,
9 indexing::{NewAxis, TryIndexOp},
10 log,
11 },
12 Array, Dtype,
13};
14use mlx_internal_macros::{generate_builder, Buildable, Builder};
15use mlx_macros::ModuleParameters;
16
/// Shorter alias for [`RotaryPositionalEncoding`].
pub type Rope = RotaryPositionalEncoding;

/// Shorter alias for [`RotaryPositionalEncodingBuilder`].
pub type RopeBuilder = RotaryPositionalEncodingBuilder;
22
generate_builder! {
    /// Rotary positional encoding (RoPE).
    ///
    /// The forward pass folds all leading axes of the input into one
    /// (`[-1, dim(-2), dim(-1)]`), calls `crate::fast::rope` with the fields below and the
    /// per-call offset, then reshapes the result back to the input's original shape.
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RotaryPositionalEncoding {
        /// Forwarded as the `dimensions` argument of `crate::fast::rope`.
        /// NOTE(review): presumably the number of feature channels that get rotated — confirm.
        pub dimensions: i32,

        /// Forwarded as `traditional` to `crate::fast::rope`.
        /// Defaults to [`RotaryPositionalEncoding::DEFAULT_TRADITIONAL`] (`false`).
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_TRADITIONAL)]
        pub traditional: bool,

        /// Forwarded as `base` to `crate::fast::rope`.
        /// Defaults to [`RotaryPositionalEncoding::DEFAULT_BASE`] (`10_000.0`).
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_BASE)]
        pub base: f32,

        /// Forwarded as `scale` to `crate::fast::rope`.
        /// Defaults to [`RotaryPositionalEncoding::DEFAULT_SCALE`] (`1.0`).
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_SCALE)]
        pub scale: f32,
    }
}
56
57impl RotaryPositionalEncoding {
58 pub const DEFAULT_TRADITIONAL: bool = false;
60
61 pub const DEFAULT_BASE: f32 = 10_000.0;
63
64 pub const DEFAULT_SCALE: f32 = 1.0;
66}
67
generate_builder! {
    /// Input to the [`RotaryPositionalEncoding`] forward pass.
    ///
    /// Besides the builder, it can be created via `From` from `&Array`, `(&Array,)` or
    /// `(&Array, i32)`; the shorter forms use [`RopeInput::DEFAULT_OFFSET`].
    #[derive(Debug, Buildable, Clone)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RopeInput<'a> {
        /// Array to encode. The forward pass reads its last two axes and folds every
        /// leading axis into one before calling `crate::fast::rope`.
        pub x: &'a Array,

        /// Forwarded as `offset` to `crate::fast::rope`.
        /// Defaults to [`RopeInput::DEFAULT_OFFSET`] (`0`).
        #[builder(optional, default = RopeInput::DEFAULT_OFFSET)]
        pub offset: i32,
    }
}
82
83impl RopeInput<'_> {
84 pub const DEFAULT_OFFSET: i32 = 0;
86}
87
88impl<'a> From<&'a Array> for RopeInput<'a> {
89 fn from(x: &'a Array) -> Self {
90 RopeInput {
91 x,
92 offset: Self::DEFAULT_OFFSET,
93 }
94 }
95}
96
97impl<'a> From<(&'a Array,)> for RopeInput<'a> {
98 fn from((x,): (&'a Array,)) -> Self {
99 RopeInput {
100 x,
101 offset: Self::DEFAULT_OFFSET,
102 }
103 }
104}
105
106impl<'a> From<(&'a Array, i32)> for RopeInput<'a> {
107 fn from((x, offset): (&'a Array, i32)) -> Self {
108 RopeInput { x, offset }
109 }
110}
111
112impl<'a, Input> Module<Input> for RotaryPositionalEncoding
113where
114 Input: Into<RopeInput<'a>>,
115{
116 type Error = Exception;
117
118 type Output = Array;
119
120 fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
121 let RopeInput { x, offset } = input.into();
122 let shape = x.shape();
123 let x = x.reshape(&[-1, x.dim(-2), x.dim(-1)])?;
124 let x = crate::fast::rope(
125 x,
126 self.dimensions,
127 self.traditional,
128 self.base,
129 self.scale,
130 offset,
131 None,
132 )?;
133 x.reshape(shape)
134 }
135
136 fn training_mode(&mut self, _mode: bool) {}
137}
138
/// Shorter alias for [`SinusoidalPositionalEncoding`].
pub type Sinpe = SinusoidalPositionalEncoding;

/// Shorter alias for [`SinusoidalPositionalEncodingBuilder`].
pub type SinpeBuilder = SinusoidalPositionalEncodingBuilder;
144
/// Sinusoidal positional encoding.
///
/// Each input value is multiplied by a fixed set of frequencies (`sigmas`), and the sine and
/// cosine of those products are concatenated along a new trailing axis. Instances are
/// produced from [`SinusoidalPositionalEncodingBuilder`] (its `build_with` is `build_sinpe`).
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct SinusoidalPositionalEncoding {
    // Frequencies multiplied against the input in `forward`, computed once in `build_sinpe`.
    // NOTE(review): registered with `#[param]` — confirm whether these are meant to be
    // trainable rather than a fixed buffer.
    #[param]
    sigmas: Param<Array>,

    /// Multiplier applied to the concatenated features (skipped when exactly `1.0`).
    pub scale: f32,

    /// When `true` the cosine features precede the sine features; otherwise sine comes first.
    pub cosine_first: bool,
}
162
163impl Sinpe {
164 pub const DEFAULT_COSINE_FIRST: bool = false;
166
167 pub const DEFAULT_MIN_FREQUENCY: f32 = 0.0001;
169
170 pub const DEFAULT_MAX_FREQUENCY: f32 = 1.0;
172
173 pub const DEFAULT_FULL_TURNS: bool = false;
175}
176
/// Builder for [`SinusoidalPositionalEncoding`]; building runs `build_sinpe`, which can fail
/// with an [`Exception`] from the underlying array operations.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_sinpe,
    err = Exception,
)]
pub struct SinusoidalPositionalEncodingBuilder {
    /// Requested feature size. `dimensions / 2` frequencies are generated and each yields a
    /// sine and a cosine channel, so the output's new last axis has `2 * (dimensions / 2)`
    /// entries (an odd value therefore loses one channel).
    dimensions: i32,

    /// Lowest frequency of the geometric range. Defaults to
    /// `SinusoidalPositionalEncoding::DEFAULT_MIN_FREQUENCY` (`0.0001`).
    #[builder(optional, default = Sinpe::DEFAULT_MIN_FREQUENCY)]
    min_frequency: f32,

    /// Highest frequency of the geometric range. Defaults to
    /// `SinusoidalPositionalEncoding::DEFAULT_MAX_FREQUENCY` (`1.0`).
    #[builder(optional, default = Sinpe::DEFAULT_MAX_FREQUENCY)]
    max_frequency: f32,

    /// Output multiplier; `None` means `sqrt(2 / dimensions)`.
    #[builder(optional, default = None)]
    scale: Option<f32>,

    /// Put cosine features before sine features. Defaults to `false`.
    #[builder(optional, default = Sinpe::DEFAULT_COSINE_FIRST)]
    cosine_first: bool,

    /// Multiply every frequency by `2π`. Defaults to `false`.
    #[builder(optional, default = Sinpe::DEFAULT_FULL_TURNS)]
    full_turns: bool,
}
202
203fn build_sinpe(builder: SinpeBuilder) -> Result<SinusoidalPositionalEncoding, Exception> {
204 let SinpeBuilder {
205 dimensions,
206 min_frequency,
207 max_frequency,
208 scale,
209 cosine_first,
210 full_turns,
211 } = builder;
212
213 let half_dim = dimensions / 2;
214 let one_zero = array!(1.0)
215 .subtract(Array::from_iter(0..half_dim, &[half_dim]).divide(array!(half_dim - 1))?)?;
216 let min_frequency = log(array!(min_frequency))?;
217 let max_frequency = log(array!(max_frequency))?;
218
219 let mut sigmas = exp(&one_zero * (&max_frequency - &min_frequency) + &min_frequency)?;
221 if full_turns {
222 sigmas *= array!(2.0 * std::f32::consts::PI);
224 }
225
226 let scale = scale.unwrap_or_else(|| (2.0 / dimensions as f32).sqrt());
227
228 Ok(SinusoidalPositionalEncoding {
229 sigmas: Param::new(sigmas),
230 scale,
231 cosine_first,
232 })
233}
234
235impl Module<&Array> for Sinpe {
236 type Error = Exception;
237 type Output = Array;
238
239 fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> {
240 let mut y = x
241 .expand_dims_axes(&[-1])
242 .and_then(|x| x.multiply(&self.sigmas))?;
243
244 let cosy = y.cos()?;
245 let siny = y.sin()?;
246
247 if self.cosine_first {
248 y = concatenate_axis(&[cosy, siny], -1)?;
249 } else {
250 y = concatenate_axis(&[siny, cosy], -1)?;
251 }
252
253 if self.scale != 1.0 {
254 y *= self.scale;
256 }
257
258 Ok(y)
259 }
260
261 fn training_mode(&mut self, _mode: bool) {}
262}
263
/// Cache key for `Alibi::matrix`: every value that determines a generated bias matrix.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct AlibiKey {
    /// Exclusive end of the query positions (`arange(offset, q_seq_len)` in `Alibi::matrix`).
    q_seq_len: i32,
    /// Number of key positions (`arange(0, k_seq_len)` in `Alibi::matrix`).
    k_seq_len: i32,
    /// Number of attention heads, i.e. the number of slopes generated.
    num_heads: i32,
    /// First query position.
    offset: i32,
    /// Dtype the bias matrix is cast to.
    dtype: Dtype,
}

thread_local! {
    /// Per-thread memo of ALiBi bias matrices. Nothing in this file evicts entries, so one
    /// matrix is retained for every distinct key seen on the thread.
    static ALIBI_CACHE: RefCell<HashMap<AlibiKey, Array>> = RefCell::new(HashMap::new());
}

/// Attention with Linear Biases (ALiBi): adds a per-head bias proportional to the negated
/// query/key distance to attention scores. The module itself carries no state; bias
/// matrices are memoised in a thread-local cache.
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Alibi;
281
282impl Alibi {
283 fn slope(num_heads: i32) -> Result<Array, Exception> {
284 let x = 2.0_f32.powi(8).powf(1.0 / num_heads as f32);
285 array!(x)
286 .power(&arange::<_, f32>(1, num_heads + 1, None)?)?
287 .expand_dims_axes(&[-1, -2])
288 }
289
290 fn matrix(key: AlibiKey) -> Result<Array, Exception> {
291 if let Some(value) = ALIBI_CACHE.with(|cache| cache.borrow().get(&key).cloned()) {
292 return Ok(value);
293 }
294
295 let x1 = arange::<_, f32>(key.offset, key.q_seq_len, None)?;
296 let x2 = arange::<_, f32>(0, key.k_seq_len, None)?;
297 let distance_matrix = x1
298 .try_index((.., NewAxis))?
299 .subtract(x2.try_index((NewAxis, ..))?)?
300 .expand_dims_axes(&[0, 1])?
301 .abs()?
302 .negative()?;
303
304 let slope = Self::slope(key.num_heads)?;
305 let mask = distance_matrix.multiply(&slope)?.as_dtype(key.dtype)?;
306
307 ALIBI_CACHE.with(|cache| {
308 cache.borrow_mut().insert(key, mask.clone());
309 });
310
311 Ok(mask)
312 }
313}
314
generate_builder! {
    /// Input to the [`Alibi`] forward pass.
    ///
    /// Besides the builder, it can be created via `From` from `&Array`, `(&Array,)`,
    /// `(&Array, i32)`, `(&Array, i32, &Array)` or `(&Array, i32, Option<&Array>)`.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct AlibiInput<'a> {
        /// Attention scores the bias is added to. The forward pass reads `dim(1)` as the head
        /// count and the last two axes as query/key lengths — presumably a
        /// `[batch, heads, q_len, k_len]` layout; confirm with callers.
        pub attention_scores: &'a Array,

        /// Position of the first query row: the bias uses query positions
        /// `offset..offset + q_len`. Defaults to [`AlibiInput::DEFAULT_OFFSET`] (`0`).
        #[builder(optional, default = AlibiInput::DEFAULT_OFFSET)]
        pub offset: i32,

        /// Optional extra mask, added to the ALiBi bias before the bias is added to the scores.
        #[builder(optional, default = None)]
        pub mask: Option<&'a Array>,
    }
}
333
334impl AlibiInput<'_> {
335 pub const DEFAULT_OFFSET: i32 = 0;
337}
338
339impl<'a> From<&'a Array> for AlibiInput<'a> {
340 fn from(attention_scores: &'a Array) -> Self {
341 AlibiInput {
342 attention_scores,
343 offset: Self::DEFAULT_OFFSET,
344 mask: None,
345 }
346 }
347}
348
349impl<'a> From<(&'a Array,)> for AlibiInput<'a> {
350 fn from((attention_scores,): (&'a Array,)) -> Self {
351 AlibiInput {
352 attention_scores,
353 offset: Self::DEFAULT_OFFSET,
354 mask: None,
355 }
356 }
357}
358
359impl<'a> From<(&'a Array, i32)> for AlibiInput<'a> {
360 fn from((attention_scores, offset): (&'a Array, i32)) -> Self {
361 AlibiInput {
362 attention_scores,
363 offset,
364 mask: None,
365 }
366 }
367}
368
369impl<'a> From<(&'a Array, i32, &'a Array)> for AlibiInput<'a> {
370 fn from((attention_scores, offset, mask): (&'a Array, i32, &'a Array)) -> Self {
371 AlibiInput {
372 attention_scores,
373 offset,
374 mask: Some(mask),
375 }
376 }
377}
378
379impl<'a> From<(&'a Array, i32, Option<&'a Array>)> for AlibiInput<'a> {
380 fn from((attention_scores, offset, mask): (&'a Array, i32, Option<&'a Array>)) -> Self {
381 AlibiInput {
382 attention_scores,
383 offset,
384 mask,
385 }
386 }
387}
388
389impl<'a, Input> Module<Input> for Alibi
390where
391 Input: Into<AlibiInput<'a>>,
392{
393 type Output = Array;
394 type Error = Exception;
395
396 fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
397 let AlibiInput {
398 attention_scores,
399 offset,
400 mask,
401 } = input.into();
402
403 let key = AlibiKey {
404 q_seq_len: attention_scores.dim(-2) + offset,
405 k_seq_len: attention_scores.dim(-1),
406 num_heads: attention_scores.dim(1),
407 offset,
408 dtype: attention_scores.dtype(),
409 };
410
411 let mut alibi_mask = Self::matrix(key)?;
412 if let Some(mask) = mask {
413 alibi_mask = alibi_mask.add(mask)?;
414 }
415
416 attention_scores.add(alibi_mask)
417 }
418
419 fn training_mode(&mut self, _mode: bool) {}
420}
421
// The reference float literals carry more digits than `f32` can represent, hence the
// clippy allowance.
#[allow(clippy::excessive_precision)]
#[cfg(test)]
mod tests {
    use crate::{module::Module, nn::AlibiInput, random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use crate::nn::Rope;

    /// `Rope::new(8)` on a seeded `[2, 8, 16]` uniform input keeps the input's shape and
    /// dtype, and its mean/sum match stored reference values within 2%.
    ///
    /// NOTE(review): the reference values presumably come from the Python MLX
    /// implementation with the same seed — confirm before regenerating them.
    #[test]
    fn test_rope() {
        // Seeding fixes the RNG stream, so the order of random calls below matters.
        crate::random::seed(71).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        // Sanity-check the generated input before exercising the module.
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.5082664489746094,
            abs <= 0.010165328979492188
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            130.1162109375,
            abs <= 2.60232421875
        );

        let mut rope = Rope::new(8);
        let result = rope.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.4562537670135498,
            abs <= 0.009125075340270997
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            116.80096435546875,
            abs <= 2.3360192871093752
        );
    }

    /// `Sinpe::new(8)` appends a trailing axis of size 8 (4 frequencies x sin/cos) to a
    /// seeded `[2, 8, 16]` input; mean/sum match stored reference values within 2%.
    #[test]
    fn test_sinpe() {
        // Seeding fixes the RNG stream, so the order of random calls below matters.
        crate::random::seed(226).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.5026599168777466,
            abs <= 0.010053198337554931
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            128.68093872070312,
            abs <= 2.5736187744140624
        );

        let mut sinpe = crate::nn::Sinpe::new(8).unwrap();
        let result = sinpe.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.2705308198928833,
            abs <= 0.005410616397857666
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            554.047119140625,
            abs <= 11.0809423828125
        );
    }

    /// `Alibi` preserves the scores' shape, and the bias follows the scores' dtype
    /// (float32 and float16). Only shape/dtype are checked here, not bias values.
    #[test]
    fn test_alibi() {
        let mut alibi = crate::nn::Alibi;
        // Axis 1 (8) is read as the head count by the forward pass.
        let shape = [1, 8, 20, 20];
        let x = uniform::<_, f32>(0, 1, &shape, None).unwrap();
        let input = AlibiInput::from(&x);
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.shape(), shape);
        assert_eq!(y.dtype(), Dtype::Float32);

        // A different dtype produces a different cache key, so a float16 bias is built.
        let x2 = x.as_dtype(Dtype::Float16).unwrap();
        let input = AlibiInput::from(&x2);
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.dtype(), Dtype::Float16);
    }
}