mlx_internal_macros/
lib.rs

1extern crate proc_macro;
2use darling::FromMeta;
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemEnum, ItemFn, Pat};
7
8mod derive_buildable;
9mod derive_builder;
10mod generate_builder;
11mod generate_macro;
12mod shared;
13
14#[derive(Debug, FromMeta)]
15enum DeviceType {
16    Cpu,
17    Gpu,
18}
19
20#[derive(Debug)]
21struct DefaultDeviceInput {
22    device: DeviceType,
23}
24
25impl FromMeta for DefaultDeviceInput {
26    fn from_meta(meta: &syn::Meta) -> darling::Result<Self> {
27        let syn::Meta::NameValue(meta_name_value) = meta else {
28            return Err(darling::Error::unsupported_format(
29                "expected a name-value attribute",
30            ));
31        };
32
33        let ident = meta_name_value.path.get_ident().unwrap();
34        assert_eq!(ident, "device", "expected `device`");
35
36        let device = DeviceType::from_expr(&meta_name_value.value)?;
37
38        Ok(DefaultDeviceInput { device })
39    }
40}
41
42#[doc(hidden)]
43#[proc_macro_attribute]
44pub fn default_device(attr: TokenStream, item: TokenStream) -> TokenStream {
45    let input = if !attr.is_empty() {
46        let meta = syn::parse_macro_input!(attr as syn::Meta);
47        Some(DefaultDeviceInput::from_meta(&meta).unwrap())
48    } else {
49        None
50    };
51
52    let mut input_fn = parse_macro_input!(item as ItemFn);
53    let original_fn = input_fn.clone();
54
55    // Ensure function name convention
56    if !input_fn.sig.ident.to_string().contains("_device") {
57        panic!("Function name must end with '_device'");
58    }
59    let new_fn_name = format_ident!("{}", &input_fn.sig.ident.to_string().replace("_device", ""));
60    input_fn.sig.ident = new_fn_name;
61
62    // Filter out the `stream` parameter and reconstruct the Punctuated collection
63    let filtered_inputs = input_fn
64        .sig
65        .inputs
66        .iter()
67        .filter(|arg| match arg {
68            FnArg::Typed(pat_typed) => {
69                if let Pat::Ident(pat_ident) = &*pat_typed.pat {
70                    pat_ident.ident != "stream"
71                } else {
72                    true
73                }
74            }
75            _ => true,
76        })
77        .cloned()
78        .collect::<Vec<_>>();
79
80    input_fn.sig.inputs = Punctuated::from_iter(filtered_inputs);
81
82    // Prepend default stream initialization
83    let default_stream_stmt = match input.map(|input| input.device) {
84        Some(DeviceType::Cpu) => parse_quote! {
85            let stream = StreamOrDevice::cpu();
86        },
87        Some(DeviceType::Gpu) => parse_quote! {
88            let stream = StreamOrDevice::gpu();
89        },
90        None => parse_quote! {
91            let stream = StreamOrDevice::default();
92        },
93    };
94    input_fn.block.stmts.insert(0, default_stream_stmt);
95
96    // Combine the original and modified functions into the output
97    let expanded = quote! {
98        #original_fn
99
100        #input_fn
101    };
102
103    TokenStream::from(expanded)
104}
105
106#[doc(hidden)]
107#[proc_macro]
108pub fn generate_test_cases(input: TokenStream) -> TokenStream {
109    let input = parse_macro_input!(input as ItemEnum);
110    let name = &input.ident;
111
112    let tests = quote! {
113        /// MLX's rules for promoting two dtypes.
114        #[rustfmt::skip]
115        const TYPE_RULES: [[Dtype; 14]; 14] = [
116            // bool             uint8               uint16              uint32              uint64              int8                int16               int32               int64               float16             float32             float64,           bfloat16            complex64
117            [Dtype::Bool,       Dtype::Uint8,       Dtype::Uint16,      Dtype::Uint32,      Dtype::Uint64,      Dtype::Int8,        Dtype::Int16,       Dtype::Int32,       Dtype::Int64,       Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // bool
118            [Dtype::Uint8,      Dtype::Uint8,       Dtype::Uint16,      Dtype::Uint32,      Dtype::Uint64,      Dtype::Int16,       Dtype::Int16,       Dtype::Int32,       Dtype::Int64,       Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // uint8
119            [Dtype::Uint16,     Dtype::Uint16,      Dtype::Uint16,      Dtype::Uint32,      Dtype::Uint64,      Dtype::Int32,       Dtype::Int32,       Dtype::Int32,       Dtype::Int64,       Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // uint16
120            [Dtype::Uint32,     Dtype::Uint32,      Dtype::Uint32,      Dtype::Uint32,      Dtype::Uint64,      Dtype::Int64,       Dtype::Int64,       Dtype::Int64,       Dtype::Int64,       Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // uint32
121            [Dtype::Uint64,     Dtype::Uint64,      Dtype::Uint64,      Dtype::Uint64,      Dtype::Uint64,      Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // uint64
122            [Dtype::Int8,       Dtype::Int16,       Dtype::Int32,       Dtype::Int64,       Dtype::Float32,     Dtype::Int8,        Dtype::Int16,       Dtype::Int32,       Dtype::Int64,       Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // int8
123            [Dtype::Int16,      Dtype::Int16,       Dtype::Int32,       Dtype::Int64,       Dtype::Float32,     Dtype::Int16,       Dtype::Int16,       Dtype::Int32,       Dtype::Int64,       Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // int16
124            [Dtype::Int32,      Dtype::Int32,       Dtype::Int32,       Dtype::Int64,       Dtype::Float32,     Dtype::Int32,       Dtype::Int32,       Dtype::Int32,       Dtype::Int64,       Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // int32
125            [Dtype::Int64,      Dtype::Int64,       Dtype::Int64,       Dtype::Int64,       Dtype::Float32,     Dtype::Int64,       Dtype::Int64,       Dtype::Int64,       Dtype::Int64,       Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // int64
126            [Dtype::Float16,    Dtype::Float16,     Dtype::Float16,     Dtype::Float16,     Dtype::Float16,     Dtype::Float16,     Dtype::Float16,     Dtype::Float16,     Dtype::Float16,     Dtype::Float16,     Dtype::Float32,     Dtype::Float64,    Dtype::Float32,     Dtype::Complex64], // float16
127            [Dtype::Float32,    Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float32,     Dtype::Float64,    Dtype::Float32,     Dtype::Complex64], // float32
128            [Dtype::Float64,    Dtype::Float64,     Dtype::Float64,     Dtype::Float64,     Dtype::Float64,     Dtype::Float64,     Dtype::Float64,     Dtype::Float64,     Dtype::Float64,     Dtype::Float64,     Dtype::Float64,     Dtype::Float64,    Dtype::Float64,     Dtype::Complex64], // Dtype::Float64
129            [Dtype::Bfloat16,   Dtype::Bfloat16,    Dtype::Bfloat16,    Dtype::Bfloat16,    Dtype::Bfloat16,    Dtype::Bfloat16,    Dtype::Bfloat16,    Dtype::Bfloat16,    Dtype::Bfloat16,    Dtype::Float32,     Dtype::Float32,     Dtype::Float64,    Dtype::Bfloat16,    Dtype::Complex64], // bfloat16
130            [Dtype::Complex64,  Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,   Dtype::Complex64,  Dtype::Complex64,   Dtype::Complex64], // complex64
131        ];
132
133        #[cfg(test)]
134        mod generated_tests {
135            use super::*;
136            use strum::IntoEnumIterator;
137            use pretty_assertions::assert_eq;
138
139            #[test]
140            fn test_all_combinations() {
141                for a in #name::iter() {
142                    for b in #name::iter() {
143                        let result = a.promote_with(b);
144                        let expected = TYPE_RULES[a as usize][b as usize];
145                        assert_eq!(result, expected, "{}", format!("Failed promotion test for {:?} and {:?}", a, b));
146                    }
147                }
148            }
149        }
150    };
151
152    TokenStream::from(quote! {
153        #input
154        #tests
155    })
156}
157
158/// Generates a builder struct for the given struct.
159///
160/// This macro should be used in conjunction with the `#[derive(Buildable)]` derive macro.
161/// See the [`Buildable`] macro for more information.
162#[doc(hidden)]
163#[proc_macro]
164pub fn generate_builder(input: TokenStream) -> TokenStream {
165    // let input = parse_macro_input!(input as ItemStruct);
166    let input = parse_macro_input!(input as DeriveInput);
167    let builder = generate_builder::expand_generate_builder(&input).unwrap();
168    quote::quote! {
169        #input
170        #builder
171    }
172    .into()
173}
174
175/// Derive `mlx_rs::builder::Buildable` for a struct. When used with the `generate_builder` macro,
176/// a builder struct `<Struct>Builder` will be generated.
177///
178/// # Attributes
179///
180/// ## `#[buildable]`
181///
182/// ### Arguments
183///
184/// - `builder`: Path to the builder struct. Default to `<Struct>Builder` if not provided.
185/// - `root`: Path to the root module. Default to `::mlx_rs` if not provided.
186///
187/// ## `#[builder]`
188///
189/// **Note**: This attribute has no effect if NOT used with the `generate_builder` macro.
190///
191/// ### Arguments when applied on struct
192///
193/// - `build_with`: Function ident to build the struct.
194/// - `root`: Path to the root module. Default to `::mlx_rs` if not provided.
195/// - `err`: Type of error to return when build fails. Default to `std::convert::Infallible`
196///   if not provided.
197/// - `default_infallible`: Whether the default error type is infallible. Default to `err.is_none()`
198///   if not provided. When `true`, the generated `<Struct>::new()` method will unwrap the build result
199///   and return `<Struct>`. When `false`, the generated `<Struct>::new()` method will return `Result<<Struct>, Err>`.
200///
201/// ### Arguments when applied on field
202///
203/// - `optional`: Whether the field is optional. Default to `false` if not provided.
204/// - `default`: Path to the default value for the field. This is required if the field is optional.
205/// - `rename`: Rename the field in the builder struct.
206/// - `ignore`: Ignore the field in the builder struct.
207/// - `ty_override`: Override the type of the field in the builder struct.
208/// - `skip_setter`: Skip the setter method for the field in the builder struct.
209///
210/// # Example
211///
212/// ```rust,ignore
213/// use mlx_internal_macros::*;
214/// use mlx_rs::builder::{Buildable, Builder};
215///
216/// generate_builder! {
217///     /// Test struct for the builder generation.
218///     #[derive(Debug, Buildable)]
219///     #[builder(build_with = build_test_struct)]
220///     struct TestStruct {
221///         #[builder(optional, default = TestStruct::DEFAULT_OPT_FIELD_1)]
222///         opt_field_1: i32,
223///         #[builder(optional, default = TestStruct::DEFAULT_OPT_FIELD_2)]
224///         opt_field_2: i32,
225///         mandatory_field_1: i32,
226///
227///         #[builder(ignore)]
228///         ignored_field: String,
229///     }
230/// }
231///
232/// fn build_test_struct(
233///     builder: TestStructBuilder,
234/// ) -> std::result::Result<TestStruct, std::convert::Infallible> {
235///     Ok(TestStruct {
236///         opt_field_1: builder.opt_field_1,
237///         opt_field_2: builder.opt_field_2,
238///         mandatory_field_1: builder.mandatory_field_1,
239///         ignored_field: String::from("ignored"),
240///     })
241/// }
242///
243/// impl TestStruct {
244///     pub const DEFAULT_OPT_FIELD_1: i32 = 1;
245///     pub const DEFAULT_OPT_FIELD_2: i32 = 2;
246/// }
247///
248/// #[test]
249/// fn test_generated_builder() {
250///     let test_struct = <TestStruct as Buildable>::Builder::new(4)
251///         .opt_field_1(2)
252///         .opt_field_2(3)
253///         .build()
254///         .unwrap();
255///
256///     assert_eq!(test_struct.opt_field_1, 2);
257///     assert_eq!(test_struct.opt_field_2, 3);
258///     assert_eq!(test_struct.mandatory_field_1, 4);
259///     assert_eq!(test_struct.ignored_field, String::from("ignored"));
260/// }
261/// ```
262#[doc(hidden)]
263#[proc_macro_derive(Buildable, attributes(buildable, builder))]
264pub fn derive_buildable(input: TokenStream) -> TokenStream {
265    let input = parse_macro_input!(input as DeriveInput);
266    let builder = derive_buildable::expand_derive_buildable(input).unwrap();
267    TokenStream::from(builder)
268}
269
270/// Derive `mlx_rs::builder::Builder` trait for a struct and generate the following methods:
271///
272/// - `<Struct>Builder::new(mandatory_fields)`: Create a new builder with the mandatory fields.
273/// - setter methods for each optinal field
274/// - `<Struct>::new(mandatory_fields)`: Create the struct from the builder with the mandatory fields.
275///
276/// # Attributes
277///
278/// ## `#[builder]`
279///
280/// ### Arguments when applied on struct
281///
282/// - `build_with`: Function ident to build the struct.
283/// - `root`: Path to the root module. Default to `::mlx_rs` if not provided.
284/// - `err`: Type of error to return when build fails. Default to `std::convert::Infallible`
285///   if not provided.
286/// - `default_infallible`: Whether the default error type is infallible. Default to `err.is_none()`
287///   if not provided. When `true`, the generated `<Struct>::new()` method will unwrap the build result
288///   and return `<Struct>`. When `false`, the generated `<Struct>::new()` method will return `Result<<Struct>, Err>`.
289///
290/// ### Arguments when applied on field
291///
292/// - `optional`: Whether the field is optional. Default to `false` if not provided.
293/// - `default`: Path to the default value for the field. This is required if the field is optional.
294/// - `rename`: Rename the field in the builder struct.
295/// - `ignore`: Ignore the field in the builder struct.
296/// - `ty_override`: Override the type of the field in the builder struct.
297/// - `skip_setter`: Skip the setter method for the field in the builder struct.
298#[doc(hidden)]
299#[proc_macro_derive(Builder, attributes(builder))]
300pub fn derive_builder(input: TokenStream) -> TokenStream {
301    let input = parse_macro_input!(input as DeriveInput);
302    let builder = derive_builder::expand_derive_builder(input).unwrap();
303    TokenStream::from(builder)
304}
305
306/// Generate a macro that expands to the given function for ergonomic purposes.
307///
308/// See `mlx-rs/mlx-tests/test_generate_macro.rs` for more usage examples.
309///
310/// ```rust,ignore
311/// #![allow(unused_variables)]
312///
313/// use mlx_internal_macros::{default_device, generate_macro};
314/// use mlx_rs::{Stream, StreamOrDevice};
315///
316/// /// Test macro generation.
317/// #[generate_macro(customize(root = "$crate"))] // Default is `$crate::ops`
318/// #[default_device]
319/// fn foo_device(
320///     a: i32, // Mandatory argument
321///     b: i32, // Mandatory argument
322///     #[optional] c: Option<i32>, // Optional argument
323///     #[optional] d: impl Into<Option<i32>>, // Optional argument but impl Trait
324///     #[optional] stream: impl AsRef<Stream>, // stream always optional and placed at the end
325/// ) -> i32 {
326///     a + b + c.unwrap_or(0) + d.into().unwrap_or(0)
327/// }
328///
329/// assert_eq!(foo!(1, 2), 3);
330/// assert_eq!(foo!(1, 2, c = Some(3)), 6);
331/// assert_eq!(foo!(1, 2, d = Some(4)), 7);
332/// assert_eq!(foo!(1, 2, c = Some(3), d = Some(4)), 10);
333///
334/// let stream = Stream::new();
335///
336/// assert_eq!(foo!(1, 2, stream = &stream), 3);
337/// assert_eq!(foo!(1, 2, c = Some(3), stream = &stream), 6);
338/// assert_eq!(foo!(1, 2, d = Some(4), stream = &stream), 7);
339/// assert_eq!(foo!(1, 2, c = Some(3), d = Some(4), stream = &stream), 10);
340/// ```
341#[doc(hidden)]
342#[proc_macro_attribute]
343pub fn generate_macro(attr: TokenStream, item: TokenStream) -> TokenStream {
344    let attr = if !attr.is_empty() {
345        let meta = syn::parse_macro_input!(attr as syn::Meta);
346        Some(meta)
347    } else {
348        None
349    };
350    let item = parse_macro_input!(item as ItemFn);
351    generate_macro::expand_generate_macro(attr, item).unwrap()
352}