mlx_rs/ops/indexing/
mod.rs

1//! Indexing Arrays
2//!
3//! # Overview
4//!
5//! Due to limitations in the `std::ops::Index` and `std::ops::IndexMut` traits (only references can
6//! be returned), the indexing is achieved with the [`IndexOp`] and [`IndexMutOp`] traits where
7//! arrays can be indexed with [`IndexOp::index()`] and [`IndexMutOp::index_mut()`] respectively.
8//!
9//! The following types can be used as indices:
10//!
11//! | Type | Description |
12//! |------|-------------|
13//! | [`i32`] | An integer index |
14//! | [`Array`] | Use an array to index another array |
15//! | `&Array` | Use a reference to an array to index another array |
16//! | [`std::ops::Range<i32>`] | A range index |
17//! | [`std::ops::RangeFrom<i32>`] | A range index |
18//! | [`std::ops::RangeFull`] | A range index |
19//! | [`std::ops::RangeInclusive<i32>`] | A range index |
20//! | [`std::ops::RangeTo<i32>`] | A range index |
21//! | [`std::ops::RangeToInclusive<i32>`] | A range index |
22//! | [`StrideBy`] | A range index with stride |
23//! | [`NewAxis`] | Add a new axis |
24//! | [`Ellipsis`] | Consume all axes |
25//!
26//! # Single axis indexing
27//!
28//! | Indexing Operation | `mlx` (python) | `mlx-swift` | `mlx-rs` |
29//! |--------------------|--------|-------|------|
30//! | integer | `arr[1]` | `arr[1]` | `arr.index(1)` |
31//! | range expression | `arr[1:3]` | `arr[1..<3]` | `arr.index(1..3)` |
32//! | full range | `arr[:]` | `arr[0...]` | `arr.index(..)` |
33//! | range with stride | `arr[::2]` | `arr[.stride(by: 2)]` | `arr.index((..).stride_by(2))` |
34//! | ellipsis (consuming all axes) | `arr[...]` | `arr[.ellipsis]` | `arr.index(Ellipsis)` |
35//! | newaxis | `arr[None]` | `arr[.newAxis]` | `arr.index(NewAxis)` |
36//! | mlx array `i` | `arr[i]` | `arr[i]` | `arr.index(i)` |
37//!
38//! # Multi-axes indexing
39//!
40//! Multi-axes indexing with combinations of the above operations is also supported by combining the
41//! operations in a tuple with the restriction that `Ellipsis` can only be used once.
42//!
43//! ## Examples
44//!
45//! ```rust
46//! // See the multi-dimensional example code for mlx python https://ml-explore.github.io/mlx/build/html/usage/indexing.html
47//!
48//! use mlx_rs::{Array, ops::indexing::*};
49//!
50//! let a = Array::from_iter(0..8, &[2, 2, 2]);
51//!
52//! // a[:, :, 0]
//! let s1 = a.index((.., .., 0));
//!
//! let expected = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
//! assert_eq!(s1, expected);
//!
//! // a[..., 0]
//! let s2 = a.index((Ellipsis, 0));
//!
//! let expected = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
//! assert_eq!(s2, expected);
63//! ```
64//!
65//! # Set values with indexing
66//!
67//! The same indexing operations (single or multiple) can be used to set values in an array using
68//! the [`IndexMutOp`] trait.
69//!
70//! ## Example
71//!
72//! ```rust
73//! use mlx_rs::{Array, ops::indexing::*};
74//!
75//! let mut a = Array::from_slice(&[1, 2, 3], &[3]);
76//! a.index_mut(2, Array::from_int(0));
77//!
78//! let expected = Array::from_slice(&[1, 2, 0], &[3]);
79//! assert_eq!(a, expected);
80//! ```
81//!
82//! ```rust
83//! use mlx_rs::{Array, ops::indexing::*};
84//!
85//! let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);
86//!
87//! // writing using slices -- this ends up covering two elements
88//! a.index_mut((0..1, 1..2, 2..4), Array::from_int(88));
89//!
90//! let expected = Array::from_slice(
91//!     &[
92//!         0, 1, 2, 3, 4, 5, 6, 88, 88, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
93//!     ],
94//!     &[2, 2, 5],
95//! );
96//! assert_eq!(a, expected);
97//! ```
98
99use std::{borrow::Cow, ops::Bound, rc::Rc};
100
101use mlx_internal_macros::{default_device, generate_macro};
102
103use crate::{error::Result, utils::guard::Guarded, Array, Stream, StreamOrDevice};
104
105pub(crate) mod index_impl;
106pub(crate) mod indexmut_impl;
107
108/* -------------------------------------------------------------------------- */
109/*                                Custom types                                */
110/* -------------------------------------------------------------------------- */
111
/// Marker index that adds a new axis, written `arr[None]` in python.
///
/// See the module level documentation for more information.
#[derive(Debug, Clone, Copy)]
pub struct NewAxis;

/// Marker index standing for all axes not addressed by the other indices, written `...` in
/// python. At most one may appear in a multi-axis index.
///
/// See the module level documentation for more information.
#[derive(Debug, Clone, Copy)]
pub struct Ellipsis;
123
/// Stride indexing operation.
///
/// Pairs an index (typically a range such as `..` or `1..5`) with the step to take between
/// selected elements. Usually built with [`IntoStrideBy::stride_by`], e.g.
/// `(..).stride_by(2)`.
///
/// See the module level documentation for more information.
#[derive(Debug, Clone, Copy)]
pub struct StrideBy<I> {
    /// The wrapped index, e.g. a range.
    pub inner: I,

    /// The step between selected elements.
    pub stride: i32,
}

/// Extension trait that attaches a stride to any value.
pub trait IntoStrideBy: Sized {
    /// Wrap `self` in a [`StrideBy`] using the given `stride`.
    fn stride_by(self, stride: i32) -> StrideBy<Self>;
}

impl<T> IntoStrideBy for T {
    fn stride_by(self, stride: i32) -> StrideBy<Self> {
        let inner = self;
        StrideBy { inner, stride }
    }
}
150
/// Range indexing operation.
///
/// Holds the start/stop bounds and the stride of a slice along a single axis. Negative bound
/// values count from the end of the axis (see [`RangeIndex::absolute_start`] and
/// [`RangeIndex::absolute_end`]).
#[derive(Debug, Clone)]
pub struct RangeIndex {
    start: Bound<i32>,
    stop: Bound<i32>,
    stride: i32,
}

impl RangeIndex {
    /// Create a range index from its bounds; `stride` defaults to `1` when `None`.
    pub(crate) fn new(start: Bound<i32>, stop: Bound<i32>, stride: Option<i32>) -> Self {
        let stride = stride.unwrap_or(1);
        Self {
            start,
            stop,
            stride,
        }
    }

    /// Whether this range is `..` with stride `1`, i.e. it selects the whole axis in order.
    pub(crate) fn is_full(&self) -> bool {
        matches!(self.start, Bound::Unbounded)
            && matches!(self.stop, Bound::Unbounded)
            && self.stride == 1
    }

    /// The step between selected elements.
    pub(crate) fn stride(&self) -> i32 {
        self.stride
    }

    /// The first position selected along an axis of length `size`.
    ///
    /// The value may be negative, meaning it is relative to the end of the axis. An unbounded
    /// start is `0` for positive strides and `size - 1` for negative strides.
    pub(crate) fn start(&self, size: i32) -> i32 {
        match self.start {
            Bound::Included(start) => start,
            Bound::Excluded(start) => start + 1,
            Bound::Unbounded => {
                // ref swift binding
                // _start ?? (stride < 0 ? size - 1 : 0)
                if self.stride.is_negative() {
                    size - 1
                } else {
                    0
                }
            }
        }
    }

    /// [`RangeIndex::start`] with a negative value resolved against `size`.
    pub(crate) fn absolute_start(&self, size: i32) -> i32 {
        // ref swift binding
        // return start < 0 ? start + size : start
        let start = self.start(size);
        if start.is_negative() {
            start + size
        } else {
            start
        }
    }

    /// The exclusive end position along an axis of length `size`.
    ///
    /// The value may be negative, meaning it is relative to the end of the axis. An unbounded
    /// stop is `size` for positive strides and `-size - 1` for negative strides; the latter is
    /// the stop used for backward slices that run through the first element.
    pub(crate) fn end(&self, size: i32) -> i32 {
        match self.stop {
            Bound::Included(stop) => self.exclusive_end_from_inclusive(stop, size),
            Bound::Excluded(stop) => stop,
            Bound::Unbounded => {
                // ref swift binding
                // _end ?? (stride < 0 ? -size - 1 : size)
                if self.stride.is_negative() {
                    -size - 1
                } else {
                    size
                }
            }
        }
    }

    /// Convert an inclusive `stop` into the equivalent exclusive end for an axis of length
    /// `size`, stepping once more in the direction of travel.
    ///
    /// A naive `stop + 1` is wrong in two ways: `..=-1` would become `..0` (an empty range
    /// instead of "through the last element"), and with a negative stride the exclusive end
    /// lies *below* `stop`, not above it.
    fn exclusive_end_from_inclusive(&self, stop: i32, size: i32) -> i32 {
        if self.stride.is_negative() {
            match stop - 1 {
                // `0 - 1 == -1` would be read as "the last element"; use the same stop as an
                // unbounded backward slice instead.
                -1 => -size - 1,
                end => end,
            }
        } else {
            match stop + 1 {
                // `-1 + 1 == 0` would be read as "the first element"; the range actually
                // extends to the end of the axis.
                0 => size,
                end => end,
            }
        }
    }

    /// [`RangeIndex::end`] with a negative value resolved against `size`.
    pub(crate) fn absolute_end(&self, size: i32) -> i32 {
        // ref swift binding
        // return end < 0 ? end + size : end
        let end = self.end(size);
        if end.is_negative() {
            end + size
        } else {
            end
        }
    }
}
237
238/// Indexing operation for arrays.
239#[derive(Debug, Clone)]
240pub enum ArrayIndexOp<'a> {
241    /// An `Ellipsis` is used to consume all axes
242    ///
243    /// This is equivalent to `...` in python
244    Ellipsis,
245
246    /// A single index operation
247    ///
248    /// This is equivalent to `arr[1]` in python
249    TakeIndex {
250        /// The index to take
251        index: i32,
252    },
253
254    /// Indexing with an array
255    TakeArray {
256        /// The indices to take
257        indices: Rc<Array>, // TODO: remove `Rc` because `Array` is `Clone`
258    },
259
260    /// Indexing with an array reference
261    TakeArrayRef {
262        /// The indices to take
263        indices: &'a Array,
264    },
265
266    /// Indexing with a range
267    ///
268    /// This is equivalent to `arr[1:3]` in python
269    Slice(RangeIndex),
270
271    /// New axis operation
272    ///
273    /// This is equivalent to `arr[None]` in python
274    ExpandDims,
275}
276
277impl ArrayIndexOp<'_> {
278    fn is_array_or_index(&self) -> bool {
279        // Using the full match syntax to avoid forgetting to add new variants
280        match self {
281            ArrayIndexOp::TakeIndex { .. }
282            | ArrayIndexOp::TakeArrayRef { .. }
283            | ArrayIndexOp::TakeArray { .. } => true,
284            ArrayIndexOp::Ellipsis | ArrayIndexOp::Slice(_) | ArrayIndexOp::ExpandDims => false,
285        }
286    }
287
288    fn is_array(&self) -> bool {
289        // Using the full match syntax to avoid forgetting to add new variants
290        match self {
291            ArrayIndexOp::TakeArray { .. } | ArrayIndexOp::TakeArrayRef { .. } => true,
292            ArrayIndexOp::TakeIndex { .. }
293            | ArrayIndexOp::Ellipsis
294            | ArrayIndexOp::Slice(_)
295            | ArrayIndexOp::ExpandDims => false,
296        }
297    }
298}
299
300/* -------------------------------------------------------------------------- */
301/*                                Custom traits                               */
302/* -------------------------------------------------------------------------- */
303
304/// Trait for custom indexing operations.
305///
306/// Out of bounds indexing is allowed and wouldn't return an error.
307pub trait TryIndexOp<Idx> {
308    /// Try to index the array with the given index.
309    fn try_index_device(&self, i: Idx, stream: impl AsRef<Stream>) -> Result<Array>;
310
311    /// Try to index the array with the given index.
312    fn try_index(&self, i: Idx) -> Result<Array> {
313        self.try_index_device(i, StreamOrDevice::default())
314    }
315}
316
317/// Trait for custom indexing operations.
318///
319/// This is implemented for all types that implement `TryIndexOp`.
320pub trait IndexOp<Idx>: TryIndexOp<Idx> {
321    /// Index the array with the given index.
322    fn index_device(&self, i: Idx, stream: impl AsRef<Stream>) -> Array {
323        self.try_index_device(i, stream).unwrap()
324    }
325
326    /// Index the array with the given index.
327    fn index(&self, i: Idx) -> Array {
328        self.try_index(i).unwrap()
329    }
330}
331
332impl<T, Idx> IndexOp<Idx> for T where T: TryIndexOp<Idx> {}
333
334/// Trait for custom mutable indexing operations.
335pub trait TryIndexMutOp<Idx, Val> {
336    /// Try to index the array with the given index and set the value.
337    fn try_index_mut_device(&mut self, i: Idx, val: Val, stream: impl AsRef<Stream>) -> Result<()>;
338
339    /// Try to index the array with the given index and set the value.
340    fn try_index_mut(&mut self, i: Idx, val: Val) -> Result<()> {
341        self.try_index_mut_device(i, val, StreamOrDevice::default())
342    }
343}
344
345// TODO: should `Val` impl `AsRef<Array>` or `Into<Array>`?
346
347/// Trait for custom mutable indexing operations.
348pub trait IndexMutOp<Idx, Val>: TryIndexMutOp<Idx, Val> {
349    /// Index the array with the given index and set the value.
350    fn index_mut_device(&mut self, i: Idx, val: Val, stream: impl AsRef<Stream>) {
351        self.try_index_mut_device(i, val, stream).unwrap()
352    }
353
354    /// Index the array with the given index and set the value.
355    fn index_mut(&mut self, i: Idx, val: Val) {
356        self.try_index_mut(i, val).unwrap()
357    }
358}
359
360impl<T, Idx, Val> IndexMutOp<Idx, Val> for T where T: TryIndexMutOp<Idx, Val> {}
361
362/// Trait for custom indexing operations.
363pub trait ArrayIndex<'a> {
364    /// `mlx` allows out of bounds indexing.
365    fn index_op(self) -> ArrayIndexOp<'a>;
366}
367
368/* -------------------------------------------------------------------------- */
369/*                               Implementation                               */
370/* -------------------------------------------------------------------------- */
371
372// Implement public bindings
373impl Array {
374    /// Take elements along an axis.
375    ///
376    /// The elements are taken from `indices` along the specified axis. If the axis is not specified
377    /// the array is treated as a flattened 1-D array prior to performing the take.
378    ///
379    /// See [`Array::take_all`] for the flattened array.
380    ///
381    /// # Params
382    ///
383    /// - `indices`: The indices to take from the array.
384    /// - `axis`: The axis along which to take the elements.
385    #[default_device]
386    pub fn take_device(
387        &self,
388        indices: impl AsRef<Array>,
389        axis: i32,
390        stream: impl AsRef<Stream>,
391    ) -> Result<Array> {
392        Array::try_from_op(|res| unsafe {
393            mlx_sys::mlx_take(
394                res,
395                self.as_ptr(),
396                indices.as_ref().as_ptr(),
397                axis,
398                stream.as_ref().as_ptr(),
399            )
400        })
401    }
402
403    /// Take elements from flattened 1-D array.
404    ///
405    /// # Params
406    ///
407    /// - `indices`: The indices to take from the array.
408    #[default_device]
409    pub fn take_all_device(
410        &self,
411        indices: impl AsRef<Array>,
412        stream: impl AsRef<Stream>,
413    ) -> Result<Array> {
414        Array::try_from_op(|res| unsafe {
415            mlx_sys::mlx_take_all(
416                res,
417                self.as_ptr(),
418                indices.as_ref().as_ptr(),
419                stream.as_ref().as_ptr(),
420            )
421        })
422    }
423
424    /// Take values along an axis at the specified indices.
425    ///
426    /// If no axis is specified, the array is flattened to 1D prior to the indexing operation.
427    ///
428    /// # Params
429    ///
430    /// - `indices`: The indices to take from the array.
431    /// - `axis`: Axis in the input to take the values from.
432    #[default_device]
433    pub fn take_along_axis_device(
434        &self,
435        indices: impl AsRef<Array>,
436        axis: impl Into<Option<i32>>,
437        stream: impl AsRef<Stream>,
438    ) -> Result<Array> {
439        let (input, axis) = match axis.into() {
440            None => (Cow::Owned(self.reshape_device(&[-1], &stream)?), 0),
441            Some(ax) => (Cow::Borrowed(self), ax),
442        };
443
444        Array::try_from_op(|res| unsafe {
445            mlx_sys::mlx_take_along_axis(
446                res,
447                input.as_ptr(),
448                indices.as_ref().as_ptr(),
449                axis,
450                stream.as_ref().as_ptr(),
451            )
452        })
453    }
454
455    /// Put values along an axis at the specified indices.
456    ///
457    /// If no axis is specified, the array is flattened to 1D prior to the indexing operation.
458    ///
459    /// # Params
460    /// - indices: Indices array. These should be broadcastable with the input array excluding the `axis` dimension.
461    /// - values: Values array. These should be broadcastable with the indices.
462    /// - axis: Axis in the destination to put the values to.
463    /// - stream: stream or device to evaluate on.
464    #[default_device]
465    pub fn put_along_axis_device(
466        &self,
467        indices: impl AsRef<Array>,
468        values: impl AsRef<Array>,
469        axis: impl Into<Option<i32>>,
470        stream: impl AsRef<Stream>,
471    ) -> Result<Array> {
472        match axis.into() {
473            None => {
474                let input = self.reshape_device(&[-1], &stream)?;
475                let array = Array::try_from_op(|res| unsafe {
476                    mlx_sys::mlx_put_along_axis(
477                        res,
478                        input.as_ptr(),
479                        indices.as_ref().as_ptr(),
480                        values.as_ref().as_ptr(),
481                        0,
482                        stream.as_ref().as_ptr(),
483                    )
484                })?;
485                let array = array.reshape_device(self.shape(), &stream)?;
486                Ok(array)
487            }
488            Some(ax) => Array::try_from_op(|res| unsafe {
489                mlx_sys::mlx_put_along_axis(
490                    res,
491                    self.as_ptr(),
492                    indices.as_ref().as_ptr(),
493                    values.as_ref().as_ptr(),
494                    ax,
495                    stream.as_ref().as_ptr(),
496                )
497            }),
498        }
499    }
500}
501
502/// Indices of the maximum values along the axis.
503///
504/// See [`argmax_all`] for the flattened array.
505///
506/// # Params
507///
508/// - `a`: The input array.
509/// - `axis`: Axis to reduce over
510/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
511#[generate_macro(customize(root = "$crate::ops::indexing"))]
512#[default_device]
513pub fn argmax_device(
514    a: impl AsRef<Array>,
515    axis: i32,
516    #[optional] keep_dims: impl Into<Option<bool>>,
517    #[optional] stream: impl AsRef<Stream>,
518) -> Result<Array> {
519    let keep_dims = keep_dims.into().unwrap_or(false);
520
521    Array::try_from_op(|res| unsafe {
522        mlx_sys::mlx_argmax(
523            res,
524            a.as_ref().as_ptr(),
525            axis,
526            keep_dims,
527            stream.as_ref().as_ptr(),
528        )
529    })
530}
531
532/// Indices of the maximum value over the entire array.
533///
534/// # Params
535///
536/// - `a`: The input array.
537/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
538#[generate_macro(customize(root = "$crate::ops::indexing"))]
539#[default_device]
540pub fn argmax_all_device(
541    a: impl AsRef<Array>,
542    #[optional] keep_dims: impl Into<Option<bool>>,
543    #[optional] stream: impl AsRef<Stream>,
544) -> Result<Array> {
545    let keep_dims = keep_dims.into().unwrap_or(false);
546
547    Array::try_from_op(|res| unsafe {
548        mlx_sys::mlx_argmax_all(
549            res,
550            a.as_ref().as_ptr(),
551            keep_dims,
552            stream.as_ref().as_ptr(),
553        )
554    })
555}
556
557/// Indices of the minimum values along the axis.
558///
559/// See [`argmin_all`] for the flattened array.
560///
561/// # Params
562///
563/// - `a`: The input array.
564/// - `axis`: Axis to reduce over.
565/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
566#[generate_macro(customize(root = "$crate::ops::indexing"))]
567#[default_device]
568pub fn argmin_device(
569    a: impl AsRef<Array>,
570    axis: i32,
571    #[optional] keep_dims: impl Into<Option<bool>>,
572    #[optional] stream: impl AsRef<Stream>,
573) -> Result<Array> {
574    let keep_dims = keep_dims.into().unwrap_or(false);
575
576    Array::try_from_op(|res| unsafe {
577        mlx_sys::mlx_argmin(
578            res,
579            a.as_ref().as_ptr(),
580            axis,
581            keep_dims,
582            stream.as_ref().as_ptr(),
583        )
584    })
585}
586
587/// Indices of the minimum value over the entire array.
588///
589/// # Params
590///
591/// - `a`: The input array.
592/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
593#[generate_macro(customize(root = "$crate::ops::indexing"))]
594#[default_device]
595pub fn argmin_all_device(
596    a: impl AsRef<Array>,
597    #[optional] keep_dims: impl Into<Option<bool>>,
598    #[optional] stream: impl AsRef<Stream>,
599) -> Result<Array> {
600    let keep_dims = keep_dims.into().unwrap_or(false);
601
602    Array::try_from_op(|res| unsafe {
603        mlx_sys::mlx_argmin_all(
604            res,
605            a.as_ref().as_ptr(),
606            keep_dims,
607            stream.as_ref().as_ptr(),
608        )
609    })
610}
611
612/// See [`Array::take_along_axis`]
613#[generate_macro(customize(root = "$crate::ops::indexing"))]
614#[default_device]
615pub fn take_along_axis_device(
616    a: impl AsRef<Array>,
617    indices: impl AsRef<Array>,
618    #[optional] axis: impl Into<Option<i32>>,
619    #[optional] stream: impl AsRef<Stream>,
620) -> Result<Array> {
621    a.as_ref().take_along_axis_device(indices, axis, stream)
622}
623
624/// See [`Array::put_along_axis`]
625#[generate_macro(customize(root = "$crate::ops::indexing"))]
626#[default_device]
627pub fn put_along_axis_device(
628    a: impl AsRef<Array>,
629    indices: impl AsRef<Array>,
630    values: impl AsRef<Array>,
631    #[optional] axis: impl Into<Option<i32>>,
632    #[optional] stream: impl AsRef<Stream>,
633) -> Result<Array> {
634    a.as_ref()
635        .put_along_axis_device(indices, values, axis, stream)
636}
637
638/// See [`Array::take`]
639#[generate_macro(customize(root = "$crate::ops::indexing"))]
640#[default_device]
641pub fn take_device(
642    a: impl AsRef<Array>,
643    indices: impl AsRef<Array>,
644    axis: i32,
645    #[optional] stream: impl AsRef<Stream>,
646) -> Result<Array> {
647    a.as_ref().take_device(indices, axis, stream)
648}
649
650/// See [`Array::take_all`]
651#[generate_macro(customize(root = "$crate::ops::indexing"))]
652#[default_device]
653pub fn take_all_device(
654    a: impl AsRef<Array>,
655    indices: impl AsRef<Array>,
656    #[optional] stream: impl AsRef<Stream>,
657) -> Result<Array> {
658    a.as_ref().take_all_device(indices, stream)
659}
660
661/// Returns the `k` largest elements from the input along a given axis.
662///
663/// The elements will not necessarily be in sorted order.
664///
665/// See [`topk_all`] for the flattened array.
666///
667/// # Params
668///
669/// - `a`: The input array.
670/// - `k`: The number of elements to return.
671/// - `axis`: Axis to sort over. Default to `-1` if not specified.
672#[generate_macro(customize(root = "$crate::ops::indexing"))]
673#[default_device]
674pub fn topk_device(
675    a: impl AsRef<Array>,
676    k: i32,
677    #[optional] axis: impl Into<Option<i32>>,
678    #[optional] stream: impl AsRef<Stream>,
679) -> Result<Array> {
680    let axis = axis.into().unwrap_or(-1);
681
682    Array::try_from_op(|res| unsafe {
683        mlx_sys::mlx_topk(res, a.as_ref().as_ptr(), k, axis, stream.as_ref().as_ptr())
684    })
685}
686
687/// Returns the `k` largest elements from the flattened input array.
688#[generate_macro(customize(root = "$crate::ops::indexing"))]
689#[default_device]
690pub fn topk_all_device(
691    a: impl AsRef<Array>,
692    k: i32,
693    #[optional] stream: impl AsRef<Stream>,
694) -> Result<Array> {
695    Array::try_from_op(|res| unsafe {
696        mlx_sys::mlx_topk_all(res, a.as_ref().as_ptr(), k, stream.as_ref().as_ptr())
697    })
698}
699
700/* -------------------------------------------------------------------------- */
701/*                              Helper functions                              */
702/* -------------------------------------------------------------------------- */
703
704fn count_non_new_axis_operations(operations: &[ArrayIndexOp]) -> usize {
705    operations
706        .iter()
707        .filter(|op| !matches!(op, ArrayIndexOp::ExpandDims))
708        .count()
709}
710
711fn expand_ellipsis_operations<'a>(
712    ndim: usize,
713    operations: &'a [ArrayIndexOp<'a>],
714) -> Cow<'a, [ArrayIndexOp<'a>]> {
715    let ellipsis_count = operations
716        .iter()
717        .filter(|op| matches!(op, ArrayIndexOp::Ellipsis))
718        .count();
719    if ellipsis_count == 0 {
720        return Cow::Borrowed(operations);
721    }
722
723    if ellipsis_count > 1 {
724        panic!("Indexing with multiple ellipsis is not supported");
725    }
726
727    let ellipsis_pos = operations
728        .iter()
729        .position(|op| matches!(op, ArrayIndexOp::Ellipsis))
730        .unwrap();
731    let prefix = &operations[..ellipsis_pos];
732    let suffix = &operations[(ellipsis_pos + 1)..];
733    let expand_range =
734        count_non_new_axis_operations(prefix)..(ndim - count_non_new_axis_operations(suffix));
735    let expand = expand_range.map(|_| (..).index_op());
736
737    let mut expanded = Vec::with_capacity(ndim);
738    expanded.extend_from_slice(prefix);
739    expanded.extend(expand);
740    expanded.extend_from_slice(suffix);
741
742    Cow::Owned(expanded)
743}