1use crate::error::IoError;
2use crate::utils::guard::Guarded;
3use crate::utils::io::{FilePtr, SafeTensors};
4use crate::utils::SUCCESS;
5use crate::{Array, Stream, StreamOrDevice};
6use mlx_internal_macros::default_device;
7use std::collections::HashMap;
8use std::ffi::CString;
9use std::path::Path;
10
11fn check_file_extension(path: &Path, expected: &str) -> Result<(), IoError> {
12 match path.extension().and_then(|ext| ext.to_str()) {
13 Some(ext) if ext == expected => Ok(()),
14 _ => Err(IoError::UnsupportedFormat),
15 }
16}
17
18impl Array {
19 #[default_device]
26 pub fn load_numpy_device(
27 path: impl AsRef<Path>,
28 stream: impl AsRef<Stream>,
29 ) -> Result<Array, IoError> {
30 let path = path.as_ref();
31 if !path.is_file() {
32 return Err(IoError::NotFile);
33 }
34 let c_path = CString::new(path.to_str().ok_or(IoError::InvalidUtf8)?)?;
35 check_file_extension(path, "npy")?;
36
37 Array::try_from_op(|res| unsafe {
38 mlx_sys::mlx_load(res, c_path.as_ptr(), stream.as_ref().as_ptr())
39 })
40 .map_err(Into::into)
41 }
42
43 #[default_device]
51 pub fn load_safetensors_device(
52 path: impl AsRef<Path>,
53 stream: impl AsRef<Stream>,
54 ) -> Result<HashMap<String, Array>, IoError> {
55 let safetensors = SafeTensors::load_device(path.as_ref(), stream)?;
56 let data = safetensors.data()?;
57 Ok(data)
58 }
59
60 #[allow(clippy::type_complexity)]
67 #[default_device]
68 pub fn load_safetensors_with_metadata_device(
69 path: impl AsRef<Path>,
70 stream: impl AsRef<Stream>,
71 ) -> Result<(HashMap<String, Array>, HashMap<String, String>), IoError> {
72 let safetensors = SafeTensors::load_device(path.as_ref(), stream)?;
73 let data = safetensors.data()?;
74 let metadata = safetensors.metadata()?;
75
76 Ok((data, metadata))
77 }
78
79 pub fn save_numpy(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
86 let path = path.as_ref();
87 check_file_extension(path, "npy")?;
88 let file_ptr = FilePtr::open(path, "w")?;
89
90 unsafe { mlx_sys::mlx_save_file(file_ptr.as_ptr(), self.as_ptr()) };
91
92 Ok(())
93 }
94
95 pub fn save_safetensors<'a, I, S, V>(
104 arrays: I,
105 metadata: impl Into<Option<&'a HashMap<String, String>>>,
106 path: impl AsRef<Path>,
107 ) -> Result<(), IoError>
108 where
109 I: IntoIterator<Item = (S, V)>,
110 S: AsRef<str>,
111 V: AsRef<Array>,
112 {
113 crate::error::INIT_ERR_HANDLER
114 .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
115
116 let path = path.as_ref();
117
118 check_file_extension(path, "safetensors")?;
119
120 let arrays = unsafe {
121 let data = mlx_sys::mlx_map_string_to_array_new();
122 for (key, array) in arrays.into_iter() {
123 let key = CString::new(key.as_ref())?;
124
125 let status = mlx_sys::mlx_map_string_to_array_insert(
126 data,
127 key.as_ptr(),
128 array.as_ref().as_ptr(),
129 );
130
131 if status != SUCCESS {
132 mlx_sys::mlx_map_string_to_array_free(data);
133 return Err(crate::error::get_and_clear_last_mlx_error()
134 .expect("A non-success status was returned, but no error was set.")
135 .into());
136 }
137 }
138 data
139 };
140
141 let default_metadata = HashMap::new();
142 let metadata_ref = metadata.into().unwrap_or(&default_metadata);
143
144 let metadata = unsafe {
145 let data = mlx_sys::mlx_map_string_to_string_new();
146 for (key, value) in metadata_ref.iter() {
147 let key = CString::new(key.as_str())?;
148 let value = CString::new(value.as_str())?;
149
150 let status =
151 mlx_sys::mlx_map_string_to_string_insert(data, key.as_ptr(), value.as_ptr());
152
153 if status != SUCCESS {
154 mlx_sys::mlx_map_string_to_string_free(data);
155 return Err(crate::error::get_and_clear_last_mlx_error()
156 .expect("A non-success status was returned, but no error was set.")
157 .into());
158 }
159 }
160 data
161 };
162
163 let file_ptr = FilePtr::open(path, "w")?;
164
165 unsafe {
166 let status = mlx_sys::mlx_save_safetensors_file(file_ptr.as_ptr(), arrays, metadata);
167
168 let last_error = match status {
169 SUCCESS => None,
170 _ => Some(
171 crate::error::get_and_clear_last_mlx_error()
172 .expect("A non-success status was returned, but no error was set."),
173 ),
174 };
175
176 mlx_sys::mlx_map_string_to_array_free(arrays);
177 mlx_sys::mlx_map_string_to_string_free(metadata);
178
179 if let Some(error) = last_error {
180 return Err(error.into());
181 }
182 };
183
184 Ok(())
185 }
186}
187
#[cfg(test)]
mod tests {
    use crate::error::IoError;
    use crate::Array;
    use std::collections::HashMap;

    /// Saving named arrays to `.safetensors` and loading them back yields the same keys
    /// with element-wise equal arrays.
    #[test]
    fn test_save_arrays() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("test.safetensors");

        let mut arrays = HashMap::new();
        arrays.insert("foo".to_string(), Array::ones::<i32>(&[1, 2]).unwrap());
        arrays.insert("bar".to_string(), Array::zeros::<i32>(&[2, 1]).unwrap());

        Array::save_safetensors(&arrays, None, &path).unwrap();

        let loaded_arrays = Array::load_safetensors(&path).unwrap();

        // HashMap iteration order is unspecified, so compare sorted key lists.
        let mut loaded_keys: Vec<_> = loaded_arrays.keys().cloned().collect();
        let mut original_keys: Vec<_> = arrays.keys().cloned().collect();
        loaded_keys.sort();
        original_keys.sort();
        assert_eq!(loaded_keys, original_keys);

        for key in loaded_keys {
            let loaded_array = loaded_arrays.get(&key).unwrap();
            let original_array = arrays.get(&key).unwrap();
            assert!(loaded_array
                .all_close(original_array, None, None, None)
                .unwrap()
                .item::<bool>());
        }
    }

    /// Metadata passed to `save_safetensors` can be read back through
    /// `load_safetensors_with_metadata`.
    #[test]
    fn test_save_arrays_with_metadata() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("meta.safetensors");

        let mut arrays = HashMap::new();
        arrays.insert("foo".to_string(), Array::ones::<i32>(&[3]).unwrap());
        let mut metadata = HashMap::new();
        metadata.insert("author".to_string(), "mlx-rs".to_string());

        Array::save_safetensors(&arrays, &metadata, &path).unwrap();

        let (loaded_arrays, loaded_metadata) =
            Array::load_safetensors_with_metadata(&path).unwrap();
        assert!(loaded_arrays.contains_key("foo"));
        // Only check our own key; the writer may add extra metadata entries.
        assert_eq!(
            loaded_metadata.get("author").map(String::as_str),
            Some("mlx-rs")
        );
    }

    /// An array name with an interior NUL byte is rejected before any file is created.
    #[test]
    fn test_save_safetensors_rejects_nul_in_key() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("bad.safetensors");

        let mut arrays = HashMap::new();
        arrays.insert("bad\0key".to_string(), Array::zeros::<i32>(&[1]).unwrap());

        assert!(Array::save_safetensors(&arrays, None, &path).is_err());
        assert!(!path.exists());
    }

    /// Both savers reject destinations with the wrong extension.
    #[test]
    fn test_save_rejects_wrong_extension() {
        let tmp_dir = tempfile::tempdir().unwrap();

        let a = Array::ones::<i32>(&[2]).unwrap();
        assert!(matches!(
            a.save_numpy(tmp_dir.path().join("test.txt")),
            Err(IoError::UnsupportedFormat)
        ));

        let arrays: HashMap<String, Array> = HashMap::new();
        assert!(matches!(
            Array::save_safetensors(&arrays, None, tmp_dir.path().join("test.npy")),
            Err(IoError::UnsupportedFormat)
        ));
    }

    /// Loading a path that does not exist reports `NotFile`.
    #[test]
    fn test_load_numpy_missing_file() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let result = Array::load_numpy(tmp_dir.path().join("missing.npy"));
        assert!(matches!(result, Err(IoError::NotFile)));
    }

    /// Saving an array to `.npy` and loading it back yields an element-wise equal array.
    #[test]
    fn test_save_array() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("test.npy");

        let a = Array::ones::<i32>(&[2, 4]).unwrap();
        a.save_numpy(&path).unwrap();

        let b = Array::load_numpy(&path).unwrap();
        assert!(a.all_close(&b, None, None, None).unwrap().item::<bool>());
    }
}