mlx_rs/ops/
quantization.rs1use mlx_internal_macros::{default_device, generate_macro};
2
3use crate::{error::Result, utils::guard::Guarded, Array, Stream, StreamOrDevice};
4
5#[generate_macro]
23#[default_device]
24pub fn quantize_device(
25 w: impl AsRef<Array>,
26 #[optional] group_size: impl Into<Option<i32>>,
27 #[optional] bits: impl Into<Option<i32>>,
28 #[optional] stream: impl AsRef<Stream>,
29) -> Result<(Array, Array, Array)> {
30 let group_size = group_size.into().unwrap_or(64);
31 let bits = bits.into().unwrap_or(4);
32
33 <(Array, Array, Array) as Guarded>::try_from_op(|(res0, res1, res2)| unsafe {
34 mlx_sys::mlx_quantize(
35 res0,
36 res1,
37 res2,
38 w.as_ref().as_ptr(),
39 group_size,
40 bits,
41 stream.as_ref().as_ptr(),
42 )
43 })
44}
45
46#[allow(clippy::too_many_arguments)]
50#[generate_macro]
51#[default_device]
52pub fn quantized_matmul_device(
53 x: impl AsRef<Array>,
54 w: impl AsRef<Array>,
55 scales: impl AsRef<Array>,
56 biases: impl AsRef<Array>,
57 #[optional] transpose: impl Into<Option<bool>>,
58 #[optional] group_size: impl Into<Option<i32>>,
59 #[optional] bits: impl Into<Option<i32>>,
60 #[optional] stream: impl AsRef<Stream>,
61) -> Result<Array> {
62 let transpose = transpose.into().unwrap_or(false);
63 let group_size = group_size.into().unwrap_or(64);
64 let bits = bits.into().unwrap_or(4);
65
66 <Array as Guarded>::try_from_op(|res| unsafe {
67 mlx_sys::mlx_quantized_matmul(
68 res,
69 x.as_ref().as_ptr(),
70 w.as_ref().as_ptr(),
71 scales.as_ref().as_ptr(),
72 biases.as_ref().as_ptr(),
73 transpose,
74 group_size,
75 bits,
76 stream.as_ref().as_ptr(),
77 )
78 })
79}
80
81#[generate_macro]
87#[default_device]
88pub fn dequantize_device(
89 w: impl AsRef<Array>,
90 scales: impl AsRef<Array>,
91 biases: impl AsRef<Array>,
92 #[optional] group_size: impl Into<Option<i32>>,
93 #[optional] bits: impl Into<Option<i32>>,
94 #[optional] stream: impl AsRef<Stream>,
95) -> Result<Array> {
96 let group_size = group_size.into().unwrap_or(64);
97 let bits = bits.into().unwrap_or(4);
98
99 <Array as Guarded>::try_from_op(|res| unsafe {
100 mlx_sys::mlx_dequantize(
101 res,
102 w.as_ref().as_ptr(),
103 scales.as_ref().as_ptr(),
104 biases.as_ref().as_ptr(),
105 group_size,
106 bits,
107 stream.as_ref().as_ptr(),
108 )
109 })
110}
111
#[cfg(test)]
mod tests {
    use crate::{
        ops::{dequantize, expand_dims, quantize},
        Array,
    };

    /// Number of rows in the test matrix.
    const ROWS: i32 = 128;
    /// Number of columns in the test matrix; every row holds the values `0..COLS`.
    const COLS: i32 = 512;

    /// Builds a `ROWS x COLS` matrix whose every row is `0, 1, ..., COLS - 1`.
    fn ramp_matrix() -> Array {
        let ones = Array::ones::<f32>(&[ROWS, 1]).unwrap();
        let ramp = expand_dims(Array::arange::<_, f32>(0, COLS, None).unwrap(), &[0]).unwrap();
        ones * ramp
    }

    /// Asserts the shapes `quantize` should produce for a `ROWS x COLS` input.
    fn assert_quantized_shapes(
        x_q: &Array,
        scales: &Array,
        biases: &Array,
        group_size: i32,
        bits: i32,
    ) {
        // Quantized values are packed into 32-bit words.
        let elements_per_word = 32 / bits;
        let groups_per_row = COLS / group_size;
        assert_eq!(x_q.shape(), [ROWS, COLS / elements_per_word]);
        assert_eq!(scales.shape(), [ROWS, groups_per_row]);
        assert_eq!(biases.shape(), [ROWS, groups_per_row]);
    }

    /// Asserts that the dequantized `x_hat` stays close to the ramp matrix `x`.
    fn assert_round_trip_error(x: &Array, x_hat: &Array, group_size: i32, bits: i32) {
        // Each group of the ramp spans `group_size - 1` units. Assuming an affine scheme with a
        // step of about `range / (2^bits - 1)`, the rounding error is at most about half a step,
        // so `range / 2^bits` is a loose upper bound.
        let group_range = (group_size - 1) as f32;
        let max_diff = (x - x_hat)
            .abs()
            .unwrap()
            .max(None, None)
            .unwrap()
            .item::<f32>();
        assert!(
            max_diff <= group_range / (1 << bits) as f32,
            "max_diff {max_diff} too large for group_size {group_size}, bits {bits}"
        );
    }

    #[test]
    fn test_quantize_dequantize() {
        const GROUP_SIZE: i32 = 128;
        let x = ramp_matrix();

        for bits in [2, 4, 8] {
            let (x_q, scales, biases) = quantize(&x, GROUP_SIZE, bits).unwrap();
            assert_quantized_shapes(&x_q, &scales, &biases, GROUP_SIZE, bits);

            let x_hat = dequantize(&x_q, &scales, &biases, GROUP_SIZE, bits).unwrap();
            assert_round_trip_error(&x, &x_hat, GROUP_SIZE, bits);
        }
    }

    #[test]
    fn test_quantize_dequantize_default_params() {
        // `None` selects the wrappers' defaults: group_size 64, 4 bits.
        let x = ramp_matrix();

        let (x_q, scales, biases) = quantize(&x, None, None).unwrap();
        assert_quantized_shapes(&x_q, &scales, &biases, 64, 4);

        let x_hat = dequantize(&x_q, &scales, &biases, None, None).unwrap();
        assert_round_trip_error(&x, &x_hat, 64, 4);
    }
}