mlx_rs/
device.rs

1use std::ffi::CStr;
2
3use crate::{
4    error::Result,
5    utils::{guard::Guarded, SUCCESS},
6};
7
8///Type of device.
9#[derive(num_enum::IntoPrimitive, Debug, Clone, Copy)]
10#[repr(u32)]
11pub enum DeviceType {
12    /// CPU device
13    Cpu = mlx_sys::mlx_device_type__MLX_CPU,
14
15    /// GPU device
16    Gpu = mlx_sys::mlx_device_type__MLX_GPU,
17}
18
19/// Representation of a Device in MLX.
20pub struct Device {
21    pub(crate) c_device: mlx_sys::mlx_device,
22}
23
24impl PartialEq for Device {
25    fn eq(&self, other: &Self) -> bool {
26        unsafe { mlx_sys::mlx_device_equal(self.c_device, other.c_device) }
27    }
28}
29
30impl Device {
31    /// Create a new [`Device`]
32    pub fn new(device_type: DeviceType, index: i32) -> Device {
33        let c_device = unsafe { mlx_sys::mlx_device_new_type(device_type.into(), index) };
34        Device { c_device }
35    }
36
37    /// Try to get the default device.
38    pub fn try_default() -> Result<Self> {
39        Device::try_from_op(|res| unsafe { mlx_sys::mlx_get_default_device(res) })
40    }
41
42    /// Create a default CPU device.
43    pub fn cpu() -> Device {
44        Device::new(DeviceType::Cpu, 0)
45    }
46
47    /// Create a default GPU device.
48    pub fn gpu() -> Device {
49        Device::new(DeviceType::Gpu, 0)
50    }
51
52    /// Get the device index
53    pub fn get_index(&self) -> Result<i32> {
54        i32::try_from_op(|res| unsafe { mlx_sys::mlx_device_get_index(res, self.c_device) })
55    }
56
57    /// Get the device type
58    pub fn get_type(&self) -> Result<DeviceType> {
59        DeviceType::try_from_op(|res| unsafe { mlx_sys::mlx_device_get_type(res, self.c_device) })
60    }
61
62    /// Set the default device.
63    ///
64    /// # Example:
65    ///
66    /// ```rust
67    /// use mlx_rs::{Device, DeviceType};
68    /// Device::set_default(&Device::new(DeviceType::Cpu, 1));
69    /// ```
70    ///
71    /// By default, this is `gpu()`.
72    pub fn set_default(device: &Device) {
73        unsafe { mlx_sys::mlx_set_default_device(device.c_device) };
74    }
75
76    fn describe(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
77        unsafe {
78            let mut mlx_str = mlx_sys::mlx_string_new();
79            let result = match mlx_sys::mlx_device_tostring(&mut mlx_str as *mut _, self.c_device) {
80                SUCCESS => {
81                    let ptr = mlx_sys::mlx_string_data(mlx_str);
82                    let c_str = CStr::from_ptr(ptr);
83                    write!(f, "{}", c_str.to_string_lossy())
84                }
85                _ => Err(std::fmt::Error),
86            };
87            mlx_sys::mlx_string_free(mlx_str);
88            result
89        }
90    }
91}
92
93impl Drop for Device {
94    fn drop(&mut self) {
95        let status = unsafe { mlx_sys::mlx_device_free(self.c_device) };
96        debug_assert_eq!(status, SUCCESS);
97    }
98}
99
100impl Default for Device {
101    fn default() -> Self {
102        Self::try_default().unwrap()
103    }
104}
105
106impl std::fmt::Debug for Device {
107    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
108        self.describe(f)
109    }
110}
111
112impl std::fmt::Display for Device {
113    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
114        self.describe(f)
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fmt() {
        let device = Device::default();
        let description = format!("{}", device);
        assert_eq!(description, "Device(gpu, 0)");
    }

    #[test]
    fn test_fmt_cpu() {
        let device = Device::cpu();
        assert_eq!(format!("{}", device), "Device(cpu, 0)");
        // Debug shares the same MLX-provided description as Display.
        assert_eq!(format!("{:?}", device), "Device(cpu, 0)");
    }

    #[test]
    fn test_equality() {
        assert_eq!(Device::cpu(), Device::cpu());
        assert_ne!(Device::cpu(), Device::gpu());
    }

    #[test]
    fn test_index_and_type() {
        let device = Device::new(DeviceType::Gpu, 0);
        assert_eq!(device.get_index().unwrap(), 0);
        // `matches!` avoids relying on `DeviceType: PartialEq`.
        assert!(matches!(device.get_type().unwrap(), DeviceType::Gpu));
    }
}