mlx_rs/
dtype.rs

1use half::{bf16, f16};
2use mlx_internal_macros::generate_test_cases;
3use strum::EnumIter;
4
5use crate::error::InexactDtypeError;
6
generate_test_cases! {
    /// Element type of an MLX array.
    ///
    /// Every discriminant is the matching `mlx_sys::mlx_dtype__MLX_*` constant and the enum is
    /// `#[repr(u32)]`, so a variant converts to its raw C value through the
    /// `num_enum::IntoPrimitive` derive and back through `num_enum::TryFromPrimitive`.
    #[derive(
        Debug,
        Clone,
        Copy,
        PartialEq,
        Eq,
        num_enum::IntoPrimitive,
        num_enum::TryFromPrimitive,
        EnumIter,
        Hash,
    )]
    #[repr(u32)]
    pub enum Dtype {
        /// Boolean (`bool`).
        Bool = mlx_sys::mlx_dtype__MLX_BOOL,

        /// 8-bit unsigned integer (`u8`).
        Uint8 = mlx_sys::mlx_dtype__MLX_UINT8,

        /// 16-bit unsigned integer (`u16`).
        Uint16 = mlx_sys::mlx_dtype__MLX_UINT16,

        /// 32-bit unsigned integer (`u32`).
        Uint32 = mlx_sys::mlx_dtype__MLX_UINT32,

        /// 64-bit unsigned integer (`u64`).
        Uint64 = mlx_sys::mlx_dtype__MLX_UINT64,

        /// 8-bit signed integer (`i8`).
        Int8 = mlx_sys::mlx_dtype__MLX_INT8,

        /// 16-bit signed integer (`i16`).
        Int16 = mlx_sys::mlx_dtype__MLX_INT16,

        /// 32-bit signed integer (`i32`).
        Int32 = mlx_sys::mlx_dtype__MLX_INT32,

        /// 64-bit signed integer (`i64`).
        Int64 = mlx_sys::mlx_dtype__MLX_INT64,

        /// 16-bit IEEE half-precision float (`f16`).
        Float16 = mlx_sys::mlx_dtype__MLX_FLOAT16,

        /// 32-bit float (`f32`).
        Float32 = mlx_sys::mlx_dtype__MLX_FLOAT32,

        /// 64-bit float (`f64`).
        Float64 = mlx_sys::mlx_dtype__MLX_FLOAT64,

        /// 16-bit brain floating point (`bfloat16`).
        Bfloat16 = mlx_sys::mlx_dtype__MLX_BFLOAT16,

        /// 64-bit complex number with `f32` real and imaginary parts (`complex64`).
        Complex64 = mlx_sys::mlx_dtype__MLX_COMPLEX64,
    }
}
65
66impl Dtype {
67    /// Returns `true` if the data type is complex.
68    pub fn is_complex(&self) -> bool {
69        matches!(self, Dtype::Complex64)
70    }
71
72    /// Returns `true` if the data type is floating point.
73    pub fn is_float(&self) -> bool {
74        matches!(self, Dtype::Float16 | Dtype::Float32 | Dtype::Bfloat16)
75    }
76
77    /// Returns `true` if the data type is one of `f16`, `f32`, `bfloat16`, or `complex64`.
78    pub fn is_inexact(&self) -> bool {
79        matches!(
80            self,
81            Dtype::Float16 | Dtype::Float32 | Dtype::Complex64 | Dtype::Bfloat16
82        )
83    }
84
85    /// Returns the promotion type of two data types.
86    pub fn from_promoting_types(a: Dtype, b: Dtype) -> Self {
87        a.promote_with(b)
88    }
89
90    /// Minimum value of the float point types. Returns `Err(_)` if the type is not
91    /// float point
92    pub fn finfo_min(&self) -> Result<f64, InexactDtypeError> {
93        match self {
94            Dtype::Float16 => Ok(f16::MIN.to_f64_const()),
95            Dtype::Float32 => Ok(f32::MIN as f64),
96            Dtype::Complex64 => Ok(f32::MIN as f64),
97            Dtype::Bfloat16 => Ok(bf16::MIN.to_f64_const()),
98            _ => Err(InexactDtypeError(*self)),
99        }
100    }
101
102    /// Maximum value of the float point types. Returns `Err(_)` if the type is not
103    /// float point
104    pub fn finfo_max(&self) -> Result<f64, InexactDtypeError> {
105        match self {
106            Dtype::Float16 => Ok(f16::MAX.to_f64_const()),
107            Dtype::Float32 => Ok(f32::MAX as f64),
108            Dtype::Complex64 => Ok(f32::MAX as f64),
109            Dtype::Bfloat16 => Ok(bf16::MAX.to_f64_const()),
110            _ => Err(InexactDtypeError(*self)),
111        }
112    }
113}
114
/// Binary type promotion: computes the common type that two values are promoted to.
pub(crate) trait TypePromotion {
    /// Returns the type that `self` and `rhs` promote to when combined.
    fn promote_with(self, rhs: Self) -> Self;
}
118
impl TypePromotion for Dtype {
    /// Returns the dtype that `self` and `other` promote to.
    ///
    /// The result is symmetric: every mixed pair is matched by an or-pattern covering both
    /// orders. The explicit integer arms come first; the wildcard arms at the bottom are
    /// order-sensitive (see the comments on them).
    fn promote_with(self, other: Self) -> Self {
        use crate::dtype::Dtype::*;
        match (self, other) {
            // Boolean promotions: `Bool` is the identity, so pairing it with any
            // type yields that other type.
            (Bool, Bool) => Bool,
            (Bool, _) | (_, Bool) => {
                if self == Bool {
                    other
                } else {
                    self
                }
            }

            // Uint8 promotions. Mixed signed/unsigned pairs widen to a signed type
            // large enough for both ranges (e.g. u8 + i8 -> i16).
            (Uint8, Uint8) => Uint8,
            (Uint8, Uint16) | (Uint16, Uint8) => Uint16,
            (Uint8, Uint32) | (Uint32, Uint8) => Uint32,
            (Uint8, Uint64) | (Uint64, Uint8) => Uint64,
            (Uint8, Int8) | (Int8, Uint8) => Int16,
            (Uint8, Int16) | (Int16, Uint8) => Int16,
            (Uint8, Int32) | (Int32, Uint8) => Int32,
            (Uint8, Int64) | (Int64, Uint8) => Int64,

            // Uint16 promotions
            (Uint16, Uint16) => Uint16,
            (Uint16, Uint32) | (Uint32, Uint16) => Uint32,
            (Uint16, Uint64) | (Uint64, Uint16) => Uint64,
            (Uint16, Int8) | (Int8, Uint16) => Int32,
            (Uint16, Int16) | (Int16, Uint16) => Int32,
            (Uint16, Int32) | (Int32, Uint16) => Int32,
            (Uint16, Int64) | (Int64, Uint16) => Int64,

            // Uint32 promotions
            (Uint32, Uint32) => Uint32,
            (Uint32, Uint64) | (Uint64, Uint32) => Uint64,
            (Uint32, Int8) | (Int8, Uint32) => Int64,
            (Uint32, Int16) | (Int16, Uint32) => Int64,
            (Uint32, Int32) | (Int32, Uint32) => Int64,
            (Uint32, Int64) | (Int64, Uint32) => Int64,

            // Uint64 promotions. No available integer type covers both the full u64
            // range and negative values, so u64 mixed with any signed integer falls
            // back to Float32.
            (Uint64, Uint64) => Uint64,
            (Uint64, Int8) | (Int8, Uint64) => Float32,
            (Uint64, Int16) | (Int16, Uint64) => Float32,
            (Uint64, Int32) | (Int32, Uint64) => Float32,
            (Uint64, Int64) | (Int64, Uint64) => Float32,

            // Int8 promotions
            (Int8, Int8) => Int8,
            (Int8, Int16) | (Int16, Int8) => Int16,
            (Int8, Int32) | (Int32, Int8) => Int32,
            (Int8, Int64) | (Int64, Int8) => Int64,

            // Int16 promotions
            (Int16, Int16) => Int16,
            (Int16, Int32) | (Int32, Int16) => Int32,
            (Int16, Int64) | (Int64, Int16) => Int64,

            // Int32 promotions
            (Int32, Int32) => Int32,
            (Int32, Int64) | (Int64, Int32) => Int64,

            // Int64 promotions
            (Int64, Int64) => Int64,

            // Every integer/integer and Bool pair has been matched above; what remains
            // involves at least one floating-point or complex type.

            // Float16 with Bfloat16: neither 16-bit format represents the other
            // exactly, so they meet at Float32. Must precede the Float16/Bfloat16
            // wildcard arms below.
            (Float16, Bfloat16) | (Bfloat16, Float16) => Float32,

            // Complex type: must precede the Float64 arm so that
            // (Float64, Complex64) yields Complex64.
            (Complex64, _) | (_, Complex64) => Complex64,

            // Float64 promotions: must precede the Float32 arm.
            (Float64, _) | (_, Float64) => Float64,

            // Float32 promotions
            (Float32, _) | (_, Float32) => Float32,

            // Float16 promotions: only reached for Float16 paired with itself or an
            // integer, since (Float16, Bfloat16) was handled above.
            (Float16, _) | (_, Float16) => Float16,

            // Bfloat16 promotions: only reached for Bfloat16 paired with itself or an
            // integer.
            (Bfloat16, _) | (_, Bfloat16) => Bfloat16,
        }
    }
}
205
206cfg_safetensors! {
207    impl TryFrom<safetensors::tensor::Dtype> for Dtype {
208        type Error = crate::error::ConversionError;
209
210        fn try_from(value: safetensors::tensor::Dtype) -> Result<Self, Self::Error> {
211            let out = match value {
212                safetensors::Dtype::BOOL => Dtype::Bool,
213                safetensors::Dtype::U8 => Dtype::Uint8,
214                safetensors::Dtype::I8 => Dtype::Int8,
215                safetensors::Dtype::F8_E5M2 => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
216                safetensors::Dtype::F8_E4M3 => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
217                safetensors::Dtype::I16 => Dtype::Int16,
218                safetensors::Dtype::U16 => Dtype::Uint16,
219                safetensors::Dtype::F16 => Dtype::Float16,
220                safetensors::Dtype::BF16 => Dtype::Bfloat16,
221                safetensors::Dtype::I32 => Dtype::Int32,
222                safetensors::Dtype::U32 => Dtype::Uint32,
223                safetensors::Dtype::F32 => Dtype::Float32,
224                safetensors::Dtype::F64 => Dtype::Float64,
225                safetensors::Dtype::I64 => Dtype::Int64,
226                safetensors::Dtype::U64 => Dtype::Uint64,
227                _ => return Err(crate::error::ConversionError::SafeTensorDtype(value)),
228            };
229            Ok(out)
230        }
231    }
232
233    impl TryFrom<Dtype> for safetensors::tensor::Dtype {
234        type Error = crate::error::ConversionError;
235
236        fn try_from(value: Dtype) -> Result<Self, Self::Error> {
237            let out = match value {
238                Dtype::Bool => safetensors::Dtype::BOOL,
239                Dtype::Uint8 => safetensors::Dtype::U8,
240                Dtype::Int8 => safetensors::Dtype::I8,
241                Dtype::Int16 => safetensors::Dtype::I16,
242                Dtype::Uint16 => safetensors::Dtype::U16,
243                Dtype::Float16 => safetensors::Dtype::F16,
244                Dtype::Bfloat16 => safetensors::Dtype::BF16,
245                Dtype::Int32 => safetensors::Dtype::I32,
246                Dtype::Uint32 => safetensors::Dtype::U32,
247                Dtype::Float32 => safetensors::Dtype::F32,
248                Dtype::Float64 => safetensors::Dtype::F64,
249                Dtype::Int64 => safetensors::Dtype::I64,
250                Dtype::Uint64 => safetensors::Dtype::U64,
251                Dtype::Complex64 => return Err(crate::error::ConversionError::MlxDtype(value)),
252            };
253            Ok(out)
254        }
255    }
256}