mlx_rs/nn/linear.rs

use std::iter::once;

use crate::{error::Exception, quantization::Quantizable, Array};
use mlx_internal_macros::{Buildable, Builder};

use crate::{
    macros::ModuleParameters,
    module::{Module, Param},
};

use super::QuantizedLinear;

/// Builder for [`Linear`] module
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_linear,
    err = Exception,
)]
pub struct LinearBuilder {
    /// The number of input dimensions.
    pub input_dims: i32,

    /// The number of output dimensions.
    pub output_dims: i32,

    /// Whether to include bias in the linear layer. Defaults to [`Linear::DEFAULT_BIAS`].
    #[builder(optional, default = Linear::DEFAULT_BIAS)]
    pub bias: bool,
}

/// Builds a new [`Linear`] layer.
fn build_linear(builder: LinearBuilder) -> Result<Linear, Exception> {
    let input_dims = builder.input_dims;
    let output_dims = builder.output_dims;
    let with_bias = builder.bias;

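    // Initialize the weight (and the bias, if enabled) uniformly in
    // [-1/sqrt(input_dims), 1/sqrt(input_dims)].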
    let scale = f32::sqrt(1.0 / (input_dims as f32));
    let weight = crate::random::uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;

    let bias = if with_bias {
        Some(crate::random::uniform::<_, f32>(
            -scale,
            scale,
            &[output_dims],
            None,
        )?)
    } else {
        None
    };

    Ok(Linear {
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

/// Applies an affine transformation to the input.
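///
/// `forward` computes `x.matmul(weight.t()) + bias`, so `weight` is stored with
/// shape `[output_dims, input_dims]`.
///
/// # Example
///
/// A minimal usage sketch (crate paths are assumed here; `Linear::new` is the
/// same constructor exercised in the tests at the bottom of this file):
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// // 16 input features -> 5 output features, with bias (the default).
/// let mut linear = mlx_rs::nn::Linear::new(16, 5)?;
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let y = linear.forward(&x)?; // shape: [2, 8, 5]
/// ```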
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Linear {
    /// The weight of the linear layer.
    #[param]
    pub weight: Param<Array>,

    /// The bias of the linear layer.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl Linear {
    /// Default value for `with_bias`
    pub const DEFAULT_BIAS: bool = true;

    /// Returns the shape of the weight matrix as `(output_dims, input_dims)`.
    pub fn shape(&self) -> (i32, i32) {
        let weight_shape = self.weight.as_ref().shape();
        (weight_shape[0], weight_shape[1])
    }
}

impl Module<&Array> for Linear {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        match &self.bias.value {
            Some(bias) => crate::ops::addmm(bias, x, self.weight.value.t(), None, None),
            None => crate::ops::matmul(x, self.weight.value.t()),
        }
    }

    fn training_mode(&mut self, _: bool) {}
}

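// Quantization support: a `Linear` converts into a `QuantizedLinear` with the
// requested group size and bit width.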
impl Quantizable for Linear {
    type Quantized = QuantizedLinear;
    type QuantizationError = Exception;

    fn try_into_quantized(
        self,
        group_size: i32,
        bits: i32,
    ) -> Result<Self::Quantized, Self::QuantizationError> {
        QuantizedLinear::try_from_linear(self, group_size, bits)
    }
}

/// Builder for [`Bilinear`] module
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_bilinear,
    err = Exception,
)]
pub struct BilinearBuilder {
    /// The number of input dimensions for the first input.
    pub input_dims_1: i32,

    /// The number of input dimensions for the second input.
    pub input_dims_2: i32,

    /// The number of output dimensions.
    pub output_dims: i32,

    /// Whether to include bias in the bilinear layer. Defaults to [`Bilinear::DEFAULT_BIAS`].
    #[builder(optional, default = Bilinear::DEFAULT_BIAS)]
    pub bias: bool,
}

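/// Builds a new [`Bilinear`] layer.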
fn build_bilinear(builder: BilinearBuilder) -> Result<Bilinear, Exception> {
    let input_dims_1 = builder.input_dims_1;
    let input_dims_2 = builder.input_dims_2;
    let output_dims = builder.output_dims;
    let with_bias = builder.bias;

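    // Initialize the weights (and the bias, if enabled) uniformly in
    // [-1/sqrt(input_dims_1), 1/sqrt(input_dims_1)]; the scale follows the
    // first input dimension only.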
    let scale = f32::sqrt(1.0 / (input_dims_1 as f32));
    let weights = crate::random::uniform::<_, f32>(
        -scale,
        scale,
        &[output_dims, input_dims_2, input_dims_1],
        None,
    )?;

    let bias = if with_bias {
        Some(crate::random::uniform::<_, f32>(
            -scale,
            scale,
            &[output_dims],
            None,
        )?)
    } else {
        None
    };

    Ok(Bilinear {
        weights: Param::new(weights),
        bias: Param::new(bias),
    })
}

/// Applies a bilinear transformation to the inputs.
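///
/// With `weights` of shape `[output_dims, input_dims_2, input_dims_1]`, each
/// output unit `k` is the bilinear form `x · W_k · xᵀ` (plus `bias[k]` when a
/// bias is present). Note that this `Module<&Array>` implementation applies
/// the form to a single input array on both sides; see `forward` below for the
/// reshape/matmul sequence that computes it.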
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Bilinear {
    /// The weight of the bilinear layer.
    #[param]
    pub weights: Param<Array>,

    /// The bias of the bilinear layer.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl Bilinear {
    /// Default value for `with_bias`
    pub const DEFAULT_BIAS: bool = true;
}

impl Module<&Array> for Bilinear {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let shape = self.weights.shape();
        let (out, in2, in1) = (shape[0], shape[1], shape[2]);
        let x_shape = &x.shape()[..x.shape().len() - 1];
        let x1 = x.reshape(&[-1, in1])?;
        let x2 = x.reshape(&[-1, 1, in2])?;

        // perform the bilinear transform
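        // x1: [N, in1], x2: [N, 1, in2]; w: [out * in2, in1]
        // x1 @ wᵀ -> [N, out * in2] -> reshaped/transposed to [N, in2, out]
        // x2 @ y  -> [N, 1, out]    -> squeezed to [N, out]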
        let w = self.weights.reshape(&[out * in2, in1])?;
        let mut y = crate::ops::matmul(&x1, w.t())?;
        y = y.reshape(&[-1, out, in2])?.swap_axes(-2, -1)?;
        y = crate::ops::matmul(&x2, &y)?;
        y = y.squeeze(&[1])?;

        // reset the shape
        let new_shape = x_shape.iter().cloned().chain(once(out)).collect::<Vec<_>>();
        y = y.reshape(&new_shape)?;

        if let Some(bias) = &self.bias.value {
            y = crate::ops::add(&y, bias)?;
        }

        Ok(y)
    }

    fn training_mode(&mut self, _: bool) {}
}

// The following tests are ported from the swift binding:
// mlx-swift/Tests/MLXTests/IntegrationTests.swift
#[cfg(test)]
mod tests {
    use crate::{random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use super::*;

    #[test]
    fn test_linear() {
        crate::random::seed(744).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.508_688_57,
            abs <= 0.010_173_771_5
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            130.224_27,
            abs <= 2.604_485_5
        );
        let result = Linear::new(16, 5).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 5]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.104_193_09,
            abs <= 0.002_083_861_7
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            8.335_447,
            abs <= 0.166_708_95
        );
    }
}