mlx_rs/nn/recurrent.rs

use std::sync::Arc;

use crate::{
    array,
    error::Exception,
    module::{Module, Param},
    ops::indexing::{Ellipsis, IndexOp},
    ops::{addmm, matmul, sigmoid, split_equal, stack, tanh, tanh_device},
    random::uniform,
    Array, Stream,
};
use mlx_internal_macros::{generate_builder, Buildable, Builder};
use mlx_macros::ModuleParameters;

/// Type alias for the non-linearity function.
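///
/// A custom non-linearity can be supplied as an `Arc`-wrapped closure. A
/// minimal sketch, using the same `maximum_device` call as the unit tests at
/// the bottom of this file:
///
/// ```rust,ignore
/// let relu = Arc::new(|x: &Array, d: &Stream| maximum_device(x, array!(0.0), d))
///     as Arc<NonLinearity>;
/// ```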
pub type NonLinearity = dyn Fn(&Array, &Stream) -> Result<Array, Exception>;

/// An Elman recurrent layer.
///
/// The input is a sequence of shape `NLD` or `LD` where:
///
/// * `N` is the optional batch dimension
/// * `L` is the sequence length
/// * `D` is the input's feature dimension
///
/// The hidden state `h` has shape `NH` or `H`, depending on
/// whether the input is batched or not. Returns the hidden state at each
/// time step, of shape `NLH` or `LH`.
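///
/// # Example
///
/// A minimal usage sketch; it mirrors the unit test at the bottom of this
/// file:
///
/// ```rust,ignore
/// let mut rnn = Rnn::new(5, 12)?; // D = 5, H = 12
/// let x = normal::<f32>(&[2, 25, 5], None, None, None)?; // N = 2, L = 25, D = 5
/// let h = rnn.forward(RnnInput::from(&x))?; // shape [2, 25, 12], i.e. NLH
/// ```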
#[derive(Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Rnn {
    /// Non-linearity function to use
    pub non_linearity: Arc<NonLinearity>,

    /// Input-to-hidden weights, `Wxh`
    #[param]
    pub wxh: Param<Array>,

    /// Hidden-to-hidden weights, `Whh`
    #[param]
    pub whh: Param<Array>,

    /// Bias. Enabled by default.
    #[param]
    pub bias: Param<Option<Array>>,
}

/// Builder for the [`Rnn`] module.
#[derive(Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_rnn,
    err = Exception,
)]
pub struct RnnBuilder {
    /// Dimension of the input, `D`.
    pub input_size: i32,

    /// Dimension of the hidden state, `H`.
    pub hidden_size: i32,

    /// Non-linearity function to use. Defaults to `tanh` if not set.
    #[builder(optional, default = Rnn::DEFAULT_NONLINEARITY)]
    pub non_linearity: Option<Arc<NonLinearity>>,

    /// Bias. Defaults to [`Rnn::DEFAULT_BIAS`].
    #[builder(optional, default = Rnn::DEFAULT_BIAS)]
    pub bias: bool,
}

impl std::fmt::Debug for RnnBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // `non_linearity` is a trait object with no `Debug` impl, so only
        // `bias` is shown.
        f.debug_struct("RnnBuilder")
            .field("bias", &self.bias)
            .finish()
    }
}

/// Build the [`Rnn`] module.
fn build_rnn(builder: RnnBuilder) -> Result<Rnn, Exception> {
    let input_size = builder.input_size;
    let hidden_size = builder.hidden_size;
    let non_linearity = builder
        .non_linearity
        .unwrap_or_else(|| Arc::new(|x, d| tanh_device(x, d)));

    let scale = 1.0 / (input_size as f32).sqrt();
    let wxh = uniform::<_, f32>(-scale, scale, &[hidden_size, input_size], None)?;
    let whh = uniform::<_, f32>(-scale, scale, &[hidden_size, hidden_size], None)?;
    let bias = if builder.bias {
        Some(uniform::<_, f32>(-scale, scale, &[hidden_size], None)?)
    } else {
        None
    };

    Ok(Rnn {
        non_linearity,
        wxh: Param::new(wxh),
        whh: Param::new(whh),
        bias: Param::new(bias),
    })
}

impl std::fmt::Debug for Rnn {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("Rnn")
            .field("wxh", &self.wxh)
            .field("whh", &self.whh)
            .field("bias", &self.bias)
            .finish()
    }
}

impl Rnn {
    /// Default value for `bias`
    pub const DEFAULT_BIAS: bool = true;

    /// Default for [`RnnBuilder::non_linearity`]. `None` means `tanh` is used.
    pub const DEFAULT_NONLINEARITY: Option<Arc<NonLinearity>> = None;

    /// Apply the RNN recurrence to the input sequence `x`, returning the
    /// hidden state at each time step.
    pub fn step(&mut self, x: &Array, hidden: Option<&Array>) -> Result<Array, Exception> {
        let x = if let Some(bias) = &self.bias.value {
            addmm(bias, x, self.wxh.t(), None, None)?
        } else {
            matmul(x, self.wxh.t())?
        };

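        // Elman update, applied per time step (the standard formulation;
        // matches the loop below):
        //   h_t = non_linearity(x_t @ Wxh^T + b + h_{t-1} @ Whh^T)
        // The input projection `x @ Wxh^T (+ b)` was computed above for all
        // time steps at once; the loop adds the recurrent term and applies
        // the non-linearity.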
        let mut all_hidden = Vec::new();
        // The hidden state is carried across time steps.
        let mut hidden = hidden.cloned();
        for index in 0..x.dim(-2) {
            let h = match &hidden {
                Some(hidden_) => addmm(
                    x.index((Ellipsis, index, 0..)),
                    hidden_,
                    self.whh.t(),
                    None,
                    None,
                )?,
                None => x.index((Ellipsis, index, 0..)),
            };

            let h = (self.non_linearity)(&h, &Stream::default())?;
            hidden = Some(h.clone());
            all_hidden.push(h);
        }

        stack(&all_hidden[..], -2)
    }
}

generate_builder! {
    /// Input for the RNN module.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct RnnInput<'a> {
        /// Input tensor
        pub x: &'a Array,

        /// Hidden state
        #[builder(optional, default = None)]
        pub hidden: Option<&'a Array>,
    }
}

impl<'a> From<&'a Array> for RnnInput<'a> {
    fn from(x: &'a Array) -> Self {
        RnnInput { x, hidden: None }
    }
}

impl<'a> From<(&'a Array,)> for RnnInput<'a> {
    fn from(input: (&'a Array,)) -> Self {
        RnnInput {
            x: input.0,
            hidden: None,
        }
    }
}

impl<'a> From<(&'a Array, &'a Array)> for RnnInput<'a> {
    fn from(input: (&'a Array, &'a Array)) -> Self {
        RnnInput {
            x: input.0,
            hidden: Some(input.1),
        }
    }
}

impl<'a> From<(&'a Array, Option<&'a Array>)> for RnnInput<'a> {
    fn from(input: (&'a Array, Option<&'a Array>)) -> Self {
        RnnInput {
            x: input.0,
            hidden: input.1,
        }
    }
}

impl<'a, Input> Module<Input> for Rnn
where
    Input: Into<RnnInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Array, Exception> {
        let input = input.into();
        self.step(input.x, input.hidden)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// A gated recurrent unit (GRU) RNN layer.
///
/// The input has shape `NLD` or `LD` where:
///
/// * `N` is the optional batch dimension
/// * `L` is the sequence length
/// * `D` is the input's feature dimension
///
/// The hidden state `h` has shape `NH` or `H`, depending on
/// whether the input is batched or not. Returns the hidden state at each
/// time step, of shape `NLH` or `LH`.
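///
/// # Example
///
/// A minimal usage sketch; it mirrors the unit test at the bottom of this
/// file:
///
/// ```rust,ignore
/// let mut gru = Gru::new(5, 12)?; // D = 5, H = 12
/// let x = normal::<f32>(&[2, 25, 5], None, None, None)?;
/// let h = gru.forward(GruInput::from(&x))?; // shape [2, 25, 12]
/// ```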
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Gru {
    /// Dimension of the hidden state, `H`
    pub hidden_size: i32,

    /// Combined input projections for the `r`, `z`, and `n` gates, `Wx`
    #[param]
    pub wx: Param<Array>,

    /// Combined hidden-state projections for the `r`, `z`, and `n` gates, `Wh`
    #[param]
    pub wh: Param<Array>,

    /// Bias. Enabled by default.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Bias added to the hidden-state projection of the candidate (`n`) gate,
    /// `bhn`. Enabled by default.
    #[param]
    pub bhn: Param<Option<Array>>,
}

/// Builder for the [`Gru`] module.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_gru,
    err = Exception,
)]
pub struct GruBuilder {
    /// Dimension of the input, `D`.
    pub input_size: i32,

    /// Dimension of the hidden state, `H`.
    pub hidden_size: i32,

    /// Bias. Defaults to [`Gru::DEFAULT_BIAS`].
    #[builder(optional, default = Gru::DEFAULT_BIAS)]
    pub bias: bool,
}

/// Build the [`Gru`] module.
fn build_gru(builder: GruBuilder) -> Result<Gru, Exception> {
    let input_size = builder.input_size;
    let hidden_size = builder.hidden_size;

    let scale = 1.0 / f32::sqrt(hidden_size as f32);
    let wx = uniform::<_, f32>(-scale, scale, &[3 * hidden_size, input_size], None)?;
    let wh = uniform::<_, f32>(-scale, scale, &[3 * hidden_size, hidden_size], None)?;
    let (bias, bhn) = if builder.bias {
        let bias = uniform::<_, f32>(-scale, scale, &[3 * hidden_size], None)?;
        let bhn = uniform::<_, f32>(-scale, scale, &[hidden_size], None)?;
        (Some(bias), Some(bhn))
    } else {
        (None, None)
    };

    Ok(Gru {
        hidden_size,
        wx: Param::new(wx),
        wh: Param::new(wh),
        bias: Param::new(bias),
        bhn: Param::new(bhn),
    })
}

impl Gru {
    /// Default value for `bias`; enables both `bias` and `bhn`.
    pub const DEFAULT_BIAS: bool = true;

    /// Apply the GRU recurrence to the input sequence `x`, returning the
    /// hidden state at each time step.
    pub fn step(&mut self, x: &Array, hidden: Option<&Array>) -> Result<Array, Exception> {
        let x = if let Some(b) = &self.bias.value {
            addmm(b, x, self.wx.t(), None, None)?
        } else {
            matmul(x, self.wx.t())?
        };

        let x_rz = x.index((Ellipsis, ..(-self.hidden_size)));
        let x_n = x.index((Ellipsis, (-self.hidden_size)..));

        let mut all_hidden = Vec::new();

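        // GRU update, applied per time step (the standard formulation; the
        // gate order in `wx`/`wh` is r, z, n, matching the slices below):
        //   r_t = sigmoid(x_t W_xr^T + h_{t-1} W_hr^T + b_r)
        //   z_t = sigmoid(x_t W_xz^T + h_{t-1} W_hz^T + b_z)
        //   n_t = tanh(x_t W_xn^T + b_n + r_t * (h_{t-1} W_hn^T + b_hn))
        //   h_t = (1 - z_t) * n_t + z_t * h_{t-1}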
        // The hidden state is carried across time steps.
        let mut hidden = hidden.cloned();
        for index in 0..x.dim(-2) {
            let mut rz = x_rz.index((Ellipsis, index, ..));
            let mut h_proj_n = None;
            if let Some(hidden_) = &hidden {
                let h_proj = matmul(hidden_, self.wh.t())?;
                let h_proj_rz = h_proj.index((Ellipsis, ..(-self.hidden_size)));
                h_proj_n = Some(h_proj.index((Ellipsis, (-self.hidden_size)..)));

                if let Some(bhn) = &self.bhn.value {
                    h_proj_n = h_proj_n
                        .map(|h_proj_n| h_proj_n.add(bhn))
                        // This is not a matrix transpose, but a conversion from
                        // `Option<Result<_>>` to `Result<Option<_>>`
                        .transpose()?;
                }

                rz = rz.add(h_proj_rz)?;
            }

            rz = sigmoid(&rz)?;

            let parts = split_equal(&rz, 2, -1)?;
            let r = &parts[0];
            let z = &parts[1];

            let mut n = x_n.index((Ellipsis, index, 0..));

            if let Some(h_proj_n) = h_proj_n {
                n = n.add(r.multiply(h_proj_n)?)?;
            }
            n = tanh(&n)?;

            let h = match &hidden {
                Some(hidden_) => array!(1.0)
                    .subtract(z)?
                    .multiply(&n)?
                    .add(z.multiply(hidden_)?)?,
                None => array!(1.0).subtract(z)?.multiply(&n)?,
            };

            hidden = Some(h.clone());
            all_hidden.push(h);
        }

        stack(&all_hidden[..], -2)
    }
}

/// Type alias for the input of the GRU module.
pub type GruInput<'a> = RnnInput<'a>;

/// Type alias for the builder of the input of the GRU module.
pub type GruInputBuilder<'a> = RnnInputBuilder<'a>;

impl<'a, Input> Module<Input> for Gru
where
    Input: Into<GruInput<'a>>,
{
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, input: Input) -> Result<Array, Exception> {
        let input = input.into();
        self.step(input.x, input.hidden)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// A long short-term memory (LSTM) RNN layer.
///
/// The input has shape `NLD` or `LD` where:
///
/// * `N` is the optional batch dimension
/// * `L` is the sequence length
/// * `D` is the input's feature dimension
///
/// Returns the hidden state and the cell state at each time step, each of
/// shape `NLH` or `LH`.
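///
/// # Example
///
/// A minimal usage sketch; it mirrors the unit test at the bottom of this
/// file:
///
/// ```rust,ignore
/// let mut lstm = Lstm::new(5, 12)?; // D = 5, H = 12
/// let x = normal::<f32>(&[2, 25, 5], None, None, None)?;
/// let (h, c) = lstm.forward(LstmInput::from(&x))?; // each of shape [2, 25, 12]
/// ```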
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Lstm {
    /// Combined input projections for the `i`, `f`, `g`, and `o` gates, `Wx`
    #[param]
    pub wx: Param<Array>,

    /// Combined hidden-state projections for the `i`, `f`, `g`, and `o` gates, `Wh`
    #[param]
    pub wh: Param<Array>,

    /// Bias. Enabled by default.
    #[param]
    pub bias: Param<Option<Array>>,
}

/// Builder for the [`Lstm`] module.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_lstm,
    err = Exception,
)]
pub struct LstmBuilder {
    /// Dimension of the input, `D`.
    pub input_size: i32,

    /// Dimension of the hidden state, `H`.
    pub hidden_size: i32,

    /// Bias. Defaults to [`Lstm::DEFAULT_BIAS`].
    #[builder(optional, default = Lstm::DEFAULT_BIAS)]
    pub bias: bool,
}

/// Build the [`Lstm`] module.
fn build_lstm(builder: LstmBuilder) -> Result<Lstm, Exception> {
    let input_size = builder.input_size;
    let hidden_size = builder.hidden_size;
    let scale = 1.0 / f32::sqrt(hidden_size as f32);
    let wx = uniform::<_, f32>(-scale, scale, &[4 * hidden_size, input_size], None)?;
    let wh = uniform::<_, f32>(-scale, scale, &[4 * hidden_size, hidden_size], None)?;
    let bias = if builder.bias {
        Some(uniform::<_, f32>(-scale, scale, &[4 * hidden_size], None)?)
    } else {
        None
    };

    Ok(Lstm {
        wx: Param::new(wx),
        wh: Param::new(wh),
        bias: Param::new(bias),
    })
}

generate_builder! {
    /// Input for the LSTM module.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LstmInput<'a> {
        /// Input tensor
        pub x: &'a Array,

        /// Hidden state
        #[builder(optional, default = None)]
        pub hidden: Option<&'a Array>,

        /// Cell state
        #[builder(optional, default = None)]
        pub cell: Option<&'a Array>,
    }
}

impl<'a> From<&'a Array> for LstmInput<'a> {
    fn from(x: &'a Array) -> Self {
        LstmInput {
            x,
            hidden: None,
            cell: None,
        }
    }
}

impl<'a> From<(&'a Array,)> for LstmInput<'a> {
    fn from(input: (&'a Array,)) -> Self {
        LstmInput {
            x: input.0,
            hidden: None,
            cell: None,
        }
    }
}

impl<'a> From<(&'a Array, &'a Array)> for LstmInput<'a> {
    fn from(input: (&'a Array, &'a Array)) -> Self {
        LstmInput {
            x: input.0,
            hidden: Some(input.1),
            cell: None,
        }
    }
}

impl<'a> From<(&'a Array, &'a Array, &'a Array)> for LstmInput<'a> {
    fn from(input: (&'a Array, &'a Array, &'a Array)) -> Self {
        LstmInput {
            x: input.0,
            hidden: Some(input.1),
            cell: Some(input.2),
        }
    }
}

impl<'a> From<(&'a Array, Option<&'a Array>)> for LstmInput<'a> {
    fn from(input: (&'a Array, Option<&'a Array>)) -> Self {
        LstmInput {
            x: input.0,
            hidden: input.1,
            cell: None,
        }
    }
}

impl<'a> From<(&'a Array, Option<&'a Array>, Option<&'a Array>)> for LstmInput<'a> {
    fn from(input: (&'a Array, Option<&'a Array>, Option<&'a Array>)) -> Self {
        LstmInput {
            x: input.0,
            hidden: input.1,
            cell: input.2,
        }
    }
}

impl Lstm {
    /// Default value for `bias`
    pub const DEFAULT_BIAS: bool = true;

    /// Apply the LSTM recurrence to the input sequence `x`, returning the
    /// hidden state and the cell state at each time step.
    pub fn step(
        &mut self,
        x: &Array,
        hidden: Option<&Array>,
        cell: Option<&Array>,
    ) -> Result<(Array, Array), Exception> {
        let x = if let Some(b) = &self.bias.value {
            addmm(b, x, self.wx.t(), None, None)?
        } else {
            matmul(x, self.wx.t())?
        };

        let mut all_hidden = Vec::new();
        let mut all_cell = Vec::new();

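        // LSTM update, applied per time step (the standard formulation; the
        // gate order in `wx`/`wh` is i, f, g, o, matching the split below):
        //   i_t = sigmoid(.)  input gate
        //   f_t = sigmoid(.)  forget gate
        //   g_t = tanh(.)     cell candidate
        //   o_t = sigmoid(.)  output gate
        //   c_t = f_t * c_{t-1} + i_t * g_t
        //   h_t = o_t * tanh(c_t)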
        // The hidden and cell states are carried across time steps.
        let mut hidden = hidden.cloned();
        let mut cell = cell.cloned();

        for index in 0..x.dim(-2) {
            let mut ifgo = x.index((Ellipsis, index, 0..));
            if let Some(hidden_) = &hidden {
                ifgo = addmm(&ifgo, hidden_, self.wh.t(), None, None)?;
            }

            let pieces = split_equal(&ifgo, 4, -1)?;

            let i = sigmoid(&pieces[0])?;
            let f = sigmoid(&pieces[1])?;
            let g = tanh(&pieces[2])?;
            let o = sigmoid(&pieces[3])?;

            let c = match &cell {
                Some(cell_) => f.multiply(cell_)?.add(i.multiply(&g)?)?,
                None => i.multiply(&g)?,
            };

            let h = o.multiply(tanh(&c)?)?;

            hidden = Some(h.clone());
            cell = Some(c.clone());
            all_hidden.push(h);
            all_cell.push(c);
        }

        Ok((stack(&all_hidden[..], -2)?, stack(&all_cell[..], -2)?))
    }
}

impl<'a, Input> Module<Input> for Lstm
where
    Input: Into<LstmInput<'a>>,
{
    type Output = (Array, Array);
    type Error = Exception;

    fn forward(&mut self, input: Input) -> Result<(Array, Array), Exception> {
        let input = input.into();
        self.step(input.x, input.hidden, input.cell)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

// The unit tests below are ported from the Python codebase.
#[cfg(test)]
mod tests {
    use crate::{builder::Builder, ops::maximum_device, random::normal};

    use super::*;

    #[test]
    fn test_rnn() {
        let mut layer = Rnn::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        let nonlinearity = |x: &Array, d: &Stream| maximum_device(x, array!(0.0), d);
        let mut layer = RnnBuilder::new(5, 12)
            .bias(false)
            .non_linearity(Arc::new(nonlinearity) as Arc<NonLinearity>)
            .build()
            .unwrap();

        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let h_out = layer.forward(RnnInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let h_out = layer.forward(RnnInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
    }

    #[test]
    fn test_gru() {
        let mut layer = Gru::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let h_out = layer.forward(GruInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        let hidden = h_out.index((.., -1, ..));
        let h_out = layer.forward(GruInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);

        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let h_out = layer.forward(GruInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let h_out = layer.forward(GruInput::from((&inp, &hidden))).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
    }

    #[test]
    fn test_lstm() {
        let mut layer = Lstm::new(5, 12).unwrap();
        let inp = normal::<f32>(&[2, 25, 5], None, None, None).unwrap();

        let (h_out, c_out) = layer.forward(LstmInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);
        assert_eq!(c_out.shape(), &[2, 25, 12]);

        let (h_out, c_out) = layer
            .step(
                &inp,
                Some(&h_out.index((.., -1, ..))),
                Some(&c_out.index((.., -1, ..))),
            )
            .unwrap();
        assert_eq!(h_out.shape(), &[2, 25, 12]);
        assert_eq!(c_out.shape(), &[2, 25, 12]);

        let inp = normal::<f32>(&[44, 5], None, None, None).unwrap();
        let (h_out, c_out) = layer.forward(LstmInput::from(&inp)).unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
        assert_eq!(c_out.shape(), &[44, 12]);

        let hidden = h_out.index((-1, ..));
        let cell = c_out.index((-1, ..));
        let (h_out, c_out) = layer
            .forward(LstmInput::from((&inp, &hidden, &cell)))
            .unwrap();
        assert_eq!(h_out.shape(), &[44, 12]);
        assert_eq!(c_out.shape(), &[44, 12]);
    }
}