use std::{cell::RefCell, marker::PhantomData, rc::Rc};
use crate::{
error::Exception,
transforms::compile::{type_id_to_usize, CompiledState},
utils::Updatable,
Array,
};
use super::{update_by_replace_with_ref_to_new_array, Closure, Compiled, Guarded, VectorArray};
/// Wraps a stateful function so that every call goes through MLX compilation.
///
/// The returned closure takes the mutable state `U` and the function's arguments and
/// returns the function's output. Passing `None` (or `false`) for `shapeless` disables
/// shapeless compilation.
///
/// A new compiled wrapper is built on each call (`f` is `Copy`). The underlying compile is
/// keyed by an id derived from `f`'s type, so repeated calls presumably reuse MLX's compile
/// cache — confirm against the MLX compile documentation.
///
/// # Errors
///
/// The returned closure yields an `Exception` when compiling or evaluating the graph fails,
/// including (for fallible functions) when `f` itself reports an error.
pub fn compile_with_state<F, U, A, O, E>(
    f: F,
    shapeless: impl Into<Option<bool>>,
) -> impl for<'a> FnMut(&mut U, F::Args<'a>) -> Result<O, Exception>
where
    F: CompileWithState<U, A, O, E> + Copy + 'static,
    U: Updatable,
{
    let shapeless = matches!(shapeless.into(), Some(true));
    move |state, args| f.compile(shapeless).call_mut(state, args)
}
/// Turns a stateful function into a compiled-function object usable by `compile_with_state`.
///
/// Type parameters, as used by the impls in this file:
/// - `U`: the state passed as `&mut U` on every call.
/// - `A`: marker for the argument shape (`&[Array]`, `&Array`, or a tuple of `&Array`).
/// - `O`: the output type (`Vec<Array>` or `Array`).
/// - `E`: marker separating infallible functions (`()`) from fallible ones (`Exception`).
pub trait CompileWithState<U, A, O, E> {
/// The argument type accepted by the compiled function, borrowed for `'a`.
type Args<'a>;
/// Consumes `self` and returns a callable compiled wrapper; `shapeless` is stored in the
/// compiled state and forwarded to the MLX compile call.
fn compile<'args>(self, shapeless: bool) -> impl CallMutWithState<U, Self::Args<'args>, O, E>;
}
/// Compiles an infallible function over a slice of arrays.
///
/// The function already has the slice-in / `Vec`-out shape that `CompiledState` drives,
/// so it is stored directly without an adapter closure.
impl<F, U> CompileWithState<U, &[Array], Vec<Array>, ()> for F
where
    F: FnMut(&mut U, &[Array]) -> Vec<Array> + 'static,
    U: Updatable,
{
    type Args<'a> = &'a [Array];

    fn compile<'args>(
        self,
        shapeless: bool,
    ) -> impl CallMutWithState<U, Self::Args<'args>, Vec<Array>, ()> {
        // The id must be derived before `self` is moved into the compiled state.
        let id = type_id_to_usize(&self);
        Compiled {
            f_marker: PhantomData::<F>,
            state: CompiledState {
                id,
                shapeless,
                f: self,
            },
        }
    }
}
/// Compiles an infallible function of a single array.
impl<F, U> CompileWithState<U, &Array, Array, ()> for F
where
    F: FnMut(&mut U, &Array) -> Array + 'static,
    U: Updatable,
{
    type Args<'a> = &'a Array;

    fn compile<'args>(
        mut self,
        shapeless: bool,
    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
        // Derive the id from the user's function type before it is moved into the adapter.
        let id = type_id_to_usize(&self);
        // Adapter: pass the first slice element to `self` and wrap its output in a `Vec`.
        let adapter = move |state: &mut U, args: &[Array]| -> Vec<Array> {
            vec![(self)(state, &args[0])]
        };
        Compiled {
            f_marker: PhantomData::<F>,
            state: CompiledState {
                id,
                shapeless,
                f: adapter,
            },
        }
    }
}
/// Compiles an infallible function of two arrays.
impl<F, U> CompileWithState<U, (&Array, &Array), Array, ()> for F
where
    F: FnMut(&mut U, (&Array, &Array)) -> Array + 'static,
    U: Updatable,
{
    type Args<'a> = (&'a Array, &'a Array);

    fn compile<'args>(
        mut self,
        shapeless: bool,
    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
        // Derive the id from the user's function type before it is moved into the adapter.
        let id = type_id_to_usize(&self);
        // Adapter: rebuild the pair from the first two slice elements.
        let adapter = move |state: &mut U, args: &[Array]| -> Vec<Array> {
            let pair = (&args[0], &args[1]);
            vec![(self)(state, pair)]
        };
        Compiled {
            f_marker: PhantomData::<F>,
            state: CompiledState {
                id,
                shapeless,
                f: adapter,
            },
        }
    }
}
/// Compiles an infallible function of three arrays.
impl<F, U> CompileWithState<U, (&Array, &Array, &Array), Array, ()> for F
where
    F: FnMut(&mut U, (&Array, &Array, &Array)) -> Array + 'static,
    U: Updatable,
{
    type Args<'a> = (&'a Array, &'a Array, &'a Array);

    fn compile<'args>(
        mut self,
        shapeless: bool,
    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
        // Derive the id from the user's function type before it is moved into the adapter.
        let id = type_id_to_usize(&self);
        // Adapter: rebuild the triple from the first three slice elements.
        let adapter = move |state: &mut U, args: &[Array]| -> Vec<Array> {
            let triple = (&args[0], &args[1], &args[2]);
            vec![(self)(state, triple)]
        };
        Compiled {
            f_marker: PhantomData::<F>,
            state: CompiledState {
                id,
                shapeless,
                f: adapter,
            },
        }
    }
}
/// Compiles a fallible function over a slice of arrays.
///
/// The function already has the slice-in / `Result<Vec>`-out shape that `CompiledState`
/// drives, so it is stored directly without an adapter closure.
impl<F, U> CompileWithState<U, &[Array], Vec<Array>, Exception> for F
where
    F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception> + 'static,
    U: Updatable,
{
    type Args<'a> = &'a [Array];

    fn compile<'args>(
        self,
        shapeless: bool,
    ) -> impl CallMutWithState<U, Self::Args<'args>, Vec<Array>, Exception> {
        // The id must be derived before `self` is moved into the compiled state.
        let id = type_id_to_usize(&self);
        Compiled {
            f_marker: PhantomData::<F>,
            state: CompiledState {
                id,
                shapeless,
                f: self,
            },
        }
    }
}
/// Compiles a fallible function of a single array.
impl<F, U> CompileWithState<U, &Array, Array, Exception> for F
where
    F: FnMut(&mut U, &Array) -> Result<Array, Exception> + 'static,
    U: Updatable,
{
    type Args<'a> = &'a Array;

    fn compile<'args>(
        mut self,
        shapeless: bool,
    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
        // Derive the id from the user's function type before it is moved into the adapter.
        let id = type_id_to_usize(&self);
        // Adapter: pass the first slice element through and wrap a successful output in a `Vec`.
        let adapter = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
            (self)(state, &args[0]).map(|output| vec![output])
        };
        Compiled {
            f_marker: PhantomData::<F>,
            state: CompiledState {
                id,
                shapeless,
                f: adapter,
            },
        }
    }
}
/// Compiles a fallible function of two arrays.
impl<F, U> CompileWithState<U, (&Array, &Array), Array, Exception> for F
where
    F: FnMut(&mut U, (&Array, &Array)) -> Result<Array, Exception> + 'static,
    U: Updatable,
{
    type Args<'a> = (&'a Array, &'a Array);

    fn compile<'args>(
        mut self,
        shapeless: bool,
    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
        // Derive the id from the user's function type before it is moved into the adapter.
        let id = type_id_to_usize(&self);
        // Adapter: rebuild the pair from the first two slice elements.
        let adapter = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
            let pair = (&args[0], &args[1]);
            (self)(state, pair).map(|output| vec![output])
        };
        Compiled {
            f_marker: PhantomData::<F>,
            state: CompiledState {
                id,
                shapeless,
                f: adapter,
            },
        }
    }
}
/// Compiles a fallible function of three arrays.
impl<F, U> CompileWithState<U, (&Array, &Array, &Array), Array, Exception> for F
where
    F: FnMut(&mut U, (&Array, &Array, &Array)) -> Result<Array, Exception> + 'static,
    U: Updatable,
{
    type Args<'a> = (&'a Array, &'a Array, &'a Array);

    fn compile<'args>(
        mut self,
        shapeless: bool,
    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
        // Derive the id from the user's function type before it is moved into the adapter.
        let id = type_id_to_usize(&self);
        // Adapter: rebuild the triple from the first three slice elements.
        let adapter = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
            let triple = (&args[0], &args[1], &args[2]);
            (self)(state, triple).map(|output| vec![output])
        };
        Compiled {
            f_marker: PhantomData::<F>,
            state: CompiledState {
                id,
                shapeless,
                f: adapter,
            },
        }
    }
}
/// Invokes a compiled stateful function with state `U` and arguments `A`.
///
/// `O` is the output type. `E` mirrors the marker on `CompileWithState`: `()` for
/// infallible functions and `Exception` for fallible ones. Every variant returns
/// `Result<O, Exception>` because the compile/apply step itself can fail.
pub trait CallMutWithState<U, A, O, E> {
/// Runs the compiled function, writing the updated state back into `state` and
/// returning the function's output.
fn call_mut(&mut self, state: &mut U, args: A) -> Result<O, Exception>;
}
/// Calls a compiled infallible slice function, returning all of its outputs.
impl<U, F, G> CallMutWithState<U, &[Array], Vec<Array>, ()> for Compiled<F, G>
where
    F: FnMut(&mut U, &[Array]) -> Vec<Array>,
    G: FnMut(&mut U, &[Array]) -> Vec<Array>,
    U: Updatable,
{
    fn call_mut(&mut self, state: &mut U, args: &[Array]) -> Result<Vec<Array>, Exception> {
        let compiled = &mut self.state;
        compiled.retry_call_mut_with_state(state, args)
    }
}
/// Calls a compiled infallible single-array function and returns its only output.
impl<U, F, G> CallMutWithState<U, &Array, Array, ()> for Compiled<F, G>
where
    F: FnMut(&mut U, &Array) -> Array,
    G: FnMut(&mut U, &[Array]) -> Vec<Array>,
    U: Updatable,
{
    fn call_mut(&mut self, state: &mut U, args: &Array) -> Result<Array, Exception> {
        // View the argument as a one-element slice without copying it.
        self.state
            .retry_call_mut_with_state(state, std::slice::from_ref(args))
            .map(|outputs| outputs.into_iter().next().unwrap())
    }
}
/// Calls a compiled infallible two-array function and returns its only output.
impl<U, F, G> CallMutWithState<U, (&Array, &Array), Array, ()> for Compiled<F, G>
where
    F: FnMut(&mut U, (&Array, &Array)) -> Array,
    G: FnMut(&mut U, &[Array]) -> Vec<Array>,
    U: Updatable,
{
    fn call_mut(&mut self, state: &mut U, args: (&Array, &Array)) -> Result<Array, Exception> {
        let (first, second) = args;
        self.state
            .retry_call_mut_with_state(state, &[first, second])
            .map(|outputs| outputs.into_iter().next().unwrap())
    }
}
/// Calls a compiled infallible three-array function and returns its only output.
impl<U, F, G> CallMutWithState<U, (&Array, &Array, &Array), Array, ()> for Compiled<F, G>
where
    F: FnMut(&mut U, (&Array, &Array, &Array)) -> Array,
    G: FnMut(&mut U, &[Array]) -> Vec<Array>,
    U: Updatable,
{
    fn call_mut(
        &mut self,
        state: &mut U,
        args: (&Array, &Array, &Array),
    ) -> Result<Array, Exception> {
        let (first, second, third) = args;
        self.state
            .retry_call_mut_with_state(state, &[first, second, third])
            .map(|outputs| outputs.into_iter().next().unwrap())
    }
}
/// Calls a compiled fallible slice function, returning all of its outputs.
impl<U, F, G> CallMutWithState<U, &[Array], Vec<Array>, Exception> for Compiled<F, G>
where
    F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
    G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
    U: Updatable,
{
    fn call_mut(&mut self, state: &mut U, args: &[Array]) -> Result<Vec<Array>, Exception> {
        let compiled = &mut self.state;
        compiled.retry_fallible_call_mut_with_state(state, args)
    }
}
/// Calls a compiled fallible single-array function and returns its only output.
impl<U, F, G> CallMutWithState<U, &Array, Array, Exception> for Compiled<F, G>
where
    F: FnMut(&mut U, &Array) -> Result<Array, Exception>,
    G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
    U: Updatable,
{
    fn call_mut(&mut self, state: &mut U, args: &Array) -> Result<Array, Exception> {
        // View the argument as a one-element slice without copying it.
        self.state
            .retry_fallible_call_mut_with_state(state, std::slice::from_ref(args))
            .map(|outputs| outputs.into_iter().next().unwrap())
    }
}
/// Calls a compiled fallible two-array function and returns its only output.
impl<U, F, G> CallMutWithState<U, (&Array, &Array), Array, Exception> for Compiled<F, G>
where
    F: FnMut(&mut U, (&Array, &Array)) -> Result<Array, Exception>,
    G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
    U: Updatable,
{
    fn call_mut(&mut self, state: &mut U, args: (&Array, &Array)) -> Result<Array, Exception> {
        let (first, second) = args;
        self.state
            .retry_fallible_call_mut_with_state(state, &[first, second])
            .map(|outputs| outputs.into_iter().next().unwrap())
    }
}
/// Calls a compiled fallible three-array function and returns its only output.
impl<U, F, G> CallMutWithState<U, (&Array, &Array, &Array), Array, Exception> for Compiled<F, G>
where
    F: FnMut(&mut U, (&Array, &Array, &Array)) -> Result<Array, Exception>,
    G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
    U: Updatable,
{
    fn call_mut(
        &mut self,
        state: &mut U,
        args: (&Array, &Array, &Array),
    ) -> Result<Array, Exception> {
        let (first, second, third) = args;
        self.state
            .retry_fallible_call_mut_with_state(state, &[first, second, third])
            .map(|outputs| outputs.into_iter().next().unwrap())
    }
}
/// Compiles `inner_closure` (identified by `fun_id`) and applies it to `args` followed by
/// the current updatable state arrays.
///
/// The closure's outputs are treated as `[function outputs..., new state...]`. The trailing
/// `state`-sized suffix is written back into `state`, and the leading function outputs are
/// returned.
///
/// # Errors
///
/// Propagates any `Exception` from compiling, building the input vector, applying the
/// compiled closure, or converting its outputs.
#[inline]
fn call_mut_with_state_inner<U>(
    inner_closure: Closure,
    fun_id: usize,
    shapeless: bool,
    state: Rc<RefCell<&mut U>>,
    args: &[impl AsRef<Array>],
) -> crate::error::Result<Vec<Array>>
where
    U: Updatable,
{
    // SAFETY: `inner_closure` stays alive for the whole FFI call, so its raw pointer is valid;
    // `constants` is empty and its length is passed as 0.
    let compiled = Closure::try_from_op(|res| unsafe {
        let constants = &[];
        mlx_sys::mlx_detail_compile(
            res,
            inner_closure.as_ptr(),
            fun_id,
            shapeless,
            constants.as_ptr(),
            0,
        )
    })?;

    // Inputs are laid out as `[args..., state...]`. The shared `Ref` is confined to this
    // block so it is released before the compiled closure runs; the callers in this file
    // build `inner_closure` around a clone of `state` that is mutably borrowed while tracing.
    let (num_state_inputs, inputs) = {
        let guard = state.borrow();
        let state_inputs: Vec<_> = guard.updatable_states().into_iter().collect();
        let count = state_inputs.len();
        let all_inputs =
            VectorArray::try_from_iter(args.iter().map(AsRef::as_ref).chain(state_inputs))?;
        (count, all_inputs)
    };

    // SAFETY: `compiled` and `inputs` are both alive for the duration of the FFI call.
    let result_vector = VectorArray::try_from_op(|res| unsafe {
        mlx_sys::mlx_closure_apply(res, compiled.as_ptr(), inputs.as_ptr())
    })?;
    let mut outputs: Vec<Array> = result_vector.try_into_values()?;

    // Everything after the first `result_len` outputs is the new state.
    let result_len = outputs.len() - num_state_inputs;
    let state_outputs = outputs.split_off(result_len);
    for (slot, new_value) in state
        .borrow_mut()
        .updatable_states_mut()
        .into_iter()
        .zip(&state_outputs)
    {
        update_by_replace_with_ref_to_new_array(slot, new_value);
    }

    Ok(outputs)
}
impl<F> CompiledState<F> {
    /// Calls the compiled infallible stateful function, retrying once on failure.
    ///
    /// If the first attempt fails, its error is discarded and the call is made one more time.
    /// The second attempt's result is returned unchanged.
    fn retry_call_mut_with_state<U>(
        &mut self,
        state: &mut U,
        args: &[impl AsRef<Array>],
    ) -> Result<Vec<Array>, Exception>
    where
        F: FnMut(&mut U, &[Array]) -> Vec<Array>,
        U: Updatable,
    {
        // NOTE(review): the reason for this unconditional single retry is not visible in
        // this file; presumably it works around a failure that only happens on the first
        // compile/apply for some state types — confirm before removing.
        self.call_mut_with_state(state, args)
            .or_else(|_first_error| self.call_mut_with_state(state, args))
    }

    /// Calls the compiled fallible stateful function, retrying once on failure.
    ///
    /// If the first attempt fails (including when the wrapped function itself returns
    /// `Err`), that error is discarded and the call is made one more time. The second
    /// attempt's result is returned unchanged.
    fn retry_fallible_call_mut_with_state<U>(
        &mut self,
        state: &mut U,
        args: &[impl AsRef<Array>],
    ) -> Result<Vec<Array>, Exception>
    where
        F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
        U: Updatable,
    {
        // NOTE(review): see `retry_call_mut_with_state`; the retry rationale is not visible
        // here. The retry is only sound because the traced closure always restores `state`,
        // even when the wrapped function fails.
        self.fallible_call_mut_with_state(state, args)
            .or_else(|_first_error| self.fallible_call_mut_with_state(state, args))
    }

    /// Traces and runs the infallible function once with `state` threaded through.
    ///
    /// While tracing, the real state arrays are swapped for tracers. After the function runs,
    /// the post-call state is recorded as extra outputs and the real arrays are put back.
    fn call_mut_with_state<U>(
        &mut self,
        state: &mut U,
        args: &[impl AsRef<Array>],
    ) -> Result<Vec<Array>, Exception>
    where
        F: FnMut(&mut U, &[Array]) -> Vec<Array>,
        U: Updatable,
    {
        let args_len = args.len();
        let state = Rc::new(RefCell::new(state));
        let f = &mut self.f;
        let state_clone = Rc::clone(&state);
        // Traced body: receives `[args..., state...]` tracers and returns
        // `[outputs..., state after the call...]`.
        let inner = move |tracers: &[Array]| -> Vec<Array> {
            let (tracer_args, state_tracers) = tracers.split_at(args_len);

            // Swap the real state for tracers for the duration of the call.
            let saved_state_inputs = snapshot_updatable_states(&state_clone);
            replace_updatable_states(&state_clone, state_tracers);

            let mut result = (f)(*state_clone.borrow_mut(), tracer_args);

            // Recapture the state (it may have been updated), then put the real arrays back.
            let state_output_tracers = snapshot_updatable_states(&state_clone);
            replace_updatable_states(&state_clone, &saved_state_inputs);

            result.extend(state_output_tracers);
            result
        };
        let inner_closure = Closure::new(inner);
        call_mut_with_state_inner(inner_closure, self.id, self.shapeless, state, args)
    }

    /// Traces and runs the fallible function once with `state` threaded through.
    ///
    /// Behaves like `call_mut_with_state`, except that the real state arrays are restored
    /// even when the wrapped function returns `Err`. The error is then propagated.
    fn fallible_call_mut_with_state<U>(
        &mut self,
        state: &mut U,
        args: &[impl AsRef<Array>],
    ) -> Result<Vec<Array>, Exception>
    where
        F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
        U: Updatable,
    {
        let args_len = args.len();
        let state = Rc::new(RefCell::new(state));
        let f = &mut self.f;
        let state_clone = Rc::clone(&state);
        let inner = move |tracers: &[Array]| -> Result<Vec<Array>, Exception> {
            let (tracer_args, state_tracers) = tracers.split_at(args_len);

            // Swap the real state for tracers for the duration of the call.
            let saved_state_inputs = snapshot_updatable_states(&state_clone);
            replace_updatable_states(&state_clone, state_tracers);

            // Deliberately no `?` here. An early return would skip the restore below,
            // leaving the caller's state holding tracers. The retry would then trace with
            // those tracers as its "real" inputs.
            let outcome = (f)(*state_clone.borrow_mut(), tracer_args);

            // Recapture the state (it may have been updated), then put the real arrays back.
            let state_output_tracers = snapshot_updatable_states(&state_clone);
            replace_updatable_states(&state_clone, &saved_state_inputs);

            let mut result = outcome?;
            result.extend(state_output_tracers);
            Ok(result)
        };
        let inner_closure = Closure::new_fallible(inner);
        call_mut_with_state_inner(inner_closure, self.id, self.shapeless, state, args)
    }
}

/// Clones every updatable state array of the value inside `state`, in iteration order.
fn snapshot_updatable_states<U>(state: &RefCell<&mut U>) -> Vec<Array>
where
    U: Updatable,
{
    state
        .borrow()
        .updatable_states()
        .into_iter()
        .cloned()
        .collect()
}

/// Replaces each updatable state array of the value inside `state` with the matching array
/// from `new_values`, stopping at whichever sequence ends first.
fn replace_updatable_states<'a, U>(
    state: &RefCell<&mut U>,
    new_values: impl IntoIterator<Item = &'a Array>,
) where
    U: Updatable,
{
    for (slot, new_value) in state
        .borrow_mut()
        .updatable_states_mut()
        .into_iter()
        .zip(new_values)
    {
        update_by_replace_with_ref_to_new_array(slot, new_value);
    }
}