mlx_rs/utils/
io.rs

1use crate::error::{Exception, IoError};
2use crate::utils::SUCCESS;
3use crate::{Array, Stream};
4use std::collections::HashMap;
5use std::ffi::{CStr, CString};
6use std::path::Path;
7use std::ptr::null_mut;
8
9use super::Guarded;
10
11pub(crate) struct SafeTensors {
12    pub(crate) c_data: mlx_sys::mlx_map_string_to_array,
13    pub(crate) c_metadata: mlx_sys::mlx_map_string_to_string,
14}
15
16impl Drop for SafeTensors {
17    fn drop(&mut self) {
18        unsafe {
19            mlx_sys::mlx_map_string_to_string_free(self.c_metadata);
20            mlx_sys::mlx_map_string_to_array_free(self.c_data);
21        }
22    }
23}
24
25impl SafeTensors {
26    pub(crate) fn load_device(path: &Path, stream: impl AsRef<Stream>) -> Result<Self, IoError> {
27        if !path.is_file() {
28            return Err(IoError::NotFile);
29        }
30
31        let extension = path
32            .extension()
33            .and_then(|ext| ext.to_str())
34            .ok_or(IoError::UnsupportedFormat)?;
35
36        if extension != "safetensors" {
37            return Err(IoError::UnsupportedFormat);
38        }
39
40        let path_str = path.to_str().ok_or(IoError::InvalidUtf8)?;
41        let filepath = CString::new(path_str)?;
42
43        SafeTensors::try_from_op(|(res_0, res_1)| unsafe {
44            mlx_sys::mlx_load_safetensors(res_0, res_1, filepath.as_ptr(), stream.as_ref().as_ptr())
45        })
46        .map_err(Into::into)
47    }
48
49    pub(crate) fn data(&self) -> Result<HashMap<String, Array>, Exception> {
50        crate::error::INIT_ERR_HANDLER
51            .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
52        let mut map = HashMap::new();
53        unsafe {
54            let iterator = mlx_sys::mlx_map_string_to_array_iterator_new(self.c_data);
55
56            loop {
57                let mut key_ptr: *const ::std::os::raw::c_char = null_mut();
58                let mut value = mlx_sys::mlx_array_new();
59                let status = mlx_sys::mlx_map_string_to_array_iterator_next(
60                    &mut key_ptr as *mut *const _,
61                    &mut value,
62                    iterator,
63                );
64
65                match status {
66                    SUCCESS => {
67                        let key = CStr::from_ptr(key_ptr).to_string_lossy().into_owned();
68                        let array = Array::from_ptr(value);
69                        map.insert(key, array);
70                    }
71                    1 => {
72                        mlx_sys::mlx_array_free(value);
73                        return Err(crate::error::get_and_clear_last_mlx_error()
74                            .expect("A non-success status was returned, but no error was set.")
75                            .into());
76                    }
77                    2 => {
78                        mlx_sys::mlx_array_free(value);
79                        break;
80                    }
81                    _ => unreachable!(),
82                }
83            }
84
85            mlx_sys::mlx_map_string_to_array_iterator_free(iterator);
86        }
87
88        Ok(map)
89    }
90
91    pub(crate) fn metadata(&self) -> Result<HashMap<String, String>, Exception> {
92        crate::error::INIT_ERR_HANDLER
93            .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
94
95        let mut map = HashMap::new();
96        unsafe {
97            let iterator = mlx_sys::mlx_map_string_to_string_iterator_new(self.c_metadata);
98
99            let mut key: *const ::std::os::raw::c_char = null_mut();
100            let mut value: *const ::std::os::raw::c_char = null_mut();
101            loop {
102                let status = mlx_sys::mlx_map_string_to_string_iterator_next(
103                    &mut key as *mut *const _,
104                    &mut value as *mut *const _,
105                    iterator,
106                );
107
108                match status {
109                    SUCCESS => {
110                        let key = CStr::from_ptr(key).to_string_lossy().into_owned();
111                        let value = CStr::from_ptr(value).to_string_lossy().into_owned();
112                        map.insert(key, value);
113                    }
114                    1 => {
115                        return Err(crate::error::get_and_clear_last_mlx_error()
116                            .expect("A non-success status was returned, but no error was set.")
117                            .into())
118                    }
119                    2 => break,
120                    _ => unreachable!(),
121                }
122            }
123        }
124
125        Ok(map)
126    }
127}