// mlx_rs/nn/recurrent.rs

use std::sync::Arc;

use crate::{
    array,
    error::Exception,
    module::{Module, Param},
    ops::{
        addmm,
        indexing::{Ellipsis, IndexOp},
        matmul, sigmoid, split, stack_axis, tanh, tanh_device,
    },
    random::uniform,
    Array, Stream,
};
use mlx_internal_macros::{generate_builder, Buildable, Builder};
use mlx_macros::ModuleParameters;

/// Type alias for the non-linearity function.
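///
/// Any closure of type `Fn(&Array, &Stream) -> Result<Array, Exception>` can
/// serve as the non-linearity; the unit tests below, for example, build a
/// ReLU as `|x, d| maximum_device(x, array!(0.0), d)`.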
pub type NonLinearity = dyn Fn(&Array, &Stream) -> Result<Array, Exception>;

/// An Elman recurrent layer.
///
/// The input is a sequence of shape `NLD` or `LD` where:
///
/// * `N` is the optional batch dimension
/// * `L` is the sequence length
/// * `D` is the input's feature dimension
///
/// Concretely, each time step computes
/// `h_{t+1} = non_linearity(x_t W_xh^T + h_t W_hh^T + b)`.
///
/// The hidden state `h` has shape `NH` or `H`, depending on whether the
/// input is batched. Returns the hidden state at each time step, of shape
/// `NLH` or `LH`.
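///
/// # Example
///
/// A minimal usage sketch mirroring the unit tests at the bottom of this
/// file; the `mlx_rs::nn` paths and the generated `Rnn::new` constructor
/// are assumed here rather than taken from this file:
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::Rnn, random::normal};
///
/// // `NLD` input: batch 2, sequence length 25, 5 input features.
/// let x = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();
/// let mut rnn = Rnn::new(5, 12).unwrap();
/// // Hidden state at every time step: shape `NLH` = [2, 25, 12].
/// let h = rnn.forward(&x).unwrap();
/// ```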
#[derive(Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Rnn {
    /// Non-linearity function to use
    pub non_linearity: Arc<NonLinearity>,

    /// Input-to-hidden weights, `Wxh`
    #[param]
    pub wxh: Param<Array>,

    /// Hidden-to-hidden weights, `Whh`
    #[param]
    pub whh: Param<Array>,

    /// Bias. Enabled by default.
    #[param]
    pub bias: Param<Option<Array>>,
}

/// Builder for the [`Rnn`] module.
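///
/// # Example
///
/// A sketch of configuring a custom non-linearity (a ReLU built from
/// `maximum_device`, as in the unit tests below); the `mlx_rs` paths and
/// the setter names generated by the `Builder` derive are assumed:
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use mlx_rs::{array, builder::Builder, Array, Stream};
/// use mlx_rs::nn::{NonLinearity, Rnn, RnnBuilder};
/// use mlx_rs::ops::maximum_device;
///
/// // ReLU as a custom non-linearity.
/// let relu = |x: &Array, d: &Stream| maximum_device(x, array!(0.0), d);
/// let rnn: Rnn = RnnBuilder::new(5, 12)
///     .bias(false)
///     .non_linearity(Arc::new(relu) as Arc<NonLinearity>)
///     .build()
///     .unwrap();
/// ```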
#[derive(Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_rnn,
    err = Exception,
)]
pub struct RnnBuilder {
    /// Dimension of the input, `D`.
    pub input_size: i32,

    /// Dimension of the hidden state, `H`.
    pub hidden_size: i32,

    /// Non-linearity function to use. Defaults to `tanh` if not set.
    #[builder(optional, default = Rnn::DEFAULT_NONLINEARITY)]
    pub non_linearity: Option<Arc<NonLinearity>>,

    /// Bias. Defaults to [`Rnn::DEFAULT_BIAS`].
    #[builder(optional, default = Rnn::DEFAULT_BIAS)]
    pub bias: bool,
}

// Manual impl: `non_linearity` is a trait object and does not implement `Debug`.
impl std::fmt::Debug for RnnBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("RnnBuilder")
            .field("input_size", &self.input_size)
            .field("hidden_size", &self.hidden_size)
            .field("bias", &self.bias)
            .finish()
    }
}

/// Build the [`Rnn`] module.
fn build_rnn(builder: RnnBuilder) -> Result<Rnn, Exception> {
    let input_size = builder.input_size;
    let hidden_size = builder.hidden_size;
    // Fall back to `tanh` when no non-linearity is given.
    let non_linearity = builder
        .non_linearity
        .unwrap_or_else(|| Arc::new(|x, d| tanh_device(x, d)));

    // Initialize parameters uniformly in [-scale, scale].
    let scale = 1.0 / (input_size as f32).sqrt();
    let wxh = uniform::<_, f32>(-scale, scale, &[hidden_size, input_size], None)?;
    let whh = uniform::<_, f32>(-scale, scale, &[hidden_size, hidden_size], None)?;
    let bias = if builder.bias {
        Some(uniform::<_, f32>(-scale, scale, &[hidden_size], None)?)
    } else {
        None
    };

    Ok(Rnn {
        non_linearity,
        wxh: Param::new(wxh),
        whh: Param::new(whh),
        bias: Param::new(bias),
    })
}

// Manual impl: `non_linearity` is a trait object and does not implement `Debug`.
impl std::fmt::Debug for Rnn {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("Rnn")
            .field("wxh", &self.wxh)
            .field("whh", &self.whh)
            .field("bias", &self.bias)
            .finish()
    }
}

impl Rnn {
    /// Default value for `bias`
    pub const DEFAULT_BIAS: bool = true;

    /// Default value for `RnnBuilder::non_linearity`; `None` means `tanh` is used.
    pub const DEFAULT_NONLINEARITY: Option<Arc<NonLinearity>> = None;

    /// Apply the RNN to a sequence `x`, optionally seeded with an initial
    /// hidden state, and return the hidden state at every time step.
    pub fn step(&mut self, x: &Array, hidden: Option<&Array>) -> Result<Array, Exception> {
        // Project the whole input sequence at once: x @ Wxh^T (+ bias).
        let x = if let Some(bias) = &self.bias.value {
            addmm(bias, x, self.wxh.t(), None, None)?
        } else {
            matmul(x, self.wxh.t())?
        };

        let mut all_hidden = Vec::new();
        // Carry the hidden state forward so each step sees the previous one.
        let mut hidden = hidden.cloned();
        for index in 0..x.dim(-2) {
            let pre_activation = match &hidden {
                Some(hidden_) => addmm(
                    x.index((Ellipsis, index, 0..)),
                    hidden_,
                    self.whh.t(),
                    None,
                    None,
                )?,
                None => x.index((Ellipsis, index, 0..)),
            };

            let new_hidden = (self.non_linearity)(&pre_activation, &Stream::default())?;
            all_hidden.push(new_hidden.clone());
            hidden = Some(new_hidden);
        }

        stack_axis(&all_hidden[..], -2)
    }
}

generate_builder! {
    /// Input for the RNN module.
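    ///
    /// `RnnInput` can also be built via the `From` impls below, from
    /// `&Array`, `(&Array, &Array)`, or `(&Array, Option<&Array>)`, so both
    /// `rnn.forward(&x)` and `rnn.forward((&x, &hidden))` work.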
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RnnInput<'a> {
        /// Input tensor
        pub x: &'a Array,

        /// Hidden state
        #[builder(optional, default = None)]
        pub hidden: Option<&'a Array>,
    }
}

impl<'a> From<&'a Array> for RnnInput<'a> {
    fn from(x: &'a Array) -> Self {
        RnnInput { x, hidden: None }
    }
}

impl<'a> From<(&'a Array,)> for RnnInput<'a> {
    fn from(input: (&'a Array,)) -> Self {
        RnnInput {
            x: input.0,
            hidden: None,
        }
    }
}

impl<'a> From<(&'a Array, &'a Array)> for RnnInput<'a> {
    fn from(input: (&'a Array, &'a Array)) -> Self {
        RnnInput {
            x: input.0,
            hidden: Some(input.1),
        }
    }
}

impl<'a> From<(&'a Array, Option<&'a Array>)> for RnnInput<'a> {
    fn from(input: (&'a Array, Option<&'a Array>)) -> Self {
        RnnInput {
            x: input.0,
            hidden: input.1,
        }
    }
}

impl<'a, Input> Module<Input> for Rnn
where
    Input: Into<RnnInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Array, Exception> {
        let input = input.into();
        self.step(input.x, input.hidden)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// A gated recurrent unit (GRU) RNN layer.
///
/// The input has shape `NLD` or `LD` where:
///
/// * `N` is the optional batch dimension
/// * `L` is the sequence length
/// * `D` is the input's feature dimension
///
/// Each time step computes reset and update gates `r_t` and `z_t` from the
/// input and the previous hidden state, a candidate `n_t`, and then
/// `h_t = (1 - z_t) * n_t + z_t * h_{t-1}`.
///
/// The hidden state `h` has shape `NH` or `H`, depending on whether the
/// input is batched. Returns the hidden state at each time step, of shape
/// `NLH` or `LH`.
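///
/// # Example
///
/// A minimal usage sketch mirroring the unit tests at the bottom of this
/// file; the `mlx_rs::nn` paths and the generated `Gru::new` constructor
/// are assumed here rather than taken from this file:
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::Gru, random::normal};
///
/// let x = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();
/// let mut gru = Gru::new(5, 12).unwrap();
/// // Hidden state at every time step: shape [2, 25, 12].
/// let h = gru.forward(&x).unwrap();
/// ```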
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Gru {
    /// Dimension of the hidden state, `H`
    pub hidden_size: i32,

    /// Input weights, `Wx` (reset, update, and candidate projections stacked)
    #[param]
    pub wx: Param<Array>,

    /// Hidden-state weights, `Wh` (reset, update, and candidate projections stacked)
    #[param]
    pub wh: Param<Array>,

    /// Bias. Enabled by default.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Candidate bias, `bhn`, applied to the hidden-state projection.
    /// Enabled by default.
    #[param]
    pub bhn: Param<Option<Array>>,
}

/// Builder for the [`Gru`] module.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_gru,
    err = Exception,
)]
pub struct GruBuilder {
    /// Dimension of the input, `D`.
    pub input_size: i32,

    /// Dimension of the hidden state, `H`.
    pub hidden_size: i32,

    /// Bias. Defaults to [`Gru::DEFAULT_BIAS`].
    #[builder(optional, default = Gru::DEFAULT_BIAS)]
    pub bias: bool,
}

fn build_gru(builder: GruBuilder) -> Result<Gru, Exception> {
    let input_size = builder.input_size;
    let hidden_size = builder.hidden_size;

    // Initialize parameters uniformly in [-scale, scale]; the weights stack
    // the reset, update, and candidate projections along the first axis.
    let scale = 1.0 / f32::sqrt(hidden_size as f32);
    let wx = uniform::<_, f32>(-scale, scale, &[3 * hidden_size, input_size], None)?;
    let wh = uniform::<_, f32>(-scale, scale, &[3 * hidden_size, hidden_size], None)?;
    let (bias, bhn) = if builder.bias {
        let bias = uniform::<_, f32>(-scale, scale, &[3 * hidden_size], None)?;
        let bhn = uniform::<_, f32>(-scale, scale, &[hidden_size], None)?;
        (Some(bias), Some(bhn))
    } else {
        (None, None)
    };

    Ok(Gru {
        hidden_size,
        wx: Param::new(wx),
        wh: Param::new(wh),
        bias: Param::new(bias),
        bhn: Param::new(bhn),
    })
}

impl Gru {
    /// Default value for `bias` and `bhn`
    pub const DEFAULT_BIAS: bool = true;

    /// Apply the GRU to a sequence `x`, optionally seeded with an initial
    /// hidden state, and return the hidden state at every time step.
    pub fn step(&mut self, x: &Array, hidden: Option<&Array>) -> Result<Array, Exception> {
        // Project the whole input sequence at once: x @ Wx^T (+ bias).
        let x = if let Some(b) = &self.bias.value {
            addmm(b, x, self.wx.t(), None, None)?
        } else {
            matmul(x, self.wx.t())?
        };

        // Split the projection into the reset/update part and the candidate part.
        let x_rz = x.index((Ellipsis, ..(-self.hidden_size)));
        let x_n = x.index((Ellipsis, (-self.hidden_size)..));

        let mut all_hidden = Vec::new();
        // Carry the hidden state forward so each step sees the previous one.
        let mut hidden = hidden.cloned();

        for index in 0..x.dim(-2) {
            let mut rz = x_rz.index((Ellipsis, index, ..));
            let mut h_proj_n = None;
            if let Some(hidden_) = &hidden {
                let h_proj = matmul(hidden_, self.wh.t())?;
                let h_proj_rz = h_proj.index((Ellipsis, ..(-self.hidden_size)));
                h_proj_n = Some(h_proj.index((Ellipsis, (-self.hidden_size)..)));

                if let Some(bhn) = &self.bhn.value {
                    h_proj_n = h_proj_n
                        .map(|h_proj_n| h_proj_n.add(bhn))
                        // Not a matrix transpose: converts `Option<Result<_>>` into `Result<Option<_>>`
                        .transpose()?;
                }

                rz = rz.add(h_proj_rz)?;
            }

            rz = sigmoid(&rz)?;

            // Reset gate `r` and update gate `z`.
            let parts = split(&rz, 2, -1)?;
            let r = &parts[0];
            let z = &parts[1];

            let mut n = x_n.index((Ellipsis, index, ..));

            if let Some(h_proj_n) = h_proj_n {
                n = n.add(r.multiply(h_proj_n)?)?;
            }
            n = tanh(&n)?;

            // h_t = (1 - z) * n + z * h_{t-1}
            let new_hidden = match &hidden {
                Some(hidden) => array!(1.0)
                    .subtract(z)?
                    .multiply(&n)?
                    .add(z.multiply(hidden)?)?,
                None => array!(1.0).subtract(z)?.multiply(&n)?,
            };

            all_hidden.push(new_hidden.clone());
            hidden = Some(new_hidden);
        }

        stack_axis(&all_hidden[..], -2)
    }
}

/// Type alias for the input of the GRU module.
pub type GruInput<'a> = RnnInput<'a>;

/// Type alias for the builder of the input of the GRU module.
pub type GruInputBuilder<'a> = RnnInputBuilder<'a>;

impl<'a, Input> Module<Input> for Gru
where
    Input: Into<GruInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Array, Exception> {
        let input = input.into();
        self.step(input.x, input.hidden)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// A long short-term memory (LSTM) RNN layer.
///
/// The input has shape `NLD` or `LD` where:
///
/// * `N` is the optional batch dimension
/// * `L` is the sequence length
/// * `D` is the input's feature dimension
///
/// Returns the hidden state and cell state at each time step, both of shape
/// `NLH` or `LH`.
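///
/// # Example
///
/// A minimal usage sketch mirroring the unit tests at the bottom of this
/// file; the `mlx_rs::nn` paths and the generated `Lstm::new` constructor
/// are assumed here rather than taken from this file:
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::Lstm, random::normal};
///
/// let x = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();
/// let mut lstm = Lstm::new(5, 12).unwrap();
/// // Hidden and cell states at every time step: both shaped [2, 25, 12].
/// let (h, c) = lstm.forward(&x).unwrap();
/// ```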
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Lstm {
    /// Input weights, `Wx` (input, forget, cell, and output gate projections stacked)
    #[param]
    pub wx: Param<Array>,

    /// Hidden-state weights, `Wh` (input, forget, cell, and output gate projections stacked)
    #[param]
    pub wh: Param<Array>,

    /// Bias. Enabled by default.
    #[param]
    pub bias: Param<Option<Array>>,
}

/// Builder for the [`Lstm`] module.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_lstm,
    err = Exception,
)]
pub struct LstmBuilder {
    /// Dimension of the input, `D`.
    pub input_size: i32,

    /// Dimension of the hidden state, `H`.
    pub hidden_size: i32,

    /// Bias. Defaults to [`Lstm::DEFAULT_BIAS`].
    #[builder(optional, default = Lstm::DEFAULT_BIAS)]
    pub bias: bool,
}

fn build_lstm(builder: LstmBuilder) -> Result<Lstm, Exception> {
    let input_size = builder.input_size;
    let hidden_size = builder.hidden_size;
    // Initialize parameters uniformly in [-scale, scale]; the weights stack
    // the four gate projections along the first axis.
    let scale = 1.0 / f32::sqrt(hidden_size as f32);
    let wx = uniform::<_, f32>(-scale, scale, &[4 * hidden_size, input_size], None)?;
    let wh = uniform::<_, f32>(-scale, scale, &[4 * hidden_size, hidden_size], None)?;
    let bias = if builder.bias {
        Some(uniform::<_, f32>(-scale, scale, &[4 * hidden_size], None)?)
    } else {
        None
    };

    Ok(Lstm {
        wx: Param::new(wx),
        wh: Param::new(wh),
        bias: Param::new(bias),
    })
}

generate_builder! {
    /// Input for the LSTM module.
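    ///
    /// `LstmInput` can also be built via the `From` impls below, from
    /// `&Array`, `(&Array, &Array)`, `(&Array, &Array, &Array)`, and their
    /// `Option` variants, so `lstm.forward(&x)` and
    /// `lstm.forward((&x, &hidden, &cell))` both work.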
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LstmInput<'a> {
        /// Input tensor
        pub x: &'a Array,

        /// Hidden state
        #[builder(optional, default = None)]
        pub hidden: Option<&'a Array>,

        /// Cell state
        #[builder(optional, default = None)]
        pub cell: Option<&'a Array>,
    }
}

impl<'a> From<&'a Array> for LstmInput<'a> {
    fn from(x: &'a Array) -> Self {
        LstmInput {
            x,
            hidden: None,
            cell: None,
        }
    }
}

impl<'a> From<(&'a Array,)> for LstmInput<'a> {
    fn from(input: (&'a Array,)) -> Self {
        LstmInput {
            x: input.0,
            hidden: None,
            cell: None,
        }
    }
}

impl<'a> From<(&'a Array, &'a Array)> for LstmInput<'a> {
    fn from(input: (&'a Array, &'a Array)) -> Self {
        LstmInput {
            x: input.0,
            hidden: Some(input.1),
            cell: None,
        }
    }
}

impl<'a> From<(&'a Array, &'a Array, &'a Array)> for LstmInput<'a> {
    fn from(input: (&'a Array, &'a Array, &'a Array)) -> Self {
        LstmInput {
            x: input.0,
            hidden: Some(input.1),
            cell: Some(input.2),
        }
    }
}

impl<'a> From<(&'a Array, Option<&'a Array>)> for LstmInput<'a> {
    fn from(input: (&'a Array, Option<&'a Array>)) -> Self {
        LstmInput {
            x: input.0,
            hidden: input.1,
            cell: None,
        }
    }
}

impl<'a> From<(&'a Array, Option<&'a Array>, Option<&'a Array>)> for LstmInput<'a> {
    fn from(input: (&'a Array, Option<&'a Array>, Option<&'a Array>)) -> Self {
        LstmInput {
            x: input.0,
            hidden: input.1,
            cell: input.2,
        }
    }
}

impl Lstm {
    /// Default value for `bias`
    pub const DEFAULT_BIAS: bool = true;

    /// Apply the LSTM to a sequence `x`, optionally seeded with initial
    /// hidden and cell states, and return the hidden and cell states at
    /// every time step.
    pub fn step(
        &mut self,
        x: &Array,
        hidden: Option<&Array>,
        cell: Option<&Array>,
    ) -> Result<(Array, Array), Exception> {
        // Project the whole input sequence at once: x @ Wx^T (+ bias).
        let x = if let Some(b) = &self.bias.value {
            addmm(b, x, self.wx.t(), None, None)?
        } else {
            matmul(x, self.wx.t())?
        };

        let mut all_hidden = Vec::new();
        let mut all_cell = Vec::new();
        // Carry the hidden and cell states forward across time steps.
        let mut hidden = hidden.cloned();
        let mut cell = cell.cloned();

        for index in 0..x.dim(-2) {
            let mut ifgo = x.index((Ellipsis, index, 0..));
            if let Some(hidden) = &hidden {
                ifgo = addmm(&ifgo, hidden, self.wh.t(), None, None)?;
            }

            // Input, forget, cell, and output gates.
            let pieces = split(&ifgo, 4, -1)?;

            let i = sigmoid(&pieces[0])?;
            let f = sigmoid(&pieces[1])?;
            let g = tanh(&pieces[2])?;
            let o = sigmoid(&pieces[3])?;

            // c_t = f * c_{t-1} + i * g
            let new_cell = match &cell {
                Some(cell) => f.multiply(cell)?.add(i.multiply(&g)?)?,
                None => i.multiply(&g)?,
            };

            // h_t = o * tanh(c_t)
            let new_hidden = o.multiply(tanh(&new_cell)?)?;

            all_hidden.push(new_hidden.clone());
            all_cell.push(new_cell.clone());
            hidden = Some(new_hidden);
            cell = Some(new_cell);
        }

        Ok((
            stack_axis(&all_hidden[..], -2)?,
            stack_axis(&all_cell[..], -2)?,
        ))
    }
}

impl<'a, Input> Module<Input> for Lstm
where
    Input: Into<LstmInput<'a>>,
{
    type Output = (Array, Array);
    type Error = Exception;

    fn forward(&mut self, input: Input) -> Result<(Array, Array), Exception> {
        let input = input.into();
        self.step(input.x, input.hidden, input.cell)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

// The unit tests below are ported from the Python codebase
#[cfg(test)]
mod tests {
    use crate::{builder::Builder, ops::maximum_device, random::normal};

    use super::*;

    #[test]
    fn test_rnn() {
        let mut layer = Rnn::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        let nonlinearity = |x: &Array, d: &Stream| maximum_device(x, array!(0.0), d);
        let mut layer = RnnBuilder::new(5, 12)
            .bias(false)
            .non_linearity(Arc::new(nonlinearity) as Arc<NonLinearity>)
            .build()
            .unwrap();

        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let h_out = layer.forward(RnnInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
    }

    #[test]
    fn test_gru() {
        let mut layer = Gru::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let h_out = layer.forward(GruInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        let hidden = h_out.index((.., -1, ..));
        let h_out = layer.forward(GruInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let h_out = layer.forward(GruInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let h_out = layer.forward(GruInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
    }

    #[test]
    fn test_lstm() {
        let mut layer = Lstm::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let (h_out, c_out) = layer.forward(LstmInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);
        assert_eq!(c_out.shape(), &[2, 25, 12]);

        let (h_out, c_out) = layer
            .step(
                &inp,
                Some(&h_out.index((.., -1, ..))),
                Some(&c_out.index((.., -1, ..))),
            )
            .unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);
        assert_eq!(c_out.shape(), &[2, 25, 12]);

        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let (h_out, c_out) = layer.forward(LstmInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
        assert_eq!(c_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let cell = c_out.index((-1, ..));
        let (h_out, c_out) = layer
            .forward(LstmInput::from((&inp, &hidden, &cell)))
            .unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
        assert_eq!(c_out.shape(), &[44, 12]);
    }
}
667}