1use darling::FromMeta;
2use itertools::Itertools;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{FnArg, Ident, ItemFn, Meta};
6
/// Parameter attribute marking an argument as optional and passed by name
/// (`name=<expr>`); when omitted from a macro call the argument is passed as `None`.
const CUSTOM_ATTRIBUTE_OPTIONAL: &str = "optional";
/// Parameter attribute marking an argument as required but passed by name (`name=<expr>`).
const CUSTOM_ATTRIBUTE_NAMED: &str = "named";

/// Every helper attribute understood by this macro; these are stripped from the emitted function.
const CUSTOM_ATTRIBUTES: &[&str] = &[CUSTOM_ATTRIBUTE_OPTIONAL, CUSTOM_ATTRIBUTE_NAMED];
11
12#[derive(Default, Debug, FromMeta)]
13#[darling(default)]
14struct Customize {
15 root: Option<syn::LitStr>,
16 default_dtype: Option<syn::Path>,
17}
18
19fn arg_type(attrs: &[syn::Attribute]) -> ArgType {
20 for attr in attrs {
21 if attr.path().is_ident(CUSTOM_ATTRIBUTE_OPTIONAL) {
22 return ArgType::NamedOptional;
23 } else if attr.path().is_ident(CUSTOM_ATTRIBUTE_NAMED) {
24 return ArgType::Named;
25 }
26 }
27 ArgType::Positional
28}
29
30fn remove_attribute(attrs: &mut Vec<syn::Attribute>, targets: &[&str]) {
31 attrs.retain(|attr| !targets.iter().any(|target| !attr.path().is_ident(target)));
32}
33
/// Returns `s` without its leading `$` characters (e.g. `"$crate::ops"` -> `"crate::ops"`).
///
/// Used to turn a macro path such as `$crate::ops` into a path usable in an intra-doc link.
fn remove_prefix_from_str(s: &str) -> String {
    let without_dollars = s.trim_start_matches('$');
    without_dollars.to_owned()
}
38
39pub fn expand_generate_macro(
40 attr: Option<Meta>,
41 mut item: ItemFn, ) -> Result<TokenStream, syn::Error> {
43 let customize = match attr {
44 Some(attr) => Customize::from_meta(&attr).map_err(|e| syn::Error::new_spanned(attr, e))?,
45 None => Customize::default(),
46 };
47
48 let (fn_mod_path, doc_mod_path) = match customize.root {
50 Some(lit_str) => {
51 let tokens: proc_macro2::TokenStream = lit_str.parse()?;
52 let s = remove_prefix_from_str(&lit_str.value());
53 (quote! { #tokens }, s)
54 }
55 None => (quote! { $crate::ops }, "crate::ops".into()),
56 };
57
58 let (default_generics, dtype_generics) =
59 handle_generic_args(&item.sig.generics, &customize.default_dtype);
60
61 let args = item
62 .sig
63 .inputs
64 .iter_mut()
65 .map(|arg| match arg {
66 FnArg::Receiver(_) => Err(syn::Error::new_spanned(arg, "self is not allowed")),
67 FnArg::Typed(pat_type) => Ok(pat_type),
68 })
69 .collect::<Result<Vec<_>, _>>()?;
70
71 let mut parsed_args = parse_args(args);
72
73 if let Some(arg) = parsed_args.last() {
75 if arg.ident != "stream" {
76 return Err(syn::Error::new_spanned(
77 &item,
78 "the last optional argument must be `stream`",
79 ));
80 }
81 }
82 parsed_args.pop();
84
85 let fn_ident = &item.sig.ident;
87
88 let generated = generate_macro(
89 &fn_mod_path,
90 &doc_mod_path,
91 fn_ident,
92 &parsed_args,
93 &default_generics,
94 &dtype_generics,
95 )?;
96
97 let output = quote! {
98 #item
99 #generated
100 };
101
102 Ok(output.into())
103}
104
105fn handle_generic_args(
111 generic_args: &syn::Generics,
112 default_dtype: &Option<syn::Path>,
113) -> (proc_macro2::TokenStream, Option<proc_macro2::TokenStream>) {
114 let count = generic_args
116 .params
117 .iter()
118 .filter(|param| matches!(param, syn::GenericParam::Type(_)))
119 .count();
120
121 if count == 0 {
122 return (quote! {}, None);
123 }
124
125 let infer_tokens = vec![quote! { _ }; count - 1];
127
128 let default_generics = match default_dtype {
129 Some(path) => quote! { ::<#(#infer_tokens,)* #path> },
130 None => quote! { ::<#(#infer_tokens,)* f32> },
131 };
132 let dtype_generics = quote! { ::<#(#infer_tokens,)* $dtype> };
133
134 (default_generics, Some(dtype_generics))
135}
136
137#[derive(Debug, Clone, Copy)]
138enum ArgType {
139 Positional,
140 Named,
141 NamedOptional,
142}
143
144struct Arg {
145 ident: Ident,
146 arg_type: ArgType,
147}
148
149fn parse_args(args: Vec<&mut syn::PatType>) -> Vec<Arg> {
150 let mut is_prev_optional = false;
151 let mut parsed = Vec::new();
152 for arg in args {
153 match &*arg.pat {
154 syn::Pat::Ident(ident) => {
155 let arg_type = arg_type(&arg.attrs);
156
157 let is_positional = matches!(arg_type, ArgType::Positional);
158 if is_prev_optional && is_positional {
159 panic!("positional argument cannot follow an optional argument");
160 }
161 is_prev_optional = matches!(arg_type, ArgType::NamedOptional);
162
163 parsed.push(Arg {
164 ident: ident.ident.clone(),
165 arg_type,
166 });
167 }
168 _ => panic!("unsupported pattern"),
169 }
170
171 remove_attribute(&mut arg.attrs, CUSTOM_ATTRIBUTES);
172 }
173 parsed
174}
175
176fn generate_macro(
177 fn_mod_path: &proc_macro2::TokenStream,
178 doc_mod_path: &str,
179 fn_ident: &Ident,
180 args: &[Arg],
181 default_generics: &proc_macro2::TokenStream,
182 dtype_generics: &Option<proc_macro2::TokenStream>,
183) -> Result<proc_macro2::TokenStream, syn::Error> {
184 let mut trimmed_fn_ident_str = fn_ident.to_string();
185 if trimmed_fn_ident_str.ends_with("_device") {
186 trimmed_fn_ident_str = trimmed_fn_ident_str.trim_end_matches("_device").to_string();
187 }
188 let trimmed_fn_ident = Ident::new(&trimmed_fn_ident_str, fn_ident.span());
189
190 let mut macro_variants = Vec::new();
191
192 generate_macro_variants(
193 fn_mod_path,
194 fn_ident,
195 &trimmed_fn_ident,
196 args,
197 default_generics,
198 dtype_generics,
199 &mut macro_variants,
200 );
201
202 let macro_docs = format!(
203 "Macro generated for the function [`{}::{}`]. See the function documentation for more details.",
204 doc_mod_path, trimmed_fn_ident
205 );
206
207 let generated = quote! {
208 #[doc = #macro_docs]
209 #[macro_export]
210 macro_rules! #trimmed_fn_ident {
211 #(
212 #macro_variants
213 )*
214 }
215 };
216
217 Ok(generated)
218}
219
220fn generate_macro_variants(
221 fn_mod_path: &proc_macro2::TokenStream,
222 fn_ident: &Ident,
223 trimmed_fn_ident: &Ident,
224 args: &[Arg],
225 default_generics: &proc_macro2::TokenStream,
226 dtype_generics: &Option<proc_macro2::TokenStream>,
227 macro_variants: &mut Vec<proc_macro2::TokenStream>,
228) {
229 let args_ident = args.iter().map(|arg| &arg.ident).collect::<Vec<_>>();
230 let args_type = args.iter().map(|arg| arg.arg_type).collect::<Vec<_>>();
231 let mut optional_indices = Vec::new();
232 let mut selected = Vec::with_capacity(args.len());
233 for (idx, arg) in args.iter().enumerate() {
234 match arg.arg_type {
235 ArgType::Positional => {
236 selected.push(true);
237 }
238 ArgType::Named => {
239 selected.push(true);
240 }
241 ArgType::NamedOptional => {
242 selected.push(false);
243 optional_indices.push(idx);
244 }
245 }
246 }
247
248 for perms in 0..optional_indices.len() + 1 {
249 for selected_indice in optional_indices.iter().permutations(perms) {
251 selected_indice.iter().for_each(|&&i| selected[i] = true);
252
253 generate_macro_variants_for_selected_args(
254 fn_mod_path,
255 fn_ident,
256 trimmed_fn_ident,
257 &args_ident,
258 &args_type,
259 &selected,
260 default_generics,
261 dtype_generics,
262 macro_variants,
263 );
264
265 selected_indice.iter().for_each(|&&i| selected[i] = false);
267 }
268 }
269}
270
271#[allow(clippy::too_many_arguments)]
272fn generate_macro_variants_for_selected_args(
273 fn_mod_path: &proc_macro2::TokenStream,
274 fn_ident: &Ident,
275 trimmed_fn_ident: &Ident,
276 args_ident: &[&Ident],
277 args_type: &[ArgType],
278 selected: &[bool],
279 default_generics: &proc_macro2::TokenStream,
280 dtype_generics: &Option<proc_macro2::TokenStream>,
281 macro_variants: &mut Vec<proc_macro2::TokenStream>,
282) {
283 let macro_args: Vec<proc_macro2::TokenStream> = args_ident
284 .iter()
285 .zip(args_type.iter())
286 .zip(selected.iter())
287 .filter_map(|((ident, arg_type), &selected)| match selected {
288 true => {
289 let token = match arg_type {
290 ArgType::Positional => quote! { $#ident:expr },
291 ArgType::Named => quote! { #ident=$#ident:expr },
292 ArgType::NamedOptional => quote! { #ident=$#ident:expr },
293 };
294 Some(token)
295 }
296 false => None,
297 })
298 .collect();
299
300 let input: Vec<proc_macro2::TokenStream> = args_ident
301 .iter()
302 .zip(selected.iter())
303 .map(|(ident, &selected)| {
304 if selected {
305 quote! { $#ident }
306 } else {
307 quote! { None }
308 }
309 })
310 .collect();
311
312 let variant_body = quote! {
313 (
314 #(#macro_args),*
315 ) => {
316 #fn_mod_path::#trimmed_fn_ident #default_generics(#(#input,)*)
317 };
318 (
319 #(#macro_args,)*
320 stream=$stream:expr
321 ) => {
322 #fn_mod_path::#fn_ident #default_generics(#(#input,)* $stream)
323 };
324 };
325
326 macro_variants.push(variant_body);
327
328 if let Some(dtype_generics) = &dtype_generics {
329 let variant_body = quote! {
330 (
331 #(#macro_args,)*
332 dtype=$dtype:ty
333 ) => {
334 #fn_mod_path::#trimmed_fn_ident #dtype_generics(#(#input,)*)
335 };
336 (
337 #(#macro_args,)*
338 dtype=$dtype:ty,
339 stream=$stream:expr
340 ) => {
341 #fn_mod_path::#fn_ident #dtype_generics(#(#input,)* $stream)
342 };
343 };
344
345 macro_variants.push(variant_body);
346 }
347}