// mlx_rs/ops/quantization.rs

use mlx_internal_macros::{default_device, generate_macro};

use crate::{error::Result, utils::guard::Guarded, Array, Stream, StreamOrDevice};

/// Quantize the matrix `w` using `bits` bits per element.
///
/// Note that every `group_size` elements in a row of `w` are quantized together. Hence, the number
/// of columns of `w` must be divisible by `group_size`. In particular, the rows of `w` are divided
/// into groups of size `group_size` which are quantized together.
///
/// > `quantize` currently only supports 2D inputs with dimensions that are multiples of 32.
///
/// For details, please see [this
/// documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.quantize.html)
///
/// # Params
///
/// - `w`: The input matrix
/// - `group_size`: The size of the group in `w` that shares a scale and bias. (default: `64`)
/// - `bits`: The number of bits occupied by each element of `w` in the returned quantized matrix.
///   (default: `4`)
///
/// # Returns
///
/// A tuple of `(quantized, scales, biases)`.
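///
/// # Example
///
/// A minimal sketch (the shapes are chosen to satisfy the multiple-of-32 requirement; with
/// `bits = 4`, eight elements are packed into each `u32`):
///
/// ```rust,ignore
/// use mlx_rs::{ops::quantize, Array};
///
/// let w = Array::ones::<f32>(&[128, 128]).unwrap();
/// let (w_q, scales, biases) = quantize(&w, 64, 4).unwrap();
/// assert_eq!(w_q.shape(), [128, 128 / 8]); // 4-bit elements packed into u32s
/// assert_eq!(scales.shape(), [128, 128 / 64]); // one scale per group of 64
/// ```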
#[generate_macro]
#[default_device]
pub fn quantize_device(
    w: impl AsRef<Array>,
    #[optional] group_size: impl Into<Option<i32>>,
    #[optional] bits: impl Into<Option<i32>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<(Array, Array, Array)> {
    let group_size = group_size.into().unwrap_or(64);
    let bits = bits.into().unwrap_or(4);

    <(Array, Array, Array) as Guarded>::try_from_op(|(res0, res1, res2)| unsafe {
        mlx_sys::mlx_quantize(
            res0,
            res1,
            res2,
            w.as_ref().as_ptr(),
            group_size,
            bits,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Perform the matrix multiplication with the quantized matrix `w`.
///
/// The quantization uses one floating point scale and bias per `group_size` elements. Each element
/// in `w` takes `bits` bits and is packed into an unsigned 32-bit integer.
///
/// # Params
///
/// - `x`: The input matrix
/// - `w`: The quantized matrix
/// - `scales`: The scales to use per `group_size` elements of `w`
/// - `biases`: The biases to use per `group_size` elements of `w`
/// - `transpose`: Whether to multiply with the transposed `w`, i.e. `x @ w.T` instead of `x @ w`.
///   (default: `false`)
/// - `group_size`: The size of the group in `w` that shares a scale and bias. (default: `64`)
/// - `bits`: The number of bits occupied by each element of `w`. (default: `4`)
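///
/// # Example
///
/// A minimal sketch: quantize a weight matrix, then multiply against it with `transpose = true`
/// (i.e. `x @ w.T` on the dequantized weights). The shapes here are illustrative:
///
/// ```rust,ignore
/// use mlx_rs::{ops::{quantize, quantized_matmul}, Array};
///
/// let x = Array::ones::<f32>(&[2, 512]).unwrap();
/// let w = Array::ones::<f32>(&[256, 512]).unwrap();
/// let (w_q, scales, biases) = quantize(&w, 64, 4).unwrap();
/// // [2, 512] @ [512, 256] -> [2, 256]
/// let y = quantized_matmul(&x, &w_q, &scales, &biases, true, 64, 4).unwrap();
/// assert_eq!(y.shape(), [2, 256]);
/// ```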
#[allow(clippy::too_many_arguments)]
#[generate_macro]
#[default_device]
pub fn quantized_matmul_device(
    x: impl AsRef<Array>,
    w: impl AsRef<Array>,
    scales: impl AsRef<Array>,
    biases: impl AsRef<Array>,
    #[optional] transpose: impl Into<Option<bool>>,
    #[optional] group_size: impl Into<Option<i32>>,
    #[optional] bits: impl Into<Option<i32>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let transpose = transpose.into().unwrap_or(false);
    let group_size = group_size.into().unwrap_or(64);
    let bits = bits.into().unwrap_or(4);

    <Array as Guarded>::try_from_op(|res| unsafe {
        mlx_sys::mlx_quantized_matmul(
            res,
            x.as_ref().as_ptr(),
            w.as_ref().as_ptr(),
            scales.as_ref().as_ptr(),
            biases.as_ref().as_ptr(),
            transpose,
            group_size,
            bits,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Dequantize the matrix `w` using the provided `scales` and `biases`, with the given `group_size`
/// and `bits` configuration.
///
/// For details, please see [this
/// documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.dequantize.html)
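///
/// # Example
///
/// A minimal round-trip sketch: dequantizing with the same `group_size` and `bits` used for
/// quantization recovers an approximation of the original matrix:
///
/// ```rust,ignore
/// use mlx_rs::{ops::{dequantize, quantize}, Array};
///
/// let w = Array::ones::<f32>(&[128, 128]).unwrap();
/// let (w_q, scales, biases) = quantize(&w, 64, 4).unwrap();
/// let w_hat = dequantize(&w_q, &scales, &biases, 64, 4).unwrap();
/// assert_eq!(w_hat.shape(), [128, 128]);
/// ```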
#[generate_macro]
#[default_device]
pub fn dequantize_device(
    w: impl AsRef<Array>,
    scales: impl AsRef<Array>,
    biases: impl AsRef<Array>,
    #[optional] group_size: impl Into<Option<i32>>,
    #[optional] bits: impl Into<Option<i32>>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    let group_size = group_size.into().unwrap_or(64);
    let bits = bits.into().unwrap_or(4);

    <Array as Guarded>::try_from_op(|res| unsafe {
        mlx_sys::mlx_dequantize(
            res,
            w.as_ref().as_ptr(),
            scales.as_ref().as_ptr(),
            biases.as_ref().as_ptr(),
            group_size,
            bits,
            stream.as_ref().as_ptr(),
        )
    })
}

111
112#[cfg(test)]
113mod tests {
114    use crate::{
115        ops::{dequantize, expand_dims, quantize},
116        Array,
117    };
118
119    #[test]
120    fn test_quantize_dequantize() {
121        let x1 = Array::ones::<f32>(&[128, 1]).unwrap();
122        let x2 = expand_dims(Array::arange::<_, f32>(0, 512, None).unwrap(), &[0]).unwrap();
123        let x = x1 * x2;
124
125        for i in [2, 4, 8].iter() {
126            let el_per_int = 32 / i;
127            let (x_q, scales, biases) = quantize(&x, 128, *i).unwrap();
128            assert_eq!(x_q.shape(), [128, 512 / el_per_int]);
129            assert_eq!(scales.shape(), [128, 4]);
130            assert_eq!(biases.shape(), [128, 4]);
131
132            let x_hat = dequantize(&x_q, &scales, &biases, 128, *i).unwrap();
133            let max_diff = ((&x - &x_hat).abs().unwrap().max(None, None).unwrap()).item::<f32>();
134            assert!(max_diff <= 127.0 / (1 << i) as f32);
135        }
136    }
137}