1use std::{
2 collections::HashMap,
3 ops::{Deref, DerefMut},
4 rc::Rc,
5};
6
7use crate::{nested::NestedValue, Array};
8
9use super::ModuleParameters;
10
11pub trait Parameter {
13 fn count(&self) -> usize;
15
16 fn freeze(&mut self, recursive: bool);
18
19 fn unfreeze(&mut self, recursive: bool);
21
22 fn is_frozen(&self) -> Option<bool>;
25
26 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array>;
28
29 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array>;
31
32 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>>;
34}
35
/// Thin wrapper pairing a value with a private "frozen" flag.
///
/// The wrapped value is public and is also reachable through the `Deref`/`AsRef` impls.
/// The flag is only read or written through the `Parameter` impls.
#[derive(Debug, Clone)]
pub struct Param<T> {
    /// The wrapped value.
    pub value: T,

    /// Whether this parameter is frozen.
    ///
    /// A frozen leaf is excluded from `as_trainable_nested_value`. The flag is kept private
    /// so that callers go through `freeze`/`unfreeze`/`is_frozen`.
    is_frozen: bool,
}
47
48impl<T> Param<T> {
49 pub fn new(value: T) -> Self {
51 Self {
52 value,
53 is_frozen: false,
54 }
55 }
56}
57
58impl<T> From<T> for Param<T> {
59 fn from(inner: T) -> Self {
60 Self::new(inner)
61 }
62}
63
64impl<T> Deref for Param<T> {
65 type Target = T;
66
67 fn deref(&self) -> &Self::Target {
68 &self.value
69 }
70}
71
72impl<T> DerefMut for Param<T> {
73 fn deref_mut(&mut self) -> &mut Self::Target {
74 &mut self.value
75 }
76}
77
78impl<T> AsRef<T> for Param<T> {
79 fn as_ref(&self) -> &T {
80 &self.value
81 }
82}
83
84impl<T> AsMut<T> for Param<T> {
85 fn as_mut(&mut self) -> &mut T {
86 &mut self.value
87 }
88}
89
90impl Parameter for Param<Array> {
91 fn count(&self) -> usize {
92 1
93 }
94
95 fn freeze(&mut self, _recursive: bool) {
96 self.is_frozen = true;
97 }
98
99 fn unfreeze(&mut self, _recursive: bool) {
100 self.is_frozen = false;
101 }
102
103 fn is_frozen(&self) -> Option<bool> {
104 Some(self.is_frozen)
105 }
106
107 fn as_nested_value<'a>(&self) -> NestedValue<Rc<str>, &Array> {
108 NestedValue::Value(&self.value)
109 }
110
111 fn as_nested_value_mut<'a>(&mut self) -> NestedValue<Rc<str>, &mut Array> {
112 NestedValue::Value(&mut self.value)
113 }
114
115 fn as_trainable_nested_value<'a>(&self) -> Option<NestedValue<Rc<str>, &Array>> {
116 match self.is_frozen {
117 true => None,
118 false => Some(NestedValue::Value(&self.value)),
119 }
120 }
121}
122
123impl Parameter for Param<Option<Array>> {
124 fn count(&self) -> usize {
125 self.value.as_ref().map_or(0, |_| 1)
126 }
127
128 fn freeze(&mut self, _recursive: bool) {
129 self.is_frozen = true;
130 }
131
132 fn unfreeze(&mut self, _recursive: bool) {
133 self.is_frozen = false;
134 }
135
136 fn is_frozen(&self) -> Option<bool> {
137 Some(self.is_frozen)
138 }
139
140 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
141 match &self.value {
142 Some(array) => NestedValue::Value(array),
143 None => NestedValue::Map(HashMap::with_capacity(0)),
145 }
146 }
147
148 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
149 match &mut self.value {
150 Some(array) => NestedValue::Value(array),
151 None => NestedValue::Map(HashMap::with_capacity(0)),
153 }
154 }
155
156 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
157 match self.is_frozen {
158 true => None,
159 false => self.value.as_ref().map(NestedValue::Value),
160 }
161 }
162}
163
164impl<T> Parameter for T
165where
166 T: ModuleParameters,
167{
168 fn count(&self) -> usize {
169 self.num_parameters()
170 }
171
172 fn freeze(&mut self, recursive: bool) {
173 self.freeze_parameters(recursive);
174 }
175
176 fn unfreeze(&mut self, recursive: bool) {
177 self.unfreeze_parameters(recursive);
178 }
179
180 fn is_frozen(&self) -> Option<bool> {
181 self.all_frozen()
182 }
183
184 fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array> {
185 self.parameters().into()
186 }
187
188 fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array> {
189 self.parameters_mut().into()
190 }
191
192 fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>> {
193 Some(self.trainable_parameters().into())
194 }
195}