1use std::{
2 borrow::Cow,
3 ops::{Bound, Deref, RangeBounds},
4 rc::Rc,
5};
6
7use smallvec::{smallvec, SmallVec};
8
9use crate::{
10 array,
11 constants::DEFAULT_STACK_VEC_LEN,
12 error::Result,
13 ops::indexing::expand_ellipsis_operations,
14 utils::{resolve_index_unchecked, VectorArray},
15 Array, Stream,
16};
17
18use super::{
19 ArrayIndex, ArrayIndexOp, Ellipsis, Guarded, NewAxis, RangeIndex, StrideBy, TryIndexOp,
20};
21
22impl<'a> ArrayIndex<'a> for i32 {
27 fn index_op(self) -> ArrayIndexOp<'a> {
28 ArrayIndexOp::TakeIndex { index: self }
29 }
30}
31
32impl<'a> ArrayIndex<'a> for NewAxis {
33 fn index_op(self) -> ArrayIndexOp<'a> {
34 ArrayIndexOp::ExpandDims
35 }
36}
37
38impl<'a> ArrayIndex<'a> for Ellipsis {
39 fn index_op(self) -> ArrayIndexOp<'a> {
40 ArrayIndexOp::Ellipsis
41 }
42}
43
44impl<'a> ArrayIndex<'a> for Array {
45 fn index_op(self) -> ArrayIndexOp<'a> {
46 ArrayIndexOp::TakeArray {
47 indices: Rc::new(self),
48 }
49 }
50}
51
52impl<'a> ArrayIndex<'a> for &'a Array {
53 fn index_op(self) -> ArrayIndexOp<'a> {
54 ArrayIndexOp::TakeArrayRef { indices: self }
55 }
56}
57
58impl<'a> ArrayIndex<'a> for ArrayIndexOp<'a> {
59 fn index_op(self) -> ArrayIndexOp<'a> {
60 self
61 }
62}
63
64macro_rules! impl_array_index_for_bounded_range {
65 ($t:ty) => {
66 impl<'a> ArrayIndex<'a> for $t {
67 fn index_op(self) -> ArrayIndexOp<'a> {
68 ArrayIndexOp::Slice(RangeIndex::new(
69 self.start_bound().cloned(),
70 self.end_bound().cloned(),
71 Some(1),
72 ))
73 }
74 }
75 };
76}
77
78impl_array_index_for_bounded_range!(std::ops::Range<i32>);
79impl_array_index_for_bounded_range!(std::ops::RangeFrom<i32>);
80impl_array_index_for_bounded_range!(std::ops::RangeInclusive<i32>);
81impl_array_index_for_bounded_range!(std::ops::RangeTo<i32>);
82impl_array_index_for_bounded_range!(std::ops::RangeToInclusive<i32>);
83
84impl<'a> ArrayIndex<'a> for std::ops::RangeFull {
85 fn index_op(self) -> ArrayIndexOp<'a> {
86 ArrayIndexOp::Slice(RangeIndex::new(Bound::Unbounded, Bound::Unbounded, Some(1)))
87 }
88}
89
90macro_rules! impl_array_index_for_stride_by_bounded_range {
91 ($t:ty) => {
92 impl<'a> ArrayIndex<'a> for StrideBy<$t> {
93 fn index_op(self) -> ArrayIndexOp<'a> {
94 ArrayIndexOp::Slice(RangeIndex::new(
95 self.inner.start_bound().cloned(),
96 self.inner.end_bound().cloned(),
97 Some(self.stride),
98 ))
99 }
100 }
101 };
102}
103
104impl_array_index_for_stride_by_bounded_range!(std::ops::Range<i32>);
105impl_array_index_for_stride_by_bounded_range!(std::ops::RangeFrom<i32>);
106impl_array_index_for_stride_by_bounded_range!(std::ops::RangeInclusive<i32>);
107impl_array_index_for_stride_by_bounded_range!(std::ops::RangeTo<i32>);
108impl_array_index_for_stride_by_bounded_range!(std::ops::RangeToInclusive<i32>);
109
110impl<'a> ArrayIndex<'a> for StrideBy<std::ops::RangeFull> {
111 fn index_op(self) -> ArrayIndexOp<'a> {
112 ArrayIndexOp::Slice(RangeIndex::new(
113 Bound::Unbounded,
114 Bound::Unbounded,
115 Some(self.stride),
116 ))
117 }
118}
119
120impl<'a, T> TryIndexOp<T> for Array
121where
122 T: ArrayIndex<'a>,
123{
124 fn try_index_device(&self, i: T, stream: impl AsRef<Stream>) -> Result<Array> {
125 get_item(self, i, stream)
126 }
127}
128
129impl<'a> TryIndexOp<&'a [ArrayIndexOp<'a>]> for Array {
130 fn try_index_device(
131 &self,
132 i: &'a [ArrayIndexOp<'a>],
133 stream: impl AsRef<Stream>,
134 ) -> Result<Array> {
135 get_item_nd(self, i, stream)
136 }
137}
138
139impl<'a, A> TryIndexOp<(A,)> for Array
140where
141 A: ArrayIndex<'a>,
142{
143 fn try_index_device(&self, i: (A,), stream: impl AsRef<Stream>) -> Result<Array> {
144 let i = [i.0.index_op()];
145 get_item_nd(self, &i, stream)
146 }
147}
148
149impl<'a, 'b, A, B> TryIndexOp<(A, B)> for Array
150where
151 A: ArrayIndex<'a>,
152 B: ArrayIndex<'b>,
153{
154 fn try_index_device(&self, i: (A, B), stream: impl AsRef<Stream>) -> Result<Array> {
155 let i = [i.0.index_op(), i.1.index_op()];
156 get_item_nd(self, &i, stream)
157 }
158}
159
160impl<'a, 'b, 'c, A, B, C> TryIndexOp<(A, B, C)> for Array
161where
162 A: ArrayIndex<'a>,
163 B: ArrayIndex<'b>,
164 C: ArrayIndex<'c>,
165{
166 fn try_index_device(&self, i: (A, B, C), stream: impl AsRef<Stream>) -> Result<Array> {
167 let i = [i.0.index_op(), i.1.index_op(), i.2.index_op()];
168 get_item_nd(self, &i, stream)
169 }
170}
171
172impl<'a, 'b, 'c, 'd, A, B, C, D> TryIndexOp<(A, B, C, D)> for Array
173where
174 A: ArrayIndex<'a>,
175 B: ArrayIndex<'b>,
176 C: ArrayIndex<'c>,
177 D: ArrayIndex<'d>,
178{
179 fn try_index_device(&self, i: (A, B, C, D), stream: impl AsRef<Stream>) -> Result<Array> {
180 let i = [
181 i.0.index_op(),
182 i.1.index_op(),
183 i.2.index_op(),
184 i.3.index_op(),
185 ];
186 get_item_nd(self, &i, stream)
187 }
188}
189
190impl<'a, 'b, 'c, 'd, 'e, A, B, C, D, E> TryIndexOp<(A, B, C, D, E)> for Array
191where
192 A: ArrayIndex<'a>,
193 B: ArrayIndex<'b>,
194 C: ArrayIndex<'c>,
195 D: ArrayIndex<'d>,
196 E: ArrayIndex<'e>,
197{
198 fn try_index_device(&self, i: (A, B, C, D, E), stream: impl AsRef<Stream>) -> Result<Array> {
199 let i = [
200 i.0.index_op(),
201 i.1.index_op(),
202 i.2.index_op(),
203 i.3.index_op(),
204 i.4.index_op(),
205 ];
206 get_item_nd(self, &i, stream)
207 }
208}
209
210impl<'a, 'b, 'c, 'd, 'e, 'f, A, B, C, D, E, F> TryIndexOp<(A, B, C, D, E, F)> for Array
211where
212 A: ArrayIndex<'a>,
213 B: ArrayIndex<'b>,
214 C: ArrayIndex<'c>,
215 D: ArrayIndex<'d>,
216 E: ArrayIndex<'e>,
217 F: ArrayIndex<'f>,
218{
219 fn try_index_device(&self, i: (A, B, C, D, E, F), stream: impl AsRef<Stream>) -> Result<Array> {
220 let i = [
221 i.0.index_op(),
222 i.1.index_op(),
223 i.2.index_op(),
224 i.3.index_op(),
225 i.4.index_op(),
226 i.5.index_op(),
227 ];
228 get_item_nd(self, &i, stream)
229 }
230}
231
232impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, A, B, C, D, E, F, G> TryIndexOp<(A, B, C, D, E, F, G)> for Array
233where
234 A: ArrayIndex<'a>,
235 B: ArrayIndex<'b>,
236 C: ArrayIndex<'c>,
237 D: ArrayIndex<'d>,
238 E: ArrayIndex<'e>,
239 F: ArrayIndex<'f>,
240 G: ArrayIndex<'g>,
241{
242 fn try_index_device(
243 &self,
244 i: (A, B, C, D, E, F, G),
245 stream: impl AsRef<Stream>,
246 ) -> Result<Array> {
247 let i = [
248 i.0.index_op(),
249 i.1.index_op(),
250 i.2.index_op(),
251 i.3.index_op(),
252 i.4.index_op(),
253 i.5.index_op(),
254 i.6.index_op(),
255 ];
256 get_item_nd(self, &i, stream)
257 }
258}
259
260impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, A, B, C, D, E, F, G, H> TryIndexOp<(A, B, C, D, E, F, G, H)>
261 for Array
262where
263 A: ArrayIndex<'a>,
264 B: ArrayIndex<'b>,
265 C: ArrayIndex<'c>,
266 D: ArrayIndex<'d>,
267 E: ArrayIndex<'e>,
268 F: ArrayIndex<'f>,
269 G: ArrayIndex<'g>,
270 H: ArrayIndex<'h>,
271{
272 fn try_index_device(
273 &self,
274 i: (A, B, C, D, E, F, G, H),
275 stream: impl AsRef<Stream>,
276 ) -> Result<Array> {
277 let i = [
278 i.0.index_op(),
279 i.1.index_op(),
280 i.2.index_op(),
281 i.3.index_op(),
282 i.4.index_op(),
283 i.5.index_op(),
284 i.6.index_op(),
285 i.7.index_op(),
286 ];
287 get_item_nd(self, &i, stream)
288 }
289}
290
291impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, A, B, C, D, E, F, G, H, I>
292 TryIndexOp<(A, B, C, D, E, F, G, H, I)> for Array
293where
294 A: ArrayIndex<'a>,
295 B: ArrayIndex<'b>,
296 C: ArrayIndex<'c>,
297 D: ArrayIndex<'d>,
298 E: ArrayIndex<'e>,
299 F: ArrayIndex<'f>,
300 G: ArrayIndex<'g>,
301 H: ArrayIndex<'h>,
302 I: ArrayIndex<'i>,
303{
304 fn try_index_device(
305 &self,
306 i: (A, B, C, D, E, F, G, H, I),
307 stream: impl AsRef<Stream>,
308 ) -> Result<Array> {
309 let i = [
310 i.0.index_op(),
311 i.1.index_op(),
312 i.2.index_op(),
313 i.3.index_op(),
314 i.4.index_op(),
315 i.5.index_op(),
316 i.6.index_op(),
317 i.7.index_op(),
318 i.8.index_op(),
319 ];
320 get_item_nd(self, &i, stream)
321 }
322}
323
324impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, A, B, C, D, E, F, G, H, I, J>
325 TryIndexOp<(A, B, C, D, E, F, G, H, I, J)> for Array
326where
327 A: ArrayIndex<'a>,
328 B: ArrayIndex<'b>,
329 C: ArrayIndex<'c>,
330 D: ArrayIndex<'d>,
331 E: ArrayIndex<'e>,
332 F: ArrayIndex<'f>,
333 G: ArrayIndex<'g>,
334 H: ArrayIndex<'h>,
335 I: ArrayIndex<'i>,
336 J: ArrayIndex<'j>,
337{
338 fn try_index_device(
339 &self,
340 i: (A, B, C, D, E, F, G, H, I, J),
341 stream: impl AsRef<Stream>,
342 ) -> Result<Array> {
343 let i = [
344 i.0.index_op(),
345 i.1.index_op(),
346 i.2.index_op(),
347 i.3.index_op(),
348 i.4.index_op(),
349 i.5.index_op(),
350 i.6.index_op(),
351 i.7.index_op(),
352 i.8.index_op(),
353 i.9.index_op(),
354 ];
355 get_item_nd(self, &i, stream)
356 }
357}
358
359impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, A, B, C, D, E, F, G, H, I, J, K>
360 TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K)> for Array
361where
362 A: ArrayIndex<'a>,
363 B: ArrayIndex<'b>,
364 C: ArrayIndex<'c>,
365 D: ArrayIndex<'d>,
366 E: ArrayIndex<'e>,
367 F: ArrayIndex<'f>,
368 G: ArrayIndex<'g>,
369 H: ArrayIndex<'h>,
370 I: ArrayIndex<'i>,
371 J: ArrayIndex<'j>,
372 K: ArrayIndex<'k>,
373{
374 fn try_index_device(
375 &self,
376 i: (A, B, C, D, E, F, G, H, I, J, K),
377 stream: impl AsRef<Stream>,
378 ) -> Result<Array> {
379 let i = [
380 i.0.index_op(),
381 i.1.index_op(),
382 i.2.index_op(),
383 i.3.index_op(),
384 i.4.index_op(),
385 i.5.index_op(),
386 i.6.index_op(),
387 i.7.index_op(),
388 i.8.index_op(),
389 i.9.index_op(),
390 i.10.index_op(),
391 ];
392 get_item_nd(self, &i, stream)
393 }
394}
395
396impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, A, B, C, D, E, F, G, H, I, J, K, L>
397 TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L)> for Array
398where
399 A: ArrayIndex<'a>,
400 B: ArrayIndex<'b>,
401 C: ArrayIndex<'c>,
402 D: ArrayIndex<'d>,
403 E: ArrayIndex<'e>,
404 F: ArrayIndex<'f>,
405 G: ArrayIndex<'g>,
406 H: ArrayIndex<'h>,
407 I: ArrayIndex<'i>,
408 J: ArrayIndex<'j>,
409 K: ArrayIndex<'k>,
410 L: ArrayIndex<'l>,
411{
412 fn try_index_device(
413 &self,
414 i: (A, B, C, D, E, F, G, H, I, J, K, L),
415 stream: impl AsRef<Stream>,
416 ) -> Result<Array> {
417 let i = [
418 i.0.index_op(),
419 i.1.index_op(),
420 i.2.index_op(),
421 i.3.index_op(),
422 i.4.index_op(),
423 i.5.index_op(),
424 i.6.index_op(),
425 i.7.index_op(),
426 i.8.index_op(),
427 i.9.index_op(),
428 i.10.index_op(),
429 i.11.index_op(),
430 ];
431 get_item_nd(self, &i, stream)
432 }
433}
434
435impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, 'm, A, B, C, D, E, F, G, H, I, J, K, L, M>
436 TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L, M)> for Array
437where
438 A: ArrayIndex<'a>,
439 B: ArrayIndex<'b>,
440 C: ArrayIndex<'c>,
441 D: ArrayIndex<'d>,
442 E: ArrayIndex<'e>,
443 F: ArrayIndex<'f>,
444 G: ArrayIndex<'g>,
445 H: ArrayIndex<'h>,
446 I: ArrayIndex<'i>,
447 J: ArrayIndex<'j>,
448 K: ArrayIndex<'k>,
449 L: ArrayIndex<'l>,
450 M: ArrayIndex<'m>,
451{
452 fn try_index_device(
453 &self,
454 i: (A, B, C, D, E, F, G, H, I, J, K, L, M),
455 stream: impl AsRef<Stream>,
456 ) -> Result<Array> {
457 let i = [
458 i.0.index_op(),
459 i.1.index_op(),
460 i.2.index_op(),
461 i.3.index_op(),
462 i.4.index_op(),
463 i.5.index_op(),
464 i.6.index_op(),
465 i.7.index_op(),
466 i.8.index_op(),
467 i.9.index_op(),
468 i.10.index_op(),
469 i.11.index_op(),
470 i.12.index_op(),
471 ];
472 get_item_nd(self, &i, stream)
473 }
474}
475
476impl<
477 'a,
478 'b,
479 'c,
480 'd,
481 'e,
482 'f,
483 'g,
484 'h,
485 'i,
486 'j,
487 'k,
488 'l,
489 'm,
490 'n,
491 A,
492 B,
493 C,
494 D,
495 E,
496 F,
497 G,
498 H,
499 I,
500 J,
501 K,
502 L,
503 M,
504 N,
505 > TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N)> for Array
506where
507 A: ArrayIndex<'a>,
508 B: ArrayIndex<'b>,
509 C: ArrayIndex<'c>,
510 D: ArrayIndex<'d>,
511 E: ArrayIndex<'e>,
512 F: ArrayIndex<'f>,
513 G: ArrayIndex<'g>,
514 H: ArrayIndex<'h>,
515 I: ArrayIndex<'i>,
516 J: ArrayIndex<'j>,
517 K: ArrayIndex<'k>,
518 L: ArrayIndex<'l>,
519 M: ArrayIndex<'m>,
520 N: ArrayIndex<'n>,
521{
522 fn try_index_device(
523 &self,
524 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N),
525 stream: impl AsRef<Stream>,
526 ) -> Result<Array> {
527 let i = [
528 i.0.index_op(),
529 i.1.index_op(),
530 i.2.index_op(),
531 i.3.index_op(),
532 i.4.index_op(),
533 i.5.index_op(),
534 i.6.index_op(),
535 i.7.index_op(),
536 i.8.index_op(),
537 i.9.index_op(),
538 i.10.index_op(),
539 i.11.index_op(),
540 i.12.index_op(),
541 i.13.index_op(),
542 ];
543 get_item_nd(self, &i, stream)
544 }
545}
546
547impl<
548 'a,
549 'b,
550 'c,
551 'd,
552 'e,
553 'f,
554 'g,
555 'h,
556 'i,
557 'j,
558 'k,
559 'l,
560 'm,
561 'n,
562 'o,
563 A,
564 B,
565 C,
566 D,
567 E,
568 F,
569 G,
570 H,
571 I,
572 J,
573 K,
574 L,
575 M,
576 N,
577 O,
578 > TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O)> for Array
579where
580 A: ArrayIndex<'a>,
581 B: ArrayIndex<'b>,
582 C: ArrayIndex<'c>,
583 D: ArrayIndex<'d>,
584 E: ArrayIndex<'e>,
585 F: ArrayIndex<'f>,
586 G: ArrayIndex<'g>,
587 H: ArrayIndex<'h>,
588 I: ArrayIndex<'i>,
589 J: ArrayIndex<'j>,
590 K: ArrayIndex<'k>,
591 L: ArrayIndex<'l>,
592 M: ArrayIndex<'m>,
593 N: ArrayIndex<'n>,
594 O: ArrayIndex<'o>,
595{
596 fn try_index_device(
597 &self,
598 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O),
599 stream: impl AsRef<Stream>,
600 ) -> Result<Array> {
601 let i = [
602 i.0.index_op(),
603 i.1.index_op(),
604 i.2.index_op(),
605 i.3.index_op(),
606 i.4.index_op(),
607 i.5.index_op(),
608 i.6.index_op(),
609 i.7.index_op(),
610 i.8.index_op(),
611 i.9.index_op(),
612 i.10.index_op(),
613 i.11.index_op(),
614 i.12.index_op(),
615 i.13.index_op(),
616 i.14.index_op(),
617 ];
618 get_item_nd(self, &i, stream)
619 }
620}
621
622impl<
623 'a,
624 'b,
625 'c,
626 'd,
627 'e,
628 'f,
629 'g,
630 'h,
631 'i,
632 'j,
633 'k,
634 'l,
635 'm,
636 'n,
637 'o,
638 'p,
639 A,
640 B,
641 C,
642 D,
643 E,
644 F,
645 G,
646 H,
647 I,
648 J,
649 K,
650 L,
651 M,
652 N,
653 O,
654 P,
655 > TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P)> for Array
656where
657 A: ArrayIndex<'a>,
658 B: ArrayIndex<'b>,
659 C: ArrayIndex<'c>,
660 D: ArrayIndex<'d>,
661 E: ArrayIndex<'e>,
662 F: ArrayIndex<'f>,
663 G: ArrayIndex<'g>,
664 H: ArrayIndex<'h>,
665 I: ArrayIndex<'i>,
666 J: ArrayIndex<'j>,
667 K: ArrayIndex<'k>,
668 L: ArrayIndex<'l>,
669 M: ArrayIndex<'m>,
670 N: ArrayIndex<'n>,
671 O: ArrayIndex<'o>,
672 P: ArrayIndex<'p>,
673{
674 fn try_index_device(
675 &self,
676 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P),
677 stream: impl AsRef<Stream>,
678 ) -> Result<Array> {
679 let i = [
680 i.0.index_op(),
681 i.1.index_op(),
682 i.2.index_op(),
683 i.3.index_op(),
684 i.4.index_op(),
685 i.5.index_op(),
686 i.6.index_op(),
687 i.7.index_op(),
688 i.8.index_op(),
689 i.9.index_op(),
690 i.10.index_op(),
691 i.11.index_op(),
692 i.12.index_op(),
693 i.13.index_op(),
694 i.14.index_op(),
695 i.15.index_op(),
696 ];
697 get_item_nd(self, &i, stream)
698 }
699}
700
701impl Array {
703 pub(crate) fn slice_device(
707 &self,
708 start: &[i32],
709 stop: &[i32],
710 strides: &[i32],
711 stream: impl AsRef<Stream>,
712 ) -> Result<Array> {
713 Array::try_from_op(|res| unsafe {
714 mlx_sys::mlx_slice(
715 res,
716 self.as_ptr(),
717 start.as_ptr(),
718 start.len(),
719 stop.as_ptr(),
720 stop.len(),
721 strides.as_ptr(),
722 strides.len(),
723 stream.as_ref().as_ptr(),
724 )
725 })
726 }
727}
728
/// Enumerates the positions visited when stepping from `absolute_start`
/// towards `absolute_end` (exclusive) by `stride`.
///
/// * A positive `stride` yields values while they are `< absolute_end`.
/// * A negative `stride` yields values while they are `> absolute_end`.
/// * A zero `stride` yields nothing.
///
/// Stepping uses checked arithmetic, so a range that ends near `i32::MAX` or
/// `i32::MIN` simply stops at the last representable position instead of
/// overflowing.
fn absolute_indices(absolute_start: i32, absolute_end: i32, stride: i32) -> Vec<i32> {
    let in_range =
        |i: i32| (stride > 0 && i < absolute_end) || (stride < 0 && i > absolute_end);

    // `checked_add` ends the sequence when the next step would leave `i32`.
    std::iter::successors(Some(absolute_start), |&i| i.checked_add(stride))
        .take_while(|&i| in_range(i))
        .collect()
}
742
743enum GatherIndexItem<'a> {
744 Owned(Array),
745 Borrowed(&'a Array),
746 Rc(Rc<Array>),
747}
748
749impl AsRef<Array> for GatherIndexItem<'_> {
750 fn as_ref(&self) -> &Array {
751 match self {
752 GatherIndexItem::Owned(array) => array,
753 GatherIndexItem::Borrowed(array) => array,
754 GatherIndexItem::Rc(array) => array,
755 }
756 }
757}
758
759impl Deref for GatherIndexItem<'_> {
760 type Target = Array;
761
762 fn deref(&self) -> &Self::Target {
763 match self {
764 GatherIndexItem::Owned(array) => array,
765 GatherIndexItem::Borrowed(array) => array,
766 GatherIndexItem::Rc(array) => array,
767 }
768 }
769}
770
771#[inline]
775fn gather_nd<'a>(
776 src: &Array,
777 operations: impl Iterator<Item = &'a ArrayIndexOp<'a>>,
778 gather_first: bool,
779 last_array_or_index: usize,
780 stream: impl AsRef<Stream>,
781) -> Result<(usize, Array)> {
782 use ArrayIndexOp::*;
783
784 let mut max_dims = 0;
785 let mut slice_count = 0;
786 let mut is_slice: Vec<bool> = Vec::with_capacity(last_array_or_index);
787 let mut gather_indices: Vec<GatherIndexItem> = Vec::with_capacity(last_array_or_index);
788
789 let shape = src.shape();
790
791 let mut axes = Vec::with_capacity(last_array_or_index);
793 let mut operation_len: usize = 0;
794 let mut slice_sizes = shape.to_vec();
795 for (i, op) in operations.enumerate() {
796 axes.push(i as i32);
797 operation_len += 1;
798 slice_sizes[i] = 1;
799 match op {
800 TakeIndex { index } => {
801 let item = Array::from_int(resolve_index_unchecked(
802 *index,
803 src.dim(i as i32) as usize,
804 ) as i32);
805 gather_indices.push(GatherIndexItem::Owned(item));
806 is_slice.push(false);
807 }
808 Slice(range) => {
809 slice_count += 1;
810 is_slice.push(true);
811
812 let size = shape[i];
813 let absolute_start = range.absolute_start(size);
814 let absolute_end = range.absolute_end(size);
815 let indices = absolute_indices(absolute_start, absolute_end, range.stride());
816
817 let item = Array::from_slice(&indices, &[indices.len() as i32]);
818
819 gather_indices.push(GatherIndexItem::Owned(item));
820 }
821 TakeArray { indices } => {
822 is_slice.push(false);
823 max_dims = max_dims.max(indices.ndim());
824 gather_indices.push(GatherIndexItem::Rc(indices.clone()));
826 }
827 TakeArrayRef { indices } => {
828 is_slice.push(false);
829 max_dims = max_dims.max(indices.ndim());
830 gather_indices.push(GatherIndexItem::Borrowed(indices));
832 }
833 Ellipsis | ExpandDims => {
834 unreachable!("Unexpected operation in gather_nd")
835 }
836 }
837 }
838
839 if gather_first {
841 if slice_count > 0 {
842 let mut slice_index = 0;
843 for (i, item) in gather_indices.iter_mut().enumerate() {
844 if is_slice[i] {
845 let mut new_shape = vec![1; max_dims + slice_count];
846 new_shape[max_dims + slice_index] = item.dim(0);
847 *item = GatherIndexItem::Owned(item.reshape(&new_shape)?);
848 slice_index += 1;
849 } else {
850 let mut new_shape = item.shape().to_vec();
851 new_shape.extend((0..slice_count).map(|_| 1));
852 *item = GatherIndexItem::Owned(item.reshape(&new_shape)?);
853 }
854 }
855 }
856 } else {
857 for (i, item) in gather_indices[..slice_count].iter_mut().enumerate() {
859 let mut new_shape = vec![1; max_dims + slice_count];
860 new_shape[i] = item.dim(0);
861 *item = GatherIndexItem::Owned(item.reshape(&new_shape)?);
862 }
863 }
864
865 let indices = VectorArray::try_from_iter(gather_indices.iter())?;
870
871 let gathered = Array::try_from_op(|res| unsafe {
872 mlx_sys::mlx_gather(
873 res,
874 src.as_ptr(),
875 indices.as_ptr(),
876 axes.as_ptr(),
877 axes.len(),
878 slice_sizes.as_ptr(),
879 slice_sizes.len(),
880 stream.as_ref().as_ptr(),
881 )
882 })?;
883 let gathered_shape = gathered.shape();
884
885 let output_shape: Vec<i32> = gathered_shape[0..(max_dims + slice_count)]
887 .iter()
888 .chain(gathered_shape[(max_dims + slice_count + operation_len)..].iter())
889 .copied()
890 .collect();
891 let result = gathered.reshape(&output_shape)?;
892
893 Ok((max_dims, result))
894}
895
896#[inline]
897fn get_item_index(src: &Array, index: i32, axis: i32, stream: impl AsRef<Stream>) -> Result<Array> {
898 let index = resolve_index_unchecked(index, src.dim(axis) as usize) as i32;
899 src.take_device(array!(index), axis, stream)
900}
901
902#[inline]
903fn get_item_array(
904 src: &Array,
905 indices: &Array,
906 axis: i32,
907 stream: impl AsRef<Stream>,
908) -> Result<Array> {
909 src.take_device(indices, axis, stream)
910}
911
912#[inline]
913fn get_item_slice(src: &Array, range: RangeIndex, stream: impl AsRef<Stream>) -> Result<Array> {
914 let ndim = src.ndim();
915 let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
916 let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
917 let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];
918
919 let size = ends[0];
920 starts[0] = range.start(size);
921 ends[0] = range.end(size);
922 strides[0] = range.stride();
923
924 src.slice_device(&starts, &ends, &strides, stream)
925}
926
927fn get_item<'a>(
930 src: &Array,
931 index: impl ArrayIndex<'a>,
932 stream: impl AsRef<Stream>,
933) -> Result<Array> {
934 use ArrayIndexOp::*;
935
936 match index.index_op() {
937 Ellipsis => Ok(src.deep_clone()),
938 TakeIndex { index } => get_item_index(src, index, 0, stream),
939 TakeArray { indices } => get_item_array(src, &indices, 0, stream),
940 TakeArrayRef { indices } => get_item_array(src, indices, 0, stream),
941 Slice(range) => get_item_slice(src, range, stream),
942 ExpandDims => src.expand_dims_device(&[0], stream),
943 }
944}
945
946fn get_item_nd(
949 src: &Array,
950 operations: &[ArrayIndexOp],
951 stream: impl AsRef<Stream>,
952) -> Result<Array> {
953 use ArrayIndexOp::*;
954
955 let mut src = Cow::Borrowed(src);
956
957 let operations = expand_ellipsis_operations(src.ndim(), operations);
963
964 let mut gather_first = false;
973 let mut have_array_or_index = false; let mut have_non_array = false;
975 for item in operations.iter() {
976 if item.is_array_or_index() {
977 if have_array_or_index && have_non_array {
978 gather_first = true;
979 break;
980 }
981 have_array_or_index = true;
982 } else {
983 have_non_array = have_non_array || have_array_or_index;
984 }
985 }
986
987 let array_count = operations.iter().filter(|op| op.is_array()).count();
988 let have_array = array_count > 0;
989
990 let mut remaining_indices: Vec<ArrayIndexOp> = Vec::new();
991 if have_array {
992 let last_array_or_index = operations
995 .iter()
996 .rposition(|op| op.is_array_or_index())
997 .unwrap(); let gather_indices = operations[..=last_array_or_index]
999 .iter()
1000 .filter(|op| !matches!(op, Ellipsis | ExpandDims));
1001 let (max_dims, gathered) = gather_nd(
1002 &src,
1003 gather_indices,
1004 gather_first,
1005 last_array_or_index,
1006 &stream,
1007 )?;
1008
1009 src = Cow::Owned(gathered);
1010
1011 if gather_first {
1013 remaining_indices.extend((0..max_dims).map(|_| (..).index_op()));
1014
1015 for item in &operations[..=last_array_or_index] {
1018 match item {
1020 ExpandDims => remaining_indices.push(item.clone()),
1021 Slice { .. } => remaining_indices.push((..).index_op()),
1022 Ellipsis
1023 | TakeIndex { index: _ }
1024 | TakeArray { indices: _ }
1025 | TakeArrayRef { indices: _ } => {}
1026 }
1027 }
1028
1029 remaining_indices.extend(operations[(last_array_or_index + 1)..].iter().cloned());
1031 } else {
1032 for item in operations.iter() {
1034 match item {
1036 TakeIndex { .. } | TakeArray { .. } | TakeArrayRef { .. } => break,
1037 ExpandDims => remaining_indices.push(item.clone()),
1038 Ellipsis | Slice(_) => remaining_indices.push((..).index_op()),
1039 }
1040 }
1041
1042 remaining_indices.extend((0..max_dims).map(|_| (..).index_op()));
1044
1045 remaining_indices.extend(operations[(last_array_or_index + 1)..].iter().cloned());
1047 }
1048 }
1049
1050 if have_array && remaining_indices.is_empty() {
1051 return Ok(src.into_owned());
1052 }
1053
1054 if remaining_indices.is_empty() {
1055 remaining_indices = operations.to_vec();
1056 }
1057
1058 let ndim = src.ndim();
1060 let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
1061 let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
1062 let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];
1063 let mut squeeze_needed = false;
1064 let mut axis = 0;
1065
1066 for item in remaining_indices.iter() {
1067 match item {
1068 ExpandDims => continue,
1069 TakeIndex { mut index } => {
1070 if !have_array {
1071 index = resolve_index_unchecked(index, src.dim(axis as i32) as usize) as i32;
1072 starts[axis] = index;
1073 ends[axis] = index + 1;
1074 squeeze_needed = true;
1075 }
1076 }
1077 Slice(range) => {
1078 let size = src.dim(axis as i32);
1079 starts[axis] = range.start(size);
1080 ends[axis] = range.end(size);
1081 strides[axis] = range.stride();
1082 }
1083 Ellipsis | TakeArray { .. } | TakeArrayRef { .. } => {
1084 unreachable!("Unexpected item in remaining_indices: {:?}", item)
1085 }
1086 }
1087 axis += 1;
1088 }
1089
1090 src = Cow::Owned(src.slice_device(&starts, &ends, &strides, stream)?);
1091
1092 if remaining_indices.len() > ndim || squeeze_needed {
1094 let mut new_shape = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::new();
1095 let mut axis_ = 0;
1096 for item in remaining_indices {
1097 match item {
1099 ExpandDims => new_shape.push(1),
1100 TakeIndex { .. } => {
1101 if squeeze_needed {
1102 axis_ += 1;
1103 }
1104 }
1105 Ellipsis | TakeArray { .. } | TakeArrayRef { .. } | Slice(_) => {
1106 new_shape.push(src.dim(axis_));
1107 axis_ += 1;
1108 }
1109 }
1110 }
1111 new_shape.extend(src.shape()[(axis_ as usize)..].iter().cloned());
1112
1113 src = Cow::Owned(src.reshape(&new_shape)?);
1114 }
1115
1116 Ok(src.into_owned())
1117}
1118
/// Unit tests for reading from an [`Array`] through the indexing API.
#[cfg(test)]
mod tests {
    use crate::{
        assert_array_eq,
        ops::indexing::{Ellipsis, IndexOp, IntoStrideBy, NewAxis},
        Array,
    };

    /// Negative integers count from the end of the axis.
    #[test]
    fn test_array_index_negative_int() {
        let source = Array::from_iter(0i32..8, &[8]);

        let last = source.index(-1);
        assert_eq!(last.ndim(), 0);
        assert_eq!(last.item::<i32>(), 7);

        let first = source.index(-8);
        assert_eq!(first.ndim(), 0);
        assert_eq!(first.item::<i32>(), 0);
    }

    /// `NewAxis` prepends an axis of length one.
    #[test]
    fn test_array_index_new_axis() {
        let source = Array::from_iter(0..60, &[3, 4, 5]);
        let expanded = source.index(NewAxis);

        assert_eq!(expanded.ndim(), 4);
        assert_eq!(expanded.shape(), &[1, 3, 4, 5]);

        let expected = Array::from_iter(0..60, &[1, 3, 4, 5]);
        assert_array_eq!(expanded, expected, 0.01);
    }

    /// `Ellipsis` stands in for the axes that are not indexed explicitly.
    #[test]
    fn test_array_index_ellipsis() {
        let source = Array::from_iter(0i32..8, &[2, 2, 2]);

        // Explicit full ranges ...
        let explicit = source.index((.., .., 0));
        let expected = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
        assert_array_eq!(explicit, expected, 0.01);

        // ... and the ellipsis shorthand give the same result.
        let shorthand = source.index((Ellipsis, 0));
        let expected = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
        assert_array_eq!(shorthand, expected, 0.01);

        // A lone ellipsis selects everything.
        let everything = source.index(Ellipsis);
        let expected = Array::from_iter(0i32..8, &[2, 2, 2]);
        assert_array_eq!(everything, expected, 0.01);
    }

    /// A strided range visits every second element of `2..8`.
    #[test]
    fn test_array_index_stride() {
        let source = Array::from_iter(0..10, &[10]);
        let strided = source.index((2..8).stride_by(2));

        let expected = Array::from_slice(&[2, 4, 6], &[3]);
        assert_array_eq!(strided, expected, 0.01);
    }

    /// A single integer drops the first axis.
    #[test]
    fn test_array_subscript_int() {
        let source = Array::from_iter(0i32..512, &[8, 8, 8]);

        let plane = source.index(1);

        assert_eq!(plane.ndim(), 2);
        assert_eq!(plane.shape(), &[8, 8]);

        let expected = Array::from_iter(64..128, &[8, 8]);
        assert_array_eq!(plane, expected, 0.01);
    }

    /// Each integer in a tuple drops one leading axis.
    #[test]
    fn test_array_subscript_int_array() {
        let source = Array::from_iter(0i32..512, &[8, 8, 8]);

        let row = source.index((1, 2));

        assert_eq!(row.ndim(), 1);
        assert_eq!(row.shape(), &[8]);

        let expected = Array::from_iter(80..88, &[8]);
        assert_array_eq!(row, expected, 0.01);

        let scalar = source.index((1, 2, 3));

        assert_eq!(scalar.ndim(), 0);
        assert!(scalar.shape().is_empty());
        assert_eq!(scalar.item::<i32>(), 64 + 2 * 8 + 3);
    }

    /// Axes that are not indexed keep their extent, including a trailing `1`.
    #[test]
    fn test_array_subscript_int_array_2() {
        let source = Array::from_iter(0i32..512, &[8, 8, 8, 1]);

        let one_index = source.index(1);
        assert_eq!(one_index.ndim(), 3);
        assert_eq!(one_index.shape(), &[8, 8, 1]);

        let two_indices = source.index((1, 2));
        assert_eq!(two_indices.ndim(), 2);
        assert_eq!(two_indices.shape(), &[8, 1]);

        let three_indices = source.index((1, 2, 3));
        assert_eq!(three_indices.ndim(), 1);
        assert_eq!(three_indices.shape(), &[1]);
    }

    /// Negative integers work on several axes at once.
    #[test]
    fn test_array_subscript_from_end() {
        let source = Array::from_iter(0i32..12, &[3, 4]);

        let element = source.index((-1, -2));

        assert_eq!(element.ndim(), 0);
        assert_eq!(element.item::<i32>(), 10);
    }

    /// Ranges keep the axis and restrict its extent.
    #[test]
    fn test_array_subscript_range() {
        let source = Array::from_iter(0i32..512, &[8, 8, 8]);

        // Half-open range on the first axis.
        let half_open = source.index(1..3);
        assert_eq!(half_open.ndim(), 3);
        assert_eq!(half_open.shape(), &[2, 8, 8]);
        let expected = Array::from_iter(64..192, &[2, 8, 8]);
        assert_array_eq!(half_open, expected, 0.01);

        // Inclusive range selecting a single plane.
        let inclusive = source.index(1..=1);
        assert_eq!(inclusive.ndim(), 3);
        assert_eq!(inclusive.shape(), &[1, 8, 8]);
        let expected = Array::from_iter(64..128, &[1, 8, 8]);
        assert_array_eq!(inclusive, expected, 0.01);

        // One range per axis.
        let per_axis = source.index((1..2, ..3, 3..));
        assert_eq!(per_axis.ndim(), 3);
        assert_eq!(per_axis.shape(), &[1, 3, 5]);
        let expected = Array::from_slice(
            &[67, 68, 69, 70, 71, 75, 76, 77, 78, 79, 83, 84, 85, 86, 87],
            &[1, 3, 5],
        );
        assert_array_eq!(per_axis, expected, 0.01);

        // Negative bounds count from the end of each axis.
        let from_end = source.index((-2..-1, ..-3, -3..));
        assert_eq!(from_end.ndim(), 3);
        assert_eq!(from_end.shape(), &[1, 5, 3]);
        let expected = Array::from_slice(
            &[
                389, 390, 391, 397, 398, 399, 405, 406, 407, 413, 414, 415, 421, 422, 423,
            ],
            &[1, 5, 3],
        );
        assert_array_eq!(from_end, expected, 0.01);
    }

    /// Two owned index arrays select element-wise (row, column) pairs.
    #[test]
    fn test_array_subscript_advanced() {
        let source = Array::from_iter(0..35, &[5, 7]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 2, 4], &[3]);
        let cols = Array::from_slice(&[0, 1, 2], &[3]);

        let picked = source.index((rows, cols));

        assert_eq!(picked.ndim(), 1);
        assert_eq!(picked.shape(), &[3]);

        let expected = Array::from_slice(&[0i32, 15, 30], &[3]);
        assert_array_eq!(picked, expected, 0.01);
    }

    /// Owned and borrowed index arrays can be mixed in one expression.
    #[test]
    fn test_array_subscript_advanced_with_ref() {
        let source = Array::from_iter(0..35, &[5, 7]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 2, 4], &[3]);
        let cols = Array::from_slice(&[0, 1, 2], &[3]);

        let picked = source.index((rows, &cols));

        assert_eq!(picked.ndim(), 1);
        assert_eq!(picked.shape(), &[3]);

        let expected = Array::from_slice(&[0i32, 15, 30], &[3]);
        assert_array_eq!(picked, expected, 0.01);
    }

    /// A single index array selects whole rows.
    #[test]
    fn test_array_subscript_advanced_2() {
        let source = Array::from_iter(0..12, &[6, 2]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 2, 4], &[3]);
        let picked = source.index(rows);

        let expected = Array::from_slice(&[0i32, 1, 4, 5, 8, 9], &[3, 2]);
        assert_array_eq!(picked, expected, 0.01);
    }

    /// Indexing the first axis in a loop yields consecutive blocks of ten.
    #[test]
    fn test_collection() {
        let source = Array::from_iter(0i32..20, &[2, 2, 5]);

        for i in 0..2 {
            let block = source.index(i);
            let expected = Array::from_iter((i * 10)..(i * 10 + 10), &[2, 5]);
            // Uses the comparison macro this module imports, like every other
            // test here.
            assert_array_eq!(block, expected, 0.01);
        }
    }

    /// Two-dimensional index arrays produce a two-dimensional result.
    #[test]
    fn test_array_subscript_advanced_2d() {
        let source = Array::from_iter(0..12, &[4, 3]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 0, 3, 3], &[2, 2]);
        let cols = Array::from_slice(&[0, 2, 0, 2], &[2, 2]);

        let corners = source.index((rows, cols));

        let expected = Array::from_slice(&[0, 2, 9, 11], &[2, 2]);
        assert_array_eq!(corners, expected, 0.01);
    }

    /// Index arrays of shape `[2, 1]` and `[2]` broadcast to `[2, 2]`.
    #[test]
    fn test_array_subscript_advanced_2d_2() {
        let source = Array::from_iter(0..12, &[4, 3]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 3], &[2, 1]);
        let cols = Array::from_slice(&[0, 2], &[2]);

        let corners = source.index((rows, cols));

        let expected = Array::from_slice(&[0, 2, 9, 11], &[2, 2]);
        assert_array_eq!(corners, expected, 0.01);
    }

    /// Asserts that `result` has the given `shape` and that all of its
    /// elements add up to `expected_sum`.
    fn check(result: impl AsRef<Array>, shape: &[i32], expected_sum: i32) {
        let actual = result.as_ref();
        assert_eq!(actual.shape(), shape);

        let total = actual.sum(None, None).unwrap();
        assert_eq!(total.item::<i32>(), expected_sum);
    }

    /// Every kind of single index expression.
    #[test]
    fn test_full_index_read_single() {
        let source = Array::from_iter(0..60, &[3, 4, 5]);

        check(source.index(Ellipsis), &[3, 4, 5], 1770);

        check(source.index(NewAxis), &[1, 3, 4, 5], 1770);

        check(source.index(0), &[4, 5], 190);

        check(source.index(1..3), &[2, 4, 5], 1580);

        let picks = Array::from_slice(&[2, 1], &[2]);
        check(source.index(picks), &[2, 4, 5], 1580);
    }

    /// Multi-axis expressions that contain no index arrays.
    #[test]
    fn test_full_index_read_no_array() {
        let source = Array::from_iter(0..360, &[2, 3, 4, 5, 3]);

        check(source.index((Ellipsis, 0)), &[2, 3, 4, 5], 21420);

        check(source.index((0, Ellipsis)), &[3, 4, 5, 3], 16110);

        check(source.index((0, Ellipsis, 0)), &[3, 4, 5], 5310);

        let strided = source.index((Ellipsis, (..).stride_by(2), ..));
        check(strided, &[2, 3, 4, 3, 3], 38772);

        let with_new_axis = source.index((Ellipsis, NewAxis, (..).stride_by(2), -1));
        check(with_new_axis, &[2, 3, 4, 1, 3], 12996);

        check(source.index((.., 2.., 0)), &[2, 1, 5, 3], 6510);

        let combined = source.index((
            (..).stride_by(-1),
            ..2,
            2..,
            Ellipsis,
            NewAxis,
            (..).stride_by(2),
        ));
        check(combined, &[2, 2, 2, 5, 1, 2], 13160);
    }

    /// Multi-axis expressions that include index arrays.
    #[test]
    fn test_full_index_read_array() {
        let source = Array::from_iter(0..540, &[3, 3, 4, 5, 3]);

        let picks = Array::from_slice(&[2, 1], &[2]);

        check(source.index((0, &picks)), &[2, 4, 5, 3], 14340);

        check(source.index((Ellipsis, &picks, 0)), &[3, 3, 4, 2], 19224);

        check(source.index((&picks, 0, Ellipsis)), &[2, 4, 5, 3], 35940);

        check(
            source.index((&picks, Ellipsis, &picks)),
            &[2, 3, 4, 5],
            43200,
        );

        let strided = source.index((&picks, Ellipsis, (..).stride_by(2), ..));
        check(strided, &[2, 3, 4, 3, 3], 77652);

        let with_new_axis = source.index((Ellipsis, &picks, NewAxis, (..).stride_by(2), -1));
        check(with_new_axis, &[2, 3, 3, 1, 3], 14607);

        check(source.index((.., 2.., &picks)), &[3, 1, 2, 5, 3], 29655);

        // The last use takes ownership of the index array.
        let combined = source.index((
            (..).stride_by(-1),
            ..2,
            picks,
            2..,
            Ellipsis,
            NewAxis,
            (..).stride_by(2),
        ));
        check(combined, &[3, 2, 2, 3, 1, 2], 17460);
    }
}