mlx_rs/ops/
other.rs

1use std::ffi::CString;
2
3use mlx_internal_macros::{default_device, generate_macro};
4
5use crate::utils::guard::Guarded;
6use crate::utils::VectorArray;
7use crate::{
8    error::{Exception, Result},
9    Array, Stream, StreamOrDevice,
10};
11
12impl Array {
13    /// Extract a diagonal or construct a diagonal matrix.
14    ///
15    /// If self is 1-D then a diagonal matrix is constructed with self on the `k`-th diagonal. If
16    /// self is 2-D then the `k`-th diagonal is returned.
17    ///
18    /// # Params:
19    ///
20    /// - `k`: the diagonal to extract or construct
21    /// - `stream`: stream or device to evaluate on
22    #[default_device]
23    pub fn diag_device(
24        &self,
25        k: impl Into<Option<i32>>,
26        stream: impl AsRef<Stream>,
27    ) -> Result<Array> {
28        Array::try_from_op(|res| unsafe {
29            mlx_sys::mlx_diag(
30                res,
31                self.as_ptr(),
32                k.into().unwrap_or(0),
33                stream.as_ref().as_ptr(),
34            )
35        })
36    }
37
38    /// Return specified diagonals.
39    ///
40    /// If self is 2-D, then a 1-D array containing the diagonal at the given `offset` is returned.
41    ///
42    /// If self has more than two dimensions, then `axis1` and `axis2` determine the 2D subarrays
43    /// from which diagonals are extracted. The new shape is the original shape with `axis1` and
44    /// `axis2` removed and a new dimension inserted at the end corresponding to the diagonal.
45    ///
46    /// # Params:
47    ///
48    /// - `offset`: offset of the diagonal.  Can be positive or negative
49    /// - `axis1`: first axis of the 2-D sub-array from which the diagonals should be taken
50    /// - `axis2`: second axis of the 2-D sub-array from which the diagonals should be taken
51    /// - `stream`: stream or device to evaluate on
52    #[default_device]
53    pub fn diagonal_device(
54        &self,
55        offset: impl Into<Option<i32>>,
56        axis1: impl Into<Option<i32>>,
57        axis2: impl Into<Option<i32>>,
58        stream: impl AsRef<Stream>,
59    ) -> Result<Array> {
60        Array::try_from_op(|res| unsafe {
61            mlx_sys::mlx_diagonal(
62                res,
63                self.as_ptr(),
64                offset.into().unwrap_or(0),
65                axis1.into().unwrap_or(0),
66                axis2.into().unwrap_or(1),
67                stream.as_ref().as_ptr(),
68            )
69        })
70    }
71
72    /// Perform the Walsh-Hadamard transform along the final axis.
73    ///
74    /// Supports sizes `n = m*2^k` for `m` in `(1, 12, 20, 28)` and `2^k <= 8192`
75    /// for ``DType/float32`` and `2^k <= 16384` for ``DType/float16`` and ``DType/bfloat16``.
76    ///
77    /// # Params
78    /// - scale: scale the output by this factor -- default is `1.0/sqrt(array.dim(-1))`
79    /// - stream: stream to evaluate on.
80    #[default_device]
81    pub fn hadamard_transform_device(
82        &self,
83        scale: impl Into<Option<f32>>,
84        stream: impl AsRef<Stream>,
85    ) -> Result<Array> {
86        let scale = scale.into();
87        let scale = mlx_sys::mlx_optional_float {
88            value: scale.unwrap_or(0.0),
89            has_value: scale.is_some(),
90        };
91
92        Array::try_from_op(|res| unsafe {
93            mlx_sys::mlx_hadamard_transform(res, self.as_ptr(), scale, stream.as_ref().as_ptr())
94        })
95    }
96}
97
98/// See [`Array::diag`]
99#[generate_macro]
100#[default_device]
101pub fn diag_device(
102    a: impl AsRef<Array>,
103    #[optional] k: impl Into<Option<i32>>,
104    #[optional] stream: impl AsRef<Stream>,
105) -> Result<Array> {
106    a.as_ref().diag_device(k, stream)
107}
108
109/// See [`Array::diagonal`]
110#[generate_macro]
111#[default_device]
112pub fn diagonal_device(
113    a: impl AsRef<Array>,
114    #[optional] offset: impl Into<Option<i32>>,
115    #[optional] axis1: impl Into<Option<i32>>,
116    #[optional] axis2: impl Into<Option<i32>>,
117    #[optional] stream: impl AsRef<Stream>,
118) -> Result<Array> {
119    a.as_ref().diagonal_device(offset, axis1, axis2, stream)
120}
121
122/// Perform the Einstein summation convention on the operands.
123///
124/// # Params
125///
126/// - subscripts: Einstein summation convention equation
127/// - operands: input arrays
128/// - stream: stream or device to evaluate on
129#[generate_macro]
130#[default_device]
131pub fn einsum_device<'a>(
132    subscripts: &str,
133    operands: impl IntoIterator<Item = &'a Array>,
134    #[optional] stream: impl AsRef<Stream>,
135) -> Result<Array> {
136    let c_subscripts =
137        CString::new(subscripts).map_err(|_| Exception::from("Invalid subscripts"))?;
138    let c_operands = VectorArray::try_from_iter(operands.into_iter())?;
139
140    Array::try_from_op(|res| unsafe {
141        mlx_sys::mlx_einsum(
142            res,
143            c_subscripts.as_ptr(),
144            c_operands.as_ptr(),
145            stream.as_ref().as_ptr(),
146        )
147    })
148}
149
150/// Perform the Kronecker product of two arrays.
151///
152/// # Params
153///
154/// - `a`: first array
155/// - `b`: second array
156/// - `stream`: stream or device to evaluate on
157#[generate_macro]
158#[default_device]
159pub fn kron_device(
160    a: impl AsRef<Array>,
161    b: impl AsRef<Array>,
162    #[optional] stream: impl AsRef<Stream>,
163) -> Result<Array> {
164    Array::try_from_op(|res| unsafe {
165        mlx_sys::mlx_kron(
166            res,
167            a.as_ref().as_ptr(),
168            b.as_ref().as_ptr(),
169            stream.as_ref().as_ptr(),
170        )
171    })
172}
173
#[cfg(test)]
mod tests {
    use crate::{
        array,
        ops::{arange, diag, einsum, reshape},
        Array,
    };
    use pretty_assertions::assert_eq;

    use super::diagonal;

    #[test]
    fn test_diagonal() {
        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[4, 2]);
        let out = diagonal(&x, None, None, None).unwrap();
        assert_eq!(out, array!([0, 3]));

        // Out-of-range axes are rejected.
        assert!(diagonal(&x, 1, 6, 0).is_err());
        assert!(diagonal(&x, 1, 0, -3).is_err());

        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
        let out = diagonal(&x, 2, 1, 0).unwrap();
        assert_eq!(out, array!([8]));

        let out = diagonal(&x, -1, 0, 1).unwrap();
        assert_eq!(out, array!([4, 9]));

        // An offset past the edge of the matrix yields an empty diagonal.
        let out = diagonal(&x, -5, 0, 1).unwrap();
        out.eval().unwrap();
        assert_eq!(out.shape(), &[0]);

        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 2, 2]);
        let out = diagonal(&x, 1, 0, 1).unwrap();
        assert_eq!(out, array!([[2], [3]]));

        let out = diagonal(&x, 0, 2, 0).unwrap();
        assert_eq!(out, array!([[0, 5], [2, 7]]));

        let out = diagonal(&x, 1, -1, 0).unwrap();
        assert_eq!(out, array!([[4, 9], [6, 11]]));

        let x = reshape(arange::<_, f32>(None, 16, None).unwrap(), &[2, 2, 2, 2]).unwrap();
        let out = diagonal(&x, 0, 0, 1).unwrap();
        assert_eq!(
            out,
            Array::from_slice(&[0, 12, 1, 13, 2, 14, 3, 15], &[2, 2, 2])
        );

        // The two axes must differ.
        assert!(diagonal(&x, 0, 1, 1).is_err());

        // At least two dimensions are required.
        let x = array!([0, 1]);
        assert!(diagonal(&x, 0, 0, 1).is_err());
    }

    #[test]
    fn test_diag() {
        // Too few or too many dimensions
        assert!(diag(Array::from_f32(0.0), None).is_err());
        assert!(diag(Array::from_slice(&[0.0], &[1, 1, 1]), None).is_err());

        // Test with 1D array
        let x = array!([0, 1, 2, 3]);
        let out = diag(&x, 0).unwrap();
        assert_eq!(
            out,
            array!([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3]])
        );

        let out = diag(&x, 1).unwrap();
        assert_eq!(
            out,
            array!([
                [0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0],
                [0, 0, 0, 2, 0],
                [0, 0, 0, 0, 3],
                [0, 0, 0, 0, 0]
            ])
        );

        let out = diag(&x, -1).unwrap();
        assert_eq!(
            out,
            array!([
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0],
                [0, 0, 2, 0, 0],
                [0, 0, 0, 3, 0]
            ])
        );

        // Test with 2D array
        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8], &[3, 3]);
        let out = diag(&x, 0).unwrap();
        assert_eq!(out, array!([0, 4, 8]));

        let out = diag(&x, 1).unwrap();
        assert_eq!(out, array!([1, 5]));

        let out = diag(&x, -1).unwrap();
        assert_eq!(out, array!([3, 7]));
    }

    // Omitting `k` must select the main diagonal (the wrapper substitutes 0).
    #[test]
    fn test_diag_default_k() {
        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8], &[3, 3]);
        assert_eq!(diag(&x, None).unwrap(), diag(&x, 0).unwrap());
    }

    #[test]
    fn test_einsum() {
        // Test dot product (vector-vector)
        let a = array!([0.0, 1.0, 2.0, 3.0]);
        let b = array!([4.0, 5.0, 6.0, 7.0]);
        let out = einsum("i,i->", &[a, b]).unwrap();
        assert_eq!(out, array!(38.0));

        // Test trace (diagonal sum)
        let m = array!([[1, 2], [3, 4]]);
        let out = einsum("ii->", &[m]).unwrap();
        assert_eq!(out, array!(5.0));
    }

    // An interior nul byte cannot cross the C boundary, so it is rejected before MLX is called.
    #[test]
    fn test_einsum_rejects_interior_nul() {
        let a = array!([1.0, 2.0]);
        assert!(einsum("i\0->", &[a]).is_err());
    }

    #[test]
    fn test_hadamard_transform() {
        let input = Array::from_slice(&[1.0, -1.0, -1.0, 1.0], &[2, 2]);
        let expected = Array::from_slice(
            &[
                0.0,
                2.0_f32 / 2.0_f32.sqrt(),
                0.0,
                -2.0_f32 / 2.0_f32.sqrt(),
            ],
            &[2, 2],
        );
        let result = input.hadamard_transform(None).unwrap();

        let c = result.all_close(&expected, 1e-5, 1e-5, None).unwrap();
        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true]);
    }

    // An explicit scale (the `has_value: true` path) replaces the default 1/sqrt(n) factor.
    #[test]
    fn test_hadamard_transform_with_scale() {
        let input = Array::from_slice(&[1.0_f32, -1.0, -1.0, 1.0], &[2, 2]);
        let expected = Array::from_slice(&[0.0_f32, 2.0, 0.0, -2.0], &[2, 2]);
        let result = input.hadamard_transform(1.0).unwrap();

        let c = result.all_close(&expected, 1e-5, 1e-5, None).unwrap();
        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true]);
    }

    // This test is adapted from the python unit test `mlx/test/test_ops.py` `test_kron`
    #[test]
    fn test_kron() {
        // Basic vector test
        let x = array!([1, 2]);
        let y = array!([3, 4]);
        let z = super::kron(&x, &y).unwrap();
        assert_eq!(z, array!([3, 4, 6, 8]));

        // Basic matrix test
        let x = array!([[1, 2], [3, 4]]);
        let y = array!([[0, 5], [6, 7]]);
        let z = super::kron(&x, &y).unwrap();
        assert_eq!(
            z,
            array!([
                [0, 5, 0, 10],
                [6, 7, 12, 14],
                [0, 15, 0, 20],
                [18, 21, 24, 28]
            ])
        );
    }
}