use std::iter::once;

use crate::{error::Exception, quantization::Quantizable, Array};
use mlx_internal_macros::{Buildable, Builder};

use crate::{
    macros::ModuleParameters,
    module::{Module, Param},
};

use super::QuantizedLinear;

/// Builder for [`Linear`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_linear,
    err = Exception,
)]
pub struct LinearBuilder {
    /// The number of input dimensions.
    pub input_dims: i32,

    /// The number of output dimensions.
    pub output_dims: i32,

    /// Whether to include a bias. Defaults to [`Linear::DEFAULT_BIAS`].
    #[builder(optional, default = Linear::DEFAULT_BIAS)]
    pub bias: bool,
}

/// Builds a [`Linear`] layer, initializing `weight` (and optionally `bias`)
/// from a uniform distribution.
fn build_linear(builder: LinearBuilder) -> Result<Linear, Exception> {
    let input_dims = builder.input_dims;
    let output_dims = builder.output_dims;
    let with_bias = builder.bias;

    // Both parameters are drawn from U(-scale, scale) with
    // scale = sqrt(1 / input_dims); e.g. input_dims = 16 gives scale = 0.25.
    let scale = f32::sqrt(1.0 / (input_dims as f32));
    let weight = crate::random::uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;

    let bias = if with_bias {
        Some(crate::random::uniform::<_, f32>(
            -scale,
            scale,
            &[output_dims],
            None,
        )?)
    } else {
        None
    };

    Ok(Linear {
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

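/// Applies an affine transformation `y = x W^T + b` to the input.
///
/// A minimal usage sketch. The paths `mlx_rs::nn::Linear` and
/// `mlx_rs::random::uniform` are assumed re-exports for illustration; within
/// this file only `Linear::new(input_dims, output_dims)` (see the test below)
/// and `Module::forward` are confirmed.
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 16], None)?;
/// let mut linear = mlx_rs::nn::Linear::new(16, 5)?; // weight: [5, 16], bias: [5]
/// let y = linear.forward(&x)?; // y: [2, 5]
/// ```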
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Linear {
    /// The weight matrix, of shape `[output_dims, input_dims]`.
    #[param]
    pub weight: Param<Array>,

    /// The optional bias vector, of shape `[output_dims]`.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl Linear {
    /// Default value for whether a bias is included.
    pub const DEFAULT_BIAS: bool = true;

    /// Returns `(output_dims, input_dims)`, i.e. the shape of `weight`.
    pub fn shape(&self) -> (i32, i32) {
        let weight_shape = self.weight.as_ref().shape();
        (weight_shape[0], weight_shape[1])
    }
}

impl Module<&Array> for Linear {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        // Computes y = x @ W^T (+ b), fusing the addition into `addmm` when a
        // bias is present.
        match &self.bias.value {
            Some(bias) => crate::ops::addmm(bias, x, self.weight.value.t(), None, None),
            None => crate::ops::matmul(x, self.weight.value.t()),
        }
    }

    fn training_mode(&mut self, _: bool) {}
}

impl Quantizable for Linear {
    type Quantized = QuantizedLinear;
    type QuantizationError = Exception;

    fn try_into_quantized(
        self,
        group_size: i32,
        bits: i32,
    ) -> Result<Self::Quantized, Self::QuantizationError> {
        QuantizedLinear::try_from_linear(self, group_size, bits)
    }
}
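
// Usage sketch for the conversion above (hedged): `group_size` and `bits`
// must be values supported by the underlying MLX quantization kernels; 64
// and 4 are common choices, but that is an assumption this file does not
// verify.
//
//     let linear = Linear::new(64, 8)?;
//     let quantized: QuantizedLinear = linear.try_into_quantized(64, 4)?;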

/// Builder for [`Bilinear`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_bilinear,
    err = Exception,
)]
pub struct BilinearBuilder {
    /// The number of dimensions in the first input.
    pub input_dims_1: i32,

    /// The number of dimensions in the second input.
    pub input_dims_2: i32,

    /// The number of output dimensions.
    pub output_dims: i32,

    /// Whether to include a bias. Defaults to [`Bilinear::DEFAULT_BIAS`].
    #[builder(optional, default = Bilinear::DEFAULT_BIAS)]
    pub bias: bool,
}

/// Builds a [`Bilinear`] layer. As in [`build_linear`], parameters are drawn
/// from U(-scale, scale), here with scale = sqrt(1 / input_dims_1).
fn build_bilinear(builder: BilinearBuilder) -> Result<Bilinear, Exception> {
    let input_dims_1 = builder.input_dims_1;
    let input_dims_2 = builder.input_dims_2;
    let output_dims = builder.output_dims;
    let with_bias = builder.bias;

    let scale = f32::sqrt(1.0 / (input_dims_1 as f32));
    let weights = crate::random::uniform::<_, f32>(
        -scale,
        scale,
        &[output_dims, input_dims_2, input_dims_1],
        None,
    )?;

    let bias = if with_bias {
        Some(crate::random::uniform::<_, f32>(
            -scale,
            scale,
            &[output_dims],
            None,
        )?)
    } else {
        None
    };

    Ok(Bilinear {
        weights: Param::new(weights),
        bias: Param::new(bias),
    })
}

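/// Applies a bilinear transformation to the input.
///
/// For weights `W` of shape `[output_dims, input_dims_2, input_dims_1]`, the
/// forward pass computes `y[.., o] = Σ_{j,k} x[.., j] · W[o, j, k] · x[.., k]`
/// (see the comments in `forward` below).
///
/// A usage sketch. The constructor `Bilinear::new(input_dims_1, input_dims_2,
/// output_dims)` is assumed to be generated by the `Buildable` derive, by
/// analogy with `Linear::new`; it is not exercised anywhere in this file.
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let mut bilinear = mlx_rs::nn::Bilinear::new(16, 16, 5)?; // weights: [5, 16, 16]
/// let y = bilinear.forward(&x)?; // y: [2, 8, 5]
/// ```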
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Bilinear {
    /// The weight tensor, of shape `[output_dims, input_dims_2, input_dims_1]`.
    #[param]
    pub weights: Param<Array>,

    /// The optional bias vector, of shape `[output_dims]`.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl Bilinear {
    /// Default value for whether a bias is included.
    pub const DEFAULT_BIAS: bool = true;
}

impl Module<&Array> for Bilinear {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let shape = self.weights.shape();
        let (out, in2, in1) = (shape[0], shape[1], shape[2]);
        let x_shape = &x.shape()[..x.shape().len() - 1];
        let x1 = x.reshape(&[-1, in1])?; // [N, in1]
        let x2 = x.reshape(&[-1, 1, in2])?; // [N, 1, in2]

        // Perform the bilinear operation as two matmuls:
        // y[n, o] = sum_{j, k} x2[n, j] * W[o, j, k] * x1[n, k]
        let w = self.weights.reshape(&[out * in2, in1])?;
        let mut y = crate::ops::matmul(&x1, w.t())?; // [N, out * in2]
        y = y.reshape(&[-1, out, in2])?.swap_axes(-2, -1)?; // [N, in2, out]
        y = crate::ops::matmul(&x2, &y)?; // [N, 1, out]
        y = y.squeeze(&[1])?; // [N, out]

        // Restore the leading dimensions of the input.
        let new_shape = x_shape.iter().cloned().chain(once(out)).collect::<Vec<_>>();
        y = y.reshape(&new_shape)?;

        if let Some(bias) = &self.bias.value {
            y = crate::ops::add(&y, bias)?;
        }

        Ok(y)
    }

    fn training_mode(&mut self, _: bool) {}
}

#[cfg(test)]
mod tests {
    use crate::{random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use super::*;

    #[test]
    fn test_linear() {
        crate::random::seed(744).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.508_688_57,
            abs <= 0.010_173_771_5
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            130.224_27,
            abs <= 2.604_485_5
        );
        let result = Linear::new(16, 5).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 5]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.104_193_09,
            abs <= 0.002_083_861_7
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            8.335_447,
            abs <= 0.166_708_95
        );
    }
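
    // The two tests below are added sketches derived only from the code in
    // this file: they assert shapes and the presence/absence of parameters,
    // which follow directly from `build_linear` / `build_bilinear` above.

    #[test]
    fn test_linear_without_bias() {
        // Drive the private build function directly via a plain struct
        // literal, avoiding any assumption about the generated builder API.
        let mut linear = build_linear(LinearBuilder {
            input_dims: 16,
            output_dims: 5,
            bias: false,
        })
        .unwrap();
        assert!(linear.bias.value.is_none());
        assert_eq!(linear.shape(), (5, 16));

        let x = uniform::<_, f32>(0.0, 1.0, &[4, 16], None).unwrap();
        let y = linear.forward(&x).unwrap();
        assert_eq!(y.shape(), &[4, 5]);
    }

    #[test]
    fn test_bilinear_shape() {
        let mut bilinear = build_bilinear(BilinearBuilder {
            input_dims_1: 16,
            input_dims_2: 16,
            output_dims: 5,
            bias: true,
        })
        .unwrap();
        assert_eq!(bilinear.weights.shape(), &[5, 16, 16]);

        let x = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        let y = bilinear.forward(&x).unwrap();
        assert_eq!(y.shape(), &[2, 8, 5]);
    }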
}