mlx_rs/
dtype.rs

1use mlx_internal_macros::generate_test_cases;
2use strum::EnumIter;
3
generate_test_cases! {
    // NOTE(review): `generate_test_cases!` is a project macro (mlx_internal_macros);
    // presumably it emits per-variant tests for this enum — confirm before reordering variants.

    /// Array element type.
    ///
    /// The enum is `#[repr(u32)]` and every discriminant is taken from the matching
    /// `mlx_sys::mlx_dtype__MLX_*` constant. The derived `num_enum::IntoPrimitive` impl
    /// converts a variant to that raw value. The derived `num_enum::TryFromPrimitive` impl
    /// fallibly converts a raw value back into a variant. `EnumIter` (from `strum`) provides
    /// an iterator over all variants.
    #[derive(
        Debug,
        Clone,
        Copy,
        PartialEq,
        Eq,
        num_enum::IntoPrimitive,
        num_enum::TryFromPrimitive,
        EnumIter,
        Hash,
    )]
    #[repr(u32)]
    pub enum Dtype {
        /// Boolean (`bool`).
        Bool = mlx_sys::mlx_dtype__MLX_BOOL,

        /// Unsigned 8-bit integer (`u8`).
        Uint8 = mlx_sys::mlx_dtype__MLX_UINT8,

        /// Unsigned 16-bit integer (`u16`).
        Uint16 = mlx_sys::mlx_dtype__MLX_UINT16,

        /// Unsigned 32-bit integer (`u32`).
        Uint32 = mlx_sys::mlx_dtype__MLX_UINT32,

        /// Unsigned 64-bit integer (`u64`).
        Uint64 = mlx_sys::mlx_dtype__MLX_UINT64,

        /// Signed 8-bit integer (`i8`).
        Int8 = mlx_sys::mlx_dtype__MLX_INT8,

        /// Signed 16-bit integer (`i16`).
        Int16 = mlx_sys::mlx_dtype__MLX_INT16,

        /// Signed 32-bit integer (`i32`).
        Int32 = mlx_sys::mlx_dtype__MLX_INT32,

        /// Signed 64-bit integer (`i64`).
        Int64 = mlx_sys::mlx_dtype__MLX_INT64,

        /// 16-bit floating point (`f16`).
        Float16 = mlx_sys::mlx_dtype__MLX_FLOAT16,

        /// 32-bit floating point (`f32`).
        Float32 = mlx_sys::mlx_dtype__MLX_FLOAT32,

        /// 64-bit floating point (`f64`).
        Float64 = mlx_sys::mlx_dtype__MLX_FLOAT64,

        /// 16-bit brain floating point (`bfloat16`).
        Bfloat16 = mlx_sys::mlx_dtype__MLX_BFLOAT16,

        /// Complex number (`complex64`).
        Complex64 = mlx_sys::mlx_dtype__MLX_COMPLEX64,
    }
}
62
63impl Dtype {
64    /// Returns `true` if the data type is complex.
65    pub fn is_complex(&self) -> bool {
66        matches!(self, Dtype::Complex64)
67    }
68
69    /// Returns `true` if the data type is floating point.
70    pub fn is_float(&self) -> bool {
71        matches!(self, Dtype::Float16 | Dtype::Float32 | Dtype::Bfloat16)
72    }
73
74    /// Returns `true` if the data type is one of `f16`, `f32`, `bfloat16`, or `complex64`.
75    pub fn is_inexact(&self) -> bool {
76        matches!(
77            self,
78            Dtype::Float16 | Dtype::Float32 | Dtype::Complex64 | Dtype::Bfloat16
79        )
80    }
81
82    /// Returns the promotion type of two data types.
83    pub fn from_promoting_types(a: Dtype, b: Dtype) -> Self {
84        a.promote_with(b)
85    }
86}
87
88pub(crate) trait TypePromotion {
89    fn promote_with(self, other: Self) -> Self;
90}
91
92impl TypePromotion for Dtype {
93    fn promote_with(self, other: Self) -> Self {
94        use crate::dtype::Dtype::*;
95        match (self, other) {
96            // Boolean promotions
97            (Bool, Bool) => Bool,
98            (Bool, _) | (_, Bool) => {
99                if self == Bool {
100                    other
101                } else {
102                    self
103                }
104            }
105
106            // Uint8 promotions
107            (Uint8, Uint8) => Uint8,
108            (Uint8, Uint16) | (Uint16, Uint8) => Uint16,
109            (Uint8, Uint32) | (Uint32, Uint8) => Uint32,
110            (Uint8, Uint64) | (Uint64, Uint8) => Uint64,
111            (Uint8, Int8) | (Int8, Uint8) => Int16,
112            (Uint8, Int16) | (Int16, Uint8) => Int16,
113            (Uint8, Int32) | (Int32, Uint8) => Int32,
114            (Uint8, Int64) | (Int64, Uint8) => Int64,
115
116            // Uint16 promotions
117            (Uint16, Uint16) => Uint16,
118            (Uint16, Uint32) | (Uint32, Uint16) => Uint32,
119            (Uint16, Uint64) | (Uint64, Uint16) => Uint64,
120            (Uint16, Int8) | (Int8, Uint16) => Int32,
121            (Uint16, Int16) | (Int16, Uint16) => Int32,
122            (Uint16, Int32) | (Int32, Uint16) => Int32,
123            (Uint16, Int64) | (Int64, Uint16) => Int64,
124
125            // Uint32 promotions
126            (Uint32, Uint32) => Uint32,
127            (Uint32, Uint64) | (Uint64, Uint32) => Uint64,
128            (Uint32, Int8) | (Int8, Uint32) => Int64,
129            (Uint32, Int16) | (Int16, Uint32) => Int64,
130            (Uint32, Int32) | (Int32, Uint32) => Int64,
131            (Uint32, Int64) | (Int64, Uint32) => Int64,
132
133            // Uint64 promotions
134            (Uint64, Uint64) => Uint64,
135            (Uint64, Int8) | (Int8, Uint64) => Float32,
136            (Uint64, Int16) | (Int16, Uint64) => Float32,
137            (Uint64, Int32) | (Int32, Uint64) => Float32,
138            (Uint64, Int64) | (Int64, Uint64) => Float32,
139
140            // Int8 promotions
141            (Int8, Int8) => Int8,
142            (Int8, Int16) | (Int16, Int8) => Int16,
143            (Int8, Int32) | (Int32, Int8) => Int32,
144            (Int8, Int64) | (Int64, Int8) => Int64,
145
146            // Int16 promotions
147            (Int16, Int16) => Int16,
148            (Int16, Int32) | (Int32, Int16) => Int32,
149            (Int16, Int64) | (Int64, Int16) => Int64,
150
151            // Int32 promotions
152            (Int32, Int32) => Int32,
153            (Int32, Int64) | (Int64, Int32) => Int64,
154
155            // Int64 promotions
156            (Int64, Int64) => Int64,
157
158            // Float16 promotions
159            (Float16, Bfloat16) | (Bfloat16, Float16) => Float32,
160
161            // Complex type
162            (Complex64, _) | (_, Complex64) => Complex64,
163
164            // Float64 promotions
165            (Float64, _) | (_, Float64) => Float64,
166
167            // Float32 promotions
168            (Float32, _) | (_, Float32) => Float32,
169
170            // Float16 promotions
171            (Float16, _) | (_, Float16) => Float16,
172
173            // Bfloat16 promotions
174            (Bfloat16, _) | (_, Bfloat16) => Bfloat16,
175        }
176    }
177}
178
179cfg_safetensors! {
180    impl TryFrom<safetensors::tensor::Dtype> for Dtype {
181        type Error = crate::error::ConversionError;
182
183        fn try_from(value: safetensors::tensor::Dtype) -> Result<Self, Self::Error> {
184            let out = match value {
185                safetensors::Dtype::BOOL => Dtype::Bool,
186                safetensors::Dtype::U8 => Dtype::Uint8,
187                safetensors::Dtype::I8 => Dtype::Int8,
188                safetensors::Dtype::F8_E5M2 => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
189                safetensors::Dtype::F8_E4M3 => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
190                safetensors::Dtype::I16 => Dtype::Int16,
191                safetensors::Dtype::U16 => Dtype::Uint16,
192                safetensors::Dtype::F16 => Dtype::Float16,
193                safetensors::Dtype::BF16 => Dtype::Bfloat16,
194                safetensors::Dtype::I32 => Dtype::Int32,
195                safetensors::Dtype::U32 => Dtype::Uint32,
196                safetensors::Dtype::F32 => Dtype::Float32,
197                safetensors::Dtype::F64 => Dtype::Float64,
198                safetensors::Dtype::I64 => Dtype::Int64,
199                safetensors::Dtype::U64 => Dtype::Uint64,
200                _ => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
201            };
202            Ok(out)
203        }
204    }
205
206    impl TryFrom<Dtype> for safetensors::tensor::Dtype {
207        type Error = crate::error::ConversionError;
208
209        fn try_from(value: Dtype) -> Result<Self, Self::Error> {
210            let out = match value {
211                Dtype::Bool => safetensors::Dtype::BOOL,
212                Dtype::Uint8 => safetensors::Dtype::U8,
213                Dtype::Int8 => safetensors::Dtype::I8,
214                Dtype::Int16 => safetensors::Dtype::I16,
215                Dtype::Uint16 => safetensors::Dtype::U16,
216                Dtype::Float16 => safetensors::Dtype::F16,
217                Dtype::Bfloat16 => safetensors::Dtype::BF16,
218                Dtype::Int32 => safetensors::Dtype::I32,
219                Dtype::Uint32 => safetensors::Dtype::U32,
220                Dtype::Float32 => safetensors::Dtype::F32,
221                Dtype::Float64 => safetensors::Dtype::F64,
222                Dtype::Int64 => safetensors::Dtype::I64,
223                Dtype::Uint64 => safetensors::Dtype::U64,
224                Dtype::Complex64 => return Err(crate::error::ConversionError::MlxDtype(value)),
225            };
226            Ok(out)
227        }
228    }
229}