1use std::{cell::RefCell, marker::PhantomData, rc::Rc};
10
11use crate::{
12 error::Exception,
13 transforms::compile::{type_id_to_usize, CompiledState},
14 utils::Updatable,
15 Array,
16};
17
18use super::{update_by_replace_with_ref_to_new_array, Closure, Compiled, Guarded, VectorArray};
19
20pub fn compile_with_state<F, U, A, O, E>(
23 f: F,
24 shapeless: impl Into<Option<bool>>,
25) -> impl for<'a> FnMut(&mut U, F::Args<'a>) -> Result<O, Exception>
26where
27 F: CompileWithState<U, A, O, E> + Copy + 'static,
28 U: Updatable,
29{
30 let shapeless = shapeless.into().unwrap_or(false);
31 move |state, args| {
32 let mut compiled = f.compile(shapeless);
33 compiled.call_mut(state, args)
34 }
35}
36
37pub trait CompileWithState<U, A, O, E> {
49 type Args<'a>;
55
56 fn compile<'args>(self, shapeless: bool) -> impl CallMutWithState<U, Self::Args<'args>, O, E>;
58}
59
60impl<F, U> CompileWithState<U, &[Array], Vec<Array>, ()> for F
61where
62 F: FnMut(&mut U, &[Array]) -> Vec<Array> + 'static,
63 U: Updatable,
64{
65 type Args<'a> = &'a [Array];
66
67 fn compile<'args>(
68 self,
69 shapeless: bool,
70 ) -> impl CallMutWithState<U, Self::Args<'args>, Vec<Array>, ()> {
71 let id = type_id_to_usize(&self);
72 let state = CompiledState {
73 f: self,
74 shapeless,
75 id,
76 };
77 Compiled {
78 f_marker: PhantomData::<F>,
79 state,
80 }
81 }
82}
83
84impl<F, U> CompileWithState<U, &Array, Array, ()> for F
85where
86 F: FnMut(&mut U, &Array) -> Array + 'static,
87 U: Updatable,
88{
89 type Args<'a> = &'a Array;
90
91 fn compile<'args>(
92 mut self,
93 shapeless: bool,
94 ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
95 let id = type_id_to_usize(&self);
96 let f = move |state: &mut U, args: &[Array]| -> Vec<Array> {
97 let result = (self)(state, &args[0]);
98 vec![result]
99 };
100 let state = CompiledState { f, shapeless, id };
101 Compiled {
102 f_marker: PhantomData::<F>,
103 state,
104 }
105 }
106}
107
108impl<F, U> CompileWithState<U, (&Array, &Array), Array, ()> for F
109where
110 F: FnMut(&mut U, (&Array, &Array)) -> Array + 'static,
111 U: Updatable,
112{
113 type Args<'a> = (&'a Array, &'a Array);
114
115 fn compile<'args>(
116 mut self,
117 shapeless: bool,
118 ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
119 let id = type_id_to_usize(&self);
120 let f = move |state: &mut U, args: &[Array]| -> Vec<Array> {
121 let result = (self)(state, (&args[0], &args[1]));
122 vec![result]
123 };
124 let state = CompiledState { f, shapeless, id };
125 Compiled {
126 f_marker: PhantomData::<F>,
127 state,
128 }
129 }
130}
131
132impl<F, U> CompileWithState<U, (&Array, &Array, &Array), Array, ()> for F
133where
134 F: FnMut(&mut U, (&Array, &Array, &Array)) -> Array + 'static,
135 U: Updatable,
136{
137 type Args<'a> = (&'a Array, &'a Array, &'a Array);
138
139 fn compile<'args>(
140 mut self,
141 shapeless: bool,
142 ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
143 let id = type_id_to_usize(&self);
144 let f = move |state: &mut U, args: &[Array]| -> Vec<Array> {
145 let result = (self)(state, (&args[0], &args[1], &args[2]));
146 vec![result]
147 };
148 let state = CompiledState { f, shapeless, id };
149 Compiled {
150 f_marker: PhantomData::<F>,
151 state,
152 }
153 }
154}
155
156impl<F, U> CompileWithState<U, &[Array], Vec<Array>, Exception> for F
157where
158 F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception> + 'static,
159 U: Updatable,
160{
161 type Args<'a> = &'a [Array];
162
163 fn compile<'args>(
164 self,
165 shapeless: bool,
166 ) -> impl CallMutWithState<U, Self::Args<'args>, Vec<Array>, Exception> {
167 let id = type_id_to_usize(&self);
168 let state = CompiledState {
169 f: self,
170 shapeless,
171 id,
172 };
173 Compiled {
174 f_marker: PhantomData::<F>,
175 state,
176 }
177 }
178}
179
180impl<F, U> CompileWithState<U, &Array, Array, Exception> for F
181where
182 F: FnMut(&mut U, &Array) -> Result<Array, Exception> + 'static,
183 U: Updatable,
184{
185 type Args<'a> = &'a Array;
186
187 fn compile<'args>(
188 mut self,
189 shapeless: bool,
190 ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
191 let id = type_id_to_usize(&self);
192 let f = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
193 let result = (self)(state, &args[0])?;
194 Ok(vec![result])
195 };
196 let state = CompiledState { f, shapeless, id };
197 Compiled {
198 f_marker: PhantomData::<F>,
199 state,
200 }
201 }
202}
203
204impl<F, U> CompileWithState<U, (&Array, &Array), Array, Exception> for F
205where
206 F: FnMut(&mut U, (&Array, &Array)) -> Result<Array, Exception> + 'static,
207 U: Updatable,
208{
209 type Args<'a> = (&'a Array, &'a Array);
210
211 fn compile<'args>(
212 mut self,
213 shapeless: bool,
214 ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
215 let id = type_id_to_usize(&self);
216 let f = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
217 let result = (self)(state, (&args[0], &args[1]))?;
218 Ok(vec![result])
219 };
220 let state = CompiledState { f, shapeless, id };
221 Compiled {
222 f_marker: PhantomData::<F>,
223 state,
224 }
225 }
226}
227
228impl<F, U> CompileWithState<U, (&Array, &Array, &Array), Array, Exception> for F
229where
230 F: FnMut(&mut U, (&Array, &Array, &Array)) -> Result<Array, Exception> + 'static,
231 U: Updatable,
232{
233 type Args<'a> = (&'a Array, &'a Array, &'a Array);
234
235 fn compile<'args>(
236 mut self,
237 shapeless: bool,
238 ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
239 let id = type_id_to_usize(&self);
240 let f = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
241 let result = (self)(state, (&args[0], &args[1], &args[2]))?;
242 Ok(vec![result])
243 };
244 let state = CompiledState { f, shapeless, id };
245 Compiled {
246 f_marker: PhantomData::<F>,
247 state,
248 }
249 }
250}
251
252pub trait CallMutWithState<U, A, O, E> {
254 fn call_mut(&mut self, state: &mut U, args: A) -> Result<O, Exception>;
256}
257
258impl<U, F, G> CallMutWithState<U, &[Array], Vec<Array>, ()> for Compiled<F, G>
259where
260 F: FnMut(&mut U, &[Array]) -> Vec<Array>,
261 G: FnMut(&mut U, &[Array]) -> Vec<Array>,
262 U: Updatable,
263{
264 fn call_mut(&mut self, state: &mut U, args: &[Array]) -> Result<Vec<Array>, Exception> {
265 self.state.retry_call_mut_with_state(state, args)
266 }
267}
268
269impl<U, F, G> CallMutWithState<U, &Array, Array, ()> for Compiled<F, G>
270where
271 F: FnMut(&mut U, &Array) -> Array,
272 G: FnMut(&mut U, &[Array]) -> Vec<Array>,
273 U: Updatable,
274{
275 fn call_mut(&mut self, state: &mut U, args: &Array) -> Result<Array, Exception> {
276 let args = std::slice::from_ref(args);
277 let result = self.state.retry_call_mut_with_state(state, args)?;
278 Ok(result.into_iter().next().unwrap())
279 }
280}
281
282impl<U, F, G> CallMutWithState<U, (&Array, &Array), Array, ()> for Compiled<F, G>
283where
284 F: FnMut(&mut U, (&Array, &Array)) -> Array,
285 G: FnMut(&mut U, &[Array]) -> Vec<Array>,
286 U: Updatable,
287{
288 fn call_mut(&mut self, state: &mut U, args: (&Array, &Array)) -> Result<Array, Exception> {
289 let args = &[args.0, args.1];
290 let result = self.state.retry_call_mut_with_state(state, args)?;
291 Ok(result.into_iter().next().unwrap())
292 }
293}
294
295impl<U, F, G> CallMutWithState<U, (&Array, &Array, &Array), Array, ()> for Compiled<F, G>
296where
297 F: FnMut(&mut U, (&Array, &Array, &Array)) -> Array,
298 G: FnMut(&mut U, &[Array]) -> Vec<Array>,
299 U: Updatable,
300{
301 fn call_mut(
302 &mut self,
303 state: &mut U,
304 args: (&Array, &Array, &Array),
305 ) -> Result<Array, Exception> {
306 let args = &[args.0, args.1, args.2];
307 let result = self.state.retry_call_mut_with_state(state, args)?;
308 Ok(result.into_iter().next().unwrap())
309 }
310}
311
312impl<U, F, G> CallMutWithState<U, &[Array], Vec<Array>, Exception> for Compiled<F, G>
313where
314 F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
315 G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
316 U: Updatable,
317{
318 fn call_mut(&mut self, state: &mut U, args: &[Array]) -> Result<Vec<Array>, Exception> {
319 self.state.retry_fallible_call_mut_with_state(state, args)
320 }
321}
322
323impl<U, F, G> CallMutWithState<U, &Array, Array, Exception> for Compiled<F, G>
324where
325 F: FnMut(&mut U, &Array) -> Result<Array, Exception>,
326 G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
327 U: Updatable,
328{
329 fn call_mut(&mut self, state: &mut U, args: &Array) -> Result<Array, Exception> {
330 let args = std::slice::from_ref(args);
331 let result = self.state.retry_fallible_call_mut_with_state(state, args)?;
332 Ok(result.into_iter().next().unwrap())
333 }
334}
335
336impl<U, F, G> CallMutWithState<U, (&Array, &Array), Array, Exception> for Compiled<F, G>
337where
338 F: FnMut(&mut U, (&Array, &Array)) -> Result<Array, Exception>,
339 G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
340 U: Updatable,
341{
342 fn call_mut(&mut self, state: &mut U, args: (&Array, &Array)) -> Result<Array, Exception> {
343 let args = &[args.0, args.1];
344 let result = self.state.retry_fallible_call_mut_with_state(state, args)?;
345 Ok(result.into_iter().next().unwrap())
346 }
347}
348
349impl<U, F, G> CallMutWithState<U, (&Array, &Array, &Array), Array, Exception> for Compiled<F, G>
350where
351 F: FnMut(&mut U, (&Array, &Array, &Array)) -> Result<Array, Exception>,
352 G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
353 U: Updatable,
354{
355 fn call_mut(
356 &mut self,
357 state: &mut U,
358 args: (&Array, &Array, &Array),
359 ) -> Result<Array, Exception> {
360 let args = &[args.0, args.1, args.2];
361 let result = self.state.retry_fallible_call_mut_with_state(state, args)?;
362 Ok(result.into_iter().next().unwrap())
363 }
364}
365
366#[inline]
367fn call_mut_with_state_inner<U>(
368 inner_closure: Closure,
369 fun_id: usize,
370 shapeless: bool,
371 state: Rc<RefCell<&mut U>>,
372 args: &[impl AsRef<Array>],
373) -> crate::error::Result<Vec<Array>>
374where
375 U: Updatable,
376{
377 let compiled = Closure::try_from_op(|res| unsafe {
380 let constants = &[];
381 mlx_sys::mlx_detail_compile(
382 res,
383 inner_closure.as_ptr(),
384 fun_id,
385 shapeless,
386 constants.as_ptr(),
387 0,
388 )
389 })?;
390
391 let (state_params_len, inner_inputs_vector) = {
392 let borrow = state.borrow();
393 let state_params: Vec<_> = borrow.updatable_states().into_iter().collect();
394 let state_params_len = state_params.len();
395 let inner_inputs_vector = VectorArray::try_from_iter(
396 args.iter()
397 .map(AsRef::as_ref)
398 .chain(state_params.into_iter()),
399 )?;
400 (state_params_len, inner_inputs_vector)
401 };
402
403 let result_vector = VectorArray::try_from_op(|res| unsafe {
406 mlx_sys::mlx_closure_apply(res, compiled.as_ptr(), inner_inputs_vector.as_ptr())
407 })?;
408 let result_plus_state_output: Vec<Array> = result_vector.try_into_values()?;
409
410 let result_plus_state_output_len = result_plus_state_output.len();
412 let suffix_len = result_plus_state_output_len - state_params_len;
413 for (s, new_values) in state
414 .borrow_mut()
415 .updatable_states_mut()
416 .into_iter()
417 .zip(result_plus_state_output[suffix_len..].iter())
418 {
419 update_by_replace_with_ref_to_new_array(s, new_values);
420 }
421
422 let result_len = result_plus_state_output.len() - state_params_len;
423 Ok(result_plus_state_output
424 .into_iter()
425 .take(result_len)
426 .collect())
427}
428
429impl<F> CompiledState<F> {
430 fn retry_call_mut_with_state<U>(
431 &mut self,
432 state: &mut U,
433 args: &[impl AsRef<Array>],
434 ) -> Result<Vec<Array>, Exception>
435 where
436 F: FnMut(&mut U, &[Array]) -> Vec<Array>,
437 U: Updatable,
438 {
439 self.call_mut_with_state(state, args).or_else(|_e| {
440 self.call_mut_with_state(state, args)
447 })
448 }
449
450 fn retry_fallible_call_mut_with_state<U>(
451 &mut self,
452 state: &mut U,
453 args: &[impl AsRef<Array>],
454 ) -> Result<Vec<Array>, Exception>
455 where
456 F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
457 U: Updatable,
458 {
459 self.fallible_call_mut_with_state(state, args)
460 .or_else(|_e| {
461 self.fallible_call_mut_with_state(state, args)
468 })
469 }
470
471 fn call_mut_with_state<U>(
472 &mut self,
473 state: &mut U,
474 args: &[impl AsRef<Array>],
475 ) -> Result<Vec<Array>, Exception>
476 where
477 F: FnMut(&mut U, &[Array]) -> Vec<Array>,
478 U: Updatable,
479 {
480 let args_len = args.len();
481 let state = Rc::new(RefCell::new(state));
482 let f = &mut self.f;
483
484 let state_clone = Rc::clone(&state);
485 let inner = move |tracers: &[Array]| -> Vec<Array> {
486 let tracer_args = &tracers[..args_len];
491
492 let saved_state_inputs = state_clone
494 .borrow()
495 .updatable_states()
496 .into_iter()
497 .map(|array| (*array).clone())
498 .collect::<Vec<Array>>();
499
500 for (s, tracer) in state_clone
502 .borrow_mut()
503 .updatable_states_mut()
504 .into_iter()
505 .zip(tracers.iter().skip(args_len))
506 {
507 update_by_replace_with_ref_to_new_array(s, tracer);
508 }
509
510 let mut result = (f)(*state_clone.borrow_mut(), tracer_args);
512
513 let mut state_output_tracers = state_clone
515 .borrow()
516 .updatable_states()
517 .into_iter()
518 .map(|array| (*array).clone())
519 .collect::<Vec<Array>>();
520
521 for (s, saved) in state_clone
523 .borrow_mut()
524 .updatable_states_mut()
525 .into_iter()
526 .zip(saved_state_inputs)
527 {
528 update_by_replace_with_ref_to_new_array(s, &saved);
529 }
530
531 result.append(&mut state_output_tracers);
533
534 result
535 };
536
537 let inner_closure = Closure::new(inner);
538 call_mut_with_state_inner(inner_closure, self.id, self.shapeless, state, args)
539 }
540
541 fn fallible_call_mut_with_state<U>(
542 &mut self,
543 state: &mut U,
544 args: &[impl AsRef<Array>],
545 ) -> Result<Vec<Array>, Exception>
546 where
547 F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
548 U: Updatable,
549 {
550 let args_len = args.len();
551 let state = Rc::new(RefCell::new(state));
552 let f = &mut self.f;
553
554 let state_clone = Rc::clone(&state);
555 let inner = move |tracers: &[Array]| -> Result<Vec<Array>, Exception> {
556 let tracer_args = &tracers[..args_len];
561
562 let saved_state_inputs = state_clone
564 .borrow()
565 .updatable_states()
566 .into_iter()
567 .map(|array| (*array).clone())
568 .collect::<Vec<Array>>();
569
570 for (s, tracer) in state_clone
572 .borrow_mut()
573 .updatable_states_mut()
574 .into_iter()
575 .zip(tracers.iter().skip(args_len))
576 {
577 update_by_replace_with_ref_to_new_array(s, tracer);
578 }
579
580 let mut result = (f)(*state_clone.borrow_mut(), tracer_args)?;
582
583 let mut state_output_tracers = state_clone
585 .borrow()
586 .updatable_states()
587 .into_iter()
588 .map(|array| (*array).clone())
589 .collect::<Vec<Array>>();
590
591 for (s, saved) in state_clone
593 .borrow_mut()
594 .updatable_states_mut()
595 .into_iter()
596 .zip(saved_state_inputs)
597 {
598 update_by_replace_with_ref_to_new_array(s, &saved);
599 }
600
601 result.append(&mut state_output_tracers);
603
604 Ok(result)
605 };
606
607 let inner_closure = Closure::new_fallible(inner);
608 call_mut_with_state_inner(inner_closure, self.id, self.shapeless, state, args)
609 }
610}