1use half::{bf16, f16};
2use mlx_internal_macros::generate_test_cases;
3use strum::EnumIter;
4
5use crate::error::InexactDtypeError;
6
generate_test_cases! {
    /// Element data type, mirroring the MLX C API's `mlx_dtype` values.
    ///
    /// Every discriminant is the matching `mlx_sys::mlx_dtype__MLX_*` constant.
    /// Together with `#[repr(u32)]`, the `num_enum` derives give an infallible
    /// conversion into `u32` and a fallible conversion back from `u32`.
    ///
    /// NOTE(review): the enum is wrapped in `generate_test_cases!` from
    /// `mlx_internal_macros`. Presumably that macro emits per-dtype test helpers
    /// alongside the enum; confirm there before relying on it. Variant order also
    /// determines the `EnumIter` iteration order.
    #[derive(
        Debug,
        Clone,
        Copy,
        PartialEq,
        Eq,
        num_enum::IntoPrimitive,
        num_enum::TryFromPrimitive,
        EnumIter,
        Hash,
    )]
    #[repr(u32)]
    pub enum Dtype {
        /// Boolean.
        Bool = mlx_sys::mlx_dtype__MLX_BOOL,

        /// Unsigned 8-bit integer.
        Uint8 = mlx_sys::mlx_dtype__MLX_UINT8,

        /// Unsigned 16-bit integer.
        Uint16 = mlx_sys::mlx_dtype__MLX_UINT16,

        /// Unsigned 32-bit integer.
        Uint32 = mlx_sys::mlx_dtype__MLX_UINT32,

        /// Unsigned 64-bit integer.
        Uint64 = mlx_sys::mlx_dtype__MLX_UINT64,

        /// Signed 8-bit integer.
        Int8 = mlx_sys::mlx_dtype__MLX_INT8,

        /// Signed 16-bit integer.
        Int16 = mlx_sys::mlx_dtype__MLX_INT16,

        /// Signed 32-bit integer.
        Int32 = mlx_sys::mlx_dtype__MLX_INT32,

        /// Signed 64-bit integer.
        Int64 = mlx_sys::mlx_dtype__MLX_INT64,

        /// IEEE 754 half-precision float (`half::f16`).
        Float16 = mlx_sys::mlx_dtype__MLX_FLOAT16,

        /// Single-precision float (`f32`).
        Float32 = mlx_sys::mlx_dtype__MLX_FLOAT32,

        /// Double-precision float (`f64`).
        Float64 = mlx_sys::mlx_dtype__MLX_FLOAT64,

        /// Brain floating point, 16 bits (`half::bf16`).
        Bfloat16 = mlx_sys::mlx_dtype__MLX_BFLOAT16,

        /// 64-bit complex number. The `finfo_*` helpers report `f32` limits for it.
        Complex64 = mlx_sys::mlx_dtype__MLX_COMPLEX64,
    }
}
65
66impl Dtype {
67 pub fn is_complex(&self) -> bool {
69 matches!(self, Dtype::Complex64)
70 }
71
72 pub fn is_float(&self) -> bool {
74 matches!(self, Dtype::Float16 | Dtype::Float32 | Dtype::Bfloat16)
75 }
76
77 pub fn is_inexact(&self) -> bool {
79 matches!(
80 self,
81 Dtype::Float16 | Dtype::Float32 | Dtype::Complex64 | Dtype::Bfloat16
82 )
83 }
84
85 pub fn from_promoting_types(a: Dtype, b: Dtype) -> Self {
87 a.promote_with(b)
88 }
89
90 pub fn finfo_min(&self) -> Result<f64, InexactDtypeError> {
93 match self {
94 Dtype::Float16 => Ok(f16::MIN.to_f64_const()),
95 Dtype::Float32 => Ok(f32::MIN as f64),
96 Dtype::Complex64 => Ok(f32::MIN as f64),
97 Dtype::Bfloat16 => Ok(bf16::MIN.to_f64_const()),
98 _ => Err(InexactDtypeError(*self)),
99 }
100 }
101
102 pub fn finfo_max(&self) -> Result<f64, InexactDtypeError> {
105 match self {
106 Dtype::Float16 => Ok(f16::MAX.to_f64_const()),
107 Dtype::Float32 => Ok(f32::MAX as f64),
108 Dtype::Complex64 => Ok(f32::MAX as f64),
109 Dtype::Bfloat16 => Ok(bf16::MAX.to_f64_const()),
110 _ => Err(InexactDtypeError(*self)),
111 }
112 }
113}
114
/// Rule for combining two values of a type into a single promoted value.
///
/// In this file it is implemented for [`Dtype`] to compute data-type promotion.
pub(crate) trait TypePromotion {
    /// Returns the value that `self` and `other` promote to.
    ///
    /// The method is pure and its return value is its only output, so discarding
    /// the result is always a mistake.
    #[must_use]
    fn promote_with(self, other: Self) -> Self;
}
118
119impl TypePromotion for Dtype {
120 fn promote_with(self, other: Self) -> Self {
121 use crate::dtype::Dtype::*;
122 match (self, other) {
123 (Bool, Bool) => Bool,
125 (Bool, _) | (_, Bool) => {
126 if self == Bool {
127 other
128 } else {
129 self
130 }
131 }
132
133 (Uint8, Uint8) => Uint8,
135 (Uint8, Uint16) | (Uint16, Uint8) => Uint16,
136 (Uint8, Uint32) | (Uint32, Uint8) => Uint32,
137 (Uint8, Uint64) | (Uint64, Uint8) => Uint64,
138 (Uint8, Int8) | (Int8, Uint8) => Int16,
139 (Uint8, Int16) | (Int16, Uint8) => Int16,
140 (Uint8, Int32) | (Int32, Uint8) => Int32,
141 (Uint8, Int64) | (Int64, Uint8) => Int64,
142
143 (Uint16, Uint16) => Uint16,
145 (Uint16, Uint32) | (Uint32, Uint16) => Uint32,
146 (Uint16, Uint64) | (Uint64, Uint16) => Uint64,
147 (Uint16, Int8) | (Int8, Uint16) => Int32,
148 (Uint16, Int16) | (Int16, Uint16) => Int32,
149 (Uint16, Int32) | (Int32, Uint16) => Int32,
150 (Uint16, Int64) | (Int64, Uint16) => Int64,
151
152 (Uint32, Uint32) => Uint32,
154 (Uint32, Uint64) | (Uint64, Uint32) => Uint64,
155 (Uint32, Int8) | (Int8, Uint32) => Int64,
156 (Uint32, Int16) | (Int16, Uint32) => Int64,
157 (Uint32, Int32) | (Int32, Uint32) => Int64,
158 (Uint32, Int64) | (Int64, Uint32) => Int64,
159
160 (Uint64, Uint64) => Uint64,
162 (Uint64, Int8) | (Int8, Uint64) => Float32,
163 (Uint64, Int16) | (Int16, Uint64) => Float32,
164 (Uint64, Int32) | (Int32, Uint64) => Float32,
165 (Uint64, Int64) | (Int64, Uint64) => Float32,
166
167 (Int8, Int8) => Int8,
169 (Int8, Int16) | (Int16, Int8) => Int16,
170 (Int8, Int32) | (Int32, Int8) => Int32,
171 (Int8, Int64) | (Int64, Int8) => Int64,
172
173 (Int16, Int16) => Int16,
175 (Int16, Int32) | (Int32, Int16) => Int32,
176 (Int16, Int64) | (Int64, Int16) => Int64,
177
178 (Int32, Int32) => Int32,
180 (Int32, Int64) | (Int64, Int32) => Int64,
181
182 (Int64, Int64) => Int64,
184
185 (Float16, Bfloat16) | (Bfloat16, Float16) => Float32,
187
188 (Complex64, _) | (_, Complex64) => Complex64,
190
191 (Float64, _) | (_, Float64) => Float64,
193
194 (Float32, _) | (_, Float32) => Float32,
196
197 (Float16, _) | (_, Float16) => Float16,
199
200 (Bfloat16, _) | (_, Bfloat16) => Bfloat16,
202 }
203 }
204}
205
206cfg_safetensors! {
207 impl TryFrom<safetensors::tensor::Dtype> for Dtype {
208 type Error = crate::error::ConversionError;
209
210 fn try_from(value: safetensors::tensor::Dtype) -> Result<Self, Self::Error> {
211 let out = match value {
212 safetensors::Dtype::BOOL => Dtype::Bool,
213 safetensors::Dtype::U8 => Dtype::Uint8,
214 safetensors::Dtype::I8 => Dtype::Int8,
215 safetensors::Dtype::F8_E5M2 => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
216 safetensors::Dtype::F8_E4M3 => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
217 safetensors::Dtype::I16 => Dtype::Int16,
218 safetensors::Dtype::U16 => Dtype::Uint16,
219 safetensors::Dtype::F16 => Dtype::Float16,
220 safetensors::Dtype::BF16 => Dtype::Bfloat16,
221 safetensors::Dtype::I32 => Dtype::Int32,
222 safetensors::Dtype::U32 => Dtype::Uint32,
223 safetensors::Dtype::F32 => Dtype::Float32,
224 safetensors::Dtype::F64 => Dtype::Float64,
225 safetensors::Dtype::I64 => Dtype::Int64,
226 safetensors::Dtype::U64 => Dtype::Uint64,
227 _ => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
228 };
229 Ok(out)
230 }
231 }
232
233 impl TryFrom<Dtype> for safetensors::tensor::Dtype {
234 type Error = crate::error::ConversionError;
235
236 fn try_from(value: Dtype) -> Result<Self, Self::Error> {
237 let out = match value {
238 Dtype::Bool => safetensors::Dtype::BOOL,
239 Dtype::Uint8 => safetensors::Dtype::U8,
240 Dtype::Int8 => safetensors::Dtype::I8,
241 Dtype::Int16 => safetensors::Dtype::I16,
242 Dtype::Uint16 => safetensors::Dtype::U16,
243 Dtype::Float16 => safetensors::Dtype::F16,
244 Dtype::Bfloat16 => safetensors::Dtype::BF16,
245 Dtype::Int32 => safetensors::Dtype::I32,
246 Dtype::Uint32 => safetensors::Dtype::U32,
247 Dtype::Float32 => safetensors::Dtype::F32,
248 Dtype::Float64 => safetensors::Dtype::F64,
249 Dtype::Int64 => safetensors::Dtype::I64,
250 Dtype::Uint64 => safetensors::Dtype::U64,
251 Dtype::Complex64 => return Err(crate::error::ConversionError::MlxDtype(value)),
252 };
253 Ok(out)
254 }
255 }
256}