1use crate::Dtype;
4use libc::strdup;
5use std::convert::Infallible;
6use std::ffi::NulError;
7use std::panic::Location;
8use std::sync::Once;
9use std::{cell::Cell, ffi::c_char};
10use thiserror::Error;
11
12pub type Result<T> = std::result::Result<T, Exception>;
14
15#[derive(Error, PartialEq, Debug)]
17pub enum IoError {
18 #[error("Path must point to a local file")]
20 NotFile,
21
22 #[error("Path contains invalid UTF-8")]
24 InvalidUtf8,
25
26 #[error("Path contains null bytes")]
28 NullBytes,
29
30 #[error("No file extension found")]
32 NoExtension,
33
34 #[error("Unsupported file format")]
36 UnsupportedFormat,
37
38 #[error("Unable to open file")]
40 UnableToOpenFile,
41
42 #[error("Unable to allocate memory")]
44 AllocationError,
45
46 #[error(transparent)]
48 NulError(#[from] NulError),
49
50 #[error(transparent)]
52 Unflatten(#[from] UnflattenError),
53
54 #[error(transparent)]
56 Exception(#[from] Exception),
57}
58
59impl From<Infallible> for IoError {
60 fn from(_: Infallible) -> Self {
61 unreachable!()
62 }
63}
64
65impl From<RawException> for IoError {
66 #[track_caller]
67 fn from(e: RawException) -> Self {
68 let exception = Exception {
69 what: e.what,
70 location: Location::caller(),
71 };
72 Self::Exception(exception)
73 }
74}
75
76#[derive(Debug, PartialEq, Error)]
78pub enum AsSliceError {
79 #[error("The data pointer is null.")]
83 Null,
84
85 #[error("dtype mismatch: expected {expecting:?}, found {found:?}")]
87 DtypeMismatch {
88 expecting: Dtype,
90
91 found: Dtype,
93 },
94
95 #[error(transparent)]
97 Exception(#[from] Exception),
98}
99
100#[derive(Debug, PartialEq, Error)]
102pub enum UnflattenError {
103 #[error("Expecting next (key, value) pair, found none")]
105 ExpectingNextPair,
106
107 #[error("Invalid key")]
109 InvalidKey,
110}
111
112#[derive(Debug, PartialEq, Error)]
114pub enum OptimizerStateLoadError {
115 #[error(transparent)]
117 Io(#[from] IoError),
118
119 #[error(transparent)]
121 Unflatten(#[from] UnflattenError),
122}
123
124impl From<Infallible> for OptimizerStateLoadError {
125 fn from(_: Infallible) -> Self {
126 unreachable!()
127 }
128}
129
130cfg_safetensors! {
131 #[derive(Debug, Error)]
134 pub enum ConversionError {
135 #[error("The safetensors data type {0:?} is not supported.")]
139 SafeTensorDtype(safetensors::tensor::Dtype),
140
141 #[error("The mlx data type {0:?} is not supported.")]
145 MlxDtype(crate::Dtype),
146
147 #[error(transparent)]
149 PodCastError(#[from] bytemuck::PodCastError),
150
151 #[error(transparent)]
153 SafeTensorError(#[from] safetensors::tensor::SafeTensorError),
154 }
155}
156
/// An MLX error message that has not yet been tagged with a source location.
///
/// Convert it into an [`Exception`] (or an error enum wrapping one) to attach
/// the location of the conversion site.
pub(crate) struct RawException {
    /// The message text copied out of MLX.
    pub(crate) what: String,
}
160
161#[derive(Debug, PartialEq, Error)]
163#[error("{what:?} at {location}")]
164pub struct Exception {
165 pub(crate) what: String,
166 pub(crate) location: &'static Location<'static>,
167}
168
169impl Exception {
170 pub fn what(&self) -> &str {
172 &self.what
173 }
174
175 pub fn location(&self) -> &'static Location<'static> {
181 self.location
182 }
183
184 #[track_caller]
186 pub fn custom(what: impl Into<String>) -> Self {
187 Self {
188 what: what.into(),
189 location: Location::caller(),
190 }
191 }
192}
193
194impl From<RawException> for Exception {
195 #[track_caller]
196 fn from(e: RawException) -> Self {
197 Self {
198 what: e.what,
199 location: Location::caller(),
200 }
201 }
202}
203
204impl From<&str> for Exception {
205 #[track_caller]
206 fn from(what: &str) -> Self {
207 Self {
208 what: what.to_string(),
209 location: Location::caller(),
210 }
211 }
212}
213
214impl From<Infallible> for Exception {
215 fn from(_: Infallible) -> Self {
216 unreachable!()
217 }
218}
219
220impl From<Exception> for String {
221 fn from(e: Exception) -> Self {
222 e.what
223 }
224}
225
226thread_local! {
227 static CLOSURE_ERROR: Cell<Option<Exception>> = const { Cell::new(None) };
228 static LAST_MLX_ERROR: Cell<*const c_char> = const { Cell::new(std::ptr::null()) };
229 pub(crate) static INIT_ERR_HANDLER: Once = const { Once::new() };
230}
231
232#[no_mangle]
233extern "C" fn default_mlx_error_handler(msg: *const c_char, _data: *mut std::ffi::c_void) {
234 unsafe {
235 LAST_MLX_ERROR.with(|last_error| {
236 last_error.set(strdup(msg));
237 });
238 }
239}
240
/// Passed to MLX as the destructor for the error handler's `data` pointer.
///
/// Deliberately does nothing. The pointer handed over by
/// [`setup_mlx_error_handler`] points into a thread-local cell, which must
/// never be freed.
#[no_mangle]
extern "C" fn noop_mlx_error_handler_data_deleter(_data: *mut std::ffi::c_void) {
    // Nothing to release.
}
243
244pub(crate) fn setup_mlx_error_handler() {
245 let handler = default_mlx_error_handler;
246 let data_ptr = LAST_MLX_ERROR.with(|last_error| last_error.as_ptr() as *mut std::ffi::c_void);
247 let dtor = noop_mlx_error_handler_data_deleter;
248 unsafe {
249 mlx_sys::mlx_set_error_handler(Some(handler), data_ptr, Some(dtor));
250 }
251}
252
253pub(crate) fn set_closure_error(err: Exception) {
254 CLOSURE_ERROR.with(|closure_error| closure_error.set(Some(err)));
255}
256
257pub(crate) fn get_and_clear_closure_error() -> Option<Exception> {
258 CLOSURE_ERROR.with(|closure_error| closure_error.replace(None))
259}
260
261#[track_caller]
262pub(crate) fn get_and_clear_last_mlx_error() -> Option<RawException> {
263 LAST_MLX_ERROR.with(|last_error| {
264 let last_err_ptr = last_error.replace(std::ptr::null());
265 if last_err_ptr.is_null() {
266 return None;
267 }
268
269 let last_err = unsafe {
270 std::ffi::CStr::from_ptr(last_err_ptr)
271 .to_string_lossy()
272 .into_owned()
273 };
274 unsafe {
275 libc::free(last_err_ptr as *mut libc::c_void);
276 }
277
278 Some(RawException { what: last_err })
279 })
280}
281
282#[derive(Debug, Clone, PartialEq, Error)]
284pub enum CrossEntropyBuildError {
285 #[error("Label smoothing factor must be in the range [0, 1)")]
287 InvalidLabelSmoothingFactor,
288}
289
290impl From<CrossEntropyBuildError> for Exception {
291 fn from(value: CrossEntropyBuildError) -> Self {
292 Exception::custom(format!("{}", value))
293 }
294}
295
296#[derive(Debug, Clone, PartialEq, Error)]
298pub enum RmsPropBuildError {
299 #[error("alpha must be non-negative")]
301 NegativeAlpha,
302
303 #[error("epsilon must be non-negative")]
305 NegativeEpsilon,
306}
307
308#[derive(Debug, Clone, PartialEq, Error)]
310pub enum AdaDeltaBuildError {
311 #[error("rho must be non-negative")]
313 NegativeRho,
314
315 #[error("epsilon must be non-negative")]
317 NegativeEps,
318}
319
320#[derive(Debug, Clone, PartialEq, Error)]
322pub enum AdafactorBuildError {
323 #[error("Either learning rate is provided or relative step is set to true")]
325 LrIsNoneAndRelativeStepIsFalse,
326}
327
328#[derive(Debug, Clone, PartialEq, Error)]
330pub enum DropoutBuildError {
331 #[error("Dropout probability must be in the range [0, 1)")]
333 InvalidProbability,
334}
335
336#[derive(Debug, PartialEq, Error)]
338pub enum MultiHeadAttentionBuildError {
339 #[error("Invalid number of heads: {0}")]
341 InvalidNumHeads(i32),
342
343 #[error(transparent)]
345 Exception(#[from] Exception),
346}
347
348#[derive(Debug, PartialEq, Error)]
350pub enum TransformerBulidError {
351 #[error("Dropout probability must be in the range [0, 1)")]
353 InvalidProbability,
354
355 #[error("Invalid number of heads: {0}")]
357 InvalidNumHeads(i32),
358
359 #[error(transparent)]
361 Exception(#[from] Exception),
362}
363
364impl From<DropoutBuildError> for TransformerBulidError {
365 fn from(e: DropoutBuildError) -> Self {
366 match e {
367 DropoutBuildError::InvalidProbability => Self::InvalidProbability,
368 }
369 }
370}
371
372impl From<MultiHeadAttentionBuildError> for TransformerBulidError {
373 fn from(e: MultiHeadAttentionBuildError) -> Self {
374 match e {
375 MultiHeadAttentionBuildError::InvalidNumHeads(n) => Self::InvalidNumHeads(n),
376 MultiHeadAttentionBuildError::Exception(e) => Self::Exception(e),
377 }
378 }
379}
380
#[cfg(test)]
mod tests {
    use crate::array;

    /// Adding a length-3 array to a length-2 array cannot broadcast; the
    /// resulting `Exception` must carry MLX's broadcast error message.
    #[test]
    fn test_exception() {
        let three = array!([1.0, 2.0, 3.0]);
        let two = array!([4.0, 5.0]);

        let err = three.add(&two).expect_err("Expected error");
        let message = err.what();

        assert!(message.contains("Shapes (3) and (2) cannot be broadcast."))
    }
}
399}