mlx_rs/
stream.rs

1use std::{cell::RefCell, ffi::CStr};
2
3use crate::{
4    device::Device,
5    error::Result,
6    utils::{guard::Guarded, SUCCESS},
7};
8
9thread_local! {
10    static TASK_LOCAL_DEFAULT_STREAM: RefCell<Option<Stream>> = const { RefCell::new(None) };
11}
12
13/// Gets the task local default stream.
14///
15/// This is NOT intended to be used directly in most cases. Instead, use the
16/// `with_default_stream` function to temporarily set a default stream for a closure.
17pub fn task_local_default_stream() -> Option<Stream> {
18    TASK_LOCAL_DEFAULT_STREAM.with_borrow(|s| s.clone())
19}
20
21/// Use a given default stream for the duration of the closure `f`.
22pub fn with_new_default_stream<F, T>(default_stream: Stream, f: F) -> T
23where
24    F: FnOnce() -> T,
25{
26    let prev_stream = TASK_LOCAL_DEFAULT_STREAM.with_borrow_mut(|s| s.replace(default_stream));
27
28    let result = f();
29
30    TASK_LOCAL_DEFAULT_STREAM.with_borrow_mut(|s| {
31        *s = prev_stream;
32    });
33
34    result
35}
36
37/// Parameter type for all MLX operations.
38///
39/// Use this to control where operations are evaluated:
40///
41/// If omitted it will use the [Default::default()], which will be [Device::gpu()] unless
42/// set otherwise.
43#[derive(PartialEq)]
44pub struct StreamOrDevice {
45    pub(crate) stream: Stream,
46}
47
48impl StreamOrDevice {
49    /// Create a new [`StreamOrDevice`] with a [`Stream`].
50    pub fn new(stream: Stream) -> StreamOrDevice {
51        StreamOrDevice { stream }
52    }
53
54    /// Create a new [`StreamOrDevice`] with a [`Device`].
55    pub fn new_with_device(device: &Device) -> StreamOrDevice {
56        StreamOrDevice {
57            stream: Stream::new_with_device(device),
58        }
59    }
60
61    /// Current default CPU stream.
62    pub fn cpu() -> StreamOrDevice {
63        StreamOrDevice {
64            stream: Stream::cpu(),
65        }
66    }
67
68    /// Current default GPU stream.
69    pub fn gpu() -> StreamOrDevice {
70        StreamOrDevice {
71            stream: Stream::gpu(),
72        }
73    }
74}
75
76impl Default for StreamOrDevice {
77    /// The default stream on the default device.
78    ///
79    /// This will be [Device::gpu()] unless [Device::set_default()]
80    /// sets it otherwise.
81    fn default() -> Self {
82        Self {
83            stream: Stream::new(),
84        }
85    }
86}
87
88impl AsRef<Stream> for StreamOrDevice {
89    fn as_ref(&self) -> &Stream {
90        &self.stream
91    }
92}
93
94impl std::fmt::Debug for StreamOrDevice {
95    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
96        write!(f, "{}", self.stream)
97    }
98}
99
100impl std::fmt::Display for StreamOrDevice {
101    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
102        write!(f, "{}", self.stream)
103    }
104}
105
106/// A stream of evaluation attached to a particular device.
107///
108/// Typically, this is used via the `stream:` parameter on a method with a [StreamOrDevice]:
109pub struct Stream {
110    pub(crate) c_stream: mlx_sys::mlx_stream,
111}
112
113impl AsRef<Stream> for Stream {
114    fn as_ref(&self) -> &Stream {
115        self
116    }
117}
118
119impl Clone for Stream {
120    fn clone(&self) -> Self {
121        Stream::try_from_op(|res| unsafe { mlx_sys::mlx_stream_set(res, self.c_stream) })
122            .expect("Failed to clone stream")
123    }
124}
125
126impl Stream {
127    /// Create a new stream on the default device, or return the task local
128    /// default stream if present.
129    pub fn task_local_or_default() -> Self {
130        task_local_default_stream().unwrap_or_default()
131    }
132
133    /// Create a new stream on the default cpu device, or return the task local
134    /// default stream if present.
135    pub fn task_local_or_cpu() -> Self {
136        task_local_default_stream().unwrap_or_else(Stream::cpu)
137    }
138
139    /// Create a new stream on the default gpu device, or return the task local
140    /// default stream if present.
141    pub fn task_local_or_gpu() -> Self {
142        task_local_default_stream().unwrap_or_else(Stream::gpu)
143    }
144
145    /// Create a new stream on the default device. Panics if fails.
146    pub fn new() -> Stream {
147        unsafe {
148            let mut dev = mlx_sys::mlx_device_new();
149            // SAFETY: mlx_get_default_device internally never throws an error
150            mlx_sys::mlx_get_default_device(&mut dev as *mut _);
151
152            let mut c_stream = mlx_sys::mlx_stream_new();
153            // SAFETY: mlx_get_default_stream internally never throws if dev is valid
154            mlx_sys::mlx_get_default_stream(&mut c_stream as *mut _, dev);
155
156            mlx_sys::mlx_device_free(dev);
157            Stream { c_stream }
158        }
159    }
160
161    /// Try to get the default stream on the given device.
162    pub fn try_default_on_device(device: &Device) -> Result<Stream> {
163        Stream::try_from_op(|res| unsafe { mlx_sys::mlx_get_default_stream(res, device.c_device) })
164    }
165
166    /// Create a new stream on the given device
167    pub fn new_with_device(device: &Device) -> Stream {
168        unsafe {
169            let c_stream = mlx_sys::mlx_stream_new_device(device.c_device);
170            Stream { c_stream }
171        }
172    }
173
174    /// Get the underlying C pointer.
175    pub fn as_ptr(&self) -> mlx_sys::mlx_stream {
176        self.c_stream
177    }
178
179    /// Current default CPU stream.
180    pub fn cpu() -> Self {
181        unsafe {
182            let c_stream = mlx_sys::mlx_default_cpu_stream_new();
183            Stream { c_stream }
184        }
185    }
186
187    /// Current default GPU stream.
188    pub fn gpu() -> Self {
189        unsafe {
190            let c_stream = mlx_sys::mlx_default_gpu_stream_new();
191            Stream { c_stream }
192        }
193    }
194
195    /// Get the index of the stream.
196    pub fn get_index(&self) -> Result<i32> {
197        i32::try_from_op(|res| unsafe { mlx_sys::mlx_stream_get_index(res, self.c_stream) })
198    }
199
200    fn describe(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
201        unsafe {
202            let mut mlx_str = mlx_sys::mlx_string_new();
203            let result = match mlx_sys::mlx_stream_tostring(&mut mlx_str as *mut _, self.c_stream) {
204                SUCCESS => {
205                    let ptr = mlx_sys::mlx_string_data(mlx_str);
206                    let c_str = CStr::from_ptr(ptr);
207                    write!(f, "{}", c_str.to_string_lossy())
208                }
209                _ => Err(std::fmt::Error),
210            };
211            mlx_sys::mlx_string_free(mlx_str);
212            result
213        }
214    }
215}
216
217impl Drop for Stream {
218    fn drop(&mut self) {
219        unsafe { mlx_sys::mlx_stream_free(self.c_stream) };
220    }
221}
222
223impl Default for Stream {
224    fn default() -> Self {
225        Stream::new()
226    }
227}
228
229impl std::fmt::Debug for Stream {
230    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
231        self.describe(f)
232    }
233}
234
235impl std::fmt::Display for Stream {
236    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
237        self.describe(f)
238    }
239}
240
241impl PartialEq for Stream {
242    fn eq(&self, other: &Self) -> bool {
243        unsafe { mlx_sys::mlx_stream_equal(self.c_stream, other.c_stream) }
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scoped_default_stream() {
        // Take the CPU stream directly rather than changing the global default
        // device, which other tests running in parallel also modify.
        let cpu_stream = Stream::cpu();

        let task_default_stream = Stream::gpu();
        with_new_default_stream(task_default_stream, || {
            let task_local_stream_0 = Stream::task_local_or_default();
            let task_local_stream_1 = Stream::task_local_or_default();
            assert_eq!(task_local_stream_0, task_local_stream_1);
            assert_ne!(task_local_stream_0, cpu_stream);
        });

        // The override must not outlive the closure.
        assert!(task_local_default_stream().is_none());
    }

    #[test]
    fn test_nested_default_streams_restore_outer() {
        with_new_default_stream(Stream::gpu(), || {
            with_new_default_stream(Stream::cpu(), || {
                assert_eq!(Stream::task_local_or_gpu(), Stream::cpu());
            });
            // Leaving the inner scope reinstates the outer override.
            assert_eq!(Stream::task_local_or_cpu(), Stream::gpu());
        });
        assert!(task_local_default_stream().is_none());
    }

    #[test]
    fn test_stream_clone() {
        let stream = Stream::new();
        let cloned_stream = stream.clone();
        assert_eq!(stream, cloned_stream);
    }

    #[test]
    fn test_cpu_gpu_stream_not_equal() {
        let cpu_device = Device::cpu();
        let gpu_device = Device::gpu();

        // First set default stream to CPU
        Device::set_default(&cpu_device);
        let cpu_stream = Stream::default();

        // Then set default stream to GPU
        Device::set_default(&gpu_device);
        let gpu_stream = Stream::default();

        // Assert that CPU and GPU streams are not equal
        assert_ne!(cpu_stream, gpu_stream);
    }
}