use mlx_sys::mlx_closure_value_and_grad;
use crate::{
error::{get_and_clear_closure_error, Result},
module::ModuleParamRef,
utils::{guard::Guarded, Closure, VectorArray},
Array,
};
pub mod compile;
mod grad;
mod keyed_value_and_grad;
mod value_and_grad;
pub use grad::*;
pub use keyed_value_and_grad::*;
pub use value_and_grad::*;
/// Evaluates every array yielded by `outputs` by handing them to `mlx_sys::mlx_eval`.
///
/// # Errors
///
/// Returns an error if the arrays cannot be packed into a C vector, or if the
/// underlying `mlx_eval` call reports a failure.
pub fn eval<'a>(outputs: impl IntoIterator<Item = &'a Array>) -> Result<()> {
    let arrays = outputs.into_iter();
    let c_vec = VectorArray::try_from_iter(arrays)?;
    // SAFETY: `c_vec` is alive for the whole call, so the handle it exposes stays valid.
    <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_eval(c_vec.as_ptr()) })
}
/// Flattens `params` and evaluates each resulting array with [`eval`].
///
/// # Errors
///
/// Propagates any error returned by [`eval`].
pub fn eval_params(params: ModuleParamRef<'_>) -> Result<()> {
    let flat = params.flatten();
    eval(flat.values().copied())
}
/// Submits every array yielded by `outputs` to `mlx_sys::mlx_async_eval`.
///
/// NOTE(review): presumably this schedules evaluation without waiting for it to
/// finish (the asynchronous counterpart of [`eval`]) — confirm against the MLX docs.
///
/// # Errors
///
/// Returns an error if the arrays cannot be packed into a C vector, or if the
/// underlying `mlx_async_eval` call reports a failure.
pub fn async_eval<'a>(outputs: impl IntoIterator<Item = &'a Array>) -> Result<()> {
    let arrays = outputs.into_iter();
    let c_vec = VectorArray::try_from_iter(arrays)?;
    // SAFETY: `c_vec` is alive for the whole call, so the handle it exposes stays valid.
    <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_async_eval(c_vec.as_ptr()) })
}
/// Flattens `params` and submits each resulting array to [`async_eval`].
///
/// # Errors
///
/// Propagates any error returned by [`async_eval`].
pub fn async_eval_params(params: ModuleParamRef<'_>) -> Result<()> {
    let flat = params.flatten();
    async_eval(flat.values().copied())
}
/// Shared implementation behind [`jvp`] and [`fallible_jvp`].
///
/// Packs `primals` and `tangents` into C vectors and calls `mlx_sys::mlx_jvp`
/// with the wrapped `closure`, collecting its two output vectors.
///
/// # Errors
///
/// Returns an error if either input cannot be packed into a C vector. If the
/// C call fails, an error recorded by the closure (via
/// `get_and_clear_closure_error`) is returned in preference to the C-side error.
#[inline]
fn jvp_inner(
    closure: Closure<'_>,
    primals: &[Array],
    tangents: &[Array],
) -> Result<(Vec<Array>, Vec<Array>)> {
    let c_primals = VectorArray::try_from_iter(primals.iter())?;
    let c_tangents = VectorArray::try_from_iter(tangents.iter())?;
    // SAFETY: `closure`, `c_primals` and `c_tangents` all outlive this call, so the
    // raw handles passed to `mlx_jvp` remain valid for its duration.
    <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
        mlx_sys::mlx_jvp(
            res_0,
            res_1,
            closure.as_ptr(),
            c_primals.as_ptr(),
            c_tangents.as_ptr(),
        )
    })
    // Prefer the error raised inside the Rust closure (if any) over the C-side one.
    .map_err(|e| get_and_clear_closure_error().unwrap_or(e))
}
/// Computes the Jacobian-vector product of `f` at `primals` along `tangents`.
///
/// Returns `(outputs, output_tangents)`: the values of `f(primals)` and the
/// directional derivatives of those outputs along `tangents`. For example, for
/// `f(x, y) = x + y` with tangents `(1, 3)`, the output tangent is `4`.
///
/// # Errors
///
/// Returns an error if the inputs cannot be marshalled or the MLX call fails.
pub fn jvp<'a, F>(f: F, primals: &[Array], tangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
where
    F: FnMut(&[Array]) -> Vec<Array> + 'a,
{
    jvp_inner(Closure::new(f), primals, tangents)
}
/// Like [`jvp`], but `f` may itself fail by returning an `Err`.
///
/// # Errors
///
/// Returns an error if the inputs cannot be marshalled or the MLX call fails,
/// including when `f` returns an error while being evaluated.
pub fn fallible_jvp<'a, F>(
    f: F,
    primals: &[Array],
    tangents: &[Array],
) -> Result<(Vec<Array>, Vec<Array>)>
where
    F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
{
    jvp_inner(Closure::new_fallible(f), primals, tangents)
}
/// Shared implementation behind [`vjp`] and [`fallible_vjp`].
///
/// Packs `primals` and `cotangents` into C vectors and calls `mlx_sys::mlx_vjp`
/// with the wrapped `closure`, collecting its two output vectors.
///
/// # Errors
///
/// Returns an error if either input cannot be packed into a C vector. If the
/// C call fails, an error recorded by the closure (via
/// `get_and_clear_closure_error`) is returned in preference to the C-side error.
#[inline]
fn vjp_inner(
    closure: Closure<'_>,
    primals: &[Array],
    cotangents: &[Array],
) -> Result<(Vec<Array>, Vec<Array>)> {
    let c_primals = VectorArray::try_from_iter(primals.iter())?;
    let c_cotangents = VectorArray::try_from_iter(cotangents.iter())?;
    // SAFETY: `closure`, `c_primals` and `c_cotangents` all outlive this call, so the
    // raw handles passed to `mlx_vjp` remain valid for its duration.
    <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
        mlx_sys::mlx_vjp(
            res_0,
            res_1,
            closure.as_ptr(),
            c_primals.as_ptr(),
            c_cotangents.as_ptr(),
        )
    })
    // Prefer the error raised inside the Rust closure (if any) over the C-side one.
    .map_err(|e| get_and_clear_closure_error().unwrap_or(e))
}
/// Computes the vector-Jacobian product of `f` at `primals` with `cotangents`.
///
/// Returns `(outputs, vjps)`: the values of `f(primals)` and the products of
/// the cotangents with the Jacobian of `f`. For example, for `f(x, y) = x + y`
/// and cotangent `1`, the vjp with respect to `x` is `1`.
///
/// # Errors
///
/// Returns an error if the inputs cannot be marshalled or the MLX call fails.
pub fn vjp<'a, F>(f: F, primals: &[Array], cotangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
where
    F: FnMut(&[Array]) -> Vec<Array> + 'a,
{
    vjp_inner(Closure::new(f), primals, cotangents)
}
/// Like [`vjp`], but `f` may itself fail by returning an `Err`.
///
/// # Errors
///
/// Returns an error if the inputs cannot be marshalled or the MLX call fails,
/// including when `f` returns an error while being evaluated.
pub fn fallible_vjp<'a, F>(
    f: F,
    primals: &[Array],
    cotangents: &[Array],
) -> Result<(Vec<Array>, Vec<Array>)>
where
    F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
{
    vjp_inner(Closure::new_fallible(f), primals, cotangents)
}
/// Crate-internal wrapper around a raw `mlx_closure_value_and_grad` handle.
///
/// NOTE(review): no `Drop` impl is visible in this file; confirm the handle is
/// released elsewhere (e.g. a `Drop` impl in another module).
pub(crate) struct ClosureValueAndGrad {
    /// The raw mlx-c handle.
    pub(crate) c_closure_value_and_grad: mlx_closure_value_and_grad,
}

impl ClosureValueAndGrad {
    /// Returns the raw handle by value (the handle type is `Copy`); `self` keeps its copy.
    pub fn as_ptr(&self) -> mlx_closure_value_and_grad {
        let Self {
            c_closure_value_and_grad,
        } = self;
        *c_closure_value_and_grad
    }
}
/// Applies the value-and-grad closure `value_and_grad` to `arrays`.
///
/// The inputs are packed into a C vector and passed to
/// `mlx_sys::mlx_closure_value_and_grad_apply`; its two output vectors are
/// returned (presumably values first, then gradients — confirm in mlx-c).
///
/// # Errors
///
/// Returns an error if the inputs cannot be packed into a C vector. If the
/// C call fails, an error recorded by the closure (via
/// `get_and_clear_closure_error`) is returned in preference to the C-side error.
fn value_and_gradient(
    value_and_grad: mlx_closure_value_and_grad,
    arrays: impl Iterator<Item = impl AsRef<Array>>,
) -> Result<(Vec<Array>, Vec<Array>)> {
    let input_vector = VectorArray::try_from_iter(arrays)?;
    // SAFETY: `input_vector` outlives this call, so its handle stays valid; the
    // caller supplies `value_and_grad` and is responsible for its validity.
    <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
        mlx_sys::mlx_closure_value_and_grad_apply(
            res_0,
            res_1,
            value_and_grad,
            input_vector.as_ptr(),
        )
    })
    // Prefer the error raised inside the Rust closure (if any) over the C-side one.
    .map_err(|e| get_and_clear_closure_error().unwrap_or(e))
}
#[cfg(test)]
mod tests {
    use crate::{
        array,
        transforms::{jvp, vjp},
        Array,
    };

    use super::*;

    #[test]
    fn test_jvp() {
        let f = |inputs: &[Array]| -> Vec<Array> { vec![&inputs[0] + &inputs[1]] };
        let x = array!(1.0f32);
        let y = array!(1.0f32);
        let (out, dout) = jvp(f, &[x, y], &[array!(1.0f32), array!(3.0f32)]).unwrap();
        // `f` has a single output, so there is one output value and one output tangent.
        assert_eq!(out.len(), 1);
        assert_eq!(dout.len(), 1);
        // f(1, 1) = 2; the tangent of x + y along (1, 3) is 1 + 3 = 4.
        assert_eq!(out[0].item::<f32>(), 2.0f32);
        assert_eq!(dout[0].item::<f32>(), 4.0f32);
    }

    #[test]
    fn test_jvp_with_error() {
        let f = |inputs: &[Array]| -> Result<Vec<Array>> {
            inputs[0].add(&inputs[1]).map(|res| vec![res])
        };

        // Success path: behaves like `jvp`.
        let x = array!(1.0f32);
        let y = array!(1.0f32);
        let (out, dout) = fallible_jvp(f, &[x, y], &[array!(1.0f32), array!(3.0f32)]).unwrap();
        assert_eq!(out[0].item::<f32>(), 2.0f32);
        assert_eq!(dout[0].item::<f32>(), 4.0f32);

        // Failure path: shapes [3] and [2] cannot be added, so `f` itself fails.
        // The tangents deliberately match the primals' shapes, so the error comes
        // from `f` even if jvp validates tangent shapes before calling `f`.
        let a = array!([1.0, 2.0, 3.0]);
        let b = array!([4.0, 5.0]);
        let result = fallible_jvp(
            f,
            &[a, b],
            &[array!([1.0, 1.0, 1.0]), array!([1.0, 1.0])],
        );
        assert!(result.is_err());
        let err = result.unwrap_err();
        // NOTE(review): assumes the generic C-side failure message contains
        // "non-zero value"; its absence shows the closure's own error was surfaced.
        assert!(!err.what().contains("non-zero value"))
    }

    #[test]
    fn test_vjp() {
        let f = |inputs: &[Array]| -> Vec<Array> { vec![&inputs[0] + &inputs[1]] };
        let x = array!(1.0f32);
        let y = array!(1.0f32);
        let primals = vec![x, y];
        let cotangents = vec![array!(1.0f32)];
        let (out, dout) = vjp(f, &primals, &cotangents).unwrap();
        // One output value; one vjp per primal.
        assert_eq!(out.len(), 1);
        assert_eq!(dout.len(), 2);
        assert_eq!(out[0].item::<f32>(), 2.0f32);
        // d(x + y)/dx = d(x + y)/dy = 1, scaled by cotangent 1.
        assert_eq!(dout[0].item::<f32>(), 1.0f32);
        assert_eq!(dout[1].item::<f32>(), 1.0f32);
    }

    #[test]
    fn test_vjp_with_error() {
        let f = |inputs: &[Array]| -> Result<Vec<Array>> {
            inputs[0].add(&inputs[1]).map(|res| vec![res])
        };

        // Success path: behaves like `vjp`.
        let x = array!(1.0f32);
        let y = array!(1.0f32);
        let primals = vec![x, y];
        let cotangents = vec![array!(1.0f32)];
        let (out, dout) = fallible_vjp(f, &primals, &cotangents).unwrap();
        assert_eq!(out[0].item::<f32>(), 2.0f32);
        assert_eq!(dout[0].item::<f32>(), 1.0f32);
        assert_eq!(dout[1].item::<f32>(), 1.0f32);

        // Failure path: shapes [3] and [2] cannot be added, so `f` itself fails.
        let a = array!([1.0, 2.0, 3.0]);
        let b = array!([4.0, 5.0]);
        let result = fallible_vjp(f, &[a, b], &[array!(1.0f32)]);
        assert!(result.is_err());
        let err = result.unwrap_err();
        // NOTE(review): assumes the generic C-side failure message contains
        // "non-zero value"; its absence shows the closure's own error was surfaced.
        assert!(!err.what().contains("non-zero value"))
    }
}