// mlx_rs/ops/indexing/indexmut_impl.rs

1use std::borrow::Cow;
2
3use smallvec::{smallvec, SmallVec};
4
5use crate::{
6    constants::DEFAULT_STACK_VEC_LEN,
7    error::Result,
8    ops::{
9        broadcast_arrays_device, broadcast_to_device,
10        indexing::{count_non_new_axis_operations, expand_ellipsis_operations},
11        reshape_device,
12    },
13    utils::{resolve_index_signed_unchecked, VectorArray},
14    Array, Stream,
15};
16
17use super::{ArrayIndex, ArrayIndexOp, Guarded, RangeIndex, TryIndexMutOp};
18
19impl Array {
20    pub(crate) fn slice_update_device(
21        &self,
22        update: &Array,
23        starts: &[i32],
24        ends: &[i32],
25        strides: &[i32],
26        stream: impl AsRef<Stream>,
27    ) -> Result<Array> {
28        Array::try_from_op(|res| unsafe {
29            mlx_sys::mlx_slice_update(
30                res,
31                self.as_ptr(),
32                update.as_ptr(),
33                starts.as_ptr(),
34                starts.len(),
35                ends.as_ptr(),
36                ends.len(),
37                strides.as_ptr(),
38                strides.len(),
39                stream.as_ref().as_ptr(),
40            )
41        })
42    }
43}
44
45// See `updateSlice` in the swift binding or `mlx_slice_update` in the python binding
46fn update_slice(
47    src: &Array,
48    operations: &[ArrayIndexOp],
49    update: &Array,
50    stream: impl AsRef<Stream>,
51) -> Result<Option<Array>> {
52    let ndim = src.ndim();
53    if ndim == 0 || operations.is_empty() {
54        return Ok(None);
55    }
56
57    // Remove leading singletons dimensions from the update
58    let mut update = remove_leading_singleton_dimensions(update, &stream)?;
59
60    // Build slice update params
61    let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
62    let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
63    let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];
64
65    // If it's just a simple slice, just do a slice update and return
66    if operations.len() == 1 {
67        if let ArrayIndexOp::Slice(range_index) = &operations[0] {
68            let size = src.dim(0);
69            starts[0] = range_index.start(size);
70            ends[0] = range_index.end(size);
71            strides[0] = range_index.stride();
72
73            return Ok(Some(src.slice_update_device(
74                &update, &starts, &ends, &strides, &stream,
75            )?));
76        }
77    }
78
79    // Can't route to slice update if any arrays are present
80    if operations.iter().any(|op| op.is_array()) {
81        return Ok(None);
82    }
83
84    // Expand ellipses into a series of ':' (range full) slices
85    let operations = expand_ellipsis_operations(ndim, operations);
86
87    // If no non-None indices return the broadcasted update
88    let non_new_axis_operation_count = count_non_new_axis_operations(&operations);
89    if non_new_axis_operation_count == 0 {
90        return Ok(Some(broadcast_to_device(&update, src.shape(), &stream)?));
91    }
92
93    // Process entries
94    // let mut update_expand_dims: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
95    let mut update_reshape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
96    let mut axis = src.ndim() - 1;
97    let mut update_axis = update.ndim() as i32 - 1;
98
99    while axis >= non_new_axis_operation_count {
100        if update_axis >= 0 {
101            update_reshape[axis] = update.dim(update_axis);
102            update_axis -= 1;
103        } else {
104            update_reshape[axis] = 1;
105        }
106        axis -= 1;
107    }
108
109    for item in operations.iter().rev() {
110        use ArrayIndexOp::*;
111
112        match item {
113            TakeIndex { index } => {
114                let size = src.dim(axis as i32);
115                let index = if index.is_negative() {
116                    size + index
117                } else {
118                    *index
119                };
120                // SAFETY: axis is always non-negative
121                starts[axis] = index;
122                ends[axis] = index.saturating_add(1);
123
124                update_reshape[axis] = 1;
125                axis = axis.saturating_sub(1);
126            }
127            Slice(slice) => {
128                let size = src.dim(axis as i32);
129                // SAFETY: axis is always non-negative
130                starts[axis] = slice.start(size);
131                ends[axis] = slice.end(size);
132                strides[axis] = slice.stride();
133
134                if update_axis >= 0 {
135                    update_reshape[axis] = update.dim(update_axis);
136                    update_axis = update_axis.saturating_sub(1);
137                } else {
138                    update_reshape[axis] = 1;
139                }
140                axis = axis.saturating_sub(1);
141            }
142            ExpandDims => {}
143            Ellipsis | TakeArray { indices: _ } | TakeArrayRef { indices: _ } => {
144                panic!("unexpected item in operations")
145            }
146        }
147    }
148
149    if update.shape() != &update_reshape[..] {
150        update = Cow::Owned(reshape_device(update, &update_reshape, &stream)?);
151    }
152
153    Ok(Some(src.slice_update_device(
154        &update, &starts, &ends, &strides, &stream,
155    )?))
156}
157
158// See `leadingSingletonDimensionsRemoved` in the swift binding
159fn remove_leading_singleton_dimensions(
160    a: &Array,
161    stream: impl AsRef<Stream>,
162) -> Result<Cow<'_, Array>> {
163    let shape = a.shape();
164    let mut new_shape: Vec<_> = shape.iter().skip_while(|&&dim| dim == 1).cloned().collect();
165    if shape != new_shape {
166        if new_shape.is_empty() {
167            new_shape = vec![1];
168        }
169        Ok(Cow::Owned(a.reshape_device(&new_shape, stream)?))
170    } else {
171        Ok(Cow::Borrowed(a))
172    }
173}
174
/// Arguments for a scatter, as produced by the `scatter_args*` helpers in this file.
struct ScatterArgs<'a> {
    /// Index arrays, which may borrow arrays from the caller's index operations.
    /// When empty, the caller assigns `update` directly instead of scattering.
    indices: SmallVec<[Cow<'a, Array>; DEFAULT_STACK_VEC_LEN]>,
    /// The update, already broadcast/reshaped for the scatter. When `indices` is empty, this
    /// is the full replacement value.
    update: Array,
    /// Axes of the source array that `indices` refer to (one per entry of `indices` in the
    /// constructors visible here).
    axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>,
}

181/// See `scatterArguments` in the swift binding
182fn scatter_args<'a>(
183    src: &'a Array,
184    operations: &'a [ArrayIndexOp],
185    update: &Array,
186    stream: impl AsRef<Stream>,
187) -> Result<ScatterArgs<'a>> {
188    use ArrayIndexOp::*;
189
190    if operations.len() == 1 {
191        return match &operations[0] {
192            TakeIndex { index } => scatter_args_index(src, *index, update, stream),
193            TakeArray { indices } => {
194                scatter_args_array(src, Cow::Borrowed(indices), update, stream)
195            }
196            TakeArrayRef { indices } => {
197                scatter_args_array(src, Cow::Borrowed(indices), update, stream)
198            }
199            Slice(range_index) => scatter_args_slice(src, range_index, update, stream),
200            ExpandDims => Ok(ScatterArgs {
201                indices: smallvec![],
202                update: broadcast_to_device(update, src.shape(), &stream)?,
203                axes: smallvec![],
204            }),
205            Ellipsis => panic!("Unable to update array with ellipsis argument"),
206        };
207    }
208
209    scatter_args_nd(src, operations, update, stream)
210}
211
212fn scatter_args_index<'a>(
213    src: &'a Array,
214    index: i32,
215    update: &Array,
216    stream: impl AsRef<Stream>,
217) -> Result<ScatterArgs<'a>> {
218    // mlx_scatter_args_index
219
220    // Remove any leading singleton dimensions from the update
221    // and then broadcast update to shape of src[0, ...]
222    let update = remove_leading_singleton_dimensions(update, &stream)?;
223
224    let mut shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
225    shape[0] = 1;
226
227    Ok(ScatterArgs {
228        indices: smallvec![Cow::Owned(Array::from_int(resolve_index_signed_unchecked(
229            index,
230            src.dim(0)
231        )))],
232        update: broadcast_to_device(&update, &shape, &stream)?,
233        axes: smallvec![0],
234    })
235}
236
237fn scatter_args_array<'a>(
238    src: &'a Array,
239    a: Cow<'a, Array>,
240    update: &Array,
241    stream: impl AsRef<Stream>,
242) -> Result<ScatterArgs<'a>> {
243    // mlx_scatter_args_array
244
245    // trim leading singleton dimensions
246    let update = remove_leading_singleton_dimensions(update, &stream)?;
247
248    // The update shape must broadcast with indices.shape + [1] + src.shape[1:]
249    let mut update_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = a
250        .shape()
251        .iter()
252        .chain(src.shape().iter().skip(1))
253        .cloned()
254        .collect();
255    let update = broadcast_to_device(&update, &update_shape, &stream)?;
256
257    update_shape.insert(a.ndim(), 1);
258    let update = update.reshape_device(&update_shape, &stream)?;
259
260    Ok(ScatterArgs {
261        indices: smallvec![a],
262        update,
263        axes: smallvec![0],
264    })
265}
266
267fn scatter_args_slice<'a>(
268    src: &'a Array,
269    range_index: &'a RangeIndex,
270    update: &Array,
271    stream: impl AsRef<Stream>,
272) -> Result<ScatterArgs<'a>> {
273    // mlx_scatter_args_slice
274
275    // if none slice is requested braodcast the update to the src size and return it
276    if range_index.is_full() {
277        let update = remove_leading_singleton_dimensions(update, &stream)?;
278
279        return Ok(ScatterArgs {
280            indices: smallvec![],
281            update: broadcast_to_device(&update, src.shape(), &stream)?,
282            axes: smallvec![],
283        });
284    }
285
286    let size = src.dim(0);
287    let start = range_index.start(size);
288    let end = range_index.end(size);
289    let stride = range_index.stride();
290
291    // If simple stride
292    if stride == 1 {
293        let update = remove_leading_singleton_dimensions(update, &stream)?;
294
295        // Broadcast update to slice size
296        let update_broadcast_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (1..end - start)
297            .chain(src.shape().iter().skip(1).cloned())
298            .collect();
299        let update = broadcast_to_device(&update, &update_broadcast_shape, &stream)?;
300
301        let indices = Array::from_slice(&[start], &[1]);
302        Ok(ScatterArgs {
303            indices: smallvec![Cow::Owned(indices)],
304            update,
305            axes: smallvec![0],
306        })
307    } else {
308        // stride != 1, convert the slice to an array
309        let a_vals = strided_range_to_vec(start, end, stride);
310        let a = Array::from_slice(&a_vals, &[a_vals.len() as i32]);
311
312        scatter_args_array(src, Cow::Owned(a), update, stream)
313    }
314}
315
316fn scatter_args_nd<'a>(
317    src: &'a Array,
318    operations: &[ArrayIndexOp],
319    update: &Array,
320    stream: impl AsRef<Stream>,
321) -> Result<ScatterArgs<'a>> {
322    use ArrayIndexOp::*;
323
324    // mlx_scatter_args_nd
325
326    let shape = src.shape();
327
328    let operations = expand_ellipsis_operations(src.ndim(), operations);
329    let update = remove_leading_singleton_dimensions(update, &stream)?;
330
331    // If no non-newAxis indices return the broadcasted update
332    let non_new_axis_operation_count = count_non_new_axis_operations(&operations);
333    if non_new_axis_operation_count == 0 {
334        return Ok(ScatterArgs {
335            indices: smallvec![],
336            update: broadcast_to_device(&update, shape, &stream)?,
337            axes: smallvec![],
338        });
339    }
340
341    // Analyse the types of the indices
342    let mut max_dims = 0;
343    let mut arrays_first = false;
344    let mut count_new_axis: i32 = 0;
345    let mut count_slices: i32 = 0;
346    let mut count_arrays: i32 = 0;
347    let mut count_strided_slices: i32 = 0;
348    let mut count_simple_slices_post: i32 = 0;
349
350    let mut have_array = false;
351    let mut have_non_array = false;
352
353    macro_rules! analyze_indices_take_array {
354        ($indices:ident) => {
355            have_array = true;
356            if have_array && have_non_array {
357                arrays_first = true;
358            }
359            max_dims = $indices.ndim().max(max_dims);
360            count_arrays = count_arrays.saturating_add(1);
361            count_simple_slices_post = 0;
362        };
363    }
364
365    for item in operations.iter() {
366        match item {
367            TakeIndex { index: _ } => {
368                // ignore
369            }
370            Slice(range_index) => {
371                have_non_array = have_array;
372                count_slices = count_slices.saturating_add(1);
373                if range_index.stride() != 1 {
374                    count_strided_slices = count_strided_slices.saturating_add(1);
375                    count_simple_slices_post = 0;
376                } else {
377                    count_simple_slices_post = count_simple_slices_post.saturating_add(1);
378                }
379            }
380            TakeArray { indices } => {
381                analyze_indices_take_array!(indices);
382            }
383            TakeArrayRef { indices } => {
384                analyze_indices_take_array!(indices);
385            }
386            ExpandDims => {
387                have_non_array = true;
388                count_new_axis = count_new_axis.saturating_add(1);
389            }
390            Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
391        }
392    }
393
394    // We have index dims for the arrays, strided slices (implemented as arrays), none
395    let mut index_dims = (max_dims + count_new_axis as usize + count_slices as usize)
396        .saturating_sub(count_simple_slices_post as usize);
397
398    // If we have simple non-strided slices, we also attach an index for that
399    if index_dims == 0 {
400        index_dims = 1;
401    }
402
403    // Go over each index type and translate to the needed scatter args
404    let mut array_indices: SmallVec<[Array; DEFAULT_STACK_VEC_LEN]> =
405        SmallVec::with_capacity(operations.len());
406    let mut slice_number: i32 = 0;
407    let mut array_number: i32 = 0;
408    let mut axis: i32 = 0;
409
410    // We collect the shapes of the slices and updates during this process
411    let mut update_shape = vec![1; non_new_axis_operation_count];
412    let mut slice_shapes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
413
414    macro_rules! update_shapes_take_array {
415        ($indices:ident) => {
416            // Place the arrays in the correct dimension
417            let start = if arrays_first {
418                max_dims - $indices.ndim()
419            } else {
420                // SAFETY: slice_number is never decremented and should be non-negative
421                slice_number as usize + max_dims - $indices.ndim()
422            };
423            let mut new_shape = vec![1; index_dims];
424
425            for j in 0..$indices.ndim() {
426                new_shape[start + j] = $indices.dim(j as i32);
427            }
428
429            array_indices.push($indices.reshape_device(&new_shape, &stream)?);
430            array_number = array_number.saturating_add(1);
431
432            if !arrays_first && array_number == count_arrays {
433                slice_number = slice_number.saturating_add_unsigned(max_dims as u32);
434            }
435
436            // Add the shape to the update
437            update_shape[axis as usize] = 1;
438            axis = axis.saturating_add(1);
439        };
440    }
441
442    for item in operations.iter() {
443        match item {
444            TakeIndex { index } => {
445                let resolved_index = resolve_index_signed_unchecked(*index, src.dim(axis));
446                array_indices.push(Array::from_int(resolved_index));
447                // SAFETY: axis is always non-negative
448                update_shape[axis as usize] = 1;
449                axis = axis.saturating_add(1);
450            }
451            Slice(range_index) => {
452                let size = src.dim(axis);
453                let start = range_index.absolute_start(size);
454                let end = range_index.absolute_end(size);
455                let stride = range_index.stride();
456
457                let mut index_shape = vec![1; index_dims];
458
459                // If it's a simple slice, we only need to add the start index
460                if array_number >= count_arrays && count_strided_slices <= 0 && stride == 1 {
461                    let index = Array::from_int(start).reshape_device(&index_shape, &stream)?;
462                    let slice_shape_entry = end - start;
463                    slice_shapes.push(slice_shape_entry);
464                    array_indices.push(index);
465
466                    // Add the shape to the update
467                    update_shape[axis as usize] = slice_shape_entry;
468                } else {
469                    // Otherwise we expand the slice into indices using arange
470                    let index_vals = strided_range_to_vec(start, end, stride);
471                    let index = Array::from_slice(&index_vals, &[index_vals.len() as i32]);
472                    let location = if arrays_first {
473                        slice_number.saturating_add(max_dims as i32)
474                    } else {
475                        slice_number
476                    };
477                    index_shape[location as usize] = index.size() as i32;
478                    array_indices.push(index.reshape_device(&index_shape, &stream)?);
479
480                    slice_number = slice_number.saturating_add(1);
481                    count_strided_slices = count_strided_slices.saturating_sub(1);
482
483                    // Add the shape to the update
484                    update_shape[axis as usize] = 1;
485                }
486
487                axis = axis.saturating_add(1);
488            }
489            TakeArray { indices } => {
490                update_shapes_take_array!(indices);
491            }
492            TakeArrayRef { indices } => {
493                update_shapes_take_array!(indices);
494            }
495            ExpandDims => slice_number = slice_number.saturating_add(1),
496            Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
497        }
498    }
499
500    // Broadcast the update to the indices and slices
501    let array_indices = broadcast_arrays_device(&array_indices, &stream)?;
502    let update_shape_broadcast: Vec<_> = array_indices[0]
503        .shape()
504        .iter()
505        .chain(slice_shapes.iter())
506        .chain(src.shape().iter().skip(non_new_axis_operation_count))
507        .cloned()
508        .collect();
509    let update = broadcast_to_device(&update, &update_shape_broadcast, &stream)?;
510
511    // Reshape the update with the size-1 dims for the int and array indices
512    let update_reshape: Vec<_> = array_indices[0]
513        .shape()
514        .iter()
515        .chain(update_shape.iter())
516        .chain(src.shape().iter().skip(non_new_axis_operation_count))
517        .cloned()
518        .collect();
519
520    let update = update.reshape_device(&update_reshape, &stream)?;
521
522    let array_indices_len = array_indices.len();
523
524    let indices = array_indices.into_iter().map(Cow::Owned).collect();
525    Ok(ScatterArgs {
526        indices,
527        update,
528        axes: (0..array_indices_len as i32).collect(),
529    })
530}
531
/// Materialises the half-open strided range from `start` towards `exclusive_end` in steps of
/// `stride`, like Python's `range(start, exclusive_end, stride)`.
///
/// A positive `stride` counts up while the value is `< exclusive_end`; a negative `stride`
/// counts down while the value is `> exclusive_end`. A range that is empty in the direction
/// of travel yields an empty vector.
///
/// # Panics
///
/// Panics with "Stride cannot be zero" if `stride` is zero.
fn strided_range_to_vec(start: i32, exclusive_end: i32, stride: i32) -> Vec<i32> {
    // Check before any division by `stride`.
    assert_ne!(stride, 0, "Stride cannot be zero");

    // Work in i64 so neither the span nor stepping past `exclusive_end` can overflow i32.
    let start = i64::from(start);
    let end = i64::from(exclusive_end);
    let step = i64::from(stride);
    let step_abs = step.abs();

    let span = if step > 0 { end - start } else { start - end };
    // Ceiling division: the number of values strictly before the end.
    let len = if span > 0 {
        (span + step_abs - 1) / step_abs
    } else {
        0
    };

    // Every produced value lies between `start` and `exclusive_end`, so it fits in an i32.
    (0..len).map(|i| (start + i * step) as i32).collect()
}

554unsafe fn scatter_device(
555    a: &Array,
556    indices: &[impl AsRef<Array>],
557    updates: &Array,
558    axes: &[i32],
559    stream: impl AsRef<Stream>,
560) -> Result<Array> {
561    let indices_vector = VectorArray::try_from_iter(indices.iter())?;
562
563    Array::try_from_op(|res| unsafe {
564        mlx_sys::mlx_scatter(
565            res,
566            a.as_ptr(),
567            indices_vector.as_ptr(),
568            updates.as_ptr(),
569            axes.as_ptr(),
570            axes.len(),
571            stream.as_ref().as_ptr(),
572        )
573    })
574}
575
576impl Array {
577    fn try_index_mut_device_inner(
578        &mut self,
579        operations: &[ArrayIndexOp],
580        update: &Array,
581        stream: impl AsRef<Stream>,
582    ) -> Result<()> {
583        if let Some(result) = update_slice(self, operations, update, &stream)? {
584            *self = result;
585            return Ok(());
586        }
587
588        let ScatterArgs {
589            indices,
590            update,
591            axes,
592        } = scatter_args(self, operations, update, &stream)?;
593        if !indices.is_empty() {
594            let result = unsafe { scatter_device(self, &indices, &update, &axes, stream)? };
595            drop(indices);
596            *self = result;
597        } else {
598            drop(indices);
599            *self = update;
600        }
601        Ok(())
602    }
603}
604
605impl<'a, Val> TryIndexMutOp<&'a [ArrayIndexOp<'a>], Val> for Array
606where
607    Val: AsRef<Array>,
608{
609    fn try_index_mut_device(
610        &mut self,
611        i: &'a [ArrayIndexOp<'a>],
612        val: Val,
613        stream: impl AsRef<Stream>,
614    ) -> Result<()> {
615        let update = val.as_ref();
616        self.try_index_mut_device_inner(i, update, stream)
617    }
618}
619
620impl<A, Val> TryIndexMutOp<A, Val> for Array
621where
622    for<'a> A: ArrayIndex<'a>,
623    Val: AsRef<Array>,
624{
625    fn try_index_mut_device(&mut self, i: A, val: Val, stream: impl AsRef<Stream>) -> Result<()> {
626        let operations = [i.index_op()];
627        let update = val.as_ref();
628        self.try_index_mut_device_inner(&operations, update, stream)
629    }
630}
631
632impl<'a, A, Val> TryIndexMutOp<(A,), Val> for Array
633where
634    A: ArrayIndex<'a>,
635    Val: AsRef<Array>,
636{
637    fn try_index_mut_device(
638        &mut self,
639        (i,): (A,),
640        val: Val,
641        stream: impl AsRef<Stream>,
642    ) -> Result<()> {
643        let operations = [i.index_op()];
644        let update = val.as_ref();
645        self.try_index_mut_device_inner(&operations, update, stream)
646    }
647}
648
649impl<'a, 'b, A, B, Val> TryIndexMutOp<(A, B), Val> for Array
650where
651    A: ArrayIndex<'a>,
652    B: ArrayIndex<'b>,
653    Val: AsRef<Array>,
654{
655    fn try_index_mut_device(
656        &mut self,
657        i: (A, B),
658        val: Val,
659        stream: impl AsRef<Stream>,
660    ) -> Result<()> {
661        let operations = [i.0.index_op(), i.1.index_op()];
662        let update = val.as_ref();
663        self.try_index_mut_device_inner(&operations, update, stream)
664    }
665}
666
667impl<'a, 'b, 'c, A, B, C, Val> TryIndexMutOp<(A, B, C), Val> for Array
668where
669    A: ArrayIndex<'a>,
670    B: ArrayIndex<'b>,
671    C: ArrayIndex<'c>,
672    Val: AsRef<Array>,
673{
674    fn try_index_mut_device(
675        &mut self,
676        i: (A, B, C),
677        val: Val,
678        stream: impl AsRef<Stream>,
679    ) -> Result<()> {
680        let operations = [i.0.index_op(), i.1.index_op(), i.2.index_op()];
681        let update = val.as_ref();
682        self.try_index_mut_device_inner(&operations, update, stream)
683    }
684}
685
686impl<'a, 'b, 'c, 'd, A, B, C, D, Val> TryIndexMutOp<(A, B, C, D), Val> for Array
687where
688    A: ArrayIndex<'a>,
689    B: ArrayIndex<'b>,
690    C: ArrayIndex<'c>,
691    D: ArrayIndex<'d>,
692    Val: AsRef<Array>,
693{
694    fn try_index_mut_device(
695        &mut self,
696        i: (A, B, C, D),
697        val: Val,
698        stream: impl AsRef<Stream>,
699    ) -> Result<()> {
700        let operations = [
701            i.0.index_op(),
702            i.1.index_op(),
703            i.2.index_op(),
704            i.3.index_op(),
705        ];
706        let update = val.as_ref();
707        self.try_index_mut_device_inner(&operations, update, stream)
708    }
709}
710
711impl<'a, 'b, 'c, 'd, 'e, A, B, C, D, E, Val> TryIndexMutOp<(A, B, C, D, E), Val> for Array
712where
713    A: ArrayIndex<'a>,
714    B: ArrayIndex<'b>,
715    C: ArrayIndex<'c>,
716    D: ArrayIndex<'d>,
717    E: ArrayIndex<'e>,
718    Val: AsRef<Array>,
719{
720    fn try_index_mut_device(
721        &mut self,
722        i: (A, B, C, D, E),
723        val: Val,
724        stream: impl AsRef<Stream>,
725    ) -> Result<()> {
726        let operations = [
727            i.0.index_op(),
728            i.1.index_op(),
729            i.2.index_op(),
730            i.3.index_op(),
731            i.4.index_op(),
732        ];
733        let update = val.as_ref();
734        self.try_index_mut_device_inner(&operations, update, stream)
735    }
736}
737
738impl<'a, 'b, 'c, 'd, 'e, 'f, A, B, C, D, E, F, Val> TryIndexMutOp<(A, B, C, D, E, F), Val> for Array
739where
740    A: ArrayIndex<'a>,
741    B: ArrayIndex<'b>,
742    C: ArrayIndex<'c>,
743    D: ArrayIndex<'d>,
744    E: ArrayIndex<'e>,
745    F: ArrayIndex<'f>,
746    Val: AsRef<Array>,
747{
748    fn try_index_mut_device(
749        &mut self,
750        i: (A, B, C, D, E, F),
751        val: Val,
752        stream: impl AsRef<Stream>,
753    ) -> Result<()> {
754        let operations = [
755            i.0.index_op(),
756            i.1.index_op(),
757            i.2.index_op(),
758            i.3.index_op(),
759            i.4.index_op(),
760            i.5.index_op(),
761        ];
762        let update = val.as_ref();
763        self.try_index_mut_device_inner(&operations, update, stream)
764    }
765}
766
767impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, A, B, C, D, E, F, G, Val> TryIndexMutOp<(A, B, C, D, E, F, G), Val>
768    for Array
769where
770    A: ArrayIndex<'a>,
771    B: ArrayIndex<'b>,
772    C: ArrayIndex<'c>,
773    D: ArrayIndex<'d>,
774    E: ArrayIndex<'e>,
775    F: ArrayIndex<'f>,
776    G: ArrayIndex<'g>,
777    Val: AsRef<Array>,
778{
779    fn try_index_mut_device(
780        &mut self,
781        i: (A, B, C, D, E, F, G),
782        val: Val,
783        stream: impl AsRef<Stream>,
784    ) -> Result<()> {
785        let operations = [
786            i.0.index_op(),
787            i.1.index_op(),
788            i.2.index_op(),
789            i.3.index_op(),
790            i.4.index_op(),
791            i.5.index_op(),
792            i.6.index_op(),
793        ];
794        let update = val.as_ref();
795        self.try_index_mut_device_inner(&operations, update, stream)
796    }
797}
798
799impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, A, B, C, D, E, F, G, H, Val>
800    TryIndexMutOp<(A, B, C, D, E, F, G, H), Val> for Array
801where
802    A: ArrayIndex<'a>,
803    B: ArrayIndex<'b>,
804    C: ArrayIndex<'c>,
805    D: ArrayIndex<'d>,
806    E: ArrayIndex<'e>,
807    F: ArrayIndex<'f>,
808    G: ArrayIndex<'g>,
809    H: ArrayIndex<'h>,
810    Val: AsRef<Array>,
811{
812    fn try_index_mut_device(
813        &mut self,
814        i: (A, B, C, D, E, F, G, H),
815        val: Val,
816        stream: impl AsRef<Stream>,
817    ) -> Result<()> {
818        let operations = [
819            i.0.index_op(),
820            i.1.index_op(),
821            i.2.index_op(),
822            i.3.index_op(),
823            i.4.index_op(),
824            i.5.index_op(),
825            i.6.index_op(),
826            i.7.index_op(),
827        ];
828        let update = val.as_ref();
829        self.try_index_mut_device_inner(&operations, update, stream)
830    }
831}
832
833impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, A, B, C, D, E, F, G, H, I, Val>
834    TryIndexMutOp<(A, B, C, D, E, F, G, H, I), Val> for Array
835where
836    A: ArrayIndex<'a>,
837    B: ArrayIndex<'b>,
838    C: ArrayIndex<'c>,
839    D: ArrayIndex<'d>,
840    E: ArrayIndex<'e>,
841    F: ArrayIndex<'f>,
842    G: ArrayIndex<'g>,
843    H: ArrayIndex<'h>,
844    I: ArrayIndex<'i>,
845    Val: AsRef<Array>,
846{
847    fn try_index_mut_device(
848        &mut self,
849        i: (A, B, C, D, E, F, G, H, I),
850        val: Val,
851        stream: impl AsRef<Stream>,
852    ) -> Result<()> {
853        let operations = [
854            i.0.index_op(),
855            i.1.index_op(),
856            i.2.index_op(),
857            i.3.index_op(),
858            i.4.index_op(),
859            i.5.index_op(),
860            i.6.index_op(),
861            i.7.index_op(),
862            i.8.index_op(),
863        ];
864        let update = val.as_ref();
865        self.try_index_mut_device_inner(&operations, update, stream)
866    }
867}
868
869impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, A, B, C, D, E, F, G, H, I, J, Val>
870    TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J), Val> for Array
871where
872    A: ArrayIndex<'a>,
873    B: ArrayIndex<'b>,
874    C: ArrayIndex<'c>,
875    D: ArrayIndex<'d>,
876    E: ArrayIndex<'e>,
877    F: ArrayIndex<'f>,
878    G: ArrayIndex<'g>,
879    H: ArrayIndex<'h>,
880    I: ArrayIndex<'i>,
881    J: ArrayIndex<'j>,
882    Val: AsRef<Array>,
883{
884    fn try_index_mut_device(
885        &mut self,
886        i: (A, B, C, D, E, F, G, H, I, J),
887        val: Val,
888        stream: impl AsRef<Stream>,
889    ) -> Result<()> {
890        let operations = [
891            i.0.index_op(),
892            i.1.index_op(),
893            i.2.index_op(),
894            i.3.index_op(),
895            i.4.index_op(),
896            i.5.index_op(),
897            i.6.index_op(),
898            i.7.index_op(),
899            i.8.index_op(),
900            i.9.index_op(),
901        ];
902        let update = val.as_ref();
903        self.try_index_mut_device_inner(&operations, update, stream)
904    }
905}
906
907impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, A, B, C, D, E, F, G, H, I, J, K, Val>
908    TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K), Val> for Array
909where
910    A: ArrayIndex<'a>,
911    B: ArrayIndex<'b>,
912    C: ArrayIndex<'c>,
913    D: ArrayIndex<'d>,
914    E: ArrayIndex<'e>,
915    F: ArrayIndex<'f>,
916    G: ArrayIndex<'g>,
917    H: ArrayIndex<'h>,
918    I: ArrayIndex<'i>,
919    J: ArrayIndex<'j>,
920    K: ArrayIndex<'k>,
921    Val: AsRef<Array>,
922{
923    fn try_index_mut_device(
924        &mut self,
925        i: (A, B, C, D, E, F, G, H, I, J, K),
926        val: Val,
927        stream: impl AsRef<Stream>,
928    ) -> Result<()> {
929        let operations = [
930            i.0.index_op(),
931            i.1.index_op(),
932            i.2.index_op(),
933            i.3.index_op(),
934            i.4.index_op(),
935            i.5.index_op(),
936            i.6.index_op(),
937            i.7.index_op(),
938            i.8.index_op(),
939            i.9.index_op(),
940            i.10.index_op(),
941        ];
942        let update = val.as_ref();
943        self.try_index_mut_device_inner(&operations, update, stream)
944    }
945}
946
947impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, A, B, C, D, E, F, G, H, I, J, K, L, Val>
948    TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L), Val> for Array
949where
950    A: ArrayIndex<'a>,
951    B: ArrayIndex<'b>,
952    C: ArrayIndex<'c>,
953    D: ArrayIndex<'d>,
954    E: ArrayIndex<'e>,
955    F: ArrayIndex<'f>,
956    G: ArrayIndex<'g>,
957    H: ArrayIndex<'h>,
958    I: ArrayIndex<'i>,
959    J: ArrayIndex<'j>,
960    K: ArrayIndex<'k>,
961    L: ArrayIndex<'l>,
962    Val: AsRef<Array>,
963{
964    fn try_index_mut_device(
965        &mut self,
966        i: (A, B, C, D, E, F, G, H, I, J, K, L),
967        val: Val,
968        stream: impl AsRef<Stream>,
969    ) -> Result<()> {
970        let operations = [
971            i.0.index_op(),
972            i.1.index_op(),
973            i.2.index_op(),
974            i.3.index_op(),
975            i.4.index_op(),
976            i.5.index_op(),
977            i.6.index_op(),
978            i.7.index_op(),
979            i.8.index_op(),
980            i.9.index_op(),
981            i.10.index_op(),
982            i.11.index_op(),
983        ];
984        let update = val.as_ref();
985        self.try_index_mut_device_inner(&operations, update, stream)
986    }
987}
988
989impl<
990        'a,
991        'b,
992        'c,
993        'd,
994        'e,
995        'f,
996        'g,
997        'h,
998        'i,
999        'j,
1000        'k,
1001        'l,
1002        'm,
1003        A,
1004        B,
1005        C,
1006        D,
1007        E,
1008        F,
1009        G,
1010        H,
1011        I,
1012        J,
1013        K,
1014        L,
1015        M,
1016        Val,
1017    > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M), Val> for Array
1018where
1019    A: ArrayIndex<'a>,
1020    B: ArrayIndex<'b>,
1021    C: ArrayIndex<'c>,
1022    D: ArrayIndex<'d>,
1023    E: ArrayIndex<'e>,
1024    F: ArrayIndex<'f>,
1025    G: ArrayIndex<'g>,
1026    H: ArrayIndex<'h>,
1027    I: ArrayIndex<'i>,
1028    J: ArrayIndex<'j>,
1029    K: ArrayIndex<'k>,
1030    L: ArrayIndex<'l>,
1031    M: ArrayIndex<'m>,
1032    Val: AsRef<Array>,
1033{
1034    fn try_index_mut_device(
1035        &mut self,
1036        i: (A, B, C, D, E, F, G, H, I, J, K, L, M),
1037        val: Val,
1038        stream: impl AsRef<Stream>,
1039    ) -> Result<()> {
1040        let operations = [
1041            i.0.index_op(),
1042            i.1.index_op(),
1043            i.2.index_op(),
1044            i.3.index_op(),
1045            i.4.index_op(),
1046            i.5.index_op(),
1047            i.6.index_op(),
1048            i.7.index_op(),
1049            i.8.index_op(),
1050            i.9.index_op(),
1051            i.10.index_op(),
1052            i.11.index_op(),
1053            i.12.index_op(),
1054        ];
1055        let update = val.as_ref();
1056        self.try_index_mut_device_inner(&operations, update, stream)
1057    }
1058}
1059
1060impl<
1061        'a,
1062        'b,
1063        'c,
1064        'd,
1065        'e,
1066        'f,
1067        'g,
1068        'h,
1069        'i,
1070        'j,
1071        'k,
1072        'l,
1073        'm,
1074        'n,
1075        A,
1076        B,
1077        C,
1078        D,
1079        E,
1080        F,
1081        G,
1082        H,
1083        I,
1084        J,
1085        K,
1086        L,
1087        M,
1088        N,
1089        Val,
1090    > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N), Val> for Array
1091where
1092    A: ArrayIndex<'a>,
1093    B: ArrayIndex<'b>,
1094    C: ArrayIndex<'c>,
1095    D: ArrayIndex<'d>,
1096    E: ArrayIndex<'e>,
1097    F: ArrayIndex<'f>,
1098    G: ArrayIndex<'g>,
1099    H: ArrayIndex<'h>,
1100    I: ArrayIndex<'i>,
1101    J: ArrayIndex<'j>,
1102    K: ArrayIndex<'k>,
1103    L: ArrayIndex<'l>,
1104    M: ArrayIndex<'m>,
1105    N: ArrayIndex<'n>,
1106    Val: AsRef<Array>,
1107{
1108    fn try_index_mut_device(
1109        &mut self,
1110        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N),
1111        val: Val,
1112        stream: impl AsRef<Stream>,
1113    ) -> Result<()> {
1114        let operations = [
1115            i.0.index_op(),
1116            i.1.index_op(),
1117            i.2.index_op(),
1118            i.3.index_op(),
1119            i.4.index_op(),
1120            i.5.index_op(),
1121            i.6.index_op(),
1122            i.7.index_op(),
1123            i.8.index_op(),
1124            i.9.index_op(),
1125            i.10.index_op(),
1126            i.11.index_op(),
1127            i.12.index_op(),
1128            i.13.index_op(),
1129        ];
1130        let update = val.as_ref();
1131        self.try_index_mut_device_inner(&operations, update, stream)
1132    }
1133}
1134
1135impl<
1136        'a,
1137        'b,
1138        'c,
1139        'd,
1140        'e,
1141        'f,
1142        'g,
1143        'h,
1144        'i,
1145        'j,
1146        'k,
1147        'l,
1148        'm,
1149        'n,
1150        'o,
1151        A,
1152        B,
1153        C,
1154        D,
1155        E,
1156        F,
1157        G,
1158        H,
1159        I,
1160        J,
1161        K,
1162        L,
1163        M,
1164        N,
1165        O,
1166        Val,
1167    > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O), Val> for Array
1168where
1169    A: ArrayIndex<'a>,
1170    B: ArrayIndex<'b>,
1171    C: ArrayIndex<'c>,
1172    D: ArrayIndex<'d>,
1173    E: ArrayIndex<'e>,
1174    F: ArrayIndex<'f>,
1175    G: ArrayIndex<'g>,
1176    H: ArrayIndex<'h>,
1177    I: ArrayIndex<'i>,
1178    J: ArrayIndex<'j>,
1179    K: ArrayIndex<'k>,
1180    L: ArrayIndex<'l>,
1181    M: ArrayIndex<'m>,
1182    N: ArrayIndex<'n>,
1183    O: ArrayIndex<'o>,
1184    Val: AsRef<Array>,
1185{
1186    fn try_index_mut_device(
1187        &mut self,
1188        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O),
1189        val: Val,
1190        stream: impl AsRef<Stream>,
1191    ) -> Result<()> {
1192        let operations = [
1193            i.0.index_op(),
1194            i.1.index_op(),
1195            i.2.index_op(),
1196            i.3.index_op(),
1197            i.4.index_op(),
1198            i.5.index_op(),
1199            i.6.index_op(),
1200            i.7.index_op(),
1201            i.8.index_op(),
1202            i.9.index_op(),
1203            i.10.index_op(),
1204            i.11.index_op(),
1205            i.12.index_op(),
1206            i.13.index_op(),
1207            i.14.index_op(),
1208        ];
1209        let update = val.as_ref();
1210        self.try_index_mut_device_inner(&operations, update, stream)
1211    }
1212}
1213
1214impl<
1215        'a,
1216        'b,
1217        'c,
1218        'd,
1219        'e,
1220        'f,
1221        'g,
1222        'h,
1223        'i,
1224        'j,
1225        'k,
1226        'l,
1227        'm,
1228        'n,
1229        'o,
1230        'p,
1231        A,
1232        B,
1233        C,
1234        D,
1235        E,
1236        F,
1237        G,
1238        H,
1239        I,
1240        J,
1241        K,
1242        L,
1243        M,
1244        N,
1245        O,
1246        P,
1247        Val,
1248    > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P), Val> for Array
1249where
1250    A: ArrayIndex<'a>,
1251    B: ArrayIndex<'b>,
1252    C: ArrayIndex<'c>,
1253    D: ArrayIndex<'d>,
1254    E: ArrayIndex<'e>,
1255    F: ArrayIndex<'f>,
1256    G: ArrayIndex<'g>,
1257    H: ArrayIndex<'h>,
1258    I: ArrayIndex<'i>,
1259    J: ArrayIndex<'j>,
1260    K: ArrayIndex<'k>,
1261    L: ArrayIndex<'l>,
1262    M: ArrayIndex<'m>,
1263    N: ArrayIndex<'n>,
1264    O: ArrayIndex<'o>,
1265    P: ArrayIndex<'p>,
1266    Val: AsRef<Array>,
1267{
1268    fn try_index_mut_device(
1269        &mut self,
1270        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P),
1271        val: Val,
1272        stream: impl AsRef<Stream>,
1273    ) -> Result<()> {
1274        let operations = [
1275            i.0.index_op(),
1276            i.1.index_op(),
1277            i.2.index_op(),
1278            i.3.index_op(),
1279            i.4.index_op(),
1280            i.5.index_op(),
1281            i.6.index_op(),
1282            i.7.index_op(),
1283            i.8.index_op(),
1284            i.9.index_op(),
1285            i.10.index_op(),
1286            i.11.index_op(),
1287            i.12.index_op(),
1288            i.13.index_op(),
1289            i.14.index_op(),
1290            i.15.index_op(),
1291        ];
1292        let update = val.as_ref();
1293        self.try_index_mut_device_inner(&operations, update, stream)
1294    }
1295}
1296
/// The unit tests below are adapted from the Swift binding tests
#[cfg(test)]
mod tests {
    use crate::{
        ops::{indexing::*, ones, zeros},
        Array,
    };

    // Writing a scalar through a single integer index fills the whole row `a[1]`.
    #[test]
    fn test_array_mutate_single_index() {
        let mut a = Array::from_iter(0i32..12, &[3, 4]);
        let new_value = Array::from_int(77);
        a.index_mut(1, new_value);

        let expected = Array::from_slice(&[0, 1, 2, 3, 77, 77, 77, 77, 8, 9, 10, 11], &[3, 4]);
        assert_array_all_close!(a, expected);
    }

    // Multi-integer indices select a row or a single element. The three writes are
    // applied in sequence, and `expected` reflects all of them.
    #[test]
    fn test_array_mutate_broadcast_multi_index() {
        let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);

        // broadcast to a row
        a.index_mut((1, 0), Array::from_int(77));

        // assign to a row
        a.index_mut((0, 0), Array::from_slice(&[55i32, 66, 77, 88, 99], &[5]));

        // single element
        a.index_mut((0, 1, 3), Array::from_int(123));

        let expected = Array::from_slice(
            &[
                55, 66, 77, 88, 99, 5, 6, 7, 123, 9, 77, 77, 77, 77, 77, 15, 16, 17, 18, 19,
            ],
            &[2, 2, 5],
        );
        assert_array_all_close!(a, expected);
    }

    // A scalar is broadcast over a region selected purely by ranges.
    #[test]
    fn test_array_mutate_broadcast_slice() {
        let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);

        // writing using slices -- this ends up covering two elements
        a.index_mut((0..1, 1..2, 2..4), Array::from_int(88));

        let expected = Array::from_slice(
            &[
                0, 1, 2, 3, 4, 5, 6, 88, 88, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
            ],
            &[2, 2, 5],
        );
        assert_array_all_close!(a, expected);
    }

    // Two index arrays are paired element-wise: positions (0, 0), (2, 1) and (4, 2)
    // receive 100, 200 and 300 respectively.
    #[test]
    fn test_array_mutate_advanced() {
        let mut a = Array::from_iter(0i32..35, &[5, 7]);

        let i1 = Array::from_slice(&[0, 2, 4], &[3]);
        let i2 = Array::from_slice(&[0, 1, 2], &[3]);

        a.index_mut((i1, i2), Array::from_slice(&[100, 200, 300], &[3]));

        assert_eq!(a.index((0, 0)).item::<i32>(), 100i32);
        assert_eq!(a.index((2, 1)).item::<i32>(), 200i32);
        assert_eq!(a.index((4, 2)).item::<i32>(), 300i32);
    }

    // Single-index writes of the scalar 1 into `0..60` reshaped to [3, 4, 5].
    // The untouched array sums to 1770, so every expected value equals
    // 1770 - (sum of the overwritten elements) + (number of overwritten elements).
    #[test]
    fn test_full_index_write_single() {
        fn check<I>(index: I, expected_sum: i32)
        where
            for<'a> I: ArrayIndex<'a>,
        {
            let mut a = Array::from_iter(0..60, &[3, 4, 5]);

            a.index_mut(index, Array::from_int(1));
            let sum = a.sum(None).unwrap().item::<i32>();
            assert_eq!(sum, expected_sum);
        }

        // a[...]
        // not valid

        // a[None]
        check(NewAxis, 60);

        // a[0]
        check(0, 1600);

        // a[1:3]
        check(1..3, 230);

        // i = mx.array([2, 1])
        let i = Array::from_slice(&[2, 1], &[2]);

        // a[i]
        check(i, 230);
    }

    // Tuple-index writes without any index arrays, on `0..360` reshaped to
    // [2, 3, 4, 5, 3]. The untouched array sums to 64620; the expected sums follow
    // the same "total - overwritten sum + overwritten count" rule as above.
    #[test]
    fn test_full_index_write_no_array() {
        macro_rules! check {
            (($( $i:expr ),*), $sum:expr ) => {
                {
                    let mut a = Array::from_iter(0..360, &[2, 3, 4, 5, 3]);

                    a.index_mut(($($i),*), Array::from_int(1));
                    let sum = a.sum(None).unwrap().item::<i32>();
                    assert_eq!(sum, $sum);
                }
            };
        }

        // a[..., 0] = 1
        check!((Ellipsis, 0), 43320);

        // a[0, ...] = 1
        check!((0, Ellipsis), 48690);

        // a[0, ..., 0] = 1
        check!((0, Ellipsis, 0), 59370);

        // a[..., ::2, :] = 1
        check!((Ellipsis, (..).stride_by(2), ..), 26064);

        // a[..., None, ::2, -1] = 1
        check!((Ellipsis, NewAxis, (..).stride_by(2), -1), 51696);

        // a[:, 2:, 0] = 1
        check!((.., 2.., 0), 58140);

        // a[::-1, :2, 2:, ..., None, ::2] = 1
        check!(
            (
                (..).stride_by(-1),
                ..2,
                2..,
                Ellipsis,
                NewAxis,
                (..).stride_by(2)
            ),
            51540
        );
    }

    // Tuple-index writes that mix in the index array `i`, on `0..540` reshaped to
    // [3, 3, 4, 5, 3]. The untouched array sums to 145530; the expected sums follow
    // the same "total - overwritten sum + overwritten count" rule as above.
    #[test]
    fn test_full_index_write_array() {
        // these have an Array as a source of indices and go through the gather path

        macro_rules! check {
            (($( $i:expr ),*), $sum:expr ) => {
                {
                    let mut a = Array::from_iter(0..540, &[3, 3, 4, 5, 3]);

                    a.index_mut(($($i),*), Array::from_int(1));
                    let sum = a.sum(None).unwrap().item::<i32>();
                    assert_eq!(sum, $sum);
                }
            };
        }

        // i = mx.array([2, 1])
        let i = Array::from_slice(&[2, 1], &[2]);

        // a[0, i] = 1
        check!((0, &i), 131310);

        // a[..., i, 0] = 1
        check!((Ellipsis, &i, 0), 126378);

        // a[i, 0, ...] = 1
        check!((&i, 0, Ellipsis), 109710);

        // a[i, ..., i] = 1
        check!((&i, Ellipsis, &i), 102450);

        // a[i, ..., ::2, :] = 1
        check!((&i, Ellipsis, (..).stride_by(2), ..), 68094);

        // a[..., i, None, ::2, -1] = 1
        check!((Ellipsis, &i, NewAxis, (..).stride_by(2), -1), 130977);

        // a[:, 2:, i] = 1
        check!((.., 2.., &i), 115965);

        // a[::-1, :2, i, 2:, ..., None, ::2] = 1
        // (`i` is passed by value here because this is its last use.)
        check!(
            (
                (..).stride_by(-1),
                ..2,
                i,
                2..,
                Ellipsis,
                NewAxis,
                (..).stride_by(2)
            ),
            128142
        );
    }

    // Writing a [4, 2] update into the [4, 2] view `xs[:, 0, :]` of a [4, 3, 2]
    // array must succeed rather than return an error.
    #[test]
    fn test_slice_update_with_broadcast() {
        let mut xs = zeros::<f32>(&[4, 3, 2]).unwrap();
        let x = ones::<f32>(&[4, 2]).unwrap();

        let result = xs.try_index_mut((.., 0, ..), x);
        assert!(
            result.is_ok(),
            "Failed to update slice with broadcast: {result:?}"
        );
    }
}