use std::{
collections::HashMap,
ops::{Deref, DerefMut},
rc::Rc,
};
use crate::{nested::NestedValue, Array};
use super::ModuleParameters;
/// Something that holds arrays and carries a frozen/unfrozen state.
///
/// Implemented in this module for `Param<Array>`, `Param<Option<Array>>`, and
/// every type implementing `ModuleParameters`.
pub trait Parameter {
/// Marks the parameter as frozen.
///
/// `recursive` is forwarded to `ModuleParameters::freeze_parameters` by the
/// module implementation; the `Param` leaf implementations ignore it.
fn freeze(&mut self, recursive: bool);
/// Clears the frozen mark; `recursive` is handled as in [`Parameter::freeze`].
fn unfreeze(&mut self, recursive: bool);
/// Reports the frozen state.
///
/// The `Param` implementations always return `Some`; the module implementation
/// returns whatever `ModuleParameters::all_frozen` reports.
/// NOTE(review): meaning of `None` is defined by `all_frozen` — confirm there.
fn is_frozen(&self) -> Option<bool>;
/// Borrows the contained array(s) as a `NestedValue` tree keyed by `Rc<str>`.
fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array>;
/// Mutable counterpart of [`Parameter::as_nested_value`].
fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array>;
/// Borrows only the trainable array(s).
///
/// The `Param` leaf implementations return `None` when frozen (and, for
/// `Param<Option<Array>>`, also when no array is present).
fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>>;
}
/// A value paired with a "frozen" flag.
///
/// `Param<T>` dereferences (both shared and mutable) to the wrapped `T`, so
/// methods of `T` can be called on it directly. A freshly constructed `Param`
/// always starts out unfrozen.
#[derive(Clone, Debug)]
pub struct Param<T> {
    /// The wrapped value.
    pub value: T,
    // Private: code outside this module cannot set the flag directly.
    is_frozen: bool,
}

impl<T> Param<T> {
    /// Wraps `value` in a new, unfrozen `Param`.
    pub fn new(value: T) -> Self {
        Param {
            is_frozen: false,
            value,
        }
    }
}

/// `value.into()` is shorthand for [`Param::new`].
impl<T> From<T> for Param<T> {
    fn from(value: T) -> Self {
        Param::new(value)
    }
}

impl<T> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.value
    }
}

impl<T> DerefMut for Param<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.value
    }
}

impl<T> AsRef<T> for Param<T> {
    fn as_ref(&self) -> &T {
        // Deref coercion: `&Param<T>` becomes `&T` through the `Deref` impl above.
        self
    }
}

impl<T> AsMut<T> for Param<T> {
    fn as_mut(&mut self) -> &mut T {
        // Deref coercion: `&mut Param<T>` becomes `&mut T` through `DerefMut`.
        self
    }
}
/// A leaf parameter holding a single, always-present [`Array`].
///
/// Freezing only toggles this parameter's own flag. The `recursive` argument is
/// accepted for trait compatibility, but a leaf has nothing to recurse into.
impl Parameter for Param<Array> {
    fn freeze(&mut self, _recursive: bool) {
        self.is_frozen = true;
    }

    fn unfreeze(&mut self, _recursive: bool) {
        self.is_frozen = false;
    }

    /// Always `Some`: a leaf has a definite frozen state.
    fn is_frozen(&self) -> Option<bool> {
        Some(self.is_frozen)
    }

    /// Exposes the wrapped array as a single `NestedValue::Value`.
    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
        NestedValue::Value(&self.value)
    }

    /// Mutable counterpart of `as_nested_value`.
    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
        NestedValue::Value(&mut self.value)
    }

    /// Returns the array when the parameter is trainable, or `None` when frozen.
    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
        if self.is_frozen {
            None
        } else {
            Some(NestedValue::Value(&self.value))
        }
    }
}
/// An optional leaf parameter: an [`Array`] that may be absent.
///
/// When no array is present, the nested views expose an empty
/// `NestedValue::Map`, and the trainable view is `None`. The `recursive`
/// argument to freeze/unfreeze is ignored because a leaf has no children.
impl Parameter for Param<Option<Array>> {
    fn freeze(&mut self, _recursive: bool) {
        self.is_frozen = true;
    }

    fn unfreeze(&mut self, _recursive: bool) {
        self.is_frozen = false;
    }

    /// Always `Some`, whether or not an array is present.
    fn is_frozen(&self) -> Option<bool> {
        Some(self.is_frozen)
    }

    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
        if let Some(array) = &self.value {
            NestedValue::Value(array)
        } else {
            NestedValue::Map(HashMap::new())
        }
    }

    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
        match self.value.as_mut() {
            Some(array) => NestedValue::Value(array),
            None => NestedValue::Map(HashMap::new()),
        }
    }

    /// `None` when frozen or when no array is present; otherwise the array.
    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
        if self.is_frozen {
            return None;
        }
        self.value.as_ref().map(NestedValue::Value)
    }
}
/// Every [`ModuleParameters`] type is also a [`Parameter`].
///
/// Each trait method forwards to the matching `ModuleParameters` method, so a
/// module can be nested as a parameter inside another module.
impl<T> Parameter for T
where
    T: ModuleParameters,
{
    fn freeze(&mut self, recursive: bool) {
        T::freeze_parameters(self, recursive);
    }

    fn unfreeze(&mut self, recursive: bool) {
        T::unfreeze_parameters(self, recursive);
    }

    fn is_frozen(&self) -> Option<bool> {
        T::all_frozen(self)
    }

    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
        let params = T::parameters(self);
        params.into()
    }

    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
        let params = T::parameters_mut(self);
        params.into()
    }

    /// Always `Some`, wrapping whatever `trainable_parameters` returns.
    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
        let trainable = T::trainable_parameters(self);
        Some(trainable.into())
    }
}