pub fn categorical_device<'a>(
    logits: impl AsRef<Array>,
    axis: impl Into<Option<i32>>,
    shape_or_count: impl Into<Option<ShapeOrCount<'a>>>,
    key: impl Into<Option<&'a Array>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>
Sample from a categorical distribution.
The values are sampled from the categorical distribution specified by the unnormalized values in `logits`. If the shape is not specified, the result shape will be the same as the shape of `logits` with the `axis` dimension removed.
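To make "unnormalized" concrete: assuming the usual convention that the probability of category `i` is proportional to `exp(logits[i])` (a softmax, which this page implies but does not spell out), the dependency-free sketch below normalizes one row of logits. The function name `softmax` is illustrative and is not part of mlx-rs.

```rust
/// Illustrative helper (not an mlx-rs API): converts one row of unnormalized
/// logits into probabilities, assuming p[i] is proportional to exp(logits[i]).
fn softmax(logits: &[f64]) -> Vec<f64> {
    // Subtract the maximum first so `exp` cannot overflow. This does not change
    // the result, because adding a constant to every logit cancels out.
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

Because a constant shift cancels, `softmax(&[1.0, 2.0, 3.0])` and `softmax(&[101.0, 102.0, 103.0])` return the same probabilities, and all-zero logits yield a uniform distribution.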
§Params
- `logits`: The unnormalized categorical distribution(s).
- `axis` (optional): The axis which specifies the distribution. Defaults to `-1` (the last axis).
- `shape_or_count` (optional):
  - `Shape`: The shape of the output. This must be broadcast-compatible with `logits.shape` with the `axis` dimension removed.
  - `Count`: The number of samples to draw from each of the categorical distributions in `logits`. The output will have the number of samples in the last dimension.
- `key` (optional): A PRNG key.
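The shape rules above can be summarized in a small, dependency-free sketch. `ShapeOrCount` here is a hypothetical local mirror of the two documented cases, not the mlx-rs type, and `reduced_shape`, `broadcast_compatible`, and `output_shape` are illustrative helpers. The `Shape` branch only checks NumPy-style broadcast compatibility as the text describes; the real function may validate more strictly.

```rust
/// Hypothetical local mirror of the two documented cases (not the mlx-rs type).
#[derive(Debug, Clone, Copy)]
enum ShapeOrCount<'a> {
    Shape(&'a [i32]),
    Count(i32),
}

/// Shape of `logits` with the distribution axis removed.
/// A negative `axis` counts from the end, so the default of -1 is the last axis.
fn reduced_shape(logits_shape: &[i32], axis: i32) -> Option<Vec<i32>> {
    let ndim = logits_shape.len() as i32;
    let axis = if axis < 0 { axis + ndim } else { axis };
    if axis < 0 || axis >= ndim {
        return None; // axis out of range
    }
    let mut shape = logits_shape.to_vec();
    shape.remove(axis as usize);
    Some(shape)
}

/// Two shapes are broadcast-compatible when, aligned from the trailing
/// dimension, each pair of sizes is equal or one of them is 1.
fn broadcast_compatible(a: &[i32], b: &[i32]) -> bool {
    a.iter()
        .rev()
        .zip(b.iter().rev())
        .all(|(&x, &y)| x == y || x == 1 || y == 1)
}

/// Output shape implied by the documented rules; `None` signals an error.
fn output_shape(
    logits_shape: &[i32],
    axis: Option<i32>,
    shape_or_count: Option<ShapeOrCount<'_>>,
) -> Option<Vec<i32>> {
    let reduced = reduced_shape(logits_shape, axis.unwrap_or(-1))?;
    match shape_or_count {
        // Neither given: `logits.shape` with the `axis` dimension removed.
        None => Some(reduced),
        // Explicit shape: must be broadcast-compatible with the reduced shape.
        Some(ShapeOrCount::Shape(shape)) => {
            if broadcast_compatible(shape, &reduced) {
                Some(shape.to_vec())
            } else {
                None
            }
        }
        // Count: the number of samples becomes the last dimension.
        Some(ShapeOrCount::Count(n)) => {
            let mut shape = reduced;
            shape.push(n);
            Some(shape)
        }
    }
}
```

For logits of shape `[5, 20]` with the default axis, omitting `shape_or_count` gives `[5]`, while `Count(3)` gives `[5, 3]`.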
§Example
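let key = mlx_rs::random::key(0).unwrap();
// All-zero logits describe a uniform distribution over the 20 categories.
let logits = mlx_rs::Array::zeros::<f32>(&[5, 20]).unwrap();
// One u32 index per row along the default axis (-1): an Array of shape &[5]
let result = mlx_rs::random::categorical(&logits, None, None, &key).unwrap();
assert_eq!(result.shape(), &[5]);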
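One standard way to sample from unnormalized logits is the Gumbel-max trick: add independent Gumbel(0, 1) noise to every logit and take the argmax along the distribution axis, which selects index `i` with probability `softmax(logits)[i]`. The sketch below applies it along the last axis of a row-major `[rows, cols]` buffer, mirroring the `[5, 20]` example. It is an illustration under stated assumptions: this page does not say which method mlx-rs uses internally, and `XorShift64`, `sample_row`, and `sample_last_axis` are hypothetical helpers, with `XorShift64` standing in for the PRNG behind `key`.

```rust
/// Hypothetical stand-in PRNG (xorshift64) so the sketch has no dependencies.
/// It is NOT the generator used by `mlx_rs::random::key`.
struct XorShift64(u64);

impl XorShift64 {
    fn new(seed: u64) -> Self {
        // xorshift must not start from 0, so substitute a fixed nonzero seed.
        Self(if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed })
    }

    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Uniform sample strictly inside (0, 1): 52 random bits plus 0.5,
    /// scaled by 2^52, so neither endpoint can occur.
    fn next_open01(&mut self) -> f64 {
        ((self.next_u64() >> 12) as f64 + 0.5) / (1u64 << 52) as f64
    }
}

/// Draws one category index from unnormalized `logits` via Gumbel-max:
/// argmax_i (logits[i] + g_i), with g_i = -ln(-ln(u_i)) and u_i uniform in (0, 1).
fn sample_row(logits: &[f64], rng: &mut XorShift64) -> u32 {
    let mut best = 0usize;
    let mut best_score = f64::NEG_INFINITY;
    for (i, &l) in logits.iter().enumerate() {
        let u = rng.next_open01();
        let gumbel = -(-u.ln()).ln();
        let score = l + gumbel;
        if score > best_score {
            best = i;
            best_score = score;
        }
    }
    best as u32
}

/// Samples along the last axis (axis = -1) of a row-major `[rows, cols]` buffer,
/// returning one u32 index per row, i.e. an output of shape `[rows]`.
fn sample_last_axis(logits: &[f64], rows: usize, cols: usize, rng: &mut XorShift64) -> Vec<u32> {
    assert!(cols > 0, "each distribution needs at least one category");
    assert_eq!(logits.len(), rows * cols, "buffer does not match [rows, cols]");
    logits.chunks(cols).map(|row| sample_row(row, rng)).collect()
}
```

With all-zero logits every index in `0..cols` is equally likely. A logit of `f64::NEG_INFINITY` is never selected, which is a common way to mask categories out.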