Function mlx_rs::random::truncated_normal_device

pub fn truncated_normal_device<'a, E: Into<Array>, T: ArrayElement>(
    lower: E,
    upper: E,
    shape: impl IntoOption<&'a [i32]>,
    key: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>

Generate values from a truncated normal distribution between `lower` and `upper`.

The values are sampled from the truncated normal distribution on the domain (lower, upper). The bounds lower and upper can be scalars or arrays and must be broadcastable to shape.
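The broadcasting requirement can be made concrete with a small check. This is a standard-library sketch, not MLX code. It assumes NumPy-style rules: trailing dimensions are aligned, each pair must match or the bound's dimension must be 1, and a bound may have fewer dimensions than `shape` but not more. The helper name `broadcastable_to` is hypothetical.

```rust
/// Hypothetical helper: returns true if a bound of shape `bound` can be
/// broadcast to the output shape `target`, assuming NumPy-style rules.
/// Shapes use `i32` dimensions to match the `shape: &[i32]` parameter.
fn broadcastable_to(bound: &[i32], target: &[i32]) -> bool {
    // A bound with more dimensions than the output cannot be broadcast to it.
    if bound.len() > target.len() {
        return false;
    }
    // Compare dimensions from the right; missing leading dims act like 1.
    bound
        .iter()
        .rev()
        .zip(target.iter().rev())
        .all(|(&b, &t)| b == t || b == 1)
}
```

For example, with `shape = [2, 3]`, a scalar bound (shape `[]`) or a bound of shape `[3]` or `[2, 1]` passes the check, while a bound of shape `[2]` does not, because its trailing dimension 2 does not match 3.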

use mlx_rs::{array, random};

let key = random::key(0).unwrap();

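// Generate an array of two f32 values: one sampled from the domain (0, 10)
// and one from the domain (10, 100). `truncated_normal` uses the default
// stream; call `truncated_normal_device` to pass a stream explicitly.
let value = random::truncated_normal::<_, f32>(array!([0, 10]), array!([10, 100]), None, &key).unwrap();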
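To illustrate what "truncated" means, here is a minimal standard-library sketch. It is not MLX's implementation. It assumes a standard normal (mean 0, variance 1), replaces MLX's key-based generator with a hypothetical xorshift64* PRNG, and uses naive rejection sampling with per-element bounds. Rejection only finishes quickly when the interval holds meaningful probability mass. For far-tail bounds such as (10, 100) it would effectively never accept a draw, so inverse-CDF or tail-specific samplers are used in practice.

```rust
/// Hypothetical stand-in PRNG (xorshift64*), not MLX's key-based generator.
/// The seed must be nonzero.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.0 = x;
        x.wrapping_mul(0x2545_F491_4F6C_DD1D)
    }

    /// Uniform sample in the open interval (0, 1).
    fn next_unit(&mut self) -> f64 {
        // Top 52 bits, offset by half a step so 0.0 and 1.0 are never returned.
        ((self.next_u64() >> 12) as f64 + 0.5) / (1u64 << 52) as f64
    }

    /// Standard normal sample via the Box-Muller transform.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_unit();
        let u2 = self.next_unit();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
}

/// One standard-normal draw restricted to (lower, upper) by rejection:
/// keep drawing until a sample falls strictly inside the interval.
fn truncated_normal_rejection(rng: &mut XorShift64, lower: f64, upper: f64) -> f64 {
    assert!(lower < upper, "lower must be strictly below upper");
    loop {
        let z = rng.next_normal();
        if z > lower && z < upper {
            return z;
        }
    }
}

/// Elementwise bounds: output[i] is truncated to (lower[i], upper[i]).
fn sample_elementwise(rng: &mut XorShift64, lower: &[f64], upper: &[f64]) -> Vec<f64> {
    assert_eq!(lower.len(), upper.len(), "bounds must have the same length");
    lower
        .iter()
        .zip(upper)
        .map(|(&lo, &hi)| truncated_normal_rejection(rng, lo, hi))
        .collect()
}
```

For example, `sample_elementwise(&mut XorShift64(42), &[-1.0, 0.5], &[2.0, 3.0])` returns two values, each strictly inside its own interval, mirroring the elementwise `lower` and `upper` arrays in the example above.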