1use crate::module::{Module, ModuleParameters};
4
/// A module (or a container of modules) that can be turned into a quantized form.
///
/// Implementors pick the concrete output type via [`Quantizable::Quantized`]
/// and the failure type via [`Quantizable::QuantizationError`].
pub trait Quantizable {
    /// The value produced by a successful quantization.
    type Quantized;

    /// The error returned when quantization fails.
    type QuantizationError;

    /// Default quantization group size (64).
    ///
    /// NOTE(review): presumably used by callers when no explicit group size is
    /// given — confirm against the call sites.
    const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default quantization bit width (4).
    ///
    /// NOTE(review): presumably used by callers when no explicit bit width is
    /// given — confirm against the call sites.
    const DEFAULT_BITS: i32 = 4;

    /// Consumes `self` and attempts to produce its quantized counterpart using
    /// the given `group_size` and `bits`.
    ///
    /// # Errors
    ///
    /// Returns `Self::QuantizationError` when the implementor cannot quantize
    /// with the requested parameters.
    fn try_into_quantized(
        self,
        group_size: i32,
        bits: i32,
    ) -> Result<Self::Quantized, Self::QuantizationError>;
}
26
27impl<M> Quantizable for Vec<M>
28where
29 M: Quantizable,
30{
31 type Quantized = Vec<M::Quantized>;
32
33 type QuantizationError = M::QuantizationError;
34
35 fn try_into_quantized(
36 self,
37 group_size: i32,
38 bits: i32,
39 ) -> Result<Self::Quantized, Self::QuantizationError> {
40 self.into_iter()
41 .map(|m| m.try_into_quantized(group_size, bits))
42 .collect()
43 }
44}
45
46impl<M> Quantizable for Box<M>
47where
48 M: Quantizable,
49{
50 type Quantized = Box<M::Quantized>;
51
52 type QuantizationError = M::QuantizationError;
53
54 fn try_into_quantized(
55 self,
56 group_size: i32,
57 bits: i32,
58 ) -> Result<Self::Quantized, Self::QuantizationError> {
59 (*self).try_into_quantized(group_size, bits).map(Box::new)
60 }
61}
62
63impl<M> Quantizable for Option<M>
64where
65 M: Quantizable,
66{
67 type Quantized = Option<M::Quantized>;
68
69 type QuantizationError = M::QuantizationError;
70
71 fn try_into_quantized(
72 self,
73 group_size: i32,
74 bits: i32,
75 ) -> Result<Self::Quantized, Self::QuantizationError> {
76 match self {
77 Some(m) => m.try_into_quantized(group_size, bits).map(Some),
78 None => Ok(None),
79 }
80 }
81}
82
83#[derive(Debug, Clone)]
85pub enum MaybeQuantized<M>
86where
87 M: Quantizable,
88{
89 Original(M),
91
92 Quantized(M::Quantized),
94}
95
96impl<M> Quantizable for MaybeQuantized<M>
97where
98 M: Quantizable,
99{
100 type Quantized = Self;
101 type QuantizationError = <M as Quantizable>::QuantizationError;
102
103 fn try_into_quantized(
104 self,
105 group_size: i32,
106 bits: i32,
107 ) -> Result<Self, Self::QuantizationError> {
108 match self {
109 MaybeQuantized::Original(m) => {
110 let quantized = m.try_into_quantized(group_size, bits)?;
111 Ok(MaybeQuantized::Quantized(quantized))
112 }
113 MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)),
114 }
115 }
116}
117
118impl<M> MaybeQuantized<M>
119where
120 M: Quantizable,
121{
122 pub fn new(module: M) -> Self {
124 MaybeQuantized::Original(module)
125 }
126
127 pub fn quantize_with(
131 self,
132 op: impl FnOnce(M) -> Result<M::Quantized, M::QuantizationError>,
133 ) -> Result<Self, M::QuantizationError> {
134 match self {
135 MaybeQuantized::Original(m) => op(m).map(MaybeQuantized::Quantized),
136 MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)),
137 }
138 }
139
140 pub fn is_quantized(&self) -> bool {
142 match self {
143 MaybeQuantized::Original(_) => false,
144 MaybeQuantized::Quantized(_) => true,
145 }
146 }
147}
148
149impl<M> ModuleParameters for MaybeQuantized<M>
150where
151 M: Quantizable + ModuleParameters,
152 M::Quantized: ModuleParameters,
153{
154 fn num_parameters(&self) -> usize {
155 match self {
156 MaybeQuantized::Original(m) => m.num_parameters(),
157 MaybeQuantized::Quantized(q) => q.num_parameters(),
158 }
159 }
160
161 fn parameters(&self) -> crate::module::ModuleParamRef<'_> {
162 match self {
163 MaybeQuantized::Original(m) => m.parameters(),
164 MaybeQuantized::Quantized(q) => q.parameters(),
165 }
166 }
167
168 fn parameters_mut(&mut self) -> crate::module::ModuleParamMut<'_> {
169 match self {
170 MaybeQuantized::Original(m) => m.parameters_mut(),
171 MaybeQuantized::Quantized(q) => q.parameters_mut(),
172 }
173 }
174
175 fn trainable_parameters(&self) -> crate::module::ModuleParamRef<'_> {
176 match self {
177 MaybeQuantized::Original(m) => m.trainable_parameters(),
178 MaybeQuantized::Quantized(q) => q.trainable_parameters(),
179 }
180 }
181
182 fn freeze_parameters(&mut self, recursive: bool) {
183 match self {
184 MaybeQuantized::Original(m) => m.freeze_parameters(recursive),
185 MaybeQuantized::Quantized(q) => q.freeze_parameters(recursive),
186 }
187 }
188
189 fn unfreeze_parameters(&mut self, recursive: bool) {
190 match self {
191 MaybeQuantized::Original(m) => m.unfreeze_parameters(recursive),
192 MaybeQuantized::Quantized(q) => q.unfreeze_parameters(recursive),
193 }
194 }
195
196 fn all_frozen(&self) -> Option<bool> {
197 match self {
198 MaybeQuantized::Original(m) => m.all_frozen(),
199 MaybeQuantized::Quantized(q) => q.all_frozen(),
200 }
201 }
202
203 fn any_frozen(&self) -> Option<bool> {
204 match self {
205 MaybeQuantized::Original(m) => m.any_frozen(),
206 MaybeQuantized::Quantized(q) => q.any_frozen(),
207 }
208 }
209}
210
211impl<M, Input> Module<Input> for MaybeQuantized<M>
212where
213 M: Quantizable + Module<Input>,
214 M::Quantized:
215 Module<Input, Output = <M as Module<Input>>::Output, Error = <M as Module<Input>>::Error>,
216{
217 type Output = <M as Module<Input>>::Output;
218
219 type Error = <M as Module<Input>>::Error;
220
221 fn forward(&mut self, x: Input) -> Result<Self::Output, Self::Error> {
222 match self {
223 MaybeQuantized::Original(m) => m.forward(x),
224 MaybeQuantized::Quantized(q) => q.forward(x),
225 }
226 }
227
228 fn training_mode(&mut self, mode: bool) {
229 match self {
230 MaybeQuantized::Original(m) => m.training_mode(mode),
231 MaybeQuantized::Quantized(q) => q.training_mode(mode),
232 }
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nn::{self, Embedding, Linear};

    /// A wrapped `Linear` starts unquantized and is quantized by `nn::quantize`.
    #[test]
    fn test_quantizable_linear() {
        let module = MaybeQuantized::new(Linear::new(64, 64).unwrap());
        assert!(!module.is_quantized());

        let module: MaybeQuantized<Linear> = nn::quantize(module, None, None).unwrap();
        assert!(module.is_quantized());
    }

    /// A wrapped `Embedding` starts unquantized and is quantized by `nn::quantize`.
    #[test]
    fn test_quantizable_embedding() {
        let module = MaybeQuantized::new(Embedding::new(64, 64).unwrap());
        assert!(!module.is_quantized());

        let module: MaybeQuantized<Embedding> = nn::quantize(module, None, None).unwrap();
        assert!(module.is_quantized());
    }
}