1use crate::module::{Module, ModuleParameters};
4
/// Conversion of a value into its quantized counterpart.
///
/// Implementors consume `self` and either yield `Self::Quantized` or fail
/// with `Self::QuantizationError`.
pub trait Quantizable {
    /// Default group size (64) made available to every implementor.
    const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default bit width (4) made available to every implementor.
    const DEFAULT_BITS: i32 = 4;

    /// Type produced by a successful quantization.
    type Quantized;

    /// Type reported when quantization fails.
    type QuantizationError;

    /// Consumes `self` and attempts to quantize it.
    ///
    /// `group_size` and `bits` are handed to the implementor unchanged; their
    /// valid ranges are defined by each implementation.
    ///
    /// # Errors
    ///
    /// Returns `Self::QuantizationError` when the implementor cannot produce
    /// a quantized value.
    fn try_into_quantized(
        self,
        group_size: i32,
        bits: i32,
    ) -> Result<Self::Quantized, Self::QuantizationError>;
}
26
27impl<M> Quantizable for Vec<M>
28where
29 M: Quantizable,
30{
31 type Quantized = Vec<M::Quantized>;
32
33 type QuantizationError = M::QuantizationError;
34
35 fn try_into_quantized(
36 self,
37 group_size: i32,
38 bits: i32,
39 ) -> Result<Self::Quantized, Self::QuantizationError> {
40 self.into_iter()
41 .map(|m| m.try_into_quantized(group_size, bits))
42 .collect()
43 }
44}
45
46impl<M> Quantizable for Box<M>
47where
48 M: Quantizable,
49{
50 type Quantized = Box<M::Quantized>;
51
52 type QuantizationError = M::QuantizationError;
53
54 fn try_into_quantized(
55 self,
56 group_size: i32,
57 bits: i32,
58 ) -> Result<Self::Quantized, Self::QuantizationError> {
59 (*self).try_into_quantized(group_size, bits).map(Box::new)
60 }
61}
62
63#[derive(Debug, Clone)]
65pub enum MaybeQuantized<M>
66where
67 M: Quantizable,
68{
69 Original(M),
71
72 Quantized(M::Quantized),
74}
75
76impl<M> Quantizable for MaybeQuantized<M>
77where
78 M: Quantizable,
79{
80 type Quantized = Self;
81 type QuantizationError = <M as Quantizable>::QuantizationError;
82
83 fn try_into_quantized(
84 self,
85 group_size: i32,
86 bits: i32,
87 ) -> Result<Self, Self::QuantizationError> {
88 match self {
89 MaybeQuantized::Original(m) => {
90 let quantized = m.try_into_quantized(group_size, bits)?;
91 Ok(MaybeQuantized::Quantized(quantized))
92 }
93 MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)),
94 }
95 }
96}
97
98impl<M> MaybeQuantized<M>
99where
100 M: Quantizable,
101{
102 pub fn new(module: M) -> Self {
104 MaybeQuantized::Original(module)
105 }
106
107 pub fn quantize_with(
111 self,
112 op: impl FnOnce(M) -> Result<M::Quantized, M::QuantizationError>,
113 ) -> Result<Self, M::QuantizationError> {
114 match self {
115 MaybeQuantized::Original(m) => op(m).map(MaybeQuantized::Quantized),
116 MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)),
117 }
118 }
119
120 pub fn is_quantized(&self) -> bool {
122 match self {
123 MaybeQuantized::Original(_) => false,
124 MaybeQuantized::Quantized(_) => true,
125 }
126 }
127}
128
129impl<M> ModuleParameters for MaybeQuantized<M>
130where
131 M: Quantizable + ModuleParameters,
132 M::Quantized: ModuleParameters,
133{
134 fn parameters(&self) -> crate::module::ModuleParamRef<'_> {
135 match self {
136 MaybeQuantized::Original(m) => m.parameters(),
137 MaybeQuantized::Quantized(q) => q.parameters(),
138 }
139 }
140
141 fn parameters_mut(&mut self) -> crate::module::ModuleParamMut<'_> {
142 match self {
143 MaybeQuantized::Original(m) => m.parameters_mut(),
144 MaybeQuantized::Quantized(q) => q.parameters_mut(),
145 }
146 }
147
148 fn trainable_parameters(&self) -> crate::module::ModuleParamRef<'_> {
149 match self {
150 MaybeQuantized::Original(m) => m.trainable_parameters(),
151 MaybeQuantized::Quantized(q) => q.trainable_parameters(),
152 }
153 }
154
155 fn freeze_parameters(&mut self, recursive: bool) {
156 match self {
157 MaybeQuantized::Original(m) => m.freeze_parameters(recursive),
158 MaybeQuantized::Quantized(q) => q.freeze_parameters(recursive),
159 }
160 }
161
162 fn unfreeze_parameters(&mut self, recursive: bool) {
163 match self {
164 MaybeQuantized::Original(m) => m.unfreeze_parameters(recursive),
165 MaybeQuantized::Quantized(q) => q.unfreeze_parameters(recursive),
166 }
167 }
168
169 fn all_frozen(&self) -> Option<bool> {
170 match self {
171 MaybeQuantized::Original(m) => m.all_frozen(),
172 MaybeQuantized::Quantized(q) => q.all_frozen(),
173 }
174 }
175
176 fn any_frozen(&self) -> Option<bool> {
177 match self {
178 MaybeQuantized::Original(m) => m.any_frozen(),
179 MaybeQuantized::Quantized(q) => q.any_frozen(),
180 }
181 }
182}
183
184impl<M, Input> Module<Input> for MaybeQuantized<M>
185where
186 M: Quantizable + Module<Input>,
187 M::Quantized:
188 Module<Input, Output = <M as Module<Input>>::Output, Error = <M as Module<Input>>::Error>,
189{
190 type Output = <M as Module<Input>>::Output;
191
192 type Error = <M as Module<Input>>::Error;
193
194 fn forward(&mut self, x: Input) -> Result<Self::Output, Self::Error> {
195 match self {
196 MaybeQuantized::Original(m) => m.forward(x),
197 MaybeQuantized::Quantized(q) => q.forward(x),
198 }
199 }
200
201 fn training_mode(&mut self, mode: bool) {
202 match self {
203 MaybeQuantized::Original(m) => m.training_mode(mode),
204 MaybeQuantized::Quantized(q) => q.training_mode(mode),
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use crate::nn::{self, Embedding, Linear};

    use super::*;

    /// A wrapped `Linear` starts unquantized and is quantized by `nn::quantize`.
    #[test]
    fn test_quantizable_linear() {
        let wrapped = MaybeQuantized::new(Linear::new(64, 64).unwrap());
        assert!(!wrapped.is_quantized());

        let wrapped = nn::quantize(wrapped, None, None).unwrap();
        assert!(wrapped.is_quantized());
    }

    /// A wrapped `Embedding` starts unquantized and is quantized by `nn::quantize`.
    #[test]
    fn test_quantizable_embedding() {
        let wrapped = MaybeQuantized::new(Embedding::new(64, 64).unwrap());
        assert!(!wrapped.is_quantized());

        let wrapped = nn::quantize(wrapped, None, None).unwrap();
        assert!(wrapped.is_quantized());
    }
}