mlx_macros/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use syn::{parse_macro_input, DeriveInput};
4
5mod module_parameters;
6mod quantizable;
7mod util;
8
9/// Derive the `ModuleParameters` trait for a struct. Mark a field with
10/// `#[param]` attribute to include it in the parameters. The field type must
11/// implement the `mlx_rs::module::Parameter` trait.
12///
13/// # Example
14///
15/// ```rust, ignore
16/// use mlx_macros::ModuleParameters;
17/// use mlx_rs::module::{ModuleParameters, Param};
18///
19/// #[derive(ModuleParameters)]
20/// struct Example {
21///     #[param]
22///     regular: Param<Array>,
23///
24///     #[param]
25///     optional: Param<Option<Array>>,
26///
27///     #[param]
28///     nested: Inner,
29///
30///     #[param]
31///     vec_nested: Vec<Inner>,
32///
33///     #[param]
34///     trait_object: Box<dyn Module>,
35///
36///     #[param]
37///     trait_object_vec: Vec<Box<dyn Module>>,
38/// }
39///
40/// #[derive(ModuleParameters)]
41/// struct Inner {
42///     #[param]
43///     a: Param<Array>,
44/// }
45/// ```
46#[proc_macro_derive(ModuleParameters, attributes(module, param))]
47pub fn derive_module_parameters(input: TokenStream) -> TokenStream {
48    let input = parse_macro_input!(input as DeriveInput);
49    let module_param_impl = module_parameters::expand_module_parameters(&input).unwrap();
50    TokenStream::from(module_param_impl)
51}
52
53/// Derive the `Quantizable` trait for a struct. Mark a field with
54/// `#[quantizable]` attribute to include it in the quantization process.
55/// Only support types `M` that `M::Quantized = Self`
56///
57/// See `mlx-rs/mlx-tests/tests/test_quantizable.rs` for example usage.
58///
59/// # Panics
60///
61/// This macro will panic if the struct does not have any field marked with
62/// `#[quantizable]`.
63#[proc_macro_derive(Quantizable, attributes(quantizable))]
64pub fn derive_quantizable_module(input: TokenStream) -> TokenStream {
65    let input = parse_macro_input!(input as DeriveInput);
66    let quantizable_module_impl = quantizable::expand_quantizable(&input).unwrap();
67    TokenStream::from(quantizable_module_impl)
68}