mlx_rs/ops/
io.rs

1use crate::error::IoError;
2use crate::utils::guard::Guarded;
3use crate::utils::io::{FilePtr, SafeTensors};
4use crate::utils::SUCCESS;
5use crate::{Array, Stream, StreamOrDevice};
6use mlx_internal_macros::default_device;
7use std::collections::HashMap;
8use std::ffi::CString;
9use std::path::Path;
10
11fn check_file_extension(path: &Path, expected: &str) -> Result<(), IoError> {
12    match path.extension().and_then(|ext| ext.to_str()) {
13        Some(ext) if ext == expected => Ok(()),
14        _ => Err(IoError::UnsupportedFormat),
15    }
16}
17
18impl Array {
19    /// Load array from a binary file in `.npy` format.
20    ///
21    /// # Params
22    ///
23    /// - path: path of file to load
24    /// - stream: stream or device to evaluate on
25    #[default_device]
26    pub fn load_numpy_device(
27        path: impl AsRef<Path>,
28        stream: impl AsRef<Stream>,
29    ) -> Result<Array, IoError> {
30        let path = path.as_ref();
31        if !path.is_file() {
32            return Err(IoError::NotFile);
33        }
34        let c_path = CString::new(path.to_str().ok_or(IoError::InvalidUtf8)?)?;
35        check_file_extension(path, "npy")?;
36
37        Array::try_from_op(|res| unsafe {
38            mlx_sys::mlx_load(res, c_path.as_ptr(), stream.as_ref().as_ptr())
39        })
40        .map_err(Into::into)
41    }
42
43    /// Load dictionary of ``MLXArray`` from a `safetensors` file.
44    ///
45    /// # Params
46    ///
47    /// - path: path of file to load
48    /// - stream: stream or device to evaluate on
49    ///
50    #[default_device]
51    pub fn load_safetensors_device(
52        path: impl AsRef<Path>,
53        stream: impl AsRef<Stream>,
54    ) -> Result<HashMap<String, Array>, IoError> {
55        let safetensors = SafeTensors::load_device(path.as_ref(), stream)?;
56        let data = safetensors.data()?;
57        Ok(data)
58    }
59
60    /// Load dictionary of ``MLXArray`` and metadata `[String:String]` from a `safetensors` file.
61    ///
62    /// # Params
63    ///
64    /// - path: path of file to load
65    /// - stream: stream or device to evaluate on
66    #[allow(clippy::type_complexity)]
67    #[default_device]
68    pub fn load_safetensors_with_metadata_device(
69        path: impl AsRef<Path>,
70        stream: impl AsRef<Stream>,
71    ) -> Result<(HashMap<String, Array>, HashMap<String, String>), IoError> {
72        let safetensors = SafeTensors::load_device(path.as_ref(), stream)?;
73        let data = safetensors.data()?;
74        let metadata = safetensors.metadata()?;
75
76        Ok((data, metadata))
77    }
78
79    /// Save array to a binary file in `.npy`format.
80    ///
81    /// # Params
82    ///
83    /// - array: array to save
84    /// - url: URL of file to load
85    pub fn save_numpy(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
86        let path = path.as_ref();
87        check_file_extension(path, "npy")?;
88        let file_ptr = FilePtr::open(path, "w")?;
89
90        unsafe { mlx_sys::mlx_save_file(file_ptr.as_ptr(), self.as_ptr()) };
91
92        Ok(())
93    }
94
95    /// Save dictionary of arrays in `safetensors` format.
96    ///
97    /// # Params
98    ///
99    /// - arrays: arrays to save
100    /// - metadata: metadata to save
101    /// - path: path of file to save
102    /// - stream: stream or device to evaluate on
103    pub fn save_safetensors<'a, I, S, V>(
104        arrays: I,
105        metadata: impl Into<Option<&'a HashMap<String, String>>>,
106        path: impl AsRef<Path>,
107    ) -> Result<(), IoError>
108    where
109        I: IntoIterator<Item = (S, V)>,
110        S: AsRef<str>,
111        V: AsRef<Array>,
112    {
113        crate::error::INIT_ERR_HANDLER
114            .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
115
116        let path = path.as_ref();
117
118        check_file_extension(path, "safetensors")?;
119
120        let arrays = unsafe {
121            let data = mlx_sys::mlx_map_string_to_array_new();
122            for (key, array) in arrays.into_iter() {
123                let key = CString::new(key.as_ref())?;
124
125                let status = mlx_sys::mlx_map_string_to_array_insert(
126                    data,
127                    key.as_ptr(),
128                    array.as_ref().as_ptr(),
129                );
130
131                if status != SUCCESS {
132                    mlx_sys::mlx_map_string_to_array_free(data);
133                    return Err(crate::error::get_and_clear_last_mlx_error()
134                        .expect("A non-success status was returned, but no error was set.")
135                        .into());
136                }
137            }
138            data
139        };
140
141        let default_metadata = HashMap::new();
142        let metadata_ref = metadata.into().unwrap_or(&default_metadata);
143
144        let metadata = unsafe {
145            let data = mlx_sys::mlx_map_string_to_string_new();
146            for (key, value) in metadata_ref.iter() {
147                let key = CString::new(key.as_str())?;
148                let value = CString::new(value.as_str())?;
149
150                let status =
151                    mlx_sys::mlx_map_string_to_string_insert(data, key.as_ptr(), value.as_ptr());
152
153                if status != SUCCESS {
154                    mlx_sys::mlx_map_string_to_string_free(data);
155                    return Err(crate::error::get_and_clear_last_mlx_error()
156                        .expect("A non-success status was returned, but no error was set.")
157                        .into());
158                }
159            }
160            data
161        };
162
163        let file_ptr = FilePtr::open(path, "w")?;
164
165        unsafe {
166            let status = mlx_sys::mlx_save_safetensors_file(file_ptr.as_ptr(), arrays, metadata);
167
168            let last_error = match status {
169                SUCCESS => None,
170                _ => Some(
171                    crate::error::get_and_clear_last_mlx_error()
172                        .expect("A non-success status was returned, but no error was set."),
173                ),
174            };
175
176            mlx_sys::mlx_map_string_to_array_free(arrays);
177            mlx_sys::mlx_map_string_to_string_free(metadata);
178
179            if let Some(error) = last_error {
180                return Err(error.into());
181            }
182        };
183
184        Ok(())
185    }
186}
187
#[cfg(test)]
mod tests {
    use crate::error::IoError;
    use crate::Array;
    use std::collections::HashMap;

    /// Round-trips a map of arrays through a `.safetensors` file.
    #[test]
    fn test_save_arrays() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("test.safetensors");

        let mut arrays = HashMap::new();
        arrays.insert("foo".to_string(), Array::ones::<i32>(&[1, 2]).unwrap());
        arrays.insert("bar".to_string(), Array::zeros::<i32>(&[2, 1]).unwrap());

        Array::save_safetensors(&arrays, None, &path).unwrap();

        let loaded_arrays = Array::load_safetensors(&path).unwrap();

        // compare keys, then values
        let mut loaded_keys: Vec<_> = loaded_arrays.keys().cloned().collect();
        let mut original_keys: Vec<_> = arrays.keys().cloned().collect();
        loaded_keys.sort();
        original_keys.sort();
        assert_eq!(loaded_keys, original_keys);

        for key in loaded_keys {
            let loaded_array = loaded_arrays.get(&key).unwrap();
            let original_array = arrays.get(&key).unwrap();
            assert!(loaded_array
                .all_close(original_array, None, None, None)
                .unwrap()
                .item::<bool>());
        }
    }

    /// Metadata passed to `save_safetensors` is readable back from the file.
    #[test]
    fn test_save_arrays_with_metadata() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("test.safetensors");

        let mut arrays = HashMap::new();
        arrays.insert("foo".to_string(), Array::ones::<i32>(&[3]).unwrap());

        let mut metadata = HashMap::new();
        metadata.insert("author".to_string(), "mlx-rs".to_string());

        Array::save_safetensors(&arrays, &metadata, &path).unwrap();

        let (loaded_arrays, loaded_metadata) =
            Array::load_safetensors_with_metadata(&path).unwrap();
        assert!(loaded_arrays.contains_key("foo"));
        // Only check the entry we wrote; MLX may add its own entries to the metadata.
        assert_eq!(
            loaded_metadata.get("author").map(String::as_str),
            Some("mlx-rs")
        );
    }

    /// Round-trips a single array through a `.npy` file.
    #[test]
    fn test_save_array() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("test.npy");

        let a = Array::ones::<i32>(&[2, 4]).unwrap();
        a.save_numpy(&path).unwrap();

        let b = Array::load_numpy(&path).unwrap();
        assert!(a.all_close(&b, None, None, None).unwrap().item::<bool>());
    }

    /// Both save functions reject paths with the wrong extension.
    #[test]
    fn test_save_rejects_wrong_extension() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let a = Array::ones::<i32>(&[2]).unwrap();

        let npy_result = a.save_numpy(tmp_dir.path().join("test.txt"));
        assert!(matches!(npy_result, Err(IoError::UnsupportedFormat)));

        let arrays: HashMap<String, Array> = HashMap::new();
        let st_result = Array::save_safetensors(&arrays, None, tmp_dir.path().join("test.npy"));
        assert!(matches!(st_result, Err(IoError::UnsupportedFormat)));
    }

    /// Loading a path that does not exist reports `NotFile`.
    #[test]
    fn test_load_numpy_missing_file() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let result = Array::load_numpy(tmp_dir.path().join("missing.npy"));
        assert!(matches!(result, Err(IoError::NotFile)));
    }

    /// An interior NUL in an array name is an error and no output file is created.
    #[test]
    fn test_save_safetensors_rejects_nul_in_key() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("test.safetensors");

        let mut arrays = HashMap::new();
        arrays.insert("bad\0key".to_string(), Array::ones::<i32>(&[1]).unwrap());

        assert!(Array::save_safetensors(&arrays, None, &path).is_err());
        // Names are validated before the output file is opened, so nothing is written.
        assert!(!path.exists());
    }
}