mlx_rs/macros/
array.rs

1//! Macros for creating arrays.
2
/// A helper macro to create an array with up to 3 dimensions.
///
/// Supported forms:
///
/// * `array!()` and `array!([])` — an empty, one-dimensional `f32` array
///   (shape `[0]`).
/// * `array!(x)` — a zero-dimensional (scalar) array, built through
///   `FromScalar::from_scalar`.
/// * `array!([..])`, `array!([[..], ..])`, `array!([[[..], ..], ..])` — 1D, 2D
///   and 3D arrays, built through `FromNested::from_nested`.
/// * `array!([..], shape=[..])` — flat data plus an explicit shape, forwarded
///   to `Array::from_slice`.
///
/// A trailing comma is accepted after the last element of every bracketed
/// list, so multi-line literals can be written in the usual Rust style.
///
/// # Examples
///
/// ```rust
/// use mlx_rs::array;
///
/// // Create an empty array
/// // Note that an empty array defaults to f32 and one dimension
/// let empty = array!();
///
/// // Empty brackets produce the same empty array
/// let empty = array!([]);
/// assert_eq!(empty.ndim(), 1);
///
/// // Create a scalar array
/// let s = array!(1);
/// // Scalar array has 0 dimension
/// assert_eq!(s.ndim(), 0);
///
/// // Create a one-element array (singleton matrix)
/// let s = array!([1]);
/// // Singleton array has 1 dimension
/// assert!(s.ndim() == 1);
///
/// // Create a 1D array
/// let a1 = array!([1, 2, 3]);
///
/// // Create a 2D array
/// let a2 = array!([
///     [1, 2, 3],
///     [4, 5, 6]
/// ]);
///
/// // Trailing commas are accepted at every nesting level
/// let a2 = array!([
///     [1, 2, 3],
///     [4, 5, 6],
/// ]);
/// assert_eq!(a2.ndim(), 2);
///
/// // Create a 3D array
/// let a3 = array!([
///     [
///         [1, 2, 3],
///         [4, 5, 6]
///     ],
///     [
///         [7, 8, 9],
///         [10, 11, 12]
///     ]
/// ]);
///
/// // Create a 2x2 array by specifying the shape
/// let a = array!([1, 2, 3, 4], shape=[2, 2]);
/// ```
#[macro_export]
macro_rules! array {
    // Flat data with an explicit shape.
    ([$($x:expr),* $(,)?], shape=[$($s:expr),* $(,)?] $(,)?) => {
        {
            let data = [$($x,)*];
            let shape = [$($s,)*];
            $crate::Array::from_slice(&data, &shape)
        }
    };
    // Empty brackets. This arm must come before the nested arms: they would
    // otherwise match with zero repetitions and expand to an untyped `[]`.
    ([]) => {
        $crate::Array::from_slice::<f32>(&[], &[0])
    };
    // 3D: the most deeply nested pattern is tried first so that inner
    // brackets are not parsed as plain array expressions by a shallower arm.
    ([$([$([$($x:expr),* $(,)?]),* $(,)?]),* $(,)?]) => {
        {
            let arr = [$([$([$($x,)*],)*],)*];
            <$crate::Array as $crate::FromNested<_>>::from_nested(arr)
        }
    };
    // 2D
    ([$([$($x:expr),* $(,)?]),* $(,)?]) => {
        {
            let arr = [$([$($x,)*],)*];
            <$crate::Array as $crate::FromNested<_>>::from_nested(arr)
        }
    };
    // 1D
    ([$($x:expr),* $(,)?]) => {
        {
            let arr = [$($x,)*];
            <$crate::Array as $crate::FromNested<_>>::from_nested(arr)
        }
    };
    // Scalar (0-dimensional) array.
    ($x:expr) => {
        {
            <$crate::Array as $crate::FromScalar<_>>::from_scalar($x)
        }
    };
    // Empty array default to f32
    () => {
        $crate::Array::from_slice::<f32>(&[], &[0])
    };
}
85
/// Unit tests for the `array!` macro: one test per supported invocation form,
/// each checking the resulting rank, shape and element values.
#[cfg(test)]
mod tests {
    use crate::ops::indexing::IndexOp;

    #[test]
    fn test_empty_array() {
        let arr = array!();

        // The empty form is built with shape `[0]`: one dimension, no elements
        assert_eq!(arr.ndim(), 1);
        assert_eq!(arr.shape(), &[0]);
    }

    #[test]
    fn test_scalar_array() {
        let arr = array!(1);

        // Scalar array has 0 dimension
        assert_eq!(arr.ndim(), 0);
        // Scalar array has empty shape
        assert!(arr.shape().is_empty());
        assert_eq!(arr.item::<i32>(), 1);
    }

    #[test]
    fn test_singleton_array() {
        let arr = array!([1]);

        // A bracketed single element is a 1D array, not a scalar
        assert_eq!(arr.ndim(), 1);
        assert_eq!(arr.shape(), &[1]);
        assert_eq!(arr.index(0).item::<i32>(), 1);
    }

    #[test]
    fn test_array_1d() {
        let arr = array!([1, 2, 3]);

        // A flat list of three elements is a 1D array of length 3
        assert_eq!(arr.ndim(), 1);
        assert_eq!(arr.shape(), &[3]);
        assert_eq!(arr.index(0).item::<i32>(), 1);
        assert_eq!(arr.index(1).item::<i32>(), 2);
        assert_eq!(arr.index(2).item::<i32>(), 3);
    }

    #[test]
    fn test_array_2d() {
        let a = array!([[1, 2, 3], [4, 5, 6]]);

        assert_eq!(a.ndim(), 2);
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.index((0, 0)).item::<i32>(), 1);
        assert_eq!(a.index((0, 1)).item::<i32>(), 2);
        assert_eq!(a.index((0, 2)).item::<i32>(), 3);
        assert_eq!(a.index((1, 0)).item::<i32>(), 4);
        assert_eq!(a.index((1, 1)).item::<i32>(), 5);
        assert_eq!(a.index((1, 2)).item::<i32>(), 6);
    }

    #[test]
    fn test_array_3d() {
        let a = array!([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]);

        // `assert_eq!` reports both values on failure, unlike `assert!(.. == ..)`
        assert_eq!(a.ndim(), 3);
        assert_eq!(a.shape(), &[2, 2, 3]);
        assert_eq!(a.index((0, 0, 0)).item::<i32>(), 1);
        assert_eq!(a.index((0, 0, 1)).item::<i32>(), 2);
        assert_eq!(a.index((0, 0, 2)).item::<i32>(), 3);
        assert_eq!(a.index((0, 1, 0)).item::<i32>(), 4);
        assert_eq!(a.index((0, 1, 1)).item::<i32>(), 5);
        assert_eq!(a.index((0, 1, 2)).item::<i32>(), 6);
        assert_eq!(a.index((1, 0, 0)).item::<i32>(), 7);
        assert_eq!(a.index((1, 0, 1)).item::<i32>(), 8);
        assert_eq!(a.index((1, 0, 2)).item::<i32>(), 9);
        assert_eq!(a.index((1, 1, 0)).item::<i32>(), 10);
        assert_eq!(a.index((1, 1, 1)).item::<i32>(), 11);
        assert_eq!(a.index((1, 1, 2)).item::<i32>(), 12);
    }

    #[test]
    fn test_array_with_shape() {
        let a = array!([1, 2, 3, 4], shape = [2, 2]);

        // The flat data is laid out row by row into the requested 2x2 shape
        assert_eq!(a.ndim(), 2);
        assert_eq!(a.shape(), &[2, 2]);
        assert_eq!(a.index((0, 0)).item::<i32>(), 1);
        assert_eq!(a.index((0, 1)).item::<i32>(), 2);
        assert_eq!(a.index((1, 0)).item::<i32>(), 3);
        assert_eq!(a.index((1, 1)).item::<i32>(), 4);
    }
}