mlx_rs/
random.rs

1//! Collection of functions related to random number generation
2
3use crate::ops::indexing::TryIndexOp;
4use crate::utils::guard::Guarded;
5use crate::utils::IntoOption;
6use crate::{error::Result, Array, ArrayElement, Stream};
7use mach_sys::mach_time;
8use mlx_internal_macros::{default_device, generate_macro};
9use parking_lot::Mutex;
10use std::borrow::Cow;
11use std::cell::RefCell;
12use std::sync::OnceLock;
13
/// Process-wide random state, created lazily on first use.
///
/// Keys are drawn from it when a sampling function gets no explicit `key` and no scoped
/// state has been installed with `with_random_state` on the current thread.
static GLOBAL_STATE: OnceLock<Mutex<RandomState>> = OnceLock::new();

thread_local! {
    /// Scoped random state installed by `with_random_state`; `None` when no scope is active.
    ///
    /// Despite the name, this is per-thread storage (`thread_local!`), not storage tied to
    /// an async task.
    static TASK_LOCAL_STATE: RefCell<Option<RandomState>> = const { RefCell::new(None) };
}
19
20/// Random state factory
21#[derive(Debug, Clone)]
22pub struct RandomState {
23    state: Array,
24}
25
26impl RandomState {
27    fn new() -> Result<Self> {
28        let now = unsafe { mach_time::mach_approximate_time() };
29        Ok(Self { state: key(now)? })
30    }
31
32    fn next(&mut self) -> Result<Array> {
33        let next = split(&self.state, 2)?;
34        self.state = next.0;
35        Ok(next.1)
36    }
37
38    fn seed(&mut self, seed: u64) -> Result<()> {
39        self.state = key(seed)?;
40        Ok(())
41    }
42}
43
44fn global_state() -> &'static Mutex<RandomState> {
45    GLOBAL_STATE.get_or_init(|| Mutex::new(RandomState::new().unwrap()))
46}
47
48/// Returns a key from the task-local state if it exists, otherwise
49/// returns `None`
50fn resolve_task_local_key() -> Option<Result<Array>> {
51    TASK_LOCAL_STATE.with_borrow_mut(|state| state.as_mut().map(|s| s.next()))
52}
53
54fn resolve_global_key() -> Result<Array> {
55    let mut state = global_state().lock();
56    state.next()
57}
58
59/// Use given key or generate a new one if `None`.
60fn resolve<'a>(key: impl Into<Option<&'a Array>>) -> Result<Cow<'a, Array>> {
61    key.into().map_or_else(
62        || {
63            resolve_task_local_key()
64                .unwrap_or_else(resolve_global_key)
65                .map(Cow::Owned)
66        },
67        |k| Ok(Cow::Borrowed(k)),
68    )
69}
70
71/// Use the given random state for the scope of `f`
72pub fn with_random_state<F, T>(state: RandomState, f: F) -> T
73where
74    F: FnOnce() -> T,
75{
76    let prev_state = TASK_LOCAL_STATE.with_borrow_mut(|s| s.replace(state));
77
78    let result = f();
79
80    TASK_LOCAL_STATE.with_borrow_mut(|s| {
81        *s = prev_state;
82    });
83
84    result
85}
86
87/// Seed the random number generator.
88pub fn seed(seed: u64) -> Result<()> {
89    let mut state = global_state().lock();
90    state.seed(seed)
91}
92
93/// Get a PRNG key from a seed.
94///
95/// Return a value that can be used as a PRNG key.  All ``random::*``
96/// functions take an optional key -- this will let you control the
97/// random number generation.
98pub fn key(seed: u64) -> Result<Array> {
99    Array::try_from_op(|res| unsafe { mlx_sys::mlx_random_key(res, seed) })
100}
101
102/// Split a PRNG key into two keys and return a tuple.
103#[default_device]
104pub fn split_device(
105    key: impl AsRef<Array>,
106    num: i32,
107    stream: impl AsRef<Stream>,
108) -> Result<(Array, Array)> {
109    let keys = Array::try_from_op(|res| unsafe {
110        mlx_sys::mlx_random_split_num(res, key.as_ref().as_ptr(), num, stream.as_ref().as_ptr())
111    })?;
112
113    Ok((keys.try_index(0)?, keys.try_index(1)?))
114}
115
116/// Generate uniformly distributed random numbers.
117/// The values are sampled uniformly in the half-open interval `[lower, upper)`.
118/// The lower and upper bound can be scalars or arrays and must be broadcastable to `shape`.
119///
120/// # Params
121///
122/// - `lower`: Lower bound of the distribution.
123/// - `upper`: Upper bound of the distribution.
124/// - `shape` (optional): Shape of the output. Default is `&[]`.
125/// - `key` (optional): A PRNG key.
126///
127/// ```rust
128/// let key = mlx_rs::random::key(0).unwrap();
129///
130/// // create an array of shape `[50]` type f32 values in the range [0, 10)
131/// let array = mlx_rs::random::uniform::<_, f32>(0, 10, &[50], &key);
132///
133/// // same, but in range [0.5, 1)
134/// let array = mlx_rs::random::uniform::<_, f32>(0.5f32, 1f32, &[50], &key);
135/// ```
136#[generate_macro(customize(root = "$crate::random"))]
137#[default_device]
138pub fn uniform_device<'a, E: Into<Array>, T: ArrayElement>(
139    lower: E,
140    upper: E,
141    #[optional] shape: impl IntoOption<&'a [i32]>,
142    #[optional] key: impl Into<Option<&'a Array>>,
143    #[optional] stream: impl AsRef<Stream>,
144) -> Result<Array> {
145    let lb: Array = lower.into();
146    let ub: Array = upper.into();
147    let shape = shape.into_option().unwrap_or(&[]);
148    let key = resolve(key)?;
149
150    Array::try_from_op(|res| unsafe {
151        mlx_sys::mlx_random_uniform(
152            res,
153            lb.as_ptr(),
154            ub.as_ptr(),
155            shape.as_ptr(),
156            shape.len(),
157            T::DTYPE.into(),
158            key.as_ptr(),
159            stream.as_ref().as_ptr(),
160        )
161    })
162}
163
164/// Generate normally distributed random numbers.
165///
166/// Generate an array of random numbers using the optional shape. The result
167/// will be of the given `T`. `T` must be a floating point type.
168///
169/// # Params
170///
171///  - shape: shape of the output, if `None` a single value is returned
172///  - loc: mean of the distribution, default is `0.0`
173///  - scale: standard deviation of the distribution, default is `1.0`
174///  - key: PRNG key
175///
176/// # Example
177///
178/// ```rust
179/// let key = mlx_rs::random::key(0).unwrap();
180///
181/// // generate a single f32 with normal distribution
182/// let value = mlx_rs::random::normal::<f32>(None, None, None, &key).unwrap().item::<f32>();
183///
184/// // generate an array of f32 with normal distribution in shape [10, 5]
185/// let array = mlx_rs::random::normal::<f32>(&[10, 5], None, None, &key);
186/// ```
187#[generate_macro(customize(root = "$crate::random"))]
188#[default_device]
189pub fn normal_device<'a, T: ArrayElement>(
190    #[optional] shape: impl IntoOption<&'a [i32]>,
191    #[optional] loc: impl Into<Option<f32>>,
192    #[optional] scale: impl Into<Option<f32>>,
193    #[optional] key: impl Into<Option<&'a Array>>,
194    #[optional] stream: impl AsRef<Stream>,
195) -> Result<Array> {
196    let shape = shape.into_option().unwrap_or(&[]);
197    let key = resolve(key)?;
198
199    Array::try_from_op(|res| unsafe {
200        mlx_sys::mlx_random_normal(
201            res,
202            shape.as_ptr(),
203            shape.len(),
204            T::DTYPE.into(),
205            loc.into().unwrap_or(0.0),
206            scale.into().unwrap_or(1.0),
207            key.as_ptr(),
208            stream.as_ref().as_ptr(),
209        )
210    })
211}
212
213/// Generate jointly-normal random samples given a mean and covariance.
214///
215/// The matrix `covariance` must be positive semi-definite. The behavior is
216/// undefined if it is not.  The only supported output type is f32.
217///
218/// # Params
219/// - `mean`: array of shape `[..., n]`, the mean of the distribution.
220/// - `covariance`: array  of shape `[..., n, n]`, the covariance matrix of the distribution. The batch shape `...` must be broadcast-compatible with that of `mean`.
221/// - `shape`: The output shape must be broadcast-compatible with `&mean.shape[..mean.shape.len()-1]` and `&covariance.shape[..covariance.shape.len()-2]`. If empty, the result shape is determined by broadcasting the batch shapes of `mean` and `covariance`.
222/// - `key`: PRNG key.
223#[generate_macro(customize(root = "$crate::random"))]
224#[default_device(device = "cpu")] // TODO: not supported on GPU yet
225pub fn multivariate_normal_device<'a, T: ArrayElement>(
226    mean: impl AsRef<Array>,
227    covariance: impl AsRef<Array>,
228    #[optional] shape: impl IntoOption<&'a [i32]>,
229    #[optional] key: impl Into<Option<&'a Array>>,
230    #[optional] stream: impl AsRef<Stream>,
231) -> Result<Array> {
232    let shape = shape.into_option().unwrap_or(&[]);
233    let key = resolve(key)?;
234
235    Array::try_from_op(|res| unsafe {
236        mlx_sys::mlx_random_multivariate_normal(
237            res,
238            mean.as_ref().as_ptr(),
239            covariance.as_ref().as_ptr(),
240            shape.as_ptr(),
241            shape.len(),
242            T::DTYPE.into(),
243            key.as_ptr(),
244            stream.as_ref().as_ptr(),
245        )
246    })
247}
248
249/// Generate random integers from the given interval (`lower:` and `upper:`).
250///
251/// The values are sampled with equal probability from the integers in
252/// half-open interval `[lb, ub)`. The lower and upper bound can be
253/// scalars or arrays and must be roadcastable to `shape`.
254///
255/// ```rust
256/// use mlx_rs::{array, random};
257///
258/// let key = random::key(0).unwrap();
259///
260/// // generate an array of Int values, one in the range [0, 20) and one in the range [10, 100)
261/// let array = random::randint::<_, i32>(array!([0, 20]), array!([10, 100]), None, &key);
262/// ```
263#[generate_macro(customize(root = "$crate::random"))]
264#[default_device]
265pub fn randint_device<'a, E: Into<Array>, T: ArrayElement>(
266    lower: E,
267    upper: E,
268    #[optional] shape: impl IntoOption<&'a [i32]>,
269    #[optional] key: impl Into<Option<&'a Array>>,
270    #[optional] stream: impl AsRef<Stream>,
271) -> Result<Array> {
272    let lb: Array = lower.into();
273    let ub: Array = upper.into();
274    let shape = shape.into_option().unwrap_or(lb.shape());
275    let key = resolve(key)?;
276
277    Array::try_from_op(|res| unsafe {
278        mlx_sys::mlx_random_randint(
279            res,
280            lb.as_ptr(),
281            ub.as_ptr(),
282            shape.as_ptr(),
283            shape.len(),
284            T::DTYPE.into(),
285            key.as_ptr(),
286            stream.as_ref().as_ptr(),
287        )
288    })
289}
290
291/// Generate Bernoulli random values with a given `p` value.
292///
293/// The values are sampled from the bernoulli distribution with parameter
294/// `p`. The parameter `p` must have a floating point type and
295/// must be broadcastable to `shape`.
296///
297/// ```rust
298/// use mlx_rs::{array, Array, random};
299///
300/// let key = random::key(0).unwrap();
301///
302/// // generate a single random Bool with p = 0.8
303/// let p: Array = 0.8.into();
304/// let value = random::bernoulli(&p, None, &key);
305///
306/// // generate an array of shape [50, 2] of random Bool with p = 0.8
307/// let array = random::bernoulli(&p, &[50, 2], &key);
308///
309/// // generate an array of [3] Bool with the given p values
310/// let array = random::bernoulli(&array!([0.1, 0.5, 0.8]), None, &key);
311/// ```
312#[generate_macro(customize(root = "$crate::random"))]
313#[default_device]
314pub fn bernoulli_device<'a>(
315    #[optional] p: impl Into<Option<&'a Array>>,
316    #[optional] shape: impl IntoOption<&'a [i32]>,
317    #[optional] key: impl Into<Option<&'a Array>>,
318    #[optional] stream: impl AsRef<Stream>,
319) -> Result<Array> {
320    let default_array = Array::from_f32(0.5);
321    let p = p.into().unwrap_or(&default_array);
322
323    let shape = shape.into_option().unwrap_or(p.shape());
324    let key = resolve(key)?;
325
326    Array::try_from_op(|res| unsafe {
327        mlx_sys::mlx_random_bernoulli(
328            res,
329            p.as_ptr(),
330            shape.as_ptr(),
331            shape.len(),
332            key.as_ptr(),
333            stream.as_ref().as_ptr(),
334        )
335    })
336}
337
338/// Generate values from a truncated normal distribution between `low` and `high`.
339///
340/// The values are sampled from the truncated normal distribution
341/// on the domain `(lower, upper)`. The bounds `lower` and `upper`
342/// can be scalars or arrays and must be broadcastable to `shape`.
343///
344/// ```rust
345/// use mlx_rs::{array, random};
346///
347/// let key = random::key(0).unwrap();
348///
349/// // generate an array of two Float values, one in the range 0 ..< 10
350/// // and one in the range 10 ..< 100
351/// let value = random::truncated_normal::<_, f32>(array!([0, 10]), array!([10, 100]), None, &key);
352/// ```
353#[generate_macro(customize(root = "$crate::random"))]
354#[default_device]
355pub fn truncated_normal_device<'a, E: Into<Array>, T: ArrayElement>(
356    lower: E,
357    upper: E,
358    #[optional] shape: impl IntoOption<&'a [i32]>,
359    #[optional] key: impl Into<Option<&'a Array>>,
360    #[optional] stream: impl AsRef<Stream>,
361) -> Result<Array> {
362    let lb: Array = lower.into();
363    let ub: Array = upper.into();
364    let shape = shape.into_option().unwrap_or(lb.shape());
365    let key = resolve(key)?;
366
367    Array::try_from_op(|res| unsafe {
368        mlx_sys::mlx_random_truncated_normal(
369            res,
370            lb.as_ptr(),
371            ub.as_ptr(),
372            shape.as_ptr(),
373            shape.len(),
374            T::DTYPE.into(),
375            key.as_ptr(),
376            stream.as_ref().as_ptr(),
377        )
378    })
379}
380
381/// Sample from the standard Gumbel distribution.
382///
383/// The values are sampled from a standard Gumbel distribution
384/// which CDF `exp(-exp(-x))`.
385///
386/// ```rust
387/// let key = mlx_rs::random::key(0).unwrap();
388///
389/// // generate a single Float with Gumbel distribution
390/// let value = mlx_rs::random::gumbel::<f32>(None, &key).unwrap().item::<f32>();
391///
392/// // generate an array of Float with Gumbel distribution in shape [10, 5]
393/// let array = mlx_rs::random::gumbel::<f32>(&[10, 5], &key);
394/// ```
395#[generate_macro(customize(root = "$crate::random"))]
396#[default_device]
397pub fn gumbel_device<'a, T: ArrayElement>(
398    #[optional] shape: impl IntoOption<&'a [i32]>,
399    #[optional] key: impl Into<Option<&'a Array>>,
400    #[optional] stream: impl AsRef<Stream>,
401) -> Result<Array> {
402    let shape = shape.into_option().unwrap_or(&[]);
403    let key = resolve(key)?;
404
405    Array::try_from_op(|res| unsafe {
406        mlx_sys::mlx_random_gumbel(
407            res,
408            shape.as_ptr(),
409            shape.len(),
410            T::DTYPE.into(),
411            key.as_ptr(),
412            stream.as_ref().as_ptr(),
413        )
414    })
415}
416
/// Output specification for `categorical`: either an explicit shape or a sample count.
#[derive(Clone, Copy, Debug)]
pub enum ShapeOrCount<'a> {
    /// Explicit output shape; must broadcast with the logits' shape minus the sampled axis.
    Shape(&'a [i32]),

    /// Number of samples drawn from each distribution; they occupy the last output axis.
    Count(i32),
}
426
427/// Sample from a categorical distribution.
428///
429/// The values are sampled from the categorical distribution specified by
430/// the unnormalized values in `logits`.   If the `shape` is not specified
431/// the result shape will be the same shape as `logits` with the `axis`
432/// dimension removed.
433///
434/// /// # Params
435/// # Params
436///
437/// - `logits`: The *unnormalized* categorical distribution(s).
438/// - `axis`(optional): The axis which specifies the distribution. Default is `-1`.
439/// - `shape_or_count`(optional):
440/// - - `Shape`: The shape of the output. This must be broadcast compatible with `logits.shape` with the `axis` dimension removed.
441/// - - `Count`: The number of samples to draw from each of the categorical distributions in `logits`. The output will have the number of samples in the last dimension.
442/// - `key` (optional): A PRNG key.
443///
444/// # Example
445///
446/// ```rust
447/// let key = mlx_rs::random::key(0).unwrap();
448///
449/// let logits = mlx_rs::Array::zeros::<u32>(&[5, 20]).unwrap();
450///
451/// // produces Array of u32 shape &[5]
452/// let result = mlx_rs::random::categorical(&logits, None, None, &key);
453/// ```
454#[generate_macro(customize(root = "$crate::random"))]
455#[default_device]
456pub fn categorical_device<'a>(
457    logits: impl AsRef<Array>,
458    #[optional] axis: impl Into<Option<i32>>,
459    #[optional] shape_or_count: impl Into<Option<ShapeOrCount<'a>>>,
460    #[optional] key: impl Into<Option<&'a Array>>,
461    #[optional] stream: impl AsRef<Stream>,
462) -> Result<Array> {
463    let axis = axis.into().unwrap_or(-1);
464    let key = resolve(key)?;
465
466    match shape_or_count.into() {
467        Some(ShapeOrCount::Shape(shape)) => Array::try_from_op(|res| unsafe {
468            mlx_sys::mlx_random_categorical_shape(
469                res,
470                logits.as_ref().as_ptr(),
471                axis,
472                shape.as_ptr(),
473                shape.len(),
474                key.as_ptr(),
475                stream.as_ref().as_ptr(),
476            )
477        }),
478        Some(ShapeOrCount::Count(num_samples)) => Array::try_from_op(|res| unsafe {
479            mlx_sys::mlx_random_categorical_num_samples(
480                res,
481                logits.as_ref().as_ptr(),
482                axis,
483                num_samples,
484                key.as_ptr(),
485                stream.as_ref().as_ptr(),
486            )
487        }),
488        None => Array::try_from_op(|res| unsafe {
489            mlx_sys::mlx_random_categorical(
490                res,
491                logits.as_ref().as_ptr(),
492                axis,
493                key.as_ptr(),
494                stream.as_ref().as_ptr(),
495            )
496        }),
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{array, assert_array_eq};
    use float_eq::assert_float_eq;

    /// Serializes the tests that draw keys from the process-wide generator.
    ///
    /// `cargo test` runs tests in parallel on separate threads. The scoped state used by
    /// `with_random_state` is per-thread, but `GLOBAL_STATE` is shared. Without this lock,
    /// another test drawing from it between `seed(3)` and the following `uniform` calls in
    /// `test_global_rng` would make that test flaky.
    static GLOBAL_RNG_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `GLOBAL_RNG_LOCK`, ignoring poisoning left by a previously failed test.
    fn lock_global_rng() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_RNG_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_global_rng() {
        let _guard = lock_global_rng();

        seed(3).unwrap();
        let a = uniform::<_, f32>(0, 1, None, None).unwrap();
        let b = uniform::<_, f32>(0, 1, None, None).unwrap();

        seed(3).unwrap();
        let x = uniform::<_, f32>(0, 1, None, None).unwrap();
        let y = uniform::<_, f32>(0, 1, None, None).unwrap();

        assert_array_eq!(a, x, 0.01);
        assert_array_eq!(b, y, 0.01);
    }

    #[test]
    fn test_key() {
        let k1 = key(0).unwrap();
        let k2 = key(0).unwrap();
        assert!(k1 == k2);

        let k2 = key(1).unwrap();
        assert!(k1 != k2);
    }

    #[test]
    fn test_split() {
        let key = key(0).unwrap();

        let (k1, k2) = split(&key, 2).unwrap();
        assert!(k1 != k2);

        let (r1, r2) = split(&key, 2).unwrap();
        assert!(r1 == k1);
        assert!(r2 == k2);
    }

    #[test]
    fn test_uniform_no_seed() {
        // Draws from the global generator, so it must not interleave with `test_global_rng`.
        let _guard = lock_global_rng();

        let value = uniform::<_, f32>(0, 10, &[3], None).unwrap();
        assert_eq!(value.shape(), &[3]);
    }

    #[test]
    fn test_uniform_single() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(0, 10, None, Some(&key)).unwrap();
        // `float_eq!` only returns a bool; the assert form actually checks the value.
        assert_float_eq!(value.item::<f32>(), 4.18, abs <= 0.01);
    }

    #[test]
    fn test_uniform_multiple() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(0, 10, &[3], Some(&key)).unwrap();
        let expected = Array::from_slice(&[9.65, 3.14, 6.33], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_uniform_multiple_array() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(&[0, 10], &[10, 100], &[2], Some(&key)).unwrap();
        let expected = Array::from_slice(&[2.16, 82.37], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_uniform_non_float() {
        let key = key(0).unwrap();
        let value = uniform::<_, i32>(&[0, 10], &[10, 100], &[2], Some(&key));
        assert!(value.is_err());
    }

    #[test]
    fn test_normal() {
        let key = key(0).unwrap();
        let value = normal::<f32>(None, None, None, &key).unwrap();
        // `float_eq!` only returns a bool; the assert form actually checks the value.
        assert_float_eq!(value.item::<f32>(), -0.20, abs <= 0.01);
    }

    #[test]
    fn test_normal_non_float() {
        let key = key(0).unwrap();
        let value = normal::<i32>(None, None, None, &key);
        assert!(value.is_err());
    }

    #[test]
    fn test_multivariate_normal() {
        let key = key(0).unwrap();
        let mean = Array::from_slice(&[0.0, 0.0], &[2]);
        let covariance = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let a = multivariate_normal::<f32>(&mean, &covariance, &[3], &key).unwrap();
        assert!(a.shape() == [3, 2]);
    }

    #[test]
    fn test_randint_single() {
        let key = key(0).unwrap();
        let value = randint::<_, i32>(0, 100, None, Some(&key)).unwrap();
        assert_eq!(value.item::<i32>(), 41);
    }

    #[test]
    fn test_randint_multiple() {
        let key = key(0).unwrap();
        let value =
            randint::<_, i32>(array!([0, 10]), array!([10, 100]), None, Some(&key)).unwrap();
        let expected = Array::from_slice(&[2, 82], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_randint_non_int() {
        let key = key(0).unwrap();
        let value = randint::<_, f32>(array!([0, 10]), array!([10, 100]), None, Some(&key));
        assert!(value.is_err());
    }

    #[test]
    fn test_bernoulli_single() {
        let key = key(0).unwrap();
        let value = bernoulli(None, None, &key).unwrap();
        assert!(value.item::<bool>());
    }

    #[test]
    fn test_bernoulli_multiple() {
        let key = key(0).unwrap();
        let value = bernoulli(None, &[4], &key).unwrap();
        let expected = Array::from_slice(&[false, true, false, true], &[4]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_bernoulli_p() {
        let key = key(0).unwrap();
        let p: Array = 0.8.into();
        let value = bernoulli(&p, &[4], &key).unwrap();
        let expected = Array::from_slice(&[false, true, true, true], &[4]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_bernoulli_p_array() {
        let key = key(0).unwrap();
        let value = bernoulli(&array!([0.1, 0.5, 0.8]), None, &key).unwrap();
        let expected = Array::from_slice(&[false, true, true], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_truncated_normal_single() {
        let key = key(0).unwrap();
        let value = truncated_normal::<_, f32>(0, 10, None, &key).unwrap();
        assert_array_eq!(value, Array::from_f32(0.55), 0.01);
    }

    #[test]
    fn test_truncated_normal_multiple() {
        let key = key(0).unwrap();
        let value = truncated_normal::<_, f32>(0.0, 0.5, &[3], &key).unwrap();
        let expected = Array::from_slice(&[0.48, 0.15, 0.30], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_truncated_normal_multiple_array() {
        let key = key(0).unwrap();
        let value =
            truncated_normal::<_, f32>(array!([0.0, 0.5]), array!([0.5, 1.0]), None, &key).unwrap();
        let expected = Array::from_slice(&[0.10, 0.88], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_gumbel() {
        let key = key(0).unwrap();
        let value = gumbel::<f32>(None, &key).unwrap();
        assert_array_eq!(value, Array::from_f32(0.13), 0.01);
    }

    #[test]
    fn test_logits() {
        let key = key(0).unwrap();
        let logits = Array::zeros::<u32>(&[5, 20]).unwrap();
        let result = categorical(&logits, None, None, &key).unwrap();

        assert_eq!(result.shape(), [5]);

        let expected = Array::from_slice(&[1, 1, 17, 17, 17], &[5]);
        assert_array_eq!(result, expected, 0.01);
    }

    #[test]
    fn test_logits_count() {
        let key = key(0).unwrap();
        let logits = Array::zeros::<u32>(&[5, 20]).unwrap();
        let result = categorical(&logits, None, ShapeOrCount::Count(2), &key).unwrap();

        assert_eq!(result.shape(), [5, 2]);

        let expected = Array::from_slice(&[16, 3, 14, 10, 17, 7, 6, 8, 12, 8], &[5, 2]);
        assert_array_eq!(result, expected, 0.01);
    }

    #[test]
    fn test_random_seed_same() {
        // Same random seed should produce the same results
        let seed = 23;
        let mut results = Vec::new();
        let f = || {
            uniform::<_, f32>(0.0, 1.0, &[10, 10], None)?
                .sum(None)?
                .try_item::<f32>()
        };
        for _ in 0..10 {
            let mut state = RandomState::new().unwrap();
            state.seed(seed).unwrap();
            let result = with_random_state(state, f).unwrap();
            results.push(result);
        }

        // Check that all results are the same within a small tolerance
        let first = results[0];
        for result in &results[1..] {
            assert_float_eq!(
                first,
                *result,
                abs <= 0.01,
                "Results should be equal for the same seed"
            );
        }
    }
}
748}