mlx_macros/
quantizable.rs1use darling::FromDeriveInput;
2use syn::{DeriveInput, Generics, Ident};
3
4use crate::util::{filter_fields_with_attr, FilteredFields};
5
/// Container-level options read from `#[quantizable(...)]` attributes on the deriving item.
#[derive(Debug, Clone, FromDeriveInput)]
#[darling(attributes(quantizable))]
struct StructProperties {
    /// Optional path used as the root of the `mlx_rs` crate in the generated code.
    ///
    /// When `None`, the expansion falls back to `extern crate mlx_rs as _mlx_rs;` and refers
    /// to items through `_mlx_rs`.
    root: Option<syn::Path>,
}
11
12pub(crate) fn expand_quantizable(
13 input: &DeriveInput,
14) -> Result<proc_macro2::TokenStream, syn::Error> {
15 let prop = StructProperties::from_derive_input(input)?;
16 let struct_ident = &input.ident;
17 let generics = &input.generics;
18
19 match &input.data {
20 syn::Data::Struct(data) => {
21 expand_quantizable_module_for_struct(struct_ident, generics, data, prop.root)
22 }
23 _ => Err(syn::Error::new_spanned(
24 input,
25 "Quantizable can only be derived for structs",
26 )),
27 }
28}
29
30fn expand_quantizable_module_for_struct(
31 ident: &syn::Ident,
32 generics: &syn::Generics,
33 data: &syn::DataStruct,
34 root: Option<syn::Path>,
35) -> Result<proc_macro2::TokenStream, syn::Error> {
36 let fields = filter_fields_with_attr(&data.fields, "quantizable")?;
38
39 impl_quantizable_module_for_struct(ident, generics, fields, root)
40}
41
42fn impl_quantizable_module_for_struct(
43 ident: &Ident,
44 generics: &Generics,
45 fields: FilteredFields,
46 root: Option<syn::Path>,
47) -> Result<proc_macro2::TokenStream, syn::Error> {
48 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
49 let filtered_field_names = fields.filtered.iter().map(|field| &field.ident);
52 let other_field_names = fields.other_fields.iter().map(|field| &field.ident);
53
54 if fields.filtered.is_empty() {
55 return Err(syn::Error::new_spanned(
56 ident,
57 "At least one field must be quantizable",
58 ));
59 }
60
61 let (extern_import, root) = match root {
62 Some(root) => (quote::quote! {}, quote::quote! { #root }),
63 None => (
64 quote::quote! { extern crate mlx_rs as _mlx_rs; },
65 quote::quote! { _mlx_rs },
66 ),
67 };
68
69 let token = quote::quote! {
70 const _: () = {
71 #extern_import
72 impl #impl_generics #root::quantization::Quantizable for #ident #ty_generics #where_clause {
73 type Quantized = Self; type QuantizationError = #root::error::Exception;
76
77 fn try_into_quantized(
78 self,
79 group_size: i32,
80 bits: i32,
81 ) -> Result<Self::Quantized, Self::QuantizationError> {
82 Ok(Self {
83 #(
84 #filtered_field_names: #root::quantization::Quantizable
85 ::try_into_quantized(self.#filtered_field_names, group_size, bits)?,
86 )*
87 #(
88 #other_field_names: self.#other_field_names,
89 )*
90 })
91 }
92 }
93 };
94 };
95 Ok(token)
96}