Module indexing

Indexing Arrays

§Overview

The std::ops::Index and std::ops::IndexMut traits can only return references into a value that already exists, so they cannot hand back a newly created Array. Indexing is therefore provided by the IndexOp and IndexMutOp traits instead: use IndexOp::index() to read a sub-array and IndexMutOp::index_mut() to write values into the indexed positions.
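To see why a by-value trait is needed, here is a standalone, std-only sketch. ToyArray and ToyIndexOp are hypothetical names invented for this illustration; they only mimic the idea behind IndexOp and are not mlx-rs APIs. A single element can be returned by reference, but a strided selection is not contiguous in memory and has to be built as a new value.

```rust
use std::ops::{Index, Range};

// Hypothetical toy 1-D array, used only for illustration (not an mlx-rs type).
pub struct ToyArray {
    pub data: Vec<i32>,
}

// std::ops::Index must return `&Self::Output`, a reference to data that
// already lives inside `self`. Selecting one element fits that contract.
impl Index<usize> for ToyArray {
    type Output = i32;
    fn index(&self, i: usize) -> &i32 {
        &self.data[i]
    }
}

// A strided selection (e.g. every second element) is not contiguous in
// `data`, so it cannot be returned as a borrowed reference; a new array has
// to be built. A trait that returns by value (hypothetical, in the spirit
// of IndexOp) has no such restriction.
pub trait ToyIndexOp<Idx> {
    fn index_op(&self, idx: Idx) -> ToyArray;
}

// Index with (range, step): keep every `step`-th element inside `range`.
impl ToyIndexOp<(Range<usize>, usize)> for ToyArray {
    fn index_op(&self, (range, step): (Range<usize>, usize)) -> ToyArray {
        let data = self.data[range].iter().step_by(step).copied().collect();
        ToyArray { data }
    }
}
```

In the examples on this page, IndexOp::index() likewise produces an Array by value, which is what lets it express slices and strides that have no borrowed representation.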

The following types can be used as indices:

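Type | Description
--- | ---
i32 | An integer index
Array | Use an array to index another array
&Array | Use a reference to an array to index another array
std::ops::Range<i32> | A half-open range index, e.g. 1..3
std::ops::RangeFrom<i32> | A range index with no end, e.g. 1..
std::ops::RangeFull | A range covering the whole axis, ..
std::ops::RangeInclusive<i32> | An inclusive range index, e.g. 1..=3
std::ops::RangeTo<i32> | A range index with no start, e.g. ..3
std::ops::RangeToInclusive<i32> | An inclusive range index with no start, e.g. ..=3
StrideBy | A range index with a stride, e.g. (..).stride_by(2)
NewAxis | Insert a new axis of length 1
Ellipsis | Expand to as many full ranges as needed to cover the remaining axes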
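All six std range types in the table implement std::ops::RangeBounds, so one generic function can accept any of them. The sketch below uses a hypothetical resolve helper (not part of mlx-rs) to show how each form maps to a half-open (start, end) pair on an axis of a given length. Negative indices and clamping are intentionally left out; how mlx-rs treats them is not covered here.

```rust
use std::ops::{Bound, RangeBounds};

/// Hypothetical helper (not part of mlx-rs): turn any std range type into a
/// half-open `(start, end)` pair for an axis of length `len`.
/// Negative indices are deliberately not handled in this sketch.
pub fn resolve<R: RangeBounds<i32>>(range: R, len: i32) -> (i32, i32) {
    let start = match range.start_bound() {
        Bound::Included(&s) => s,
        Bound::Excluded(&s) => s + 1,
        // `..3`, `..=3` and `..` start at the beginning of the axis.
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        // `1..=3` includes 3, so the half-open end is 4.
        Bound::Included(&e) => e + 1,
        Bound::Excluded(&e) => e,
        // `1..` and `..` run to the end of the axis.
        Bound::Unbounded => len,
    };
    (start, end)
}
```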

§Single axis indexing

Indexing operation | mlx (python) | mlx-swift | mlx-rs
--- | --- | --- | ---
integer | arr[1] | arr[1] | arr.index(1)
range expression | arr[1:3] | arr[1..<3] | arr.index(1..3)
full range | arr[:] | arr[0...] | arr.index(..)
range with stride | arr[::2] | arr[.stride(by: 2)] | arr.index((..).stride_by(2))
ellipsis (consuming all axes) | arr[...] | arr[.ellipsis] | arr.index(Ellipsis)
newaxis | arr[None] | arr[.newAxis] | arr.index(NewAxis)
mlx array i | arr[i] | arr[i] | arr.index(i)
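The rows of the table differ not only in which elements they select but also in the shape they produce. Assuming NumPy-style semantics, the std-only sketch below models the effect of each single-axis operation on the first axis of a shape; AxisOp and result_shape are hypothetical names, not mlx-rs APIs. An integer index drops the axis, a range keeps it with length ceil((end - start) / stride), and a new axis inserts a length-1 dimension. Range ends past the axis length are clamped here, as NumPy does.

```rust
/// Hypothetical model (not an mlx-rs type) of the single-axis operations
/// in the table above. Only non-negative bounds are modelled.
#[derive(Clone, Copy, Debug)]
pub enum AxisOp {
    /// Like `arr.index(1)`: pick one position and drop the axis.
    Int(usize),
    /// Like `arr.index(start..end)`, with a stride (1 means no stride).
    Range { start: usize, end: usize, stride: usize },
    /// Like `arr.index(NewAxis)`: insert an axis of length 1.
    NewAxis,
}

/// Shape obtained by applying `op` to the first axis of `shape`,
/// following NumPy-style rules.
pub fn result_shape(shape: &[usize], op: AxisOp) -> Vec<usize> {
    match op {
        AxisOp::Int(i) => {
            assert!(!shape.is_empty() && i < shape[0], "index out of bounds");
            // The indexed axis disappears from the result.
            shape[1..].to_vec()
        }
        AxisOp::Range { start, end, stride } => {
            assert!(!shape.is_empty() && stride > 0);
            // Clamp the end to the axis length, as NumPy slicing does.
            let end = end.min(shape[0]);
            // Number of selected positions: ceil((end - start) / stride).
            let len = if end > start { (end - start + stride - 1) / stride } else { 0 };
            let mut out = vec![len];
            out.extend_from_slice(&shape[1..]);
            out
        }
        AxisOp::NewAxis => {
            // A length-1 axis is prepended; no existing axis is consumed.
            let mut out = vec![1];
            out.extend_from_slice(shape);
            out
        }
    }
}
```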

§Multi-axes indexing

Several axes can be indexed at once by combining the operations above in a tuple. The only restriction is that Ellipsis may appear at most once in a tuple.
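The once-only restriction on Ellipsis follows from how it is expanded: it stands for as many full ranges as are needed to cover the axes that the other entries do not consume, and with two ellipses that count could be split between them in more than one way. The sketch below, using hypothetical Item and expand_ellipsis names that are not mlx-rs APIs, expands a tuple of index items for an array with ndim axes. NewAxis entries consume no existing axis, so they do not reduce the number of full ranges inserted.

```rust
/// Hypothetical index items (not mlx-rs types) used to show how an
/// ellipsis is expanded in a multi-axis tuple.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Item {
    Int(usize),
    Full,     // `..`
    NewAxis,  // inserts an axis, consumes none
    Ellipsis, // stands for "as many `..` as needed"
}

/// Replace the (single) ellipsis with enough full ranges to cover all
/// `ndim` axes. Without an ellipsis the items are returned unchanged.
pub fn expand_ellipsis(items: &[Item], ndim: usize) -> Result<Vec<Item>, String> {
    let n_ellipsis = items.iter().filter(|&&it| it == Item::Ellipsis).count();
    if n_ellipsis > 1 {
        // Two ellipses would make the split of remaining axes ambiguous.
        return Err("an ellipsis may appear only once".to_string());
    }
    // Count the items that consume an existing axis.
    let consumed = items
        .iter()
        .filter(|&&it| matches!(it, Item::Int(_) | Item::Full))
        .count();
    if consumed > ndim {
        return Err("too many indices".to_string());
    }
    let mut out = Vec::with_capacity(items.len() + ndim);
    for &it in items {
        if it == Item::Ellipsis {
            out.extend(std::iter::repeat(Item::Full).take(ndim - consumed));
        } else {
            out.push(it);
        }
    }
    Ok(out)
}
```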

§Examples

// See the multi-dimensional example code for mlx python https://ml-explore.github.io/mlx/build/html/usage/indexing.html

use mlx_rs::{Array, ops::indexing::*};

let a = Array::from_iter(0..8, &[2, 2, 2]);

// a[:, :, 0]
let s1 = a.index((.., .., 0));

let expected = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
assert_eq!(s1, expected);

// a[..., 0]
let s2 = a.index((Ellipsis, 0));

let expected = Array::from_slice(&[0, 2, 4, 6], &[2, 2]);
assert_eq!(s2, expected);
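The expected values follow from row-major (C-order) layout, assuming Array::from_iter fills the array in that order: element (i, j, k) of a [2, 2, 2] array sits at flat offset (i * 2 + j) * 2 + k, so fixing k = 0 picks offsets 0, 2, 4 and 6. A std-only sketch of that computation, with a hypothetical helper last_axis_slice that is not part of mlx-rs:

```rust
/// Hypothetical helper (not part of mlx-rs): for a row-major array of
/// shape `[d0, d1, d2]` stored flat in `data`, return `a[:, :, k]`
/// flattened in row-major order.
pub fn last_axis_slice(data: &[i32], shape: [usize; 3], k: usize) -> Vec<i32> {
    let [d0, d1, d2] = shape;
    assert_eq!(data.len(), d0 * d1 * d2);
    assert!(k < d2, "k is out of bounds for the last axis");
    let mut out = Vec::with_capacity(d0 * d1);
    for i in 0..d0 {
        for j in 0..d1 {
            // Row-major offset of element (i, j, k).
            out.push(data[(i * d1 + j) * d2 + k]);
        }
    }
    out
}
```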

§Set values with indexing

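The same single-axis and multi-axis indexing operations can be used to set values in an array through the IndexMutOp trait: IndexMutOp::index_mut() takes the index and the value to write (an Array in the examples below).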

§Examples

use mlx_rs::{Array, ops::indexing::*};

let mut a = Array::from_slice(&[1, 2, 3], &[3]);
a.index_mut(2, Array::from_int(0));

let expected = Array::from_slice(&[1, 2, 0], &[3]);
assert_eq!(a, expected);

use mlx_rs::{Array, ops::indexing::*};

let mut a = Array::from_iter(0i32..20, &[2, 2, 5]);

// write through slices: (0..1, 1..2, 2..4) covers exactly two elements, a[0, 1, 2] and a[0, 1, 3]
a.index_mut((0..1, 1..2, 2..4), Array::from_int(88));

let expected = Array::from_slice(
    &[
        0, 1, 2, 3, 4, 5, 6, 88, 88, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
    ],
    &[2, 2, 5],
);
assert_eq!(a, expected);
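The expected array can be checked by hand: in a row-major [2, 2, 5] array, element (i, j, k) sits at flat offset (i * 2 + j) * 5 + k, so (0..1, 1..2, 2..4) covers offsets 7 and 8, exactly where the 88s appear. The std-only sketch below computes those offsets for any list of per-axis ranges; covered_offsets and fill are hypothetical helpers for illustration, not part of mlx-rs.

```rust
use std::ops::Range;

/// Hypothetical helper (not part of mlx-rs): list the row-major flat
/// offsets covered by one half-open range per axis.
pub fn covered_offsets(shape: &[usize], ranges: &[Range<usize>]) -> Vec<usize> {
    assert_eq!(shape.len(), ranges.len());
    // Start from a single offset for the empty prefix and extend axis by axis.
    let mut offsets = vec![0usize];
    for (dim, r) in shape.iter().zip(ranges) {
        let mut next = Vec::new();
        for &base in &offsets {
            for i in r.clone() {
                // Row-major: offset = previous_offset * axis_len + index.
                next.push(base * dim + i);
            }
        }
        offsets = next;
    }
    offsets
}

/// Write `value` at every covered offset of a flat row-major buffer,
/// a flat-array model of the index_mut call above.
pub fn fill(data: &mut [i32], shape: &[usize], ranges: &[Range<usize>], value: i32) {
    for off in covered_offsets(shape, ranges) {
        data[off] = value;
    }
}
```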

Structs§

Ellipsis
Ellipsis indexing operation.
NewAxis
New axis indexing operation.
RangeIndex
Range indexing operation.
StrideBy
Stride indexing operation.

Enums§

ArrayIndexOp
Indexing operation for arrays.

Traits§

ArrayIndex
Trait for custom indexing operations.
IndexMutOp
Trait for custom mutable indexing operations.
IndexOp
Trait for custom indexing operations.
IntoStrideBy
Helper trait for creating a stride indexing operation.
TryIndexMutOp
Trait for custom mutable indexing operations.
TryIndexOp
Trait for custom indexing operations.

Functions§

argmax
Indices of the maximum values along the axis.
argmax_all
Indices of the maximum value over the entire array.
argmax_all_device
Indices of the maximum value over the entire array.
argmax_device
Indices of the maximum values along the axis.
argmin
Indices of the minimum values along the axis.
argmin_all
Indices of the minimum value over the entire array.
argmin_all_device
Indices of the minimum value over the entire array.
argmin_device
Indices of the minimum values along the axis.
put_along_axis
See Array::put_along_axis
put_along_axis_device
See Array::put_along_axis
take
See Array::take
take_all
See Array::take_all
take_all_device
See Array::take_all
take_along_axis
See Array::take_along_axis
take_along_axis_device
See Array::take_along_axis
take_device
See Array::take
topk
Returns the k largest elements from the input along a given axis.
topk_all
Returns the k largest elements from the flattened input array.
topk_all_device
Returns the k largest elements from the flattened input array.
topk_device
Returns the k largest elements from the input along a given axis.
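To make "along the axis" concrete for argmax: reducing a 2-D array along axis 0 yields one index per column, and along axis 1 one index per row. The std-only sketch below uses a hypothetical argmax_2d helper, not the mlx-rs function; it resolves ties to the first index, which is an assumption of the sketch rather than a documented property of argmax.

```rust
/// Hypothetical helper (not part of mlx-rs): argmax along `axis` (0 or 1)
/// of a row-major 2-D array with shape `[rows, cols]`.
/// Ties resolve to the first (lowest) index in this sketch.
pub fn argmax_2d(data: &[i32], rows: usize, cols: usize, axis: usize) -> Vec<usize> {
    assert_eq!(data.len(), rows * cols);
    assert!(rows > 0 && cols > 0);
    // Element at (r, c) in row-major layout.
    let at = |r: usize, c: usize| data[r * cols + c];
    match axis {
        // Reduce over rows: one result per column.
        0 => (0..cols)
            .map(|c| (0..rows).fold(0, |best, r| if at(r, c) > at(best, c) { r } else { best }))
            .collect(),
        // Reduce over columns: one result per row.
        1 => (0..rows)
            .map(|r| (0..cols).fold(0, |best, c| if at(r, c) > at(r, best) { c } else { best }))
            .collect(),
        _ => panic!("axis must be 0 or 1 for a 2-D array"),
    }
}
```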