mlx_rs/utils/io.rs

1use crate::error::{Exception, IoError};
2use crate::utils::SUCCESS;
3use crate::{Array, Stream};
4use mlx_sys::FILE;
5use std::collections::HashMap;
6use std::ffi::{CStr, CString};
7use std::path::Path;
8use std::ptr::{null_mut, NonNull};
9
10use super::Guarded;
11
12pub(crate) struct FilePtr(NonNull<FILE>);
13
14impl Drop for FilePtr {
15    fn drop(&mut self) {
16        unsafe {
17            mlx_sys::fclose(self.0.as_ptr());
18        }
19    }
20}
21
22impl FilePtr {
23    pub(crate) fn as_ptr(&self) -> *mut FILE {
24        self.0.as_ptr()
25    }
26
27    pub(crate) fn open(path: &Path, mode: &str) -> Result<Self, IoError> {
28        let path = CString::new(path.to_str().ok_or(IoError::InvalidUtf8)?)?;
29        let mode = CString::new(mode)?;
30
31        unsafe {
32            NonNull::new(mlx_sys::fopen(path.as_ptr(), mode.as_ptr()))
33                .map(Self)
34                .ok_or(IoError::UnableToOpenFile)
35        }
36    }
37}
38
39pub(crate) struct SafeTensors {
40    pub(crate) c_data: mlx_sys::mlx_map_string_to_array,
41    pub(crate) c_metadata: mlx_sys::mlx_map_string_to_string,
42}
43
44impl Drop for SafeTensors {
45    fn drop(&mut self) {
46        unsafe {
47            mlx_sys::mlx_map_string_to_string_free(self.c_metadata);
48            mlx_sys::mlx_map_string_to_array_free(self.c_data);
49        }
50    }
51}
52
53impl SafeTensors {
54    pub(crate) fn load_device(path: &Path, stream: impl AsRef<Stream>) -> Result<Self, IoError> {
55        if !path.is_file() {
56            return Err(IoError::NotFile);
57        }
58
59        let extension = path
60            .extension()
61            .and_then(|ext| ext.to_str())
62            .ok_or(IoError::UnsupportedFormat)?;
63
64        if extension != "safetensors" {
65            return Err(IoError::UnsupportedFormat);
66        }
67
68        let path_str = path.to_str().ok_or(IoError::InvalidUtf8)?;
69        let filepath = CString::new(path_str)?;
70
71        SafeTensors::try_from_op(|(res_0, res_1)| unsafe {
72            mlx_sys::mlx_load_safetensors(res_0, res_1, filepath.as_ptr(), stream.as_ref().as_ptr())
73        })
74        .map_err(Into::into)
75    }
76
77    pub(crate) fn data(&self) -> Result<HashMap<String, Array>, Exception> {
78        crate::error::INIT_ERR_HANDLER
79            .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
80        let mut map = HashMap::new();
81        unsafe {
82            let iterator = mlx_sys::mlx_map_string_to_array_iterator_new(self.c_data);
83
84            loop {
85                let mut key_ptr: *const ::std::os::raw::c_char = null_mut();
86                let mut value = mlx_sys::mlx_array_new();
87                let status = mlx_sys::mlx_map_string_to_array_iterator_next(
88                    &mut key_ptr as *mut *const _,
89                    &mut value,
90                    iterator,
91                );
92
93                match status {
94                    SUCCESS => {
95                        let key = CStr::from_ptr(key_ptr).to_string_lossy().into_owned();
96                        let array = Array::from_ptr(value);
97                        map.insert(key, array);
98                    }
99                    1 => {
100                        mlx_sys::mlx_array_free(value);
101                        return Err(crate::error::get_and_clear_last_mlx_error()
102                            .expect("A non-success status was returned, but no error was set.")
103                            .into());
104                    }
105                    2 => {
106                        mlx_sys::mlx_array_free(value);
107                        break;
108                    }
109                    _ => unreachable!(),
110                }
111            }
112
113            mlx_sys::mlx_map_string_to_array_iterator_free(iterator);
114        }
115
116        Ok(map)
117    }
118
119    pub(crate) fn metadata(&self) -> Result<HashMap<String, String>, Exception> {
120        crate::error::INIT_ERR_HANDLER
121            .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
122
123        let mut map = HashMap::new();
124        unsafe {
125            let iterator = mlx_sys::mlx_map_string_to_string_iterator_new(self.c_metadata);
126
127            let mut key: *const ::std::os::raw::c_char = null_mut();
128            let mut value: *const ::std::os::raw::c_char = null_mut();
129            loop {
130                let status = mlx_sys::mlx_map_string_to_string_iterator_next(
131                    &mut key as *mut *const _,
132                    &mut value as *mut *const _,
133                    iterator,
134                );
135
136                match status {
137                    SUCCESS => {
138                        let key = CStr::from_ptr(key).to_string_lossy().into_owned();
139                        let value = CStr::from_ptr(value).to_string_lossy().into_owned();
140                        map.insert(key, value);
141                    }
142                    1 => {
143                        return Err(crate::error::get_and_clear_last_mlx_error()
144                            .expect("A non-success status was returned, but no error was set.")
145                            .into())
146                    }
147                    2 => break,
148                    _ => unreachable!(),
149                }
150            }
151        }
152
153        Ok(map)
154    }
155}