mlx_internal_macros/lib.rs
1extern crate proc_macro;
2use darling::FromMeta;
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemEnum, ItemFn, Pat};
7
8mod derive_buildable;
9mod derive_builder;
10mod generate_builder;
11mod generate_macro;
12mod shared;
13
14#[derive(Debug, FromMeta)]
15enum DeviceType {
16 Cpu,
17 Gpu,
18}
19
20#[derive(Debug)]
21struct DefaultDeviceInput {
22 device: DeviceType,
23}
24
25impl FromMeta for DefaultDeviceInput {
26 fn from_meta(meta: &syn::Meta) -> darling::Result<Self> {
27 let syn::Meta::NameValue(meta_name_value) = meta else {
28 return Err(darling::Error::unsupported_format(
29 "expected a name-value attribute",
30 ));
31 };
32
33 let ident = meta_name_value.path.get_ident().unwrap();
34 assert_eq!(ident, "device", "expected `device`");
35
36 let device = DeviceType::from_expr(&meta_name_value.value)?;
37
38 Ok(DefaultDeviceInput { device })
39 }
40}
41
42#[doc(hidden)]
43#[proc_macro_attribute]
44pub fn default_device(attr: TokenStream, item: TokenStream) -> TokenStream {
45 let input = if !attr.is_empty() {
46 let meta = syn::parse_macro_input!(attr as syn::Meta);
47 Some(DefaultDeviceInput::from_meta(&meta).unwrap())
48 } else {
49 None
50 };
51
52 let mut input_fn = parse_macro_input!(item as ItemFn);
53 let original_fn = input_fn.clone();
54
55 // Ensure function name convention
56 if !input_fn.sig.ident.to_string().contains("_device") {
57 panic!("Function name must end with '_device'");
58 }
59 let new_fn_name = format_ident!("{}", &input_fn.sig.ident.to_string().replace("_device", ""));
60 input_fn.sig.ident = new_fn_name;
61
62 // Filter out the `stream` parameter and reconstruct the Punctuated collection
63 let filtered_inputs = input_fn
64 .sig
65 .inputs
66 .iter()
67 .filter(|arg| match arg {
68 FnArg::Typed(pat_typed) => {
69 if let Pat::Ident(pat_ident) = &*pat_typed.pat {
70 pat_ident.ident != "stream"
71 } else {
72 true
73 }
74 }
75 _ => true,
76 })
77 .cloned()
78 .collect::<Vec<_>>();
79
80 input_fn.sig.inputs = Punctuated::from_iter(filtered_inputs);
81
82 // Prepend default stream initialization
83 let default_stream_stmt = match input.map(|input| input.device) {
84 Some(DeviceType::Cpu) => parse_quote! {
85 let stream = StreamOrDevice::cpu();
86 },
87 Some(DeviceType::Gpu) => parse_quote! {
88 let stream = StreamOrDevice::gpu();
89 },
90 None => parse_quote! {
91 let stream = StreamOrDevice::default();
92 },
93 };
94 input_fn.block.stmts.insert(0, default_stream_stmt);
95
96 // Combine the original and modified functions into the output
97 let expanded = quote! {
98 #original_fn
99
100 #input_fn
101 };
102
103 TokenStream::from(expanded)
104}
105
106#[doc(hidden)]
107#[proc_macro]
108pub fn generate_test_cases(input: TokenStream) -> TokenStream {
109 let input = parse_macro_input!(input as ItemEnum);
110 let name = &input.ident;
111
112 let tests = quote! {
113 /// MLX's rules for promoting two dtypes.
114 #[rustfmt::skip]
115 const TYPE_RULES: [[Dtype; 14]; 14] = [
116 // bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 float64, bfloat16 complex64
117 [Dtype::Bool, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // bool
118 [Dtype::Uint8, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // uint8
119 [Dtype::Uint16, Dtype::Uint16, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // uint16
120 [Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // uint32
121 [Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // uint64
122 [Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // int8
123 [Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // int16
124 [Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // int32
125 [Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float32, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // int64
126 [Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float32, Dtype::Float64, Dtype::Float32, Dtype::Complex64], // float16
127 [Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float64, Dtype::Float32, Dtype::Complex64], // float32
128 [Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Float64, Dtype::Complex64], // Dtype::Float64
129 [Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Float32, Dtype::Float32, Dtype::Float64, Dtype::Bfloat16, Dtype::Complex64], // bfloat16
130 [Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64], // complex64
131 ];
132
133 #[cfg(test)]
134 mod generated_tests {
135 use super::*;
136 use strum::IntoEnumIterator;
137 use pretty_assertions::assert_eq;
138
139 #[test]
140 fn test_all_combinations() {
141 for a in #name::iter() {
142 for b in #name::iter() {
143 let result = a.promote_with(b);
144 let expected = TYPE_RULES[a as usize][b as usize];
145 assert_eq!(result, expected, "{}", format!("Failed promotion test for {:?} and {:?}", a, b));
146 }
147 }
148 }
149 }
150 };
151
152 TokenStream::from(quote! {
153 #input
154 #tests
155 })
156}
157
158/// Generates a builder struct for the given struct.
159///
160/// This macro should be used in conjunction with the `#[derive(Buildable)]` derive macro.
161/// See the [`Buildable`] macro for more information.
162#[doc(hidden)]
163#[proc_macro]
164pub fn generate_builder(input: TokenStream) -> TokenStream {
165 // let input = parse_macro_input!(input as ItemStruct);
166 let input = parse_macro_input!(input as DeriveInput);
167 let builder = generate_builder::expand_generate_builder(&input).unwrap();
168 quote::quote! {
169 #input
170 #builder
171 }
172 .into()
173}
174
175/// Derive `mlx_rs::builder::Buildable` for a struct. When used with the `generate_builder` macro,
176/// a builder struct `<Struct>Builder` will be generated.
177///
178/// # Attributes
179///
180/// ## `#[buildable]`
181///
182/// ### Arguments
183///
184/// - `builder`: Path to the builder struct. Default to `<Struct>Builder` if not provided.
185/// - `root`: Path to the root module. Default to `::mlx_rs` if not provided.
186///
187/// ## `#[builder]`
188///
189/// **Note**: This attribute has no effect if NOT used with the `generate_builder` macro.
190///
191/// ### Arguments when applied on struct
192///
193/// - `build_with`: Function ident to build the struct.
194/// - `root`: Path to the root module. Default to `::mlx_rs` if not provided.
195/// - `err`: Type of error to return when build fails. Default to `std::convert::Infallible`
196/// if not provided.
197/// - `default_infallible`: Whether the default error type is infallible. Default to `err.is_none()`
198/// if not provided. When `true`, the generated `<Struct>::new()` method will unwrap the build result
199/// and return `<Struct>`. When `false`, the generated `<Struct>::new()` method will return `Result<<Struct>, Err>`.
200///
201/// ### Arguments when applied on field
202///
203/// - `optional`: Whether the field is optional. Default to `false` if not provided.
204/// - `default`: Path to the default value for the field. This is required if the field is optional.
205/// - `rename`: Rename the field in the builder struct.
206/// - `ignore`: Ignore the field in the builder struct.
207/// - `ty_override`: Override the type of the field in the builder struct.
208/// - `skip_setter`: Skip the setter method for the field in the builder struct.
209///
210/// # Example
211///
212/// ```rust,ignore
213/// use mlx_internal_macros::*;
214/// use mlx_rs::builder::{Buildable, Builder};
215///
216/// generate_builder! {
217/// /// Test struct for the builder generation.
218/// #[derive(Debug, Buildable)]
219/// #[builder(build_with = build_test_struct)]
220/// struct TestStruct {
221/// #[builder(optional, default = TestStruct::DEFAULT_OPT_FIELD_1)]
222/// opt_field_1: i32,
223/// #[builder(optional, default = TestStruct::DEFAULT_OPT_FIELD_2)]
224/// opt_field_2: i32,
225/// mandatory_field_1: i32,
226///
227/// #[builder(ignore)]
228/// ignored_field: String,
229/// }
230/// }
231///
232/// fn build_test_struct(
233/// builder: TestStructBuilder,
234/// ) -> std::result::Result<TestStruct, std::convert::Infallible> {
235/// Ok(TestStruct {
236/// opt_field_1: builder.opt_field_1,
237/// opt_field_2: builder.opt_field_2,
238/// mandatory_field_1: builder.mandatory_field_1,
239/// ignored_field: String::from("ignored"),
240/// })
241/// }
242///
243/// impl TestStruct {
244/// pub const DEFAULT_OPT_FIELD_1: i32 = 1;
245/// pub const DEFAULT_OPT_FIELD_2: i32 = 2;
246/// }
247///
248/// #[test]
249/// fn test_generated_builder() {
250/// let test_struct = <TestStruct as Buildable>::Builder::new(4)
251/// .opt_field_1(2)
252/// .opt_field_2(3)
253/// .build()
254/// .unwrap();
255///
256/// assert_eq!(test_struct.opt_field_1, 2);
257/// assert_eq!(test_struct.opt_field_2, 3);
258/// assert_eq!(test_struct.mandatory_field_1, 4);
259/// assert_eq!(test_struct.ignored_field, String::from("ignored"));
260/// }
261/// ```
262#[doc(hidden)]
263#[proc_macro_derive(Buildable, attributes(buildable, builder))]
264pub fn derive_buildable(input: TokenStream) -> TokenStream {
265 let input = parse_macro_input!(input as DeriveInput);
266 let builder = derive_buildable::expand_derive_buildable(input).unwrap();
267 TokenStream::from(builder)
268}
269
270/// Derive `mlx_rs::builder::Builder` trait for a struct and generate the following methods:
271///
272/// - `<Struct>Builder::new(mandatory_fields)`: Create a new builder with the mandatory fields.
273/// - setter methods for each optinal field
274/// - `<Struct>::new(mandatory_fields)`: Create the struct from the builder with the mandatory fields.
275///
276/// # Attributes
277///
278/// ## `#[builder]`
279///
280/// ### Arguments when applied on struct
281///
282/// - `build_with`: Function ident to build the struct.
283/// - `root`: Path to the root module. Default to `::mlx_rs` if not provided.
284/// - `err`: Type of error to return when build fails. Default to `std::convert::Infallible`
285/// if not provided.
286/// - `default_infallible`: Whether the default error type is infallible. Default to `err.is_none()`
287/// if not provided. When `true`, the generated `<Struct>::new()` method will unwrap the build result
288/// and return `<Struct>`. When `false`, the generated `<Struct>::new()` method will return `Result<<Struct>, Err>`.
289///
290/// ### Arguments when applied on field
291///
292/// - `optional`: Whether the field is optional. Default to `false` if not provided.
293/// - `default`: Path to the default value for the field. This is required if the field is optional.
294/// - `rename`: Rename the field in the builder struct.
295/// - `ignore`: Ignore the field in the builder struct.
296/// - `ty_override`: Override the type of the field in the builder struct.
297/// - `skip_setter`: Skip the setter method for the field in the builder struct.
298#[doc(hidden)]
299#[proc_macro_derive(Builder, attributes(builder))]
300pub fn derive_builder(input: TokenStream) -> TokenStream {
301 let input = parse_macro_input!(input as DeriveInput);
302 let builder = derive_builder::expand_derive_builder(input).unwrap();
303 TokenStream::from(builder)
304}
305
306/// Generate a macro that expands to the given function for ergonomic purposes.
307///
308/// See `mlx-rs/mlx-tests/test_generate_macro.rs` for more usage examples.
309///
310/// ```rust,ignore
311/// #![allow(unused_variables)]
312///
313/// use mlx_internal_macros::{default_device, generate_macro};
314/// use mlx_rs::{Stream, StreamOrDevice};
315///
316/// /// Test macro generation.
317/// #[generate_macro(customize(root = "$crate"))] // Default is `$crate::ops`
318/// #[default_device]
319/// fn foo_device(
320/// a: i32, // Mandatory argument
321/// b: i32, // Mandatory argument
322/// #[optional] c: Option<i32>, // Optional argument
323/// #[optional] d: impl Into<Option<i32>>, // Optional argument but impl Trait
324/// #[optional] stream: impl AsRef<Stream>, // stream always optional and placed at the end
325/// ) -> i32 {
326/// a + b + c.unwrap_or(0) + d.into().unwrap_or(0)
327/// }
328///
329/// assert_eq!(foo!(1, 2), 3);
330/// assert_eq!(foo!(1, 2, c = Some(3)), 6);
331/// assert_eq!(foo!(1, 2, d = Some(4)), 7);
332/// assert_eq!(foo!(1, 2, c = Some(3), d = Some(4)), 10);
333///
334/// let stream = Stream::new();
335///
336/// assert_eq!(foo!(1, 2, stream = &stream), 3);
337/// assert_eq!(foo!(1, 2, c = Some(3), stream = &stream), 6);
338/// assert_eq!(foo!(1, 2, d = Some(4), stream = &stream), 7);
339/// assert_eq!(foo!(1, 2, c = Some(3), d = Some(4), stream = &stream), 10);
340/// ```
341#[doc(hidden)]
342#[proc_macro_attribute]
343pub fn generate_macro(attr: TokenStream, item: TokenStream) -> TokenStream {
344 let attr = if !attr.is_empty() {
345 let meta = syn::parse_macro_input!(attr as syn::Meta);
346 Some(meta)
347 } else {
348 None
349 };
350 let item = parse_macro_input!(item as ItemFn);
351 generate_macro::expand_generate_macro(attr, item).unwrap()
352}