mlx_rs::random

Function normal_device

pub fn normal_device<'a, T: ArrayElement>(
    shape: impl IntoOption<&'a [i32]>,
    loc: impl Into<Option<f32>>,
    scale: impl Into<Option<f32>>,
    key: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>

Generate normally distributed random numbers.

Generate an array of random numbers with the given shape, or a single value if no shape is given. The elements have type T, which must be a floating-point type.

§Params

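  • shape: shape of the output; if None, a single value is returned
  • loc: mean of the distribution; defaults to 0.0
  • scale: standard deviation of the distribution; defaults to 1.0
  • key: optional PRNG key; if None, the global PRNG state is used
  • stream: the stream on which the operation runs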
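The parameters compose as an affine transform of a standard normal: each output element is distributed as loc + scale * z with z ~ N(0, 1), and a missing shape yields one value. The std-only sketch below illustrates only these documented semantics. SplitMix64, normal_sketch, and the Box-Muller transform are hypothetical stand-ins chosen for illustration; they are not the PRNG or sampling algorithm MLX uses, and the seed argument merely plays the role of key.

```rust
// Hypothetical, std-only illustration of how shape, loc and scale combine.
// It does NOT reproduce MLX's PRNG or its sampling algorithm.

/// SplitMix64: a small 64-bit PRNG used here as a stand-in for an MLX key.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform sample in the open interval (0, 1): top 52 bits plus half a step.
    fn next_open01(&mut self) -> f64 {
        ((self.next_u64() >> 12) as f64 + 0.5) / (1u64 << 52) as f64
    }
}

/// Flat buffer of N(loc, scale^2) samples whose length is the element count of `shape`.
fn normal_sketch(
    shape: Option<&[i32]>,
    loc: Option<f32>,
    scale: Option<f32>,
    seed: u64,
) -> Vec<f32> {
    // Documented defaults: loc = 0.0, scale = 1.0, no shape = a single value.
    let loc = loc.unwrap_or(0.0);
    let scale = scale.unwrap_or(1.0);
    // Assumes non-negative dimensions.
    let count: usize = shape.map_or(1, |dims| dims.iter().map(|&d| d as usize).product());

    let mut rng = SplitMix64(seed);
    (0..count)
        .map(|_| {
            // Box-Muller: two independent uniforms give one standard normal z.
            let u1 = rng.next_open01();
            let u2 = rng.next_open01();
            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
            // Shift by loc and stretch by scale.
            loc + scale * z as f32
        })
        .collect()
}

fn main() {
    let single = normal_sketch(None, None, None, 0);
    let grid = normal_sketch(Some(&[10, 5][..]), Some(2.0), Some(0.5), 0);
    println!("{} value(s), {} value(s)", single.len(), grid.len());
}
```

Here main prints 1 value(s), 50 value(s), mirroring the single-value and [10, 5] calls in the example. Because the same seed produces the same z values, changing loc and scale only shifts and rescales those values, which is why the defaults yield a standard normal.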

§Example

// mlx_rs::random::normal takes the same arguments without `stream` and runs on the default stream
let key = mlx_rs::random::key(0).unwrap();

// generate a single f32 with normal distribution
let value = mlx_rs::random::normal::<f32>(None, None, None, &key).unwrap().item::<f32>();

// generate an array of f32 with normal distribution in shape [10, 5]
let array = mlx_rs::random::normal::<f32>(&[10, 5], None, None, &key).unwrap();