mlx_rs/
random.rs

1//! Collection of functions related to random number generation
2
3use crate::ops::indexing::TryIndexOp;
4use crate::utils::guard::Guarded;
5use crate::utils::IntoOption;
6use crate::{error::Result, Array, ArrayElement, Stream, StreamOrDevice};
7use mach_sys::mach_time;
8use mlx_internal_macros::{default_device, generate_macro};
9use parking_lot::Mutex;
10use std::borrow::Cow;
11use std::sync::OnceLock;
12
13struct RandomState {
14    state: Array,
15}
16
17impl RandomState {
18    fn new() -> Result<Self> {
19        let now = unsafe { mach_time::mach_approximate_time() };
20        Ok(Self { state: key(now)? })
21    }
22
23    fn next(&mut self) -> Result<Array> {
24        let next = split(&self.state, 2)?;
25        self.state = next.0;
26        Ok(next.1)
27    }
28
29    fn seed(&mut self, seed: u64) -> Result<()> {
30        self.state = key(seed)?;
31        Ok(())
32    }
33}
34
35fn state() -> &'static Mutex<RandomState> {
36    static STATE: OnceLock<Mutex<RandomState>> = OnceLock::new();
37    STATE.get_or_init(|| Mutex::new(RandomState::new().unwrap()))
38}
39
40/// Use given key or generate a new one if `None`.
41fn key_or_next<'a>(key: impl Into<Option<&'a Array>>) -> Result<Cow<'a, Array>> {
42    key.into().map_or_else(
43        || {
44            let mut state = state().lock();
45            state.next().map(Cow::Owned)
46        },
47        |k| Ok(Cow::Borrowed(k)),
48    )
49}
50
51/// Seed the random number generator.
52pub fn seed(seed: u64) -> Result<()> {
53    let mut state = state().lock();
54    state.seed(seed)
55}
56
57/// Get a PRNG key from a seed.
58///
59/// Return a value that can be used as a PRNG key.  All ``random::*``
60/// functions take an optional key -- this will let you control the
61/// random number generation.
62pub fn key(seed: u64) -> Result<Array> {
63    Array::try_from_op(|res| unsafe { mlx_sys::mlx_random_key(res, seed) })
64}
65
66/// Split a PRNG key into two keys and return a tuple.
67#[default_device]
68pub fn split_device(
69    key: impl AsRef<Array>,
70    num: i32,
71    stream: impl AsRef<Stream>,
72) -> Result<(Array, Array)> {
73    let keys = Array::try_from_op(|res| unsafe {
74        mlx_sys::mlx_random_split_num(res, key.as_ref().as_ptr(), num, stream.as_ref().as_ptr())
75    })?;
76
77    Ok((keys.try_index(0)?, keys.try_index(1)?))
78}
79
80/// Generate uniformly distributed random numbers.
81/// The values are sampled uniformly in the half-open interval `[lower, upper)`.
82/// The lower and upper bound can be scalars or arrays and must be broadcastable to `shape`.
83///
84/// # Params
85///
86/// - `lower`: Lower bound of the distribution.
87/// - `upper`: Upper bound of the distribution.
88/// - `shape` (optional): Shape of the output. Default is `&[]`.
89/// - `key` (optional): A PRNG key.
90///
91/// ```rust
92/// let key = mlx_rs::random::key(0).unwrap();
93///
94/// // create an array of shape `[50]` type f32 values in the range [0, 10)
95/// let array = mlx_rs::random::uniform::<_, f32>(0, 10, &[50], &key);
96///
97/// // same, but in range [0.5, 1)
98/// let array = mlx_rs::random::uniform::<_, f32>(0.5f32, 1f32, &[50], &key);
99/// ```
100#[generate_macro(customize(root = "$crate::random"))]
101#[default_device]
102pub fn uniform_device<'a, E: Into<Array>, T: ArrayElement>(
103    lower: E,
104    upper: E,
105    #[optional] shape: impl IntoOption<&'a [i32]>,
106    #[optional] key: impl Into<Option<&'a Array>>,
107    #[optional] stream: impl AsRef<Stream>,
108) -> Result<Array> {
109    let lb: Array = lower.into();
110    let ub: Array = upper.into();
111    let shape = shape.into_option().unwrap_or(&[]);
112    let key = key_or_next(key)?;
113
114    Array::try_from_op(|res| unsafe {
115        mlx_sys::mlx_random_uniform(
116            res,
117            lb.as_ptr(),
118            ub.as_ptr(),
119            shape.as_ptr(),
120            shape.len(),
121            T::DTYPE.into(),
122            key.as_ptr(),
123            stream.as_ref().as_ptr(),
124        )
125    })
126}
127
128/// Generate normally distributed random numbers.
129///
130/// Generate an array of random numbers using the optional shape. The result
131/// will be of the given `T`. `T` must be a floating point type.
132///
133/// # Params
134///
135///  - shape: shape of the output, if `None` a single value is returned
136///  - loc: mean of the distribution, default is `0.0`
137///  - scale: standard deviation of the distribution, default is `1.0`
138///  - key: PRNG key
139///
140/// # Example
141///
142/// ```rust
143/// let key = mlx_rs::random::key(0).unwrap();
144///
145/// // generate a single f32 with normal distribution
146/// let value = mlx_rs::random::normal::<f32>(None, None, None, &key).unwrap().item::<f32>();
147///
148/// // generate an array of f32 with normal distribution in shape [10, 5]
149/// let array = mlx_rs::random::normal::<f32>(&[10, 5], None, None, &key);
150/// ```
151#[generate_macro(customize(root = "$crate::random"))]
152#[default_device]
153pub fn normal_device<'a, T: ArrayElement>(
154    #[optional] shape: impl IntoOption<&'a [i32]>,
155    #[optional] loc: impl Into<Option<f32>>,
156    #[optional] scale: impl Into<Option<f32>>,
157    #[optional] key: impl Into<Option<&'a Array>>,
158    #[optional] stream: impl AsRef<Stream>,
159) -> Result<Array> {
160    let shape = shape.into_option().unwrap_or(&[]);
161    let key = key_or_next(key)?;
162
163    Array::try_from_op(|res| unsafe {
164        mlx_sys::mlx_random_normal(
165            res,
166            shape.as_ptr(),
167            shape.len(),
168            T::DTYPE.into(),
169            loc.into().unwrap_or(0.0),
170            scale.into().unwrap_or(1.0),
171            key.as_ptr(),
172            stream.as_ref().as_ptr(),
173        )
174    })
175}
176
177/// Generate jointly-normal random samples given a mean and covariance.
178///
179/// The matrix `covariance` must be positive semi-definite. The behavior is
180/// undefined if it is not.  The only supported output type is f32.
181///
182/// # Params
183/// - `mean`: array of shape `[..., n]`, the mean of the distribution.
184/// - `covariance`: array  of shape `[..., n, n]`, the covariance matrix of the distribution. The batch shape `...` must be broadcast-compatible with that of `mean`.
185/// - `shape`: The output shape must be broadcast-compatible with `&mean.shape[..mean.shape.len()-1]` and `&covariance.shape[..covariance.shape.len()-2]`. If empty, the result shape is determined by broadcasting the batch shapes of `mean` and `covariance`.
186/// - `key`: PRNG key.
187#[generate_macro(customize(root = "$crate::random"))]
188#[default_device(device = "cpu")] // TODO: not supported on GPU yet
189pub fn multivariate_normal_device<'a, T: ArrayElement>(
190    mean: impl AsRef<Array>,
191    covariance: impl AsRef<Array>,
192    #[optional] shape: impl IntoOption<&'a [i32]>,
193    #[optional] key: impl Into<Option<&'a Array>>,
194    #[optional] stream: impl AsRef<Stream>,
195) -> Result<Array> {
196    let shape = shape.into_option().unwrap_or(&[]);
197    let key = key_or_next(key)?;
198
199    Array::try_from_op(|res| unsafe {
200        mlx_sys::mlx_random_multivariate_normal(
201            res,
202            mean.as_ref().as_ptr(),
203            covariance.as_ref().as_ptr(),
204            shape.as_ptr(),
205            shape.len(),
206            T::DTYPE.into(),
207            key.as_ptr(),
208            stream.as_ref().as_ptr(),
209        )
210    })
211}
212
213/// Generate random integers from the given interval (`lower:` and `upper:`).
214///
215/// The values are sampled with equal probability from the integers in
216/// half-open interval `[lb, ub)`. The lower and upper bound can be
217/// scalars or arrays and must be roadcastable to `shape`.
218///
219/// ```rust
220/// use mlx_rs::{array, random};
221///
222/// let key = random::key(0).unwrap();
223///
224/// // generate an array of Int values, one in the range [0, 20) and one in the range [10, 100)
225/// let array = random::randint::<_, i32>(array!([0, 20]), array!([10, 100]), None, &key);
226/// ```
227#[generate_macro(customize(root = "$crate::random"))]
228#[default_device]
229pub fn randint_device<'a, E: Into<Array>, T: ArrayElement>(
230    lower: E,
231    upper: E,
232    #[optional] shape: impl IntoOption<&'a [i32]>,
233    #[optional] key: impl Into<Option<&'a Array>>,
234    #[optional] stream: impl AsRef<Stream>,
235) -> Result<Array> {
236    let lb: Array = lower.into();
237    let ub: Array = upper.into();
238    let shape = shape.into_option().unwrap_or(lb.shape());
239    let key = key_or_next(key)?;
240
241    Array::try_from_op(|res| unsafe {
242        mlx_sys::mlx_random_randint(
243            res,
244            lb.as_ptr(),
245            ub.as_ptr(),
246            shape.as_ptr(),
247            shape.len(),
248            T::DTYPE.into(),
249            key.as_ptr(),
250            stream.as_ref().as_ptr(),
251        )
252    })
253}
254
255/// Generate Bernoulli random values with a given `p` value.
256///
257/// The values are sampled from the bernoulli distribution with parameter
258/// `p`. The parameter `p` must have a floating point type and
259/// must be broadcastable to `shape`.
260///
261/// ```rust
262/// use mlx_rs::{array, Array, random};
263///
264/// let key = random::key(0).unwrap();
265///
266/// // generate a single random Bool with p = 0.8
267/// let p: Array = 0.8.into();
268/// let value = random::bernoulli(&p, None, &key);
269///
270/// // generate an array of shape [50, 2] of random Bool with p = 0.8
271/// let array = random::bernoulli(&p, &[50, 2], &key);
272///
273/// // generate an array of [3] Bool with the given p values
274/// let array = random::bernoulli(&array!([0.1, 0.5, 0.8]), None, &key);
275/// ```
276#[generate_macro(customize(root = "$crate::random"))]
277#[default_device]
278pub fn bernoulli_device<'a>(
279    #[optional] p: impl Into<Option<&'a Array>>,
280    #[optional] shape: impl IntoOption<&'a [i32]>,
281    #[optional] key: impl Into<Option<&'a Array>>,
282    #[optional] stream: impl AsRef<Stream>,
283) -> Result<Array> {
284    let default_array = Array::from_f32(0.5);
285    let p = p.into().unwrap_or(&default_array);
286
287    let shape = shape.into_option().unwrap_or(p.shape());
288    let key = key_or_next(key)?;
289
290    Array::try_from_op(|res| unsafe {
291        mlx_sys::mlx_random_bernoulli(
292            res,
293            p.as_ptr(),
294            shape.as_ptr(),
295            shape.len(),
296            key.as_ptr(),
297            stream.as_ref().as_ptr(),
298        )
299    })
300}
301
302/// Generate values from a truncated normal distribution between `low` and `high`.
303///
304/// The values are sampled from the truncated normal distribution
305/// on the domain `(lower, upper)`. The bounds `lower` and `upper`
306/// can be scalars or arrays and must be broadcastable to `shape`.
307///
308/// ```rust
309/// use mlx_rs::{array, random};
310///
311/// let key = random::key(0).unwrap();
312///
313/// // generate an array of two Float values, one in the range 0 ..< 10
314/// // and one in the range 10 ..< 100
315/// let value = random::truncated_normal::<_, f32>(array!([0, 10]), array!([10, 100]), None, &key);
316/// ```
317#[generate_macro(customize(root = "$crate::random"))]
318#[default_device]
319pub fn truncated_normal_device<'a, E: Into<Array>, T: ArrayElement>(
320    lower: E,
321    upper: E,
322    #[optional] shape: impl IntoOption<&'a [i32]>,
323    #[optional] key: impl Into<Option<&'a Array>>,
324    #[optional] stream: impl AsRef<Stream>,
325) -> Result<Array> {
326    let lb: Array = lower.into();
327    let ub: Array = upper.into();
328    let shape = shape.into_option().unwrap_or(lb.shape());
329    let key = key_or_next(key)?;
330
331    Array::try_from_op(|res| unsafe {
332        mlx_sys::mlx_random_truncated_normal(
333            res,
334            lb.as_ptr(),
335            ub.as_ptr(),
336            shape.as_ptr(),
337            shape.len(),
338            T::DTYPE.into(),
339            key.as_ptr(),
340            stream.as_ref().as_ptr(),
341        )
342    })
343}
344
345/// Sample from the standard Gumbel distribution.
346///
347/// The values are sampled from a standard Gumbel distribution
348/// which CDF `exp(-exp(-x))`.
349///
350/// ```rust
351/// let key = mlx_rs::random::key(0).unwrap();
352///
353/// // generate a single Float with Gumbel distribution
354/// let value = mlx_rs::random::gumbel::<f32>(None, &key).unwrap().item::<f32>();
355///
356/// // generate an array of Float with Gumbel distribution in shape [10, 5]
357/// let array = mlx_rs::random::gumbel::<f32>(&[10, 5], &key);
358/// ```
359#[generate_macro(customize(root = "$crate::random"))]
360#[default_device]
361pub fn gumbel_device<'a, T: ArrayElement>(
362    #[optional] shape: impl IntoOption<&'a [i32]>,
363    #[optional] key: impl Into<Option<&'a Array>>,
364    #[optional] stream: impl AsRef<Stream>,
365) -> Result<Array> {
366    let shape = shape.into_option().unwrap_or(&[]);
367    let key = key_or_next(key)?;
368
369    Array::try_from_op(|res| unsafe {
370        mlx_sys::mlx_random_gumbel(
371            res,
372            shape.as_ptr(),
373            shape.len(),
374            T::DTYPE.into(),
375            key.as_ptr(),
376            stream.as_ref().as_ptr(),
377        )
378    })
379}
380
/// Controls the output of categorical sampling: either an explicit output shape
/// or a number of samples to draw per distribution.
#[derive(Debug, Clone, Copy)]
pub enum ShapeOrCount<'a> {
    /// Explicit output shape. It must be broadcast compatible with the logits
    /// shape with the sampling axis removed.
    Shape(&'a [i32]),

    /// Number of samples to draw from each categorical distribution.
    Count(i32),
}
390
391/// Sample from a categorical distribution.
392///
393/// The values are sampled from the categorical distribution specified by
394/// the unnormalized values in `logits`.   If the `shape` is not specified
395/// the result shape will be the same shape as `logits` with the `axis`
396/// dimension removed.
397///
398/// /// # Params
399/// # Params
400///
401/// - `logits`: The *unnormalized* categorical distribution(s).
402/// - `axis`(optional): The axis which specifies the distribution. Default is `-1`.
403/// - `shape_or_count`(optional):
404/// - - `Shape`: The shape of the output. This must be broadcast compatible with `logits.shape` with the `axis` dimension removed.
405/// - - `Count`: The number of samples to draw from each of the categorical distributions in `logits`. The output will have the number of samples in the last dimension.
406/// - `key` (optional): A PRNG key.
407///
408/// # Example
409///
410/// ```rust
411/// let key = mlx_rs::random::key(0).unwrap();
412///
413/// let logits = mlx_rs::Array::zeros::<u32>(&[5, 20]).unwrap();
414///
415/// // produces Array of u32 shape &[5]
416/// let result = mlx_rs::random::categorical(&logits, None, None, &key);
417/// ```
418#[generate_macro(customize(root = "$crate::random"))]
419#[default_device]
420pub fn categorical_device<'a>(
421    logits: impl AsRef<Array>,
422    #[optional] axis: impl Into<Option<i32>>,
423    #[optional] shape_or_count: impl Into<Option<ShapeOrCount<'a>>>,
424    #[optional] key: impl Into<Option<&'a Array>>,
425    #[optional] stream: impl AsRef<Stream>,
426) -> Result<Array> {
427    let axis = axis.into().unwrap_or(-1);
428    let key = key_or_next(key)?;
429
430    match shape_or_count.into() {
431        Some(ShapeOrCount::Shape(shape)) => Array::try_from_op(|res| unsafe {
432            mlx_sys::mlx_random_categorical_shape(
433                res,
434                logits.as_ref().as_ptr(),
435                axis,
436                shape.as_ptr(),
437                shape.len(),
438                key.as_ptr(),
439                stream.as_ref().as_ptr(),
440            )
441        }),
442        Some(ShapeOrCount::Count(num_samples)) => Array::try_from_op(|res| unsafe {
443            mlx_sys::mlx_random_categorical_num_samples(
444                res,
445                logits.as_ref().as_ptr(),
446                axis,
447                num_samples,
448                key.as_ptr(),
449                stream.as_ref().as_ptr(),
450            )
451        }),
452        None => Array::try_from_op(|res| unsafe {
453            mlx_sys::mlx_random_categorical(
454                res,
455                logits.as_ref().as_ptr(),
456                axis,
457                key.as_ptr(),
458                stream.as_ref().as_ptr(),
459            )
460        }),
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{array, assert_array_eq};
    use float_eq::float_eq;

    /// Serialises tests that reseed or draw from the global PRNG state.
    /// `cargo test` runs tests on parallel threads, so without this lock a test
    /// that samples without a key could advance the global state in the middle
    /// of `test_global_rng` and make it fail spuriously.
    static GLOBAL_RNG_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `GLOBAL_RNG_LOCK`, ignoring poisoning: the guarded data is `()`,
    /// so a panic in another test leaves nothing inconsistent behind.
    fn lock_global_rng() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_RNG_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_global_rng() {
        let _guard = lock_global_rng();

        seed(3).unwrap();
        let a = uniform::<_, f32>(0, 1, None, None).unwrap();
        let b = uniform::<_, f32>(0, 1, None, None).unwrap();

        seed(3).unwrap();
        let x = uniform::<_, f32>(0, 1, None, None).unwrap();
        let y = uniform::<_, f32>(0, 1, None, None).unwrap();

        assert_array_eq!(a, x, 0.01);
        assert_array_eq!(b, y, 0.01);
    }

    #[test]
    fn test_key() {
        let k1 = key(0).unwrap();
        let k2 = key(0).unwrap();
        assert!(k1 == k2);

        let k2 = key(1).unwrap();
        assert!(k1 != k2);
    }

    #[test]
    fn test_split() {
        let key = key(0).unwrap();

        let (k1, k2) = split(&key, 2).unwrap();
        assert!(k1 != k2);

        let (r1, r2) = split(&key, 2).unwrap();
        assert!(r1 == k1);
        assert!(r2 == k2);
    }

    #[test]
    fn test_uniform_no_seed() {
        // Draws from the global state, so it must not interleave with `test_global_rng`.
        let _guard = lock_global_rng();

        let value = uniform::<_, f32>(0, 10, &[3], None).unwrap();
        assert_eq!(value.shape(), &[3]);
    }

    #[test]
    fn test_uniform_single() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(0, 10, None, Some(&key)).unwrap();
        // `float_eq!` only returns a bool; it must be asserted to check anything.
        assert!(float_eq!(value.item::<f32>(), 4.18, abs <= 0.01));
    }

    #[test]
    fn test_uniform_multiple() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(0, 10, &[3], Some(&key)).unwrap();
        let expected = Array::from_slice(&[9.65, 3.14, 6.33], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_uniform_multiple_array() {
        let key = key(0).unwrap();
        let value = uniform::<_, f32>(&[0, 10], &[10, 100], &[2], Some(&key)).unwrap();
        let expected = Array::from_slice(&[2.16, 82.37], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_uniform_non_float() {
        let key = key(0).unwrap();
        let value = uniform::<_, i32>(&[0, 10], &[10, 100], &[2], Some(&key));
        assert!(value.is_err());
    }

    #[test]
    fn test_normal() {
        let key = key(0).unwrap();
        let value = normal::<f32>(None, None, None, &key).unwrap();
        // `float_eq!` only returns a bool; it must be asserted to check anything.
        assert!(float_eq!(value.item::<f32>(), -0.20, abs <= 0.01));
    }

    #[test]
    fn test_normal_non_float() {
        let key = key(0).unwrap();
        let value = normal::<i32>(None, None, None, &key);
        assert!(value.is_err());
    }

    #[test]
    fn test_multivariate_normal() {
        let key = key(0).unwrap();
        let mean = Array::from_slice(&[0.0, 0.0], &[2]);
        let covariance = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let a = multivariate_normal::<f32>(&mean, &covariance, &[3], &key).unwrap();
        assert!(a.shape() == [3, 2]);
    }

    #[test]
    fn test_randint_single() {
        let key = key(0).unwrap();
        let value = randint::<_, i32>(0, 100, None, Some(&key)).unwrap();
        assert_eq!(value.item::<i32>(), 41);
    }

    #[test]
    fn test_randint_multiple() {
        let key = key(0).unwrap();
        let value =
            randint::<_, i32>(array!([0, 10]), array!([10, 100]), None, Some(&key)).unwrap();
        let expected = Array::from_slice(&[2, 82], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_randint_non_int() {
        let key = key(0).unwrap();
        let value = randint::<_, f32>(array!([0, 10]), array!([10, 100]), None, Some(&key));
        assert!(value.is_err());
    }

    #[test]
    fn test_bernoulli_single() {
        let key = key(0).unwrap();
        let value = bernoulli(None, None, &key).unwrap();
        assert!(value.item::<bool>());
    }

    #[test]
    fn test_bernoulli_multiple() {
        let key = key(0).unwrap();
        let value = bernoulli(None, &[4], &key).unwrap();
        let expected = Array::from_slice(&[false, true, false, true], &[4]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_bernoulli_p() {
        let key = key(0).unwrap();
        let p: Array = 0.8.into();
        let value = bernoulli(&p, &[4], &key).unwrap();
        let expected = Array::from_slice(&[false, true, true, true], &[4]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_bernoulli_p_array() {
        let key = key(0).unwrap();
        let value = bernoulli(&array!([0.1, 0.5, 0.8]), None, &key).unwrap();
        let expected = Array::from_slice(&[false, true, true], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_truncated_normal_single() {
        let key = key(0).unwrap();
        let value = truncated_normal::<_, f32>(0, 10, None, &key).unwrap();
        assert_array_eq!(value, Array::from_f32(0.55), 0.01);
    }

    #[test]
    fn test_truncated_normal_multiple() {
        let key = key(0).unwrap();
        let value = truncated_normal::<_, f32>(0.0, 0.5, &[3], &key).unwrap();
        let expected = Array::from_slice(&[0.48, 0.15, 0.30], &[3]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_truncated_normal_multiple_array() {
        let key = key(0).unwrap();
        let value =
            truncated_normal::<_, f32>(array!([0.0, 0.5]), array!([0.5, 1.0]), None, &key).unwrap();
        let expected = Array::from_slice(&[0.10, 0.88], &[2]);

        assert_array_eq!(value, expected, 0.01);
    }

    #[test]
    fn test_gumbel() {
        let key = key(0).unwrap();
        let value = gumbel::<f32>(None, &key).unwrap();
        assert_array_eq!(value, Array::from_f32(0.13), 0.01);
    }

    #[test]
    fn test_logits() {
        let key = key(0).unwrap();
        let logits = Array::zeros::<u32>(&[5, 20]).unwrap();
        let result = categorical(&logits, None, None, &key).unwrap();

        assert_eq!(result.shape(), [5]);

        let expected = Array::from_slice(&[1, 1, 17, 17, 17], &[5]);
        assert_array_eq!(result, expected, 0.01);
    }

    #[test]
    fn test_logits_count() {
        let key = key(0).unwrap();
        let logits = Array::zeros::<u32>(&[5, 20]).unwrap();
        let result = categorical(&logits, None, ShapeOrCount::Count(2), &key).unwrap();

        assert_eq!(result.shape(), [5, 2]);

        let expected = Array::from_slice(&[16, 3, 14, 10, 17, 7, 6, 8, 12, 8], &[5, 2]);
        assert_array_eq!(result, expected, 0.01);
    }
}