mlx_macros/lib.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
extern crate proc_macro;
use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput};
mod module_parameters;
mod quantizable;
mod util;
/// Derive the `ModuleParameters` trait for a struct. Mark a field with
/// `#[param]` attribute to include it in the parameters. The field type must
/// implement the `mlx_rs::module::Parameter` trait.
///
/// # Example
///
/// ```rust, ignore
/// use mlx_macros::ModuleParameters;
/// use mlx_rs::module::{ModuleParameters, Param};
///
/// #[derive(ModuleParameters)]
/// struct Example {
/// #[param]
/// regular: Param<Array>,
///
/// #[param]
/// optional: Param<Option<Array>>,
///
/// #[param]
/// nested: Inner,
///
/// #[param]
/// vec_nested: Vec<Inner>,
///
/// #[param]
/// trait_object: Box<dyn Module>,
///
/// #[param]
/// trait_object_vec: Vec<Box<dyn Module>>,
/// }
///
/// #[derive(ModuleParameters)]
/// struct Inner {
/// #[param]
/// a: Param<Array>,
/// }
/// ```
#[proc_macro_derive(ModuleParameters, attributes(module, param))]
pub fn derive_module_parameters(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let module_param_impl = module_parameters::expand_module_parameters(&input).unwrap();
TokenStream::from(module_param_impl)
}
/// Derive the `Quantizable` trait for a struct. Mark a field with
/// `#[quantizable]` attribute to include it in the quantization process.
/// Only support types `M` that `M::Quantized = Self`
///
/// See `mlx-rs/mlx-tests/tests/test_quantizable.rs` for example usage.
///
/// # Panics
///
/// This macro will panic if the struct does not have any field marked with
/// `#[quantizable]`.
#[proc_macro_derive(Quantizable, attributes(quantizable))]
pub fn derive_quantizable_module(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let quantizable_module_impl = quantizable::expand_quantizable(&input).unwrap();
TokenStream::from(quantizable_module_impl)
}