// mnist/mlp.rs

1use mlx_rs::{
2    error::Exception,
3    macros::ModuleParameters,
4    module::Module,
5    nn::{Linear, Relu, Sequential},
6    Array,
7};
8
/// A multi-layer perceptron (MLP) whose layers are held in a single
/// [`Sequential`] container so the whole stack is invoked with one call.
///
/// `ModuleParameters` is derived so the `#[param]`-marked field exposes the
/// contained layers' parameters for optimization.
#[derive(Debug, ModuleParameters)]
pub struct Mlp {
    /// The full layer stack: alternating `Linear`/`Relu` pairs followed by a
    /// final `Linear` output projection (see `Mlp::new`).
    #[param]
    pub layers: Sequential,
}
14
impl Module<&Array> for Mlp {
    type Error = Exception;
    type Output = Array;

    /// Runs the input through the entire layer stack, delegating directly to
    /// the inner [`Sequential`] container.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        self.layers.forward(x)
    }

    /// Propagates the train/eval mode flag to every contained layer.
    fn training_mode(&mut self, mode: bool) {
        self.layers.training_mode(mode);
    }
}
27
28impl Mlp {
29    pub fn new(
30        num_layers: usize,
31        input_dim: i32,
32        hidden_dim: i32,
33        output_dim: i32,
34    ) -> Result<Self, Exception> {
35        let mut layers = Sequential::new();
36
37        // Add the input layer
38        layers = layers
39            .append(Linear::new(input_dim, hidden_dim)?)
40            .append(Relu);
41
42        // Add the hidden layers
43        for _ in 1..num_layers {
44            layers = layers
45                .append(Linear::new(hidden_dim, hidden_dim)?)
46                .append(Relu);
47        }
48
49        // Add the output layer
50        layers = layers.append(Linear::new(hidden_dim, output_dim)?);
51
52        Ok(Self { layers })
53    }
54}