1use crate::module::{update_parameters, ModuleParameters};
2use crate::transforms::keyed_value_and_grad;
3use crate::{error::Exception, Array};
4
5use crate::module::FlattenedModuleParam;
6
7fn trainable_params(model: &impl ModuleParameters) -> FlattenedModuleParam {
8 model
9 .trainable_parameters()
10 .flatten()
11 .into_iter()
12 .map(|(k, v)| (k, v.clone()))
13 .collect()
14}
15
16pub trait IntoModuleValueAndGrad<'a, M, Args, Val, Err>
18where
19 M: ModuleParameters + 'a,
20 Args: Clone,
21{
22 fn into_module_value_and_grad(
25 self,
26 ) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a;
27}
28
29impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec<Array>, ()> for F
30where
31 M: ModuleParameters + 'a,
32 F: FnMut(&mut M, Args) -> Vec<Array> + 'a,
33 Args: Clone,
34{
35 fn into_module_value_and_grad(
36 mut self,
37 ) -> impl FnMut(&mut M, Args) -> Result<(Vec<Array>, FlattenedModuleParam), Exception> + 'a
38 {
39 move |model, arrays| {
40 let trainable_parameters = trainable_params(model);
41 let inner = |parameters: FlattenedModuleParam, arrays: Args| -> Vec<Array> {
42 let flattened_parameters = parameters.into_iter();
43 update_parameters(model, flattened_parameters);
44
45 self(model, arrays)
46 };
47 let mut vg = keyed_value_and_grad(inner);
48
49 let (v, g) = vg(trainable_parameters, arrays)?;
50 Ok((v, g))
51 }
52 }
53}
54
55impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec<Array>, Exception> for F
56where
57 M: ModuleParameters + 'a,
58 F: FnMut(&mut M, Args) -> Result<Vec<Array>, Exception> + 'a,
59 Args: Clone,
60{
61 fn into_module_value_and_grad(
62 mut self,
63 ) -> impl FnMut(&mut M, Args) -> Result<(Vec<Array>, FlattenedModuleParam), Exception> + 'a
64 {
65 move |model, arrays| {
66 let trainable_parameters = trainable_params(model);
67 let inner =
68 |parameters: FlattenedModuleParam, arrays: Args| -> Result<Vec<Array>, Exception> {
69 let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
70 update_parameters(model, flattened_parameters);
71
72 self(model, arrays)
73 };
74 let mut vg = keyed_value_and_grad(inner);
75
76 let (v, g) = vg(trainable_parameters, arrays)?;
77 Ok((v, g))
78 }
79 }
80}
81
82impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, ()> for F
83where
84 M: ModuleParameters + 'a,
85 F: FnMut(&mut M, Args) -> Array + 'a,
86 Args: Clone,
87{
88 fn into_module_value_and_grad(
89 mut self,
90 ) -> impl FnMut(&mut M, Args) -> Result<(Array, FlattenedModuleParam), Exception> + 'a {
91 move |model, arrays| {
92 let trainable_parameters = trainable_params(model);
93 let inner = |parameters: FlattenedModuleParam, arrays: Args| -> Vec<Array> {
94 let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
95 update_parameters(model, flattened_parameters);
96
97 vec![self(model, arrays)]
98 };
99 let mut vg = keyed_value_and_grad(inner);
100
101 let (v, g) = vg(trainable_parameters, arrays)?;
102 let v = v.into_iter().next().expect("Expected a single value");
103 Ok((v, g))
104 }
105 }
106}
107
108impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, Exception> for F
109where
110 M: ModuleParameters + 'a,
111 F: FnMut(&mut M, Args) -> Result<Array, Exception> + 'a,
112 Args: Clone,
113{
114 fn into_module_value_and_grad(
115 mut self,
116 ) -> impl FnMut(&mut M, Args) -> Result<(Array, FlattenedModuleParam), Exception> + 'a {
117 move |model, arrays| {
118 let trainable_parameters = trainable_params(model);
119 let inner =
120 |parameters: FlattenedModuleParam, arrays: Args| -> Result<Vec<Array>, Exception> {
121 let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
122 update_parameters(model, flattened_parameters);
123
124 self(model, arrays).map(|v| vec![v])
125 };
126 let mut vg = keyed_value_and_grad(inner);
127
128 let (v, g) = vg(trainable_parameters, arrays)?;
129 let v = v.into_iter().next().expect("Expected a single value");
130 Ok((v, g))
131 }
132 }
133}
134
135pub fn value_and_grad<'a, F, M, Args, Val, Err>(
138 f: F,
139) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a
140where
141 M: ModuleParameters + 'a,
142 F: IntoModuleValueAndGrad<'a, M, Args, Val, Err>,
143 Args: Clone,
144{
145 f.into_module_value_and_grad()
146}
147
#[cfg(test)]
mod tests {
    use crate::module::Module;
    use crate::{array, error::Exception, Array};

    use crate::nn::{self, Linear};

    /// Loss function returning `Vec<Array>` directly.
    #[test]
    fn test_value_and_grad() {
        let mut linear = Linear::new(2, 2).unwrap();
        let input = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();

        let loss = |model: &mut Linear, x: &Array| -> Vec<Array> {
            let output = model.forward(x).unwrap();
            vec![output.sum(None, None).unwrap()]
        };

        let mut grad_fn = nn::value_and_grad(loss);
        let (values, grads) = grad_fn(&mut linear, &input).unwrap();

        assert_ne!(values[0].sum(None, None).unwrap(), array!(0.0));
        for name in ["weight", "bias"] {
            assert_ne!(grads[name].sum(None, None).unwrap(), array!(0.0));
        }
    }

    /// Loss function returning a single `Array` directly.
    #[test]
    fn test_value_and_grad_with_unary_output() {
        let mut linear = Linear::new(2, 2).unwrap();
        let input = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();

        let loss = |model: &mut Linear, x: &Array| -> Array {
            let output = model.forward(x).unwrap();
            output.sum(None, None).unwrap()
        };

        let mut grad_fn = nn::value_and_grad(loss);
        let (value, grads) = grad_fn(&mut linear, &input).unwrap();

        assert_ne!(value.sum(None, None).unwrap(), array!(0.0));
        for name in ["weight", "bias"] {
            assert_ne!(grads[name].sum(None, None).unwrap(), array!(0.0));
        }
    }

    /// Loss function returning `Result<Vec<Array>, Exception>`.
    #[test]
    fn test_fallible_module_value_and_grad() {
        let mut linear = Linear::new(2, 2).unwrap();
        let input = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();

        let loss = |model: &mut Linear, x: &Array| -> Result<Vec<Array>, Exception> {
            let output = model.forward(x)?;
            let total = output.sum(None, None)?;
            Ok(vec![total])
        };

        let mut grad_fn = nn::value_and_grad(loss);
        let (values, grads) = grad_fn(&mut linear, &input).unwrap();

        assert_ne!(values[0].sum(None, None).unwrap(), array!(0.0));
        for name in ["weight", "bias"] {
            assert_ne!(grads[name].sum(None, None).unwrap(), array!(0.0));
        }
    }

    /// Loss function whose extra arguments are passed as a tuple.
    #[test]
    fn test_value_and_grad_with_two_args() {
        let mut linear = Linear::new(2, 2).unwrap();
        let input = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();
        let target = crate::ops::ones::<f32>(input.shape()).unwrap();

        let loss =
            |model: &mut Linear, (x, y): (&Array, &Array)| -> Result<Vec<Array>, Exception> {
                let residual = model.forward(x)?.subtract(y)?;
                let total = residual.square()?.sum(None, None)?;
                Ok(vec![total])
            };

        let mut grad_fn = nn::value_and_grad(loss);
        let (values, grads) = grad_fn(&mut linear, (&input, &target)).unwrap();

        assert_ne!(values[0].sum(None, None).unwrap(), array!(0.0));
        for name in ["weight", "bias"] {
            assert_ne!(grads[name].sum(None, None).unwrap(), array!(0.0));
        }
    }

    /// An error raised inside the loss function must surface to the caller.
    #[test]
    fn test_value_and_grad_with_error() {
        let mut linear = Linear::new(2, 2).unwrap();
        // A 3x3 input for a `Linear::new(2, 2)` layer; the test expects the
        // call to fail (presumably because of the shape mismatch).
        let input = crate::random::uniform::<_, f32>(1.0, 2.0, &[3, 3], None).unwrap();

        let loss = |model: &mut Linear, x: &Array| -> Result<Vec<Array>, Exception> {
            let output = model.forward(x)?;
            let total = output.sum(None, None)?;
            Ok(vec![total])
        };

        let mut grad_fn = nn::value_and_grad(loss);
        let outcome = grad_fn(&mut linear, &input);

        assert!(outcome.is_err());

        // The reported message must not contain "non-zero value"; presumably
        // this guards against a generic low-level message hiding the cause.
        let exception = outcome.unwrap_err();
        assert!(!exception.what().contains("non-zero value"))
    }
}