mlx_rs/transforms/
keyed_value_and_grad.rs

1use std::{collections::HashMap, rc::Rc};
2
3use crate::{
4    error::{Exception, Result},
5    utils::{guard::Guarded, Closure},
6    Array,
7};
8
9use super::{value_and_gradient, ClosureValueAndGrad};
10
11/// Type alias for a hashmap of parameters.
12pub type KeyedParameters<Arr> = HashMap<Rc<str>, Arr>;
13
14/// Type alias for a hashmap of gradients.
15pub type KeyedGrad = KeyedParameters<Array>;
16
// Expands to the `move` closure returned by the `IntoKeyedValueAndGrad` impls below.
//
// Arguments:
// - `$inner_ret`: return type of the wrapped function (`Vec<Array>` or `Result<Vec<Array>>`);
// - `$cls_new`:   `Closure` constructor matching that return type (`new` / `new_fallible`);
// - `$f`:         identifier bound to the wrapped `FnMut`; it is called mutably, so the
//                 binding must be `mut`;
// - `$args_ty`:   type of the extra arguments forwarded to `$f` (cloned on every call).
//
// The keyed map is split into parallel name/value vectors so the MLX transform can
// differentiate positionally; the names are re-attached to the resulting gradients.
macro_rules! keyed_value_and_grad {
    ($inner_ret:ty, $cls_new:ident, $f:ident, $args_ty:ty) => {
        move |params: KeyedParameters<Arr>,
              args: $args_ty|
              -> Result<(Vec<Array>, KeyedGrad)> {
            // A single pass over the map, so `names[i]` labels `values[i]`.
            let (names, values): (Vec<Rc<str>>, Vec<Arr>) = params.into_iter().unzip();

            // Rebuilds the keyed map from positional arrays and forwards to the user function.
            let forward = |positional: &[Array]| -> $inner_ret {
                let rebuilt: KeyedParameters<Array> = names
                    .iter()
                    .cloned()
                    .zip(positional.iter().cloned())
                    .collect();
                ($f)(rebuilt, args.clone())
            };

            // Differentiate with respect to every positional argument.
            let argnums: Vec<i32> = (0..values.len() as i32).collect();

            let closure = Closure::$cls_new(forward);
            // SAFETY: `closure` and `argnums` are locals that live until the end of this
            // closure body, so the pointer/length pairs handed to MLX are valid for the call.
            let vag = ClosureValueAndGrad::try_from_op(|out| unsafe {
                mlx_sys::mlx_value_and_grad(
                    out,
                    closure.as_ptr(),
                    argnums.as_ptr(),
                    argnums.len(),
                )
            })?;

            let (outputs, grads) = value_and_gradient(vag.as_ptr(), values.into_iter())?;

            // Zipped with `names`, which assumes gradients come back in argument order.
            let keyed_grads: KeyedGrad = names.iter().cloned().zip(grads).collect();

            Ok((outputs, keyed_grads))
        }
    };
}
54
55/// Similar to [`IntoValueAndGrad`] but for functions that take a hashmap of parameters.
56pub trait IntoKeyedValueAndGrad<'a, Arr, Args, Err>
57where
58    Arr: AsRef<Array>,
59    Args: Clone,
60{
61    /// Convert the function/closure into a closure that computes the value and gradient.
62    fn into_keyed_value_and_grad(
63        self,
64    ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a;
65}
66
67impl<'a, F, Arr, Args> IntoKeyedValueAndGrad<'a, Arr, Args, ()> for F
68where
69    F: FnMut(HashMap<Rc<str>, Array>, Args) -> Vec<Array> + 'a,
70    Arr: AsRef<Array>,
71    Args: Clone,
72{
73    fn into_keyed_value_and_grad(
74        mut self,
75    ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a {
76        keyed_value_and_grad!(Vec<Array>, new, self, Args)
77    }
78}
79
80impl<'a, F, Arr, Args> IntoKeyedValueAndGrad<'a, Arr, Args, Exception> for F
81where
82    F: FnMut(HashMap<Rc<str>, Array>, Args) -> Result<Vec<Array>> + 'a,
83    Arr: AsRef<Array>,
84    Args: Clone,
85{
86    fn into_keyed_value_and_grad(
87        mut self,
88    ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a {
89        keyed_value_and_grad!(Result<Vec<Array>>, new_fallible, self, Args)
90    }
91}
92
93/// Returns a function which computes the value and gradient of `f` with keyed parameters.
94pub fn keyed_value_and_grad<'a, F, Arr, Args, Err>(
95    f: F,
96) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a
97where
98    F: IntoKeyedValueAndGrad<'a, Arr, Args, Err> + 'a,
99    Arr: AsRef<Array>,
100    Args: Clone,
101{
102    f.into_keyed_value_and_grad()
103}
104
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, rc::Rc};

    use crate::{array, Array};

    use super::*;

    /// Gradient of `x * y` through the infallible (`Vec<Array>`-returning) closure impl.
    #[test]
    fn test_keyed_value_and_grad() {
        let f = |parameters: HashMap<Rc<str>, Array>, _: i32| -> Vec<Array> {
            vec![&parameters["x"] * &parameters["y"]]
        };

        let x = array!(1.5f32);
        let y = array!(2.0f32);
        let parameters = vec![("x", x), ("y", y)]
            .into_iter()
            .map(|(k, v)| (k.into(), v))
            .collect();

        let mut vg = keyed_value_and_grad(f);

        let (value, grad) = vg(parameters, 0).unwrap();

        assert_eq!(value[0].item::<f32>(), 1.5 * 2.0);
        assert_eq!(grad["x"].item::<f32>(), 2.0);
        assert_eq!(grad["y"].item::<f32>(), 1.5);
    }

    /// Same computation through the fallible (`Result`-returning) closure impl, which the
    /// test above does not exercise.
    #[test]
    fn test_keyed_value_and_grad_fallible() {
        let f = |parameters: HashMap<Rc<str>, Array>, _: i32| -> Result<Vec<Array>> {
            Ok(vec![&parameters["x"] * &parameters["y"]])
        };

        let parameters: KeyedParameters<Array> = vec![("x", array!(1.5f32)), ("y", array!(2.0f32))]
            .into_iter()
            .map(|(k, v)| (Rc::from(k), v))
            .collect();

        let mut vg = keyed_value_and_grad(f);

        let (value, grad) = vg(parameters, 0).unwrap();

        assert_eq!(value[0].item::<f32>(), 1.5 * 2.0);
        // One gradient per input parameter, keyed by the same names.
        assert_eq!(grad.len(), 2);
        assert_eq!(grad["x"].item::<f32>(), 2.0);
        assert_eq!(grad["y"].item::<f32>(), 1.5);
    }
}