mlx_rs/transforms/
value_and_grad.rs

1use crate::{
2    error::{Exception, Result},
3    utils::{guard::Guarded, Closure, IntoOption},
4    Array,
5};
6
7use super::{value_and_gradient, ClosureValueAndGrad};
8
9fn build_value_and_gradient_inner<'a>(
10    closure: Closure<'a>,
11    argnums: &'a [i32],
12) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a {
13    move |arrays: &[Array]| unsafe {
14        let cvg = ClosureValueAndGrad::try_from_op(|res| {
15            mlx_sys::mlx_value_and_grad(res, closure.as_ptr(), argnums.as_ptr(), argnums.len())
16        })?;
17        value_and_gradient(cvg.as_ptr(), arrays.iter())
18    }
19}
20
21fn build_value_and_gradient<'a, F>(
22    f: F,
23    argnums: &'a [i32],
24) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a
25where
26    F: FnMut(&[Array]) -> Vec<Array> + 'a,
27{
28    let closure = Closure::new(f);
29    build_value_and_gradient_inner(closure, argnums)
30}
31
32fn build_fallible_value_and_gradient<'a, F>(
33    f: F,
34    argnums: &'a [i32],
35) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a
36where
37    F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
38{
39    let closure = Closure::new_fallible(f);
40    build_value_and_gradient_inner(closure, argnums)
41}
42
43/// Trait for functions/closures that can be converted into a closure that computes the value and
44/// gradient.
45pub trait IntoValueAndGrad<'a, Err> {
46    /// Convert the function/closure into a closure that computes the value and gradient.
47    fn into_value_and_grad(
48        self,
49        argnums: impl IntoOption<&'a [i32]>,
50    ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a;
51}
52
53impl<'a, F> IntoValueAndGrad<'a, ()> for F
54where
55    F: FnMut(&[Array]) -> Vec<Array> + 'a,
56{
57    // refining_impl_trait is fine here because we have restricted the Args and Output types
58    // in the generics.
59    #[allow(refining_impl_trait)]
60    fn into_value_and_grad(
61        self,
62        argnums: impl IntoOption<&'a [i32]>,
63    ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a {
64        let argnums = argnums.into_option().unwrap_or(&[0]);
65        build_value_and_gradient(self, argnums)
66    }
67}
68
69impl<'a, F> IntoValueAndGrad<'a, Exception> for F
70where
71    F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
72{
73    #[allow(refining_impl_trait)]
74    fn into_value_and_grad(
75        self,
76        argnums: impl IntoOption<&'a [i32]>,
77    ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a {
78        let argnums = argnums.into_option().unwrap_or(&[0]);
79        build_fallible_value_and_gradient(self, argnums)
80    }
81}
82
83/// Returns a function which computes the value and gradient of `f` with a
84/// default argument number `&[0]`.
85///
86/// See also [`value_and_grad_with_arg_nums`] for a version that allows
87/// specifying the argument numbers
88pub fn value_and_grad<'a, F, Err>(
89    f: F,
90) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a
91where
92    F: IntoValueAndGrad<'a, Err> + 'a,
93{
94    f.into_value_and_grad(None)
95}
96
97/// Returns a function which computes the value and gradient of `f`.
98///
99/// See also [`value_and_grad`] for a version that uses the default argument
100/// numbers `&[0]`.
101pub fn value_and_grad_with_argnums<'a, F, Err>(
102    f: F,
103    argnums: impl IntoOption<&'a [i32]>,
104) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a
105where
106    F: IntoValueAndGrad<'a, Err> + 'a,
107{
108    f.into_value_and_grad(argnums)
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{array, transforms::value_and_grad, Array};

    // These unit tests are adapted from the mlx C++ codebase.

    /// `x + 1` evaluated at `x = 1` must give value 2 and gradient 1, with explicit and with
    /// default argument numbers.
    #[test]
    fn test_value_and_grad() {
        let add_one = |inputs: &[Array]| -> Vec<Array> { vec![&inputs[0] + 1.0] };
        let point = &[Array::from_f32(1.0)];

        // Explicit argument numbers.
        let argnums = &[0];
        let mut explicit = value_and_grad_with_argnums(add_one, argnums);
        let (value, grad) = explicit(point).unwrap();
        assert_eq!(value[0].item::<f32>(), 2.0);
        assert_eq!(grad[0].item::<f32>(), 1.0);

        // Default argument numbers.
        let mut defaulted = value_and_grad(add_one);
        let (value, grad) = defaulted(point).unwrap();
        assert_eq!(value[0].item::<f32>(), 2.0);
        assert_eq!(grad[0].item::<f32>(), 1.0);
    }

    /// A fallible closure must work on valid inputs and surface a meaningful error otherwise.
    #[test]
    fn test_value_and_grad_with_error() {
        // NOTE(review): this closure only reads `argin[0]`; the second array passed below is
        // never combined with the first. The failure in the error case therefore presumably
        // does not originate inside the closure -- confirm this is the intended scenario.
        let fallible_add = |argin: &[Array]| -> Result<Vec<Array>> {
            argin[0].add(array!(1.0)).map(|res| vec![res])
        };
        let argnums = &[0];

        // Success case: scalar inputs.
        let lhs = array!(1.0f32);
        let rhs = array!(1.0f32);
        let scalars = &[lhs, rhs];
        let with_argnums = value_and_grad_with_argnums(fallible_add, argnums)(scalars);
        assert!(with_argnums.is_ok());
        let with_default = value_and_grad(fallible_add)(scalars);
        assert!(with_default.is_ok());

        // Error case: inputs whose shapes cannot be broadcast against each other.
        let three = array!([1.0, 2.0, 3.0]);
        let two = array!([4.0, 5.0]);
        let mismatched = &[three, two];
        let with_argnums = value_and_grad_with_argnums(fallible_add, argnums)(mismatched);
        assert!(with_argnums.is_err());
        let with_default = value_and_grad(fallible_add)(mismatched);
        assert!(with_default.is_err());

        // The error must be more specific than "mlx_closure returned a non-zero value".
        let failure = with_default.unwrap_err();
        assert!(!failure.what().contains("non-zero value"));
    }
}