1use crate::ops::indexing::TryIndexOp;
4use crate::utils::guard::Guarded;
5use crate::utils::IntoOption;
6use crate::{error::Result, Array, ArrayElement, Stream};
7use mach_sys::mach_time;
8use mlx_internal_macros::{default_device, generate_macro};
9use parking_lot::Mutex;
10use std::borrow::Cow;
11use std::cell::RefCell;
12use std::sync::OnceLock;
13
/// Process-wide fallback PRNG state; created lazily on first use by `global_state`.
static GLOBAL_STATE: OnceLock<Mutex<RandomState>> = OnceLock::new();

thread_local! {
    /// Per-thread override installed by `with_random_state`. When it is `Some`, keys for
    /// random operations without an explicit key are drawn from it instead of `GLOBAL_STATE`.
    static TASK_LOCAL_STATE: RefCell<Option<RandomState>> = const { RefCell::new(None) };
}
19
20#[derive(Debug, Clone)]
22pub struct RandomState {
23 state: Array,
24}
25
26impl RandomState {
27 fn new() -> Result<Self> {
28 let now = unsafe { mach_time::mach_approximate_time() };
29 Ok(Self { state: key(now)? })
30 }
31
32 fn next(&mut self) -> Result<Array> {
33 let next = split(&self.state, 2)?;
34 self.state = next.0;
35 Ok(next.1)
36 }
37
38 fn seed(&mut self, seed: u64) -> Result<()> {
39 self.state = key(seed)?;
40 Ok(())
41 }
42}
43
44fn global_state() -> &'static Mutex<RandomState> {
45 GLOBAL_STATE.get_or_init(|| Mutex::new(RandomState::new().unwrap()))
46}
47
48fn resolve_task_local_key() -> Option<Result<Array>> {
51 TASK_LOCAL_STATE.with_borrow_mut(|state| state.as_mut().map(|s| s.next()))
52}
53
54fn resolve_global_key() -> Result<Array> {
55 let mut state = global_state().lock();
56 state.next()
57}
58
59fn resolve<'a>(key: impl Into<Option<&'a Array>>) -> Result<Cow<'a, Array>> {
61 key.into().map_or_else(
62 || {
63 resolve_task_local_key()
64 .unwrap_or_else(resolve_global_key)
65 .map(Cow::Owned)
66 },
67 |k| Ok(Cow::Borrowed(k)),
68 )
69}
70
71pub fn with_random_state<F, T>(state: RandomState, f: F) -> T
73where
74 F: FnOnce() -> T,
75{
76 let prev_state = TASK_LOCAL_STATE.with_borrow_mut(|s| s.replace(state));
77
78 let result = f();
79
80 TASK_LOCAL_STATE.with_borrow_mut(|s| {
81 *s = prev_state;
82 });
83
84 result
85}
86
87pub fn seed(seed: u64) -> Result<()> {
89 let mut state = global_state().lock();
90 state.seed(seed)
91}
92
/// Creates a PRNG key array from `seed`.
///
/// # Errors
///
/// Returns an error if the underlying `mlx_random_key` call fails.
pub fn key(seed: u64) -> Result<Array> {
    // SAFETY: `res` is the out-parameter supplied by `try_from_op`; `seed` is passed by value.
    Array::try_from_op(|res| unsafe { mlx_sys::mlx_random_key(res, seed) })
}
101
102#[default_device]
104pub fn split_device(
105 key: impl AsRef<Array>,
106 num: i32,
107 stream: impl AsRef<Stream>,
108) -> Result<(Array, Array)> {
109 let keys = Array::try_from_op(|res| unsafe {
110 mlx_sys::mlx_random_split_num(res, key.as_ref().as_ptr(), num, stream.as_ref().as_ptr())
111 })?;
112
113 Ok((keys.try_index(0)?, keys.try_index(1)?))
114}
115
116#[generate_macro(customize(root = "$crate::random"))]
137#[default_device]
138pub fn uniform_device<'a, E: Into<Array>, T: ArrayElement>(
139 lower: E,
140 upper: E,
141 #[optional] shape: impl IntoOption<&'a [i32]>,
142 #[optional] key: impl Into<Option<&'a Array>>,
143 #[optional] stream: impl AsRef<Stream>,
144) -> Result<Array> {
145 let lb: Array = lower.into();
146 let ub: Array = upper.into();
147 let shape = shape.into_option().unwrap_or(&[]);
148 let key = resolve(key)?;
149
150 Array::try_from_op(|res| unsafe {
151 mlx_sys::mlx_random_uniform(
152 res,
153 lb.as_ptr(),
154 ub.as_ptr(),
155 shape.as_ptr(),
156 shape.len(),
157 T::DTYPE.into(),
158 key.as_ptr(),
159 stream.as_ref().as_ptr(),
160 )
161 })
162}
163
164#[generate_macro(customize(root = "$crate::random"))]
188#[default_device]
189pub fn normal_device<'a, T: ArrayElement>(
190 #[optional] shape: impl IntoOption<&'a [i32]>,
191 #[optional] loc: impl Into<Option<f32>>,
192 #[optional] scale: impl Into<Option<f32>>,
193 #[optional] key: impl Into<Option<&'a Array>>,
194 #[optional] stream: impl AsRef<Stream>,
195) -> Result<Array> {
196 let shape = shape.into_option().unwrap_or(&[]);
197 let key = resolve(key)?;
198
199 Array::try_from_op(|res| unsafe {
200 mlx_sys::mlx_random_normal(
201 res,
202 shape.as_ptr(),
203 shape.len(),
204 T::DTYPE.into(),
205 loc.into().unwrap_or(0.0),
206 scale.into().unwrap_or(1.0),
207 key.as_ptr(),
208 stream.as_ref().as_ptr(),
209 )
210 })
211}
212
213#[generate_macro(customize(root = "$crate::random"))]
224#[default_device(device = "cpu")] pub fn multivariate_normal_device<'a, T: ArrayElement>(
226 mean: impl AsRef<Array>,
227 covariance: impl AsRef<Array>,
228 #[optional] shape: impl IntoOption<&'a [i32]>,
229 #[optional] key: impl Into<Option<&'a Array>>,
230 #[optional] stream: impl AsRef<Stream>,
231) -> Result<Array> {
232 let shape = shape.into_option().unwrap_or(&[]);
233 let key = resolve(key)?;
234
235 Array::try_from_op(|res| unsafe {
236 mlx_sys::mlx_random_multivariate_normal(
237 res,
238 mean.as_ref().as_ptr(),
239 covariance.as_ref().as_ptr(),
240 shape.as_ptr(),
241 shape.len(),
242 T::DTYPE.into(),
243 key.as_ptr(),
244 stream.as_ref().as_ptr(),
245 )
246 })
247}
248
249#[generate_macro(customize(root = "$crate::random"))]
264#[default_device]
265pub fn randint_device<'a, E: Into<Array>, T: ArrayElement>(
266 lower: E,
267 upper: E,
268 #[optional] shape: impl IntoOption<&'a [i32]>,
269 #[optional] key: impl Into<Option<&'a Array>>,
270 #[optional] stream: impl AsRef<Stream>,
271) -> Result<Array> {
272 let lb: Array = lower.into();
273 let ub: Array = upper.into();
274 let shape = shape.into_option().unwrap_or(lb.shape());
275 let key = resolve(key)?;
276
277 Array::try_from_op(|res| unsafe {
278 mlx_sys::mlx_random_randint(
279 res,
280 lb.as_ptr(),
281 ub.as_ptr(),
282 shape.as_ptr(),
283 shape.len(),
284 T::DTYPE.into(),
285 key.as_ptr(),
286 stream.as_ref().as_ptr(),
287 )
288 })
289}
290
291#[generate_macro(customize(root = "$crate::random"))]
313#[default_device]
314pub fn bernoulli_device<'a>(
315 #[optional] p: impl Into<Option<&'a Array>>,
316 #[optional] shape: impl IntoOption<&'a [i32]>,
317 #[optional] key: impl Into<Option<&'a Array>>,
318 #[optional] stream: impl AsRef<Stream>,
319) -> Result<Array> {
320 let default_array = Array::from_f32(0.5);
321 let p = p.into().unwrap_or(&default_array);
322
323 let shape = shape.into_option().unwrap_or(p.shape());
324 let key = resolve(key)?;
325
326 Array::try_from_op(|res| unsafe {
327 mlx_sys::mlx_random_bernoulli(
328 res,
329 p.as_ptr(),
330 shape.as_ptr(),
331 shape.len(),
332 key.as_ptr(),
333 stream.as_ref().as_ptr(),
334 )
335 })
336}
337
338#[generate_macro(customize(root = "$crate::random"))]
354#[default_device]
355pub fn truncated_normal_device<'a, E: Into<Array>, T: ArrayElement>(
356 lower: E,
357 upper: E,
358 #[optional] shape: impl IntoOption<&'a [i32]>,
359 #[optional] key: impl Into<Option<&'a Array>>,
360 #[optional] stream: impl AsRef<Stream>,
361) -> Result<Array> {
362 let lb: Array = lower.into();
363 let ub: Array = upper.into();
364 let shape = shape.into_option().unwrap_or(lb.shape());
365 let key = resolve(key)?;
366
367 Array::try_from_op(|res| unsafe {
368 mlx_sys::mlx_random_truncated_normal(
369 res,
370 lb.as_ptr(),
371 ub.as_ptr(),
372 shape.as_ptr(),
373 shape.len(),
374 T::DTYPE.into(),
375 key.as_ptr(),
376 stream.as_ref().as_ptr(),
377 )
378 })
379}
380
381#[generate_macro(customize(root = "$crate::random"))]
396#[default_device]
397pub fn gumbel_device<'a, T: ArrayElement>(
398 #[optional] shape: impl IntoOption<&'a [i32]>,
399 #[optional] key: impl Into<Option<&'a Array>>,
400 #[optional] stream: impl AsRef<Stream>,
401) -> Result<Array> {
402 let shape = shape.into_option().unwrap_or(&[]);
403 let key = resolve(key)?;
404
405 Array::try_from_op(|res| unsafe {
406 mlx_sys::mlx_random_gumbel(
407 res,
408 shape.as_ptr(),
409 shape.len(),
410 T::DTYPE.into(),
411 key.as_ptr(),
412 stream.as_ref().as_ptr(),
413 )
414 })
415}
416
/// How `categorical_device` sizes its output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShapeOrCount<'a> {
    /// Sample with an explicit output shape.
    Shape(&'a [i32]),

    /// Draw this many samples per distribution. For example, `[5, 20]` logits sampled along
    /// the last axis with `Count(2)` give a `[5, 2]` result.
    Count(i32),
}
426
427#[generate_macro(customize(root = "$crate::random"))]
455#[default_device]
456pub fn categorical_device<'a>(
457 logits: impl AsRef<Array>,
458 #[optional] axis: impl Into<Option<i32>>,
459 #[optional] shape_or_count: impl Into<Option<ShapeOrCount<'a>>>,
460 #[optional] key: impl Into<Option<&'a Array>>,
461 #[optional] stream: impl AsRef<Stream>,
462) -> Result<Array> {
463 let axis = axis.into().unwrap_or(-1);
464 let key = resolve(key)?;
465
466 match shape_or_count.into() {
467 Some(ShapeOrCount::Shape(shape)) => Array::try_from_op(|res| unsafe {
468 mlx_sys::mlx_random_categorical_shape(
469 res,
470 logits.as_ref().as_ptr(),
471 axis,
472 shape.as_ptr(),
473 shape.len(),
474 key.as_ptr(),
475 stream.as_ref().as_ptr(),
476 )
477 }),
478 Some(ShapeOrCount::Count(num_samples)) => Array::try_from_op(|res| unsafe {
479 mlx_sys::mlx_random_categorical_num_samples(
480 res,
481 logits.as_ref().as_ptr(),
482 axis,
483 num_samples,
484 key.as_ptr(),
485 stream.as_ref().as_ptr(),
486 )
487 }),
488 None => Array::try_from_op(|res| unsafe {
489 mlx_sys::mlx_random_categorical(
490 res,
491 logits.as_ref().as_ptr(),
492 axis,
493 key.as_ptr(),
494 stream.as_ref().as_ptr(),
495 )
496 }),
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{array, assert_array_eq};
    use float_eq::assert_float_eq;

    // The test harness runs tests on parallel threads. Tests that draw from the global RNG
    // (no explicit key and no `with_random_state`) advance shared state. Without this lock,
    // one of them could run between `seed(3)` and the draws in `test_global_rng` and make
    // that test fail intermittently. Only tests in this module are serialized by it.
    static GLOBAL_RNG_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `GLOBAL_RNG_LOCK`. Poisoning from an earlier failed test is ignored because
    /// the lock guards no data.
    fn lock_global_rng() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_RNG_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_global_rng() {
        let _guard = lock_global_rng();

        seed(3).unwrap();
        let a = uniform::<_, f32>(0, 1, None, None).unwrap();
        let b = uniform::<_, f32>(0, 1, None, None).unwrap();

        seed(3).unwrap();
        let x = uniform::<_, f32>(0, 1, None, None).unwrap();
        let y = uniform::<_, f32>(0, 1, None, None).unwrap();

        assert_array_eq!(a, x, 0.01);
        assert_array_eq!(b, y, 0.01);
    }

    #[test]
    fn test_key() {
        let k1 = key(0).unwrap();
        let k2 = key(0).unwrap();
        assert_eq!(k1, k2);

        let k2 = key(1).unwrap();
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_split() {
        let key = key(0).unwrap();

        let (k1, k2) = split(&key, 2).unwrap();
        assert_ne!(k1, k2);

        let (r1, r2) = split(&key, 2).unwrap();
        assert_eq!(r1, k1);
        assert_eq!(r2, k2);
    }

    #[test]
    fn test_uniform_no_seed() {
        let _guard = lock_global_rng();

        let value = uniform::<_, f32>(0, 10, &[3], None).unwrap();
        assert_eq!(value.shape(), &[3]);
    }

    #[test]
    fn test_uniform_single() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(0, 10, None, Some(&key)).unwrap();
        // `float_eq!` only returns a bool; `assert_float_eq!` actually checks it.
        assert_float_eq!(value.item::<f32>(), 4.18, abs <= 0.01);
    }

    #[test]
    fn test_uniform_multiple() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(0, 10, &[3], Some(&key)).unwrap();
        let expected = Array::from_slice(&[9.65, 3.14, 6.33], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_uniform_multiple_array() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(&[0, 10], &[10, 100], &[2], Some(&key)).unwrap();
        let expected = Array::from_slice(&[2.16, 82.37], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_uniform_non_float() {
        let key = key(0).unwrap();
        let value = uniform::<_, i32>(&[0, 10], &[10, 100], &[2], Some(&key));
        assert!(value.is_err());
    }

    #[test]
    fn test_normal() {
        let key = key(0).unwrap();
        let value = normal::<f32>(None, None, None, &key).unwrap();
        assert_float_eq!(value.item::<f32>(), -0.20, abs <= 0.01);
    }

    #[test]
    fn test_normal_non_float() {
        let key = key(0).unwrap();
        let value = normal::<i32>(None, None, None, &key);
        assert!(value.is_err());
    }

    #[test]
    fn test_multivariate_normal() {
        let key = key(0).unwrap();
        let mean = Array::from_slice(&[0.0, 0.0], &[2]);
        let covariance = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let a = multivariate_normal::<f32>(&mean, &covariance, &[3], &key).unwrap();
        assert_eq!(a.shape(), [3, 2]);
    }

    #[test]
    fn test_randint_single() {
        let key = key(0).unwrap();
        let value = randint::<_, i32>(0, 100, None, Some(&key)).unwrap();
        assert_eq!(value.item::<i32>(), 41);
    }

    #[test]
    fn test_randint_multiple() {
        let key = key(0).unwrap();
        let value =
            randint::<_, i32>(array!([0, 10]), array!([10, 100]), None, Some(&key)).unwrap();
        let expected = Array::from_slice(&[2, 82], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_randint_non_int() {
        let key = key(0).unwrap();
        let value = randint::<_, f32>(array!([0, 10]), array!([10, 100]), None, Some(&key));
        assert!(value.is_err());
    }

    #[test]
    fn test_bernoulli_single() {
        let key = key(0).unwrap();
        let value = bernoulli(None, None, &key).unwrap();
        assert!(value.item::<bool>());
    }

    #[test]
    fn test_bernoulli_multiple() {
        let key = key(0).unwrap();
        let value = bernoulli(None, &[4], &key).unwrap();
        let expected = Array::from_slice(&[false, true, false, true], &[4]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_bernoulli_p() {
        let key = key(0).unwrap();
        let p: Array = 0.8.into();
        let value = bernoulli(&p, &[4], &key).unwrap();
        let expected = Array::from_slice(&[false, true, true, true], &[4]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_bernoulli_p_array() {
        let key = key(0).unwrap();
        let value = bernoulli(&array!([0.1, 0.5, 0.8]), None, &key).unwrap();
        let expected = Array::from_slice(&[false, true, true], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_truncated_normal_single() {
        let key = key(0).unwrap();
        let value = truncated_normal::<_, f32>(0, 10, None, &key).unwrap();
        assert_array_eq!(value, Array::from_f32(0.55), 0.01);
    }

    #[test]
    fn test_truncated_normal_multiple() {
        let key = key(0).unwrap();
        let value = truncated_normal::<_, f32>(0.0, 0.5, &[3], &key).unwrap();
        let expected = Array::from_slice(&[0.48, 0.15, 0.30], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_truncated_normal_multiple_array() {
        let key = key(0).unwrap();
        let value =
            truncated_normal::<_, f32>(array!([0.0, 0.5]), array!([0.5, 1.0]), None, &key).unwrap();
        let expected = Array::from_slice(&[0.10, 0.88], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_gumbel() {
        let key = key(0).unwrap();
        let value = gumbel::<f32>(None, &key).unwrap();
        assert_array_eq!(value, Array::from_f32(0.13), 0.01);
    }

    #[test]
    fn test_logits() {
        let key = key(0).unwrap();
        let logits = Array::zeros::<u32>(&[5, 20]).unwrap();
        let result = categorical(&logits, None, None, &key).unwrap();

        assert_eq!(result.shape(), [5]);

        let expected = Array::from_slice(&[1, 1, 17, 17, 17], &[5]);
        assert_array_eq!(result, expected, 0.01);
    }

    #[test]
    fn test_logits_count() {
        let key = key(0).unwrap();
        let logits = Array::zeros::<u32>(&[5, 20]).unwrap();
        let result = categorical(&logits, None, ShapeOrCount::Count(2), &key).unwrap();

        assert_eq!(result.shape(), [5, 2]);

        let expected = Array::from_slice(&[16, 3, 14, 10, 17, 7, 6, 8, 12, 8], &[5, 2]);
        assert_array_eq!(result, expected, 0.01);
    }

    #[test]
    fn test_random_seed_same() {
        let seed = 23;
        let mut results = Vec::new();
        let f = || {
            uniform::<_, f32>(0.0, 1.0, &[10, 10], None)?
                .sum(None)?
                .try_item::<f32>()
        };
        for _ in 0..10 {
            let mut state = RandomState::new().unwrap();
            state.seed(seed).unwrap();
            let result = with_random_state(state, f).unwrap();
            results.push(result);
        }

        let first = results[0];
        for result in &results[1..] {
            assert_float_eq!(
                first,
                *result,
                abs <= 0.01,
                "Results should be equal for the same seed"
            );
        }
    }
}