mlx_rs/transforms/
grad.rs

1use crate::{
2    error::{Exception, Result},
3    utils::{guard::Guarded, Closure, IntoOption},
4    Array,
5};
6
7use super::{value_and_gradient, ClosureValueAndGrad};
8
9#[inline]
10fn build_gradient_inner<'a>(
11    closure: Closure<'a>,
12    argnums: &'a [i32],
13) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a {
14    move |arrays: &[Array]| -> Result<Vec<Array>> {
15        let cvg = ClosureValueAndGrad::try_from_op(|res| unsafe {
16            mlx_sys::mlx_value_and_grad(res, closure.as_ptr(), argnums.as_ptr(), argnums.len())
17        })?;
18        let result = value_and_gradient(cvg.as_ptr(), arrays.iter())?;
19        Ok(result.1)
20    }
21}
22
23fn build_gradient<'a, F>(
24    f: F,
25    argnums: &'a [i32],
26) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a
27where
28    F: FnMut(&[Array]) -> Vec<Array> + 'a,
29{
30    let argnums = argnums.into_option().unwrap_or(&[0]);
31    let closure = Closure::new(f);
32    build_gradient_inner(closure, argnums)
33}
34
35fn build_fallible_gradient<'a, F>(
36    f: F,
37    argnums: &'a [i32],
38) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a
39where
40    F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
41{
42    let closure = Closure::new_fallible(f);
43    build_gradient_inner(closure, argnums)
44}
45
46/// Trait for functions/closures that can be converted into a closure that computes the gradient.
47pub trait IntoGrad<'a, Args, Output, Err> {
48    /// Convert the function/closure into a closure that computes the gradient.
49    fn into_grad(
50        self,
51        argnums: impl IntoOption<&'a [i32]>,
52    ) -> impl FnMut(Args) -> Result<Output> + 'a;
53}
54
55impl<'a, F> IntoGrad<'a, &[Array], Vec<Array>, ()> for F
56where
57    F: FnMut(&[Array]) -> Vec<Array> + 'a,
58{
59    // refining_impl_trait is fine here because we have restricted the Args and Output types
60    // in the generics.
61    #[allow(refining_impl_trait)]
62    fn into_grad(
63        self,
64        argnums: impl IntoOption<&'a [i32]>,
65    ) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a {
66        let argnums = argnums.into_option().unwrap_or(&[0]);
67        build_gradient(self, argnums)
68    }
69}
70
71impl<'a, F> IntoGrad<'a, &[Array], Vec<Array>, Exception> for F
72where
73    F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
74{
75    #[allow(refining_impl_trait)]
76    fn into_grad(
77        self,
78        argnums: impl IntoOption<&'a [i32]>,
79    ) -> impl FnMut(&[Array]) -> Result<Vec<Array>> + 'a {
80        let argnums = argnums.into_option().unwrap_or(&[0]);
81        build_fallible_gradient(self, argnums)
82    }
83}
84
85impl<'a, F> IntoGrad<'a, &Array, Array, ()> for F
86where
87    F: FnMut(&Array) -> Array + 'a,
88{
89    #[allow(refining_impl_trait)]
90    fn into_grad(
91        mut self,
92        argnums: impl IntoOption<&'a [i32]>,
93    ) -> impl FnMut(&Array) -> Result<Array> + 'a {
94        let f = move |args: &[Array]| -> Vec<Array> { vec![self(&args[0])] };
95        let argnums = argnums.into_option().unwrap_or(&[0]);
96        let mut g = build_gradient(f, argnums);
97        move |args: &Array| -> Result<Array> {
98            let args_clone = &[args.clone()];
99            let result = g(args_clone)?;
100            Ok(result.into_iter().next().unwrap())
101        }
102    }
103}
104
105impl<'a, F> IntoGrad<'a, &Array, Array, Exception> for F
106where
107    F: FnMut(&Array) -> Result<Array> + 'a,
108{
109    #[allow(refining_impl_trait)]
110    fn into_grad(
111        mut self,
112        argnums: impl IntoOption<&'a [i32]>,
113    ) -> impl FnMut(&Array) -> Result<Array> + 'a {
114        let f = move |args: &[Array]| -> Result<Vec<Array>> { self(&args[0]).map(|res| vec![res]) };
115        let argnums = argnums.into_option().unwrap_or(&[0]);
116        let mut g = build_fallible_gradient(f, argnums);
117        move |args: &Array| -> Result<Array> {
118            let args_clone = &[args.clone()];
119            let result = g(args_clone)?;
120            Ok(result.into_iter().next().unwrap())
121        }
122    }
123}
124
125impl<'a, F> IntoGrad<'a, &[Array], Array, ()> for F
126where
127    F: FnMut(&[Array]) -> Array + 'a,
128{
129    #[allow(refining_impl_trait)]
130    fn into_grad(
131        mut self,
132        argnums: impl IntoOption<&'a [i32]>,
133    ) -> impl FnMut(&[Array]) -> Result<Array> + 'a {
134        let f = move |args: &[Array]| -> Vec<Array> { vec![self(args)] };
135        let argnums = argnums.into_option().unwrap_or(&[0]);
136        let mut g = build_gradient(f, argnums);
137        move |args: &[Array]| -> Result<Array> {
138            let result = g(args)?;
139            Ok(result.into_iter().next().unwrap())
140        }
141    }
142}
143
144impl<'a, F> IntoGrad<'a, &[Array], Array, Exception> for F
145where
146    F: FnMut(&[Array]) -> Result<Array> + 'a,
147{
148    #[allow(refining_impl_trait)]
149    fn into_grad(
150        mut self,
151        argnums: impl IntoOption<&'a [i32]>,
152    ) -> impl FnMut(&[Array]) -> Result<Array> + 'a {
153        let f = move |args: &[Array]| -> Result<Vec<Array>> { self(args).map(|res| vec![res]) };
154        let argnums = argnums.into_option().unwrap_or(&[0]);
155        let mut g = build_fallible_gradient(f, argnums);
156        move |args: &[Array]| -> Result<Array> {
157            let result = g(args)?;
158            Ok(result.into_iter().next().unwrap())
159        }
160    }
161}
162
163impl<'a, F> IntoGrad<'a, &Array, Vec<Array>, ()> for F
164where
165    F: FnMut(&Array) -> Vec<Array> + 'a,
166{
167    #[allow(refining_impl_trait)]
168    fn into_grad(
169        mut self,
170        argnums: impl IntoOption<&'a [i32]>,
171    ) -> impl FnMut(&Array) -> Result<Vec<Array>> + 'a {
172        let f = move |args: &[Array]| -> Vec<Array> { self(&args[0]) };
173        let argnums = argnums.into_option().unwrap_or(&[0]);
174        let mut g = build_gradient(f, argnums);
175        move |args: &Array| -> Result<Vec<Array>> {
176            let args_clone = &[args.clone()];
177            let result = g(args_clone)?;
178            Ok(result)
179        }
180    }
181}
182
183impl<'a, F> IntoGrad<'a, &Array, Vec<Array>, Exception> for F
184where
185    F: FnMut(&Array) -> Result<Vec<Array>> + 'a,
186{
187    #[allow(refining_impl_trait)]
188    fn into_grad(
189        mut self,
190        argnums: impl IntoOption<&'a [i32]>,
191    ) -> impl FnMut(&Array) -> Result<Vec<Array>> + 'a {
192        let f = move |args: &[Array]| -> Result<Vec<Array>> { self(&args[0]) };
193        let argnums = argnums.into_option().unwrap_or(&[0]);
194        let mut g = build_fallible_gradient(f, argnums);
195        move |args: &Array| -> Result<Vec<Array>> {
196            let args_clone = &[args.clone()];
197            let result = g(args_clone)?;
198            Ok(result)
199        }
200    }
201}
202
203/// Returns a function which computes the gradient of `f` with the default
204/// argument numbers `&[0]`.
205///
206/// See also [`grad_with_arg_nums`] for a version that allows specifying the
207/// argument numbers
208pub fn grad<'a, F, Args, Output, Err>(f: F) -> impl FnMut(Args) -> Result<Output> + 'a
209where
210    F: IntoGrad<'a, Args, Output, Err>,
211{
212    f.into_grad(None)
213}
214
215/// Returns a function which computes the gradient of `f`.
216///
217/// See also [`grad`] for a version that uses the default argument numbers
218/// `&[0]`.
219pub fn grad_with_argnums<'a, F, Args, Output, Err>(
220    f: F,
221    argnums: impl IntoOption<&'a [i32]>,
222) -> impl FnMut(Args) -> Result<Output> + 'a
223where
224    F: IntoGrad<'a, Args, Output, Err>,
225{
226    f.into_grad(argnums)
227}
228
#[cfg(test)]
mod tests {

    use crate::{
        transforms::{grad, grad_with_argnums, value_and_grad, value_and_grad_with_argnums},
        Array,
    };

    // Adapted from the unit tests of the mlx c++ codebase.
    #[test]
    fn test_grad() {
        let inputs = &[Array::from_f32(1.0)];
        let add_one = |argin: &[Array]| -> Vec<Array> { vec![&argin[0] + 1.0] };
        let wrt = &[0];

        // Explicit argument numbers: take the gradient of `add_one`, then the value and
        // gradient of that gradient function.
        // TODO: how to make this more "functional"?
        let first_order =
            move |args: &[Array]| -> Vec<Array> { grad_with_argnums(add_one, wrt)(args).unwrap() };
        let (dfdx, d2fdx2) = value_and_grad_with_argnums(first_order, wrt)(inputs).unwrap();

        assert_eq!(dfdx[0].item::<f32>(), 1.0);
        assert_eq!(d2fdx2[0].item::<f32>(), 0.0);

        // Same check through the variants that use the default argument numbers.
        let first_order = move |args: &[Array]| -> Vec<Array> { grad(add_one)(args).unwrap() };
        let (dfdx, d2fdx2) = value_and_grad(first_order)(inputs).unwrap();

        assert_eq!(dfdx[0].item::<f32>(), 1.0);
        assert_eq!(d2fdx2[0].item::<f32>(), 0.0);
    }
}