mlx_rs/nn/
container.rs
use std::borrow::Cow;

use crate::module::{Module, UnaryModule};
use crate::{error::Exception, Array};
use mlx_macros::ModuleParameters;

/// Marker trait for items that can be stored in a [`Sequential`] container.
///
/// This is blanket-implemented for every type that is both a [`UnaryModule`]
/// and [`std::fmt::Debug`].
pub trait SequentialModuleItem: UnaryModule + std::fmt::Debug {}

impl<T> SequentialModuleItem for T where T: UnaryModule + std::fmt::Debug {}

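/// A container that runs a sequence of unary modules in order.
///
/// Calling `forward` feeds the input through each layer in turn; with no
/// layers, the input is returned unchanged (as a clone).
///
/// A minimal usage sketch (layer sizes are hypothetical and error handling
/// is elided):
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
/// use mlx_rs::nn::{Linear, Sequential};
///
/// // `x` is any input `Array` with two features.
/// let mut model = Sequential::new()
///     .append(Linear::new(2, 3)?)
///     .append(Linear::new(3, 1)?);
/// let y = model.forward(&x)?;
/// ```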
#[derive(Debug, ModuleParameters)]
#[module(root = crate)]
pub struct Sequential<Err = Exception> {
    /// The layers to call, in order.
    #[param]
    pub layers: Vec<Box<dyn SequentialModuleItem<Error = Err>>>,
}

impl Module<&Array> for Sequential {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        // Start with a borrow; each layer's output becomes the owned input
        // to the next.
        let mut x = Cow::Borrowed(x);

        for layer in &mut self.layers {
            x = Cow::Owned(layer.forward(x.as_ref())?);
        }

        match x {
            Cow::Owned(array) => Ok(array),
            // Still borrowed: there were no layers, so clone the input.
            Cow::Borrowed(array) => Ok(array.clone()),
        }
    }

    fn training_mode(&mut self, mode: bool) {
        // Propagate the training flag to every layer.
        self.layers
            .iter_mut()
            .for_each(|layer| layer.training_mode(mode));
    }
}

impl<Err> Default for Sequential<Err> {
    fn default() -> Self {
        Self::new()
    }
}

impl<Err> Sequential<Err> {
    /// Creates a new [`Sequential`] with no layers.
    pub fn new() -> Self {
        Self { layers: Vec::new() }
    }

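    /// Appends a layer to the sequential module, consuming and returning
    /// `self` so that calls can be chained.
    ///
    /// A sketch of building a variable-depth stack (sizes are hypothetical
    /// and error handling is elided):
    ///
    /// ```rust,ignore
    /// let mut model = Sequential::new();
    /// for (fan_in, fan_out) in [(2, 16), (16, 16), (16, 1)] {
    ///     // `append` moves the container, so reassign the result.
    ///     model = model.append(Linear::new(fan_in, fan_out)?);
    /// }
    /// ```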
    pub fn append<M>(mut self, layer: M) -> Self
    where
        M: UnaryModule<Error = Err> + std::fmt::Debug + 'static,
    {
        self.layers.push(Box::new(layer));
        self
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        array,
        builder::Builder,
        module::ModuleParameters,
        nn::{self, Linear},
        ops::zeros,
        optimizers::{Optimizer, Sgd},
        random::uniform,
        transforms::{eval, eval_params},
    };

    use crate::losses::{LossReduction, MseLossBuilder};

    use super::*;

    #[test]
    fn test_sequential_linear_param_len() {
        let model = Sequential::new()
            .append(Linear::new(2, 3).unwrap())
            .append(Linear::new(3, 1).unwrap());

        // Two `Linear` layers contribute a weight and a bias each.
        let params = model.parameters().flatten();
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_sequential_linear_param_update() {
        let mut model = Sequential::new()
            .append(Linear::new(2, 3).unwrap())
            .append(Linear::new(3, 1).unwrap());

        model
            .trainable_parameters()
            .flatten()
            .iter()
            .for_each(|(key, value)| {
                println!("{}: {:?}", key, value);
            });

        let mut params = model.parameters_mut().flatten();

        // The initialized weight should be non-zero.
        assert!(
            params["layers.0.weight"]
                .abs()
                .unwrap()
                .sum(None, None)
                .unwrap()
                .item::<f32>()
                > 1e-6
        );

        // Overwrite the first layer's weight with zeros through the
        // mutable parameter map.
        let shape = params["layers.0.weight"].shape();
        let zeros = zeros::<f32>(shape).unwrap();
        let value_mut = params.get_mut("layers.0.weight").unwrap();
        **value_mut = zeros;

        // The update should be observable through the layer itself.
        let first_layer = &model.layers[0];
        let linear_params = first_layer.parameters().flatten();
        let weight = linear_params["weight"];
        assert!(weight.abs().unwrap().sum(None, None).unwrap().item::<f32>() < 1e-6);
    }

    #[test]
    fn test_sgd_update_sequential_linear_params() {
        let lr = 1e-2;
        let input_dim = 2;
        let hidden_dim = 3;
        let output_dim = 2;

        // Ground truth: y = m * x + b.
        let m = array!(0.25);
        let b = array!(0.75);

        let mut model = Sequential::new()
            .append(Linear::new(input_dim, hidden_dim).unwrap())
            .append(Linear::new(hidden_dim, output_dim).unwrap());

        let loss = MseLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_fn = |model: &mut Sequential, (x, y): (&Array, &Array)| {
            let y_pred = model.forward(x)?;
            loss.apply(&y_pred, y)
        };

        let mut lg = nn::value_and_grad(loss_fn);

        let mut optimizer = Sgd::new(lr);

        let mut losses = vec![];
        for _ in 0..100 {
            // Sample a random input and compute the target output.
            let x = uniform::<_, f32>(-5.0, 5.0, &[input_dim], None).unwrap();
            let y = &m * &x + &b;

            eval([&x, &y]).unwrap();

            // One SGD step on the MSE loss.
            let (loss, grads) = lg(&mut model, (&x, &y)).unwrap();
            optimizer.update(&mut model, grads).unwrap();

            eval_params(model.parameters()).unwrap();

            losses.push(loss.item::<f32>());
        }

        // Training should reduce the loss.
        assert!(
            losses[0] > losses[losses.len() - 1],
            "Not converging loss: {:?}",
            losses
        );
    }
}