mlx_rs/
stream.rs

1use std::ffi::CStr;
2
3use crate::{
4    device::Device,
5    error::Result,
6    utils::{guard::Guarded, SUCCESS},
7};
8
9/// Parameter type for all MLX operations.
10///
11/// Use this to control where operations are evaluated:
12///
13/// If omitted it will use the [Default::default()], which will be [Device::gpu()] unless
14/// set otherwise.
15#[derive(PartialEq)]
16pub struct StreamOrDevice {
17    pub(crate) stream: Stream,
18}
19
20impl StreamOrDevice {
21    /// Create a new [`StreamOrDevice`] with a [`Stream`].
22    pub fn new(stream: Stream) -> StreamOrDevice {
23        StreamOrDevice { stream }
24    }
25
26    /// Create a new [`StreamOrDevice`] with a [`Device`].
27    pub fn new_with_device(device: &Device) -> StreamOrDevice {
28        StreamOrDevice {
29            stream: Stream::new_with_device(device),
30        }
31    }
32
33    /// Current default CPU stream.
34    pub fn cpu() -> StreamOrDevice {
35        StreamOrDevice {
36            stream: Stream::cpu(),
37        }
38    }
39
40    /// Current default GPU stream.
41    pub fn gpu() -> StreamOrDevice {
42        StreamOrDevice {
43            stream: Stream::gpu(),
44        }
45    }
46}
47
48impl Default for StreamOrDevice {
49    /// The default stream on the default device.
50    ///
51    /// This will be [Device::gpu()] unless [Device::set_default()]
52    /// sets it otherwise.
53    fn default() -> Self {
54        Self {
55            stream: Stream::new(),
56        }
57    }
58}
59
60impl AsRef<Stream> for StreamOrDevice {
61    fn as_ref(&self) -> &Stream {
62        &self.stream
63    }
64}
65
66impl std::fmt::Debug for StreamOrDevice {
67    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
68        write!(f, "{}", self.stream)
69    }
70}
71
72impl std::fmt::Display for StreamOrDevice {
73    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
74        write!(f, "{}", self.stream)
75    }
76}
77
78/// A stream of evaluation attached to a particular device.
79///
80/// Typically, this is used via the `stream:` parameter on a method with a [StreamOrDevice]:
81pub struct Stream {
82    pub(crate) c_stream: mlx_sys::mlx_stream,
83}
84
85impl AsRef<Stream> for Stream {
86    fn as_ref(&self) -> &Stream {
87        self
88    }
89}
90
91impl Stream {
92    /// Create a new stream on the default device. Panics if fails.
93    pub fn new() -> Stream {
94        unsafe {
95            let mut dev = mlx_sys::mlx_device_new();
96            // SAFETY: mlx_get_default_device internally never throws an error
97            mlx_sys::mlx_get_default_device(&mut dev as *mut _);
98
99            let mut c_stream = mlx_sys::mlx_stream_new();
100            // SAFETY: mlx_get_default_stream internally never throws if dev is valid
101            mlx_sys::mlx_get_default_stream(&mut c_stream as *mut _, dev);
102
103            mlx_sys::mlx_device_free(dev);
104            Stream { c_stream }
105        }
106    }
107
108    /// Try to get the default stream on the given device.
109    pub fn try_default_on_device(device: &Device) -> Result<Stream> {
110        Stream::try_from_op(|res| unsafe { mlx_sys::mlx_get_default_stream(res, device.c_device) })
111    }
112
113    /// Create a new stream on the given device
114    pub fn new_with_device(device: &Device) -> Stream {
115        unsafe {
116            let c_stream = mlx_sys::mlx_stream_new_device(device.c_device);
117            Stream { c_stream }
118        }
119    }
120
121    /// Get the underlying C pointer.
122    pub fn as_ptr(&self) -> mlx_sys::mlx_stream {
123        self.c_stream
124    }
125
126    /// Current default CPU stream.
127    pub fn cpu() -> Self {
128        unsafe {
129            let c_stream = mlx_sys::mlx_default_cpu_stream_new();
130            Stream { c_stream }
131        }
132    }
133
134    /// Current default GPU stream.
135    pub fn gpu() -> Self {
136        unsafe {
137            let c_stream = mlx_sys::mlx_default_gpu_stream_new();
138            Stream { c_stream }
139        }
140    }
141
142    /// Get the index of the stream.
143    pub fn get_index(&self) -> Result<i32> {
144        i32::try_from_op(|res| unsafe { mlx_sys::mlx_stream_get_index(res, self.c_stream) })
145    }
146
147    fn describe(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
148        unsafe {
149            let mut mlx_str = mlx_sys::mlx_string_new();
150            let result = match mlx_sys::mlx_stream_tostring(&mut mlx_str as *mut _, self.c_stream) {
151                SUCCESS => {
152                    let ptr = mlx_sys::mlx_string_data(mlx_str);
153                    let c_str = CStr::from_ptr(ptr);
154                    write!(f, "{}", c_str.to_string_lossy())
155                }
156                _ => Err(std::fmt::Error),
157            };
158            mlx_sys::mlx_string_free(mlx_str);
159            result
160        }
161    }
162}
163
164impl Drop for Stream {
165    fn drop(&mut self) {
166        unsafe { mlx_sys::mlx_stream_free(self.c_stream) };
167    }
168}
169
170impl Default for Stream {
171    fn default() -> Self {
172        Stream::new()
173    }
174}
175
176impl std::fmt::Debug for Stream {
177    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
178        self.describe(f)
179    }
180}
181
182impl std::fmt::Display for Stream {
183    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
184        self.describe(f)
185    }
186}
187
188impl PartialEq for Stream {
189    fn eq(&self, other: &Self) -> bool {
190        unsafe { mlx_sys::mlx_stream_equal(self.c_stream, other.c_stream) }
191    }
192}