mlx_rs/nn/quantized.rs

use std::iter::once;

use crate::{
    array,
    error::Exception,
    module::{Module, ModuleParameters, Param},
    ops::indexing::IndexOp,
    ops::{self, dequantize, quantized_matmul, zeros},
    quantization::Quantizable,
    random::uniform,
    Array,
};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

use crate::nn::{Embedding, Linear};

/// Quantize a module.
///
/// # Params
///
/// - `module`: The module to quantize.
/// - `group_size`: The group size to use for the quantized weight. Defaults to [`Quantizable::DEFAULT_GROUP_SIZE`]
/// - `bits`: The bit width to use for the quantized weight. Defaults to [`Quantizable::DEFAULT_BITS`]
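///
/// # Example
///
/// A minimal sketch; assumes `linear` and `embedding` are existing [`Linear`]
/// and [`Embedding`] layers built elsewhere (both implement [`Quantizable`]):
///
/// ```rust,ignore
/// // Quantize with the default group size and bit width.
/// let q_linear = quantize(linear, None, None)?;
/// // Or choose them explicitly, e.g. 8 bits with groups of 32.
/// let q_embedding = quantize(embedding, 32, 8)?;
/// ```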
pub fn quantize<M>(
    module: M,
    group_size: impl Into<Option<i32>>,
    bits: impl Into<Option<i32>>,
) -> Result<M::Quantized, M::QuantizationError>
where
    M: Quantizable,
{
    let group_size = group_size.into().unwrap_or(M::DEFAULT_GROUP_SIZE);
    let bits = bits.into().unwrap_or(M::DEFAULT_BITS);
    module.try_into_quantized(group_size, bits)
}

/// Builder for [`QuantizedEmbedding`]
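///
/// # Example
///
/// A minimal sketch. The builder can be filled in directly since its fields
/// are public; the `build()` method shown here is assumed to be the one
/// generated by the `Builder` derive:
///
/// ```rust,ignore
/// let qe = QuantizedEmbeddingBuilder {
///     embedding_count: 1000,
///     dimensions: 32,
///     group_size: QuantizedEmbedding::DEFAULT_GROUP_SIZE,
///     bits: QuantizedEmbedding::DEFAULT_BITS,
/// }
/// .build()?;
/// ```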
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_embedding,
    err = Exception,
)]
pub struct QuantizedEmbeddingBuilder {
    /// The number of possible discrete tokens that can be embedded, usually
    /// called the vocabulary size.
    pub embedding_count: i32,

    /// The dimensionality of the embeddings.
    pub dimensions: i32,

    /// Quantization group size. Defaults to [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`]
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedEmbedding::DEFAULT_BITS`]
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_BITS)]
    pub bits: i32,
}

/// The same as [`Embedding`] but with a quantized weight matrix.
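///
/// Calling it as a module performs the usual embedding lookup, dequantizing
/// only the rows that are gathered. A minimal sketch, assuming a mutable `qe`
/// built as above:
///
/// ```rust,ignore
/// let indices = array!([1, 2, 3]);
/// // Each index selects a row; the result has shape [3, dimensions].
/// let embeddings = qe.forward(&indices)?;
/// ```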
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedEmbedding {
    /// Quantization group size. Defaults to [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedEmbedding::DEFAULT_BITS`]
    pub bits: i32,

    /// Scales
    #[param]
    pub scales: Param<Array>,

    /// Biases
    #[param]
    pub biases: Param<Array>,

    /// Inner embedding
    #[param]
    pub inner: Embedding,
}

impl QuantizedEmbeddingBuilder {
    /// Convenience method to build a new [`QuantizedEmbedding`] with an existing [`Embedding`]
    pub fn build_with_embedding(
        self,
        embedding: Embedding,
    ) -> Result<QuantizedEmbedding, Exception> {
        let weight = embedding.weight.value;
        self.build_with_weight(weight)
    }

    /// Convenience method to build a new [`QuantizedEmbedding`] with an existing weight matrix
    pub fn build_with_weight(self, weight: Array) -> Result<QuantizedEmbedding, Exception> {
        let group_size = self.group_size;
        let bits = self.bits;
        build_quantized_embedding_inner(weight, group_size, bits)
    }
}

fn build_quantized_embedding_inner(
    weight: Array,
    group_size: i32,
    bits: i32,
) -> Result<QuantizedEmbedding, Exception> {
    // Quantize the weight matrix into packed values plus per-group scales and biases.
    let (quantized_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;

    let inner = Embedding {
        weight: Param::new(quantized_weight),
    };

    let mut qe = QuantizedEmbedding {
        group_size,
        bits,
        scales: Param::new(scales),
        biases: Param::new(biases),
        inner,
    };

    // Freeze all parameters: quantized weights take no part in gradient computation.
    qe.freeze_parameters(true);

    Ok(qe)
}

fn build_quantized_embedding(
    builder: QuantizedEmbeddingBuilder,
) -> Result<QuantizedEmbedding, Exception> {
    let embedding_count = builder.embedding_count;
    let dims = builder.dimensions;

    // Initialize the weight from a normal distribution scaled by 1/sqrt(dims).
    let scale = array!(f32::sqrt(1.0 / (dims as f32)));
    // The multiplication cannot fail: `scale` is a single-element array and
    // always broadcasts against `weight`.
    let weight = crate::random::normal::<f32>(&[embedding_count, dims], None, None, None)? * &scale;

    builder.build_with_weight(weight)
}

impl QuantizedEmbedding {
    /// Default group size
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default bits
    pub const DEFAULT_BITS: i32 = 4;

    /// Convert an embedding layer to a quantized embedding layer.
    ///
    /// # Params
    ///
    /// - `embedding`: The embedding layer to convert.
    /// - `group_size`: The group size to use for the quantized weight. Defaults to [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`]
    /// - `bits`: The bit width to use for the quantized weight. Defaults to [`QuantizedEmbedding::DEFAULT_BITS`]
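    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `embedding` is an existing [`Embedding`]:
    ///
    /// ```rust,ignore
    /// // Defaults; the `TryFrom` impl below does the same thing.
    /// let qe = QuantizedEmbedding::try_from_embedding(embedding, None, None)?;
    /// // Explicit group size and bit width.
    /// let qe8 = QuantizedEmbedding::try_from_embedding(other_embedding, 32, 8)?;
    /// ```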
    pub fn try_from_embedding(
        embedding: Embedding,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        let group_size = group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE);
        let bits = bits.into().unwrap_or(Self::DEFAULT_BITS);
        build_quantized_embedding_inner(embedding.weight.value, group_size, bits)
    }

    /// Call the embedding layer as a linear layer.
    ///
    /// Use this, for example, when the input embedding and output projection
    /// weights are tied.
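    ///
    /// A minimal sketch of the tied-weight case, assuming `hidden` holds
    /// hidden states of shape `[.., dimensions]`:
    ///
    /// ```rust,ignore
    /// // Projects back onto the vocabulary: shape [.., embedding_count].
    /// let logits = qe.as_linear(&hidden)?;
    /// ```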
    pub fn as_linear(&self, x: impl AsRef<Array>) -> Result<Array, Exception> {
        quantized_matmul(
            x.as_ref(),
            &self.inner.weight,
            &self.scales,
            &self.biases,
            true,
            self.group_size,
            self.bits,
        )
    }
}

impl TryFrom<Embedding> for QuantizedEmbedding {
    type Error = Exception;

    fn try_from(embedding: Embedding) -> Result<Self, Self::Error> {
        Self::try_from_embedding(embedding, None, None)
    }
}

impl Module<&Array> for QuantizedEmbedding {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let s = x.shape();
        // Flatten the indices, gather the matching quantized rows along with
        // their scales and biases, then dequantize only those rows.
        let x = x.flatten(None, None)?;
        let w = self.inner.weight.index(&x);
        let scales = self.scales.index(&x);
        let biases = self.biases.index(&x);

        let out = dequantize(&w, &scales, &biases, self.group_size, self.bits)?;

        // Restore the original index shape, with the embedding dimension appended.
        let ret_shape = s.iter().copied().chain(once(-1)).collect::<Vec<_>>();
        out.reshape(&ret_shape)
    }

    fn training_mode(&mut self, mode: bool) {
        self.inner.training_mode(mode);
    }
}

/// Builder for [`QuantizedLinear`]
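///
/// # Example
///
/// A minimal sketch; as with [`QuantizedEmbeddingBuilder`], the `build()`
/// method is assumed to be the one generated by the `Builder` derive:
///
/// ```rust,ignore
/// let ql = QuantizedLinearBuilder {
///     input_dims: 64,
///     output_dims: 64,
///     group_size: QuantizedLinear::DEFAULT_GROUP_SIZE,
///     bits: QuantizedLinear::DEFAULT_BITS,
///     bias: true,
/// }
/// .build()?;
/// ```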
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_linear,
    err = Exception,
)]
pub struct QuantizedLinearBuilder {
    /// The dimensionality of the input features.
    pub input_dims: i32,

    /// The dimensionality of the output features.
    pub output_dims: i32,

    /// Quantization group size. Defaults to [`QuantizedLinear::DEFAULT_GROUP_SIZE`]
    #[builder(optional, default = QuantizedLinear::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedLinear::DEFAULT_BITS`]
    #[builder(optional, default = QuantizedLinear::DEFAULT_BITS)]
    pub bits: i32,

    /// Whether the linear layer has a bias. Defaults to [`Linear::DEFAULT_BIAS`]
    #[builder(optional, default = Linear::DEFAULT_BIAS)]
    pub bias: bool,
}

impl QuantizedLinearBuilder {
    /// Convenience method to build a new [`QuantizedLinear`] with an existing [`Linear`]
    pub fn build_with_linear(self, other: Linear) -> Result<QuantizedLinear, Exception> {
        self.build_with_weight_and_bias(other.weight.value, other.bias.value)
    }

    fn build_with_weight_and_bias(
        self,
        weight: Array,
        bias: Option<Array>,
    ) -> Result<QuantizedLinear, Exception> {
        build_quantized_linear_inner(weight, bias, self.group_size, self.bits)
    }
}

fn build_quantized_linear_inner(
    weight: Array,
    bias: Option<Array>,
    group_size: i32,
    bits: i32,
) -> Result<QuantizedLinear, Exception> {
    // Quantize the weight matrix into packed values plus per-group scales and biases.
    let (quantized_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;

    let inner = Linear {
        weight: Param::new(quantized_weight),
        bias: Param::new(bias),
    };

    let mut ql = QuantizedLinear {
        group_size,
        bits,
        scales: Param::new(scales),
        biases: Param::new(biases),
        inner,
    };

    // Freeze all parameters: quantized weights take no part in gradient computation.
    ql.freeze_parameters(true);

    Ok(ql)
}

/// Builds a new [`QuantizedLinear`]
pub fn build_quantized_linear(
    builder: QuantizedLinearBuilder,
) -> Result<QuantizedLinear, Exception> {
    let input_dims = builder.input_dims;
    let output_dims = builder.output_dims;

    // Initialize the weight uniformly in [-1/sqrt(input_dims), 1/sqrt(input_dims)].
    let scale = f32::sqrt(1.0 / (input_dims as f32));
    let weight = uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;

    let bias = if builder.bias {
        Some(zeros::<f32>(&[output_dims])?)
    } else {
        None
    };

    builder.build_with_weight_and_bias(weight, bias)
}

/// Applies an affine transformation to the input using a quantized weight matrix.
///
/// It is the quantized equivalent of [`Linear`]. For now its parameters are
/// frozen and will not be included in any gradient computation, but this will
/// probably change in the future.
///
/// [`QuantizedLinear`] also provides several useful methods to convert
/// [`Linear`] layers to [`QuantizedLinear`] layers.
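///
/// A minimal sketch of a forward pass, assuming a mutable `ql` built as above
/// and an input `x` of shape `[.., input_dims]`:
///
/// ```rust,ignore
/// let y = ql.forward(&x)?; // shape [.., output_dims]
/// ```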
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedLinear {
    /// Quantization group size. Defaults to [`QuantizedLinear::DEFAULT_GROUP_SIZE`]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedLinear::DEFAULT_BITS`]
    pub bits: i32,

    /// Scales
    #[param]
    pub scales: Param<Array>,

    /// Biases
    #[param]
    pub biases: Param<Array>,

    /// Inner linear layer
    #[param]
    pub inner: Linear,
}

impl QuantizedLinear {
    /// Default group size
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default bits
    pub const DEFAULT_BITS: i32 = 4;

    /// Convert a linear layer to a quantized linear layer.
    ///
    /// # Params
    ///
    /// - `linear`: The linear layer to convert.
    /// - `group_size`: The group size to use for the quantized weight. Defaults to [`QuantizedLinear::DEFAULT_GROUP_SIZE`]
    /// - `bits`: The bit width to use for the quantized weight. Defaults to [`QuantizedLinear::DEFAULT_BITS`]
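    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `linear` is an existing [`Linear`]:
    ///
    /// ```rust,ignore
    /// // Defaults; equivalently `QuantizedLinear::try_from(linear)?`.
    /// let ql = QuantizedLinear::try_from_linear(linear, None, None)?;
    /// // Explicit group size and bit width.
    /// let ql8 = QuantizedLinear::try_from_linear(other_linear, 32, 8)?;
    /// ```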
    pub fn try_from_linear(
        linear: Linear,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        let group_size = group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE);
        let bits = bits.into().unwrap_or(Self::DEFAULT_BITS);
        build_quantized_linear_inner(linear.weight.value, linear.bias.value, group_size, bits)
    }
}

impl TryFrom<Linear> for QuantizedLinear {
    type Error = Exception;

    fn try_from(linear: Linear) -> Result<Self, Self::Error> {
        Self::try_from_linear(linear, None, None)
    }
}

impl Module<&Array> for QuantizedLinear {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        // Multiply against the quantized weight (transposed), dequantizing on
        // the fly with the stored scales and biases.
        let mut x = quantized_matmul(
            x,
            &self.inner.weight,
            &self.scales,
            &self.biases,
            true,
            self.group_size,
            self.bits,
        )?;
        // The linear bias, if present, is kept in full precision.
        if let Some(bias) = &self.inner.bias.value {
            x = x.add(bias)?;
        }
        Ok(x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.inner.training_mode(mode);
    }
}