mlx_rs/
quantization.rs

1//! Traits for quantization
2
3use crate::module::{Module, ModuleParameters};
4
/// Trait for quantization of modules.
///
/// An implementor describes how a value (typically a module) is consumed and
/// turned into its quantized counterpart, [`Quantizable::Quantized`], given a
/// quantization group size and a bit width.
pub trait Quantizable {
    /// The type produced by quantizing `Self`.
    type Quantized;

    /// The error returned when quantization fails.
    type QuantizationError;

    /// The default group size for quantization.
    const DEFAULT_GROUP_SIZE: i32 = 64;

    /// The default number of bits for quantization.
    const DEFAULT_BITS: i32 = 4;

    /// Consumes `self` and quantizes it with the given `group_size` and `bits`.
    ///
    /// # Errors
    ///
    /// Returns [`Self::QuantizationError`] when the value cannot be quantized.
    fn try_into_quantized(
        self,
        group_size: i32,
        bits: i32,
    ) -> Result<Self::Quantized, Self::QuantizationError>;
}
26
27impl<M> Quantizable for Vec<M>
28where
29    M: Quantizable,
30{
31    type Quantized = Vec<M::Quantized>;
32
33    type QuantizationError = M::QuantizationError;
34
35    fn try_into_quantized(
36        self,
37        group_size: i32,
38        bits: i32,
39    ) -> Result<Self::Quantized, Self::QuantizationError> {
40        self.into_iter()
41            .map(|m| m.try_into_quantized(group_size, bits))
42            .collect()
43    }
44}
45
46impl<M> Quantizable for Box<M>
47where
48    M: Quantizable,
49{
50    type Quantized = Box<M::Quantized>;
51
52    type QuantizationError = M::QuantizationError;
53
54    fn try_into_quantized(
55        self,
56        group_size: i32,
57        bits: i32,
58    ) -> Result<Self::Quantized, Self::QuantizationError> {
59        (*self).try_into_quantized(group_size, bits).map(Box::new)
60    }
61}
62
63/// A wrapper for a quantizable module.
64#[derive(Debug, Clone)]
65pub enum MaybeQuantized<M>
66where
67    M: Quantizable,
68{
69    /// The original module.
70    Original(M),
71
72    /// The quantized version of the module.
73    Quantized(M::Quantized),
74}
75
76impl<M> Quantizable for MaybeQuantized<M>
77where
78    M: Quantizable,
79{
80    type Quantized = Self;
81    type QuantizationError = <M as Quantizable>::QuantizationError;
82
83    fn try_into_quantized(
84        self,
85        group_size: i32,
86        bits: i32,
87    ) -> Result<Self, Self::QuantizationError> {
88        match self {
89            MaybeQuantized::Original(m) => {
90                let quantized = m.try_into_quantized(group_size, bits)?;
91                Ok(MaybeQuantized::Quantized(quantized))
92            }
93            MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)),
94        }
95    }
96}
97
98impl<M> MaybeQuantized<M>
99where
100    M: Quantizable,
101{
102    /// Create a new [`MaybeQuantized`] from the original module.
103    pub fn new(module: M) -> Self {
104        MaybeQuantized::Original(module)
105    }
106
107    /// Quantize the module with a custom quantization function.
108    ///
109    /// This is useful if one would like to quantize with a custom group size or bit width.
110    pub fn quantize_with(
111        self,
112        op: impl FnOnce(M) -> Result<M::Quantized, M::QuantizationError>,
113    ) -> Result<Self, M::QuantizationError> {
114        match self {
115            MaybeQuantized::Original(m) => op(m).map(MaybeQuantized::Quantized),
116            MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)),
117        }
118    }
119
120    /// Check if the module is quantized.
121    pub fn is_quantized(&self) -> bool {
122        match self {
123            MaybeQuantized::Original(_) => false,
124            MaybeQuantized::Quantized(_) => true,
125        }
126    }
127}
128
129impl<M> ModuleParameters for MaybeQuantized<M>
130where
131    M: Quantizable + ModuleParameters,
132    M::Quantized: ModuleParameters,
133{
134    fn parameters(&self) -> crate::module::ModuleParamRef<'_> {
135        match self {
136            MaybeQuantized::Original(m) => m.parameters(),
137            MaybeQuantized::Quantized(q) => q.parameters(),
138        }
139    }
140
141    fn parameters_mut(&mut self) -> crate::module::ModuleParamMut<'_> {
142        match self {
143            MaybeQuantized::Original(m) => m.parameters_mut(),
144            MaybeQuantized::Quantized(q) => q.parameters_mut(),
145        }
146    }
147
148    fn trainable_parameters(&self) -> crate::module::ModuleParamRef<'_> {
149        match self {
150            MaybeQuantized::Original(m) => m.trainable_parameters(),
151            MaybeQuantized::Quantized(q) => q.trainable_parameters(),
152        }
153    }
154
155    fn freeze_parameters(&mut self, recursive: bool) {
156        match self {
157            MaybeQuantized::Original(m) => m.freeze_parameters(recursive),
158            MaybeQuantized::Quantized(q) => q.freeze_parameters(recursive),
159        }
160    }
161
162    fn unfreeze_parameters(&mut self, recursive: bool) {
163        match self {
164            MaybeQuantized::Original(m) => m.unfreeze_parameters(recursive),
165            MaybeQuantized::Quantized(q) => q.unfreeze_parameters(recursive),
166        }
167    }
168
169    fn all_frozen(&self) -> Option<bool> {
170        match self {
171            MaybeQuantized::Original(m) => m.all_frozen(),
172            MaybeQuantized::Quantized(q) => q.all_frozen(),
173        }
174    }
175
176    fn any_frozen(&self) -> Option<bool> {
177        match self {
178            MaybeQuantized::Original(m) => m.any_frozen(),
179            MaybeQuantized::Quantized(q) => q.any_frozen(),
180        }
181    }
182}
183
184impl<M, Input> Module<Input> for MaybeQuantized<M>
185where
186    M: Quantizable + Module<Input>,
187    M::Quantized:
188        Module<Input, Output = <M as Module<Input>>::Output, Error = <M as Module<Input>>::Error>,
189{
190    type Output = <M as Module<Input>>::Output;
191
192    type Error = <M as Module<Input>>::Error;
193
194    fn forward(&mut self, x: Input) -> Result<Self::Output, Self::Error> {
195        match self {
196            MaybeQuantized::Original(m) => m.forward(x),
197            MaybeQuantized::Quantized(q) => q.forward(x),
198        }
199    }
200
201    fn training_mode(&mut self, mode: bool) {
202        match self {
203            MaybeQuantized::Original(m) => m.training_mode(mode),
204            MaybeQuantized::Quantized(q) => q.training_mode(mode),
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nn::{self, Embedding, Linear};

    // A freshly wrapped `Linear` is unquantized; `nn::quantize` switches it to
    // the quantized variant.
    #[test]
    fn test_quantizable_linear() {
        let mut qlinear = MaybeQuantized::new(Linear::new(64, 64).unwrap());
        assert!(!qlinear.is_quantized());

        qlinear = nn::quantize(qlinear, None, None).unwrap();
        assert!(qlinear.is_quantized());
    }

    // Same check for `Embedding`.
    #[test]
    fn test_quantizable_embedding() {
        let mut qembedding = MaybeQuantized::new(Embedding::new(64, 64).unwrap());
        assert!(!qembedding.is_quantized());

        qembedding = nn::quantize(qembedding, None, None).unwrap();
        assert!(qembedding.is_quantized());
    }
}
234}