use crate::module::{update_parameters, ModuleParameters};
use crate::transforms::keyed_value_and_grad;
use crate::{error::Exception, Array};
use crate::module::FlattenedModuleParam;
fn trainable_params(model: &impl ModuleParameters) -> FlattenedModuleParam {
model
.trainable_parameters()
.flatten()
.into_iter()
.map(|(k, v)| (k, v.clone()))
.collect()
}
pub trait IntoModuleValueAndGrad<'a, M, Args, Val, Err>
where
M: ModuleParameters + 'a,
Args: Clone,
{
fn into_module_value_and_grad(
self,
) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a;
}
impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec<Array>, ()> for F
where
M: ModuleParameters + 'a,
F: FnMut(&mut M, Args) -> Vec<Array> + 'a,
Args: Clone,
{
fn into_module_value_and_grad(
mut self,
) -> impl FnMut(&mut M, Args) -> Result<(Vec<Array>, FlattenedModuleParam), Exception> + 'a
{
move |model, arrays| {
let trainable_parameters = trainable_params(model);
let inner = |parameters: FlattenedModuleParam, arrays: Args| -> Vec<Array> {
let flattened_parameters = parameters.into_iter();
update_parameters(model, flattened_parameters);
self(model, arrays)
};
let mut vg = keyed_value_and_grad(inner);
let (v, g) = vg(trainable_parameters, arrays)?;
Ok((v, g))
}
}
}
impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec<Array>, Exception> for F
where
M: ModuleParameters + 'a,
F: FnMut(&mut M, Args) -> Result<Vec<Array>, Exception> + 'a,
Args: Clone,
{
fn into_module_value_and_grad(
mut self,
) -> impl FnMut(&mut M, Args) -> Result<(Vec<Array>, FlattenedModuleParam), Exception> + 'a
{
move |model, arrays| {
let trainable_parameters = trainable_params(model);
let inner =
|parameters: FlattenedModuleParam, arrays: Args| -> Result<Vec<Array>, Exception> {
let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
update_parameters(model, flattened_parameters);
self(model, arrays)
};
let mut vg = keyed_value_and_grad(inner);
let (v, g) = vg(trainable_parameters, arrays)?;
Ok((v, g))
}
}
}
impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, ()> for F
where
M: ModuleParameters + 'a,
F: FnMut(&mut M, Args) -> Array + 'a,
Args: Clone,
{
fn into_module_value_and_grad(
mut self,
) -> impl FnMut(&mut M, Args) -> Result<(Array, FlattenedModuleParam), Exception> + 'a {
move |model, arrays| {
let trainable_parameters = trainable_params(model);
let inner = |parameters: FlattenedModuleParam, arrays: Args| -> Vec<Array> {
let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
update_parameters(model, flattened_parameters);
vec![self(model, arrays)]
};
let mut vg = keyed_value_and_grad(inner);
let (v, g) = vg(trainable_parameters, arrays)?;
let v = v.into_iter().next().expect("Expected a single value");
Ok((v, g))
}
}
}
impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, Exception> for F
where
M: ModuleParameters + 'a,
F: FnMut(&mut M, Args) -> Result<Array, Exception> + 'a,
Args: Clone,
{
fn into_module_value_and_grad(
mut self,
) -> impl FnMut(&mut M, Args) -> Result<(Array, FlattenedModuleParam), Exception> + 'a {
move |model, arrays| {
let trainable_parameters = trainable_params(model);
let inner =
|parameters: FlattenedModuleParam, arrays: Args| -> Result<Vec<Array>, Exception> {
let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone()));
update_parameters(model, flattened_parameters);
self(model, arrays).map(|v| vec![v])
};
let mut vg = keyed_value_and_grad(inner);
let (v, g) = vg(trainable_parameters, arrays)?;
let v = v.into_iter().next().expect("Expected a single value");
Ok((v, g))
}
}
}
pub fn value_and_grad<'a, F, M, Args, Val, Err>(
f: F,
) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a
where
M: ModuleParameters + 'a,
F: IntoModuleValueAndGrad<'a, M, Args, Val, Err>,
Args: Clone,
{
f.into_module_value_and_grad()
}
#[cfg(test)]
mod tests {
use crate::module::Module;
use crate::{array, error::Exception, Array};
use crate::nn::{self, Linear};
#[test]
fn test_value_and_grad() {
let mut model = Linear::new(2, 2).unwrap();
let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();
let loss = |model: &mut Linear, x: &Array| -> Vec<Array> {
vec![model.forward(x).unwrap().sum(None, None).unwrap()]
};
let mut vg = nn::value_and_grad(loss);
let (v, g) = vg(&mut model, &x).unwrap();
assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0));
assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0));
assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0));
}
#[test]
fn test_value_and_grad_with_unary_output() {
let mut model = Linear::new(2, 2).unwrap();
let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();
let loss = |model: &mut Linear, x: &Array| -> Array {
model.forward(x).unwrap().sum(None, None).unwrap()
};
let mut vg = nn::value_and_grad(loss);
let (v, g) = vg(&mut model, &x).unwrap();
assert_ne!(v.sum(None, None).unwrap(), array!(0.0));
assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0));
assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0));
}
#[test]
fn test_fallible_module_value_and_grad() {
let mut model = Linear::new(2, 2).unwrap();
let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();
let loss = |model: &mut Linear, x: &Array| -> Result<Vec<Array>, Exception> {
Ok(vec![model.forward(x)?.sum(None, None)?])
};
let mut vg = nn::value_and_grad(loss);
let (v, g) = vg(&mut model, &x).unwrap();
assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0));
assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0));
assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0));
}
#[test]
fn test_value_and_grad_with_two_args() {
let mut model = Linear::new(2, 2).unwrap();
let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap();
let y = crate::ops::ones::<f32>(x.shape()).unwrap();
let loss =
|model: &mut Linear, (x, y): (&Array, &Array)| -> Result<Vec<Array>, Exception> {
model
.forward(x)?
.subtract(y)?
.square()?
.sum(None, None)
.map(|v| vec![v])
};
let mut vg = nn::value_and_grad(loss);
let (v, g) = vg(&mut model, (&x, &y)).unwrap();
assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0));
assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0));
assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0));
}
#[test]
fn test_value_and_grad_with_error() {
let mut model = Linear::new(2, 2).unwrap();
let x = crate::random::uniform::<_, f32>(1.0, 2.0, &[3, 3], None).unwrap();
let loss = |model: &mut Linear, x: &Array| -> Result<Vec<Array>, Exception> {
Ok(vec![model.forward(x)?.sum(None, None)?])
};
let mut vg = nn::value_and_grad(loss);
let result = vg(&mut model, &x);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(!err.what().contains("non-zero value"))
}
}