1use std::borrow::Cow;
2
3use smallvec::{smallvec, SmallVec};
4
5use crate::{
6 constants::DEFAULT_STACK_VEC_LEN,
7 error::Result,
8 ops::{
9 broadcast_arrays_device, broadcast_to_device,
10 indexing::{count_non_new_axis_operations, expand_ellipsis_operations},
11 },
12 utils::{resolve_index_signed_unchecked, VectorArray},
13 Array, Stream,
14};
15
16use super::{ArrayIndex, ArrayIndexOp, Guarded, RangeIndex, TryIndexMutOp};
17
18impl Array {
19 pub(crate) fn slice_update_device(
20 &self,
21 update: &Array,
22 starts: &[i32],
23 ends: &[i32],
24 strides: &[i32],
25 stream: impl AsRef<Stream>,
26 ) -> Result<Array> {
27 Array::try_from_op(|res| unsafe {
28 mlx_sys::mlx_slice_update(
29 res,
30 self.as_ptr(),
31 update.as_ptr(),
32 starts.as_ptr(),
33 starts.len(),
34 ends.as_ptr(),
35 ends.len(),
36 strides.as_ptr(),
37 strides.len(),
38 stream.as_ref().as_ptr(),
39 )
40 })
41 }
42}
43
44fn update_slice(
46 src: &Array,
47 operations: &[ArrayIndexOp],
48 update: &Array,
49 stream: impl AsRef<Stream>,
50) -> Result<Option<Array>> {
51 let ndim = src.ndim();
52 if ndim == 0 || operations.is_empty() {
53 return Ok(None);
54 }
55
56 let mut update = remove_leading_singleton_dimensions(update, &stream)?;
58
59 let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
61 let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
62 let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];
63
64 if operations.len() == 1 {
66 if let ArrayIndexOp::Slice(range_index) = &operations[0] {
67 let size = src.dim(0);
68 starts[0] = range_index.start(size);
69 ends[0] = range_index.end(size);
70 strides[0] = range_index.stride();
71
72 return Ok(Some(src.slice_update_device(
73 &update, &starts, &ends, &strides, &stream,
74 )?));
75 }
76 }
77
78 if operations.iter().any(|op| op.is_array()) {
80 return Ok(None);
81 }
82
83 let operations = expand_ellipsis_operations(ndim, operations);
85
86 if count_non_new_axis_operations(&operations) == 0 {
88 return Ok(Some(broadcast_to_device(&update, src.shape(), &stream)?));
89 }
90
91 let mut update_expand_dims: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
93 let mut axis = 0i32;
94 for item in operations.iter() {
95 use ArrayIndexOp::*;
96
97 match item {
98 TakeIndex { index } => {
99 let size = src.dim(axis);
100 let index = if index.is_negative() {
101 size + index
102 } else {
103 *index
104 };
105 starts[axis as usize] = index;
107 ends[axis as usize] = index + 1;
108 if ndim - (axis as usize) < update.ndim() {
109 update_expand_dims.push(axis.saturating_sub_unsigned(ndim as u32));
110 }
111
112 axis = axis.saturating_add(1);
113 }
114 Slice(range_index) => {
115 let size = src.dim(axis);
116 starts[axis as usize] = range_index.start(size);
118 ends[axis as usize] = range_index.end(size);
119 strides[axis as usize] = range_index.stride();
120 axis = axis.saturating_add(1);
121 }
122 ExpandDims => {}
123 Ellipsis | TakeArray { indices: _ } | TakeArrayRef { indices: _ } => {
124 panic!("unexpected item in operations")
125 }
126 }
127 }
128
129 if !update_expand_dims.is_empty() {
130 update = Cow::Owned(update.expand_dims_device(&update_expand_dims, &stream)?);
131 }
132
133 Ok(Some(src.slice_update_device(
134 &update, &starts, &ends, &strides, &stream,
135 )?))
136}
137
138fn remove_leading_singleton_dimensions(
140 a: &Array,
141 stream: impl AsRef<Stream>,
142) -> Result<Cow<'_, Array>> {
143 let shape = a.shape();
144 let mut new_shape: Vec<_> = shape.iter().skip_while(|&&dim| dim == 1).cloned().collect();
145 if shape != new_shape {
146 if new_shape.is_empty() {
147 new_shape = vec![1];
148 }
149 Ok(Cow::Owned(a.reshape_device(&new_shape, stream)?))
150 } else {
151 Ok(Cow::Borrowed(a))
152 }
153}
154
155struct ScatterArgs<'a> {
156 indices: SmallVec<[Cow<'a, Array>; DEFAULT_STACK_VEC_LEN]>,
157 update: Array,
158 axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>,
159}
160
161fn scatter_args<'a>(
163 src: &'a Array,
164 operations: &'a [ArrayIndexOp],
165 update: &Array,
166 stream: impl AsRef<Stream>,
167) -> Result<ScatterArgs<'a>> {
168 use ArrayIndexOp::*;
169
170 if operations.len() == 1 {
171 return match &operations[0] {
172 TakeIndex { index } => scatter_args_index(src, *index, update, stream),
173 TakeArray { indices } => {
174 scatter_args_array(src, Cow::Borrowed(indices), update, stream)
175 }
176 TakeArrayRef { indices } => {
177 scatter_args_array(src, Cow::Borrowed(indices), update, stream)
178 }
179 Slice(range_index) => scatter_args_slice(src, range_index, update, stream),
180 ExpandDims => Ok(ScatterArgs {
181 indices: smallvec![],
182 update: broadcast_to_device(update, src.shape(), &stream)?,
183 axes: smallvec![],
184 }),
185 Ellipsis => panic!("Unable to update array with ellipsis argument"),
186 };
187 }
188
189 scatter_args_nd(src, operations, update, stream)
190}
191
192fn scatter_args_index<'a>(
193 src: &'a Array,
194 index: i32,
195 update: &Array,
196 stream: impl AsRef<Stream>,
197) -> Result<ScatterArgs<'a>> {
198 let update = remove_leading_singleton_dimensions(update, &stream)?;
203
204 let mut shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
205 shape[0] = 1;
206
207 Ok(ScatterArgs {
208 indices: smallvec![Cow::Owned(Array::from_int(resolve_index_signed_unchecked(
209 index,
210 src.dim(0)
211 )))],
212 update: broadcast_to_device(&update, &shape, &stream)?,
213 axes: smallvec![0],
214 })
215}
216
217fn scatter_args_array<'a>(
218 src: &'a Array,
219 a: Cow<'a, Array>,
220 update: &Array,
221 stream: impl AsRef<Stream>,
222) -> Result<ScatterArgs<'a>> {
223 let update = remove_leading_singleton_dimensions(update, &stream)?;
227
228 let mut update_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = a
230 .shape()
231 .iter()
232 .chain(src.shape().iter().skip(1))
233 .cloned()
234 .collect();
235 let update = broadcast_to_device(&update, &update_shape, &stream)?;
236
237 update_shape.insert(a.ndim(), 1);
238 let update = update.reshape_device(&update_shape, &stream)?;
239
240 Ok(ScatterArgs {
241 indices: smallvec![a],
242 update,
243 axes: smallvec![0],
244 })
245}
246
247fn scatter_args_slice<'a>(
248 src: &'a Array,
249 range_index: &'a RangeIndex,
250 update: &Array,
251 stream: impl AsRef<Stream>,
252) -> Result<ScatterArgs<'a>> {
253 if range_index.is_full() {
257 let update = remove_leading_singleton_dimensions(update, &stream)?;
258
259 return Ok(ScatterArgs {
260 indices: smallvec![],
261 update: broadcast_to_device(&update, src.shape(), &stream)?,
262 axes: smallvec![],
263 });
264 }
265
266 let size = src.dim(0);
267 let start = range_index.start(size);
268 let end = range_index.end(size);
269 let stride = range_index.stride();
270
271 if stride == 1 {
273 let update = remove_leading_singleton_dimensions(update, &stream)?;
274
275 let update_broadcast_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (1..end - start)
277 .chain(src.shape().iter().skip(1).cloned())
278 .collect();
279 let update = broadcast_to_device(&update, &update_broadcast_shape, &stream)?;
280
281 let indices = Array::from_slice(&[start], &[1]);
282 Ok(ScatterArgs {
283 indices: smallvec![Cow::Owned(indices)],
284 update,
285 axes: smallvec![0],
286 })
287 } else {
288 let a_vals = strided_range_to_vec(start, end, stride);
290 let a = Array::from_slice(&a_vals, &[a_vals.len() as i32]);
291
292 scatter_args_array(src, Cow::Owned(a), update, stream)
293 }
294}
295
/// Builds scatter arguments for a general index expression with several operations
/// (`src[op0, op1, ...] = update`).
///
/// After any ellipsis is expanded, every operation other than `ExpandDims` pushes exactly one
/// index array, and the arrays are scattered along axes `0..n`, one axis per array. Here `n`
/// is the number of non-new-axis operations.
///
/// The update is first broadcast to `index_shape ++ slice_shapes ++ src.shape[n..]`. It is then
/// reshaped to `index_shape ++ update_shape ++ src.shape[n..]`, where `index_shape` is the
/// common broadcast shape of the index arrays.
///
/// If every operation is `ExpandDims`, no indices are produced and `update` is simply
/// broadcast to `src`'s shape.
///
/// # Panics
///
/// Panics if an `Ellipsis` is still present after `expand_ellipsis_operations`.
fn scatter_args_nd<'a>(
    src: &'a Array,
    operations: &[ArrayIndexOp],
    update: &Array,
    stream: impl AsRef<Stream>,
) -> Result<ScatterArgs<'a>> {
    use ArrayIndexOp::*;

    let shape = src.shape();

    // Replace any ellipsis with the equivalent run of operations for the remaining axes.
    let operations = expand_ellipsis_operations(src.ndim(), operations);
    let update = remove_leading_singleton_dimensions(update, &stream)?;

    // Only new-axis operations: nothing narrows `src`, so the update covers all of it.
    let non_new_axis_operation_count = count_non_new_axis_operations(&operations);
    if non_new_axis_operation_count == 0 {
        return Ok(ScatterArgs {
            indices: smallvec![],
            update: broadcast_to_device(&update, shape, &stream)?,
            axes: smallvec![],
        });
    }

    // ---- Pass 1: classify the operations to decide the layout of the index arrays. ----
    //
    // `max_dims`: largest ndim among the array indices.
    // `arrays_first`: set when an array index is seen while `have_non_array` is true.
    // `count_simple_slices_post`: unit-stride slices seen since the last array index or
    //     strided slice, i.e. the trailing run of simple slices.
    let mut max_dims = 0;
    let mut arrays_first = false;
    let mut count_new_axis: i32 = 0;
    let mut count_slices: i32 = 0;
    let mut count_arrays: i32 = 0;
    let mut count_strided_slices: i32 = 0;
    let mut count_simple_slices_post: i32 = 0;

    let mut have_array = false;
    let mut have_non_array = false;

    // Shared handling for `TakeArray` / `TakeArrayRef`, whose `indices` bindings have
    // different types (`&Array` vs `&&Array`) and so cannot share one match arm.
    macro_rules! analyze_indices_take_array {
        ($indices:ident) => {
            have_array = true;
            if have_array && have_non_array {
                arrays_first = true;
            }
            max_dims = $indices.ndim().max(max_dims);
            count_arrays = count_arrays.saturating_add(1);
            count_simple_slices_post = 0;
        };
    }

    for item in operations.iter() {
        match item {
            TakeIndex { index: _ } => {
                // Integer indices do not affect any of the pass-1 counters.
            }
            Slice(range_index) => {
                // Only counts as "non-array after an array" if an array was already seen.
                have_non_array = have_array;
                count_slices = count_slices.saturating_add(1);
                if range_index.stride() != 1 {
                    count_strided_slices = count_strided_slices.saturating_add(1);
                    count_simple_slices_post = 0;
                } else {
                    count_simple_slices_post = count_simple_slices_post.saturating_add(1);
                }
            }
            TakeArray { indices } => {
                analyze_indices_take_array!(indices);
            }
            TakeArrayRef { indices } => {
                analyze_indices_take_array!(indices);
            }
            ExpandDims => {
                // NOTE(review): unlike `Slice` (which sets `have_non_array = have_array`), a new
                // axis sets `have_non_array` unconditionally, so a new axis *before* any array
                // also makes a later array set `arrays_first`. Confirm this asymmetry is intended.
                have_non_array = true;
                count_new_axis = count_new_axis.saturating_add(1);
            }
            Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
        }
    }

    // Rank of every index array before broadcasting. It is one dim per array dim, per new axis
    // and per slice, minus the trailing run of unit-stride slices. It is clamped to at least 1.
    let mut index_dims = (max_dims + count_new_axis as usize + count_slices as usize)
        .saturating_sub(count_simple_slices_post as usize);

    if index_dims == 0 {
        index_dims = 1;
    }

    // ---- Pass 2: build one index array per non-new-axis operation. ----
    //
    // `slice_number`: next index dim to be used by a strided slice / new axis.
    // `array_number`: how many array indices have been placed so far.
    // `axis`: the `src` axis the current operation addresses.
    let mut array_indices: SmallVec<[Array; DEFAULT_STACK_VEC_LEN]> =
        SmallVec::with_capacity(operations.len());
    let mut slice_number: i32 = 0;
    let mut array_number: i32 = 0;
    let mut axis: i32 = 0;

    // Per-axis extent of the update for the indexed axes (1 unless a simple slice sets it).
    let mut update_shape = vec![1; non_new_axis_operation_count];
    // Lengths of the simple (unit-stride, start-index) slices, in order.
    let mut slice_shapes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();

    // Places an array index into its block of dims within an `index_dims`-rank shape.
    macro_rules! update_shapes_take_array {
        ($indices:ident) => {
            // Arrays are right-aligned within the `max_dims` block. That block starts at dim 0
            // when `arrays_first` is set, otherwise after the dims already used by slices.
            let start = if arrays_first {
                max_dims - $indices.ndim()
            } else {
                slice_number as usize + max_dims - $indices.ndim()
            };
            let mut new_shape = vec![1; index_dims];

            for j in 0..$indices.ndim() {
                new_shape[start + j] = $indices.dim(j as i32);
            }

            array_indices.push($indices.reshape_device(&new_shape, &stream)?);
            array_number = array_number.saturating_add(1);

            // After the last array (when arrays are not first), later slices use the dims
            // following the shared array block.
            if !arrays_first && array_number == count_arrays {
                slice_number = slice_number.saturating_add_unsigned(max_dims as u32);
            }

            update_shape[axis as usize] = 1;
            axis = axis.saturating_add(1);
        };
    }

    for item in operations.iter() {
        match item {
            TakeIndex { index } => {
                // A scalar index broadcasts against every other index array.
                let resolved_index = resolve_index_signed_unchecked(*index, src.dim(axis));
                array_indices.push(Array::from_int(resolved_index));
                update_shape[axis as usize] = 1;
                axis = axis.saturating_add(1);
            }
            Slice(range_index) => {
                let size = src.dim(axis);
                let start = range_index.absolute_start(size);
                let end = range_index.absolute_end(size);
                let stride = range_index.stride();

                let mut index_shape = vec![1; index_dims];

                if array_number >= count_arrays && count_strided_slices <= 0 && stride == 1 {
                    // Simple slice (after all arrays, no strided slices pending): index only the
                    // start and let the update carry the slice length on this axis.
                    let index = Array::from_int(start).reshape_device(&index_shape, &stream)?;
                    let slice_shape_entry = end - start;
                    slice_shapes.push(slice_shape_entry);
                    array_indices.push(index);

                    update_shape[axis as usize] = slice_shape_entry;
                } else {
                    // Otherwise materialise every position of the slice as its own index dim.
                    let index_vals = strided_range_to_vec(start, end, stride);
                    let index = Array::from_slice(&index_vals, &[index_vals.len() as i32]);
                    let location = if arrays_first {
                        slice_number.saturating_add(max_dims as i32)
                    } else {
                        slice_number
                    };
                    index_shape[location as usize] = index.size() as i32;
                    array_indices.push(index.reshape_device(&index_shape, &stream)?);

                    slice_number = slice_number.saturating_add(1);
                    count_strided_slices = count_strided_slices.saturating_sub(1);

                    update_shape[axis as usize] = 1;
                }

                axis = axis.saturating_add(1);
            }
            TakeArray { indices } => {
                update_shapes_take_array!(indices);
            }
            TakeArrayRef { indices } => {
                update_shapes_take_array!(indices);
            }
            // A new axis only reserves an index dim; it does not consume an axis of `src`.
            ExpandDims => slice_number = slice_number.saturating_add(1),
            Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
        }
    }

    // Bring all index arrays to a common shape; `array_indices[0]` then carries that shape.
    let array_indices = broadcast_arrays_device(&array_indices, &stream)?;
    let update_shape_broadcast: Vec<_> = array_indices[0]
        .shape()
        .iter()
        .chain(slice_shapes.iter())
        .chain(src.shape().iter().skip(non_new_axis_operation_count))
        .cloned()
        .collect();
    let update = broadcast_to_device(&update, &update_shape_broadcast, &stream)?;

    // Re-lay the update as index dims ++ one dim per indexed axis ++ untouched trailing dims.
    let update_reshape: Vec<_> = array_indices[0]
        .shape()
        .iter()
        .chain(update_shape.iter())
        .chain(src.shape().iter().skip(non_new_axis_operation_count))
        .cloned()
        .collect();

    let update = update.reshape_device(&update_reshape, &stream)?;

    let array_indices_len = array_indices.len();

    let indices = array_indices.into_iter().map(Cow::Owned).collect();
    Ok(ScatterArgs {
        indices,
        update,
        axes: (0..array_indices_len as i32).collect(),
    })
}
511
/// Expands the half-open strided range from `start` towards `exclusive_end` (step `stride`)
/// into its elements.
///
/// A positive `stride` walks upward while the value is `< exclusive_end`. A negative `stride`
/// walks downward while the value is `> exclusive_end`. A range whose direction disagrees with
/// the sign of `stride` (or an empty span) yields an empty vector.
///
/// # Panics
///
/// Panics with "Stride cannot be zero" if `stride == 0`.
fn strided_range_to_vec(start: i32, exclusive_end: i32, stride: i32) -> Vec<i32> {
    // Validate before any arithmetic involving `stride`; the length computation divides by it.
    assert_ne!(stride, 0, "Stride cannot be zero");

    // Work in i64 so neither the span nor the last step can overflow for extreme i32 inputs.
    let span = i64::from(exclusive_end) - i64::from(start);
    let step = i64::from(stride);

    // Exact element count: ceil(|span| / |step|) when the span points the same way as the step.
    let len = if span != 0 && (span > 0) == (step > 0) {
        let (span, step) = (span.abs(), step.abs());
        (span + step - 1) / step
    } else {
        0
    };

    // Every element lies in the half-open interval between `start` and `exclusive_end`, so it
    // fits back into an i32.
    (0..len)
        .map(|k| (i64::from(start) + k * step) as i32)
        .collect()
}
533
534unsafe fn scatter_device(
535 a: &Array,
536 indices: &[impl AsRef<Array>],
537 updates: &Array,
538 axes: &[i32],
539 stream: impl AsRef<Stream>,
540) -> Result<Array> {
541 let indices_vector = VectorArray::try_from_iter(indices.iter())?;
542
543 Array::try_from_op(|res| unsafe {
544 mlx_sys::mlx_scatter(
545 res,
546 a.as_ptr(),
547 indices_vector.as_ptr(),
548 updates.as_ptr(),
549 axes.as_ptr(),
550 axes.len(),
551 stream.as_ref().as_ptr(),
552 )
553 })
554}
555
556impl Array {
557 fn try_index_mut_device_inner(
558 &mut self,
559 operations: &[ArrayIndexOp],
560 update: &Array,
561 stream: impl AsRef<Stream>,
562 ) -> Result<()> {
563 if let Some(result) = update_slice(self, operations, update, &stream)? {
564 *self = result;
565 return Ok(());
566 }
567
568 let ScatterArgs {
569 indices,
570 update,
571 axes,
572 } = scatter_args(self, operations, update, &stream)?;
573 if !indices.is_empty() {
574 let result = unsafe { scatter_device(self, &indices, &update, &axes, stream)? };
575 drop(indices);
576 *self = result;
577 } else {
578 drop(indices);
579 *self = update;
580 }
581 Ok(())
582 }
583}
584
585impl<'a, Val> TryIndexMutOp<&'a [ArrayIndexOp<'a>], Val> for Array
586where
587 Val: AsRef<Array>,
588{
589 fn try_index_mut_device(
590 &mut self,
591 i: &'a [ArrayIndexOp<'a>],
592 val: Val,
593 stream: impl AsRef<Stream>,
594 ) -> Result<()> {
595 let update = val.as_ref();
596 self.try_index_mut_device_inner(i, update, stream)
597 }
598}
599
600impl<A, Val> TryIndexMutOp<A, Val> for Array
601where
602 for<'a> A: ArrayIndex<'a>,
603 Val: AsRef<Array>,
604{
605 fn try_index_mut_device(&mut self, i: A, val: Val, stream: impl AsRef<Stream>) -> Result<()> {
606 let operations = [i.index_op()];
607 let update = val.as_ref();
608 self.try_index_mut_device_inner(&operations, update, stream)
609 }
610}
611
612impl<'a, A, Val> TryIndexMutOp<(A,), Val> for Array
613where
614 A: ArrayIndex<'a>,
615 Val: AsRef<Array>,
616{
617 fn try_index_mut_device(
618 &mut self,
619 (i,): (A,),
620 val: Val,
621 stream: impl AsRef<Stream>,
622 ) -> Result<()> {
623 let operations = [i.index_op()];
624 let update = val.as_ref();
625 self.try_index_mut_device_inner(&operations, update, stream)
626 }
627}
628
629impl<'a, 'b, A, B, Val> TryIndexMutOp<(A, B), Val> for Array
630where
631 A: ArrayIndex<'a>,
632 B: ArrayIndex<'b>,
633 Val: AsRef<Array>,
634{
635 fn try_index_mut_device(
636 &mut self,
637 i: (A, B),
638 val: Val,
639 stream: impl AsRef<Stream>,
640 ) -> Result<()> {
641 let operations = [i.0.index_op(), i.1.index_op()];
642 let update = val.as_ref();
643 self.try_index_mut_device_inner(&operations, update, stream)
644 }
645}
646
647impl<'a, 'b, 'c, A, B, C, Val> TryIndexMutOp<(A, B, C), Val> for Array
648where
649 A: ArrayIndex<'a>,
650 B: ArrayIndex<'b>,
651 C: ArrayIndex<'c>,
652 Val: AsRef<Array>,
653{
654 fn try_index_mut_device(
655 &mut self,
656 i: (A, B, C),
657 val: Val,
658 stream: impl AsRef<Stream>,
659 ) -> Result<()> {
660 let operations = [i.0.index_op(), i.1.index_op(), i.2.index_op()];
661 let update = val.as_ref();
662 self.try_index_mut_device_inner(&operations, update, stream)
663 }
664}
665
666impl<'a, 'b, 'c, 'd, A, B, C, D, Val> TryIndexMutOp<(A, B, C, D), Val> for Array
667where
668 A: ArrayIndex<'a>,
669 B: ArrayIndex<'b>,
670 C: ArrayIndex<'c>,
671 D: ArrayIndex<'d>,
672 Val: AsRef<Array>,
673{
674 fn try_index_mut_device(
675 &mut self,
676 i: (A, B, C, D),
677 val: Val,
678 stream: impl AsRef<Stream>,
679 ) -> Result<()> {
680 let operations = [
681 i.0.index_op(),
682 i.1.index_op(),
683 i.2.index_op(),
684 i.3.index_op(),
685 ];
686 let update = val.as_ref();
687 self.try_index_mut_device_inner(&operations, update, stream)
688 }
689}
690
691impl<'a, 'b, 'c, 'd, 'e, A, B, C, D, E, Val> TryIndexMutOp<(A, B, C, D, E), Val> for Array
692where
693 A: ArrayIndex<'a>,
694 B: ArrayIndex<'b>,
695 C: ArrayIndex<'c>,
696 D: ArrayIndex<'d>,
697 E: ArrayIndex<'e>,
698 Val: AsRef<Array>,
699{
700 fn try_index_mut_device(
701 &mut self,
702 i: (A, B, C, D, E),
703 val: Val,
704 stream: impl AsRef<Stream>,
705 ) -> Result<()> {
706 let operations = [
707 i.0.index_op(),
708 i.1.index_op(),
709 i.2.index_op(),
710 i.3.index_op(),
711 i.4.index_op(),
712 ];
713 let update = val.as_ref();
714 self.try_index_mut_device_inner(&operations, update, stream)
715 }
716}
717
718impl<'a, 'b, 'c, 'd, 'e, 'f, A, B, C, D, E, F, Val> TryIndexMutOp<(A, B, C, D, E, F), Val> for Array
719where
720 A: ArrayIndex<'a>,
721 B: ArrayIndex<'b>,
722 C: ArrayIndex<'c>,
723 D: ArrayIndex<'d>,
724 E: ArrayIndex<'e>,
725 F: ArrayIndex<'f>,
726 Val: AsRef<Array>,
727{
728 fn try_index_mut_device(
729 &mut self,
730 i: (A, B, C, D, E, F),
731 val: Val,
732 stream: impl AsRef<Stream>,
733 ) -> Result<()> {
734 let operations = [
735 i.0.index_op(),
736 i.1.index_op(),
737 i.2.index_op(),
738 i.3.index_op(),
739 i.4.index_op(),
740 i.5.index_op(),
741 ];
742 let update = val.as_ref();
743 self.try_index_mut_device_inner(&operations, update, stream)
744 }
745}
746
747impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, A, B, C, D, E, F, G, Val> TryIndexMutOp<(A, B, C, D, E, F, G), Val>
748 for Array
749where
750 A: ArrayIndex<'a>,
751 B: ArrayIndex<'b>,
752 C: ArrayIndex<'c>,
753 D: ArrayIndex<'d>,
754 E: ArrayIndex<'e>,
755 F: ArrayIndex<'f>,
756 G: ArrayIndex<'g>,
757 Val: AsRef<Array>,
758{
759 fn try_index_mut_device(
760 &mut self,
761 i: (A, B, C, D, E, F, G),
762 val: Val,
763 stream: impl AsRef<Stream>,
764 ) -> Result<()> {
765 let operations = [
766 i.0.index_op(),
767 i.1.index_op(),
768 i.2.index_op(),
769 i.3.index_op(),
770 i.4.index_op(),
771 i.5.index_op(),
772 i.6.index_op(),
773 ];
774 let update = val.as_ref();
775 self.try_index_mut_device_inner(&operations, update, stream)
776 }
777}
778
779impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, A, B, C, D, E, F, G, H, Val>
780 TryIndexMutOp<(A, B, C, D, E, F, G, H), Val> for Array
781where
782 A: ArrayIndex<'a>,
783 B: ArrayIndex<'b>,
784 C: ArrayIndex<'c>,
785 D: ArrayIndex<'d>,
786 E: ArrayIndex<'e>,
787 F: ArrayIndex<'f>,
788 G: ArrayIndex<'g>,
789 H: ArrayIndex<'h>,
790 Val: AsRef<Array>,
791{
792 fn try_index_mut_device(
793 &mut self,
794 i: (A, B, C, D, E, F, G, H),
795 val: Val,
796 stream: impl AsRef<Stream>,
797 ) -> Result<()> {
798 let operations = [
799 i.0.index_op(),
800 i.1.index_op(),
801 i.2.index_op(),
802 i.3.index_op(),
803 i.4.index_op(),
804 i.5.index_op(),
805 i.6.index_op(),
806 i.7.index_op(),
807 ];
808 let update = val.as_ref();
809 self.try_index_mut_device_inner(&operations, update, stream)
810 }
811}
812
813impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, A, B, C, D, E, F, G, H, I, Val>
814 TryIndexMutOp<(A, B, C, D, E, F, G, H, I), Val> for Array
815where
816 A: ArrayIndex<'a>,
817 B: ArrayIndex<'b>,
818 C: ArrayIndex<'c>,
819 D: ArrayIndex<'d>,
820 E: ArrayIndex<'e>,
821 F: ArrayIndex<'f>,
822 G: ArrayIndex<'g>,
823 H: ArrayIndex<'h>,
824 I: ArrayIndex<'i>,
825 Val: AsRef<Array>,
826{
827 fn try_index_mut_device(
828 &mut self,
829 i: (A, B, C, D, E, F, G, H, I),
830 val: Val,
831 stream: impl AsRef<Stream>,
832 ) -> Result<()> {
833 let operations = [
834 i.0.index_op(),
835 i.1.index_op(),
836 i.2.index_op(),
837 i.3.index_op(),
838 i.4.index_op(),
839 i.5.index_op(),
840 i.6.index_op(),
841 i.7.index_op(),
842 i.8.index_op(),
843 ];
844 let update = val.as_ref();
845 self.try_index_mut_device_inner(&operations, update, stream)
846 }
847}
848
849impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, A, B, C, D, E, F, G, H, I, J, Val>
850 TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J), Val> for Array
851where
852 A: ArrayIndex<'a>,
853 B: ArrayIndex<'b>,
854 C: ArrayIndex<'c>,
855 D: ArrayIndex<'d>,
856 E: ArrayIndex<'e>,
857 F: ArrayIndex<'f>,
858 G: ArrayIndex<'g>,
859 H: ArrayIndex<'h>,
860 I: ArrayIndex<'i>,
861 J: ArrayIndex<'j>,
862 Val: AsRef<Array>,
863{
864 fn try_index_mut_device(
865 &mut self,
866 i: (A, B, C, D, E, F, G, H, I, J),
867 val: Val,
868 stream: impl AsRef<Stream>,
869 ) -> Result<()> {
870 let operations = [
871 i.0.index_op(),
872 i.1.index_op(),
873 i.2.index_op(),
874 i.3.index_op(),
875 i.4.index_op(),
876 i.5.index_op(),
877 i.6.index_op(),
878 i.7.index_op(),
879 i.8.index_op(),
880 i.9.index_op(),
881 ];
882 let update = val.as_ref();
883 self.try_index_mut_device_inner(&operations, update, stream)
884 }
885}
886
887impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, A, B, C, D, E, F, G, H, I, J, K, Val>
888 TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K), Val> for Array
889where
890 A: ArrayIndex<'a>,
891 B: ArrayIndex<'b>,
892 C: ArrayIndex<'c>,
893 D: ArrayIndex<'d>,
894 E: ArrayIndex<'e>,
895 F: ArrayIndex<'f>,
896 G: ArrayIndex<'g>,
897 H: ArrayIndex<'h>,
898 I: ArrayIndex<'i>,
899 J: ArrayIndex<'j>,
900 K: ArrayIndex<'k>,
901 Val: AsRef<Array>,
902{
903 fn try_index_mut_device(
904 &mut self,
905 i: (A, B, C, D, E, F, G, H, I, J, K),
906 val: Val,
907 stream: impl AsRef<Stream>,
908 ) -> Result<()> {
909 let operations = [
910 i.0.index_op(),
911 i.1.index_op(),
912 i.2.index_op(),
913 i.3.index_op(),
914 i.4.index_op(),
915 i.5.index_op(),
916 i.6.index_op(),
917 i.7.index_op(),
918 i.8.index_op(),
919 i.9.index_op(),
920 i.10.index_op(),
921 ];
922 let update = val.as_ref();
923 self.try_index_mut_device_inner(&operations, update, stream)
924 }
925}
926
927impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, A, B, C, D, E, F, G, H, I, J, K, L, Val>
928 TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L), Val> for Array
929where
930 A: ArrayIndex<'a>,
931 B: ArrayIndex<'b>,
932 C: ArrayIndex<'c>,
933 D: ArrayIndex<'d>,
934 E: ArrayIndex<'e>,
935 F: ArrayIndex<'f>,
936 G: ArrayIndex<'g>,
937 H: ArrayIndex<'h>,
938 I: ArrayIndex<'i>,
939 J: ArrayIndex<'j>,
940 K: ArrayIndex<'k>,
941 L: ArrayIndex<'l>,
942 Val: AsRef<Array>,
943{
944 fn try_index_mut_device(
945 &mut self,
946 i: (A, B, C, D, E, F, G, H, I, J, K, L),
947 val: Val,
948 stream: impl AsRef<Stream>,
949 ) -> Result<()> {
950 let operations = [
951 i.0.index_op(),
952 i.1.index_op(),
953 i.2.index_op(),
954 i.3.index_op(),
955 i.4.index_op(),
956 i.5.index_op(),
957 i.6.index_op(),
958 i.7.index_op(),
959 i.8.index_op(),
960 i.9.index_op(),
961 i.10.index_op(),
962 i.11.index_op(),
963 ];
964 let update = val.as_ref();
965 self.try_index_mut_device_inner(&operations, update, stream)
966 }
967}
968
969impl<
970 'a,
971 'b,
972 'c,
973 'd,
974 'e,
975 'f,
976 'g,
977 'h,
978 'i,
979 'j,
980 'k,
981 'l,
982 'm,
983 A,
984 B,
985 C,
986 D,
987 E,
988 F,
989 G,
990 H,
991 I,
992 J,
993 K,
994 L,
995 M,
996 Val,
997 > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M), Val> for Array
998where
999 A: ArrayIndex<'a>,
1000 B: ArrayIndex<'b>,
1001 C: ArrayIndex<'c>,
1002 D: ArrayIndex<'d>,
1003 E: ArrayIndex<'e>,
1004 F: ArrayIndex<'f>,
1005 G: ArrayIndex<'g>,
1006 H: ArrayIndex<'h>,
1007 I: ArrayIndex<'i>,
1008 J: ArrayIndex<'j>,
1009 K: ArrayIndex<'k>,
1010 L: ArrayIndex<'l>,
1011 M: ArrayIndex<'m>,
1012 Val: AsRef<Array>,
1013{
1014 fn try_index_mut_device(
1015 &mut self,
1016 i: (A, B, C, D, E, F, G, H, I, J, K, L, M),
1017 val: Val,
1018 stream: impl AsRef<Stream>,
1019 ) -> Result<()> {
1020 let operations = [
1021 i.0.index_op(),
1022 i.1.index_op(),
1023 i.2.index_op(),
1024 i.3.index_op(),
1025 i.4.index_op(),
1026 i.5.index_op(),
1027 i.6.index_op(),
1028 i.7.index_op(),
1029 i.8.index_op(),
1030 i.9.index_op(),
1031 i.10.index_op(),
1032 i.11.index_op(),
1033 i.12.index_op(),
1034 ];
1035 let update = val.as_ref();
1036 self.try_index_mut_device_inner(&operations, update, stream)
1037 }
1038}
1039
1040impl<
1041 'a,
1042 'b,
1043 'c,
1044 'd,
1045 'e,
1046 'f,
1047 'g,
1048 'h,
1049 'i,
1050 'j,
1051 'k,
1052 'l,
1053 'm,
1054 'n,
1055 A,
1056 B,
1057 C,
1058 D,
1059 E,
1060 F,
1061 G,
1062 H,
1063 I,
1064 J,
1065 K,
1066 L,
1067 M,
1068 N,
1069 Val,
1070 > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N), Val> for Array
1071where
1072 A: ArrayIndex<'a>,
1073 B: ArrayIndex<'b>,
1074 C: ArrayIndex<'c>,
1075 D: ArrayIndex<'d>,
1076 E: ArrayIndex<'e>,
1077 F: ArrayIndex<'f>,
1078 G: ArrayIndex<'g>,
1079 H: ArrayIndex<'h>,
1080 I: ArrayIndex<'i>,
1081 J: ArrayIndex<'j>,
1082 K: ArrayIndex<'k>,
1083 L: ArrayIndex<'l>,
1084 M: ArrayIndex<'m>,
1085 N: ArrayIndex<'n>,
1086 Val: AsRef<Array>,
1087{
1088 fn try_index_mut_device(
1089 &mut self,
1090 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N),
1091 val: Val,
1092 stream: impl AsRef<Stream>,
1093 ) -> Result<()> {
1094 let operations = [
1095 i.0.index_op(),
1096 i.1.index_op(),
1097 i.2.index_op(),
1098 i.3.index_op(),
1099 i.4.index_op(),
1100 i.5.index_op(),
1101 i.6.index_op(),
1102 i.7.index_op(),
1103 i.8.index_op(),
1104 i.9.index_op(),
1105 i.10.index_op(),
1106 i.11.index_op(),
1107 i.12.index_op(),
1108 i.13.index_op(),
1109 ];
1110 let update = val.as_ref();
1111 self.try_index_mut_device_inner(&operations, update, stream)
1112 }
1113}
1114
1115impl<
1116 'a,
1117 'b,
1118 'c,
1119 'd,
1120 'e,
1121 'f,
1122 'g,
1123 'h,
1124 'i,
1125 'j,
1126 'k,
1127 'l,
1128 'm,
1129 'n,
1130 'o,
1131 A,
1132 B,
1133 C,
1134 D,
1135 E,
1136 F,
1137 G,
1138 H,
1139 I,
1140 J,
1141 K,
1142 L,
1143 M,
1144 N,
1145 O,
1146 Val,
1147 > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O), Val> for Array
1148where
1149 A: ArrayIndex<'a>,
1150 B: ArrayIndex<'b>,
1151 C: ArrayIndex<'c>,
1152 D: ArrayIndex<'d>,
1153 E: ArrayIndex<'e>,
1154 F: ArrayIndex<'f>,
1155 G: ArrayIndex<'g>,
1156 H: ArrayIndex<'h>,
1157 I: ArrayIndex<'i>,
1158 J: ArrayIndex<'j>,
1159 K: ArrayIndex<'k>,
1160 L: ArrayIndex<'l>,
1161 M: ArrayIndex<'m>,
1162 N: ArrayIndex<'n>,
1163 O: ArrayIndex<'o>,
1164 Val: AsRef<Array>,
1165{
1166 fn try_index_mut_device(
1167 &mut self,
1168 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O),
1169 val: Val,
1170 stream: impl AsRef<Stream>,
1171 ) -> Result<()> {
1172 let operations = [
1173 i.0.index_op(),
1174 i.1.index_op(),
1175 i.2.index_op(),
1176 i.3.index_op(),
1177 i.4.index_op(),
1178 i.5.index_op(),
1179 i.6.index_op(),
1180 i.7.index_op(),
1181 i.8.index_op(),
1182 i.9.index_op(),
1183 i.10.index_op(),
1184 i.11.index_op(),
1185 i.12.index_op(),
1186 i.13.index_op(),
1187 i.14.index_op(),
1188 ];
1189 let update = val.as_ref();
1190 self.try_index_mut_device_inner(&operations, update, stream)
1191 }
1192}
1193
1194impl<
1195 'a,
1196 'b,
1197 'c,
1198 'd,
1199 'e,
1200 'f,
1201 'g,
1202 'h,
1203 'i,
1204 'j,
1205 'k,
1206 'l,
1207 'm,
1208 'n,
1209 'o,
1210 'p,
1211 A,
1212 B,
1213 C,
1214 D,
1215 E,
1216 F,
1217 G,
1218 H,
1219 I,
1220 J,
1221 K,
1222 L,
1223 M,
1224 N,
1225 O,
1226 P,
1227 Val,
1228 > TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P), Val> for Array
1229where
1230 A: ArrayIndex<'a>,
1231 B: ArrayIndex<'b>,
1232 C: ArrayIndex<'c>,
1233 D: ArrayIndex<'d>,
1234 E: ArrayIndex<'e>,
1235 F: ArrayIndex<'f>,
1236 G: ArrayIndex<'g>,
1237 H: ArrayIndex<'h>,
1238 I: ArrayIndex<'i>,
1239 J: ArrayIndex<'j>,
1240 K: ArrayIndex<'k>,
1241 L: ArrayIndex<'l>,
1242 M: ArrayIndex<'m>,
1243 N: ArrayIndex<'n>,
1244 O: ArrayIndex<'o>,
1245 P: ArrayIndex<'p>,
1246 Val: AsRef<Array>,
1247{
1248 fn try_index_mut_device(
1249 &mut self,
1250 i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P),
1251 val: Val,
1252 stream: impl AsRef<Stream>,
1253 ) -> Result<()> {
1254 let operations = [
1255 i.0.index_op(),
1256 i.1.index_op(),
1257 i.2.index_op(),
1258 i.3.index_op(),
1259 i.4.index_op(),
1260 i.5.index_op(),
1261 i.6.index_op(),
1262 i.7.index_op(),
1263 i.8.index_op(),
1264 i.9.index_op(),
1265 i.10.index_op(),
1266 i.11.index_op(),
1267 i.12.index_op(),
1268 i.13.index_op(),
1269 i.14.index_op(),
1270 i.15.index_op(),
1271 ];
1272 let update = val.as_ref();
1273 self.try_index_mut_device_inner(&operations, update, stream)
1274 }
1275}
1276
#[cfg(test)]
mod tests {
    use crate::{ops::indexing::*, Array};

    /// `grid[1] = 77` on a 3x4 array overwrites the whole second row.
    #[test]
    fn test_array_mutate_single_index() {
        let mut grid = Array::from_iter(0i32..12, &[3, 4]);
        grid.index_mut(1, Array::from_int(77));

        let want = Array::from_slice(&[0, 1, 2, 3, 77, 77, 77, 77, 8, 9, 10, 11], &[3, 4]);
        assert_array_all_close!(grid, want);
    }

    /// Successive multi-axis writes into a 2x2x5 array, broadcasting scalars and rows.
    #[test]
    fn test_array_mutate_broadcast_multi_index() {
        let mut cube = Array::from_iter(0i32..20, &[2, 2, 5]);

        // cube[1, 0] = 77: scalar broadcast along the last axis.
        cube.index_mut((1, 0), Array::from_int(77));
        // cube[0, 0] = [55, 66, 77, 88, 99]
        cube.index_mut((0, 0), Array::from_slice(&[55i32, 66, 77, 88, 99], &[5]));
        // cube[0, 1, 3] = 123
        cube.index_mut((0, 1, 3), Array::from_int(123));

        let want = Array::from_slice(
            &[
                55, 66, 77, 88, 99, 5, 6, 7, 123, 9, 77, 77, 77, 77, 77, 15, 16, 17, 18, 19,
            ],
            &[2, 2, 5],
        );
        assert_array_all_close!(cube, want);
    }

    /// A scalar broadcast into a sub-block selected by three ranges.
    #[test]
    fn test_array_mutate_broadcast_slice() {
        let mut cube = Array::from_iter(0i32..20, &[2, 2, 5]);

        cube.index_mut((0..1, 1..2, 2..4), Array::from_int(88));

        let want = Array::from_slice(
            &[
                0, 1, 2, 3, 4, 5, 6, 88, 88, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
            ],
            &[2, 2, 5],
        );
        assert_array_all_close!(cube, want);
    }

    /// Paired array indices write one value per (row, col) pair.
    #[test]
    fn test_array_mutate_advanced() {
        let mut grid = Array::from_iter(0i32..35, &[5, 7]);

        let rows = Array::from_slice(&[0, 2, 4], &[3]);
        let cols = Array::from_slice(&[0, 1, 2], &[3]);

        grid.index_mut((rows, cols), Array::from_slice(&[100, 200, 300], &[3]));

        assert_eq!(grid.index((0, 0)).item::<i32>(), 100i32);
        assert_eq!(grid.index((2, 1)).item::<i32>(), 200i32);
        assert_eq!(grid.index((4, 2)).item::<i32>(), 300i32);
    }

    /// Single-operation writes of `1` into a 3x4x5 `0..60` array, checked via the total sum.
    #[test]
    fn test_full_index_write_single() {
        fn assert_sum_after_write<I>(index: I, expected_sum: i32)
        where
            for<'a> I: ArrayIndex<'a>,
        {
            let mut arr = Array::from_iter(0..60, &[3, 4, 5]);

            arr.index_mut(index, Array::from_int(1));
            let total = arr.sum(None, None).unwrap().item::<i32>();
            assert_eq!(total, expected_sum);
        }

        // A bare new axis selects everything: all 60 elements become 1.
        assert_sum_after_write(NewAxis, 60);

        // Integer index on axis 0: the first 20 elements (sum 190) become 1.
        assert_sum_after_write(0, 1600);

        // Range on axis 0: the last 40 elements become 1.
        assert_sum_after_write(1..3, 230);

        // Array index on axis 0 selecting the same two blocks as `1..3`.
        let rows = Array::from_slice(&[2, 1], &[2]);

        assert_sum_after_write(rows, 230);
    }

    /// Multi-operation writes without array indices into a 2x3x4x5x3 array.
    #[test]
    fn test_full_index_write_no_array() {
        macro_rules! check_sum {
            (($( $i:expr ),*), $sum:expr ) => {
                {
                    let mut arr = Array::from_iter(0..360, &[2, 3, 4, 5, 3]);

                    arr.index_mut(($($i),*), Array::from_int(1));
                    let total = arr.sum(None, None).unwrap().item::<i32>();
                    assert_eq!(total, $sum);
                }
            };
        }

        check_sum!((Ellipsis, 0), 43320);

        check_sum!((0, Ellipsis), 48690);

        check_sum!((0, Ellipsis, 0), 59370);

        check_sum!((Ellipsis, (..).stride_by(2), ..), 26064);

        check_sum!((Ellipsis, NewAxis, (..).stride_by(2), -1), 51696);

        check_sum!((.., 2.., 0), 58140);

        check_sum!(
            (
                (..).stride_by(-1),
                ..2,
                2..,
                Ellipsis,
                NewAxis,
                (..).stride_by(2)
            ),
            51540
        );
    }

    /// Multi-operation writes mixing array indices with other operations into a 3x3x4x5x3 array.
    #[test]
    fn test_full_index_write_array() {
        macro_rules! check_sum {
            (($( $i:expr ),*), $sum:expr ) => {
                {
                    let mut arr = Array::from_iter(0..540, &[3, 3, 4, 5, 3]);

                    arr.index_mut(($($i),*), Array::from_int(1));
                    let total = arr.sum(None, None).unwrap().item::<i32>();
                    assert_eq!(total, $sum);
                }
            };
        }

        let i = Array::from_slice(&[2, 1], &[2]);

        check_sum!((0, &i), 131310);

        check_sum!((Ellipsis, &i, 0), 126378);

        check_sum!((&i, 0, Ellipsis), 109710);

        check_sum!((&i, Ellipsis, &i), 102450);

        check_sum!((&i, Ellipsis, (..).stride_by(2), ..), 68094);

        check_sum!((Ellipsis, &i, NewAxis, (..).stride_by(2), -1), 130977);

        check_sum!((.., 2.., &i), 115965);

        // Last use of `i`: moved into the index tuple by value.
        check_sum!(
            (
                (..).stride_by(-1),
                ..2,
                i,
                2..,
                Ellipsis,
                NewAxis,
                (..).stride_by(2)
            ),
            128142
        );
    }
}