mlx_macros/lib.rs
1extern crate proc_macro;
2use proc_macro::TokenStream;
3use syn::{parse_macro_input, DeriveInput};
4
5mod module_parameters;
6mod quantizable;
7mod util;
8
9/// Derive the `ModuleParameters` trait for a struct. Mark a field with
10/// `#[param]` attribute to include it in the parameters. The field type must
11/// implement the `mlx_rs::module::Parameter` trait.
12///
13/// # Example
14///
15/// ```rust, ignore
16/// use mlx_macros::ModuleParameters;
17/// use mlx_rs::module::{ModuleParameters, Param};
18///
19/// #[derive(ModuleParameters)]
20/// struct Example {
21/// #[param]
22/// regular: Param<Array>,
23///
24/// #[param]
25/// optional: Param<Option<Array>>,
26///
27/// #[param]
28/// nested: Inner,
29///
30/// #[param]
31/// vec_nested: Vec<Inner>,
32///
33/// #[param]
34/// trait_object: Box<dyn Module>,
35///
36/// #[param]
37/// trait_object_vec: Vec<Box<dyn Module>>,
38/// }
39///
40/// #[derive(ModuleParameters)]
41/// struct Inner {
42/// #[param]
43/// a: Param<Array>,
44/// }
45/// ```
46#[proc_macro_derive(ModuleParameters, attributes(module, param))]
47pub fn derive_module_parameters(input: TokenStream) -> TokenStream {
48 let input = parse_macro_input!(input as DeriveInput);
49 let module_param_impl = module_parameters::expand_module_parameters(&input).unwrap();
50 TokenStream::from(module_param_impl)
51}
52
53/// Derive the `Quantizable` trait for a struct. Mark a field with
54/// `#[quantizable]` attribute to include it in the quantization process.
55/// Only support types `M` that `M::Quantized = Self`
56///
57/// See `mlx-rs/mlx-tests/tests/test_quantizable.rs` for example usage.
58///
59/// # Panics
60///
61/// This macro will panic if the struct does not have any field marked with
62/// `#[quantizable]`.
63#[proc_macro_derive(Quantizable, attributes(quantizable))]
64pub fn derive_quantizable_module(input: TokenStream) -> TokenStream {
65 let input = parse_macro_input!(input as DeriveInput);
66 let quantizable_module_impl = quantizable::expand_quantizable(&input).unwrap();
67 TokenStream::from(quantizable_module_impl)
68}