mlx_internal_macros/
generate_macro.rs

1use darling::FromMeta;
2use itertools::Itertools;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{FnArg, Ident, ItemFn, Meta};
6
/// Parameter attribute for a named argument that callers of the generated macro may omit
/// (an omitted argument is forwarded to the function as `None`).
const CUSTOM_ATTRIBUTE_OPTIONAL: &str = "optional";
/// Parameter attribute for an argument that must be passed to the generated macro as
/// `name=value`.
const CUSTOM_ATTRIBUTE_NAMED: &str = "named";

/// All helper attributes understood by this macro; they are stripped from the function's
/// parameters before the function is re-emitted.
const CUSTOM_ATTRIBUTES: &[&str] = &[CUSTOM_ATTRIBUTE_OPTIONAL, CUSTOM_ATTRIBUTE_NAMED];
11
12#[derive(Default, Debug, FromMeta)]
13#[darling(default)]
14struct Customize {
15    root: Option<syn::LitStr>,
16    default_dtype: Option<syn::Path>,
17}
18
19fn arg_type(attrs: &[syn::Attribute]) -> ArgType {
20    for attr in attrs {
21        if attr.path().is_ident(CUSTOM_ATTRIBUTE_OPTIONAL) {
22            return ArgType::NamedOptional;
23        } else if attr.path().is_ident(CUSTOM_ATTRIBUTE_NAMED) {
24            return ArgType::Named;
25        }
26    }
27    ArgType::Positional
28}
29
30fn remove_attribute(attrs: &mut Vec<syn::Attribute>, targets: &[&str]) {
31    attrs.retain(|attr| !targets.iter().any(|target| !attr.path().is_ident(target)));
32}
33
/// Returns `s` with all leading `$` characters removed, e.g. `"$crate::ops"` becomes
/// `"crate::ops"`. Characters after the first non-`$` are untouched.
fn remove_prefix_from_str(s: &str) -> String {
    String::from(s.trim_start_matches('$'))
}
38
39pub fn expand_generate_macro(
40    attr: Option<Meta>,
41    mut item: ItemFn, // The original function should be kept as is
42) -> Result<TokenStream, syn::Error> {
43    let customize = match attr {
44        Some(attr) => Customize::from_meta(&attr).map_err(|e| syn::Error::new_spanned(attr, e))?,
45        None => Customize::default(),
46    };
47
48    // The mod path where the function can be accessed publicly
49    let (fn_mod_path, doc_mod_path) = match customize.root {
50        Some(lit_str) => {
51            let tokens: proc_macro2::TokenStream = lit_str.parse()?;
52            let s = remove_prefix_from_str(&lit_str.value());
53            (quote! { #tokens }, s)
54        }
55        None => (quote! { $crate::ops }, "crate::ops".into()),
56    };
57
58    let (default_generics, dtype_generics) =
59        handle_generic_args(&item.sig.generics, &customize.default_dtype);
60
61    let args = item
62        .sig
63        .inputs
64        .iter_mut()
65        .map(|arg| match arg {
66            FnArg::Receiver(_) => Err(syn::Error::new_spanned(arg, "self is not allowed")),
67            FnArg::Typed(pat_type) => Ok(pat_type),
68        })
69        .collect::<Result<Vec<_>, _>>()?;
70
71    let mut parsed_args = parse_args(args);
72
73    // Check if the last optional argument is `stream`
74    if let Some(arg) = parsed_args.last() {
75        if arg.ident != "stream" {
76            return Err(syn::Error::new_spanned(
77                &item,
78                "the last optional argument must be `stream`",
79            ));
80        }
81    }
82    // Remove the last optional argument `stream`
83    parsed_args.pop();
84
85    // Remove "_device" suffix from the macro name if it exists
86    let fn_ident = &item.sig.ident;
87
88    let generated = generate_macro(
89        &fn_mod_path,
90        &doc_mod_path,
91        fn_ident,
92        &parsed_args,
93        &default_generics,
94        &dtype_generics,
95    )?;
96
97    let output = quote! {
98        #item
99        #generated
100    };
101
102    Ok(output.into())
103}
104
105/// If there are generic arguments, the last argument is assumed to be `dtype`.
106///
107/// Returns two `syn::Generics`:
108/// 1. With the last argument set to `f32`
109/// 2. With the last argument set to `$dtype`
110fn handle_generic_args(
111    generic_args: &syn::Generics,
112    default_dtype: &Option<syn::Path>,
113) -> (proc_macro2::TokenStream, Option<proc_macro2::TokenStream>) {
114    // Count number of generic type arguments
115    let count = generic_args
116        .params
117        .iter()
118        .filter(|param| matches!(param, syn::GenericParam::Type(_)))
119        .count();
120
121    if count == 0 {
122        return (quote! {}, None);
123    }
124
125    // All generics arguments except for the last one will be inferred
126    let infer_tokens = vec![quote! { _ }; count - 1];
127
128    let default_generics = match default_dtype {
129        Some(path) => quote! { ::<#(#infer_tokens,)* #path> },
130        None => quote! { ::<#(#infer_tokens,)* f32> },
131    };
132    let dtype_generics = quote! { ::<#(#infer_tokens,)* $dtype> };
133
134    (default_generics, Some(dtype_generics))
135}
136
/// How a function parameter is accepted by the generated macro.
#[derive(Clone, Copy, Debug)]
enum ArgType {
    /// No helper attribute: passed by position as a bare expression.
    Positional,
    /// `#[named]`: always passed as `name=value`.
    Named,
    /// `#[optional]`: passed as `name=value`, or omitted, in which case `None` is forwarded.
    NamedOptional,
}
143
144struct Arg {
145    ident: Ident,
146    arg_type: ArgType,
147}
148
149fn parse_args(args: Vec<&mut syn::PatType>) -> Vec<Arg> {
150    let mut is_prev_optional = false;
151    let mut parsed = Vec::new();
152    for arg in args {
153        match &*arg.pat {
154            syn::Pat::Ident(ident) => {
155                let arg_type = arg_type(&arg.attrs);
156
157                let is_positional = matches!(arg_type, ArgType::Positional);
158                if is_prev_optional && is_positional {
159                    panic!("positional argument cannot follow an optional argument");
160                }
161                is_prev_optional = matches!(arg_type, ArgType::NamedOptional);
162
163                parsed.push(Arg {
164                    ident: ident.ident.clone(),
165                    arg_type,
166                });
167            }
168            _ => panic!("unsupported pattern"),
169        }
170
171        remove_attribute(&mut arg.attrs, CUSTOM_ATTRIBUTES);
172    }
173    parsed
174}
175
176fn generate_macro(
177    fn_mod_path: &proc_macro2::TokenStream,
178    doc_mod_path: &str,
179    fn_ident: &Ident,
180    args: &[Arg],
181    default_generics: &proc_macro2::TokenStream,
182    dtype_generics: &Option<proc_macro2::TokenStream>,
183) -> Result<proc_macro2::TokenStream, syn::Error> {
184    let mut trimmed_fn_ident_str = fn_ident.to_string();
185    if trimmed_fn_ident_str.ends_with("_device") {
186        trimmed_fn_ident_str = trimmed_fn_ident_str.trim_end_matches("_device").to_string();
187    }
188    let trimmed_fn_ident = Ident::new(&trimmed_fn_ident_str, fn_ident.span());
189
190    let mut macro_variants = Vec::new();
191
192    generate_macro_variants(
193        fn_mod_path,
194        fn_ident,
195        &trimmed_fn_ident,
196        args,
197        default_generics,
198        dtype_generics,
199        &mut macro_variants,
200    );
201
202    let macro_docs = format!(
203        "Macro generated for the function [`{}::{}`]. See the function documentation for more details.",
204        doc_mod_path, trimmed_fn_ident
205    );
206
207    let generated = quote! {
208        #[doc = #macro_docs]
209        #[macro_export]
210        macro_rules! #trimmed_fn_ident {
211            #(
212                #macro_variants
213            )*
214        }
215    };
216
217    Ok(generated)
218}
219
220fn generate_macro_variants(
221    fn_mod_path: &proc_macro2::TokenStream,
222    fn_ident: &Ident,
223    trimmed_fn_ident: &Ident,
224    args: &[Arg],
225    default_generics: &proc_macro2::TokenStream,
226    dtype_generics: &Option<proc_macro2::TokenStream>,
227    macro_variants: &mut Vec<proc_macro2::TokenStream>,
228) {
229    let args_ident = args.iter().map(|arg| &arg.ident).collect::<Vec<_>>();
230    let args_type = args.iter().map(|arg| arg.arg_type).collect::<Vec<_>>();
231    let mut optional_indices = Vec::new();
232    let mut selected = Vec::with_capacity(args.len());
233    for (idx, arg) in args.iter().enumerate() {
234        match arg.arg_type {
235            ArgType::Positional => {
236                selected.push(true);
237            }
238            ArgType::Named => {
239                selected.push(true);
240            }
241            ArgType::NamedOptional => {
242                selected.push(false);
243                optional_indices.push(idx);
244            }
245        }
246    }
247
248    for perms in 0..optional_indices.len() + 1 {
249        // Select `perms` number of optional arguments
250        for selected_indice in optional_indices.iter().permutations(perms) {
251            selected_indice.iter().for_each(|&&i| selected[i] = true);
252
253            generate_macro_variants_for_selected_args(
254                fn_mod_path,
255                fn_ident,
256                trimmed_fn_ident,
257                &args_ident,
258                &args_type,
259                &selected,
260                default_generics,
261                dtype_generics,
262                macro_variants,
263            );
264
265            // Clear the selected flag for the next iteration
266            selected_indice.iter().for_each(|&&i| selected[i] = false);
267        }
268    }
269}
270
271#[allow(clippy::too_many_arguments)]
272fn generate_macro_variants_for_selected_args(
273    fn_mod_path: &proc_macro2::TokenStream,
274    fn_ident: &Ident,
275    trimmed_fn_ident: &Ident,
276    args_ident: &[&Ident],
277    args_type: &[ArgType],
278    selected: &[bool],
279    default_generics: &proc_macro2::TokenStream,
280    dtype_generics: &Option<proc_macro2::TokenStream>,
281    macro_variants: &mut Vec<proc_macro2::TokenStream>,
282) {
283    let macro_args: Vec<proc_macro2::TokenStream> = args_ident
284        .iter()
285        .zip(args_type.iter())
286        .zip(selected.iter())
287        .filter_map(|((ident, arg_type), &selected)| match selected {
288            true => {
289                let token = match arg_type {
290                    ArgType::Positional => quote! { $#ident:expr },
291                    ArgType::Named => quote! { #ident=$#ident:expr },
292                    ArgType::NamedOptional => quote! { #ident=$#ident:expr },
293                };
294                Some(token)
295            }
296            false => None,
297        })
298        .collect();
299
300    let input: Vec<proc_macro2::TokenStream> = args_ident
301        .iter()
302        .zip(selected.iter())
303        .map(|(ident, &selected)| {
304            if selected {
305                quote! { $#ident }
306            } else {
307                quote! { None }
308            }
309        })
310        .collect();
311
312    let variant_body = quote! {
313        (
314            #(#macro_args),*
315        ) => {
316            #fn_mod_path::#trimmed_fn_ident #default_generics(#(#input,)*)
317        };
318        (
319            #(#macro_args,)*
320            stream=$stream:expr
321        ) => {
322            #fn_mod_path::#fn_ident #default_generics(#(#input,)* $stream)
323        };
324    };
325
326    macro_variants.push(variant_body);
327
328    if let Some(dtype_generics) = &dtype_generics {
329        let variant_body = quote! {
330            (
331                #(#macro_args,)*
332                dtype=$dtype:ty
333            ) => {
334                #fn_mod_path::#trimmed_fn_ident #dtype_generics(#(#input,)*)
335            };
336            (
337                #(#macro_args,)*
338                dtype=$dtype:ty,
339                stream=$stream:expr
340            ) => {
341                #fn_mod_path::#fn_ident #dtype_generics(#(#input,)* $stream)
342            };
343        };
344
345        macro_variants.push(variant_body);
346    }
347}