mlx_rs/ops/quantization.rs
// quantized(_:groupSize:bits:stream:)
// quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:stream:)
// dequantized(_:scales:biases:groupSize:bits:stream:)

use mlx_internal_macros::default_device;
use crate::{error::Result, utils::guard::Guarded, Array, Stream, StreamOrDevice};

/// Quantize the matrix `w` using `bits` bits per element.
///
/// Note that every `group_size` elements in a row of `w` are quantized together. Hence, the
/// number of columns of `w` must be divisible by `group_size`. In particular, the rows of `w`
/// are divided into groups of size `group_size`, and each group is quantized with its own scale
/// and bias.
///
/// > `quantize` currently only supports 2D inputs with dimensions that are multiples of 32.
///
/// For details, please see [this
/// documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.quantize.html)
///
/// # Params
///
/// - `w`: The input matrix
/// - `group_size`: The size of the group in `w` that shares a scale and bias. (default: `64`)
/// - `bits`: The number of bits occupied by each element of `w` in the returned quantized matrix.
///   (default: `4`)
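///
/// # Example
///
/// A minimal sketch, assuming the stream-defaulted `quantize` wrapper generated by
/// `#[default_device]` (the same form the tests below use); crate paths are illustrative:
///
/// ```rust,ignore
/// use mlx_rs::{ops::quantize, Array};
///
/// // A 128 x 512 matrix; both dimensions are multiples of 32.
/// let w = Array::ones::<f32>(&[128, 512]).unwrap();
///
/// // Quantize with the default group size (64) and bit width (4), passed explicitly here.
/// let (w_q, scales, biases) = quantize(&w, 64, 4).unwrap();
/// ```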
#[default_device]
pub fn quantize_device(
    w: impl AsRef<Array>,
    group_size: impl Into<Option<i32>>,
    bits: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<(Array, Array, Array)> {
    let group_size = group_size.into().unwrap_or(64);
    let bits = bits.into().unwrap_or(4);
    <(Array, Array, Array) as Guarded>::try_from_op(|(res0, res1, res2)| unsafe {
        mlx_sys::mlx_quantize(
            res0,
            res1,
            res2,
            w.as_ref().as_ptr(),
            group_size,
            bits,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Perform a matrix multiplication with the quantized matrix `w`. The quantization uses one
/// floating point scale and bias per `group_size` elements. Each element in `w` takes `bits`
/// bits and is packed in an unsigned 32-bit integer.
///
/// For details, please see [this
/// documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.quantized_matmul.html)
///
/// # Params
///
/// - `x`: The input matrix
/// - `w`: The quantized matrix packed in unsigned integers
/// - `scales`: The scales to use per `group_size` elements of `w`
/// - `biases`: The biases to use per `group_size` elements of `w`
/// - `transpose`: Whether to multiply with the transposed `w`, i.e. perform `x @ w.T` instead of
///   `x @ w`. (default: `false`)
/// - `group_size`: The size of the group in `w` that shares a scale and bias. (default: `64`)
/// - `bits`: The number of bits occupied by each element of `w`. (default: `4`)
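///
/// # Example
///
/// A minimal sketch, assuming the stream-defaulted wrappers generated by `#[default_device]`;
/// crate paths and shapes are illustrative only:
///
/// ```rust,ignore
/// use mlx_rs::{
///     ops::{quantize, quantized_matmul},
///     Array,
/// };
///
/// // Quantize a 64 x 64 weight matrix with the default settings.
/// let w = Array::ones::<f32>(&[64, 64]).unwrap();
/// let (w_q, scales, biases) = quantize(&w, 64, 4).unwrap();
///
/// // Multiply a 2 x 64 input by the transpose of the (dequantized) weights.
/// let x = Array::ones::<f32>(&[2, 64]).unwrap();
/// let y = quantized_matmul(&x, &w_q, &scales, &biases, true, 64, 4).unwrap();
/// ```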
#[allow(clippy::too_many_arguments)]
#[default_device]
pub fn quantized_matmul_device(
    x: impl AsRef<Array>,
    w: impl AsRef<Array>,
    scales: impl AsRef<Array>,
    biases: impl AsRef<Array>,
    transpose: impl Into<Option<bool>>,
    group_size: impl Into<Option<i32>>,
    bits: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let transpose = transpose.into().unwrap_or(false);
    let group_size = group_size.into().unwrap_or(64);
    let bits = bits.into().unwrap_or(4);
    <Array as Guarded>::try_from_op(|res| unsafe {
        mlx_sys::mlx_quantized_matmul(
            res,
            x.as_ref().as_ptr(),
            w.as_ref().as_ptr(),
            scales.as_ref().as_ptr(),
            biases.as_ref().as_ptr(),
            transpose,
            group_size,
            bits,
            stream.as_ref().as_ptr(),
        )
    })
}

/// Dequantize the matrix `w` using the provided `scales` and `biases` with the given `group_size`
/// and `bits` configuration.
///
/// For details, please see [this
/// documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.dequantize.html)
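///
/// # Example
///
/// A minimal round-trip sketch, assuming the stream-defaulted wrappers generated by
/// `#[default_device]`; crate paths are illustrative:
///
/// ```rust,ignore
/// use mlx_rs::{
///     ops::{dequantize, quantize},
///     Array,
/// };
///
/// let w = Array::ones::<f32>(&[128, 512]).unwrap();
/// let (w_q, scales, biases) = quantize(&w, 64, 4).unwrap();
///
/// // Reconstruct an approximation of the original matrix from the quantized representation.
/// let w_hat = dequantize(&w_q, &scales, &biases, 64, 4).unwrap();
/// ```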
#[default_device]
pub fn dequantize_device(
    w: impl AsRef<Array>,
    scales: impl AsRef<Array>,
    biases: impl AsRef<Array>,
    group_size: impl Into<Option<i32>>,
    bits: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array> {
    let group_size = group_size.into().unwrap_or(64);
    let bits = bits.into().unwrap_or(4);
    <Array as Guarded>::try_from_op(|res| unsafe {
        mlx_sys::mlx_dequantize(
            res,
            w.as_ref().as_ptr(),
            scales.as_ref().as_ptr(),
            biases.as_ref().as_ptr(),
            group_size,
            bits,
            stream.as_ref().as_ptr(),
        )
    })
}

#[cfg(test)]
mod tests {
    use crate::{
        ops::{dequantize, expand_dims, quantize},
        Array,
    };

    #[test]
    fn test_quantize_dequantize() {
        let x1 = Array::ones::<f32>(&[128, 1]).unwrap();
        let x2 = expand_dims(Array::arange::<_, f32>(0, 512, None).unwrap(), &[0]).unwrap();
        let x = x1 * x2;

        for i in [2, 4, 8].iter() {
            let el_per_int = 32 / i;
            let (x_q, scales, biases) = quantize(&x, 128, *i).unwrap();
            assert_eq!(x_q.shape(), [128, 512 / el_per_int]);
            assert_eq!(scales.shape(), [128, 4]);
            assert_eq!(biases.shape(), [128, 4]);

            let x_hat = dequantize(&x_q, &scales, &biases, 128, *i).unwrap();
            let max_diff = ((&x - &x_hat).abs().unwrap().max(None, None).unwrap()).item::<f32>();
            assert!(max_diff <= 127.0 / (1 << i) as f32);
        }
    }
}