mlx_rs/nn/
value_and_grad.rs

use crate::module::{update_parameters, ModuleParameters};
use crate::transforms::keyed_value_and_grad;
use crate::{error::Exception, Array};

use crate::module::FlattenedModuleParam;

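/// Takes an owned, flattened snapshot of the model's current trainable
/// parameters, keyed by parameter path (e.g. `"weight"`, `"bias"`).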
fn trainable_params(model: &impl ModuleParameters) -> FlattenedModuleParam {
    model
        .trainable_parameters()
        .flatten()
        .into_iter()
        .map(|(k, v)| (k, v.clone()))
        .collect()
}

/// Helper trait for [`value_and_grad`].
///
/// It is implemented for closures of the form `f(model, args)` that return
/// either a `Vec<Array>` or a single `Array`, both directly and wrapped in a
/// `Result<_, Exception>`.
pub trait IntoModuleValueAndGrad<'a, M, Args, Val, Err>
where
    M: ModuleParameters + 'a,
    Args: Clone,
{
    /// Computes the value and gradient of the passed function `f(model, args)` with respect to
    /// the model's trainable parameters.
    fn into_module_value_and_grad(
        self,
    ) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a;
}

impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec<Array>, ()> for F
where
    M: ModuleParameters + 'a,
    F: FnMut(&mut M, Args) -> Vec<Array> + 'a,
    Args: Clone,
{
    fn into_module_value_and_grad(
        mut self,
    ) -> impl FnMut(&mut M, Args) -> Result<(Vec<Array>, FlattenedModuleParam), Exception> + 'a
    {
        move |model, arrays| {
            let trainable_parameters = trainable_params(model);
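            // The inner closure receives candidate parameter values from the
            // transform, installs them in the model, and then evaluates `f`,
            // making the loss a function of the keyed parameters that
            // `keyed_value_and_grad` can differentiate.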
            let inner = |parameters: FlattenedModuleParam, arrays: Args| -> Vec<Array> {
                let flattened_parameters = parameters.into_iter();
                update_parameters(model, flattened_parameters);

                self(model, arrays)
            };
            let mut vg = keyed_value_and_grad(inner);

            let (v, g) = vg(trainable_parameters, arrays)?;
            Ok((v, g))
        }
    }
}

impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec<Array>, Exception> for F
where
    M: ModuleParameters + 'a,
    F: FnMut(&mut M, Args) -> Result<Vec<Array>, Exception> + 'a,
    Args: Clone,
{
    fn into_module_value_and_grad(
        mut self,
    ) -> impl FnMut(&mut M, Args) -> Result<(Vec<Array>, FlattenedModuleParam), Exception> + 'a
    {
        move |model, arrays| {
            let trainable_parameters = trainable_params(model);
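            // Same pattern as the infallible impl above, except that `f`'s
            // `Result` is propagated through the transform.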
            let inner =
                |parameters: FlattenedModuleParam, arrays: Args| -> Result<Vec<Array>, Exception> {
                    let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
                    update_parameters(model, flattened_parameters);

                    self(model, arrays)
                };
            let mut vg = keyed_value_and_grad(inner);

            let (v, g) = vg(trainable_parameters, arrays)?;
            Ok((v, g))
        }
    }
}

impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, ()> for F
where
    M: ModuleParameters + 'a,
    F: FnMut(&mut M, Args) -> Array + 'a,
    Args: Clone,
{
    fn into_module_value_and_grad(
        mut self,
    ) -> impl FnMut(&mut M, Args) -> Result<(Array, FlattenedModuleParam), Exception> + 'a {
        move |model, arrays| {
            let trainable_parameters = trainable_params(model);
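            // `keyed_value_and_grad` differentiates `Vec<Array>`-valued
            // functions, so wrap the scalar loss in a one-element vector here
            // and unwrap it again after the call.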
            let inner = |parameters: FlattenedModuleParam, arrays: Args| -> Vec<Array> {
                let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
                update_parameters(model, flattened_parameters);

                vec![self(model, arrays)]
            };
            let mut vg = keyed_value_and_grad(inner);

            let (v, g) = vg(trainable_parameters, arrays)?;
            let v = v.into_iter().next().expect("Expected a single value");
            Ok((v, g))
        }
    }
}

impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, Exception> for F
where
    M: ModuleParameters + 'a,
    F: FnMut(&mut M, Args) -> Result<Array, Exception> + 'a,
    Args: Clone,
{
    fn into_module_value_and_grad(
        mut self,
    ) -> impl FnMut(&mut M, Args) -> Result<(Array, FlattenedModuleParam), Exception> + 'a {
        move |model, arrays| {
            let trainable_parameters = trainable_params(model);
            let inner =
                |parameters: FlattenedModuleParam, arrays: Args| -> Result<Vec<Array>, Exception> {
                    let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
                    update_parameters(model, flattened_parameters);

                    self(model, arrays).map(|v| vec![v])
                };
            let mut vg = keyed_value_and_grad(inner);

            let (v, g) = vg(trainable_parameters, arrays)?;
            let v = v.into_iter().next().expect("Expected a single value");
            Ok((v, g))
        }
    }
}

/// Transforms the passed function `f(model, args)` into a function that computes both the value
/// of `f` and its gradients with respect to the model's trainable parameters.
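///
/// # Example
///
/// A minimal sketch of the intended usage, mirroring the tests in this module
/// (a `Linear` model with a unary sum-of-outputs loss):
///
/// ```rust,ignore
/// use mlx_rs::error::Exception;
/// use mlx_rs::module::Module;
/// use mlx_rs::{nn, Array};
///
/// let mut model = nn::Linear::new(2, 2)?;
/// let x = mlx_rs::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None)?;
///
/// let loss = |model: &mut nn::Linear, x: &Array| -> Result<Array, Exception> {
///     model.forward(x)?.sum(None, None)
/// };
///
/// let mut vg = nn::value_and_grad(loss);
/// let (value, grads) = vg(&mut model, &x)?;
///
/// // `grads` is a flattened map keyed by parameter path, e.g. `grads["weight"]`.
/// ```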
pub fn value_and_grad<'a, F, M, Args, Val, Err>(
    f: F,
) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a
where
    M: ModuleParameters + 'a,
    F: IntoModuleValueAndGrad<'a, M, Args, Val, Err>,
    Args: Clone,
{
    f.into_module_value_and_grad()
}

#[cfg(test)]
mod tests {
    use crate::module::Module;
    use crate::{array, error::Exception, Array};

    use crate::nn::{self, Linear};

    // The unit test below is adapted from `test_compiled_optimizer` in
    // `mlx/python/tests/test_optimizers.py`
    #[test]
    fn test_value_and_grad() {
        let mut model = Linear::new(2, 2).unwrap();
        let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();

        let loss = |model: &mut Linear, x: &Array| -> Vec<Array> {
            vec![model.forward(x).unwrap().sum(None, None).unwrap()]
        };

        let mut vg = nn::value_and_grad(loss);
        let (v, g) = vg(&mut model, &x).unwrap();

        assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0));
        assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0));
        assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0));
    }

    #[test]
    fn test_value_and_grad_with_unary_output() {
        let mut model = Linear::new(2, 2).unwrap();
        let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();

        let loss = |model: &mut Linear, x: &Array| -> Array {
            model.forward(x).unwrap().sum(None, None).unwrap()
        };

        let mut vg = nn::value_and_grad(loss);
        let (v, g) = vg(&mut model, &x).unwrap();

        assert_ne!(v.sum(None, None).unwrap(), array!(0.0));
        assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0));
        assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0));
    }

    #[test]
    fn test_fallible_module_value_and_grad() {
        let mut model = Linear::new(2, 2).unwrap();
        let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();

        let loss = |model: &mut Linear, x: &Array| -> Result<Vec<Array>, Exception> {
            Ok(vec![model.forward(x)?.sum(None, None)?])
        };

        let mut vg = nn::value_and_grad(loss);
        let (v, g) = vg(&mut model, &x).unwrap();

        assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0));
        assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0));
        assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0));
    }

    #[test]
    fn test_value_and_grad_with_two_args() {
        let mut model = Linear::new(2, 2).unwrap();
        let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();
        let y = crate::ops::ones::<f32>(x.shape()).unwrap();

        let loss =
            |model: &mut Linear, (x, y): (&Array, &Array)| -> Result<Vec<Array>, Exception> {
                model
                    .forward(x)?
                    .subtract(y)?
                    .square()?
                    .sum(None, None)
                    .map(|v| vec![v])
            };

        let mut vg = nn::value_and_grad(loss);
        let (v, g) = vg(&mut model, (&x, &y)).unwrap();

        assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0));
        assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0));
        assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0));
    }

    #[test]
    fn test_value_and_grad_with_error() {
        let mut model = Linear::new(2, 2).unwrap();
        // Use a shape that is not compatible with the model
        let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[3, 3], None).unwrap();

        let loss = |model: &mut Linear, x: &Array| -> Result<Vec<Array>, Exception> {
            Ok(vec![model.forward(x)?.sum(None, None)?])
        };

        let mut vg = nn::value_and_grad(loss);
        let result = vg(&mut model, &x);

        assert!(result.is_err());

        // Check that the error message is not just "mlx_closure returned a non-zero value"
        let err = result.unwrap_err();
        assert!(!err.what().contains("non-zero value"));
    }
}