Function mlx_rs::random::truncated_normal

pub fn truncated_normal<'a, E: Into<Array>, T: ArrayElement>(
    lower: E,
    upper: E,
    shape: impl IntoOption<&'a [i32]>,
    key: impl Into<Option<&'a Array>>,
) -> Result<Array>

Generate values from a truncated normal distribution between lower and upper.

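The values are sampled from the standard normal distribution (mean 0, standard deviation 1) truncated to the domain (lower, upper). The bounds lower and upper can be scalars or arrays and must be broadcastable to shape. The type parameter T selects the element type of the returned array.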
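The broadcasting requirement can be illustrated with a small, self-contained Rust sketch that does not use mlx-rs. It assumes NumPy-style rules (shapes are aligned from the right, and each pair of dimensions must be equal or one of them must be 1) and treats a scalar bound as the empty shape. broadcast_shapes and bounds_fit_shape are hypothetical helpers, not mlx-rs functions; the second accepts a bound only if broadcasting it against shape leaves shape unchanged, since the bounds may not enlarge the output.

```rust
/// Hypothetical helper (not part of mlx-rs): NumPy-style broadcast of two shapes.
/// Dimensions are compared from the right; missing leading dimensions count as 1.
/// Returns None when the shapes are incompatible.
fn broadcast_shapes(a: &[i32], b: &[i32]) -> Option<Vec<i32>> {
    let n = a.len().max(b.len());
    let mut out = vec![0; n];
    for i in 0..n {
        // i-th dimension counted from the right, padding with 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[n - 1 - i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
    }
    Some(out)
}

/// Hypothetical check: both bounds fit `shape` if broadcasting each one
/// against `shape` produces exactly `shape`.
fn bounds_fit_shape(lower: &[i32], upper: &[i32], shape: &[i32]) -> bool {
    [lower, upper]
        .iter()
        .all(|b| broadcast_shapes(b, shape).as_deref() == Some(shape))
}
```

For instance, a scalar lower bound (shape []) and an upper bound of shape [2] both fit a requested shape of [3, 2], whereas a bound of shape [3, 1] does not fit a requested shape of [2] because it would enlarge the output.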

use mlx_rs::{array, random};

let key = random::key(0).unwrap();

// Generate an array of two f32 values, one in the range (0, 10)
// and one in the range (10, 100).
let value = random::truncated_normal::<_, f32>(array!([0, 10]), array!([10, 100]), None, &key).unwrap();
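To make the sampling itself concrete, here is a self-contained sketch in plain Rust (no mlx-rs) of one common technique, inverse-CDF sampling: draw u uniformly between Φ(lower) and Φ(upper), where Φ is the standard normal CDF, then map u back through Φ. It is not presented as mlx-rs's implementation. The XorShift64 generator, the erf approximation (Abramowitz and Stegun 7.1.26, absolute error around 1.5e-7) and all function names are hypothetical stand-ins, and the inverse of Φ is found by bisection instead of an erfinv routine to keep the code short.

```rust
/// Hypothetical minimal PRNG (xorshift64); illustrative only, not MLX's generator.
/// The state must be nonzero, otherwise it stays at zero forever.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Uniform f64 in [0, 1) built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Abramowitz and Stegun 7.1.26 approximation of erf (max abs error ~1.5e-7).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t
        * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// Standard normal CDF: Φ(x) = (1 + erf(x / √2)) / 2.
fn phi(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

/// Draw one sample from N(0, 1) truncated to [lower, upper] by inverse-CDF sampling:
/// pick u uniformly in [Φ(lower), Φ(upper)), then solve Φ(x) = u by bisection.
fn truncated_normal_sample(lower: f64, upper: f64, rng: &mut XorShift64) -> f64 {
    assert!(lower < upper, "lower must be below upper");
    let (cl, cu) = (phi(lower), phi(upper));
    let u = cl + (cu - cl) * rng.next_f64();
    let (mut lo, mut hi) = (lower, upper);
    // Φ is increasing, so bisection keeps the root inside [lo, hi].
    for _ in 0..100 {
        let mid = 0.5 * (lo + hi);
        if phi(mid) < u {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    0.5 * (lo + hi)
}
```

Every result stays inside [lower, upper] because bisection never leaves that interval. For bounds far in the tail, such as (10, 100), Φ(lower) and Φ(upper) both round to 1.0 in f64, so this sketch returns a value at the lower bound; a production sampler needs more care in that regime.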