use crate::Dtype;
use libc::strdup;
use std::convert::Infallible;
use std::ffi::NulError;
use std::panic::Location;
use std::sync::Once;
use std::{cell::Cell, ffi::c_char};
use thiserror::Error;
/// Result type used throughout the crate.
///
/// The error type defaults to [`Exception`], so `Result<T>` keeps its existing meaning.
/// The optional second parameter lets callers name another error type
/// (e.g. `Result<T, IoError>`) without spelling out `std::result::Result`.
pub type Result<T, E = Exception> = std::result::Result<T, E>;
#[derive(Error, PartialEq, Debug)]
pub enum IoError {
#[error("Path must point to a local file")]
NotFile,
#[error("Path contains invalid UTF-8")]
InvalidUtf8,
#[error("Path contains null bytes")]
NullBytes,
#[error("No file extension found")]
NoExtension,
#[error("Unsupported file format")]
UnsupportedFormat,
#[error("Unable to open file")]
UnableToOpenFile,
#[error("Unable to allocate memory")]
AllocationError,
#[error(transparent)]
NulError(#[from] NulError),
#[error(transparent)]
Unflatten(#[from] UnflattenError),
#[error(transparent)]
Exception(#[from] Exception),
}
impl From<Infallible> for IoError {
fn from(_: Infallible) -> Self {
unreachable!()
}
}
impl From<RawException> for IoError {
#[track_caller]
fn from(e: RawException) -> Self {
let exception = Exception {
what: e.what,
location: Location::caller(),
};
Self::Exception(exception)
}
}
/// Errors returned when viewing an array's data as a typed slice.
#[derive(Debug, PartialEq, Error)]
pub enum AsSliceError {
    /// The underlying data pointer was null.
    #[error("The data pointer is null.")]
    Null,

    /// The requested element type differs from the array's dtype.
    #[error("dtype mismatch: expected {expecting:?}, found {found:?}")]
    DtypeMismatch {
        /// The dtype the caller asked for.
        expecting: Dtype,
        /// The dtype the array actually has.
        found: Dtype,
    },

    /// Wraps an [`Exception`].
    #[error(transparent)]
    Exception(#[from] Exception),
}
/// Errors raised while rebuilding a nested structure from flattened (key, value) pairs.
#[derive(Debug, PartialEq, Error)]
pub enum UnflattenError {
    /// The iterator ended while another (key, value) pair was still expected.
    #[error("Expecting next (key, value) pair, found none")]
    ExpectingNextPair,

    /// A key could not be interpreted.
    #[error("Invalid key")]
    InvalidKey,
}

/// Errors raised while loading a saved optimizer state.
#[derive(Debug, PartialEq, Error)]
pub enum OptimizerStateLoadError {
    /// Reading the state failed.
    #[error(transparent)]
    Io(#[from] IoError),

    /// The flattened state could not be unflattened.
    #[error(transparent)]
    Unflatten(#[from] UnflattenError),
}

impl From<Infallible> for OptimizerStateLoadError {
    fn from(never: Infallible) -> Self {
        // `Infallible` is uninhabited; an empty match proves this is unreachable.
        match never {}
    }
}
cfg_safetensors! {
// Errors raised when converting between safetensors tensors and MLX arrays.
// NOTE(review): `cfg_safetensors!` presumably gates this item on the safetensors
// feature — confirm against the macro definition.
// Unlike the other error enums in this file it does not derive `PartialEq`;
// presumably because a wrapped error type lacks it — verify before adding it.
#[derive(Debug, Error)]
pub enum ConversionError {
// A safetensors dtype that this conversion does not support.
#[error("The safetensors data type {0:?} is not supported.")]
SafeTensorDtype(safetensors::tensor::Dtype),
// An MLX dtype that this conversion does not support.
#[error("The mlx data type {0:?} is not supported.")]
MlxDtype(crate::Dtype),
// Wraps a `bytemuck` cast failure.
#[error(transparent)]
PodCastError(#[from] bytemuck::PodCastError),
// Wraps an error reported by the safetensors crate.
#[error(transparent)]
SafeTensorError(#[from] safetensors::tensor::SafeTensorError),
}
}
/// An MLX error message that has not yet been tagged with a source location.
///
/// Converting it into an [`Exception`] (or [`IoError`]) attaches the caller's location.
pub(crate) struct RawException {
    /// The message text taken from MLX's error handler.
    pub(crate) what: String,
}
/// An error message paired with the source location where it was created.
///
/// `Display` prints the message in `Debug` form (quoted and escaped), followed by
/// ` at ` and the location.
#[derive(Debug, PartialEq, Error)]
#[error("{what:?} at {location}")]
pub struct Exception {
    /// The error message.
    pub(crate) what: String,
    /// Where the exception was constructed (captured through `#[track_caller]`).
    pub(crate) location: &'static Location<'static>,
}

impl Exception {
    /// Returns the error message.
    pub fn what(&self) -> &str {
        self.what.as_str()
    }

    /// Returns the source location at which this exception was created.
    pub fn location(&self) -> &'static Location<'static> {
        self.location
    }

    /// Builds an exception from any string-like message and records the caller's location.
    #[track_caller]
    pub fn custom(what: impl Into<String>) -> Self {
        let message: String = what.into();
        let location = Location::caller();
        Self {
            what: message,
            location,
        }
    }
}

impl From<RawException> for Exception {
    /// Attaches the caller's location to a raw MLX error message.
    #[track_caller]
    fn from(raw: RawException) -> Self {
        // `custom` is `#[track_caller]` too, so the location still refers to our caller.
        Exception::custom(raw.what)
    }
}

impl From<&str> for Exception {
    /// Builds an exception from a string slice and records the caller's location.
    #[track_caller]
    fn from(what: &str) -> Self {
        Exception::custom(what)
    }
}

impl From<Infallible> for Exception {
    fn from(never: Infallible) -> Self {
        // No value of `Infallible` exists, so this can never run.
        match never {}
    }
}

impl From<Exception> for String {
    /// Extracts the message, discarding the location.
    fn from(e: Exception) -> Self {
        let Exception { what, .. } = e;
        what
    }
}
thread_local! {
// Per-thread slot for an `Exception`, written by `set_closure_error` and taken by
// `get_and_clear_closure_error`.
// NOTE(review): presumably used to carry errors raised inside closure callbacks
// back across the FFI boundary — confirm at the call sites.
static CLOSURE_ERROR: Cell<Option<Exception>> = const { Cell::new(None) };
// `strdup`-allocated copy of the most recent MLX error message, or null if none.
// Set by `default_mlx_error_handler`; swapped back to null and `free`d by
// `get_and_clear_last_mlx_error`.
static LAST_MLX_ERROR: Cell<*const c_char> = const { Cell::new(std::ptr::null()) };
// NOTE(review): presumably guards one-time, per-thread registration of the MLX
// error handler; its uses are outside this file section — confirm.
pub(crate) static INIT_ERR_HANDLER: Once = const { Once::new() };
}
#[no_mangle]
extern "C" fn default_mlx_error_handler(msg: *const c_char, _data: *mut std::ffi::c_void) {
unsafe {
LAST_MLX_ERROR.with(|last_error| {
last_error.set(strdup(msg));
});
}
}
/// Destructor handed to MLX for the error handler's `data` pointer.
///
/// That pointer is the address of a Rust-owned thread-local cell, so there is
/// nothing to release.
#[no_mangle]
extern "C" fn noop_mlx_error_handler_data_deleter(_unused: *mut std::ffi::c_void) {
    // Intentionally empty.
}
/// Registers `default_mlx_error_handler` as MLX's error handler.
///
/// The `data` argument is the address of the calling thread's `LAST_MLX_ERROR` cell.
/// The handler itself ignores it and writes to the thread-local directly. The paired
/// destructor is a no-op.
pub(crate) fn setup_mlx_error_handler() {
    let data = LAST_MLX_ERROR.with(|cell| cell.as_ptr() as *mut std::ffi::c_void);

    // SAFETY: both callbacks are `extern "C"` functions with the signatures MLX expects,
    // and neither of them dereferences `data`.
    unsafe {
        mlx_sys::mlx_set_error_handler(
            Some(default_mlx_error_handler),
            data,
            Some(noop_mlx_error_handler_data_deleter),
        );
    }
}
/// Stores `err` in this thread's closure-error slot, dropping any previous value.
pub(crate) fn set_closure_error(err: Exception) {
    CLOSURE_ERROR.with(|slot| {
        slot.set(Some(err));
    });
}

/// Removes and returns this thread's stored closure error, leaving the slot empty.
pub(crate) fn get_and_clear_closure_error() -> Option<Exception> {
    CLOSURE_ERROR.with(|slot| slot.take())
}
/// Takes the last MLX error message recorded on this thread, if any.
///
/// Resets `LAST_MLX_ERROR` to null and converts the stored C string to an owned
/// `String`, replacing invalid UTF-8 lossily. It then frees the `strdup`ed buffer.
/// Returns `None` when no error is pending.
#[track_caller]
pub(crate) fn get_and_clear_last_mlx_error() -> Option<RawException> {
    let raw = LAST_MLX_ERROR.with(|slot| slot.replace(std::ptr::null()));
    if raw.is_null() {
        return None;
    }

    // SAFETY: a non-null value in LAST_MLX_ERROR is a NUL-terminated copy produced by
    // `strdup` in the error handler, and nothing else references it now that the slot
    // has been cleared.
    let what = unsafe { std::ffi::CStr::from_ptr(raw) }
        .to_string_lossy()
        .into_owned();

    // SAFETY: `raw` came from `strdup` (i.e. `malloc`) and is not used after this call.
    unsafe { libc::free(raw as *mut libc::c_void) };

    Some(RawException { what })
}
#[derive(Debug, Clone, PartialEq, Error)]
pub enum CrossEntropyBuildError {
#[error("Label smoothing factor must be in the range [0, 1)")]
InvalidLabelSmoothingFactor,
}
impl From<CrossEntropyBuildError> for Exception {
fn from(value: CrossEntropyBuildError) -> Self {
Exception::custom(format!("{}", value))
}
}
/// Errors returned when building an RMSProp optimizer.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum RmsPropBuildError {
    /// `alpha` was negative.
    #[error("alpha must be non-negative")]
    NegativeAlpha,

    /// `epsilon` was negative.
    #[error("epsilon must be non-negative")]
    NegativeEpsilon,
}

/// Errors returned when building an AdaDelta optimizer.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum AdaDeltaBuildError {
    /// `rho` was negative.
    #[error("rho must be non-negative")]
    NegativeRho,

    /// `epsilon` was negative.
    #[error("epsilon must be non-negative")]
    NegativeEps,
}

/// Errors returned when building an Adafactor optimizer.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum AdafactorBuildError {
    /// No learning rate was given and relative step is disabled.
    #[error("Either learning rate is provided or relative step is set to true")]
    LrIsNoneAndRelativeStepIsFalse,
}

/// Errors returned when building a dropout layer.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum DropoutBuildError {
    /// The dropout probability was outside `[0, 1)`.
    #[error("Dropout probability must be in the range [0, 1)")]
    InvalidProbability,
}
/// Errors returned when building a multi-head attention layer.
#[derive(Debug, PartialEq, Error)]
pub enum MultiHeadAttentionBuildError {
    /// The requested number of heads is not valid.
    #[error("Invalid number of heads: {0}")]
    InvalidNumHeads(i32),

    /// Wraps an [`Exception`].
    #[error(transparent)]
    Exception(#[from] Exception),
}

/// Errors returned when building a transformer.
///
/// The misspelling in the type name is kept because it is part of the public API.
#[derive(Debug, PartialEq, Error)]
pub enum TransformerBulidError {
    /// The dropout probability was outside `[0, 1)`.
    #[error("Dropout probability must be in the range [0, 1)")]
    InvalidProbability,

    /// The requested number of attention heads is not valid.
    #[error("Invalid number of heads: {0}")]
    InvalidNumHeads(i32),

    /// Wraps an [`Exception`].
    #[error(transparent)]
    Exception(#[from] Exception),
}

impl From<DropoutBuildError> for TransformerBulidError {
    fn from(e: DropoutBuildError) -> Self {
        // `DropoutBuildError` has a single variant, so this pattern is irrefutable.
        // Adding a variant will make this fail to compile, prompting an update.
        let DropoutBuildError::InvalidProbability = e;
        Self::InvalidProbability
    }
}

impl From<MultiHeadAttentionBuildError> for TransformerBulidError {
    fn from(e: MultiHeadAttentionBuildError) -> Self {
        use MultiHeadAttentionBuildError as Mha;
        match e {
            Mha::InvalidNumHeads(heads) => Self::InvalidNumHeads(heads),
            Mha::Exception(inner) => Self::Exception(inner),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::array;

    /// Adding arrays whose shapes cannot broadcast must return an error whose
    /// message names the incompatible shapes.
    #[test]
    fn test_exception() {
        let lhs = array!([1.0, 2.0, 3.0]);
        let rhs = array!([4.0, 5.0]);

        let error = lhs.add(&rhs).expect_err("Expected error");
        let message = error.what();

        assert!(message.contains("Shapes (3) and (2) cannot be broadcast."));
    }
}