Function mlx_rs::ops::flatten_device

pub fn flatten_device(
    a: impl AsRef<Array>,
    start_axis: impl Into<Option<i32>>,
    end_axis: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>

Flatten an array. Returns an error if the axes are invalid.

The flattened axes are those from start_axis through end_axis, inclusive. Negative axes are supported and count from the end, so -1 refers to the last axis. After negative axes are converted to positive ones, out-of-range values are clamped: start_axis to 0 and end_axis to ndim - 1.

Params

  • a: The input array.
  • start_axis: The first axis to flatten. Defaults to 0 if not provided.
  • end_axis: The last axis to flatten. Defaults to -1 (the last axis) if not provided.
  • stream: The stream on which to run the operation.
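The axis rules above (defaults, negative-axis conversion, clamping, then merging the selected dimensions) can be illustrated with a small shape calculation. The sketch below is std-only Rust and does not call mlx-rs; `flattened_shape` is a hypothetical helper, not part of the library. It assumes that a range which is empty or reversed after clamping counts as invalid (this sketch also rejects 0-d shapes that way); the library's exact error conditions may differ.

```rust
/// Hypothetical helper (not part of mlx-rs): computes the shape that results
/// from flattening axes `start_axis..=end_axis`, following the documented rules.
fn flattened_shape(
    shape: &[i32],
    start_axis: Option<i32>,
    end_axis: Option<i32>,
) -> Result<Vec<i32>, String> {
    let ndim = shape.len() as i32;

    // Documented defaults: start at axis 0, end at axis -1 (the last axis).
    let start = start_axis.unwrap_or(0);
    let end = end_axis.unwrap_or(-1);

    // Convert negative axes to positive ones by adding ndim.
    let start = if start < 0 { start + ndim } else { start };
    let end = if end < 0 { end + ndim } else { end };

    // Clamp out-of-range values: start_axis to at least 0, end_axis to at most ndim - 1.
    let start = start.max(0);
    let end = end.min(ndim - 1);

    // Assumption: an empty or reversed range is an error (this also rejects 0-d shapes).
    if start > end {
        return Err(format!("invalid axes: start_axis {start} > end_axis {end}"));
    }

    let (s, e) = (start as usize, end as usize);
    // Multiply the selected dimensions together into a single axis.
    let merged: i32 = shape[s..=e].iter().product();

    // Keep the leading and trailing dimensions unchanged.
    let mut out = shape[..s].to_vec();
    out.push(merged);
    out.extend_from_slice(&shape[e + 1..]);
    Ok(out)
}

fn main() {
    let shape = [2, 3, 4];
    println!("{:?}", flattened_shape(&shape, None, None)); // all axes
    println!("{:?}", flattened_shape(&shape, Some(1), None)); // axes 1..=2
    println!("{:?}", flattened_shape(&shape, Some(-3), Some(-2))); // axes 0..=1
    println!("{:?}", flattened_shape(&shape, Some(-100), Some(100))); // clamped to 0..=2
    println!("{:?}", flattened_shape(&shape, Some(2), Some(0))); // reversed range
}
```

For a `[2, 3, 4]` shape, flattening every axis gives `[24]`, while starting at axis 1 keeps the leading dimension and gives `[2, 12]`. Out-of-range axes such as -100 and 100 are clamped rather than rejected, so they behave like flattening all axes.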

Example

use mlx_rs::{Array, ops::*};

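let x = Array::zeros::<i32>(&[2, 2, 2]).unwrap();

// `flatten` is the variant of `flatten_device` that runs on the default stream.
// Flatten all three axes of the 2x2x2 array into a single axis of length 8.
let y = flatten(&x, None, None).unwrap();
assert_eq!(y.shape(), &[8]);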