1use crate::ops::indexing::TryIndexOp;
4use crate::utils::guard::Guarded;
5use crate::utils::IntoOption;
6use crate::{error::Result, Array, ArrayElement, Stream, StreamOrDevice};
7use mach_sys::mach_time;
8use mlx_internal_macros::{default_device, generate_macro};
9use parking_lot::Mutex;
10use std::borrow::Cow;
11use std::sync::OnceLock;
12
13struct RandomState {
14 state: Array,
15}
16
17impl RandomState {
18 fn new() -> Result<Self> {
19 let now = unsafe { mach_time::mach_approximate_time() };
20 Ok(Self { state: key(now)? })
21 }
22
23 fn next(&mut self) -> Result<Array> {
24 let next = split(&self.state, 2)?;
25 self.state = next.0;
26 Ok(next.1)
27 }
28
29 fn seed(&mut self, seed: u64) -> Result<()> {
30 self.state = key(seed)?;
31 Ok(())
32 }
33}
34
35fn state() -> &'static Mutex<RandomState> {
36 static STATE: OnceLock<Mutex<RandomState>> = OnceLock::new();
37 STATE.get_or_init(|| Mutex::new(RandomState::new().unwrap()))
38}
39
40fn key_or_next<'a>(key: impl Into<Option<&'a Array>>) -> Result<Cow<'a, Array>> {
42 key.into().map_or_else(
43 || {
44 let mut state = state().lock();
45 state.next().map(Cow::Owned)
46 },
47 |k| Ok(Cow::Borrowed(k)),
48 )
49}
50
51pub fn seed(seed: u64) -> Result<()> {
53 let mut state = state().lock();
54 state.seed(seed)
55}
56
57pub fn key(seed: u64) -> Result<Array> {
63 Array::try_from_op(|res| unsafe { mlx_sys::mlx_random_key(res, seed) })
64}
65
66#[default_device]
68pub fn split_device(
69 key: impl AsRef<Array>,
70 num: i32,
71 stream: impl AsRef<Stream>,
72) -> Result<(Array, Array)> {
73 let keys = Array::try_from_op(|res| unsafe {
74 mlx_sys::mlx_random_split_num(res, key.as_ref().as_ptr(), num, stream.as_ref().as_ptr())
75 })?;
76
77 Ok((keys.try_index(0)?, keys.try_index(1)?))
78}
79
80#[generate_macro(customize(root = "$crate::random"))]
101#[default_device]
102pub fn uniform_device<'a, E: Into<Array>, T: ArrayElement>(
103 lower: E,
104 upper: E,
105 #[optional] shape: impl IntoOption<&'a [i32]>,
106 #[optional] key: impl Into<Option<&'a Array>>,
107 #[optional] stream: impl AsRef<Stream>,
108) -> Result<Array> {
109 let lb: Array = lower.into();
110 let ub: Array = upper.into();
111 let shape = shape.into_option().unwrap_or(&[]);
112 let key = key_or_next(key)?;
113
114 Array::try_from_op(|res| unsafe {
115 mlx_sys::mlx_random_uniform(
116 res,
117 lb.as_ptr(),
118 ub.as_ptr(),
119 shape.as_ptr(),
120 shape.len(),
121 T::DTYPE.into(),
122 key.as_ptr(),
123 stream.as_ref().as_ptr(),
124 )
125 })
126}
127
128#[generate_macro(customize(root = "$crate::random"))]
152#[default_device]
153pub fn normal_device<'a, T: ArrayElement>(
154 #[optional] shape: impl IntoOption<&'a [i32]>,
155 #[optional] loc: impl Into<Option<f32>>,
156 #[optional] scale: impl Into<Option<f32>>,
157 #[optional] key: impl Into<Option<&'a Array>>,
158 #[optional] stream: impl AsRef<Stream>,
159) -> Result<Array> {
160 let shape = shape.into_option().unwrap_or(&[]);
161 let key = key_or_next(key)?;
162
163 Array::try_from_op(|res| unsafe {
164 mlx_sys::mlx_random_normal(
165 res,
166 shape.as_ptr(),
167 shape.len(),
168 T::DTYPE.into(),
169 loc.into().unwrap_or(0.0),
170 scale.into().unwrap_or(1.0),
171 key.as_ptr(),
172 stream.as_ref().as_ptr(),
173 )
174 })
175}
176
177#[generate_macro(customize(root = "$crate::random"))]
188#[default_device(device = "cpu")] pub fn multivariate_normal_device<'a, T: ArrayElement>(
190 mean: impl AsRef<Array>,
191 covariance: impl AsRef<Array>,
192 #[optional] shape: impl IntoOption<&'a [i32]>,
193 #[optional] key: impl Into<Option<&'a Array>>,
194 #[optional] stream: impl AsRef<Stream>,
195) -> Result<Array> {
196 let shape = shape.into_option().unwrap_or(&[]);
197 let key = key_or_next(key)?;
198
199 Array::try_from_op(|res| unsafe {
200 mlx_sys::mlx_random_multivariate_normal(
201 res,
202 mean.as_ref().as_ptr(),
203 covariance.as_ref().as_ptr(),
204 shape.as_ptr(),
205 shape.len(),
206 T::DTYPE.into(),
207 key.as_ptr(),
208 stream.as_ref().as_ptr(),
209 )
210 })
211}
212
213#[generate_macro(customize(root = "$crate::random"))]
228#[default_device]
229pub fn randint_device<'a, E: Into<Array>, T: ArrayElement>(
230 lower: E,
231 upper: E,
232 #[optional] shape: impl IntoOption<&'a [i32]>,
233 #[optional] key: impl Into<Option<&'a Array>>,
234 #[optional] stream: impl AsRef<Stream>,
235) -> Result<Array> {
236 let lb: Array = lower.into();
237 let ub: Array = upper.into();
238 let shape = shape.into_option().unwrap_or(lb.shape());
239 let key = key_or_next(key)?;
240
241 Array::try_from_op(|res| unsafe {
242 mlx_sys::mlx_random_randint(
243 res,
244 lb.as_ptr(),
245 ub.as_ptr(),
246 shape.as_ptr(),
247 shape.len(),
248 T::DTYPE.into(),
249 key.as_ptr(),
250 stream.as_ref().as_ptr(),
251 )
252 })
253}
254
255#[generate_macro(customize(root = "$crate::random"))]
277#[default_device]
278pub fn bernoulli_device<'a>(
279 #[optional] p: impl Into<Option<&'a Array>>,
280 #[optional] shape: impl IntoOption<&'a [i32]>,
281 #[optional] key: impl Into<Option<&'a Array>>,
282 #[optional] stream: impl AsRef<Stream>,
283) -> Result<Array> {
284 let default_array = Array::from_f32(0.5);
285 let p = p.into().unwrap_or(&default_array);
286
287 let shape = shape.into_option().unwrap_or(p.shape());
288 let key = key_or_next(key)?;
289
290 Array::try_from_op(|res| unsafe {
291 mlx_sys::mlx_random_bernoulli(
292 res,
293 p.as_ptr(),
294 shape.as_ptr(),
295 shape.len(),
296 key.as_ptr(),
297 stream.as_ref().as_ptr(),
298 )
299 })
300}
301
302#[generate_macro(customize(root = "$crate::random"))]
318#[default_device]
319pub fn truncated_normal_device<'a, E: Into<Array>, T: ArrayElement>(
320 lower: E,
321 upper: E,
322 #[optional] shape: impl IntoOption<&'a [i32]>,
323 #[optional] key: impl Into<Option<&'a Array>>,
324 #[optional] stream: impl AsRef<Stream>,
325) -> Result<Array> {
326 let lb: Array = lower.into();
327 let ub: Array = upper.into();
328 let shape = shape.into_option().unwrap_or(lb.shape());
329 let key = key_or_next(key)?;
330
331 Array::try_from_op(|res| unsafe {
332 mlx_sys::mlx_random_truncated_normal(
333 res,
334 lb.as_ptr(),
335 ub.as_ptr(),
336 shape.as_ptr(),
337 shape.len(),
338 T::DTYPE.into(),
339 key.as_ptr(),
340 stream.as_ref().as_ptr(),
341 )
342 })
343}
344
345#[generate_macro(customize(root = "$crate::random"))]
360#[default_device]
361pub fn gumbel_device<'a, T: ArrayElement>(
362 #[optional] shape: impl IntoOption<&'a [i32]>,
363 #[optional] key: impl Into<Option<&'a Array>>,
364 #[optional] stream: impl AsRef<Stream>,
365) -> Result<Array> {
366 let shape = shape.into_option().unwrap_or(&[]);
367 let key = key_or_next(key)?;
368
369 Array::try_from_op(|res| unsafe {
370 mlx_sys::mlx_random_gumbel(
371 res,
372 shape.as_ptr(),
373 shape.len(),
374 T::DTYPE.into(),
375 key.as_ptr(),
376 stream.as_ref().as_ptr(),
377 )
378 })
379}
380
/// Chooses how `categorical_device` sizes its output.
#[derive(Copy, Clone, Debug)]
pub enum ShapeOrCount<'a> {
    /// Use this explicit output shape (forwarded to `mlx_random_categorical_shape`).
    Shape(&'a [i32]),

    /// Draw this many samples (forwarded to `mlx_random_categorical_num_samples`).
    Count(i32),
}
390
391#[generate_macro(customize(root = "$crate::random"))]
419#[default_device]
420pub fn categorical_device<'a>(
421 logits: impl AsRef<Array>,
422 #[optional] axis: impl Into<Option<i32>>,
423 #[optional] shape_or_count: impl Into<Option<ShapeOrCount<'a>>>,
424 #[optional] key: impl Into<Option<&'a Array>>,
425 #[optional] stream: impl AsRef<Stream>,
426) -> Result<Array> {
427 let axis = axis.into().unwrap_or(-1);
428 let key = key_or_next(key)?;
429
430 match shape_or_count.into() {
431 Some(ShapeOrCount::Shape(shape)) => Array::try_from_op(|res| unsafe {
432 mlx_sys::mlx_random_categorical_shape(
433 res,
434 logits.as_ref().as_ptr(),
435 axis,
436 shape.as_ptr(),
437 shape.len(),
438 key.as_ptr(),
439 stream.as_ref().as_ptr(),
440 )
441 }),
442 Some(ShapeOrCount::Count(num_samples)) => Array::try_from_op(|res| unsafe {
443 mlx_sys::mlx_random_categorical_num_samples(
444 res,
445 logits.as_ref().as_ptr(),
446 axis,
447 num_samples,
448 key.as_ptr(),
449 stream.as_ref().as_ptr(),
450 )
451 }),
452 None => Array::try_from_op(|res| unsafe {
453 mlx_sys::mlx_random_categorical(
454 res,
455 logits.as_ref().as_ptr(),
456 axis,
457 key.as_ptr(),
458 stream.as_ref().as_ptr(),
459 )
460 }),
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{array, assert_array_eq};
    use float_eq::float_eq;

    /// Serializes tests that read or reseed the process-wide RNG state.
    ///
    /// `cargo test` runs tests in parallel. Without this lock, a key-less sampling call
    /// in one test can advance the global key between `seed` and the draws of another.
    static GLOBAL_RNG_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `GLOBAL_RNG_LOCK`, ignoring poisoning so one failing test does not
    /// cascade into the others.
    fn lock_global_rng() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_RNG_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_global_rng() {
        let _guard = lock_global_rng();

        seed(3).unwrap();
        let a = uniform::<_, f32>(0, 1, None, None).unwrap();
        let b = uniform::<_, f32>(0, 1, None, None).unwrap();

        seed(3).unwrap();
        let x = uniform::<_, f32>(0, 1, None, None).unwrap();
        let y = uniform::<_, f32>(0, 1, None, None).unwrap();

        assert_array_eq!(a, x, 0.01);
        assert_array_eq!(b, y, 0.01);
    }

    #[test]
    fn test_key() {
        let k1 = key(0).unwrap();
        let k2 = key(0).unwrap();
        assert!(k1 == k2);

        let k2 = key(1).unwrap();
        assert!(k1 != k2);
    }

    #[test]
    fn test_split() {
        let key = key(0).unwrap();

        let (k1, k2) = split(&key, 2).unwrap();
        assert!(k1 != k2);

        let (r1, r2) = split(&key, 2).unwrap();
        assert!(r1 == k1);
        assert!(r2 == k2);
    }

    #[test]
    fn test_uniform_no_seed() {
        // Draws from (and advances) the global state, so it must not interleave
        // with `test_global_rng`.
        let _guard = lock_global_rng();

        let value = uniform::<_, f32>(0, 10, &[3], None).unwrap();
        assert_eq!(value.shape(), &[3]);
    }

    #[test]
    fn test_uniform_single() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(0, 10, None, Some(&key)).unwrap();
        // `float_eq!` only evaluates to a bool; it must be asserted to check anything.
        assert!(float_eq!(value.item::<f32>(), 4.18, abs <= 0.01));
    }

    #[test]
    fn test_uniform_multiple() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(0, 10, &[3], Some(&key)).unwrap();
        let expected = Array::from_slice(&[9.65, 3.14, 6.33], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_uniform_multiple_array() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(&[0, 10], &[10, 100], &[2], Some(&key)).unwrap();
        let expected = Array::from_slice(&[2.16, 82.37], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_uniform_non_float() {
        let key = key(0).unwrap();
        let value = uniform::<_, i32>(&[0, 10], &[10, 100], &[2], Some(&key));
        assert!(value.is_err());
    }

    #[test]
    fn test_normal() {
        let key = key(0).unwrap();
        let value = normal::<f32>(None, None, None, &key).unwrap();
        // `float_eq!` only evaluates to a bool; it must be asserted to check anything.
        assert!(float_eq!(value.item::<f32>(), -0.20, abs <= 0.01));
    }

    #[test]
    fn test_normal_non_float() {
        let key = key(0).unwrap();
        let value = normal::<i32>(None, None, None, &key);
        assert!(value.is_err());
    }

    #[test]
    fn test_multivariate_normal() {
        let key = key(0).unwrap();
        let mean = Array::from_slice(&[0.0, 0.0], &[2]);
        let covariance = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let a = multivariate_normal::<f32>(&mean, &covariance, &[3], &key).unwrap();
        assert!(a.shape() == [3, 2]);
    }

    #[test]
    fn test_randint_single() {
        let key = key(0).unwrap();
        let value = randint::<_, i32>(0, 100, None, Some(&key)).unwrap();
        assert_eq!(value.item::<i32>(), 41);
    }

    #[test]
    fn test_randint_multiple() {
        let key = key(0).unwrap();
        let value =
            randint::<_, i32>(array!([0, 10]), array!([10, 100]), None, Some(&key)).unwrap();
        let expected = Array::from_slice(&[2, 82], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_randint_non_int() {
        let key = key(0).unwrap();
        let value = randint::<_, f32>(array!([0, 10]), array!([10, 100]), None, Some(&key));
        assert!(value.is_err());
    }

    #[test]
    fn test_bernoulli_single() {
        let key = key(0).unwrap();
        let value = bernoulli(None, None, &key).unwrap();
        assert!(value.item::<bool>());
    }

    #[test]
    fn test_bernoulli_multiple() {
        let key = key(0).unwrap();
        let value = bernoulli(None, &[4], &key).unwrap();
        let expected = Array::from_slice(&[false, true, false, true], &[4]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_bernoulli_p() {
        let key = key(0).unwrap();
        let p: Array = 0.8.into();
        let value = bernoulli(&p, &[4], &key).unwrap();
        let expected = Array::from_slice(&[false, true, true, true], &[4]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_bernoulli_p_array() {
        let key = key(0).unwrap();
        let value = bernoulli(&array!([0.1, 0.5, 0.8]), None, &key).unwrap();
        let expected = Array::from_slice(&[false, true, true], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_truncated_normal_single() {
        let key = key(0).unwrap();
        let value = truncated_normal::<_, f32>(0, 10, None, &key).unwrap();
        assert_array_eq!(value, Array::from_f32(0.55), 0.01);
    }

    #[test]
    fn test_truncated_normal_multiple() {
        let key = key(0).unwrap();
        let value = truncated_normal::<_, f32>(0.0, 0.5, &[3], &key).unwrap();
        let expected = Array::from_slice(&[0.48, 0.15, 0.30], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_truncated_normal_multiple_array() {
        let key = key(0).unwrap();
        let value =
            truncated_normal::<_, f32>(array!([0.0, 0.5]), array!([0.5, 1.0]), None, &key)
                .unwrap();
        let expected = Array::from_slice(&[0.10, 0.88], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_gumbel() {
        let key = key(0).unwrap();
        let value = gumbel::<f32>(None, &key).unwrap();
        assert_array_eq!(value, Array::from_f32(0.13), 0.01);
    }

    #[test]
    fn test_logits() {
        let key = key(0).unwrap();
        let logits = Array::zeros::<u32>(&[5, 20]).unwrap();
        let result = categorical(&logits, None, None, &key).unwrap();

        assert_eq!(result.shape(), [5]);

        let expected = Array::from_slice(&[1, 1, 17, 17, 17], &[5]);
        assert_array_eq!(result, expected, 0.01);
    }

    #[test]
    fn test_logits_count() {
        let key = key(0).unwrap();
        let logits = Array::zeros::<u32>(&[5, 20]).unwrap();
        let result = categorical(&logits, None, ShapeOrCount::Count(2), &key).unwrap();

        assert_eq!(result.shape(), [5, 2]);

        let expected = Array::from_slice(&[16, 3, 14, 10, 17, 7, 6, 8, 12, 8], &[5, 2]);
        assert_array_eq!(result, expected, 0.01);
    }
}