1use std::{
2 collections::HashMap,
3 ops::{Deref, DerefMut},
4 rc::Rc,
5};
6
7use crate::{nested::NestedValue, Array};
8
9use super::ModuleParameters;
10
/// Common interface for anything that holds (or may hold) module parameters.
///
/// Implementations include `Param<Array>`, `Param<Option<Array>>`, `Option<M>`
/// for `M: ModuleParameters`, and every `T: ModuleParameters`.
pub trait Parameter {
    /// Returns how many parameters this value contributes.
    fn count(&self) -> usize;

    /// Marks this value as frozen.
    ///
    /// Implementations that wrap nested parameters forward `recursive` to them;
    /// single-array `Param` implementations ignore it.
    fn freeze(&mut self, recursive: bool);

    /// Marks this value as not frozen; `recursive` is treated as in [`Parameter::freeze`].
    fn unfreeze(&mut self, recursive: bool);

    /// Reports whether this value is frozen.
    ///
    /// `None` means no frozen state is available (for example, an absent `Option<M>`).
    fn is_frozen(&self) -> Option<bool>;

    /// Borrows the held parameters as a [`NestedValue`] tree.
    fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array>;

    /// Mutably borrows the held parameters as a [`NestedValue`] tree.
    fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array>;

    /// Borrows only the trainable parameters.
    ///
    /// Returns `None` when nothing trainable is exposed (for example, a frozen `Param`).
    fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>>;
}
35
/// A value paired with a "frozen" flag.
///
/// `Param` dereferences to the wrapped value, so it can usually be used where a
/// `T` is expected. The flag is toggled through the [`Parameter`] implementations.
#[derive(Debug, Clone)]
pub struct Param<T> {
    /// The wrapped value.
    pub value: T,

    /// Whether this parameter is frozen. While `true`, the `Param` implementations
    /// return `None` from `as_trainable_nested_value`.
    is_frozen: bool,
}
47
48impl<T> Param<T> {
49 pub fn new(value: T) -> Self {
51 Self {
52 value,
53 is_frozen: false,
54 }
55 }
56}
57
58impl<T> From<T> for Param<T> {
59 fn from(inner: T) -> Self {
60 Self::new(inner)
61 }
62}
63
64impl<T> Deref for Param<T> {
65 type Target = T;
66
67 fn deref(&self) -> &Self::Target {
68 &self.value
69 }
70}
71
72impl<T> DerefMut for Param<T> {
73 fn deref_mut(&mut self) -> &mut Self::Target {
74 &mut self.value
75 }
76}
77
78impl<T> AsRef<T> for Param<T> {
79 fn as_ref(&self) -> &T {
80 &self.value
81 }
82}
83
84impl<T> AsMut<T> for Param<T> {
85 fn as_mut(&mut self) -> &mut T {
86 &mut self.value
87 }
88}
89
90impl Parameter for Param<Array> {
91 fn count(&self) -> usize {
92 1
93 }
94
95 fn freeze(&mut self, _recursive: bool) {
96 self.is_frozen = true;
97 }
98
99 fn unfreeze(&mut self, _recursive: bool) {
100 self.is_frozen = false;
101 }
102
103 fn is_frozen(&self) -> Option<bool> {
104 Some(self.is_frozen)
105 }
106
107 fn as_nested_value<'a>(&self) -> NestedValue<Rc<str>, &Array> {
108 NestedValue::Value(&self.value)
109 }
110
111 fn as_nested_value_mut<'a>(&mut self) -> NestedValue<Rc<str>, &mut Array> {
112 NestedValue::Value(&mut self.value)
113 }
114
115 fn as_trainable_nested_value<'a>(&self) -> Option<NestedValue<Rc<str>, &Array>> {
116 match self.is_frozen {
117 true => None,
118 false => Some(NestedValue::Value(&self.value)),
119 }
120 }
121}
122
123impl Parameter for Param<Option<Array>> {
124 fn count(&self) -> usize {
125 self.value.as_ref().map_or(0, |_| 1)
126 }
127
128 fn freeze(&mut self, _recursive: bool) {
129 self.is_frozen = true;
130 }
131
132 fn unfreeze(&mut self, _recursive: bool) {
133 self.is_frozen = false;
134 }
135
136 fn is_frozen(&self) -> Option<bool> {
137 Some(self.is_frozen)
138 }
139
140 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
141 match &self.value {
142 Some(array) => NestedValue::Value(array),
143 None => NestedValue::Map(HashMap::with_capacity(0)),
145 }
146 }
147
148 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
149 match &mut self.value {
150 Some(array) => NestedValue::Value(array),
151 None => NestedValue::Map(HashMap::with_capacity(0)),
153 }
154 }
155
156 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
157 match self.is_frozen {
158 true => None,
159 false => self.value.as_ref().map(NestedValue::Value),
160 }
161 }
162}
163
164impl<M> Parameter for Option<M>
165where
166 M: ModuleParameters,
167{
168 fn count(&self) -> usize {
169 self.as_ref().map_or(0, |m| m.count())
170 }
171
172 fn freeze(&mut self, recursive: bool) {
173 if let Some(m) = self.as_mut() {
174 m.freeze(recursive);
175 }
176 }
177
178 fn unfreeze(&mut self, recursive: bool) {
179 if let Some(m) = self.as_mut() {
180 m.unfreeze(recursive);
181 }
182 }
183
184 fn is_frozen(&self) -> Option<bool> {
185 self.as_ref().and_then(|m| m.is_frozen())
186 }
187
188 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
189 match self {
190 Some(m) => m.as_nested_value(),
191 None => NestedValue::Map(HashMap::with_capacity(0)),
192 }
193 }
194
195 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
196 match self {
197 Some(m) => m.as_nested_value_mut(),
198 None => NestedValue::Map(HashMap::with_capacity(0)),
199 }
200 }
201
202 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
203 match self {
204 Some(m) => m.as_trainable_nested_value(),
205 None => None,
206 }
207 }
208}
209
210impl<T> Parameter for T
211where
212 T: ModuleParameters,
213{
214 fn count(&self) -> usize {
215 self.num_parameters()
216 }
217
218 fn freeze(&mut self, recursive: bool) {
219 self.freeze_parameters(recursive);
220 }
221
222 fn unfreeze(&mut self, recursive: bool) {
223 self.unfreeze_parameters(recursive);
224 }
225
226 fn is_frozen(&self) -> Option<bool> {
227 self.all_frozen()
228 }
229
230 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
231 self.parameters().into()
232 }
233
234 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
235 self.parameters_mut().into()
236 }
237
238 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
239 Some(self.trainable_parameters().into())
240 }
241}