mistral/model.rs

use mlx_rs::{
    builder::Builder,
    error::Exception,
    fast::scaled_dot_product_attention,
    macros::{ModuleParameters, Quantizable},
    module::Module,
    nn,
    ops::concatenate,
    quantization::MaybeQuantized,
    Array,
};
use serde::Deserialize;

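/// Model hyperparameters, typically deserialized from the checkpoint's JSON config.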
#[derive(Debug, Clone, Deserialize)]
pub struct ModelArgs {
    pub dim: i32,
    pub n_layers: i32,
    pub head_dim: i32,
    pub hidden_dim: i32,
    pub n_heads: i32,
    pub n_kv_heads: i32,
    pub norm_eps: f32,
    pub vocab_size: i32,
    pub rope_theta: Option<f32>,
}

impl ModelArgs {
    pub const DEFAULT_ROPE_THETA: f32 = 10000.0;
}

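/// Multi-head attention with rotary position embeddings and grouped-query
/// attention: `n_heads` query heads attend over `n_kv_heads` key/value heads.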
#[derive(Debug, Clone, ModuleParameters, Quantizable)]
pub struct Attention {
    n_heads: i32,
    n_kv_heads: i32,
    repeats: i32,
    scale: f32,

    #[quantizable]
    #[param]
    wq: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    wk: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    wv: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    wo: MaybeQuantized<nn::Linear>,

    #[param]
    rope: nn::Rope,
}

impl Attention {
    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
        let n_heads = args.n_heads;
        let n_kv_heads = args.n_kv_heads;
        let repeats = n_heads / n_kv_heads;
        let scale = (args.head_dim as f32).powf(-0.5);

        let wq = nn::LinearBuilder::new(args.dim, n_heads * args.head_dim)
            .bias(false)
            .build()?;
        let wk = nn::LinearBuilder::new(args.dim, n_kv_heads * args.head_dim)
            .bias(false)
            .build()?;
        let wv = nn::LinearBuilder::new(args.dim, n_kv_heads * args.head_dim)
            .bias(false)
            .build()?;
        let wo = nn::LinearBuilder::new(n_heads * args.head_dim, args.dim)
            .bias(false)
            .build()?;
        let rope = nn::RopeBuilder::new(args.head_dim)
            .traditional(true)
            .base(args.rope_theta.unwrap_or(ModelArgs::DEFAULT_ROPE_THETA))
            .build()?;

        Ok(Self {
            n_heads,
            n_kv_heads,
            repeats,
            scale,
            wq: MaybeQuantized::new(wq),
            wk: MaybeQuantized::new(wk),
            wv: MaybeQuantized::new(wv),
            wo: MaybeQuantized::new(wo),
            rope,
        })
    }
}

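/// Input to the attention and transformer blocks: hidden states `x`, an
/// optional additive attention mask, and an optional `(keys, values)` cache
/// from earlier decoding steps.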
struct AttentionInput<'a> {
    x: &'a Array,
    mask: Option<&'a Array>,
    cache: Option<(&'a Array, &'a Array)>,
}

struct AttentionOutput {
    output: Array,
    cache: (Array, Array),
}

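// The attention forward pass projects the input to queries, keys, and values,
// reshapes them to `[B, n_heads, L, head_dim]`, applies RoPE (offset by the
// cached sequence length when a cache is present), and runs scaled dot-product
// attention. The updated key/value cache is returned alongside the output.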
impl Module<AttentionInput<'_>> for Attention {
    type Output = AttentionOutput;

    type Error = Exception;

    #[allow(non_snake_case)]
    fn forward(&mut self, input: AttentionInput<'_>) -> Result<Self::Output, Self::Error> {
        let AttentionInput { x, mask, cache } = input;

        // NOTE: this will panic if the input shape is not correct
        let B = x.shape()[0];
        let L = x.shape()[1];

        let mut queries = self.wq.forward(x)?;
        let mut keys = self.wk.forward(x)?;
        let mut values = self.wv.forward(x)?;

        // Prepare the queries, keys, and values for the attention computation
        queries = queries
            .reshape(&[B, L, self.n_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;
        keys = keys
            .reshape(&[B, L, self.n_kv_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;
        values = values
            .reshape(&[B, L, self.n_kv_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;

        match cache {
            Some((key_cache, value_cache)) => {
                let offset = key_cache.shape()[2];
                queries = self.rope.forward((&queries, offset))?;
                keys = self.rope.forward((&keys, offset))?;
                keys = concatenate(&[key_cache, &keys], 2)?;
                values = concatenate(&[value_cache, &values], 2)?;
            }
            None => {
                queries = self.rope.forward(&queries)?;
                keys = self.rope.forward(&keys)?;
            }
        }

        let output = scaled_dot_product_attention(queries, &keys, &values, self.scale, mask, None)?;
        let output = output.transpose(&[0, 2, 1, 3])?.reshape(&[B, L, -1])?;
        let output = self.wo.forward(&output)?;

        Ok(AttentionOutput {
            output,
            cache: (keys, values),
        })
    }

    fn training_mode(&mut self, mode: bool) {
        self.wq.training_mode(mode);
        self.wk.training_mode(mode);
        self.wv.training_mode(mode);
        self.wo.training_mode(mode);
    }
}

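/// SwiGLU feed-forward block: `w2(silu(w1(x)) * w3(x))`.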
#[derive(Debug, Clone, ModuleParameters, Quantizable)]
struct FeedForward {
    #[quantizable]
    #[param]
    w1: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    w2: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    w3: MaybeQuantized<nn::Linear>,
}

impl FeedForward {
    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
        let w1 = nn::LinearBuilder::new(args.dim, args.hidden_dim)
            .bias(false)
            .build()?;
        let w2 = nn::LinearBuilder::new(args.hidden_dim, args.dim)
            .bias(false)
            .build()?;
        // w3 gates the hidden activation, so it must project dim -> hidden_dim
        let w3 = nn::LinearBuilder::new(args.dim, args.hidden_dim)
            .bias(false)
            .build()?;
        Ok(Self {
            w1: MaybeQuantized::new(w1),
            w2: MaybeQuantized::new(w2),
            w3: MaybeQuantized::new(w3),
        })
    }
}

impl Module<&Array> for FeedForward {
    type Output = Array;

    type Error = Exception;

    fn forward(&mut self, x: &'_ Array) -> Result<Self::Output, Self::Error> {
        let w2_input = nn::silu(self.w1.forward(x)?)?.multiply(self.w3.forward(x)?)?;
        self.w2.forward(&w2_input)
    }

    fn training_mode(&mut self, mode: bool) {
        self.w1.training_mode(mode);
        self.w2.training_mode(mode);
        self.w3.training_mode(mode);
    }
}

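/// A single transformer layer: pre-norm attention followed by a pre-norm
/// feed-forward block, each wrapped in a residual connection.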
#[derive(Debug, Clone, ModuleParameters, Quantizable)]
struct TransformerBlock {
    n_heads: i32,
    dim: i32,

    #[quantizable]
    #[param]
    attention: Attention,

    #[quantizable]
    #[param]
    feed_forward: FeedForward,

    #[param]
    attention_norm: nn::RmsNorm,

    #[param]
    ffn_norm: nn::RmsNorm,
}

impl TransformerBlock {
    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
        let n_heads = args.n_heads;
        let dim = args.dim;

        let attention = Attention::new(args)?;
        let feed_forward = FeedForward::new(args)?;
        let attention_norm = nn::RmsNormBuilder::new(dim).eps(args.norm_eps).build()?;
        let ffn_norm = nn::RmsNormBuilder::new(dim).eps(args.norm_eps).build()?;
        Ok(Self {
            n_heads,
            dim,
            attention,
            feed_forward,
            attention_norm,
            ffn_norm,
        })
    }
}

impl Module<AttentionInput<'_>> for TransformerBlock {
    type Output = AttentionOutput;

    type Error = Exception;

    fn forward(&mut self, input: AttentionInput<'_>) -> Result<Self::Output, Self::Error> {
        let AttentionInput { x, mask, cache } = input;
        let norm_x = self.attention_norm.forward(x)?;
        let attention_input = AttentionInput {
            x: &norm_x,
            mask,
            cache,
        };
        let attention_output = self.attention.forward(attention_input)?;

        let r = attention_output.output;
        let cache = attention_output.cache;

        let h = x.add(r)?;
        let r = self.feed_forward.forward(&self.ffn_norm.forward(&h)?)?;
        let output = h.add(r)?;

        Ok(AttentionOutput { output, cache })
    }

    fn training_mode(&mut self, mode: bool) {
        self.attention.training_mode(mode);
        self.feed_forward.training_mode(mode);
        self.attention_norm.training_mode(mode);
        self.ffn_norm.training_mode(mode);
    }
}

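/// Errors that can occur while constructing or running the model.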
#[derive(Debug, thiserror::Error)]
pub enum MistralError {
    #[error("Invalid vocab size: {0}")]
    InvalidVocabSize(i32),

    #[error(transparent)]
    Exception(#[from] Exception),
}

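/// The full Mistral model: token embeddings, a stack of transformer blocks,
/// a final RMS norm, and an output projection to vocabulary logits.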
#[derive(Debug, Clone, ModuleParameters, Quantizable)]
pub struct Mistral {
    vocab_size: i32,
    n_layers: i32,

    #[quantizable]
    #[param]
    tok_embeddings: MaybeQuantized<nn::Embedding>,

    #[quantizable]
    #[param]
    layers: Vec<TransformerBlock>,

    #[param]
    norm: nn::RmsNorm,

    #[quantizable]
    #[param]
    output: MaybeQuantized<nn::Linear>,
}

impl Mistral {
    pub fn new(args: &ModelArgs) -> Result<Self, MistralError> {
        let vocab_size = args.vocab_size;
        if vocab_size <= 0 {
            // We would still have to check for the zero case even if we switch to u32
            return Err(MistralError::InvalidVocabSize(vocab_size));
        }
        let n_layers = args.n_layers;

        let tok_embeddings = nn::Embedding::new(vocab_size, args.dim)?;
        let layers = (0..n_layers)
            .map(|_| TransformerBlock::new(args))
            .collect::<Result<Vec<_>, _>>()?;
        let norm = nn::RmsNormBuilder::new(args.dim)
            .eps(args.norm_eps)
            .build()?;
        let output = nn::LinearBuilder::new(args.dim, vocab_size)
            .bias(false)
            .build()?;

        Ok(Self {
            vocab_size,
            n_layers,
            tok_embeddings: MaybeQuantized::new(tok_embeddings),
            layers,
            norm,
            output: MaybeQuantized::new(output),
        })
    }
}

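/// Token ids for the current step plus the per-layer key/value cache from
/// previous calls (one entry per transformer block, `None` on the first call).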
pub struct MistralInput<'a> {
    pub inputs: &'a Array,
    pub cache: &'a [Option<(Array, Array)>],
}

pub struct MistralOutput {
    pub logits: Array,
    pub cache: Vec<Option<(Array, Array)>>,
}

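// The model forward pass embeds the tokens, builds an additive causal mask
// when more than one position is being processed, threads the per-layer cache
// through every transformer block, and finally projects the normalized hidden
// states to logits.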
impl Module<MistralInput<'_>> for Mistral {
    type Output = MistralOutput;

    type Error = MistralError;

    fn forward(&mut self, input: MistralInput<'_>) -> Result<Self::Output, Self::Error> {
        let MistralInput { inputs, cache } = input;

        let mut h = self.tok_embeddings.forward(inputs)?;

        let mut mask = None;
        if h.shape()[1] > 1 {
            let mask_ = nn::MultiHeadAttention::create_additive_causal_mask::<f32>(h.shape()[1])?;
            let mask_ = mask_.as_dtype(h.dtype())?;
            mask = Some(mask_);
        }

        let mut out_cache = Vec::with_capacity(self.layers.len());
        for (i, layer) in self.layers.iter_mut().enumerate() {
            let cache_entry = cache.get(i).and_then(Option::as_ref).map(|(k, v)| (k, v));
            let input = AttentionInput {
                x: &h,
                mask: mask.as_ref(),
                cache: cache_entry,
            };
            let output = layer.forward(input)?;
            h = output.output;
            out_cache.push(Some(output.cache));
        }

        let output = self.output.forward(&self.norm.forward(&h)?)?;

        Ok(MistralOutput {
            logits: output,
            cache: out_cache,
        })
    }

    fn training_mode(&mut self, mode: bool) {
        self.tok_embeddings.training_mode(mode);
        self.layers
            .iter_mut()
            .for_each(|layer| layer.training_mode(mode));
        self.norm.training_mode(mode);
        self.output.training_mode(mode);
    }
}
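
// A minimal usage sketch (not part of the original example): it constructs a
// tiny, untrained model from hand-written `ModelArgs` and runs one forward
// pass with an empty cache. The dimensions below are made up for illustration;
// a real checkpoint would supply its own config and weights.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn forward_pass_smoke() -> Result<(), MistralError> {
        let args = ModelArgs {
            dim: 64,
            n_layers: 2,
            head_dim: 16,
            hidden_dim: 128,
            n_heads: 4,
            n_kv_heads: 2,
            norm_eps: 1e-5,
            vocab_size: 100,
            rope_theta: None,
        };
        let mut model = Mistral::new(&args)?;

        // A batch of one sequence with three token ids.
        let tokens = Array::from_slice(&[1i32, 2, 3], &[1, 3]);
        let empty_cache: Vec<Option<(Array, Array)>> = vec![None; args.n_layers as usize];

        let out = model.forward(MistralInput {
            inputs: &tokens,
            cache: &empty_cache,
        })?;

        // Logits have shape [batch, seq_len, vocab_size]; one cache entry per layer.
        assert_eq!(out.logits.shape(), &[1, 3, args.vocab_size][..]);
        assert_eq!(out.cache.len(), args.n_layers as usize);
        Ok(())
    }
}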