1use std::marker::PhantomData;
6
7use crate::{error::Exception, Array};
8
9use super::{type_id_to_usize, Closure, Compiled, CompiledState, Guarded, VectorArray};
10
11pub fn compile<F, A, O, E>(
17 f: F,
18 shapeless: impl Into<Option<bool>>,
19) -> impl for<'a> FnMut(F::Args<'a>) -> Result<O, Exception>
20where
21 F: Compile<A, O, E> + 'static + Copy,
22{
23 let shapeless = shapeless.into().unwrap_or(false);
24 move |args| {
25 let mut compiled = f.compile(shapeless);
28 compiled.call_mut(args)
29 }
30}
31
32pub trait Compile<A, O, E>: Sized {
40 type Args<'a>;
46
47 fn compile<'args>(self, shapeless: bool) -> impl CallMut<Self::Args<'args>, O, E>;
49}
50
51impl<F> Compile<&[Array], Vec<Array>, ()> for F
52where
53 F: FnMut(&[Array]) -> Vec<Array> + 'static,
54{
55 type Args<'a> = &'a [Array];
56
57 fn compile<'args>(self, shapeless: bool) -> impl CallMut<Self::Args<'args>, Vec<Array>, ()> {
58 let id = type_id_to_usize(&self);
59 let state = CompiledState {
60 f: self,
61
62 shapeless,
63 id,
64 };
65 Compiled {
66 f_marker: PhantomData::<F>,
67 state,
68 }
69 }
70}
71
72impl<F> Compile<&Array, Array, ()> for F
73where
74 F: FnMut(&Array) -> Array + 'static,
75{
76 type Args<'a> = &'a Array;
77
78 fn compile<'args>(mut self, shapeless: bool) -> impl CallMut<Self::Args<'args>, Array, ()> {
79 let id = type_id_to_usize(&self);
80 let f = move |args: &[Array]| -> Vec<Array> {
81 let result = (self)(&args[0]);
82 vec![result]
83 };
84 let state = CompiledState { f, shapeless, id };
85 Compiled {
86 f_marker: PhantomData::<F>,
87 state,
88 }
89 }
90}
91
92impl<F> Compile<(&Array, &Array), Array, ()> for F
93where
94 F: FnMut((&Array, &Array)) -> Array + 'static,
95{
96 type Args<'a> = (&'a Array, &'a Array);
97
98 fn compile<'args>(mut self, shapeless: bool) -> impl CallMut<Self::Args<'args>, Array, ()> {
99 let id = type_id_to_usize(&self);
100 let f = move |args: &[Array]| -> Vec<Array> {
101 let result = (self)((&args[0], &args[1]));
102 vec![result]
103 };
104 let state = CompiledState { f, shapeless, id };
105 Compiled {
106 f_marker: PhantomData::<F>,
107 state,
108 }
109 }
110}
111
112impl<F> Compile<(&Array, &Array, &Array), Array, ()> for F
113where
114 F: FnMut((&Array, &Array, &Array)) -> Array + 'static,
115{
116 type Args<'a> = (&'a Array, &'a Array, &'a Array);
117
118 fn compile<'args>(mut self, shapeless: bool) -> impl CallMut<Self::Args<'args>, Array, ()> {
119 let id = type_id_to_usize(&self);
120 let f = move |args: &[Array]| -> Vec<Array> {
121 let result = (self)((&args[0], &args[1], &args[2]));
122 vec![result]
123 };
124 let state = CompiledState { f, shapeless, id };
125 Compiled {
126 f_marker: PhantomData::<F>,
127 state,
128 }
129 }
130}
131
132impl<F> Compile<&[Array], Vec<Array>, Exception> for F
133where
134 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'static,
135{
136 type Args<'a> = &'a [Array];
137
138 fn compile<'args>(
139 self,
140 shapeless: bool,
141 ) -> impl CallMut<Self::Args<'args>, Vec<Array>, Exception> {
142 let id = type_id_to_usize(&self);
143 let state = CompiledState {
144 f: self,
145 shapeless,
146 id,
147 };
148 Compiled {
149 f_marker: PhantomData::<F>,
150 state,
151 }
152 }
153}
154
155impl<F> Compile<&Array, Array, Exception> for F
156where
157 F: FnMut(&Array) -> Result<Array, Exception> + 'static,
158{
159 type Args<'a> = &'a Array;
160
161 fn compile<'args>(
162 mut self,
163 shapeless: bool,
164 ) -> impl CallMut<Self::Args<'args>, Array, Exception> {
165 let id = type_id_to_usize(&self);
166 let f = move |args: &[Array]| -> Result<Vec<Array>, Exception> {
167 let result = (self)(&args[0])?;
168 Ok(vec![result])
169 };
170 let state = CompiledState { f, shapeless, id };
171 Compiled {
172 f_marker: PhantomData::<F>,
173 state,
174 }
175 }
176}
177
178impl<F> Compile<(&Array, &Array), Array, Exception> for F
179where
180 F: FnMut((&Array, &Array)) -> Result<Array, Exception> + 'static,
181{
182 type Args<'a> = (&'a Array, &'a Array);
183
184 fn compile<'args>(
185 mut self,
186 shapeless: bool,
187 ) -> impl CallMut<Self::Args<'args>, Array, Exception> {
188 let id = type_id_to_usize(&self);
189 let f = move |args: &[Array]| -> Result<Vec<Array>, Exception> {
190 let result = (self)((&args[0], &args[1]))?;
191 Ok(vec![result])
192 };
193 let state = CompiledState { f, shapeless, id };
194 Compiled {
195 f_marker: PhantomData::<F>,
196 state,
197 }
198 }
199}
200
201impl<F> Compile<(&Array, &Array, &Array), Array, Exception> for F
202where
203 F: FnMut((&Array, &Array, &Array)) -> Result<Array, Exception> + 'static,
204{
205 type Args<'a> = (&'a Array, &'a Array, &'a Array);
206
207 fn compile<'args>(
208 mut self,
209 shapeless: bool,
210 ) -> impl CallMut<Self::Args<'args>, Array, Exception> {
211 let id = type_id_to_usize(&self);
212 let f = move |args: &[Array]| -> Result<Vec<Array>, Exception> {
213 let result = (self)((&args[0], &args[1], &args[2]))?;
214 Ok(vec![result])
215 };
216 let state = CompiledState { f, shapeless, id };
217 Compiled {
218 f_marker: PhantomData::<F>,
219 state,
220 }
221 }
222}
223
224pub trait CallMut<A, O, E> {
226 fn call_mut(&mut self, args: A) -> Result<O, Exception>;
228}
229
230impl<'a, F, G> CallMut<&'a [Array], Vec<Array>, ()> for Compiled<F, G>
231where
232 F: FnMut(&[Array]) -> Vec<Array> + 'a,
233 G: FnMut(&[Array]) -> Vec<Array> + 'a,
234{
235 fn call_mut(&mut self, args: &[Array]) -> Result<Vec<Array>, Exception> {
236 self.state.call_mut(args)
237 }
238}
239
240impl<'a, F, G> CallMut<&'a Array, Array, ()> for Compiled<F, G>
241where
242 F: FnMut(&Array) -> Array + 'a,
243 G: FnMut(&[Array]) -> Vec<Array> + 'a,
244{
245 fn call_mut(&mut self, args: &Array) -> Result<Array, Exception> {
246 let args = std::slice::from_ref(args);
247 let result = self.state.call_mut(args)?;
248 Ok(result.into_iter().next().unwrap())
249 }
250}
251
252impl<'a, F, G> CallMut<(&'a Array, &'a Array), Array, ()> for Compiled<F, G>
253where
254 F: FnMut((&Array, &Array)) -> Array + 'a,
255 G: FnMut(&[Array]) -> Vec<Array> + 'a,
256{
257 fn call_mut(&mut self, args: (&Array, &Array)) -> Result<Array, Exception> {
258 let args = &[args.0, args.1];
259 let result = self.state.call_mut(args)?;
260 Ok(result.into_iter().next().unwrap())
261 }
262}
263
264impl<'a, F, G> CallMut<(&'a Array, &'a Array, &'a Array), Array, ()> for Compiled<F, G>
265where
266 F: FnMut((&Array, &Array, &Array)) -> Array + 'a,
267 G: FnMut(&[Array]) -> Vec<Array> + 'a,
268{
269 fn call_mut(&mut self, args: (&Array, &Array, &Array)) -> Result<Array, Exception> {
270 let args = &[args.0, args.1, args.2];
272 let result = self.state.call_mut(args)?;
273 Ok(result.into_iter().next().unwrap())
274 }
275}
276
277impl<'a, F, G> CallMut<&'a [Array], Vec<Array>, Exception> for Compiled<F, G>
278where
279 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
280 G: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
281{
282 fn call_mut(&mut self, args: &[Array]) -> Result<Vec<Array>, Exception> {
283 self.state.fallible_call_mut(args)
284 }
285}
286
287impl<'a, F, G> CallMut<&'a Array, Array, Exception> for Compiled<F, G>
288where
289 F: FnMut(&Array) -> Result<Array, Exception> + 'a,
290 G: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
291{
292 fn call_mut(&mut self, args: &Array) -> Result<Array, Exception> {
293 let args = &[args];
294 let result = self.state.fallible_call_mut(args)?;
295 Ok(result.into_iter().next().unwrap())
296 }
297}
298
299impl<'a, F, G> CallMut<(&'a Array, &'a Array), Array, Exception> for Compiled<F, G>
300where
301 F: FnMut((&Array, &Array)) -> Result<Array, Exception> + 'a,
302 G: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
303{
304 fn call_mut(&mut self, args: (&Array, &Array)) -> Result<Array, Exception> {
305 let args = &[args.0, args.1];
306 let result = self.state.fallible_call_mut(args)?;
307 Ok(result.into_iter().next().unwrap())
308 }
309}
310
311impl<'a, F, G> CallMut<(&'a Array, &'a Array, &'a Array), Array, Exception> for Compiled<F, G>
312where
313 F: FnMut((&Array, &Array, &Array)) -> Result<Array, Exception> + 'a,
314 G: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
315{
316 fn call_mut(&mut self, args: (&Array, &Array, &Array)) -> Result<Array, Exception> {
317 let args = &[args.0, args.1, args.2];
318 let result = self.state.fallible_call_mut(args)?;
319 Ok(result.into_iter().next().unwrap())
320 }
321}
322
323#[inline]
324fn call_mut_inner(
325 inner_closure: Closure,
326 fun_id: usize,
327 shapeless: bool,
328 args: &[impl AsRef<Array>],
329) -> crate::error::Result<Vec<Array>> {
330 let compiled = Closure::try_from_op(|res| unsafe {
333 let constants = &[];
334 mlx_sys::mlx_detail_compile(
335 res,
336 inner_closure.as_ptr(),
337 fun_id,
338 shapeless,
339 constants.as_ptr(),
340 0,
341 )
342 })?;
343
344 let inner_inputs_vector = VectorArray::try_from_iter(args.iter())?;
345
346 let result_vector = VectorArray::try_from_op(|res| unsafe {
349 mlx_sys::mlx_closure_apply(res, compiled.as_ptr(), inner_inputs_vector.as_ptr())
350 })?;
351 let result_plus_state_output: Vec<Array> = result_vector.try_into_values()?;
352
353 let result_len = result_plus_state_output.len();
354 Ok(result_plus_state_output
355 .into_iter()
356 .take(result_len)
357 .collect())
358}
359
360impl<F> CompiledState<F> {
361 fn call_mut(&mut self, args: &[impl AsRef<Array>]) -> Result<Vec<Array>, Exception>
362 where
363 F: FnMut(&[Array]) -> Vec<Array>,
364 {
365 let inner_closure = Closure::new(&mut self.f);
366
367 call_mut_inner(inner_closure, self.id, self.shapeless, args)
368 }
369
370 fn fallible_call_mut(&mut self, args: &[impl AsRef<Array>]) -> Result<Vec<Array>, Exception>
371 where
372 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception>,
373 {
374 let inner_closure = Closure::new_fallible(&mut self.f);
375
376 call_mut_inner(inner_closure, self.id, self.shapeless, args)
377 }
378}
379
#[cfg(test)]
mod tests {
    use crate::{
        array,
        error::Exception,
        ops::{multiply, ones},
        Array,
    };

    use super::compile;

    fn example_fn_0(x: f32) -> f32 {
        x + 1.0
    }

    fn example_fn_3(x: f32) -> f32 {
        x + 1.0
    }

    /// Distinct functions and closures must map to distinct ids, even with identical bodies.
    #[test]
    fn test_type_id_to_usize() {
        let example_fn_1 = |x: f32| x + 1.0;
        let example_fn_2 = |x: f32| x + 1.0;

        let mut ids = vec![super::type_id_to_usize(&example_fn_0)];

        let remaining = [
            ("id1", super::type_id_to_usize(&example_fn_1)),
            ("id2", super::type_id_to_usize(&example_fn_2)),
            ("id3", super::type_id_to_usize(&example_fn_3)),
        ];
        for (label, id) in remaining {
            assert!(!ids.contains(&id), "{label} already exists");
            ids.push(id);
        }

        assert_eq!(ids.len(), 4);
    }

    #[test]
    fn test_compile() {
        let f = |inputs: &[Array]| -> Vec<Array> { vec![&inputs[0] * &inputs[1]] };
        let mut compiled = compile(f, None);

        let lhs = ones::<f32>(&[20, 20]).unwrap();
        let rhs = ones::<f32>(&[20, 20]).unwrap();
        let args = [lhs, rhs];

        let expected = f(&args).into_iter().next().unwrap();

        let first_run = compiled(&args).unwrap().into_iter().next().unwrap();
        assert_eq!(&expected, &first_run);

        let second_run = compiled(&args).unwrap().into_iter().next().unwrap();
        assert_eq!(&expected, &second_run);
    }

    #[test]
    fn test_compile_with_error() {
        let f = |inputs: &[Array]| -> Result<Vec<Array>, Exception> {
            multiply(&inputs[0], &inputs[1]).map(|x| vec![x])
        };

        // Compatible shapes: compiled and eager results must agree.
        let lhs = ones::<f32>(&[20, 20]).unwrap();
        let rhs = ones::<f32>(&[20, 20]).unwrap();
        let args = [lhs, rhs];

        let expected = f(&args).unwrap().into_iter().next().unwrap();

        let mut compiled = compile(f, None);
        let first_run = compiled(&args).unwrap().into_iter().next().unwrap();
        assert_eq!(&expected, &first_run);

        let second_run = compiled(&args).unwrap().into_iter().next().unwrap();
        assert_eq!(&expected, &second_run);

        // Incompatible shapes: both eager and compiled calls must fail, every time.
        let bad_args = [array!([1.0, 2.0, 3.0]), array!([4.0, 5.0])];
        let other_bad_args = [array!([4.0, 5.0, 6.0]), array!([7.0, 8.0])];

        assert!(f(&bad_args).is_err());

        let mut compiled = compile(f, None);
        assert!(compiled(&bad_args).is_err());
        assert!(compiled(&bad_args).is_err());
        assert!(compiled(&other_bad_args).is_err());
    }

    #[test]
    fn test_compile_with_one_arg() {
        let f = |x: &Array| x * x;

        let input = ones::<f32>(&[20, 20]).unwrap();
        let expected = f(&input);

        let mut compiled = compile(f, None);
        let first_run = compiled(&input).unwrap();
        assert_eq!(&expected, &first_run);

        let second_run = compiled(&input).unwrap();
        assert_eq!(&expected, &second_run);
    }

    #[test]
    fn test_compile_with_two_args() {
        let f = |(x, y): (&Array, &Array)| x * y;

        let lhs = ones::<f32>(&[20, 20]).unwrap();
        let rhs = ones::<f32>(&[20, 20]).unwrap();
        let expected = f((&lhs, &rhs));

        let mut compiled = compile(f, None);
        let first_run = compiled((&lhs, &rhs)).unwrap();
        assert_eq!(&expected, &first_run);

        let second_run = compiled((&lhs, &rhs)).unwrap();
        assert_eq!(&expected, &second_run);
    }

    #[test]
    fn test_compile_with_three_args() {
        let f = |(x, y, z): (&Array, &Array, &Array)| x * y * z;
        let mut compiled = compile(f, None);

        let a = ones::<f32>(&[20, 20]).unwrap();
        let b = ones::<f32>(&[20, 20]).unwrap();
        let c = ones::<f32>(&[20, 20]).unwrap();

        let expected = f((&a, &b, &c));

        let first_run = compiled((&a, &b, &c)).unwrap();
        assert_eq!(&expected, &first_run);

        let second_run = compiled((&a, &b, &c)).unwrap();
        assert_eq!(&expected, &second_run);
    }
}