pub fn truncated_normal<'a, E: Into<Array>, T: ArrayElement>(
lower: E,
upper: E,
shape: impl IntoOption<&'a [i32]>,
key: impl Into<Option<&'a Array>>,
) -> Result<Array>Expand description
Generate values from a truncated normal distribution between `lower` and `upper`.
The values are sampled from the normal distribution truncated to the open
interval (`lower`, `upper`). The bounds `lower` and `upper` can be scalars or
arrays and must be broadcastable to `shape`.
use mlx_rs::{array, random};
let key = random::key(0).unwrap();
// Generate an array of two f32 values (element type chosen by the `f32` type argument):
// one drawn from the open interval (0, 10) and one from the open interval (10, 100).
// `shape` is `None`; the output shape here presumably follows the broadcast of the
// two-element bounds — NOTE(review): confirm against the implementation.
// The call returns a `Result<Array>`, so `value` still has to be unwrapped or `?`-ed.
let value = random::truncated_normal::<_, f32>(array!([0, 10]), array!([10, 100]), None, &key);