pub fn categorical_device<'a>(
    logits: impl AsRef<Array>,
    axis: impl Into<Option<i32>>,
    shape_or_count: impl Into<Option<ShapeOrCount<'a>>>,
    key: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>
Sample from a categorical distribution.
The values are sampled from the categorical distribution specified by
the unnormalized values in logits. If neither a shape nor a count is
specified, the result has the shape of logits with the axis
dimension removed.
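To illustrate what sampling from unnormalized logits means, here is a sketch using the Gumbel-max trick: add independent Gumbel(0, 1) noise to each logit and take the argmax, which selects index i with probability softmax(logits)_i. This is a standard technique; we don't claim it is how mlx-rs implements categorical internally. The sketch is dependency-free Rust, and the names SplitMix64, sample_categorical, and sample_rows are hypothetical. SplitMix64 simply stands in for a PRNG key.

```rust
/// Minimal SplitMix64 PRNG so the sketch needs no external crates.
/// Hypothetical stand-in for a PRNG key; this is not MLX's generator.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform sample strictly inside (0, 1): the top 52 bits plus half a
    /// step, so neither 0.0 nor 1.0 can occur and ln() below stays finite.
    fn next_open01(&mut self) -> f64 {
        ((self.next_u64() >> 12) as f64 + 0.5) / (1u64 << 52) as f64
    }
}

/// Draw one category index from unnormalized `logits` via Gumbel-max:
/// argmax_i (logits[i] + g_i) with g_i ~ Gumbel(0, 1).
fn sample_categorical(logits: &[f64], rng: &mut SplitMix64) -> usize {
    assert!(!logits.is_empty(), "need at least one category");
    let mut best = 0;
    let mut best_score = f64::NEG_INFINITY;
    for (i, &logit) in logits.iter().enumerate() {
        // Gumbel(0, 1) noise: g = -ln(-ln(u)) for u uniform in (0, 1).
        let gumbel = -(-rng.next_open01().ln()).ln();
        let score = logit + gumbel;
        if score > best_score {
            best_score = score;
            best = i;
        }
    }
    best
}

/// Sample once per row of a row-major [rows, cols] logits buffer (axis = -1),
/// so the result has `rows` entries, mirroring a [5, 20] -> [5] call.
fn sample_rows(logits: &[f64], cols: usize, rng: &mut SplitMix64) -> Vec<u32> {
    assert!(cols > 0, "each distribution needs at least one category");
    logits
        .chunks(cols)
        .map(|row| sample_categorical(row, rng) as u32)
        .collect()
}

fn main() {
    let mut rng = SplitMix64(0);
    // 5 distributions over 20 categories with all-zero (uniform) logits.
    let logits = vec![0.0; 5 * 20];
    let samples = sample_rows(&logits, 20, &mut rng);
    println!("{:?}", samples);
}
```

Since only the noisy scores are compared, adding the same constant to every logit in a row leaves the sampling distribution unchanged, which is why the logits do not need to be normalized.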
§Params
- logits: The unnormalized categorical distribution(s).
- axis (optional): The axis which specifies the distribution. Default is -1.
- shape_or_count (optional):
  - Shape: The shape of the output. This must be broadcast compatible with logits.shape with the axis dimension removed.
  - Count: The number of samples to draw from each of the categorical distributions in logits. The output will have the number of samples in the last dimension.
- key (optional): A PRNG key.
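The shape rules for axis and shape_or_count can be summarized in a small, dependency-free sketch. ShapeOrCountSketch and categorical_output_shape are hypothetical names, not mlx-rs types. Because Shape is documented as the shape of the output, the sketch assumes the requested shape must itself be the result of broadcasting it against logits.shape with the axis dimension removed, using NumPy-style broadcasting.

```rust
/// Hypothetical mirror of the `shape_or_count` argument (not the mlx-rs type).
enum ShapeOrCountSketch<'a> {
    Shape(&'a [i32]),
    Count(i32),
}

/// NumPy-style broadcast of two shapes: align trailing dims; each pair must
/// be equal or contain a 1. Returns None if the shapes are incompatible.
fn broadcast(a: &[i32], b: &[i32]) -> Option<Vec<i32>> {
    let n = a.len().max(b.len());
    let mut out = vec![0; n];
    for i in 0..n {
        // Read from the right; missing leading dims behave like 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[n - 1 - i] = if da == db || db == 1 {
            da
        } else if da == 1 {
            db
        } else {
            return None;
        };
    }
    Some(out)
}

/// Output shape of categorical sampling under the documented rules.
fn categorical_output_shape(
    logits_shape: &[i32],
    axis: Option<i32>,
    shape_or_count: Option<ShapeOrCountSketch<'_>>,
) -> Result<Vec<i32>, String> {
    let ndim = logits_shape.len() as i32;
    let axis = axis.unwrap_or(-1); // documented default
    let ax = if axis < 0 { axis + ndim } else { axis };
    if ax < 0 || ax >= ndim {
        return Err(format!("axis {axis} out of range for {ndim}-d logits"));
    }
    // logits.shape with the axis dimension removed
    let mut reduced = logits_shape.to_vec();
    reduced.remove(ax as usize);
    match shape_or_count {
        None => Ok(reduced),
        Some(ShapeOrCountSketch::Count(n)) => {
            reduced.push(n); // samples go in the last dimension
            Ok(reduced)
        }
        Some(ShapeOrCountSketch::Shape(shape)) => match broadcast(shape, &reduced) {
            Some(b) if b == shape => Ok(shape.to_vec()),
            _ => Err(format!("shape {shape:?} is not broadcast compatible with {reduced:?}")),
        },
    }
}

fn main() {
    // [5, 20] logits, default axis -1, no shape or count.
    println!("{:?}", categorical_output_shape(&[5, 20], None, None));
    // Three samples from each of the 5 distributions.
    println!("{:?}", categorical_output_shape(&[5, 20], None, Some(ShapeOrCountSketch::Count(3))));
}
```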
§Example
let key = mlx_rs::random::key(0).unwrap();
let logits = mlx_rs::Array::zeros::<f32>(&[5, 20]).unwrap();
// one sample per row: the result is a u32 Array of shape [5]
let result = mlx_rs::random::categorical(&logits, None, None, &key).unwrap();