mlx_rs/ops/indexing/mod.rs
1//! Indexing Arrays
2//!
3//! # Overview
4//!
5//! Due to limitations in the `std::ops::Index` and `std::ops::IndexMut` traits (only references can
6//! be returned), the indexing is achieved with the [`IndexOp`] and [`IndexMutOp`] traits where
7//! arrays can be indexed with [`IndexOp::index()`] and [`IndexMutOp::index_mut()`] respectively.
8//!
9//! The following types can be used as indices:
10//!
11//! | Type | Description |
12//! |------|-------------|
13//! | [`i32`] | An integer index |
14//! | [`Array`] | Use an array to index another array |
15//! | `&Array` | Use a reference to an array to index another array |
16//! | [`std::ops::Range<i32>`] | A range index |
17//! | [`std::ops::RangeFrom<i32>`] | A range index |
18//! | [`std::ops::RangeFull`] | A range index |
19//! | [`std::ops::RangeInclusive<i32>`] | A range index |
20//! | [`std::ops::RangeTo<i32>`] | A range index |
21//! | [`std::ops::RangeToInclusive<i32>`] | A range index |
22//! | [`StrideBy`] | A range index with stride |
23//! | [`NewAxis`] | Add a new axis |
24//! | [`Ellipsis`] | Consume all axes |
25//!
26//! # Single axis indexing
27//!
28//! | Indexing Operation | `mlx` (python) | `mlx-swift` | `mlx-rs` |
29//! |--------------------|--------|-------|------|
30//! | integer | `arr[1]` | `arr[1]` | `arr.index(1)` |
31//! | range expression | `arr[1:3]` | `arr[1..<3]` | `arr.index(1..3)` |
32//! | full range | `arr[:]` | `arr[0...]` | `arr.index(..)` |
33//! | range with stride | `arr[::2]` | `arr[.stride(by: 2)]` | `arr.index((..).stride_by(2))` |
34//! | ellipsis (consuming all axes) | `arr[...]` | `arr[.ellipsis]` | `arr.index(Ellipsis)` |
35//! | newaxis | `arr[None]` | `arr[.newAxis]` | `arr.index(NewAxis)` |
36//! | mlx array `i` | `arr[i]` | `arr[i]` | `arr.index(i)` |
37//!
38//! # Multi-axes indexing
39//!
40//! Multi-axes indexing with combinations of the above operations is also supported by combining the
41//! operations in a tuple with the restriction that `Ellipsis` can only be used once.
42//!
43//! ## Examples
44//!
45//! ```rust
46//! // See the multi-dimensional example code for mlx python https://ml-explore.github.io/mlx/build/html/usage/indexing.html
47//!
48//! use mlx_rs::{Array, ops::indexing::*};
49//!
50//! let a = Array::from_iter(0..8, &[2, 2, 2]);
51//!
//! // a[:, :, 0]
//! let s1 = a.index((.., .., 0));
//!
//! let expected = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
//! assert_eq!(s1, expected);
//!
//! // a[..., 0]
//! let s2 = a.index((Ellipsis, 0));
//!
//! let expected = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
//! assert_eq!(s2, expected);
63//! ```
64//!
65//! # Set values with indexing
66//!
67//! The same indexing operations (single or multiple) can be used to set values in an array using
68//! the [`IndexMutOp`] trait.
69//!
70//! ## Example
71//!
72//! ```rust
73//! use mlx_rs::{Array, ops::indexing::*};
74//!
75//! let mut a = Array::from_slice(&[1, 2, 3], &[3]);
76//! a.index_mut(2, Array::from_int(0));
77//!
78//! let expected = Array::from_slice(&[1, 2, 0], &[3]);
79//! assert_eq!(a, expected);
80//! ```
81//!
82//! ```rust
83//! use mlx_rs::{Array, ops::indexing::*};
84//!
85//! let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);
86//!
87//! // writing using slices -- this ends up covering two elements
88//! a.index_mut((0..1, 1..2, 2..4), Array::from_int(88));
89//!
90//! let expected = Array::from_slice(
91//! &[
92//! 0, 1, 2, 3, 4, 5, 6, 88, 88, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
93//! ],
94//! &[2, 2, 5],
95//! );
96//! assert_eq!(a, expected);
97//! ```
98
99use std::{borrow::Cow, ops::Bound, rc::Rc};
100
101use mlx_internal_macros::{default_device, generate_macro};
102
103use crate::{error::Result, utils::guard::Guarded, Array, Stream, StreamOrDevice};
104
105pub(crate) mod index_impl;
106pub(crate) mod indexmut_impl;
107
108/* -------------------------------------------------------------------------- */
109/* Custom types */
110/* -------------------------------------------------------------------------- */
111
112/// New axis indexing operation.
113///
114/// See the module level documentation for more information.
115#[derive(Debug, Clone, Copy)]
116pub struct NewAxis;
117
118/// Ellipsis indexing operation.
119///
120/// See the module level documentation for more information.
121#[derive(Debug, Clone, Copy)]
122pub struct Ellipsis;
123
124/// Stride indexing operation.
125///
126/// See the module level documentation for more information.
127#[derive(Debug, Clone, Copy)]
128pub struct StrideBy<I> {
129 /// The inner iterator
130 pub inner: I,
131
132 /// The stride
133 pub stride: i32,
134}
135
136/// Helper trait for creating a stride indexing operation.
137pub trait IntoStrideBy: Sized {
138 /// Create a stride indexing operation.
139 fn stride_by(self, stride: i32) -> StrideBy<Self>;
140}
141
142impl<T> IntoStrideBy for T {
143 fn stride_by(self, stride: i32) -> StrideBy<Self> {
144 StrideBy {
145 inner: self,
146 stride,
147 }
148 }
149}
150
151/// Range indexing operation.
152#[derive(Debug, Clone)]
153pub struct RangeIndex {
154 start: Bound<i32>,
155 stop: Bound<i32>,
156 stride: i32,
157}
158
159impl RangeIndex {
160 pub(crate) fn new(start: Bound<i32>, stop: Bound<i32>, stride: Option<i32>) -> Self {
161 let stride = stride.unwrap_or(1);
162 Self {
163 start,
164 stop,
165 stride,
166 }
167 }
168
169 pub(crate) fn is_full(&self) -> bool {
170 matches!(self.start, Bound::Unbounded)
171 && matches!(self.stop, Bound::Unbounded)
172 && self.stride == 1
173 }
174
175 pub(crate) fn stride(&self) -> i32 {
176 self.stride
177 }
178
179 pub(crate) fn start(&self, size: i32) -> i32 {
180 match self.start {
181 Bound::Included(start) => start,
182 Bound::Excluded(start) => start + 1,
183 Bound::Unbounded => {
184 // ref swift binding
185 // _start ?? (stride < 0 ? size - 1 : 0)
186
187 if self.stride.is_negative() {
188 size - 1
189 } else {
190 0
191 }
192 }
193 }
194 }
195
196 pub(crate) fn absolute_start(&self, size: i32) -> i32 {
197 // ref swift binding
198 // return start < 0 ? start + size : start
199
200 let start = self.start(size);
201 if start.is_negative() {
202 start + size
203 } else {
204 start
205 }
206 }
207
208 pub(crate) fn end(&self, size: i32) -> i32 {
209 match self.stop {
210 Bound::Included(stop) => stop + 1,
211 Bound::Excluded(stop) => stop,
212 Bound::Unbounded => {
213 // ref swift binding
214 // _end ?? (stride < 0 ? -size - 1 : size)
215
216 if self.stride.is_negative() {
217 -size - 1
218 } else {
219 size
220 }
221 }
222 }
223 }
224
225 pub(crate) fn absolute_end(&self, size: i32) -> i32 {
226 // ref swift binding
227 // return end < 0 ? end + size : end
228
229 let end = self.end(size);
230 if end.is_negative() {
231 end + size
232 } else {
233 end
234 }
235 }
236}
237
238/// Indexing operation for arrays.
239#[derive(Debug, Clone)]
240pub enum ArrayIndexOp<'a> {
241 /// An `Ellipsis` is used to consume all axes
242 ///
243 /// This is equivalent to `...` in python
244 Ellipsis,
245
246 /// A single index operation
247 ///
248 /// This is equivalent to `arr[1]` in python
249 TakeIndex {
250 /// The index to take
251 index: i32,
252 },
253
254 /// Indexing with an array
255 TakeArray {
256 /// The indices to take
257 indices: Rc<Array>, // TODO: remove `Rc` because `Array` is `Clone`
258 },
259
260 /// Indexing with an array reference
261 TakeArrayRef {
262 /// The indices to take
263 indices: &'a Array,
264 },
265
266 /// Indexing with a range
267 ///
268 /// This is equivalent to `arr[1:3]` in python
269 Slice(RangeIndex),
270
271 /// New axis operation
272 ///
273 /// This is equivalent to `arr[None]` in python
274 ExpandDims,
275}
276
277impl ArrayIndexOp<'_> {
278 fn is_array_or_index(&self) -> bool {
279 // Using the full match syntax to avoid forgetting to add new variants
280 match self {
281 ArrayIndexOp::TakeIndex { .. }
282 | ArrayIndexOp::TakeArrayRef { .. }
283 | ArrayIndexOp::TakeArray { .. } => true,
284 ArrayIndexOp::Ellipsis | ArrayIndexOp::Slice(_) | ArrayIndexOp::ExpandDims => false,
285 }
286 }
287
288 fn is_array(&self) -> bool {
289 // Using the full match syntax to avoid forgetting to add new variants
290 match self {
291 ArrayIndexOp::TakeArray { .. } | ArrayIndexOp::TakeArrayRef { .. } => true,
292 ArrayIndexOp::TakeIndex { .. }
293 | ArrayIndexOp::Ellipsis
294 | ArrayIndexOp::Slice(_)
295 | ArrayIndexOp::ExpandDims => false,
296 }
297 }
298}
299
300/* -------------------------------------------------------------------------- */
301/* Custom traits */
302/* -------------------------------------------------------------------------- */
303
304/// Trait for custom indexing operations.
305///
306/// Out of bounds indexing is allowed and wouldn't return an error.
307pub trait TryIndexOp<Idx> {
308 /// Try to index the array with the given index.
309 fn try_index_device(&self, i: Idx, stream: impl AsRef<Stream>) -> Result<Array>;
310
311 /// Try to index the array with the given index.
312 fn try_index(&self, i: Idx) -> Result<Array> {
313 self.try_index_device(i, StreamOrDevice::default())
314 }
315}
316
317/// Trait for custom indexing operations.
318///
319/// This is implemented for all types that implement `TryIndexOp`.
320pub trait IndexOp<Idx>: TryIndexOp<Idx> {
321 /// Index the array with the given index.
322 fn index_device(&self, i: Idx, stream: impl AsRef<Stream>) -> Array {
323 self.try_index_device(i, stream).unwrap()
324 }
325
326 /// Index the array with the given index.
327 fn index(&self, i: Idx) -> Array {
328 self.try_index(i).unwrap()
329 }
330}
331
332impl<T, Idx> IndexOp<Idx> for T where T: TryIndexOp<Idx> {}
333
334/// Trait for custom mutable indexing operations.
335pub trait TryIndexMutOp<Idx, Val> {
336 /// Try to index the array with the given index and set the value.
337 fn try_index_mut_device(&mut self, i: Idx, val: Val, stream: impl AsRef<Stream>) -> Result<()>;
338
339 /// Try to index the array with the given index and set the value.
340 fn try_index_mut(&mut self, i: Idx, val: Val) -> Result<()> {
341 self.try_index_mut_device(i, val, StreamOrDevice::default())
342 }
343}
344
345// TODO: should `Val` impl `AsRef<Array>` or `Into<Array>`?
346
347/// Trait for custom mutable indexing operations.
348pub trait IndexMutOp<Idx, Val>: TryIndexMutOp<Idx, Val> {
349 /// Index the array with the given index and set the value.
350 fn index_mut_device(&mut self, i: Idx, val: Val, stream: impl AsRef<Stream>) {
351 self.try_index_mut_device(i, val, stream).unwrap()
352 }
353
354 /// Index the array with the given index and set the value.
355 fn index_mut(&mut self, i: Idx, val: Val) {
356 self.try_index_mut(i, val).unwrap()
357 }
358}
359
360impl<T, Idx, Val> IndexMutOp<Idx, Val> for T where T: TryIndexMutOp<Idx, Val> {}
361
362/// Trait for custom indexing operations.
363pub trait ArrayIndex<'a> {
364 /// `mlx` allows out of bounds indexing.
365 fn index_op(self) -> ArrayIndexOp<'a>;
366}
367
368/* -------------------------------------------------------------------------- */
369/* Implementation */
370/* -------------------------------------------------------------------------- */
371
372// Implement public bindings
373impl Array {
374 /// Take elements along an axis.
375 ///
376 /// The elements are taken from `indices` along the specified axis. If the axis is not specified
377 /// the array is treated as a flattened 1-D array prior to performing the take.
378 ///
379 /// See [`Array::take_all`] for the flattened array.
380 ///
381 /// # Params
382 ///
383 /// - `indices`: The indices to take from the array.
384 /// - `axis`: The axis along which to take the elements.
385 #[default_device]
386 pub fn take_device(
387 &self,
388 indices: impl AsRef<Array>,
389 axis: i32,
390 stream: impl AsRef<Stream>,
391 ) -> Result<Array> {
392 Array::try_from_op(|res| unsafe {
393 mlx_sys::mlx_take(
394 res,
395 self.as_ptr(),
396 indices.as_ref().as_ptr(),
397 axis,
398 stream.as_ref().as_ptr(),
399 )
400 })
401 }
402
403 /// Take elements from flattened 1-D array.
404 ///
405 /// # Params
406 ///
407 /// - `indices`: The indices to take from the array.
408 #[default_device]
409 pub fn take_all_device(
410 &self,
411 indices: impl AsRef<Array>,
412 stream: impl AsRef<Stream>,
413 ) -> Result<Array> {
414 Array::try_from_op(|res| unsafe {
415 mlx_sys::mlx_take_all(
416 res,
417 self.as_ptr(),
418 indices.as_ref().as_ptr(),
419 stream.as_ref().as_ptr(),
420 )
421 })
422 }
423
424 /// Take values along an axis at the specified indices.
425 ///
426 /// If no axis is specified, the array is flattened to 1D prior to the indexing operation.
427 ///
428 /// # Params
429 ///
430 /// - `indices`: The indices to take from the array.
431 /// - `axis`: Axis in the input to take the values from.
432 #[default_device]
433 pub fn take_along_axis_device(
434 &self,
435 indices: impl AsRef<Array>,
436 axis: impl Into<Option<i32>>,
437 stream: impl AsRef<Stream>,
438 ) -> Result<Array> {
439 let (input, axis) = match axis.into() {
440 None => (Cow::Owned(self.reshape_device(&[-1], &stream)?), 0),
441 Some(ax) => (Cow::Borrowed(self), ax),
442 };
443
444 Array::try_from_op(|res| unsafe {
445 mlx_sys::mlx_take_along_axis(
446 res,
447 input.as_ptr(),
448 indices.as_ref().as_ptr(),
449 axis,
450 stream.as_ref().as_ptr(),
451 )
452 })
453 }
454
455 /// Put values along an axis at the specified indices.
456 ///
457 /// If no axis is specified, the array is flattened to 1D prior to the indexing operation.
458 ///
459 /// # Params
460 /// - indices: Indices array. These should be broadcastable with the input array excluding the `axis` dimension.
461 /// - values: Values array. These should be broadcastable with the indices.
462 /// - axis: Axis in the destination to put the values to.
463 /// - stream: stream or device to evaluate on.
464 #[default_device]
465 pub fn put_along_axis_device(
466 &self,
467 indices: impl AsRef<Array>,
468 values: impl AsRef<Array>,
469 axis: impl Into<Option<i32>>,
470 stream: impl AsRef<Stream>,
471 ) -> Result<Array> {
472 match axis.into() {
473 None => {
474 let input = self.reshape_device(&[-1], &stream)?;
475 let array = Array::try_from_op(|res| unsafe {
476 mlx_sys::mlx_put_along_axis(
477 res,
478 input.as_ptr(),
479 indices.as_ref().as_ptr(),
480 values.as_ref().as_ptr(),
481 0,
482 stream.as_ref().as_ptr(),
483 )
484 })?;
485 let array = array.reshape_device(self.shape(), &stream)?;
486 Ok(array)
487 }
488 Some(ax) => Array::try_from_op(|res| unsafe {
489 mlx_sys::mlx_put_along_axis(
490 res,
491 self.as_ptr(),
492 indices.as_ref().as_ptr(),
493 values.as_ref().as_ptr(),
494 ax,
495 stream.as_ref().as_ptr(),
496 )
497 }),
498 }
499 }
500}
501
502/// Indices of the maximum values along the axis.
503///
504/// See [`argmax_all`] for the flattened array.
505///
506/// # Params
507///
508/// - `a`: The input array.
509/// - `axis`: Axis to reduce over
510/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
511#[generate_macro(customize(root = "$crate::ops::indexing"))]
512#[default_device]
513pub fn argmax_device(
514 a: impl AsRef<Array>,
515 axis: i32,
516 #[optional] keep_dims: impl Into<Option<bool>>,
517 #[optional] stream: impl AsRef<Stream>,
518) -> Result<Array> {
519 let keep_dims = keep_dims.into().unwrap_or(false);
520
521 Array::try_from_op(|res| unsafe {
522 mlx_sys::mlx_argmax(
523 res,
524 a.as_ref().as_ptr(),
525 axis,
526 keep_dims,
527 stream.as_ref().as_ptr(),
528 )
529 })
530}
531
532/// Indices of the maximum value over the entire array.
533///
534/// # Params
535///
536/// - `a`: The input array.
537/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
538#[generate_macro(customize(root = "$crate::ops::indexing"))]
539#[default_device]
540pub fn argmax_all_device(
541 a: impl AsRef<Array>,
542 #[optional] keep_dims: impl Into<Option<bool>>,
543 #[optional] stream: impl AsRef<Stream>,
544) -> Result<Array> {
545 let keep_dims = keep_dims.into().unwrap_or(false);
546
547 Array::try_from_op(|res| unsafe {
548 mlx_sys::mlx_argmax_all(
549 res,
550 a.as_ref().as_ptr(),
551 keep_dims,
552 stream.as_ref().as_ptr(),
553 )
554 })
555}
556
557/// Indices of the minimum values along the axis.
558///
559/// See [`argmin_all`] for the flattened array.
560///
561/// # Params
562///
563/// - `a`: The input array.
564/// - `axis`: Axis to reduce over.
565/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
566#[generate_macro(customize(root = "$crate::ops::indexing"))]
567#[default_device]
568pub fn argmin_device(
569 a: impl AsRef<Array>,
570 axis: i32,
571 #[optional] keep_dims: impl Into<Option<bool>>,
572 #[optional] stream: impl AsRef<Stream>,
573) -> Result<Array> {
574 let keep_dims = keep_dims.into().unwrap_or(false);
575
576 Array::try_from_op(|res| unsafe {
577 mlx_sys::mlx_argmin(
578 res,
579 a.as_ref().as_ptr(),
580 axis,
581 keep_dims,
582 stream.as_ref().as_ptr(),
583 )
584 })
585}
586
587/// Indices of the minimum value over the entire array.
588///
589/// # Params
590///
591/// - `a`: The input array.
592/// - `keep_dims`: Keep reduced axes as singleton dimensions, defaults to False.
593#[generate_macro(customize(root = "$crate::ops::indexing"))]
594#[default_device]
595pub fn argmin_all_device(
596 a: impl AsRef<Array>,
597 #[optional] keep_dims: impl Into<Option<bool>>,
598 #[optional] stream: impl AsRef<Stream>,
599) -> Result<Array> {
600 let keep_dims = keep_dims.into().unwrap_or(false);
601
602 Array::try_from_op(|res| unsafe {
603 mlx_sys::mlx_argmin_all(
604 res,
605 a.as_ref().as_ptr(),
606 keep_dims,
607 stream.as_ref().as_ptr(),
608 )
609 })
610}
611
612/// See [`Array::take_along_axis`]
613#[generate_macro(customize(root = "$crate::ops::indexing"))]
614#[default_device]
615pub fn take_along_axis_device(
616 a: impl AsRef<Array>,
617 indices: impl AsRef<Array>,
618 #[optional] axis: impl Into<Option<i32>>,
619 #[optional] stream: impl AsRef<Stream>,
620) -> Result<Array> {
621 a.as_ref().take_along_axis_device(indices, axis, stream)
622}
623
624/// See [`Array::put_along_axis`]
625#[generate_macro(customize(root = "$crate::ops::indexing"))]
626#[default_device]
627pub fn put_along_axis_device(
628 a: impl AsRef<Array>,
629 indices: impl AsRef<Array>,
630 values: impl AsRef<Array>,
631 #[optional] axis: impl Into<Option<i32>>,
632 #[optional] stream: impl AsRef<Stream>,
633) -> Result<Array> {
634 a.as_ref()
635 .put_along_axis_device(indices, values, axis, stream)
636}
637
638/// See [`Array::take`]
639#[generate_macro(customize(root = "$crate::ops::indexing"))]
640#[default_device]
641pub fn take_device(
642 a: impl AsRef<Array>,
643 indices: impl AsRef<Array>,
644 axis: i32,
645 #[optional] stream: impl AsRef<Stream>,
646) -> Result<Array> {
647 a.as_ref().take_device(indices, axis, stream)
648}
649
650/// See [`Array::take_all`]
651#[generate_macro(customize(root = "$crate::ops::indexing"))]
652#[default_device]
653pub fn take_all_device(
654 a: impl AsRef<Array>,
655 indices: impl AsRef<Array>,
656 #[optional] stream: impl AsRef<Stream>,
657) -> Result<Array> {
658 a.as_ref().take_all_device(indices, stream)
659}
660
661/// Returns the `k` largest elements from the input along a given axis.
662///
663/// The elements will not necessarily be in sorted order.
664///
665/// See [`topk_all`] for the flattened array.
666///
667/// # Params
668///
669/// - `a`: The input array.
670/// - `k`: The number of elements to return.
671/// - `axis`: Axis to sort over. Default to `-1` if not specified.
672#[generate_macro(customize(root = "$crate::ops::indexing"))]
673#[default_device]
674pub fn topk_device(
675 a: impl AsRef<Array>,
676 k: i32,
677 #[optional] axis: impl Into<Option<i32>>,
678 #[optional] stream: impl AsRef<Stream>,
679) -> Result<Array> {
680 let axis = axis.into().unwrap_or(-1);
681
682 Array::try_from_op(|res| unsafe {
683 mlx_sys::mlx_topk(res, a.as_ref().as_ptr(), k, axis, stream.as_ref().as_ptr())
684 })
685}
686
687/// Returns the `k` largest elements from the flattened input array.
688#[generate_macro(customize(root = "$crate::ops::indexing"))]
689#[default_device]
690pub fn topk_all_device(
691 a: impl AsRef<Array>,
692 k: i32,
693 #[optional] stream: impl AsRef<Stream>,
694) -> Result<Array> {
695 Array::try_from_op(|res| unsafe {
696 mlx_sys::mlx_topk_all(res, a.as_ref().as_ptr(), k, stream.as_ref().as_ptr())
697 })
698}
699
700/* -------------------------------------------------------------------------- */
701/* Helper functions */
702/* -------------------------------------------------------------------------- */
703
704fn count_non_new_axis_operations(operations: &[ArrayIndexOp]) -> usize {
705 operations
706 .iter()
707 .filter(|op| !matches!(op, ArrayIndexOp::ExpandDims))
708 .count()
709}
710
711fn expand_ellipsis_operations<'a>(
712 ndim: usize,
713 operations: &'a [ArrayIndexOp<'a>],
714) -> Cow<'a, [ArrayIndexOp<'a>]> {
715 let ellipsis_count = operations
716 .iter()
717 .filter(|op| matches!(op, ArrayIndexOp::Ellipsis))
718 .count();
719 if ellipsis_count == 0 {
720 return Cow::Borrowed(operations);
721 }
722
723 if ellipsis_count > 1 {
724 panic!("Indexing with multiple ellipsis is not supported");
725 }
726
727 let ellipsis_pos = operations
728 .iter()
729 .position(|op| matches!(op, ArrayIndexOp::Ellipsis))
730 .unwrap();
731 let prefix = &operations[..ellipsis_pos];
732 let suffix = &operations[(ellipsis_pos + 1)..];
733 let expand_range =
734 count_non_new_axis_operations(prefix)..(ndim - count_non_new_axis_operations(suffix));
735 let expand = expand_range.map(|_| (..).index_op());
736
737 let mut expanded = Vec::with_capacity(ndim);
738 expanded.extend_from_slice(prefix);
739 expanded.extend(expand);
740 expanded.extend_from_slice(suffix);
741
742 Cow::Owned(expanded)
743}