mlx_rs/ops/indexing/
index_impl.rs

1use std::{
2    borrow::Cow,
3    ops::{Bound, Deref, RangeBounds},
4    rc::Rc,
5};
6
7use smallvec::{smallvec, SmallVec};
8
9use crate::{
10    array,
11    constants::DEFAULT_STACK_VEC_LEN,
12    error::Result,
13    ops::indexing::expand_ellipsis_operations,
14    utils::{resolve_index_unchecked, VectorArray},
15    Array, Stream,
16};
17
18use super::{
19    ArrayIndex, ArrayIndexOp, Ellipsis, Guarded, NewAxis, RangeIndex, StrideBy, TryIndexOp,
20};
21
22/* -------------------------------------------------------------------------- */
23/*                               Implementation                               */
24/* -------------------------------------------------------------------------- */
25
26impl<'a> ArrayIndex<'a> for i32 {
27    fn index_op(self) -> ArrayIndexOp<'a> {
28        ArrayIndexOp::TakeIndex { index: self }
29    }
30}
31
32impl<'a> ArrayIndex<'a> for NewAxis {
33    fn index_op(self) -> ArrayIndexOp<'a> {
34        ArrayIndexOp::ExpandDims
35    }
36}
37
38impl<'a> ArrayIndex<'a> for Ellipsis {
39    fn index_op(self) -> ArrayIndexOp<'a> {
40        ArrayIndexOp::Ellipsis
41    }
42}
43
44impl<'a> ArrayIndex<'a> for Array {
45    fn index_op(self) -> ArrayIndexOp<'a> {
46        ArrayIndexOp::TakeArray {
47            indices: Rc::new(self),
48        }
49    }
50}
51
52impl<'a> ArrayIndex<'a> for &'a Array {
53    fn index_op(self) -> ArrayIndexOp<'a> {
54        ArrayIndexOp::TakeArrayRef { indices: self }
55    }
56}
57
58impl<'a> ArrayIndex<'a> for ArrayIndexOp<'a> {
59    fn index_op(self) -> ArrayIndexOp<'a> {
60        self
61    }
62}
63
64macro_rules! impl_array_index_for_bounded_range {
65    ($t:ty) => {
66        impl<'a> ArrayIndex<'a> for $t {
67            fn index_op(self) -> ArrayIndexOp<'a> {
68                ArrayIndexOp::Slice(RangeIndex::new(
69                    self.start_bound().cloned(),
70                    self.end_bound().cloned(),
71                    Some(1),
72                ))
73            }
74        }
75    };
76}
77
78impl_array_index_for_bounded_range!(std::ops::Range<i32>);
79impl_array_index_for_bounded_range!(std::ops::RangeFrom<i32>);
80impl_array_index_for_bounded_range!(std::ops::RangeInclusive<i32>);
81impl_array_index_for_bounded_range!(std::ops::RangeTo<i32>);
82impl_array_index_for_bounded_range!(std::ops::RangeToInclusive<i32>);
83
84impl<'a> ArrayIndex<'a> for std::ops::RangeFull {
85    fn index_op(self) -> ArrayIndexOp<'a> {
86        ArrayIndexOp::Slice(RangeIndex::new(Bound::Unbounded, Bound::Unbounded, Some(1)))
87    }
88}
89
90macro_rules! impl_array_index_for_stride_by_bounded_range {
91    ($t:ty) => {
92        impl<'a> ArrayIndex<'a> for StrideBy<$t> {
93            fn index_op(self) -> ArrayIndexOp<'a> {
94                ArrayIndexOp::Slice(RangeIndex::new(
95                    self.inner.start_bound().cloned(),
96                    self.inner.end_bound().cloned(),
97                    Some(self.stride),
98                ))
99            }
100        }
101    };
102}
103
104impl_array_index_for_stride_by_bounded_range!(std::ops::Range<i32>);
105impl_array_index_for_stride_by_bounded_range!(std::ops::RangeFrom<i32>);
106impl_array_index_for_stride_by_bounded_range!(std::ops::RangeInclusive<i32>);
107impl_array_index_for_stride_by_bounded_range!(std::ops::RangeTo<i32>);
108impl_array_index_for_stride_by_bounded_range!(std::ops::RangeToInclusive<i32>);
109
110impl<'a> ArrayIndex<'a> for StrideBy<std::ops::RangeFull> {
111    fn index_op(self) -> ArrayIndexOp<'a> {
112        ArrayIndexOp::Slice(RangeIndex::new(
113            Bound::Unbounded,
114            Bound::Unbounded,
115            Some(self.stride),
116        ))
117    }
118}
119
120impl<'a, T> TryIndexOp<T> for Array
121where
122    T: ArrayIndex<'a>,
123{
124    fn try_index_device(&self, i: T, stream: impl AsRef<Stream>) -> Result<Array> {
125        get_item(self, i, stream)
126    }
127}
128
129impl<'a> TryIndexOp<&'a [ArrayIndexOp<'a>]> for Array {
130    fn try_index_device(
131        &self,
132        i: &'a [ArrayIndexOp<'a>],
133        stream: impl AsRef<Stream>,
134    ) -> Result<Array> {
135        get_item_nd(self, i, stream)
136    }
137}
138
139impl<'a, A> TryIndexOp<(A,)> for Array
140where
141    A: ArrayIndex<'a>,
142{
143    fn try_index_device(&self, i: (A,), stream: impl AsRef<Stream>) -> Result<Array> {
144        let i = [i.0.index_op()];
145        get_item_nd(self, &i, stream)
146    }
147}
148
149impl<'a, 'b, A, B> TryIndexOp<(A, B)> for Array
150where
151    A: ArrayIndex<'a>,
152    B: ArrayIndex<'b>,
153{
154    fn try_index_device(&self, i: (A, B), stream: impl AsRef<Stream>) -> Result<Array> {
155        let i = [i.0.index_op(), i.1.index_op()];
156        get_item_nd(self, &i, stream)
157    }
158}
159
160impl<'a, 'b, 'c, A, B, C> TryIndexOp<(A, B, C)> for Array
161where
162    A: ArrayIndex<'a>,
163    B: ArrayIndex<'b>,
164    C: ArrayIndex<'c>,
165{
166    fn try_index_device(&self, i: (A, B, C), stream: impl AsRef<Stream>) -> Result<Array> {
167        let i = [i.0.index_op(), i.1.index_op(), i.2.index_op()];
168        get_item_nd(self, &i, stream)
169    }
170}
171
172impl<'a, 'b, 'c, 'd, A, B, C, D> TryIndexOp<(A, B, C, D)> for Array
173where
174    A: ArrayIndex<'a>,
175    B: ArrayIndex<'b>,
176    C: ArrayIndex<'c>,
177    D: ArrayIndex<'d>,
178{
179    fn try_index_device(&self, i: (A, B, C, D), stream: impl AsRef<Stream>) -> Result<Array> {
180        let i = [
181            i.0.index_op(),
182            i.1.index_op(),
183            i.2.index_op(),
184            i.3.index_op(),
185        ];
186        get_item_nd(self, &i, stream)
187    }
188}
189
190impl<'a, 'b, 'c, 'd, 'e, A, B, C, D, E> TryIndexOp<(A, B, C, D, E)> for Array
191where
192    A: ArrayIndex<'a>,
193    B: ArrayIndex<'b>,
194    C: ArrayIndex<'c>,
195    D: ArrayIndex<'d>,
196    E: ArrayIndex<'e>,
197{
198    fn try_index_device(&self, i: (A, B, C, D, E), stream: impl AsRef<Stream>) -> Result<Array> {
199        let i = [
200            i.0.index_op(),
201            i.1.index_op(),
202            i.2.index_op(),
203            i.3.index_op(),
204            i.4.index_op(),
205        ];
206        get_item_nd(self, &i, stream)
207    }
208}
209
210impl<'a, 'b, 'c, 'd, 'e, 'f, A, B, C, D, E, F> TryIndexOp<(A, B, C, D, E, F)> for Array
211where
212    A: ArrayIndex<'a>,
213    B: ArrayIndex<'b>,
214    C: ArrayIndex<'c>,
215    D: ArrayIndex<'d>,
216    E: ArrayIndex<'e>,
217    F: ArrayIndex<'f>,
218{
219    fn try_index_device(&self, i: (A, B, C, D, E, F), stream: impl AsRef<Stream>) -> Result<Array> {
220        let i = [
221            i.0.index_op(),
222            i.1.index_op(),
223            i.2.index_op(),
224            i.3.index_op(),
225            i.4.index_op(),
226            i.5.index_op(),
227        ];
228        get_item_nd(self, &i, stream)
229    }
230}
231
232impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, A, B, C, D, E, F, G> TryIndexOp<(A, B, C, D, E, F, G)> for Array
233where
234    A: ArrayIndex<'a>,
235    B: ArrayIndex<'b>,
236    C: ArrayIndex<'c>,
237    D: ArrayIndex<'d>,
238    E: ArrayIndex<'e>,
239    F: ArrayIndex<'f>,
240    G: ArrayIndex<'g>,
241{
242    fn try_index_device(
243        &self,
244        i: (A, B, C, D, E, F, G),
245        stream: impl AsRef<Stream>,
246    ) -> Result<Array> {
247        let i = [
248            i.0.index_op(),
249            i.1.index_op(),
250            i.2.index_op(),
251            i.3.index_op(),
252            i.4.index_op(),
253            i.5.index_op(),
254            i.6.index_op(),
255        ];
256        get_item_nd(self, &i, stream)
257    }
258}
259
260impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, A, B, C, D, E, F, G, H> TryIndexOp<(A, B, C, D, E, F, G, H)>
261    for Array
262where
263    A: ArrayIndex<'a>,
264    B: ArrayIndex<'b>,
265    C: ArrayIndex<'c>,
266    D: ArrayIndex<'d>,
267    E: ArrayIndex<'e>,
268    F: ArrayIndex<'f>,
269    G: ArrayIndex<'g>,
270    H: ArrayIndex<'h>,
271{
272    fn try_index_device(
273        &self,
274        i: (A, B, C, D, E, F, G, H),
275        stream: impl AsRef<Stream>,
276    ) -> Result<Array> {
277        let i = [
278            i.0.index_op(),
279            i.1.index_op(),
280            i.2.index_op(),
281            i.3.index_op(),
282            i.4.index_op(),
283            i.5.index_op(),
284            i.6.index_op(),
285            i.7.index_op(),
286        ];
287        get_item_nd(self, &i, stream)
288    }
289}
290
291impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, A, B, C, D, E, F, G, H, I>
292    TryIndexOp<(A, B, C, D, E, F, G, H, I)> for Array
293where
294    A: ArrayIndex<'a>,
295    B: ArrayIndex<'b>,
296    C: ArrayIndex<'c>,
297    D: ArrayIndex<'d>,
298    E: ArrayIndex<'e>,
299    F: ArrayIndex<'f>,
300    G: ArrayIndex<'g>,
301    H: ArrayIndex<'h>,
302    I: ArrayIndex<'i>,
303{
304    fn try_index_device(
305        &self,
306        i: (A, B, C, D, E, F, G, H, I),
307        stream: impl AsRef<Stream>,
308    ) -> Result<Array> {
309        let i = [
310            i.0.index_op(),
311            i.1.index_op(),
312            i.2.index_op(),
313            i.3.index_op(),
314            i.4.index_op(),
315            i.5.index_op(),
316            i.6.index_op(),
317            i.7.index_op(),
318            i.8.index_op(),
319        ];
320        get_item_nd(self, &i, stream)
321    }
322}
323
324impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, A, B, C, D, E, F, G, H, I, J>
325    TryIndexOp<(A, B, C, D, E, F, G, H, I, J)> for Array
326where
327    A: ArrayIndex<'a>,
328    B: ArrayIndex<'b>,
329    C: ArrayIndex<'c>,
330    D: ArrayIndex<'d>,
331    E: ArrayIndex<'e>,
332    F: ArrayIndex<'f>,
333    G: ArrayIndex<'g>,
334    H: ArrayIndex<'h>,
335    I: ArrayIndex<'i>,
336    J: ArrayIndex<'j>,
337{
338    fn try_index_device(
339        &self,
340        i: (A, B, C, D, E, F, G, H, I, J),
341        stream: impl AsRef<Stream>,
342    ) -> Result<Array> {
343        let i = [
344            i.0.index_op(),
345            i.1.index_op(),
346            i.2.index_op(),
347            i.3.index_op(),
348            i.4.index_op(),
349            i.5.index_op(),
350            i.6.index_op(),
351            i.7.index_op(),
352            i.8.index_op(),
353            i.9.index_op(),
354        ];
355        get_item_nd(self, &i, stream)
356    }
357}
358
359impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, A, B, C, D, E, F, G, H, I, J, K>
360    TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K)> for Array
361where
362    A: ArrayIndex<'a>,
363    B: ArrayIndex<'b>,
364    C: ArrayIndex<'c>,
365    D: ArrayIndex<'d>,
366    E: ArrayIndex<'e>,
367    F: ArrayIndex<'f>,
368    G: ArrayIndex<'g>,
369    H: ArrayIndex<'h>,
370    I: ArrayIndex<'i>,
371    J: ArrayIndex<'j>,
372    K: ArrayIndex<'k>,
373{
374    fn try_index_device(
375        &self,
376        i: (A, B, C, D, E, F, G, H, I, J, K),
377        stream: impl AsRef<Stream>,
378    ) -> Result<Array> {
379        let i = [
380            i.0.index_op(),
381            i.1.index_op(),
382            i.2.index_op(),
383            i.3.index_op(),
384            i.4.index_op(),
385            i.5.index_op(),
386            i.6.index_op(),
387            i.7.index_op(),
388            i.8.index_op(),
389            i.9.index_op(),
390            i.10.index_op(),
391        ];
392        get_item_nd(self, &i, stream)
393    }
394}
395
396impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, A, B, C, D, E, F, G, H, I, J, K, L>
397    TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L)> for Array
398where
399    A: ArrayIndex<'a>,
400    B: ArrayIndex<'b>,
401    C: ArrayIndex<'c>,
402    D: ArrayIndex<'d>,
403    E: ArrayIndex<'e>,
404    F: ArrayIndex<'f>,
405    G: ArrayIndex<'g>,
406    H: ArrayIndex<'h>,
407    I: ArrayIndex<'i>,
408    J: ArrayIndex<'j>,
409    K: ArrayIndex<'k>,
410    L: ArrayIndex<'l>,
411{
412    fn try_index_device(
413        &self,
414        i: (A, B, C, D, E, F, G, H, I, J, K, L),
415        stream: impl AsRef<Stream>,
416    ) -> Result<Array> {
417        let i = [
418            i.0.index_op(),
419            i.1.index_op(),
420            i.2.index_op(),
421            i.3.index_op(),
422            i.4.index_op(),
423            i.5.index_op(),
424            i.6.index_op(),
425            i.7.index_op(),
426            i.8.index_op(),
427            i.9.index_op(),
428            i.10.index_op(),
429            i.11.index_op(),
430        ];
431        get_item_nd(self, &i, stream)
432    }
433}
434
435impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, 'm, A, B, C, D, E, F, G, H, I, J, K, L, M>
436    TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L, M)> for Array
437where
438    A: ArrayIndex<'a>,
439    B: ArrayIndex<'b>,
440    C: ArrayIndex<'c>,
441    D: ArrayIndex<'d>,
442    E: ArrayIndex<'e>,
443    F: ArrayIndex<'f>,
444    G: ArrayIndex<'g>,
445    H: ArrayIndex<'h>,
446    I: ArrayIndex<'i>,
447    J: ArrayIndex<'j>,
448    K: ArrayIndex<'k>,
449    L: ArrayIndex<'l>,
450    M: ArrayIndex<'m>,
451{
452    fn try_index_device(
453        &self,
454        i: (A, B, C, D, E, F, G, H, I, J, K, L, M),
455        stream: impl AsRef<Stream>,
456    ) -> Result<Array> {
457        let i = [
458            i.0.index_op(),
459            i.1.index_op(),
460            i.2.index_op(),
461            i.3.index_op(),
462            i.4.index_op(),
463            i.5.index_op(),
464            i.6.index_op(),
465            i.7.index_op(),
466            i.8.index_op(),
467            i.9.index_op(),
468            i.10.index_op(),
469            i.11.index_op(),
470            i.12.index_op(),
471        ];
472        get_item_nd(self, &i, stream)
473    }
474}
475
476impl<
477        'a,
478        'b,
479        'c,
480        'd,
481        'e,
482        'f,
483        'g,
484        'h,
485        'i,
486        'j,
487        'k,
488        'l,
489        'm,
490        'n,
491        A,
492        B,
493        C,
494        D,
495        E,
496        F,
497        G,
498        H,
499        I,
500        J,
501        K,
502        L,
503        M,
504        N,
505    > TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N)> for Array
506where
507    A: ArrayIndex<'a>,
508    B: ArrayIndex<'b>,
509    C: ArrayIndex<'c>,
510    D: ArrayIndex<'d>,
511    E: ArrayIndex<'e>,
512    F: ArrayIndex<'f>,
513    G: ArrayIndex<'g>,
514    H: ArrayIndex<'h>,
515    I: ArrayIndex<'i>,
516    J: ArrayIndex<'j>,
517    K: ArrayIndex<'k>,
518    L: ArrayIndex<'l>,
519    M: ArrayIndex<'m>,
520    N: ArrayIndex<'n>,
521{
522    fn try_index_device(
523        &self,
524        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N),
525        stream: impl AsRef<Stream>,
526    ) -> Result<Array> {
527        let i = [
528            i.0.index_op(),
529            i.1.index_op(),
530            i.2.index_op(),
531            i.3.index_op(),
532            i.4.index_op(),
533            i.5.index_op(),
534            i.6.index_op(),
535            i.7.index_op(),
536            i.8.index_op(),
537            i.9.index_op(),
538            i.10.index_op(),
539            i.11.index_op(),
540            i.12.index_op(),
541            i.13.index_op(),
542        ];
543        get_item_nd(self, &i, stream)
544    }
545}
546
547impl<
548        'a,
549        'b,
550        'c,
551        'd,
552        'e,
553        'f,
554        'g,
555        'h,
556        'i,
557        'j,
558        'k,
559        'l,
560        'm,
561        'n,
562        'o,
563        A,
564        B,
565        C,
566        D,
567        E,
568        F,
569        G,
570        H,
571        I,
572        J,
573        K,
574        L,
575        M,
576        N,
577        O,
578    > TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O)> for Array
579where
580    A: ArrayIndex<'a>,
581    B: ArrayIndex<'b>,
582    C: ArrayIndex<'c>,
583    D: ArrayIndex<'d>,
584    E: ArrayIndex<'e>,
585    F: ArrayIndex<'f>,
586    G: ArrayIndex<'g>,
587    H: ArrayIndex<'h>,
588    I: ArrayIndex<'i>,
589    J: ArrayIndex<'j>,
590    K: ArrayIndex<'k>,
591    L: ArrayIndex<'l>,
592    M: ArrayIndex<'m>,
593    N: ArrayIndex<'n>,
594    O: ArrayIndex<'o>,
595{
596    fn try_index_device(
597        &self,
598        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O),
599        stream: impl AsRef<Stream>,
600    ) -> Result<Array> {
601        let i = [
602            i.0.index_op(),
603            i.1.index_op(),
604            i.2.index_op(),
605            i.3.index_op(),
606            i.4.index_op(),
607            i.5.index_op(),
608            i.6.index_op(),
609            i.7.index_op(),
610            i.8.index_op(),
611            i.9.index_op(),
612            i.10.index_op(),
613            i.11.index_op(),
614            i.12.index_op(),
615            i.13.index_op(),
616            i.14.index_op(),
617        ];
618        get_item_nd(self, &i, stream)
619    }
620}
621
622impl<
623        'a,
624        'b,
625        'c,
626        'd,
627        'e,
628        'f,
629        'g,
630        'h,
631        'i,
632        'j,
633        'k,
634        'l,
635        'm,
636        'n,
637        'o,
638        'p,
639        A,
640        B,
641        C,
642        D,
643        E,
644        F,
645        G,
646        H,
647        I,
648        J,
649        K,
650        L,
651        M,
652        N,
653        O,
654        P,
655    > TryIndexOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P)> for Array
656where
657    A: ArrayIndex<'a>,
658    B: ArrayIndex<'b>,
659    C: ArrayIndex<'c>,
660    D: ArrayIndex<'d>,
661    E: ArrayIndex<'e>,
662    F: ArrayIndex<'f>,
663    G: ArrayIndex<'g>,
664    H: ArrayIndex<'h>,
665    I: ArrayIndex<'i>,
666    J: ArrayIndex<'j>,
667    K: ArrayIndex<'k>,
668    L: ArrayIndex<'l>,
669    M: ArrayIndex<'m>,
670    N: ArrayIndex<'n>,
671    O: ArrayIndex<'o>,
672    P: ArrayIndex<'p>,
673{
674    fn try_index_device(
675        &self,
676        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P),
677        stream: impl AsRef<Stream>,
678    ) -> Result<Array> {
679        let i = [
680            i.0.index_op(),
681            i.1.index_op(),
682            i.2.index_op(),
683            i.3.index_op(),
684            i.4.index_op(),
685            i.5.index_op(),
686            i.6.index_op(),
687            i.7.index_op(),
688            i.8.index_op(),
689            i.9.index_op(),
690            i.10.index_op(),
691            i.11.index_op(),
692            i.12.index_op(),
693            i.13.index_op(),
694            i.14.index_op(),
695            i.15.index_op(),
696        ];
697        get_item_nd(self, &i, stream)
698    }
699}
700
701// Implement private bindings
702impl Array {
703    // This is exposed in the c api but not found in the swift or python api
704    //
705    // Thie is not the same as rust slice. Slice in python is more like `StepBy` iterator in rust
706    pub(crate) fn slice_device(
707        &self,
708        start: &[i32],
709        stop: &[i32],
710        strides: &[i32],
711        stream: impl AsRef<Stream>,
712    ) -> Result<Array> {
713        Array::try_from_op(|res| unsafe {
714            mlx_sys::mlx_slice(
715                res,
716                self.as_ptr(),
717                start.as_ptr(),
718                start.len(),
719                stop.as_ptr(),
720                stop.len(),
721                strides.as_ptr(),
722                strides.len(),
723                stream.as_ref().as_ptr(),
724            )
725        })
726    }
727}
728
729/* -------------------------------------------------------------------------- */
730/*                              Helper functions                              */
731/* -------------------------------------------------------------------------- */
732
/// Lists the positions visited by stepping from `absolute_start` towards `absolute_end`
/// (exclusive) by `stride`.
///
/// - `stride > 0`: yields `start, start + stride, ...` while the value is `< absolute_end`.
/// - `stride < 0`: yields `start, start + stride, ...` while the value is `> absolute_end`.
/// - `stride == 0`: yields nothing.
///
/// Stepping stops as soon as the next value would overflow `i32`. With the previous
/// `i += stride` loop, overflow panicked in debug builds. In release builds it wrapped around,
/// kept the loop condition true and pushed indices forever.
fn absolute_indices(absolute_start: i32, absolute_end: i32, stride: i32) -> Vec<i32> {
    std::iter::successors(Some(absolute_start), |&i| i.checked_add(stride))
        .take_while(|&i| (stride > 0 && i < absolute_end) || (stride < 0 && i > absolute_end))
        .collect()
}
742
/// One index array collected by `gather_nd` before it is handed to `mlx_gather`.
///
/// The variants differ only in how the [`Array`] is held. This lets arrays created locally,
/// arrays borrowed from the caller and `Rc`-shared arrays sit in the same `Vec`.
enum GatherIndexItem<'a> {
    /// An array created locally (from an integer index, a slice, or a reshape).
    Owned(Array),
    /// An array borrowed from an `ArrayIndexOp::TakeArrayRef` operation.
    Borrowed(&'a Array),
    /// A shared array taken from an `ArrayIndexOp::TakeArray` operation (refcount bump only).
    Rc(Rc<Array>),
}
748
749impl AsRef<Array> for GatherIndexItem<'_> {
750    fn as_ref(&self) -> &Array {
751        match self {
752            GatherIndexItem::Owned(array) => array,
753            GatherIndexItem::Borrowed(array) => array,
754            GatherIndexItem::Rc(array) => array,
755        }
756    }
757}
758
759impl Deref for GatherIndexItem<'_> {
760    type Target = Array;
761
762    fn deref(&self) -> &Self::Target {
763        match self {
764            GatherIndexItem::Owned(array) => array,
765            GatherIndexItem::Borrowed(array) => array,
766            GatherIndexItem::Rc(array) => array,
767        }
768    }
769}
770
// Gather-based helper used by multi-axis indexing.
//
// TODO: rewrite this in a more rusty way
/// Applies a run of integer, slice and array indexing operations to the leading axes of `src`
/// with a single `mlx_gather` call.
///
/// The `i`-th item yielded by `operations` indexes axis `i` of `src`: `i` is pushed to `axes`
/// and `slice_sizes[i]` is set to 1. Integer indices are resolved with
/// `resolve_index_unchecked` and turned into scalar index arrays. Slices are expanded into an
/// explicit list of positions with `absolute_indices`. Array indices are used as given. The
/// index arrays are then reshaped so they can be combined:
///
/// - `gather_first == true`: each int/array index keeps its own shape followed by
///   `slice_count` ones, and the `k`-th slice is placed in dim `max_dims + k`.
/// - `gather_first == false`: only the first `slice_count` items are reshaped (slice `k` goes
///   in dim `k`). This assumes every slice precedes every int/array index in `operations`.
///
/// `last_array_or_index` is only used as a capacity hint for the temporary vectors.
///
/// Returns `(max_dims, result)`. `max_dims` is the largest `ndim()` among the array indices
/// (0 if there are only integers and slices). `result` is the gathered array with
/// `operation_len` dims removed by a reshape: the dims that follow the first
/// `max_dims + slice_count` ones. These are presumably the size-1 slice dims `mlx_gather`
/// leaves for the gathered axes (NOTE(review): confirm against the MLX `gather` docs).
///
/// # Errors
/// Propagates errors from the reshapes, from building the `VectorArray` and from `mlx_gather`.
///
/// # Panics
/// - Via `unreachable!` if `operations` yields `Ellipsis` or `ExpandDims`; callers must filter
///   those out first.
/// - If `operations` yields more items than `src` has dims (`slice_sizes[i]` / `shape[i]`).
#[inline]
fn gather_nd<'a>(
    src: &Array,
    operations: impl Iterator<Item = &'a ArrayIndexOp<'a>>,
    gather_first: bool,
    last_array_or_index: usize,
    stream: impl AsRef<Stream>,
) -> Result<(usize, Array)> {
    use ArrayIndexOp::*;

    // Largest rank among the array indices.
    let mut max_dims = 0;
    // Number of slice operations seen so far.
    let mut slice_count = 0;
    // `is_slice[k]` tells whether `gather_indices[k]` came from a slice.
    let mut is_slice: Vec<bool> = Vec::with_capacity(last_array_or_index);
    let mut gather_indices: Vec<GatherIndexItem> = Vec::with_capacity(last_array_or_index);

    let shape = src.shape();

    // prepare the gather indices: one index array per operation, operation `i` -> axis `i`
    let mut axes = Vec::with_capacity(last_array_or_index);
    let mut operation_len: usize = 0;
    let mut slice_sizes = shape.to_vec();
    for (i, op) in operations.enumerate() {
        axes.push(i as i32);
        operation_len += 1;
        slice_sizes[i] = 1;
        match op {
            TakeIndex { index } => {
                // Resolve the integer against the axis length and wrap it in a scalar array.
                let item = Array::from_int(resolve_index_unchecked(
                    *index,
                    src.dim(i as i32) as usize,
                ) as i32);
                gather_indices.push(GatherIndexItem::Owned(item));
                is_slice.push(false);
            }
            Slice(range) => {
                slice_count += 1;
                is_slice.push(true);

                // Expand the slice into the explicit 1-D list of positions it selects.
                let size = shape[i];
                let absolute_start = range.absolute_start(size);
                let absolute_end = range.absolute_end(size);
                let indices = absolute_indices(absolute_start, absolute_end, range.stride());

                let item = Array::from_slice(&indices, &[indices.len() as i32]);

                gather_indices.push(GatherIndexItem::Owned(item));
            }
            TakeArray { indices } => {
                is_slice.push(false);
                max_dims = max_dims.max(indices.ndim());
                // Cloning the `Rc` only increments the reference count
                gather_indices.push(GatherIndexItem::Rc(indices.clone()));
            }
            TakeArrayRef { indices } => {
                is_slice.push(false);
                max_dims = max_dims.max(indices.ndim());
                // Borrowed straight from the caller; nothing is cloned
                gather_indices.push(GatherIndexItem::Borrowed(indices));
            }
            Ellipsis | ExpandDims => {
                unreachable!("Unexpected operation in gather_nd")
            }
        }
    }

    // reshape them so that the int/array indices are first
    if gather_first {
        if slice_count > 0 {
            let mut slice_index = 0;
            for (i, item) in gather_indices.iter_mut().enumerate() {
                if is_slice[i] {
                    // slice -> all ones except its own dim at `max_dims + slice_index`
                    let mut new_shape = vec![1; max_dims + slice_count];
                    new_shape[max_dims + slice_index] = item.dim(0);
                    *item = GatherIndexItem::Owned(item.reshape(&new_shape)?);
                    slice_index += 1;
                } else {
                    // int/array -> its own shape followed by one trailing 1 per slice
                    let mut new_shape = item.shape().to_vec();
                    new_shape.extend((0..slice_count).map(|_| 1));
                    *item = GatherIndexItem::Owned(item.reshape(&new_shape)?);
                }
            }
        }
    } else {
        // reshape them so that the int/array indices are last
        // (relies on the slices being the first `slice_count` items)
        for (i, item) in gather_indices[..slice_count].iter_mut().enumerate() {
            let mut new_shape = vec![1; max_dims + slice_count];
            new_shape[i] = item.dim(0);
            *item = GatherIndexItem::Owned(item.reshape(&new_shape)?);
        }
    }

    // Do the gather.
    // `gather_indices` is not dropped before the end of this function, so every array it holds
    // outlives the `mlx_gather` call below.
    let indices = VectorArray::try_from_iter(gather_indices.iter())?;

    // SAFETY: `src`, `indices`, `axes`, `slice_sizes` and `stream` are all borrowed for the whole
    // call, and each pointer/length pair comes from the same Vec.
    let gathered = Array::try_from_op(|res| unsafe {
        mlx_sys::mlx_gather(
            res,
            src.as_ptr(),
            indices.as_ptr(),
            axes.as_ptr(),
            axes.len(),
            slice_sizes.as_ptr(),
            slice_sizes.len(),
            stream.as_ref().as_ptr(),
        )
    })?;
    let gathered_shape = gathered.shape();

    // Squeeze the dims: keep the first `max_dims + slice_count` dims, drop the next
    // `operation_len` ones, and keep everything after them.
    let output_shape: Vec<i32> = gathered_shape[0..(max_dims + slice_count)]
        .iter()
        .chain(gathered_shape[(max_dims + slice_count + operation_len)..].iter())
        .copied()
        .collect();
    let result = gathered.reshape(&output_shape)?;

    Ok((max_dims, result))
}
895
896#[inline]
897fn get_item_index(src: &Array, index: i32, axis: i32, stream: impl AsRef<Stream>) -> Result<Array> {
898    let index = resolve_index_unchecked(index, src.dim(axis) as usize) as i32;
899    src.take_device(array!(index), axis, stream)
900}
901
902#[inline]
903fn get_item_array(
904    src: &Array,
905    indices: &Array,
906    axis: i32,
907    stream: impl AsRef<Stream>,
908) -> Result<Array> {
909    src.take_device(indices, axis, stream)
910}
911
912#[inline]
913fn get_item_slice(src: &Array, range: RangeIndex, stream: impl AsRef<Stream>) -> Result<Array> {
914    let ndim = src.ndim();
915    let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
916    let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
917    let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];
918
919    let size = ends[0];
920    starts[0] = range.start(size);
921    ends[0] = range.end(size);
922    strides[0] = range.stride();
923
924    src.slice_device(&starts, &ends, &strides, stream)
925}
926
927// See `mlx_get_item` in python/src/indexing.cpp and `getItem` in
928// mlx-swift/Sources/MLX/MLXArray+Indexing.swift
929fn get_item<'a>(
930    src: &Array,
931    index: impl ArrayIndex<'a>,
932    stream: impl AsRef<Stream>,
933) -> Result<Array> {
934    use ArrayIndexOp::*;
935
936    match index.index_op() {
937        Ellipsis => Ok(src.deep_clone()),
938        TakeIndex { index } => get_item_index(src, index, 0, stream),
939        TakeArray { indices } => get_item_array(src, &indices, 0, stream),
940        TakeArrayRef { indices } => get_item_array(src, indices, 0, stream),
941        Slice(range) => get_item_slice(src, range, stream),
942        ExpandDims => src.expand_dims_device(&[0], stream),
943    }
944}
945
946// See `mlx_get_item_nd` in python/src/indexing.cpp and `getItemNd` in
947// mlx-swift/Sources/MLX/MLXArray+Indexing.swift
948fn get_item_nd(
949    src: &Array,
950    operations: &[ArrayIndexOp],
951    stream: impl AsRef<Stream>,
952) -> Result<Array> {
953    use ArrayIndexOp::*;
954
955    let mut src = Cow::Borrowed(src);
956
957    // The plan is as follows:
958    // 1. Replace the ellipsis with a series of slice(None)
959    // 2. Loop over the indices and calculate the gather indices
960    // 3. Calculate the remaining slices and reshapes
961
962    let operations = expand_ellipsis_operations(src.ndim(), operations);
963
964    // Gather handling
965
966    // compute gatherFirst -- this will be true if there is:
967    // - a leading array or index operation followed by
968    // - a non index/array (e.g. a slice)
969    // - an int/array operation
970    //
971    // - and there is at least one array operation (hanled below with haveArray)
972    let mut gather_first = false;
973    let mut have_array_or_index = false; // This is `haveArrayOrIndex` in the swift binding
974    let mut have_non_array = false;
975    for item in operations.iter() {
976        if item.is_array_or_index() {
977            if have_array_or_index && have_non_array {
978                gather_first = true;
979                break;
980            }
981            have_array_or_index = true;
982        } else {
983            have_non_array = have_non_array || have_array_or_index;
984        }
985    }
986
987    let array_count = operations.iter().filter(|op| op.is_array()).count();
988    let have_array = array_count > 0;
989
990    let mut remaining_indices: Vec<ArrayIndexOp> = Vec::new();
991    if have_array {
992        // apply all the operations (except for .newAxis) up to and including the
993        // final .array operation (array operations are implemented via gather)
994        let last_array_or_index = operations
995            .iter()
996            .rposition(|op| op.is_array_or_index())
997            .unwrap(); // safe because we know there is at least one array operation
998        let gather_indices = operations[..=last_array_or_index]
999            .iter()
1000            .filter(|op| !matches!(op, Ellipsis | ExpandDims));
1001        let (max_dims, gathered) = gather_nd(
1002            &src,
1003            gather_indices,
1004            gather_first,
1005            last_array_or_index,
1006            &stream,
1007        )?;
1008
1009        src = Cow::Owned(gathered);
1010
1011        // Reassemble the indices for the slicing or reshaping if there are any
1012        if gather_first {
1013            remaining_indices.extend((0..max_dims).map(|_| (..).index_op()));
1014
1015            // copy any newAxis in the gatherIndices through.  any slices get
1016            // copied in as full range (already applied)
1017            for item in &operations[..=last_array_or_index] {
1018                // Using full match syntax to avoid forgetting to add new cases
1019                match item {
1020                    ExpandDims => remaining_indices.push(item.clone()),
1021                    Slice { .. } => remaining_indices.push((..).index_op()),
1022                    Ellipsis
1023                    | TakeIndex { index: _ }
1024                    | TakeArray { indices: _ }
1025                    | TakeArrayRef { indices: _ } => {}
1026                }
1027            }
1028
1029            // append the remaining operations
1030            remaining_indices.extend(operations[(last_array_or_index + 1)..].iter().cloned());
1031        } else {
1032            // !gather_first
1033            for item in operations.iter() {
1034                // Using full match syntax to avoid forgetting to add new cases
1035                match item {
1036                    TakeIndex { .. } | TakeArray { .. } | TakeArrayRef { .. } => break,
1037                    ExpandDims => remaining_indices.push(item.clone()),
1038                    Ellipsis | Slice(_) => remaining_indices.push((..).index_op()),
1039                }
1040            }
1041
1042            // handle the trailing gathers
1043            remaining_indices.extend((0..max_dims).map(|_| (..).index_op()));
1044
1045            // and the remaining operations
1046            remaining_indices.extend(operations[(last_array_or_index + 1)..].iter().cloned());
1047        }
1048    }
1049
1050    if have_array && remaining_indices.is_empty() {
1051        return Ok(src.into_owned());
1052    }
1053
1054    if remaining_indices.is_empty() {
1055        remaining_indices = operations.to_vec();
1056    }
1057
1058    // Slice handling
1059    let ndim = src.ndim();
1060    let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
1061    let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
1062    let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];
1063    let mut squeeze_needed = false;
1064    let mut axis = 0;
1065
1066    for item in remaining_indices.iter() {
1067        match item {
1068            ExpandDims => continue,
1069            TakeIndex { mut index } => {
1070                if !have_array {
1071                    index = resolve_index_unchecked(index, src.dim(axis as i32) as usize) as i32;
1072                    starts[axis] = index;
1073                    ends[axis] = index + 1;
1074                    squeeze_needed = true;
1075                }
1076            }
1077            Slice(range) => {
1078                let size = src.dim(axis as i32);
1079                starts[axis] = range.start(size);
1080                ends[axis] = range.end(size);
1081                strides[axis] = range.stride();
1082            }
1083            Ellipsis | TakeArray { .. } | TakeArrayRef { .. } => {
1084                unreachable!("Unexpected item in remaining_indices: {:?}", item)
1085            }
1086        }
1087        axis += 1;
1088    }
1089
1090    src = Cow::Owned(src.slice_device(&starts, &ends, &strides, stream)?);
1091
1092    // Unsqueeze handling
1093    if remaining_indices.len() > ndim || squeeze_needed {
1094        let mut new_shape = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::new();
1095        let mut axis_ = 0;
1096        for item in remaining_indices {
1097            // using full match syntax to avoid forgetting to add new cases
1098            match item {
1099                ExpandDims => new_shape.push(1),
1100                TakeIndex { .. } => {
1101                    if squeeze_needed {
1102                        axis_ += 1;
1103                    }
1104                }
1105                Ellipsis | TakeArray { .. } | TakeArrayRef { .. } | Slice(_) => {
1106                    new_shape.push(src.dim(axis_));
1107                    axis_ += 1;
1108                }
1109            }
1110        }
1111        new_shape.extend(src.shape()[(axis_ as usize)..].iter().cloned());
1112
1113        src = Cow::Owned(src.reshape(&new_shape)?);
1114    }
1115
1116    Ok(src.into_owned())
1117}
1118
1119/* -------------------------------------------------------------------------- */
1120/*                                 Unit tests                                 */
1121/* -------------------------------------------------------------------------- */
1122
#[cfg(test)]
mod tests {
    use crate::{
        assert_array_eq,
        ops::indexing::{Ellipsis, IndexOp, IntoStrideBy, NewAxis},
        Array,
    };

    /// Asserts that `arr` has rank `shape.len()` and exactly the dimensions in `shape`
    /// (rank is checked first, then the dimensions).
    fn assert_shape(arr: &Array, shape: &[i32]) {
        assert_eq!(arr.ndim(), shape.len());
        assert_eq!(arr.shape(), shape);
    }

    #[test]
    fn test_array_index_negative_int() {
        let values = Array::from_iter(0i32..8, &[8]);

        // Negative indices count back from the end of the axis and yield scalars.
        for (idx, want) in [(-1i32, 7i32), (-8, 0)] {
            let picked = values.index(idx);
            assert_eq!(picked.ndim(), 0);
            assert_eq!(picked.item::<i32>(), want);
        }
    }

    #[test]
    fn test_array_index_new_axis() {
        let source = Array::from_iter(0..60, &[3, 4, 5]);
        let lifted = source.index(NewAxis);
        assert_shape(&lifted, &[1, 3, 4, 5]);

        let want = Array::from_iter(0..60, &[1, 3, 4, 5]);
        assert_array_eq!(lifted, want, 0.01);
    }

    #[test]
    fn test_array_index_ellipsis() {
        let cube = Array::from_iter(0i32..8, &[2, 2, 2]);

        // Spelling out the leading full ranges...
        let explicit = cube.index((.., .., 0));
        let want_explicit = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
        assert_array_eq!(explicit, want_explicit, 0.01);

        // ...must match letting an ellipsis expand to them.
        let via_ellipsis = cube.index((Ellipsis, 0));
        let want_ellipsis = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
        assert_array_eq!(via_ellipsis, want_ellipsis, 0.01);

        // A lone ellipsis selects everything.
        let whole = cube.index(Ellipsis);
        let want_whole = Array::from_iter(0i32..8, &[2, 2, 2]);
        assert_array_eq!(whole, want_whole, 0.01);
    }

    #[test]
    fn test_array_index_stride() {
        let line = Array::from_iter(0..10, &[10]);
        let every_other = line.index((2..8).stride_by(2));

        let want = Array::from_slice(&[2, 4, 6], &[3]);
        assert_array_eq!(every_other, want, 0.01);
    }

    // Everything from here on is ported from the Swift binding's test-suite,
    // `mlx-swift/Tests/MLXTests/MLXArray+IndexingTests.swift`.

    #[test]
    fn test_array_subscript_int() {
        let cube = Array::from_iter(0i32..512, &[8, 8, 8]);

        let plane = cube.index(1);
        assert_shape(&plane, &[8, 8]);

        let want = Array::from_iter(64..128, &[8, 8]);
        assert_array_eq!(plane, want, 0.01);
    }

    #[test]
    fn test_array_subscript_int_array() {
        // Every integer index squeezes away the axis it selects from.
        let cube = Array::from_iter(0i32..512, &[8, 8, 8]);

        let row = cube.index((1, 2));
        assert_shape(&row, &[8]);
        let want_row = Array::from_iter(80..88, &[8]);
        assert_array_eq!(row, want_row, 0.01);

        let scalar = cube.index((1, 2, 3));
        assert_shape(&scalar, &[]);
        assert_eq!(scalar.item::<i32>(), 64 + 2 * 8 + 3);
    }

    #[test]
    fn test_array_subscript_int_array_2() {
        // Only the indexed axes are removed; the trailing size-1 axis is kept.
        let tensor = Array::from_iter(0i32..512, &[8, 8, 8, 1]);

        assert_shape(&tensor.index(1), &[8, 8, 1]);
        assert_shape(&tensor.index((1, 2)), &[8, 1]);
        assert_shape(&tensor.index((1, 2, 3)), &[1]);
    }

    #[test]
    fn test_array_subscript_from_end() {
        let grid = Array::from_iter(0i32..12, &[3, 4]);

        let corner = grid.index((-1, -2));
        assert_eq!(corner.ndim(), 0);
        assert_eq!(corner.item::<i32>(), 10);
    }

    #[test]
    fn test_array_subscript_range() {
        let cube = Array::from_iter(0i32..512, &[8, 8, 8]);

        let leading = cube.index(1..3);
        assert_shape(&leading, &[2, 8, 8]);
        let want_leading = Array::from_iter(64..192, &[2, 8, 8]);
        assert_array_eq!(leading, want_leading, 0.01);

        // A range covering a single element keeps its size-1 axis (no squeeze).
        let single = cube.index(1..=1);
        assert_shape(&single, &[1, 8, 8]);
        let want_single = Array::from_iter(64..128, &[1, 8, 8]);
        assert_array_eq!(single, want_single, 0.01);

        // Several ranges at once, each resolved against its own dimension.
        let mixed = cube.index((1..2, ..3, 3..));
        assert_shape(&mixed, &[1, 3, 5]);
        let want_mixed = Array::from_slice(
            &[67, 68, 69, 70, 71, 75, 76, 77, 78, 79, 83, 84, 85, 86, 87],
            &[1, 3, 5],
        );
        assert_array_eq!(mixed, want_mixed, 0.01);

        // Negative bounds are measured from the end of each dimension.
        let tail = cube.index((-2..-1, ..-3, -3..));
        assert_shape(&tail, &[1, 5, 3]);
        let want_tail = Array::from_slice(
            &[
                389, 390, 391, 397, 398, 399, 405, 406, 407, 413, 414, 415, 421, 422, 423,
            ],
            &[1, 5, 3],
        );
        assert_array_eq!(tail, want_tail, 0.01);
    }

    #[test]
    fn test_array_subscript_advanced() {
        // Integer-array indexing examples borrowed from
        // https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
        let matrix = Array::from_iter(0..35, &[5, 7]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 2, 4], &[3]);
        let cols = Array::from_slice(&[0, 1, 2], &[3]);

        let picked = matrix.index((rows, cols));
        assert_shape(&picked, &[3]);

        let want = Array::from_slice(&[0i32, 15, 30], &[3]);
        assert_array_eq!(picked, want, 0.01);
    }

    #[test]
    fn test_array_subscript_advanced_with_ref() {
        // Same as above, but the column indices are passed by reference.
        let matrix = Array::from_iter(0..35, &[5, 7]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 2, 4], &[3]);
        let cols = Array::from_slice(&[0, 1, 2], &[3]);

        let picked = matrix.index((rows, &cols));
        assert_shape(&picked, &[3]);

        let want = Array::from_slice(&[0i32, 15, 30], &[3]);
        assert_array_eq!(picked, want, 0.01);
    }

    #[test]
    fn test_array_subscript_advanced_2() {
        let matrix = Array::from_iter(0..12, &[6, 2]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 2, 4], &[3]);
        let picked = matrix.index(rows);

        let want = Array::from_slice(&[0i32, 1, 4, 5, 8, 9], &[3, 2]);
        assert_array_eq!(picked, want, 0.01);
    }

    #[test]
    fn test_collection() {
        let stack = Array::from_iter(0i32..20, &[2, 2, 5]);

        // Walking the first axis yields consecutive [2, 5] blocks.
        for row_idx in 0i32..2 {
            let row = stack.index(row_idx);
            let first = row_idx * 10;
            let expected = Array::from_iter(first..(first + 10), &[2, 5]);
            assert_array_all_close!(row, expected);
        }
    }

    #[test]
    fn test_array_subscript_advanced_2d() {
        let matrix = Array::from_iter(0..12, &[4, 3]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 0, 3, 3], &[2, 2]);
        let cols = Array::from_slice(&[0, 2, 0, 2], &[2, 2]);
        let corners = matrix.index((rows, cols));

        let want = Array::from_slice(&[0, 2, 9, 11], &[2, 2]);
        assert_array_eq!(corners, want, 0.01);
    }

    #[test]
    fn test_array_subscript_advanced_2d_2() {
        // Row indices of shape [2, 1] broadcast against column indices of shape [2].
        let matrix = Array::from_iter(0..12, &[4, 3]).as_type::<i32>().unwrap();

        let rows = Array::from_slice(&[0, 3], &[2, 1]);
        let cols = Array::from_slice(&[0, 2], &[2]);
        let corners = matrix.index((rows, cols));

        let want = Array::from_slice(&[0, 2, 9, 11], &[2, 2]);
        assert_array_eq!(corners, want, 0.01);
    }

    /// Asserts that `result` has exactly `shape` and that its elements sum to
    /// `expected_sum`.
    fn check(result: impl AsRef<Array>, shape: &[i32], expected_sum: i32) {
        let arr = result.as_ref();
        assert_eq!(arr.shape(), shape);

        let total = arr.sum(None, None).unwrap().item::<i32>();
        assert_eq!(total, expected_sum);
    }

    #[test]
    fn test_full_index_read_single() {
        let a = Array::from_iter(0..60, &[3, 4, 5]);

        check(a.index(Ellipsis), &[3, 4, 5], 1770); // a[...]
        check(a.index(NewAxis), &[1, 3, 4, 5], 1770); // a[None]
        check(a.index(0), &[4, 5], 190); // a[0]
        check(a.index(1..3), &[2, 4, 5], 1580); // a[1:3]

        // a[mx.array([2, 1])]
        let picks = Array::from_slice(&[2, 1], &[2]);
        check(a.index(picks), &[2, 4, 5], 1580);
    }

    #[test]
    fn test_full_index_read_no_array() {
        let a = Array::from_iter(0..360, &[2, 3, 4, 5, 3]);

        check(a.index((Ellipsis, 0)), &[2, 3, 4, 5], 21420); // a[..., 0]
        check(a.index((0, Ellipsis)), &[3, 4, 5, 3], 16110); // a[0, ...]
        check(a.index((0, Ellipsis, 0)), &[3, 4, 5], 5310); // a[0, ..., 0]

        // a[..., ::2, :]
        let strided = a.index((Ellipsis, (..).stride_by(2), ..));
        check(strided, &[2, 3, 4, 3, 3], 38772);

        // a[..., None, ::2, -1]
        let lifted = a.index((Ellipsis, NewAxis, (..).stride_by(2), -1));
        check(lifted, &[2, 3, 4, 1, 3], 12996);

        check(a.index((.., 2.., 0)), &[2, 1, 5, 3], 6510); // a[:, 2:, 0]

        // a[::-1, :2, 2:, ..., None, ::2]
        let combined = a.index((
            (..).stride_by(-1),
            ..2,
            2..,
            Ellipsis,
            NewAxis,
            (..).stride_by(2),
        ));
        check(combined, &[2, 2, 2, 5, 1, 2], 13160);
    }

    #[test]
    fn test_full_index_read_array() {
        // Every case here indexes with an `Array`, so it exercises the gather path.
        // a = mx.arange(540).reshape(3, 3, 4, 5, 3); i = mx.array([2, 1])
        let a = Array::from_iter(0..540, &[3, 3, 4, 5, 3]);
        let i = Array::from_slice(&[2, 1], &[2]);

        check(a.index((0, &i)), &[2, 4, 5, 3], 14340); // a[0, i]
        check(a.index((Ellipsis, &i, 0)), &[3, 3, 4, 2], 19224); // a[..., i, 0]
        check(a.index((&i, 0, Ellipsis)), &[2, 4, 5, 3], 35940); // a[i, 0, ...]

        // a[i, ..., i] -- gather-first path
        check(a.index((&i, Ellipsis, &i)), &[2, 3, 4, 5], 43200);

        // a[i, ..., ::2, :]
        let strided = a.index((&i, Ellipsis, (..).stride_by(2), ..));
        check(strided, &[2, 3, 4, 3, 3], 77652);

        // a[..., i, None, ::2, -1] -- gather-first path
        let lifted = a.index((Ellipsis, &i, NewAxis, (..).stride_by(2), -1));
        check(lifted, &[2, 3, 3, 1, 3], 14607);

        check(a.index((.., 2.., &i)), &[3, 1, 2, 5, 3], 29655); // a[:, 2:, i]

        // a[::-1, :2, i, 2:, ..., None, ::2]
        let combined = a.index((
            (..).stride_by(-1),
            ..2,
            i,
            2..,
            Ellipsis,
            NewAxis,
            (..).stride_by(2),
        ));
        check(combined, &[3, 2, 2, 3, 1, 2], 17460);
    }
}