1use std::borrow::Cow;
2
3use smallvec::{smallvec, SmallVec};
4
5use crate::{
6 constants::DEFAULT_STACK_VEC_LEN,
7 error::Result,
8 ops::{
9 broadcast_arrays_device, broadcast_to_device,
10 indexing::{count_non_new_axis_operations, expand_ellipsis_operations},
11 reshape_device,
12 },
13 utils::{resolve_index_signed_unchecked, VectorArray},
14 Array, Stream,
15};
16
17use super::{ArrayIndex, ArrayIndexOp, Guarded, RangeIndex, TryIndexMutOp};
18
19impl Array {
20 pub(crate) fn slice_update_device(
21 &self,
22 update: &Array,
23 starts: &[i32],
24 ends: &[i32],
25 strides: &[i32],
26 stream: impl AsRef<Stream>,
27 ) -> Result<Array> {
28 Array::try_from_op(|res| unsafe {
29 mlx_sys::mlx_slice_update(
30 res,
31 self.as_ptr(),
32 update.as_ptr(),
33 starts.as_ptr(),
34 starts.len(),
35 ends.as_ptr(),
36 ends.len(),
37 strides.as_ptr(),
38 strides.len(),
39 stream.as_ref().as_ptr(),
40 )
41 })
42 }
43}
44
/// Tries to perform `src[operations] = update` as a single strided `slice_update`.
///
/// Returns `Ok(None)` when this fast path does not apply, so the caller can fall back to a
/// scatter. That happens when `src` is 0-dimensional, `operations` is empty, or any operation is
/// an array index. Otherwise returns `Ok(Some(result))` with the updated array.
///
/// NOTE(review): the number of non-`NewAxis` operations is not checked against `src.ndim()`.
/// With more indices than dimensions, `axis` saturates at 0 in the reverse walk below, and the
/// leading operations overwrite the bounds already written for axis 0 instead of reporting an
/// error. Confirm whether callers validate this.
fn update_slice(
    src: &Array,
    operations: &[ArrayIndexOp],
    update: &Array,
    stream: impl AsRef<Stream>,
) -> Result<Option<Array>> {
    let ndim = src.ndim();
    if ndim == 0 || operations.is_empty() {
        return Ok(None);
    }

    // Leading size-1 dims are dropped; the update's dims are consumed right-to-left below.
    let mut update = remove_leading_singleton_dimensions(update, &stream)?;

    // Default bounds select everything: [0, dim) with stride 1 on every axis.
    let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
    let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
    let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];

    // A lone slice only restricts axis 0; the update is passed through unreshaped.
    if operations.len() == 1 {
        if let ArrayIndexOp::Slice(range_index) = &operations[0] {
            let size = src.dim(0);
            starts[0] = range_index.start(size);
            ends[0] = range_index.end(size);
            strides[0] = range_index.stride();

            return Ok(Some(src.slice_update_device(
                &update, &starts, &ends, &strides, &stream,
            )?));
        }
    }

    // Array indices cannot be expressed as a slice; let the caller scatter instead.
    if operations.iter().any(|op| op.is_array()) {
        return Ok(None);
    }

    let operations = expand_ellipsis_operations(ndim, operations);

    // Only `NewAxis` operations: the result is the update broadcast to the source shape.
    let non_new_axis_operation_count = count_non_new_axis_operations(&operations);
    if non_new_axis_operation_count == 0 {
        return Ok(Some(broadcast_to_device(&update, src.shape(), &stream)?));
    }

    // Shape the update is reshaped to (rank == src.ndim()), filled from the last axis backwards.
    let mut update_reshape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
    let mut axis = src.ndim() - 1;
    let mut update_axis = update.ndim() as i32 - 1;

    // Trailing axes not addressed by any operation take the update's trailing dims (or 1 once
    // the update's dims are exhausted). `axis >= count >= 1` here, so `axis -= 1` cannot underflow.
    while axis >= non_new_axis_operation_count {
        if update_axis >= 0 {
            update_reshape[axis] = update.dim(update_axis);
            update_axis -= 1;
        } else {
            update_reshape[axis] = 1;
        }
        axis -= 1;
    }

    // Walk the operations from last to first, each consuming one source axis (except NewAxis).
    for item in operations.iter().rev() {
        use ArrayIndexOp::*;

        match item {
            TakeIndex { index } => {
                // Integer index: write exactly one position on this axis (negative wraps once).
                let size = src.dim(axis as i32);
                let index = if index.is_negative() {
                    size + index
                } else {
                    *index
                };
                starts[axis] = index;
                ends[axis] = index.saturating_add(1);

                // The update contributes a size-1 dim for the selected position.
                update_reshape[axis] = 1;
                axis = axis.saturating_sub(1);
            }
            Slice(slice) => {
                let size = src.dim(axis as i32);
                starts[axis] = slice.start(size);
                ends[axis] = slice.end(size);
                strides[axis] = slice.stride();

                // The update's next dim (from the right) maps onto this axis, or 1 if none left.
                if update_axis >= 0 {
                    update_reshape[axis] = update.dim(update_axis);
                    update_axis = update_axis.saturating_sub(1);
                } else {
                    update_reshape[axis] = 1;
                }
                axis = axis.saturating_sub(1);
            }
            // NewAxis does not correspond to a source axis.
            ExpandDims => {}
            // Ellipsis was expanded above and arrays were rejected earlier.
            Ellipsis | TakeArray { indices: _ } | TakeArrayRef { indices: _ } => {
                panic!("unexpected item in operations")
            }
        }
    }

    if update.shape() != &update_reshape[..] {
        update = Cow::Owned(reshape_device(update, &update_reshape, &stream)?);
    }

    Ok(Some(src.slice_update_device(
        &update, &starts, &ends, &strides, &stream,
    )?))
}
157
158fn remove_leading_singleton_dimensions(
160 a: &Array,
161 stream: impl AsRef<Stream>,
162) -> Result<Cow<'_, Array>> {
163 let shape = a.shape();
164 let mut new_shape: Vec<_> = shape.iter().skip_while(|&&dim| dim == 1).cloned().collect();
165 if shape != new_shape {
166 if new_shape.is_empty() {
167 new_shape = vec![1];
168 }
169 Ok(Cow::Owned(a.reshape_device(&new_shape, stream)?))
170 } else {
171 Ok(Cow::Borrowed(a))
172 }
173}
174
/// Prepared inputs for `scatter_device`, produced by the `scatter_args*` helpers.
///
/// When `indices` is empty, `update` is already the complete new array and is used as-is by
/// `try_index_mut_device_inner` instead of being scattered.
struct ScatterArgs<'a> {
    // One index array per indexed axis; may borrow index arrays from the caller's operations.
    indices: SmallVec<[Cow<'a, Array>; DEFAULT_STACK_VEC_LEN]>,
    // Update values, already broadcast/reshaped to the layout the scatter call expects.
    update: Array,
    // The source axes that `indices` address, in the same order as `indices`.
    axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>,
}
180
181fn scatter_args<'a>(
183 src: &'a Array,
184 operations: &'a [ArrayIndexOp],
185 update: &Array,
186 stream: impl AsRef<Stream>,
187) -> Result<ScatterArgs<'a>> {
188 use ArrayIndexOp::*;
189
190 if operations.len() == 1 {
191 return match &operations[0] {
192 TakeIndex { index } => scatter_args_index(src, *index, update, stream),
193 TakeArray { indices } => {
194 scatter_args_array(src, Cow::Borrowed(indices), update, stream)
195 }
196 TakeArrayRef { indices } => {
197 scatter_args_array(src, Cow::Borrowed(indices), update, stream)
198 }
199 Slice(range_index) => scatter_args_slice(src, range_index, update, stream),
200 ExpandDims => Ok(ScatterArgs {
201 indices: smallvec![],
202 update: broadcast_to_device(update, src.shape(), &stream)?,
203 axes: smallvec![],
204 }),
205 Ellipsis => panic!("Unable to update array with ellipsis argument"),
206 };
207 }
208
209 scatter_args_nd(src, operations, update, stream)
210}
211
212fn scatter_args_index<'a>(
213 src: &'a Array,
214 index: i32,
215 update: &Array,
216 stream: impl AsRef<Stream>,
217) -> Result<ScatterArgs<'a>> {
218 let update = remove_leading_singleton_dimensions(update, &stream)?;
223
224 let mut shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
225 shape[0] = 1;
226
227 Ok(ScatterArgs {
228 indices: smallvec![Cow::Owned(Array::from_int(resolve_index_signed_unchecked(
229 index,
230 src.dim(0)
231 )))],
232 update: broadcast_to_device(&update, &shape, &stream)?,
233 axes: smallvec![0],
234 })
235}
236
237fn scatter_args_array<'a>(
238 src: &'a Array,
239 a: Cow<'a, Array>,
240 update: &Array,
241 stream: impl AsRef<Stream>,
242) -> Result<ScatterArgs<'a>> {
243 let update = remove_leading_singleton_dimensions(update, &stream)?;
247
248 let mut update_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = a
250 .shape()
251 .iter()
252 .chain(src.shape().iter().skip(1))
253 .cloned()
254 .collect();
255 let update = broadcast_to_device(&update, &update_shape, &stream)?;
256
257 update_shape.insert(a.ndim(), 1);
258 let update = update.reshape_device(&update_shape, &stream)?;
259
260 Ok(ScatterArgs {
261 indices: smallvec![a],
262 update,
263 axes: smallvec![0],
264 })
265}
266
267fn scatter_args_slice<'a>(
268 src: &'a Array,
269 range_index: &'a RangeIndex,
270 update: &Array,
271 stream: impl AsRef<Stream>,
272) -> Result<ScatterArgs<'a>> {
273 if range_index.is_full() {
277 let update = remove_leading_singleton_dimensions(update, &stream)?;
278
279 return Ok(ScatterArgs {
280 indices: smallvec![],
281 update: broadcast_to_device(&update, src.shape(), &stream)?,
282 axes: smallvec![],
283 });
284 }
285
286 let size = src.dim(0);
287 let start = range_index.start(size);
288 let end = range_index.end(size);
289 let stride = range_index.stride();
290
291 if stride == 1 {
293 let update = remove_leading_singleton_dimensions(update, &stream)?;
294
295 let update_broadcast_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (1..end - start)
297 .chain(src.shape().iter().skip(1).cloned())
298 .collect();
299 let update = broadcast_to_device(&update, &update_broadcast_shape, &stream)?;
300
301 let indices = Array::from_slice(&[start], &[1]);
302 Ok(ScatterArgs {
303 indices: smallvec![Cow::Owned(indices)],
304 update,
305 axes: smallvec![0],
306 })
307 } else {
308 let a_vals = strided_range_to_vec(start, end, stride);
310 let a = Array::from_slice(&a_vals, &[a_vals.len() as i32]);
311
312 scatter_args_array(src, Cow::Owned(a), update, stream)
313 }
314}
315
316fn scatter_args_nd<'a>(
317 src: &'a Array,
318 operations: &[ArrayIndexOp],
319 update: &Array,
320 stream: impl AsRef<Stream>,
321) -> Result<ScatterArgs<'a>> {
322 use ArrayIndexOp::*;
323
324 let shape = src.shape();
327
328 let operations = expand_ellipsis_operations(src.ndim(), operations);
329 let update = remove_leading_singleton_dimensions(update, &stream)?;
330
331 let non_new_axis_operation_count = count_non_new_axis_operations(&operations);
333 if non_new_axis_operation_count == 0 {
334 return Ok(ScatterArgs {
335 indices: smallvec![],
336 update: broadcast_to_device(&update, shape, &stream)?,
337 axes: smallvec![],
338 });
339 }
340
341 let mut max_dims = 0;
343 let mut arrays_first = false;
344 let mut count_new_axis: i32 = 0;
345 let mut count_slices: i32 = 0;
346 let mut count_arrays: i32 = 0;
347 let mut count_strided_slices: i32 = 0;
348 let mut count_simple_slices_post: i32 = 0;
349
350 let mut have_array = false;
351 let mut have_non_array = false;
352
353 macro_rules! analyze_indices_take_array {
354 ($indices:ident) => {
355 have_array = true;
356 if have_array && have_non_array {
357 arrays_first = true;
358 }
359 max_dims = $indices.ndim().max(max_dims);
360 count_arrays = count_arrays.saturating_add(1);
361 count_simple_slices_post = 0;
362 };
363 }
364
365 for item in operations.iter() {
366 match item {
367 TakeIndex { index: _ } => {
368 }
370 Slice(range_index) => {
371 have_non_array = have_array;
372 count_slices = count_slices.saturating_add(1);
373 if range_index.stride() != 1 {
374 count_strided_slices = count_strided_slices.saturating_add(1);
375 count_simple_slices_post = 0;
376 } else {
377 count_simple_slices_post = count_simple_slices_post.saturating_add(1);
378 }
379 }
380 TakeArray { indices } => {
381 analyze_indices_take_array!(indices);
382 }
383 TakeArrayRef { indices } => {
384 analyze_indices_take_array!(indices);
385 }
386 ExpandDims => {
387 have_non_array = true;
388 count_new_axis = count_new_axis.saturating_add(1);
389 }
390 Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
391 }
392 }
393
394 let mut index_dims = (max_dims + count_new_axis as usize + count_slices as usize)
396 .saturating_sub(count_simple_slices_post as usize);
397
398 if index_dims == 0 {
400 index_dims = 1;
401 }
402
403 let mut array_indices: SmallVec<[Array; DEFAULT_STACK_VEC_LEN]> =
405 SmallVec::with_capacity(operations.len());
406 let mut slice_number: i32 = 0;
407 let mut array_number: i32 = 0;
408 let mut axis: i32 = 0;
409
410 let mut update_shape = vec![1; non_new_axis_operation_count];
412 let mut slice_shapes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
413
414 macro_rules! update_shapes_take_array {
415 ($indices:ident) => {
416 let start = if arrays_first {
418 max_dims - $indices.ndim()
419 } else {
420 slice_number as usize + max_dims - $indices.ndim()
422 };
423 let mut new_shape = vec![1; index_dims];
424
425 for j in 0..$indices.ndim() {
426 new_shape[start + j] = $indices.dim(j as i32);
427 }
428
429 array_indices.push($indices.reshape_device(&new_shape, &stream)?);
430 array_number = array_number.saturating_add(1);
431
432 if !arrays_first && array_number == count_arrays {
433 slice_number = slice_number.saturating_add_unsigned(max_dims as u32);
434 }
435
436 update_shape[axis as usize] = 1;
438 axis = axis.saturating_add(1);
439 };
440 }
441
442 for item in operations.iter() {
443 match item {
444 TakeIndex { index } => {
445 let resolved_index = resolve_index_signed_unchecked(*index, src.dim(axis));
446 array_indices.push(Array::from_int(resolved_index));
447 update_shape[axis as usize] = 1;
449 axis = axis.saturating_add(1);
450 }
451 Slice(range_index) => {
452 let size = src.dim(axis);
453 let start = range_index.absolute_start(size);
454 let end = range_index.absolute_end(size);
455 let stride = range_index.stride();
456
457 let mut index_shape = vec![1; index_dims];
458
459 if array_number >= count_arrays && count_strided_slices <= 0 && stride == 1 {
461 let index = Array::from_int(start).reshape_device(&index_shape, &stream)?;
462 let slice_shape_entry = end - start;
463 slice_shapes.push(slice_shape_entry);
464 array_indices.push(index);
465
466 update_shape[axis as usize] = slice_shape_entry;
468 } else {
469 let index_vals = strided_range_to_vec(start, end, stride);
471 let index = Array::from_slice(&index_vals, &[index_vals.len() as i32]);
472 let location = if arrays_first {
473 slice_number.saturating_add(max_dims as i32)
474 } else {
475 slice_number
476 };
477 index_shape[location as usize] = index.size() as i32;
478 array_indices.push(index.reshape_device(&index_shape, &stream)?);
479
480 slice_number = slice_number.saturating_add(1);
481 count_strided_slices = count_strided_slices.saturating_sub(1);
482
483 update_shape[axis as usize] = 1;
485 }
486
487 axis = axis.saturating_add(1);
488 }
489 TakeArray { indices } => {
490 update_shapes_take_array!(indices);
491 }
492 TakeArrayRef { indices } => {
493 update_shapes_take_array!(indices);
494 }
495 ExpandDims => slice_number = slice_number.saturating_add(1),
496 Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
497 }
498 }
499
500 let array_indices = broadcast_arrays_device(&array_indices, &stream)?;
502 let update_shape_broadcast: Vec<_> = array_indices[0]
503 .shape()
504 .iter()
505 .chain(slice_shapes.iter())
506 .chain(src.shape().iter().skip(non_new_axis_operation_count))
507 .cloned()
508 .collect();
509 let update = broadcast_to_device(&update, &update_shape_broadcast, &stream)?;
510
511 let update_reshape: Vec<_> = array_indices[0]
513 .shape()
514 .iter()
515 .chain(update_shape.iter())
516 .chain(src.shape().iter().skip(non_new_axis_operation_count))
517 .cloned()
518 .collect();
519
520 let update = update.reshape_device(&update_reshape, &stream)?;
521
522 let array_indices_len = array_indices.len();
523
524 let indices = array_indices.into_iter().map(Cow::Owned).collect();
525 Ok(ScatterArgs {
526 indices,
527 update,
528 axes: (0..array_indices_len as i32).collect(),
529 })
530}
531
/// Returns `start, start + stride, start + 2*stride, …`, stopping before `exclusive_end`.
///
/// For a positive `stride` the values stay `< exclusive_end`; for a negative `stride` they stay
/// `> exclusive_end`. A range whose direction disagrees with the stride's sign is empty. Values
/// are generated without overflowing, even next to `i32::MIN` / `i32::MAX`.
///
/// # Panics
///
/// Panics with "Stride cannot be zero" if `stride == 0`.
fn strided_range_to_vec(start: i32, exclusive_end: i32, stride: i32) -> Vec<i32> {
    // Validate before any arithmetic so a zero stride reports this message, not a div-by-zero.
    assert_ne!(stride, 0, "Stride cannot be zero");

    let step = stride.unsigned_abs() as usize;

    if stride > 0 {
        (start..exclusive_end).step_by(step).collect()
    } else if start > exclusive_end {
        // Descending: walk `start, start-1, …, exclusive_end+1` and keep every `step`-th value.
        // `exclusive_end < start <= i32::MAX`, so `exclusive_end + 1` cannot overflow.
        ((exclusive_end + 1)..=start).rev().step_by(step).collect()
    } else {
        Vec::new()
    }
}
553
554unsafe fn scatter_device(
555 a: &Array,
556 indices: &[impl AsRef<Array>],
557 updates: &Array,
558 axes: &[i32],
559 stream: impl AsRef<Stream>,
560) -> Result<Array> {
561 let indices_vector = VectorArray::try_from_iter(indices.iter())?;
562
563 Array::try_from_op(|res| unsafe {
564 mlx_sys::mlx_scatter(
565 res,
566 a.as_ptr(),
567 indices_vector.as_ptr(),
568 updates.as_ptr(),
569 axes.as_ptr(),
570 axes.len(),
571 stream.as_ref().as_ptr(),
572 )
573 })
574}
575
576impl Array {
577 fn try_index_mut_device_inner(
578 &mut self,
579 operations: &[ArrayIndexOp],
580 update: &Array,
581 stream: impl AsRef<Stream>,
582 ) -> Result<()> {
583 if let Some(result) = update_slice(self, operations, update, &stream)? {
584 *self = result;
585 return Ok(());
586 }
587
588 let ScatterArgs {
589 indices,
590 update,
591 axes,
592 } = scatter_args(self, operations, update, &stream)?;
593 if !indices.is_empty() {
594 let result = unsafe { scatter_device(self, &indices, &update, &axes, stream)? };
595 drop(indices);
596 *self = result;
597 } else {
598 drop(indices);
599 *self = update;
600 }
601 Ok(())
602 }
603}
604
605impl<'a, Val> TryIndexMutOp<&'a [ArrayIndexOp<'a>], Val> for Array
606where
607 Val: AsRef<Array>,
608{
609 fn try_index_mut_device(
610 &mut self,
611 i: &'a [ArrayIndexOp<'a>],
612 val: Val,
613 stream: impl AsRef<Stream>,
614 ) -> Result<()> {
615 let update = val.as_ref();
616 self.try_index_mut_device_inner(i, update, stream)
617 }
618}
619
620impl<A, Val> TryIndexMutOp<A, Val> for Array
621where
622 for<'a> A: ArrayIndex<'a>,
623 Val: AsRef<Array>,
624{
625 fn try_index_mut_device(&mut self, i: A, val: Val, stream: impl AsRef<Stream>) -> Result<()> {
626 let operations = [i.index_op()];
627 let update = val.as_ref();
628 self.try_index_mut_device_inner(&operations, update, stream)
629 }
630}
631
632impl<'a, A, Val> TryIndexMutOp<(A,), Val> for Array
633where
634 A: ArrayIndex<'a>,
635 Val: AsRef<Array>,
636{
637 fn try_index_mut_device(
638 &mut self,
639 (i,): (A,),
640 val: Val,
641 stream: impl AsRef<Stream>,
642 ) -> Result<()> {
643 let operations = [i.index_op()];
644 let update = val.as_ref();
645 self.try_index_mut_device_inner(&operations, update, stream)
646 }
647}
648
649impl<'a, 'b, A, B, Val> TryIndexMutOp<(A, B), Val> for Array
650where
651 A: ArrayIndex<'a>,
652 B: ArrayIndex<'b>,
653 Val: AsRef<Array>,
654{
655 fn try_index_mut_device(
656 &mut self,
657 i: (A, B),
658 val: Val,
659 stream: impl AsRef<Stream>,
660 ) -> Result<()> {
661 let operations = [i.0.index_op(), i.1.index_op()];
662 let update = val.as_ref();
663 self.try_index_mut_device_inner(&operations, update, stream)
664 }
665}
666
667impl<'a, 'b, 'c, A, B, C, Val> TryIndexMutOp<(A, B, C), Val> for Array
668where
669 A: ArrayIndex<'a>,
670 B: ArrayIndex<'b>,
671 C: ArrayIndex<'c>,
672 Val: AsRef<Array>,
673{
674 fn try_index_mut_device(
675 &mut self,
676 i: (A, B, C),
677 val: Val,
678 stream: impl AsRef<Stream>,
679 ) -> Result<()> {
680 let operations = [i.0.index_op(), i.1.index_op(), i.2.index_op()];
681 let update = val.as_ref();
682 self.try_index_mut_device_inner(&operations, update, stream)
683 }
684}
685
686impl<'a, 'b, 'c, 'd, A, B, C, D, Val> TryIndexMutOp<(A, B, C, D), Val> for Array
687where
688 A: ArrayIndex<'a>,
689 B: ArrayIndex<'b>,
690 C: ArrayIndex<'c>,
691 D: ArrayIndex<'d>,
692 Val: AsRef<Array>,
693{
694 fn try_index_mut_device(
695 &mut self,
696 i: (A, B, C, D),
697 val: Val,
698 stream: impl AsRef<Stream>,
699 ) -> Result<()> {
700 let operations = [
701 i.0.index_op(),
702 i.1.index_op(),
703 i.2.index_op(),
704 i.3.index_op(),
705 ];
706 let update = val.as_ref();
707 self.try_index_mut_device_inner(&operations, update, stream)
708 }
709}
710
711impl<'a, 'b, 'c, 'd, 'e, A, B, C, D, E, Val> TryIndexMutOp<(A, B, C, D, E), Val> for Array
712where
713 A: ArrayIndex<'a>,
714 B: ArrayIndex<'b>,
715 C: ArrayIndex<'c>,
716 D: ArrayIndex<'d>,
717 E: ArrayIndex<'e>,
718 Val: AsRef<Array>,
719{
720 fn try_index_mut_device(
721 &mut self,
722 i: (A, B, C, D, E),
723 val: Val,
724 stream: impl AsRef<Stream>,
725 ) -> Result<()> {
726 let operations = [
727 i.0.index_op(),
728 i.1.index_op(),
729 i.2.index_op(),
730 i.3.index_op(),
731 i.4.index_op(),
732 ];
733 let update = val.as_ref();
734 self.try_index_mut_device_inner(&operations, update, stream)
735 }
736}
737
738impl<'a, 'b, 'c, 'd, 'e, 'f, A, B, C, D, E, F, Val> TryIndexMutOp<(A, B, C, D, E, F), Val> for Array
739where
740 A: ArrayIndex<'a>,
741 B: ArrayIndex<'b>,
742 C: ArrayIndex<'c>,
743 D: ArrayIndex<'d>,
744 E: ArrayIndex<'e>,
745 F: ArrayIndex<'f>,
746 Val: AsRef<Array>,
747{
748 fn try_index_mut_device(
749 &mut self,
750 i: (A, B, C, D, E, F),
751 val: Val,
752 stream: impl AsRef<Stream>,
753 ) -> Result<()> {
754 let operations = [
755 i.0.index_op(),
756 i.1.index_op(),
757 i.2.index_op(),
758 i.3.index_op(),
759 i.4.index_op(),
760 i.5.index_op(),
761 ];
762 let update = val.as_ref();
763 self.try_index_mut_device_inner(&operations, update, stream)
764 }
765}
766
767impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, A, B, C, D, E, F, G, Val> TryIndexMutOp<(A, B, C, D, E, F, G), Val>
768 for Array
769where
770 A: ArrayIndex<'a>,
771 B: ArrayIndex<'b>,
772 C: ArrayIndex<'c>,
773 D: ArrayIndex<'d>,
774 E: ArrayIndex<'e>,
775 F: ArrayIndex<'f>,
776 G: ArrayIndex<'g>,
777 Val: AsRef<Array>,
778{
779 fn try_index_mut_device(
780 &mut self,
781 i: (A, B, C, D, E, F, G),
782 val: Val,
783 stream: impl AsRef<Stream>,
784 ) -> Result<()> {
785 let operations = [
786 i.0.index_op(),
787 i.1.index_op(),
788 i.2.index_op(),
789 i.3.index_op(),
790 i.4.index_op(),
791 i.5.index_op(),
792 i.6.index_op(),
793 ];
794 let update = val.as_ref();
795 self.try_index_mut_device_inner(&operations, update, stream)
796 }
797}
798
799impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, A, B, C, D, E, F, G, H, Val>
800 TryIndexMutOp<(A, B, C, D, E, F, G, H), Val> for Array
801where
802 A: ArrayIndex<'a>,
803 B: ArrayIndex<'b>,
804 C: ArrayIndex<'c>,
805 D: ArrayIndex<'d>,
806 E: ArrayIndex<'e>,
807 F: ArrayIndex<'f>,
808 G: ArrayIndex<'g>,
809 H: ArrayIndex<'h>,
810 Val: AsRef<Array>,
811{
812 fn try_index_mut_device(
813 &mut self,
814 i: (A, B, C, D, E, F, G, H),
815 val: Val,
816 stream: impl AsRef<Stream>,
817 ) -> Result<()> {
818 let operations = [
819 i.0.index_op(),
820 i.1.index_op(),
821 i.2.index_op(),
822 i.3.index_op(),
823 i.4.index_op(),
824 i.5.index_op(),
825 i.6.index_op(),
826 i.7.index_op(),
827 ];
828 let update = val.as_ref();
829 self.try_index_mut_device_inner(&operations, update, stream)
830 }
831}
832
833impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, A, B, C, D, E, F, G, H, I, Val>
834 TryIndexMutOp<(A, B, C, D, E, F, G, H, I), Val> for Array
835where
836 A: ArrayIndex<'a>,
837 B: ArrayIndex<'b>,
838 C: ArrayIndex<'c>,
839 D: ArrayIndex<'d>,
840 E: ArrayIndex<'e>,
841 F: ArrayIndex<'f>,
842 G: ArrayIndex<'g>,
843 H: ArrayIndex<'h>,
844 I: ArrayIndex<'i>,
845 Val: AsRef<Array>,
846{
847 fn try_index_mut_device(
848 &mut self,
849 i: (A, B, C, D, E, F, G, H, I),
850 val: Val,
851 stream: impl AsRef<Stream>,
852 ) -> Result<()> {
853 let operations = [
854 i.0.index_op(),
855 i.1.index_op(),
856 i.2.index_op(),
857 i.3.index_op(),
858 i.4.index_op(),
859 i.5.index_op(),
860 i.6.index_op(),
861 i.7.index_op(),
862 i.8.index_op(),
863 ];
864 let update = val.as_ref();
865 self.try_index_mut_device_inner(&operations, update, stream)
866 }
867}
868
869impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, A, B, C, D, E, F, G, H, I, J, Val>
870 TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J), Val> for Array
871where
872 A: ArrayIndex<'a>,
873 B: ArrayIndex<'b>,
874 C: ArrayIndex<'c>,
875 D: ArrayIndex<'d>,
876 E: ArrayIndex<'e>,
877 F: ArrayIndex<'f>,
878 G: ArrayIndex<'g>,
879 H: ArrayIndex<'h>,
880 I: ArrayIndex<'i>,
881 J: ArrayIndex<'j>,
882 Val: AsRef<Array>,
883{
884 fn try_index_mut_device(
885 &mut self,
886 i: (A, B, C, D, E, F, G, H, I, J),
887 val: Val,
888 stream: impl AsRef<Stream>,
889 ) -> Result<()> {
890 let operations = [
891 i.0.index_op(),
892 i.1.index_op(),
893 i.2.index_op(),
894 i.3.index_op(),
895 i.4.index_op(),
896 i.5.index_op(),
897 i.6.index_op(),
898 i.7.index_op(),
899 i.8.index_op(),
900 i.9.index_op(),
901 ];
902 let update = val.as_ref();
903 self.try_index_mut_device_inner(&operations, update, stream)
904 }
905}
906
907impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, A, B, C, D, E, F, G, H, I, J, K, Val>
908 TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K), Val> for Array
909where
910 A: ArrayIndex<'a>,
911 B: ArrayIndex<'b>,
912 C: ArrayIndex<'c>,
913 D: ArrayIndex<'d>,
914 E: ArrayIndex<'e>,
915 F: ArrayIndex<'f>,
916 G: ArrayIndex<'g>,
917 H: ArrayIndex<'h>,
918 I: ArrayIndex<'i>,
919 J: ArrayIndex<'j>,
920 K: ArrayIndex<'k>,
921 Val: AsRef<Array>,
922{
923 fn try_index_mut_device(
924 &mut self,
925 i: (A, B, C, D, E, F, G, H, I, J, K),
926 val: Val,
927 stream: impl AsRef<Stream>,
928 ) -> Result<()> {
929 let operations = [
930 i.0.index_op(),
931 i.1.index_op(),
932 i.2.index_op(),
933 i.3.index_op(),
934 i.4.index_op(),
935 i.5.index_op(),
936 i.6.index_op(),
937 i.7.index_op(),
938 i.8.index_op(),
939 i.9.index_op(),
940 i.10.index_op(),
941 ];
942 let update = val.as_ref();
943 self.try_index_mut_device_inner(&operations, update, stream)
944 }
945}
946
947impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, A, B, C, D, E, F, G, H, I, J, K, L, Val>
948 TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L), Val> for Array
949where
950 A: ArrayIndex<'a>,
951 B: ArrayIndex<'b>,
952 C: ArrayIndex<'c>,
953 D: ArrayIndex<'d>,
954 E: ArrayIndex<'e>,
955 F: ArrayIndex<'f>,
956 G: ArrayIndex<'g>,
957 H: ArrayIndex<'h>,
958 I: ArrayIndex<'i>,
959 J: ArrayIndex<'j>,
960 K: ArrayIndex<'k>,
961 L: ArrayIndex<'l>,
962 Val: AsRef<Array>,
963{
964 fn try_index_mut_device(
965 &mut self,
966 i: (A, B, C, D, E, F, G, H, I, J, K, L),
967 val: Val,
968 stream: impl AsRef<Stream>,
969 ) -> Result<()> {
970 let operations = [
971 i.0.index_op(),
972 i.1.index_op(),
973 i.2.index_op(),
974 i.3.index_op(),
975 i.4.index_op(),
976 i.5.index_op(),
977 i.6.index_op(),
978 i.7.index_op(),
979 i.8.index_op(),
980 i.9.index_op(),
981 i.10.index_op(),
982 i.11.index_op(),
983 ];
984 let update = val.as_ref();
985 self.try_index_mut_device_inner(&operations, update, stream)
986 }
987}
988
989impl<
990 'a,
991 'b,
992 'c,
993 'd,
994 'e,
995 'f,
996 'g,
997 'h,
998 'i,
999 'j,
1000 'k,
1001 'l,
1002 'm,
1003 A,
1004 B,
1005 C,
1006 D,
1007 E,
1008 F,
1009 G,
1010 H,
1011 I,
1012 J,
1013 K,
1014 L,
1015 M,
1016 Val,
1017 > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M), Val> for Array
1018where
1019 A: ArrayIndex<'a>,
1020 B: ArrayIndex<'b>,
1021 C: ArrayIndex<'c>,
1022 D: ArrayIndex<'d>,
1023 E: ArrayIndex<'e>,
1024 F: ArrayIndex<'f>,
1025 G: ArrayIndex<'g>,
1026 H: ArrayIndex<'h>,
1027 I: ArrayIndex<'i>,
1028 J: ArrayIndex<'j>,
1029 K: ArrayIndex<'k>,
1030 L: ArrayIndex<'l>,
1031 M: ArrayIndex<'m>,
1032 Val: AsRef<Array>,
1033{
1034 fn try_index_mut_device(
1035 &mut self,
1036 i: (A, B, C, D, E, F, G, H, I, J, K, L, M),
1037 val: Val,
1038 stream: impl AsRef<Stream>,
1039 ) -> Result<()> {
1040 let operations = [
1041 i.0.index_op(),
1042 i.1.index_op(),
1043 i.2.index_op(),
1044 i.3.index_op(),
1045 i.4.index_op(),
1046 i.5.index_op(),
1047 i.6.index_op(),
1048 i.7.index_op(),
1049 i.8.index_op(),
1050 i.9.index_op(),
1051 i.10.index_op(),
1052 i.11.index_op(),
1053 i.12.index_op(),
1054 ];
1055 let update = val.as_ref();
1056 self.try_index_mut_device_inner(&operations, update, stream)
1057 }
1058}
1059
1060impl<
1061 'a,
1062 'b,
1063 'c,
1064 'd,
1065 'e,
1066 'f,
1067 'g,
1068 'h,
1069 'i,
1070 'j,
1071 'k,
1072 'l,
1073 'm,
1074 'n,
1075 A,
1076 B,
1077 C,
1078 D,
1079 E,
1080 F,
1081 G,
1082 H,
1083 I,
1084 J,
1085 K,
1086 L,
1087 M,
1088 N,
1089 Val,
1090 > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N), Val> for Array
1091where
1092 A: ArrayIndex<'a>,
1093 B: ArrayIndex<'b>,
1094 C: ArrayIndex<'c>,
1095 D: ArrayIndex<'d>,
1096 E: ArrayIndex<'e>,
1097 F: ArrayIndex<'f>,
1098 G: ArrayIndex<'g>,
1099 H: ArrayIndex<'h>,
1100 I: ArrayIndex<'i>,
1101 J: ArrayIndex<'j>,
1102 K: ArrayIndex<'k>,
1103 L: ArrayIndex<'l>,
1104 M: ArrayIndex<'m>,
1105 N: ArrayIndex<'n>,
1106 Val: AsRef<Array>,
1107{
1108 fn try_index_mut_device(
1109 &mut self,
1110 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N),
1111 val: Val,
1112 stream: impl AsRef<Stream>,
1113 ) -> Result<()> {
1114 let operations = [
1115 i.0.index_op(),
1116 i.1.index_op(),
1117 i.2.index_op(),
1118 i.3.index_op(),
1119 i.4.index_op(),
1120 i.5.index_op(),
1121 i.6.index_op(),
1122 i.7.index_op(),
1123 i.8.index_op(),
1124 i.9.index_op(),
1125 i.10.index_op(),
1126 i.11.index_op(),
1127 i.12.index_op(),
1128 i.13.index_op(),
1129 ];
1130 let update = val.as_ref();
1131 self.try_index_mut_device_inner(&operations, update, stream)
1132 }
1133}
1134
1135impl<
1136 'a,
1137 'b,
1138 'c,
1139 'd,
1140 'e,
1141 'f,
1142 'g,
1143 'h,
1144 'i,
1145 'j,
1146 'k,
1147 'l,
1148 'm,
1149 'n,
1150 'o,
1151 A,
1152 B,
1153 C,
1154 D,
1155 E,
1156 F,
1157 G,
1158 H,
1159 I,
1160 J,
1161 K,
1162 L,
1163 M,
1164 N,
1165 O,
1166 Val,
1167 > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O), Val> for Array
1168where
1169 A: ArrayIndex<'a>,
1170 B: ArrayIndex<'b>,
1171 C: ArrayIndex<'c>,
1172 D: ArrayIndex<'d>,
1173 E: ArrayIndex<'e>,
1174 F: ArrayIndex<'f>,
1175 G: ArrayIndex<'g>,
1176 H: ArrayIndex<'h>,
1177 I: ArrayIndex<'i>,
1178 J: ArrayIndex<'j>,
1179 K: ArrayIndex<'k>,
1180 L: ArrayIndex<'l>,
1181 M: ArrayIndex<'m>,
1182 N: ArrayIndex<'n>,
1183 O: ArrayIndex<'o>,
1184 Val: AsRef<Array>,
1185{
1186 fn try_index_mut_device(
1187 &mut self,
1188 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O),
1189 val: Val,
1190 stream: impl AsRef<Stream>,
1191 ) -> Result<()> {
1192 let operations = [
1193 i.0.index_op(),
1194 i.1.index_op(),
1195 i.2.index_op(),
1196 i.3.index_op(),
1197 i.4.index_op(),
1198 i.5.index_op(),
1199 i.6.index_op(),
1200 i.7.index_op(),
1201 i.8.index_op(),
1202 i.9.index_op(),
1203 i.10.index_op(),
1204 i.11.index_op(),
1205 i.12.index_op(),
1206 i.13.index_op(),
1207 i.14.index_op(),
1208 ];
1209 let update = val.as_ref();
1210 self.try_index_mut_device_inner(&operations, update, stream)
1211 }
1212}
1213
1214impl<
1215 'a,
1216 'b,
1217 'c,
1218 'd,
1219 'e,
1220 'f,
1221 'g,
1222 'h,
1223 'i,
1224 'j,
1225 'k,
1226 'l,
1227 'm,
1228 'n,
1229 'o,
1230 'p,
1231 A,
1232 B,
1233 C,
1234 D,
1235 E,
1236 F,
1237 G,
1238 H,
1239 I,
1240 J,
1241 K,
1242 L,
1243 M,
1244 N,
1245 O,
1246 P,
1247 Val,
1248 > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P), Val> for Array
1249where
1250 A: ArrayIndex<'a>,
1251 B: ArrayIndex<'b>,
1252 C: ArrayIndex<'c>,
1253 D: ArrayIndex<'d>,
1254 E: ArrayIndex<'e>,
1255 F: ArrayIndex<'f>,
1256 G: ArrayIndex<'g>,
1257 H: ArrayIndex<'h>,
1258 I: ArrayIndex<'i>,
1259 J: ArrayIndex<'j>,
1260 K: ArrayIndex<'k>,
1261 L: ArrayIndex<'l>,
1262 M: ArrayIndex<'m>,
1263 N: ArrayIndex<'n>,
1264 O: ArrayIndex<'o>,
1265 P: ArrayIndex<'p>,
1266 Val: AsRef<Array>,
1267{
1268 fn try_index_mut_device(
1269 &mut self,
1270 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P),
1271 val: Val,
1272 stream: impl AsRef<Stream>,
1273 ) -> Result<()> {
1274 let operations = [
1275 i.0.index_op(),
1276 i.1.index_op(),
1277 i.2.index_op(),
1278 i.3.index_op(),
1279 i.4.index_op(),
1280 i.5.index_op(),
1281 i.6.index_op(),
1282 i.7.index_op(),
1283 i.8.index_op(),
1284 i.9.index_op(),
1285 i.10.index_op(),
1286 i.11.index_op(),
1287 i.12.index_op(),
1288 i.13.index_op(),
1289 i.14.index_op(),
1290 i.15.index_op(),
1291 ];
1292 let update = val.as_ref();
1293 self.try_index_mut_device_inner(&operations, update, stream)
1294 }
1295}
1296
// Tests for indexed assignment (`index_mut` / `try_index_mut`) across the slice-update fast path
// and the scatter-based paths.
#[cfg(test)]
mod tests {
    use crate::{
        ops::{indexing::*, ones, zeros},
        Array,
    };

    // An integer index on axis 0 overwrites a whole row with a broadcast scalar.
    #[test]
    fn test_array_mutate_single_index() {
        let mut a = Array::from_iter(0i32..12, &[3, 4]);
        let new_value = Array::from_int(77);
        a.index_mut(1, new_value);

        let expected = Array::from_slice(&[0, 1, 2, 3, 77, 77, 77, 77, 8, 9, 10, 11], &[3, 4]);
        assert_array_all_close!(a, expected);
    }

    // Tuples of integer indices with scalar and 1-D updates, applied one after another.
    #[test]
    fn test_array_mutate_broadcast_multi_index() {
        let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);

        a.index_mut((1, 0), Array::from_int(77));

        a.index_mut((0, 0), Array::from_slice(&[55i32, 66, 77, 88, 99], &[5]));

        a.index_mut((0, 1, 3), Array::from_int(123));

        let expected = Array::from_slice(
            &[
                55, 66, 77, 88, 99, 5, 6, 7, 123, 9, 77, 77, 77, 77, 77, 15, 16, 17, 18, 19,
            ],
            &[2, 2, 5],
        );
        assert_array_all_close!(a, expected);
    }

    // Range indices on every axis with a broadcast scalar update.
    #[test]
    fn test_array_mutate_broadcast_slice() {
        let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);

        a.index_mut((0..1, 1..2, 2..4), Array::from_int(88));

        let expected = Array::from_slice(
            &[
                0, 1, 2, 3, 4, 5, 6, 88, 88, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
            ],
            &[2, 2, 5],
        );
        assert_array_all_close!(a, expected);
    }

    // Two 1-D array indices select (row, col) pairs; each pair receives one update value.
    #[test]
    fn test_array_mutate_advanced() {
        let mut a = Array::from_iter(0i32..35, &[5, 7]);

        let i1 = Array::from_slice(&[0, 2, 4], &[3]);
        let i2 = Array::from_slice(&[0, 1, 2], &[3]);

        a.index_mut((i1, i2), Array::from_slice(&[100, 200, 300], &[3]));

        assert_eq!(a.index((0, 0)).item::<i32>(), 100i32);
        assert_eq!(a.index((2, 1)).item::<i32>(), 200i32);
        assert_eq!(a.index((4, 2)).item::<i32>(), 300i32);
    }

    // Single (non-tuple) indices: writes 1 into the selected region of 0..60 and checks the sum.
    #[test]
    fn test_full_index_write_single() {
        fn check<I>(index: I, expected_sum: i32)
        where
            for<'a> I: ArrayIndex<'a>,
        {
            let mut a = Array::from_iter(0..60, &[3, 4, 5]);

            a.index_mut(index, Array::from_int(1));
            let sum = a.sum(None).unwrap().item::<i32>();
            assert_eq!(sum, expected_sum);
        }

        check(NewAxis, 60);

        check(0, 1600);

        check(1..3, 230);

        let i = Array::from_slice(&[2, 1], &[2]);

        check(i, 230);
    }

    // Tuple indices without array operations: writes 1 and compares the total sum.
    #[test]
    fn test_full_index_write_no_array() {
        macro_rules! check {
            (($( $i:expr ),*), $sum:expr ) => {
                {
                    let mut a = Array::from_iter(0..360, &[2, 3, 4, 5, 3]);

                    a.index_mut(($($i),*), Array::from_int(1));
                    let sum = a.sum(None).unwrap().item::<i32>();
                    assert_eq!(sum, $sum);
                }
            };
        }

        check!((Ellipsis, 0), 43320);

        check!((0, Ellipsis), 48690);

        check!((0, Ellipsis, 0), 59370);

        check!((Ellipsis, (..).stride_by(2), ..), 26064);

        check!((Ellipsis, NewAxis, (..).stride_by(2), -1), 51696);

        check!((.., 2.., 0), 58140);

        check!(
            (
                (..).stride_by(-1),
                ..2,
                2..,
                Ellipsis,
                NewAxis,
                (..).stride_by(2)
            ),
            51540
        );
    }

    // Tuple indices that include array operations (scatter path): writes 1 and checks the sum.
    #[test]
    fn test_full_index_write_array() {
        macro_rules! check {
            (($( $i:expr ),*), $sum:expr ) => {
                {
                    let mut a = Array::from_iter(0..540, &[3, 3, 4, 5, 3]);

                    a.index_mut(($($i),*), Array::from_int(1));
                    let sum = a.sum(None).unwrap().item::<i32>();
                    assert_eq!(sum, $sum);
                }
            };
        }

        let i = Array::from_slice(&[2, 1], &[2]);

        check!((0, &i), 131310);

        check!((Ellipsis, &i, 0), 126378);

        check!((&i, 0, Ellipsis), 109710);

        check!((&i, Ellipsis, &i), 102450);

        check!((&i, Ellipsis, (..).stride_by(2), ..), 68094);

        check!((Ellipsis, &i, NewAxis, (..).stride_by(2), -1), 130977);

        check!((.., 2.., &i), 115965);

        check!(
            (
                (..).stride_by(-1),
                ..2,
                i,
                2..,
                Ellipsis,
                NewAxis,
                (..).stride_by(2)
            ),
            128142
        );
    }

    // A [4, 2] update assigned to xs[:, 0, :] (shape [4, 2]) must succeed without a shape error.
    #[test]
    fn test_slice_update_with_broadcast() {
        let mut xs = zeros::<f32>(&[4, 3, 2]).unwrap();
        let x = ones::<f32>(&[4, 2]).unwrap();

        let result = xs.try_index_mut((.., 0, ..), x);
        assert!(
            result.is_ok(),
            "Failed to update slice with broadcast: {result:?}"
        );
    }
}
1511}