1use crate::error::Result;
2use crate::sealed::Sealed;
3use crate::{complex64, Array, Dtype};
4use half::{bf16, f16};
5
6pub trait ArrayElement: Sized + Sealed {
8 const DTYPE: Dtype;
10
11 fn array_item(array: &Array) -> Result<Self>;
13
14 fn array_data(array: &Array) -> *const Self;
16}
17
18pub trait FromSliceElement: ArrayElement {}
28
/// Implements [`Sealed`] and [`ArrayElement`] for one Rust element type.
///
/// Arguments, in order:
/// - `$elem`: the Rust type (e.g. `f32`);
/// - `$dtype_value`: the matching [`Dtype`] value (e.g. `Dtype::Float32`);
/// - `$c_suffix`: the suffix pasted onto the `mlx_sys` accessors, producing
///   `mlx_array_item_<suffix>` and `mlx_array_data_<suffix>` (e.g. `float32`).
macro_rules! impl_array_element {
    ($elem:ty, $dtype_value:expr, $c_suffix:ident) => {
        impl Sealed for $elem {}

        // `paste!` is only needed to build the `mlx_sys` function names.
        paste::paste! {
            impl ArrayElement for $elem {
                const DTYPE: Dtype = $dtype_value;

                fn array_item(array: &Array) -> Result<Self> {
                    use crate::utils::guard::*;

                    <$elem as Guarded>::try_from_op(|out| {
                        // SAFETY: `array` is borrowed for the whole call, so its
                        // handle stays live. Assumes `try_from_op` hands us a
                        // valid, writable output slot in `out`.
                        unsafe { mlx_sys::[<mlx_array_item_ $c_suffix>](out, array.as_ptr()) }
                    })
                }

                fn array_data(array: &Array) -> *const Self {
                    // SAFETY: `array` is borrowed, so its handle is live for
                    // this FFI call. The returned pointer is not dereferenced
                    // here; validity of reading through it is the caller's concern.
                    let raw = unsafe { mlx_sys::[<mlx_array_data_ $c_suffix>](array.as_ptr()) };
                    raw as *const Self
                }
            }
        }
    };
}
52
53impl_array_element!(bool, Dtype::Bool, bool);
54impl_array_element!(u8, Dtype::Uint8, uint8);
55impl_array_element!(u16, Dtype::Uint16, uint16);
56impl_array_element!(u32, Dtype::Uint32, uint32);
57impl_array_element!(u64, Dtype::Uint64, uint64);
58impl_array_element!(i8, Dtype::Int8, int8);
59impl_array_element!(i16, Dtype::Int16, int16);
60impl_array_element!(i32, Dtype::Int32, int32);
61impl_array_element!(i64, Dtype::Int64, int64);
62impl_array_element!(f64, Dtype::Float64, float64);
63impl_array_element!(f32, Dtype::Float32, float32);
64impl_array_element!(f16, Dtype::Float16, float16);
65impl_array_element!(bf16, Dtype::Bfloat16, bfloat16);
66impl_array_element!(complex64, Dtype::Complex64, complex64);
67
/// Implements the [`FromSliceElement`] marker trait for one or more types.
///
/// Accepts a comma-separated list of types (a trailing comma is allowed).
/// Both `impl_from_slice_element!(u8);` and `impl_from_slice_element!(u8, u16);`
/// are valid, and each type gets its own empty `impl` block.
macro_rules! impl_from_slice_element {
    ($($elem:ty),+ $(,)?) => {
        $(impl FromSliceElement for $elem {})+
    };
}
73
74impl_from_slice_element!(bool);
75impl_from_slice_element!(u8);
76impl_from_slice_element!(u16);
77impl_from_slice_element!(u32);
78impl_from_slice_element!(u64);
79impl_from_slice_element!(i8);
80impl_from_slice_element!(i16);
81impl_from_slice_element!(i32);
82impl_from_slice_element!(i64);
83impl_from_slice_element!(f32);
84impl_from_slice_element!(f16);
85impl_from_slice_element!(bf16);
86impl_from_slice_element!(complex64);