mlx_rs/ops/indexing/indexmut_impl.rs

1use std::borrow::Cow;
2
3use smallvec::{smallvec, SmallVec};
4
5use crate::{
6    constants::DEFAULT_STACK_VEC_LEN,
7    error::Result,
8    ops::{
9        broadcast_arrays_device, broadcast_to_device,
10        indexing::{count_non_new_axis_operations, expand_ellipsis_operations},
11    },
12    utils::{resolve_index_signed_unchecked, VectorArray},
13    Array, Stream,
14};
15
16use super::{ArrayIndex, ArrayIndexOp, Guarded, RangeIndex, TryIndexMutOp};
17
18impl Array {
19    pub(crate) fn slice_update_device(
20        &self,
21        update: &Array,
22        starts: &[i32],
23        ends: &[i32],
24        strides: &[i32],
25        stream: impl AsRef<Stream>,
26    ) -> Result<Array> {
27        Array::try_from_op(|res| unsafe {
28            mlx_sys::mlx_slice_update(
29                res,
30                self.as_ptr(),
31                update.as_ptr(),
32                starts.as_ptr(),
33                starts.len(),
34                ends.as_ptr(),
35                ends.len(),
36                strides.as_ptr(),
37                strides.len(),
38                stream.as_ref().as_ptr(),
39            )
40        })
41    }
42}
43
44// See `updateSlice` in the swift binding or `mlx_slice_update` in the python binding
45fn update_slice(
46    src: &Array,
47    operations: &[ArrayIndexOp],
48    update: &Array,
49    stream: impl AsRef<Stream>,
50) -> Result<Option<Array>> {
51    let ndim = src.ndim();
52    if ndim == 0 || operations.is_empty() {
53        return Ok(None);
54    }
55
56    // Remove leading singletons dimensions from the update
57    let mut update = remove_leading_singleton_dimensions(update, &stream)?;
58
59    // Build slice update params
60    let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
61    let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
62    let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];
63
64    // If it's just a simple slice, just do a slice update and return
65    if operations.len() == 1 {
66        if let ArrayIndexOp::Slice(range_index) = &operations[0] {
67            let size = src.dim(0);
68            starts[0] = range_index.start(size);
69            ends[0] = range_index.end(size);
70            strides[0] = range_index.stride();
71
72            return Ok(Some(src.slice_update_device(
73                &update, &starts, &ends, &strides, &stream,
74            )?));
75        }
76    }
77
78    // Can't route to slice update if any arrays are present
79    if operations.iter().any(|op| op.is_array()) {
80        return Ok(None);
81    }
82
83    // Expand ellipses into a series of ':' (range full) slices
84    let operations = expand_ellipsis_operations(ndim, operations);
85
86    // If no non-None indices return the broadcasted update
87    if count_non_new_axis_operations(&operations) == 0 {
88        return Ok(Some(broadcast_to_device(&update, src.shape(), &stream)?));
89    }
90
91    // Process entries
92    let mut update_expand_dims: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
93    let mut axis = 0i32;
94    for item in operations.iter() {
95        use ArrayIndexOp::*;
96
97        match item {
98            TakeIndex { index } => {
99                let size = src.dim(axis);
100                let index = if index.is_negative() {
101                    size + index
102                } else {
103                    *index
104                };
105                // SAFETY: axis is always non-negative
106                starts[axis as usize] = index;
107                ends[axis as usize] = index + 1;
108                if ndim - (axis as usize) < update.ndim() {
109                    update_expand_dims.push(axis.saturating_sub_unsigned(ndim as u32));
110                }
111
112                axis = axis.saturating_add(1);
113            }
114            Slice(range_index) => {
115                let size = src.dim(axis);
116                // SAFETY: axis is always non-negative
117                starts[axis as usize] = range_index.start(size);
118                ends[axis as usize] = range_index.end(size);
119                strides[axis as usize] = range_index.stride();
120                axis = axis.saturating_add(1);
121            }
122            ExpandDims => {}
123            Ellipsis | TakeArray { indices: _ } | TakeArrayRef { indices: _ } => {
124                panic!("unexpected item in operations")
125            }
126        }
127    }
128
129    if !update_expand_dims.is_empty() {
130        update = Cow::Owned(update.expand_dims_device(&update_expand_dims, &stream)?);
131    }
132
133    Ok(Some(src.slice_update_device(
134        &update, &starts, &ends, &strides, &stream,
135    )?))
136}
137
138// See `leadingSingletonDimensionsRemoved` in the swift binding
139fn remove_leading_singleton_dimensions(
140    a: &Array,
141    stream: impl AsRef<Stream>,
142) -> Result<Cow<'_, Array>> {
143    let shape = a.shape();
144    let mut new_shape: Vec<_> = shape.iter().skip_while(|&&dim| dim == 1).cloned().collect();
145    if shape != new_shape {
146        if new_shape.is_empty() {
147            new_shape = vec![1];
148        }
149        Ok(Cow::Owned(a.reshape_device(&new_shape, stream)?))
150    } else {
151        Ok(Cow::Borrowed(a))
152    }
153}
154
155struct ScatterArgs<'a> {
156    indices: SmallVec<[Cow<'a, Array>; DEFAULT_STACK_VEC_LEN]>,
157    update: Array,
158    axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>,
159}
160
161/// See `scatterArguments` in the swift binding
162fn scatter_args<'a>(
163    src: &'a Array,
164    operations: &'a [ArrayIndexOp],
165    update: &Array,
166    stream: impl AsRef<Stream>,
167) -> Result<ScatterArgs<'a>> {
168    use ArrayIndexOp::*;
169
170    if operations.len() == 1 {
171        return match &operations[0] {
172            TakeIndex { index } => scatter_args_index(src, *index, update, stream),
173            TakeArray { indices } => {
174                scatter_args_array(src, Cow::Borrowed(indices), update, stream)
175            }
176            TakeArrayRef { indices } => {
177                scatter_args_array(src, Cow::Borrowed(indices), update, stream)
178            }
179            Slice(range_index) => scatter_args_slice(src, range_index, update, stream),
180            ExpandDims => Ok(ScatterArgs {
181                indices: smallvec![],
182                update: broadcast_to_device(update, src.shape(), &stream)?,
183                axes: smallvec![],
184            }),
185            Ellipsis => panic!("Unable to update array with ellipsis argument"),
186        };
187    }
188
189    scatter_args_nd(src, operations, update, stream)
190}
191
192fn scatter_args_index<'a>(
193    src: &'a Array,
194    index: i32,
195    update: &Array,
196    stream: impl AsRef<Stream>,
197) -> Result<ScatterArgs<'a>> {
198    // mlx_scatter_args_index
199
200    // Remove any leading singleton dimensions from the update
201    // and then broadcast update to shape of src[0, ...]
202    let update = remove_leading_singleton_dimensions(update, &stream)?;
203
204    let mut shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
205    shape[0] = 1;
206
207    Ok(ScatterArgs {
208        indices: smallvec![Cow::Owned(Array::from_int(resolve_index_signed_unchecked(
209            index,
210            src.dim(0)
211        )))],
212        update: broadcast_to_device(&update, &shape, &stream)?,
213        axes: smallvec![0],
214    })
215}
216
217fn scatter_args_array<'a>(
218    src: &'a Array,
219    a: Cow<'a, Array>,
220    update: &Array,
221    stream: impl AsRef<Stream>,
222) -> Result<ScatterArgs<'a>> {
223    // mlx_scatter_args_array
224
225    // trim leading singleton dimensions
226    let update = remove_leading_singleton_dimensions(update, &stream)?;
227
228    // The update shape must broadcast with indices.shape + [1] + src.shape[1:]
229    let mut update_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = a
230        .shape()
231        .iter()
232        .chain(src.shape().iter().skip(1))
233        .cloned()
234        .collect();
235    let update = broadcast_to_device(&update, &update_shape, &stream)?;
236
237    update_shape.insert(a.ndim(), 1);
238    let update = update.reshape_device(&update_shape, &stream)?;
239
240    Ok(ScatterArgs {
241        indices: smallvec![a],
242        update,
243        axes: smallvec![0],
244    })
245}
246
247fn scatter_args_slice<'a>(
248    src: &'a Array,
249    range_index: &'a RangeIndex,
250    update: &Array,
251    stream: impl AsRef<Stream>,
252) -> Result<ScatterArgs<'a>> {
253    // mlx_scatter_args_slice
254
255    // if none slice is requested braodcast the update to the src size and return it
256    if range_index.is_full() {
257        let update = remove_leading_singleton_dimensions(update, &stream)?;
258
259        return Ok(ScatterArgs {
260            indices: smallvec![],
261            update: broadcast_to_device(&update, src.shape(), &stream)?,
262            axes: smallvec![],
263        });
264    }
265
266    let size = src.dim(0);
267    let start = range_index.start(size);
268    let end = range_index.end(size);
269    let stride = range_index.stride();
270
271    // If simple stride
272    if stride == 1 {
273        let update = remove_leading_singleton_dimensions(update, &stream)?;
274
275        // Broadcast update to slice size
276        let update_broadcast_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (1..end - start)
277            .chain(src.shape().iter().skip(1).cloned())
278            .collect();
279        let update = broadcast_to_device(&update, &update_broadcast_shape, &stream)?;
280
281        let indices = Array::from_slice(&[start], &[1]);
282        Ok(ScatterArgs {
283            indices: smallvec![Cow::Owned(indices)],
284            update,
285            axes: smallvec![0],
286        })
287    } else {
288        // stride != 1, convert the slice to an array
289        let a_vals = strided_range_to_vec(start, end, stride);
290        let a = Array::from_slice(&a_vals, &[a_vals.len() as i32]);
291
292        scatter_args_array(src, Cow::Owned(a), update, stream)
293    }
294}
295
296fn scatter_args_nd<'a>(
297    src: &'a Array,
298    operations: &[ArrayIndexOp],
299    update: &Array,
300    stream: impl AsRef<Stream>,
301) -> Result<ScatterArgs<'a>> {
302    use ArrayIndexOp::*;
303
304    // mlx_scatter_args_nd
305
306    let shape = src.shape();
307
308    let operations = expand_ellipsis_operations(src.ndim(), operations);
309    let update = remove_leading_singleton_dimensions(update, &stream)?;
310
311    // If no non-newAxis indices return the broadcasted update
312    let non_new_axis_operation_count = count_non_new_axis_operations(&operations);
313    if non_new_axis_operation_count == 0 {
314        return Ok(ScatterArgs {
315            indices: smallvec![],
316            update: broadcast_to_device(&update, shape, &stream)?,
317            axes: smallvec![],
318        });
319    }
320
321    // Analyse the types of the indices
322    let mut max_dims = 0;
323    let mut arrays_first = false;
324    let mut count_new_axis: i32 = 0;
325    let mut count_slices: i32 = 0;
326    let mut count_arrays: i32 = 0;
327    let mut count_strided_slices: i32 = 0;
328    let mut count_simple_slices_post: i32 = 0;
329
330    let mut have_array = false;
331    let mut have_non_array = false;
332
333    macro_rules! analyze_indices_take_array {
334        ($indices:ident) => {
335            have_array = true;
336            if have_array && have_non_array {
337                arrays_first = true;
338            }
339            max_dims = $indices.ndim().max(max_dims);
340            count_arrays = count_arrays.saturating_add(1);
341            count_simple_slices_post = 0;
342        };
343    }
344
345    for item in operations.iter() {
346        match item {
347            TakeIndex { index: _ } => {
348                // ignore
349            }
350            Slice(range_index) => {
351                have_non_array = have_array;
352                count_slices = count_slices.saturating_add(1);
353                if range_index.stride() != 1 {
354                    count_strided_slices = count_strided_slices.saturating_add(1);
355                    count_simple_slices_post = 0;
356                } else {
357                    count_simple_slices_post = count_simple_slices_post.saturating_add(1);
358                }
359            }
360            TakeArray { indices } => {
361                analyze_indices_take_array!(indices);
362            }
363            TakeArrayRef { indices } => {
364                analyze_indices_take_array!(indices);
365            }
366            ExpandDims => {
367                have_non_array = true;
368                count_new_axis = count_new_axis.saturating_add(1);
369            }
370            Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
371        }
372    }
373
374    // We have index dims for the arrays, strided slices (implemented as arrays), none
375    let mut index_dims = (max_dims + count_new_axis as usize + count_slices as usize)
376        .saturating_sub(count_simple_slices_post as usize);
377
378    // If we have simple non-strided slices, we also attach an index for that
379    if index_dims == 0 {
380        index_dims = 1;
381    }
382
383    // Go over each index type and translate to the needed scatter args
384    let mut array_indices: SmallVec<[Array; DEFAULT_STACK_VEC_LEN]> =
385        SmallVec::with_capacity(operations.len());
386    let mut slice_number: i32 = 0;
387    let mut array_number: i32 = 0;
388    let mut axis: i32 = 0;
389
390    // We collect the shapes of the slices and updates during this process
391    let mut update_shape = vec![1; non_new_axis_operation_count];
392    let mut slice_shapes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
393
394    macro_rules! update_shapes_take_array {
395        ($indices:ident) => {
396            // Place the arrays in the correct dimension
397            let start = if arrays_first {
398                max_dims - $indices.ndim()
399            } else {
400                // SAFETY: slice_number is never decremented and should be non-negative
401                slice_number as usize + max_dims - $indices.ndim()
402            };
403            let mut new_shape = vec![1; index_dims];
404
405            for j in 0..$indices.ndim() {
406                new_shape[start + j] = $indices.dim(j as i32);
407            }
408
409            array_indices.push($indices.reshape_device(&new_shape, &stream)?);
410            array_number = array_number.saturating_add(1);
411
412            if !arrays_first && array_number == count_arrays {
413                slice_number = slice_number.saturating_add_unsigned(max_dims as u32);
414            }
415
416            // Add the shape to the update
417            update_shape[axis as usize] = 1;
418            axis = axis.saturating_add(1);
419        };
420    }
421
422    for item in operations.iter() {
423        match item {
424            TakeIndex { index } => {
425                let resolved_index = resolve_index_signed_unchecked(*index, src.dim(axis));
426                array_indices.push(Array::from_int(resolved_index));
427                // SAFETY: axis is always non-negative
428                update_shape[axis as usize] = 1;
429                axis = axis.saturating_add(1);
430            }
431            Slice(range_index) => {
432                let size = src.dim(axis);
433                let start = range_index.absolute_start(size);
434                let end = range_index.absolute_end(size);
435                let stride = range_index.stride();
436
437                let mut index_shape = vec![1; index_dims];
438
439                // If it's a simple slice, we only need to add the start index
440                if array_number >= count_arrays && count_strided_slices <= 0 && stride == 1 {
441                    let index = Array::from_int(start).reshape_device(&index_shape, &stream)?;
442                    let slice_shape_entry = end - start;
443                    slice_shapes.push(slice_shape_entry);
444                    array_indices.push(index);
445
446                    // Add the shape to the update
447                    update_shape[axis as usize] = slice_shape_entry;
448                } else {
449                    // Otherwise we expand the slice into indices using arange
450                    let index_vals = strided_range_to_vec(start, end, stride);
451                    let index = Array::from_slice(&index_vals, &[index_vals.len() as i32]);
452                    let location = if arrays_first {
453                        slice_number.saturating_add(max_dims as i32)
454                    } else {
455                        slice_number
456                    };
457                    index_shape[location as usize] = index.size() as i32;
458                    array_indices.push(index.reshape_device(&index_shape, &stream)?);
459
460                    slice_number = slice_number.saturating_add(1);
461                    count_strided_slices = count_strided_slices.saturating_sub(1);
462
463                    // Add the shape to the update
464                    update_shape[axis as usize] = 1;
465                }
466
467                axis = axis.saturating_add(1);
468            }
469            TakeArray { indices } => {
470                update_shapes_take_array!(indices);
471            }
472            TakeArrayRef { indices } => {
473                update_shapes_take_array!(indices);
474            }
475            ExpandDims => slice_number = slice_number.saturating_add(1),
476            Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
477        }
478    }
479
480    // Broadcast the update to the indices and slices
481    let array_indices = broadcast_arrays_device(&array_indices, &stream)?;
482    let update_shape_broadcast: Vec<_> = array_indices[0]
483        .shape()
484        .iter()
485        .chain(slice_shapes.iter())
486        .chain(src.shape().iter().skip(non_new_axis_operation_count))
487        .cloned()
488        .collect();
489    let update = broadcast_to_device(&update, &update_shape_broadcast, &stream)?;
490
491    // Reshape the update with the size-1 dims for the int and array indices
492    let update_reshape: Vec<_> = array_indices[0]
493        .shape()
494        .iter()
495        .chain(update_shape.iter())
496        .chain(src.shape().iter().skip(non_new_axis_operation_count))
497        .cloned()
498        .collect();
499
500    let update = update.reshape_device(&update_reshape, &stream)?;
501
502    let array_indices_len = array_indices.len();
503
504    let indices = array_indices.into_iter().map(Cow::Owned).collect();
505    Ok(ScatterArgs {
506        indices,
507        update,
508        axes: (0..array_indices_len as i32).collect(),
509    })
510}
511
/// Materializes the half-open strided range `start, start + stride, ...` that stops
/// before `exclusive_end`.
///
/// With a positive `stride`, values increase while `< exclusive_end`. With a
/// negative `stride`, values decrease while `> exclusive_end`. If the direction of
/// the range disagrees with the sign of `stride`, the result is empty.
///
/// # Panics
///
/// Panics with "Stride cannot be zero" if `stride == 0`.
fn strided_range_to_vec(start: i32, exclusive_end: i32, stride: i32) -> Vec<i32> {
    // Validate before any arithmetic. Previously the capacity estimate divided by
    // `stride` first, so a zero stride hit a divide-by-zero panic instead.
    assert_ne!(stride, 0, "Stride cannot be zero");

    // Work in i64 so that none of these can overflow: the span, `|stride|` for
    // `i32::MIN`, and stepping past the end near `i32::MIN`/`i32::MAX`.
    let start = i64::from(start);
    let step = i64::from(stride);
    let span = if step > 0 {
        i64::from(exclusive_end) - start
    } else {
        start - i64::from(exclusive_end)
    };
    if span <= 0 {
        return Vec::new();
    }

    // Exact element count: ceil(span / |step|).
    let magnitude = step.abs();
    let len = (span + magnitude - 1) / magnitude;

    // Every produced value lies between `start` and `exclusive_end`. Both are i32,
    // so the narrowing cast is lossless.
    (0..len).map(|i| (start + i * step) as i32).collect()
}
533
534unsafe fn scatter_device(
535    a: &Array,
536    indices: &[impl AsRef<Array>],
537    updates: &Array,
538    axes: &[i32],
539    stream: impl AsRef<Stream>,
540) -> Result<Array> {
541    let indices_vector = VectorArray::try_from_iter(indices.iter())?;
542
543    Array::try_from_op(|res| unsafe {
544        mlx_sys::mlx_scatter(
545            res,
546            a.as_ptr(),
547            indices_vector.as_ptr(),
548            updates.as_ptr(),
549            axes.as_ptr(),
550            axes.len(),
551            stream.as_ref().as_ptr(),
552        )
553    })
554}
555
556impl Array {
557    fn try_index_mut_device_inner(
558        &mut self,
559        operations: &[ArrayIndexOp],
560        update: &Array,
561        stream: impl AsRef<Stream>,
562    ) -> Result<()> {
563        if let Some(result) = update_slice(self, operations, update, &stream)? {
564            *self = result;
565            return Ok(());
566        }
567
568        let ScatterArgs {
569            indices,
570            update,
571            axes,
572        } = scatter_args(self, operations, update, &stream)?;
573        if !indices.is_empty() {
574            let result = unsafe { scatter_device(self, &indices, &update, &axes, stream)? };
575            drop(indices);
576            *self = result;
577        } else {
578            drop(indices);
579            *self = update;
580        }
581        Ok(())
582    }
583}
584
585impl<'a, Val> TryIndexMutOp<&'a [ArrayIndexOp<'a>], Val> for Array
586where
587    Val: AsRef<Array>,
588{
589    fn try_index_mut_device(
590        &mut self,
591        i: &'a [ArrayIndexOp<'a>],
592        val: Val,
593        stream: impl AsRef<Stream>,
594    ) -> Result<()> {
595        let update = val.as_ref();
596        self.try_index_mut_device_inner(i, update, stream)
597    }
598}
599
600impl<A, Val> TryIndexMutOp<A, Val> for Array
601where
602    for<'a> A: ArrayIndex<'a>,
603    Val: AsRef<Array>,
604{
605    fn try_index_mut_device(&mut self, i: A, val: Val, stream: impl AsRef<Stream>) -> Result<()> {
606        let operations = [i.index_op()];
607        let update = val.as_ref();
608        self.try_index_mut_device_inner(&operations, update, stream)
609    }
610}
611
612impl<'a, A, Val> TryIndexMutOp<(A,), Val> for Array
613where
614    A: ArrayIndex<'a>,
615    Val: AsRef<Array>,
616{
617    fn try_index_mut_device(
618        &mut self,
619        (i,): (A,),
620        val: Val,
621        stream: impl AsRef<Stream>,
622    ) -> Result<()> {
623        let operations = [i.index_op()];
624        let update = val.as_ref();
625        self.try_index_mut_device_inner(&operations, update, stream)
626    }
627}
628
629impl<'a, 'b, A, B, Val> TryIndexMutOp<(A, B), Val> for Array
630where
631    A: ArrayIndex<'a>,
632    B: ArrayIndex<'b>,
633    Val: AsRef<Array>,
634{
635    fn try_index_mut_device(
636        &mut self,
637        i: (A, B),
638        val: Val,
639        stream: impl AsRef<Stream>,
640    ) -> Result<()> {
641        let operations = [i.0.index_op(), i.1.index_op()];
642        let update = val.as_ref();
643        self.try_index_mut_device_inner(&operations, update, stream)
644    }
645}
646
647impl<'a, 'b, 'c, A, B, C, Val> TryIndexMutOp<(A, B, C), Val> for Array
648where
649    A: ArrayIndex<'a>,
650    B: ArrayIndex<'b>,
651    C: ArrayIndex<'c>,
652    Val: AsRef<Array>,
653{
654    fn try_index_mut_device(
655        &mut self,
656        i: (A, B, C),
657        val: Val,
658        stream: impl AsRef<Stream>,
659    ) -> Result<()> {
660        let operations = [i.0.index_op(), i.1.index_op(), i.2.index_op()];
661        let update = val.as_ref();
662        self.try_index_mut_device_inner(&operations, update, stream)
663    }
664}
665
666impl<'a, 'b, 'c, 'd, A, B, C, D, Val> TryIndexMutOp<(A, B, C, D), Val> for Array
667where
668    A: ArrayIndex<'a>,
669    B: ArrayIndex<'b>,
670    C: ArrayIndex<'c>,
671    D: ArrayIndex<'d>,
672    Val: AsRef<Array>,
673{
674    fn try_index_mut_device(
675        &mut self,
676        i: (A, B, C, D),
677        val: Val,
678        stream: impl AsRef<Stream>,
679    ) -> Result<()> {
680        let operations = [
681            i.0.index_op(),
682            i.1.index_op(),
683            i.2.index_op(),
684            i.3.index_op(),
685        ];
686        let update = val.as_ref();
687        self.try_index_mut_device_inner(&operations, update, stream)
688    }
689}
690
691impl<'a, 'b, 'c, 'd, 'e, A, B, C, D, E, Val> TryIndexMutOp<(A, B, C, D, E), Val> for Array
692where
693    A: ArrayIndex<'a>,
694    B: ArrayIndex<'b>,
695    C: ArrayIndex<'c>,
696    D: ArrayIndex<'d>,
697    E: ArrayIndex<'e>,
698    Val: AsRef<Array>,
699{
700    fn try_index_mut_device(
701        &mut self,
702        i: (A, B, C, D, E),
703        val: Val,
704        stream: impl AsRef<Stream>,
705    ) -> Result<()> {
706        let operations = [
707            i.0.index_op(),
708            i.1.index_op(),
709            i.2.index_op(),
710            i.3.index_op(),
711            i.4.index_op(),
712        ];
713        let update = val.as_ref();
714        self.try_index_mut_device_inner(&operations, update, stream)
715    }
716}
717
718impl<'a, 'b, 'c, 'd, 'e, 'f, A, B, C, D, E, F, Val> TryIndexMutOp<(A, B, C, D, E, F), Val> for Array
719where
720    A: ArrayIndex<'a>,
721    B: ArrayIndex<'b>,
722    C: ArrayIndex<'c>,
723    D: ArrayIndex<'d>,
724    E: ArrayIndex<'e>,
725    F: ArrayIndex<'f>,
726    Val: AsRef<Array>,
727{
728    fn try_index_mut_device(
729        &mut self,
730        i: (A, B, C, D, E, F),
731        val: Val,
732        stream: impl AsRef<Stream>,
733    ) -> Result<()> {
734        let operations = [
735            i.0.index_op(),
736            i.1.index_op(),
737            i.2.index_op(),
738            i.3.index_op(),
739            i.4.index_op(),
740            i.5.index_op(),
741        ];
742        let update = val.as_ref();
743        self.try_index_mut_device_inner(&operations, update, stream)
744    }
745}
746
747impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, A, B, C, D, E, F, G, Val> TryIndexMutOp<(A, B, C, D, E, F, G), Val>
748    for Array
749where
750    A: ArrayIndex<'a>,
751    B: ArrayIndex<'b>,
752    C: ArrayIndex<'c>,
753    D: ArrayIndex<'d>,
754    E: ArrayIndex<'e>,
755    F: ArrayIndex<'f>,
756    G: ArrayIndex<'g>,
757    Val: AsRef<Array>,
758{
759    fn try_index_mut_device(
760        &mut self,
761        i: (A, B, C, D, E, F, G),
762        val: Val,
763        stream: impl AsRef<Stream>,
764    ) -> Result<()> {
765        let operations = [
766            i.0.index_op(),
767            i.1.index_op(),
768            i.2.index_op(),
769            i.3.index_op(),
770            i.4.index_op(),
771            i.5.index_op(),
772            i.6.index_op(),
773        ];
774        let update = val.as_ref();
775        self.try_index_mut_device_inner(&operations, update, stream)
776    }
777}
778
779impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, A, B, C, D, E, F, G, H, Val>
780    TryIndexMutOp<(A, B, C, D, E, F, G, H), Val> for Array
781where
782    A: ArrayIndex<'a>,
783    B: ArrayIndex<'b>,
784    C: ArrayIndex<'c>,
785    D: ArrayIndex<'d>,
786    E: ArrayIndex<'e>,
787    F: ArrayIndex<'f>,
788    G: ArrayIndex<'g>,
789    H: ArrayIndex<'h>,
790    Val: AsRef<Array>,
791{
792    fn try_index_mut_device(
793        &mut self,
794        i: (A, B, C, D, E, F, G, H),
795        val: Val,
796        stream: impl AsRef<Stream>,
797    ) -> Result<()> {
798        let operations = [
799            i.0.index_op(),
800            i.1.index_op(),
801            i.2.index_op(),
802            i.3.index_op(),
803            i.4.index_op(),
804            i.5.index_op(),
805            i.6.index_op(),
806            i.7.index_op(),
807        ];
808        let update = val.as_ref();
809        self.try_index_mut_device_inner(&operations, update, stream)
810    }
811}
812
813impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, A, B, C, D, E, F, G, H, I, Val>
814    TryIndexMutOp<(A, B, C, D, E, F, G, H, I), Val> for Array
815where
816    A: ArrayIndex<'a>,
817    B: ArrayIndex<'b>,
818    C: ArrayIndex<'c>,
819    D: ArrayIndex<'d>,
820    E: ArrayIndex<'e>,
821    F: ArrayIndex<'f>,
822    G: ArrayIndex<'g>,
823    H: ArrayIndex<'h>,
824    I: ArrayIndex<'i>,
825    Val: AsRef<Array>,
826{
827    fn try_index_mut_device(
828        &mut self,
829        i: (A, B, C, D, E, F, G, H, I),
830        val: Val,
831        stream: impl AsRef<Stream>,
832    ) -> Result<()> {
833        let operations = [
834            i.0.index_op(),
835            i.1.index_op(),
836            i.2.index_op(),
837            i.3.index_op(),
838            i.4.index_op(),
839            i.5.index_op(),
840            i.6.index_op(),
841            i.7.index_op(),
842            i.8.index_op(),
843        ];
844        let update = val.as_ref();
845        self.try_index_mut_device_inner(&operations, update, stream)
846    }
847}
848
849impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, A, B, C, D, E, F, G, H, I, J, Val>
850    TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J), Val> for Array
851where
852    A: ArrayIndex<'a>,
853    B: ArrayIndex<'b>,
854    C: ArrayIndex<'c>,
855    D: ArrayIndex<'d>,
856    E: ArrayIndex<'e>,
857    F: ArrayIndex<'f>,
858    G: ArrayIndex<'g>,
859    H: ArrayIndex<'h>,
860    I: ArrayIndex<'i>,
861    J: ArrayIndex<'j>,
862    Val: AsRef<Array>,
863{
864    fn try_index_mut_device(
865        &mut self,
866        i: (A, B, C, D, E, F, G, H, I, J),
867        val: Val,
868        stream: impl AsRef<Stream>,
869    ) -> Result<()> {
870        let operations = [
871            i.0.index_op(),
872            i.1.index_op(),
873            i.2.index_op(),
874            i.3.index_op(),
875            i.4.index_op(),
876            i.5.index_op(),
877            i.6.index_op(),
878            i.7.index_op(),
879            i.8.index_op(),
880            i.9.index_op(),
881        ];
882        let update = val.as_ref();
883        self.try_index_mut_device_inner(&operations, update, stream)
884    }
885}
886
887impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, A, B, C, D, E, F, G, H, I, J, K, Val>
888    TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K), Val> for Array
889where
890    A: ArrayIndex<'a>,
891    B: ArrayIndex<'b>,
892    C: ArrayIndex<'c>,
893    D: ArrayIndex<'d>,
894    E: ArrayIndex<'e>,
895    F: ArrayIndex<'f>,
896    G: ArrayIndex<'g>,
897    H: ArrayIndex<'h>,
898    I: ArrayIndex<'i>,
899    J: ArrayIndex<'j>,
900    K: ArrayIndex<'k>,
901    Val: AsRef<Array>,
902{
903    fn try_index_mut_device(
904        &mut self,
905        i: (A, B, C, D, E, F, G, H, I, J, K),
906        val: Val,
907        stream: impl AsRef<Stream>,
908    ) -> Result<()> {
909        let operations = [
910            i.0.index_op(),
911            i.1.index_op(),
912            i.2.index_op(),
913            i.3.index_op(),
914            i.4.index_op(),
915            i.5.index_op(),
916            i.6.index_op(),
917            i.7.index_op(),
918            i.8.index_op(),
919            i.9.index_op(),
920            i.10.index_op(),
921        ];
922        let update = val.as_ref();
923        self.try_index_mut_device_inner(&operations, update, stream)
924    }
925}
926
927impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, A, B, C, D, E, F, G, H, I, J, K, L, Val>
928    TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L), Val> for Array
929where
930    A: ArrayIndex<'a>,
931    B: ArrayIndex<'b>,
932    C: ArrayIndex<'c>,
933    D: ArrayIndex<'d>,
934    E: ArrayIndex<'e>,
935    F: ArrayIndex<'f>,
936    G: ArrayIndex<'g>,
937    H: ArrayIndex<'h>,
938    I: ArrayIndex<'i>,
939    J: ArrayIndex<'j>,
940    K: ArrayIndex<'k>,
941    L: ArrayIndex<'l>,
942    Val: AsRef<Array>,
943{
944    fn try_index_mut_device(
945        &mut self,
946        i: (A, B, C, D, E, F, G, H, I, J, K, L),
947        val: Val,
948        stream: impl AsRef<Stream>,
949    ) -> Result<()> {
950        let operations = [
951            i.0.index_op(),
952            i.1.index_op(),
953            i.2.index_op(),
954            i.3.index_op(),
955            i.4.index_op(),
956            i.5.index_op(),
957            i.6.index_op(),
958            i.7.index_op(),
959            i.8.index_op(),
960            i.9.index_op(),
961            i.10.index_op(),
962            i.11.index_op(),
963        ];
964        let update = val.as_ref();
965        self.try_index_mut_device_inner(&operations, update, stream)
966    }
967}
968
969impl<
970        'a,
971        'b,
972        'c,
973        'd,
974        'e,
975        'f,
976        'g,
977        'h,
978        'i,
979        'j,
980        'k,
981        'l,
982        'm,
983        A,
984        B,
985        C,
986        D,
987        E,
988        F,
989        G,
990        H,
991        I,
992        J,
993        K,
994        L,
995        M,
996        Val,
997    > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M), Val> for Array
998where
999    A: ArrayIndex<'a>,
1000    B: ArrayIndex<'b>,
1001    C: ArrayIndex<'c>,
1002    D: ArrayIndex<'d>,
1003    E: ArrayIndex<'e>,
1004    F: ArrayIndex<'f>,
1005    G: ArrayIndex<'g>,
1006    H: ArrayIndex<'h>,
1007    I: ArrayIndex<'i>,
1008    J: ArrayIndex<'j>,
1009    K: ArrayIndex<'k>,
1010    L: ArrayIndex<'l>,
1011    M: ArrayIndex<'m>,
1012    Val: AsRef<Array>,
1013{
1014    fn try_index_mut_device(
1015        &mut self,
1016        i: (A, B, C, D, E, F, G, H, I, J, K, L, M),
1017        val: Val,
1018        stream: impl AsRef<Stream>,
1019    ) -> Result<()> {
1020        let operations = [
1021            i.0.index_op(),
1022            i.1.index_op(),
1023            i.2.index_op(),
1024            i.3.index_op(),
1025            i.4.index_op(),
1026            i.5.index_op(),
1027            i.6.index_op(),
1028            i.7.index_op(),
1029            i.8.index_op(),
1030            i.9.index_op(),
1031            i.10.index_op(),
1032            i.11.index_op(),
1033            i.12.index_op(),
1034        ];
1035        let update = val.as_ref();
1036        self.try_index_mut_device_inner(&operations, update, stream)
1037    }
1038}
1039
1040impl<
1041        'a,
1042        'b,
1043        'c,
1044        'd,
1045        'e,
1046        'f,
1047        'g,
1048        'h,
1049        'i,
1050        'j,
1051        'k,
1052        'l,
1053        'm,
1054        'n,
1055        A,
1056        B,
1057        C,
1058        D,
1059        E,
1060        F,
1061        G,
1062        H,
1063        I,
1064        J,
1065        K,
1066        L,
1067        M,
1068        N,
1069        Val,
1070    > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N), Val> for Array
1071where
1072    A: ArrayIndex<'a>,
1073    B: ArrayIndex<'b>,
1074    C: ArrayIndex<'c>,
1075    D: ArrayIndex<'d>,
1076    E: ArrayIndex<'e>,
1077    F: ArrayIndex<'f>,
1078    G: ArrayIndex<'g>,
1079    H: ArrayIndex<'h>,
1080    I: ArrayIndex<'i>,
1081    J: ArrayIndex<'j>,
1082    K: ArrayIndex<'k>,
1083    L: ArrayIndex<'l>,
1084    M: ArrayIndex<'m>,
1085    N: ArrayIndex<'n>,
1086    Val: AsRef<Array>,
1087{
1088    fn try_index_mut_device(
1089        &mut self,
1090        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N),
1091        val: Val,
1092        stream: impl AsRef<Stream>,
1093    ) -> Result<()> {
1094        let operations = [
1095            i.0.index_op(),
1096            i.1.index_op(),
1097            i.2.index_op(),
1098            i.3.index_op(),
1099            i.4.index_op(),
1100            i.5.index_op(),
1101            i.6.index_op(),
1102            i.7.index_op(),
1103            i.8.index_op(),
1104            i.9.index_op(),
1105            i.10.index_op(),
1106            i.11.index_op(),
1107            i.12.index_op(),
1108            i.13.index_op(),
1109        ];
1110        let update = val.as_ref();
1111        self.try_index_mut_device_inner(&operations, update, stream)
1112    }
1113}
1114
1115impl<
1116        'a,
1117        'b,
1118        'c,
1119        'd,
1120        'e,
1121        'f,
1122        'g,
1123        'h,
1124        'i,
1125        'j,
1126        'k,
1127        'l,
1128        'm,
1129        'n,
1130        'o,
1131        A,
1132        B,
1133        C,
1134        D,
1135        E,
1136        F,
1137        G,
1138        H,
1139        I,
1140        J,
1141        K,
1142        L,
1143        M,
1144        N,
1145        O,
1146        Val,
1147    > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O), Val> for Array
1148where
1149    A: ArrayIndex<'a>,
1150    B: ArrayIndex<'b>,
1151    C: ArrayIndex<'c>,
1152    D: ArrayIndex<'d>,
1153    E: ArrayIndex<'e>,
1154    F: ArrayIndex<'f>,
1155    G: ArrayIndex<'g>,
1156    H: ArrayIndex<'h>,
1157    I: ArrayIndex<'i>,
1158    J: ArrayIndex<'j>,
1159    K: ArrayIndex<'k>,
1160    L: ArrayIndex<'l>,
1161    M: ArrayIndex<'m>,
1162    N: ArrayIndex<'n>,
1163    O: ArrayIndex<'o>,
1164    Val: AsRef<Array>,
1165{
1166    fn try_index_mut_device(
1167        &mut self,
1168        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O),
1169        val: Val,
1170        stream: impl AsRef<Stream>,
1171    ) -> Result<()> {
1172        let operations = [
1173            i.0.index_op(),
1174            i.1.index_op(),
1175            i.2.index_op(),
1176            i.3.index_op(),
1177            i.4.index_op(),
1178            i.5.index_op(),
1179            i.6.index_op(),
1180            i.7.index_op(),
1181            i.8.index_op(),
1182            i.9.index_op(),
1183            i.10.index_op(),
1184            i.11.index_op(),
1185            i.12.index_op(),
1186            i.13.index_op(),
1187            i.14.index_op(),
1188        ];
1189        let update = val.as_ref();
1190        self.try_index_mut_device_inner(&operations, update, stream)
1191    }
1192}
1193
1194impl<
1195        'a,
1196        'b,
1197        'c,
1198        'd,
1199        'e,
1200        'f,
1201        'g,
1202        'h,
1203        'i,
1204        'j,
1205        'k,
1206        'l,
1207        'm,
1208        'n,
1209        'o,
1210        'p,
1211        A,
1212        B,
1213        C,
1214        D,
1215        E,
1216        F,
1217        G,
1218        H,
1219        I,
1220        J,
1221        K,
1222        L,
1223        M,
1224        N,
1225        O,
1226        P,
1227        Val,
1228    > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P), Val> for Array
1229where
1230    A: ArrayIndex<'a>,
1231    B: ArrayIndex<'b>,
1232    C: ArrayIndex<'c>,
1233    D: ArrayIndex<'d>,
1234    E: ArrayIndex<'e>,
1235    F: ArrayIndex<'f>,
1236    G: ArrayIndex<'g>,
1237    H: ArrayIndex<'h>,
1238    I: ArrayIndex<'i>,
1239    J: ArrayIndex<'j>,
1240    K: ArrayIndex<'k>,
1241    L: ArrayIndex<'l>,
1242    M: ArrayIndex<'m>,
1243    N: ArrayIndex<'n>,
1244    O: ArrayIndex<'o>,
1245    P: ArrayIndex<'p>,
1246    Val: AsRef<Array>,
1247{
1248    fn try_index_mut_device(
1249        &mut self,
1250        i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P),
1251        val: Val,
1252        stream: impl AsRef<Stream>,
1253    ) -> Result<()> {
1254        let operations = [
1255            i.0.index_op(),
1256            i.1.index_op(),
1257            i.2.index_op(),
1258            i.3.index_op(),
1259            i.4.index_op(),
1260            i.5.index_op(),
1261            i.6.index_op(),
1262            i.7.index_op(),
1263            i.8.index_op(),
1264            i.9.index_op(),
1265            i.10.index_op(),
1266            i.11.index_op(),
1267            i.12.index_op(),
1268            i.13.index_op(),
1269            i.14.index_op(),
1270            i.15.index_op(),
1271        ];
1272        let update = val.as_ref();
1273        self.try_index_mut_device_inner(&operations, update, stream)
1274    }
1275}
1276
/// Tests for writing through `index_mut`, adapted from the Swift binding's test suite.
#[cfg(test)]
mod tests {
    use crate::{ops::indexing::*, Array};

    #[test]
    fn test_array_mutate_single_index() {
        let mut a = Array::from_iter(0i32..12, &[3, 4]);

        // a[1] = 77: the scalar is broadcast across the whole second row.
        a.index_mut(1, Array::from_int(77));

        let expected = Array::from_slice(&[0, 1, 2, 3, 77, 77, 77, 77, 8, 9, 10, 11], &[3, 4]);
        assert_array_all_close!(a, expected);
    }

    #[test]
    fn test_array_mutate_broadcast_multi_index() {
        let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);

        // a[1, 0] = 77: a scalar broadcast over one row
        a.index_mut((1, 0), Array::from_int(77));
        // a[0, 0] = [55, 66, 77, 88, 99]: a full row written element-wise
        a.index_mut((0, 0), Array::from_slice(&[55i32, 66, 77, 88, 99], &[5]));
        // a[0, 1, 3] = 123: a single element
        a.index_mut((0, 1, 3), Array::from_int(123));

        #[rustfmt::skip]
        let expected = Array::from_slice(
            &[
                55, 66, 77, 88, 99,
                5, 6, 7, 123, 9,
                77, 77, 77, 77, 77,
                15, 16, 17, 18, 19,
            ],
            &[2, 2, 5],
        );
        assert_array_all_close!(a, expected);
    }

    #[test]
    fn test_array_mutate_broadcast_slice() {
        let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);

        // a[0:1, 1:2, 2:4] = 88 -- the slices select exactly two elements.
        a.index_mut((0..1, 1..2, 2..4), Array::from_int(88));

        #[rustfmt::skip]
        let expected = Array::from_slice(
            &[
                0, 1, 2, 3, 4,
                5, 6, 88, 88, 9,
                10, 11, 12, 13, 14,
                15, 16, 17, 18, 19,
            ],
            &[2, 2, 5],
        );
        assert_array_all_close!(a, expected);
    }

    #[test]
    fn test_array_mutate_advanced() {
        let mut a = Array::from_iter(0i32..35, &[5, 7]);

        // Paired integer-array indices address (0, 0), (2, 1) and (4, 2).
        let rows = Array::from_slice(&[0, 2, 4], &[3]);
        let cols = Array::from_slice(&[0, 1, 2], &[3]);
        a.index_mut((rows, cols), Array::from_slice(&[100, 200, 300], &[3]));

        for (pos, want) in [((0, 0), 100i32), ((2, 1), 200i32), ((4, 2), 300i32)] {
            assert_eq!(a.index(pos).item::<i32>(), want);
        }
    }

    #[test]
    fn test_full_index_write_single() {
        // Builds a fresh 0..60 array of shape [3, 4, 5], writes 1 through `index`,
        // and returns the sum of the mutated array.
        fn sum_after_write<I>(index: I) -> i32
        where
            for<'a> I: ArrayIndex<'a>,
        {
            let mut a = Array::from_iter(0..60, &[3, 4, 5]);
            a.index_mut(index, Array::from_int(1));
            a.sum(None, None).unwrap().item::<i32>()
        }

        // a[...] -- not valid here, so it is not exercised.

        // a[None] = 1
        assert_eq!(sum_after_write(NewAxis), 60);
        // a[0] = 1
        assert_eq!(sum_after_write(0), 1600);
        // a[1:3] = 1
        assert_eq!(sum_after_write(1..3), 230);
        // a[mx.array([2, 1])] = 1
        assert_eq!(sum_after_write(Array::from_slice(&[2, 1], &[2])), 230);
    }

    #[test]
    fn test_full_index_write_no_array() {
        // Writes 1 through the index tuple into a fresh 0..360 array of shape
        // [2, 3, 4, 5, 3] and checks the sum of the result.
        macro_rules! assert_write_sum {
            ($index:expr => $sum:expr) => {{
                let mut a = Array::from_iter(0..360, &[2, 3, 4, 5, 3]);
                a.index_mut($index, Array::from_int(1));
                assert_eq!(a.sum(None, None).unwrap().item::<i32>(), $sum);
            }};
        }

        // a[..., 0] = 1
        assert_write_sum!((Ellipsis, 0) => 43320);
        // a[0, ...] = 1
        assert_write_sum!((0, Ellipsis) => 48690);
        // a[0, ..., 0] = 1
        assert_write_sum!((0, Ellipsis, 0) => 59370);
        // a[..., ::2, :] = 1
        assert_write_sum!((Ellipsis, (..).stride_by(2), ..) => 26064);
        // a[..., None, ::2, -1] = 1
        assert_write_sum!((Ellipsis, NewAxis, (..).stride_by(2), -1) => 51696);
        // a[:, 2:, 0] = 1
        assert_write_sum!((.., 2.., 0) => 58140);
        // a[::-1, :2, 2:, ..., None, ::2] = 1
        assert_write_sum!(
            (
                (..).stride_by(-1),
                ..2,
                2..,
                Ellipsis,
                NewAxis,
                (..).stride_by(2)
            ) => 51540
        );
    }

    #[test]
    fn test_full_index_write_array() {
        // Every index here contains an Array, so these writes go through the gather path.
        macro_rules! assert_write_sum {
            ($index:expr => $sum:expr) => {{
                let mut a = Array::from_iter(0..540, &[3, 3, 4, 5, 3]);
                a.index_mut($index, Array::from_int(1));
                assert_eq!(a.sum(None, None).unwrap().item::<i32>(), $sum);
            }};
        }

        // i = mx.array([2, 1])
        let i = Array::from_slice(&[2, 1], &[2]);

        // a[0, i] = 1
        assert_write_sum!((0, &i) => 131310);
        // a[..., i, 0] = 1
        assert_write_sum!((Ellipsis, &i, 0) => 126378);
        // a[i, 0, ...] = 1
        assert_write_sum!((&i, 0, Ellipsis) => 109710);
        // a[i, ..., i] = 1
        assert_write_sum!((&i, Ellipsis, &i) => 102450);
        // a[i, ..., ::2, :] = 1
        assert_write_sum!((&i, Ellipsis, (..).stride_by(2), ..) => 68094);
        // a[..., i, None, ::2, -1] = 1
        assert_write_sum!((Ellipsis, &i, NewAxis, (..).stride_by(2), -1) => 130977);
        // a[:, 2:, i] = 1
        assert_write_sum!((.., 2.., &i) => 115965);
        // a[::-1, :2, i, 2:, ..., None, ::2] = 1 (the last use of `i`, so it is moved)
        assert_write_sum!(
            (
                (..).stride_by(-1),
                ..2,
                i,
                2..,
                Ellipsis,
                NewAxis,
                (..).stride_by(2)
            ) => 128142
        );
    }
}