1use std::ffi::CString;
2
3use mlx_internal_macros::{default_device, generate_macro};
4
5use crate::utils::guard::Guarded;
6use crate::utils::VectorArray;
7use crate::{
8 error::{Exception, Result},
9 Array, Stream, StreamOrDevice,
10};
11
12impl Array {
13 #[default_device]
23 pub fn diag_device(
24 &self,
25 k: impl Into<Option<i32>>,
26 stream: impl AsRef<Stream>,
27 ) -> Result<Array> {
28 Array::try_from_op(|res| unsafe {
29 mlx_sys::mlx_diag(
30 res,
31 self.as_ptr(),
32 k.into().unwrap_or(0),
33 stream.as_ref().as_ptr(),
34 )
35 })
36 }
37
38 #[default_device]
53 pub fn diagonal_device(
54 &self,
55 offset: impl Into<Option<i32>>,
56 axis1: impl Into<Option<i32>>,
57 axis2: impl Into<Option<i32>>,
58 stream: impl AsRef<Stream>,
59 ) -> Result<Array> {
60 Array::try_from_op(|res| unsafe {
61 mlx_sys::mlx_diagonal(
62 res,
63 self.as_ptr(),
64 offset.into().unwrap_or(0),
65 axis1.into().unwrap_or(0),
66 axis2.into().unwrap_or(1),
67 stream.as_ref().as_ptr(),
68 )
69 })
70 }
71
72 #[default_device]
81 pub fn hadamard_transform_device(
82 &self,
83 scale: impl Into<Option<f32>>,
84 stream: impl AsRef<Stream>,
85 ) -> Result<Array> {
86 let scale = scale.into();
87 let scale = mlx_sys::mlx_optional_float {
88 value: scale.unwrap_or(0.0),
89 has_value: scale.is_some(),
90 };
91
92 Array::try_from_op(|res| unsafe {
93 mlx_sys::mlx_hadamard_transform(res, self.as_ptr(), scale, stream.as_ref().as_ptr())
94 })
95 }
96}
97
98#[generate_macro]
100#[default_device]
101pub fn diag_device(
102 a: impl AsRef<Array>,
103 #[optional] k: impl Into<Option<i32>>,
104 #[optional] stream: impl AsRef<Stream>,
105) -> Result<Array> {
106 a.as_ref().diag_device(k, stream)
107}
108
109#[generate_macro]
111#[default_device]
112pub fn diagonal_device(
113 a: impl AsRef<Array>,
114 #[optional] offset: impl Into<Option<i32>>,
115 #[optional] axis1: impl Into<Option<i32>>,
116 #[optional] axis2: impl Into<Option<i32>>,
117 #[optional] stream: impl AsRef<Stream>,
118) -> Result<Array> {
119 a.as_ref().diagonal_device(offset, axis1, axis2, stream)
120}
121
122#[generate_macro]
130#[default_device]
131pub fn einsum_device<'a>(
132 subscripts: &str,
133 operands: impl IntoIterator<Item = &'a Array>,
134 #[optional] stream: impl AsRef<Stream>,
135) -> Result<Array> {
136 let c_subscripts =
137 CString::new(subscripts).map_err(|_| Exception::from("Invalid subscripts"))?;
138 let c_operands = VectorArray::try_from_iter(operands.into_iter())?;
139
140 Array::try_from_op(|res| unsafe {
141 mlx_sys::mlx_einsum(
142 res,
143 c_subscripts.as_ptr(),
144 c_operands.as_ptr(),
145 stream.as_ref().as_ptr(),
146 )
147 })
148}
149
150#[generate_macro]
158#[default_device]
159pub fn kron_device(
160 a: impl AsRef<Array>,
161 b: impl AsRef<Array>,
162 #[optional] stream: impl AsRef<Stream>,
163) -> Result<Array> {
164 Array::try_from_op(|res| unsafe {
165 mlx_sys::mlx_kron(
166 res,
167 a.as_ref().as_ptr(),
168 b.as_ref().as_ptr(),
169 stream.as_ref().as_ptr(),
170 )
171 })
172}
173
#[cfg(test)]
mod tests {
    use crate::{
        array,
        ops::{arange, diag, einsum, reshape},
        Array,
    };
    use pretty_assertions::assert_eq;

    use super::diagonal;

    /// Exercises `diagonal` on 2-D, 3-D and 4-D inputs: default axes, positive/negative
    /// offsets, offsets past the matrix, and argument combinations that must be rejected.
    #[test]
    fn test_diagonal() {
        // 4x2 input: the main diagonal is [x[0][0], x[1][1]] = [0, 3].
        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[4, 2]);
        let out = diagonal(&x, None, None, None).unwrap();
        assert_eq!(out, array!([0, 3]));

        // Out-of-range axes (6, and -3 for a 2-D input) are errors.
        assert!(diagonal(&x, 1, 6, 0).is_err());
        assert!(diagonal(&x, 1, 0, -3).is_err());

        // 3x4 input; with axis1 = 1, axis2 = 0 and offset 2 only x[2][0] = 8 is selected.
        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
        let out = diagonal(&x, 2, 1, 0).unwrap();
        assert_eq!(out, array!([8]));

        // A negative offset selects the diagonal below the main one: x[1][0], x[2][1].
        let out = diagonal(&x, -1, 0, 1).unwrap();
        assert_eq!(out, array!([4, 9]));

        // An offset that runs past the matrix yields an empty result rather than an error.
        let out = diagonal(&x, -5, 0, 1).unwrap();
        out.eval().unwrap();
        assert_eq!(out.shape(), &[0]);

        // 3-D input: the extracted diagonal becomes the trailing axis of the output.
        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 2, 2]);
        let out = diagonal(&x, 1, 0, 1).unwrap();
        assert_eq!(out, array!([[2], [3]]));

        let out = diagonal(&x, 0, 2, 0).unwrap();
        assert_eq!(out, array!([[0, 5], [2, 7]]));

        // Negative axis indices are accepted (-1 is the last axis).
        let out = diagonal(&x, 1, -1, 0).unwrap();
        assert_eq!(out, array!([[4, 9], [6, 11]]));

        // 4-D input: diagonal over the first two axes; the remaining axes lead.
        let x = reshape(arange::<_, f32>(None, 16, None).unwrap(), &[2, 2, 2, 2]).unwrap();
        let out = diagonal(&x, 0, 0, 1).unwrap();
        assert_eq!(
            out,
            Array::from_slice(&[0, 12, 1, 13, 2, 14, 3, 15], &[2, 2, 2])
        );

        // The two axes must differ.
        assert!(diagonal(&x, 0, 1, 1).is_err());

        // A 1-D input has no 2-D plane to take a diagonal from.
        let x = array!([0, 1]);
        assert!(diagonal(&x, 0, 0, 1).is_err());
    }

    /// Checks `diag` in both directions: a 1-D input builds a square matrix, a 2-D input
    /// extracts a diagonal, and other ranks are rejected.
    #[test]
    fn test_diag() {
        // 0-D and 3-D inputs are errors.
        assert!(diag(Array::from_f32(0.0), None).is_err());
        assert!(diag(Array::from_slice(&[0.0], &[1, 1, 1]), None).is_err());

        // 1-D input -> diagonal matrix.
        let x = array!([0, 1, 2, 3]);
        let out = diag(&x, 0).unwrap();
        assert_eq!(
            out,
            array!([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3]])
        );

        // A nonzero `k` places the values off the main diagonal and grows the matrix to 5x5.
        let out = diag(&x, 1).unwrap();
        assert_eq!(
            out,
            array!([
                [0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0],
                [0, 0, 0, 2, 0],
                [0, 0, 0, 0, 3],
                [0, 0, 0, 0, 0]
            ])
        );

        let out = diag(&x, -1).unwrap();
        assert_eq!(
            out,
            array!([
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0],
                [0, 0, 2, 0, 0],
                [0, 0, 0, 3, 0]
            ])
        );

        // 2-D input -> the selected diagonal as a 1-D array.
        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8], &[3, 3]);
        let out = diag(&x, 0).unwrap();
        assert_eq!(out, array!([0, 4, 8]));

        // k = 1 is above the main diagonal (x[0][1], x[1][2]); k = -1 is below it.
        let out = diag(&x, 1).unwrap();
        assert_eq!(out, array!([1, 5]));

        let out = diag(&x, -1).unwrap();
        assert_eq!(out, array!([3, 7]));
    }

    /// Checks two classic einsum contractions: an inner product and a trace.
    #[test]
    fn test_einsum() {
        // Inner product: 0*4 + 1*5 + 2*6 + 3*7 = 38.
        let a = array!([0.0, 1.0, 2.0, 3.0]);
        let b = array!([4.0, 5.0, 6.0, 7.0]);
        let out = einsum("i,i->", &[a, b]).unwrap();
        assert_eq!(out, array!(38.0));

        // Trace: 1 + 4 = 5. The input is an integer array but is compared against a float
        // scalar. NOTE(review): this relies on Array equality not requiring matching dtypes.
        let m = array!([[1, 2], [3, 4]]);
        let out = einsum("ii->", &[m]).unwrap();
        assert_eq!(out, array!(5.0));
    }

    /// Checks the Hadamard transform with the default scale on a 2x2 input, where each row
    /// is transformed independently.
    #[test]
    fn test_hadamard_transform() {
        // Rows [1, -1] and [-1, 1] map to [0, 2/sqrt(2)] and [0, -2/sqrt(2)].
        let input = Array::from_slice(&[1.0, -1.0, -1.0, 1.0], &[2, 2]);
        let expected = Array::from_slice(
            &[
                0.0,
                2.0_f32 / 2.0_f32.sqrt(),
                0.0,
                -2.0_f32 / 2.0_f32.sqrt(),
            ],
            &[2, 2],
        );
        let result = input.hadamard_transform(None).unwrap();

        // Compare with a tolerance; `all_close` yields a single boolean element.
        let c = result.all_close(&expected, 1e-5, 1e-5, None).unwrap();
        let c_data: &[bool] = c.as_slice();
        assert_eq!(c_data, [true]);
    }

    /// Checks the Kronecker product for 1-D and 2-D inputs.
    #[test]
    fn test_kron() {
        // [1, 2] ⊗ [3, 4] = [1*3, 1*4, 2*3, 2*4].
        let x = array!([1, 2]);
        let y = array!([3, 4]);
        let z = super::kron(&x, &y).unwrap();
        assert_eq!(z, array!([3, 4, 6, 8]));

        // 2x2 ⊗ 2x2 -> 4x4 block matrix whose (i, j) block is x[i][j] * y.
        let x = array!([[1, 2], [3, 4]]);
        let y = array!([[0, 5], [6, 7]]);
        let z = super::kron(&x, &y).unwrap();
        assert_eq!(
            z,
            array!([
                [0, 5, 0, 10],
                [6, 7, 12, 14],
                [0, 15, 0, 20],
                [18, 21, 24, 28]
            ])
        );
    }
}