// mlx_rs/nn/positional_encoding.rs

1use std::{cell::RefCell, collections::HashMap};
2
3use crate::{
4    array,
5    error::Exception,
6    module::{Module, Param},
7    ops::indexing::NewAxis,
8    ops::{arange, concatenate, exp, indexing::TryIndexOp, log},
9    Array, Dtype,
10};
11use mlx_internal_macros::{generate_builder, Buildable, Builder};
12use mlx_macros::ModuleParameters;
13
/// Type alias for [`RotaryPositionalEncoding`] (short, conventional name).
pub type Rope = RotaryPositionalEncoding;

/// Type alias for [`RotaryPositionalEncodingBuilder`].
pub type RopeBuilder = RotaryPositionalEncodingBuilder;
19
generate_builder! {
    /// Implements the rotary positional encoding.
    ///
    /// The traditional implementation rotates consecutive pairs of elements in the
    /// feature dimension while the default implementation rotates pairs with
    /// stride half the feature dimensions for efficiency.
    ///
    /// For more details see _RoFormer: Enhanced Transformer with Rotary Position
    /// Embedding_ ([https://arxiv.org/abs/2104.09864](https://arxiv.org/abs/2104.09864))
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RotaryPositionalEncoding {
        /// The feature dimensions to be rotated. If the input feature is larger
        /// than dims then the rest is left unchanged.
        ///
        /// This is the only required builder field; the others have defaults.
        pub dimensions: i32,

        /// If `true` choose the traditional implementation which is slightly
        /// less efficient. Defaults to [`RotaryPositionalEncoding::DEFAULT_TRADITIONAL`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_TRADITIONAL)]
        pub traditional: bool,

        /// The base used to compute angular frequency for each dimension in the
        /// positional encodings. Defaults to [`RotaryPositionalEncoding::DEFAULT_BASE`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_BASE)]
        pub base: f32,

        /// Scale used to scale the positions.
        /// Defaults to [`RotaryPositionalEncoding::DEFAULT_SCALE`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_SCALE)]
        pub scale: f32,
    }
}
53
impl RotaryPositionalEncoding {
    /// Default value for `traditional` field (`false`: use the more efficient
    /// half-stride rotation rather than consecutive pairs).
    pub const DEFAULT_TRADITIONAL: bool = false;

    /// Default value for `base` field.
    pub const DEFAULT_BASE: f32 = 10_000.0;

    /// Default value for `scale` field (`1.0`: positions are unscaled).
    pub const DEFAULT_SCALE: f32 = 1.0;
}
64
generate_builder! {
    /// Input for the [`RotaryPositionalEncoding`] module.
    #[derive(Debug, Buildable, Clone)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RopeInput<'a> {
        /// The input tensor.
        pub x: &'a Array,

        /// Position offset of the first element of `x`.
        /// Defaults to [`RopeInput::DEFAULT_OFFSET`].
        #[builder(optional, default = RopeInput::DEFAULT_OFFSET)]
        pub offset: i32,
    }
}
79
impl RopeInput<'_> {
    /// Default value for `offset` field (`0`: encode from the first position).
    pub const DEFAULT_OFFSET: i32 = 0;
}
84
85impl<'a> From<&'a Array> for RopeInput<'a> {
86    fn from(x: &'a Array) -> Self {
87        RopeInput {
88            x,
89            offset: Self::DEFAULT_OFFSET,
90        }
91    }
92}
93
94impl<'a> From<(&'a Array,)> for RopeInput<'a> {
95    fn from((x,): (&'a Array,)) -> Self {
96        RopeInput {
97            x,
98            offset: Self::DEFAULT_OFFSET,
99        }
100    }
101}
102
103impl<'a> From<(&'a Array, i32)> for RopeInput<'a> {
104    fn from((x, offset): (&'a Array, i32)) -> Self {
105        RopeInput { x, offset }
106    }
107}
108
impl<'a, Input> Module<Input> for RotaryPositionalEncoding
where
    Input: Into<RopeInput<'a>>,
{
    type Error = Exception;

    type Output = Array;

    /// Applies rotary positional encoding to the input, starting at `offset`.
    ///
    /// The input is flattened to rank 3 before calling the fast RoPE kernel
    /// and reshaped back to the caller's shape afterwards.
    /// NOTE(review): `x.dim(-2)` / `x.dim(-1)` assume the input has at least
    /// two dimensions — confirm with callers.
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let RopeInput { x, offset } = input.into();
        // Remember the caller's shape so it can be restored after the kernel call.
        let shape = x.shape();
        // Collapse all leading dimensions into a single batch dimension.
        let x = x.reshape(&[-1, x.dim(-2), x.dim(-1)])?;
        let x = crate::fast::rope(
            x,
            self.dimensions,
            self.traditional,
            self.base,
            self.scale,
            offset,
            None,
        )?;
        // Restore the original shape.
        x.reshape(shape)
    }

    // RoPE is stateless w.r.t. training; nothing to toggle.
    fn training_mode(&mut self, _mode: bool) {}
}
135
/// Type alias for [`SinusoidalPositionalEncoding`] (short, conventional name).
pub type Sinpe = SinusoidalPositionalEncoding;

/// Type alias for [`SinusoidalPositionalEncodingBuilder`].
pub type SinpeBuilder = SinusoidalPositionalEncodingBuilder;
141
/// Implements sinusoidal positional encoding.
///
/// For more details see the paper "Attention Is All You Need"
/// <https://arxiv.org/abs/1706.03762>.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct SinusoidalPositionalEncoding {
    // Precomputed per-dimension angular frequencies (built in `build_sinpe`);
    // positions are multiplied by these before taking sin/cos.
    #[param]
    sigmas: Param<Array>,

    /// multiplicative scale for the embeddings.  Default is `sqrt(2/dimensions)`
    pub scale: f32,

    /// if `true` embed using `[cos(x), sin(x)]` instead of the reverse
    pub cosine_first: bool,
}
159
impl Sinpe {
    /// Default value for `cosine_first` field (`false`: `[sin(x), cos(x)]` order).
    pub const DEFAULT_COSINE_FIRST: bool = false;

    /// Default value for min frequency.
    pub const DEFAULT_MIN_FREQUENCY: f32 = 0.0001;

    /// Default value for max frequency.
    pub const DEFAULT_MAX_FREQUENCY: f32 = 1.0;

    /// Default value for full turns (`false`: frequencies are not scaled by `2π`).
    pub const DEFAULT_FULL_TURNS: bool = false;
}
173
/// Builder for [`SinusoidalPositionalEncoding`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_sinpe,
    err = Exception,
)]
pub struct SinusoidalPositionalEncodingBuilder {
    /// Dimensionality of the resulting embeddings (required).
    dimensions: i32,

    /// Minimum angular frequency; defaults to [`Sinpe::DEFAULT_MIN_FREQUENCY`].
    #[builder(optional, default = Sinpe::DEFAULT_MIN_FREQUENCY)]
    min_frequency: f32,

    /// Maximum angular frequency; defaults to [`Sinpe::DEFAULT_MAX_FREQUENCY`].
    #[builder(optional, default = Sinpe::DEFAULT_MAX_FREQUENCY)]
    max_frequency: f32,

    /// Embedding scale; when `None`, `build_sinpe` uses `sqrt(2 / dimensions)`.
    #[builder(optional, default = None)]
    scale: Option<f32>,

    /// Emit `[cos, sin]` instead of `[sin, cos]`; defaults to
    /// [`Sinpe::DEFAULT_COSINE_FIRST`].
    #[builder(optional, default = Sinpe::DEFAULT_COSINE_FIRST)]
    cosine_first: bool,

    /// Multiply frequencies by `2π`; defaults to [`Sinpe::DEFAULT_FULL_TURNS`].
    #[builder(optional, default = Sinpe::DEFAULT_FULL_TURNS)]
    full_turns: bool,
}
199
/// Builds a [`SinusoidalPositionalEncoding`] from its builder.
///
/// Precomputes the `sigmas` frequency table: `dimensions / 2` angular
/// frequencies log-spaced from `max_frequency` down to `min_frequency`,
/// optionally multiplied by `2π` when `full_turns` is set. When no explicit
/// `scale` is given it defaults to `sqrt(2 / dimensions)`.
fn build_sinpe(builder: SinpeBuilder) -> Result<SinusoidalPositionalEncoding, Exception> {
    let SinpeBuilder {
        dimensions,
        min_frequency,
        max_frequency,
        scale,
        cosine_first,
        full_turns,
    } = builder;

    // `one_zero` goes linearly from 1 down to 0 over `half_dim` steps, so the
    // first frequency is `max_frequency` and the last is `min_frequency`.
    // NOTE(review): for `dimensions <= 2`, `half_dim - 1 == 0` and the divisor
    // below is zero — confirm callers never pass such small dimensions.
    let half_dim = dimensions / 2;
    let one_zero = array!(1.0)
        .subtract(Array::from_iter(0..half_dim, &[half_dim]).divide(array!(half_dim - 1))?)?;
    // Interpolate between the two frequencies in log space.
    let min_frequency = log(array!(min_frequency))?;
    let max_frequency = log(array!(max_frequency))?;

    // SAFETY: max_frequency and min_frequency are scalars and operations with scalars won't throw
    let mut sigmas = exp(&one_zero * (&max_frequency - &min_frequency) + &min_frequency)?;
    if full_turns {
        // SAFETY: scalar array operation won't throw
        sigmas *= array!(2.0 * std::f32::consts::PI);
    }

    // Default scale keeps embedding magnitude roughly unit: sqrt(2 / dimensions).
    let scale = scale.unwrap_or_else(|| (2.0 / dimensions as f32).sqrt());

    Ok(SinusoidalPositionalEncoding {
        sigmas: Param::new(sigmas),
        scale,
        cosine_first,
    })
}
231
232impl Module<&Array> for Sinpe {
233    type Error = Exception;
234    type Output = Array;
235
236    fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> {
237        let mut y = x
238            .expand_dims(&[-1])
239            .and_then(|x| x.multiply(&self.sigmas))?;
240
241        let cosy = y.cos()?;
242        let siny = y.sin()?;
243
244        if self.cosine_first {
245            y = concatenate(&[cosy, siny], -1)?;
246        } else {
247            y = concatenate(&[siny, cosy], -1)?;
248        }
249
250        if self.scale != 1.0 {
251            // SAFETY: multiplication with scalar won't throw
252            y *= self.scale;
253        }
254
255        Ok(y)
256    }
257
258    fn training_mode(&mut self, _mode: bool) {}
259}
260
/// Cache key identifying a precomputed ALiBi bias matrix; two forward passes
/// with identical key values share the same cached mask.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct AlibiKey {
    // Query sequence length (includes the offset).
    q_seq_len: i32,
    // Key sequence length.
    k_seq_len: i32,
    // Number of attention heads (one slope per head).
    num_heads: i32,
    // Position offset of the first query.
    offset: i32,
    // Dtype the mask is cast to.
    dtype: Dtype,
}
269
thread_local! {
    /// Per-thread cache of ALiBi bias matrices, keyed by [`AlibiKey`].
    static ALIBI_CACHE: RefCell<HashMap<AlibiKey, Array>> = RefCell::new(HashMap::new());
}
273
/// Attention with Linear Biases
///
/// Adds a head-dependent linear bias to attention scores in place of
/// positional embeddings (<https://arxiv.org/abs/2108.12409>).
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Alibi;
278
279impl Alibi {
280    fn slope(num_heads: i32) -> Result<Array, Exception> {
281        let x = 2.0_f32.powi(8).powf(1.0 / num_heads as f32);
282        array!(x)
283            .power(&arange::<_, f32>(1, num_heads + 1, None)?)?
284            .expand_dims(&[-1, -2])
285    }
286
287    fn matrix(key: AlibiKey) -> Result<Array, Exception> {
288        if let Some(value) = ALIBI_CACHE.with(|cache| cache.borrow().get(&key).cloned()) {
289            return Ok(value);
290        }
291
292        let x1 = arange::<_, f32>(key.offset, key.q_seq_len, None)?;
293        let x2 = arange::<_, f32>(0, key.k_seq_len, None)?;
294        let distance_matrix = x1
295            .try_index((.., NewAxis))?
296            .subtract(x2.try_index((NewAxis, ..))?)?
297            .expand_dims(&[0, 1])?
298            .abs()?
299            .negative()?;
300
301        let slope = Self::slope(key.num_heads)?;
302        let mask = distance_matrix.multiply(&slope)?.as_dtype(key.dtype)?;
303
304        ALIBI_CACHE.with(|cache| {
305            cache.borrow_mut().insert(key, mask.clone());
306        });
307
308        Ok(mask)
309    }
310}
311
generate_builder! {
    /// Input for the [`Alibi`] module.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct AlibiInput<'a> {
        /// The attention scores the ALiBi bias is added to.
        pub attention_scores: &'a Array,

        /// Position offset of the first query.
        /// Defaults to [`AlibiInput::DEFAULT_OFFSET`].
        #[builder(optional, default = AlibiInput::DEFAULT_OFFSET)]
        pub offset: i32,

        /// Optional additional mask added on top of the ALiBi bias.
        #[builder(optional, default = None)]
        pub mask: Option<&'a Array>,
    }
}
330
impl AlibiInput<'_> {
    /// Default value for `offset` field (`0`: queries start at position zero).
    pub const DEFAULT_OFFSET: i32 = 0;
}
335
336impl<'a> From<&'a Array> for AlibiInput<'a> {
337    fn from(attention_scores: &'a Array) -> Self {
338        AlibiInput {
339            attention_scores,
340            offset: Self::DEFAULT_OFFSET,
341            mask: None,
342        }
343    }
344}
345
346impl<'a> From<(&'a Array,)> for AlibiInput<'a> {
347    fn from((attention_scores,): (&'a Array,)) -> Self {
348        AlibiInput {
349            attention_scores,
350            offset: Self::DEFAULT_OFFSET,
351            mask: None,
352        }
353    }
354}
355
356impl<'a> From<(&'a Array, i32)> for AlibiInput<'a> {
357    fn from((attention_scores, offset): (&'a Array, i32)) -> Self {
358        AlibiInput {
359            attention_scores,
360            offset,
361            mask: None,
362        }
363    }
364}
365
366impl<'a> From<(&'a Array, i32, &'a Array)> for AlibiInput<'a> {
367    fn from((attention_scores, offset, mask): (&'a Array, i32, &'a Array)) -> Self {
368        AlibiInput {
369            attention_scores,
370            offset,
371            mask: Some(mask),
372        }
373    }
374}
375
376impl<'a> From<(&'a Array, i32, Option<&'a Array>)> for AlibiInput<'a> {
377    fn from((attention_scores, offset, mask): (&'a Array, i32, Option<&'a Array>)) -> Self {
378        AlibiInput {
379            attention_scores,
380            offset,
381            mask,
382        }
383    }
384}
385
impl<'a, Input> Module<Input> for Alibi
where
    Input: Into<AlibiInput<'a>>,
{
    type Output = Array;
    type Error = Exception;

    /// Adds the (cached) ALiBi positional bias — plus an optional extra
    /// mask — to the attention scores.
    ///
    /// NOTE(review): `attention_scores.dim(1)` is taken as the number of
    /// heads and `dim(-2)`/`dim(-1)` as query/key lengths, so the scores are
    /// assumed to be `[batch, num_heads, q_len, k_len]` — confirm with callers.
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let AlibiInput {
            attention_scores,
            offset,
            mask,
        } = input.into();

        // The bias matrix depends only on these five values (cache key).
        let key = AlibiKey {
            q_seq_len: attention_scores.dim(-2) + offset,
            k_seq_len: attention_scores.dim(-1),
            num_heads: attention_scores.dim(1),
            offset,
            dtype: attention_scores.dtype(),
        };

        let mut alibi_mask = Self::matrix(key)?;
        // Fold in any caller-supplied mask (e.g. a causal mask).
        if let Some(mask) = mask {
            alibi_mask = alibi_mask.add(mask)?;
        }

        attention_scores.add(alibi_mask)
    }

    // ALiBi is stateless w.r.t. training; nothing to toggle.
    fn training_mode(&mut self, _mode: bool) {}
}
418
#[allow(clippy::excessive_precision)]
#[cfg(test)]
mod tests {
    use crate::{module::Module, nn::AlibiInput, random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use crate::nn::Rope;

    // The unit test below is adapted from the swift binding at:
    // mlx-swift/Tests/MLXTests/IntegrationTests.swift
    #[test]
    fn test_rope() {
        // Fixed seed so the statistics asserted below are reproducible.
        crate::random::seed(71).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        // Tolerances are 2% of the expected value (mean/sum sanity check on input).
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.5082664489746094,
            abs <= 0.010165328979492188
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            130.1162109375,
            abs <= 2.60232421875
        );

        // RoPE with 8 rotated dimensions; shape must be preserved.
        let mut rope = Rope::new(8);
        let result = rope.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.4562537670135498,
            abs <= 0.009125075340270997
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            116.80096435546875,
            abs <= 2.3360192871093752
        );
    }

    // The unit test below is adapted from the swift binding at:
    // mlx-swift/Tests/MLXTests/IntegrationTests.swift
    #[test]
    fn test_sinpe() {
        // Fixed seed so the statistics asserted below are reproducible.
        crate::random::seed(226).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.5026599168777466,
            abs <= 0.010053198337554931
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            128.68093872070312,
            abs <= 2.5736187744140624
        );

        // Sinpe appends an embedding dimension of size `dimensions` (8).
        let mut sinpe = crate::nn::Sinpe::new(8).unwrap();
        let result = sinpe.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.2705308198928833,
            abs <= 0.005410616397857666
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            554.047119140625,
            abs <= 11.0809423828125
        );
    }

    // The unit test below is adapted from the python binding at:
    // mlx/python/tests/test_nn.py
    #[test]
    fn test_alibi() {
        let mut alibi = crate::nn::Alibi;
        // Scores shaped [batch, heads, q_len, k_len]; shape must be preserved.
        let shape = [1, 8, 20, 20];
        let x = uniform::<_, f32>(0, 1, &shape, None).unwrap();
        let input = AlibiInput::from(&x);
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.shape(), shape);
        assert_eq!(y.dtype(), Dtype::Float32);

        // The bias is cast to the scores' dtype (Float16 here).
        let x2 = x.as_dtype(Dtype::Float16).unwrap();
        let input = AlibiInput::from(&x2);
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.dtype(), Dtype::Float16);
    }
}
514}