1use std::sync::Arc;
2
3use crate::{
4 array,
5 error::Exception,
6 module::{Module, Param},
7 ops::indexing::{Ellipsis, IndexOp},
8 ops::{addmm, matmul, sigmoid, split_equal, stack, tanh, tanh_device},
9 random::uniform,
10 Array, Stream,
11};
12use mlx_internal_macros::{generate_builder, Buildable, Builder};
13use mlx_macros::ModuleParameters;
14
/// Signature of the activation applied to each RNN hidden state: maps an
/// input array (on a given stream) to a new array, possibly failing.
pub type NonLinearity = dyn Fn(&Array, &Stream) -> Result<Array, Exception>;
17
/// An Elman recurrent layer.
///
/// Each time step computes `non_linearity(x_t @ Wxh^T + h @ Whh^T + bias)`
/// (see [`Rnn::step`]).
#[derive(Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Rnn {
    /// Activation applied to each hidden state; defaults to `tanh`.
    pub non_linearity: Arc<NonLinearity>,

    /// Input-to-hidden weights, shape `[hidden_size, input_size]`.
    #[param]
    pub wxh: Param<Array>,

    /// Hidden-to-hidden weights, shape `[hidden_size, hidden_size]`.
    #[param]
    pub whh: Param<Array>,

    /// Optional additive bias, shape `[hidden_size]`.
    #[param]
    pub bias: Param<Option<Array>>,
}
48
/// Builder for [`Rnn`]; `input_size` and `hidden_size` are mandatory.
#[derive(Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_rnn,
    err = Exception,
)]
pub struct RnnBuilder {
    /// Dimension of each input feature vector.
    pub input_size: i32,

    /// Dimension of the hidden state.
    pub hidden_size: i32,

    /// Optional activation; `None` falls back to `tanh` at build time.
    #[builder(optional, default = Rnn::DEFAULT_NONLINEARITY)]
    pub non_linearity: Option<Arc<NonLinearity>>,

    /// Whether to create the additive bias parameter.
    #[builder(optional, default = Rnn::DEFAULT_BIAS)]
    pub bias: bool,
}
71
72impl std::fmt::Debug for RnnBuilder {
73 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
74 f.debug_struct("RnnBuilder")
75 .field("bias", &self.bias)
76 .finish()
77 }
78}
79
80fn build_rnn(builder: RnnBuilder) -> Result<Rnn, Exception> {
82 let input_size = builder.input_size;
83 let hidden_size = builder.hidden_size;
84 let non_linearity = builder
85 .non_linearity
86 .unwrap_or_else(|| Arc::new(|x, d| tanh_device(x, d)));
87
88 let scale = 1.0 / (input_size as f32).sqrt();
89 let wxh = uniform::<_, f32>(-scale, scale, &[hidden_size, input_size], None)?;
90 let whh = uniform::<_, f32>(-scale, scale, &[hidden_size, hidden_size], None)?;
91 let bias = if builder.bias {
92 Some(uniform::<_, f32>(-scale, scale, &[hidden_size], None)?)
93 } else {
94 None
95 };
96
97 Ok(Rnn {
98 non_linearity,
99 wxh: Param::new(wxh),
100 whh: Param::new(whh),
101 bias: Param::new(bias),
102 })
103}
104
impl std::fmt::Debug for Rnn {
    // Manual impl because `non_linearity` is a function trait object with no
    // `Debug`; it is silently omitted from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("Rnn")
            .field("wxh", &self.wxh)
            .field("whh", &self.whh)
            .field("bias", &self.bias)
            .finish()
    }
}
114
115impl Rnn {
116 pub const DEFAULT_BIAS: bool = true;
118
119 pub const DEFAULT_NONLINEARITY: Option<Arc<NonLinearity>> = None;
121
122 pub fn step(&mut self, x: &Array, hidden: Option<&Array>) -> Result<Array, Exception> {
124 let x = if let Some(bias) = &self.bias.value {
125 addmm(bias, x, self.wxh.t(), None, None)?
126 } else {
127 matmul(x, self.wxh.t())?
128 };
129
130 let mut all_hidden = Vec::new();
131 for index in 0..x.dim(-2) {
132 let hidden = match hidden {
133 Some(hidden_) => addmm(
134 x.index((Ellipsis, index, 0..)),
135 hidden_,
136 self.whh.t(),
137 None,
138 None,
139 )?,
140 None => x.index((Ellipsis, index, 0..)),
141 };
142
143 let hidden = (self.non_linearity)(&hidden, &Stream::default())?;
144 all_hidden.push(hidden);
145 }
146
147 stack(&all_hidden[..], -2)
148 }
149}
150
generate_builder! {
    /// Input to [`Rnn`]: the sequence plus an optional initial hidden state.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RnnInput<'a> {
        /// Input sequence; the time axis is the second-to-last dimension.
        pub x: &'a Array,

        /// Optional initial hidden state.
        #[builder(optional, default = None)]
        pub hidden: Option<&'a Array>,
    }
}
165
166impl<'a> From<&'a Array> for RnnInput<'a> {
167 fn from(x: &'a Array) -> Self {
168 RnnInput { x, hidden: None }
169 }
170}
171
172impl<'a> From<(&'a Array,)> for RnnInput<'a> {
173 fn from(input: (&'a Array,)) -> Self {
174 RnnInput {
175 x: input.0,
176 hidden: None,
177 }
178 }
179}
180
181impl<'a> From<(&'a Array, &'a Array)> for RnnInput<'a> {
182 fn from(input: (&'a Array, &'a Array)) -> Self {
183 RnnInput {
184 x: input.0,
185 hidden: Some(input.1),
186 }
187 }
188}
189
190impl<'a> From<(&'a Array, Option<&'a Array>)> for RnnInput<'a> {
191 fn from(input: (&'a Array, Option<&'a Array>)) -> Self {
192 RnnInput {
193 x: input.0,
194 hidden: input.1,
195 }
196 }
197}
198
199impl<'a, Input> Module<Input> for Rnn
200where
201 Input: Into<RnnInput<'a>>,
202{
203 type Error = Exception;
204 type Output = Array;
205
206 fn forward(&mut self, input: Input) -> Result<Array, Exception> {
207 let input = input.into();
208 self.step(input.x, input.hidden)
209 }
210
211 fn training_mode(&mut self, _mode: bool) {}
212}
213
/// A gated recurrent unit (GRU) layer.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Gru {
    /// Dimension of the hidden state.
    pub hidden_size: i32,

    /// Input projection for all three gates, shape `[3 * hidden_size, input_size]`.
    #[param]
    pub wx: Param<Array>,

    /// Hidden projection for all three gates, shape `[3 * hidden_size, hidden_size]`.
    #[param]
    pub wh: Param<Array>,

    /// Optional bias for the input projection, shape `[3 * hidden_size]`.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Optional bias added to the candidate hidden projection, shape `[hidden_size]`.
    #[param]
    pub bhn: Param<Option<Array>>,
}
248
/// Builder for [`Gru`]; `input_size` and `hidden_size` are mandatory.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_gru,
    err = Exception,
)]
pub struct GruBuilder {
    /// Dimension of each input feature vector.
    pub input_size: i32,

    /// Dimension of the hidden state.
    pub hidden_size: i32,

    /// Whether to create the bias parameters (`bias` and `bhn`).
    #[builder(optional, default = Gru::DEFAULT_BIAS)]
    pub bias: bool,
}
267
268fn build_gru(builder: GruBuilder) -> Result<Gru, Exception> {
269 let input_size = builder.input_size;
270 let hidden_size = builder.hidden_size;
271
272 let scale = 1.0 / f32::sqrt(hidden_size as f32);
273 let wx = uniform::<_, f32>(-scale, scale, &[3 * hidden_size, input_size], None)?;
274 let wh = uniform::<_, f32>(-scale, scale, &[3 * hidden_size, hidden_size], None)?;
275 let (bias, bhn) = if builder.bias {
276 let bias = uniform::<_, f32>(-scale, scale, &[3 * hidden_size], None)?;
277 let bhn = uniform::<_, f32>(-scale, scale, &[hidden_size], None)?;
278 (Some(bias), Some(bhn))
279 } else {
280 (None, None)
281 };
282
283 Ok(Gru {
284 hidden_size,
285 wx: Param::new(wx),
286 wh: Param::new(wh),
287 bias: Param::new(bias),
288 bhn: Param::new(bhn),
289 })
290}
291
292impl Gru {
293 pub const DEFAULT_BIAS: bool = true;
295
296 pub fn step(&mut self, x: &Array, hidden: Option<&Array>) -> Result<Array, Exception> {
298 let x = if let Some(b) = &self.bias.value {
299 addmm(b, x, self.wx.t(), None, None)?
300 } else {
301 matmul(x, self.wx.t())?
302 };
303
304 let x_rz = x.index((Ellipsis, ..(-self.hidden_size)));
305 let x_n = x.index((Ellipsis, (-self.hidden_size)..));
306
307 let mut all_hidden = Vec::new();
308
309 for index in 0..x.dim(-2) {
310 let mut rz = x_rz.index((Ellipsis, index, ..));
311 let mut h_proj_n = None;
312 if let Some(hidden_) = hidden {
313 let h_proj = matmul(hidden_, self.wh.t())?;
314 let h_proj_rz = h_proj.index((Ellipsis, ..(-self.hidden_size)));
315 h_proj_n = Some(h_proj.index((Ellipsis, (-self.hidden_size)..)));
316
317 if let Some(bhn) = &self.bhn.value {
318 h_proj_n = h_proj_n
319 .map(|h_proj_n| h_proj_n.add(bhn))
320 .transpose()?;
322 }
323
324 rz = rz.add(h_proj_rz)?;
325 }
326
327 rz = sigmoid(&rz)?;
328
329 let parts = split_equal(&rz, 2, -1)?;
330 let r = &parts[0];
331 let z = &parts[1];
332
333 let mut n = x_n.index((Ellipsis, index, 0..));
334
335 if let Some(h_proj_n) = h_proj_n {
336 n = n.add(r.multiply(h_proj_n)?)?;
337 }
338 n = tanh(&n)?;
339
340 let hidden = match hidden {
341 Some(hidden) => array!(1.0)
342 .subtract(z)?
343 .multiply(&n)?
344 .add(z.multiply(hidden)?)?,
345 None => array!(1.0).subtract(z)?.multiply(&n)?,
346 };
347
348 all_hidden.push(hidden);
349 }
350
351 stack(&all_hidden[..], -2)
352 }
353}
354
/// Input to [`Gru`] — structurally identical to [`RnnInput`].
pub type GruInput<'a> = RnnInput<'a>;

/// Builder for [`GruInput`].
pub type GruInputBuilder<'a> = RnnInputBuilder<'a>;
360
361impl<'a, Input> Module<Input> for Gru
362where
363 Input: Into<GruInput<'a>>,
364{
365 type Error = Exception;
366 type Output = Array;
367
368 fn forward(&mut self, input: Input) -> Result<Array, Exception> {
369 let input = input.into();
370 self.step(input.x, input.hidden)
371 }
372
373 fn training_mode(&mut self, _mode: bool) {}
374}
375
/// A long short-term memory (LSTM) layer.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Lstm {
    /// Input projection for the i, f, g, o gates, shape `[4 * hidden_size, input_size]`.
    #[param]
    pub wx: Param<Array>,

    /// Hidden projection for the i, f, g, o gates, shape `[4 * hidden_size, hidden_size]`.
    #[param]
    pub wh: Param<Array>,

    /// Optional additive bias, shape `[4 * hidden_size]`.
    #[param]
    pub bias: Param<Option<Array>>,
}
393
/// Builder for [`Lstm`]; `input_size` and `hidden_size` are mandatory.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_lstm,
    err = Exception,
)]
pub struct LstmBuilder {
    /// Dimension of each input feature vector.
    pub input_size: i32,

    /// Dimension of the hidden and cell states.
    pub hidden_size: i32,

    /// Whether to create the additive bias parameter.
    #[builder(optional, default = Lstm::DEFAULT_BIAS)]
    pub bias: bool,
}
412
413fn build_lstm(builder: LstmBuilder) -> Result<Lstm, Exception> {
414 let input_size = builder.input_size;
415 let hidden_size = builder.hidden_size;
416 let scale = 1.0 / f32::sqrt(hidden_size as f32);
417 let wx = uniform::<_, f32>(-scale, scale, &[4 * hidden_size, input_size], None)?;
418 let wh = uniform::<_, f32>(-scale, scale, &[4 * hidden_size, hidden_size], None)?;
419 let bias = if builder.bias {
420 Some(uniform::<_, f32>(-scale, scale, &[4 * hidden_size], None)?)
421 } else {
422 None
423 };
424
425 Ok(Lstm {
426 wx: Param::new(wx),
427 wh: Param::new(wh),
428 bias: Param::new(bias),
429 })
430}
431
generate_builder! {
    /// Input to [`Lstm`]: the sequence plus optional initial hidden and cell states.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LstmInput<'a> {
        /// Input sequence; the time axis is the second-to-last dimension.
        pub x: &'a Array,

        /// Optional initial hidden state.
        #[builder(optional, default = None)]
        pub hidden: Option<&'a Array>,

        /// Optional initial cell state.
        #[builder(optional, default = None)]
        pub cell: Option<&'a Array>,
    }
}
450
451impl<'a> From<&'a Array> for LstmInput<'a> {
452 fn from(x: &'a Array) -> Self {
453 LstmInput {
454 x,
455 hidden: None,
456 cell: None,
457 }
458 }
459}
460
461impl<'a> From<(&'a Array,)> for LstmInput<'a> {
462 fn from(input: (&'a Array,)) -> Self {
463 LstmInput {
464 x: input.0,
465 hidden: None,
466 cell: None,
467 }
468 }
469}
470
471impl<'a> From<(&'a Array, &'a Array)> for LstmInput<'a> {
472 fn from(input: (&'a Array, &'a Array)) -> Self {
473 LstmInput {
474 x: input.0,
475 hidden: Some(input.1),
476 cell: None,
477 }
478 }
479}
480
481impl<'a> From<(&'a Array, &'a Array, &'a Array)> for LstmInput<'a> {
482 fn from(input: (&'a Array, &'a Array, &'a Array)) -> Self {
483 LstmInput {
484 x: input.0,
485 hidden: Some(input.1),
486 cell: Some(input.2),
487 }
488 }
489}
490
491impl<'a> From<(&'a Array, Option<&'a Array>)> for LstmInput<'a> {
492 fn from(input: (&'a Array, Option<&'a Array>)) -> Self {
493 LstmInput {
494 x: input.0,
495 hidden: input.1,
496 cell: None,
497 }
498 }
499}
500
501impl<'a> From<(&'a Array, Option<&'a Array>, Option<&'a Array>)> for LstmInput<'a> {
502 fn from(input: (&'a Array, Option<&'a Array>, Option<&'a Array>)) -> Self {
503 LstmInput {
504 x: input.0,
505 hidden: input.1,
506 cell: input.2,
507 }
508 }
509}
510
511impl Lstm {
512 pub const DEFAULT_BIAS: bool = true;
514
515 pub fn step(
517 &mut self,
518 x: &Array,
519 hidden: Option<&Array>,
520 cell: Option<&Array>,
521 ) -> Result<(Array, Array), Exception> {
522 let x = if let Some(b) = &self.bias.value {
523 addmm(b, x, self.wx.t(), None, None)?
524 } else {
525 matmul(x, self.wx.t())?
526 };
527
528 let mut all_hidden = Vec::new();
529 let mut all_cell = Vec::new();
530
531 for index in 0..x.dim(-2) {
532 let mut ifgo = x.index((Ellipsis, index, 0..));
533 if let Some(hidden) = hidden {
534 ifgo = addmm(&ifgo, hidden, self.wh.t(), None, None)?;
535 }
536
537 let pieces = split_equal(&ifgo, 4, -1)?;
538
539 let i = sigmoid(&pieces[0])?;
540 let f = sigmoid(&pieces[1])?;
541 let g = tanh(&pieces[2])?;
542 let o = sigmoid(&pieces[3])?;
543
544 let cell = match cell {
545 Some(cell) => f.multiply(cell)?.add(i.multiply(&g)?)?,
546 None => i.multiply(&g)?,
547 };
548
549 let hidden = o.multiply(tanh(&cell)?)?;
550
551 all_hidden.push(hidden);
552 all_cell.push(cell);
553 }
554
555 Ok((stack(&all_hidden[..], -2)?, stack(&all_cell[..], -2)?))
556 }
557}
558
559impl<'a, Input> Module<Input> for Lstm
560where
561 Input: Into<LstmInput<'a>>,
562{
563 type Output = (Array, Array);
564 type Error = Exception;
565
566 fn forward(&mut self, input: Input) -> Result<(Array, Array), Exception> {
567 let input = input.into();
568 self.step(input.x, input.hidden, input.cell)
569 }
570
571 fn training_mode(&mut self, _mode: bool) {}
572}
573
#[cfg(test)]
mod tests {
    use crate::{builder::Builder, ops::maximum_device, random::normal};

    use super::*;

    // NOTE(review): these tests assert output shapes only, not numeric
    // values, so they cannot detect errors in the recurrence arithmetic.

    #[test]
    fn test_rnn() {
        // Batched input: (batch=2, seq=25, features=5) -> hidden size 12.
        let mut layer = Rnn::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        // Rebuild with a custom ReLU-style non-linearity and no bias.
        let nonlinearity = |x: &Array, d: &Stream| maximum_device(x, array!(0.0), d);
        let mut layer = RnnBuilder::new(5, 12)
            .bias(false)
            .non_linearity(Arc::new(nonlinearity) as Arc<NonLinearity>)
            .build()
            .unwrap();

        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        // Unbatched input: (seq=44, features=5).
        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);

        // Feed the last hidden state back in as the initial state.
        let hidden = h_out.index((-1, ..));
        let h_out = layer.forward(RnnInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
    }

    #[test]
    fn test_gru() {
        // Batched input: (batch=2, seq=25, features=5) -> hidden size 12.
        let mut layer = Gru::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let h_out = layer.forward(GruInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        // Use the final time step's hidden state as the initial state.
        let hidden = h_out.index((.., -1, ..));
        let h_out = layer.forward(GruInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        // Unbatched input: (seq=44, features=5).
        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let h_out = layer.forward(GruInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let h_out = layer.forward(GruInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
    }

    #[test]
    fn test_lstm() {
        // Batched input: (batch=2, seq=25, features=5) -> hidden size 12.
        let mut layer = Lstm::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let (h_out, c_out) = layer.forward(LstmInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);
        assert_eq!(c_out.shape(), &[2, 25, 12]);

        // Resume from the last hidden/cell states via `step` directly.
        let (h_out, c_out) = layer
            .step(
                &inp,
                Some(&h_out.index((.., -1, ..))),
                Some(&c_out.index((.., -1, ..))),
            )
            .unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);
        assert_eq!(c_out.shape(), &[2, 25, 12]);

        // Unbatched input: (seq=44, features=5).
        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let (h_out, c_out) = layer.forward(LstmInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
        assert_eq!(c_out.shape(), &[44, 12]);

        // Resume via the tuple `From` conversion.
        let hidden = h_out.index((-1, ..));
        let cell = c_out.index((-1, ..));
        let (h_out, c_out) = layer
            .forward(LstmInput::from((&inp, &hidden, &cell)))
            .unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
        assert_eq!(c_out.shape(), &[44, 12]);
    }
}