mlx_rs/
quantization.rs

//! Traits and wrapper types for quantizing modules.
2
3use crate::module::{Module, ModuleParameters};
4
/// Modules that can be converted into a quantized counterpart.
///
/// Implementors pick the concrete quantized representation via [`Quantizable::Quantized`]
/// and the failure type via [`Quantizable::QuantizationError`].
pub trait Quantizable {
    /// The type produced by a successful quantization.
    type Quantized;

    /// The error returned when quantization fails.
    type QuantizationError;

    /// The default group size for quantization (64).
    const DEFAULT_GROUP_SIZE: i32 = 64;

    /// The default number of bits for quantization (4).
    const DEFAULT_BITS: i32 = 4;

    /// Consumes `self` and quantizes it with the given `group_size` and bit width `bits`.
    ///
    /// # Errors
    ///
    /// Returns [`Self::QuantizationError`] when the implementor cannot quantize with the
    /// requested parameters.
    fn try_into_quantized(
        self,
        group_size: i32,
        bits: i32,
    ) -> Result<Self::Quantized, Self::QuantizationError>;
}
26
27impl<M> Quantizable for Vec<M>
28where
29    M: Quantizable,
30{
31    type Quantized = Vec<M::Quantized>;
32
33    type QuantizationError = M::QuantizationError;
34
35    fn try_into_quantized(
36        self,
37        group_size: i32,
38        bits: i32,
39    ) -> Result<Self::Quantized, Self::QuantizationError> {
40        self.into_iter()
41            .map(|m| m.try_into_quantized(group_size, bits))
42            .collect()
43    }
44}
45
46impl<M> Quantizable for Box<M>
47where
48    M: Quantizable,
49{
50    type Quantized = Box<M::Quantized>;
51
52    type QuantizationError = M::QuantizationError;
53
54    fn try_into_quantized(
55        self,
56        group_size: i32,
57        bits: i32,
58    ) -> Result<Self::Quantized, Self::QuantizationError> {
59        (*self).try_into_quantized(group_size, bits).map(Box::new)
60    }
61}
62
63impl<M> Quantizable for Option<M>
64where
65    M: Quantizable,
66{
67    type Quantized = Option<M::Quantized>;
68
69    type QuantizationError = M::QuantizationError;
70
71    fn try_into_quantized(
72        self,
73        group_size: i32,
74        bits: i32,
75    ) -> Result<Self::Quantized, Self::QuantizationError> {
76        match self {
77            Some(m) => m.try_into_quantized(group_size, bits).map(Some),
78            None => Ok(None),
79        }
80    }
81}
82
83/// A wrapper for a quantizable module.
84#[derive(Debug, Clone)]
85pub enum MaybeQuantized<M>
86where
87    M: Quantizable,
88{
89    /// The original module.
90    Original(M),
91
92    /// The quantized version of the module.
93    Quantized(M::Quantized),
94}
95
96impl<M> Quantizable for MaybeQuantized<M>
97where
98    M: Quantizable,
99{
100    type Quantized = Self;
101    type QuantizationError = <M as Quantizable>::QuantizationError;
102
103    fn try_into_quantized(
104        self,
105        group_size: i32,
106        bits: i32,
107    ) -> Result<Self, Self::QuantizationError> {
108        match self {
109            MaybeQuantized::Original(m) => {
110                let quantized = m.try_into_quantized(group_size, bits)?;
111                Ok(MaybeQuantized::Quantized(quantized))
112            }
113            MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)),
114        }
115    }
116}
117
118impl<M> MaybeQuantized<M>
119where
120    M: Quantizable,
121{
122    /// Create a new [`MaybeQuantized`] from the original module.
123    pub fn new(module: M) -> Self {
124        MaybeQuantized::Original(module)
125    }
126
127    /// Quantize the module with a custom quantization function.
128    ///
129    /// This is useful if one would like to quantize with a custom group size or bit width.
130    pub fn quantize_with(
131        self,
132        op: impl FnOnce(M) -> Result<M::Quantized, M::QuantizationError>,
133    ) -> Result<Self, M::QuantizationError> {
134        match self {
135            MaybeQuantized::Original(m) => op(m).map(MaybeQuantized::Quantized),
136            MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)),
137        }
138    }
139
140    /// Check if the module is quantized.
141    pub fn is_quantized(&self) -> bool {
142        match self {
143            MaybeQuantized::Original(_) => false,
144            MaybeQuantized::Quantized(_) => true,
145        }
146    }
147}
148
149impl<M> ModuleParameters for MaybeQuantized<M>
150where
151    M: Quantizable + ModuleParameters,
152    M::Quantized: ModuleParameters,
153{
154    fn num_parameters(&self) -> usize {
155        match self {
156            MaybeQuantized::Original(m) => m.num_parameters(),
157            MaybeQuantized::Quantized(q) => q.num_parameters(),
158        }
159    }
160
161    fn parameters(&self) -> crate::module::ModuleParamRef<'_> {
162        match self {
163            MaybeQuantized::Original(m) => m.parameters(),
164            MaybeQuantized::Quantized(q) => q.parameters(),
165        }
166    }
167
168    fn parameters_mut(&mut self) -> crate::module::ModuleParamMut<'_> {
169        match self {
170            MaybeQuantized::Original(m) => m.parameters_mut(),
171            MaybeQuantized::Quantized(q) => q.parameters_mut(),
172        }
173    }
174
175    fn trainable_parameters(&self) -> crate::module::ModuleParamRef<'_> {
176        match self {
177            MaybeQuantized::Original(m) => m.trainable_parameters(),
178            MaybeQuantized::Quantized(q) => q.trainable_parameters(),
179        }
180    }
181
182    fn freeze_parameters(&mut self, recursive: bool) {
183        match self {
184            MaybeQuantized::Original(m) => m.freeze_parameters(recursive),
185            MaybeQuantized::Quantized(q) => q.freeze_parameters(recursive),
186        }
187    }
188
189    fn unfreeze_parameters(&mut self, recursive: bool) {
190        match self {
191            MaybeQuantized::Original(m) => m.unfreeze_parameters(recursive),
192            MaybeQuantized::Quantized(q) => q.unfreeze_parameters(recursive),
193        }
194    }
195
196    fn all_frozen(&self) -> Option<bool> {
197        match self {
198            MaybeQuantized::Original(m) => m.all_frozen(),
199            MaybeQuantized::Quantized(q) => q.all_frozen(),
200        }
201    }
202
203    fn any_frozen(&self) -> Option<bool> {
204        match self {
205            MaybeQuantized::Original(m) => m.any_frozen(),
206            MaybeQuantized::Quantized(q) => q.any_frozen(),
207        }
208    }
209}
210
211impl<M, Input> Module<Input> for MaybeQuantized<M>
212where
213    M: Quantizable + Module<Input>,
214    M::Quantized:
215        Module<Input, Output = <M as Module<Input>>::Output, Error = <M as Module<Input>>::Error>,
216{
217    type Output = <M as Module<Input>>::Output;
218
219    type Error = <M as Module<Input>>::Error;
220
221    fn forward(&mut self, x: Input) -> Result<Self::Output, Self::Error> {
222        match self {
223            MaybeQuantized::Original(m) => m.forward(x),
224            MaybeQuantized::Quantized(q) => q.forward(x),
225        }
226    }
227
228    fn training_mode(&mut self, mode: bool) {
229        match self {
230            MaybeQuantized::Original(m) => m.training_mode(mode),
231            MaybeQuantized::Quantized(q) => q.training_mode(mode),
232        }
233    }
234}
235
#[cfg(test)]
mod tests {
    use crate::nn::{self, Embedding, Linear};

    use super::*;

    #[test]
    fn test_quantizable_linear() {
        let linear = Linear::new(64, 64).unwrap();
        let mut qlinear = MaybeQuantized::new(linear);
        assert!(!qlinear.is_quantized());

        qlinear = nn::quantize(qlinear, None, None).unwrap();
        assert!(qlinear.is_quantized());
    }

    #[test]
    fn test_quantizable_embedding() {
        let embedding = Embedding::new(64, 64).unwrap();
        let mut qembedding = MaybeQuantized::new(embedding);
        assert!(!qembedding.is_quantized());

        qembedding = nn::quantize(qembedding, None, None).unwrap();
        assert!(qembedding.is_quantized());
    }

    // Exercises the pass-through branch: quantizing an already quantized module must
    // leave it quantized instead of failing.
    #[test]
    fn test_quantize_twice_stays_quantized() {
        let qlinear = MaybeQuantized::new(Linear::new(64, 64).unwrap());
        let qlinear: MaybeQuantized<Linear> = nn::quantize(qlinear, None, None).unwrap();
        let qlinear: MaybeQuantized<Linear> = nn::quantize(qlinear, None, None).unwrap();
        assert!(qlinear.is_quantized());
    }

    // `quantize_with` lets the caller drive quantization directly.
    #[test]
    fn test_quantize_with_custom_op() {
        let qlinear = MaybeQuantized::new(Linear::new(64, 64).unwrap())
            .quantize_with(|linear| {
                linear.try_into_quantized(Linear::DEFAULT_GROUP_SIZE, Linear::DEFAULT_BITS)
            })
            .unwrap();
        assert!(qlinear.is_quantized());
    }

    // The `Option` impl quantizes `Some` and passes `None` through.
    #[test]
    fn test_quantize_option() {
        let none: Option<Linear> = None;
        assert!(none.try_into_quantized(64, 4).unwrap().is_none());

        let some = Some(Linear::new(64, 64).unwrap());
        let quantized = some
            .try_into_quantized(Linear::DEFAULT_GROUP_SIZE, Linear::DEFAULT_BITS)
            .unwrap();
        assert!(quantized.is_some());
    }
}
261}