1use std::ffi::CStr;
2
3use crate::{
4 device::Device,
5 error::Result,
6 utils::{guard::Guarded, SUCCESS},
7};
8
/// Wrapper holding the [`Stream`] that an operation should run on.
///
/// Despite the name, this always stores a [`Stream`]: the device-based
/// constructor (`new_with_device`) creates a new stream on the given
/// [`Device`]. Equality is the derived `PartialEq`, which compares the
/// wrapped streams using `Stream`'s `PartialEq`.
#[derive(PartialEq)]
pub struct StreamOrDevice {
    /// The wrapped stream; borrowed by callers through the `AsRef<Stream>` impl.
    pub(crate) stream: Stream,
}
26 pub fn new_with_device(device: &Device) -> StreamOrDevice {
28 StreamOrDevice {
29 stream: Stream::new_with_device(device),
30 }
31 }
32
33 pub fn cpu() -> StreamOrDevice {
35 StreamOrDevice {
36 stream: Stream::cpu(),
37 }
38 }
39
40 pub fn gpu() -> StreamOrDevice {
42 StreamOrDevice {
43 stream: Stream::gpu(),
44 }
45 }
46}
47
48impl Default for StreamOrDevice {
49 fn default() -> Self {
54 Self {
55 stream: Stream::new(),
56 }
57 }
58}
59
60impl AsRef<Stream> for StreamOrDevice {
61 fn as_ref(&self) -> &Stream {
62 &self.stream
63 }
64}
65
66impl std::fmt::Debug for StreamOrDevice {
67 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
68 write!(f, "{}", self.stream)
69 }
70}
71
72impl std::fmt::Display for StreamOrDevice {
73 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
74 write!(f, "{}", self.stream)
75 }
76}
77
78pub struct Stream {
82 pub(crate) c_stream: mlx_sys::mlx_stream,
83}
84
85impl AsRef<Stream> for Stream {
86 fn as_ref(&self) -> &Stream {
87 self
88 }
89}
90
91impl Stream {
92 pub fn new() -> Stream {
94 unsafe {
95 let mut dev = mlx_sys::mlx_device_new();
96 mlx_sys::mlx_get_default_device(&mut dev as *mut _);
98
99 let mut c_stream = mlx_sys::mlx_stream_new();
100 mlx_sys::mlx_get_default_stream(&mut c_stream as *mut _, dev);
102
103 mlx_sys::mlx_device_free(dev);
104 Stream { c_stream }
105 }
106 }
107
108 pub fn try_default_on_device(device: &Device) -> Result<Stream> {
110 Stream::try_from_op(|res| unsafe { mlx_sys::mlx_get_default_stream(res, device.c_device) })
111 }
112
113 pub fn new_with_device(device: &Device) -> Stream {
115 unsafe {
116 let c_stream = mlx_sys::mlx_stream_new_device(device.c_device);
117 Stream { c_stream }
118 }
119 }
120
121 pub fn as_ptr(&self) -> mlx_sys::mlx_stream {
123 self.c_stream
124 }
125
126 pub fn cpu() -> Self {
128 unsafe {
129 let c_stream = mlx_sys::mlx_default_cpu_stream_new();
130 Stream { c_stream }
131 }
132 }
133
134 pub fn gpu() -> Self {
136 unsafe {
137 let c_stream = mlx_sys::mlx_default_gpu_stream_new();
138 Stream { c_stream }
139 }
140 }
141
142 pub fn get_index(&self) -> Result<i32> {
144 i32::try_from_op(|res| unsafe { mlx_sys::mlx_stream_get_index(res, self.c_stream) })
145 }
146
147 fn describe(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
148 unsafe {
149 let mut mlx_str = mlx_sys::mlx_string_new();
150 let result = match mlx_sys::mlx_stream_tostring(&mut mlx_str as *mut _, self.c_stream) {
151 SUCCESS => {
152 let ptr = mlx_sys::mlx_string_data(mlx_str);
153 let c_str = CStr::from_ptr(ptr);
154 write!(f, "{}", c_str.to_string_lossy())
155 }
156 _ => Err(std::fmt::Error),
157 };
158 mlx_sys::mlx_string_free(mlx_str);
159 result
160 }
161 }
162}
163
164impl Drop for Stream {
165 fn drop(&mut self) {
166 unsafe { mlx_sys::mlx_stream_free(self.c_stream) };
167 }
168}
169
170impl Default for Stream {
171 fn default() -> Self {
172 Stream::new()
173 }
174}
175
176impl std::fmt::Debug for Stream {
177 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
178 self.describe(f)
179 }
180}
181
182impl std::fmt::Display for Stream {
183 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
184 self.describe(f)
185 }
186}
187
188impl PartialEq for Stream {
189 fn eq(&self, other: &Self) -> bool {
190 unsafe { mlx_sys::mlx_stream_equal(self.c_stream, other.c_stream) }
191 }
192}