1use std::ffi::CStr;
2
3use crate::{
4 error::Result,
5 utils::{guard::Guarded, SUCCESS},
6};
7
8#[derive(num_enum::IntoPrimitive, Debug, Clone, Copy)]
10#[repr(u32)]
11pub enum DeviceType {
12 Cpu = mlx_sys::mlx_device_type__MLX_CPU,
14
15 Gpu = mlx_sys::mlx_device_type__MLX_GPU,
17}
18
19pub struct Device {
21 pub(crate) c_device: mlx_sys::mlx_device,
22}
23
24impl PartialEq for Device {
25 fn eq(&self, other: &Self) -> bool {
26 unsafe { mlx_sys::mlx_device_equal(self.c_device, other.c_device) }
27 }
28}
29
30impl Device {
31 pub fn new(device_type: DeviceType, index: i32) -> Device {
33 let c_device = unsafe { mlx_sys::mlx_device_new_type(device_type.into(), index) };
34 Device { c_device }
35 }
36
37 pub fn try_default() -> Result<Self> {
39 Device::try_from_op(|res| unsafe { mlx_sys::mlx_get_default_device(res) })
40 }
41
42 pub fn cpu() -> Device {
44 Device::new(DeviceType::Cpu, 0)
45 }
46
47 pub fn gpu() -> Device {
49 Device::new(DeviceType::Gpu, 0)
50 }
51
52 pub fn get_index(&self) -> Result<i32> {
54 i32::try_from_op(|res| unsafe { mlx_sys::mlx_device_get_index(res, self.c_device) })
55 }
56
57 pub fn get_type(&self) -> Result<DeviceType> {
59 DeviceType::try_from_op(|res| unsafe { mlx_sys::mlx_device_get_type(res, self.c_device) })
60 }
61
62 pub fn set_default(device: &Device) {
73 unsafe { mlx_sys::mlx_set_default_device(device.c_device) };
74 }
75
76 fn describe(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
77 unsafe {
78 let mut mlx_str = mlx_sys::mlx_string_new();
79 let result = match mlx_sys::mlx_device_tostring(&mut mlx_str as *mut _, self.c_device) {
80 SUCCESS => {
81 let ptr = mlx_sys::mlx_string_data(mlx_str);
82 let c_str = CStr::from_ptr(ptr);
83 write!(f, "{}", c_str.to_string_lossy())
84 }
85 _ => Err(std::fmt::Error),
86 };
87 mlx_sys::mlx_string_free(mlx_str);
88 result
89 }
90 }
91}
92
93impl Drop for Device {
94 fn drop(&mut self) {
95 let status = unsafe { mlx_sys::mlx_device_free(self.c_device) };
96 debug_assert_eq!(status, SUCCESS);
97 }
98}
99
100impl Default for Device {
101 fn default() -> Self {
102 Self::try_default().unwrap()
103 }
104}
105
106impl std::fmt::Debug for Device {
107 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
108 self.describe(f)
109 }
110}
111
112impl std::fmt::Display for Device {
113 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
114 self.describe(f)
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fmt() {
        // NOTE(review): assumes the default device is GPU 0; this will not hold on builds
        // where MLX defaults to the CPU — confirm the supported targets.
        let device = Device::default();
        assert_eq!(device.to_string(), "Device(gpu, 0)");
    }

    #[test]
    fn test_fmt_explicit_devices() {
        // Independent of the process-wide default, so deterministic on every platform.
        assert_eq!(Device::cpu().to_string(), "Device(cpu, 0)");
        assert_eq!(Device::gpu().to_string(), "Device(gpu, 0)");
    }

    #[test]
    fn test_debug_matches_display() {
        let device = Device::cpu();
        assert_eq!(format!("{:?}", device), device.to_string());
    }
}