mlx_rs/ops/conversion.rs

1use mlx_internal_macros::default_device;
2
3use crate::{
4    error::Result, utils::guard::Guarded, Array, ArrayElement, Dtype, Stream, StreamOrDevice,
5};
6
7impl Array {
8    /// Create a new array with the contents converted to the given [ArrayElement] type.
9    ///
10    /// # Example
11    ///
12    /// ```rust
13    /// use mlx_rs::{Array, Dtype};
14    ///
15    /// let array = Array::from_slice(&[1i16,2,3], &[3]);
16    /// let mut new_array = array.as_type::<f32>().unwrap();
17    ///
18    /// assert_eq!(new_array.dtype(), Dtype::Float32);
19    /// assert_eq!(new_array.shape(), &[3]);
20    /// assert_eq!(new_array.item_size(), 4);
21    /// assert_eq!(new_array.as_slice::<f32>(), &[1.0,2.0,3.0]);
22    /// ```
23    #[default_device]
24    pub fn as_type_device<T: ArrayElement>(&self, stream: impl AsRef<Stream>) -> Result<Array> {
25        self.as_dtype_device(T::DTYPE, stream)
26    }
27
28    /// Same as `as_type` but with a [`Dtype`] argument.
29    #[default_device]
30    pub fn as_dtype_device(&self, dtype: Dtype, stream: impl AsRef<Stream>) -> Result<Array> {
31        Array::try_from_op(|res| unsafe {
32            mlx_sys::mlx_astype(res, self.as_ptr(), dtype.into(), stream.as_ref().as_ptr())
33        })
34    }
35
36    /// View the array as a different type.
37    ///
38    /// The output array will change along the last axis if the input array's
39    /// type and the output array's type do not have the same size.
40    ///
41    /// _Note: the view op does not imply that the input and output arrays share
42    /// their underlying data. The view only guarantees that the binary
43    /// representation of each element (or group of elements) is the same._
44    ///
45    #[default_device]
46    pub fn view_device<T: ArrayElement>(&self, stream: impl AsRef<Stream>) -> Result<Array> {
47        self.view_dtype_device(T::DTYPE, stream)
48    }
49
50    /// Same as `view` but with a [`Dtype`] argument.
51    #[default_device]
52    pub fn view_dtype_device(&self, dtype: Dtype, stream: impl AsRef<Stream>) -> Result<Array> {
53        Array::try_from_op(|res| unsafe {
54            mlx_sys::mlx_view(res, self.as_ptr(), dtype.into(), stream.as_ref().as_ptr())
55        })
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::complex64;
    use half::{bf16, f16};
    use pretty_assertions::assert_eq;

    /// Generates a `#[test]` named `test_as_type_<src>_<dst>`.
    ///
    /// The generated test builds a 1-D array of `$len` copies of `$src_val`, converts it with
    /// `as_type::<$dst_type>()`, and checks the resulting dtype, shape, item size and that
    /// every element equals `$dst_val`.
    macro_rules! test_as_type {
        ($src_type:ty, $src_val:expr, $dst_type:ty, $dst_val:expr, $len:expr) => {
            paste::paste! {
                #[test]
                fn [<test_as_type_ $src_type _ $dst_type>]() {
                    let array = Array::from_slice(&[$src_val; $len], &[$len as i32]);
                    let new_array = array.as_type::<$dst_type>().unwrap();

                    // `<T>::CONST` is the always-valid form for a `ty` fragment.
                    assert_eq!(new_array.dtype(), <$dst_type>::DTYPE);
                    // Conversion must preserve the input shape for any `$len`.
                    assert_eq!(new_array.shape(), &[$len as i32]);
                    assert_eq!(new_array.item_size(), std::mem::size_of::<$dst_type>());
                    assert_eq!(new_array.as_slice::<$dst_type>(), &[$dst_val; $len]);
                }
            }
        };
    }

    test_as_type!(bool, true, i8, 1, 3);
    test_as_type!(bool, true, i16, 1, 3);
    test_as_type!(bool, true, i32, 1, 3);
    test_as_type!(bool, true, i64, 1, 3);
    test_as_type!(bool, true, u8, 1, 3);
    test_as_type!(bool, true, u16, 1, 3);
    test_as_type!(bool, true, u32, 1, 3);
    test_as_type!(bool, true, u64, 1, 3);
    test_as_type!(bool, true, f32, 1.0, 3);
    test_as_type!(bool, true, f16, f16::from_f32(1.0), 3);
    test_as_type!(bool, true, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(bool, true, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(i8, 1, bool, true, 3);
    test_as_type!(i8, 1, i16, 1, 3);
    test_as_type!(i8, 1, i32, 1, 3);
    test_as_type!(i8, 1, i64, 1, 3);
    test_as_type!(i8, 1, u8, 1, 3);
    test_as_type!(i8, 1, u16, 1, 3);
    test_as_type!(i8, 1, u32, 1, 3);
    test_as_type!(i8, 1, u64, 1, 3);
    test_as_type!(i8, 1, f32, 1.0, 3);
    test_as_type!(i8, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(i8, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(i8, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(i16, 1, bool, true, 3);
    test_as_type!(i16, 1, i8, 1, 3);
    test_as_type!(i16, 1, i32, 1, 3);
    test_as_type!(i16, 1, i64, 1, 3);
    test_as_type!(i16, 1, u8, 1, 3);
    test_as_type!(i16, 1, u16, 1, 3);
    test_as_type!(i16, 1, u32, 1, 3);
    test_as_type!(i16, 1, u64, 1, 3);
    test_as_type!(i16, 1, f32, 1.0, 3);
    test_as_type!(i16, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(i16, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(i16, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(i32, 1, bool, true, 3);
    test_as_type!(i32, 1, i8, 1, 3);
    test_as_type!(i32, 1, i16, 1, 3);
    test_as_type!(i32, 1, i64, 1, 3);
    test_as_type!(i32, 1, u8, 1, 3);
    test_as_type!(i32, 1, u16, 1, 3);
    test_as_type!(i32, 1, u32, 1, 3);
    test_as_type!(i32, 1, u64, 1, 3);
    test_as_type!(i32, 1, f32, 1.0, 3);
    test_as_type!(i32, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(i32, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(i32, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(i64, 1, bool, true, 3);
    test_as_type!(i64, 1, i8, 1, 3);
    test_as_type!(i64, 1, i16, 1, 3);
    test_as_type!(i64, 1, i32, 1, 3);
    test_as_type!(i64, 1, u8, 1, 3);
    test_as_type!(i64, 1, u16, 1, 3);
    test_as_type!(i64, 1, u32, 1, 3);
    test_as_type!(i64, 1, u64, 1, 3);
    test_as_type!(i64, 1, f32, 1.0, 3);
    test_as_type!(i64, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(i64, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(i64, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(u8, 1, bool, true, 3);
    test_as_type!(u8, 1, i8, 1, 3);
    test_as_type!(u8, 1, i16, 1, 3);
    test_as_type!(u8, 1, i32, 1, 3);
    test_as_type!(u8, 1, i64, 1, 3);
    test_as_type!(u8, 1, u16, 1, 3);
    test_as_type!(u8, 1, u32, 1, 3);
    test_as_type!(u8, 1, u64, 1, 3);
    test_as_type!(u8, 1, f32, 1.0, 3);
    test_as_type!(u8, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(u8, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(u8, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(u16, 1, bool, true, 3);
    test_as_type!(u16, 1, i8, 1, 3);
    test_as_type!(u16, 1, i16, 1, 3);
    test_as_type!(u16, 1, i32, 1, 3);
    test_as_type!(u16, 1, i64, 1, 3);
    test_as_type!(u16, 1, u8, 1, 3);
    test_as_type!(u16, 1, u32, 1, 3);
    test_as_type!(u16, 1, u64, 1, 3);
    test_as_type!(u16, 1, f32, 1.0, 3);
    test_as_type!(u16, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(u16, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(u16, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(u32, 1, bool, true, 3);
    test_as_type!(u32, 1, i8, 1, 3);
    test_as_type!(u32, 1, i16, 1, 3);
    test_as_type!(u32, 1, i32, 1, 3);
    test_as_type!(u32, 1, i64, 1, 3);
    test_as_type!(u32, 1, u8, 1, 3);
    test_as_type!(u32, 1, u16, 1, 3);
    test_as_type!(u32, 1, u64, 1, 3);
    test_as_type!(u32, 1, f32, 1.0, 3);
    test_as_type!(u32, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(u32, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(u32, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(u64, 1, bool, true, 3);
    test_as_type!(u64, 1, i8, 1, 3);
    test_as_type!(u64, 1, i16, 1, 3);
    test_as_type!(u64, 1, i32, 1, 3);
    test_as_type!(u64, 1, i64, 1, 3);
    test_as_type!(u64, 1, u8, 1, 3);
    test_as_type!(u64, 1, u16, 1, 3);
    test_as_type!(u64, 1, u32, 1, 3);
    test_as_type!(u64, 1, f32, 1.0, 3);
    test_as_type!(u64, 1, f16, f16::from_f32(1.0), 3);
    test_as_type!(u64, 1, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(u64, 1, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(f32, 1.0, bool, true, 3);
    test_as_type!(f32, 1.0, i8, 1, 3);
    test_as_type!(f32, 1.0, i16, 1, 3);
    test_as_type!(f32, 1.0, i32, 1, 3);
    test_as_type!(f32, 1.0, i64, 1, 3);
    test_as_type!(f32, 1.0, u8, 1, 3);
    test_as_type!(f32, 1.0, u16, 1, 3);
    test_as_type!(f32, 1.0, u32, 1, 3);
    test_as_type!(f32, 1.0, u64, 1, 3);
    test_as_type!(f32, 1.0, f16, f16::from_f32(1.0), 3);
    test_as_type!(f32, 1.0, bf16, bf16::from_f32(1.0), 3);
    test_as_type!(f32, 1.0, complex64, complex64::new(1.0, 0.0), 3);

    test_as_type!(f16, f16::from_f32(1.0), bool, true, 3);
    test_as_type!(f16, f16::from_f32(1.0), i8, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), i16, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), i32, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), i64, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), u8, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), u16, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), u32, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), u64, 1, 3);
    test_as_type!(f16, f16::from_f32(1.0), f32, 1.0, 3);
    test_as_type!(f16, f16::from_f32(1.0), bf16, bf16::from_f32(1.0), 3);
    test_as_type!(
        f16,
        f16::from_f32(1.0),
        complex64,
        complex64::new(1.0, 0.0),
        3
    );

    test_as_type!(bf16, bf16::from_f32(1.0), bool, true, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), i8, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), i16, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), i32, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), i64, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), u8, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), u16, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), u32, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), u64, 1, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), f32, 1.0, 3);
    test_as_type!(bf16, bf16::from_f32(1.0), f16, f16::from_f32(1.0), 3);

    test_as_type!(complex64, complex64::new(1.0, 0.0), bool, true, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), i8, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), i16, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), i32, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), i64, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), u8, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), u16, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), u32, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), u64, 1, 3);
    test_as_type!(complex64, complex64::new(1.0, 0.0), f32, 1.0, 3);
    test_as_type!(
        complex64,
        complex64::new(1.0, 0.0),
        f16,
        f16::from_f32(1.0),
        3
    );
    test_as_type!(
        complex64,
        complex64::new(1.0, 0.0),
        bf16,
        bf16::from_f32(1.0),
        3
    );

    /// Viewing three `i16`s as `i8` doubles the last axis and exposes the raw bytes.
    #[test]
    fn test_view() {
        let array = Array::from_slice(&[1i16, 2, 3], &[3]);
        let new_array = array.view::<i8>().unwrap();

        assert_eq!(new_array.dtype(), Dtype::Int8);
        assert_eq!(new_array.shape(), &[6]);
        assert_eq!(new_array.item_size(), 1);
        // Expected bytes assume a little-endian host (low byte first).
        assert_eq!(new_array.as_slice::<i8>(), &[1, 0, 2, 0, 3, 0]);
    }
}