1use mlx_internal_macros::default_device;
2
3use crate::{error::Result, utils::guard::Guarded, Array, ArrayElement, Dtype, Stream};
4
5impl Array {
6 #[default_device]
22 pub fn as_type_device<T: ArrayElement>(&self, stream: impl AsRef<Stream>) -> Result<Array> {
23 self.as_dtype_device(T::DTYPE, stream)
24 }
25
26 #[default_device]
28 pub fn as_dtype_device(&self, dtype: Dtype, stream: impl AsRef<Stream>) -> Result<Array> {
29 Array::try_from_op(|res| unsafe {
30 mlx_sys::mlx_astype(res, self.as_ptr(), dtype.into(), stream.as_ref().as_ptr())
31 })
32 }
33
34 #[default_device]
44 pub fn view_device<T: ArrayElement>(&self, stream: impl AsRef<Stream>) -> Result<Array> {
45 self.view_dtype_device(T::DTYPE, stream)
46 }
47
48 #[default_device]
50 pub fn view_dtype_device(&self, dtype: Dtype, stream: impl AsRef<Stream>) -> Result<Array> {
51 Array::try_from_op(|res| unsafe {
52 mlx_sys::mlx_view(res, self.as_ptr(), dtype.into(), stream.as_ref().as_ptr())
53 })
54 }
55}
56
57#[cfg(test)]
58mod tests {
59 use super::*;
60 use crate::complex64;
61 use half::{bf16, f16};
62 use pretty_assertions::assert_eq;
63
64 macro_rules! test_as_type {
65 ($src_type:ty, $src_val:expr, $dst_type:ty, $dst_val:expr, $len:expr) => {
66 paste::paste! {
67 #[test]
68 fn [<test_as_type_ $src_type _ $dst_type>]() {
69 let array = Array::from_slice(&[$src_val; $len], &[$len as i32]);
70 let new_array = array.as_type::<$dst_type>().unwrap();
71
72 assert_eq!(new_array.dtype(), $dst_type::DTYPE);
73 assert_eq!(new_array.shape(), &[3]);
74 assert_eq!(new_array.item_size(), std::mem::size_of::<$dst_type>());
75 assert_eq!(new_array.as_slice::<$dst_type>(), &[$dst_val; $len]);
76 }
77 }
78 };
79 }
80
81 test_as_type!(bool, true, i8, 1, 3);
82 test_as_type!(bool, true, i16, 1, 3);
83 test_as_type!(bool, true, i32, 1, 3);
84 test_as_type!(bool, true, i64, 1, 3);
85 test_as_type!(bool, true, u8, 1, 3);
86 test_as_type!(bool, true, u16, 1, 3);
87 test_as_type!(bool, true, u32, 1, 3);
88 test_as_type!(bool, true, u64, 1, 3);
89 test_as_type!(bool, true, f32, 1.0, 3);
90 test_as_type!(bool, true, f16, f16::from_f32(1.0), 3);
91 test_as_type!(bool, true, bf16, bf16::from_f32(1.0), 3);
92 test_as_type!(bool, true, complex64, complex64::new(1.0, 0.0), 3);
93
94 test_as_type!(i8, 1, bool, true, 3);
95 test_as_type!(i8, 1, i16, 1, 3);
96 test_as_type!(i8, 1, i32, 1, 3);
97 test_as_type!(i8, 1, i64, 1, 3);
98 test_as_type!(i8, 1, u8, 1, 3);
99 test_as_type!(i8, 1, u16, 1, 3);
100 test_as_type!(i8, 1, u32, 1, 3);
101 test_as_type!(i8, 1, u64, 1, 3);
102 test_as_type!(i8, 1, f32, 1.0, 3);
103 test_as_type!(i8, 1, f16, f16::from_f32(1.0), 3);
104 test_as_type!(i8, 1, bf16, bf16::from_f32(1.0), 3);
105 test_as_type!(i8, 1, complex64, complex64::new(1.0, 0.0), 3);
106
107 test_as_type!(i16, 1, bool, true, 3);
108 test_as_type!(i16, 1, i8, 1, 3);
109 test_as_type!(i16, 1, i32, 1, 3);
110 test_as_type!(i16, 1, i64, 1, 3);
111 test_as_type!(i16, 1, u8, 1, 3);
112 test_as_type!(i16, 1, u16, 1, 3);
113 test_as_type!(i16, 1, u32, 1, 3);
114 test_as_type!(i16, 1, u64, 1, 3);
115 test_as_type!(i16, 1, f32, 1.0, 3);
116 test_as_type!(i16, 1, f16, f16::from_f32(1.0), 3);
117 test_as_type!(i16, 1, bf16, bf16::from_f32(1.0), 3);
118 test_as_type!(i16, 1, complex64, complex64::new(1.0, 0.0), 3);
119
120 test_as_type!(i32, 1, bool, true, 3);
121 test_as_type!(i32, 1, i8, 1, 3);
122 test_as_type!(i32, 1, i16, 1, 3);
123 test_as_type!(i32, 1, i64, 1, 3);
124 test_as_type!(i32, 1, u8, 1, 3);
125 test_as_type!(i32, 1, u16, 1, 3);
126 test_as_type!(i32, 1, u32, 1, 3);
127 test_as_type!(i32, 1, u64, 1, 3);
128 test_as_type!(i32, 1, f32, 1.0, 3);
129 test_as_type!(i32, 1, f16, f16::from_f32(1.0), 3);
130 test_as_type!(i32, 1, bf16, bf16::from_f32(1.0), 3);
131 test_as_type!(i32, 1, complex64, complex64::new(1.0, 0.0), 3);
132
133 test_as_type!(i64, 1, bool, true, 3);
134 test_as_type!(i64, 1, i8, 1, 3);
135 test_as_type!(i64, 1, i16, 1, 3);
136 test_as_type!(i64, 1, i32, 1, 3);
137 test_as_type!(i64, 1, u8, 1, 3);
138 test_as_type!(i64, 1, u16, 1, 3);
139 test_as_type!(i64, 1, u32, 1, 3);
140 test_as_type!(i64, 1, u64, 1, 3);
141 test_as_type!(i64, 1, f32, 1.0, 3);
142 test_as_type!(i64, 1, f16, f16::from_f32(1.0), 3);
143 test_as_type!(i64, 1, bf16, bf16::from_f32(1.0), 3);
144 test_as_type!(i64, 1, complex64, complex64::new(1.0, 0.0), 3);
145
146 test_as_type!(u8, 1, bool, true, 3);
147 test_as_type!(u8, 1, i8, 1, 3);
148 test_as_type!(u8, 1, i16, 1, 3);
149 test_as_type!(u8, 1, i32, 1, 3);
150 test_as_type!(u8, 1, i64, 1, 3);
151 test_as_type!(u8, 1, u16, 1, 3);
152 test_as_type!(u8, 1, u32, 1, 3);
153 test_as_type!(u8, 1, u64, 1, 3);
154 test_as_type!(u8, 1, f32, 1.0, 3);
155 test_as_type!(u8, 1, f16, f16::from_f32(1.0), 3);
156 test_as_type!(u8, 1, bf16, bf16::from_f32(1.0), 3);
157 test_as_type!(u8, 1, complex64, complex64::new(1.0, 0.0), 3);
158
159 test_as_type!(u16, 1, bool, true, 3);
160 test_as_type!(u16, 1, i8, 1, 3);
161 test_as_type!(u16, 1, i16, 1, 3);
162 test_as_type!(u16, 1, i32, 1, 3);
163 test_as_type!(u16, 1, i64, 1, 3);
164 test_as_type!(u16, 1, u8, 1, 3);
165 test_as_type!(u16, 1, u32, 1, 3);
166 test_as_type!(u16, 1, u64, 1, 3);
167 test_as_type!(u16, 1, f32, 1.0, 3);
168 test_as_type!(u16, 1, f16, f16::from_f32(1.0), 3);
169 test_as_type!(u16, 1, bf16, bf16::from_f32(1.0), 3);
170 test_as_type!(u16, 1, complex64, complex64::new(1.0, 0.0), 3);
171
172 test_as_type!(u32, 1, bool, true, 3);
173 test_as_type!(u32, 1, i8, 1, 3);
174 test_as_type!(u32, 1, i16, 1, 3);
175 test_as_type!(u32, 1, i32, 1, 3);
176 test_as_type!(u32, 1, i64, 1, 3);
177 test_as_type!(u32, 1, u8, 1, 3);
178 test_as_type!(u32, 1, u16, 1, 3);
179 test_as_type!(u32, 1, u64, 1, 3);
180 test_as_type!(u32, 1, f32, 1.0, 3);
181 test_as_type!(u32, 1, f16, f16::from_f32(1.0), 3);
182 test_as_type!(u32, 1, bf16, bf16::from_f32(1.0), 3);
183 test_as_type!(u32, 1, complex64, complex64::new(1.0, 0.0), 3);
184
185 test_as_type!(u64, 1, bool, true, 3);
186 test_as_type!(u64, 1, i8, 1, 3);
187 test_as_type!(u64, 1, i16, 1, 3);
188 test_as_type!(u64, 1, i32, 1, 3);
189 test_as_type!(u64, 1, i64, 1, 3);
190 test_as_type!(u64, 1, u8, 1, 3);
191 test_as_type!(u64, 1, u16, 1, 3);
192 test_as_type!(u64, 1, u32, 1, 3);
193 test_as_type!(u64, 1, f32, 1.0, 3);
194 test_as_type!(u64, 1, f16, f16::from_f32(1.0), 3);
195 test_as_type!(u64, 1, bf16, bf16::from_f32(1.0), 3);
196 test_as_type!(u64, 1, complex64, complex64::new(1.0, 0.0), 3);
197
198 test_as_type!(f32, 1.0, bool, true, 3);
199 test_as_type!(f32, 1.0, i8, 1, 3);
200 test_as_type!(f32, 1.0, i16, 1, 3);
201 test_as_type!(f32, 1.0, i32, 1, 3);
202 test_as_type!(f32, 1.0, i64, 1, 3);
203 test_as_type!(f32, 1.0, u8, 1, 3);
204 test_as_type!(f32, 1.0, u16, 1, 3);
205 test_as_type!(f32, 1.0, u32, 1, 3);
206 test_as_type!(f32, 1.0, u64, 1, 3);
207 test_as_type!(f32, 1.0, f16, f16::from_f32(1.0), 3);
208 test_as_type!(f32, 1.0, bf16, bf16::from_f32(1.0), 3);
209 test_as_type!(f32, 1.0, complex64, complex64::new(1.0, 0.0), 3);
210
211 test_as_type!(f16, f16::from_f32(1.0), bool, true, 3);
212 test_as_type!(f16, f16::from_f32(1.0), i8, 1, 3);
213 test_as_type!(f16, f16::from_f32(1.0), i16, 1, 3);
214 test_as_type!(f16, f16::from_f32(1.0), i32, 1, 3);
215 test_as_type!(f16, f16::from_f32(1.0), i64, 1, 3);
216 test_as_type!(f16, f16::from_f32(1.0), u8, 1, 3);
217 test_as_type!(f16, f16::from_f32(1.0), u16, 1, 3);
218 test_as_type!(f16, f16::from_f32(1.0), u32, 1, 3);
219 test_as_type!(f16, f16::from_f32(1.0), u64, 1, 3);
220 test_as_type!(f16, f16::from_f32(1.0), f32, 1.0, 3);
221 test_as_type!(f16, f16::from_f32(1.0), bf16, bf16::from_f32(1.0), 3);
222 test_as_type!(
223 f16,
224 f16::from_f32(1.0),
225 complex64,
226 complex64::new(1.0, 0.0),
227 3
228 );
229
230 test_as_type!(bf16, bf16::from_f32(1.0), bool, true, 3);
231 test_as_type!(bf16, bf16::from_f32(1.0), i8, 1, 3);
232 test_as_type!(bf16, bf16::from_f32(1.0), i16, 1, 3);
233 test_as_type!(bf16, bf16::from_f32(1.0), i32, 1, 3);
234 test_as_type!(bf16, bf16::from_f32(1.0), i64, 1, 3);
235 test_as_type!(bf16, bf16::from_f32(1.0), u8, 1, 3);
236 test_as_type!(bf16, bf16::from_f32(1.0), u16, 1, 3);
237 test_as_type!(bf16, bf16::from_f32(1.0), u32, 1, 3);
238 test_as_type!(bf16, bf16::from_f32(1.0), u64, 1, 3);
239 test_as_type!(bf16, bf16::from_f32(1.0), f32, 1.0, 3);
240 test_as_type!(bf16, bf16::from_f32(1.0), f16, f16::from_f32(1.0), 3);
241
242 test_as_type!(complex64, complex64::new(1.0, 0.0), bool, true, 3);
243 test_as_type!(complex64, complex64::new(1.0, 0.0), i8, 1, 3);
244 test_as_type!(complex64, complex64::new(1.0, 0.0), i16, 1, 3);
245 test_as_type!(complex64, complex64::new(1.0, 0.0), i32, 1, 3);
246 test_as_type!(complex64, complex64::new(1.0, 0.0), i64, 1, 3);
247 test_as_type!(complex64, complex64::new(1.0, 0.0), u8, 1, 3);
248 test_as_type!(complex64, complex64::new(1.0, 0.0), u16, 1, 3);
249 test_as_type!(complex64, complex64::new(1.0, 0.0), u32, 1, 3);
250 test_as_type!(complex64, complex64::new(1.0, 0.0), u64, 1, 3);
251 test_as_type!(complex64, complex64::new(1.0, 0.0), f32, 1.0, 3);
252 test_as_type!(
253 complex64,
254 complex64::new(1.0, 0.0),
255 f16,
256 f16::from_f32(1.0),
257 3
258 );
259 test_as_type!(
260 complex64,
261 complex64::new(1.0, 0.0),
262 bf16,
263 bf16::from_f32(1.0),
264 3
265 );
266
267 #[test]
268 fn test_view() {
269 let array = Array::from_slice(&[1i16, 2, 3], &[3]);
270 let new_array = array.view::<i8>().unwrap();
271
272 assert_eq!(new_array.dtype(), Dtype::Int8);
273 assert_eq!(new_array.shape(), &[6]);
274 assert_eq!(new_array.item_size(), 1);
275 assert_eq!(new_array.as_slice::<i8>(), &[1, 0, 2, 0, 3, 0]);
276 }
277}