1use crate::error::{Exception, IoError};
2use crate::utils::SUCCESS;
3use crate::{Array, Stream};
4use mlx_sys::FILE;
5use std::collections::HashMap;
6use std::ffi::{CStr, CString};
7use std::path::Path;
8use std::ptr::{null_mut, NonNull};
9
10use super::Guarded;
11
12pub(crate) struct FilePtr(NonNull<FILE>);
13
14impl Drop for FilePtr {
15 fn drop(&mut self) {
16 unsafe {
17 mlx_sys::fclose(self.0.as_ptr());
18 }
19 }
20}
21
22impl FilePtr {
23 pub(crate) fn as_ptr(&self) -> *mut FILE {
24 self.0.as_ptr()
25 }
26
27 pub(crate) fn open(path: &Path, mode: &str) -> Result<Self, IoError> {
28 let path = CString::new(path.to_str().ok_or(IoError::InvalidUtf8)?)?;
29 let mode = CString::new(mode)?;
30
31 unsafe {
32 NonNull::new(mlx_sys::fopen(path.as_ptr(), mode.as_ptr()))
33 .map(Self)
34 .ok_or(IoError::UnableToOpenFile)
35 }
36 }
37}
38
39pub(crate) struct SafeTensors {
40 pub(crate) c_data: mlx_sys::mlx_map_string_to_array,
41 pub(crate) c_metadata: mlx_sys::mlx_map_string_to_string,
42}
43
44impl Drop for SafeTensors {
45 fn drop(&mut self) {
46 unsafe {
47 mlx_sys::mlx_map_string_to_string_free(self.c_metadata);
48 mlx_sys::mlx_map_string_to_array_free(self.c_data);
49 }
50 }
51}
52
53impl SafeTensors {
54 pub(crate) fn load_device(path: &Path, stream: impl AsRef<Stream>) -> Result<Self, IoError> {
55 if !path.is_file() {
56 return Err(IoError::NotFile);
57 }
58
59 let extension = path
60 .extension()
61 .and_then(|ext| ext.to_str())
62 .ok_or(IoError::UnsupportedFormat)?;
63
64 if extension != "safetensors" {
65 return Err(IoError::UnsupportedFormat);
66 }
67
68 let path_str = path.to_str().ok_or(IoError::InvalidUtf8)?;
69 let filepath = CString::new(path_str)?;
70
71 SafeTensors::try_from_op(|(res_0, res_1)| unsafe {
72 mlx_sys::mlx_load_safetensors(res_0, res_1, filepath.as_ptr(), stream.as_ref().as_ptr())
73 })
74 .map_err(Into::into)
75 }
76
77 pub(crate) fn data(&self) -> Result<HashMap<String, Array>, Exception> {
78 crate::error::INIT_ERR_HANDLER
79 .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
80 let mut map = HashMap::new();
81 unsafe {
82 let iterator = mlx_sys::mlx_map_string_to_array_iterator_new(self.c_data);
83
84 loop {
85 let mut key_ptr: *const ::std::os::raw::c_char = null_mut();
86 let mut value = mlx_sys::mlx_array_new();
87 let status = mlx_sys::mlx_map_string_to_array_iterator_next(
88 &mut key_ptr as *mut *const _,
89 &mut value,
90 iterator,
91 );
92
93 match status {
94 SUCCESS => {
95 let key = CStr::from_ptr(key_ptr).to_string_lossy().into_owned();
96 let array = Array::from_ptr(value);
97 map.insert(key, array);
98 }
99 1 => {
100 mlx_sys::mlx_array_free(value);
101 return Err(crate::error::get_and_clear_last_mlx_error()
102 .expect("A non-success status was returned, but no error was set.")
103 .into());
104 }
105 2 => {
106 mlx_sys::mlx_array_free(value);
107 break;
108 }
109 _ => unreachable!(),
110 }
111 }
112
113 mlx_sys::mlx_map_string_to_array_iterator_free(iterator);
114 }
115
116 Ok(map)
117 }
118
119 pub(crate) fn metadata(&self) -> Result<HashMap<String, String>, Exception> {
120 crate::error::INIT_ERR_HANDLER
121 .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
122
123 let mut map = HashMap::new();
124 unsafe {
125 let iterator = mlx_sys::mlx_map_string_to_string_iterator_new(self.c_metadata);
126
127 let mut key: *const ::std::os::raw::c_char = null_mut();
128 let mut value: *const ::std::os::raw::c_char = null_mut();
129 loop {
130 let status = mlx_sys::mlx_map_string_to_string_iterator_next(
131 &mut key as *mut *const _,
132 &mut value as *mut *const _,
133 iterator,
134 );
135
136 match status {
137 SUCCESS => {
138 let key = CStr::from_ptr(key).to_string_lossy().into_owned();
139 let value = CStr::from_ptr(value).to_string_lossy().into_owned();
140 map.insert(key, value);
141 }
142 1 => {
143 return Err(crate::error::get_and_clear_last_mlx_error()
144 .expect("A non-success status was returned, but no error was set.")
145 .into())
146 }
147 2 => break,
148 _ => unreachable!(),
149 }
150 }
151 }
152
153 Ok(map)
154 }
155}