mlx_rs/nn/transformer.rs

use std::borrow::Cow;

use crate::{
    array,
    builder::Builder,
    error::Exception,
    module::{Module, UnaryModule},
    ops::{arange, expand_dims, matmul, softmax},
    quantization::MaybeQuantized,
    Array, ArrayElement, FromScalar,
};
use dyn_clone::DynClone;
use mlx_internal_macros::{generate_builder, Buildable, Builder};
use mlx_macros::{ModuleParameters, Quantizable};
use num_traits::bounds::LowerBounded;

use crate::{
    error::{MultiHeadAttentionBuildError, TransformerBulidError},
    nn::{Dropout, DropoutBuilder, LayerNorm, Linear, LinearBuilder, Relu},
};

/// A marker trait for activation functions used in transformers.
pub trait Activation: UnaryModule<Error = Exception> + std::fmt::Debug + DynClone {}

impl<M> Activation for M where M: UnaryModule<Error = Exception> + std::fmt::Debug + DynClone {}

/// Builder for the [`MultiHeadAttention`] module
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_multi_head_attention,
    err = MultiHeadAttentionBuildError,
)]
pub struct MultiHeadAttentionBuilder {
    /// Model dimensions; also used as the default for the other dimensions when they are not supplied
    pub dims: i32,

    /// Number of attention heads
    pub num_heads: i32,

    /// Input dimensions of queries
    #[builder(optional, default = None)]
    pub query_input_dims: Option<i32>,

    /// Input dimensions of keys
    #[builder(optional, default = None)]
    pub key_input_dims: Option<i32>,

    /// Input dimensions of values
    #[builder(optional, default = None)]
    pub value_input_dims: Option<i32>,

    /// Dimensions of values after the projection
    #[builder(optional, default = None)]
    pub value_dims: Option<i32>,

    /// Dimensions new values will be projected to
    #[builder(optional, default = None)]
    pub value_output_dims: Option<i32>,

    /// If `true`, use a bias in the [`Linear`] layers
    #[builder(optional, default = MultiHeadAttention::DEFAULT_BIAS)]
    pub bias: bool,
}

fn build_multi_head_attention(
    builder: MultiHeadAttentionBuilder,
) -> Result<MultiHeadAttention, MultiHeadAttentionBuildError> {
    if builder.dims % builder.num_heads != 0 {
        return Err(MultiHeadAttentionBuildError::InvalidNumHeads(
            builder.num_heads,
        ));
    }

    let dims = builder.dims;
    let bias = builder.bias;
    let query_input_dims = builder.query_input_dims.unwrap_or(builder.dims);
    let key_input_dims = builder.key_input_dims.unwrap_or(builder.dims);
    let value_input_dims = builder.value_input_dims.unwrap_or(builder.dims);
    let value_dims = builder.value_dims.unwrap_or(builder.dims);
    let value_output_dims = builder.value_output_dims.unwrap_or(builder.dims);

    let num_heads = builder.num_heads;

    let query_proj = LinearBuilder::new(query_input_dims, dims)
        .bias(bias)
        .build()?;
    let key_proj = LinearBuilder::new(key_input_dims, dims)
        .bias(bias)
        .build()?;
    let value_proj = LinearBuilder::new(value_input_dims, value_dims)
        .bias(bias)
        .build()?;
    let output_proj = LinearBuilder::new(value_dims, value_output_dims)
        .bias(bias)
        .build()?;

    Ok(MultiHeadAttention {
        num_heads,
        query_proj: MaybeQuantized::new(query_proj),
        key_proj: MaybeQuantized::new(key_proj),
        value_proj: MaybeQuantized::new(value_proj),
        output_proj: MaybeQuantized::new(output_proj),
    })
}

/// Implements the scaled dot product attention with multiple heads.
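///
/// Each head computes `softmax((Q · Kᵀ) / sqrt(d_head) + mask) · V`, where
/// `d_head = dims / num_heads`; the heads are then concatenated and passed
/// through the output projection.
///
/// A minimal usage sketch. The crate re-exports and the `zeros` constructor used
/// for the placeholder input are assumptions for illustration, not part of this
/// module:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
/// use mlx_rs::nn::MultiHeadAttention;
///
/// // 4 attention heads over a model dimension of 64.
/// let mut attention = MultiHeadAttention::new(64, 4)?;
///
/// // Placeholder input of shape [batch, sequence, dims] (constructor assumed).
/// let x = mlx_rs::ops::zeros::<f32>(&[2, 10, 64])?;
///
/// // Self-attention: queries, keys, and values are all `x`; no mask.
/// let y = attention.forward((&x, &x, &x))?;
/// ```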
#[derive(Debug, Clone, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
pub struct MultiHeadAttention {
    /// Number of attention heads
    pub num_heads: i32,

    /// Query projection layer
    #[quantizable]
    #[param]
    pub query_proj: MaybeQuantized<Linear>,

    /// Key projection layer
    #[quantizable]
    #[param]
    pub key_proj: MaybeQuantized<Linear>,

    /// Value projection layer
    #[quantizable]
    #[param]
    pub value_proj: MaybeQuantized<Linear>,

    /// Output projection layer
    #[quantizable]
    #[param]
    pub output_proj: MaybeQuantized<Linear>,
}

impl MultiHeadAttention {
    /// Default value for the `bias` field
    pub const DEFAULT_BIAS: bool = false;

    /// Creates an attention mask for use with [`MultiHeadAttention`].
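    ///
    /// The result is an `[n, n]` additive mask that is `T::min_value()` strictly
    /// above the diagonal and `0` elsewhere, so adding it to the attention scores
    /// prevents each position from attending to later positions. For example, for
    /// `n = 3` (writing `-inf` for `T::min_value()`):
    ///
    /// ```text
    /// [[0, -inf, -inf],
    ///  [0,    0, -inf],
    ///  [0,    0,    0]]
    /// ```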
    pub fn create_additive_causal_mask<T>(n: i32) -> Result<Array, Exception>
    where
        T: ArrayElement + LowerBounded,
        Array: FromScalar<T>,
    {
        let indices = arange::<_, T>(0, n, 1)?;
        let left = expand_dims(&indices, &[1])?;
        let right = expand_dims(&indices, &[0])?;
        let mask = left.lt(right)?;
        let mask = mask.as_type::<T>()?.multiply(array!(T::min_value()))?; // TODO: replace with f32::MIN?
        Ok(mask)
    }
}

generate_builder! {
    /// Input to the [`MultiHeadAttention`] module
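    ///
    /// In addition to the generated builder, this can be created directly from
    /// `(&queries, &keys, &values)`, `(&queries, &keys, &values, &mask)`, or
    /// `(&queries, &keys, &values, Option<&mask>)` tuples via the `From` impls below.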
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MultiHeadAttentionInput<'a> {
        /// Queries
        pub queries: &'a Array,

        /// Keys
        pub keys: &'a Array,

        /// Values
        pub values: &'a Array,

        /// Mask
        #[builder(optional, default = None)]
        pub mask: Option<&'a Array>,
    }
}

impl<'a> From<(&'a Array, &'a Array, &'a Array)> for MultiHeadAttentionInput<'a> {
    fn from((queries, keys, values): (&'a Array, &'a Array, &'a Array)) -> Self {
        MultiHeadAttentionInput {
            queries,
            keys,
            values,
            mask: None,
        }
    }
}

impl<'a> From<(&'a Array, &'a Array, &'a Array, &'a Array)> for MultiHeadAttentionInput<'a> {
    fn from((queries, keys, values, mask): (&'a Array, &'a Array, &'a Array, &'a Array)) -> Self {
        MultiHeadAttentionInput {
            queries,
            keys,
            values,
            mask: Some(mask),
        }
    }
}

impl<'a> From<(&'a Array, &'a Array, &'a Array, Option<&'a Array>)>
    for MultiHeadAttentionInput<'a>
{
    fn from(
        (queries, keys, values, mask): (&'a Array, &'a Array, &'a Array, Option<&'a Array>),
    ) -> Self {
        MultiHeadAttentionInput {
            queries,
            keys,
            values,
            mask,
        }
    }
}

impl<'a, Input> Module<Input> for MultiHeadAttention
where
    Input: Into<MultiHeadAttentionInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    #[allow(non_snake_case)]
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let queries = self.query_proj.forward(input.queries)?;
        let keys = self.key_proj.forward(input.keys)?;
        let values = self.value_proj.forward(input.values)?;

        let B = queries.dim(0);
        let L = queries.dim(1);
        let S = keys.dim(1);

        let queries = queries
            .reshape(&[B, L, self.num_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;
        let keys = keys
            .reshape(&[B, S, self.num_heads, -1])?
            .transpose(&[0, 2, 3, 1])?;
        let values = values
            .reshape(&[B, S, self.num_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;

        // Dimensions are [batch x num_heads x sequence x hidden_dim]
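        // (keys are transposed to [batch x num_heads x hidden_dim x sequence], so the
        // matmul below yields scores of shape [batch x num_heads x L x S]).
        // Queries are scaled by 1 / sqrt(hidden_dim) before the softmax.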
        let scale = f32::sqrt(1.0 / queries.dim(-1) as f32);
        let mut scores = (queries * scale).matmul(&keys)?;
        if let Some(mask) = input.mask {
            scores = scores.add(mask.as_dtype(scores.dtype())?)?;
        }
        scores = softmax(&scores, &[-1], None)?;
        let value_hat = matmul(&scores, &values)?
            .transpose(&[0, 2, 1, 3])?
            .reshape(&[B, L, -1])?;

        self.output_proj.forward(&value_hat)
    }

    fn training_mode(&mut self, mode: bool) {
        self.query_proj.training_mode(mode);
        self.key_proj.training_mode(mode);
        self.value_proj.training_mode(mode);
        self.output_proj.training_mode(mode);
    }
}

#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer_encoder_layer,
    err = TransformerBulidError,
)]
struct TransformerEncoderLayerBuilder {
    pub dimensions: i32,
    pub num_heads: i32,

    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,

    #[builder(optional, default = Self::DEFAULT_DROPOUT)]
    pub dropout: f32,

    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,

    pub norm_first: bool,
}

impl Clone for TransformerEncoderLayerBuilder {
    fn clone(&self) -> Self {
        Self {
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

// The constants are placed in the builder because the encoder layer is not public anyway
impl TransformerEncoderLayerBuilder {
    const DEFAULT_DROPOUT: f32 = 0.0;
}

fn build_transformer_encoder_layer(
    builder: TransformerEncoderLayerBuilder,
) -> Result<TransformerEncoderLayer, TransformerBulidError> {
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let mlp_dimensions = builder.mlp_dimensions.unwrap_or(4 * dimensions);
    let dropout = builder.dropout;
    let attention = MultiHeadAttention::new(dimensions, num_heads)?;
    let ln1 = LayerNorm::new(dimensions)?;
    let ln2 = LayerNorm::new(dimensions)?;
    let linear1 = Linear::new(dimensions, mlp_dimensions)?;
    let linear2 = Linear::new(mlp_dimensions, dimensions)?;
    let dropout1 = DropoutBuilder::new().p(dropout).build()?;
    let dropout2 = DropoutBuilder::new().p(dropout).build()?;
    let activation = builder.activation.unwrap_or(Box::new(Relu));
    let norm_first = builder.norm_first;

    Ok(TransformerEncoderLayer {
        attention,
        ln1,
        ln2,
        linear1: MaybeQuantized::new(linear1),
        linear2: MaybeQuantized::new(linear2),
        dropout1,
        dropout2,
        activation,
        norm_first,
    })
}

/// Transformer encoder layer.
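///
/// Consists of self-attention followed by an MLP block, each wrapped in a residual
/// connection. With `norm_first` the layer computes `x + Attn(LN1(x))` and then
/// `x + MLP(LN2(x))`; otherwise it computes `LN1(x + Attn(x))` followed by
/// `LN2(x + MLP(x))`, with dropout applied after the attention and after the
/// activation in the MLP.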
#[derive(Debug, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
struct TransformerEncoderLayer {
    /// Multi-head attention module
    #[param]
    pub attention: MultiHeadAttention,

    /// First layer norm module
    #[param]
    pub ln1: LayerNorm,

    /// Second layer norm module
    #[param]
    pub ln2: LayerNorm,

    /// First linear module
    #[quantizable]
    #[param]
    pub linear1: MaybeQuantized<Linear>,

    /// Second linear module
    #[quantizable]
    #[param]
    pub linear2: MaybeQuantized<Linear>,

    /// Dropout module for the first layer
    #[param]
    pub dropout1: Dropout,

    /// Dropout module for the second layer
    #[param]
    pub dropout2: Dropout,

    /// Activation function
    #[param]
    pub activation: Box<dyn Activation>,

    /// If `true`, apply layer norm before the attention and MLP blocks, otherwise after
    pub norm_first: bool,
}

impl Clone for TransformerEncoderLayer {
    fn clone(&self) -> Self {
        Self {
            attention: self.attention.clone(),
            ln1: self.ln1.clone(),
            ln2: self.ln2.clone(),
            linear1: self.linear1.clone(),
            linear2: self.linear2.clone(),
            dropout1: self.dropout1.clone(),
            dropout2: self.dropout2.clone(),
            activation: dyn_clone::clone_box(&*self.activation),
            norm_first: self.norm_first,
        }
    }
}

struct TransformerEncoderInput<'a> {
    pub x: &'a Array,
    pub mask: &'a Array,
}

impl<'a> From<(&'a Array, &'a Array)> for TransformerEncoderInput<'a> {
    fn from((x, mask): (&'a Array, &'a Array)) -> Self {
        TransformerEncoderInput { x, mask }
    }
}

impl<'a, Input> Module<Input> for TransformerEncoderLayer
where
    Input: Into<TransformerEncoderInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let x = input.x;
        let mask = input.mask;

        if self.norm_first {
            let mut y = self.ln1.forward(x)?;
            let attention_input = MultiHeadAttentionInput::from((&y, &y, &y, mask));
            y = self.attention.forward(attention_input)?;
            y = self.dropout1.forward(&y)?;
            let x = x.add(&y)?;

            y = self.ln2.forward(&x)?;
            y = self.linear1.forward(&y)?;
            y = self.activation.forward(&y)?;
            y = self.dropout2.forward(&y)?;
            y = self.linear2.forward(&y)?;
            y = x.add(&y)?;

            Ok(y)
        } else {
            let attention_input = MultiHeadAttentionInput::from((x, x, x, mask));
            let mut y = self.attention.forward(attention_input)?;
            y = self.dropout1.forward(&y)?;
            let mut x = x.add(&y)?;
            x = self.ln1.forward(&x)?;

            y = self.linear1.forward(&x)?;
            y = self.activation.forward(&y)?;
            y = self.dropout2.forward(&y)?;
            y = self.linear2.forward(&y)?;
            y = x.add(&y)?;
            y = self.ln2.forward(&y)?;

            Ok(y)
        }
    }

    fn training_mode(&mut self, mode: bool) {
        <MultiHeadAttention as Module<MultiHeadAttentionInput>>::training_mode(
            &mut self.attention,
            mode,
        );
        self.ln1.training_mode(mode);
        self.ln2.training_mode(mode);
        self.linear1.training_mode(mode);
        self.linear2.training_mode(mode);
        self.dropout1.training_mode(mode);
        self.dropout2.training_mode(mode);
        self.activation.training_mode(mode);
    }
}

#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer_encoder,
    err = TransformerBulidError,
)]
struct TransformerEncoderBuilder {
    pub layer_count: usize,
    pub dimensions: i32,
    pub num_heads: i32,

    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,

    #[builder(optional, default = Self::DEFAULT_DROPOUT)]
    pub dropout: f32,

    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,

    pub norm_first: bool,
}

impl TransformerEncoderBuilder {
    const DEFAULT_DROPOUT: f32 = 0.0;
}

impl Clone for TransformerEncoderBuilder {
    fn clone(&self) -> Self {
        Self {
            layer_count: self.layer_count,
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

fn build_transformer_encoder(
    builder: TransformerEncoderBuilder,
) -> Result<TransformerEncoder, TransformerBulidError> {
    let layer_count = builder.layer_count;
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let norm_first = builder.norm_first;
    let activation = builder.activation.unwrap_or(Box::new(Relu));

    let layers = (0..layer_count)
        .map(|_| {
            TransformerEncoderLayerBuilder::new(dimensions, num_heads, norm_first)
                .mlp_dimensions(builder.mlp_dimensions)
                .dropout(builder.dropout)
                .activation(dyn_clone::clone_box(&*activation))
                .build()
        })
        .collect::<Result<Vec<_>, _>>()?;
    let ln = LayerNorm::new(dimensions)?;

    Ok(TransformerEncoder { layers, ln })
}

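/// A stack of [`TransformerEncoderLayer`]s followed by a final [`LayerNorm`].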
#[derive(Debug, Clone, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
struct TransformerEncoder {
    #[quantizable]
    #[param]
    pub layers: Vec<TransformerEncoderLayer>,

    #[param]
    pub ln: LayerNorm,
}

impl<'a, Input> Module<Input> for TransformerEncoder
where
    Input: Into<TransformerEncoderInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let x = input.x;
        let mask = input.mask;

        let mut x = Cow::Borrowed(x);

        for l in &mut self.layers {
            let layer_input = TransformerEncoderInput::from((&*x, mask));
            x = Cow::Owned(l.forward(layer_input)?);
        }

        self.ln.forward(&*x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.layers.iter_mut().for_each(|layer| {
            <TransformerEncoderLayer as Module<TransformerEncoderInput>>::training_mode(
                layer, mode,
            );
        });
        self.ln.training_mode(mode);
    }
}

#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer_decoder_layer,
    err = TransformerBulidError,
)]
struct TransformerDecoderLayerBuilder {
    pub dimensions: i32,
    pub num_heads: i32,
    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,
    #[builder(optional, default = Self::DEFAULT_DROPOUT)]
    pub dropout: f32,
    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,
    pub norm_first: bool,
}

impl TransformerDecoderLayerBuilder {
    const DEFAULT_DROPOUT: f32 = 0.0;
}

impl Clone for TransformerDecoderLayerBuilder {
    fn clone(&self) -> Self {
        Self {
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

fn build_transformer_decoder_layer(
    builder: TransformerDecoderLayerBuilder,
) -> Result<TransformerDecoderLayer, TransformerBulidError> {
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let mlp_dimensions = builder.mlp_dimensions.unwrap_or(4 * dimensions);
    let dropout = builder.dropout;

    let self_attention = MultiHeadAttention::new(dimensions, num_heads)?;
    let cross_attention = MultiHeadAttention::new(dimensions, num_heads)?;
    let ln1 = LayerNorm::new(dimensions)?;
    let ln2 = LayerNorm::new(dimensions)?;
    let ln3 = LayerNorm::new(dimensions)?;
    let linear1 = Linear::new(dimensions, mlp_dimensions)?;
    let linear2 = Linear::new(mlp_dimensions, dimensions)?;
    let dropout1 = DropoutBuilder::new().p(dropout).build()?;
    let dropout2 = DropoutBuilder::new().p(dropout).build()?;
    let dropout3 = DropoutBuilder::new().p(dropout).build()?;
    let activation = builder.activation.unwrap_or(Box::new(Relu));
    let norm_first = builder.norm_first;

    Ok(TransformerDecoderLayer {
        self_attention,
        cross_attention,
        ln1,
        ln2,
        ln3,
        linear1: MaybeQuantized::new(linear1),
        linear2: MaybeQuantized::new(linear2),
        dropout1,
        dropout2,
        dropout3,
        activation,
        norm_first,
    })
}

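/// Transformer decoder layer: self-attention over the target sequence, cross-attention
/// over the encoder memory, and an MLP block, each wrapped in a residual connection
/// (pre- or post-norm depending on `norm_first`).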
#[derive(Debug, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
struct TransformerDecoderLayer {
    #[param]
    pub self_attention: MultiHeadAttention,

    #[param]
    pub cross_attention: MultiHeadAttention,

    #[param]
    pub ln1: LayerNorm,

    #[param]
    pub ln2: LayerNorm,

    #[param]
    pub ln3: LayerNorm,

    #[quantizable]
    #[param]
    pub linear1: MaybeQuantized<Linear>,

    #[quantizable]
    #[param]
    pub linear2: MaybeQuantized<Linear>,

    #[param]
    pub dropout1: Dropout,

    #[param]
    pub dropout2: Dropout,

    #[param]
    pub dropout3: Dropout,

    #[param]
    pub activation: Box<dyn Activation>,

    pub norm_first: bool,
}

impl Clone for TransformerDecoderLayer {
    fn clone(&self) -> Self {
        Self {
            self_attention: self.self_attention.clone(),
            cross_attention: self.cross_attention.clone(),
            ln1: self.ln1.clone(),
            ln2: self.ln2.clone(),
            ln3: self.ln3.clone(),
            linear1: self.linear1.clone(),
            linear2: self.linear2.clone(),
            dropout1: self.dropout1.clone(),
            dropout2: self.dropout2.clone(),
            dropout3: self.dropout3.clone(),
            activation: dyn_clone::clone_box(&*self.activation),
            norm_first: self.norm_first,
        }
    }
}

struct TransformerDecoderInput<'a> {
    pub x: &'a Array,
    pub memory: &'a Array,
    pub x_mask: &'a Array,
    pub memory_mask: &'a Array,
}

impl<'a> From<(&'a Array, &'a Array, &'a Array, &'a Array)> for TransformerDecoderInput<'a> {
    fn from(
        (x, memory, x_mask, memory_mask): (&'a Array, &'a Array, &'a Array, &'a Array),
    ) -> Self {
        TransformerDecoderInput {
            x,
            memory,
            x_mask,
            memory_mask,
        }
    }
}

impl<'a, Input> Module<Input> for TransformerDecoderLayer
where
    Input: Into<TransformerDecoderInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let x = input.x;
        let memory = input.memory;
        let x_mask = input.x_mask;
        let memory_mask = input.memory_mask;

        if self.norm_first {
            let mut y = self.ln1.forward(x)?;
            y = self
                .self_attention
                .forward(MultiHeadAttentionInput::from((&y, &y, &y, x_mask)))?;
            y = self.dropout1.forward(&y)?;
            let x = x.add(&y)?;

            y = self.ln2.forward(&x)?;
            y = self
                .cross_attention
                .forward(MultiHeadAttentionInput::from((
                    &y,
                    memory,
                    memory,
                    memory_mask,
                )))?;
            y = self.dropout2.forward(&y)?;
            let x = x.add(&y)?;

            y = self.ln3.forward(&x)?;
            y = self.linear1.forward(&y)?;
            y = self.activation.forward(&y)?;
            y = self.dropout3.forward(&y)?;
            y = self.linear2.forward(&y)?;
            x.add(&y)
        } else {
            let mut y = self
                .self_attention
                .forward(MultiHeadAttentionInput::from((x, x, x, x_mask)))?;
            y = self.dropout1.forward(&y)?;
            let mut x = x.add(&y)?;
            x = self.ln1.forward(&x)?;

            y = self
                .cross_attention
                .forward(MultiHeadAttentionInput::from((
                    &y,
                    memory,
                    memory,
                    memory_mask,
                )))?;
            y = self.dropout2.forward(&y)?;
            x = x.add(&y)?;
            x = self.ln2.forward(&x)?; // TODO: https://github.com/ml-explore/mlx/issues/1636

            y = self.linear1.forward(&x)?;
            y = self.activation.forward(&y)?;
            y = self.dropout3.forward(&y)?;
            y = self.linear2.forward(&y)?;
            y = x.add(&y)?;
            self.ln3.forward(&y)
        }
    }

    fn training_mode(&mut self, mode: bool) {
        <MultiHeadAttention as Module<MultiHeadAttentionInput>>::training_mode(
            &mut self.self_attention,
            mode,
        );
        <MultiHeadAttention as Module<MultiHeadAttentionInput>>::training_mode(
            &mut self.cross_attention,
            mode,
        );
        self.ln1.training_mode(mode);
        self.ln2.training_mode(mode);
        self.ln3.training_mode(mode);
        self.linear1.training_mode(mode);
        self.linear2.training_mode(mode);
        self.dropout1.training_mode(mode);
        self.dropout2.training_mode(mode);
        self.dropout3.training_mode(mode);
        self.activation.training_mode(mode);
    }
}

#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer_decoder,
    err = TransformerBulidError,
)]
struct TransformerDecoderBuilder {
    pub layer_count: usize,
    pub dimensions: i32,
    pub num_heads: i32,

    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,

    #[builder(optional, default = Self::DEFAULT_DROPOUT)]
    pub dropout: f32,

    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,

    pub norm_first: bool,
}

impl TransformerDecoderBuilder {
    const DEFAULT_DROPOUT: f32 = 0.0;
}

impl Clone for TransformerDecoderBuilder {
    fn clone(&self) -> Self {
        Self {
            layer_count: self.layer_count,
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

fn build_transformer_decoder(
    builder: TransformerDecoderBuilder,
) -> Result<TransformerDecoder, TransformerBulidError> {
    let layer_count = builder.layer_count;
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let norm_first = builder.norm_first;

    let activation = builder.activation.unwrap_or(Box::new(Relu));

    let layers = (0..layer_count)
        .map(|_| {
            TransformerDecoderLayerBuilder::new(dimensions, num_heads, norm_first)
                .mlp_dimensions(builder.mlp_dimensions)
                .dropout(builder.dropout)
                .activation(dyn_clone::clone_box(&*activation))
                .build()
        })
        .collect::<Result<Vec<_>, _>>()?;
    let ln = LayerNorm::new(dimensions)?;

    Ok(TransformerDecoder { layers, ln })
}

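/// A stack of [`TransformerDecoderLayer`]s followed by a final [`LayerNorm`].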
#[derive(Debug, Clone, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
struct TransformerDecoder {
    #[quantizable]
    #[param]
    pub layers: Vec<TransformerDecoderLayer>,

    #[param]
    pub ln: LayerNorm,
}

impl<'a, Input> Module<Input> for TransformerDecoder
where
    Input: Into<TransformerDecoderInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let x = input.x;
        let memory = input.memory;
        let x_mask = input.x_mask;
        let memory_mask = input.memory_mask;

        let mut x = Cow::Borrowed(x);

        for l in &mut self.layers {
            let layer_input = TransformerDecoderInput::from((&*x, memory, x_mask, memory_mask));
            x = Cow::Owned(l.forward(layer_input)?);
        }

        self.ln.forward(&*x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.layers.iter_mut().for_each(|layer| {
            <TransformerDecoderLayer as Module<TransformerDecoderInput>>::training_mode(
                layer, mode,
            );
        });
        self.ln.training_mode(mode);
    }
}

/// Builder for the [`Transformer`] module
#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer,
    err = TransformerBulidError,
)]
pub struct TransformerBuilder {
    /// Number of expected features in the encoder/decoder
    #[builder(optional, default = Transformer::DEFAULT_DIMENSIONS)]
    pub dimensions: i32,

    /// Number of attention heads
    #[builder(optional, default = Transformer::DEFAULT_NUM_HEADS)]
    pub num_heads: i32,

    /// Number of layers in the encoder
    #[builder(optional, default = Transformer::DEFAULT_ENCODER_LAYERS_COUNT)]
    pub encoder_layer_count: usize,

    /// Number of layers in the decoder
    #[builder(optional, default = Transformer::DEFAULT_DECODER_LAYERS_COUNT)]
    pub decoder_layer_count: usize,

    /// Hidden dimensions of the MLP block in each layer. Defaults to `4 * dimensions`
    /// if not specified
    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,

    /// Dropout value for the encoder and decoder. Dropout is applied after each attention
    /// layer and after the activation in the MLP layer
    #[builder(optional, default = Transformer::DEFAULT_DROPOUT)]
    pub dropout: f32,

    /// The activation layer for the MLP hidden layer
    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,

    /// If `true`, encoder and decoder layers will perform layer normalization before
    /// the attention and MLP operations, otherwise after
    #[builder(optional, default = Transformer::DEFAULT_NORM_FIRST)]
    pub norm_first: bool,
}

impl Clone for TransformerBuilder {
    fn clone(&self) -> Self {
        Self {
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            encoder_layer_count: self.encoder_layer_count,
            decoder_layer_count: self.decoder_layer_count,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

fn build_transformer(builder: TransformerBuilder) -> Result<Transformer, TransformerBulidError> {
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let encoder_layer_count = builder.encoder_layer_count;
    let decoder_layer_count = builder.decoder_layer_count;
    let mlp_dimensions = builder.mlp_dimensions;
    let dropout = builder.dropout;
    let activation = builder.activation.unwrap_or(Box::new(Relu));
    let norm_first = builder.norm_first;

    let encoder =
        TransformerEncoderBuilder::new(encoder_layer_count, dimensions, num_heads, norm_first)
            .mlp_dimensions(mlp_dimensions)
            .dropout(dropout)
            .activation(dyn_clone::clone_box(&*activation))
            .build()?;
    let decoder =
        TransformerDecoderBuilder::new(decoder_layer_count, dimensions, num_heads, norm_first)
            .mlp_dimensions(mlp_dimensions)
            .dropout(dropout)
            .activation(dyn_clone::clone_box(&*activation))
            .build()?;

    Ok(Transformer { encoder, decoder })
}

/// Implements a standard Transformer model.
///
/// The implementation is based on "Attention Is All You Need"
/// <https://arxiv.org/abs/1706.03762>.
///
/// The Transformer model contains an encoder and a decoder. The encoder
/// processes the input sequence and the decoder generates the output sequence.
/// The interaction between encoder and decoder happens through the attention
/// mechanism.
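///
/// A minimal end-to-end sketch. The builder setter methods are assumptions based on
/// the derive macros, and `source`, `target`, the mask arrays, and `target_len` are
/// placeholders for illustration:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
/// use mlx_rs::nn::{MultiHeadAttention, TransformerBuilder};
///
/// // A small model: 128 features, 4 heads, 2 encoder and 2 decoder layers.
/// let mut transformer = TransformerBuilder::new()
///     .dimensions(128)
///     .num_heads(4)
///     .encoder_layer_count(2)
///     .decoder_layer_count(2)
///     .build()?;
///
/// // Causal mask over the target sequence (target_len is a placeholder).
/// let target_mask = MultiHeadAttention::create_additive_causal_mask::<f32>(target_len)?;
///
/// // Input is (source, target, source_mask, target_mask, memory_mask).
/// let output = transformer.forward((&source, &target, &source_mask, &target_mask, &memory_mask))?;
/// ```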
#[derive(Debug, Clone, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
pub struct Transformer {
    /// Encoder module
    #[quantizable]
    #[param]
    encoder: TransformerEncoder, // TODO: visibility?

    /// Decoder module
    #[quantizable]
    #[param]
    decoder: TransformerDecoder, // TODO: visibility?
}

impl Transformer {
    /// Default value for `dimensions`
    pub const DEFAULT_DIMENSIONS: i32 = 512;

    /// Default value for `num_heads`
    pub const DEFAULT_NUM_HEADS: i32 = 8;

    /// Default number of encoder layers
    pub const DEFAULT_ENCODER_LAYERS_COUNT: usize = 6;

    /// Default number of decoder layers
    pub const DEFAULT_DECODER_LAYERS_COUNT: usize = 6;

    /// Default value for dropout
    pub const DEFAULT_DROPOUT: f32 = 0.0;

    /// Default value for `norm_first`
    pub const DEFAULT_NORM_FIRST: bool = false;
}

/// Input to the [`Transformer`] module
#[derive(Debug, Clone)]
pub struct TransformerInput<'a> {
    /// Source
    pub source: &'a Array,

    /// Target
    pub target: &'a Array,

    /// Source mask
    pub source_mask: &'a Array,

    /// Target mask
    pub target_mask: &'a Array,

    /// Memory mask
    pub memory_mask: &'a Array,
}

impl<'a> From<(&'a Array, &'a Array, &'a Array, &'a Array, &'a Array)> for TransformerInput<'a> {
    fn from(
        (source, target, source_mask, target_mask, memory_mask): (
            &'a Array,
            &'a Array,
            &'a Array,
            &'a Array,
            &'a Array,
        ),
    ) -> Self {
        TransformerInput {
            source,
            target,
            source_mask,
            target_mask,
            memory_mask,
        }
    }
}

impl<'a, Input> Module<Input> for Transformer
where
    Input: Into<TransformerInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let source = input.source;
        let target = input.target;
        let source_mask = input.source_mask;
        let target_mask = input.target_mask;
        let memory_mask = input.memory_mask;

        let memory = self
            .encoder
            .forward(TransformerEncoderInput::from((source, source_mask)))?;
        self.decoder.forward(TransformerDecoderInput::from((
            target,
            &memory,
            target_mask,
            memory_mask,
        )))
    }

    fn training_mode(&mut self, mode: bool) {
        <TransformerEncoder as Module<TransformerEncoderInput>>::training_mode(
            &mut self.encoder,
            mode,
        );
        <TransformerDecoder as Module<TransformerDecoderInput>>::training_mode(
            &mut self.decoder,
            mode,
        );
    }
}