mlx_rs/
error.rs

//! Custom error types and the error handler for the C FFI.
2
3use crate::Dtype;
4use libc::strdup;
5use std::convert::Infallible;
6use std::ffi::NulError;
7use std::panic::Location;
8use std::sync::Once;
9use std::{cell::Cell, ffi::c_char};
10use thiserror::Error;
11
12/// Type alias for a `Result` with an `Exception` error type.
13pub type Result<T> = std::result::Result<T, Exception>;
14
15/// Error with io operations
16#[derive(Error, PartialEq, Debug)]
17pub enum IoError {
18    /// Path must point to a local file
19    #[error("Path must point to a local file")]
20    NotFile,
21
22    /// Path contains invalid UTF-8
23    #[error("Path contains invalid UTF-8")]
24    InvalidUtf8,
25
26    /// Path contains null bytes
27    #[error("Path contains null bytes")]
28    NullBytes,
29
30    /// No file extension found
31    #[error("No file extension found")]
32    NoExtension,
33
34    /// Unsupported file format
35    #[error("Unsupported file format")]
36    UnsupportedFormat,
37
38    /// Unable to open file
39    #[error("Unable to open file")]
40    UnableToOpenFile,
41
42    /// Unable to allocate memory
43    #[error("Unable to allocate memory")]
44    AllocationError,
45
46    /// Null error
47    #[error(transparent)]
48    NulError(#[from] NulError),
49
50    /// Error with unfalttening the loaded optimizer state
51    #[error(transparent)]
52    Unflatten(#[from] UnflattenError),
53
54    /// Exception
55    #[error(transparent)]
56    Exception(#[from] Exception),
57}
58
59impl From<Infallible> for IoError {
60    fn from(_: Infallible) -> Self {
61        unreachable!()
62    }
63}
64
65impl From<RawException> for IoError {
66    #[track_caller]
67    fn from(e: RawException) -> Self {
68        let exception = Exception {
69            what: e.what,
70            location: Location::caller(),
71        };
72        Self::Exception(exception)
73    }
74}
75
76/// Error associated with `Array::try_as_slice()`
77#[derive(Debug, PartialEq, Error)]
78pub enum AsSliceError {
79    /// The underlying data pointer is null.
80    ///
81    /// This is likely because the array has not been evaluated yet.
82    #[error("The data pointer is null.")]
83    Null,
84
85    /// The output dtype does not match the data type of the array.
86    #[error("dtype mismatch: expected {expecting:?}, found {found:?}")]
87    DtypeMismatch {
88        /// The expected data type.
89        expecting: Dtype,
90
91        /// The actual data type
92        found: Dtype,
93    },
94
95    /// Exception
96    #[error(transparent)]
97    Exception(#[from] Exception),
98}
99
100/// Error with unflattening a loaded optimizer state
101#[derive(Debug, PartialEq, Error)]
102pub enum UnflattenError {
103    /// Expecting next (key, value) pair, found none
104    #[error("Expecting next (key, value) pair, found none")]
105    ExpectingNextPair,
106
107    /// The key is not in a valid format
108    #[error("Invalid key")]
109    InvalidKey,
110}
111
112/// Error with loading an optimizer state
113#[derive(Debug, PartialEq, Error)]
114pub enum OptimizerStateLoadError {
115    /// Error with io operations
116    #[error(transparent)]
117    Io(#[from] IoError),
118
119    /// Error with unflattening the optimizer state
120    #[error(transparent)]
121    Unflatten(#[from] UnflattenError),
122}
123
124impl From<Infallible> for OptimizerStateLoadError {
125    fn from(_: Infallible) -> Self {
126        unreachable!()
127    }
128}
129
130cfg_safetensors! {
131    /// Error associated with conversion between `safetensors::tensor::TensorView` and `Array`
132    /// when the data type is not supported.
133    #[derive(Debug, Error)]
134    pub enum ConversionError {
135        /// The safetensors data type that is not supported.
136        ///
137        /// This is the error type for conversions from `safetensors::tensor::TensorView` to `Array`.
138        #[error("The safetensors data type {0:?} is not supported.")]
139        SafeTensorDtype(safetensors::tensor::Dtype),
140
141        /// The mlx data type that is not supported.
142        ///
143        /// This is the error type for conversions from `Array` to `safetensors::tensor::TensorView`.
144        #[error("The mlx data type {0:?} is not supported.")]
145        MlxDtype(crate::Dtype),
146
147        /// Error casting the data buffer to `&[u8]`.
148        #[error(transparent)]
149        PodCastError(#[from] bytemuck::PodCastError),
150
151        /// Error with creating a `safetensors::tensor::TensorView`.
152        #[error(transparent)]
153        SafeTensorError(#[from] safetensors::tensor::SafeTensorError),
154    }
155}
156
/// An error message captured from the MLX C API before a source location is attached.
///
/// Converting it into an `Exception` records where that conversion took place.
pub(crate) struct RawException {
    /// The message reported by MLX.
    pub(crate) what: String,
}
160
161/// Exception. Most will come from the C API.
162#[derive(Debug, PartialEq, Error)]
163#[error("{what:?} at {location}")]
164pub struct Exception {
165    pub(crate) what: String,
166    pub(crate) location: &'static Location<'static>,
167}
168
169impl Exception {
170    /// The error message.
171    pub fn what(&self) -> &str {
172        &self.what
173    }
174
175    /// The location of the error.
176    ///
177    /// The location is obtained from `std::panic::Location::caller()` and points
178    /// to the location in the code where the error was created and not where it was
179    /// propagated.
180    pub fn location(&self) -> &'static Location<'static> {
181        self.location
182    }
183
184    /// Creates a new exception with the given message.
185    #[track_caller]
186    pub fn custom(what: impl Into<String>) -> Self {
187        Self {
188            what: what.into(),
189            location: Location::caller(),
190        }
191    }
192}
193
194impl From<RawException> for Exception {
195    #[track_caller]
196    fn from(e: RawException) -> Self {
197        Self {
198            what: e.what,
199            location: Location::caller(),
200        }
201    }
202}
203
204impl From<&str> for Exception {
205    #[track_caller]
206    fn from(what: &str) -> Self {
207        Self {
208            what: what.to_string(),
209            location: Location::caller(),
210        }
211    }
212}
213
214impl From<Infallible> for Exception {
215    fn from(_: Infallible) -> Self {
216        unreachable!()
217    }
218}
219
220impl From<Exception> for String {
221    fn from(e: Exception) -> Self {
222        e.what
223    }
224}
225
226thread_local! {
227    static CLOSURE_ERROR: Cell<Option<Exception>> = const { Cell::new(None) };
228    static LAST_MLX_ERROR: Cell<*const c_char> = const { Cell::new(std::ptr::null()) };
229    pub(crate) static INIT_ERR_HANDLER: Once = const { Once::new() };
230}
231
232#[no_mangle]
233extern "C" fn default_mlx_error_handler(msg: *const c_char, _data: *mut std::ffi::c_void) {
234    unsafe {
235        LAST_MLX_ERROR.with(|last_error| {
236            last_error.set(strdup(msg));
237        });
238    }
239}
240
/// Destructor handed to MLX for the error handler's user data.
///
/// It deliberately does nothing. The data pointer registered in `setup_mlx_error_handler`
/// refers to a Rust thread-local, not to memory that MLX should release.
#[no_mangle]
extern "C" fn noop_mlx_error_handler_data_deleter(_data: *mut std::ffi::c_void) {
    // Intentionally empty.
}
243
244pub(crate) fn setup_mlx_error_handler() {
245    let handler = default_mlx_error_handler;
246    let data_ptr = LAST_MLX_ERROR.with(|last_error| last_error.as_ptr() as *mut std::ffi::c_void);
247    let dtor = noop_mlx_error_handler_data_deleter;
248    unsafe {
249        mlx_sys::mlx_set_error_handler(Some(handler), data_ptr, Some(dtor));
250    }
251}
252
253pub(crate) fn set_closure_error(err: Exception) {
254    CLOSURE_ERROR.with(|closure_error| closure_error.set(Some(err)));
255}
256
257pub(crate) fn get_and_clear_closure_error() -> Option<Exception> {
258    CLOSURE_ERROR.with(|closure_error| closure_error.replace(None))
259}
260
261#[track_caller]
262pub(crate) fn get_and_clear_last_mlx_error() -> Option<RawException> {
263    LAST_MLX_ERROR.with(|last_error| {
264        let last_err_ptr = last_error.replace(std::ptr::null());
265        if last_err_ptr.is_null() {
266            return None;
267        }
268
269        let last_err = unsafe {
270            std::ffi::CStr::from_ptr(last_err_ptr)
271                .to_string_lossy()
272                .into_owned()
273        };
274        unsafe {
275            libc::free(last_err_ptr as *mut libc::c_void);
276        }
277
278        Some(RawException { what: last_err })
279    })
280}
281
282/// Error with building a cross-entropy loss function
283#[derive(Debug, Clone, PartialEq, Error)]
284pub enum CrossEntropyBuildError {
285    /// Label smoothing factor must be in the range [0, 1)
286    #[error("Label smoothing factor must be in the range [0, 1)")]
287    InvalidLabelSmoothingFactor,
288}
289
290impl From<CrossEntropyBuildError> for Exception {
291    fn from(value: CrossEntropyBuildError) -> Self {
292        Exception::custom(format!("{}", value))
293    }
294}
295
296/// Error with building a RmsProp optimizer
297#[derive(Debug, Clone, PartialEq, Error)]
298pub enum RmsPropBuildError {
299    /// Alpha must be non-negative
300    #[error("alpha must be non-negative")]
301    NegativeAlpha,
302
303    /// Epsilon must be non-negative
304    #[error("epsilon must be non-negative")]
305    NegativeEpsilon,
306}
307
308/// Error with building an AdaDelta optimizer
309#[derive(Debug, Clone, PartialEq, Error)]
310pub enum AdaDeltaBuildError {
311    /// Rho must be non-negative
312    #[error("rho must be non-negative")]
313    NegativeRho,
314
315    /// Epsilon must be non-negative
316    #[error("epsilon must be non-negative")]
317    NegativeEps,
318}
319
320/// Error with building an Adafactor optimizer.
321#[derive(Debug, Clone, PartialEq, Error)]
322pub enum AdafactorBuildError {
323    /// Either learning rate is provided or relative step is set to true.
324    #[error("Either learning rate is provided or relative step is set to true")]
325    LrIsNoneAndRelativeStepIsFalse,
326}
327
328/// Error with building a dropout layer
329#[derive(Debug, Clone, PartialEq, Error)]
330pub enum DropoutBuildError {
331    /// Dropout probability must be in the range [0, 1)
332    #[error("Dropout probability must be in the range [0, 1)")]
333    InvalidProbability,
334}
335
336/// Error with building a MultiHeadAttention module
337#[derive(Debug, PartialEq, Error)]
338pub enum MultiHeadAttentionBuildError {
339    /// Invalid number of heads
340    #[error("Invalid number of heads: {0}")]
341    InvalidNumHeads(i32),
342
343    /// Exceptions
344    #[error(transparent)]
345    Exception(#[from] Exception),
346}
347
348/// Error with building a transformer
349#[derive(Debug, PartialEq, Error)]
350pub enum TransformerBulidError {
351    /// Dropout probability must be in the range [0, 1)
352    #[error("Dropout probability must be in the range [0, 1)")]
353    InvalidProbability,
354
355    /// Invalid number of heads
356    #[error("Invalid number of heads: {0}")]
357    InvalidNumHeads(i32),
358
359    /// Exceptions
360    #[error(transparent)]
361    Exception(#[from] Exception),
362}
363
364impl From<DropoutBuildError> for TransformerBulidError {
365    fn from(e: DropoutBuildError) -> Self {
366        match e {
367            DropoutBuildError::InvalidProbability => Self::InvalidProbability,
368        }
369    }
370}
371
372impl From<MultiHeadAttentionBuildError> for TransformerBulidError {
373    fn from(e: MultiHeadAttentionBuildError) -> Self {
374        match e {
375            MultiHeadAttentionBuildError::InvalidNumHeads(n) => Self::InvalidNumHeads(n),
376            MultiHeadAttentionBuildError::Exception(e) => Self::Exception(e),
377        }
378    }
379}
380
#[cfg(test)]
mod tests {
    use crate::array;

    #[test]
    fn test_exception() {
        let lhs = array!([1.0, 2.0, 3.0]);
        let rhs = array!([4.0, 5.0]);

        // Shapes (3) and (2) are not broadcast-compatible, so the addition must fail.
        let err = lhs.add(&rhs).expect_err("Expected error");

        // MLX's message also embeds the path of the originating C++ source file, so only a
        // substring is checked.
        let expected = "Shapes (3) and (2) cannot be broadcast.";
        assert!(err.what().contains(expected))
    }
}