1use std::iter::once;
2
3use crate::{
4 array,
5 error::Exception,
6 module::{Module, ModuleParameters, Param},
7 ops::indexing::IndexOp,
8 ops::{self, dequantize, quantized_matmul, zeros},
9 quantization::Quantizable,
10 random::uniform,
11 Array,
12};
13use mlx_internal_macros::{Buildable, Builder};
14use mlx_macros::ModuleParameters;
15
16use crate::nn::{Embedding, Linear};
17
18pub fn quantize<M>(
26 module: M,
27 group_size: impl Into<Option<i32>>,
28 bits: impl Into<Option<i32>>,
29) -> Result<M::Quantized, M::QuantizationError>
30where
31 M: Quantizable,
32{
33 let group_size = group_size.into().unwrap_or(M::DEFAULT_GROUP_SIZE);
34 let bits = bits.into().unwrap_or(M::DEFAULT_BITS);
35 module.try_into_quantized(group_size, bits)
36}
37
/// Builder for [`QuantizedEmbedding`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_embedding,
    err = Exception,
)]
pub struct QuantizedEmbeddingBuilder {
    /// Number of embeddings (vocabulary size).
    pub embedding_count: i32,

    /// Dimensionality of each embedding vector.
    pub dimensions: i32,

    /// Quantization group size. Default: [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`].
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per quantized parameter. Default: [`QuantizedEmbedding::DEFAULT_BITS`].
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_BITS)]
    pub bits: i32,
}
60
/// An embedding layer whose weight table is stored in quantized form.
///
/// Construction (see `build_quantized_embedding_inner`) freezes all
/// parameters, since the quantized representation is not directly trainable.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedEmbedding {
    /// Quantization group size used for the stored weight.
    pub group_size: i32,

    /// Number of bits per quantized parameter.
    pub bits: i32,

    /// Per-group scales produced by [`ops::quantize`].
    #[param]
    pub scales: Param<Array>,

    /// Per-group biases produced by [`ops::quantize`].
    #[param]
    pub biases: Param<Array>,

    /// Inner [`Embedding`] holding the quantized weight matrix.
    #[param]
    pub inner: Embedding,
}
84
85impl QuantizedEmbeddingBuilder {
86 pub fn build_with_embedding(
88 self,
89 embedding: Embedding,
90 ) -> Result<QuantizedEmbedding, Exception> {
91 let weight = embedding.weight.value;
92 self.build_with_weight(weight)
93 }
94
95 pub fn build_with_weight(self, weight: Array) -> Result<QuantizedEmbedding, Exception> {
97 let group_size = self.group_size;
98 let bits = self.bits;
99 build_quantized_embedding_inner(weight, group_size, bits)
100 }
101}
102
103fn build_quantized_embedding_inner(
104 weight: Array,
105 group_size: i32,
106 bits: i32,
107) -> Result<QuantizedEmbedding, Exception> {
108 let (quantized_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;
109
110 let inner = Embedding {
111 weight: Param::new(quantized_weight),
112 };
113
114 let mut qe = QuantizedEmbedding {
115 group_size,
116 bits,
117 scales: Param::new(scales),
118 biases: Param::new(biases),
119 inner,
120 };
121
122 qe.freeze_parameters(true);
124
125 Ok(qe)
126}
127
128fn build_quantized_embedding(
129 builder: QuantizedEmbeddingBuilder,
130) -> Result<QuantizedEmbedding, Exception> {
131 let embedding_count = builder.embedding_count;
132 let dims = builder.dimensions;
133
134 let scale = array!(f32::sqrt(1.0 / (dims as f32)));
135 let weight = crate::random::normal::<f32>(&[embedding_count, dims], None, None, None)? * &scale;
137
138 builder.build_with_weight(weight)
139}
140
impl QuantizedEmbedding {
    /// Default quantization group size.
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default number of bits per quantized parameter.
    pub const DEFAULT_BITS: i32 = 4;

    /// Quantize an existing [`Embedding`] into a [`QuantizedEmbedding`].
    ///
    /// `group_size` / `bits` fall back to [`Self::DEFAULT_GROUP_SIZE`] /
    /// [`Self::DEFAULT_BITS`] when `None`. Consumes the embedding and
    /// quantizes its weight; the result has all parameters frozen.
    pub fn try_from_embedding(
        embedding: Embedding,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        let group_size = group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE);
        let bits = bits.into().unwrap_or(Self::DEFAULT_BITS);
        build_quantized_embedding_inner(embedding.weight.value, group_size, bits)
    }

    /// Use the quantized embedding table as a linear layer on `x`
    /// (quantized matmul against the table with `transpose = true`),
    /// e.g. for tied input/output embeddings.
    pub fn as_linear(&self, x: impl AsRef<Array>) -> Result<Array, Exception> {
        quantized_matmul(
            x.as_ref(),
            &self.inner.weight,
            &self.scales,
            self.biases.as_ref(),
            true,
            self.group_size,
            self.bits,
        )
    }
}
181
182impl TryFrom<Embedding> for QuantizedEmbedding {
183 type Error = Exception;
184
185 fn try_from(embedding: Embedding) -> Result<Self, Self::Error> {
186 Self::try_from_embedding(embedding, None, None)
187 }
188}
189
190impl Module<&Array> for QuantizedEmbedding {
191 type Error = Exception;
192 type Output = Array;
193
194 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
195 let s = x.shape();
196 let x = x.flatten(None, None)?;
197 let w = self.inner.weight.index(&x);
198 let scales = self.scales.index(&x);
199 let biases = self.biases.index(&x);
200
201 let out = dequantize(&w, &scales, &biases, self.group_size, self.bits)?;
202
203 let ret_shape = s.iter().copied().chain(once(-1)).collect::<Vec<_>>();
204 out.reshape(&ret_shape)
205 }
206
207 fn training_mode(&mut self, mode: bool) {
208 self.inner.training_mode(mode);
209 }
210}
211
/// Builder for [`QuantizedLinear`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_linear,
    err = Exception,
)]
pub struct QuantizedLinearBuilder {
    /// Number of input features.
    pub input_dims: i32,

    /// Number of output features.
    pub output_dims: i32,

    /// Quantization group size. Default: [`QuantizedLinear::DEFAULT_GROUP_SIZE`].
    #[builder(optional, default = QuantizedLinear::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per quantized parameter. Default: [`QuantizedLinear::DEFAULT_BITS`].
    #[builder(optional, default = QuantizedLinear::DEFAULT_BITS)]
    pub bits: i32,

    /// Whether to include a bias term. Default: [`Linear::DEFAULT_BIAS`].
    #[builder(optional, default = Linear::DEFAULT_BIAS)]
    pub bias: bool,
}
238
239impl QuantizedLinearBuilder {
240 pub fn build_with_linear(self, other: Linear) -> Result<QuantizedLinear, Exception> {
242 self.build_with_weight_and_bias(other.weight.value, other.bias.value)
243 }
244
245 fn build_with_weight_and_bias(
246 self,
247 weight: Array,
248 bias: Option<Array>,
249 ) -> Result<QuantizedLinear, Exception> {
250 build_quantized_linear_inner(weight, bias, self.group_size, self.bits)
251 }
252}
253
254fn build_quantized_linear_inner(
255 weight: Array,
256 bias: Option<Array>,
257 group_size: i32,
258 bits: i32,
259) -> Result<QuantizedLinear, Exception> {
260 let (quantized_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;
261
262 let inner = Linear {
263 weight: Param::new(quantized_weight),
264 bias: Param::new(bias),
265 };
266
267 let mut ql = QuantizedLinear {
268 group_size,
269 bits,
270 scales: Param::new(scales),
271 biases: Param::new(biases),
272 inner,
273 };
274
275 ql.freeze_parameters(true);
277
278 Ok(ql)
279}
280
281pub fn build_quantized_linear(
283 builder: QuantizedLinearBuilder,
284) -> Result<QuantizedLinear, Exception> {
285 let input_dims = builder.input_dims;
286 let output_dims = builder.output_dims;
287 let scale = f32::sqrt(1.0 / (input_dims as f32));
288 let weight = uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;
289
290 let bias = if builder.bias {
291 Some(zeros::<f32>(&[output_dims])?)
292 } else {
293 None
294 };
295
296 builder.build_with_weight_and_bias(weight, bias)
297}
298
/// A linear layer whose weight matrix is stored in quantized form.
///
/// Construction (see `build_quantized_linear_inner`) freezes all
/// parameters, since the quantized representation is not directly trainable.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedLinear {
    /// Quantization group size used for the stored weight.
    pub group_size: i32,

    /// Number of bits per quantized parameter.
    pub bits: i32,

    /// Per-group scales produced by [`ops::quantize`].
    #[param]
    pub scales: Param<Array>,

    /// Per-group biases produced by [`ops::quantize`].
    #[param]
    pub biases: Param<Array>,

    /// Inner [`Linear`] holding the quantized weight and the (unquantized)
    /// optional bias.
    #[param]
    pub inner: Linear,
}
329
330impl QuantizedLinear {
331 pub const DEFAULT_GROUP_SIZE: i32 = 64;
333
334 pub const DEFAULT_BITS: i32 = 4;
336
337 pub fn try_from_linear(
345 linear: Linear,
346 group_size: impl Into<Option<i32>>,
347 bits: impl Into<Option<i32>>,
348 ) -> Result<Self, Exception> {
349 let group_size = group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE);
350 let bits = bits.into().unwrap_or(Self::DEFAULT_BITS);
351 build_quantized_linear_inner(linear.weight.value, linear.bias.value, group_size, bits)
352 }
353}
354
355impl TryFrom<Linear> for QuantizedLinear {
356 type Error = Exception;
357
358 fn try_from(linear: Linear) -> Result<Self, Self::Error> {
359 Self::try_from_linear(linear, None, None)
360 }
361}
362
363impl Module<&Array> for QuantizedLinear {
364 type Error = Exception;
365 type Output = Array;
366
367 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
368 let mut x = quantized_matmul(
369 x,
370 &self.inner.weight,
371 &self.scales,
372 self.biases.as_ref(),
373 true,
374 self.group_size,
375 self.bits,
376 )?;
377 if let Some(bias) = &self.inner.bias.value {
378 x = x.add(bias)?;
379 }
380 Ok(x)
381 }
382
383 fn training_mode(&mut self, mode: bool) {
384 self.inner.training_mode(mode);
385 }
386}