mlx_rs/nn/embedding.rs

//! Embedding layer.

use crate::error::Exception;
use crate::module::Module;
use crate::module::Param;
use crate::ops::indexing::IndexOp;
use crate::quantization::Quantizable;
use crate::Array;
use mlx_macros::ModuleParameters;

use super::QuantizedEmbedding;

/// Implements a simple lookup table that maps each input integer to a high-dimensional vector.
///
/// Typically used to embed discrete tokens for processing by neural networks.
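///
/// # Example
///
/// A minimal usage sketch; the `mlx_rs::...` import paths are assumed from the
/// public crate API and may differ in your build:
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::Embedding, Array};
///
/// // Vocabulary of 10 tokens, each mapped to a 4-dimensional vector.
/// let mut embedding = Embedding::new(10, 4).unwrap();
/// let tokens = Array::from_slice(&[1i32, 3, 7], &[3]);
/// let vectors = embedding.forward(&tokens).unwrap();
/// assert_eq!(vectors.shape(), &[3, 4]);
/// ```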
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Embedding {
    /// The weight matrix of the embedding layer, with shape `[embedding_count, dimensions]`.
    #[param]
    pub weight: Param<Array>,
}

impl Embedding {
    /// Creates a new [`Embedding`] layer.
    ///
    /// # Params
    ///
    /// - `embedding_count`: The number of possible discrete tokens that can be embedded,
    ///   usually called the vocabulary size.
    /// - `dimensions`: The dimensionality of the embeddings.
    pub fn new(embedding_count: i32, dimensions: i32) -> Result<Self, Exception> {
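        // Weights are drawn from a standard normal and scaled by sqrt(1 / dimensions),
        // i.e. approximately N(0, 1 / dimensions).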
        let scale = f32::sqrt(1.0 / (dimensions as f32));
        let weight =
            crate::random::normal::<f32>(&[embedding_count, dimensions], None, None, None)? * scale;

        Ok(Self {
            weight: Param::new(weight),
        })
    }

    /// Call the embedding layer as a linear layer.
    ///
    /// Use this, for example, when the input embedding and output projection
    /// weights are tied.
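    ///
    /// For an input of shape `[..., dimensions]` this multiplies by the transposed
    /// weight, producing logits of shape `[..., embedding_count]`. A minimal sketch
    /// of weight tying (identifiers are illustrative only):
    ///
    /// ```rust,ignore
    /// // `hidden` has shape [batch, dimensions]; reuse the embedding weight as
    /// // the output projection instead of a separate Linear layer.
    /// let logits = embedding.as_linear(&hidden)?;
    /// assert_eq!(logits.shape(), &[batch, embedding_count]);
    /// ```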
    pub fn as_linear(&self, x: &Array) -> Result<Array, Exception> {
        crate::ops::matmul(x, self.weight.value.t())
    }
}

impl Quantizable for Embedding {
    type Quantized = QuantizedEmbedding;

    type QuantizationError = Exception;

    fn try_into_quantized(
        self,
        group_size: i32,
        bits: i32,
    ) -> Result<Self::Quantized, Self::QuantizationError> {
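        // Delegate to QuantizedEmbedding, which quantizes the weight matrix in
        // groups of `group_size` values using `bits` bits per weight.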
        QuantizedEmbedding::try_from_embedding(self, group_size, bits)
    }
}

impl Module<&Array> for Embedding {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
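        // Lookup: indexing the `[embedding_count, dimensions]` weight with integer
        // indices of shape `S` returns embeddings of shape `S + [dimensions]`.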
        Ok(self.weight.index(x))
    }

    fn training_mode(&mut self, _mode: bool) {}
}

#[cfg(test)]
mod tests {
    use super::*;
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_embedding() {
        crate::random::seed(557).unwrap();
        let a = crate::random::randint::<_, i32>(0, 10, &[2, 8, 8, 4], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 8, 4]);
        assert_eq!(a.dtype(), crate::Dtype::Int32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            4.605_468_8,
            abs <= 0.092_109_375
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            2358.0,
            abs <= 47.16
        );

        let result = Embedding::new(10, 8).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 8, 4, 8]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            -0.001_197_346_3,
            abs <= 2.394_692_5e-5
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            -4.904_330_3,
            abs <= 0.098_086_6
        );
    }
}