mlx_rs/transforms/
value_and_grad.rs1use crate::{
2 error::{Exception, Result},
3 utils::{guard::Guarded, Closure, IntoOption},
4 Array,
5};
6
7use super::{value_and_gradient, ClosureValueAndGrad};
8
9fn build_value_and_gradient_inner<'a>(
10 closure: Closure<'a>,
11 argnums: &'a [i32],
12) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a {
13 move |arrays: &[Array]| unsafe {
14 let cvg = ClosureValueAndGrad::try_from_op(|res| {
15 mlx_sys::mlx_value_and_grad(res, closure.as_ptr(), argnums.as_ptr(), argnums.len())
16 })?;
17 value_and_gradient(cvg.as_ptr(), arrays.iter())
18 }
19}
20
21fn build_value_and_gradient<'a, F>(
22 f: F,
23 argnums: &'a [i32],
24) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a
25where
26 F: FnMut(&[Array]) -> Vec<Array> + 'a,
27{
28 let closure = Closure::new(f);
29 build_value_and_gradient_inner(closure, argnums)
30}
31
32fn build_fallible_value_and_gradient<'a, F>(
33 f: F,
34 argnums: &'a [i32],
35) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a
36where
37 F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
38{
39 let closure = Closure::new_fallible(f);
40 build_value_and_gradient_inner(closure, argnums)
41}
42
43pub trait IntoValueAndGrad<'a, Err> {
46 fn into_value_and_grad(
48 self,
49 argnums: impl IntoOption<&'a [i32]>,
50 ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a;
51}
52
53impl<'a, F> IntoValueAndGrad<'a, ()> for F
54where
55 F: FnMut(&[Array]) -> Vec<Array> + 'a,
56{
57 #[allow(refining_impl_trait)]
60 fn into_value_and_grad(
61 self,
62 argnums: impl IntoOption<&'a [i32]>,
63 ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a {
64 let argnums = argnums.into_option().unwrap_or(&[0]);
65 build_value_and_gradient(self, argnums)
66 }
67}
68
69impl<'a, F> IntoValueAndGrad<'a, Exception> for F
70where
71 F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
72{
73 #[allow(refining_impl_trait)]
74 fn into_value_and_grad(
75 self,
76 argnums: impl IntoOption<&'a [i32]>,
77 ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a {
78 let argnums = argnums.into_option().unwrap_or(&[0]);
79 build_fallible_value_and_gradient(self, argnums)
80 }
81}
82
83pub fn value_and_grad<'a, F, Err>(
89 f: F,
90) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a
91where
92 F: IntoValueAndGrad<'a, Err> + 'a,
93{
94 f.into_value_and_grad(None)
95}
96
97pub fn value_and_grad_with_argnums<'a, F, Err>(
102 f: F,
103 argnums: impl IntoOption<&'a [i32]>,
104) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a
105where
106 F: IntoValueAndGrad<'a, Err> + 'a,
107{
108 f.into_value_and_grad(argnums)
109}
110
#[cfg(test)]
mod tests {

    use crate::{array, transforms::value_and_grad, Array};

    use super::*;

    /// `x + 1` evaluated at `x = 1` must give value 2 and gradient 1, both with
    /// explicit argnums and with the default.
    #[test]
    fn test_value_and_grad() {
        let inputs = &[Array::from_f32(1.0)];
        let add_one = |args: &[Array]| -> Vec<Array> { vec![&args[0] + 1.0] };
        let explicit_argnums = &[0];

        // Both entry points must produce the same (value, gradient) pair.
        let check = |(values, grads): (Vec<Array>, Vec<Array>)| {
            assert_eq!(values[0].item::<f32>(), 2.0);
            assert_eq!(grads[0].item::<f32>(), 1.0);
        };
        check(value_and_grad_with_argnums(add_one, explicit_argnums)(inputs).unwrap());
        check(value_and_grad(add_one)(inputs).unwrap());
    }

    /// A fallible function succeeds on scalar inputs and reports an error on
    /// the vector inputs below, through both entry points.
    #[test]
    fn test_value_and_grad_with_error() {
        let add_one = |args: &[Array]| -> Result<Vec<Array>> {
            args[0].add(array!(1.0)).map(|sum| vec![sum])
        };
        let explicit_argnums = &[0];

        // Scalar inputs: expected to succeed.
        let scalar_args = &[array!(1.0f32), array!(1.0f32)];
        assert!(value_and_grad_with_argnums(add_one, explicit_argnums)(scalar_args).is_ok());
        assert!(value_and_grad(add_one)(scalar_args).is_ok());

        // Vector inputs: expected to fail.
        let vector_args = &[array!([1.0, 2.0, 3.0]), array!([4.0, 5.0])];
        assert!(value_and_grad_with_argnums(add_one, explicit_argnums)(vector_args).is_err());
        let err = value_and_grad(add_one)(vector_args).unwrap_err();
        // The surfaced message must not be the generic "non-zero value" text.
        // NOTE(review): presumably this guards against losing the real error.
        assert!(!err.what().contains("non-zero value"))
    }
}