Function mlx_rs::random::multivariate_normal

pub fn multivariate_normal<'a, T: ArrayElement>(
    mean: impl AsRef<Array>,
    covariance: impl AsRef<Array>,
    shape: impl IntoOption<&'a [i32]>,
    key: impl Into<Option<&'a Array>>,
) -> Result<Array>

Generate jointly-normal random samples given a mean and covariance.

The matrix covariance must be positive semi-definite; the behavior is undefined if it is not. The only supported output type is f32, so the type parameter T should be f32.
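
Conceptually, a jointly-normal sample is an affine transform of independent standard-normal draws. If z has independent N(0, 1) entries and L satisfies L L^T = covariance, then x = mean + L z has mean equal to mean and covariance L I L^T = covariance. The plain-Rust sketch below (no mlx_rs; the helpers cholesky and transform_sample are hypothetical) shows that transform for a single sample, assuming the caller already has z. It uses a Cholesky factor, so unlike the positive semi-definite requirement above it only accepts strictly positive-definite matrices, and it is not necessarily how mlx_rs computes the factor.

```rust
/// Hypothetical helper: Cholesky factorization of a symmetric positive-definite
/// n x n matrix stored row-major. Returns the lower-triangular L with
/// L * L^T = cov, or None if the matrix is not (numerically) positive definite.
fn cholesky(cov: &[f64], n: usize) -> Option<Vec<f64>> {
    assert_eq!(cov.len(), n * n);
    let mut l = vec![0.0; n * n];
    for i in 0..n {
        for j in 0..=i {
            // Dot product of the already-computed parts of rows i and j of L.
            let mut sum = 0.0;
            for k in 0..j {
                sum += l[i * n + k] * l[j * n + k];
            }
            if i == j {
                let d = cov[i * n + i] - sum;
                if d <= 0.0 {
                    // Singular or indefinite: Cholesky cannot proceed.
                    return None;
                }
                l[i * n + j] = d.sqrt();
            } else {
                l[i * n + j] = (cov[i * n + j] - sum) / l[j * n + j];
            }
        }
    }
    Some(l)
}

/// Hypothetical helper: maps one vector of standard-normal draws z to
/// x = mean + L z. L is lower-triangular, so only entries with k <= i contribute.
fn transform_sample(mean: &[f64], l: &[f64], z: &[f64]) -> Vec<f64> {
    let n = mean.len();
    assert_eq!(z.len(), n);
    (0..n)
        .map(|i| mean[i] + (0..=i).map(|k| l[i * n + k] * z[k]).sum::<f64>())
        .collect()
}

fn main() {
    // Mean and a symmetric positive-definite covariance for n = 2.
    let mean = [1.0, -1.0];
    let cov = [4.0, 2.0, 2.0, 3.0]; // row-major [[4, 2], [2, 3]]
    let l = cholesky(&cov, 2).expect("covariance must be positive definite");
    // In practice z would come from a standard-normal generator.
    let z = [0.5, -0.25];
    let x = transform_sample(&mean, &l, &z);
    println!("{:?}", x);
}
```

Drawing many independent z vectors and applying the same transform yields independent samples from the same distribution.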

Params

  • mean: array of shape [..., n], the mean of the distribution.
  • covariance: array of shape [..., n, n], the covariance matrix of the distribution. The batch shape ... must be broadcast-compatible with that of mean.
  • shape: the output batch shape. It must be broadcast-compatible with &mean.shape[..mean.shape.len()-1] and &covariance.shape[..covariance.shape.len()-2]. If empty, the batch shape is determined by broadcasting the batch shapes of mean and covariance. In either case each sample is a vector of length n, so the result has a trailing dimension of size n.
  • key: optional PRNG key. If None, the global random state is used.
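
The shape rules above can be checked with a small shape calculator. The sketch below is plain Rust and hypothetical (sample_shape and broadcast are not mlx_rs functions). It assumes NumPy-style broadcasting, that the batch shapes of mean and covariance and the requested shape are all broadcast together, and that the event dimension n is appended last.

```rust
/// NumPy-style broadcasting of two shapes: align from the right; each pair of
/// dimensions must be equal or one of them must be 1. Returns None on mismatch.
fn broadcast(a: &[i32], b: &[i32]) -> Option<Vec<i32>> {
    let len = a.len().max(b.len());
    let mut out = vec![0; len];
    for i in 0..len {
        // Missing leading dimensions behave like size 1.
        let da = if i < len - a.len() { 1 } else { a[i - (len - a.len())] };
        let db = if i < len - b.len() { 1 } else { b[i - (len - b.len())] };
        out[i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
    }
    Some(out)
}

/// Hypothetical helper: the shape the samples would have, assuming the batch
/// shape is broadcast from mean, covariance and `shape`, with n appended.
fn sample_shape(mean: &[i32], cov: &[i32], shape: &[i32]) -> Option<Vec<i32>> {
    // mean is [..., n]; an empty mean shape has no n.
    let n = *mean.last()?;
    // covariance must be [..., n, n].
    if cov.len() < 2 || cov[cov.len() - 1] != n || cov[cov.len() - 2] != n {
        return None;
    }
    let batch = broadcast(&mean[..mean.len() - 1], &cov[..cov.len() - 2])?;
    // An empty `shape` leaves the broadcast batch shape unchanged.
    let mut out = broadcast(&batch, shape)?;
    out.push(n);
    Some(out)
}

fn main() {
    // A batch of 3 means of length 2 sharing a single [2, 2] covariance.
    println!("{:?}", sample_shape(&[3, 2], &[2, 2], &[])); // prints Some([3, 2])
}
```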