pub fn truncated_normal_device<'a, E: Into<Array>, T: ArrayElement>(
lower: E,
upper: E,
shape: impl IntoOption<&'a [i32]>,
key: impl Into<Option<&'a Array>>,
stream: impl AsRef<Stream>,
) -> Result<Array>
Generate values from a truncated normal distribution between `lower` and `upper`.
The values are sampled from a normal distribution truncated to the
domain (`lower`, `upper`). The bounds `lower` and `upper` can be
scalars or arrays, and they must be broadcastable to `shape`.
use mlx_rs::{array, random};
let key = random::key(0).unwrap();
// generate an array of two `f32` values, one in the interval (0, 10)
// and one in the interval (10, 100)
let value = random::truncated_normal::<_, f32>(array!([0, 10]), array!([10, 100]), None, &key);