pub fn truncated_normal<'a, E: Into<Array>, T: ArrayElement>(
lower: E,
upper: E,
shape: impl IntoOption<&'a [i32]>,
key: impl Into<Option<&'a Array>>,
) -> Result<Array>
Expand description
Generate values from a truncated normal distribution between `lower` and `upper`.
The values are sampled from the truncated normal distribution on the domain `(lower, upper)`. The bounds `lower` and `upper` can be scalars or arrays and must be broadcastable to `shape`.
use mlx_rs::{array, random};
let key = random::key(0).unwrap();
// generate an array of two f32 values: one between 0 and 10,
// and one between 10 and 100
let value = random::truncated_normal::<_, f32>(array!([0, 10]), array!([10, 100]), None, &key);