1use std::{collections::HashMap, rc::Rc};
2
3use crate::{
4 error::{Exception, Result},
5 utils::{guard::Guarded, Closure},
6 Array,
7};
8
9use super::{value_and_gradient, ClosureValueAndGrad};
10
/// A map from parameter name to parameter value.
///
/// `Arr` is the value type; the differentiation helpers in this module accept any
/// `Arr: AsRef<Array>`.
pub type KeyedParameters<Arr> = std::collections::HashMap<std::rc::Rc<str>, Arr>;
13
14pub type KeyedGrad = KeyedParameters<Array>;
16
/// Expands to a `move` closure computing the value of `$f` together with its gradients
/// with respect to every entry of a [`KeyedParameters`] map.
///
/// - `$inner_ret`: return type of the flattened closure handed to [`Closure`]
///   (`Vec<Array>` for `new`, `Result<Vec<Array>>` for `new_fallible`).
/// - `$cls_new`: the `Closure` constructor to use.
/// - `$f`: identifier of the user function, invoked as `$f(parameters, args)`.
/// - `$args_ty`: type of the extra, non-differentiated arguments (cloned per call).
///
/// The expansion names a type parameter `Arr` that must be in scope at the call site.
macro_rules! keyed_value_and_grad {
    ($inner_ret:ty, $cls_new:ident, $f:ident, $args_ty:ty) => {
        move |parameters: KeyedParameters<Arr>,
              arrays: $args_ty|
              -> Result<(Vec<Array>, KeyedGrad)> {
            // Split the map exactly once so keys and values share one fixed order. That
            // order is reused both to rebuild the map inside the traced closure and to
            // label the gradients afterwards.
            let (keys, values): (Vec<Rc<str>>, Vec<Arr>) = parameters.into_iter().unzip();

            // Differentiate with respect to every flattened input: 0, 1, ..., len - 1.
            let argument_numbers: Vec<i32> = (0..values.len() as i32).collect();

            let inner = |inputs: &[Array]| -> $inner_ret {
                let rebuilt: HashMap<Rc<str>, Array> =
                    keys.iter().cloned().zip(inputs.iter().cloned()).collect();
                ($f)(rebuilt, arrays.clone())
            };

            let closure = Closure::$cls_new(inner);
            // SAFETY: `closure` and `argument_numbers` are locals that are still alive for
            // the whole duration of this call, so the pointers passed here are valid while
            // it runs. `res` is the output slot supplied by `try_from_op` (assumed valid and
            // writable — see `try_from_op`).
            let cvg = ClosureValueAndGrad::try_from_op(|res| unsafe {
                mlx_sys::mlx_value_and_grad(
                    res,
                    closure.as_ptr(),
                    argument_numbers.as_ptr(),
                    argument_numbers.len(),
                )
            })?;

            let (value, grad_list) = value_and_gradient(cvg.as_ptr(), values.into_iter())?;

            // Presumably the gradients come back in argument-number order, i.e. the order of
            // `keys`; pair them back up by position.
            let grads: KeyedGrad = keys.iter().cloned().zip(grad_list).collect();
            Ok((value, grads))
        }
    };
}
54
/// Conversion of a function over keyed parameters into a function that also returns the
/// gradients with respect to those parameters.
///
/// `Err` is a marker type parameter used only to keep the blanket impls in this module
/// apart: it is `()` for functions returning `Vec<Array>` and [`Exception`] for functions
/// returning `Result<Vec<Array>>`.
pub trait IntoKeyedValueAndGrad<'a, Arr, Args, Err>
where
    Arr: AsRef<Array>,
    Args: Clone,
{
    /// Consumes `self` and returns a closure mapping `(parameters, args)` to
    /// `(value, gradients)`, with the gradients keyed by the same names as `parameters`.
    fn into_keyed_value_and_grad(
        self,
    ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a;
}
66
67impl<'a, F, Arr, Args> IntoKeyedValueAndGrad<'a, Arr, Args, ()> for F
68where
69 F: FnMut(HashMap<Rc<str>, Array>, Args) -> Vec<Array> + 'a,
70 Arr: AsRef<Array>,
71 Args: Clone,
72{
73 fn into_keyed_value_and_grad(
74 mut self,
75 ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a {
76 keyed_value_and_grad!(Vec<Array>, new, self, Args)
77 }
78}
79
80impl<'a, F, Arr, Args> IntoKeyedValueAndGrad<'a, Arr, Args, Exception> for F
81where
82 F: FnMut(HashMap<Rc<str>, Array>, Args) -> Result<Vec<Array>> + 'a,
83 Arr: AsRef<Array>,
84 Args: Clone,
85{
86 fn into_keyed_value_and_grad(
87 mut self,
88 ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a {
89 keyed_value_and_grad!(Result<Vec<Array>>, new_fallible, self, Args)
90 }
91}
92
93pub fn keyed_value_and_grad<'a, F, Arr, Args, Err>(
95 f: F,
96) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a
97where
98 F: IntoKeyedValueAndGrad<'a, Arr, Args, Err> + 'a,
99 Arr: AsRef<Array>,
100 Args: Clone,
101{
102 f.into_keyed_value_and_grad()
103}
104
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, rc::Rc};

    use crate::{array, Array};

    use super::*;

    /// For `f(x, y) = x * y`: the value is `x * y`, `df/dx = y` and `df/dy = x`.
    #[test]
    fn test_keyed_value_and_grad() {
        let f = |parameters: HashMap<Rc<str>, Array>, _: i32| -> Vec<Array> {
            vec![&parameters["x"] * &parameters["y"]]
        };

        let parameters: KeyedParameters<Array> = HashMap::from([
            (Rc::from("x"), array!(1.5f32)),
            (Rc::from("y"), array!(2.0f32)),
        ]);

        let mut vg = keyed_value_and_grad(f);

        let (value, grad) = vg(parameters, 0).unwrap();

        assert_eq!(value[0].item::<f32>(), 1.5 * 2.0);
        assert_eq!(grad["x"].item::<f32>(), 2.0);
        assert_eq!(grad["y"].item::<f32>(), 1.5);
    }

    /// Same function written as a fallible closure, which selects the `Exception` impl.
    #[test]
    fn test_keyed_value_and_grad_fallible() {
        let f = |parameters: HashMap<Rc<str>, Array>, _: i32| -> Result<Vec<Array>> {
            Ok(vec![&parameters["x"] * &parameters["y"]])
        };

        let parameters: KeyedParameters<Array> = HashMap::from([
            (Rc::from("x"), array!(1.5f32)),
            (Rc::from("y"), array!(2.0f32)),
        ]);

        let mut vg = keyed_value_and_grad(f);

        let (value, grad) = vg(parameters, 0).unwrap();

        assert_eq!(value[0].item::<f32>(), 1.5 * 2.0);
        assert_eq!(grad["x"].item::<f32>(), 2.0);
        assert_eq!(grad["y"].item::<f32>(), 1.5);
    }
}