1use mlx_sys::mlx_closure_value_and_grad;
49
50use crate::{
51 error::{get_and_clear_closure_error, Result},
52 module::ModuleParamRef,
53 utils::{guard::Guarded, Closure, VectorArray},
54 Array,
55};
56
57pub mod compile;
58mod grad;
59mod keyed_value_and_grad;
60mod value_and_grad;
61
62pub use grad::*;
63pub use keyed_value_and_grad::*;
64pub use value_and_grad::*;
65
66pub fn eval<'a>(outputs: impl IntoIterator<Item = &'a Array>) -> Result<()> {
68 let vec = VectorArray::try_from_iter(outputs.into_iter())?;
69 <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_eval(vec.as_ptr()) })
70}
71
72pub fn eval_params(params: ModuleParamRef<'_>) -> Result<()> {
76 eval(params.flatten().values().copied())
77}
78
79pub fn async_eval<'a>(outputs: impl IntoIterator<Item = &'a Array>) -> Result<()> {
83 let vec = VectorArray::try_from_iter(outputs.into_iter())?;
84 <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_async_eval(vec.as_ptr()) })
85}
86
87pub fn async_eval_params(params: ModuleParamRef<'_>) -> Result<()> {
91 async_eval(params.flatten().values().copied())
92}
93
94#[inline]
95fn jvp_inner(
96 closure: Closure<'_>,
97 primals: &[Array],
98 tangents: &[Array],
99) -> Result<(Vec<Array>, Vec<Array>)> {
100 let c_primals = VectorArray::try_from_iter(primals.iter())?;
101 let c_tangents = VectorArray::try_from_iter(tangents.iter())?;
102
103 <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
104 mlx_sys::mlx_jvp(
105 res_0,
106 res_1,
107 closure.as_ptr(),
108 c_primals.as_ptr(),
109 c_tangents.as_ptr(),
110 )
111 })
112 .map_err(|e| match get_and_clear_closure_error() {
113 Some(err) => err,
114 None => e,
115 })
116}
117
118pub fn jvp<'a, F>(f: F, primals: &[Array], tangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
137where
138 F: FnMut(&[Array]) -> Vec<Array> + 'a,
139{
140 let closure = Closure::new(f);
141 jvp_inner(closure, primals, tangents)
142}
143
144pub fn fallible_jvp<'a, F>(
146 f: F,
147 primals: &[Array],
148 tangents: &[Array],
149) -> Result<(Vec<Array>, Vec<Array>)>
150where
151 F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
152{
153 let closure = Closure::new_fallible(f);
154 jvp_inner(closure, primals, tangents)
155}
156
157#[inline]
158fn vjp_inner(
159 closure: Closure<'_>,
160 primals: &[Array],
161 cotangents: &[Array],
162) -> Result<(Vec<Array>, Vec<Array>)> {
163 let c_primals = VectorArray::try_from_iter(primals.iter())?;
164 let c_cotangents = VectorArray::try_from_iter(cotangents.iter())?;
165
166 <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
167 mlx_sys::mlx_vjp(
168 res_0,
169 res_1,
170 closure.as_ptr(),
171 c_primals.as_ptr(),
172 c_cotangents.as_ptr(),
173 )
174 })
175 .map_err(|e| match get_and_clear_closure_error() {
176 Some(err) => err,
177 None => e,
178 })
179}
180
181pub fn vjp<'a, F>(f: F, primals: &[Array], cotangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
198where
199 F: FnMut(&[Array]) -> Vec<Array> + 'a,
200{
201 let closure = Closure::new(f);
202 vjp_inner(closure, primals, cotangents)
203}
204
205pub fn fallible_vjp<'a, F>(
207 f: F,
208 primals: &[Array],
209 cotangents: &[Array],
210) -> Result<(Vec<Array>, Vec<Array>)>
211where
212 F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
213{
214 let closure = Closure::new_fallible(f);
215 vjp_inner(closure, primals, cotangents)
216}
217
218pub(crate) struct ClosureValueAndGrad {
219 pub(crate) c_closure_value_and_grad: mlx_closure_value_and_grad,
220}
221
222impl ClosureValueAndGrad {
223 pub fn as_ptr(&self) -> mlx_closure_value_and_grad {
224 self.c_closure_value_and_grad
225 }
226}
227
228fn value_and_gradient(
229 value_and_grad: mlx_closure_value_and_grad,
230 arrays: impl Iterator<Item = impl AsRef<Array>>,
231) -> Result<(Vec<Array>, Vec<Array>)> {
232 let input_vector = VectorArray::try_from_iter(arrays)?;
233
234 <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
235 mlx_sys::mlx_closure_value_and_grad_apply(
236 res_0,
237 res_1,
238 value_and_grad,
239 input_vector.as_ptr(),
240 )
241 })
242 .map_err(|e| match get_and_clear_closure_error() {
243 Some(err) => err,
244 None => e,
245 })
246}
247
#[cfg(test)]
mod tests {
    // `super::*` already brings in `jvp`, `vjp`, their fallible variants,
    // `Array` and the crate `Result`; only the crate-root macro is imported.
    use super::*;
    use crate::array;

    #[test]
    fn test_jvp() {
        let sum = |inputs: &[Array]| -> Vec<Array> { vec![&inputs[0] + &inputs[1]] };
        let primals = [array!(1.0f32), array!(1.0f32)];
        let tangents = [array!(1.0f32), array!(3.0f32)];

        let (out, dout) = jvp(sum, &primals, &tangents).unwrap();

        assert_eq!(out[0].item::<f32>(), 2.0f32);
        assert_eq!(dout[0].item::<f32>(), 4.0f32);
    }

    #[test]
    fn test_jvp_with_error() {
        let checked_sum = |inputs: &[Array]| -> Result<Vec<Array>> {
            inputs[0].add(&inputs[1]).map(|res| vec![res])
        };

        // Compatible scalar inputs succeed.
        let primals = [array!(1.0f32), array!(1.0f32)];
        let tangents = [array!(1.0f32), array!(3.0f32)];
        let (out, dout) = fallible_jvp(checked_sum, &primals, &tangents).unwrap();
        assert_eq!(out[0].item::<f32>(), 2.0f32);
        assert_eq!(dout[0].item::<f32>(), 4.0f32);

        // Shapes [3] and [2] cannot be added, so the closure fails.
        let mismatched = [array!([1.0, 2.0, 3.0]), array!([4.0, 5.0])];
        let result = fallible_jvp(
            checked_sum,
            &mismatched,
            &[array!(1.0f32), array!(3.0f32)],
        );
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(!err.what().contains("non-zero value"))
    }

    #[test]
    fn test_vjp() {
        let sum = |inputs: &[Array]| -> Vec<Array> { vec![&inputs[0] + &inputs[1]] };
        let primals = [array!(1.0f32), array!(1.0f32)];
        let cotangents = [array!(1.0f32)];

        let (out, dout) = vjp(sum, &primals, &cotangents).unwrap();

        assert_eq!(out[0].item::<f32>(), 2.0f32);
        assert_eq!(dout[0].item::<f32>(), 1.0f32);
    }

    #[test]
    fn test_vjp_with_error() {
        let checked_sum = |inputs: &[Array]| -> Result<Vec<Array>> {
            inputs[0].add(&inputs[1]).map(|res| vec![res])
        };

        // Compatible scalar inputs succeed.
        let primals = [array!(1.0f32), array!(1.0f32)];
        let cotangents = [array!(1.0f32)];
        let (out, dout) = fallible_vjp(checked_sum, &primals, &cotangents).unwrap();
        assert_eq!(out[0].item::<f32>(), 2.0f32);
        assert_eq!(dout[0].item::<f32>(), 1.0f32);

        // Shapes [3] and [2] cannot be added, so the closure fails.
        let mismatched = [array!([1.0, 2.0, 3.0]), array!([4.0, 5.0])];
        let result = fallible_vjp(checked_sum, &mismatched, &[array!(1.0f32)]);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(!err.what().contains("non-zero value"))
    }
}