mlx_rs::random

Function bernoulli_device

Source
pub fn bernoulli_device<'a>(
    p: impl Into<Option<&'a Array>>,
    shape: impl IntoOption<&'a [i32]>,
    key: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>
Expand description

Generate Bernoulli random values with a given p value.

The values are sampled from the Bernoulli distribution with parameter `p`. The parameter `p` must have a floating-point type and must be broadcastable to `shape`.

use mlx_rs::{array, Array, random};

let key = random::key(0).unwrap();

// generate a single random Bool with p = 0.8
let p: Array = 0.8.into();
let value = random::bernoulli(&p, None, &key);

// generate an array of shape [50, 2] of random Bool with p = 0.8
let array = random::bernoulli(&p, &[50, 2], &key);

// generate an array of shape [3] of random Bool with the given p values
let array = random::bernoulli(&array!([0.1, 0.5, 0.8]), None, &key);