// mlx_rs/array/element.rs

1use crate::error::Result;
2use crate::sealed::Sealed;
3use crate::{complex64, Array, Dtype};
4use half::{bf16, f16};
5
6/// A marker trait for array elements.
7pub trait ArrayElement: Sized + Sealed {
8    /// The data type of the element.
9    const DTYPE: Dtype;
10
11    /// Access the value of a scalar array. Returns `Err` if the array is not scalar.
12    fn array_item(array: &Array) -> Result<Self>;
13
14    /// Access the raw data of an array.
15    fn array_data(array: &Array) -> *const Self;
16}
17
18/// A marker trait for array element types that can be constructed via the
19/// [`Array::from_slice`] method. This trait is implemented for all array
20/// element types except for [`f64`].
21///
22/// [`f64`] is treated specially because it is not supported on GPU devices, but
23/// rust defaults floating point literals to `f64`. With this trait, we can
24/// limit the default floating point literals to `f32` for constructors
25/// functions like [`Array::from_slice`] and [`Array::from_iter`], and macro
26/// [`crate::array!`].
27pub trait FromSliceElement: ArrayElement {}
28
29macro_rules! impl_array_element {
30    ($type:ty, $dtype:expr, $ctype:ident) => {
31        paste::paste! {
32            impl Sealed for $type {}
33            impl ArrayElement for $type {
34                const DTYPE: Dtype = $dtype;
35
36                fn array_item(array: &Array) -> Result<Self> {
37                    use crate::utils::guard::*;
38
39                    <$type as Guarded>::try_from_op(|ptr| unsafe {
40                        mlx_sys::[<mlx_array_item_ $ctype >](ptr, array.as_ptr())
41                    })
42                }
43
44                fn array_data(array: &Array) -> *const Self {
45                    unsafe { mlx_sys::[<mlx_array_data_ $ctype >](array.as_ptr()) as *const Self }
46                }
47
48            }
49        }
50    };
51}
52
53impl_array_element!(bool, Dtype::Bool, bool);
54impl_array_element!(u8, Dtype::Uint8, uint8);
55impl_array_element!(u16, Dtype::Uint16, uint16);
56impl_array_element!(u32, Dtype::Uint32, uint32);
57impl_array_element!(u64, Dtype::Uint64, uint64);
58impl_array_element!(i8, Dtype::Int8, int8);
59impl_array_element!(i16, Dtype::Int16, int16);
60impl_array_element!(i32, Dtype::Int32, int32);
61impl_array_element!(i64, Dtype::Int64, int64);
62impl_array_element!(f64, Dtype::Float64, float64);
63impl_array_element!(f32, Dtype::Float32, float32);
64impl_array_element!(f16, Dtype::Float16, float16);
65impl_array_element!(bf16, Dtype::Bfloat16, bfloat16);
66impl_array_element!(complex64, Dtype::Complex64, complex64);
67
68macro_rules! impl_from_slice_element {
69    ($type:ty) => {
70        impl FromSliceElement for $type {}
71    };
72}
73
74impl_from_slice_element!(bool);
75impl_from_slice_element!(u8);
76impl_from_slice_element!(u16);
77impl_from_slice_element!(u32);
78impl_from_slice_element!(u64);
79impl_from_slice_element!(i8);
80impl_from_slice_element!(i16);
81impl_from_slice_element!(i32);
82impl_from_slice_element!(i64);
83impl_from_slice_element!(f32);
84impl_from_slice_element!(f16);
85impl_from_slice_element!(bf16);
86impl_from_slice_element!(complex64);