1use std::sync::Arc;
2
3use crate::{
4 array,
5 error::Exception,
6 module::{Module, Param},
7 ops::{
8 addmm,
9 indexing::{Ellipsis, IndexOp},
10 matmul, sigmoid, split, stack_axis, tanh, tanh_device,
11 },
12 random::uniform,
13 Array, Stream,
14};
15use mlx_internal_macros::{generate_builder, Buildable, Builder};
16use mlx_macros::ModuleParameters;
17
/// Non-linearity applied to each time step of an [`Rnn`] layer.
pub type NonLinearity = dyn Fn(&Array, &Stream) -> Result<Array, Exception>;
20
/// An Elman recurrent layer.
///
/// Construct via [`RnnBuilder`] (see `build_rnn` for the initialization
/// scheme). Each time step projects the input through `wxh`, the hidden
/// state through `whh`, adds `bias` when present, and applies
/// `non_linearity`.
#[derive(Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Rnn {
    /// Activation applied to each step's pre-activation (defaults to `tanh`).
    pub non_linearity: Arc<NonLinearity>,

    /// Input-to-hidden weights, shape `[hidden_size, input_size]`.
    #[param]
    pub wxh: Param<Array>,

    /// Hidden-to-hidden weights, shape `[hidden_size, hidden_size]`.
    #[param]
    pub whh: Param<Array>,

    /// Optional bias, shape `[hidden_size]`.
    #[param]
    pub bias: Param<Option<Array>>,
}
51
/// Builder for [`Rnn`]; finalized by `build_rnn`.
#[derive(Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_rnn,
    err = Exception,
)]
pub struct RnnBuilder {
    /// Dimension of the input features (last axis of `x`).
    pub input_size: i32,

    /// Dimension of the hidden state.
    pub hidden_size: i32,

    /// Optional activation; `None` resolves to `tanh` at build time.
    #[builder(optional, default = Rnn::DEFAULT_NONLINEARITY)]
    pub non_linearity: Option<Arc<NonLinearity>>,

    /// Whether to create the bias parameter.
    #[builder(optional, default = Rnn::DEFAULT_BIAS)]
    pub bias: bool,
}
74
75impl std::fmt::Debug for RnnBuilder {
76 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
77 f.debug_struct("RnnBuilder")
78 .field("bias", &self.bias)
79 .finish()
80 }
81}
82
83fn build_rnn(builder: RnnBuilder) -> Result<Rnn, Exception> {
85 let input_size = builder.input_size;
86 let hidden_size = builder.hidden_size;
87 let non_linearity = builder
88 .non_linearity
89 .unwrap_or_else(|| Arc::new(|x, d| tanh_device(x, d)));
90
91 let scale = 1.0 / (input_size as f32).sqrt();
92 let wxh = uniform::<_, f32>(-scale, scale, &[hidden_size, input_size], None)?;
93 let whh = uniform::<_, f32>(-scale, scale, &[hidden_size, hidden_size], None)?;
94 let bias = if builder.bias {
95 Some(uniform::<_, f32>(-scale, scale, &[hidden_size], None)?)
96 } else {
97 None
98 };
99
100 Ok(Rnn {
101 non_linearity,
102 wxh: Param::new(wxh),
103 whh: Param::new(whh),
104 bias: Param::new(bias),
105 })
106}
107
108impl std::fmt::Debug for Rnn {
109 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
110 f.debug_struct("Rnn")
111 .field("wxh", &self.wxh)
112 .field("whh", &self.whh)
113 .field("bias", &self.bias)
114 .finish()
115 }
116}
117
118impl Rnn {
119 pub const DEFAULT_BIAS: bool = true;
121
122 pub const DEFAULT_NONLINEARITY: Option<Arc<NonLinearity>> = None;
124
125 pub fn step(&mut self, x: &Array, hidden: Option<&Array>) -> Result<Array, Exception> {
127 let x = if let Some(bias) = &self.bias.value {
128 addmm(bias, x, self.wxh.t(), None, None)?
129 } else {
130 matmul(x, self.wxh.t())?
131 };
132
133 let mut all_hidden = Vec::new();
134 for index in 0..x.dim(-2) {
135 let hidden = match hidden {
136 Some(hidden_) => addmm(
137 x.index((Ellipsis, index, 0..)),
138 hidden_,
139 self.whh.t(),
140 None,
141 None,
142 )?,
143 None => x.index((Ellipsis, index, 0..)),
144 };
145
146 let hidden = (self.non_linearity)(&hidden, &Stream::default())?;
147 all_hidden.push(hidden);
148 }
149
150 stack_axis(&all_hidden[..], -2)
151 }
152}
153
generate_builder! {
    /// Input for [`Rnn`]: the data array plus an optional initial hidden state.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RnnInput<'a> {
        /// Input data, with the time axis at position `-2`.
        pub x: &'a Array,

        /// Optional initial hidden state.
        #[builder(optional, default = None)]
        pub hidden: Option<&'a Array>,
    }
}
168
169impl<'a> From<&'a Array> for RnnInput<'a> {
170 fn from(x: &'a Array) -> Self {
171 RnnInput { x, hidden: None }
172 }
173}
174
175impl<'a> From<(&'a Array,)> for RnnInput<'a> {
176 fn from(input: (&'a Array,)) -> Self {
177 RnnInput {
178 x: input.0,
179 hidden: None,
180 }
181 }
182}
183
184impl<'a> From<(&'a Array, &'a Array)> for RnnInput<'a> {
185 fn from(input: (&'a Array, &'a Array)) -> Self {
186 RnnInput {
187 x: input.0,
188 hidden: Some(input.1),
189 }
190 }
191}
192
193impl<'a> From<(&'a Array, Option<&'a Array>)> for RnnInput<'a> {
194 fn from(input: (&'a Array, Option<&'a Array>)) -> Self {
195 RnnInput {
196 x: input.0,
197 hidden: input.1,
198 }
199 }
200}
201
202impl<'a, Input> Module<Input> for Rnn
203where
204 Input: Into<RnnInput<'a>>,
205{
206 type Error = Exception;
207 type Output = Array;
208
209 fn forward(&mut self, input: Input) -> Result<Array, Exception> {
210 let input = input.into();
211 self.step(input.x, input.hidden)
212 }
213
214 fn training_mode(&mut self, _mode: bool) {}
215}
216
/// A gated recurrent unit (GRU) layer.
///
/// Construct via [`GruBuilder`] (see `build_gru` for the initialization
/// scheme). The projection weights pack the reset/update/candidate gates
/// along the first axis, hence the `3 *` factor in their shapes.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Gru {
    /// Dimension of the hidden state; used by `step` to slice gate blocks.
    pub hidden_size: i32,

    /// Input projection weights, shape `[3 * hidden_size, input_size]`.
    #[param]
    pub wx: Param<Array>,

    /// Hidden projection weights, shape `[3 * hidden_size, hidden_size]`.
    #[param]
    pub wh: Param<Array>,

    /// Optional input-projection bias, shape `[3 * hidden_size]`.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Optional candidate-gate hidden bias, shape `[hidden_size]`.
    #[param]
    pub bhn: Param<Option<Array>>,
}
251
/// Builder for [`Gru`]; finalized by `build_gru`.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_gru,
    err = Exception,
)]
pub struct GruBuilder {
    /// Dimension of the input features (last axis of `x`).
    pub input_size: i32,

    /// Dimension of the hidden state.
    pub hidden_size: i32,

    /// Whether to create the bias parameters (`bias` and `bhn`).
    #[builder(optional, default = Gru::DEFAULT_BIAS)]
    pub bias: bool,
}
270
271fn build_gru(builder: GruBuilder) -> Result<Gru, Exception> {
272 let input_size = builder.input_size;
273 let hidden_size = builder.hidden_size;
274
275 let scale = 1.0 / f32::sqrt(hidden_size as f32);
276 let wx = uniform::<_, f32>(-scale, scale, &[3 * hidden_size, input_size], None)?;
277 let wh = uniform::<_, f32>(-scale, scale, &[3 * hidden_size, hidden_size], None)?;
278 let (bias, bhn) = if builder.bias {
279 let bias = uniform::<_, f32>(-scale, scale, &[3 * hidden_size], None)?;
280 let bhn = uniform::<_, f32>(-scale, scale, &[hidden_size], None)?;
281 (Some(bias), Some(bhn))
282 } else {
283 (None, None)
284 };
285
286 Ok(Gru {
287 hidden_size,
288 wx: Param::new(wx),
289 wh: Param::new(wh),
290 bias: Param::new(bias),
291 bhn: Param::new(bhn),
292 })
293}
294
295impl Gru {
296 pub const DEFAULT_BIAS: bool = true;
298
299 pub fn step(&mut self, x: &Array, hidden: Option<&Array>) -> Result<Array, Exception> {
301 let x = if let Some(b) = &self.bias.value {
302 addmm(b, x, self.wx.t(), None, None)?
303 } else {
304 matmul(x, self.wx.t())?
305 };
306
307 let x_rz = x.index((Ellipsis, ..(-self.hidden_size)));
308 let x_n = x.index((Ellipsis, (-self.hidden_size)..));
309
310 let mut all_hidden = Vec::new();
311
312 for index in 0..x.dim(-2) {
313 let mut rz = x_rz.index((Ellipsis, index, ..));
314 let mut h_proj_n = None;
315 if let Some(hidden_) = hidden {
316 let h_proj = matmul(hidden_, self.wh.t())?;
317 let h_proj_rz = h_proj.index((Ellipsis, ..(-self.hidden_size)));
318 h_proj_n = Some(h_proj.index((Ellipsis, (-self.hidden_size)..)));
319
320 if let Some(bhn) = &self.bhn.value {
321 h_proj_n = h_proj_n
322 .map(|h_proj_n| h_proj_n.add(bhn))
323 .transpose()?;
325 }
326
327 rz = rz.add(h_proj_rz)?;
328 }
329
330 rz = sigmoid(&rz)?;
331
332 let parts = split(&rz, 2, -1)?;
333 let r = &parts[0];
334 let z = &parts[1];
335
336 let mut n = x_n.index((Ellipsis, index, 0..));
337
338 if let Some(h_proj_n) = h_proj_n {
339 n = n.add(r.multiply(h_proj_n)?)?;
340 }
341 n = tanh(&n)?;
342
343 let hidden = match hidden {
344 Some(hidden) => array!(1.0)
345 .subtract(z)?
346 .multiply(&n)?
347 .add(z.multiply(hidden)?)?,
348 None => array!(1.0).subtract(z)?.multiply(&n)?,
349 };
350
351 all_hidden.push(hidden);
352 }
353
354 stack_axis(&all_hidden[..], -2)
355 }
356}
357
/// Input for [`Gru`]; identical in shape to [`RnnInput`].
pub type GruInput<'a> = RnnInput<'a>;

/// Builder alias for [`GruInput`].
pub type GruInputBuilder<'a> = RnnInputBuilder<'a>;
363
364impl<'a, Input> Module<Input> for Gru
365where
366 Input: Into<GruInput<'a>>,
367{
368 type Error = Exception;
369 type Output = Array;
370
371 fn forward(&mut self, input: Input) -> Result<Array, Exception> {
372 let input = input.into();
373 self.step(input.x, input.hidden)
374 }
375
376 fn training_mode(&mut self, _mode: bool) {}
377}
378
/// A long short-term memory (LSTM) layer.
///
/// Construct via [`LstmBuilder`] (see `build_lstm` for the initialization
/// scheme). The projection weights pack the input/forget/cell/output gates
/// along the first axis, hence the `4 *` factor in their shapes.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Lstm {
    /// Input projection weights, shape `[4 * hidden_size, input_size]`.
    #[param]
    pub wx: Param<Array>,

    /// Hidden projection weights, shape `[4 * hidden_size, hidden_size]`.
    #[param]
    pub wh: Param<Array>,

    /// Optional bias, shape `[4 * hidden_size]`.
    #[param]
    pub bias: Param<Option<Array>>,
}
396
/// Builder for [`Lstm`]; finalized by `build_lstm`.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_lstm,
    err = Exception,
)]
pub struct LstmBuilder {
    /// Dimension of the input features (last axis of `x`).
    pub input_size: i32,

    /// Dimension of the hidden and cell states.
    pub hidden_size: i32,

    /// Whether to create the bias parameter.
    #[builder(optional, default = Lstm::DEFAULT_BIAS)]
    pub bias: bool,
}
415
416fn build_lstm(builder: LstmBuilder) -> Result<Lstm, Exception> {
417 let input_size = builder.input_size;
418 let hidden_size = builder.hidden_size;
419 let scale = 1.0 / f32::sqrt(hidden_size as f32);
420 let wx = uniform::<_, f32>(-scale, scale, &[4 * hidden_size, input_size], None)?;
421 let wh = uniform::<_, f32>(-scale, scale, &[4 * hidden_size, hidden_size], None)?;
422 let bias = if builder.bias {
423 Some(uniform::<_, f32>(-scale, scale, &[4 * hidden_size], None)?)
424 } else {
425 None
426 };
427
428 Ok(Lstm {
429 wx: Param::new(wx),
430 wh: Param::new(wh),
431 bias: Param::new(bias),
432 })
433}
434
generate_builder! {
    /// Input for [`Lstm`]: the data array plus optional initial hidden and
    /// cell states.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LstmInput<'a> {
        /// Input data, with the time axis at position `-2`.
        pub x: &'a Array,

        /// Optional initial hidden state.
        #[builder(optional, default = None)]
        pub hidden: Option<&'a Array>,

        /// Optional initial cell state.
        #[builder(optional, default = None)]
        pub cell: Option<&'a Array>,
    }
}
453
454impl<'a> From<&'a Array> for LstmInput<'a> {
455 fn from(x: &'a Array) -> Self {
456 LstmInput {
457 x,
458 hidden: None,
459 cell: None,
460 }
461 }
462}
463
464impl<'a> From<(&'a Array,)> for LstmInput<'a> {
465 fn from(input: (&'a Array,)) -> Self {
466 LstmInput {
467 x: input.0,
468 hidden: None,
469 cell: None,
470 }
471 }
472}
473
474impl<'a> From<(&'a Array, &'a Array)> for LstmInput<'a> {
475 fn from(input: (&'a Array, &'a Array)) -> Self {
476 LstmInput {
477 x: input.0,
478 hidden: Some(input.1),
479 cell: None,
480 }
481 }
482}
483
484impl<'a> From<(&'a Array, &'a Array, &'a Array)> for LstmInput<'a> {
485 fn from(input: (&'a Array, &'a Array, &'a Array)) -> Self {
486 LstmInput {
487 x: input.0,
488 hidden: Some(input.1),
489 cell: Some(input.2),
490 }
491 }
492}
493
494impl<'a> From<(&'a Array, Option<&'a Array>)> for LstmInput<'a> {
495 fn from(input: (&'a Array, Option<&'a Array>)) -> Self {
496 LstmInput {
497 x: input.0,
498 hidden: input.1,
499 cell: None,
500 }
501 }
502}
503
504impl<'a> From<(&'a Array, Option<&'a Array>, Option<&'a Array>)> for LstmInput<'a> {
505 fn from(input: (&'a Array, Option<&'a Array>, Option<&'a Array>)) -> Self {
506 LstmInput {
507 x: input.0,
508 hidden: input.1,
509 cell: input.2,
510 }
511 }
512}
513
514impl Lstm {
515 pub const DEFAULT_BIAS: bool = true;
517
518 pub fn step(
520 &mut self,
521 x: &Array,
522 hidden: Option<&Array>,
523 cell: Option<&Array>,
524 ) -> Result<(Array, Array), Exception> {
525 let x = if let Some(b) = &self.bias.value {
526 addmm(b, x, self.wx.t(), None, None)?
527 } else {
528 matmul(x, self.wx.t())?
529 };
530
531 let mut all_hidden = Vec::new();
532 let mut all_cell = Vec::new();
533
534 for index in 0..x.dim(-2) {
535 let mut ifgo = x.index((Ellipsis, index, 0..));
536 if let Some(hidden) = hidden {
537 ifgo = addmm(&ifgo, hidden, self.wh.t(), None, None)?;
538 }
539
540 let pieces = split(&ifgo, 4, -1)?;
541
542 let i = sigmoid(&pieces[0])?;
543 let f = sigmoid(&pieces[1])?;
544 let g = tanh(&pieces[2])?;
545 let o = sigmoid(&pieces[3])?;
546
547 let cell = match cell {
548 Some(cell) => f.multiply(cell)?.add(i.multiply(&g)?)?,
549 None => i.multiply(&g)?,
550 };
551
552 let hidden = o.multiply(tanh(&cell)?)?;
553
554 all_hidden.push(hidden);
555 all_cell.push(cell);
556 }
557
558 Ok((
559 stack_axis(&all_hidden[..], -2)?,
560 stack_axis(&all_cell[..], -2)?,
561 ))
562 }
563}
564
565impl<'a, Input> Module<Input> for Lstm
566where
567 Input: Into<LstmInput<'a>>,
568{
569 type Output = (Array, Array);
570 type Error = Exception;
571
572 fn forward(&mut self, input: Input) -> Result<(Array, Array), Exception> {
573 let input = input.into();
574 self.step(input.x, input.hidden, input.cell)
575 }
576
577 fn training_mode(&mut self, _mode: bool) {}
578}
579
#[cfg(test)]
mod tests {
    use crate::{builder::Builder, ops::maximum_device, random::normal};

    use super::*;

    // NOTE(review): these tests only assert output *shapes*, not values, so
    // they do not exercise the correctness of the recurrence itself.

    #[test]
    fn test_rnn() {
        // Batched input: [batch=2, time=25, features=5].
        let mut layer = Rnn::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        // Custom non-linearity (ReLU via maximum with 0) and no bias.
        let nonlinearity = |x: &Array, d: &Stream| maximum_device(x, array!(0.0), d);
        let mut layer = RnnBuilder::new(5, 12)
            .bias(false)
            .non_linearity(Arc::new(nonlinearity) as Arc<NonLinearity>)
            .build()
            .unwrap();

        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        // Unbatched input: [time=44, features=5].
        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);

        // Feed the last hidden state back in as the initial state.
        let hidden = h_out.index((-1, ..));
        let h_out = layer.forward(RnnInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
    }

    #[test]
    fn test_gru() {
        // Batched input: [batch=2, time=25, features=5].
        let mut layer = Gru::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let h_out = layer.forward(GruInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        // Feed the last time step's hidden state back in.
        let hidden = h_out.index((.., -1, ..));
        let h_out = layer.forward(GruInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        // Unbatched input: [time=44, features=5].
        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let h_out = layer.forward(GruInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let h_out = layer.forward(GruInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
    }

    #[test]
    fn test_lstm() {
        // Batched input: [batch=2, time=25, features=5].
        let mut layer = Lstm::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let (h_out, c_out) = layer.forward(LstmInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);
        assert_eq!(c_out.shape(), &[2, 25, 12]);

        // Feed the last hidden/cell states back through `step` directly.
        let (h_out, c_out) = layer
            .step(
                &inp,
                Some(&h_out.index((.., -1, ..))),
                Some(&c_out.index((.., -1, ..))),
            )
            .unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);
        assert_eq!(c_out.shape(), &[2, 25, 12]);

        // Unbatched input: [time=44, features=5].
        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let (h_out, c_out) = layer.forward(LstmInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
        assert_eq!(c_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let cell = c_out.index((-1, ..));
        let (h_out, c_out) = layer
            .forward(LstmInput::from((&inp, &hidden, &cell)))
            .unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
        assert_eq!(c_out.shape(), &[44, 12]);
    }
}