1use mlx_internal_macros::generate_test_cases;
2use strum::EnumIter;
3
4generate_test_cases! {
5 #[derive(
7 Debug,
8 Clone,
9 Copy,
10 PartialEq,
11 Eq,
12 num_enum::IntoPrimitive,
13 num_enum::TryFromPrimitive,
14 EnumIter,
15 Hash,
16 )]
17 #[repr(u32)]
18 pub enum Dtype {
19 Bool = mlx_sys::mlx_dtype__MLX_BOOL,
21
22 Uint8 = mlx_sys::mlx_dtype__MLX_UINT8,
24
25 Uint16 = mlx_sys::mlx_dtype__MLX_UINT16,
27
28 Uint32 = mlx_sys::mlx_dtype__MLX_UINT32,
30
31 Uint64 = mlx_sys::mlx_dtype__MLX_UINT64,
33
34 Int8 = mlx_sys::mlx_dtype__MLX_INT8,
36
37 Int16 = mlx_sys::mlx_dtype__MLX_INT16,
39
40 Int32 = mlx_sys::mlx_dtype__MLX_INT32,
42
43 Int64 = mlx_sys::mlx_dtype__MLX_INT64,
45
46 Float16 = mlx_sys::mlx_dtype__MLX_FLOAT16,
48
49 Float32 = mlx_sys::mlx_dtype__MLX_FLOAT32,
51
52 Float64 = mlx_sys::mlx_dtype__MLX_FLOAT64,
54
55 Bfloat16 = mlx_sys::mlx_dtype__MLX_BFLOAT16,
57
58 Complex64 = mlx_sys::mlx_dtype__MLX_COMPLEX64,
60 }
61}
62
63impl Dtype {
64 pub fn is_complex(&self) -> bool {
66 matches!(self, Dtype::Complex64)
67 }
68
69 pub fn is_float(&self) -> bool {
71 matches!(self, Dtype::Float16 | Dtype::Float32 | Dtype::Bfloat16)
72 }
73
74 pub fn is_inexact(&self) -> bool {
76 matches!(
77 self,
78 Dtype::Float16 | Dtype::Float32 | Dtype::Complex64 | Dtype::Bfloat16
79 )
80 }
81
82 pub fn from_promoting_types(a: Dtype, b: Dtype) -> Self {
84 a.promote_with(b)
85 }
86}
87
/// Crate-internal abstraction for combining two values of the same type into
/// the type that both of them promote to.
pub(crate) trait TypePromotion {
    /// Consumes `self` and `other` and returns the promoted value of the same type.
    fn promote_with(self, other: Self) -> Self;
}
91
92impl TypePromotion for Dtype {
93 fn promote_with(self, other: Self) -> Self {
94 use crate::dtype::Dtype::*;
95 match (self, other) {
96 (Bool, Bool) => Bool,
98 (Bool, _) | (_, Bool) => {
99 if self == Bool {
100 other
101 } else {
102 self
103 }
104 }
105
106 (Uint8, Uint8) => Uint8,
108 (Uint8, Uint16) | (Uint16, Uint8) => Uint16,
109 (Uint8, Uint32) | (Uint32, Uint8) => Uint32,
110 (Uint8, Uint64) | (Uint64, Uint8) => Uint64,
111 (Uint8, Int8) | (Int8, Uint8) => Int16,
112 (Uint8, Int16) | (Int16, Uint8) => Int16,
113 (Uint8, Int32) | (Int32, Uint8) => Int32,
114 (Uint8, Int64) | (Int64, Uint8) => Int64,
115
116 (Uint16, Uint16) => Uint16,
118 (Uint16, Uint32) | (Uint32, Uint16) => Uint32,
119 (Uint16, Uint64) | (Uint64, Uint16) => Uint64,
120 (Uint16, Int8) | (Int8, Uint16) => Int32,
121 (Uint16, Int16) | (Int16, Uint16) => Int32,
122 (Uint16, Int32) | (Int32, Uint16) => Int32,
123 (Uint16, Int64) | (Int64, Uint16) => Int64,
124
125 (Uint32, Uint32) => Uint32,
127 (Uint32, Uint64) | (Uint64, Uint32) => Uint64,
128 (Uint32, Int8) | (Int8, Uint32) => Int64,
129 (Uint32, Int16) | (Int16, Uint32) => Int64,
130 (Uint32, Int32) | (Int32, Uint32) => Int64,
131 (Uint32, Int64) | (Int64, Uint32) => Int64,
132
133 (Uint64, Uint64) => Uint64,
135 (Uint64, Int8) | (Int8, Uint64) => Float32,
136 (Uint64, Int16) | (Int16, Uint64) => Float32,
137 (Uint64, Int32) | (Int32, Uint64) => Float32,
138 (Uint64, Int64) | (Int64, Uint64) => Float32,
139
140 (Int8, Int8) => Int8,
142 (Int8, Int16) | (Int16, Int8) => Int16,
143 (Int8, Int32) | (Int32, Int8) => Int32,
144 (Int8, Int64) | (Int64, Int8) => Int64,
145
146 (Int16, Int16) => Int16,
148 (Int16, Int32) | (Int32, Int16) => Int32,
149 (Int16, Int64) | (Int64, Int16) => Int64,
150
151 (Int32, Int32) => Int32,
153 (Int32, Int64) | (Int64, Int32) => Int64,
154
155 (Int64, Int64) => Int64,
157
158 (Float16, Bfloat16) | (Bfloat16, Float16) => Float32,
160
161 (Complex64, _) | (_, Complex64) => Complex64,
163
164 (Float64, _) | (_, Float64) => Float64,
166
167 (Float32, _) | (_, Float32) => Float32,
169
170 (Float16, _) | (_, Float16) => Float16,
172
173 (Bfloat16, _) | (_, Bfloat16) => Bfloat16,
175 }
176 }
177}
178
179cfg_safetensors! {
180 impl TryFrom<safetensors::tensor::Dtype> for Dtype {
181 type Error = crate::error::ConversionError;
182
183 fn try_from(value: safetensors::tensor::Dtype) -> Result<Self, Self::Error> {
184 let out = match value {
185 safetensors::Dtype::BOOL => Dtype::Bool,
186 safetensors::Dtype::U8 => Dtype::Uint8,
187 safetensors::Dtype::I8 => Dtype::Int8,
188 safetensors::Dtype::F8_E5M2 => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
189 safetensors::Dtype::F8_E4M3 => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
190 safetensors::Dtype::I16 => Dtype::Int16,
191 safetensors::Dtype::U16 => Dtype::Uint16,
192 safetensors::Dtype::F16 => Dtype::Float16,
193 safetensors::Dtype::BF16 => Dtype::Bfloat16,
194 safetensors::Dtype::I32 => Dtype::Int32,
195 safetensors::Dtype::U32 => Dtype::Uint32,
196 safetensors::Dtype::F32 => Dtype::Float32,
197 safetensors::Dtype::F64 => Dtype::Float64,
198 safetensors::Dtype::I64 => Dtype::Int64,
199 safetensors::Dtype::U64 => Dtype::Uint64,
200 _ => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
201 };
202 Ok(out)
203 }
204 }
205
206 impl TryFrom<Dtype> for safetensors::tensor::Dtype {
207 type Error = crate::error::ConversionError;
208
209 fn try_from(value: Dtype) -> Result<Self, Self::Error> {
210 let out = match value {
211 Dtype::Bool => safetensors::Dtype::BOOL,
212 Dtype::Uint8 => safetensors::Dtype::U8,
213 Dtype::Int8 => safetensors::Dtype::I8,
214 Dtype::Int16 => safetensors::Dtype::I16,
215 Dtype::Uint16 => safetensors::Dtype::U16,
216 Dtype::Float16 => safetensors::Dtype::F16,
217 Dtype::Bfloat16 => safetensors::Dtype::BF16,
218 Dtype::Int32 => safetensors::Dtype::I32,
219 Dtype::Uint32 => safetensors::Dtype::U32,
220 Dtype::Float32 => safetensors::Dtype::F32,
221 Dtype::Float64 => safetensors::Dtype::F64,
222 Dtype::Int64 => safetensors::Dtype::I64,
223 Dtype::Uint64 => safetensors::Dtype::U64,
224 Dtype::Complex64 => return Err(crate::error::ConversionError::MlxDtype(value)),
225 };
226 Ok(out)
227 }
228 }
229}