1use std::{cell::RefCell, ffi::CStr};
2
3use crate::{
4 device::Device,
5 error::Result,
6 utils::{guard::Guarded, SUCCESS},
7};
8
9thread_local! {
10 static TASK_LOCAL_DEFAULT_STREAM: RefCell<Option<Stream>> = const { RefCell::new(None) };
11}
12
13pub fn task_local_default_stream() -> Option<Stream> {
18 TASK_LOCAL_DEFAULT_STREAM.with_borrow(|s| s.clone())
19}
20
21pub fn with_new_default_stream<F, T>(default_stream: Stream, f: F) -> T
23where
24 F: FnOnce() -> T,
25{
26 let prev_stream = TASK_LOCAL_DEFAULT_STREAM.with_borrow_mut(|s| s.replace(default_stream));
27
28 let result = f();
29
30 TASK_LOCAL_DEFAULT_STREAM.with_borrow_mut(|s| {
31 *s = prev_stream;
32 });
33
34 result
35}
36
37#[derive(PartialEq)]
44pub struct StreamOrDevice {
45 pub(crate) stream: Stream,
46}
47
48impl StreamOrDevice {
49 pub fn new(stream: Stream) -> StreamOrDevice {
51 StreamOrDevice { stream }
52 }
53
54 pub fn new_with_device(device: &Device) -> StreamOrDevice {
56 StreamOrDevice {
57 stream: Stream::new_with_device(device),
58 }
59 }
60
61 pub fn cpu() -> StreamOrDevice {
63 StreamOrDevice {
64 stream: Stream::cpu(),
65 }
66 }
67
68 pub fn gpu() -> StreamOrDevice {
70 StreamOrDevice {
71 stream: Stream::gpu(),
72 }
73 }
74}
75
76impl Default for StreamOrDevice {
77 fn default() -> Self {
82 Self {
83 stream: Stream::new(),
84 }
85 }
86}
87
88impl AsRef<Stream> for StreamOrDevice {
89 fn as_ref(&self) -> &Stream {
90 &self.stream
91 }
92}
93
94impl std::fmt::Debug for StreamOrDevice {
95 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
96 write!(f, "{}", self.stream)
97 }
98}
99
100impl std::fmt::Display for StreamOrDevice {
101 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
102 write!(f, "{}", self.stream)
103 }
104}
105
106pub struct Stream {
110 pub(crate) c_stream: mlx_sys::mlx_stream,
111}
112
113impl AsRef<Stream> for Stream {
114 fn as_ref(&self) -> &Stream {
115 self
116 }
117}
118
119impl Clone for Stream {
120 fn clone(&self) -> Self {
121 Stream::try_from_op(|res| unsafe { mlx_sys::mlx_stream_set(res, self.c_stream) })
122 .expect("Failed to clone stream")
123 }
124}
125
126impl Stream {
127 pub fn task_local_or_default() -> Self {
130 task_local_default_stream().unwrap_or_default()
131 }
132
133 pub fn task_local_or_cpu() -> Self {
136 task_local_default_stream().unwrap_or_else(Stream::cpu)
137 }
138
139 pub fn task_local_or_gpu() -> Self {
142 task_local_default_stream().unwrap_or_else(Stream::gpu)
143 }
144
145 pub fn new() -> Stream {
147 unsafe {
148 let mut dev = mlx_sys::mlx_device_new();
149 mlx_sys::mlx_get_default_device(&mut dev as *mut _);
151
152 let mut c_stream = mlx_sys::mlx_stream_new();
153 mlx_sys::mlx_get_default_stream(&mut c_stream as *mut _, dev);
155
156 mlx_sys::mlx_device_free(dev);
157 Stream { c_stream }
158 }
159 }
160
161 pub fn try_default_on_device(device: &Device) -> Result<Stream> {
163 Stream::try_from_op(|res| unsafe { mlx_sys::mlx_get_default_stream(res, device.c_device) })
164 }
165
166 pub fn new_with_device(device: &Device) -> Stream {
168 unsafe {
169 let c_stream = mlx_sys::mlx_stream_new_device(device.c_device);
170 Stream { c_stream }
171 }
172 }
173
174 pub fn as_ptr(&self) -> mlx_sys::mlx_stream {
176 self.c_stream
177 }
178
179 pub fn cpu() -> Self {
181 unsafe {
182 let c_stream = mlx_sys::mlx_default_cpu_stream_new();
183 Stream { c_stream }
184 }
185 }
186
187 pub fn gpu() -> Self {
189 unsafe {
190 let c_stream = mlx_sys::mlx_default_gpu_stream_new();
191 Stream { c_stream }
192 }
193 }
194
195 pub fn get_index(&self) -> Result<i32> {
197 i32::try_from_op(|res| unsafe { mlx_sys::mlx_stream_get_index(res, self.c_stream) })
198 }
199
200 fn describe(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
201 unsafe {
202 let mut mlx_str = mlx_sys::mlx_string_new();
203 let result = match mlx_sys::mlx_stream_tostring(&mut mlx_str as *mut _, self.c_stream) {
204 SUCCESS => {
205 let ptr = mlx_sys::mlx_string_data(mlx_str);
206 let c_str = CStr::from_ptr(ptr);
207 write!(f, "{}", c_str.to_string_lossy())
208 }
209 _ => Err(std::fmt::Error),
210 };
211 mlx_sys::mlx_string_free(mlx_str);
212 result
213 }
214 }
215}
216
217impl Drop for Stream {
218 fn drop(&mut self) {
219 unsafe { mlx_sys::mlx_stream_free(self.c_stream) };
220 }
221}
222
223impl Default for Stream {
224 fn default() -> Self {
225 Stream::new()
226 }
227}
228
229impl std::fmt::Debug for Stream {
230 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
231 self.describe(f)
232 }
233}
234
235impl std::fmt::Display for Stream {
236 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
237 self.describe(f)
238 }
239}
240
241impl PartialEq for Stream {
242 fn eq(&self, other: &Self) -> bool {
243 unsafe { mlx_sys::mlx_stream_equal(self.c_stream, other.c_stream) }
244 }
245}
246
247#[cfg(test)]
248mod tests {
249 use super::*;
250
251 #[test]
252 fn test_scoped_default_stream() {
253 let cpu_device = Device::cpu();
255 Device::set_default(&cpu_device);
256 let cpu_stream = Stream::default();
257
258 let task_default_stream = Stream::gpu();
259 with_new_default_stream(task_default_stream, || {
260 let task_local_stream_0 = Stream::task_local_or_default();
261 let task_local_stream_1 = Stream::task_local_or_default();
262 assert_eq!(task_local_stream_0, task_local_stream_1);
263 assert_ne!(task_local_stream_0, cpu_stream);
264 });
265 }
266
267 #[test]
268 fn test_stream_clone() {
269 let stream = Stream::new();
270 let cloned_stream = stream.clone();
271 assert_eq!(stream, cloned_stream);
272 }
273
274 #[test]
275 fn test_cpu_gpu_stream_not_equal() {
276 let cpu_device = Device::cpu();
277 let gpu_device = Device::gpu();
278
279 Device::set_default(&cpu_device);
281 let cpu_stream = Stream::default();
282
283 Device::set_default(&gpu_device);
285 let gpu_stream = Stream::default();
286
287 assert_ne!(cpu_stream, gpu_stream);
289 }
290}