mlx_rs/nn/
positional_encoding.rs

1use std::{cell::RefCell, collections::HashMap};
2
3use crate::{
4    array,
5    error::Exception,
6    module::{Module, Param},
7    ops::{
8        arange, concatenate_axis, exp,
9        indexing::{NewAxis, TryIndexOp},
10        log,
11    },
12    Array, Dtype,
13};
14use mlx_internal_macros::{generate_builder, Buildable, Builder};
15use mlx_macros::ModuleParameters;
16
/// Convenience alias for [`RotaryPositionalEncoding`] (RoPE).
pub type Rope = RotaryPositionalEncoding;

/// Convenience alias for [`RotaryPositionalEncodingBuilder`].
pub type RopeBuilder = RotaryPositionalEncodingBuilder;
22
generate_builder! {
    /// Implements the rotary positional encoding.
    ///
    /// The traditional implementation rotates consecutive pairs of elements in the
    /// feature dimension while the default implementation rotates pairs with
    /// stride half the feature dimensions for efficiency.
    ///
    /// For more details see _RoFormer: Enhanced Transformer with Rotary Position
    /// Embedding_ ([https://arxiv.org/abs/2104.09864](https://arxiv.org/abs/2104.09864))
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RotaryPositionalEncoding {
        /// The feature dimensions to be rotated. If the input feature is larger
        /// than `dimensions` then the rest is left unchanged.
        pub dimensions: i32,

        /// If `true` choose the traditional implementation which is slightly
        /// less efficient. Default: [`RotaryPositionalEncoding::DEFAULT_TRADITIONAL`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_TRADITIONAL)]
        pub traditional: bool,

        /// The base used to compute angular frequency for each dimension in the
        /// positional encodings. Default: [`RotaryPositionalEncoding::DEFAULT_BASE`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_BASE)]
        pub base: f32,

        /// Scale used to scale the positions. Default: [`RotaryPositionalEncoding::DEFAULT_SCALE`].
        #[builder(optional, default = RotaryPositionalEncoding::DEFAULT_SCALE)]
        pub scale: f32,
    }
}
56
impl RotaryPositionalEncoding {
    /// Default value for the `traditional` field (use the strided, faster rotation).
    pub const DEFAULT_TRADITIONAL: bool = false;

    /// Default value for the `base` field (standard RoPE base of 10,000).
    pub const DEFAULT_BASE: f32 = 10_000.0;

    /// Default value for the `scale` field (positions used as-is).
    pub const DEFAULT_SCALE: f32 = 1.0;
}
67
generate_builder! {
    /// Input for the [`RotaryPositionalEncoding`] module.
    #[derive(Debug, Buildable, Clone)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RopeInput<'a> {
        /// The input tensor.
        pub x: &'a Array,

        /// Position offset applied to the input sequence (e.g. when decoding
        /// incrementally with a KV cache). Default: [`RopeInput::DEFAULT_OFFSET`].
        #[builder(optional, default = RopeInput::DEFAULT_OFFSET)]
        pub offset: i32,
    }
}

impl RopeInput<'_> {
    /// Default value for `offset` field (no positional shift).
    pub const DEFAULT_OFFSET: i32 = 0;
}
87
88impl<'a> From<&'a Array> for RopeInput<'a> {
89    fn from(x: &'a Array) -> Self {
90        RopeInput {
91            x,
92            offset: Self::DEFAULT_OFFSET,
93        }
94    }
95}
96
97impl<'a> From<(&'a Array,)> for RopeInput<'a> {
98    fn from((x,): (&'a Array,)) -> Self {
99        RopeInput {
100            x,
101            offset: Self::DEFAULT_OFFSET,
102        }
103    }
104}
105
106impl<'a> From<(&'a Array, i32)> for RopeInput<'a> {
107    fn from((x, offset): (&'a Array, i32)) -> Self {
108        RopeInput { x, offset }
109    }
110}
111
112impl<'a, Input> Module<Input> for RotaryPositionalEncoding
113where
114    Input: Into<RopeInput<'a>>,
115{
116    type Error = Exception;
117
118    type Output = Array;
119
120    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
121        let RopeInput { x, offset } = input.into();
122        let shape = x.shape();
123        let x = x.reshape(&[-1, x.dim(-2), x.dim(-1)])?;
124        let x = crate::fast::rope(
125            x,
126            self.dimensions,
127            self.traditional,
128            self.base,
129            self.scale,
130            offset,
131            None,
132        )?;
133        x.reshape(shape)
134    }
135
136    fn training_mode(&mut self, _mode: bool) {}
137}
138
/// Convenience alias for [`SinusoidalPositionalEncoding`].
pub type Sinpe = SinusoidalPositionalEncoding;

/// Convenience alias for [`SinusoidalPositionalEncodingBuilder`].
pub type SinpeBuilder = SinusoidalPositionalEncodingBuilder;
144
/// Implements sinusoidal positional encoding.
///
/// For more details see the paper "Attention Is All You Need"
/// <https://arxiv.org/abs/1706.03762>.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct SinusoidalPositionalEncoding {
    /// Precomputed angular frequencies, one per half-dimension; built by
    /// [`build_sinpe`] from the builder's frequency range.
    #[param]
    sigmas: Param<Array>,

    /// multiplicative scale for the embeddings.  Default is `sqrt(2/dimensions)`
    pub scale: f32,

    /// if `true` embed using `[cos(x), sin(x)]` instead of the reverse
    pub cosine_first: bool,
}
162
impl Sinpe {
    /// Default value for `cosine_first` field (embed as `[sin(x), cos(x)]`).
    pub const DEFAULT_COSINE_FIRST: bool = false;

    /// Default value for min frequency used by [`SinusoidalPositionalEncodingBuilder`].
    pub const DEFAULT_MIN_FREQUENCY: f32 = 0.0001;

    /// Default value for max frequency used by [`SinusoidalPositionalEncodingBuilder`].
    pub const DEFAULT_MAX_FREQUENCY: f32 = 1.0;

    /// Default value for full turns (do not scale frequencies by `2π`).
    pub const DEFAULT_FULL_TURNS: bool = false;
}
176
/// Builder for [`SinusoidalPositionalEncoding`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_sinpe,
    err = Exception,
)]
pub struct SinusoidalPositionalEncodingBuilder {
    /// Dimensionality of the resulting positional embeddings.
    dimensions: i32,

    /// Minimum angular frequency. Default: [`Sinpe::DEFAULT_MIN_FREQUENCY`].
    #[builder(optional, default = Sinpe::DEFAULT_MIN_FREQUENCY)]
    min_frequency: f32,

    /// Maximum angular frequency. Default: [`Sinpe::DEFAULT_MAX_FREQUENCY`].
    #[builder(optional, default = Sinpe::DEFAULT_MAX_FREQUENCY)]
    max_frequency: f32,

    /// Optional multiplicative scale; `None` means `sqrt(2 / dimensions)`.
    #[builder(optional, default = None)]
    scale: Option<f32>,

    /// Whether to emit `[cos, sin]` instead of `[sin, cos]`.
    #[builder(optional, default = Sinpe::DEFAULT_COSINE_FIRST)]
    cosine_first: bool,

    /// Whether to scale frequencies by `2π`. Default: [`Sinpe::DEFAULT_FULL_TURNS`].
    #[builder(optional, default = Sinpe::DEFAULT_FULL_TURNS)]
    full_turns: bool,
}
202
/// Builds a [`SinusoidalPositionalEncoding`] from the builder fields.
///
/// Precomputes the angular frequencies (`sigmas`) as a geometric progression
/// from `max_frequency` down to `min_frequency` across `dimensions / 2` bins.
fn build_sinpe(builder: SinpeBuilder) -> Result<SinusoidalPositionalEncoding, Exception> {
    let SinpeBuilder {
        dimensions,
        min_frequency,
        max_frequency,
        scale,
        cosine_first,
        full_turns,
    } = builder;

    // Half the output dimensions carry sin components, half cos.
    let half_dim = dimensions / 2;
    // Linear ramp from 1.0 down to 0.0 over the half_dim bins.
    // NOTE(review): for dimensions < 4, `half_dim - 1` is <= 0, so this divides
    // by zero and produces non-finite sigmas — confirm callers pass dimensions >= 4.
    let one_zero = array!(1.0)
        .subtract(Array::from_iter(0..half_dim, &[half_dim]).divide(array!(half_dim - 1))?)?;
    // Interpolate in log-space so the frequencies form a geometric sequence.
    let min_frequency = log(array!(min_frequency))?;
    let max_frequency = log(array!(max_frequency))?;

    // SAFETY: max_frequency and min_frequency are scalars and operations with scalars won't throw
    let mut sigmas = exp(&one_zero * (&max_frequency - &min_frequency) + &min_frequency)?;
    if full_turns {
        // SAFETY: scalar array operation won't throw
        sigmas *= array!(2.0 * std::f32::consts::PI);
    }

    // Default scale matches the original Transformer: sqrt(2 / dimensions).
    let scale = scale.unwrap_or_else(|| (2.0 / dimensions as f32).sqrt());

    Ok(SinusoidalPositionalEncoding {
        sigmas: Param::new(sigmas),
        scale,
        cosine_first,
    })
}
234
235impl Module<&Array> for Sinpe {
236    type Error = Exception;
237    type Output = Array;
238
239    fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> {
240        let mut y = x
241            .expand_dims_axes(&[-1])
242            .and_then(|x| x.multiply(&self.sigmas))?;
243
244        let cosy = y.cos()?;
245        let siny = y.sin()?;
246
247        if self.cosine_first {
248            y = concatenate_axis(&[cosy, siny], -1)?;
249        } else {
250            y = concatenate_axis(&[siny, cosy], -1)?;
251        }
252
253        if self.scale != 1.0 {
254            // SAFETY: multiplication with scalar won't throw
255            y *= self.scale;
256        }
257
258        Ok(y)
259    }
260
261    fn training_mode(&mut self, _mode: bool) {}
262}
263
/// Cache key identifying a unique ALiBi bias matrix: any change in the score
/// dimensions, offset, or dtype requires a freshly computed matrix.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct AlibiKey {
    q_seq_len: i32,
    k_seq_len: i32,
    num_heads: i32,
    offset: i32,
    dtype: Dtype,
}

thread_local! {
    /// Thread-local memoization of computed ALiBi bias matrices.
    static ALIBI_CACHE: RefCell<HashMap<AlibiKey, Array>> = RefCell::new(HashMap::new());
}
276
/// Attention with Linear Biases
///
/// Adds a distance-dependent, per-head linear bias to attention scores
/// (<https://arxiv.org/abs/2108.12409>). The struct is stateless: the bias
/// matrix is computed on demand and cached per thread.
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Alibi;
281
282impl Alibi {
283    fn slope(num_heads: i32) -> Result<Array, Exception> {
284        let x = 2.0_f32.powi(8).powf(1.0 / num_heads as f32);
285        array!(x)
286            .power(&arange::<_, f32>(1, num_heads + 1, None)?)?
287            .expand_dims_axes(&[-1, -2])
288    }
289
290    fn matrix(key: AlibiKey) -> Result<Array, Exception> {
291        if let Some(value) = ALIBI_CACHE.with(|cache| cache.borrow().get(&key).cloned()) {
292            return Ok(value);
293        }
294
295        let x1 = arange::<_, f32>(key.offset, key.q_seq_len, None)?;
296        let x2 = arange::<_, f32>(0, key.k_seq_len, None)?;
297        let distance_matrix = x1
298            .try_index((.., NewAxis))?
299            .subtract(x2.try_index((NewAxis, ..))?)?
300            .expand_dims_axes(&[0, 1])?
301            .abs()?
302            .negative()?;
303
304        let slope = Self::slope(key.num_heads)?;
305        let mask = distance_matrix.multiply(&slope)?.as_dtype(key.dtype)?;
306
307        ALIBI_CACHE.with(|cache| {
308            cache.borrow_mut().insert(key, mask.clone());
309        });
310
311        Ok(mask)
312    }
313}
314
generate_builder! {
    /// Input for the [`Alibi`] module.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct AlibiInput<'a> {
        /// The attention scores.
        pub attention_scores: &'a Array,

        /// Position offset for the query sequence (e.g. with incremental
        /// decoding). Default: [`AlibiInput::DEFAULT_OFFSET`].
        #[builder(optional, default = AlibiInput::DEFAULT_OFFSET)]
        pub offset: i32,

        /// Optional additional mask added on top of the ALiBi bias.
        #[builder(optional, default = None)]
        pub mask: Option<&'a Array>,
    }
}

impl AlibiInput<'_> {
    /// Default value for `offset` field (no positional shift).
    pub const DEFAULT_OFFSET: i32 = 0;
}
338
339impl<'a> From<&'a Array> for AlibiInput<'a> {
340    fn from(attention_scores: &'a Array) -> Self {
341        AlibiInput {
342            attention_scores,
343            offset: Self::DEFAULT_OFFSET,
344            mask: None,
345        }
346    }
347}
348
349impl<'a> From<(&'a Array,)> for AlibiInput<'a> {
350    fn from((attention_scores,): (&'a Array,)) -> Self {
351        AlibiInput {
352            attention_scores,
353            offset: Self::DEFAULT_OFFSET,
354            mask: None,
355        }
356    }
357}
358
359impl<'a> From<(&'a Array, i32)> for AlibiInput<'a> {
360    fn from((attention_scores, offset): (&'a Array, i32)) -> Self {
361        AlibiInput {
362            attention_scores,
363            offset,
364            mask: None,
365        }
366    }
367}
368
369impl<'a> From<(&'a Array, i32, &'a Array)> for AlibiInput<'a> {
370    fn from((attention_scores, offset, mask): (&'a Array, i32, &'a Array)) -> Self {
371        AlibiInput {
372            attention_scores,
373            offset,
374            mask: Some(mask),
375        }
376    }
377}
378
379impl<'a> From<(&'a Array, i32, Option<&'a Array>)> for AlibiInput<'a> {
380    fn from((attention_scores, offset, mask): (&'a Array, i32, Option<&'a Array>)) -> Self {
381        AlibiInput {
382            attention_scores,
383            offset,
384            mask,
385        }
386    }
387}
388
389impl<'a, Input> Module<Input> for Alibi
390where
391    Input: Into<AlibiInput<'a>>,
392{
393    type Output = Array;
394    type Error = Exception;
395
396    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
397        let AlibiInput {
398            attention_scores,
399            offset,
400            mask,
401        } = input.into();
402
403        let key = AlibiKey {
404            q_seq_len: attention_scores.dim(-2) + offset,
405            k_seq_len: attention_scores.dim(-1),
406            num_heads: attention_scores.dim(1),
407            offset,
408            dtype: attention_scores.dtype(),
409        };
410
411        let mut alibi_mask = Self::matrix(key)?;
412        if let Some(mask) = mask {
413            alibi_mask = alibi_mask.add(mask)?;
414        }
415
416        attention_scores.add(alibi_mask)
417    }
418
419    fn training_mode(&mut self, _mode: bool) {}
420}
421
#[allow(clippy::excessive_precision)]
#[cfg(test)]
mod tests {
    use crate::{module::Module, nn::AlibiInput, random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use crate::nn::Rope;

    // The unit test below is adapted from the swift binding at:
    // mlx-swift/Tests/MLXTests/IntegrationTests.swift
    #[test]
    fn test_rope() {
        // Fixed seed so the mean/sum reference values below are reproducible.
        crate::random::seed(71).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        // Tolerances are 2% of the expected value.
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.5082664489746094,
            abs <= 0.010165328979492188
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            130.1162109375,
            abs <= 2.60232421875
        );

        // Rotate the first 8 of 16 feature dimensions; shape is preserved.
        let mut rope = Rope::new(8);
        let result = rope.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.4562537670135498,
            abs <= 0.009125075340270997
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            116.80096435546875,
            abs <= 2.3360192871093752
        );
    }

    // The unit test below is adapted from the swift binding at:
    // mlx-swift/Tests/MLXTests/IntegrationTests.swift
    #[test]
    fn test_sinpe() {
        // Fixed seed so the mean/sum reference values below are reproducible.
        crate::random::seed(226).unwrap();
        let a = uniform::<_, f32>(0, 1, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.5026599168777466,
            abs <= 0.010053198337554931
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            128.68093872070312,
            abs <= 2.5736187744140624
        );

        // Sinpe appends a trailing axis of size `dimensions` (8 here).
        let mut sinpe = crate::nn::Sinpe::new(8).unwrap();
        let result = sinpe.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.2705308198928833,
            abs <= 0.005410616397857666
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            554.047119140625,
            abs <= 11.0809423828125
        );
    }

    // The unit test below is adapted from the python binding at:
    // mlx/python/tests/test_nn.py
    #[test]
    fn test_alibi() {
        let mut alibi = crate::nn::Alibi;
        // Scores shaped [batch, heads, q_len, k_len]; shape must be preserved.
        let shape = [1, 8, 20, 20];
        let x = uniform::<_, f32>(0, 1, &shape, None).unwrap();
        let input = AlibiInput::from(&x);
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.shape(), shape);
        assert_eq!(y.dtype(), Dtype::Float32);

        // The bias is cast to the score dtype before being added.
        let x2 = x.as_dtype(Dtype::Float16).unwrap();
        let input = AlibiInput::from(&x2);
        let y = alibi.forward(input).unwrap();
        assert_eq!(y.dtype(), Dtype::Float16);
    }
}
517}