mlx_rs/nn/container.rs

use std::borrow::Cow;

use crate::module::{Module, UnaryModule};
use crate::{error::Exception, Array};
use mlx_macros::ModuleParameters;

/// Marker trait for items that can be used in a [`Sequential`] module.
///
/// It is implemented for all types that implement [`UnaryModule`] and [`std::fmt::Debug`].
pub trait SequentialModuleItem: UnaryModule + std::fmt::Debug {}

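// Blanket impl: every unary module that also implements `Debug` qualifies automatically.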
impl<T> SequentialModuleItem for T where T: UnaryModule + std::fmt::Debug {}

/// A sequential layer.
///
/// It calls each layer in sequence.
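///
/// # Example
///
/// A minimal usage sketch (assuming the `Linear` layer from this crate's `nn`
/// module, as exercised in the tests below, and some input array `x`):
///
/// ```rust,ignore
/// let mut model = Sequential::new()
///     .append(Linear::new(2, 3)?)
///     .append(Linear::new(3, 1)?);
/// let y = model.forward(&x)?;
/// ```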
#[derive(Debug, ModuleParameters)]
#[module(root = crate)]
pub struct Sequential<Err = Exception> {
    /// The layers to be called in sequence.
    #[param]
    pub layers: Vec<Box<dyn SequentialModuleItem<Error = Err>>>,
}

impl Module<&Array> for Sequential {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
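        // Borrow the input via `Cow` so no copy is made before the first layer runs.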
        let mut x = Cow::Borrowed(x);

        for layer in &mut self.layers {
            x = Cow::Owned(layer.forward(x.as_ref())?);
        }

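        // After at least one layer has run, `x` is owned and is returned as-is;
        // for an empty `layers` vec the still-borrowed input must be cloned.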
        match x {
            Cow::Owned(array) => Ok(array),
            Cow::Borrowed(array) => Ok(array.clone()),
        }
    }

    fn training_mode(&mut self, mode: bool) {
        self.layers
            .iter_mut()
            .for_each(|layer| layer.training_mode(mode));
    }
}

impl<Err> Default for Sequential<Err> {
    fn default() -> Self {
        Self::new()
    }
}

impl<Err> Sequential<Err> {
    /// Creates a new [`Sequential`] module.
    pub fn new() -> Self {
        Self { layers: Vec::new() }
    }

    /// Appends a layer to the sequential module, returning `self` so calls can be chained.
    pub fn append<M>(mut self, layer: M) -> Self
    where
        M: UnaryModule<Error = Err> + std::fmt::Debug + 'static,
    {
        self.layers.push(Box::new(layer));
        self
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        array,
        builder::Builder,
        module::ModuleParameters,
        nn::{self, Linear},
        ops::zeros,
        optimizers::{Optimizer, Sgd},
        random::uniform,
        transforms::{eval, eval_params},
    };

    use crate::losses::{LossReduction, MseLossBuilder};

    use super::*;

    #[test]
    fn test_sequential_linear_param_len() {
        let model = Sequential::new()
            .append(Linear::new(2, 3).unwrap())
            .append(Linear::new(3, 1).unwrap());

        let params = model.parameters().flatten();
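        // Two `Linear` layers, each with a weight and a bias => 4 parameter arrays.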
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_sequential_linear_param_update() {
        let mut model = Sequential::new()
            .append(Linear::new(2, 3).unwrap())
            .append(Linear::new(3, 1).unwrap());

        model
            .trainable_parameters()
            .flatten()
            .iter()
            .for_each(|(key, value)| {
                println!("{}: {:?}", key, value);
            });

        let mut params = model.parameters_mut().flatten();

        // Check that the initial weights are not all zeros
        assert!(
            params["layers.0.weight"]
                .abs()
                .unwrap()
                .sum(None, None)
                .unwrap()
                .item::<f32>()
                > 1e-6
        );

        // Update the weight with zeros
        let shape = params["layers.0.weight"].shape();
        let zeros = zeros::<f32>(shape).unwrap();
        let value_mut = params.get_mut("layers.0.weight").unwrap();
        **value_mut = zeros;

        // Check that the weight is now all zeros
        let first_layer = &model.layers[0];
        let linear_params = first_layer.parameters().flatten();
        let weight = linear_params["weight"];
        assert!(weight.abs().unwrap().sum(None, None).unwrap().item::<f32>() < 1e-6);
    }

    #[test]
    fn test_sgd_update_sequential_linear_params() {
        let lr = 1e-2;
        let input_dim = 2;
        let hidden_dim = 3;
        let output_dim = 2;

        // Test using a simple linear equation
        let m = array!(0.25);
        let b = array!(0.75);

        let mut model = Sequential::new()
            .append(Linear::new(input_dim, hidden_dim).unwrap())
            .append(Linear::new(hidden_dim, output_dim).unwrap());

        let loss = MseLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_fn = |model: &mut Sequential, (x, y): (&Array, &Array)| {
            let y_pred = model.forward(x)?;
            loss.apply(&y_pred, y)
        };

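        // `nn::value_and_grad` wraps the loss closure so each call returns the loss
        // together with the gradients w.r.t. the model's trainable parameters.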
        let mut lg = nn::value_and_grad(loss_fn);

        let mut optimizer = Sgd::new(lr);

        let mut losses = vec![];
        for _ in 0..100 {
            // Generate random data
            let x = uniform::<_, f32>(-5.0, 5.0, &[input_dim], None).unwrap();
            let y = &m * &x + &b;

            eval([&x, &y]).unwrap();

            // Compute the loss and gradients and update the model
            let (loss, grads) = lg(&mut model, (&x, &y)).unwrap();
            optimizer.update(&mut model, grads).unwrap();

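            // mlx evaluates lazily; materialize the updated parameters before the next iteration.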
            eval_params(model.parameters()).unwrap();

            losses.push(loss.item::<f32>());
        }

        // Check that it converges
        assert!(
            losses[0] > losses[losses.len() - 1],
            "Not converging loss: {:?}",
            losses
        );
    }
}