1use crate::error::{Exception, IoError};
2use crate::utils::SUCCESS;
3use crate::{Array, Stream};
4use std::collections::HashMap;
5use std::ffi::{CStr, CString};
6use std::path::Path;
7use std::ptr::null_mut;
8
9use super::Guarded;
10
/// Owning handle to the two MLX-C maps produced when a `.safetensors` file is loaded:
/// a string-keyed map of arrays and a string-to-string metadata map.
///
/// Both maps are released by this type's `Drop` implementation, so a `SafeTensors`
/// value must be the sole owner of the handles it holds.
pub(crate) struct SafeTensors {
    /// MLX-C map from string keys to arrays (read out by `data()`).
    pub(crate) c_data: mlx_sys::mlx_map_string_to_array,
    /// MLX-C map from string keys to string values (read out by `metadata()`).
    pub(crate) c_metadata: mlx_sys::mlx_map_string_to_string,
}
15
16impl Drop for SafeTensors {
17 fn drop(&mut self) {
18 unsafe {
19 mlx_sys::mlx_map_string_to_string_free(self.c_metadata);
20 mlx_sys::mlx_map_string_to_array_free(self.c_data);
21 }
22 }
23}
24
25impl SafeTensors {
26 pub(crate) fn load_device(path: &Path, stream: impl AsRef<Stream>) -> Result<Self, IoError> {
27 if !path.is_file() {
28 return Err(IoError::NotFile);
29 }
30
31 let extension = path
32 .extension()
33 .and_then(|ext| ext.to_str())
34 .ok_or(IoError::UnsupportedFormat)?;
35
36 if extension != "safetensors" {
37 return Err(IoError::UnsupportedFormat);
38 }
39
40 let path_str = path.to_str().ok_or(IoError::InvalidUtf8)?;
41 let filepath = CString::new(path_str)?;
42
43 SafeTensors::try_from_op(|(res_0, res_1)| unsafe {
44 mlx_sys::mlx_load_safetensors(res_0, res_1, filepath.as_ptr(), stream.as_ref().as_ptr())
45 })
46 .map_err(Into::into)
47 }
48
49 pub(crate) fn data(&self) -> Result<HashMap<String, Array>, Exception> {
50 crate::error::INIT_ERR_HANDLER
51 .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
52 let mut map = HashMap::new();
53 unsafe {
54 let iterator = mlx_sys::mlx_map_string_to_array_iterator_new(self.c_data);
55
56 loop {
57 let mut key_ptr: *const ::std::os::raw::c_char = null_mut();
58 let mut value = mlx_sys::mlx_array_new();
59 let status = mlx_sys::mlx_map_string_to_array_iterator_next(
60 &mut key_ptr as *mut *const _,
61 &mut value,
62 iterator,
63 );
64
65 match status {
66 SUCCESS => {
67 let key = CStr::from_ptr(key_ptr).to_string_lossy().into_owned();
68 let array = Array::from_ptr(value);
69 map.insert(key, array);
70 }
71 1 => {
72 mlx_sys::mlx_array_free(value);
73 return Err(crate::error::get_and_clear_last_mlx_error()
74 .expect("A non-success status was returned, but no error was set.")
75 .into());
76 }
77 2 => {
78 mlx_sys::mlx_array_free(value);
79 break;
80 }
81 _ => unreachable!(),
82 }
83 }
84
85 mlx_sys::mlx_map_string_to_array_iterator_free(iterator);
86 }
87
88 Ok(map)
89 }
90
91 pub(crate) fn metadata(&self) -> Result<HashMap<String, String>, Exception> {
92 crate::error::INIT_ERR_HANDLER
93 .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
94
95 let mut map = HashMap::new();
96 unsafe {
97 let iterator = mlx_sys::mlx_map_string_to_string_iterator_new(self.c_metadata);
98
99 let mut key: *const ::std::os::raw::c_char = null_mut();
100 let mut value: *const ::std::os::raw::c_char = null_mut();
101 loop {
102 let status = mlx_sys::mlx_map_string_to_string_iterator_next(
103 &mut key as *mut *const _,
104 &mut value as *mut *const _,
105 iterator,
106 );
107
108 match status {
109 SUCCESS => {
110 let key = CStr::from_ptr(key).to_string_lossy().into_owned();
111 let value = CStr::from_ptr(value).to_string_lossy().into_owned();
112 map.insert(key, value);
113 }
114 1 => {
115 return Err(crate::error::get_and_clear_last_mlx_error()
116 .expect("A non-success status was returned, but no error was set.")
117 .into())
118 }
119 2 => break,
120 _ => unreachable!(),
121 }
122 }
123 }
124
125 Ok(map)
126 }
127}