use std::{cell::RefCell, collections::HashMap};

use crate::{
    array,
    error::Exception,
    module::{Module, Param},
    ops::indexing::NewAxis,
    ops::{arange, concatenate, exp, indexing::TryIndexOp, log},
    Array, Dtype,
};
use mlx_internal_macros::{generate_builder, Buildable, Builder};
use mlx_macros::ModuleParameters;

/// Type alias for [`RotaryPositionalEncoding`].
pub type Rope = RotaryPositionalEncoding;

/// Type alias for [`RotaryPositionalEncodingBuilder`].
pub type RopeBuilder = RotaryPositionalEncodingBuilder;

generate_builder! {
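    /// Implements the rotary positional encoding (RoPE).
    ///
    /// RoPE rotates pairs of feature dimensions by an angle that depends on the
    /// token position, so relative positions are encoded directly in the
    /// attention computation. See _RoFormer: Enhanced Transformer with Rotary
    /// Position Embedding_ (<https://arxiv.org/abs/2104.09864>).
    ///
    /// A minimal usage sketch (not taken from the upstream docs; it assumes the
    /// crate is consumed as `mlx_rs`, hence the `ignore` marker):
    ///
    /// ```rust,ignore
    /// use mlx_rs::{module::Module, nn::Rope, random::uniform};
    ///
    /// // Rotate the first 8 feature dimensions of a [batch, seq, features] input.
    /// let x = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
    /// let mut rope = Rope::new(8);
    /// // A `(array, offset)` tuple starts the rotation at a given position,
    /// // which is useful when decoding with a KV cache.
    /// let y = rope.forward((&x, 4)).unwrap();
    /// assert_eq!(y.shape(), &[2, 8, 16]);
    /// ```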
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RotaryPositionalEncoding {
        /// The feature dimensions to be rotated.
        pub dimensions: i32,

        /// If `true`, rotate consecutive pairs of dimensions (the traditional
        /// implementation); otherwise rotate pairs with stride half the feature
        /// dimensions. Default: [`RotaryPositionalEncoding::DEFAULT_TRADITIONAL`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_TRADITIONAL)]
        pub traditional: bool,

        /// Base used to compute the angular frequency of each dimension.
        /// Default: [`RotaryPositionalEncoding::DEFAULT_BASE`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_BASE)]
        pub base: f32,

        /// Scale applied to the positions. Default:
        /// [`RotaryPositionalEncoding::DEFAULT_SCALE`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_SCALE)]
        pub scale: f32,
    }
}

impl RotaryPositionalEncoding {
    /// Default value for the `traditional` field.
    pub const DEFAULT_TRADITIONAL: bool = false;

    /// Default value for the `base` field.
    pub const DEFAULT_BASE: f32 = 10_000.0;

    /// Default value for the `scale` field.
    pub const DEFAULT_SCALE: f32 = 1.0;
}

generate_builder! {
    /// Input to the forward pass of [`RotaryPositionalEncoding`].
    #[derive(Debug, Buildable, Clone)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RopeInput<'a> {
        /// The input array.
        pub x: &'a Array,

        /// Position offset at which the rotation starts. Default:
        /// [`RopeInput::DEFAULT_OFFSET`].
        #[builder(optional, default = RopeInput::DEFAULT_OFFSET)]
        pub offset: i32,
    }
}

impl RopeInput<'_> {
    /// Default value for the `offset` field.
    pub const DEFAULT_OFFSET: i32 = 0;
}

impl<'a> From<&'a Array> for RopeInput<'a> {
    fn from(x: &'a Array) -> Self {
        RopeInput {
            x,
            offset: Self::DEFAULT_OFFSET,
        }
    }
}

impl<'a> From<(&'a Array,)> for RopeInput<'a> {
    fn from((x,): (&'a Array,)) -> Self {
        RopeInput {
            x,
            offset: Self::DEFAULT_OFFSET,
        }
    }
}

impl<'a> From<(&'a Array, i32)> for RopeInput<'a> {
    fn from((x, offset): (&'a Array, i32)) -> Self {
        RopeInput { x, offset }
    }
}

impl<'a, Input> Module<Input> for RotaryPositionalEncoding
where
    Input: Into<RopeInput<'a>>,
{
    type Error = Exception;

    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let RopeInput { x, offset } = input.into();
        let shape = x.shape();
        // Flatten all leading dimensions so the fast kernel sees a
        // [batch, sequence, features] input.
        let x = x.reshape(&[-1, x.dim(-2), x.dim(-1)])?;
        let x = crate::fast::rope(
            x,
            self.dimensions,
            self.traditional,
            self.base,
            self.scale,
            offset,
            None,
        )?;
        // Restore the original shape.
        x.reshape(shape)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Type alias for [`SinusoidalPositionalEncoding`].
pub type Sinpe = SinusoidalPositionalEncoding;

/// Type alias for [`SinusoidalPositionalEncodingBuilder`].
pub type SinpeBuilder = SinusoidalPositionalEncodingBuilder;

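/// Implements sinusoidal positional encoding.
///
/// Each position is mapped to sines and cosines of the position multiplied by a
/// set of geometrically spaced frequencies, as in _Attention Is All You Need_
/// (<https://arxiv.org/abs/1706.03762>).
///
/// A minimal usage sketch (not taken from the upstream docs; it assumes the
/// crate is consumed as `mlx_rs`, hence the `ignore` marker):
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::Sinpe, random::uniform};
///
/// // Encode positions with 8 feature dimensions; the encoding is appended as a
/// // new trailing axis.
/// let x = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
/// let mut sinpe = Sinpe::new(8).unwrap();
/// let y = sinpe.forward(&x).unwrap();
/// assert_eq!(y.shape(), &[2, 8, 16, 8]);
/// ```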
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct SinusoidalPositionalEncoding {
    /// Precomputed angular frequencies, one per half feature dimension.
    #[param]
    sigmas: Param<Array>,

    /// Multiplicative scale applied to the output.
    pub scale: f32,

    /// If `true`, place the cosine components before the sine components.
    pub cosine_first: bool,
}

impl Sinpe {
    /// Default value for the `cosine_first` field.
    pub const DEFAULT_COSINE_FIRST: bool = false;

    /// Default minimum frequency used by the builder.
    pub const DEFAULT_MIN_FREQUENCY: f32 = 0.0001;

    /// Default maximum frequency used by the builder.
    pub const DEFAULT_MAX_FREQUENCY: f32 = 1.0;

    /// Default value for the builder's `full_turns` option.
    pub const DEFAULT_FULL_TURNS: bool = false;
}

/// Builder for [`SinusoidalPositionalEncoding`].
///
/// The builder is written by hand (rather than through `generate_builder!`)
/// because the frequency options are folded into the precomputed `sigmas` when
/// the module is built.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_sinpe,
    err = Exception,
)]
pub struct SinusoidalPositionalEncodingBuilder {
    /// The number of output feature dimensions.
    dimensions: i32,

    /// Minimum frequency. Default: [`Sinpe::DEFAULT_MIN_FREQUENCY`].
    #[builder(optional, default = Sinpe::DEFAULT_MIN_FREQUENCY)]
    min_frequency: f32,

    /// Maximum frequency. Default: [`Sinpe::DEFAULT_MAX_FREQUENCY`].
    #[builder(optional, default = Sinpe::DEFAULT_MAX_FREQUENCY)]
    max_frequency: f32,

    /// Output scale. Defaults to `sqrt(2 / dimensions)` when not set.
    #[builder(optional, default = None)]
    scale: Option<f32>,

    /// Whether to place the cosine components first. Default:
    /// [`Sinpe::DEFAULT_COSINE_FIRST`].
    #[builder(optional, default = Sinpe::DEFAULT_COSINE_FIRST)]
    cosine_first: bool,

    /// If `true`, multiply the frequencies by `2 * PI`. Default:
    /// [`Sinpe::DEFAULT_FULL_TURNS`].
    #[builder(optional, default = Sinpe::DEFAULT_FULL_TURNS)]
    full_turns: bool,
}

/// Builds a [`SinusoidalPositionalEncoding`] by precomputing the angular
/// frequencies from the builder options.
fn build_sinpe(builder: SinpeBuilder) -> Result<SinusoidalPositionalEncoding, Exception> {
    let SinpeBuilder {
        dimensions,
        min_frequency,
        max_frequency,
        scale,
        cosine_first,
        full_turns,
    } = builder;

    let half_dim = dimensions / 2;
    // `one_zero` goes linearly from 1 to 0 over `half_dim` steps, so the
    // frequencies below interpolate geometrically from `max_frequency` down to
    // `min_frequency`.
    let one_zero = array!(1.0)
        .subtract(Array::from_iter(0..half_dim, &[half_dim]).divide(array!(half_dim - 1))?)?;
    let min_frequency = log(array!(min_frequency))?;
    let max_frequency = log(array!(max_frequency))?;

    let mut sigmas = exp(&one_zero * (&max_frequency - &min_frequency) + &min_frequency)?;
    if full_turns {
        sigmas *= array!(2.0 * std::f32::consts::PI);
    }

    // The default scale gives each encoding vector approximately unit norm.
    let scale = scale.unwrap_or_else(|| (2.0 / dimensions as f32).sqrt());

    Ok(SinusoidalPositionalEncoding {
        sigmas: Param::new(sigmas),
        scale,
        cosine_first,
    })
}

impl Module<&Array> for Sinpe {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> {
        // Broadcast the positions against the frequencies, producing one angle
        // per (position, frequency) pair along a new trailing axis.
        let mut y = x
            .expand_dims(&[-1])
            .and_then(|x| x.multiply(&self.sigmas))?;

        let cosy = y.cos()?;
        let siny = y.sin()?;

        if self.cosine_first {
            y = concatenate(&[cosy, siny], -1)?;
        } else {
            y = concatenate(&[siny, cosy], -1)?;
        }

        if self.scale != 1.0 {
            y *= self.scale;
        }

        Ok(y)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Cache key for precomputed ALiBi bias matrices.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct AlibiKey {
    q_seq_len: i32,
    k_seq_len: i32,
    num_heads: i32,
    offset: i32,
    dtype: Dtype,
}

thread_local! {
    /// Thread-local cache of ALiBi bias matrices, keyed by shape, offset, and dtype.
    static ALIBI_CACHE: RefCell<HashMap<AlibiKey, Array>> = RefCell::new(HashMap::new());
}

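/// Attention with Linear Biases (ALiBi).
///
/// ALiBi adds a per-head bias to the attention scores that grows linearly with
/// the distance between query and key positions, instead of using position
/// embeddings. See _Train Short, Test Long_ (<https://arxiv.org/abs/2108.12409>).
///
/// A minimal usage sketch (not taken from the upstream docs; it assumes the
/// crate is consumed as `mlx_rs`, hence the `ignore` marker):
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::{Alibi, AlibiInput}, random::uniform};
///
/// // Attention scores shaped [batch, heads, q_len, k_len].
/// let scores = uniform::<_, f32>(0, 1, &[1, 8, 20, 20], None).unwrap();
/// let mut alibi = Alibi;
/// let biased = alibi.forward(AlibiInput::from(&scores)).unwrap();
/// assert_eq!(biased.shape(), &[1, 8, 20, 20]);
/// ```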
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Alibi;

impl Alibi {
    /// Per-head slopes: the geometric sequence `2^(-8 * i / num_heads)` for
    /// heads `i = 1..=num_heads`, shaped for broadcasting against the distance
    /// matrix.
    fn slope(num_heads: i32) -> Result<Array, Exception> {
        let x = 2.0_f32.powi(8).powf(1.0 / num_heads as f32);
        array!(x)
            .power(&arange::<_, f32>(1, num_heads + 1, None)?.negative()?)?
            .expand_dims(&[-1, -2])
    }

    /// Builds (and caches) the ALiBi bias matrix described by `key`.
    fn matrix(key: AlibiKey) -> Result<Array, Exception> {
        // Reuse a previously computed matrix if one exists for this key.
        if let Some(value) = ALIBI_CACHE.with(|cache| cache.borrow().get(&key).cloned()) {
            return Ok(value);
        }

        // Negative absolute distance between every query and key position,
        // shaped [1, 1, q_seq_len - offset, k_seq_len].
        let x1 = arange::<_, f32>(key.offset, key.q_seq_len, None)?;
        let x2 = arange::<_, f32>(0, key.k_seq_len, None)?;
        let distance_matrix = x1
            .try_index((.., NewAxis))?
            .subtract(x2.try_index((NewAxis, ..))?)?
            .expand_dims(&[0, 1])?
            .abs()?
            .negative()?;

        // Scale the distances by the per-head slopes.
        let slope = Self::slope(key.num_heads)?;
        let mask = distance_matrix.multiply(&slope)?.as_dtype(key.dtype)?;

        ALIBI_CACHE.with(|cache| {
            cache.borrow_mut().insert(key, mask.clone());
        });

        Ok(mask)
    }
}

generate_builder! {
    /// Input to the forward pass of [`Alibi`].
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct AlibiInput<'a> {
        /// The attention scores, shaped `[batch, heads, q_seq_len, k_seq_len]`.
        pub attention_scores: &'a Array,

        /// Position offset of the first query. Default:
        /// [`AlibiInput::DEFAULT_OFFSET`].
        #[builder(optional, default = AlibiInput::DEFAULT_OFFSET)]
        pub offset: i32,

        /// Optional additive attention mask.
        #[builder(optional, default = None)]
        pub mask: Option<&'a Array>,
    }
}

impl AlibiInput<'_> {
    /// Default value for the `offset` field.
    pub const DEFAULT_OFFSET: i32 = 0;
}

impl<'a> From<&'a Array> for AlibiInput<'a> {
    fn from(attention_scores: &'a Array) -> Self {
        AlibiInput {
            attention_scores,
            offset: Self::DEFAULT_OFFSET,
            mask: None,
        }
    }
}

impl<'a> From<(&'a Array,)> for AlibiInput<'a> {
    fn from((attention_scores,): (&'a Array,)) -> Self {
        AlibiInput {
            attention_scores,
            offset: Self::DEFAULT_OFFSET,
            mask: None,
        }
    }
}

impl<'a> From<(&'a Array, i32)> for AlibiInput<'a> {
    fn from((attention_scores, offset): (&'a Array, i32)) -> Self {
        AlibiInput {
            attention_scores,
            offset,
            mask: None,
        }
    }
}

impl<'a> From<(&'a Array, i32, &'a Array)> for AlibiInput<'a> {
    fn from((attention_scores, offset, mask): (&'a Array, i32, &'a Array)) -> Self {
        AlibiInput {
            attention_scores,
            offset,
            mask: Some(mask),
        }
    }
}

impl<'a> From<(&'a Array, i32, Option<&'a Array>)> for AlibiInput<'a> {
    fn from((attention_scores, offset, mask): (&'a Array, i32, Option<&'a Array>)) -> Self {
        AlibiInput {
            attention_scores,
            offset,
            mask,
        }
    }
}

impl<'a, Input> Module<Input> for Alibi
where
    Input: Into<AlibiInput<'a>>,
{
    type Output = Array;
    type Error = Exception;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let AlibiInput {
            attention_scores,
            offset,
            mask,
        } = input.into();

        let key = AlibiKey {
            q_seq_len: attention_scores.dim(-2) + offset,
            k_seq_len: attention_scores.dim(-1),
            num_heads: attention_scores.dim(1),
            offset,
            dtype: attention_scores.dtype(),
        };

        // Add the (cached) ALiBi biases, plus any user-provided mask, to the scores.
        let mut alibi_mask = Self::matrix(key)?;
        if let Some(mask) = mask {
            alibi_mask = alibi_mask.add(mask)?;
        }

        attention_scores.add(alibi_mask)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

#[allow(clippy::excessive_precision)]
#[cfg(test)]
mod tests {
    use crate::{module::Module, nn::AlibiInput, random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use crate::nn::Rope;

    #[test]
    fn test_rope() {
        crate::random::seed(71).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.5082664489746094,
            abs <= 0.010165328979492188
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            130.1162109375,
            abs <= 2.60232421875
        );

        let mut rope = Rope::new(8);
        let result = rope.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.4562537670135498,
            abs <= 0.009125075340270997
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            116.80096435546875,
            abs <= 2.3360192871093752
        );
    }
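
    // A shape-only sketch (not part of the upstream test suite) exercising the
    // `(array, offset)` tuple input accepted by `Module::forward` via
    // `From<(&Array, i32)> for RopeInput`.
    #[test]
    fn test_rope_with_offset() {
        crate::random::seed(71).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        let mut rope = Rope::new(8);
        // Start the rotation at position 4, as a decoder with a KV cache would.
        let result = rope.forward((&a, 4)).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
    }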

    #[test]
    fn test_sinpe() {
        crate::random::seed(226).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.5026599168777466,
            abs <= 0.010053198337554931
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            128.68093872070312,
            abs <= 2.5736187744140624
        );

        let mut sinpe = crate::nn::Sinpe::new(8).unwrap();
        let result = sinpe.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.2705308198928833,
            abs <= 0.005410616397857666
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            554.047119140625,
            abs <= 11.0809423828125
        );
    }
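
    // A shape-only sketch (not part of the upstream test suite): the sinusoidal
    // encoding appends `dimensions` features (half sines, half cosines) as a
    // new trailing axis.
    #[test]
    fn test_sinpe_appends_dimensions_axis() {
        crate::random::seed(226).unwrap();
        let a = uniform::<_, f32>(0, 1, &[4, 10], None).unwrap();
        let mut sinpe = crate::nn::Sinpe::new(32).unwrap();
        let result = sinpe.forward(&a).unwrap();
        assert_eq!(result.shape(), &[4, 10, 32]);
        assert_eq!(result.dtype(), Dtype::Float32);
    }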

    #[test]
    fn test_alibi() {
        let mut alibi = crate::nn::Alibi;
        let shape = [1, 8, 20, 20];
        let x = uniform::<_, f32>(0, 1, &shape, None).unwrap();
        let input = AlibiInput::from(&x);
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.shape(), shape);
        assert_eq!(y.dtype(), Dtype::Float32);

        let x2 = x.as_dtype(Dtype::Float16).unwrap();
        let input = AlibiInput::from(&x2);
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.dtype(), Dtype::Float16);
    }
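
    // A shape-only sketch (not part of the upstream test suite) combining an
    // explicit offset with an additive mask via
    // `From<(&Array, i32, &Array)> for AlibiInput`.
    #[test]
    fn test_alibi_with_offset_and_mask() {
        let mut alibi = crate::nn::Alibi;
        // Scores for 4 new queries against 20 keys, decoded at offset 16.
        let shape = [1, 8, 4, 20];
        let scores = uniform::<_, f32>(0, 1, &shape, None).unwrap();
        let mask = uniform::<_, f32>(0, 1, &shape, None).unwrap();
        let input = AlibiInput::from((&scores, 16, &mask));
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.shape(), shape);
        assert_eq!(y.dtype(), Dtype::Float32);
    }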
}