use mlx_rs::{
    builder::Builder,
    error::Exception,
    fast::scaled_dot_product_attention,
    macros::{ModuleParameters, Quantizable},
    module::Module,
    nn,
    ops::concatenate,
    quantization::MaybeQuantized,
    Array,
};
use serde::Deserialize;

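/// Model hyperparameters, deserializable with serde (e.g. from the model's config file).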
#[derive(Debug, Clone, Deserialize)]
pub struct ModelArgs {
    pub dim: i32,
    pub n_layers: i32,
    pub head_dim: i32,
    pub hidden_dim: i32,
    pub n_heads: i32,
    pub n_kv_heads: i32,
    pub norm_eps: f32,
    pub vocab_size: i32,
    pub rope_theta: Option<f32>,
}

impl ModelArgs {
    pub const DEFAULT_ROPE_THETA: f32 = 10000.0;
}

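/// Multi-head attention with grouped key/value heads (GQA) and rotary position embeddings.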
#[derive(Debug, Clone, ModuleParameters, Quantizable)]
pub struct Attention {
    n_heads: i32,
    n_kv_heads: i32,
    repeats: i32,
    scale: f32,

    #[quantizable]
    #[param]
    wq: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    wk: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    wv: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    wo: MaybeQuantized<nn::Linear>,

    #[param]
    rope: nn::Rope,
}

impl Attention {
    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
        let n_heads = args.n_heads;
        let n_kv_heads = args.n_kv_heads;
        // Each key/value head is shared by `repeats` query heads (grouped-query attention).
        let repeats = n_heads / n_kv_heads;
        // Standard attention scaling factor: 1 / sqrt(head_dim).
        let scale = (args.head_dim as f32).powf(-0.5);

        let wq = nn::LinearBuilder::new(args.dim, n_heads * args.head_dim)
            .bias(false)
            .build()?;
        let wk = nn::LinearBuilder::new(args.dim, n_kv_heads * args.head_dim)
            .bias(false)
            .build()?;
        let wv = nn::LinearBuilder::new(args.dim, n_kv_heads * args.head_dim)
            .bias(false)
            .build()?;
        let wo = nn::LinearBuilder::new(n_heads * args.head_dim, args.dim)
            .bias(false)
            .build()?;
        let rope = nn::RopeBuilder::new(args.head_dim)
            .traditional(true)
            .base(args.rope_theta.unwrap_or(ModelArgs::DEFAULT_ROPE_THETA))
            .build()?;

        Ok(Self {
            n_heads,
            n_kv_heads,
            repeats,
            scale,
            wq: MaybeQuantized::new(wq),
            wk: MaybeQuantized::new(wk),
            wv: MaybeQuantized::new(wv),
            wo: MaybeQuantized::new(wo),
            rope,
        })
    }
}

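/// Per-layer attention input: hidden states `x`, an optional additive mask, and an optional
/// (keys, values) cache carried over from previous decoding steps.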
struct AttentionInput<'a> {
    x: &'a Array,
    mask: Option<&'a Array>,
    cache: Option<(&'a Array, &'a Array)>,
}

struct AttentionOutput {
    output: Array,
    cache: (Array, Array),
}

impl Module<AttentionInput<'_>> for Attention {
    type Output = AttentionOutput;

    type Error = Exception;

    #[allow(non_snake_case)]
    fn forward(&mut self, input: AttentionInput<'_>) -> Result<Self::Output, Self::Error> {
        let AttentionInput { x, mask, cache } = input;

        // B: batch size, L: sequence length
        let B = x.shape()[0];
        let L = x.shape()[1];

        let mut queries = self.wq.forward(x)?;
        let mut keys = self.wk.forward(x)?;
        let mut values = self.wv.forward(x)?;

        // Reshape to (B, n_heads, L, head_dim) so attention is computed per head.
        queries = queries
            .reshape(&[B, L, self.n_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;
        keys = keys
            .reshape(&[B, L, self.n_kv_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;
        values = values
            .reshape(&[B, L, self.n_kv_heads, -1])?
            .transpose(&[0, 2, 1, 3])?;

        match cache {
            Some((key_cache, value_cache)) => {
                // Offset RoPE by the number of cached positions, then append the new
                // keys/values to the cache along the sequence axis.
                let offset = key_cache.shape()[2];
                queries = self.rope.forward((&queries, offset))?;
                keys = self.rope.forward((&keys, offset))?;
                keys = concatenate(&[key_cache, &keys], 2)?;
                values = concatenate(&[value_cache, &values], 2)?;
            }
            None => {
                queries = self.rope.forward(&queries)?;
                keys = self.rope.forward(&keys)?;
            }
        }

        let output = scaled_dot_product_attention(queries, &keys, &values, self.scale, mask, None)?;
        // Merge the heads back into a single hidden dimension before the output projection.
        let output = output.transpose(&[0, 2, 1, 3])?.reshape(&[B, L, -1])?;
        let output = self.wo.forward(&output)?;

        Ok(AttentionOutput {
            output,
            cache: (keys, values),
        })
    }

    fn training_mode(&mut self, mode: bool) {
        self.wq.training_mode(mode);
        self.wk.training_mode(mode);
        self.wv.training_mode(mode);
        self.wo.training_mode(mode);
    }
}

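/// Position-wise feed-forward network with a SwiGLU-style gate: w2(silu(w1(x)) * w3(x)).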
#[derive(Debug, Clone, ModuleParameters, Quantizable)]
struct FeedForward {
    #[quantizable]
    #[param]
    w1: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    w2: MaybeQuantized<nn::Linear>,

    #[quantizable]
    #[param]
    w3: MaybeQuantized<nn::Linear>,
}

impl FeedForward {
    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
        let w1 = nn::LinearBuilder::new(args.dim, args.hidden_dim)
            .bias(false)
            .build()?;
        let w2 = nn::LinearBuilder::new(args.hidden_dim, args.dim)
            .bias(false)
            .build()?;
        // w3 is the gate projection; it must also map dim -> hidden_dim so its output
        // can be multiplied elementwise with silu(w1(x)).
        let w3 = nn::LinearBuilder::new(args.dim, args.hidden_dim)
            .bias(false)
            .build()?;
        Ok(Self {
            w1: MaybeQuantized::new(w1),
            w2: MaybeQuantized::new(w2),
            w3: MaybeQuantized::new(w3),
        })
    }
}

impl Module<&Array> for FeedForward {
    type Output = Array;

    type Error = Exception;

    fn forward(&mut self, x: &'_ Array) -> Result<Self::Output, Self::Error> {
        // SwiGLU: silu(w1(x)) gated by w3(x), then projected back down by w2.
        let w2_input = nn::silu(self.w1.forward(x)?)?.multiply(self.w3.forward(x)?)?;
        self.w2.forward(&w2_input)
    }

    fn training_mode(&mut self, mode: bool) {
        self.w1.training_mode(mode);
        self.w2.training_mode(mode);
        self.w3.training_mode(mode);
    }
}

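/// A single transformer layer: attention and feed-forward sub-blocks, each preceded by
/// RMSNorm and wrapped in a residual connection.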
#[derive(Debug, Clone, ModuleParameters, Quantizable)]
struct TransformerBlock {
    n_heads: i32,
    dim: i32,

    #[quantizable]
    #[param]
    attention: Attention,

    #[quantizable]
    #[param]
    feed_forward: FeedForward,

    #[param]
    attention_norm: nn::RmsNorm,

    #[param]
    ffn_norm: nn::RmsNorm,
}

impl TransformerBlock {
    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
        let n_heads = args.n_heads;
        let dim = args.dim;

        let attention = Attention::new(args)?;
        let feed_forward = FeedForward::new(args)?;
        let attention_norm = nn::RmsNormBuilder::new(dim).eps(args.norm_eps).build()?;
        let ffn_norm = nn::RmsNormBuilder::new(dim).eps(args.norm_eps).build()?;
        Ok(Self {
            n_heads,
            dim,
            attention,
            feed_forward,
            attention_norm,
            ffn_norm,
        })
    }
}

impl Module<AttentionInput<'_>> for TransformerBlock {
    type Output = AttentionOutput;

    type Error = Exception;

    fn forward(&mut self, input: AttentionInput<'_>) -> Result<Self::Output, Self::Error> {
        let AttentionInput { x, mask, cache } = input;

        // Pre-norm attention with a residual connection.
        let norm_x = self.attention_norm.forward(x)?;
        let attention_input = AttentionInput {
            x: &norm_x,
            mask,
            cache,
        };
        let attention_output = self.attention.forward(attention_input)?;

        let r = attention_output.output;
        let cache = attention_output.cache;

        let h = x.add(r)?;
        // Pre-norm feed-forward with a second residual connection.
        let r = self.feed_forward.forward(&self.ffn_norm.forward(&h)?)?;
        let output = h.add(r)?;

        Ok(AttentionOutput { output, cache })
    }

    fn training_mode(&mut self, mode: bool) {
        self.attention.training_mode(mode);
        self.feed_forward.training_mode(mode);
        self.attention_norm.training_mode(mode);
        self.ffn_norm.training_mode(mode);
    }
}

#[derive(Debug, thiserror::Error)]
pub enum MistralError {
    #[error("Invalid vocab size: {0}")]
    InvalidVocabSize(i32),

    #[error(transparent)]
    Exception(#[from] Exception),
}

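/// The full Mistral model: token embeddings, a stack of transformer blocks, a final
/// RMSNorm, and the output projection to vocabulary logits.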
#[derive(Debug, Clone, ModuleParameters, Quantizable)]
pub struct Mistral {
    vocab_size: i32,
    n_layers: i32,

    #[quantizable]
    #[param]
    tok_embeddings: MaybeQuantized<nn::Embedding>,

    #[quantizable]
    #[param]
    layers: Vec<TransformerBlock>,

    #[param]
    norm: nn::RmsNorm,

    #[quantizable]
    #[param]
    output: MaybeQuantized<nn::Linear>,
}

impl Mistral {
    pub fn new(args: &ModelArgs) -> Result<Self, MistralError> {
        let vocab_size = args.vocab_size;
        if vocab_size <= 0 {
            return Err(MistralError::InvalidVocabSize(vocab_size));
        }
        let n_layers = args.n_layers;

        let tok_embeddings = nn::Embedding::new(vocab_size, args.dim)?;
        let layers = (0..n_layers)
            .map(|_| TransformerBlock::new(args))
            .collect::<Result<Vec<_>, _>>()?;
        let norm = nn::RmsNormBuilder::new(args.dim)
            .eps(args.norm_eps)
            .build()?;
        let output = nn::LinearBuilder::new(args.dim, vocab_size)
            .bias(false)
            .build()?;

        Ok(Self {
            vocab_size,
            n_layers,
            tok_embeddings: MaybeQuantized::new(tok_embeddings),
            layers,
            norm,
            output: MaybeQuantized::new(output),
        })
    }
}

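/// Model input: token ids plus one optional (keys, values) cache entry per layer.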
pub struct MistralInput<'a> {
    pub inputs: &'a Array,
    pub cache: &'a [Option<(Array, Array)>],
}

pub struct MistralOutput {
    pub logits: Array,
    pub cache: Vec<Option<(Array, Array)>>,
}

impl Module<MistralInput<'_>> for Mistral {
    type Output = MistralOutput;

    type Error = MistralError;

    fn forward(&mut self, input: MistralInput<'_>) -> Result<Self::Output, Self::Error> {
        let MistralInput { inputs, cache } = input;

        let mut h = self.tok_embeddings.forward(inputs)?;

        // A causal mask is only needed when processing more than one token at a time
        // (i.e. the prompt); single-token decoding steps attend to the whole cache.
        let mut mask = None;
        if h.shape()[1] > 1 {
            let mask_ = nn::MultiHeadAttention::create_additive_causal_mask::<f32>(h.shape()[1])?;
            let mask_ = mask_.as_dtype(h.dtype())?;
            mask = Some(mask_);
        }

        // Run each layer, threading its key/value cache through and collecting the
        // updated cache for the next call.
        let mut out_cache = Vec::with_capacity(self.layers.len());
        for (i, layer) in self.layers.iter_mut().enumerate() {
            let cache_entry = cache.get(i).and_then(Option::as_ref).map(|(k, v)| (k, v));
            let input = AttentionInput {
                x: &h,
                mask: mask.as_ref(),
                cache: cache_entry,
            };
            let output = layer.forward(input)?;
            h = output.output;
            out_cache.push(Some(output.cache));
        }

        let output = self.output.forward(&self.norm.forward(&h)?)?;

        Ok(MistralOutput {
            logits: output,
            cache: out_cache,
        })
    }

    fn training_mode(&mut self, mode: bool) {
        self.tok_embeddings.training_mode(mode);
        self.layers
            .iter_mut()
            .for_each(|layer| layer.training_mode(mode));
        self.norm.training_mode(mode);
        self.output.training_mode(mode);
    }
}