mlx_macros/
quantizable.rs

1use darling::FromDeriveInput;
2use syn::{DeriveInput, Generics, Ident};
3
4use crate::util::{filter_fields_with_attr, FilteredFields};
5
6#[derive(Debug, Clone, FromDeriveInput)]
7#[darling(attributes(quantizable))]
8struct StructProperties {
9    root: Option<syn::Path>,
10}
11
12pub(crate) fn expand_quantizable(
13    input: &DeriveInput,
14) -> Result<proc_macro2::TokenStream, syn::Error> {
15    let prop = StructProperties::from_derive_input(input)?;
16    let struct_ident = &input.ident;
17    let generics = &input.generics;
18
19    match &input.data {
20        syn::Data::Struct(data) => {
21            expand_quantizable_module_for_struct(struct_ident, generics, data, prop.root)
22        }
23        _ => Err(syn::Error::new_spanned(
24            input,
25            "Quantizable can only be derived for structs",
26        )),
27    }
28}
29
30fn expand_quantizable_module_for_struct(
31    ident: &syn::Ident,
32    generics: &syn::Generics,
33    data: &syn::DataStruct,
34    root: Option<syn::Path>,
35) -> Result<proc_macro2::TokenStream, syn::Error> {
36    // Filter fields with #[quantizable]
37    let fields = filter_fields_with_attr(&data.fields, "quantizable")?;
38
39    impl_quantizable_module_for_struct(ident, generics, fields, root)
40}
41
42fn impl_quantizable_module_for_struct(
43    ident: &Ident,
44    generics: &Generics,
45    fields: FilteredFields,
46    root: Option<syn::Path>,
47) -> Result<proc_macro2::TokenStream, syn::Error> {
48    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
49    // let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect();
50
51    let filtered_field_names = fields.filtered.iter().map(|field| &field.ident);
52    let other_field_names = fields.other_fields.iter().map(|field| &field.ident);
53
54    if fields.filtered.is_empty() {
55        return Err(syn::Error::new_spanned(
56            ident,
57            "At least one field must be quantizable",
58        ));
59    }
60
61    let (extern_import, root) = match root {
62        Some(root) => (quote::quote! {}, quote::quote! { #root }),
63        None => (
64            quote::quote! { extern crate mlx_rs as _mlx_rs; },
65            quote::quote! { _mlx_rs },
66        ),
67    };
68
69    let token = quote::quote! {
70        const _: () = {
71            #extern_import
72            impl #impl_generics #root::quantization::Quantizable for #ident #ty_generics #where_clause {
73                type Quantized = Self; // Generating new struct is not supported yet
74
75                type QuantizationError = #root::error::Exception;
76
77                fn try_into_quantized(
78                    self,
79                    group_size: i32,
80                    bits: i32,
81                ) -> Result<Self::Quantized, Self::QuantizationError> {
82                    Ok(Self {
83                        #(
84                            #filtered_field_names: #root::quantization::Quantizable
85                                ::try_into_quantized(self.#filtered_field_names, group_size, bits)?,
86                        )*
87                        #(
88                            #other_field_names: self.#other_field_names,
89                        )*
90                    })
91                }
92            }
93        };
94    };
95    Ok(token)
96}