1use mlx_rs::{
2 error::Exception,
3 macros::ModuleParameters,
4 module::Module,
5 nn::{Linear, Relu, Sequential},
6 Array,
7};
8
9#[derive(Debug, ModuleParameters)]
10pub struct Mlp {
11 #[param]
12 pub layers: Sequential,
13}
14
15impl Module<&Array> for Mlp {
16 type Error = Exception;
17 type Output = Array;
18
19 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
20 self.layers.forward(x)
21 }
22
23 fn training_mode(&mut self, mode: bool) {
24 self.layers.training_mode(mode);
25 }
26}
27
28impl Mlp {
29 pub fn new(
30 num_layers: usize,
31 input_dim: i32,
32 hidden_dim: i32,
33 output_dim: i32,
34 ) -> Result<Self, Exception> {
35 let mut layers = Sequential::new();
36
37 layers = layers
39 .append(Linear::new(input_dim, hidden_dim)?)
40 .append(Relu);
41
42 for _ in 1..num_layers {
44 layers = layers
45 .append(Linear::new(hidden_dim, hidden_dim)?)
46 .append(Relu);
47 }
48
49 layers = layers.append(Linear::new(hidden_dim, output_dim)?);
51
52 Ok(Self { layers })
53 }
54}