mlx_rs/nn/
quantized.rs

use std::iter::once;

use crate::{
    array,
    error::Exception,
    module::{Module, ModuleParameters, Param},
    ops::indexing::IndexOp,
    ops::{self, dequantize, quantized_matmul, zeros},
    quantization::Quantizable,
    random::uniform,
    Array,
};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

use crate::nn::{Embedding, Linear};

/// Quantize a module.
///
/// # Params
///
/// - `module`: The module to quantize.
/// - `group_size`: The group size to use for the quantized weight. Defaults to [`Quantizable::DEFAULT_GROUP_SIZE`]
/// - `bits`: The bit width to use for the quantized weight. Defaults to [`Quantizable::DEFAULT_BITS`]
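///
/// # Example
///
/// A minimal sketch, assuming a `linear: Linear` value has already been built
/// and that [`Linear`] implements [`Quantizable`] with
/// `Quantized = QuantizedLinear`:
///
/// ```rust,ignore
/// // Quantize with the default group size and bit width.
/// let quantized = quantize(linear, None, None)?;
///
/// // Or with explicit settings: 32-element groups, 8 bits per weight.
/// let quantized = quantize(linear, 32, 8)?;
/// ```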
pub fn quantize<M>(
    module: M,
    group_size: impl Into<Option<i32>>,
    bits: impl Into<Option<i32>>,
) -> Result<M::Quantized, M::QuantizationError>
where
    M: Quantizable,
{
    let group_size = group_size.into().unwrap_or(M::DEFAULT_GROUP_SIZE);
    let bits = bits.into().unwrap_or(M::DEFAULT_BITS);
    module.try_into_quantized(group_size, bits)
}

/// Builder for [`QuantizedEmbedding`]
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_embedding,
    err = Exception,
)]
pub struct QuantizedEmbeddingBuilder {
    /// The number of possible discrete tokens that can be embedded. Usually called the vocabulary size.
    pub embedding_count: i32,

    /// The dimensionality of the embeddings.
    pub dimensions: i32,

    /// Quantization group size. Defaults to [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`]
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedEmbedding::DEFAULT_BITS`]
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_BITS)]
    pub bits: i32,
}

/// The same as [`Embedding`] but with a quantized weight matrix.
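///
/// # Example
///
/// A minimal sketch, assuming an `embedding: Embedding` value already exists;
/// the conversion uses the [`TryFrom<Embedding>`] impl defined below:
///
/// ```rust,ignore
/// // Quantize an existing embedding layer with the default settings.
/// let quantized = QuantizedEmbedding::try_from(embedding)?;
/// ```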
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedEmbedding {
    /// Quantization group size. Defaults to [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedEmbedding::DEFAULT_BITS`]
    pub bits: i32,

    /// Scales
    #[param]
    pub scales: Param<Array>,

    /// Biases
    #[param]
    pub biases: Param<Array>,

    /// Inner embedding
    #[param]
    pub inner: Embedding,
}

impl QuantizedEmbeddingBuilder {
    /// Convenience method to build a new [`QuantizedEmbedding`] with an existing [`Embedding`]
    pub fn build_with_embedding(
        self,
        embedding: Embedding,
    ) -> Result<QuantizedEmbedding, Exception> {
        let weight = embedding.weight.value;
        self.build_with_weight(weight)
    }

    /// Convenience method to build a new [`QuantizedEmbedding`] with an existing weight matrix
    pub fn build_with_weight(self, weight: Array) -> Result<QuantizedEmbedding, Exception> {
        let group_size = self.group_size;
        let bits = self.bits;
        build_quantized_embedding_inner(weight, group_size, bits)
    }
}

fn build_quantized_embedding_inner(
    weight: Array,
    group_size: i32,
    bits: i32,
) -> Result<QuantizedEmbedding, Exception> {
    let (quantized_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;

    let inner = Embedding {
        weight: Param::new(quantized_weight),
    };

    let mut qe = QuantizedEmbedding {
        group_size,
        bits,
        scales: Param::new(scales),
        biases: Param::new(biases),
        inner,
    };

    // Freeze all parameters
    qe.freeze_parameters(true);

    Ok(qe)
}

fn build_quantized_embedding(
    builder: QuantizedEmbeddingBuilder,
) -> Result<QuantizedEmbedding, Exception> {
    let embedding_count = builder.embedding_count;
    let dims = builder.dimensions;

    let scale = array!(f32::sqrt(1.0 / (dims as f32)));
    // `scale` is a single-element array, so the multiplication broadcasts over
    // the weight and cannot fail.
    let weight = crate::random::normal::<f32>(&[embedding_count, dims], None, None, None)? * &scale;

    builder.build_with_weight(weight)
}

impl QuantizedEmbedding {
    /// Default group size
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default bits
    pub const DEFAULT_BITS: i32 = 4;

    /// Convert an embedding layer to a quantized embedding layer.
    ///
    /// # Params
    ///
    /// - `embedding`: The embedding layer to convert.
    /// - `group_size`: The group size to use for the quantized weight. Defaults to [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`]
    /// - `bits`: The bit width to use for the quantized weight. Defaults to [`QuantizedEmbedding::DEFAULT_BITS`]
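    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an `embedding: Embedding` value already exists:
    ///
    /// ```rust,ignore
    /// // 32-element quantization groups, 8 bits per parameter.
    /// let qe = QuantizedEmbedding::try_from_embedding(embedding, 32, 8)?;
    ///
    /// // Pass `None` to fall back to the defaults.
    /// let qe = QuantizedEmbedding::try_from_embedding(embedding, None, None)?;
    /// ```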
    pub fn try_from_embedding(
        embedding: Embedding,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        let group_size = group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE);
        let bits = bits.into().unwrap_or(Self::DEFAULT_BITS);
        build_quantized_embedding_inner(embedding.weight.value, group_size, bits)
    }

    /// Call the embedding layer as a linear layer.
    ///
    /// Use this, for example, when the input embedding and output projection
    /// weights are tied.
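    ///
    /// # Example
    ///
    /// A minimal sketch of tied weights, assuming `qe` is a mutable
    /// [`QuantizedEmbedding`] and `tokens` and `hidden` are suitably shaped
    /// [`Array`]s:
    ///
    /// ```rust,ignore
    /// // Embed the input tokens...
    /// let embedded = qe.forward(&tokens)?;
    /// // ...and reuse the same quantized weight as the output projection.
    /// let logits = qe.as_linear(&hidden)?;
    /// ```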
    pub fn as_linear(&self, x: impl AsRef<Array>) -> Result<Array, Exception> {
        quantized_matmul(
            x.as_ref(),
            &self.inner.weight,
            &self.scales,
            self.biases.as_ref(),
            true,
            self.group_size,
            self.bits,
        )
    }
}

impl TryFrom<Embedding> for QuantizedEmbedding {
    type Error = Exception;

    fn try_from(embedding: Embedding) -> Result<Self, Self::Error> {
        Self::try_from_embedding(embedding, None, None)
    }
}

impl Module<&Array> for QuantizedEmbedding {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let s = x.shape();
        // Flatten the input indices, then gather the quantized rows along with
        // their per-group scales and biases.
        let x = x.flatten(None, None)?;
        let w = self.inner.weight.index(&x);
        let scales = self.scales.index(&x);
        let biases = self.biases.index(&x);

        // Dequantize only the gathered rows.
        let out = dequantize(&w, &scales, &biases, self.group_size, self.bits)?;

        // Restore the original input shape, with the embedding dimension last.
        let ret_shape = s.iter().copied().chain(once(-1)).collect::<Vec<_>>();
        out.reshape(&ret_shape)
    }

    fn training_mode(&mut self, mode: bool) {
        self.inner.training_mode(mode);
    }
}

/// Builder for [`QuantizedLinear`]
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_linear,
    err = Exception,
)]
pub struct QuantizedLinearBuilder {
    /// The dimensionality of the input features.
    pub input_dims: i32,

    /// The dimensionality of the output features.
    pub output_dims: i32,

    /// Quantization group size. Defaults to [`QuantizedLinear::DEFAULT_GROUP_SIZE`]
    #[builder(optional, default = QuantizedLinear::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedLinear::DEFAULT_BITS`]
    #[builder(optional, default = QuantizedLinear::DEFAULT_BITS)]
    pub bits: i32,

    /// Whether the linear layer has a bias. Defaults to [`Linear::DEFAULT_BIAS`]
    #[builder(optional, default = Linear::DEFAULT_BIAS)]
    pub bias: bool,
}

impl QuantizedLinearBuilder {
    /// Convenience method to build a new [`QuantizedLinear`] with an existing [`Linear`]
    pub fn build_with_linear(self, other: Linear) -> Result<QuantizedLinear, Exception> {
        self.build_with_weight_and_bias(other.weight.value, other.bias.value)
    }

    fn build_with_weight_and_bias(
        self,
        weight: Array,
        bias: Option<Array>,
    ) -> Result<QuantizedLinear, Exception> {
        build_quantized_linear_inner(weight, bias, self.group_size, self.bits)
    }
}

fn build_quantized_linear_inner(
    weight: Array,
    bias: Option<Array>,
    group_size: i32,
    bits: i32,
) -> Result<QuantizedLinear, Exception> {
    let (quantized_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;

    let inner = Linear {
        weight: Param::new(quantized_weight),
        bias: Param::new(bias),
    };

    let mut ql = QuantizedLinear {
        group_size,
        bits,
        scales: Param::new(scales),
        biases: Param::new(biases),
        inner,
    };

    // Freeze all parameters
    ql.freeze_parameters(true);

    Ok(ql)
}

/// Builds a new [`QuantizedLinear`]
pub fn build_quantized_linear(
    builder: QuantizedLinearBuilder,
) -> Result<QuantizedLinear, Exception> {
    let input_dims = builder.input_dims;
    let output_dims = builder.output_dims;
    let scale = f32::sqrt(1.0 / (input_dims as f32));
    let weight = uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;

    let bias = if builder.bias {
        Some(zeros::<f32>(&[output_dims])?)
    } else {
        None
    };

    builder.build_with_weight_and_bias(weight, bias)
}

/// Applies an affine transformation to the input using a quantized weight matrix.
///
/// It is the quantized equivalent of [`Linear`]. For now its parameters are
/// frozen and will not be included in any gradient computation, but this will
/// probably change in the future.
///
/// [`QuantizedLinear`] also provides several convenience methods for converting
/// [`Linear`] layers to [`QuantizedLinear`] layers.
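///
/// # Example
///
/// A minimal sketch, assuming a `linear: Linear` value already exists; the
/// conversion uses the [`TryFrom<Linear>`] impl defined below:
///
/// ```rust,ignore
/// // Quantize an existing linear layer with the default settings.
/// let quantized = QuantizedLinear::try_from(linear)?;
/// ```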
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedLinear {
    /// Quantization group size. Defaults to [`QuantizedLinear::DEFAULT_GROUP_SIZE`]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedLinear::DEFAULT_BITS`]
    pub bits: i32,

    /// Scales
    #[param]
    pub scales: Param<Array>,

    /// Biases
    #[param]
    pub biases: Param<Array>,

    /// Inner linear layer
    #[param]
    pub inner: Linear,
}

impl QuantizedLinear {
    /// Default group size
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default bits
    pub const DEFAULT_BITS: i32 = 4;

    /// Convert a linear layer to a quantized linear layer.
    ///
    /// # Params
    ///
    /// - `linear`: The linear layer to convert.
    /// - `group_size`: The group size to use for the quantized weight. Defaults to [`QuantizedLinear::DEFAULT_GROUP_SIZE`]
    /// - `bits`: The bit width to use for the quantized weight. Defaults to [`QuantizedLinear::DEFAULT_BITS`]
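    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a `linear: Linear` value already exists:
    ///
    /// ```rust,ignore
    /// // 128-element quantization groups, 4 bits per parameter.
    /// let ql = QuantizedLinear::try_from_linear(linear, 128, 4)?;
    ///
    /// // Pass `None` to fall back to the defaults.
    /// let ql = QuantizedLinear::try_from_linear(linear, None, None)?;
    /// ```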
    pub fn try_from_linear(
        linear: Linear,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        let group_size = group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE);
        let bits = bits.into().unwrap_or(Self::DEFAULT_BITS);
        build_quantized_linear_inner(linear.weight.value, linear.bias.value, group_size, bits)
    }
}

impl TryFrom<Linear> for QuantizedLinear {
    type Error = Exception;

    fn try_from(linear: Linear) -> Result<Self, Self::Error> {
        Self::try_from_linear(linear, None, None)
    }
}

impl Module<&Array> for QuantizedLinear {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        // Matrix-multiply the input directly against the quantized weight.
        let mut x = quantized_matmul(
            x,
            &self.inner.weight,
            &self.scales,
            self.biases.as_ref(),
            true,
            self.group_size,
            self.bits,
        )?;
        // Apply the (unquantized) bias, if any.
        if let Some(bias) = &self.inner.bias.value {
            x = x.add(bias)?;
        }
        Ok(x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.inner.training_mode(mode);
    }
}