Function mlx_rs::random::categorical

pub fn categorical<'a>(
    logits: impl AsRef<Array>,
    axis: impl Into<Option<i32>>,
    shape_or_count: impl Into<Option<ShapeOrCount<'a>>>,
    key: impl Into<Option<&'a Array>>,
) -> Result<Array>

Sample from a categorical distribution.

The values are sampled from the categorical distribution(s) specified by the unnormalized values in logits. If neither a shape nor a count is specified, the result has the shape of logits with the axis dimension removed.
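The description leaves the sampling procedure implicit. One standard way to sample directly from unnormalized logits, without computing a softmax, is the Gumbel-max trick: add independent Gumbel(0, 1) noise to every logit and take the argmax. It is assumed here (not stated by this page) that MLX's categorical follows this approach internally, adding Gumbel noise and taking an argmax along axis. The sketch below is a standalone pure-Rust illustration for a single 1-D distribution, not mlx-rs code; `SplitMix64` and `sample_categorical` are hypothetical names introduced only for this example.

```rust
/// Minimal SplitMix64 generator so the sketch needs no external crates.
/// Illustrative only; this is not the PRNG that MLX uses.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform sample in the open interval (0, 1).
    fn next_open01(&mut self) -> f64 {
        // Keep the top 53 bits and offset by half a step so 0.0 never occurs.
        ((self.next_u64() >> 11) as f64 + 0.5) / (1u64 << 53) as f64
    }
}

/// Draws one index from the categorical distribution defined by the
/// unnormalized `logits`: argmax_i (logits[i] + g_i), where
/// g_i = -ln(-ln(u_i)) and u_i ~ Uniform(0, 1) is Gumbel(0, 1) noise.
fn sample_categorical(logits: &[f64], rng: &mut SplitMix64) -> usize {
    let mut best = 0;
    let mut best_score = f64::NEG_INFINITY;
    for (i, &l) in logits.iter().enumerate() {
        let gumbel = -(-rng.next_open01().ln()).ln();
        let score = l + gumbel;
        if score > best_score {
            best_score = score;
            best = i;
        }
    }
    best
}

fn main() {
    let mut rng = SplitMix64(0);
    // Class 2 has by far the largest logit, so it should dominate the draws.
    let logits = [0.0, 1.0, 4.0];
    let mut counts = [0usize; 3];
    for _ in 0..10_000 {
        counts[sample_categorical(&logits, &mut rng)] += 1;
    }
    println!("{counts:?}");
}
```

No normalization step appears anywhere: the argmax is unchanged if a constant is added to every logit, which is why the logits may be unnormalized. The empirical frequencies approach softmax(logits), roughly 0.017, 0.047 and 0.936 for the logits above.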

§Params

  • logits: The unnormalized categorical distribution(s).
  • axis (optional): The axis along which each distribution is defined. Defaults to -1 (the last axis).
  • shape_or_count (optional):
    • Shape: The shape of the output. It must be broadcast compatible with logits.shape with the axis dimension removed.
    • Count: The number of samples to draw from each categorical distribution in logits. The output has the shape of logits with the axis dimension removed, plus a trailing dimension of size Count holding the samples.
  • key (optional): A PRNG key.
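The shape rules in the parameter list can be summarized as a small pure-Rust function. This is a sketch under stated assumptions: `ShapeOrCount` below is a hypothetical local mirror of the mlx-rs type (the real variants' field types may differ), `categorical_output_shape` is a hypothetical helper, and "broadcast compatible" is interpreted as: the batch shape (logits.shape without axis) must broadcast, right-aligned in the NumPy style, to the requested shape, which then becomes the output shape.

```rust
/// Hypothetical stand-in for mlx-rs's `ShapeOrCount`, used only to
/// illustrate the documented shape rules.
enum ShapeOrCount<'a> {
    Shape(&'a [i32]),
    Count(i32),
}

/// Returns the output shape implied by the rules above, or an error message
/// if `axis` is out of range or the requested shape is not compatible.
fn categorical_output_shape(
    logits_shape: &[i32],
    axis: Option<i32>,
    shape_or_count: Option<ShapeOrCount<'_>>,
) -> Result<Vec<i32>, String> {
    let ndim = logits_shape.len() as i32;
    // Default axis is -1; negative axes count from the end.
    let axis = axis.unwrap_or(-1);
    let ax = if axis < 0 { axis + ndim } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(format!("axis {axis} out of range for {ndim}-d logits"));
    }
    // Batch shape: logits.shape with the axis dimension removed.
    let mut batch: Vec<i32> = logits_shape.to_vec();
    batch.remove(ax as usize);

    match shape_or_count {
        // No shape or count: one sample per distribution.
        None => Ok(batch),
        // Count: a trailing dimension holding `n` samples per distribution.
        Some(ShapeOrCount::Count(n)) => {
            batch.push(n);
            Ok(batch)
        }
        // Shape: the batch shape must broadcast to `shape` (assumed rule).
        Some(ShapeOrCount::Shape(shape)) => {
            if batch.len() > shape.len() {
                return Err(format!("{shape:?} has fewer dims than batch {batch:?}"));
            }
            // Compare dimensions right-aligned: equal, or batch dim is 1.
            let ok = batch
                .iter()
                .rev()
                .zip(shape.iter().rev())
                .all(|(&b, &s)| b == s || b == 1);
            if ok {
                Ok(shape.to_vec())
            } else {
                Err(format!("{batch:?} does not broadcast to {shape:?}"))
            }
        }
    }
}

fn main() {
    // [5, 20] logits with the default axis: five distributions over 20 classes.
    println!("{:?}", categorical_output_shape(&[5, 20], None, None));
    println!("{:?}", categorical_output_shape(&[5, 20], None, Some(ShapeOrCount::Count(3))));
    println!("{:?}", categorical_output_shape(&[5, 20], None, Some(ShapeOrCount::Shape(&[4, 5]))));
}
```

For [5, 20] logits this gives [5] with no shape or count, [5, 3] for a count of 3, and [4, 5] for a requested shape of [4, 5]; a requested shape such as [5, 4] is rejected because the batch dimension 5 does not line up with the trailing 4.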

§Example

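let key = mlx_rs::random::key(0).unwrap();

// Five independent distributions (rows), each over 20 classes.
// All-zero logits describe a uniform distribution in every row.
let logits = mlx_rs::Array::zeros::<f32>(&[5, 20]).unwrap();

// Samples one class index per row along the default axis (-1):
// an Array of u32 indices with shape &[5].
let result = mlx_rs::random::categorical(&logits, None, None, &key).unwrap();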
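In the example every row of logits is all zeros. Because the logits are unnormalized, the probability of class i is softmax(logits)_i = exp(l_i) / sum_j exp(l_j), so a row of zeros is a uniform distribution and each sampled index is equally likely to be any of 0 through 19. The `softmax` helper below is a hypothetical pure-Rust illustration of that normalization for a single row; it is independent of mlx-rs.

```rust
/// Converts one row of unnormalized logits into probabilities.
/// Subtracting the row maximum first keeps `exp` from overflowing.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // One row of the example's [5, 20] all-zero logits.
    let row = [0.0f32; 20];
    let probs = softmax(&row);
    // Each of the 20 classes receives probability 1/20.
    println!("{:?}", &probs[..3]);
}
```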