1use std::{
2 collections::HashMap,
3 ops::{Deref, DerefMut},
4 rc::Rc,
5};
6
7use crate::{nested::NestedValue, Array};
8
9use super::ModuleParameters;
10
11pub trait Parameter {
13 fn freeze(&mut self, recursive: bool);
15
16 fn unfreeze(&mut self, recursive: bool);
18
19 fn is_frozen(&self) -> Option<bool>;
22
23 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array>;
25
26 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array>;
28
29 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>>;
31}
32
/// A value of type `T` paired with a frozen flag.
///
/// New instances start unfrozen; the flag is private and is toggled through the
/// `Parameter` trait.
#[derive(Clone, Debug)]
pub struct Param<T> {
    /// The wrapped value.
    pub value: T,

    /// Whether this parameter is frozen. The leaf `Parameter` impls consult it,
    /// e.g. to hide frozen values from `as_trainable_nested_value`.
    is_frozen: bool,
}
44
45impl<T> Param<T> {
46 pub fn new(value: T) -> Self {
48 Self {
49 value,
50 is_frozen: false,
51 }
52 }
53}
54
55impl<T> From<T> for Param<T> {
56 fn from(inner: T) -> Self {
57 Self::new(inner)
58 }
59}
60
61impl<T> Deref for Param<T> {
62 type Target = T;
63
64 fn deref(&self) -> &Self::Target {
65 &self.value
66 }
67}
68
69impl<T> DerefMut for Param<T> {
70 fn deref_mut(&mut self) -> &mut Self::Target {
71 &mut self.value
72 }
73}
74
75impl<T> AsRef<T> for Param<T> {
76 fn as_ref(&self) -> &T {
77 &self.value
78 }
79}
80
81impl<T> AsMut<T> for Param<T> {
82 fn as_mut(&mut self) -> &mut T {
83 &mut self.value
84 }
85}
86
87impl Parameter for Param<Array> {
88 fn freeze(&mut self, _recursive: bool) {
89 self.is_frozen = true;
90 }
91
92 fn unfreeze(&mut self, _recursive: bool) {
93 self.is_frozen = false;
94 }
95
96 fn is_frozen(&self) -> Option<bool> {
97 Some(self.is_frozen)
98 }
99
100 fn as_nested_value<'a>(&self) -> NestedValue<Rc<str>, &Array> {
101 NestedValue::Value(&self.value)
102 }
103
104 fn as_nested_value_mut<'a>(&mut self) -> NestedValue<Rc<str>, &mut Array> {
105 NestedValue::Value(&mut self.value)
106 }
107
108 fn as_trainable_nested_value<'a>(&self) -> Option<NestedValue<Rc<str>, &Array>> {
109 match self.is_frozen {
110 true => None,
111 false => Some(NestedValue::Value(&self.value)),
112 }
113 }
114}
115
116impl Parameter for Param<Option<Array>> {
117 fn freeze(&mut self, _recursive: bool) {
118 self.is_frozen = true;
119 }
120
121 fn unfreeze(&mut self, _recursive: bool) {
122 self.is_frozen = false;
123 }
124
125 fn is_frozen(&self) -> Option<bool> {
126 Some(self.is_frozen)
127 }
128
129 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
130 match &self.value {
131 Some(array) => NestedValue::Value(array),
132 None => NestedValue::Map(HashMap::with_capacity(0)),
134 }
135 }
136
137 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
138 match &mut self.value {
139 Some(array) => NestedValue::Value(array),
140 None => NestedValue::Map(HashMap::with_capacity(0)),
142 }
143 }
144
145 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
146 match self.is_frozen {
147 true => None,
148 false => self.value.as_ref().map(NestedValue::Value),
149 }
150 }
151}
152
153impl<T> Parameter for T
154where
155 T: ModuleParameters,
156{
157 fn freeze(&mut self, recursive: bool) {
158 self.freeze_parameters(recursive);
159 }
160
161 fn unfreeze(&mut self, recursive: bool) {
162 self.unfreeze_parameters(recursive);
163 }
164
165 fn is_frozen(&self) -> Option<bool> {
166 self.all_frozen()
167 }
168
169 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
170 self.parameters().into()
171 }
172
173 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
174 self.parameters_mut().into()
175 }
176
177 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
178 Some(self.trainable_parameters().into())
179 }
180}