1use mlx_internal_macros::default_device;
2
3use crate::{
4 error::Result, utils::guard::Guarded, Array, ArrayElement, Dtype, Stream, StreamOrDevice,
5};
6
7impl Array {
8 #[default_device]
24 pub fn as_type_device<T: ArrayElement>(&self, stream: impl AsRef<Stream>) -> Result<Array> {
25 self.as_dtype_device(T::DTYPE, stream)
26 }
27
28 #[default_device]
30 pub fn as_dtype_device(&self, dtype: Dtype, stream: impl AsRef<Stream>) -> Result<Array> {
31 Array::try_from_op(|res| unsafe {
32 mlx_sys::mlx_astype(res, self.as_ptr(), dtype.into(), stream.as_ref().as_ptr())
33 })
34 }
35
36 #[default_device]
46 pub fn view_device<T: ArrayElement>(&self, stream: impl AsRef<Stream>) -> Result<Array> {
47 self.view_dtype_device(T::DTYPE, stream)
48 }
49
50 #[default_device]
52 pub fn view_dtype_device(&self, dtype: Dtype, stream: impl AsRef<Stream>) -> Result<Array> {
53 Array::try_from_op(|res| unsafe {
54 mlx_sys::mlx_view(res, self.as_ptr(), dtype.into(), stream.as_ref().as_ptr())
55 })
56 }
57}
58
59#[cfg(test)]
60mod tests {
61 use super::*;
62 use crate::complex64;
63 use half::{bf16, f16};
64 use pretty_assertions::assert_eq;
65
66 macro_rules! test_as_type {
67 ($src_type:ty, $src_val:expr, $dst_type:ty, $dst_val:expr, $len:expr) => {
68 paste::paste! {
69 #[test]
70 fn [<test_as_type_ $src_type _ $dst_type>]() {
71 let array = Array::from_slice(&[$src_val; $len], &[$len as i32]);
72 let new_array = array.as_type::<$dst_type>().unwrap();
73
74 assert_eq!(new_array.dtype(), $dst_type::DTYPE);
75 assert_eq!(new_array.shape(), &[3]);
76 assert_eq!(new_array.item_size(), std::mem::size_of::<$dst_type>());
77 assert_eq!(new_array.as_slice::<$dst_type>(), &[$dst_val; $len]);
78 }
79 }
80 };
81 }
82
83 test_as_type!(bool, true, i8, 1, 3);
84 test_as_type!(bool, true, i16, 1, 3);
85 test_as_type!(bool, true, i32, 1, 3);
86 test_as_type!(bool, true, i64, 1, 3);
87 test_as_type!(bool, true, u8, 1, 3);
88 test_as_type!(bool, true, u16, 1, 3);
89 test_as_type!(bool, true, u32, 1, 3);
90 test_as_type!(bool, true, u64, 1, 3);
91 test_as_type!(bool, true, f32, 1.0, 3);
92 test_as_type!(bool, true, f16, f16::from_f32(1.0), 3);
93 test_as_type!(bool, true, bf16, bf16::from_f32(1.0), 3);
94 test_as_type!(bool, true, complex64, complex64::new(1.0, 0.0), 3);
95
96 test_as_type!(i8, 1, bool, true, 3);
97 test_as_type!(i8, 1, i16, 1, 3);
98 test_as_type!(i8, 1, i32, 1, 3);
99 test_as_type!(i8, 1, i64, 1, 3);
100 test_as_type!(i8, 1, u8, 1, 3);
101 test_as_type!(i8, 1, u16, 1, 3);
102 test_as_type!(i8, 1, u32, 1, 3);
103 test_as_type!(i8, 1, u64, 1, 3);
104 test_as_type!(i8, 1, f32, 1.0, 3);
105 test_as_type!(i8, 1, f16, f16::from_f32(1.0), 3);
106 test_as_type!(i8, 1, bf16, bf16::from_f32(1.0), 3);
107 test_as_type!(i8, 1, complex64, complex64::new(1.0, 0.0), 3);
108
109 test_as_type!(i16, 1, bool, true, 3);
110 test_as_type!(i16, 1, i8, 1, 3);
111 test_as_type!(i16, 1, i32, 1, 3);
112 test_as_type!(i16, 1, i64, 1, 3);
113 test_as_type!(i16, 1, u8, 1, 3);
114 test_as_type!(i16, 1, u16, 1, 3);
115 test_as_type!(i16, 1, u32, 1, 3);
116 test_as_type!(i16, 1, u64, 1, 3);
117 test_as_type!(i16, 1, f32, 1.0, 3);
118 test_as_type!(i16, 1, f16, f16::from_f32(1.0), 3);
119 test_as_type!(i16, 1, bf16, bf16::from_f32(1.0), 3);
120 test_as_type!(i16, 1, complex64, complex64::new(1.0, 0.0), 3);
121
122 test_as_type!(i32, 1, bool, true, 3);
123 test_as_type!(i32, 1, i8, 1, 3);
124 test_as_type!(i32, 1, i16, 1, 3);
125 test_as_type!(i32, 1, i64, 1, 3);
126 test_as_type!(i32, 1, u8, 1, 3);
127 test_as_type!(i32, 1, u16, 1, 3);
128 test_as_type!(i32, 1, u32, 1, 3);
129 test_as_type!(i32, 1, u64, 1, 3);
130 test_as_type!(i32, 1, f32, 1.0, 3);
131 test_as_type!(i32, 1, f16, f16::from_f32(1.0), 3);
132 test_as_type!(i32, 1, bf16, bf16::from_f32(1.0), 3);
133 test_as_type!(i32, 1, complex64, complex64::new(1.0, 0.0), 3);
134
135 test_as_type!(i64, 1, bool, true, 3);
136 test_as_type!(i64, 1, i8, 1, 3);
137 test_as_type!(i64, 1, i16, 1, 3);
138 test_as_type!(i64, 1, i32, 1, 3);
139 test_as_type!(i64, 1, u8, 1, 3);
140 test_as_type!(i64, 1, u16, 1, 3);
141 test_as_type!(i64, 1, u32, 1, 3);
142 test_as_type!(i64, 1, u64, 1, 3);
143 test_as_type!(i64, 1, f32, 1.0, 3);
144 test_as_type!(i64, 1, f16, f16::from_f32(1.0), 3);
145 test_as_type!(i64, 1, bf16, bf16::from_f32(1.0), 3);
146 test_as_type!(i64, 1, complex64, complex64::new(1.0, 0.0), 3);
147
148 test_as_type!(u8, 1, bool, true, 3);
149 test_as_type!(u8, 1, i8, 1, 3);
150 test_as_type!(u8, 1, i16, 1, 3);
151 test_as_type!(u8, 1, i32, 1, 3);
152 test_as_type!(u8, 1, i64, 1, 3);
153 test_as_type!(u8, 1, u16, 1, 3);
154 test_as_type!(u8, 1, u32, 1, 3);
155 test_as_type!(u8, 1, u64, 1, 3);
156 test_as_type!(u8, 1, f32, 1.0, 3);
157 test_as_type!(u8, 1, f16, f16::from_f32(1.0), 3);
158 test_as_type!(u8, 1, bf16, bf16::from_f32(1.0), 3);
159 test_as_type!(u8, 1, complex64, complex64::new(1.0, 0.0), 3);
160
161 test_as_type!(u16, 1, bool, true, 3);
162 test_as_type!(u16, 1, i8, 1, 3);
163 test_as_type!(u16, 1, i16, 1, 3);
164 test_as_type!(u16, 1, i32, 1, 3);
165 test_as_type!(u16, 1, i64, 1, 3);
166 test_as_type!(u16, 1, u8, 1, 3);
167 test_as_type!(u16, 1, u32, 1, 3);
168 test_as_type!(u16, 1, u64, 1, 3);
169 test_as_type!(u16, 1, f32, 1.0, 3);
170 test_as_type!(u16, 1, f16, f16::from_f32(1.0), 3);
171 test_as_type!(u16, 1, bf16, bf16::from_f32(1.0), 3);
172 test_as_type!(u16, 1, complex64, complex64::new(1.0, 0.0), 3);
173
174 test_as_type!(u32, 1, bool, true, 3);
175 test_as_type!(u32, 1, i8, 1, 3);
176 test_as_type!(u32, 1, i16, 1, 3);
177 test_as_type!(u32, 1, i32, 1, 3);
178 test_as_type!(u32, 1, i64, 1, 3);
179 test_as_type!(u32, 1, u8, 1, 3);
180 test_as_type!(u32, 1, u16, 1, 3);
181 test_as_type!(u32, 1, u64, 1, 3);
182 test_as_type!(u32, 1, f32, 1.0, 3);
183 test_as_type!(u32, 1, f16, f16::from_f32(1.0), 3);
184 test_as_type!(u32, 1, bf16, bf16::from_f32(1.0), 3);
185 test_as_type!(u32, 1, complex64, complex64::new(1.0, 0.0), 3);
186
187 test_as_type!(u64, 1, bool, true, 3);
188 test_as_type!(u64, 1, i8, 1, 3);
189 test_as_type!(u64, 1, i16, 1, 3);
190 test_as_type!(u64, 1, i32, 1, 3);
191 test_as_type!(u64, 1, i64, 1, 3);
192 test_as_type!(u64, 1, u8, 1, 3);
193 test_as_type!(u64, 1, u16, 1, 3);
194 test_as_type!(u64, 1, u32, 1, 3);
195 test_as_type!(u64, 1, f32, 1.0, 3);
196 test_as_type!(u64, 1, f16, f16::from_f32(1.0), 3);
197 test_as_type!(u64, 1, bf16, bf16::from_f32(1.0), 3);
198 test_as_type!(u64, 1, complex64, complex64::new(1.0, 0.0), 3);
199
200 test_as_type!(f32, 1.0, bool, true, 3);
201 test_as_type!(f32, 1.0, i8, 1, 3);
202 test_as_type!(f32, 1.0, i16, 1, 3);
203 test_as_type!(f32, 1.0, i32, 1, 3);
204 test_as_type!(f32, 1.0, i64, 1, 3);
205 test_as_type!(f32, 1.0, u8, 1, 3);
206 test_as_type!(f32, 1.0, u16, 1, 3);
207 test_as_type!(f32, 1.0, u32, 1, 3);
208 test_as_type!(f32, 1.0, u64, 1, 3);
209 test_as_type!(f32, 1.0, f16, f16::from_f32(1.0), 3);
210 test_as_type!(f32, 1.0, bf16, bf16::from_f32(1.0), 3);
211 test_as_type!(f32, 1.0, complex64, complex64::new(1.0, 0.0), 3);
212
213 test_as_type!(f16, f16::from_f32(1.0), bool, true, 3);
214 test_as_type!(f16, f16::from_f32(1.0), i8, 1, 3);
215 test_as_type!(f16, f16::from_f32(1.0), i16, 1, 3);
216 test_as_type!(f16, f16::from_f32(1.0), i32, 1, 3);
217 test_as_type!(f16, f16::from_f32(1.0), i64, 1, 3);
218 test_as_type!(f16, f16::from_f32(1.0), u8, 1, 3);
219 test_as_type!(f16, f16::from_f32(1.0), u16, 1, 3);
220 test_as_type!(f16, f16::from_f32(1.0), u32, 1, 3);
221 test_as_type!(f16, f16::from_f32(1.0), u64, 1, 3);
222 test_as_type!(f16, f16::from_f32(1.0), f32, 1.0, 3);
223 test_as_type!(f16, f16::from_f32(1.0), bf16, bf16::from_f32(1.0), 3);
224 test_as_type!(
225 f16,
226 f16::from_f32(1.0),
227 complex64,
228 complex64::new(1.0, 0.0),
229 3
230 );
231
232 test_as_type!(bf16, bf16::from_f32(1.0), bool, true, 3);
233 test_as_type!(bf16, bf16::from_f32(1.0), i8, 1, 3);
234 test_as_type!(bf16, bf16::from_f32(1.0), i16, 1, 3);
235 test_as_type!(bf16, bf16::from_f32(1.0), i32, 1, 3);
236 test_as_type!(bf16, bf16::from_f32(1.0), i64, 1, 3);
237 test_as_type!(bf16, bf16::from_f32(1.0), u8, 1, 3);
238 test_as_type!(bf16, bf16::from_f32(1.0), u16, 1, 3);
239 test_as_type!(bf16, bf16::from_f32(1.0), u32, 1, 3);
240 test_as_type!(bf16, bf16::from_f32(1.0), u64, 1, 3);
241 test_as_type!(bf16, bf16::from_f32(1.0), f32, 1.0, 3);
242 test_as_type!(bf16, bf16::from_f32(1.0), f16, f16::from_f32(1.0), 3);
243
244 test_as_type!(complex64, complex64::new(1.0, 0.0), bool, true, 3);
245 test_as_type!(complex64, complex64::new(1.0, 0.0), i8, 1, 3);
246 test_as_type!(complex64, complex64::new(1.0, 0.0), i16, 1, 3);
247 test_as_type!(complex64, complex64::new(1.0, 0.0), i32, 1, 3);
248 test_as_type!(complex64, complex64::new(1.0, 0.0), i64, 1, 3);
249 test_as_type!(complex64, complex64::new(1.0, 0.0), u8, 1, 3);
250 test_as_type!(complex64, complex64::new(1.0, 0.0), u16, 1, 3);
251 test_as_type!(complex64, complex64::new(1.0, 0.0), u32, 1, 3);
252 test_as_type!(complex64, complex64::new(1.0, 0.0), u64, 1, 3);
253 test_as_type!(complex64, complex64::new(1.0, 0.0), f32, 1.0, 3);
254 test_as_type!(
255 complex64,
256 complex64::new(1.0, 0.0),
257 f16,
258 f16::from_f32(1.0),
259 3
260 );
261 test_as_type!(
262 complex64,
263 complex64::new(1.0, 0.0),
264 bf16,
265 bf16::from_f32(1.0),
266 3
267 );
268
269 #[test]
270 fn test_view() {
271 let array = Array::from_slice(&[1i16, 2, 3], &[3]);
272 let new_array = array.view::<i8>().unwrap();
273
274 assert_eq!(new_array.dtype(), Dtype::Int8);
275 assert_eq!(new_array.shape(), &[6]);
276 assert_eq!(new_array.item_size(), 1);
277 assert_eq!(new_array.as_slice::<i8>(), &[1, 0, 2, 0, 3, 0]);
278 }
279}