use crate::error::Exception;
use crate::module::Module;
use crate::module::Param;
use crate::ops::indexing::IndexOp;
use crate::quantization::Quantizable;
use crate::Array;
use mlx_macros::ModuleParameters;
use super::QuantizedEmbedding;
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Embedding {
#[param]
pub weight: Param<Array>,
}
impl Embedding {
pub fn new(embedding_count: i32, dimensions: i32) -> Result<Self, Exception> {
let scale = f32::sqrt(1.0 / (dimensions as f32));
let weight =
crate::random::normal::<f32>(&[embedding_count, dimensions], None, None, None)? * scale;
Ok(Self {
weight: Param::new(weight),
})
}
pub fn as_linear(&self, x: &Array) -> Result<Array, Exception> {
crate::ops::matmul(x, self.weight.value.t())
}
}
impl Quantizable for Embedding {
type Quantized = QuantizedEmbedding;
type QuantizationError = Exception;
fn try_into_quantized(
self,
group_size: i32,
bits: i32,
) -> Result<Self::Quantized, Self::QuantizationError> {
QuantizedEmbedding::try_from_embedding(self, group_size, bits)
}
}
impl Module<&Array> for Embedding {
type Error = Exception;
type Output = Array;
fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
Ok(self.weight.index(x))
}
fn training_mode(&mut self, _mode: bool) {}
}
#[cfg(test)]
mod tests {
use super::*;
use float_eq::float_eq;
use pretty_assertions::assert_eq;
#[test]
fn test_embedding() {
crate::random::seed(557).unwrap();
let a = crate::random::randint::<_, i32>(0, 10, &[2, 8, 8, 4], None).unwrap();
assert_eq!(a.shape(), &[2, 8, 8, 4]);
assert_eq!(a.dtype(), crate::Dtype::Int32);
float_eq!(
a.mean(None, None).unwrap().item::<f32>(),
4.605_468_8,
abs <= 0.092_109_375
);
float_eq!(
a.sum(None, None).unwrap().item::<f32>(),
2358.0,
abs <= 47.16
);
let result = Embedding::new(10, 8).unwrap().forward(&a).unwrap();
assert_eq!(result.shape(), &[2, 8, 8, 4, 8]);
assert_eq!(result.dtype(), crate::Dtype::Float32);
float_eq!(
result.mean(None, None).unwrap().item::<f32>(),
-0.001_197_346_3,
abs <= 2.394_692_5e-5
);
float_eq!(
result.sum(None, None).unwrap().item::<f32>(),
-4.904_330_3,
abs <= 0.098_086_6
);
}
}