use std::iter::once;

use crate::{
    array,
    error::Exception,
    module::{Module, ModuleParameters, Param},
    ops::indexing::IndexOp,
    ops::{self, dequantize, quantized_matmul, zeros},
    quantization::Quantizable,
    random::uniform,
    Array,
};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

use crate::nn::{Embedding, Linear};

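/// Quantize a module that implements [`Quantizable`], falling back to the
/// module's default group size and bit width when `None` is passed.
///
/// A minimal usage sketch (illustrative only; it assumes `Linear` implements
/// [`Quantizable`] with `Quantized = QuantizedLinear`):
///
/// ```rust,ignore
/// // `linear` is an existing `nn::Linear` built elsewhere.
/// let quantized: QuantizedLinear = quantize(linear, None, None)?;
/// ```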
pub fn quantize<M>(
    module: M,
    group_size: impl Into<Option<i32>>,
    bits: impl Into<Option<i32>>,
) -> Result<M::Quantized, M::QuantizationError>
where
    M: Quantizable,
{
    let group_size = group_size.into().unwrap_or(M::DEFAULT_GROUP_SIZE);
    let bits = bits.into().unwrap_or(M::DEFAULT_BITS);
    module.try_into_quantized(group_size, bits)
}

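/// Builder for [`QuantizedEmbedding`].
///
/// A minimal usage sketch (illustrative only; `existing_weight` is a
/// placeholder for an embedding weight matrix you already have):
///
/// ```rust,ignore
/// let qe = QuantizedEmbeddingBuilder {
///     embedding_count: 1000,
///     dimensions: 64,
///     group_size: QuantizedEmbedding::DEFAULT_GROUP_SIZE,
///     bits: QuantizedEmbedding::DEFAULT_BITS,
/// }
/// .build_with_weight(existing_weight)?;
/// ```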
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_embedding,
    err = Exception,
)]
pub struct QuantizedEmbeddingBuilder {
    /// Number of embeddings (vocabulary size)
    pub embedding_count: i32,

    /// Dimensionality of each embedding vector
    pub dimensions: i32,

    /// Quantization group size. Defaults to [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`]
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedEmbedding::DEFAULT_BITS`]
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_BITS)]
    pub bits: i32,
}

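/// An embedding layer that stores its weight in quantized form and
/// dequantizes the looked-up rows on the fly.
///
/// A minimal usage sketch (illustrative only; `qe` is a built
/// `QuantizedEmbedding` and the indices are constructed with `Array::from_slice`):
///
/// ```rust,ignore
/// let indices = Array::from_slice(&[0i32, 5, 9], &[3]);
/// let embedded = qe.forward(&indices)?; // shape: [3, dimensions]
/// ```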
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedEmbedding {
    /// Quantization group size
    pub group_size: i32,

    /// Bits per parameter
    pub bits: i32,

    /// Quantization scales
    #[param]
    pub scales: Param<Array>,

    /// Quantization biases
    #[param]
    pub biases: Param<Array>,

    /// Inner embedding holding the quantized weight
    #[param]
    pub inner: Embedding,
}

impl QuantizedEmbeddingBuilder {
    /// Convenience method to build a [`QuantizedEmbedding`] from an existing
    /// [`Embedding`], quantizing its weight.
    pub fn build_with_embedding(
        self,
        embedding: Embedding,
    ) -> Result<QuantizedEmbedding, Exception> {
        let weight = embedding.weight.value;
        self.build_with_weight(weight)
    }

    /// Convenience method to build a [`QuantizedEmbedding`] from an existing
    /// weight matrix.
    pub fn build_with_weight(self, weight: Array) -> Result<QuantizedEmbedding, Exception> {
        let group_size = self.group_size;
        let bits = self.bits;
        build_quantized_embedding_inner(weight, group_size, bits)
    }
}

fn build_quantized_embedding_inner(
    weight: Array,
    group_size: i32,
    bits: i32,
) -> Result<QuantizedEmbedding, Exception> {
    // Quantize the full-precision weight into packed weights, scales, and biases
    let (quantized_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;

    let inner = Embedding {
        weight: Param::new(quantized_weight),
    };

    let mut qe = QuantizedEmbedding {
        group_size,
        bits,
        scales: Param::new(scales),
        biases: Param::new(biases),
        inner,
    };

    // Quantized parameters are not trainable, so freeze them by default
    qe.freeze_parameters(true);

    Ok(qe)
}

fn build_quantized_embedding(
    builder: QuantizedEmbeddingBuilder,
) -> Result<QuantizedEmbedding, Exception> {
    let embedding_count = builder.embedding_count;
    let dims = builder.dimensions;

    // Initialize the weight from a normal distribution scaled by 1/sqrt(dims)
    let scale = array!(f32::sqrt(1.0 / (dims as f32)));
    let weight = crate::random::normal::<f32>(&[embedding_count, dims], None, None, None)? * &scale;

    builder.build_with_weight(weight)
}

impl QuantizedEmbedding {
    /// Default group size
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default bits per parameter
    pub const DEFAULT_BITS: i32 = 4;

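    /// Convert an [`Embedding`] into a [`QuantizedEmbedding`], optionally
    /// overriding the group size and bit width.
    ///
    /// A minimal usage sketch (illustrative only; `embedding` is an existing
    /// `nn::Embedding`):
    ///
    /// ```rust,ignore
    /// let qe = QuantizedEmbedding::try_from_embedding(embedding, 32, 8)?;
    /// ```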
    pub fn try_from_embedding(
        embedding: Embedding,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        let group_size = group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE);
        let bits = bits.into().unwrap_or(Self::DEFAULT_BITS);
        build_quantized_embedding_inner(embedding.weight.value, group_size, bits)
    }

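    /// Apply the quantized embedding weight as a (transposed) linear
    /// projection. This is commonly used to tie a language model's output
    /// projection to its input embedding.
    ///
    /// A minimal usage sketch (illustrative only; `hidden` is a placeholder
    /// activation of shape `[.., dimensions]`):
    ///
    /// ```rust,ignore
    /// let logits = qe.as_linear(&hidden)?; // shape: [.., embedding_count]
    /// ```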
    pub fn as_linear(&self, x: impl AsRef<Array>) -> Result<Array, Exception> {
        quantized_matmul(
            x.as_ref(),
            &self.inner.weight,
            &self.scales,
            &self.biases,
            true,
            self.group_size,
            self.bits,
        )
    }
}

impl TryFrom<Embedding> for QuantizedEmbedding {
    type Error = Exception;

    fn try_from(embedding: Embedding) -> Result<Self, Self::Error> {
        Self::try_from_embedding(embedding, None, None)
    }
}

impl Module<&Array> for QuantizedEmbedding {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let s = x.shape();
        let x = x.flatten(None, None)?;

        // Gather the quantized rows, scales, and biases for the requested indices
        let w = self.inner.weight.index(&x);
        let scales = self.scales.index(&x);
        let biases = self.biases.index(&x);

        // Dequantize only the gathered rows
        let out = dequantize(&w, &scales, &biases, self.group_size, self.bits)?;

        // Restore the original index shape with the embedding dimension appended
        let ret_shape = s.iter().copied().chain(once(-1)).collect::<Vec<_>>();
        out.reshape(&ret_shape)
    }

    fn training_mode(&mut self, mode: bool) {
        self.inner.training_mode(mode);
    }
}

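/// Builder for [`QuantizedLinear`].
///
/// A minimal usage sketch (illustrative only; `existing_linear` is a
/// placeholder for an `nn::Linear` you already have):
///
/// ```rust,ignore
/// let ql = QuantizedLinearBuilder {
///     input_dims: 256,
///     output_dims: 512,
///     group_size: QuantizedLinear::DEFAULT_GROUP_SIZE,
///     bits: QuantizedLinear::DEFAULT_BITS,
///     bias: true,
/// }
/// .build_with_linear(existing_linear)?;
/// ```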
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_linear,
    err = Exception,
)]
pub struct QuantizedLinearBuilder {
    /// Number of input dimensions
    pub input_dims: i32,

    /// Number of output dimensions
    pub output_dims: i32,

    /// Quantization group size. Defaults to [`QuantizedLinear::DEFAULT_GROUP_SIZE`]
    #[builder(optional, default = QuantizedLinear::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per parameter. Defaults to [`QuantizedLinear::DEFAULT_BITS`]
    #[builder(optional, default = QuantizedLinear::DEFAULT_BITS)]
    pub bits: i32,

    /// Whether to include a bias. Defaults to [`Linear::DEFAULT_BIAS`]
    #[builder(optional, default = Linear::DEFAULT_BIAS)]
    pub bias: bool,
}

impl QuantizedLinearBuilder {
    /// Convenience method to build a [`QuantizedLinear`] from an existing
    /// [`Linear`], quantizing its weight and reusing its bias.
    pub fn build_with_linear(self, other: Linear) -> Result<QuantizedLinear, Exception> {
        self.build_with_weight_and_bias(other.weight.value, other.bias.value)
    }

    fn build_with_weight_and_bias(
        self,
        weight: Array,
        bias: Option<Array>,
    ) -> Result<QuantizedLinear, Exception> {
        build_quantized_linear_inner(weight, bias, self.group_size, self.bits)
    }
}

fn build_quantized_linear_inner(
    weight: Array,
    bias: Option<Array>,
    group_size: i32,
    bits: i32,
) -> Result<QuantizedLinear, Exception> {
    // Quantize the full-precision weight into packed weights, scales, and biases
    let (quantized_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;

    let inner = Linear {
        weight: Param::new(quantized_weight),
        bias: Param::new(bias),
    };

    let mut ql = QuantizedLinear {
        group_size,
        bits,
        scales: Param::new(scales),
        biases: Param::new(biases),
        inner,
    };

    // Quantized parameters are not trainable, so freeze them by default
    ql.freeze_parameters(true);

    Ok(ql)
}

pub fn build_quantized_linear(
    builder: QuantizedLinearBuilder,
) -> Result<QuantizedLinear, Exception> {
    let input_dims = builder.input_dims;
    let output_dims = builder.output_dims;

    // Initialize the weight uniformly in [-1/sqrt(input_dims), 1/sqrt(input_dims)]
    let scale = f32::sqrt(1.0 / (input_dims as f32));
    let weight = uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;

    let bias = if builder.bias {
        Some(zeros::<f32>(&[output_dims])?)
    } else {
        None
    };

    builder.build_with_weight_and_bias(weight, bias)
}

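/// Applies an affine transformation using a quantized weight matrix.
///
/// A minimal usage sketch (illustrative only; `linear` is an existing
/// `nn::Linear` and `x` an input activation):
///
/// ```rust,ignore
/// let mut ql = QuantizedLinear::try_from(linear)?; // default group size and bits
/// let y = ql.forward(&x)?;
/// ```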
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedLinear {
    /// Quantization group size
    pub group_size: i32,

    /// Bits per parameter
    pub bits: i32,

    /// Quantization scales
    #[param]
    pub scales: Param<Array>,

    /// Quantization biases
    #[param]
    pub biases: Param<Array>,

    /// Inner linear layer holding the quantized weight and optional bias
    #[param]
    pub inner: Linear,
}

impl QuantizedLinear {
    /// Default group size
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default bits per parameter
    pub const DEFAULT_BITS: i32 = 4;

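    /// Convert a [`Linear`] into a [`QuantizedLinear`], optionally overriding
    /// the group size and bit width.
    ///
    /// A minimal usage sketch (illustrative only; `linear` is an existing
    /// `nn::Linear`):
    ///
    /// ```rust,ignore
    /// let ql = QuantizedLinear::try_from_linear(linear, 64, 4)?;
    /// ```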
    pub fn try_from_linear(
        linear: Linear,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        let group_size = group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE);
        let bits = bits.into().unwrap_or(Self::DEFAULT_BITS);
        build_quantized_linear_inner(linear.weight.value, linear.bias.value, group_size, bits)
    }
}

impl TryFrom<Linear> for QuantizedLinear {
    type Error = Exception;

    fn try_from(linear: Linear) -> Result<Self, Self::Error> {
        Self::try_from_linear(linear, None, None)
    }
}

impl Module<&Array> for QuantizedLinear {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        // Matrix multiply against the quantized (transposed) weight
        let mut x = quantized_matmul(
            x,
            &self.inner.weight,
            &self.scales,
            &self.biases,
            true,
            self.group_size,
            self.bits,
        )?;
        // The bias, if present, is kept in full precision
        if let Some(bias) = &self.inner.bias.value {
            x = x.add(bias)?;
        }
        Ok(x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.inner.training_mode(mode);
    }
}