use std::{borrow::Borrow, collections::HashMap, hash::Hash, path::Path, rc::Rc};
use crate::{
error::{Exception, IoError},
nested::{NestedHashMap, NestedValue},
Array,
};
/// Nested map of owned parameter arrays, keyed by `Rc<str>`.
pub type ModuleParam = NestedHashMap<Rc<str>, Array>;
/// Nested map of shared references to parameter arrays, keyed by `Rc<str>`.
pub type ModuleParamRef<'a> = NestedHashMap<Rc<str>, &'a Array>;
/// Nested map of mutable references to parameter arrays, keyed by `Rc<str>`.
pub type ModuleParamMut<'a> = NestedHashMap<Rc<str>, &'a mut Array>;
/// Flat (single-level) map of owned parameter arrays, keyed by `Rc<str>`.
pub type FlattenedModuleParam = HashMap<Rc<str>, Array>;
/// Flat (single-level) map of shared references to parameter arrays, keyed by `Rc<str>`.
pub type FlattenedModuleParamRef<'a> = HashMap<Rc<str>, &'a Array>;
/// Flat (single-level) map of mutable references to parameter arrays, keyed by `Rc<str>`.
pub type FlattenedModuleParamMut<'a> = HashMap<Rc<str>, &'a mut Array>;
/// A module that maps an `Input` to an output and exposes its parameters.
///
/// Implementors must also implement [`ModuleParameters`] and [`std::fmt::Debug`].
pub trait Module<Input>: ModuleParameters + std::fmt::Debug {
/// Value produced by a successful [`Module::forward`] call.
type Output;
/// Error returned when [`Module::forward`] fails.
type Error: std::error::Error;
/// Runs the module on `input`.
///
/// Takes `&mut self`, so implementations are allowed to mutate internal state.
///
/// # Errors
///
/// Returns `Self::Error` when the implementation fails; the exact conditions are
/// implementation-defined.
fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;
/// Sets the module's training-mode flag to `mode`.
///
/// NOTE(review): presumably `true` selects training behavior and `false` selects
/// evaluation behavior — confirm against implementors.
fn training_mode(&mut self, mode: bool);
}
/// Marker trait for modules that accept an `&Array` of any lifetime as input and produce an
/// owned `Array` as output.
pub trait UnaryModule: for<'a> Module<&'a Array, Output = Array> {}
// Blanket impl: every type that satisfies the bound is automatically a `UnaryModule`.
impl<T> UnaryModule for T where T: for<'a> Module<&'a Array, Output = Array> {}
/// Access to, and bulk manipulation of, the parameter arrays held by a module.
pub trait ModuleParameters {
    /// Returns shared references to the module's parameters as a nested map.
    fn parameters(&self) -> ModuleParamRef<'_>;

    /// Returns mutable references to the module's parameters as a nested map.
    fn parameters_mut(&mut self) -> ModuleParamMut<'_>;

    /// Returns shared references to the module's trainable parameters as a nested map.
    ///
    /// NOTE(review): presumably this is the subset of parameters that are not frozen —
    /// confirm against implementors.
    fn trainable_parameters(&self) -> ModuleParamRef<'_>;

    /// Overwrites the module's parameters with the values in the nested map `parameters`.
    ///
    /// The map is flattened and handed to [`update_parameters`], which replaces only the
    /// parameters whose key already exists in this module; other entries are skipped.
    fn update(&mut self, parameters: ModuleParam) {
        update_parameters(self, parameters.flatten());
    }

    /// Overwrites the module's parameters with the values in the already-flattened map
    /// `flattened_parameters`.
    ///
    /// Delegates to [`update_parameters`]; entries whose key does not match an existing
    /// parameter are skipped.
    fn update_flattened(&mut self, flattened_parameters: FlattenedModuleParam) {
        update_parameters(self, flattened_parameters);
    }

    /// Freezes the module's parameters.
    ///
    /// NOTE(review): `recursive` presumably controls whether nested submodules are frozen
    /// as well — confirm against implementors.
    fn freeze_parameters(&mut self, recursive: bool);

    /// Unfreezes the module's parameters.
    ///
    /// NOTE(review): `recursive` presumably controls whether nested submodules are unfrozen
    /// as well — confirm against implementors.
    fn unfreeze_parameters(&mut self, recursive: bool);

    /// Reports whether all parameters are frozen.
    ///
    /// NOTE(review): `None` presumably means there is nothing to report on (e.g. no
    /// parameters) — confirm against implementors.
    fn all_frozen(&self) -> Option<bool>;

    /// Reports whether any parameter is frozen.
    ///
    /// NOTE(review): `None` presumably means there is nothing to report on (e.g. no
    /// parameters) — confirm against implementors.
    fn any_frozen(&self) -> Option<bool>;
}
/// Replaces parameters of `module` with the arrays supplied in `parameters`.
///
/// The module's parameters are obtained mutably and flattened into a single-level map.
/// Each `(key, array)` pair from `parameters` whose key is present in that map overwrites
/// the corresponding parameter; pairs with an unknown key are skipped without error.
///
/// `Q` may be any hashable key type that `Rc<str>` can be borrowed as.
pub fn update_parameters<M, I, Q>(module: &mut M, parameters: I)
where
    M: ModuleParameters + ?Sized,
    I: IntoIterator<Item = (Q, Array)>,
    Q: Hash + Eq,
    Rc<str>: Borrow<Q>,
{
    let mut targets = module.parameters_mut().flatten();
    parameters.into_iter().for_each(|(name, replacement)| {
        // Each slot is a `&mut &mut Array`; the double deref writes into the module's array.
        if let Some(slot) = targets.get_mut(&name) {
            **slot = replacement;
        }
    });
}
/// Forwards every [`ModuleParameters`] method through a mutable reference to the
/// referenced value, so `&mut T` can be used wherever a `ModuleParameters` is expected.
impl<T> ModuleParameters for &'_ mut T
where
    T: ModuleParameters + ?Sized,
{
    /// Delegates to `T::parameters` on the referenced value.
    fn parameters(&self) -> ModuleParamRef<'_> {
        // `self` is `&&mut T`; deref coercion turns it into the `&T` the callee expects.
        T::parameters(self)
    }

    /// Delegates to `T::parameters_mut` on the referenced value.
    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
        // `self` is `&mut &mut T`; deref coercion reborrows it as `&mut T`.
        T::parameters_mut(self)
    }

    /// Delegates to `T::trainable_parameters` on the referenced value.
    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
        T::trainable_parameters(self)
    }

    /// Delegates to `T::freeze_parameters` on the referenced value.
    fn freeze_parameters(&mut self, recursive: bool) {
        T::freeze_parameters(self, recursive);
    }

    /// Delegates to `T::unfreeze_parameters` on the referenced value.
    fn unfreeze_parameters(&mut self, recursive: bool) {
        T::unfreeze_parameters(self, recursive);
    }

    /// Delegates to `T::all_frozen` on the referenced value.
    fn all_frozen(&self) -> Option<bool> {
        T::all_frozen(self)
    }

    /// Delegates to `T::any_frozen` on the referenced value.
    fn any_frozen(&self) -> Option<bool> {
        T::any_frozen(self)
    }
}
/// Forwards every [`ModuleParameters`] method to the boxed value, so `Box<T>` (including
/// boxed unsized types) can be used wherever a `ModuleParameters` is expected.
impl<T> ModuleParameters for Box<T>
where
    T: ModuleParameters + ?Sized,
{
    /// Delegates to the boxed value's `parameters`.
    fn parameters(&self) -> ModuleParamRef<'_> {
        (**self).parameters()
    }

    /// Delegates to the boxed value's `parameters_mut`.
    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
        (**self).parameters_mut()
    }

    /// Delegates to the boxed value's `trainable_parameters`.
    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
        (**self).trainable_parameters()
    }

    /// Delegates to the boxed value's `freeze_parameters`.
    fn freeze_parameters(&mut self, recursive: bool) {
        (**self).freeze_parameters(recursive);
    }

    /// Delegates to the boxed value's `unfreeze_parameters`.
    fn unfreeze_parameters(&mut self, recursive: bool) {
        (**self).unfreeze_parameters(recursive);
    }

    /// Delegates to the boxed value's `all_frozen`.
    fn all_frozen(&self) -> Option<bool> {
        (**self).all_frozen()
    }

    /// Delegates to the boxed value's `any_frozen`.
    fn any_frozen(&self) -> Option<bool> {
        (**self).any_frozen()
    }
}
/// Treats a vector of modules as one module whose parameters are grouped per element.
///
/// In the nested maps returned here, each element's parameters are stored under a key that
/// is the element's index formatted as a decimal string (`"0"`, `"1"`, ...).
impl<T> ModuleParameters for Vec<T>
where
    T: ModuleParameters,
{
    /// Collects every element's parameters into a nested map keyed by element index.
    fn parameters(&self) -> ModuleParamRef<'_> {
        let mut nested = NestedHashMap::new();
        for (index, child) in self.iter().enumerate() {
            let child_entries = child.parameters().entries;
            nested.insert(Rc::from(index.to_string()), NestedValue::Map(child_entries));
        }
        nested
    }

    /// Collects mutable references to every element's parameters into a nested map keyed
    /// by element index.
    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
        let mut nested = NestedHashMap::new();
        for (index, child) in self.iter_mut().enumerate() {
            let child_entries = child.parameters_mut().entries;
            nested.insert(Rc::from(index.to_string()), NestedValue::Map(child_entries));
        }
        nested
    }

    /// Collects every element's trainable parameters into a nested map keyed by element
    /// index.
    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
        let mut nested = NestedHashMap::new();
        for (index, child) in self.iter().enumerate() {
            let child_entries = child.trainable_parameters().entries;
            nested.insert(Rc::from(index.to_string()), NestedValue::Map(child_entries));
        }
        nested
    }

    /// Calls `freeze_parameters(recursive)` on every element, in order.
    fn freeze_parameters(&mut self, recursive: bool) {
        for child in self.iter_mut() {
            child.freeze_parameters(recursive);
        }
    }

    /// Calls `unfreeze_parameters(recursive)` on every element, in order.
    fn unfreeze_parameters(&mut self, recursive: bool) {
        for child in self.iter_mut() {
            child.unfreeze_parameters(recursive);
        }
    }

    /// Combines the elements' `all_frozen` answers.
    ///
    /// Returns `Some(false)` as soon as any element reports `Some(false)`, `Some(true)` if
    /// at least one element reports `Some(true)` and none reports `Some(false)`, and `None`
    /// if the vector is empty or every element reports `None`.
    fn all_frozen(&self) -> Option<bool> {
        let mut verdict = None;
        // Elements answering `None` carry no information and are skipped.
        for frozen in self.iter().filter_map(|child| child.all_frozen()) {
            if !frozen {
                return Some(false);
            }
            verdict = Some(true);
        }
        verdict
    }

    /// Combines the elements' `any_frozen` answers.
    ///
    /// Returns `Some(true)` as soon as any element reports `Some(true)`, `Some(false)` if
    /// at least one element reports `Some(false)` and none reports `Some(true)`, and `None`
    /// if the vector is empty or every element reports `None`.
    fn any_frozen(&self) -> Option<bool> {
        let mut verdict = None;
        // Elements answering `None` carry no information and are skipped.
        for frozen in self.iter().filter_map(|child| child.any_frozen()) {
            if frozen {
                return Some(true);
            }
            verdict = Some(false);
        }
        verdict
    }
}
/// Convenience methods layered on top of [`ModuleParameters`].
pub trait ModuleParametersExt: ModuleParameters {
    /// Passes all of the module's parameters to `crate::transforms::eval_params`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Exception`] `eval_params` reports.
    fn eval(&self) -> Result<(), Exception> {
        crate::transforms::eval_params(self.parameters())
    }

    /// Loads arrays from the safetensors file at `path` into this module's parameters.
    ///
    /// Every loaded entry whose key matches one of the module's flattened parameter keys
    /// overwrites that parameter; loaded entries with no matching key are skipped. After
    /// the assignment, [`ModuleParametersExt::eval`] is called on the module.
    ///
    /// # Errors
    ///
    /// Returns an [`IoError`] if `Array::load_safetensors` fails, or if the final `eval`
    /// fails (its error is converted into `IoError`).
    fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let tensors = Array::load_safetensors(path)?;
        {
            // Scope the mutable view so it is clearly released before `eval` borrows `self`.
            let mut slots = self.parameters_mut().flatten();
            tensors.into_iter().for_each(|(name, tensor)| {
                if let Some(slot) = slots.get_mut(&*name) {
                    **slot = tensor;
                }
            });
        }
        self.eval()?;
        Ok(())
    }

    /// Writes the module's flattened parameters to a safetensors file at `path`.
    ///
    /// `None` is passed as the second argument of `Array::save_safetensors`.
    /// NOTE(review): that argument is presumably optional metadata — confirm against
    /// `Array::save_safetensors`.
    ///
    /// # Errors
    ///
    /// Returns an [`IoError`] if `Array::save_safetensors` fails.
    fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
        Array::save_safetensors(self.parameters().flatten(), None, path)?;
        Ok(())
    }
}
// Blanket impl: every sized type implementing `ModuleParameters` gets the
// `ModuleParametersExt` methods automatically.
impl<T: ModuleParameters> ModuleParametersExt for T {}