mlx_internal_macros/
generate_macro.rs

1use darling::FromMeta;
2use itertools::Itertools;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{FnArg, Ident, ItemFn, Meta};
6
/// Name of the marker attribute that tags a function argument as a named,
/// optional macro argument.
const CUSTOM_ATTRIBUTE_OPTIONAL: &str = "optional";

/// Name of the marker attribute that tags a function argument as a named
/// (`name=value`) macro argument.
const CUSTOM_ATTRIBUTE_NAMED: &str = "named";

/// Every marker attribute name understood by this module.
const CUSTOM_ATTRIBUTES: &[&str] = &[
    CUSTOM_ATTRIBUTE_OPTIONAL,
    CUSTOM_ATTRIBUTE_NAMED,
];
11
12#[derive(Default, Debug, FromMeta)]
13#[darling(default)]
14struct Customize {
15    root: Option<syn::LitStr>,
16    default_dtype: Option<syn::Path>,
17}
18
19fn arg_type(attrs: &[syn::Attribute]) -> ArgType {
20    for attr in attrs {
21        if attr.path().is_ident(CUSTOM_ATTRIBUTE_OPTIONAL) {
22            return ArgType::NamedOptional;
23        } else if attr.path().is_ident(CUSTOM_ATTRIBUTE_NAMED) {
24            return ArgType::Named;
25        }
26    }
27    ArgType::Positional
28}
29
30fn remove_attribute(attrs: &mut Vec<syn::Attribute>, targets: &[&str]) {
31    attrs.retain(|attr| !targets.iter().any(|target| !attr.path().is_ident(target)));
32}
33
/// Returns a copy of `s` with every leading `$` character stripped.
///
/// `$` characters that are not at the start of the string are kept.
fn remove_prefix_from_str(s: &str) -> String {
    String::from(s.trim_start_matches('$'))
}
38
39pub fn expand_generate_macro(
40    attr: Option<Meta>,
41    mut item: ItemFn, // The original function should be kept as is
42) -> Result<TokenStream, syn::Error> {
43    let customize = match attr {
44        Some(attr) => Customize::from_meta(&attr).map_err(|e| syn::Error::new_spanned(attr, e))?,
45        None => Customize::default(),
46    };
47
48    // The mod path where the function can be accessed publicly
49    let (fn_mod_path, doc_mod_path) = match customize.root {
50        Some(lit_str) => {
51            let tokens: proc_macro2::TokenStream = lit_str.parse()?;
52            let s = remove_prefix_from_str(&lit_str.value());
53            (quote! { #tokens }, s)
54        }
55        None => (quote! { $crate::ops }, "crate::ops".into()),
56    };
57
58    let (default_generics, dtype_generics) =
59        handle_generic_args(&item.sig.generics, &customize.default_dtype);
60
61    let args = item
62        .sig
63        .inputs
64        .iter_mut()
65        .map(|arg| match arg {
66            FnArg::Receiver(_) => Err(syn::Error::new_spanned(arg, "self is not allowed")),
67            FnArg::Typed(pat_type) => Ok(pat_type),
68        })
69        .collect::<Result<Vec<_>, _>>()?;
70
71    let mut parsed_args = parse_args(args);
72
73    // Check if the last optional argument is `stream`
74    if let Some(arg) = parsed_args.last() {
75        if arg.ident != "stream" {
76            return Err(syn::Error::new_spanned(
77                &item,
78                "the last optional argument must be `stream`",
79            ));
80        }
81    }
82    // Remove the last optional argument `stream`
83    parsed_args.pop();
84
85    // Remove "_device" suffix from the macro name if it exists
86    let fn_ident = &item.sig.ident;
87
88    let generated = generate_macro(
89        &fn_mod_path,
90        &doc_mod_path,
91        fn_ident,
92        &parsed_args,
93        &default_generics,
94        &dtype_generics,
95    )?;
96
97    let output = quote! {
98        #item
99        #generated
100    };
101
102    Ok(output.into())
103}
104
105/// If there are generic arguments, the last argument is assumed to be `dtype`.
106///
107/// Returns two `syn::Generics`:
108/// 1. With the last argument set to `f32`
109/// 2. With the last argument set to `$dtype`
110fn handle_generic_args(
111    generic_args: &syn::Generics,
112    default_dtype: &Option<syn::Path>,
113) -> (proc_macro2::TokenStream, Option<proc_macro2::TokenStream>) {
114    // Count number of generic type arguments
115    let count = generic_args
116        .params
117        .iter()
118        .filter(|param| matches!(param, syn::GenericParam::Type(_)))
119        .count();
120
121    if count == 0 {
122        return (quote! {}, None);
123    }
124
125    // All generics arguments except for the last one will be inferred
126    let infer_tokens = vec![quote! { _ }; count - 1];
127
128    let default_generics = match default_dtype {
129        Some(path) => quote! { ::<#(#infer_tokens,)* #path> },
130        None => quote! { ::<#(#infer_tokens,)* f32> },
131    };
132    let dtype_generics = quote! { ::<#(#infer_tokens,)* $dtype> };
133
134    (default_generics, Some(dtype_generics))
135}
136
/// How a function argument is exposed in the generated macro.
#[derive(Clone, Copy, Debug)]
enum ArgType {
    /// Passed by position as a bare expression.
    Positional,
    /// Passed as `name=value`; present in every macro arm.
    Named,
    /// Passed as `name=value`; may be left out, in which case `None` is
    /// forwarded to the function.
    NamedOptional,
}
143
144struct Arg {
145    ident: Ident,
146    arg_type: ArgType,
147}
148
149fn parse_args(args: Vec<&mut syn::PatType>) -> Vec<Arg> {
150    let mut is_prev_optional = false;
151    let mut parsed = Vec::new();
152    for arg in args {
153        match &*arg.pat {
154            syn::Pat::Ident(ident) => {
155                let arg_type = arg_type(&arg.attrs);
156
157                let is_positional = matches!(arg_type, ArgType::Positional);
158                if is_prev_optional && is_positional {
159                    panic!("positional argument cannot follow an optional argument");
160                }
161                is_prev_optional = matches!(arg_type, ArgType::NamedOptional);
162
163                parsed.push(Arg {
164                    ident: ident.ident.clone(),
165                    arg_type,
166                });
167            }
168            _ => panic!("unsupported pattern"),
169        }
170
171        remove_attribute(&mut arg.attrs, CUSTOM_ATTRIBUTES);
172    }
173    parsed
174}
175
176fn generate_macro(
177    fn_mod_path: &proc_macro2::TokenStream,
178    doc_mod_path: &str,
179    fn_ident: &Ident,
180    args: &[Arg],
181    default_generics: &proc_macro2::TokenStream,
182    dtype_generics: &Option<proc_macro2::TokenStream>,
183) -> Result<proc_macro2::TokenStream, syn::Error> {
184    let mut trimmed_fn_ident_str = fn_ident.to_string();
185    if trimmed_fn_ident_str.ends_with("_device") {
186        trimmed_fn_ident_str = trimmed_fn_ident_str.trim_end_matches("_device").to_string();
187    }
188    let trimmed_fn_ident = Ident::new(&trimmed_fn_ident_str, fn_ident.span());
189
190    let mut macro_variants = Vec::new();
191
192    generate_macro_variants(
193        fn_mod_path,
194        fn_ident,
195        &trimmed_fn_ident,
196        args,
197        default_generics,
198        dtype_generics,
199        &mut macro_variants,
200    );
201
202    let macro_docs = format!(
203        "Macro generated for the function [`{doc_mod_path}::{trimmed_fn_ident}`]. See the function documentation for more details."
204    );
205
206    let generated = quote! {
207        #[doc = #macro_docs]
208        #[macro_export]
209        macro_rules! #trimmed_fn_ident {
210            #(
211                #macro_variants
212            )*
213        }
214    };
215
216    Ok(generated)
217}
218
219fn generate_macro_variants(
220    fn_mod_path: &proc_macro2::TokenStream,
221    fn_ident: &Ident,
222    trimmed_fn_ident: &Ident,
223    args: &[Arg],
224    default_generics: &proc_macro2::TokenStream,
225    dtype_generics: &Option<proc_macro2::TokenStream>,
226    macro_variants: &mut Vec<proc_macro2::TokenStream>,
227) {
228    let args_ident = args.iter().map(|arg| &arg.ident).collect::<Vec<_>>();
229    let args_type = args.iter().map(|arg| arg.arg_type).collect::<Vec<_>>();
230    let mut optional_indices = Vec::new();
231    let mut selected = Vec::with_capacity(args.len());
232    for (idx, arg) in args.iter().enumerate() {
233        match arg.arg_type {
234            ArgType::Positional => {
235                selected.push(true);
236            }
237            ArgType::Named => {
238                selected.push(true);
239            }
240            ArgType::NamedOptional => {
241                selected.push(false);
242                optional_indices.push(idx);
243            }
244        }
245    }
246
247    for perms in 0..optional_indices.len() + 1 {
248        // Select `perms` number of optional arguments
249        for selected_indice in optional_indices.iter().permutations(perms) {
250            selected_indice.iter().for_each(|&&i| selected[i] = true);
251
252            generate_macro_variants_for_selected_args(
253                fn_mod_path,
254                fn_ident,
255                trimmed_fn_ident,
256                &args_ident,
257                &args_type,
258                &selected,
259                default_generics,
260                dtype_generics,
261                macro_variants,
262            );
263
264            // Clear the selected flag for the next iteration
265            selected_indice.iter().for_each(|&&i| selected[i] = false);
266        }
267    }
268}
269
270#[allow(clippy::too_many_arguments)]
271fn generate_macro_variants_for_selected_args(
272    fn_mod_path: &proc_macro2::TokenStream,
273    fn_ident: &Ident,
274    trimmed_fn_ident: &Ident,
275    args_ident: &[&Ident],
276    args_type: &[ArgType],
277    selected: &[bool],
278    default_generics: &proc_macro2::TokenStream,
279    dtype_generics: &Option<proc_macro2::TokenStream>,
280    macro_variants: &mut Vec<proc_macro2::TokenStream>,
281) {
282    let macro_args: Vec<proc_macro2::TokenStream> = args_ident
283        .iter()
284        .zip(args_type.iter())
285        .zip(selected.iter())
286        .filter_map(|((ident, arg_type), &selected)| match selected {
287            true => {
288                let token = match arg_type {
289                    ArgType::Positional => quote! { $#ident:expr },
290                    ArgType::Named => quote! { #ident=$#ident:expr },
291                    ArgType::NamedOptional => quote! { #ident=$#ident:expr },
292                };
293                Some(token)
294            }
295            false => None,
296        })
297        .collect();
298
299    let input: Vec<proc_macro2::TokenStream> = args_ident
300        .iter()
301        .zip(selected.iter())
302        .map(|(ident, &selected)| {
303            if selected {
304                quote! { $#ident }
305            } else {
306                quote! { None }
307            }
308        })
309        .collect();
310
311    let variant_body = quote! {
312        (
313            #(#macro_args),*
314        ) => {
315            #fn_mod_path::#trimmed_fn_ident #default_generics(#(#input,)*)
316        };
317        (
318            #(#macro_args,)*
319            stream=$stream:expr
320        ) => {
321            #fn_mod_path::#fn_ident #default_generics(#(#input,)* $stream)
322        };
323    };
324
325    macro_variants.push(variant_body);
326
327    if let Some(dtype_generics) = &dtype_generics {
328        let variant_body = quote! {
329            (
330                #(#macro_args,)*
331                dtype=$dtype:ty
332            ) => {
333                #fn_mod_path::#trimmed_fn_ident #dtype_generics(#(#input,)*)
334            };
335            (
336                #(#macro_args,)*
337                dtype=$dtype:ty,
338                stream=$stream:expr
339            ) => {
340                #fn_mod_path::#fn_ident #dtype_generics(#(#input,)* $stream)
341            };
342        };
343
344        macro_variants.push(variant_body);
345    }
346}