use std::borrow::Cow;

use crate::{
    array,
    builder::Builder,
    error::Exception,
    module::{Module, UnaryModule},
    ops::{arange, expand_dims, matmul, softmax},
    quantization::MaybeQuantized,
    Array, ArrayElement, FromScalar,
};
use dyn_clone::DynClone;
use mlx_internal_macros::{generate_builder, Buildable, Builder};
use mlx_macros::{ModuleParameters, Quantizable};
use num_traits::bounds::LowerBounded;

use crate::{
    error::{MultiHeadAttentionBuildError, TransformerBulidError},
    nn::{Dropout, DropoutBuilder, LayerNorm, Linear, LinearBuilder, Relu},
};

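/// Marker trait for activation modules usable inside the transformer layers:
/// any cloneable, debuggable [`UnaryModule`] whose error type is [`Exception`]
/// qualifies via the blanket implementation below.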
pub trait Activation: UnaryModule<Error = Exception> + std::fmt::Debug + DynClone {}

impl<M> Activation for M where M: UnaryModule<Error = Exception> + std::fmt::Debug + DynClone {}

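/// Builder for [`MultiHeadAttention`].
///
/// `dims` and `num_heads` are required; the `*_input_dims`, `value_dims`, and
/// `value_output_dims` fields fall back to `dims` when left as `None`, and
/// `bias` defaults to [`MultiHeadAttention::DEFAULT_BIAS`].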
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_multi_head_attention,
    err = MultiHeadAttentionBuildError,
)]
pub struct MultiHeadAttentionBuilder {
    pub dims: i32,

    pub num_heads: i32,

    #[builder(optional, default = None)]
    pub query_input_dims: Option<i32>,

    #[builder(optional, default = None)]
    pub key_input_dims: Option<i32>,

    #[builder(optional, default = None)]
    pub value_input_dims: Option<i32>,

    #[builder(optional, default = None)]
    pub value_dims: Option<i32>,

    #[builder(optional, default = None)]
    pub value_output_dims: Option<i32>,

    #[builder(optional, default = MultiHeadAttention::DEFAULT_BIAS)]
    pub bias: bool,
}

fn build_multi_head_attention(
    builder: MultiHeadAttentionBuilder,
) -> Result<MultiHeadAttention, MultiHeadAttentionBuildError> {
    if builder.dims % builder.num_heads != 0 {
        return Err(MultiHeadAttentionBuildError::InvalidNumHeads(
            builder.num_heads,
        ));
    }

    let dims = builder.dims;
    let bias = builder.bias;
    let query_input_dims = builder.query_input_dims.unwrap_or(builder.dims);
    let key_input_dims = builder.key_input_dims.unwrap_or(builder.dims);
    let value_input_dims = builder.value_input_dims.unwrap_or(builder.dims);
    let value_dims = builder.value_dims.unwrap_or(builder.dims);
    let value_output_dims = builder.value_output_dims.unwrap_or(builder.dims);

    let num_heads = builder.num_heads;

    let query_proj = LinearBuilder::new(query_input_dims, dims)
        .bias(bias)
        .build()?;
    let key_proj = LinearBuilder::new(key_input_dims, dims)
        .bias(bias)
        .build()?;
    let value_proj = LinearBuilder::new(value_input_dims, value_dims)
        .bias(bias)
        .build()?;
    let output_proj = LinearBuilder::new(value_dims, value_output_dims)
        .bias(bias)
        .build()?;

    Ok(MultiHeadAttention {
        num_heads,
        query_proj: MaybeQuantized::new(query_proj),
        key_proj: MaybeQuantized::new(key_proj),
        value_proj: MaybeQuantized::new(value_proj),
        output_proj: MaybeQuantized::new(output_proj),
    })
}

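/// Scaled dot-product attention with multiple heads and learned linear
/// projections for queries, keys, values, and the output.
///
/// `dims` must be divisible by `num_heads`; otherwise the builder returns
/// [`MultiHeadAttentionBuildError::InvalidNumHeads`].
///
/// A sketch of typical usage (assuming `queries`, `keys`, and `values` are
/// existing [`Array`]s of shape `[batch, seq, 512]`):
///
/// ```rust,ignore
/// let mut attn = MultiHeadAttention::new(512, 8)?;
/// // An optional fourth tuple element supplies an additive attention mask.
/// let output = attn.forward((&queries, &keys, &values))?;
/// ```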
#[derive(Debug, Clone, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
pub struct MultiHeadAttention {
    pub num_heads: i32,

    #[quantizable]
    #[param]
    pub query_proj: MaybeQuantized<Linear>,

    #[quantizable]
    #[param]
    pub key_proj: MaybeQuantized<Linear>,

    #[quantizable]
    #[param]
    pub value_proj: MaybeQuantized<Linear>,

    #[quantizable]
    #[param]
    pub output_proj: MaybeQuantized<Linear>,
}

impl MultiHeadAttention {
    pub const DEFAULT_BIAS: bool = false;

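    /// Builds an additive causal mask of shape `[n, n]`: entries where the key
    /// index exceeds the query index are set to `T::min_value()`, so adding the
    /// mask to the attention scores drives attention to future positions toward
    /// zero after the softmax.
    ///
    /// A minimal sketch, assuming `f32` satisfies the `ArrayElement +
    /// LowerBounded` bounds used here:
    ///
    /// ```rust,ignore
    /// let mask = MultiHeadAttention::create_additive_causal_mask::<f32>(8)?;
    /// ```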
    pub fn create_additive_causal_mask<T>(n: i32) -> Result<Array, Exception>
    where
        T: ArrayElement + LowerBounded,
        Array: FromScalar<T>,
    {
        let indices = arange::<_, T>(0, n, 1)?;
        let left = expand_dims(&indices, &[1])?;
        let right = expand_dims(&indices, &[0])?;
        let mask = left.lt(right)?;
        let mask = mask.as_type::<T>()?.multiply(array!(T::min_value()))?;
        Ok(mask)
    }
}

generate_builder! {
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
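    /// Borrowed input to [`MultiHeadAttention::forward`]: query, key, and value
    /// arrays plus an optional additive attention mask. Tuples of `(q, k, v)`,
    /// `(q, k, v, mask)`, and `(q, k, v, Option<mask>)` convert into this type
    /// via the `From` impls below.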
    pub struct MultiHeadAttentionInput<'a> {
        pub queries: &'a Array,

        pub keys: &'a Array,

        pub values: &'a Array,

        #[builder(optional, default = None)]
        pub mask: Option<&'a Array>,
    }
}

impl<'a> From<(&'a Array, &'a Array, &'a Array)> for MultiHeadAttentionInput<'a> {
    fn from((queries, keys, values): (&'a Array, &'a Array, &'a Array)) -> Self {
        MultiHeadAttentionInput {
            queries,
            keys,
            values,
            mask: None,
        }
    }
}

impl<'a> From<(&'a Array, &'a Array, &'a Array, &'a Array)> for MultiHeadAttentionInput<'a> {
    fn from((queries, keys, values, mask): (&'a Array, &'a Array, &'a Array, &'a Array)) -> Self {
        MultiHeadAttentionInput {
            queries,
            keys,
            values,
            mask: Some(mask),
        }
    }
}

impl<'a> From<(&'a Array, &'a Array, &'a Array, Option<&'a Array>)>
    for MultiHeadAttentionInput<'a>
{
    fn from(
        (queries, keys, values, mask): (&'a Array, &'a Array, &'a Array, Option<&'a Array>),
    ) -> Self {
        MultiHeadAttentionInput {
            queries,
            keys,
            values,
            mask,
        }
    }
}

impl<'a, Input> Module<Input> for MultiHeadAttention
where
    Input: Into<MultiHeadAttentionInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    #[allow(non_snake_case)]
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let queries = self.query_proj.forward(input.queries)?;
        let keys = self.key_proj.forward(input.keys)?;
        let values = self.value_proj.forward(input.values)?;

        let B = queries.dim(0);
        let L = queries.dim(1);
        let S = keys.dim(1);

        let queries = queries
            .reshape(&[B, L, self.num_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;
        let keys = keys
            .reshape(&[B, S, self.num_heads, -1])?
            .transpose(&[0, 2, 3, 1])?;
        let values = values
            .reshape(&[B, S, self.num_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;

        let scale = f32::sqrt(1.0 / queries.dim(-1) as f32);
        let mut scores = (queries * scale).matmul(&keys)?;
        if let Some(mask) = input.mask {
            scores = scores.add(mask.as_dtype(scores.dtype())?)?;
        }
        scores = softmax(&scores, &[-1], None)?;
        let value_hat = matmul(&scores, &values)?
            .transpose(&[0, 2, 1, 3])?
            .reshape(&[B, L, -1])?;

        self.output_proj.forward(&value_hat)
    }

    fn training_mode(&mut self, mode: bool) {
        self.query_proj.training_mode(mode);
        self.key_proj.training_mode(mode);
        self.value_proj.training_mode(mode);
        self.output_proj.training_mode(mode);
    }
}

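/// Builder for a single [`TransformerEncoderLayer`]; `mlp_dimensions` defaults
/// to `4 * dimensions` and `activation` to [`Relu`] when left unset.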
#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer_encoder_layer,
    err = TransformerBulidError,
)]
struct TransformerEncoderLayerBuilder {
    pub dimensions: i32,
    pub num_heads: i32,

    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,

    #[builder(optional, default = Self::DEFAULT_DROPOUT)]
    pub dropout: f32,

    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,

    pub norm_first: bool,
}

impl Clone for TransformerEncoderLayerBuilder {
    fn clone(&self) -> Self {
        Self {
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

impl TransformerEncoderLayerBuilder {
    const DEFAULT_DROPOUT: f32 = 0.0;
}

fn build_transformer_encoder_layer(
    builder: TransformerEncoderLayerBuilder,
) -> Result<TransformerEncoderLayer, TransformerBulidError> {
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let mlp_dimensions = builder.mlp_dimensions.unwrap_or(4 * dimensions);
    let dropout = builder.dropout;
    let attention = MultiHeadAttention::new(dimensions, num_heads)?;
    let ln1 = LayerNorm::new(dimensions)?;
    let ln2 = LayerNorm::new(dimensions)?;
    let linear1 = Linear::new(dimensions, mlp_dimensions)?;
    let linear2 = Linear::new(mlp_dimensions, dimensions)?;
    let dropout1 = DropoutBuilder::new().p(dropout).build()?;
    let dropout2 = DropoutBuilder::new().p(dropout).build()?;
    let activation = builder.activation.unwrap_or(Box::new(Relu));
    let norm_first = builder.norm_first;

    Ok(TransformerEncoderLayer {
        attention,
        ln1,
        ln2,
        linear1: MaybeQuantized::new(linear1),
        linear2: MaybeQuantized::new(linear2),
        dropout1,
        dropout2,
        activation,
        norm_first,
    })
}

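/// A single transformer encoder layer: self-attention followed by a two-layer
/// MLP, with residual connections, dropout, and layer norm applied either
/// before (`norm_first`) or after each sub-block.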
#[derive(Debug, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
struct TransformerEncoderLayer {
    #[param]
    pub attention: MultiHeadAttention,

    #[param]
    pub ln1: LayerNorm,

    #[param]
    pub ln2: LayerNorm,

    #[quantizable]
    #[param]
    pub linear1: MaybeQuantized<Linear>,

    #[quantizable]
    #[param]
    pub linear2: MaybeQuantized<Linear>,

    #[param]
    pub dropout1: Dropout,

    #[param]
    pub dropout2: Dropout,

    #[param]
    pub activation: Box<dyn Activation>,

    pub norm_first: bool,
}

impl Clone for TransformerEncoderLayer {
    fn clone(&self) -> Self {
        Self {
            attention: self.attention.clone(),
            ln1: self.ln1.clone(),
            ln2: self.ln2.clone(),
            linear1: self.linear1.clone(),
            linear2: self.linear2.clone(),
            dropout1: self.dropout1.clone(),
            dropout2: self.dropout2.clone(),
            activation: dyn_clone::clone_box(&*self.activation),
            norm_first: self.norm_first,
        }
    }
}

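/// Borrowed input to the encoder (and its layers): the sequence `x` and an
/// attention mask, convertible from an `(&Array, &Array)` tuple.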
struct TransformerEncoderInput<'a> {
    pub x: &'a Array,
    pub mask: &'a Array,
}

impl<'a> From<(&'a Array, &'a Array)> for TransformerEncoderInput<'a> {
    fn from((x, mask): (&'a Array, &'a Array)) -> Self {
        TransformerEncoderInput { x, mask }
    }
}

impl<'a, Input> Module<Input> for TransformerEncoderLayer
where
    Input: Into<TransformerEncoderInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let x = input.x;
        let mask = input.mask;

        if self.norm_first {
            let mut y = self.ln1.forward(x)?;
            let attention_input = MultiHeadAttentionInput::from((&y, &y, &y, mask));
            y = self.attention.forward(attention_input)?;
            y = self.dropout1.forward(&y)?;
            let x = x.add(&y)?;

            y = self.ln2.forward(&x)?;
            y = self.linear1.forward(&y)?;
            y = self.activation.forward(&y)?;
            y = self.dropout2.forward(&y)?;
            y = self.linear2.forward(&y)?;
            y = x.add(&y)?;

            Ok(y)
        } else {
            let attention_input = MultiHeadAttentionInput::from((x, x, x, mask));
            let mut y = self.attention.forward(attention_input)?;
            y = self.dropout1.forward(&y)?;
            let mut x = x.add(&y)?;
            x = self.ln1.forward(&x)?;

            y = self.linear1.forward(&x)?;
            y = self.activation.forward(&y)?;
            y = self.dropout2.forward(&y)?;
            y = self.linear2.forward(&y)?;
            y = x.add(&y)?;
            y = self.ln2.forward(&y)?;

            Ok(y)
        }
    }

    fn training_mode(&mut self, mode: bool) {
        <MultiHeadAttention as Module<MultiHeadAttentionInput>>::training_mode(
            &mut self.attention,
            mode,
        );
        self.ln1.training_mode(mode);
        self.ln2.training_mode(mode);
        self.linear1.training_mode(mode);
        self.linear2.training_mode(mode);
        self.dropout1.training_mode(mode);
        self.dropout2.training_mode(mode);
        self.activation.training_mode(mode);
    }
}

#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer_encoder,
    err = TransformerBulidError,
)]
struct TransformerEncoderBuilder {
    pub layer_count: usize,
    pub dimensions: i32,
    pub num_heads: i32,

    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,

    #[builder(optional, default = Self::DEFAULT_DROPOUT)]
    pub dropout: f32,

    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,

    pub norm_first: bool,
}

impl TransformerEncoderBuilder {
    const DEFAULT_DROPOUT: f32 = 0.0;
}

impl Clone for TransformerEncoderBuilder {
    fn clone(&self) -> Self {
        Self {
            layer_count: self.layer_count,
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

fn build_transformer_encoder(
    builder: TransformerEncoderBuilder,
) -> Result<TransformerEncoder, TransformerBulidError> {
    let layer_count = builder.layer_count;
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let norm_first = builder.norm_first;
    let activation = builder.activation.unwrap_or(Box::new(Relu));

    let layers = (0..layer_count)
        .map(|_| {
            TransformerEncoderLayerBuilder::new(dimensions, num_heads, norm_first)
                .mlp_dimensions(builder.mlp_dimensions)
                .dropout(builder.dropout)
                .activation(dyn_clone::clone_box(&*activation))
                .build()
        })
        .collect::<Result<Vec<_>, _>>()?;
    let ln = LayerNorm::new(dimensions)?;

    Ok(TransformerEncoder { layers, ln })
}

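/// A stack of [`TransformerEncoderLayer`]s followed by a final [`LayerNorm`].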
#[derive(Debug, Clone, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
struct TransformerEncoder {
    #[quantizable]
    #[param]
    pub layers: Vec<TransformerEncoderLayer>,

    #[param]
    pub ln: LayerNorm,
}

impl<'a, Input> Module<Input> for TransformerEncoder
where
    Input: Into<TransformerEncoderInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let x = input.x;
        let mask = input.mask;

        let mut x = Cow::Borrowed(x);

        for l in &mut self.layers {
            let layer_input = TransformerEncoderInput::from((&*x, mask));
            x = Cow::Owned(l.forward(layer_input)?);
        }

        self.ln.forward(&*x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.layers.iter_mut().for_each(|layer| {
            <TransformerEncoderLayer as Module<TransformerEncoderInput>>::training_mode(
                layer, mode,
            );
        });
        self.ln.training_mode(mode);
    }
}

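/// Builder for a single [`TransformerDecoderLayer`]; `ml_dimensions` defaults
/// to `4 * dimensions` and `activation` to [`Relu`] when left unset.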
#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer_decoder_layer,
    err = TransformerBulidError,
)]
struct TransformerDecoderLayerBuilder {
    pub dimensions: i32,
    pub num_heads: i32,
    #[builder(optional, default = None)]
    pub ml_dimensions: Option<i32>,
    #[builder(optional, default = Self::DEFAULT_DROPOUT)]
    pub dropout: f32,
    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,
    pub norm_first: bool,
}

impl TransformerDecoderLayerBuilder {
    const DEFAULT_DROPOUT: f32 = 0.0;
}

impl Clone for TransformerDecoderLayerBuilder {
    fn clone(&self) -> Self {
        Self {
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            ml_dimensions: self.ml_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

fn build_transformer_decoder_layer(
    builder: TransformerDecoderLayerBuilder,
) -> Result<TransformerDecoderLayer, TransformerBulidError> {
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let mlp_dimensions = builder.ml_dimensions.unwrap_or(4 * dimensions);
    let dropout = builder.dropout;

    let self_attention = MultiHeadAttention::new(dimensions, num_heads)?;
    let cross_attention = MultiHeadAttention::new(dimensions, num_heads)?;
    let ln1 = LayerNorm::new(dimensions)?;
    let ln2 = LayerNorm::new(dimensions)?;
    let ln3 = LayerNorm::new(dimensions)?;
    let linear1 = Linear::new(dimensions, mlp_dimensions)?;
    let linear2 = Linear::new(mlp_dimensions, dimensions)?;
    let dropout1 = DropoutBuilder::new().p(dropout).build()?;
    let dropout2 = DropoutBuilder::new().p(dropout).build()?;
    let dropout3 = DropoutBuilder::new().p(dropout).build()?;
    let activation = builder.activation.unwrap_or(Box::new(Relu));
    let norm_first = builder.norm_first;

    Ok(TransformerDecoderLayer {
        self_attention,
        cross_attention,
        ln1,
        ln2,
        ln3,
        linear1: MaybeQuantized::new(linear1),
        linear2: MaybeQuantized::new(linear2),
        dropout1,
        dropout2,
        dropout3,
        activation,
        norm_first,
    })
}

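/// A single transformer decoder layer: masked self-attention over the target
/// sequence, cross-attention over the encoder `memory`, and a two-layer MLP,
/// each wrapped with residual connections, dropout, and layer norm in either
/// pre-norm (`norm_first`) or post-norm order.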
#[derive(Debug, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
struct TransformerDecoderLayer {
    #[param]
    pub self_attention: MultiHeadAttention,

    #[param]
    pub cross_attention: MultiHeadAttention,

    #[param]
    pub ln1: LayerNorm,

    #[param]
    pub ln2: LayerNorm,

    #[param]
    pub ln3: LayerNorm,

    #[quantizable]
    #[param]
    pub linear1: MaybeQuantized<Linear>,

    #[quantizable]
    #[param]
    pub linear2: MaybeQuantized<Linear>,

    #[param]
    pub dropout1: Dropout,

    #[param]
    pub dropout2: Dropout,

    #[param]
    pub dropout3: Dropout,

    #[param]
    pub activation: Box<dyn Activation>,

    pub norm_first: bool,
}

impl Clone for TransformerDecoderLayer {
    fn clone(&self) -> Self {
        Self {
            self_attention: self.self_attention.clone(),
            cross_attention: self.cross_attention.clone(),
            ln1: self.ln1.clone(),
            ln2: self.ln2.clone(),
            ln3: self.ln3.clone(),
            linear1: self.linear1.clone(),
            linear2: self.linear2.clone(),
            dropout1: self.dropout1.clone(),
            dropout2: self.dropout2.clone(),
            dropout3: self.dropout3.clone(),
            activation: dyn_clone::clone_box(&*self.activation),
            norm_first: self.norm_first,
        }
    }
}

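/// Borrowed input to the decoder (and its layers): the target sequence `x`,
/// the encoder output `memory`, and their respective attention masks.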
struct TransformerDecoderInput<'a> {
    pub x: &'a Array,
    pub memory: &'a Array,
    pub x_mask: &'a Array,
    pub memory_mask: &'a Array,
}

impl<'a> From<(&'a Array, &'a Array, &'a Array, &'a Array)> for TransformerDecoderInput<'a> {
    fn from(
        (x, memory, x_mask, memory_mask): (&'a Array, &'a Array, &'a Array, &'a Array),
    ) -> Self {
        TransformerDecoderInput {
            x,
            memory,
            x_mask,
            memory_mask,
        }
    }
}

impl<'a, Input> Module<Input> for TransformerDecoderLayer
where
    Input: Into<TransformerDecoderInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let x = input.x;
        let memory = input.memory;
        let x_mask = input.x_mask;
        let memory_mask = input.memory_mask;

        if self.norm_first {
            let mut y = self.ln1.forward(x)?;
            y = self
                .self_attention
                .forward(MultiHeadAttentionInput::from((&y, &y, &y, x_mask)))?;
            y = self.dropout1.forward(&y)?;
            let x = x.add(&y)?;

            y = self.ln2.forward(&x)?;
            y = self
                .cross_attention
                .forward(MultiHeadAttentionInput::from((
                    &y,
                    memory,
                    memory,
                    memory_mask,
                )))?;
            y = self.dropout2.forward(&y)?;
            let x = x.add(&y)?;

            y = self.ln3.forward(&x)?;
            y = self.linear1.forward(&y)?;
            y = self.activation.forward(&y)?;
            y = self.dropout3.forward(&y)?;
            y = self.linear2.forward(&y)?;
            x.add(&y)
        } else {
            let mut y = self
                .self_attention
                .forward(MultiHeadAttentionInput::from((x, x, x, x_mask)))?;
            y = self.dropout1.forward(&y)?;
            let mut x = x.add(&y)?;
            x = self.ln1.forward(&x)?;

            // Cross-attention queries come from the normalized residual stream,
            // not from the raw self-attention output.
            y = self
                .cross_attention
                .forward(MultiHeadAttentionInput::from((
                    &x,
                    memory,
                    memory,
                    memory_mask,
                )))?;
            y = self.dropout2.forward(&y)?;
            x = x.add(&y)?;
            x = self.ln2.forward(&x)?;

            y = self.linear1.forward(&x)?;
            y = self.activation.forward(&y)?;
            y = self.dropout3.forward(&y)?;
            y = self.linear2.forward(&y)?;
            y = x.add(&y)?;
            self.ln3.forward(&y)
        }
    }

    fn training_mode(&mut self, mode: bool) {
        <MultiHeadAttention as Module<MultiHeadAttentionInput>>::training_mode(
            &mut self.self_attention,
            mode,
        );
        <MultiHeadAttention as Module<MultiHeadAttentionInput>>::training_mode(
            &mut self.cross_attention,
            mode,
        );
        self.ln1.training_mode(mode);
        self.ln2.training_mode(mode);
        self.ln3.training_mode(mode);
        self.linear1.training_mode(mode);
        self.linear2.training_mode(mode);
        self.dropout1.training_mode(mode);
        self.dropout2.training_mode(mode);
        self.dropout3.training_mode(mode);
        self.activation.training_mode(mode);
    }
}

#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer_decoder,
    err = TransformerBulidError,
)]
struct TransformerDecoderBuilder {
    pub layer_count: usize,
    pub dimensions: i32,
    pub num_heads: i32,

    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,

    #[builder(optional, default = Self::DEFAULT_DROPOUT)]
    pub dropout: f32,

    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,

    pub norm_first: bool,
}

impl TransformerDecoderBuilder {
    const DEFAULT_DROPOUT: f32 = 0.0;
}

impl Clone for TransformerDecoderBuilder {
    fn clone(&self) -> Self {
        Self {
            layer_count: self.layer_count,
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

fn build_transformer_decoder(
    builder: TransformerDecoderBuilder,
) -> Result<TransformerDecoder, TransformerBulidError> {
    let layer_count = builder.layer_count;
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let norm_first = builder.norm_first;

    let activation = builder.activation.unwrap_or(Box::new(Relu));

    let layers = (0..layer_count)
        .map(|_| {
            TransformerDecoderLayerBuilder::new(dimensions, num_heads, norm_first)
                .ml_dimensions(builder.mlp_dimensions)
                .dropout(builder.dropout)
                .activation(dyn_clone::clone_box(&*activation))
                .build()
        })
        .collect::<Result<Vec<_>, _>>()?;
    let ln = LayerNorm::new(dimensions)?;

    Ok(TransformerDecoder { layers, ln })
}

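/// A stack of [`TransformerDecoderLayer`]s followed by a final [`LayerNorm`].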
#[derive(Debug, Clone, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
struct TransformerDecoder {
    #[quantizable]
    #[param]
    pub layers: Vec<TransformerDecoderLayer>,

    #[param]
    pub ln: LayerNorm,
}

impl<'a, Input> Module<Input> for TransformerDecoder
where
    Input: Into<TransformerDecoderInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let x = input.x;
        let memory = input.memory;
        let x_mask = input.x_mask;
        let memory_mask = input.memory_mask;

        let mut x = Cow::Borrowed(x);

        for l in &mut self.layers {
            let layer_input = TransformerDecoderInput::from((&*x, memory, x_mask, memory_mask));
            x = Cow::Owned(l.forward(layer_input)?);
        }

        self.ln.forward(&*x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.layers.iter_mut().for_each(|layer| {
            <TransformerDecoderLayer as Module<TransformerDecoderInput>>::training_mode(
                layer, mode,
            );
        });
        self.ln.training_mode(mode);
    }
}

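/// Builder for [`Transformer`]. Every field is optional: the defaults are the
/// `Transformer::DEFAULT_*` constants (512 dimensions, 8 heads, 6 encoder and
/// 6 decoder layers, no dropout, post-norm), with `mlp_dimensions` falling
/// back to four times the model dimensions and `activation` to [`Relu`].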
#[derive(Debug, Builder)]
#[builder(
    root = crate,
    build_with = build_transformer,
    err = TransformerBulidError,
)]
pub struct TransformerBuilder {
    #[builder(optional, default = Transformer::DEFAULT_DIMENSIONS)]
    pub dimensions: i32,

    #[builder(optional, default = Transformer::DEFAULT_NUM_HEADS)]
    pub num_heads: i32,

    #[builder(optional, default = Transformer::DEFAULT_ENCODER_LAYERS_COUNT)]
    pub encoder_layer_count: usize,

    #[builder(optional, default = Transformer::DEFAULT_DECODER_LAYERS_COUNT)]
    pub decoder_layer_count: usize,

    #[builder(optional, default = None)]
    pub mlp_dimensions: Option<i32>,

    #[builder(optional, default = Transformer::DEFAULT_DROPOUT)]
    pub dropout: f32,

    #[builder(optional, default = None)]
    pub activation: Option<Box<dyn Activation>>,

    #[builder(optional, default = Transformer::DEFAULT_NORM_FIRST)]
    pub norm_first: bool,
}

impl Clone for TransformerBuilder {
    fn clone(&self) -> Self {
        Self {
            dimensions: self.dimensions,
            num_heads: self.num_heads,
            encoder_layer_count: self.encoder_layer_count,
            decoder_layer_count: self.decoder_layer_count,
            mlp_dimensions: self.mlp_dimensions,
            dropout: self.dropout,
            activation: self
                .activation
                .as_ref()
                .map(|a| dyn_clone::clone_box(a.as_ref())),
            norm_first: self.norm_first,
        }
    }
}

fn build_transformer(builder: TransformerBuilder) -> Result<Transformer, TransformerBulidError> {
    let dimensions = builder.dimensions;
    let num_heads = builder.num_heads;
    let encoder_layer_count = builder.encoder_layer_count;
    let decoder_layer_count = builder.decoder_layer_count;
    let mlp_dimensions = builder.mlp_dimensions;
    let dropout = builder.dropout;
    let activation = builder.activation.unwrap_or(Box::new(Relu));
    let norm_first = builder.norm_first;

    let encoder =
        TransformerEncoderBuilder::new(encoder_layer_count, dimensions, num_heads, norm_first)
            .mlp_dimensions(mlp_dimensions)
            .dropout(dropout)
            .activation(dyn_clone::clone_box(&*activation))
            .build()?;
    let decoder =
        TransformerDecoderBuilder::new(decoder_layer_count, dimensions, num_heads, norm_first)
            .mlp_dimensions(mlp_dimensions)
            .dropout(dropout)
            .activation(dyn_clone::clone_box(&*activation))
            .build()?;

    Ok(Transformer { encoder, decoder })
}

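/// A full encoder/decoder transformer built from [`TransformerEncoder`] and
/// [`TransformerDecoder`] stacks that share the same model dimensions, head
/// count, dropout, and activation settings.
///
/// A sketch of a forward pass (assuming the module was already constructed,
/// e.g. through [`TransformerBuilder`], and that the source/target arrays and
/// masks exist):
///
/// ```rust,ignore
/// let output = transformer.forward((
///     &source,      // encoder input
///     &target,      // decoder input
///     &source_mask, // encoder self-attention mask
///     &target_mask, // decoder self-attention mask
///     &memory_mask, // decoder cross-attention mask
/// ))?;
/// ```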
#[derive(Debug, Clone, ModuleParameters, Quantizable, Buildable)]
#[module(root = crate)]
#[quantizable(root = crate)]
#[buildable(root = crate)]
pub struct Transformer {
    #[quantizable]
    #[param]
    encoder: TransformerEncoder,

    #[quantizable]
    #[param]
    decoder: TransformerDecoder,
}

impl Transformer {
    pub const DEFAULT_DIMENSIONS: i32 = 512;

    pub const DEFAULT_NUM_HEADS: i32 = 8;

    pub const DEFAULT_ENCODER_LAYERS_COUNT: usize = 6;

    pub const DEFAULT_DECODER_LAYERS_COUNT: usize = 6;

    pub const DEFAULT_DROPOUT: f32 = 0.0;

    pub const DEFAULT_NORM_FIRST: bool = false;
}

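/// Borrowed input to [`Transformer::forward`]: source and target sequences
/// plus the source, target, and memory attention masks. A five-element tuple
/// of array references converts into this type.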
#[derive(Debug, Clone)]
pub struct TransformerInput<'a> {
    pub source: &'a Array,

    pub target: &'a Array,

    pub source_mask: &'a Array,

    pub target_mask: &'a Array,

    pub memory_mask: &'a Array,
}

impl<'a> From<(&'a Array, &'a Array, &'a Array, &'a Array, &'a Array)> for TransformerInput<'a> {
    fn from(
        (source, target, source_mask, target_mask, memory_mask): (
            &'a Array,
            &'a Array,
            &'a Array,
            &'a Array,
            &'a Array,
        ),
    ) -> Self {
        TransformerInput {
            source,
            target,
            source_mask,
            target_mask,
            memory_mask,
        }
    }
}

impl<'a, Input> Module<Input> for Transformer
where
    Input: Into<TransformerInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
        let input = input.into();
        let source = input.source;
        let target = input.target;
        let source_mask = input.source_mask;
        let target_mask = input.target_mask;
        let memory_mask = input.memory_mask;

        let memory = self
            .encoder
            .forward(TransformerEncoderInput::from((source, source_mask)))?;
        self.decoder.forward(TransformerDecoderInput::from((
            target,
            &memory,
            target_mask,
            memory_mask,
        )))
    }

    fn training_mode(&mut self, mode: bool) {
        <TransformerEncoder as Module<TransformerEncoderInput>>::training_mode(
            &mut self.encoder,
            mode,
        );
        <TransformerDecoder as Module<TransformerDecoderInput>>::training_mode(
            &mut self.decoder,
            mode,
        );
    }
}