mlx_rs/ops/conversion.rs

1use mlx_internal_macros::default_device;
2
3use crate::{error::Result, utils::guard::Guarded, Array, ArrayElement, Dtype, Stream};
4
5impl Array {
6    /// Create a new array with the contents converted to the given [ArrayElement] type.
7    ///
8    /// # Example
9    ///
10    /// ```rust
11    /// use mlx_rs::{Array, Dtype};
12    ///
13    /// let array = Array::from_slice(&[1i16,2,3], &[3]);
14    /// let mut new_array = array.as_type::<f32>().unwrap();
15    ///
16    /// assert_eq!(new_array.dtype(), Dtype::Float32);
17    /// assert_eq!(new_array.shape(), &[3]);
18    /// assert_eq!(new_array.item_size(), 4);
19    /// assert_eq!(new_array.as_slice::<f32>(), &[1.0,2.0,3.0]);
20    /// ```
21    #[default_device]
22    pub fn as_type_device<T: ArrayElement>(&self, stream: impl AsRef<Stream>) -> Result<Array> {
23        self.as_dtype_device(T::DTYPE, stream)
24    }
25
26    /// Same as `as_type` but with a [`Dtype`] argument.
27    #[default_device]
28    pub fn as_dtype_device(&self, dtype: Dtype, stream: impl AsRef<Stream>) -> Result<Array> {
29        Array::try_from_op(|res| unsafe {
30            mlx_sys::mlx_astype(res, self.as_ptr(), dtype.into(), stream.as_ref().as_ptr())
31        })
32    }
33
34    /// View the array as a different type.
35    ///
36    /// The output array will change along the last axis if the input array's
37    /// type and the output array's type do not have the same size.
38    ///
39    /// _Note: the view op does not imply that the input and output arrays share
40    /// their underlying data. The view only guarantees that the binary
41    /// representation of each element (or group of elements) is the same._
42    ///
43    #[default_device]
44    pub fn view_device<T: ArrayElement>(&self, stream: impl AsRef<Stream>) -> Result<Array> {
45        self.view_dtype_device(T::DTYPE, stream)
46    }
47
48    /// Same as `view` but with a [`Dtype`] argument.
49    #[default_device]
50    pub fn view_dtype_device(&self, dtype: Dtype, stream: impl AsRef<Stream>) -> Result<Array> {
51        Array::try_from_op(|res| unsafe {
52            mlx_sys::mlx_view(res, self.as_ptr(), dtype.into(), stream.as_ref().as_ptr())
53        })
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::complex64;
    use half::{bf16, f16};
    use pretty_assertions::assert_eq;

    /// Generates a `#[test]` that fills an array of `$len` copies of `$src_val`
    /// (type `$src_type`), converts it with `as_type::<$dst_type>()`, and checks
    /// dtype, shape, item size and every converted element against `$dst_val`.
    macro_rules! test_as_type {
        ($src_type:ty, $src_val:expr, $dst_type:ty, $dst_val:expr, $len:expr) => {
            paste::paste! {
                #[test]
                fn [<test_as_type_ $src_type _ $dst_type>]() {
                    let array = Array::from_slice(&[$src_val; $len], &[$len as i32]);
                    let new_array = array.as_type::<$dst_type>().unwrap();

                    assert_eq!(new_array.dtype(), $dst_type::DTYPE);
                    // Conversion never changes the shape: it must match the source length.
                    assert_eq!(new_array.shape(), &[$len as i32]);
                    assert_eq!(new_array.item_size(), std::mem::size_of::<$dst_type>());
                    assert_eq!(new_array.as_slice::<$dst_type>(), &[$dst_val; $len]);
                }
            }
        };
    }

    test_as_type!(bool, true, i8, 1, 3);
    test_as_type!(bool, true, i16, 1, 3);
    test_as_type!(bool, true, i32, 1, 3);
    test_as_type!(bool, true, i64, 1, 3);
    test_as_type!(bool, true, u8, 1, 3);
    test_as_type!(bool, true, u16, 1, 3);
    test_as_type!(bool, true, u32, 1, 3);
    test_as_type!(bool, true, u64, 1, 3);
    test_as_type!(bool, true, f32, 1.0, 3);
    test_as_type!(bool, true, f16, f16::from_f32(1.0), 3);
    test_as_type!(bool, true, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(bool, true, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(i8, 1, bool, true, 3);
    test_as_type!(i8, 1, i16, 1, 3);
    test_as_type!(i8, 1, i32, 1, 3);
    test_as_type!(i8, 1, i64, 1, 3);
    test_as_type!(i8, 1, u8, 1, 3);
    test_as_type!(i8, 1, u16, 1, 3);
    test_as_type!(i8, 1, u32, 1, 3);
    test_as_type!(i8, 1, u64, 1, 3);
    test_as_type!(i8, 1, f32, 1.0, 3);
    test_as_type!(i8, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(i8, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(i8, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(i16, 1, bool, true, 3);
    test_as_type!(i16, 1, i8, 1, 3);
    test_as_type!(i16, 1, i32, 1, 3);
    test_as_type!(i16, 1, i64, 1, 3);
    test_as_type!(i16, 1, u8, 1, 3);
    test_as_type!(i16, 1, u16, 1, 3);
    test_as_type!(i16, 1, u32, 1, 3);
    test_as_type!(i16, 1, u64, 1, 3);
    test_as_type!(i16, 1, f32, 1.0, 3);
    test_as_type!(i16, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(i16, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(i16, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(i32, 1, bool, true, 3);
    test_as_type!(i32, 1, i8, 1, 3);
    test_as_type!(i32, 1, i16, 1, 3);
    test_as_type!(i32, 1, i64, 1, 3);
    test_as_type!(i32, 1, u8, 1, 3);
    test_as_type!(i32, 1, u16, 1, 3);
    test_as_type!(i32, 1, u32, 1, 3);
    test_as_type!(i32, 1, u64, 1, 3);
    test_as_type!(i32, 1, f32, 1.0, 3);
    test_as_type!(i32, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(i32, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(i32, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(i64, 1, bool, true, 3);
    test_as_type!(i64, 1, i8, 1, 3);
    test_as_type!(i64, 1, i16, 1, 3);
    test_as_type!(i64, 1, i32, 1, 3);
    test_as_type!(i64, 1, u8, 1, 3);
    test_as_type!(i64, 1, u16, 1, 3);
    test_as_type!(i64, 1, u32, 1, 3);
    test_as_type!(i64, 1, u64, 1, 3);
    test_as_type!(i64, 1, f32, 1.0, 3);
    test_as_type!(i64, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(i64, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(i64, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(u8, 1, bool, true, 3);
    test_as_type!(u8, 1, i8, 1, 3);
    test_as_type!(u8, 1, i16, 1, 3);
    test_as_type!(u8, 1, i32, 1, 3);
    test_as_type!(u8, 1, i64, 1, 3);
    test_as_type!(u8, 1, u16, 1, 3);
    test_as_type!(u8, 1, u32, 1, 3);
    test_as_type!(u8, 1, u64, 1, 3);
    test_as_type!(u8, 1, f32, 1.0, 3);
    test_as_type!(u8, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(u8, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(u8, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(u16, 1, bool, true, 3);
    test_as_type!(u16, 1, i8, 1, 3);
    test_as_type!(u16, 1, i16, 1, 3);
    test_as_type!(u16, 1, i32, 1, 3);
    test_as_type!(u16, 1, i64, 1, 3);
    test_as_type!(u16, 1, u8, 1, 3);
    test_as_type!(u16, 1, u32, 1, 3);
    test_as_type!(u16, 1, u64, 1, 3);
    test_as_type!(u16, 1, f32, 1.0, 3);
    test_as_type!(u16, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(u16, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(u16, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(u32, 1, bool, true, 3);
    test_as_type!(u32, 1, i8, 1, 3);
    test_as_type!(u32, 1, i16, 1, 3);
    test_as_type!(u32, 1, i32, 1, 3);
    test_as_type!(u32, 1, i64, 1, 3);
    test_as_type!(u32, 1, u8, 1, 3);
    test_as_type!(u32, 1, u16, 1, 3);
    test_as_type!(u32, 1, u64, 1, 3);
    test_as_type!(u32, 1, f32, 1.0, 3);
    test_as_type!(u32, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(u32, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(u32, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(u64, 1, bool, true, 3);
    test_as_type!(u64, 1, i8, 1, 3);
    test_as_type!(u64, 1, i16, 1, 3);
    test_as_type!(u64, 1, i32, 1, 3);
    test_as_type!(u64, 1, i64, 1, 3);
    test_as_type!(u64, 1, u8, 1, 3);
    test_as_type!(u64, 1, u16, 1, 3);
    test_as_type!(u64, 1, u32, 1, 3);
    test_as_type!(u64, 1, f32, 1.0, 3);
    test_as_type!(u64, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(u64, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(u64, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(f32, 1.0, bool, true, 3);
    test_as_type!(f32, 1.0, i8, 1, 3);
    test_as_type!(f32, 1.0, i16, 1, 3);
    test_as_type!(f32, 1.0, i32, 1, 3);
    test_as_type!(f32, 1.0, i64, 1, 3);
    test_as_type!(f32, 1.0, u8, 1, 3);
    test_as_type!(f32, 1.0, u16, 1, 3);
    test_as_type!(f32, 1.0, u32, 1, 3);
    test_as_type!(f32, 1.0, u64, 1, 3);
    test_as_type!(f32, 1.0, f16, f16::from_f32(1.0), 3);
    test_as_type!(f32, 1.0, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(f32, 1.0, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(f16, f16::from_f32(1.0), bool, true, 3);
    test_as_type!(f16, f16::from_f32(1.0), i8, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), i16, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), i32, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), i64, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), u8, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), u16, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), u32, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), u64, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), f32, 1.0, 3);
    test_as_type!(f16, f16::from_f32(1.0), bf16, bf16::from_f32(1.0), 3);
    test_as_type!(
        f16,
        f16::from_f32(1.0),
        complex64,
        complex64::new(1.0, 0.0),
        3
    );

    test_as_type!(bf16, bf16::from_f32(1.0), bool, true, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), i8, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), i16, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), i32, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), i64, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), u8, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), u16, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), u32, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), u64, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), f32, 1.0, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), f16, f16::from_f32(1.0), 3);
    // Previously missing from the matrix; mirrors the f16 -> complex64 case above.
    test_as_type!(
        bf16,
        bf16::from_f32(1.0),
        complex64,
        complex64::new(1.0, 0.0),
        3
    );

    test_as_type!(complex64, complex64::new(1.0, 0.0), bool, true, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), i8, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), i16, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), i32, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), i64, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), u8, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), u16, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), u32, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), u64, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), f32, 1.0, 3);
    test_as_type!(
        complex64,
        complex64::new(1.0, 0.0),
        f16,
        f16::from_f32(1.0),
        3
    );
    test_as_type!(
        complex64,
        complex64::new(1.0, 0.0),
        bf16,
        bf16::from_f32(1.0),
        3
    );

    #[test]
    fn test_as_dtype() {
        let array = Array::from_slice(&[1i16, 2, 3], &[3]);
        let new_array = array.as_dtype(Dtype::Float32).unwrap();

        assert_eq!(new_array.dtype(), Dtype::Float32);
        assert_eq!(new_array.shape(), &[3]);
        assert_eq!(new_array.item_size(), 4);
        assert_eq!(new_array.as_slice::<f32>(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_view() {
        let array = Array::from_slice(&[1i16, 2, 3], &[3]);
        let new_array = array.view::<i8>().unwrap();

        assert_eq!(new_array.dtype(), Dtype::Int8);
        assert_eq!(new_array.shape(), &[6]);
        assert_eq!(new_array.item_size(), 1);
        // The byte order of each reinterpreted `i16` depends on the target's endianness.
        if cfg!(target_endian = "little") {
            assert_eq!(new_array.as_slice::<i8>(), &[1, 0, 2, 0, 3, 0]);
        } else {
            assert_eq!(new_array.as_slice::<i8>(), &[0, 1, 0, 2, 0, 3]);
        }
    }

    #[test]
    fn test_view_dtype() {
        // Widening view: pairs of `i8` bytes are reinterpreted as single `i16`s.
        let array = Array::from_slice(&[1i8, 0, 2, 0], &[4]);
        let new_array = array.view_dtype(Dtype::Int16).unwrap();

        assert_eq!(new_array.dtype(), Dtype::Int16);
        assert_eq!(new_array.shape(), &[2]);
        assert_eq!(new_array.item_size(), 2);
        if cfg!(target_endian = "little") {
            assert_eq!(new_array.as_slice::<i16>(), &[1, 2]);
        } else {
            assert_eq!(new_array.as_slice::<i16>(), &[256, 512]);
        }
    }
}