1use crate::{
2 error::{Exception, Result},
3 utils::{guard::Guarded, Closure, IntoOption},
4 Array,
5};
6
7use super::{value_and_gradient, ClosureValueAndGrad};
8
9#[inline]
10fn build_gradient_inner<'a>(
11 closure: Closure<'a>,
12 argnums: &'a [i32],
13) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a {
14 move |arrays: &[Array]| -> Result<Vec<Array>> {
15 let cvg = ClosureValueAndGrad::try_from_op(|res| unsafe {
16 mlx_sys::mlx_value_and_grad(res, closure.as_ptr(), argnums.as_ptr(), argnums.len())
17 })?;
18 let result = value_and_gradient(cvg.as_ptr(), arrays.iter())?;
19 Ok(result.1)
20 }
21}
22
23fn build_gradient<'a, F>(
24 f: F,
25 argnums: &'a [i32],
26) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a
27where
28 F: FnMut(&[Array]) -> Vec<Array> + 'a,
29{
30 let argnums = argnums.into_option().unwrap_or(&[0]);
31 let closure = Closure::new(f);
32 build_gradient_inner(closure, argnums)
33}
34
35fn build_fallible_gradient<'a, F>(
36 f: F,
37 argnums: &'a [i32],
38) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a
39where
40 F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
41{
42 let closure = Closure::new_fallible(f);
43 build_gradient_inner(closure, argnums)
44}
45
/// Conversion of a closure into a function that computes its gradient.
///
/// `Args` and `Output` describe the argument and return types of the produced
/// gradient function. `Err` does not appear in the method signature. It only
/// keeps the blanket impls in this module from overlapping: `()` marks
/// infallible closures and `Exception` marks closures returning `Result`.
pub trait IntoGrad<'a, Args, Output, Err> {
    /// Consumes `self` and returns a function that computes gradients with
    /// respect to the argument positions in `argnums`. The impls in this module
    /// fall back to `&[0]` when `argnums` resolves to `None`.
    fn into_grad(
        self,
        argnums: impl IntoOption<&'a [i32]>,
    ) -> impl FnMut(Args) -> Result<Output> + 'a;
}
54
55impl<'a, F> IntoGrad<'a, &[Array], Vec<Array>, ()> for F
56where
57 F: FnMut(&[Array]) -> Vec<Array> + 'a,
58{
59 #[allow(refining_impl_trait)]
62 fn into_grad(
63 self,
64 argnums: impl IntoOption<&'a [i32]>,
65 ) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a {
66 let argnums = argnums.into_option().unwrap_or(&[0]);
67 build_gradient(self, argnums)
68 }
69}
70
71impl<'a, F> IntoGrad<'a, &[Array], Vec<Array>, Exception> for F
72where
73 F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
74{
75 #[allow(refining_impl_trait)]
76 fn into_grad(
77 self,
78 argnums: impl IntoOption<&'a [i32]>,
79 ) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a {
80 let argnums = argnums.into_option().unwrap_or(&[0]);
81 build_fallible_gradient(self, argnums)
82 }
83}
84
85impl<'a, F> IntoGrad<'a, &Array, Array, ()> for F
86where
87 F: FnMut(&Array) -> Array + 'a,
88{
89 #[allow(refining_impl_trait)]
90 fn into_grad(
91 mut self,
92 argnums: impl IntoOption<&'a [i32]>,
93 ) -> impl FnMut(&Array) -> Result<Array> + 'a {
94 let f = move |args: &[Array]| -> Vec<Array> { vec![self(&args[0])] };
95 let argnums = argnums.into_option().unwrap_or(&[0]);
96 let mut g = build_gradient(f, argnums);
97 move |args: &Array| -> Result<Array> {
98 let args_clone = &[args.clone()];
99 let result = g(args_clone)?;
100 Ok(result.into_iter().next().unwrap())
101 }
102 }
103}
104
105impl<'a, F> IntoGrad<'a, &Array, Array, Exception> for F
106where
107 F: FnMut(&Array) -> Result<Array> + 'a,
108{
109 #[allow(refining_impl_trait)]
110 fn into_grad(
111 mut self,
112 argnums: impl IntoOption<&'a [i32]>,
113 ) -> impl FnMut(&Array) -> Result<Array> + 'a {
114 let f = move |args: &[Array]| -> Result<Vec<Array>> { self(&args[0]).map(|res| vec![res]) };
115 let argnums = argnums.into_option().unwrap_or(&[0]);
116 let mut g = build_fallible_gradient(f, argnums);
117 move |args: &Array| -> Result<Array> {
118 let args_clone = &[args.clone()];
119 let result = g(args_clone)?;
120 Ok(result.into_iter().next().unwrap())
121 }
122 }
123}
124
125impl<'a, F> IntoGrad<'a, &[Array], Array, ()> for F
126where
127 F: FnMut(&[Array]) -> Array + 'a,
128{
129 #[allow(refining_impl_trait)]
130 fn into_grad(
131 mut self,
132 argnums: impl IntoOption<&'a [i32]>,
133 ) -> impl FnMut(&[Array]) -> Result<Array> + 'a {
134 let f = move |args: &[Array]| -> Vec<Array> { vec![self(args)] };
135 let argnums = argnums.into_option().unwrap_or(&[0]);
136 let mut g = build_gradient(f, argnums);
137 move |args: &[Array]| -> Result<Array> {
138 let result = g(args)?;
139 Ok(result.into_iter().next().unwrap())
140 }
141 }
142}
143
144impl<'a, F> IntoGrad<'a, &[Array], Array, Exception> for F
145where
146 F: FnMut(&[Array]) -> Result<Array> + 'a,
147{
148 #[allow(refining_impl_trait)]
149 fn into_grad(
150 mut self,
151 argnums: impl IntoOption<&'a [i32]>,
152 ) -> impl FnMut(&[Array]) -> Result<Array> + 'a {
153 let f = move |args: &[Array]| -> Result<Vec<Array>> { self(args).map(|res| vec![res]) };
154 let argnums = argnums.into_option().unwrap_or(&[0]);
155 let mut g = build_fallible_gradient(f, argnums);
156 move |args: &[Array]| -> Result<Array> {
157 let result = g(args)?;
158 Ok(result.into_iter().next().unwrap())
159 }
160 }
161}
162
163impl<'a, F> IntoGrad<'a, &Array, Vec<Array>, ()> for F
164where
165 F: FnMut(&Array) -> Vec<Array> + 'a,
166{
167 #[allow(refining_impl_trait)]
168 fn into_grad(
169 mut self,
170 argnums: impl IntoOption<&'a [i32]>,
171 ) -> impl FnMut(&Array) -> Result<Vec<Array>> + 'a {
172 let f = move |args: &[Array]| -> Vec<Array> { self(&args[0]) };
173 let argnums = argnums.into_option().unwrap_or(&[0]);
174 let mut g = build_gradient(f, argnums);
175 move |args: &Array| -> Result<Vec<Array>> {
176 let args_clone = &[args.clone()];
177 let result = g(args_clone)?;
178 Ok(result)
179 }
180 }
181}
182
183impl<'a, F> IntoGrad<'a, &Array, Vec<Array>, Exception> for F
184where
185 F: FnMut(&Array) -> Result<Vec<Array>> + 'a,
186{
187 #[allow(refining_impl_trait)]
188 fn into_grad(
189 mut self,
190 argnums: impl IntoOption<&'a [i32]>,
191 ) -> impl FnMut(&Array) -> Result<Vec<Array>> + 'a {
192 let f = move |args: &[Array]| -> Result<Vec<Array>> { self(&args[0]) };
193 let argnums = argnums.into_option().unwrap_or(&[0]);
194 let mut g = build_fallible_gradient(f, argnums);
195 move |args: &Array| -> Result<Vec<Array>> {
196 let args_clone = &[args.clone()];
197 let result = g(args_clone)?;
198 Ok(result)
199 }
200 }
201}
202
203pub fn grad<'a, F, Args, Output, Err>(f: F) -> impl FnMut(Args) -> Result<Output> + 'a
209where
210 F: IntoGrad<'a, Args, Output, Err>,
211{
212 f.into_grad(None)
213}
214
215pub fn grad_with_argnums<'a, F, Args, Output, Err>(
220 f: F,
221 argnums: impl IntoOption<&'a [i32]>,
222) -> impl FnMut(Args) -> Result<Output> + 'a
223where
224 F: IntoGrad<'a, Args, Output, Err>,
225{
226 f.into_grad(argnums)
227}
228
#[cfg(test)]
mod tests {
    use crate::{
        transforms::{grad, grad_with_argnums, value_and_grad, value_and_grad_with_argnums},
        Array,
    };

    /// Differentiates `f(x) = x + 1` twice at `x = 1`.
    ///
    /// The first derivative must be 1 and the second 0. Both the explicit
    /// `argnums` path and the default path are checked.
    #[test]
    fn test_grad() {
        let inputs = &[Array::from_f32(1.0)];
        let add_one = |args: &[Array]| -> Vec<Array> { vec![&args[0] + 1.0] };
        let first_arg = &[0];

        // Explicit argnums.
        let derivative = move |args: &[Array]| -> Vec<Array> {
            grad_with_argnums(add_one, first_arg)(args).unwrap()
        };
        let (dfdx, d2fdx2) = value_and_grad_with_argnums(derivative, first_arg)(inputs).unwrap();
        assert_eq!(dfdx[0].item::<f32>(), 1.0);
        assert_eq!(d2fdx2[0].item::<f32>(), 0.0);

        // Default argnums.
        let derivative = move |args: &[Array]| -> Vec<Array> { grad(add_one)(args).unwrap() };
        let (dfdx, d2fdx2) = value_and_grad(derivative)(inputs).unwrap();
        assert_eq!(dfdx[0].item::<f32>(), 1.0);
        assert_eq!(d2fdx2[0].item::<f32>(), 0.0);
    }
}