1use crate::error::Exception;
4use crate::module::Module;
5use crate::module::Param;
6use crate::ops::indexing::IndexOp;
7use crate::quantization::Quantizable;
8use crate::Array;
9use mlx_macros::ModuleParameters;
10
11use super::QuantizedEmbedding;
12
13#[derive(Debug, Clone, ModuleParameters)]
17#[module(root = crate)]
18pub struct Embedding {
19 #[param]
21 pub weight: Param<Array>,
22}
23
24impl Embedding {
25 pub fn new(embedding_count: i32, dimensions: i32) -> Result<Self, Exception> {
32 let scale = f32::sqrt(1.0 / (dimensions as f32));
33 let weight =
34 crate::random::normal::<f32>(&[embedding_count, dimensions], None, None, None)? * scale;
35
36 Ok(Self {
37 weight: Param::new(weight),
38 })
39 }
40
41 pub fn as_linear(&self, x: &Array) -> Result<Array, Exception> {
46 crate::ops::matmul(x, self.weight.value.t())
47 }
48}
49
50impl Quantizable for Embedding {
51 type Quantized = QuantizedEmbedding;
52
53 type QuantizationError = Exception;
54
55 fn try_into_quantized(
56 self,
57 group_size: i32,
58 bits: i32,
59 ) -> Result<Self::Quantized, Self::QuantizationError> {
60 QuantizedEmbedding::try_from_embedding(self, group_size, bits)
61 }
62}
63
64impl Module<&Array> for Embedding {
65 type Error = Exception;
66 type Output = Array;
67
68 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
69 Ok(self.weight.index(x))
70 }
71
72 fn training_mode(&mut self, _mode: bool) {}
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    // `assert_float_eq!` panics on mismatch; plain `float_eq!` only evaluates
    // to a `bool`, which would be silently discarded in statement position.
    use float_eq::assert_float_eq;
    use pretty_assertions::assert_eq;

    /// Seeds the RNG, draws integer indices, runs them through a fresh
    /// `Embedding`, and checks shape, dtype, and summary statistics of both
    /// the indices and the looked-up embeddings.
    #[test]
    fn test_embedding() {
        crate::random::seed(557).unwrap();
        let a = crate::random::randint::<_, i32>(0, 10, &[2, 8, 8, 4], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 8, 4]);
        assert_eq!(a.dtype(), crate::Dtype::Int32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            4.605_468_8,
            abs <= 0.092_109_375
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            2358.0,
            abs <= 47.16
        );

        // The lookup appends the embedding dimension to the index shape.
        let result = Embedding::new(10, 8).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 8, 4, 8]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            -0.001_197_346_3,
            abs <= 2.394_692_5e-5
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            -4.904_330_3,
            abs <= 0.098_086_6
        );
    }
}