use std::borrow::Cow;
use smallvec::{smallvec, SmallVec};
use crate::{
constants::DEFAULT_STACK_VEC_LEN,
error::Result,
ops::{
broadcast_arrays_device, broadcast_to_device,
indexing::{count_non_new_axis_operations, expand_ellipsis_operations},
},
utils::{resolve_index_signed_unchecked, VectorArray},
Array, Stream,
};
use super::{ArrayIndex, ArrayIndexOp, Guarded, RangeIndex, TryIndexMutOp};
impl Array {
pub(crate) fn slice_update_device(
&self,
update: &Array,
starts: &[i32],
ends: &[i32],
strides: &[i32],
stream: impl AsRef<Stream>,
) -> Result<Array> {
Array::try_from_op(|res| unsafe {
mlx_sys::mlx_slice_update(
res,
self.as_ptr(),
update.as_ptr(),
starts.as_ptr(),
starts.len(),
ends.as_ptr(),
ends.len(),
strides.as_ptr(),
strides.len(),
stream.as_ref().as_ptr(),
)
})
}
}
fn update_slice(
src: &Array,
operations: &[ArrayIndexOp],
update: &Array,
stream: impl AsRef<Stream>,
) -> Result<Option<Array>> {
let ndim = src.ndim();
if ndim == 0 || operations.is_empty() {
return Ok(None);
}
let mut update = remove_leading_singleton_dimensions(update, &stream)?;
let mut starts: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
let mut ends: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
let mut strides: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![1; ndim];
if operations.len() == 1 {
if let ArrayIndexOp::Slice(range_index) = &operations[0] {
let size = src.dim(0);
starts[0] = range_index.start(size);
ends[0] = range_index.end(size);
strides[0] = range_index.stride();
return Ok(Some(src.slice_update_device(
&update, &starts, &ends, &strides, &stream,
)?));
}
}
if operations.iter().any(|op| op.is_array()) {
return Ok(None);
}
let operations = expand_ellipsis_operations(ndim, operations);
if count_non_new_axis_operations(&operations) == 0 {
return Ok(Some(broadcast_to_device(&update, src.shape(), &stream)?));
}
let mut update_expand_dims: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
let mut axis = 0i32;
for item in operations.iter() {
use ArrayIndexOp::*;
match item {
TakeIndex { index } => {
let size = src.dim(axis);
let index = if index.is_negative() {
size + index
} else {
*index
};
starts[axis as usize] = index;
ends[axis as usize] = index + 1;
if ndim - (axis as usize) < update.ndim() {
update_expand_dims.push(axis.saturating_sub_unsigned(ndim as u32));
}
axis = axis.saturating_add(1);
}
Slice(range_index) => {
let size = src.dim(axis);
starts[axis as usize] = range_index.start(size);
ends[axis as usize] = range_index.end(size);
strides[axis as usize] = range_index.stride();
axis = axis.saturating_add(1);
}
ExpandDims => {}
Ellipsis | TakeArray { indices: _ } | TakeArrayRef { indices: _ } => {
panic!("unexpected item in operations")
}
}
}
if !update_expand_dims.is_empty() {
update = Cow::Owned(update.expand_dims_device(&update_expand_dims, &stream)?);
}
Ok(Some(src.slice_update_device(
&update, &starts, &ends, &strides, &stream,
)?))
}
fn remove_leading_singleton_dimensions(
a: &Array,
stream: impl AsRef<Stream>,
) -> Result<Cow<'_, Array>> {
let shape = a.shape();
let mut new_shape: Vec<_> = shape.iter().skip_while(|&&dim| dim == 1).cloned().collect();
if shape != new_shape {
if new_shape.is_empty() {
new_shape = vec![1];
}
Ok(Cow::Owned(a.reshape_device(&new_shape, stream)?))
} else {
Ok(Cow::Borrowed(a))
}
}
struct ScatterArgs<'a> {
indices: SmallVec<[Cow<'a, Array>; DEFAULT_STACK_VEC_LEN]>,
update: Array,
axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>,
}
fn scatter_args<'a>(
src: &'a Array,
operations: &'a [ArrayIndexOp],
update: &Array,
stream: impl AsRef<Stream>,
) -> Result<ScatterArgs<'a>> {
use ArrayIndexOp::*;
if operations.len() == 1 {
return match &operations[0] {
TakeIndex { index } => scatter_args_index(src, *index, update, stream),
TakeArray { indices } => {
scatter_args_array(src, Cow::Borrowed(indices), update, stream)
}
TakeArrayRef { indices } => {
scatter_args_array(src, Cow::Borrowed(indices), update, stream)
}
Slice(range_index) => scatter_args_slice(src, range_index, update, stream),
ExpandDims => Ok(ScatterArgs {
indices: smallvec![],
update: broadcast_to_device(update, src.shape(), &stream)?,
axes: smallvec![],
}),
Ellipsis => panic!("Unable to update array with ellipsis argument"),
};
}
scatter_args_nd(src, operations, update, stream)
}
fn scatter_args_index<'a>(
src: &'a Array,
index: i32,
update: &Array,
stream: impl AsRef<Stream>,
) -> Result<ScatterArgs<'a>> {
let update = remove_leading_singleton_dimensions(update, &stream)?;
let mut shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::from_slice(src.shape());
shape[0] = 1;
Ok(ScatterArgs {
indices: smallvec![Cow::Owned(Array::from_int(resolve_index_signed_unchecked(
index,
src.dim(0)
)))],
update: broadcast_to_device(&update, &shape, &stream)?,
axes: smallvec![0],
})
}
fn scatter_args_array<'a>(
src: &'a Array,
a: Cow<'a, Array>,
update: &Array,
stream: impl AsRef<Stream>,
) -> Result<ScatterArgs<'a>> {
let update = remove_leading_singleton_dimensions(update, &stream)?;
let mut update_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = a
.shape()
.iter()
.chain(src.shape().iter().skip(1))
.cloned()
.collect();
let update = broadcast_to_device(&update, &update_shape, &stream)?;
update_shape.insert(a.ndim(), 1);
let update = update.reshape_device(&update_shape, &stream)?;
Ok(ScatterArgs {
indices: smallvec![a],
update,
axes: smallvec![0],
})
}
fn scatter_args_slice<'a>(
src: &'a Array,
range_index: &'a RangeIndex,
update: &Array,
stream: impl AsRef<Stream>,
) -> Result<ScatterArgs<'a>> {
if range_index.is_full() {
let update = remove_leading_singleton_dimensions(update, &stream)?;
return Ok(ScatterArgs {
indices: smallvec![],
update: broadcast_to_device(&update, src.shape(), &stream)?,
axes: smallvec![],
});
}
let size = src.dim(0);
let start = range_index.start(size);
let end = range_index.end(size);
let stride = range_index.stride();
if stride == 1 {
let update = remove_leading_singleton_dimensions(update, &stream)?;
let update_broadcast_shape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (1..end - start)
.chain(src.shape().iter().skip(1).cloned())
.collect();
let update = broadcast_to_device(&update, &update_broadcast_shape, &stream)?;
let indices = Array::from_slice(&[start], &[1]);
Ok(ScatterArgs {
indices: smallvec![Cow::Owned(indices)],
update,
axes: smallvec![0],
})
} else {
let a_vals = strided_range_to_vec(start, end, stride);
let a = Array::from_slice(&a_vals, &[a_vals.len() as i32]);
scatter_args_array(src, Cow::Owned(a), update, stream)
}
}
fn scatter_args_nd<'a>(
src: &'a Array,
operations: &[ArrayIndexOp],
update: &Array,
stream: impl AsRef<Stream>,
) -> Result<ScatterArgs<'a>> {
use ArrayIndexOp::*;
let shape = src.shape();
let operations = expand_ellipsis_operations(src.ndim(), operations);
let update = remove_leading_singleton_dimensions(update, &stream)?;
let non_new_axis_operation_count = count_non_new_axis_operations(&operations);
if non_new_axis_operation_count == 0 {
return Ok(ScatterArgs {
indices: smallvec![],
update: broadcast_to_device(&update, shape, &stream)?,
axes: smallvec![],
});
}
let mut max_dims = 0;
let mut arrays_first = false;
let mut count_new_axis: i32 = 0;
let mut count_slices: i32 = 0;
let mut count_arrays: i32 = 0;
let mut count_strided_slices: i32 = 0;
let mut count_simple_slices_post: i32 = 0;
let mut have_array = false;
let mut have_non_array = false;
macro_rules! analyze_indices_take_array {
($indices:ident) => {
have_array = true;
if have_array && have_non_array {
arrays_first = true;
}
max_dims = $indices.ndim().max(max_dims);
count_arrays = count_arrays.saturating_add(1);
count_simple_slices_post = 0;
};
}
for item in operations.iter() {
match item {
TakeIndex { index: _ } => {
}
Slice(range_index) => {
have_non_array = have_array;
count_slices = count_slices.saturating_add(1);
if range_index.stride() != 1 {
count_strided_slices = count_strided_slices.saturating_add(1);
count_simple_slices_post = 0;
} else {
count_simple_slices_post = count_simple_slices_post.saturating_add(1);
}
}
TakeArray { indices } => {
analyze_indices_take_array!(indices);
}
TakeArrayRef { indices } => {
analyze_indices_take_array!(indices);
}
ExpandDims => {
have_non_array = true;
count_new_axis = count_new_axis.saturating_add(1);
}
Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
}
}
let mut index_dims = (max_dims + count_new_axis as usize + count_slices as usize)
.saturating_sub(count_simple_slices_post as usize);
if index_dims == 0 {
index_dims = 1;
}
let mut array_indices: SmallVec<[Array; DEFAULT_STACK_VEC_LEN]> =
SmallVec::with_capacity(operations.len());
let mut slice_number: i32 = 0;
let mut array_number: i32 = 0;
let mut axis: i32 = 0;
let mut update_shape = vec![1; non_new_axis_operation_count];
let mut slice_shapes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
macro_rules! update_shapes_take_array {
($indices:ident) => {
let start = if arrays_first {
max_dims - $indices.ndim()
} else {
slice_number as usize + max_dims - $indices.ndim()
};
let mut new_shape = vec![1; index_dims];
for j in 0..$indices.ndim() {
new_shape[start + j] = $indices.dim(j as i32);
}
array_indices.push($indices.reshape_device(&new_shape, &stream)?);
array_number = array_number.saturating_add(1);
if !arrays_first && array_number == count_arrays {
slice_number = slice_number.saturating_add_unsigned(max_dims as u32);
}
update_shape[axis as usize] = 1;
axis = axis.saturating_add(1);
};
}
for item in operations.iter() {
match item {
TakeIndex { index } => {
let resolved_index = resolve_index_signed_unchecked(*index, src.dim(axis));
array_indices.push(Array::from_int(resolved_index));
update_shape[axis as usize] = 1;
axis = axis.saturating_add(1);
}
Slice(range_index) => {
let size = src.dim(axis);
let start = range_index.absolute_start(size);
let end = range_index.absolute_end(size);
let stride = range_index.stride();
let mut index_shape = vec![1; index_dims];
if array_number >= count_arrays && count_strided_slices <= 0 && stride == 1 {
let index = Array::from_int(start).reshape_device(&index_shape, &stream)?;
let slice_shape_entry = end - start;
slice_shapes.push(slice_shape_entry);
array_indices.push(index);
update_shape[axis as usize] = slice_shape_entry;
} else {
let index_vals = strided_range_to_vec(start, end, stride);
let index = Array::from_slice(&index_vals, &[index_vals.len() as i32]);
let location = if arrays_first {
slice_number.saturating_add(max_dims as i32)
} else {
slice_number
};
index_shape[location as usize] = index.size() as i32;
array_indices.push(index.reshape_device(&index_shape, &stream)?);
slice_number = slice_number.saturating_add(1);
count_strided_slices = count_strided_slices.saturating_sub(1);
update_shape[axis as usize] = 1;
}
axis = axis.saturating_add(1);
}
TakeArray { indices } => {
update_shapes_take_array!(indices);
}
TakeArrayRef { indices } => {
update_shapes_take_array!(indices);
}
ExpandDims => slice_number = slice_number.saturating_add(1),
Ellipsis => panic!("Unexpected item ellipsis in scatter_args_nd"),
}
}
let array_indices = broadcast_arrays_device(&array_indices, &stream)?;
let update_shape_broadcast: Vec<_> = array_indices[0]
.shape()
.iter()
.chain(slice_shapes.iter())
.chain(src.shape().iter().skip(non_new_axis_operation_count))
.cloned()
.collect();
let update = broadcast_to_device(&update, &update_shape_broadcast, &stream)?;
let update_reshape: Vec<_> = array_indices[0]
.shape()
.iter()
.chain(update_shape.iter())
.chain(src.shape().iter().skip(non_new_axis_operation_count))
.cloned()
.collect();
let update = update.reshape_device(&update_reshape, &stream)?;
let array_indices_len = array_indices.len();
let indices = array_indices.into_iter().map(Cow::Owned).collect();
Ok(ScatterArgs {
indices,
update,
axes: (0..array_indices_len as i32).collect(),
})
}
fn strided_range_to_vec(start: i32, exclusive_end: i32, stride: i32) -> Vec<i32> {
let estimated_capacity = (exclusive_end - start).abs() / stride.abs();
let mut vec = Vec::with_capacity(estimated_capacity as usize);
let mut current = start;
assert_ne!(stride, 0, "Stride cannot be zero");
if stride.is_negative() {
while current > exclusive_end {
vec.push(current);
current += stride;
}
} else {
while current < exclusive_end {
vec.push(current);
current += stride;
}
}
vec
}
unsafe fn scatter_device(
a: &Array,
indices: &[impl AsRef<Array>],
updates: &Array,
axes: &[i32],
stream: impl AsRef<Stream>,
) -> Result<Array> {
let indices_vector = VectorArray::try_from_iter(indices.iter())?;
Array::try_from_op(|res| unsafe {
mlx_sys::mlx_scatter(
res,
a.as_ptr(),
indices_vector.as_ptr(),
updates.as_ptr(),
axes.as_ptr(),
axes.len(),
stream.as_ref().as_ptr(),
)
})
}
impl Array {
fn try_index_mut_device_inner(
&mut self,
operations: &[ArrayIndexOp],
update: &Array,
stream: impl AsRef<Stream>,
) -> Result<()> {
if let Some(result) = update_slice(self, operations, update, &stream)? {
*self = result;
return Ok(());
}
let ScatterArgs {
indices,
update,
axes,
} = scatter_args(self, operations, update, &stream)?;
if !indices.is_empty() {
let result = unsafe { scatter_device(self, &indices, &update, &axes, stream)? };
drop(indices);
*self = result;
} else {
drop(indices);
*self = update;
}
Ok(())
}
}
impl<'a, Val> TryIndexMutOp<&'a [ArrayIndexOp<'a>], Val> for Array
where
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: &'a [ArrayIndexOp<'a>],
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let update = val.as_ref();
self.try_index_mut_device_inner(i, update, stream)
}
}
impl<A, Val> TryIndexMutOp<A, Val> for Array
where
for<'a> A: ArrayIndex<'a>,
Val: AsRef<Array>,
{
fn try_index_mut_device(&mut self, i: A, val: Val, stream: impl AsRef<Stream>) -> Result<()> {
let operations = [i.index_op()];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, A, Val> TryIndexMutOp<(A,), Val> for Array
where
A: ArrayIndex<'a>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
(i,): (A,),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [i.index_op()];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, A, B, Val> TryIndexMutOp<(A, B), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [i.0.index_op(), i.1.index_op()];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, A, B, C, Val> TryIndexMutOp<(A, B, C), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [i.0.index_op(), i.1.index_op(), i.2.index_op()];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, A, B, C, D, Val> TryIndexMutOp<(A, B, C, D), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, 'e, A, B, C, D, E, Val> TryIndexMutOp<(A, B, C, D, E), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, A, B, C, D, E, F, Val> TryIndexMutOp<(A, B, C, D, E, F), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, A, B, C, D, E, F, G, Val> TryIndexMutOp<(A, B, C, D, E, F, G), Val>
for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, A, B, C, D, E, F, G, H, Val>
TryIndexMutOp<(A, B, C, D, E, F, G, H), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, A, B, C, D, E, F, G, H, I, Val>
TryIndexMutOp<(A, B, C, D, E, F, G, H, I), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
I: ArrayIndex<'i>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H, I),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
i.8.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, A, B, C, D, E, F, G, H, I, J, Val>
TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
I: ArrayIndex<'i>,
J: ArrayIndex<'j>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H, I, J),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
i.8.index_op(),
i.9.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, A, B, C, D, E, F, G, H, I, J, K, Val>
TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
I: ArrayIndex<'i>,
J: ArrayIndex<'j>,
K: ArrayIndex<'k>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H, I, J, K),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
i.8.index_op(),
i.9.index_op(),
i.10.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, A, B, C, D, E, F, G, H, I, J, K, L, Val>
TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
I: ArrayIndex<'i>,
J: ArrayIndex<'j>,
K: ArrayIndex<'k>,
L: ArrayIndex<'l>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H, I, J, K, L),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
i.8.index_op(),
i.9.index_op(),
i.10.index_op(),
i.11.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<
'a,
'b,
'c,
'd,
'e,
'f,
'g,
'h,
'i,
'j,
'k,
'l,
'm,
A,
B,
C,
D,
E,
F,
G,
H,
I,
J,
K,
L,
M,
Val,
> TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
I: ArrayIndex<'i>,
J: ArrayIndex<'j>,
K: ArrayIndex<'k>,
L: ArrayIndex<'l>,
M: ArrayIndex<'m>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H, I, J, K, L, M),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
i.8.index_op(),
i.9.index_op(),
i.10.index_op(),
i.11.index_op(),
i.12.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<
'a,
'b,
'c,
'd,
'e,
'f,
'g,
'h,
'i,
'j,
'k,
'l,
'm,
'n,
A,
B,
C,
D,
E,
F,
G,
H,
I,
J,
K,
L,
M,
N,
Val,
> TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
I: ArrayIndex<'i>,
J: ArrayIndex<'j>,
K: ArrayIndex<'k>,
L: ArrayIndex<'l>,
M: ArrayIndex<'m>,
N: ArrayIndex<'n>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
i.8.index_op(),
i.9.index_op(),
i.10.index_op(),
i.11.index_op(),
i.12.index_op(),
i.13.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<
'a,
'b,
'c,
'd,
'e,
'f,
'g,
'h,
'i,
'j,
'k,
'l,
'm,
'n,
'o,
A,
B,
C,
D,
E,
F,
G,
H,
I,
J,
K,
L,
M,
N,
O,
Val,
> TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
I: ArrayIndex<'i>,
J: ArrayIndex<'j>,
K: ArrayIndex<'k>,
L: ArrayIndex<'l>,
M: ArrayIndex<'m>,
N: ArrayIndex<'n>,
O: ArrayIndex<'o>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
i.8.index_op(),
i.9.index_op(),
i.10.index_op(),
i.11.index_op(),
i.12.index_op(),
i.13.index_op(),
i.14.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
impl<
'a,
'b,
'c,
'd,
'e,
'f,
'g,
'h,
'i,
'j,
'k,
'l,
'm,
'n,
'o,
'p,
A,
B,
C,
D,
E,
F,
G,
H,
I,
J,
K,
L,
M,
N,
O,
P,
Val,
> TryIndexMutOp<(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P), Val> for Array
where
A: ArrayIndex<'a>,
B: ArrayIndex<'b>,
C: ArrayIndex<'c>,
D: ArrayIndex<'d>,
E: ArrayIndex<'e>,
F: ArrayIndex<'f>,
G: ArrayIndex<'g>,
H: ArrayIndex<'h>,
I: ArrayIndex<'i>,
J: ArrayIndex<'j>,
K: ArrayIndex<'k>,
L: ArrayIndex<'l>,
M: ArrayIndex<'m>,
N: ArrayIndex<'n>,
O: ArrayIndex<'o>,
P: ArrayIndex<'p>,
Val: AsRef<Array>,
{
fn try_index_mut_device(
&mut self,
i: (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P),
val: Val,
stream: impl AsRef<Stream>,
) -> Result<()> {
let operations = [
i.0.index_op(),
i.1.index_op(),
i.2.index_op(),
i.3.index_op(),
i.4.index_op(),
i.5.index_op(),
i.6.index_op(),
i.7.index_op(),
i.8.index_op(),
i.9.index_op(),
i.10.index_op(),
i.11.index_op(),
i.12.index_op(),
i.13.index_op(),
i.14.index_op(),
i.15.index_op(),
];
let update = val.as_ref();
self.try_index_mut_device_inner(&operations, update, stream)
}
}
#[cfg(test)]
mod tests {
use crate::{ops::indexing::*, Array};
#[test]
fn test_array_mutate_single_index() {
let mut a = Array::from_iter(0i32..12, &[3, 4]);
let new_value = Array::from_int(77);
a.index_mut(1, new_value);
let expected = Array::from_slice(&[0, 1, 2, 3, 77, 77, 77, 77, 8, 9, 10, 11], &[3, 4]);
assert_array_all_close!(a, expected);
}
#[test]
fn test_array_mutate_broadcast_multi_index() {
let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);
a.index_mut((1, 0), Array::from_int(77));
a.index_mut((0, 0), Array::from_slice(&[55i32, 66, 77, 88, 99], &[5]));
a.index_mut((0, 1, 3), Array::from_int(123));
let expected = Array::from_slice(
&[
55, 66, 77, 88, 99, 5, 6, 7, 123, 9, 77, 77, 77, 77, 77, 15, 16, 17, 18, 19,
],
&[2, 2, 5],
);
assert_array_all_close!(a, expected);
}
#[test]
fn test_array_mutate_broadcast_slice() {
let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);
a.index_mut((0..1, 1..2, 2..4), Array::from_int(88));
let expected = Array::from_slice(
&[
0, 1, 2, 3, 4, 5, 6, 88, 88, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
],
&[2, 2, 5],
);
assert_array_all_close!(a, expected);
}
#[test]
fn test_array_mutate_advanced() {
let mut a = Array::from_iter(0i32..35, &[5, 7]);
let i1 = Array::from_slice(&[0, 2, 4], &[3]);
let i2 = Array::from_slice(&[0, 1, 2], &[3]);
a.index_mut((i1, i2), Array::from_slice(&[100, 200, 300], &[3]));
assert_eq!(a.index((0, 0)).item::<i32>(), 100i32);
assert_eq!(a.index((2, 1)).item::<i32>(), 200i32);
assert_eq!(a.index((4, 2)).item::<i32>(), 300i32);
}
#[test]
fn test_full_index_write_single() {
fn check<I>(index: I, expected_sum: i32)
where
for<'a> I: ArrayIndex<'a>,
{
let mut a = Array::from_iter(0..60, &[3, 4, 5]);
a.index_mut(index, Array::from_int(1));
let sum = a.sum(None, None).unwrap().item::<i32>();
assert_eq!(sum, expected_sum);
}
check(NewAxis, 60);
check(0, 1600);
check(1..3, 230);
let i = Array::from_slice(&[2, 1], &[2]);
check(i, 230);
}
#[test]
fn test_full_index_write_no_array() {
macro_rules! check {
(($( $i:expr ),*), $sum:expr ) => {
{
let mut a = Array::from_iter(0..360, &[2, 3, 4, 5, 3]);
a.index_mut(($($i),*), Array::from_int(1));
let sum = a.sum(None, None).unwrap().item::<i32>();
assert_eq!(sum, $sum);
}
};
}
check!((Ellipsis, 0), 43320);
check!((0, Ellipsis), 48690);
check!((0, Ellipsis, 0), 59370);
check!((Ellipsis, (..).stride_by(2), ..), 26064);
check!((Ellipsis, NewAxis, (..).stride_by(2), -1), 51696);
check!((.., 2.., 0), 58140);
check!(
(
(..).stride_by(-1),
..2,
2..,
Ellipsis,
NewAxis,
(..).stride_by(2)
),
51540
);
}
#[test]
fn test_full_index_write_array() {
macro_rules! check {
(($( $i:expr ),*), $sum:expr ) => {
{
let mut a = Array::from_iter(0..540, &[3, 3, 4, 5, 3]);
a.index_mut(($($i),*), Array::from_int(1));
let sum = a.sum(None, None).unwrap().item::<i32>();
assert_eq!(sum, $sum);
}
};
}
let i = Array::from_slice(&[2, 1], &[2]);
check!((0, &i), 131310);
check!((Ellipsis, &i, 0), 126378);
check!((&i, 0, Ellipsis), 109710);
check!((&i, Ellipsis, &i), 102450);
check!((&i, Ellipsis, (..).stride_by(2), ..), 68094);
check!((Ellipsis, &i, NewAxis, (..).stride_by(2), -1), 130977);
check!((.., 2.., &i), 115965);
check!(
(
(..).stride_by(-1),
..2,
i,
2..,
Ellipsis,
NewAxis,
(..).stride_by(2)
),
128142
);
}
}