use std::iter::once;
use crate::{
array,
error::Exception,
module::{Module, ModuleParameters, Param},
ops::indexing::IndexOp,
ops::{self, dequantize, quantized_matmul, zeros},
quantization::Quantizable,
random::uniform,
Array,
};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;
use crate::nn::{Embedding, Linear};
/// Quantize a [`Quantizable`] module.
///
/// Passing `None` for `group_size` or `bits` falls back to the module's
/// associated `DEFAULT_GROUP_SIZE` / `DEFAULT_BITS` constants.
///
/// # Errors
///
/// Returns the module's [`Quantizable::QuantizationError`] if quantization fails.
pub fn quantize<M>(
    module: M,
    group_size: impl Into<Option<i32>>,
    bits: impl Into<Option<i32>>,
) -> Result<M::Quantized, M::QuantizationError>
where
    M: Quantizable,
{
    let resolved_group_size = group_size.into().unwrap_or(M::DEFAULT_GROUP_SIZE);
    let resolved_bits = bits.into().unwrap_or(M::DEFAULT_BITS);
    module.try_into_quantized(resolved_group_size, resolved_bits)
}
/// Builder for [`QuantizedEmbedding`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_embedding,
    err = Exception,
)]
pub struct QuantizedEmbeddingBuilder {
    /// Number of embeddings (vocabulary size).
    pub embedding_count: i32,

    /// Dimensionality of each embedding vector.
    pub dimensions: i32,

    /// Quantization group size. Default: [`QuantizedEmbedding::DEFAULT_GROUP_SIZE`].
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per quantized element. Default: [`QuantizedEmbedding::DEFAULT_BITS`].
    #[builder(optional, default = QuantizedEmbedding::DEFAULT_BITS)]
    pub bits: i32,
}
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedEmbedding {
pub group_size: i32,
pub bits: i32,
pub scales: Param<Array>,
pub biases: Param<Array>,
pub inner: Embedding,
}
impl QuantizedEmbeddingBuilder {
    /// Quantize the weight of an existing [`Embedding`] using this builder's
    /// `group_size` / `bits` settings.
    pub fn build_with_embedding(
        self,
        embedding: Embedding,
    ) -> Result<QuantizedEmbedding, Exception> {
        self.build_with_weight(embedding.weight.value)
    }

    /// Quantize a raw weight array using this builder's settings.
    pub fn build_with_weight(self, weight: Array) -> Result<QuantizedEmbedding, Exception> {
        build_quantized_embedding_inner(weight, self.group_size, self.bits)
    }
}
/// Shared construction path: quantize `weight`, wrap the packed result in an
/// [`Embedding`], and return a frozen [`QuantizedEmbedding`].
fn build_quantized_embedding_inner(
    weight: Array,
    group_size: i32,
    bits: i32,
) -> Result<QuantizedEmbedding, Exception> {
    let (packed_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;
    let mut quantized = QuantizedEmbedding {
        group_size,
        bits,
        scales: Param::new(scales),
        biases: Param::new(biases),
        inner: Embedding {
            weight: Param::new(packed_weight),
        },
    };
    // Quantized parameters are not trainable.
    quantized.freeze_parameters(true);
    Ok(quantized)
}
/// Builder entry point: draw a normal random weight scaled by
/// `1 / sqrt(dimensions)` and quantize it.
fn build_quantized_embedding(
    builder: QuantizedEmbeddingBuilder,
) -> Result<QuantizedEmbedding, Exception> {
    let shape = [builder.embedding_count, builder.dimensions];
    let scale = array!((1.0 / builder.dimensions as f32).sqrt());
    let weight = crate::random::normal::<f32>(&shape, None, None, None)? * &scale;
    builder.build_with_weight(weight)
}
impl QuantizedEmbedding {
    /// Default quantization group size.
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default number of bits per quantized element.
    pub const DEFAULT_BITS: i32 = 4;

    /// Quantize an existing [`Embedding`]; `None` selects the corresponding
    /// default constant.
    pub fn try_from_embedding(
        embedding: Embedding,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        build_quantized_embedding_inner(
            embedding.weight.value,
            group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE),
            bits.into().unwrap_or(Self::DEFAULT_BITS),
        )
    }

    /// Use the quantized embedding table as a linear projection (`x @ W^T`),
    /// e.g. for weight-tied output layers.
    pub fn as_linear(&self, x: impl AsRef<Array>) -> Result<Array, Exception> {
        quantized_matmul(
            x.as_ref(),
            &self.inner.weight,
            &self.scales,
            &self.biases,
            true, // transpose the packed weight
            self.group_size,
            self.bits,
        )
    }
}
impl TryFrom<Embedding> for QuantizedEmbedding {
type Error = Exception;
fn try_from(embedding: Embedding) -> Result<Self, Self::Error> {
Self::try_from_embedding(embedding, None, None)
}
}
impl Module<&Array> for QuantizedEmbedding {
    type Error = Exception;
    type Output = Array;

    /// Look up the quantized rows selected by the indices in `x` and
    /// dequantize them, returning shape `[..x.shape(), -1]`.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let original_shape = x.shape();
        let flat_indices = x.flatten(None, None)?;
        // Gather the packed rows together with their matching scales/biases.
        let rows = self.inner.weight.index(&flat_indices);
        let row_scales = self.scales.index(&flat_indices);
        let row_biases = self.biases.index(&flat_indices);
        let dequantized =
            dequantize(&rows, &row_scales, &row_biases, self.group_size, self.bits)?;
        // Restore the input shape, inferring the feature dimension with -1.
        let target_shape: Vec<i32> = original_shape.iter().copied().chain(once(-1)).collect();
        dequantized.reshape(&target_shape)
    }

    fn training_mode(&mut self, mode: bool) {
        self.inner.training_mode(mode);
    }
}
/// Builder for [`QuantizedLinear`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_quantized_linear,
    err = Exception,
)]
pub struct QuantizedLinearBuilder {
    /// Number of input features.
    pub input_dims: i32,

    /// Number of output features.
    pub output_dims: i32,

    /// Quantization group size. Default: [`QuantizedLinear::DEFAULT_GROUP_SIZE`].
    #[builder(optional, default = QuantizedLinear::DEFAULT_GROUP_SIZE)]
    pub group_size: i32,

    /// Bits per quantized element. Default: [`QuantizedLinear::DEFAULT_BITS`].
    #[builder(optional, default = QuantizedLinear::DEFAULT_BITS)]
    pub bits: i32,

    /// Whether to create a bias term. Default: [`Linear::DEFAULT_BIAS`].
    #[builder(optional, default = Linear::DEFAULT_BIAS)]
    pub bias: bool,
}
impl QuantizedLinearBuilder {
    /// Quantize the weight of an existing [`Linear`], carrying its bias over
    /// unchanged.
    pub fn build_with_linear(self, other: Linear) -> Result<QuantizedLinear, Exception> {
        let weight = other.weight.value;
        let bias = other.bias.value;
        self.build_with_weight_and_bias(weight, bias)
    }

    /// Quantize a raw weight array; `bias`, if any, stays unquantized.
    fn build_with_weight_and_bias(
        self,
        weight: Array,
        bias: Option<Array>,
    ) -> Result<QuantizedLinear, Exception> {
        build_quantized_linear_inner(weight, bias, self.group_size, self.bits)
    }
}
/// Shared construction path: quantize `weight`, wrap it (and the optional
/// unquantized `bias`) in a [`Linear`], and return a frozen [`QuantizedLinear`].
fn build_quantized_linear_inner(
    weight: Array,
    bias: Option<Array>,
    group_size: i32,
    bits: i32,
) -> Result<QuantizedLinear, Exception> {
    let (packed_weight, scales, biases) = ops::quantize(&weight, group_size, bits)?;
    let mut quantized = QuantizedLinear {
        group_size,
        bits,
        scales: Param::new(scales),
        biases: Param::new(biases),
        inner: Linear {
            weight: Param::new(packed_weight),
            bias: Param::new(bias),
        },
    };
    // Quantized parameters are not trainable.
    quantized.freeze_parameters(true);
    Ok(quantized)
}
pub fn build_quantized_linear(
builder: QuantizedLinearBuilder,
) -> Result<QuantizedLinear, Exception> {
let input_dims = builder.input_dims;
let output_dims = builder.output_dims;
let scale = f32::sqrt(1.0 / (input_dims as f32));
let weight = uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;
let bias = if builder.bias {
Some(zeros::<f32>(&[output_dims])?)
} else {
None
};
builder.build_with_weight_and_bias(weight, bias)
}
/// A [`Linear`] layer whose weight is stored in quantized form.
///
/// Construct via [`QuantizedLinearBuilder`], [`QuantizedLinear::try_from_linear`],
/// or `TryFrom<Linear>`. Quantized parameters are frozen on construction.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct QuantizedLinear {
    /// Number of weight elements that share one scale/bias pair.
    pub group_size: i32,

    /// Bits per quantized weight element.
    pub bits: i32,

    /// Per-group scales used to dequantize the packed weight.
    #[param]
    pub scales: Param<Array>,

    /// Per-group biases used to dequantize the packed weight.
    #[param]
    pub biases: Param<Array>,

    /// Inner linear layer holding the packed weight and the (unquantized) bias.
    #[param]
    pub inner: Linear,
}
impl QuantizedLinear {
    /// Default quantization group size.
    pub const DEFAULT_GROUP_SIZE: i32 = 64;

    /// Default number of bits per quantized element.
    pub const DEFAULT_BITS: i32 = 4;

    /// Quantize an existing [`Linear`]; `None` selects the corresponding
    /// default constant.
    pub fn try_from_linear(
        linear: Linear,
        group_size: impl Into<Option<i32>>,
        bits: impl Into<Option<i32>>,
    ) -> Result<Self, Exception> {
        build_quantized_linear_inner(
            linear.weight.value,
            linear.bias.value,
            group_size.into().unwrap_or(Self::DEFAULT_GROUP_SIZE),
            bits.into().unwrap_or(Self::DEFAULT_BITS),
        )
    }
}
impl TryFrom<Linear> for QuantizedLinear {
type Error = Exception;
fn try_from(linear: Linear) -> Result<Self, Self::Error> {
Self::try_from_linear(linear, None, None)
}
}
impl Module<&Array> for QuantizedLinear {
    type Error = Exception;
    type Output = Array;

    /// Compute `x @ W^T (+ bias)` directly on the quantized weight.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let projected = quantized_matmul(
            x,
            &self.inner.weight,
            &self.scales,
            &self.biases,
            true, // transpose the packed weight
            self.group_size,
            self.bits,
        )?;
        // The bias, when present, is kept unquantized and added afterwards.
        match &self.inner.bias.value {
            Some(bias) => projected.add(bias),
            None => Ok(projected),
        }
    }

    fn training_mode(&mut self, mode: bool) {
        self.inner.training_mode(mode);
    }
}